forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
boolean_mask_ops.cc
572 lines (484 loc) · 16.1 KB
/
boolean_mask_ops.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
#include "caffe2/operators/boolean_mask_ops.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
namespace caffe2 {
namespace {

// Given segment `lengths` (Input 0) and a 1D boolean `mask` (Input 1) that
// covers all segments back to back, writes to Output 0 the number of true
// mask entries that fall inside each segment. Output has the same shape and
// integer type as `lengths`.
template <class Context>
class BooleanMaskLengthsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit BooleanMaskLengthsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    // Dispatch on the integer type of `lengths`.
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    auto& lengths = Input(0);
    auto& mask = Input(1);
    CAFFE_ENFORCE(lengths.dim() == 1);
    CAFFE_ENFORCE(mask.dim() == 1);
    const auto* lengthsPtr = lengths.template data<T>();
    const auto* maskPtr = mask.template data<bool>();
    const int64_t numSegments = lengths.numel();

    // Sum in int64_t. Seeding std::accumulate with a plain `0` would make the
    // running total an `int`, truncating int64 lengths on every step.
    int64_t totalLength = 0;
    for (int64_t i = 0; i < numSegments; ++i) {
      // Negative lengths could still sum to mask.numel() while letting a later
      // segment read past the end of `mask`; reject them up front.
      CAFFE_ENFORCE_GE(lengthsPtr[i], 0, "Lengths must be non-negative");
      totalLength += lengthsPtr[i];
    }
    CAFFE_ENFORCE_EQ(
        mask.numel(), totalLength, "Mask size must equal the sum of lengths");

    auto* lengthsOut = Output(0, lengths.sizes(), at::dtype<T>());
    auto* lengthsOutPtr = lengthsOut->template mutable_data<T>();

