1#ifndef CYTNX_NETWORK_H_
2#define CYTNX_NETWORK_H_
6#include <initializer_list>
11#include "utils/utils.hpp"
20 #include <cutensornet.h>
// Unscoped enumeration of network kinds. The enumerator values mirror the
// anonymous enum inside NetworkType_class (Void = -1, Regular = 0, Fermion = 1).
// NOTE(review): identifiers that begin with a double underscore are reserved
// for the implementation in C++; the name is kept unchanged here only so that
// any existing users of `__nttype` keep compiling.
enum __nttype {
  Void = -1,   // no concrete network type assigned
  Regular = 0, // regular network (the only kind the Network front-end accepts)
  Fermion = 1  // fermionic network (implementation still under development)
};
29 class NetworkType_class {
31 enum :
int {
Void = -1, Regular = 0, Fermion = 1 };
32 std::string getname(
const int &nwrktype_id);
// Global NetworkType_class instance through which network-type constants are
// read (e.g. NtType.Regular, NtType.Void). Declared extern: its single
// definition lives in a separate source file.
34 extern NetworkType_class NtType;
38 class Network_base :
public intrusive_ptr_base<Network_base> {
// One slot per tensor in the network. HasPutAllUniTensor() treats a slot whose
// uten_type() is UTenType.Void as "not yet put"; RegularNetwork::clear()
// empties this vector.
// NOTE(review): RegularNetwork::clone() and FermionNetwork::clone() do not copy
// this member (nor order_line, einsum_path, TOUT_pos, int_modes) — confirm
// that dropping them on clone is intended.
43 std::vector<UniTensor> tensors;
// Presumably the labels and inner-bond count of the output tensor (TOUT);
// both are reset by clear() and copied by clone() — TODO confirm semantics.
44 std::vector<std::string> TOUT_labels;
46 cytnx_uint64 TOUT_iBondNum;
// Contraction tree and the tokenized contraction order (assumed from naming);
// both are cleared by clear() and copied by clone().
51 ContractionTree CtTree;
52 std::vector<std::string> ORDER_tokens;
// Contraction-order text; empty by default.
54 std::string order_line =
"";
// Pairwise (index, index) contraction path — exact semantics not visible here.
56 std::vector<std::pair<cytnx_int64, cytnx_int64>> einsum_path;
// Per-tensor label lists and inner-bond counts (presumably parallel to
// `tensors`; verify against the construct()/Fromfile() implementations).
59 std::vector<std::vector<std::string>> label_arr;
60 std::vector<cytnx_int64> iBondNums;
// Tensor names and the name -> slot-index lookup, presumably used by the
// name-based Put/Rm overloads.
63 std::vector<std::string> names;
64 std::map<std::string, cytnx_uint64> name2pos;
// NOTE(review): the meaning of TOUT_pos is not visible in this chunk.
67 std::vector<std::pair<int, int>> TOUT_pos;
// Integer mode ids per input tensor and for the output; presumably consumed by
// the cuTensorNet descriptor below — confirm.
70 std::vector<std::vector<cytnx_int64>> int_modes;
71 std::vector<cytnx_int64> int_out_mode;
// cuTensorNet network descriptor and contraction-optimizer info handles
// (types come from <cutensornet.h>).
81 cutensornetNetworkDescriptor_t descNet;
83 cutensornetContractionOptimizerInfo_t optimizerInfo;
// Friend declarations for the two concrete network implementations.
86 friend class FermionNetwork;
87 friend class RegularNetwork;
89 Network_base() : nwrktype_id(NtType.
Void){};
91 bool HasPutAllUniTensor() {
92 for (cytnx_uint64 i = 0; i < this->tensors.size(); i++) {
93 if (this->tensors[i].uten_type() ==
UTenType.Void)
return false;
109 virtual void PutUniTensor(
const std::string &name,
const UniTensor &utensor);
110 virtual void PutUniTensor(
const cytnx_uint64 &idx,
const UniTensor &utensor);
111 virtual void PutUniTensors(
const std::vector<std::string> &name,
112 const std::vector<UniTensor> &utensors);
114 virtual void RmUniTensor(
const cytnx_uint64 &idx);
115 virtual void RmUniTensor(
const std::string &name);
116 virtual void RmUniTensors(
const std::vector<std::string> &name);
118 virtual void Contract_plan(
const std::vector<UniTensor> &utensors,
const std::string &Tout,
119 const std::vector<std::string> &alias,
120 const std::string &contract_order);
122 virtual void Fromfile(
const std::string &fname);
123 virtual void FromString(
const std::vector<std::string> &content);
124 virtual void clear();
125 virtual std::string getOptimalOrder();
129 virtual std::string getOrder();
130 virtual void setOrder(
const bool &optimal =
false,
const std::string &contract_order =
"");
132 virtual UniTensor Launch();
134 virtual void construct(
const std::vector<std::string> &alias,
135 const std::vector<std::vector<std::string>> &labels,
136 const std::vector<std::string> &outlabel,
const cytnx_int64 &outrk,
137 const std::string &order,
const bool optim);
138 virtual void PrintNet(std::ostream &os);
139 virtual boost::intrusive_ptr<Network_base> clone();
140 virtual void Savefile(
const std::string &fname);
141 virtual ~Network_base(){};
145 class RegularNetwork :
public Network_base {
147 RegularNetwork() { this->nwrktype_id = NtType.Regular; };
148 void Fromfile(
const std::string &fname);
149 void FromString(
const std::vector<std::string> &contents);
150 void PutUniTensor(
const std::string &name,
const UniTensor &utensor);
151 void PutUniTensor(
const cytnx_uint64 &idx,
const UniTensor &utensor);
152 void PutUniTensors(
const std::vector<std::string> &name,
153 const std::vector<UniTensor> &utensors);
155 void RmUniTensor(
const cytnx_uint64 &idx);
156 void RmUniTensor(
const std::string &name);
157 void RmUniTensors(
const std::vector<std::string> &name);
159 void Contract_plan(
const std::vector<UniTensor> &utensors,
const std::string &Tout,
160 const std::vector<std::string> &alias = {},
161 const std::string &contract_order =
"");
163 this->tensors.clear();
164 this->name2pos.clear();
165 this->CtTree.clear();
167 this->iBondNums.clear();
168 this->label_arr.clear();
169 this->TOUT_labels.clear();
170 this->TOUT_iBondNum = 0;
171 this->ORDER_tokens.clear();
173 std::string getOptimalOrder();
176 std::string getOrder();
177 void setOrder(
const bool &optimal =
false,
const std::string &contract_order =
"");
180 void construct(
const std::vector<std::string> &alias,
181 const std::vector<std::vector<std::string>> &labels,
182 const std::vector<std::string> &outlabel,
const cytnx_int64 &outrk,
183 const std::string &order,
const bool optim);
184 boost::intrusive_ptr<Network_base> clone() {
185 RegularNetwork *tmp =
new RegularNetwork();
186 tmp->name2pos = this->name2pos;
187 tmp->CtTree = this->CtTree;
188 tmp->names = this->names;
189 tmp->iBondNums = this->iBondNums;
190 tmp->label_arr = this->label_arr;
191 tmp->TOUT_labels = this->TOUT_labels;
192 tmp->TOUT_iBondNum = this->TOUT_iBondNum;
193 tmp->ORDER_tokens = this->ORDER_tokens;
194 boost::intrusive_ptr<Network_base> out(tmp);
197 void PrintNet(std::ostream &os);
198 void Savefile(
const std::string &fname);
203 class FermionNetwork :
public Network_base {
208 FermionNetwork() { this->nwrktype_id = NtType.Fermion; };
209 void Fromfile(
const std::string &fname){};
210 void FromString(
const std::vector<std::string> &contents){};
211 void RmUniTensor(
const cytnx_uint64 &idx){};
212 void RmUniTensor(
const std::string &name){};
213 void RmUniTensors(
const std::vector<std::string> &name){};
215 void PutUniTensor(
const std::string &name,
const UniTensor &utensor){};
216 void PutUniTensor(
const cytnx_uint64 &idx,
const UniTensor &utensor){};
217 void PutUniTensors(
const std::vector<std::string> &name,
218 const std::vector<UniTensor> &utensors){};
219 void Contract_plan(
const std::vector<UniTensor> &utensors,
const std::string &Tout,
220 const std::vector<std::string> &alias = {},
221 const std::string &contract_order =
""){};
223 this->name2pos.clear();
224 this->CtTree.clear();
226 this->iBondNums.clear();
227 this->label_arr.clear();
228 this->TOUT_labels.clear();
229 this->TOUT_iBondNum = 0;
230 this->ORDER_tokens.clear();
232 UniTensor Launch(
const bool &optimal =
false,
const std::string &contract_order =
"") {
235 boost::intrusive_ptr<Network_base> clone() {
236 FermionNetwork *tmp =
new FermionNetwork();
237 tmp->name2pos = this->name2pos;
238 tmp->CtTree = this->CtTree;
239 tmp->names = this->names;
240 tmp->iBondNums = this->iBondNums;
241 tmp->label_arr = this->label_arr;
242 tmp->TOUT_labels = this->TOUT_labels;
243 tmp->TOUT_iBondNum = this->TOUT_iBondNum;
244 tmp->ORDER_tokens = this->ORDER_tokens;
245 boost::intrusive_ptr<Network_base> out(tmp);
248 void PrintNet(std::ostream &os){};
249 void Savefile(
const std::string &fname){};
264 boost::intrusive_ptr<Network_base> _impl;
265 Network() : _impl(
new Network_base()){};
268 this->_impl = rhs._impl;
320 void Fromfile(
const std::string &fname,
const int &network_type = NtType.Regular) {
321 if (network_type == NtType.Regular) {
322 boost::intrusive_ptr<Network_base> tmp(
new RegularNetwork());
325 cytnx_error_msg(
true,
"[Developing] currently only support regular type network.%s",
"\n");
327 this->_impl->Fromfile(fname);
358 const int &network_type = NtType.Regular) {
359 if (network_type == NtType.Regular) {
360 boost::intrusive_ptr<Network_base> tmp(
new RegularNetwork());
363 cytnx_error_msg(
true,
"[Developing] currently only support regular type network.%s",
"\n");
365 this->_impl->FromString(contents);
369 static Network Contract(
const std::vector<UniTensor> &tensors,
const std::string &Tout,
370 const std::vector<std::string> &alias = {},
371 const std::string &contract_order =
"") {
372 boost::intrusive_ptr<Network_base> tmp(
new RegularNetwork());
375 out._impl->Contract_plan(tensors, Tout, alias, contract_order);
379 Network(
const std::string &fname,
const int &network_type = NtType.Regular) {
380 this->
Fromfile(fname, network_type);
383 void RmUniTensor(
const std::string &name) { this->_impl->RmUniTensor(name); }
384 void RmUniTensor(
const cytnx_uint64 &idx) { this->_impl->RmUniTensor(idx); }
385 void RmUniTensors(
const std::vector<std::string> &names) { this->_impl->RmUniTensors(names); }
387 const std::vector<std::string> &label_order = {}) {
388 if (label_order.size()) {
389 auto tmpu = utensor.
permute(label_order);
390 this->_impl->PutUniTensor(name, tmpu);
392 this->_impl->PutUniTensor(name, utensor);
395 const std::vector<std::string> &label_order = {}) {
396 if (label_order.size()) {
397 auto tmpu = utensor.
permute(label_order);
398 this->_impl->PutUniTensor(idx, tmpu);
400 this->_impl->PutUniTensor(idx, utensor);
404 const std::vector<UniTensor> &utensors) {
405 this->_impl->PutUniTensors(name, utensors);
408 if (network_type == NtType.Regular) {
409 return this->_impl->getOptimalOrder();
411 cytnx_error_msg(
true,
"[Developing] currently only support regular type network.%s",
"\n");
415 std::string
getOrder() {
return this->_impl->getOrder(); }
417 void setOrder(
const bool &optimal,
const std::string &contract_order ) {
418 return this->_impl->setOrder(optimal, contract_order);
422 if (network_type == NtType.Regular) {
423 return this->_impl->Launch();
425 cytnx_error_msg(
true,
"[Developing] currently only support regular type network.%s",
"\n");
430 const std::vector<std::vector<std::string>> &labels,
431 const std::vector<std::string> &outlabel = std::vector<std::string>(),
432 const cytnx_int64 &outrk = 0,
const std::string &order =
"",
433 const bool optim =
false,
const int &network_type = NtType.Regular) {
434 if (network_type == NtType.Regular) {
435 boost::intrusive_ptr<Network_base> tmp(
new RegularNetwork());
438 cytnx_error_msg(
true,
"[Developing] currently only support regular type network.%s",
"\n");
440 this->_impl->construct(alias, labels, outlabel, outrk, order, optim);
445 this->_impl->clear();
450 out._impl = this->_impl->
clone();
453 void PrintNet() { this->_impl->PrintNet(std::cout); }
455 void Savefile(
const std::string &fname) { this->_impl->Savefile(fname); }
459 std::ostream &operator<<(std::ostream &os,
const Network &bin);
Definition Network.hpp:261
void Fromfile(const std::string &fname, const int &network_type=NtType.Regular)
Construct a Network from a network file.
Definition Network.hpp:320
Network(const std::string &fname, const int &network_type=NtType.Regular)
Definition Network.hpp:379
Network clone()
Definition Network.hpp:448
void setOrder(const bool &optimal, const std::string &contract_order)
Definition Network.hpp:417
void construct(const std::vector< std::string > &alias, const std::vector< std::vector< std::string > > &labels, const std::vector< std::string > &outlabel=std::vector< std::string >(), const cytnx_int64 &outrk=0, const std::string &order="", const bool optim=false, const int &network_type=NtType.Regular)
Definition Network.hpp:429
void PrintNet()
Definition Network.hpp:453
std::string getOptimalOrder(const int &network_type=NtType.Regular)
Definition Network.hpp:407
void RmUniTensors(const std::vector< std::string > &names)
Definition Network.hpp:385
UniTensor Launch(const int &network_type=NtType.Regular)
Definition Network.hpp:421
void RmUniTensor(const cytnx_uint64 &idx)
Definition Network.hpp:384
static Network Contract(const std::vector< UniTensor > &tensors, const std::string &Tout, const std::vector< std::string > &alias={}, const std::string &contract_order="")
Definition Network.hpp:369
void FromString(const std::vector< std::string > &contents, const int &network_type=NtType.Regular)
Construct a Network from a list of strings, where each string corresponds to one line of a network file.
Definition Network.hpp:357
void clear()
Definition Network.hpp:443
void RmUniTensor(const std::string &name)
Definition Network.hpp:383
std::string getOrder()
Definition Network.hpp:415
void PutUniTensors(const std::vector< std::string > &name, const std::vector< UniTensor > &utensors)
Definition Network.hpp:403
void PutUniTensor(const cytnx_uint64 &idx, const UniTensor &utensor, const std::vector< std::string > &label_order={})
Definition Network.hpp:394
void PutUniTensor(const std::string &name, const UniTensor &utensor, const std::vector< std::string > &label_order={})
Definition Network.hpp:386
void Savefile(const std::string &fname)
Definition Network.hpp:455
An enhanced tensor specifically designed for physical tensor-network simulations.
Definition UniTensor.hpp:2599
UniTensor permute(const std::vector< cytnx_int64 > &mapper, const cytnx_int64 &rowrank=-1) const
Permute the legs of the UniTensor.
Definition UniTensor.hpp:3521
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:18
Definition Accessor.hpp:12
UniTensorType_class UTenType
UniTensor type.
@ Void
Definition Symmetry.hpp:32