// tracer.h (from a fork of pytorch/pytorch; original listing: 212 lines, 179 loc, 6.02 KB)
#pragma once
#include <ATen/Backtrace.h>
#include <ATen/core/functional.h>
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/tracing_state.h>
#include <torch/csrc/utils/variadic.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
// Forward declaration only: `enter` below just needs the name to form a
// `std::shared_ptr<script::Module>` parameter.
namespace script { struct Module; } // namespace script
namespace tracer {
using ::c10::ivalue::List;
using ::c10::ivalue::Shared;
using ::c10::IValue;
using ::c10::ivalue::Future;
using ::c10::ivalue::Tuple;
using ::c10::ivalue::BoolList;
using ::c10::ivalue::DoubleList;
using ::c10::ivalue::GenericList;
using ::c10::ivalue::IntList;
using ::c10::ivalue::TensorList;
using ::c10::ivalue::ConstantString;
using torch::autograd::Variable;
using variable_list = std::vector<Variable>;
// Source-location hooks for nodes created while tracing.
// NOTE(review): setRecordSourceLocation presumably installs the function that
// recordSourceLocation dispatches to — confirm in tracer.cpp.
TORCH_API void recordSourceLocation(Node* n);
TORCH_API void setRecordSourceLocation(
    void (*v)(Node*));
// Mapping between runtime values and the IR values that represent them.
//
// After a new node has been appended to the graph IR, `setValueTrace` ties
// the given value to that node's output `value`. Later operations involving
// the same value then know which IR value to reference.
TORCH_API void setValueTrace(const IValue& v, Value* value);
// NOTE(review): presumably drops the association made by setValueTrace.
TORCH_API void delValueTrace(const Variable& var);
// Returns a callable. NOTE(review): presumably invoking it resumes tracing;
// confirm in tracer.cpp.
TORCH_API std::function<void()> pauseTracing();

// Lookups of the IR value associated with a runtime value. The "Nested"
// variants accept an arbitrary IValue rather than a single tensor.
TORCH_API Value* getValueTrace(const IValue& var);
TORCH_API Value* getNestedValueTrace(const IValue& v);

// Lookups against an explicit tracing state (used for trace outputs).
TORCH_API Value* getOutputTrace(
    const std::shared_ptr<TracingState>& state,
    const Variable& var);
TORCH_API Value* getNestedOutputTrace(
    const std::shared_ptr<TracingState>& state,
    const IValue& iv);
// A Stack of values paired with a TupleType describing them:
// `first` (stack()) holds the values and `second` (types()) holds the tuple
// type whose elements() are expected to match the values one-to-one.
struct TypedStack : public std::pair<Stack, TupleTypePtr>
{
  using pair::pair;

  // NB: The inherited default constructor gives nullptr for |type|,
  // so we provide a saner one: an empty stack typed as an empty tuple.
  TypedStack()
    : pair({}, TupleType::create({}))
  {}

  // The values. Const overload added so read-only callers can use it.
  Stack& stack() {
    return this->first;
  }
  const Stack& stack() const {
    return this->first;
  }

  // The tuple type describing the values. Const overload added so read-only
  // callers can use it.
  TupleTypePtr& types() {
    return this->second;
  }
  const TupleTypePtr& types() const {
    return this->second;
  }

  // Number of values. Asserts that a type is present (the inherited
  // constructors can leave it null) and that the number of values equals the
  // number of tuple element types. Const: it only reads.
  size_t size() const {
    AT_ASSERT(types() != nullptr);
    auto s = stack().size();
    AT_ASSERT(s == types()->elements().size());
    return s;
  }
};
// Tracing session entry/exit.
// `enter` takes the typed inputs (and optionally the script module `self`)
// and returns a TracingState together with a Stack.
// NOTE(review): the returned Stack is presumably the traced view of the
// inputs — confirm in tracer.cpp.
TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(
    TypedStack inputs,
    const std::shared_ptr<script::Module>& self = nullptr);
// Finishes the trace with the given outputs.
TORCH_API void exit(const Stack& outputs);
// NOTE(review): presumably discards the in-progress trace — confirm.
TORCH_API void abandon();
// NB: these overloads act both as intermediate steps for the addInputs calls
// below and as the terminating overloads of the template recursion.

// Integral, boolean and floating-point scalars (plain and optional).
TORCH_API void addInputs(Node* n, const char* name, int64_t value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    c10::optional<int64_t> value);
TORCH_API void addInputs(Node* n, const char* name, bool value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<bool>& value);
TORCH_API void addInputs(Node* n, const char* name, double value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const at::Scalar& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::Scalar>& value);

// Tensors and array-like arguments.
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const at::Tensor& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    at::IntArrayRef value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    at::TensorList value,
    bool allow_undefined = false);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const ArrayRef<double>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const std::vector<double>& value);

// Strings and tensor metadata / configuration arguments.
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const std::string& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const at::SparseTensorRef& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const at::TensorOptions& value);
TORCH_API void addInputs(Node* n, const char* name, at::Device value);
TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::ScalarType>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    at::MemoryFormat value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    at::Generator* value);
template<typename T>
TORCH_API void addInputs(
Node* n,
const char* name,
const std::vector<T>& value);
template<typename K, typename V>
TORCH_API void addInputs(
Node* n,
const char* name,
const std::unordered_map<K, V>& value);
template<typename T>
void addInputs(
Node* n,
const char* name,
const std::vector<T>& value) {
AT_ERROR("Tracing a list of arbitrary type is currently not supported!");
}
template<typename K, typename V>
void addInputs(
Node* n,
const char* name,
const std::unordered_map<K, V>& value) {
AT_ERROR("Tracing a dict of arbitrary types is currently not supported!");
}
template <size_t N>
void addInputs(Node* n, const char* name, std::array<bool, N> value) {
throw std::runtime_error(
"Found an unsupported argument type in the JIT tracer. File a bug report.");
}
// NOTE(review): from the name, presumably validates that `tensor` (the
// argument called `name`) is safe to use when tracing an out-of-place op;
// confirm the exact check in tracer.cpp.
TORCH_API void
ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor);
// Catch-all for output types the tracer cannot record.
// The defaulted enable_if template argument removes this overload for any
// type whose decayed form converts to at::Tensor or at::TensorList, so those
// resolve to the non-template addOutput overloads. Any other type reaches
// AT_ERROR, and the message names the offending type.
template <
    typename Out,
    typename = torch::enable_if_t<
        !(std::is_convertible<torch::decay_t<Out>, at::TensorList>::value ||
          std::is_convertible<torch::decay_t<Out>, at::Tensor>::value)>>
void addOutput(Node* node, Out&&) {
  AT_ERROR(
      "Found an unsupported argument type ",
      c10::demangle_type<Out>(),
      " in the JIT tracer. File a bug report.");
}
// Output recording for tensors and lists of tensors.
TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
TORCH_API void setOutput(Value* value, const at::Tensor& output);
TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);

// NOTE(review): presumably returns the size of `var` along `dim` as a
// Variable so the size query appears in the trace — confirm in tracer.cpp.
TORCH_API autograd::Variable
getSizeOf(const autograd::Variable& var, int64_t dim);
} // namespace tracer
} // namespace jit
} // namespace torch