    // `p` walks through `mask` segment by segment.
    int64_t p = 0;
    for (int64_t i = 0; i < numSegments; ++i) {
      T lengthOut = 0;
      for (T j = 0; j < lengthsPtr[i]; ++j) {
        if (maskPtr[p++]) {
          ++lengthOut;
        }
      }
      lengthsOutPtr[i] = lengthOut;
    }
    return true;
  }
};

} // namespace
// CPU BooleanMask: keeps the slices data[i] (along dim 0) for which mask[i]
// is true. Output 0 has shape [num_true, data.sizes()[1:]...]. When a second
// output is requested, it receives the int64 indices of the true mask entries.
template <>
bool BooleanMaskOp<CPUContext>::RunOnDevice() {
  auto& data = Input(0);
  auto& mask = Input(1);
  auto* dataOut = Output(0);
  CAFFE_ENFORCE(data.dim() >= 1);
  CAFFE_ENFORCE_EQ(mask.dim(), 1);
  CAFFE_ENFORCE(data.size(0) == mask.size(0));

  const auto* maskPtr = mask.template data<bool>();
  // 64-bit counters: tensor sizes are int64_t and may exceed INT_MAX.
  int64_t numOutputs = 0;
  const int64_t outerSize = mask.numel();
  for (int64_t i = 0; i < outerSize; ++i) {
    if (maskPtr[i]) {
      ++numOutputs;
    }
  }

  std::vector<int64_t> outShape;
  outShape.push_back(numOutputs);
  outShape.insert(outShape.end(), data.sizes().begin() + 1, data.sizes().end());
  dataOut->Resize(outShape);
  auto* outPtr = static_cast<char*>(dataOut->raw_mutable_data(data.dtype()));

  int64_t* out_vec = nullptr;
  if (OutputSize() == 2) {
    auto* indicesOut = Output(1, {numOutputs}, at::dtype<int64_t>());
    out_vec = indicesOut->template mutable_data<int64_t>();
  }

  if (numOutputs == 0) {
    return true;
  }

  const auto innerSize = data.size_from_dim(1);
  const auto innerSizeBytes = innerSize * data.dtype().itemsize();

  // Copy maximal runs of consecutive true mask entries with one call each.
  // lastStart is the first row of the current run, or -1 when not in a run.
  int64_t lastStart = -1;
  const auto* inPtr = static_cast<const char*>(data.raw_data());
  int64_t outStart = 0;

  for (int64_t i = 0;; ++i) {
    // mask was true and either a) became false, or b) sequence finished
    if (lastStart != -1 && ((i >= outerSize) || !maskPtr[i])) {
      const auto* src = inPtr + lastStart * innerSizeBytes;
      auto* dst = outPtr + outStart * innerSizeBytes;
      const int64_t numItems = i - lastStart;
      context_.CopyItemsSameDevice(
          data.dtype(), numItems * innerSize, src, dst);
      outStart += numItems;
      lastStart = -1;
    }
    if (i >= outerSize) {
      break;
    }
    // mask was false and became true
    if (lastStart == -1 && maskPtr[i]) {
      lastStart = i;
    }
    if (maskPtr[i] && OutputSize() == 2) {
      *(out_vec++) = i;
    }
  }
  return true;
}
// Register the CPU implementations of BooleanMask and BooleanMaskLengths.
REGISTER_CPU_OPERATOR(BooleanMask, BooleanMaskOp<CPUContext>);
REGISTER_CPU_OPERATOR(BooleanMaskLengths, BooleanMaskLengthsOp<CPUContext>);
// Schema for BooleanMask. The doc text describes what the CPU kernel actually
// accepts: N-D `data` masked along its first dimension by a 1D bool `mask`.
OPERATOR_SCHEMA(BooleanMask)
    .NumInputs(2)
    .NumOutputs(1, 2)
    .SetDoc(R"DOC(
Given a `data` tensor with at least one dimension and a 1D boolean `mask` tensor whose length equals the first dimension of `data`, returns a `masked_data` tensor containing only the slices `data[i]` for which `mask[i]` is True, and optionally a `masked_indices` tensor containing the indices of the True elements of `mask`.

For 1D `data` this simply selects the elements at the positions where `mask` is True.

Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/boolean_mask_ops.cc
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"BooleanMask",
["data", "mask"],
["masked_data", "masked_indices"]
)
workspace.FeedBlob("data", np.array([1,2,3,4,5,6]))
workspace.FeedBlob("mask", np.array([True,False,False,True,True,False]))
print("data:", workspace.FetchBlob("data"))
print("mask:", workspace.FetchBlob("mask"))
workspace.RunOperatorOnce(op)
print("masked_data:", workspace.FetchBlob("masked_data"))
print("masked_indices:", workspace.FetchBlob("masked_indices"))
```
**Result**
```
data: [1 2 3 4 5 6]
mask: [ True False False True True False]
masked_data: [1 4 5]
masked_indices: [0 3 4]
```
</details>
)DOC")
    .Input(0, "data", "(*Tensor*): input tensor with at least one dimension; masking is applied along the first dimension")
    .Input(1, "mask", "(*Tensor`<bool>`*): 1D tensor of bools, one per slice along the first dimension of `data`, which determines the slices that will be left in the `masked_data` output tensor")
    .Output(0, "masked_data", "(*Tensor*): tensor of the same type as `data` whose first dimension is the number of True values in `mask` and whose remaining dimensions match `data`")
    .Output(1, "masked_indices", "(*Tensor`<int64>`*): 1D tensor of indices of the True elements in the `mask` tensor");
// Schema for BooleanMaskLengths. The kernel dispatches on int32 and int64
// `lengths`, and the output shares the type and shape of `lengths`.
OPERATOR_SCHEMA(BooleanMaskLengths)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Given an integer (int32 or int64) `lengths` tensor representing segment lengths and a `mask` (boolean) tensor, return the segment lengths of the corresponding segmented tensor after **BooleanMask** is applied, i.e. the number of True `mask` entries inside each segment.

If `lengths` tensor is $[a_1, a_2, ..., a_n]$, then length of `mask` tensor must be $a_1 + a_2 + ... + a_n$.

Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/boolean_mask_ops.cc
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"BooleanMaskLengths",
["lengths", "mask"],
["masked_lengths"]
)
workspace.FeedBlob("lengths", np.array([1,3,2], dtype=np.int32))
workspace.FeedBlob("mask", np.array([False,True,True,False,True,True]))
print("lengths:", workspace.FetchBlob("lengths"))
print("mask:", workspace.FetchBlob("mask"))
workspace.RunOperatorOnce(op)
print("masked_lengths:", workspace.FetchBlob("masked_lengths"))
```
**Result**
```
lengths: [1 3 2]
mask: [False True True False True True]
masked_lengths: [0 2 2]
```
</details>
)DOC")
    .Input(0, "lengths", "(*Tensor`<int>`*): 1D tensor of int32 or int64 segment lengths")
    .Input(1, "mask", "(*Tensor`<bool>`*): A 1D bool tensor of values to keep; its length must equal the sum of `lengths`.")
    .Output(0, "masked_lengths", "(*Tensor`<int>`*): 1D tensor of the same type and shape as `lengths` holding the number of True `mask` entries in each segment");
