Cytnx v0.9.1
Loading...
Searching...
No Matches
LinOp.hpp
Go to the documentation of this file.
1#ifndef _H_LinOp_
2#define _H_LinOp_
3
4#include "Type.hpp"
5#include "cytnx_error.hpp"
6#include "Tensor.hpp"
7#include "Scalar.hpp"
8#include "UniTensor.hpp"
9#include <vector>
10#include <fstream>
11#include <functional>
12#include <map>
13#include <utility>
14#include <algorithm>
16#include "utils/vec_clone.hpp"
17
18namespace cytnx {
19
20 class LinOp {
21 private:
22 // type:
23 std::string _type;
24
25 // nx
26 cytnx_uint64 _nx;
27
28 // device
29 int _device;
30 int _dtype;
31
32 // pre-storage data:
33 std::map<cytnx_uint64, std::pair<std::vector<cytnx_uint64>, Tensor>>
34 _elems; // map[i] -> pair[<js>,<Storage>]
35 std::map<cytnx_uint64, std::pair<std::vector<cytnx_uint64>, Tensor>>::iterator _elems_it;
36
37 Tensor _mv_elemfunc(const Tensor &);
38
39 public:
41 // we need driver of void f(nx,vin,vout)
43
74 LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype = Type.Double,
75 const int &device = Device.cpu) {
76 if (type == "mv") {
77 } else if (type == "mv_elem") {
78 } else
79 cytnx_error_msg(type != "mv",
80 "[ERROR][LinOp] currently only type=\"mv\" (matvec) can be used.%s", "\n");
81
82 this->_type = type;
83 this->_nx = nx;
84 cytnx_error_msg(device < -1 || device >= Device.Ngpus, "[ERROR] invalid device.%s", "\n");
85 this->_device = device;
86 cytnx_error_msg(dtype < 1 || dtype >= N_Type, "[ERROR] invalid dtype.%s", "\n");
87 this->_dtype = dtype;
88 };
89 /*
90 void set_func(std::function<Tensor(const Tensor&)> custom_f, const int &dtype, const int
91 &device){ if(this->_type=="mv"){ this->_mvfunc = custom_f; cytnx_error_msg(device<-1 || device
92 >=Device.Ngpus,"[ERROR] invalid device.%s","\n"); this->_device = device;
93 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
94 this->_dtype = dtype;
95 }else{
96 cytnx_error_msg(true,"[ERROR] cannot specify func with type=mv_elem%s. use set_elem
97 instead.","\n");
98 }
99 };
100 */
101 template <class T>
102 void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem,
103 const bool check_exists = true) {
104 this->_elems_it = this->_elems.find(i);
105 if (this->_elems_it == this->_elems.end()) {
106 // not exists:
107 Tensor x({1}, this->_dtype);
108 x(0) = elem;
109 this->_elems[i] = std::pair<std::vector<cytnx_uint64>, Tensor>({j}, x);
110
111 } else {
112 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
113 Tensor &ie = this->_elems_it->second.second;
114 if (check_exists) {
115 cytnx_error_msg(std::find(vi.begin(), vi.end(), j) != vi.end(),
116 "[ERROR] the element is set%s", "\n");
117 }
118 vi.push_back(j);
119 ie.append(elem);
120 }
121 };
122 Tensor::Tproxy operator()(const cytnx_uint64 &i, const cytnx_uint64 &j) {
123 //[Note that this can only call by mv_elem]
124 // if the element is not exists, it will create one.
125 this->_elems_it = this->_elems.find(i);
126 if (this->_elems_it == this->_elems.end()) {
127 // not exists:
128 Tensor x({1}, this->_dtype);
129 x(0) = 0;
130 this->_elems[i] = std::pair<std::vector<cytnx_uint64>, Tensor>({j}, x);
131 return this->_elems[i].second(0);
132 } else {
133 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
134 Tensor &ie = this->_elems_it->second.second;
135 auto tmp_it = std::find(vi.begin(), vi.end(), j);
136
137 // if(check_exists){
138 // cytnx_error_msg(std::find(vi.begin(), vi.end(), j)!=vi.end(),"[ERROR] the element is
139 // set%s","\n");
140 // }
141 if (tmp_it == vi.end()) {
142 vi.push_back(j);
143 ie.append(0);
144 return ie(vi.size() - 1);
145 } else {
146 return ie(std::distance(vi.begin(), tmp_it));
147 }
148 }
149 }
150
151 void set_device(const int &device) {
152 cytnx_error_msg(device < -1 || device >= Device.Ngpus, "[ERROR] invalid device.%s", "\n");
153 this->_device = device;
154 };
155 void set_dtype(const int &dtype) {
156 cytnx_error_msg(dtype < 1 || dtype >= N_Type, "[ERROR] invalid dtype.%s", "\n");
157 this->_dtype = dtype;
158 };
159 int device() const { return this->_device; };
160 int dtype() const { return this->_dtype; };
161 cytnx_uint64 nx() const { return this->_nx; };
162
163 void _print();
164
166 // this expose to interitance:
167 // need user to check the output to be Tensor
169 virtual Tensor matvec(const Tensor &Tin);
170
172 // this expose to interface:
173 virtual UniTensor matvec(const UniTensor &Tin);
174 // virtual std::vector<UniTensor> matvec(const std::vector<UniTensor> &Tin);
176 };
177
178} // namespace cytnx
179
180#endif
Definition LinOp.hpp:20
int device() const
Definition LinOp.hpp:159
LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype=Type.Double, const int &device=Device.cpu)
Linear Operator class for iterative solvers.
Definition LinOp.hpp:74
int dtype() const
Definition LinOp.hpp:160
cytnx_uint64 nx() const
Definition LinOp.hpp:161
virtual Tensor matvec(const Tensor &Tin)
Definition LinOp.cpp:64
void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem, const bool check_exists=true)
Definition LinOp.hpp:102
void set_device(const int &device)
Definition LinOp.hpp:151
void _print()
Definition LinOp.cpp:10
void set_dtype(const int &dtype)
Definition LinOp.hpp:155
Tensor::Tproxy operator()(const cytnx_uint64 &i, const cytnx_uint64 &j)
Definition LinOp.hpp:122
a tensor (multi-dimensional array)
Definition Tensor.hpp:345
void append(const Tensor &rhs)
the append function.
Definition Tensor.hpp:1584
An enhanced tensor specifically designed for physical tensor network simulations.
Definition UniTensor.hpp:2449
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:16
Definition Accessor.hpp:12
Device_class Device
the device on which data is stored.
Definition Device.cpp:140
uint64_t cytnx_uint64
Definition Type.hpp:45
Type_class Type
data type
Definition Type.cpp:23