forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBinaryOps.cpp
312 lines (263 loc) · 13.8 KB
/
BinaryOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
#include <ATen/native/BinaryOps.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
namespace at {
namespace native {
// Definitions of the per-device kernel dispatch stubs used in this file. The
// matching declarations presumably live in ATen/native/BinaryOps.h (included
// above) -- TODO confirm.

// Arithmetic kernels.
DEFINE_DISPATCH(add_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
DEFINE_DISPATCH(atan2_stub);

// Comparison kernels.
DEFINE_DISPATCH(lt_stub);
DEFINE_DISPATCH(le_stub);
DEFINE_DISPATCH(gt_stub);
DEFINE_DISPATCH(ge_stub);
DEFINE_DISPATCH(eq_stub);
DEFINE_DISPATCH(ne_stub);

// Logical kernels (driven through the same comparison helpers as above).
DEFINE_DISPATCH(logical_xor_stub);
// Validates the `alpha` scaling factor against the dtype reported by `iter`:
//  - a boolean alpha is only allowed when that dtype is Bool;
//  - a non-integral alpha (e.g. floating point) requires a floating-point dtype.
static inline void alpha_check(const TensorIterator& iter, Scalar alpha) {
  const bool result_is_bool = iter.dtype() == ScalarType::Bool;
  TORCH_CHECK(result_is_bool || !alpha.isBoolean(),
              "Boolean alpha only supported for Boolean results.");

  const bool alpha_is_integral = alpha.isIntegral(/*includeBool=*/true);
  TORCH_CHECK(alpha_is_integral || isFloatingType(iter.dtype()),
              "For integral input tensors, argument alpha must not be a floating point number.");
}
// Out-variant of add with scaling factor `alpha`; the kernel is add_stub.
// The iterator is asked to check `result` for memory overlap before writing.
Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  auto it = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true);
  alpha_check(it, alpha);
  add_stub(it.device_type(), it, alpha);
  // Sanity check: the iterator's output must carry result's dtype.
  TORCH_INTERNAL_ASSERT(result.scalar_type() == it.output().dtype());
  return result;
}
// Functional add: the output starts undefined so the iterator provides it,
// and the iterator's output is returned.
Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  alpha_check(it, alpha);
  add_stub(it.device_type(), it, alpha);
  return it.output();
}
// In-place add, implemented via the out-variant.
Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
  // `self` is both the destination and the first operand.
  return native::add_out(/*result=*/self, self, other, alpha);
}
// Out-variant of div; the kernel is div_stub.
// The iterator is asked to check `result` for memory overlap before writing.
Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true);
  div_stub(it.device_type(), it);
  return result;
}
// Functional div: the iterator provides the output, which is returned.
Tensor div(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  div_stub(it.device_type(), it);
  return it.output();
}
// In-place div, implemented via the out-variant.
Tensor& div_(Tensor& self, const Tensor& other) {
  // `self` is both the destination and the dividend.
  return native::div_out(/*result=*/self, self, other);
}
// Out-variant of mul; the kernel is mul_stub.
// The iterator is asked to check `result` for memory overlap before writing.
Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true);
  mul_stub(it.device_type(), it);
  return result;
}
// Functional mul: the iterator provides the output, which is returned.
Tensor mul(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  mul_stub(it.device_type(), it);
  return it.output();
}
// In-place mul, implemented via the out-variant.
Tensor& mul_(Tensor& self, const Tensor& other) {
  // `self` is both the destination and the first factor.
  return native::mul_out(/*result=*/self, self, other);
}
// Basic checking shared by all sub functions: subtraction is not defined for
// bool tensors.
//  - Both operands bool: the error suggests xor (`^` / logical_xor()).
//  - Exactly one operand bool: the error suggests `~` / logical_not().
// The first check runs first, so two bool operands always get the xor hint.
static inline void sub_check(const Tensor& self, const Tensor& other) {
  // Terminated with `;` like every other statement. The original relied on
  // TORCH_CHECK expanding to a braced `if`, which is fragile.
  TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
              "Subtraction, the `-` operator, with two bool tensors is not supported. "
              "Use the `^` or `logical_xor()` operator instead.");
  TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
              "Subtraction, the `-` operator, with a bool tensor is not supported. "
              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}
