Cytnx v0.7.3
Loading...
Searching...
No Matches
LinOp.hpp
Go to the documentation of this file.
1#ifndef _H_LinOp_
2#define _H_LinOp_
3
4#include "Type.hpp"
5#include "cytnx_error.hpp"
6#include "Tensor.hpp"
7#include "Scalar.hpp"
8#include <vector>
9#include <fstream>
10#include <functional>
11#include <map>
12#include <utility>
13#include <algorithm>
15#include "utils/vec_clone.hpp"
16
17namespace cytnx{
18
19 class LinOp{
20 private:
21
22 // function pointer:
23 std::function<Tensor(const Tensor&)> _mvfunc;
24
25 // type:
26 std::string _type;
27
28 // nx
29 cytnx_uint64 _nx;
30
31 // device
32 int _device;
33 int _dtype;
34
35 // pre-storage data:
36 std::map<cytnx_uint64,std::pair<std::vector<cytnx_uint64>,Tensor> > _elems; //map[i] -> pair[<js>,<Storage>]
37 std::map<cytnx_uint64,std::pair<std::vector<cytnx_uint64>,Tensor> >::iterator _elems_it;
38
39 Tensor _mv_elemfunc(const Tensor &);
40
41
42 public:
43
45 // we need driver of void f(nx,vin,vout)
47
77 LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype=Type.Double, const int &device=Device.cpu, std::function<Tensor(const Tensor&)> custom_f = nullptr){
78
79 if(type=="mv"){
80 this->_mvfunc=custom_f;
81
82 }else if(type=="mv_elem"){
83 cytnx_error_msg(custom_f != nullptr,"[ERROR][LinOp] with type=mv_elem cannot accept any function. use set_elem/set_elemts instead.%s","\n");
84
85
86 }else
87 cytnx_error_msg(type!="mv","[ERROR][LinOp] currently only type=\"mv\" (matvec) can be used.%s","\n");
88
89 this->_type = type;
90 this->_nx = nx;
91 cytnx_error_msg(device<-1 || device >=Device.Ngpus,"[ERROR] invalid device.%s","\n");
92 this->_device = device;
93 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
94 this->_dtype = dtype;
95 };
96 void set_func(std::function<Tensor(const Tensor&)> custom_f, const int &dtype, const int &device){
97 if(this->_type=="mv"){
98 this->_mvfunc = custom_f;
99 cytnx_error_msg(device<-1 || device >=Device.Ngpus,"[ERROR] invalid device.%s","\n");
100 this->_device = device;
101 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
102 this->_dtype = dtype;
103 }else{
104 cytnx_error_msg(true,"[ERROR] cannot specify func with type=mv_elem%s. use set_elem instead.","\n");
105 }
106 };
107
108 template<class T>
109 void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem, const bool check_exists=true){
110 this->_elems_it = this->_elems.find(i);
111 if(this->_elems_it == this->_elems.end()){
112 //not exists:
113 Tensor x({1},this->_dtype);
114 x(0) = elem;
115 this->_elems[i] = std::pair<std::vector<cytnx_uint64>,Tensor>({j},x);
116
117 }else{
118 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
119 Tensor &ie = this->_elems_it->second.second;
120 if(check_exists){
121 cytnx_error_msg(std::find(vi.begin(), vi.end(), j)!=vi.end(),"[ERROR] the element is set%s","\n");
122 }
123 vi.push_back(j);
124 ie.append(elem);
125 }
126
127 };
129 //[Note that this can only call by mv_elem]
130 //if the element is not exists, it will create one.
131 this->_elems_it = this->_elems.find(i);
132 if(this->_elems_it == this->_elems.end()){
133 //not exists:
134 Tensor x({1},this->_dtype);
135 x(0) = 0;
136 this->_elems[i] = std::pair<std::vector<cytnx_uint64>,Tensor>({j},x);
137 return this->_elems[i].second(0);
138 }else{
139 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
140 Tensor &ie = this->_elems_it->second.second;
141 auto tmp_it = std::find(vi.begin(), vi.end(), j);
142
143 //if(check_exists){
144 // cytnx_error_msg(std::find(vi.begin(), vi.end(), j)!=vi.end(),"[ERROR] the element is set%s","\n");
145 //}
146 if(tmp_it==vi.end()){
147 vi.push_back(j);
148 ie.append(0);
149 return ie(vi.size()-1);
150 }else{
151 return ie(std::distance(vi.begin(), tmp_it));
152 }
153 }
154 }
155
156 void set_device(const int &device){
157 cytnx_error_msg(device<-1 || device >=Device.Ngpus,"[ERROR] invalid device.%s","\n");
158 this->_device = device;
159 };
160 void set_dtype(const int &dtype){
161 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
162 this->_dtype = dtype;
163 };
164 int device() const{
165 return this->_device;
166 };
167 int dtype() const{
168 return this->_dtype;
169 };
171 return this->_nx;
172 };
173
174 void _print();
175
177 // this expose to interitance:
178 // need user to check the output to be Tensor
180 virtual Tensor matvec(const Tensor &Tin);
181
182
183
184 };
185
186}
187
188
189#endif
Definition LinOp.hpp:19
LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype=Type.Double, const int &device=Device.cpu, std::function< Tensor(const Tensor &)> custom_f=nullptr)
Linear Operator class for iterative solvers.
Definition LinOp.hpp:77
virtual Tensor matvec(const Tensor &Tin)
Definition LinOp.cpp:67
void set_func(std::function< Tensor(const Tensor &)> custom_f, const int &dtype, const int &device)
Definition LinOp.hpp:96
int device() const
Definition LinOp.hpp:164
int dtype() const
Definition LinOp.hpp:167
cytnx_uint64 nx() const
Definition LinOp.hpp:170
void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem, const bool check_exists=true)
Definition LinOp.hpp:109
void set_device(const int &device)
Definition LinOp.hpp:156
void _print()
Definition LinOp.cpp:11
void set_dtype(const int &dtype)
Definition LinOp.hpp:160
Tensor::Tproxy operator()(const cytnx_uint64 &i, const cytnx_uint64 &j)
Definition LinOp.hpp:128
a tensor (multi-dimensional array)
Definition Tensor.hpp:289
int device() const
the device-id of the Tensor
Definition Tensor.hpp:668
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:18
Definition Accessor.hpp:12
Device_class Device
Definition Device.cpp:105
uint64_t cytnx_uint64
Definition Type.hpp:22
Type_class Type
Definition Type.cpp:137