Cytnx v0.9.6
Loading...
Searching...
No Matches
contraction_tree.hpp
Go to the documentation of this file.
1#ifndef _H_contraction_tree_
2#define _H_contraction_tree_
3
4#include "Type.hpp"
5#include "cytnx_error.hpp"
6#include "UniTensor.hpp"
7#include "utils/utils.hpp"
8#include <vector>
9#include <map>
10#include <string>
11
12#ifdef BACKEND_TORCH
13#else
14namespace cytnx {
16 class Node {
17 public:
18 UniTensor utensor; // don't worry about copy, because everything are references in cytnx!
19 bool is_assigned;
20 Node *left;
21 Node *right;
22 std::string name;
23 Node *root;
24
25 Node() : is_assigned(false), left(nullptr), right(nullptr), root(nullptr){};
26 Node(const Node &rhs) {
27 this->left = rhs.left;
28 this->right = rhs.right;
29 this->root = rhs.root;
30 this->utensor = rhs.utensor;
31 this->is_assigned = rhs.is_assigned;
32 }
33 Node &operator==(const Node &rhs) {
34 this->left = rhs.left;
35 this->right = rhs.right;
36 this->root = rhs.root;
37 this->utensor = rhs.utensor;
38 this->is_assigned = rhs.is_assigned;
39 return *this;
40 }
41 Node(Node *in_left, Node *in_right, const UniTensor &in_uten = UniTensor())
42 : is_assigned(false), left(nullptr), right(nullptr), root(nullptr) {
43 this->left = in_left;
44 this->right = in_right;
45 in_left->root = this;
46 in_right->root = this;
47 if (in_uten.uten_type() != UTenType.Void) this->utensor = in_uten;
48 }
49 void assign_utensor(const UniTensor &in_uten) {
50 this->utensor = in_uten;
51 this->is_assigned = true;
52 }
53 void clear_utensor() {
54 this->is_assigned = false;
55 this->utensor = UniTensor();
56 }
57 };
58
59 class ContractionTree {
60 public:
61 std::vector<Node> nodes_container; // this contains intermediate layer.
62 std::vector<Node> base_nodes; // this is the button layer.
63
64 ContractionTree(){};
65 ContractionTree(const ContractionTree &rhs) {
66 this->nodes_container = rhs.nodes_container;
67 this->base_nodes = rhs.base_nodes;
68 }
69 ContractionTree &operator==(const ContractionTree &rhs) {
70 this->nodes_container = rhs.nodes_container;
71 this->base_nodes = rhs.base_nodes;
72 return *this;
73 }
74
75 // clear all the elements in the whole tree.
76 void clear() {
77 nodes_container.clear();
78 base_nodes.clear();
79 // nodes_container.reserve(1024);
80 }
81 // clear all the intermediate layer, leave all the base_nodes intact.
82 // and reset the root pointer on the base ondes
83 void reset_contraction_order() {
84 nodes_container.clear();
85 for (cytnx_uint64 i = 0; i < base_nodes.size(); i++) {
86 base_nodes[i].root = nullptr;
87 }
88 // nodes_container.reserve(1024);
89 }
90 void reset_nodes() {
91 // reset all nodes but keep the skeleton
92 for (cytnx_uint64 i = 0; i < this->nodes_container.size(); i++) {
93 this->nodes_container[i].clear_utensor();
94 }
95 for (cytnx_uint64 i = 0; i < this->base_nodes.size(); i++) {
96 this->base_nodes[i].clear_utensor();
97 }
98 }
99 void build_default_contraction_tree();
100 void build_contraction_tree_by_tokens(const std::map<std::string, cytnx_uint64> &name2pos,
101 const std::vector<std::string> &tokens);
102 };
104} // namespace cytnx
105#endif // BACKEND_TORCH
106
107#endif
Helper function to print vector with ODT:
Definition Accessor.hpp:12
UniTensorType_class UTenType
UniTensor type.
Tensor operator==(const Tensor &Lt, const Tensor &Rt)
The comparison operator for Tensor.