Cytnx v1.0.0
Loading...
Searching...
No Matches
contraction_tree.hpp
Go to the documentation of this file.
1#ifndef CYTNX_CONTRACTION_TREE_H_
2#define CYTNX_CONTRACTION_TREE_H_
3
4#include "Type.hpp"
5#include "cytnx_error.hpp"
6#include "UniTensor.hpp"
7#include "utils/utils.hpp"
8#include <vector>
9#include <map>
10#include <string>
11#include <memory>
12
13#ifdef BACKEND_TORCH
14#else
15namespace cytnx {
17 class Node : public std::enable_shared_from_this<Node> {
18 public:
19 UniTensor utensor;
20 bool is_assigned;
21 std::shared_ptr<Node> left;
22 std::shared_ptr<Node> right;
23 std::weak_ptr<Node> root;
24 std::string name;
25
26 Node() : is_assigned(false) {}
27
28 Node(const Node& rhs)
29 : utensor(rhs.utensor),
30 is_assigned(rhs.is_assigned),
31 left(rhs.left),
32 right(rhs.right),
33 name(rhs.name) {
34 // Only copy root if it exists
35 if (auto r = rhs.root.lock()) {
36 root = r;
37 }
38 }
39
40 Node& operator=(const Node& rhs) {
41 if (this != &rhs) {
42 utensor = rhs.utensor;
43 is_assigned = rhs.is_assigned;
44 left = rhs.left;
45 right = rhs.right;
46 name = rhs.name;
47 if (auto r = rhs.root.lock()) {
48 root = r;
49 }
50 }
51 return *this;
52 }
53
54 Node(std::shared_ptr<Node> in_left, std::shared_ptr<Node> in_right,
55 const UniTensor& in_uten = UniTensor())
56 : is_assigned(false), left(in_left), right(in_right) {
57 // Set name based on children
58 if (left && right) {
59 name = "(" + left->name + "," + right->name + ")";
60 }
61
62 if (in_uten.uten_type() != UTenType.Void) {
63 utensor = in_uten;
64 }
65 }
66
67 void set_root_ptrs() {
68 try {
69 auto self = shared_from_this();
70
71 if (left) {
72 left->root = self;
73 left->set_root_ptrs();
74 }
75
76 if (right) {
77 right->root = self;
78 right->set_root_ptrs();
79 }
80 } catch (const std::bad_weak_ptr& e) {
81 std::cerr << "Failed to set root ptrs for node " << name << ": " << e.what() << std::endl;
82 throw;
83 }
84 }
85
86 void clear_utensor() {
87 if (left) {
88 left->clear_utensor();
89 left->root.reset();
90 }
91 if (right) {
92 right->clear_utensor();
93 right->root.reset();
94 }
95 is_assigned = false;
96 utensor = UniTensor();
97 }
98
99 void assign_utensor(const UniTensor& in_uten) {
100 utensor = in_uten;
101 is_assigned = true;
102 }
103 };
104
105 class ContractionTree {
106 public:
107 std::vector<std::shared_ptr<Node>> nodes_container; // intermediate layer
108 std::vector<std::shared_ptr<Node>> base_nodes; // bottom layer
109
110 ContractionTree() = default;
111 ContractionTree(const ContractionTree&) = default;
112 ContractionTree& operator=(const ContractionTree&) = default;
113
114 void clear() {
115 nodes_container.clear();
116 base_nodes.clear();
117 }
118
119 void reset_contraction_order() {
120 // First clear all root pointers
121 for (auto& node : base_nodes) {
122 if (node) node->root.reset();
123 }
124 // Then clear the container
125 nodes_container.clear();
126 }
127
128 void reset_nodes() {
129 // Clear from root down if we have nodes
130 if (!nodes_container.empty() && nodes_container.back()) {
131 nodes_container.back()->clear_utensor();
132 }
133 nodes_container.clear();
134
135 // Reset base nodes
136 for (auto& node : base_nodes) {
137 if (node) {
138 node->is_assigned = false;
139 node->utensor = UniTensor();
140 }
141 }
142 }
143
144 void build_default_contraction_tree();
145 void build_contraction_tree_by_tokens(const std::map<std::string, cytnx_uint64>& name2pos,
146 const std::vector<std::string>& tokens);
147 };
149} // namespace cytnx
150#endif // BACKEND_TORCH
151
152#endif // CYTNX_CONTRACTION_TREE_H_
Definition Accessor.hpp:12
UniTensorType_class UTenType
UniTensor type.