// Out-variant of sub with scaling factor `alpha`; the kernel is sub_stub.
// Bool operands are rejected up front by sub_check.
Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  sub_check(self, other);
  auto it = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true);
  alpha_check(it, alpha);
  sub_stub(it.device_type(), it, alpha);
  // Sanity check: the iterator's output must carry result's dtype.
  TORCH_INTERNAL_ASSERT(result.scalar_type() == it.output().dtype());
  return result;
}
// Functional sub: validates operands, lets the iterator provide the output,
// and returns it.
Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) {
  sub_check(self, other);
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  alpha_check(it, alpha);
  sub_stub(it.device_type(), it, alpha);
  return it.output();
}
// In-place sub, implemented via the out-variant.
Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) {
  // `self` is both the destination and the minuend.
  return native::sub_out(/*result=*/self, self, other, alpha);
}
// Reversed subtraction: sub with the operands swapped, i.e. sub(other, self, alpha).
Tensor rsub(const Tensor& self, const Tensor& other, Scalar alpha) {
  const Tensor& minuend = other;
  const Tensor& subtrahend = self;
  return native::sub(minuend, subtrahend, alpha);
}
// Out-variant of atan2; the kernel is atan2_stub, and the result is written
// into `result`.
//
// Like the other *_out functions in this file (add_out, sub_out, mul_out,
// div_out), it asks TensorIterator to check `result` for memory overlap.
// Previously the flag was omitted, so an output whose elements alias one
// another (for example an in-place atan2_ on an expanded view) was silently
// accepted instead of being rejected.
Tensor& atan2_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other,
                                        /*check_mem_overlap=*/true);
  atan2_stub(iter.device_type(), iter);
  return result;
}
// Functional atan2: allocates an empty output with self's options and fills
// it through the out-variant.
Tensor atan2(const Tensor& self, const Tensor& other) {
  auto out = at::empty({0}, self.options());
  return native::atan2_out(out, self, other);
}
// In-place atan2, implemented via the out-variant.
Tensor& atan2_(Tensor& self, const Tensor& other) {
  // `self` is both the destination and the first operand.
  return native::atan2_out(/*result=*/self, self, other);
}
// The Scalar-taking overloads in this file remain necessary because C++ has
// no conversion from number types (int, float, etc.) to Tensor, only to
// Scalar. They are not exposed to Python.
//
// Converts `scalar` to a tensor and marks it as a "wrapped number". The flag
// presumably makes type promotion treat it like a plain number rather than an
// ordinary tensor -- TODO confirm.
static Tensor wrapped_scalar_tensor(Scalar scalar) {
  Tensor wrapped = scalar_to_tensor(scalar);
  wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
  return wrapped;
}
// Validates that `scalar` can be converted to `scalarType` without overflow.
// The converted value is discarded: Scalar::to<T>() is called only for the
// check it performs, and it is relied upon to throw on overflow.
static void check_convert(Scalar scalar, ScalarType scalarType) {
  AT_DISPATCH_ALL_TYPES_AND3(
      at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
      scalarType, "check_convert", [&]() {
        (void)scalar.to<scalar_t>();
      });
}
// Checks that `scalar` fits in `tensor`'s dtype (via check_convert, which
// throws otherwise), then wraps `scalar` as a wrapped-number Tensor.
//
// `tensor` is only consulted for its dtype, so it is taken by const reference.
// Taking it by value cost an atomic refcount increment and decrement on every
// call. Existing callers, which pass a Tensor lvalue, are unaffected.
static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, const Tensor& tensor) {
  check_convert(scalar, tensor.scalar_type());
  return wrapped_scalar_tensor(scalar);
}
// Scalar overload of add: wraps `other` and defers to the Tensor overload.
Tensor add(const Tensor& self, Scalar other, Scalar alpha) {
  const Tensor other_tensor = wrapped_scalar_tensor(other);
  return native::add(self, other_tensor, alpha);
}
// Scalar overload of in-place add: wraps `other` and defers to the Tensor overload.
Tensor& add_(Tensor& self, Scalar other, Scalar alpha) {
  const Tensor other_tensor = wrapped_scalar_tensor(other);
  return native::add_(self, other_tensor, alpha);
}
// Scalar overload of div. Unlike most Scalar overloads here, it goes through
// the Tensor method (a full redispatch) instead of calling native::div
// directly, presumably so non-dense `self` (e.g. sparse) reaches its own
// kernel.
// WARNING: This function does not appear to be tested with a sparse `self`.
Tensor div(const Tensor& self, Scalar other) {
  const Tensor divisor = wrapped_scalar_tensor(other);
  return self.div(divisor); // redispatch!
}
// Scalar overload of in-place div; redispatches through the Tensor method.
// WARNING: With a sparse `self`, this function is currently exercised only by
// DistributedDataParallelTest.test_sparse_gradients. It has to be driven from
// C++, because Python never uses this overload.
Tensor& div_(Tensor& self, Scalar other) {
  const Tensor divisor = wrapped_scalar_tensor(other);
  return self.div_(divisor); // redispatch!
}
// Scalar overload of mul: wraps `other` and defers to the Tensor overload.
Tensor mul(const Tensor& self, Scalar other) {
  const Tensor factor = wrapped_scalar_tensor(other);
  return native::mul(self, factor);
}
// Scalar overload of in-place mul: wraps `other` and defers to the Tensor overload.
Tensor& mul_(Tensor& self, Scalar other) {
  const Tensor factor = wrapped_scalar_tensor(other);
  return native::mul_(self, factor);
}
// Scalar overload of sub: wraps `other` and defers to the Tensor overload.
Tensor sub(const Tensor& self, Scalar other, Scalar alpha) {
  const Tensor subtrahend = wrapped_scalar_tensor(other);
  return native::sub(self, subtrahend, alpha);
}
// Scalar overload of in-place sub: wraps `other` and defers to the Tensor overload.
Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) {
  const Tensor subtrahend = wrapped_scalar_tensor(other);
  return native::sub_(self, subtrahend, alpha);
}
// Scalar overload of rsub: wraps `other` and defers to the Tensor overload.
Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
  const Tensor other_tensor = wrapped_scalar_tensor(other);
  return native::rsub(self, other_tensor, alpha);
}
// Shared launcher for the comparison/logical ops. It builds a comparison
// TensorIterator with the memory-overlap check enabled, runs `stub` on it,
// and returns `result`.
template <typename Stub>
static inline Tensor& comparison_op_impl_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
  auto it = TensorIterator::comparison_op(result, self, other, /*check_mem_overlap=*/true);
  stub(it.device_type(), it);
  return result;
}
// Out-variant shared by the comparison and logical ops.
//  - `result` must be a bool tensor.
//  - If the operand dtypes differ and exactly one operand is zero-dim, the
//    zero-dim operand's value must convert to the other operand's dtype
//    without overflow (check_convert throws otherwise).
template <typename Stub>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
  TORCH_CHECK(result.scalar_type() == kBool,
              "The output tensor of a comparison or logical op must be a bool, but was ", result.scalar_type());

  const bool dtypes_differ = self.scalar_type() != other.scalar_type();
  if (dtypes_differ) {
    const bool self_is_zero_dim = self.dim() == 0;
    const bool other_is_zero_dim = other.dim() == 0;
    if (other_is_zero_dim && !self_is_zero_dim) {
      check_convert(other.item(), self.scalar_type());
    } else if (self_is_zero_dim && !other_is_zero_dim) {
      check_convert(self.item(), other.scalar_type());
    }
  }

  return native::comparison_op_impl_out(result, self, other, stub);
}
// Functional variant: allocates an empty output with self's options but
// dtype kBool, then fills it through the out-variant.
template <typename Stub>
Tensor comparison_op(const Tensor& self, const Tensor& other, Stub& stub) {
  auto out = at::empty({0}, self.options().dtype(kBool));
  return native::comparison_op_out(out, self, other, stub);
}
// In-place variant: `self` is also the output. Both operands must already
// share a dtype, which avoids overflow during type promotion.
template <typename Stub>
Tensor& comparison_op_(Tensor& self, const Tensor& other, Stub& stub) {
  const bool same_dtype = self.dtype() == other.dtype();
  TORCH_CHECK(same_dtype,
              "Expected object of scalar type ", self.dtype(), " but got scalar type ",
              other.dtype(), " for argument 'other'");
  return native::comparison_op_impl_out(self, self, other, stub);
}
// Scalar overload of the out-variant. Before wrapping, it validates that
// `other` converts to self's dtype without overflow.
// Only the comparison ops do this; arithmetic ops do not. This inconsistency
// should be revisited, possibly by adding the same check to arithmetic ops.
template <typename Stub>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, Scalar other, Stub& stub) {
  const Tensor other_tensor = wrapped_scalar_tensor_and_check_convert(other, self);
  return native::comparison_op_out(result, self, other_tensor, stub);
}
// Functional Scalar variant: allocates an empty kBool output with self's
// options; the Scalar out-variant validates and wraps `other`.
template <typename Stub>
Tensor comparison_op(const Tensor& self, Scalar other, Stub& stub) {
  auto out = at::empty({0}, self.options().dtype(kBool));
  return native::comparison_op_out(out, self, other, stub);
}
// In-place Scalar variant. It validates `other` against self's dtype, wraps
// it, and writes the result back into `self`.
template <typename Stub>
Tensor& comparison_op_(Tensor& self, Scalar other, Stub& stub) {
  const Tensor other_tensor = wrapped_scalar_tensor_and_check_convert(other, self);
  return native::comparison_op_impl_out(self, self, other_tensor, stub);
}
// lt entry points (Tensor and Scalar `other`), all routed through the
// comparison helpers with lt_stub.
Tensor& lt_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, lt_stub);
}
Tensor lt(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, lt_stub);
}
Tensor& lt_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, lt_stub);
}
Tensor& lt_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, lt_stub);
}
Tensor lt(const Tensor& self, Scalar other) {
  return comparison_op(self, other, lt_stub);
}
Tensor& lt_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, lt_stub);
}
// le entry points (Tensor and Scalar `other`), all routed through the
// comparison helpers with le_stub.
Tensor& le_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, le_stub);
}
Tensor le(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, le_stub);
}
Tensor& le_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, le_stub);
}
Tensor& le_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, le_stub);
}
Tensor le(const Tensor& self, Scalar other) {
  return comparison_op(self, other, le_stub);
}
Tensor& le_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, le_stub);
}
// gt entry points (Tensor and Scalar `other`), all routed through the
// comparison helpers with gt_stub.
Tensor& gt_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, gt_stub);
}
Tensor gt(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, gt_stub);
}
Tensor& gt_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, gt_stub);
}
Tensor& gt_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, gt_stub);
}
Tensor gt(const Tensor& self, Scalar other) {
  return comparison_op(self, other, gt_stub);
}
Tensor& gt_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, gt_stub);
}
// ge entry points (Tensor and Scalar `other`), all routed through the
// comparison helpers with ge_stub.
Tensor& ge_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, ge_stub);
}
Tensor ge(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, ge_stub);
}
Tensor& ge_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, ge_stub);
}
Tensor& ge_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, ge_stub);
}
Tensor ge(const Tensor& self, Scalar other) {
  return comparison_op(self, other, ge_stub);
}
Tensor& ge_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, ge_stub);
}
// eq entry points (Tensor and Scalar `other`), all routed through the
// comparison helpers with eq_stub.
Tensor& eq_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, eq_stub);
}
Tensor eq(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, eq_stub);
}
Tensor& eq_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, eq_stub);
}
Tensor& eq_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, eq_stub);
}
Tensor eq(const Tensor& self, Scalar other) {
  return comparison_op(self, other, eq_stub);
}
Tensor& eq_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, eq_stub);
}
// ne entry points (Tensor and Scalar `other`), all routed through the
// comparison helpers with ne_stub.
Tensor& ne_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, ne_stub);
}
Tensor ne(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, ne_stub);
}
Tensor& ne_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, ne_stub);
}
Tensor& ne_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, ne_stub);
}
Tensor ne(const Tensor& self, Scalar other) {
  return comparison_op(self, other, ne_stub);
}
Tensor& ne_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, ne_stub);
}
// logical_xor entry points (Tensor and Scalar `other`). They reuse the
// comparison helpers, with logical_xor_stub as the kernel.
Tensor& logical_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, logical_xor_stub);
}
Tensor logical_xor(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, logical_xor_stub);
}
Tensor& logical_xor_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, logical_xor_stub);
}
Tensor& logical_xor_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, logical_xor_stub);
}
Tensor logical_xor(const Tensor& self, Scalar other) {
  return comparison_op(self, other, logical_xor_stub);
}
Tensor& logical_xor_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, logical_xor_stub);
}
}
} // namespace at