forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PythonFallbackKernel.cpp
208 lines (172 loc) · 8.09 KB
/
PythonFallbackKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <ATen/core/PythonFallbackKernel.h>
#include <c10/core/SafePyObject.h>
namespace {
// Thread-local record of the dispatcher's local dispatch key set, kept so that
// this state can be restored when calling back into Python.
//
// Invariants:
// - It must be empty while Python code is executing.
// - It is set only once, even across several dispatcher calls that do not come
//   back to Python in between.
//
// To uphold this, the value is empty by default and is emptied again both when
// we call into a user's __torch_dispatch__ and when we return to Python after
// that call.
thread_local std::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;

// Returns a copy of the key set recorded in tls_on_entry.
// Fails with TORCH_CHECK (throws) when nothing is recorded, i.e. when torch
// dispatch state is accessed outside of '__torch_dispatch__'.
c10::impl::LocalDispatchKeySet safe_get_tls_on_entry() {
  const bool snapshot_recorded = tls_on_entry.has_value();
  TORCH_CHECK(
      snapshot_recorded,
      "Accessing torch dispatch state outside of '__torch_dispatch__' "
      "is not allowed.");
  // Safe: the TORCH_CHECK above throws when the optional is empty.
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  return *tls_on_entry;
}
// Key set excluded (via ExcludeDispatchKeyGuard) while pythonFallback runs.
// It is computed as FULL ^ (FULL_AFTER Python | Python). Provided that the
// right-hand set is a subset of FULL, this is "every key except Python and
// except those in FULL_AFTER Python". The original comment describes these as
// "all the keys below the Python key".
// NOTE(review): the exact membership of FULL / FULL_AFTER is defined in
// c10/core/DispatchKeySet.h, which is not visible here.
constexpr c10::DispatchKeySet python_and_full_after_python_keyset =
    c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python) |
    c10::DispatchKeySet(c10::DispatchKey::Python);
constexpr c10::DispatchKeySet after_Python_keyset =
    c10::DispatchKeySet(c10::DispatchKeySet::FULL) ^ python_and_full_after_python_keyset;
// Scope guard that temporarily empties tls_on_entry.
// - Constructor: copies the current value of tls_on_entry into the guard and
//   clears the thread-local slot.
// - Destructor: asserts the slot is still empty, then writes the saved value
//   back.
// Precondition: tls_on_entry must hold a value at construction. Otherwise
// std::optional::value() throws std::bad_optional_access.
// NOTE(review): currently unused in this file; its use in pythonFallback is
// commented out.
class StashTLSOnEntryGuard {
 public:
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  StashTLSOnEntryGuard() : saved_(tls_on_entry.value()) {
    tls_on_entry.reset();
  }

  ~StashTLSOnEntryGuard() {
    TORCH_INTERNAL_ASSERT(!tls_on_entry.has_value());
    tls_on_entry = saved_;
  }

  // The guard can be neither copied nor moved.
  StashTLSOnEntryGuard(const StashTLSOnEntryGuard&) = delete;
  StashTLSOnEntryGuard(StashTLSOnEntryGuard&&) = delete;
  StashTLSOnEntryGuard& operator=(const StashTLSOnEntryGuard&) = delete;
  StashTLSOnEntryGuard& operator=(StashTLSOnEntryGuard&&) = delete;

