Cytnx v0.9.6
Loading...
Searching...
No Matches
MPS.hpp
Go to the documentation of this file.
1#ifndef _H_MPS_
2#define _H_MPS_
3
4#include "cytnx_error.hpp"
5#include "Device.hpp"
7#include "UniTensor.hpp"
8#include <iostream>
9#include <fstream>
10
11#include "utils/vec_clone.hpp"
12#include "Accessor.hpp"
13#include <vector>
14#include <initializer_list>
15#include <string>
16
17#ifdef BACKEND_TORCH
18#else
19
20 #include "backend/Scalar.hpp"
21
22namespace cytnx {
23 namespace tn_algo {
24
/// Enumerates the integer ids that identify which MPS implementation an object is.
///
/// The ids are stored in MPS_impl::mps_type_id and are also the values accepted by the
/// `mps_type` argument of the MPS initialisation functions.
class MPSType_class {
 public:
  enum : int {
    Void = -99,  ///< id set by the MPS_impl base-class constructor (no concrete type yet)
    RegularMPS = 0,  ///< id set by the RegularMPS constructor (finite size)
    iMPS = 1,  ///< id set by the iMPS constructor (infinite size)
  };

  /// Maps a type id to a name. Declared here only; the definition lives outside this header.
  /// @param mps_type one of the ids above.
  std::string getname(const int &mps_type);
};
35 extern MPSType_class MPSType;
36
37 class MPS_impl : public intrusive_ptr_base<MPS_impl> {
38 private:
39 public:
40 friend class MPS;
41
42 // std::vector<cytnx_int64> phys_dim;
43 cytnx_int64 virt_dim; // maximum
44 cytnx_int64 S_loc;
45
46 int mps_type_id;
47
48 MPS_impl() : mps_type_id(MPSType.Void) {}
49
50 // place holder for the tensors:
51 std::vector<UniTensor> _TNs;
52
53 std::vector<UniTensor> &get_data() { return this->_TNs; }
54
55 virtual Scalar norm() const;
56 virtual boost::intrusive_ptr<MPS_impl> clone() const;
57 virtual std::ostream &Print(std::ostream &os);
58 virtual cytnx_uint64 size() { return 0; };
59 virtual void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
60 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
61 virtual void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
62 const cytnx_uint64 &virt_dim,
63 const std::vector<cytnx_int64> &select, const cytnx_int64 &dtype);
64 // virtual void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const
65 // std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
66
67 // for finite MPS:
68
69 // virtual void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64
70 // &virt_dim, const std::vector<std::vector<cytnx_int64> > &state_qnums, const cytnx_int64
71 // &dtype);
72
73 virtual void Into_Lortho();
74 virtual void S_mvleft();
75 virtual void S_mvright();
76
77 virtual void _save_dispatch(std::fstream &f);
78 virtual void _load_dispatch(std::fstream &f);
79 };
80
81 // finite size:
82 class RegularMPS : public MPS_impl {
83 public:
84 // only for this:
85 RegularMPS() {
86 this->mps_type_id = MPSType.RegularMPS;
87 this->S_loc = 0;
88 this->virt_dim = -1;
89 };
90
91 // specialization:
92 std::ostream &Print(std::ostream &os);
93 cytnx_uint64 size() { return this->_TNs.size(); };
94 void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
95 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
96 void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
97 const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select,
98 const cytnx_int64 &dtype);
99 // void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const
100 // std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
101
102 // void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64
103 // &virt_dim, const std::vector<std::vector<cytnx_int64> >&state_qnums, const cytnx_int64
104 // &dtype);
105
106 void Into_Lortho();
107 void S_mvleft();
108 void S_mvright();
109
110 Scalar norm() const;
111 boost::intrusive_ptr<MPS_impl> clone() const {
112 boost::intrusive_ptr<MPS_impl> out(new RegularMPS());
113 out->S_loc = this->S_loc;
114 out->virt_dim = this->virt_dim;
115 out->_TNs = vec_clone(this->_TNs);
116 return out;
117 }
118
119 void _save_dispatch(std::fstream &f);
120 void _load_dispatch(std::fstream &f);
121 };
122
123 // infinite size:
124 class iMPS : public MPS_impl {
125 public:
126 // only for this:
127 iMPS() {
128 this->mps_type_id = MPSType.iMPS;
129 this->virt_dim = -1;
130 };
131
132 // specialization:
133 std::ostream &Print(std::ostream &os);
134 cytnx_uint64 size() { return this->_TNs.size(); };
135 void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
136 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
137 void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
138 const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select,
139 const cytnx_int64 &dtype) {
140 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call Init_Msector%s", "\n");
141 }
142 // void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const
143 // std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
144 // cytnx_error_msg(true,"[ERROR][MPS][type=iMPS] cannot call prodstate%s","\n");
145 // }
146 void Into_Lortho() {
147 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call Into_Lortho%s", "\n");
148 }
149 void S_mvleft() {
150 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call S_mvleft%s", "\n");
151 }
152 void S_mvright() {
153 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call S_mvright%s", "\n");
154 }
155 boost::intrusive_ptr<MPS_impl> clone() const {
156 boost::intrusive_ptr<MPS_impl> out(new RegularMPS());
157 out->S_loc = this->S_loc;
158 out->virt_dim = this->virt_dim;
159 out->_TNs = vec_clone(this->_TNs);
160 return out;
161 }
162 Scalar norm() const;
163 void _save_dispatch(std::fstream &f);
164 void _load_dispatch(std::fstream &f);
165 };
167
168 // API
169 class MPS {
170 private:
171 public:
173 boost::intrusive_ptr<MPS_impl> _impl;
174 MPS()
175 : _impl(new MPS_impl()){
176 // currently default init is RegularMPS;:
177 };
178
180 const cytnx_int64 &dtype = Type.Double, const cytnx_int64 &mps_type = 0)
181 : _impl(new MPS_impl()) {
182 this->Init(N, phys_dim, virt_dim, dtype, mps_type);
183 };
184
185 MPS(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim,
186 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype = Type.Double,
187 const cytnx_int64 &mps_type = 0)
188 : _impl(new MPS_impl()) {
189 this->Init(N, vphys_dim, virt_dim, dtype, mps_type);
190 };
191
192 MPS(const MPS &rhs) { _impl = rhs._impl; }
193
194 MPS &operator=(const MPS &rhs) {
195 _impl = rhs._impl;
196 return *this;
197 }
199
200 // Initialization API:
201 //-----------------------
202 MPS &Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim,
203 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype = Type.Double,
204 const cytnx_int64 &mps_type = 0) {
205 if (mps_type == 0) {
206 this->_impl = boost::intrusive_ptr<MPS_impl>(new RegularMPS());
207 } else if (mps_type == 1) {
208 this->_impl = boost::intrusive_ptr<MPS_impl>(new iMPS());
209 } else {
210 cytnx_error_msg(true, "[ERROR] invalid MPS type.%s", "\n");
211 }
212 this->_impl->Init(N, vphys_dim, virt_dim, dtype);
213 return *this;
214 }
216 const cytnx_int64 &dtype = Type.Double, const cytnx_int64 &mps_type = 0) {
217 std::vector<cytnx_uint64> vphys_dim(N, phys_dim);
218
219 this->Init(N, vphys_dim, virt_dim, dtype);
220 return *this;
221 }
222 //-----------------------
223
224 MPS &Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim,
225 const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select,
226 const cytnx_int64 &dtype = Type.Double, const cytnx_int64 &mps_type = 0) {
227 // only the select phys index will have non-zero element.
228 if (mps_type == 0) {
229 this->_impl = boost::intrusive_ptr<MPS_impl>(new RegularMPS());
230 } else if (mps_type == 1) {
231 this->_impl = boost::intrusive_ptr<MPS_impl>(new iMPS());
232 } else {
233 cytnx_error_msg(true, "[ERROR] invalid MPS type.%s", "\n");
234 }
235 this->_impl->Init_Msector(N, vphys_dim, virt_dim, select, dtype);
236 return *this;
237 }
238
239 /*
240 MPS& Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const std::vector<cytnx_uint64>
241 cstate, const cytnx_int64 &dtype){
242 // only the select phys index will have non-zero element.
243 if(mps_type==0){
244 this->_impl =boost::intrusive_ptr<MPS_impl>(new RegularMPS());
245 }else if(mps_type==1){
246 this->_impl =boost::intrusive_ptr<MPS_impl>(new iMPS());
247 }else{
248 cytnx_error_msg(true,"[ERROR] invalid MPS type.%s","\n");
249 }
250 this->_impl->Init_prodstate(phys_dim, cstate, dtype);
251 return *this;
252 }
253 */
254
255 cytnx_uint64 size() { return this->_impl->size(); }
256
257 int mps_type() const { return this->_impl->mps_type_id; }
258 std::string mps_type_str() const { return MPSType.getname(this->_impl->mps_type_id); }
259
260 MPS clone() const {
261 MPS out;
262 out._impl = this->_impl->clone();
263 return out;
264 }
265
266 std::vector<UniTensor> &data() { return this->_impl->get_data(); };
267
269 this->_impl->Into_Lortho();
270 return *this;
271 }
273 this->_impl->S_mvleft();
274 return *this;
275 }
277 this->_impl->S_mvright();
278 return *this;
279 }
280
281 Scalar norm() const { return this->_impl->norm(); }
282
283 cytnx_int64 phys_dim(const cytnx_int64 &idx) { return this->_impl->_TNs[idx].shape()[1]; }
284
285 cytnx_int64 &virt_dim() { return this->_impl->virt_dim; }
286
287 cytnx_int64 &S_loc() { return this->_impl->S_loc; }
288
290 void _Save(std::fstream &f) const;
291 void _Load(std::fstream &f);
293
294 void Save(const std::string &fname) const;
295 void Save(const char *fname) const;
296
297 static MPS Load(const std::string &fname);
298 static MPS Load(const char *fname);
299 };
300
301 std::ostream &operator<<(std::ostream &os, const MPS &in);
302
303 } // namespace tn_algo
304} // namespace cytnx
305
306#endif // BACKEND_TORCH
307
308#endif
Definition MPS.hpp:169
std::string mps_type_str() const
Definition MPS.hpp:258
cytnx_uint64 size()
Definition MPS.hpp:255
void Save(const std::string &fname) const
void Save(const char *fname) const
std::vector< UniTensor > & data()
Definition MPS.hpp:266
MPS & S_mvleft()
Definition MPS.hpp:272
cytnx_int64 phys_dim(const cytnx_int64 &idx)
Definition MPS.hpp:283
MPS & Init(const cytnx_uint64 &N, const cytnx_uint64 &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:215
int mps_type() const
Definition MPS.hpp:257
static MPS Load(const char *fname)
MPS & Init(const cytnx_uint64 &N, const std::vector< cytnx_uint64 > &vphys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:202
MPS clone() const
Definition MPS.hpp:260
MPS & Init_Msector(const cytnx_uint64 &N, const std::vector< cytnx_uint64 > &vphys_dim, const cytnx_uint64 &virt_dim, const std::vector< cytnx_int64 > &select, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:224
MPS & S_mvright()
Definition MPS.hpp:276
Scalar norm() const
Definition MPS.hpp:281
cytnx_int64 & S_loc()
Definition MPS.hpp:287
cytnx_int64 & virt_dim()
Definition MPS.hpp:285
static MPS Load(const std::string &fname)
MPS & Into_Lortho()
Definition MPS.hpp:268
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:16
std::ostream & operator<<(std::ostream &os, const MPO &in)
Helper function to print vector with ODT:
Definition Accessor.hpp:12
uint64_t cytnx_uint64
Definition Type.hpp:55
int64_t cytnx_int64
Definition Type.hpp:58
Type_class Type
data type