Cytnx v0.7.6
Loading...
Searching...
No Matches
MPS.hpp
Go to the documentation of this file.
1#ifndef _H_MPS_
2#define _H_MPS_
3
4#include "cytnx_error.hpp"
5#include "Device.hpp"
7#include "UniTensor.hpp"
8#include <iostream>
9#include <fstream>
10
11#include "utils/vec_clone.hpp"
12//#include "utils/dynamic_arg_resolver.hpp"
13//#include "linalg.hpp"
14#include "Accessor.hpp"
15#include <vector>
16#include <initializer_list>
17#include <string>
18#include "Scalar.hpp"
19
20namespace cytnx{
21 namespace tn_algo{
22
/// Catalogue of the integer ids that identify each MPS flavour.
class MPSType_class {
  public:
    /// Type ids; these are the values stored in MPS_impl::mps_type_id.
    enum : int {
        Void = -99,      ///< id assigned by the MPS_impl default constructor
        RegularMPS = 0,  ///< finite-size MPS
        iMPS = 1         ///< infinite-size MPS
    };

    /// Maps a type id to a std::string (defined out of line); used by MPS::mps_type_str().
    std::string getname(const int &mps_type);
};

/// Shared instance, so ids can be written as MPSType.RegularMPS etc.
extern MPSType_class MPSType;
34
35
36
37 class MPS_impl: public intrusive_ptr_base<MPS_impl>{
38 private:
39
40 public:
41 friend class MPS;
42
43 //std::vector<cytnx_int64> phys_dim;
44 cytnx_int64 virt_dim; // maximum
45 cytnx_int64 S_loc;
46
47 int mps_type_id;
48
49 MPS_impl(): mps_type_id(MPSType.Void){}
50
51
52 // place holder for the tensors:
53 std::vector<UniTensor> _TNs;
54
55
56 std::vector<UniTensor>& get_data(){
57 return this->_TNs;
58 }
59
60 virtual Scalar norm() const;
61 virtual boost::intrusive_ptr<MPS_impl> clone() const;
62 virtual std::ostream& Print(std::ostream &os);
63 virtual cytnx_uint64 size(){return 0;};
64 virtual void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
65 virtual void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select, const cytnx_int64 &dtype);
66 //virtual void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
67
68 // for finite MPS:
69
70 //virtual void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const std::vector<std::vector<cytnx_int64> > &state_qnums, const cytnx_int64 &dtype);
71
72 virtual void Into_Lortho();
73 virtual void S_mvleft();
74 virtual void S_mvright();
75
76 virtual void _save_dispatch(std::fstream &f);
77 virtual void _load_dispatch(std::fstream &f);
78
79
80 };
81
82 // finite size:
83 class RegularMPS: public MPS_impl{
84 public:
85
86
87
88 // only for this:
89 RegularMPS(){
90 this->mps_type_id = MPSType.RegularMPS;
91 this->S_loc = 0;
92 this->virt_dim = -1;
93 };
94
95
96 // specialization:
97 std::ostream& Print(std::ostream &os);
98 cytnx_uint64 size(){return this->_TNs.size();};
99 void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
100 void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select, const cytnx_int64 &dtype);
101 //void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
102
103 //void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const std::vector<std::vector<cytnx_int64> >&state_qnums, const cytnx_int64 &dtype);
104
105 void Into_Lortho();
106 void S_mvleft();
107 void S_mvright();
108
109 Scalar norm() const;
110 boost::intrusive_ptr<MPS_impl> clone() const{
111 boost::intrusive_ptr<MPS_impl> out(new RegularMPS());
112 out->S_loc = this->S_loc;
113 out->virt_dim = this->virt_dim;
114 out->_TNs = vec_clone(this->_TNs);
115 return out;
116 }
117
118 void _save_dispatch(std::fstream &f);
119 void _load_dispatch(std::fstream &f);
120 };
121
122 // infinite size:
123 class iMPS: public MPS_impl{
124 public:
125
126
127
128 // only for this:
129 iMPS(){
130 this->mps_type_id = MPSType.iMPS;
131 this->virt_dim = -1;
132 };
133
134
135 // specialization:
136 std::ostream& Print(std::ostream &os);
137 cytnx_uint64 size(){return this->_TNs.size();};
138 void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
139 void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select, const cytnx_int64 &dtype){
140 cytnx_error_msg(true,"[ERROR][MPS][type=iMPS] cannot call Init_Msector%s","\n");
141 }
142 //void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
143 // cytnx_error_msg(true,"[ERROR][MPS][type=iMPS] cannot call prodstate%s","\n");
144 //}
145 void Into_Lortho(){
146 cytnx_error_msg(true,"[ERROR][MPS][type=iMPS] cannot call Into_Lortho%s","\n");
147 }
148 void S_mvleft(){
149 cytnx_error_msg(true,"[ERROR][MPS][type=iMPS] cannot call S_mvleft%s","\n");
150 }
151 void S_mvright(){
152 cytnx_error_msg(true,"[ERROR][MPS][type=iMPS] cannot call S_mvright%s","\n");
153 }
154 boost::intrusive_ptr<MPS_impl> clone() const{
155 boost::intrusive_ptr<MPS_impl> out(new RegularMPS());
156 out->S_loc = this->S_loc;
157 out->virt_dim = this->virt_dim;
158 out->_TNs = vec_clone(this->_TNs);
159 return out;
160 }
161 Scalar norm() const;
162 void _save_dispatch(std::fstream &f);
163 void _load_dispatch(std::fstream &f);
164
165 };
167
168
169 // API
170 class MPS{
171 private:
172
173 public:
174
176 boost::intrusive_ptr<MPS_impl> _impl;
177 MPS(): _impl(new MPS_impl()){
178 // currently default init is RegularMPS;:
179 };
180
181 MPS(const cytnx_uint64 &N, const cytnx_uint64 &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0): _impl(new MPS_impl()){
182 this->Init(N,phys_dim, virt_dim,dtype,mps_type);
183 };
184
185 MPS(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0): _impl(new MPS_impl()){
186 this->Init(N,vphys_dim, virt_dim,dtype,mps_type);
187 };
188
189
190 MPS(const MPS &rhs){
191 _impl = rhs._impl;
192 }
193
194 MPS& operator=(const MPS &rhs){
195 _impl = rhs._impl;
196 return *this;
197 }
199
200
201
202
203 // Initialization API:
204 //-----------------------
205 MPS& Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0){
206
207 if(mps_type==0){
208 this->_impl =boost::intrusive_ptr<MPS_impl>(new RegularMPS());
209 }else if(mps_type==1){
210 this->_impl =boost::intrusive_ptr<MPS_impl>(new iMPS());
211 }else{
212 cytnx_error_msg(true,"[ERROR] invalid MPS type.%s","\n");
213 }
214 this->_impl->Init(N, vphys_dim, virt_dim,dtype);
215 return *this;
216 }
217 MPS& Init(const cytnx_uint64 &N, const cytnx_uint64 &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0){
218
219 std::vector<cytnx_uint64> vphys_dim(N,phys_dim);
220
221 this->Init(N, vphys_dim, virt_dim,dtype);
222 return *this;
223 }
224 //-----------------------
225
226
227 MPS& Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim, const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0){
228 // only the select phys index will have non-zero element.
229 if(mps_type==0){
230 this->_impl =boost::intrusive_ptr<MPS_impl>(new RegularMPS());
231 }else if(mps_type==1){
232 this->_impl =boost::intrusive_ptr<MPS_impl>(new iMPS());
233 }else{
234 cytnx_error_msg(true,"[ERROR] invalid MPS type.%s","\n");
235 }
236 this->_impl->Init_Msector(N, vphys_dim, virt_dim, select, dtype);
237 return *this;
238 }
239
240 /*
241 MPS& Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype){
242 // only the select phys index will have non-zero element.
243 if(mps_type==0){
244 this->_impl =boost::intrusive_ptr<MPS_impl>(new RegularMPS());
245 }else if(mps_type==1){
246 this->_impl =boost::intrusive_ptr<MPS_impl>(new iMPS());
247 }else{
248 cytnx_error_msg(true,"[ERROR] invalid MPS type.%s","\n");
249 }
250 this->_impl->Init_prodstate(phys_dim, cstate, dtype);
251 return *this;
252 }
253 */
254
256 return this->_impl->size();
257 }
258
259 int mps_type() const{
260 return this->_impl->mps_type_id;
261 }
262 std::string mps_type_str() const{
263 return MPSType.getname(this->_impl->mps_type_id);
264 }
265
266 MPS clone() const{
267 MPS out;
268 out._impl = this->_impl->clone();
269 return out;
270 }
271
272 std::vector<UniTensor> &data(){return this->_impl->get_data();};
273
275 this->_impl->Into_Lortho();
276 return *this;
277 }
279 this->_impl->S_mvleft();
280 return *this;
281 }
283 this->_impl->S_mvright();
284 return *this;
285 }
286
287 Scalar norm() const{
288 return this->_impl->norm();
289 }
290
292 return this->_impl->_TNs[idx].shape()[1];
293 }
294
296 return this->_impl->virt_dim;
297 }
298
300 return this->_impl->S_loc;
301 }
302
304 void _Save(std::fstream &f) const;
305 void _Load(std::fstream &f);
307
308 void Save(const std::string &fname) const;
309 void Save(const char* fname) const;
310
311 static MPS Load(const std::string &fname);
312 static MPS Load(const char* fname);
313
314
315
316 };
317
318 std::ostream& operator<<(std::ostream& os, const MPS &in);
319
320
321
322
323
324 }
325}
326
327
328#endif
329
Definition Scalar.hpp:1751
Definition MPS.hpp:170
std::string mps_type_str() const
Definition MPS.hpp:262
cytnx_uint64 size()
Definition MPS.hpp:255
void Save(const std::string &fname) const
void Save(const char *fname) const
std::vector< UniTensor > & data()
Definition MPS.hpp:272
MPS & S_mvleft()
Definition MPS.hpp:278
cytnx_int64 phys_dim(const cytnx_int64 &idx)
Definition MPS.hpp:291
MPS & Init(const cytnx_uint64 &N, const cytnx_uint64 &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:217
int mps_type() const
Definition MPS.hpp:259
static MPS Load(const char *fname)
MPS & Init(const cytnx_uint64 &N, const std::vector< cytnx_uint64 > &vphys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:205
MPS clone() const
Definition MPS.hpp:266
MPS & Init_Msector(const cytnx_uint64 &N, const std::vector< cytnx_uint64 > &vphys_dim, const cytnx_uint64 &virt_dim, const std::vector< cytnx_int64 > &select, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:227
MPS & S_mvright()
Definition MPS.hpp:282
Scalar norm() const
Definition MPS.hpp:287
cytnx_int64 & S_loc()
Definition MPS.hpp:299
cytnx_int64 & virt_dim()
Definition MPS.hpp:295
static MPS Load(const std::string &fname)
MPS & Into_Lortho()
Definition MPS.hpp:274
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:18
std::ostream & operator<<(std::ostream &os, const MPO &in)
Definition Accessor.hpp:12
uint64_t cytnx_uint64
Definition Type.hpp:22
int64_t cytnx_int64
Definition Type.hpp:25
Type_class Type
Definition Type.cpp:143