Cytnx v0.9.1
Loading...
Searching...
No Matches
MPS.hpp
Go to the documentation of this file.
1#ifndef _H_MPS_
2#define _H_MPS_
3
4#include "cytnx_error.hpp"
5#include "Device.hpp"
7#include "UniTensor.hpp"
8#include <iostream>
9#include <fstream>
10
11#include "utils/vec_clone.hpp"
12//#include "utils/dynamic_arg_resolver.hpp"
13//#include "linalg.hpp"
14#include "Accessor.hpp"
15#include <vector>
16#include <initializer_list>
17#include <string>
18#include "Scalar.hpp"
19
20namespace cytnx {
21 namespace tn_algo {
22
/// @brief Integer ids identifying which kind of MPS an implementation object is.
///
/// The enumerators are plain `int` values; they are the values stored in
/// MPS_impl::mps_type_id (Void for a bare, uninitialized implementation).
class MPSType_class {
 public:
  enum : int { Void = -99, RegularMPS = 0, iMPS = 1 };

  /// @brief Human-readable name for an MPS type id (defined out of line).
  std::string getname(const int &mps_type);
};

/// Shared instance through which the ids are read, e.g. `MPSType.RegularMPS`.
extern MPSType_class MPSType;
34
35 class MPS_impl : public intrusive_ptr_base<MPS_impl> {
36 private:
37 public:
38 friend class MPS;
39
40 // std::vector<cytnx_int64> phys_dim;
41 cytnx_int64 virt_dim; // maximum
42 cytnx_int64 S_loc;
43
44 int mps_type_id;
45
46 MPS_impl() : mps_type_id(MPSType.Void) {}
47
48 // place holder for the tensors:
49 std::vector<UniTensor> _TNs;
50
51 std::vector<UniTensor> &get_data() { return this->_TNs; }
52
53 virtual Scalar norm() const;
54 virtual boost::intrusive_ptr<MPS_impl> clone() const;
55 virtual std::ostream &Print(std::ostream &os);
56 virtual cytnx_uint64 size() { return 0; };
57 virtual void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
58 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
59 virtual void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
60 const cytnx_uint64 &virt_dim,
61 const std::vector<cytnx_int64> &select, const cytnx_int64 &dtype);
62 // virtual void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const
63 // std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
64
65 // for finite MPS:
66
67 // virtual void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64
68 // &virt_dim, const std::vector<std::vector<cytnx_int64> > &state_qnums, const cytnx_int64
69 // &dtype);
70
71 virtual void Into_Lortho();
72 virtual void S_mvleft();
73 virtual void S_mvright();
74
75 virtual void _save_dispatch(std::fstream &f);
76 virtual void _load_dispatch(std::fstream &f);
77 };
78
79 // finite size:
80 class RegularMPS : public MPS_impl {
81 public:
82 // only for this:
83 RegularMPS() {
84 this->mps_type_id = MPSType.RegularMPS;
85 this->S_loc = 0;
86 this->virt_dim = -1;
87 };
88
89 // specialization:
90 std::ostream &Print(std::ostream &os);
91 cytnx_uint64 size() { return this->_TNs.size(); };
92 void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
93 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
94 void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
95 const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select,
96 const cytnx_int64 &dtype);
97 // void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const
98 // std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
99
100 // void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const cytnx_uint64
101 // &virt_dim, const std::vector<std::vector<cytnx_int64> >&state_qnums, const cytnx_int64
102 // &dtype);
103
104 void Into_Lortho();
105 void S_mvleft();
106 void S_mvright();
107
108 Scalar norm() const;
109 boost::intrusive_ptr<MPS_impl> clone() const {
110 boost::intrusive_ptr<MPS_impl> out(new RegularMPS());
111 out->S_loc = this->S_loc;
112 out->virt_dim = this->virt_dim;
113 out->_TNs = vec_clone(this->_TNs);
114 return out;
115 }
116
117 void _save_dispatch(std::fstream &f);
118 void _load_dispatch(std::fstream &f);
119 };
120
121 // infinite size:
122 class iMPS : public MPS_impl {
123 public:
124 // only for this:
125 iMPS() {
126 this->mps_type_id = MPSType.iMPS;
127 this->virt_dim = -1;
128 };
129
130 // specialization:
131 std::ostream &Print(std::ostream &os);
132 cytnx_uint64 size() { return this->_TNs.size(); };
133 void Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
134 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype);
135 void Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &phys_dim,
136 const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select,
137 const cytnx_int64 &dtype) {
138 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call Init_Msector%s", "\n");
139 }
140 // void Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const
141 // std::vector<cytnx_uint64> cstate, const cytnx_int64 &dtype);
142 // cytnx_error_msg(true,"[ERROR][MPS][type=iMPS] cannot call prodstate%s","\n");
143 // }
144 void Into_Lortho() {
145 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call Into_Lortho%s", "\n");
146 }
147 void S_mvleft() {
148 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call S_mvleft%s", "\n");
149 }
150 void S_mvright() {
151 cytnx_error_msg(true, "[ERROR][MPS][type=iMPS] cannot call S_mvright%s", "\n");
152 }
153 boost::intrusive_ptr<MPS_impl> clone() const {
154 boost::intrusive_ptr<MPS_impl> out(new RegularMPS());
155 out->S_loc = this->S_loc;
156 out->virt_dim = this->virt_dim;
157 out->_TNs = vec_clone(this->_TNs);
158 return out;
159 }
160 Scalar norm() const;
161 void _save_dispatch(std::fstream &f);
162 void _load_dispatch(std::fstream &f);
163 };
165
166 // API
167 class MPS {
168 private:
169 public:
171 boost::intrusive_ptr<MPS_impl> _impl;
172 MPS()
173 : _impl(new MPS_impl()){
174 // currently default init is RegularMPS;:
175 };
176
178 const cytnx_int64 &dtype = Type.Double, const cytnx_int64 &mps_type = 0)
179 : _impl(new MPS_impl()) {
180 this->Init(N, phys_dim, virt_dim, dtype, mps_type);
181 };
182
183 MPS(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim,
184 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype = Type.Double,
185 const cytnx_int64 &mps_type = 0)
186 : _impl(new MPS_impl()) {
187 this->Init(N, vphys_dim, virt_dim, dtype, mps_type);
188 };
189
190 MPS(const MPS &rhs) { _impl = rhs._impl; }
191
192 MPS &operator=(const MPS &rhs) {
193 _impl = rhs._impl;
194 return *this;
195 }
197
198 // Initialization API:
199 //-----------------------
200 MPS &Init(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim,
201 const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype = Type.Double,
202 const cytnx_int64 &mps_type = 0) {
203 if (mps_type == 0) {
204 this->_impl = boost::intrusive_ptr<MPS_impl>(new RegularMPS());
205 } else if (mps_type == 1) {
206 this->_impl = boost::intrusive_ptr<MPS_impl>(new iMPS());
207 } else {
208 cytnx_error_msg(true, "[ERROR] invalid MPS type.%s", "\n");
209 }
210 this->_impl->Init(N, vphys_dim, virt_dim, dtype);
211 return *this;
212 }
214 const cytnx_int64 &dtype = Type.Double, const cytnx_int64 &mps_type = 0) {
215 std::vector<cytnx_uint64> vphys_dim(N, phys_dim);
216
217 this->Init(N, vphys_dim, virt_dim, dtype);
218 return *this;
219 }
220 //-----------------------
221
222 MPS &Init_Msector(const cytnx_uint64 &N, const std::vector<cytnx_uint64> &vphys_dim,
223 const cytnx_uint64 &virt_dim, const std::vector<cytnx_int64> &select,
224 const cytnx_int64 &dtype = Type.Double, const cytnx_int64 &mps_type = 0) {
225 // only the select phys index will have non-zero element.
226 if (mps_type == 0) {
227 this->_impl = boost::intrusive_ptr<MPS_impl>(new RegularMPS());
228 } else if (mps_type == 1) {
229 this->_impl = boost::intrusive_ptr<MPS_impl>(new iMPS());
230 } else {
231 cytnx_error_msg(true, "[ERROR] invalid MPS type.%s", "\n");
232 }
233 this->_impl->Init_Msector(N, vphys_dim, virt_dim, select, dtype);
234 return *this;
235 }
236
237 /*
238 MPS& Init_prodstate(const std::vector<cytnx_uint64> &phys_dim, const std::vector<cytnx_uint64>
239 cstate, const cytnx_int64 &dtype){
240 // only the select phys index will have non-zero element.
241 if(mps_type==0){
242 this->_impl =boost::intrusive_ptr<MPS_impl>(new RegularMPS());
243 }else if(mps_type==1){
244 this->_impl =boost::intrusive_ptr<MPS_impl>(new iMPS());
245 }else{
246 cytnx_error_msg(true,"[ERROR] invalid MPS type.%s","\n");
247 }
248 this->_impl->Init_prodstate(phys_dim, cstate, dtype);
249 return *this;
250 }
251 */
252
253 cytnx_uint64 size() { return this->_impl->size(); }
254
255 int mps_type() const { return this->_impl->mps_type_id; }
256 std::string mps_type_str() const { return MPSType.getname(this->_impl->mps_type_id); }
257
258 MPS clone() const {
259 MPS out;
260 out._impl = this->_impl->clone();
261 return out;
262 }
263
264 std::vector<UniTensor> &data() { return this->_impl->get_data(); };
265
267 this->_impl->Into_Lortho();
268 return *this;
269 }
271 this->_impl->S_mvleft();
272 return *this;
273 }
275 this->_impl->S_mvright();
276 return *this;
277 }
278
279 Scalar norm() const { return this->_impl->norm(); }
280
281 cytnx_int64 phys_dim(const cytnx_int64 &idx) { return this->_impl->_TNs[idx].shape()[1]; }
282
283 cytnx_int64 &virt_dim() { return this->_impl->virt_dim; }
284
285 cytnx_int64 &S_loc() { return this->_impl->S_loc; }
286
288 void _Save(std::fstream &f) const;
289 void _Load(std::fstream &f);
291
292 void Save(const std::string &fname) const;
293 void Save(const char *fname) const;
294
295 static MPS Load(const std::string &fname);
296 static MPS Load(const char *fname);
297 };
298
299 std::ostream &operator<<(std::ostream &os, const MPS &in);
300
301 } // namespace tn_algo
302} // namespace cytnx
303
304#endif
A class to represent a scalar.
Definition Scalar.hpp:2470
Definition MPS.hpp:167
std::string mps_type_str() const
Definition MPS.hpp:256
cytnx_uint64 size()
Definition MPS.hpp:253
void Save(const std::string &fname) const
void Save(const char *fname) const
std::vector< UniTensor > & data()
Definition MPS.hpp:264
MPS & S_mvleft()
Definition MPS.hpp:270
cytnx_int64 phys_dim(const cytnx_int64 &idx)
Definition MPS.hpp:281
MPS & Init(const cytnx_uint64 &N, const cytnx_uint64 &phys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:213
int mps_type() const
Definition MPS.hpp:255
static MPS Load(const char *fname)
MPS & Init(const cytnx_uint64 &N, const std::vector< cytnx_uint64 > &vphys_dim, const cytnx_uint64 &virt_dim, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:200
MPS clone() const
Definition MPS.hpp:258
MPS & Init_Msector(const cytnx_uint64 &N, const std::vector< cytnx_uint64 > &vphys_dim, const cytnx_uint64 &virt_dim, const std::vector< cytnx_int64 > &select, const cytnx_int64 &dtype=Type.Double, const cytnx_int64 &mps_type=0)
Definition MPS.hpp:222
MPS & S_mvright()
Definition MPS.hpp:274
Scalar norm() const
Definition MPS.hpp:279
cytnx_int64 & S_loc()
Definition MPS.hpp:285
cytnx_int64 & virt_dim()
Definition MPS.hpp:283
static MPS Load(const std::string &fname)
MPS & Into_Lortho()
Definition MPS.hpp:266
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:16
std::ostream & operator<<(std::ostream &os, const MPS &in)
Definition Accessor.hpp:12
uint64_t cytnx_uint64
Definition Type.hpp:45
int64_t cytnx_int64
Definition Type.hpp:48
Type_class Type
data type