 private:
  // Value of tls_on_entry captured at construction, restored at destruction.
  c10::impl::LocalDispatchKeySet saved_;
};
// Boxed fallback kernel for the Python dispatch key.
//
// While it runs, after_Python_keyset is excluded from the local dispatch key
// set. The operator is then handled as follows:
//   1. If a TorchDispatchMode is active, the call goes to the PyInterpreter of
//      the mode at the top of the mode stack (index stack_len() - 1).
//   2. Otherwise the operator's arguments on the stack (Tensors, and elements
//      of Tensor / optional-Tensor lists) are scanned. The scan records
//      whether any of them has the Python key, and it records the first
//      PyInterpreter found on any of them.
//      - If an interpreter was found and some argument has the Python key,
//        the call is dispatched into Python through that interpreter.
//      - If an interpreter was found but no argument has the Python key, the
//        Python key is removed from both the dispatch key set and the TLS
//        included set, and the op is redispatched
//        (https://github.com/pytorch/pytorch/issues/136565).
//   3. If no argument carries an interpreter, an internal assert fires.
//
// Precondition: tls_on_entry must hold a value (asserted on entry). For
// example, MaybeSetTLSOnEntryGuard in pythonTLSSnapshotFallback sets it.
void pythonFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
  TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());

  // NOTE(review): these two guards are disabled in the original source.
  // c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
  // StashTLSOnEntryGuard stash_guard;
  c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);

  // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
  const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
  if (mode_stack_len > 0) {
    const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
    cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
    return;
  }

  // Otherwise, find a PyInterpreter on a Tensor
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();

  // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
  // without checking the interpreters of any of the arguments, because when
  // we actually run dispatch(), we will take out PyObjects in the context
  // of that interpreter, and this will ensure that everyone is on the same
  // interpreter.
  bool tensors_with_python_key_present = false;
  c10::impl::PyInterpreter* interpreter = nullptr;

  // Folds one tensor argument into the two facts above. The plain-Tensor path
  // and the Tensor-list path share it, so the two cannot drift apart.
  const auto inspect_tensor = [&](auto* t) {
    if (t->key_set().has(c10::DispatchKey::Python)) {
      tensors_with_python_key_present = true;
    }
    // Keep only the first interpreter seen.
    if (!interpreter) {
      auto* t_interpreter = t->pyobj_slot()->pyobj_interpreter();
      if (t_interpreter) {
        interpreter = t_interpreter;
      }
    }
  };

  for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
    if (ivalue.isTensor()) {
      inspect_tensor(ivalue.unsafeToTensorImpl());
    } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
      // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
      // is not a thing)
      for (const auto& nv : ivalue.toListRef()) {
        // Optional-tensor lists may contain None entries; skip them.
        if (!nv.isNone()) {
          inspect_tensor(nv.unsafeToTensorImpl());
        }
      }
    }
  }

  if (interpreter) {
    if (tensors_with_python_key_present) {
      (*interpreter)->dispatch(op, stack);
    } else {
      // At this point, there are no modes in the stack and no tensors with the python key.
      // so disable the python key before redispatching.
      // See https://github.com/pytorch/pytorch/issues/136565
      c10::DispatchKeySet keyset = dispatch_keys.remove(c10::DispatchKey::Python);

      // Remove Python key from the included set as well (modes add it there).
      c10::impl::LocalDispatchKeySet local_keyset = c10::impl::tls_local_dispatch_key_set();
      c10::impl::ForceDispatchKeyGuard no_python_guard(
        local_keyset.included_.remove(c10::DispatchKey::Python),
        local_keyset.excluded_
      );

      op.redispatchBoxed(keyset, stack);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT(false, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
}
// Boxed fallback kernel for the PythonDispatcher key.
// Forwards the call to python_dispatcher() on the state returned by
// PythonDispatcherTLS::get_state(). The incoming key set is passed with the
// PythonDispatcher key removed. Internal-asserts if that state is not set.
void pythonDispatcherFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
  auto* const python_dispatcher_state = c10::impl::PythonDispatcherTLS::get_state();
  TORCH_INTERNAL_ASSERT(python_dispatcher_state, "Hit PythonDispatcher dispatch key but PythonDispatcherTLS was not set");

  const c10::DispatchKeySet remaining_keys = dispatch_keys.remove(c10::DispatchKey::PythonDispatcher);
  (*python_dispatcher_state)->python_dispatcher(op, remaining_keys, stack);
}
// Boxed fallback kernel for the PythonTLSSnapshot key.
// For the duration of the call, tls_on_entry holds the current local dispatch
// key set, unless a value was already recorded. The op is then redispatched
// to the keys in FULL_AFTER PythonTLSSnapshot.
void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
  // tls_on_entry may already be set here. That happens when several calls into
  // the dispatcher are made without going back through Python code.
  // MaybeSetTLSOnEntryGuard leaves an existing value untouched and only clears
  // tls_on_entry on exit if it was the one that set it.
  at::impl::MaybeSetTLSOnEntryGuard snapshot_guard;

  const c10::DispatchKeySet keys_after_snapshot =
      c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot);
  op.redispatchBoxed(dispatch_keys & keys_after_snapshot, stack);
}
// No-op fallback for the PreDispatch key. It only redispatches to the keys in
// FULL_AFTER PreDispatch.
//
// The key exists mainly so that a mode can be registered to it from Python
// (e.g. TorchProxyDispatchMode, for pre_dispatch tracing).
//
// Why not a fallthrough kernel instead of a fallback that no-ops and
// redispatches? The PythonDispatcher needs a real, non-fallthrough kernel in
// order to interpose on it.
// Alternatively, this kernel could have been hardcoded in C++ to call directly
// into TorchProxyDispatchMode. Doing that in C++ is a pain, so for convenience
// it is done in Python through the PythonDispatcher.
void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
  const c10::DispatchKeySet keys_after_pre_dispatch =
      c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch);
  op.redispatchBoxed(dispatch_keys & keys_after_pre_dispatch, stack);
}
} // anonymous namespace
namespace at::impl {
// Constructor: saved_ and guard_ are both initialized from the recorded
// snapshot, and then tls_on_entry is cleared. It throws (via
// safe_get_tls_on_entry) if no snapshot is recorded.
// NOTE(review): the member types are declared in the header, which is not
// visible here. guard_ presumably forces the snapshot as the active local key
// set for this object's lifetime.
RestorePythonTLSSnapshot::RestorePythonTLSSnapshot()
    : saved_(safe_get_tls_on_entry()), guard_(safe_get_tls_on_entry()) {
  tls_on_entry.reset();
}

// Destructor: tls_on_entry must still be empty. The saved snapshot is written
// back into it.
RestorePythonTLSSnapshot::~RestorePythonTLSSnapshot() {
  TORCH_INTERNAL_ASSERT(!tls_on_entry.has_value());
  tls_on_entry = saved_;
}
// Constructor: if tls_on_entry is empty, records the current local dispatch key
// set in it. value_set_ remembers whether this guard was the one that set it.
MaybeSetTLSOnEntryGuard::MaybeSetTLSOnEntryGuard() {
  value_set_ = !tls_on_entry.has_value();
  if (value_set_) {
    tls_on_entry = c10::impl::tls_local_dispatch_key_set();
  }
}

// Destructor: clears tls_on_entry, but only if this guard set it. A value that
// existed before construction is left in place.
MaybeSetTLSOnEntryGuard::~MaybeSetTLSOnEntryGuard() {
  if (!value_set_) {
    return;
  }
  TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
  tls_on_entry.reset();
}
} // namespace at::impl
// Register the boxed kernels above as fallbacks (Library::fallback) under the
// `_` namespace, one per dispatch key.
TORCH_LIBRARY_IMPL(_, Python, lib) {
  lib.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
}

TORCH_LIBRARY_IMPL(_, PythonDispatcher, lib) {
  lib.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonDispatcherFallback>());
}

TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, lib) {
  lib.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
}

TORCH_LIBRARY_IMPL(_, PreDispatch, lib) {
  lib.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
}