// Both boolean-mask operators are registered as having no gradient.
NO_GRADIENT(BooleanMask)
NO_GRADIENT(BooleanMaskLengths);
// Negative infinity as a float constant.
const float minf = -std::numeric_limits<float>::infinity();
// Masks a row-major tensor viewed as one or more N x M matrices.
//
// Every element (row, col) of each matrix is passed to `fn(row, col, value)`;
// where it returns true the output element becomes `fill_val`, otherwise the
// input value is copied through. `row` and `col` are relative to the current
// matrix, so the same mask is applied to every batch.
//
// B >= 0: the data is viewed as [B, N, M] and B matrices are processed.
// B <  0: the data is viewed as a single [N, M] matrix.
//
// Templated on the functor so each mask kind gets its own instantiation,
// giving the compiler a better chance to inline the predicate.
template <typename Functor>
void MaskWithFunctor(
    size_t N,
    size_t M,
    int B,
    const float* in,
    Functor fn,
    float fill_val,
    float* out) {
  // Masks one N x M matrix starting at `src`, writing it to `dst`.
  const auto mask_matrix = [&](const float* src, float* dst) {
    for (int row = 0; row < N; ++row) {
      const float* src_row = src + M * row;
      float* dst_row = dst + M * row;
      for (int col = 0; col < M; ++col) {
        const float v = src_row[col];
        dst_row[col] = fn(row, col, v) ? fill_val : v;
      }
    }
  };

  if (B < 0) {
    // TODO(T20952436): vector implementation
    mask_matrix(in, out);
    return;
  }

  // Batch b occupies elements [N * M * b, N * M * (b + 1)).
  for (int b = 0; b < B; ++b) {
    mask_matrix(in + N * M * b, out + N * M * b);
  }
}
// Applies a 2D mask to a tensor viewed as [N, M, D], repeating the decision
// for (row, col) across the D contiguous trailing elements of that cell.
//
// `fn(row, col, value)` returning true replaces the element with `fill_val`;
// otherwise the input value is copied to `out`. Note that `value` passed to
// `fn` is the individual element, but `row`/`col` ignore the D position.
template <typename Functor>
void RepeatedMaskWithFunctor(
    size_t N,
    size_t M,
    int D,
    const float* in,
    Functor fn,
    float fill_val,
    float* out) {
  for (int row = 0; row < N; ++row) {
    for (int col = 0; col < M; ++col) {
      // Offset of the first of the D elements belonging to (row, col).
      const size_t base = M * D * row + D * col;
      for (int rep = 0; rep < D; ++rep) {
        const float v = in[base + rep];
        out[base + rep] = fn(row, col, v) ? fill_val : v;
      }
    }
  }
}
namespace {

// Mask predicates used with MaskWithFunctor / RepeatedMaskWithFunctor.
// Each is called as fn(row, col, value) and returns true when the element
// should be replaced by the fill value.

// Masks columns at or beyond the row's sequence length sl[row]. `len` is the
// number of entries in `sl`; rows outside it raise an error.
class SequenceFunctor {
 public:
  explicit SequenceFunctor(const int* sl, const size_t len)
      : sl_(sl), len_(len) {}
  bool operator()(int i, int j, float /* val*/) {
    CAFFE_ENFORCE(i < len_, "Out of bound.");
    return sl_[i] <= j;
  }

 private:
  const int* sl_;
  const size_t len_;
};

// Masks columns outside [centers[row] - radius, centers[row] + radius].
// NOTE(review): `centers` is indexed by row without a bounds check; callers
// must supply at least one center per row.
class WindowFunctor {
 public:
  explicit WindowFunctor(const int* centers, int radius)
      : centers_(centers), radius_(radius) {}
  bool operator()(int i, int j, float /* val*/) {
    const int center = centers_[i];
    return center + radius_ < j || j < center - radius_;
  }

