Cytnx v0.7.4
Loading...
Searching...
No Matches
LinOp.hpp
Go to the documentation of this file.
1#ifndef _H_LinOp_
2#define _H_LinOp_
3
4#include "Type.hpp"
5#include "cytnx_error.hpp"
6#include "Tensor.hpp"
7#include "Scalar.hpp"
8#include "UniTensor.hpp"
9#include <vector>
10#include <fstream>
11#include <functional>
12#include <map>
13#include <utility>
14#include <algorithm>
16#include "utils/vec_clone.hpp"
17
18namespace cytnx{
19
20 class LinOp{
21 private:
22
23
24 // type:
25 std::string _type;
26
27 // nx
28 cytnx_uint64 _nx;
29
30 // device
31 int _device;
32 int _dtype;
33
34 // pre-storage data:
35 std::map<cytnx_uint64,std::pair<std::vector<cytnx_uint64>,Tensor> > _elems; //map[i] -> pair[<js>,<Storage>]
36 std::map<cytnx_uint64,std::pair<std::vector<cytnx_uint64>,Tensor> >::iterator _elems_it;
37
38 Tensor _mv_elemfunc(const Tensor &);
39
40
41 public:
42
44 // we need driver of void f(nx,vin,vout)
46
74 LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype=Type.Double, const int &device=Device.cpu){
75
76 if(type=="mv"){
77
78 }else if(type=="mv_elem"){
79
80
81 }else
82 cytnx_error_msg(type!="mv","[ERROR][LinOp] currently only type=\"mv\" (matvec) can be used.%s","\n");
83
84 this->_type = type;
85 this->_nx = nx;
86 cytnx_error_msg(device<-1 || device >=Device.Ngpus,"[ERROR] invalid device.%s","\n");
87 this->_device = device;
88 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
89 this->_dtype = dtype;
90 };
91 /*
92 void set_func(std::function<Tensor(const Tensor&)> custom_f, const int &dtype, const int &device){
93 if(this->_type=="mv"){
94 this->_mvfunc = custom_f;
95 cytnx_error_msg(device<-1 || device >=Device.Ngpus,"[ERROR] invalid device.%s","\n");
96 this->_device = device;
97 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
98 this->_dtype = dtype;
99 }else{
100 cytnx_error_msg(true,"[ERROR] cannot specify func with type=mv_elem%s. use set_elem instead.","\n");
101 }
102 };
103 */
104 template<class T>
105 void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem, const bool check_exists=true){
106 this->_elems_it = this->_elems.find(i);
107 if(this->_elems_it == this->_elems.end()){
108 //not exists:
109 Tensor x({1},this->_dtype);
110 x(0) = elem;
111 this->_elems[i] = std::pair<std::vector<cytnx_uint64>,Tensor>({j},x);
112
113 }else{
114 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
115 Tensor &ie = this->_elems_it->second.second;
116 if(check_exists){
117 cytnx_error_msg(std::find(vi.begin(), vi.end(), j)!=vi.end(),"[ERROR] the element is set%s","\n");
118 }
119 vi.push_back(j);
120 ie.append(elem);
121 }
122
123 };
125 //[Note that this can only call by mv_elem]
126 //if the element is not exists, it will create one.
127 this->_elems_it = this->_elems.find(i);
128 if(this->_elems_it == this->_elems.end()){
129 //not exists:
130 Tensor x({1},this->_dtype);
131 x(0) = 0;
132 this->_elems[i] = std::pair<std::vector<cytnx_uint64>,Tensor>({j},x);
133 return this->_elems[i].second(0);
134 }else{
135 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
136 Tensor &ie = this->_elems_it->second.second;
137 auto tmp_it = std::find(vi.begin(), vi.end(), j);
138
139 //if(check_exists){
140 // cytnx_error_msg(std::find(vi.begin(), vi.end(), j)!=vi.end(),"[ERROR] the element is set%s","\n");
141 //}
142 if(tmp_it==vi.end()){
143 vi.push_back(j);
144 ie.append(0);
145 return ie(vi.size()-1);
146 }else{
147 return ie(std::distance(vi.begin(), tmp_it));
148 }
149 }
150 }
151
152 void set_device(const int &device){
153 cytnx_error_msg(device<-1 || device >=Device.Ngpus,"[ERROR] invalid device.%s","\n");
154 this->_device = device;
155 };
156 void set_dtype(const int &dtype){
157 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
158 this->_dtype = dtype;
159 };
160 int device() const{
161 return this->_device;
162 };
163 int dtype() const{
164 return this->_dtype;
165 };
167 return this->_nx;
168 };
169
170 void _print();
171
173 // this expose to interitance:
174 // need user to check the output to be Tensor
176 virtual Tensor matvec(const Tensor &Tin);
177
179 // this expose to interface:
180 virtual UniTensor matvec(const UniTensor &Tin);
181 //virtual std::vector<UniTensor> matvec(const std::vector<UniTensor> &Tin);
183
184
185 };
186
187}
188
189
190#endif
Definition LinOp.hpp:20
int device() const
Definition LinOp.hpp:160
LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype=Type.Double, const int &device=Device.cpu)
Linear Operator class for iterative solvers.
Definition LinOp.hpp:74
int dtype() const
Definition LinOp.hpp:163
cytnx_uint64 nx() const
Definition LinOp.hpp:166
virtual Tensor matvec(const Tensor &Tin)
Definition LinOp.cpp:67
void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem, const bool check_exists=true)
Definition LinOp.hpp:105
void set_device(const int &device)
Definition LinOp.hpp:152
void _print()
Definition LinOp.cpp:11
void set_dtype(const int &dtype)
Definition LinOp.hpp:156
Tensor::Tproxy operator()(const cytnx_uint64 &i, const cytnx_uint64 &j)
Definition LinOp.hpp:124
a tensor (multi-dimensional array)
Definition Tensor.hpp:333
int device() const
the device ID of the Tensor
Definition Tensor.hpp:719
An enhanced tensor specifically designed for physical tensor-network simulations.
Definition UniTensor.hpp:1123
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:18
Definition Accessor.hpp:12
Device_class Device
Definition Device.cpp:105
uint64_t cytnx_uint64
Definition Type.hpp:22
Type_class Type
Definition Type.cpp:137