forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
GridSampler.cu
715 lines (649 loc) · 31.8 KB
/
GridSampler.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
#include <ATen/ATen.h>
#include <ATen/native/cuda/GridSampler.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>
namespace at { namespace native {
using namespace at::cuda::detail;
using at::native::detail::GridSamplerInterpolation;
using at::native::detail::GridSamplerPadding;
namespace {
// Forward 2-D grid sampler.
//
// Layout as indexed here: `input` is (N, C, in_h, in_w); `grid` is
// (N, out_h, out_w, *) with the x coordinate in slot 0 and y in slot 1 of the
// last dimension; `output` is (N, C, out_h, out_w). x is resolved against
// in_w and y against in_h. Every offset is a 32-bit int computed from the
// TensorInfo strides, so the caller must keep all offsets within int range.
// Each CUDA_KERNEL_LOOP iteration fills every channel of one output pixel.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_2d_kernel(
    const int nthreads,
    TensorInfo<scalar_t, int> input,
    TensorInfo<scalar_t, int> grid,
    TensorInfo<scalar_t, int> output,
    const GridSamplerInterpolation interpolation_mode,
    const GridSamplerPadding padding_mode,
    bool align_corners) {
  const int channels = input.sizes[1];
  const int in_h = input.sizes[2];
  const int in_w = input.sizes[3];
  const int out_h = grid.sizes[1];
  const int out_w = grid.sizes[2];

  const int in_stride_n = input.strides[0];
  const int in_stride_c = input.strides[1];
  const int in_stride_h = input.strides[2];
  const int in_stride_w = input.strides[3];

  const int grid_stride_n = grid.strides[0];
  const int grid_stride_h = grid.strides[1];
  const int grid_stride_w = grid.strides[2];
  const int grid_stride_xy = grid.strides[3];

  const int out_stride_n = output.strides[0];
  const int out_stride_c = output.strides[1];
  const int out_stride_h = output.strides[2];
  const int out_stride_w = output.strides[3];

  CUDA_KERNEL_LOOP(index, nthreads) {
    // Unflatten the linear index into (batch, row, col); col varies fastest.
    const int col = index % out_w;
    const int row = (index / out_w) % out_h;
    const int batch = index / (out_h * out_w);

    // Read this pixel's (x, y) grid coordinates and convert them to input
    // space via grid_sampler_compute_source_index (GridSampler.cuh).
    const int coord_offset =
        batch * grid_stride_n + row * grid_stride_h + col * grid_stride_w;
    const scalar_t x = grid_sampler_compute_source_index(
        grid.data[coord_offset], in_w, padding_mode, align_corners);
    const scalar_t y = grid_sampler_compute_source_index(
        grid.data[coord_offset + grid_stride_xy], in_h, padding_mode, align_corners);

    // Channel 0 of this batch in the input, and channel 0 of this output pixel.
    const scalar_t* in_plane = input.data + batch * in_stride_n;
    scalar_t* out_pixel =
        output.data + batch * out_stride_n + row * out_stride_h + col * out_stride_w;

    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
      // Top-left integer corner (x0, y0); the other corners step +1 in x/y.
      const int x0 = static_cast<int>(::floor(x));
      const int y0 = static_cast<int>(::floor(y));
      const int x1 = x0 + 1;
      const int y1 = y0 + 1;
      // Each corner's weight is the area of the rectangle spanned by the
      // sample point and the diagonally opposite corner.
      const scalar_t w_tl = (x1 - x) * (y1 - y);
      const scalar_t w_tr = (x - x0) * (y1 - y);
      const scalar_t w_bl = (x1 - x) * (y - y0);
      const scalar_t w_br = (x - x0) * (y - y0);

      const scalar_t* src = in_plane;
      scalar_t* dst = out_pixel;
      for (int ch = 0; ch < channels; ++ch, src += in_stride_c, dst += out_stride_c) {
        // Corners outside the input contribute nothing.
        *dst = static_cast<scalar_t>(0);
        if (within_bounds_2d(y0, x0, in_h, in_w)) {
          *dst += src[y0 * in_stride_h + x0 * in_stride_w] * w_tl;
        }
        if (within_bounds_2d(y0, x1, in_h, in_w)) {
          *dst += src[y0 * in_stride_h + x1 * in_stride_w] * w_tr;
        }
        if (within_bounds_2d(y1, x0, in_h, in_w)) {
          *dst += src[y1 * in_stride_h + x0 * in_stride_w] * w_bl;
        }
        if (within_bounds_2d(y1, x1, in_h, in_w)) {
          *dst += src[y1 * in_stride_h + x1 * in_stride_w] * w_br;
        }
      }
    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      // Copy the closest input pixel, or zero when it lies outside the input.
      const int xn = static_cast<int>(::round(x));
      const int yn = static_cast<int>(::round(y));
      const bool inside = within_bounds_2d(yn, xn, in_h, in_w);

      const scalar_t* src = in_plane;
      scalar_t* dst = out_pixel;
      for (int ch = 0; ch < channels; ++ch, src += in_stride_c, dst += out_stride_c) {
        *dst = inside ? src[yn * in_stride_h + xn * in_stride_w]
                      : static_cast<scalar_t>(0);
      }
    }
  }
}
// Forward 3-D (volumetric) grid sampler kernel.
//
// Indexing below treats `input` as (N, C, inp_D, inp_H, inp_W), `grid` as
// (N, out_D, out_H, out_W, *) whose last dimension holds the (x, y, z)
// sampling coordinates, and `output` as (N, C, out_D, out_H, out_W).
// x is resolved against inp_W, y against inp_H and z against inp_D.
// All offsets use 32-bit `int` arithmetic on the TensorInfo strides, so the
// caller must ensure every element offset fits in an int.
//
// Each CUDA_KERNEL_LOOP iteration produces one output location (n, d, h, w)
// for all C channels; `nthreads` is the number of such locations.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_3d_kernel(
    const int nthreads,
    TensorInfo<scalar_t, int> input,
    TensorInfo<scalar_t, int> grid,
    TensorInfo<scalar_t, int> output,
    const GridSamplerInterpolation interpolation_mode,
    const GridSamplerPadding padding_mode,
    bool align_corners) {
  int C = input.sizes[1];
  int inp_D = input.sizes[2];
  int inp_H = input.sizes[3];
  int inp_W = input.sizes[4];
  int out_D = grid.sizes[1];
  int out_H = grid.sizes[2];
  int out_W = grid.sizes[3];
  int inp_sN = input.strides[0];
  int inp_sC = input.strides[1];
  int inp_sD = input.strides[2];
  int inp_sH = input.strides[3];
  int inp_sW = input.strides[4];
  int grid_sN = grid.strides[0];
  int grid_sD = grid.strides[1];
  int grid_sH = grid.strides[2];
  int grid_sW = grid.strides[3];
  int grid_sCoor = grid.strides[4];
  int out_sN = output.strides[0];
  int out_sC = output.strides[1];
  int out_sD = output.strides[2];
  int out_sH = output.strides[3];
  int out_sW = output.strides[4];
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Unflatten the linear index into (n, d, h, w); w varies fastest.
    const int w = index % out_W;
    const int h = (index / out_W) % out_H;
    const int d = (index / (out_H * out_W)) % out_D;
    const int n = index / (out_D * out_H * out_W);
    const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
    // get the corresponding input x, y, z co-ordinates from grid
    scalar_t ix = grid.data[grid_offset];
    scalar_t iy = grid.data[grid_offset + grid_sCoor];
    scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
    // Convert each grid coordinate into an (unrounded) input index. The
    // mapping lives in grid_sampler_compute_source_index (GridSampler.cuh);
    // presumably it un-normalizes per align_corners and applies
    // padding_mode -- confirm there.
    ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
    iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
    iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
      // Trilinear interpolation over the 8 integer corners surrounding
      // (ix, iy, iz).
      // For 4-D inputs we used north-east-south-west; for 5-D we add
      // top-bottom. "tnw" is the floor of each coordinate; east is +1 in x,
      // south is +1 in y, bottom is +1 in z.
      int ix_tnw = static_cast<int>(::floor(ix));
      int iy_tnw = static_cast<int>(::floor(iy));
      int iz_tnw = static_cast<int>(::floor(iz));
      int ix_tne = ix_tnw + 1;
      int iy_tne = iy_tnw;
      int iz_tne = iz_tnw;
      int ix_tsw = ix_tnw;
      int iy_tsw = iy_tnw + 1;
      int iz_tsw = iz_tnw;
      int ix_tse = ix_tnw + 1;
      int iy_tse = iy_tnw + 1;
      int iz_tse = iz_tnw;
      int ix_bnw = ix_tnw;
      int iy_bnw = iy_tnw;
      int iz_bnw = iz_tnw + 1;
      int ix_bne = ix_tnw + 1;
      int iy_bne = iy_tnw;
      int iz_bne = iz_tnw + 1;
      int ix_bsw = ix_tnw;
      int iy_bsw = iy_tnw + 1;
      int iz_bsw = iz_tnw + 1;
      int ix_bse = ix_tnw + 1;
      int iy_bse = iy_tnw + 1;
      int iz_bse = iz_tnw + 1;
      // Weight of each corner = volume of the box between the sample point and
      // the diagonally opposite corner (e.g. tnw uses the distances to bse).
      scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
      scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
      scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
      scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
      scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
      scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
      scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
      scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
      // Base pointers for channel 0 of this batch and of this output location.
      auto inp_ptr_NC = input.data + n * inp_sN;
      auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
      for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
        // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
        // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
        // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
        // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
        // Start from zero; corners failing within_bounds_3d contribute nothing.
        *out_ptr_NCDHW = static_cast<scalar_t>(0);
        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
        }
        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
        }
        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
        }
        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
        }
        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
        }
        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
        }
        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
        }
        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
        }
      }
    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      int ix_nearest = static_cast<int>(::round(ix));
      int iy_nearest = static_cast<int>(::round(iy));
      int iz_nearest = static_cast<int>(::round(iz));
      // assign nearest neighbor pixel value to output pixel (zero if the
      // nearest voxel lies outside the input)
      auto inp_ptr_NC = input.data + n * inp_sN;
      auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
      for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
        if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
        } else {
          *out_ptr_NCDHW = static_cast<scalar_t>(0);
        }
      }
    }
  }
}
// Backward of the 2-D grid sampler kernel.
//
// Indexing below treats `input`/`grad_input` as (N, C, inp_H, inp_W),
// `grid`/`grad_grid` as (N, out_H, out_W, *) with the (x, y) coordinates in
// the last dimension, and `grad_output` as (N, C, out_H, out_W).
// Each CUDA_KERNEL_LOOP iteration handles one output location (n, h, w):
//   * it scatters grad_output into grad_input at the input pixels that were
//     sampled (via safe_add_2d), and
//   * it writes the gradient w.r.t. that location's two grid coordinates
//     into grad_grid.
// Different output locations can sample the same input pixel, so grad_input
// updates may collide across threads; safe_add_2d (GridSampler.cuh) is
// presumably a bounds-checked atomic add -- confirm there.
// All offsets are 32-bit ints, so the caller must keep them within int range.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_2d_backward_kernel(
    const int nthreads,
    TensorInfo<scalar_t, int> grad_output,
    TensorInfo<scalar_t, int> input,
    TensorInfo<scalar_t, int> grid,
    TensorInfo<scalar_t, int> grad_input,  // initialized to zeros (accumulated into)
    TensorInfo<scalar_t, int> grad_grid,   // uninitialized; written below, assumed contiguous
    const GridSamplerInterpolation interpolation_mode,
    const GridSamplerPadding padding_mode,
    bool align_corners) {
  int C = input.sizes[1];
  int inp_H = input.sizes[2];
  int inp_W = input.sizes[3];
  int out_H = grid.sizes[1];
  int out_W = grid.sizes[2];
  int inp_sN = input.strides[0];
  int inp_sC = input.strides[1];
  int inp_sH = input.strides[2];
  int inp_sW = input.strides[3];
  int grid_sN = grid.strides[0];
  int grid_sH = grid.strides[1];
  int grid_sW = grid.strides[2];
  int grid_sCoor = grid.strides[3];
  int gOut_sN = grad_output.strides[0];
  int gOut_sC = grad_output.strides[1];
  int gOut_sH = grad_output.strides[2];
  int gOut_sW = grad_output.strides[3];
  int gInp_sN = grad_input.strides[0];
  int gInp_sC = grad_input.strides[1];
  int gInp_sH = grad_input.strides[2];
  int gInp_sW = grad_input.strides[3];
  int gGrid_sW = grad_grid.strides[2];
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Unflatten the linear index into (n, h, w); w varies fastest.
    const int w = index % out_W;
    const int h = (index / out_W) % out_H;
    const int n = index / (out_H * out_W);
    const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
    // get the corresponding input x, y co-ordinates from grid
    scalar_t ix = grid.data[grid_offset];
    scalar_t iy = grid.data[grid_offset + grid_sCoor];
    // multipliers for gradients on ix and iy; set by the helper below
    // (presumably the derivative of the source-index mapping -- see
    // grid_sampler_compute_source_index_set_grad in GridSampler.cuh)
    scalar_t gix_mult, giy_mult;
    ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
    iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
      // get NE, NW, SE, SW corner indices from (x, y); east is +1 in x,
      // south is +1 in y
      int ix_nw = static_cast<int>(::floor(ix));
      int iy_nw = static_cast<int>(::floor(iy));
      int ix_ne = ix_nw + 1;
      int iy_ne = iy_nw;
      int ix_sw = ix_nw;
      int iy_sw = iy_nw + 1;
      int ix_se = ix_nw + 1;
      int iy_se = iy_nw + 1;
      // get surfaces to each neighbor (same bilinear weights as the forward):
      scalar_t nw = (ix_se - ix) * (iy_se - iy);
      scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
      scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
      scalar_t se = (ix - ix_nw) * (iy - iy_nw);
      // Gradients w.r.t. the (unscaled) source indices, summed over channels.
      scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
      scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
      scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
      scalar_t *inp_ptr_NC = input.data + n * inp_sN;
      for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
        scalar_t gOut = *gOut_ptr_NCHW;
        // calculate and set grad_input: each corner receives gOut scaled by
        // its interpolation weight
        safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut);
        safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut);
        safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut);
        safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut);
        // calculate grad_grid: derivative of each in-bounds corner's term
        // w.r.t. ix and iy, e.g. d(nw)/d(ix) = -(iy_se - iy)
        if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
          scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
          gix -= nw_val * (iy_se - iy) * gOut;
          giy -= nw_val * (ix_se - ix) * gOut;
        }
        if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
          scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
          gix += ne_val * (iy_sw - iy) * gOut;
          giy -= ne_val * (ix - ix_sw) * gOut;
        }
        if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
          scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
          gix -= sw_val * (iy - iy_ne) * gOut;
          giy += sw_val * (ix_ne - ix) * gOut;
        }
        if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
          scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
          gix += se_val * (iy - iy_nw) * gOut;
          giy += se_val * (ix - ix_nw) * gOut;
        }
      }
      // assuming grad_grid is contiguous
      // thus we can
      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
      //   2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
      scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
      gGrid_ptr_NHW[0] = gix_mult * gix;
      gGrid_ptr_NHW[1] = giy_mult * giy;
    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      int ix_nearest = static_cast<int>(::round(ix));
      int iy_nearest = static_cast<int>(::round(iy));
      // scatter grad_output back to the nearest input pixel
      scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
      scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
      for (int c = 0; c < C; ++c, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
        // calculate and set grad_input
        safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW);
      }
      // Nearest sampling is treated as having zero gradient w.r.t. the grid.
      // assuming grad_grid is contiguous
      // thus we can
      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
      //   2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
      scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
      gGrid_ptr_NHW[0] = static_cast<scalar_t>(0);
      gGrid_ptr_NHW[1] = static_cast<scalar_t>(0);
    }
  }
}
// Backward of the 3-D (volumetric) grid sampler kernel.
//
// Indexing below treats `input`/`grad_input` as (N, C, inp_D, inp_H, inp_W),
// `grid`/`grad_grid` as (N, out_D, out_H, out_W, *) with the (x, y, z)
// coordinates in the last dimension, and `grad_output` as
// (N, C, out_D, out_H, out_W).
// Each CUDA_KERNEL_LOOP iteration handles one output location (n, d, h, w):
//   * it scatters grad_output into grad_input at the sampled input voxels
//     (via safe_add_3d), and
//   * it writes the gradient w.r.t. that location's three grid coordinates
//     into grad_grid.
// Different output locations can hit the same input voxel; safe_add_3d
// (GridSampler.cuh) is presumably a bounds-checked atomic add -- confirm.
// All offsets are 32-bit ints, so the caller must keep them within int range.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_3d_backward_kernel(
    const int nthreads,
    TensorInfo<scalar_t, int> grad_output,
    TensorInfo<scalar_t, int> input,
    TensorInfo<scalar_t, int> grid,
    TensorInfo<scalar_t, int> grad_input,  // initialized to zeros (accumulated into)
    TensorInfo<scalar_t, int> grad_grid,   // uninitialized; written below, assumed contiguous
    const GridSamplerInterpolation interpolation_mode,
    const GridSamplerPadding padding_mode,
    bool align_corners) {
  int C = input.sizes[1];
  int inp_D = input.sizes[2];
  int inp_H = input.sizes[3];
  int inp_W = input.sizes[4];
  int out_D = grid.sizes[1];
  int out_H = grid.sizes[2];
  int out_W = grid.sizes[3];
  int inp_sN = input.strides[0];
  int inp_sC = input.strides[1];
  int inp_sD = input.strides[2];
  int inp_sH = input.strides[3];
  int inp_sW = input.strides[4];
  int grid_sN = grid.strides[0];
  int grid_sD = grid.strides[1];
  int grid_sH = grid.strides[2];
  int grid_sW = grid.strides[3];
  int grid_sCoor = grid.strides[4];
  int gOut_sN = grad_output.strides[0];
  int gOut_sC = grad_output.strides[1];
  int gOut_sD = grad_output.strides[2];
  int gOut_sH = grad_output.strides[3];
  int gOut_sW = grad_output.strides[4];
  int gInp_sN = grad_input.strides[0];
  int gInp_sC = grad_input.strides[1];
  int gInp_sD = grad_input.strides[2];
  int gInp_sH = grad_input.strides[3];
  int gInp_sW = grad_input.strides[4];
  int gGrid_sW = grad_grid.strides[3];
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Unflatten the linear index into (n, d, h, w); w varies fastest.
    const int w = index % out_W;
    const int h = (index / out_W) % out_H;
    const int d = (index / (out_H * out_W)) % out_D;
    const int n = index / (out_D * out_H * out_W);
    const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
    // get the corresponding input x, y, z co-ordinates from grid
    scalar_t ix = grid.data[grid_offset];
    scalar_t iy = grid.data[grid_offset + grid_sCoor];
    scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
    // multipliers for gradients on ix, iy, and iz; set by the helper below
    // (presumably the derivative of the source-index mapping -- see
    // grid_sampler_compute_source_index_set_grad in GridSampler.cuh)
    scalar_t gix_mult, giy_mult, giz_mult;
    ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
    iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
    iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
      // get corner indices from (x, y, z)
      // for 4d, we used north-east-south-west
      // for 5d, we add top-bottom
      // "tnw" is the floor of each coordinate; east is +1 in x, south is +1
      // in y, bottom is +1 in z.
      int ix_tnw = static_cast<int>(::floor(ix));
      int iy_tnw = static_cast<int>(::floor(iy));
      int iz_tnw = static_cast<int>(::floor(iz));
      int ix_tne = ix_tnw + 1;
      int iy_tne = iy_tnw;
      int iz_tne = iz_tnw;
      int ix_tsw = ix_tnw;
      int iy_tsw = iy_tnw + 1;
      int iz_tsw = iz_tnw;
      int ix_tse = ix_tnw + 1;
      int iy_tse = iy_tnw + 1;
      int iz_tse = iz_tnw;
      int ix_bnw = ix_tnw;
      int iy_bnw = iy_tnw;
      int iz_bnw = iz_tnw + 1;
      int ix_bne = ix_tnw + 1;
      int iy_bne = iy_tnw;
      int iz_bne = iz_tnw + 1;
      int ix_bsw = ix_tnw;
      int iy_bsw = iy_tnw + 1;
      int iz_bsw = iz_tnw + 1;
      int ix_bse = ix_tnw + 1;
      int iy_bse = iy_tnw + 1;
      int iz_bse = iz_tnw + 1;
      // get volumes to each neighbor (same trilinear weights as the forward):
      scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
      scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
      scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
      scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
      scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
      scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
      scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
      scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
      // Gradients w.r.t. the (unscaled) source indices, summed over channels.
      scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
      scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
      scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
      scalar_t *inp_ptr_NC = input.data + n * inp_sN;
      // accumulate grad_input and the grid gradients over every channel
      for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
        scalar_t gOut = *gOut_ptr_NCDHW;
        // calculate and set grad_input: each corner receives gOut scaled by
        // its interpolation weight
        safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
        // calculate grad_grid: derivative of each in-bounds corner's term
        // w.r.t. ix, iy and iz (product of the other two distance factors,
        // signed by which side of the sample point the corner lies on)
        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
          scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
          gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
          giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
          giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
        }
        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
          scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
          gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
          giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
          giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
        }
        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
          scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
          gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
          giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
          giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
        }
        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
          scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
          gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
          giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
          giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
        }
        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
          scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
          gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
          giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
          giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
        }
        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
          scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
          gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
          giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
          giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
        }
        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
          scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
          gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
          giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
          giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
        }
        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
          scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
          gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
          giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
          giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
        }
      }
      // assuming grad_grid is contiguous
      // thus we can
      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
      //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
      scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
      gGrid_ptr_NDHW[0] = gix_mult * gix;
      gGrid_ptr_NDHW[1] = giy_mult * giy;
      gGrid_ptr_NDHW[2] = giz_mult * giz;
    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      int ix_nearest = static_cast<int>(::round(ix));
      int iy_nearest = static_cast<int>(::round(iy));
      int iz_nearest = static_cast<int>(::round(iz));
      // scatter grad_output back to the nearest input voxel
      scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
      scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
      for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
        // calculate and set grad_input
        safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest,
                    gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW);
      }
      // Nearest sampling is treated as having zero gradient w.r.t. the grid.
      // assuming grad_grid is contiguous
      // thus we can
      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
      //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
      scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
      gGrid_ptr_NDHW[0] = static_cast<scalar_t>(0);
      gGrid_ptr_NDHW[1] = static_cast<scalar_t>(0);
      gGrid_ptr_NDHW[2] = static_cast<scalar_t>(0);
    }
  }
}
} // namespace
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
//
// CUDA forward for 2-D grid sampling.
// Returns a tensor of shape (N, C, H_out, W_out) where N = input.size(0),
// C = input.size(1) and H_out / W_out = grid.size(1) / grid.size(2).
// One logical kernel thread is launched per output location
// (N * H_out * W_out); each handles all C channels.
//
// grid_sampler_2d_kernel does all offset arithmetic in 32-bit `int`
// (TensorInfo<scalar_t, int>), so tensors whose offsets would not fit are
// rejected up front instead of overflowing silently. Without this guard an
// oversized launch count wraps when narrowed to int, which can skip the
// launch entirely and return the uninitialized `output`.
Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  TORCH_CHECK(canUse32BitIndexMath(input) && canUse32BitIndexMath(grid),
              "grid_sampler_2d_cuda: input and grid must be indexable with "
              "32-bit integer offsets");
  auto N = input.size(0);
  auto H = grid.size(1);
  auto W = grid.size(2);
  auto output = at::empty({N, input.size(1), H, W}, input.options());
  TORCH_CHECK(canUse32BitIndexMath(output),
              "grid_sampler_2d_cuda: output is too large to be indexed with "
              "32-bit integer offsets");
  // Assuming grid is (N, H, W, 2) as validated by the caller (see the NOTE
  // above), N * H * W is bounded by grid.numel() and therefore fits in int
  // once the checks above pass.
  int count = static_cast<int>(N * H * W);
  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] {
      grid_sampler_2d_kernel<scalar_t>
        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
          count,
          getTensorInfo<scalar_t, int>(input),
          getTensorInfo<scalar_t, int>(grid),
          getTensorInfo<scalar_t, int>(output),
          static_cast<GridSamplerInterpolation>(interpolation_mode),
          static_cast<GridSamplerPadding>(padding_mode),
          align_corners);
    });
    // Surface launch-configuration errors here instead of at a later,
    // unrelated CUDA call.
    AT_CUDA_CHECK(cudaGetLastError());
  }
  return output;
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
//
// CUDA forward for 3-D (volumetric) grid sampling.
// Returns a tensor of shape (N, C, D_out, H_out, W_out) where N = input.size(0),
// C = input.size(1) and D_out / H_out / W_out = grid.size(1..3).
// One logical kernel thread is launched per output location
// (N * D_out * H_out * W_out); each handles all C channels.
//
// grid_sampler_3d_kernel does all offset arithmetic in 32-bit `int`, so
// tensors whose offsets would not fit are rejected up front rather than
// overflowing (a wrapped launch count could otherwise skip the launch and
// return the uninitialized `output`).
Tensor grid_sampler_3d_cuda(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode,
                            bool align_corners) {
  TORCH_CHECK(canUse32BitIndexMath(input) && canUse32BitIndexMath(grid),
              "grid_sampler_3d_cuda: input and grid must be indexable with "
              "32-bit integer offsets");
  auto N = input.size(0);
  auto D = grid.size(1);
  auto H = grid.size(2);
  auto W = grid.size(3);
  auto output = at::empty({N, input.size(1), D, H, W}, input.options());
  TORCH_CHECK(canUse32BitIndexMath(output),
              "grid_sampler_3d_cuda: output is too large to be indexed with "
              "32-bit integer offsets");
  // Assuming grid is (N, D, H, W, 3) as validated by the caller (see the NOTE
  // above), N * D * H * W is bounded by grid.numel() and fits in int.
  int count = static_cast<int>(N * D * H * W);
  if (count > 0) {
    // The dispatch name appears in dtype error messages, so it must name
    // this function (it previously said "grid_sampler_2d_cuda").
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_cuda", [&] {
      grid_sampler_3d_kernel<scalar_t>
        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
          count,
          getTensorInfo<scalar_t, int>(input),
          getTensorInfo<scalar_t, int>(grid),
          getTensorInfo<scalar_t, int>(output),
          static_cast<GridSamplerInterpolation>(interpolation_mode),
          static_cast<GridSamplerPadding>(padding_mode),
          align_corners);
    });
    // Surface launch-configuration errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  }
  return output;
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
//
// CUDA backward for 2-D grid sampling.
// Returns (grad_input, grad_grid):
//   * grad_input has input's shape; the kernel accumulates into it, so it is
//     allocated zero-filled.
//   * grad_grid has grid's shape; the kernel writes it for every sampled
//     location, so it is allocated uninitialized.
// One logical kernel thread is launched per output location (N * H * W).
// The kernel uses 32-bit `int` offsets, so oversized tensors are rejected.
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
                              const Tensor& grid, int64_t interpolation_mode,
                              int64_t padding_mode, bool align_corners) {
  TORCH_CHECK(canUse32BitIndexMath(grad_output) && canUse32BitIndexMath(input) &&
                  canUse32BitIndexMath(grid),
              "grid_sampler_2d_backward_cuda: grad_output, input and grid must be "
              "indexable with 32-bit integer offsets");
  auto N = input.size(0);
  auto H = grid.size(1);
  auto W = grid.size(2);
  auto grad_input = at::zeros_like(input);
  // grid_sampler_2d_backward_kernel addresses grad_grid as
  // `data + index * stride(2)`, which is only correct for a contiguous
  // tensor. Allocate it contiguously instead of via empty_like(grid), which
  // (depending on the PyTorch version) may preserve a non-contiguous layout.
  auto grad_grid = at::empty(grid.sizes(), grid.options());
  TORCH_CHECK(canUse32BitIndexMath(grad_input) && canUse32BitIndexMath(grad_grid),
              "grid_sampler_2d_backward_cuda: gradients are too large to be "
              "indexed with 32-bit integer offsets");
  // Assuming grid is (N, H, W, 2) as validated by the caller, N * H * W is
  // bounded by grid.numel() and fits in int.
  int count = static_cast<int>(N * H * W);
  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
      grid_sampler_2d_backward_kernel<scalar_t>
        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
          count,
          getTensorInfo<scalar_t, int>(grad_output),
          getTensorInfo<scalar_t, int>(input),
          getTensorInfo<scalar_t, int>(grid),
          getTensorInfo<scalar_t, int>(grad_input),
          getTensorInfo<scalar_t, int>(grad_grid),
          static_cast<GridSamplerInterpolation>(interpolation_mode),
          static_cast<GridSamplerPadding>(padding_mode),
          align_corners);
    });
    // Surface launch-configuration errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  }
  return std::make_tuple(grad_input, grad_grid);
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
//
// CUDA backward for 3-D (volumetric) grid sampling.
// Returns (grad_input, grad_grid):
//   * grad_input has input's shape; the kernel accumulates into it, so it is
//     allocated zero-filled.
//   * grad_grid has grid's shape; the kernel writes it for every sampled
//     location, so it is allocated uninitialized.
// One logical kernel thread is launched per output location (N * D * H * W).
// The kernel uses 32-bit `int` offsets, so oversized tensors are rejected.
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
                              const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
                              bool align_corners) {
  TORCH_CHECK(canUse32BitIndexMath(grad_output) && canUse32BitIndexMath(input) &&
                  canUse32BitIndexMath(grid),
              "grid_sampler_3d_backward_cuda: grad_output, input and grid must be "
              "indexable with 32-bit integer offsets");
  auto N = input.size(0);
  auto D = grid.size(1);
  auto H = grid.size(2);
  auto W = grid.size(3);
  auto grad_input = at::zeros_like(input);
  // grid_sampler_3d_backward_kernel addresses grad_grid as
  // `data + index * stride(3)`, which is only correct for a contiguous
  // tensor. Allocate it contiguously instead of via empty_like(grid), which
  // (depending on the PyTorch version) may preserve a non-contiguous layout.
  auto grad_grid = at::empty(grid.sizes(), grid.options());
  TORCH_CHECK(canUse32BitIndexMath(grad_input) && canUse32BitIndexMath(grad_grid),
              "grid_sampler_3d_backward_cuda: gradients are too large to be "
              "indexed with 32-bit integer offsets");
  // Assuming grid is (N, D, H, W, 3) as validated by the caller,
  // N * D * H * W is bounded by grid.numel() and fits in int.
  int count = static_cast<int>(N * D * H * W);
  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
      grid_sampler_3d_backward_kernel<scalar_t>
        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
          count,
          getTensorInfo<scalar_t, int>(grad_output),
          getTensorInfo<scalar_t, int>(input),
          getTensorInfo<scalar_t, int>(grid),
          getTensorInfo<scalar_t, int>(grad_input),
          getTensorInfo<scalar_t, int>(grad_grid),
          static_cast<GridSamplerInterpolation>(interpolation_mode),
          static_cast<GridSamplerPadding>(padding_mode),
          align_corners);
    });
    // Surface launch-configuration errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  }
  return std::make_tuple(grad_input, grad_grid);
}
}} // namespace at::native