 private:
  const int* centers_;
  const int radius_;
};

// Masks strictly above the diagonal (col > row).
class UpperFunctor {
 public:
  bool operator()(int i, int j, float /* val */) {
    return i < j;
  }
};

// Masks strictly below the diagonal (col < row).
class LowerFunctor {
 public:
  bool operator()(int i, int j, float /* val */) {
    return i > j;
  }
};

// Masks the diagonal and everything above it (col >= row).
class UpperDiagFunctor {
 public:
  bool operator()(int i, int j, float /* val */) {
    return i <= j;
  }
};

// Masks the diagonal and everything below it (col <= row).
class LowerDiagFunctor {
 public:
  bool operator()(int i, int j, float /* val */) {
    return i >= j;
  }
};

} // namespace
// CPU entry point for SequenceMask: only float inputs are dispatched, so
// DoRunWithType is instantiated with T = float.
template <>
bool SequenceMaskOp<CPUContext>::RunOnDevice() {
  const auto& input = Input(0);
  return DispatchHelper<TensorTypes<float>>::call(this, input);
}
// Computes the masked copy of Input(0) into Output(0).
//
// The input is viewed as rows x columns around `axis_`: `right` is the
// product of dims at/after the axis (columns), and `left` is the product of
// dims before it (rows). With a `batch` argument, left covers only the dims
// strictly between batch and axis, and the mask is repeated for each of the
// batch_dim leading slabs. Masked elements get `fill_val_`, or 0 in grad mode.
//
// Input(1) holds sequence lengths ("sequence" mode) or window centers
// ("window" mode). Other modes take a single input.
template <>
template <class T>
bool SequenceMaskOp<CPUContext>::DoRunWithType() {
  const Tensor* input = &Input(0);
  const Tensor* sequence_lengths = nullptr;
  const Tensor* window_centers = nullptr;

  if (mode_ == "sequence") {
    sequence_lengths = &Input(1);
  } else if (mode_ == "window") {
    window_centers = &Input(1);
  }

  auto* output = Output(0, input->sizes(), at::dtype<T>());

  const auto canonical_axis = input->canonical_axis_index(axis_);

  // canonical_batch is non-negative if batching, -1 otherwise
  int canonical_batch = -1;
  if (HasArgument("batch")) {
    canonical_batch = input->canonical_axis_index(batch_);
  }

  // make sure batch < axis
  if (canonical_batch >= 0) {
    CAFFE_ENFORCE_LT(canonical_batch, canonical_axis);
  }

  // if no batch, then left is product of dims up to axis
  // otherwise, left is product of dims between batch and axis
  const int left =
      (canonical_batch >= 0
           ? input->size_between_dim(canonical_batch, canonical_axis)
           : input->size_to_dim(canonical_axis));
  const int right = input->size_from_dim(canonical_axis);

  // product of dims from 0 through batch (number of independent slabs)
  const int batch_dim =
      (canonical_batch >= 0
           ? input->size_to_dim(canonical_batch) * input->size(canonical_batch)
           : -1);

  T fill_val = convert::To<float, T>(grad_ ? 0.0f : fill_val_);
  if (mode_ == "sequence") {
    CAFFE_ENFORCE(
        sequence_lengths, "Sequence length not provided for mode 'sequence'!");
    if (HasArgument("repeat_from_axis")) {
      const int canonical_repeat_from =
          input->canonical_axis_index(repeat_from_);
      const int repeated_dims = input->size_from_dim(canonical_repeat_from);
      // A zero-sized repeated span means the tensor has no elements; avoid
      // dividing by zero in that case.
      const int masked_dims = repeated_dims == 0 ? 0 : right / repeated_dims;
      // If the repeated dims do not tile the column span exactly (e.g.
      // repeat_from_axis precedes axis), part of the output would be left
      // unwritten.
      CAFFE_ENFORCE(
          repeated_dims == 0 || masked_dims * repeated_dims == right,
          "Dims from repeat_from_axis must evenly tile the dims from axis");

      // RepeatedMaskWithFunctor covers one [left, masked_dims, repeated_dims]
      // slab. With batching there are batch_dim such slabs back to back;
      // without it there is exactly one.
      const int num_slabs = canonical_batch >= 0 ? batch_dim : 1;
      const int64_t slab_size = static_cast<int64_t>(left) * right;
      const T* in_data = input->data<T>();
      T* out_data = output->template mutable_data<T>();
      for (int b = 0; b < num_slabs; ++b) {
        RepeatedMaskWithFunctor(
            left,
            masked_dims,
            repeated_dims,
            in_data + b * slab_size,
            SequenceFunctor(
                sequence_lengths->data<int>(), sequence_lengths->numel()),
            fill_val,
            out_data + b * slab_size);
      }
    } else {
      MaskWithFunctor(
          left,
          right,
          batch_dim,
          input->data<T>(),
          SequenceFunctor(
              sequence_lengths->data<int>(), sequence_lengths->numel()),
          fill_val,
          output->template mutable_data<T>());
    }
  } else if (mode_ == "window") {
    // WindowFunctor reads window_centers[row] without a bounds check, so a
    // non-empty input needs at least one center per row.
    CAFFE_ENFORCE(
        input->numel() == 0 || window_centers->numel() >= left,
        "Not enough window centers for the number of rows!");
    MaskWithFunctor(
        left,
        right,
        batch_dim,
        input->data<T>(),
        WindowFunctor(window_centers->data<int>(), radius_),
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "upper") {
    MaskWithFunctor(
        left,
        right,
        batch_dim,
        input->data<T>(),
        UpperFunctor(),
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "lower") {
    MaskWithFunctor(
        left,
        right,
        batch_dim,
        input->data<T>(),
        LowerFunctor(),
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "upperdiag") {
    MaskWithFunctor(
        left,
        right,
        batch_dim,
        input->data<T>(),
        UpperDiagFunctor(),
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "lowerdiag") {
    MaskWithFunctor(
        left,
        right,
        batch_dim,
        input->data<T>(),
        LowerDiagFunctor(),
        fill_val,
        output->template mutable_data<T>());
  } else {
    CAFFE_ENFORCE(false, "Unsupported mode for SequenceMaskOp!");
    return false;
  }

  return true;
}
REGISTER_CPU_OPERATOR(SequenceMask, SequenceMaskOp<CPUContext>);

// Schema for SequenceMask. The triangular conditions below match the
// functors used by the CPU kernel (e.g. 'upper' masks entries with j > i).
OPERATOR_SCHEMA(SequenceMask)
    .NumInputs(1, 2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Mask op designed for use in attention mechanisms for sequence modeling tasks.
Supports batching: given the `batch` argument, collapses dims 0 through `batch` into a
single dimension, e.g. if tensor dims are [4,2,1,3,4] and batch=2, first
collapse tensor to [4*2*1,3,4], then mask each batch [i,:,:].

Three current operating modes:

1) Sequence mask. Given a 2D input tensor and 1D tensor of sequence lengths, for each row i in
the input tensor, set elements in that row to -inf if their column index
j >= sequence_lengths[i]. This mode takes two inputs and argument mode =
'sequence'

2) Triangular mask. Given row index i and column index j, set elements to -inf
given the following conditions:

      mode='upper', x_ij = -inf if j > i
      mode='lower', x_ij = -inf if j < i
      mode='upperdiag', x_ij = -inf if j >= i
      mode='lowerdiag', x_ij = -inf if j <= i

This mode takes one input.

3) Window mask. Given a 2D input tensor and 1D tensor of window centers,
for each row i in the input tensor, set elements in that row to -inf
if their column index j lies outside [center - radius, center + radius].
This mode takes two inputs and argument mode = 'window'.
Argument 'radius' should be provided.
)DOC")
    .Input(0, "input", "Tensor to apply masking to")
    .Input(1, "sequence_lengths", "1D Tensor of sequence lengths for mode #1, or of window centers for mode #3")
    .Output(0, "masked_tensor", "Input tensor with masking applied")
    .Arg(
        "mode",
        "(string) Mode selection. Possible values: "
        "'sequence', 'window', 'upper', 'lower', 'upperdiag', 'lowerdiag'")
    .Arg(
        "axis",
        "(int) Beginning axis of row elements. All dimensions to the left "
        "will be treated as row indices and those to the right (inclusive) "
        "will be treated as column indices in the 2D mask")
    .Arg("grad", "(bool) operate in gradient mode")
    .Arg("radius", "(int) radius of windows in window mode")
    .Arg("batch", "(int) batch dimension of tensor (optional)")
    .Arg(
        "repeat_from_axis",
        "(int) used when mask should be repeated for "
        "one or more data dimensions (beginning at this axis). "
        "(currently only supported for sequence mode without batch argument)");
// Gradient maker for SequenceMask. The gradient is another SequenceMask op
// applied to the output gradient GO(0), using the forward op's arguments plus
// grad=true. When the forward op had a second input, it is passed along as
// well so that the same mask is rebuilt.
class GetSequenceMaskGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    // Forward all arguments of the original op, then switch on grad mode.
    vector<Argument> args(Def().arg().begin(), Def().arg().end());
    args.push_back(MakeArgument<bool>("grad", true));

    vector<string> grad_inputs{GO(0)};
    if (def_.input_size() != 1) {
      grad_inputs.push_back(I(1));
    }
    return SingleGradientDef(
        "SequenceMask", "", grad_inputs, vector<string>{GI(0)}, args);
  }

  // Arguments are forwarded explicitly above, so the framework should not
  // copy them a second time.
  bool CopyArguments() const override {
    return false;
  }
};

REGISTER_GRADIENT(SequenceMask, GetSequenceMaskGradient);
} // namespace caffe2