Cytnx v1.0.0
Loading...
Searching...
No Matches
LinOp.hpp
Go to the documentation of this file.
1#ifndef CYTNX_LINOP_H_
2#define CYTNX_LINOP_H_
3
4#include "Type.hpp"
5#include "cytnx_error.hpp"
6#include <vector>
7#include <fstream>
8#include <functional>
9#include <map>
10#include <utility>
11#include <algorithm>
13#include "Tensor.hpp"
14#include "UniTensor.hpp"
15
16#ifdef BACKEND_TORCH
17#else
18namespace cytnx {
19
20 class LinOp {
21 private:
22 // type:
23 std::string _type;
24
25 // nx
26 cytnx_uint64 _nx;
27
28 // device
29 int _device;
30 int _dtype;
31
32 // pre-storage data:
33 std::map<cytnx_uint64, std::pair<std::vector<cytnx_uint64>, Tensor>>
34 _elems; // map[i] -> pair[<js>,<Storage>]
35 std::map<cytnx_uint64, std::pair<std::vector<cytnx_uint64>, Tensor>>::iterator _elems_it;
36
37 Tensor _mv_elemfunc(const Tensor &);
38
39 public:
41 // we need driver of void f(nx,vin,vout)
43
70 LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype = Type.Double,
71 const int &device = Device.cpu) {
72 if (type == "mv") {
73 } else if (type == "mv_elem") {
74 } else
75 cytnx_error_msg(type != "mv",
76 "[ERROR][LinOp] currently only type=\"mv\" (matvec) can be used.%s", "\n");
77
78 this->_type = type;
79 this->_nx = nx;
80 cytnx_error_msg(device < -1 || device >= Device.Ngpus, "[ERROR] invalid device.%s", "\n");
81 this->_device = device;
82 cytnx_error_msg(dtype < 1 || dtype >= N_Type, "[ERROR] invalid dtype.%s", "\n");
83 this->_dtype = dtype;
84 };
85 /*
86 void set_func(std::function<Tensor(const Tensor&)> custom_f, const int &dtype, const int
87 &device){ if(this->_type=="mv"){ this->_mvfunc = custom_f; cytnx_error_msg(device<-1 || device
88 >=Device.Ngpus,"[ERROR] invalid device.%s","\n"); this->_device = device;
89 cytnx_error_msg(dtype<1 || dtype >= N_Type,"[ERROR] invalid dtype.%s","\n");
90 this->_dtype = dtype;
91 }else{
92 cytnx_error_msg(true,"[ERROR] cannot specify func with type=mv_elem%s. use set_elem
93 instead.","\n");
94 }
95 };
96 */
97 template <class T>
98 void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem,
99 const bool check_exists = true) {
100 this->_elems_it = this->_elems.find(i);
101 if (this->_elems_it == this->_elems.end()) {
102 // not exists:
103 Tensor x({1}, this->_dtype);
104 x(0) = elem;
105 this->_elems[i] = std::pair<std::vector<cytnx_uint64>, Tensor>({j}, x);
106
107 } else {
108 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
109 Tensor &ie = this->_elems_it->second.second;
110 if (check_exists) {
111 cytnx_error_msg(std::find(vi.begin(), vi.end(), j) != vi.end(),
112 "[ERROR] the element is set%s", "\n");
113 }
114 vi.push_back(j);
115 ie.append(elem);
116 }
117 };
119 //[Note that this can only call by mv_elem]
120 // if the element is not exists, it will create one.
121 this->_elems_it = this->_elems.find(i);
122 if (this->_elems_it == this->_elems.end()) {
123 // not exists:
124 Tensor x({1}, this->_dtype);
125 x(0) = 0;
126 this->_elems[i] = std::pair<std::vector<cytnx_uint64>, Tensor>({j}, x);
127 return this->_elems[i].second(0);
128 } else {
129 std::vector<cytnx_uint64> &vi = this->_elems_it->second.first; // pair:
130 Tensor &ie = this->_elems_it->second.second;
131 auto tmp_it = std::find(vi.begin(), vi.end(), j);
132
133 // if(check_exists){
134 // cytnx_error_msg(std::find(vi.begin(), vi.end(), j)!=vi.end(),"[ERROR] the element is
135 // set%s","\n");
136 // }
137 if (tmp_it == vi.end()) {
138 vi.push_back(j);
139 ie.append(0);
140 return ie(vi.size() - 1);
141 } else {
142 return ie(std::distance(vi.begin(), tmp_it));
143 }
144 }
145 }
146
147 void set_device(const int &device) {
148 cytnx_error_msg(device < -1 || device >= Device.Ngpus, "[ERROR] invalid device.%s", "\n");
149 this->_device = device;
150 };
151 void set_dtype(const int &dtype) {
152 cytnx_error_msg(dtype < 1 || dtype >= N_Type, "[ERROR] invalid dtype.%s", "\n");
153 this->_dtype = dtype;
154 };
155 int device() const { return this->_device; };
156 int dtype() const { return this->_dtype; };
157 cytnx_uint64 nx() const { return this->_nx; };
158
159 void _print();
160
162 // this expose to interitance:
163 // need user to check the output to be Tensor
165 virtual Tensor matvec(const Tensor &Tin);
166
168 // this expose to interface:
169 virtual UniTensor matvec(const UniTensor &Tin);
170 // virtual std::vector<UniTensor> matvec(const std::vector<UniTensor> &Tin);
172 };
173
174} // namespace cytnx
175
176#endif // BACKEND_TORCH
177
178#endif // CYTNX_LINOP_H_
constexpr Type_class Type
data type
Definition Type.hpp:426
Definition LinOp.hpp:20
virtual Tensor matvec(const Tensor &Tin)
int device() const
Definition LinOp.hpp:155
LinOp(const std::string &type, const cytnx_uint64 &nx, const int &dtype=Type.Double, const int &device=Device.cpu)
Linear Operator class for iterative solvers.
Definition LinOp.hpp:70
int dtype() const
Definition LinOp.hpp:156
cytnx_uint64 nx() const
Definition LinOp.hpp:157
void set_elem(const cytnx_uint64 &i, const cytnx_uint64 &j, const T &elem, const bool check_exists=true)
Definition LinOp.hpp:98
void set_device(const int &device)
Definition LinOp.hpp:147
void set_dtype(const int &dtype)
Definition LinOp.hpp:151
Tensor::Tproxy operator()(const cytnx_uint64 &i, const cytnx_uint64 &j)
Definition LinOp.hpp:118
a tensor (multi-dimensional array)
Definition Tensor.hpp:41
int device() const
the device-id of the Tensor
Definition Tensor.hpp:581
An Enhanced tensor specifically designed for physical Tensor network simulation.
Definition UniTensor.hpp:2599
#define cytnx_error_msg(is_true, format,...)
Definition cytnx_error.hpp:18
Definition Accessor.hpp:12
Device_class Device
the devices on which data can be placed.