Cytnx v0.9.3
Loading...
Searching...
No Matches
LinOp.hpp
Go to the documentation of this file.
1#ifndef _H_LinOp_
2#define _H_LinOp_
3
4#include "Type.hpp"
5#include "cytnx_error.hpp"
6#include <vector>
7#include <fstream>
8#include <functional>
9#include <map>
10#include <utility>
11#include <algorithm>
13#include "Tensor.hpp"
14#include "UniTensor.hpp"
15#include "utils/vec_clone.hpp"
16
17#ifdef BACKEND_TORCH
18#else
19namespace cytnx {
20
21 class LinOp {
22 private:
23 // type:
24 std::string _type;
25
26 // nx
27 cytnx_uint64 _nx;
28
29 // device
30 int _device;
31 int _dtype;
32
33 // pre-storage data:
34 std::map<cytnx_uint64, std::pair<std::vector<cytnx_uint64>, Tensor>>
35 _elems; // map[i] -> pair[<js>,<Storage>]
36 std::map<cytnx_uint64, std::pair<std::vector<cytnx_uint64>, Tensor>>::iterator _elems_it;
37
38 Tensor _mv_elemfunc(const Tensor &);
39
40 public:
42 // we need driver of void f(nx,vin,vout)
44
75 LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype = Type.Double,
76 const int &device = Device.cpu) {
77 if (type == "mv") {
78 } else if (type == "mv_elem") {
79 } else
80 cytnx_error_msg(type != "mv",
81 "[ERROR][LinOp] currently only type=\"mv\" (matvec) can be used.%s", "\n");
82
83 this->_type = type;
84 this->_nx = nx;
85 cytnx_error_msg(device < -1 || device >= Device.Ngpus, "[ERROR] invalid device.%s", "\n");
86 this->_device = device;
87 cytnx_error_msg(dtype < 1 || dtype >= N_Type, "[ERROR] invalid dtype.%s", "\n");
88 this->_dtype = dtype;
89 };
90 /*
91 void set_func(std::function<Tensor(const Tensor&)> custom_f, const int &dtype, const int
92 &device){ if(this->_type=="mv"){ this->_mvfunc = custom_f; cytnx_error_msg(device<-1 || device
93 >=Device.Ngpus,"[ERROR] invalid device.%s","\n"); this->_device = device;
94 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
95 this->_dtype = dtype;
96 }else{
97 cytnx_error_msg(true,"[ERROR] cannot specify func with type=mv_elem%s. use set_elem
98 instead.","\n");
99 }
100 };
101 */
102 template <class T>
103 void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem,
104 const bool check_exists = true) {
105 this->_elems_it = this->_elems.find(i);
106 if (this->_elems_it == this->_elems.end()) {
107 // not exists:
108 Tensor x({1}, this->_dtype);
109 x(0) = elem;
110 this->_elems[i] = std::pair<std::vector<cytnx_uint64>, Tensor>({j}, x);
111
112 } else {
113 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
114 Tensor &ie = this->_elems_it->second.second;
115 if (check_exists) {
116 cytnx_error_msg(std::find(vi.begin(), vi.end(), j) != vi.end(),
117 "[ERROR] the element is set%s", "\n");
118 }
119 vi.push_back(j);
120 ie.append(elem);
121 }
122 };
124 //[Note that this can only call by mv_elem]
125 // if the element is not exists, it will create one.
126 this->_elems_it = this->_elems.find(i);
127 if (this->_elems_it == this->_elems.end()) {
128 // not exists:
129 Tensor x({1}, this->_dtype);
130 x(0) = 0;
131 this->_elems[i] = std::pair<std::vector<cytnx_uint64>, Tensor>({j}, x);
132 return this->_elems[i].second(0);
133 } else {
134 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
135 Tensor &ie = this->_elems_it->second.second;
136 auto tmp_it = std::find(vi.begin(), vi.end(), j);
137
138 // if(check_exists){
139 // cytnx_error_msg(std::find(vi.begin(), vi.end(), j)!=vi.end(),"[ERROR] the element is
140 // set%s","\n");
141 // }
142 if (tmp_it == vi.end()) {
143 vi.push_back(j);
144 ie.append(0);
145 return ie(vi.size() - 1);
146 } else {
147 return ie(std::distance(vi.begin(), tmp_it));
148 }
149 }
150 }
151
152 void set_device(const int &device) {
153 cytnx_error_msg(device < -1 || device >= Device.Ngpus, "[ERROR] invalid device.%s", "\n");
154 this->_device = device;
155 };
156 void set_dtype(const int &dtype) {
157 cytnx_error_msg(dtype < 1 || dtype >= N_Type, "[ERROR] invalid dtype.%s", "\n");
158 this->_dtype = dtype;
159 };
160 int device() const { return this->_device; };
161 int dtype() const { return this->_dtype; };
162 cytnx_uint64 nx() const { return this->_nx; };
163
164 void _print();
165
167 // this expose to interitance:
168 // need user to check the output to be Tensor
170 virtual Tensor matvec(const Tensor &Tin);
171
173 // this expose to interface:
174 virtual UniTensor matvec(const UniTensor &Tin);
175 // virtual std::vector<UniTensor> matvec(const std::vector<UniTensor> &Tin);
177 };
178
179} // namespace cytnx
180
181#endif // BACKEND_TORCH
182
183#endif
Definition LinOp.hpp:21
virtual Tensor matvec(const Tensor &Tin)
int device() const
Definition LinOp.hpp:160
LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype=Type.Double, const int &device=Device.cpu)
Linear Operator class for iterative solvers.
Definition LinOp.hpp:75
int dtype() const
Definition LinOp.hpp:161
cytnx_uint64 nx() const
Definition LinOp.hpp:162
void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem, const bool check_exists=true)
Definition LinOp.hpp:103
void set_device(const int &device)
Definition LinOp.hpp:152
void set_dtype(const int &dtype)
Definition LinOp.hpp:156
Tensor::Tproxy operator()(const cytnx_uint64 &i, const cytnx_uint64 &j)
Definition LinOp.hpp:123
a tensor (multi-dimensional array)
Definition Tensor.hpp:41
int device() const
the device-id of the Tensor
Definition Tensor.hpp:521
An Enhanced tensor specifically designed for physical Tensor network simulation.
Definition UniTensor.hpp:1706
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:16
Helper function to print vector with ODT:
Definition Accessor.hpp:12
Device_class Device
the device on which the data resides.
uint64_t cytnx_uint64
Definition Type.hpp:55
Type_class Type
data type