-
Notifications
You must be signed in to change notification settings - Fork 8
/
BlockSparse.cu
652 lines (532 loc) · 27.5 KB
/
BlockSparse.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
#include "utils.h"
// Threads per CUDA block for every kernel launched in this file; also the size of the
// per-block __shared__ reduction buffers. The tree reduction in
// cunnx_BlockSparse_updateGradOutput_kernel requires this to be a power of two.
#define BLOCKSPARSE_THREADS 32
// Capacity (in floats) of the __shared__ gradOutput staging buffer in
// cunnx_BlockSparse_accGradParameters_kernel; the host rejects outputSize above this.
#define BLOCKSPARSE_MAXOUTPUTBLOCKSIZE 512
// Number of CUDA streams used round-robin by the per-block cublasSgemv ("streamed") path.
#define BLOCKSPARSE_STREAMS 8
/*
 * Forward-pass reduction kernel for BlockSparse.
 *
 * Launch layout (see cunnx_BlockSparse_updateOutput below): one CUDA block per
 * example k (blockIdx.x); threads stride over the outputSize elements, so any
 * blockDim.x works.
 *
 * Arguments are indexed as dense row-major arrays (assumes contiguous tensors —
 * TODO confirm at call sites):
 *   output       : batch x outputWindowSize x outputSize (written)
 *   input        : batch x inputWindowSize x outputWindowSize x outputSize, the
 *                  per-(input slot l, output slot m) block products
 *   outputIndice : batch x outputWindowSize, 1-based output block indices stored
 *                  as 64-bit integers (the data of a torch.CudaLongTensor). They
 *                  must be read as long: reinterpreting them as float yields
 *                  garbage indices and out-of-bounds bias reads.
 *   outputScale  : batch x outputWindowSize
 *   bias         : nOutputBlock x outputSize
 *   nOutputBlock : not used by the kernel
 *
 * For each output slot m of example k, with o = outputIndice[k][m] - 1:
 *   output[k][m][j] = outputScale[k][m] * (bias[o][j] + sum_l input[k][l][m][j])
 */
__global__ void cunnx_BlockSparse_updateOutput_kernel(
  float *output, const float *input, const long *outputIndice,
  const float *outputScale, const float *bias,
  int outputSize, int nOutputBlock,
  int inputWindowSize, int outputWindowSize)
{
  int tx = threadIdx.x;
  int i_step = blockDim.x;
  int k = blockIdx.x;
  float *output_k = output + k*outputWindowSize*outputSize;
  const float *input_k = input + k*inputWindowSize*outputWindowSize*outputSize;
  const long *outputIndice_k = outputIndice + k*outputWindowSize;
  const float *outputScale_k = outputScale + k*outputWindowSize;
  for (int m=0; m<outputWindowSize; m++)
  {
    int outputIdx = (int)outputIndice_k[m] - 1; // indices are 1-based
    float scale_m = outputScale_k[m];
    for (int j=tx; j<outputSize; j+=i_step)
    {
      // each thread owns element j; a register accumulator is sufficient
      float acc = bias[outputIdx*outputSize + j];
      for (int l=0; l<inputWindowSize; l++)
        acc += input_k[l*outputWindowSize*outputSize + m*outputSize + j];
      output_k[m*outputSize + j] = scale_m*acc;
    }
  }
}

/*
 * Lua binding: BlockSparse_updateOutput(self, input, inputIndice, outputIndice,
 *                                        inputScale, outputScale)
 *
 * Stack arguments (1-based): 1 = module table `self`, 2 = input (torch.CudaTensor),
 * 3 = inputIndice (torch.CudaLongTensor), 4 = outputIndice (torch.CudaLongTensor),
 * 5 = inputScale (torch.CudaTensor, only dimension-checked here),
 * 6 = outputScale (torch.CudaTensor).
 *
 * The indices are copied to self.inputIndiceHost / self.outputIndiceHost. Then,
 * for every example and every (input slot l, output slot m) pair, the selected
 * weight block (outputSize x inputSize) is multiplied by the l-th input slice
 * into self.outputBatched. This uses either:
 *   - one cublasSgemv per pair, spread round-robin over BLOCKSPARSE_STREAMS
 *     streams, when sqrt(inputSize*outputSize) > self.batchedGemmMax, or
 *   - a single cublasSgemmBatched call otherwise.
 * The kernel above then sums over l, adds the bias and applies outputScale into
 * self._output.
 *
 * Returns 1 value: self._output. It is the last value pushed on the Lua stack by
 * luaT_getfieldcheckudata, so the order of those calls matters.
 */
static int cunnx_BlockSparse_updateOutput(lua_State *L)
{
  /* input, inputIndice, outputIndice, inputScale, outputScale */
  THCState *state = getCutorchState(L);
  // batchSize x inputWindowSize x inputSize
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  // batchSize x inputWindowSize
  THCudaLongTensor *inputIndice = (THCudaLongTensor*)luaT_checkudata(L, 3, "torch.CudaLongTensor");
  THCudaTensor *inputScale = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor");
  // batchSize x outputWindowSize
  THCudaLongTensor *outputIndice = (THCudaLongTensor*)luaT_checkudata(L, 4, "torch.CudaLongTensor");
  THCudaTensor *outputScale = (THCudaTensor*)luaT_checkudata(L, 6, "torch.CudaTensor");
  int batchSize = luaT_getfieldcheckint(L, 1, "batchSize");
  int inputSize = luaT_getfieldcheckint(L, 1, "inputSize");
  int outputSize = luaT_getfieldcheckint(L, 1, "outputSize");
  int inputWindowSize = luaT_getfieldcheckint(L, 1, "inputWindowSize");
  int outputWindowSize = luaT_getfieldcheckint(L, 1, "outputWindowSize");
  int nInputBlock = luaT_getfieldcheckint(L, 1, "nInputBlock");
  int nOutputBlock = luaT_getfieldcheckint(L, 1, "nOutputBlock");
  int batchedGemmMax = luaT_getfieldcheckint(L, 1, "batchedGemmMax");
  // widen before multiplying so the product cannot overflow int
  long nBatched = (long)batchSize*inputWindowSize*outputWindowSize;
  THLongTensor *inputIndiceHost = (THLongTensor*)luaT_getfieldcheckudata(L, 1, "inputIndiceHost", "torch.LongTensor");
  THLongTensor *outputIndiceHost = (THLongTensor*)luaT_getfieldcheckudata(L, 1, "outputIndiceHost", "torch.LongTensor");
  // nOutputBlock x nInputBlock x outputSize x inputSize
  THCudaTensor *weight = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "weight", "torch.CudaTensor");
  // nOutputBlock x outputSize
  THCudaTensor *bias = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "bias", "torch.CudaTensor");
  // batchSize x inputWindowSize x outputWindowSize x outputSize
  THCudaTensor *outputBatched = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "outputBatched", "torch.CudaTensor");
  // batchSize x outputWindowSize x outputSize
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "_output", "torch.CudaTensor");
  cublasStatus_t stat;
  cublasHandle_t handle;
  float alpha = 1;
  float beta = 0;
  if (nInputBlock > 1)
  {
    luaL_argcheck(L, input->nDimension == 3, 2, "3D(batch mode) tensor expected");
    luaL_argcheck(L, input->size[2] == inputSize, 2, "invalid input size");
  }
  else
  {
    luaL_argcheck(L, input->nDimension == 2, 2, "2D(batch mode) tensor expected");
    luaL_argcheck(L, input->size[1] == inputSize, 2, "invalid input size");
  }
  luaL_argcheck(L, inputIndice->nDimension == 2, 3, "2D(batch mode) tensor expected");
  luaL_argcheck(L, outputIndice->nDimension == 2, 4, "2D(batch mode) tensor expected");
  luaL_argcheck(L, inputScale->nDimension == 2, 5, "2D(batch mode) tensor expected");
  luaL_argcheck(L, outputScale->nDimension == 2, 6, "2D(batch mode) tensor expected");
  luaL_argcheck(L, THCudaTensor_isContiguous(state, input), 2, "Expecting contiguous input");
  THCudaTensor_resize4d(state, outputBatched, batchSize, inputWindowSize, outputWindowSize, outputSize);
  THLongTensor_resize2d(inputIndiceHost, batchSize, inputWindowSize);
  THLongTensor_resize2d(outputIndiceHost, batchSize, outputWindowSize);
  // host copies are needed to compute per-pair weight pointers below
  THLongTensor_copyCuda(state, inputIndiceHost, inputIndice);
  THLongTensor_copyCuda(state, outputIndiceHost, outputIndice);
  stat = cublasCreate(&handle);
  if (stat != CUBLAS_STATUS_SUCCESS)
    THError("CUBLAS initialization failed");
  if ( nOutputBlock > 1 )
    THCudaTensor_resize3d(state, output, batchSize, outputWindowSize, outputSize);
  else
    THCudaTensor_resize2d(state, output, batchSize, outputSize);
  /* streamed or batched */
  // compute in double so inputSize*outputSize cannot overflow int
  if (sqrt((double)inputSize*outputSize) > batchedGemmMax)
  {
    cudaStream_t streams[BLOCKSPARSE_STREAMS];
    for (int i=0; i<BLOCKSPARSE_STREAMS; i++)
    {
      if (cudaStreamCreate(&streams[i]) != cudaSuccess)
        THError("error initializing stream");
    }
    cudaDeviceSynchronize();
    long batchedIdx = 0;
    for (int i=0; i<batchSize; i++)
    {
      float *inputPtr = THCudaTensor_data(state, input)+i*input->stride[0];
      float *outputPtr = THCudaTensor_data(state, outputBatched)+i*outputBatched->stride[0];
      long *inputIdxPtr = THLongTensor_data(inputIndiceHost)+i*inputIndiceHost->stride[0];
      long *outputIdxPtr = THLongTensor_data(outputIndiceHost)+i*outputIndiceHost->stride[0];
      for (int l=0; l<inputWindowSize; l++)
      {
        for (int m=0; m<outputWindowSize; m++)
        {
          cublasSetStream(handle, streams[batchedIdx%BLOCKSPARSE_STREAMS]);
          // row-major outputSize x inputSize block == column-major inputSize x outputSize;
          // OP_T therefore yields block * input
          stat = cublasSgemv(handle, CUBLAS_OP_T, inputSize, outputSize,
                             &alpha, (const float*)THCudaTensor_data(state, weight)+(inputIdxPtr[l]-1)*weight->stride[1] + (outputIdxPtr[m]-1)*weight->stride[0], inputSize,
                             (const float*)inputPtr, 1,
                             &beta, outputPtr, 1);
          if (stat != CUBLAS_STATUS_SUCCESS)
            THError("cublasSgemv failed");
          outputPtr += outputBatched->stride[2];
          batchedIdx++;
        }
        inputPtr += input->stride[1];
      }
    }
    cublasSetStream(handle, NULL);
    cudaDeviceSynchronize();
    for (int i=0; i<BLOCKSPARSE_STREAMS; i++)
    {
      if (cudaStreamDestroy(streams[i]) != cudaSuccess)
        THError("error destroying stream");
    }
  }
  else
  {
    THCharTensor *inputHost = (THCharTensor*)luaT_getfieldcheckudata(L, 1, "inputHost", "torch.CharTensor");
    THCharTensor *weightHost = (THCharTensor*)luaT_getfieldcheckudata(L, 1, "weightHost", "torch.CharTensor");
    THCharTensor *outputHost = (THCharTensor*)luaT_getfieldcheckudata(L, 1, "outputHost", "torch.CharTensor");
    THCudaTensor *inputCuda = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "inputCuda", "torch.CudaTensor");
    THCudaTensor *weightCuda = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "weightCuda", "torch.CudaTensor");
    THCudaTensor *outputCuda = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "outputCuda", "torch.CudaTensor");
    // put output back on top of the stack
    output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "_output", "torch.CudaTensor");
    cublasSetStream(handle, NULL);
    // char / float tensors are used as raw byte buffers for arrays of pointers
    THCharTensor_resize1d(inputHost, nBatched*sizeof(float*));
    THCharTensor_resize1d(weightHost, nBatched*sizeof(float*));
    THCharTensor_resize1d(outputHost, nBatched*sizeof(float*));
    THCudaTensor_resize1d(state, inputCuda, nBatched*sizeof(float*)/sizeof(float));
    THCudaTensor_resize1d(state, weightCuda, nBatched*sizeof(float*)/sizeof(float));
    THCudaTensor_resize1d(state, outputCuda, nBatched*sizeof(float*)/sizeof(float));
    const float **inputB = (const float **)THCharTensor_data(inputHost);
    const float **weightB = (const float **)THCharTensor_data(weightHost);
    float **outputB = (float **)THCharTensor_data(outputHost);
    const float **inputB_d = (const float **)THCudaTensor_data(state, inputCuda);
    const float **weightB_d = (const float **)THCudaTensor_data(state, weightCuda);
    float **outputB_d = (float **)THCudaTensor_data(state, outputCuda);
    long batchedIdx = 0;
    for (int i=0; i<batchSize; i++)
    {
      float *inputPtr = THCudaTensor_data(state, input)+i*input->stride[0];
      float *outputPtr = THCudaTensor_data(state, outputBatched)+i*outputBatched->stride[0];
      long *inputIdxPtr = THLongTensor_data(inputIndiceHost)+i*inputIndiceHost->stride[0];
      long *outputIdxPtr = THLongTensor_data(outputIndiceHost)+i*outputIndiceHost->stride[0];
      for (int l=0; l<inputWindowSize; l++)
      {
        for (int m=0; m<outputWindowSize; m++)
        {
          inputB[batchedIdx] = inputPtr;
          weightB[batchedIdx] = THCudaTensor_data(state, weight) + (outputIdxPtr[m]-1)*weight->stride[0] + (inputIdxPtr[l]-1)*weight->stride[1];
          outputB[batchedIdx] = outputPtr;
          outputPtr += outputBatched->stride[2];
          batchedIdx++;
        }
        inputPtr += input->stride[1];
      }
    }
    if(cudaMemcpy(inputB_d, inputB, sizeof(float*) * nBatched, cudaMemcpyHostToDevice) != cudaSuccess)
      THError("cudaMemcpy failed");
    if(cudaMemcpy(weightB_d, weightB, sizeof(float*) * nBatched, cudaMemcpyHostToDevice) != cudaSuccess)
      THError("cudaMemcpy failed");
    if(cudaMemcpy(outputB_d, outputB, sizeof(float*) * nBatched, cudaMemcpyHostToDevice) != cudaSuccess)
      THError("cudaMemcpy failed");
    stat = cublasSgemmBatched(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                              outputSize, 1, inputSize,
                              &alpha, weightB_d, inputSize,
                              inputB_d, inputSize,
                              &beta, outputB_d, outputSize,
                              nBatched);
    if (stat != CUBLAS_STATUS_SUCCESS)
      THError("cublasSgemmBatched failed");
  }
  /* call cudakernel */
  dim3 blocks(input->size[0]); // each cuda-block is an example
  dim3 threads(BLOCKSPARSE_THREADS);
  // outputIndice holds 64-bit integers; pass them with their real element type
  cunnx_BlockSparse_updateOutput_kernel<<<blocks,threads>>>(
    THCudaTensor_data(state, output), THCudaTensor_data(state, outputBatched),
    THCudaLongTensor_data(state, outputIndice), THCudaTensor_data(state, outputScale),
    THCudaTensor_data(state, bias), outputSize, nOutputBlock,
    inputWindowSize, outputWindowSize
  );
  cublasDestroy(handle);
  cudaError errcode = cudaGetLastError();
  if(errcode != cudaSuccess)
    THError(cudaGetErrorString(errcode));
  return 1;
}
/*
 * Backward helper for the per-slot output scaling.
 *
 * Launch layout: one CUDA block per example k (blockIdx.x) with
 * BLOCKSPARSE_THREADS threads. The tree reduction below requires blockDim.x to
 * be a power of two and no larger than BLOCKSPARSE_THREADS (the size of the
 * __shared__ buffer).
 *
 * Tensors are indexed as dense row-major arrays (assumes contiguous storage —
 * TODO confirm at call sites):
 *   _gradOutput, gradOutput, output : batch x outputWindowSize x outputSize
 *   gradOutputScale, outputScale    : batch x outputWindowSize
 *
 * For every output slot m of example k:
 *   _gradOutput[k][m][j]  = gradOutput[k][m][j] * outputScale[k][m]
 *   gradOutputScale[k][m] = sum_j(output[k][m][j] * gradOutput[k][m][j])
 *                           / (outputScale[k][m] + 1e-8)
 * Dividing by the scale undoes the scaling already applied to `output`,
 * assuming `output` is the scaled forward output. The 1e-8 term guards
 * against a zero scale.
 */
__global__ void cunnx_BlockSparse_updateGradOutput_kernel(
  float *_gradOutput, float* gradOutputScale, const float *gradOutput,
  const float *output, const float *outputScale,
  int outputWindowSize, int outputSize)
{
  __shared__ float buffer[BLOCKSPARSE_THREADS];
  int tx = threadIdx.x;
  int i_step = blockDim.x;
  int k = blockIdx.x;
  float *_gradOutput_k = _gradOutput + k*outputWindowSize*outputSize;
  float *gradOutputScale_k = gradOutputScale + k*outputWindowSize;
  const float *gradOutput_k = gradOutput + k*outputWindowSize*outputSize;
  const float *output_k = output + k*outputWindowSize*outputSize;
  const float *outputScale_k = outputScale + k*outputWindowSize;
  // get gradients for outputScale (to be backwarded to a Gater)
  for (int m=0; m<outputWindowSize; m++)
  {
    float outputScale = outputScale_k[m];
    float *_blockGradOutput = _gradOutput_k + m*outputSize;
    const float *blockGradOutput = gradOutput_k + m*outputSize;
    const float *blockOutput = output_k + m*outputSize;
    // per-thread partial dot product kept in a register, published once
    float partial = 0.0f;
    for (int j=tx; j<outputSize; j+=i_step)
    {
      const float grad = blockGradOutput[j];
      partial += blockOutput[j]*grad;
      _blockGradOutput[j] = grad*outputScale;
    }
    buffer[tx] = partial;
    // add (reduce); the barrier at the top of each step orders the previous writes
    for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
    {
      __syncthreads();
      if (tx < stride)
        buffer[tx] += buffer[tx+stride];
    }
    if (tx == 0)
      gradOutputScale_k[m] = buffer[0]/(outputScale+1e-8f); // float literal: no double promotion
    // Without this barrier a thread could start slot m+1 and overwrite its
    // buffer entry while thread 0 is still reading buffer[1] in the last step.
    __syncthreads();
  }
}
/*
 * Lua binding: BlockSparse_updateGradInput(self, input, inputIndice,
 *   outputIndice, inputScale, outputScale, gradOutput)
 *
 * Stack arguments (1-based): 1 = module table `self`, 2 = input,
 * 3 = inputIndice, 4 = outputIndice, 5 = inputScale, 6 = outputScale,
 * 7 = gradOutput. All are checked as torch.CudaTensor here; the index and
 * scale tensors are only dimension-checked.
 *
 * NOTE(review): cunnx_BlockSparse_updateOutput checks arguments 3 and 4 as
 * torch.CudaLongTensor. If the caller passes the same index tensors to both
 * functions, one of the two type checks fails — confirm the intended type.
 *
 * Steps:
 *  1. updateGradOutput kernel: self._gradOutput = gradOutput * outputScale and
 *     self.gradOutputScale = gradient w.r.t. outputScale.
 *  2. For every example and every (output slot m, input slot l) pair, compute
 *     W^T * _gradOutput[m], where W is the selected outputSize x inputSize weight
 *     block. The result goes into self.gradInputBatched
 *     (batchSize x outputWindowSize x inputWindowSize x inputSize). This uses
 *     either per-pair cublasSgemv calls spread over BLOCKSPARSE_STREAMS streams
 *     (when sqrt(inputSize*outputSize) > self.batchedGemmMax) or one
 *     cublasSgemmBatched call.
 *  3. Sum gradInputBatched over the outputWindowSize dimension (dim 1) into
 *     self._gradInput and reshape it like input.
 *
 * The host index copies self.inputIndiceHost / self.outputIndiceHost are not
 * refreshed here — presumably they were filled by the preceding updateOutput.
 *
 * Returns 2 values: self._gradInput and self.gradOutputScale. They are the last
 * two values pushed by luaT_getfieldcheckudata, so the order of those calls
 * matters.
 */
static int cunnx_BlockSparse_updateGradInput(lua_State *L)
{
  /* input, inputIndice, outputIndice, inputScale, outputScale, gradOutput */
  THCState *state = getCutorchState(L);
  // batchSize x inputWindowSize x inputSize
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  // batchSize x inputWindowSize
  THCudaTensor *inputIndice = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
  THCudaTensor *inputScale = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor");
  // batchSize x outputWindowSize
  THCudaTensor *outputIndice = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor");
  THCudaTensor *outputScale = (THCudaTensor*)luaT_checkudata(L, 6, "torch.CudaTensor");
  // batchSize x outputWindowSize x outputSize
  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 7, "torch.CudaTensor");
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "_output", "torch.CudaTensor");
  int batchSize = luaT_getfieldcheckint(L, 1, "batchSize");
  int inputSize = luaT_getfieldcheckint(L, 1, "inputSize");
  int outputSize = luaT_getfieldcheckint(L, 1, "outputSize");
  int inputWindowSize = luaT_getfieldcheckint(L, 1, "inputWindowSize");
  int outputWindowSize = luaT_getfieldcheckint(L, 1, "outputWindowSize");
  int nInputBlock = luaT_getfieldcheckint(L, 1, "nInputBlock");
  int nOutputBlock = luaT_getfieldcheckint(L, 1, "nOutputBlock");
  int batchedGemmMax = luaT_getfieldcheckint(L, 1, "batchedGemmMax");
  // widen before multiplying so the product cannot overflow int
  long nBatched = (long)batchSize*inputWindowSize*outputWindowSize;
  THLongTensor *inputIndiceHost = (THLongTensor*)luaT_getfieldcheckudata(L, 1, "inputIndiceHost", "torch.LongTensor");
  THLongTensor *outputIndiceHost = (THLongTensor*)luaT_getfieldcheckudata(L, 1, "outputIndiceHost", "torch.LongTensor");
  THCudaTensor *weight = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "weight", "torch.CudaTensor");
  THCudaTensor *gradInputBatched = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInputBatched", "torch.CudaTensor");
  THCudaTensor *_gradOutput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "_gradOutput", "torch.CudaTensor");
  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "_gradInput", "torch.CudaTensor");
  THCudaTensor *gradOutputScale = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradOutputScale", "torch.CudaTensor");
  cublasStatus_t stat;
  cublasHandle_t handle;
  float alpha = 1;
  float beta = 0;
  if (nInputBlock > 1)
  {
    luaL_argcheck(L, input->nDimension == 3, 2, "3D(batch mode) tensor expected");
    luaL_argcheck(L, input->size[2] == inputSize, 2, "invalid input size");
  }
  else
  {
    luaL_argcheck(L, input->nDimension == 2, 2, "2D(batch mode) tensor expected");
    luaL_argcheck(L, input->size[1] == inputSize, 2, "invalid input size");
  }
  luaL_argcheck(L, inputIndice->nDimension == 2, 3, "2D(batch mode) tensor expected");
  luaL_argcheck(L, outputIndice->nDimension == 2, 4, "2D(batch mode) tensor expected");
  luaL_argcheck(L, inputScale->nDimension == 2, 5, "2D(batch mode) tensor expected");
  luaL_argcheck(L, outputScale->nDimension == 2, 6, "2D(batch mode) tensor expected");
  luaL_argcheck(L, THCudaTensor_isContiguous(state, input), 2, "Expecting contiguous input");
  THCudaTensor_resizeAs(state, _gradOutput, gradOutput);
  THCudaTensor_resizeAs(state, gradOutputScale, outputScale);
  THCudaTensor_resize4d(state, gradInputBatched, batchSize, outputWindowSize, inputWindowSize, inputSize);
  /* call cudakernel */
  dim3 blocks(input->size[0]); // each cuda-block is an example
  dim3 threads(BLOCKSPARSE_THREADS);
  cunnx_BlockSparse_updateGradOutput_kernel<<<blocks,threads>>>(
    THCudaTensor_data(state, _gradOutput), THCudaTensor_data(state, gradOutputScale),
    THCudaTensor_data(state, gradOutput), THCudaTensor_data(state, output),
    THCudaTensor_data(state, outputScale), outputWindowSize, outputSize
  );
  cudaError errcode = cudaGetLastError();
  if(errcode != cudaSuccess)
    THError(cudaGetErrorString(errcode));
  stat = cublasCreate(&handle);
  if (stat != CUBLAS_STATUS_SUCCESS)
    THError("CUBLAS initialization failed");
  /* streamed or batched */
  // compute in double so inputSize*outputSize cannot overflow int
  if (sqrt((double)inputSize*outputSize) > batchedGemmMax)
  {
    cudaStream_t streams[BLOCKSPARSE_STREAMS];
    for (int i=0; i<BLOCKSPARSE_STREAMS; i++)
    {
      if (cudaStreamCreate(&streams[i]) != cudaSuccess)
        THError("error initializing stream");
    }
    cudaDeviceSynchronize();
    long batchedIdx = 0;
    for (int i=0; i<batchSize; i++)
    {
      float *gradOutputPtr = THCudaTensor_data(state, _gradOutput)+i*_gradOutput->stride[0];
      float *gradInputPtr = THCudaTensor_data(state, gradInputBatched)+i*gradInputBatched->stride[0];
      long *inputIdxPtr = THLongTensor_data(inputIndiceHost)+i*inputIndiceHost->stride[0];
      long *outputIdxPtr = THLongTensor_data(outputIndiceHost)+i*outputIndiceHost->stride[0];
      for (int m=0; m<outputWindowSize; m++)
      {
        for (int l=0; l<inputWindowSize; l++)
        {
          cublasSetStream(handle, streams[batchedIdx%BLOCKSPARSE_STREAMS]);
          // column-major view of the row-major block is W^T (inputSize x outputSize)
          stat = cublasSgemv(handle, CUBLAS_OP_N, inputSize, outputSize,
                             &alpha, (const float*)THCudaTensor_data(state, weight)+(outputIdxPtr[m]-1)*weight->stride[0]+(inputIdxPtr[l]-1)*weight->stride[1], inputSize,
                             (const float*)gradOutputPtr, 1,
                             &beta, gradInputPtr, 1);
          if (stat != CUBLAS_STATUS_SUCCESS)
            THError("cublasSgemv failed");
          gradInputPtr += gradInputBatched->stride[2];
          batchedIdx++;
        }
        gradOutputPtr += _gradOutput->stride[1];
      }
    }
    cublasSetStream(handle, NULL);
    cudaDeviceSynchronize();
    for (int i=0; i<BLOCKSPARSE_STREAMS; i++)
    {
      if (cudaStreamDestroy(streams[i]) != cudaSuccess)
        THError("error destroying stream");
    }
  }
  else
  {
    THCharTensor *inputHost = (THCharTensor*)luaT_getfieldcheckudata(L, 1, "inputHost", "torch.CharTensor");
    THCharTensor *weightHost = (THCharTensor*)luaT_getfieldcheckudata(L, 1, "weightHost", "torch.CharTensor");
    THCharTensor *outputHost = (THCharTensor*)luaT_getfieldcheckudata(L, 1, "outputHost", "torch.CharTensor");
    THCudaTensor *inputCuda = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "inputCuda", "torch.CudaTensor");
    THCudaTensor *weightCuda = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "weightCuda", "torch.CudaTensor");
    THCudaTensor *outputCuda = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "outputCuda", "torch.CudaTensor");
    // put gradInput back on top of the stack
    gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "_gradInput", "torch.CudaTensor");
    gradOutputScale = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradOutputScale", "torch.CudaTensor");
    cublasSetStream(handle, NULL);
    // char / float tensors are used as raw byte buffers for arrays of pointers
    THCharTensor_resize1d(inputHost, nBatched*sizeof(float*));
    THCharTensor_resize1d(weightHost, nBatched*sizeof(float*));
    THCharTensor_resize1d(outputHost, nBatched*sizeof(float*));
    THCudaTensor_resize1d(state, inputCuda, nBatched*sizeof(float*)/sizeof(float));
    THCudaTensor_resize1d(state, weightCuda, nBatched*sizeof(float*)/sizeof(float));
    THCudaTensor_resize1d(state, outputCuda, nBatched*sizeof(float*)/sizeof(float));
    float **gradInputB = (float **)THCharTensor_data(inputHost);
    const float **weightB = (const float **)THCharTensor_data(weightHost);
    const float **gradOutputB = (const float **)THCharTensor_data(outputHost);
    float **gradInputB_d = (float **)THCudaTensor_data(state, inputCuda);
    const float **weightB_d = (const float **)THCudaTensor_data(state, weightCuda);
    const float **gradOutputB_d = (const float **)THCudaTensor_data(state, outputCuda);
    long batchedIdx = 0;
    for (int i=0; i<batchSize; i++)
    {
      float *gradOutputPtr = THCudaTensor_data(state, _gradOutput)+i*_gradOutput->stride[0];
      float *gradInputPtr = THCudaTensor_data(state, gradInputBatched)+i*gradInputBatched->stride[0];
      long *inputIdxPtr = THLongTensor_data(inputIndiceHost)+i*inputIndiceHost->stride[0];
      long *outputIdxPtr = THLongTensor_data(outputIndiceHost)+i*outputIndiceHost->stride[0];
      for (int m=0; m<outputWindowSize; m++)
      {
        for (int l=0; l<inputWindowSize; l++)
        {
          gradInputB[batchedIdx] = gradInputPtr;
          weightB[batchedIdx] = THCudaTensor_data(state, weight)+(outputIdxPtr[m]-1)*weight->stride[0]+(inputIdxPtr[l]-1)*weight->stride[1];
          gradOutputB[batchedIdx] = gradOutputPtr;
          gradInputPtr += gradInputBatched->stride[2];
          batchedIdx++;
        }
        gradOutputPtr += _gradOutput->stride[1];
      }
    }
    if(cudaMemcpy(gradInputB_d, gradInputB, sizeof(float*)*nBatched, cudaMemcpyHostToDevice) != cudaSuccess)
      THError("cudaMemcpy failed");
    if(cudaMemcpy(weightB_d, weightB, sizeof(float*)*nBatched, cudaMemcpyHostToDevice) != cudaSuccess)
      THError("cudaMemcpy failed");
    if(cudaMemcpy(gradOutputB_d, gradOutputB, sizeof(float*)*nBatched, cudaMemcpyHostToDevice) != cudaSuccess)
      THError("cudaMemcpy failed");
    stat = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                              inputSize, 1, outputSize,
                              &alpha, weightB_d, inputSize,
                              gradOutputB_d, outputSize,
                              &beta, gradInputB_d, inputSize,
                              nBatched);
    if (stat != CUBLAS_STATUS_SUCCESS)
      THError("cublasSgemmBatched failed");
  }
  cublasDestroy(handle);
  // gradInputBatched is batchSize x outputWindowSize x inputWindowSize x inputSize:
  // accumulate over the output window (dim 1, keepdim), not over the batch (dim 0),
  // so the element count matches input for the resizeAs below.
  THCudaTensor_sum(state, gradInput, gradInputBatched, 1, 1);
  THCudaTensor_resizeAs(state, gradInput, input);
  errcode = cudaGetLastError();
  if(errcode != cudaSuccess)
    THError(cudaGetErrorString(errcode));
  return 2;
}
/*
 * Accumulates weight and bias gradients for the blocks selected by each example.
 *
 * Launch layout: one CUDA block per example (blockIdx.x), lanes striding over
 * the input (weight columns) and output (bias) elements. The host wrapper
 * enforces outputSize <= BLOCKSPARSE_MAXOUTPUTBLOCKSIZE, because the scaled
 * gradOutput slice of one output slot is staged in __shared__ memory.
 *
 * Indices are 1-based and stored as floats. For output slot m (block
 * o = outputIndice[k][m]-1) and input slot l (block i = inputIndice[k][l]-1):
 *   gradWeight[o][i][r][c] += scale * gradOutput[k][m][r] * input[k][l][c]
 *   gradBias[o][r]         += scale * gradOutput[k][m][r]
 * Updates go through atomicAdd because CUDA blocks of different examples may
 * target the same gradWeight / gradBias entries.
 */
__global__ void cunnx_BlockSparse_accGradParameters_kernel(
  float *gradWeight, float* gradBias, float *gradOutput,
  float *input, float *inputIndice, float *outputIndice,
  int inputSize, int outputSize, int nInputBlock, int nOutputBlock,
  int inputWindowSize, int outputWindowSize, float scale)
{
  __shared__ float scaledGradOut[BLOCKSPARSE_MAXOUTPUTBLOCKSIZE];
  const int lane = threadIdx.x;
  const int nLanes = blockDim.x;
  const int example = blockIdx.x;
  const int weightBlockSize = outputSize*inputSize;

  const float *exInput = input + example*inputWindowSize*inputSize;
  const float *exGradOut = gradOutput + example*outputWindowSize*outputSize;
  const float *exInIdx = inputIndice + example*inputWindowSize;
  const float *exOutIdx = outputIndice + example*outputWindowSize;

  for (int slotOut = 0; slotOut < outputWindowSize; ++slotOut)
  {
    const int blockOut = (int)exOutIdx[slotOut] - 1;
    const float *gradOutSlice = exGradOut + slotOut*outputSize;
    float *biasSlice = gradBias + blockOut*outputSize;

    // stage scale * gradOutput for this output slot
    for (int r = lane; r < outputSize; r += nLanes)
      scaledGradOut[r] = gradOutSlice[r]*scale;
    // every lane reads every staged entry in the weight loop below
    __syncthreads();

    for (int slotIn = 0; slotIn < inputWindowSize; ++slotIn)
    {
      const int blockIn = (int)exInIdx[slotIn] - 1;
      const float *inSlice = exInput + slotIn*inputSize;
      float *weightSlice = gradWeight + blockOut*nInputBlock*weightBlockSize + blockIn*weightBlockSize;
      // rank-1 update: each lane owns weight column c
      for (int c = lane; c < inputSize; c += nLanes)
      {
        const float x = inSlice[c];
        for (int r = 0; r < outputSize; ++r)
          atomicAdd(weightSlice + r*inputSize + c, scaledGradOut[r]*x);
      }
    }
    // the next slot overwrites scaledGradOut, which other lanes may still be
    // reading in the weight loop above
    __syncthreads();

    // bias update: each lane only touches the entries it staged itself
    for (int r = lane; r < outputSize; r += nLanes)
      atomicAdd(biasSlice + r, scaledGradOut[r]);
  }
}
/*
 * Lua binding: BlockSparse_accGradParameters(self, input, inputIndice,
 *   outputIndice, inputScale, outputScale, gradOutput [, scale = 1])
 *
 * Stack arguments (1-based): 1 = module table `self`, 2 = input,
 * 3 = inputIndice, 4 = outputIndice, 5 = inputScale, 6 = outputScale,
 * 8 = optional scale. All tensors are checked as torch.CudaTensor. gradOutput
 * (argument 7) is not read; the kernel consumes self._gradOutput instead.
 *
 * Launches cunnx_BlockSparse_accGradParameters_kernel (one CUDA block per
 * example) to accumulate into self.gradWeight and self.gradBias.
 * Returns nothing to Lua.
 */
static int cunnx_BlockSparse_accGradParameters(lua_State *L)
{
  THCState *state = getCutorchState(L);

  // positional arguments (checked in this exact order)
  THCudaTensor *input        = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor"); // batchSize x inputWindowSize x inputSize
  THCudaTensor *inputIndice  = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor"); // batchSize x inputWindowSize
  THCudaTensor *inputScale   = (THCudaTensor*)luaT_checkudata(L, 5, "torch.CudaTensor");
  THCudaTensor *outputIndice = (THCudaTensor*)luaT_checkudata(L, 4, "torch.CudaTensor"); // batchSize x outputWindowSize
  THCudaTensor *outputScale  = (THCudaTensor*)luaT_checkudata(L, 6, "torch.CudaTensor");
  float scale = luaL_optnumber(L, 8, 1);

  // module fields
  int inputSize    = luaT_getfieldcheckint(L, 1, "inputSize");
  int outputSize   = luaT_getfieldcheckint(L, 1, "outputSize");
  int nInputBlock  = luaT_getfieldcheckint(L, 1, "nInputBlock");
  int nOutputBlock = luaT_getfieldcheckint(L, 1, "nOutputBlock");
  THCudaTensor *gradWeight  = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradWeight", "torch.CudaTensor");
  THCudaTensor *gradBias    = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradBias", "torch.CudaTensor");
  THCudaTensor *_gradOutput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "_gradOutput", "torch.CudaTensor");
  // fetched only to validate that the fields exist with the expected type
  THLongTensor *inputIndiceHost  = (THLongTensor*)luaT_getfieldcheckudata(L, 1, "inputIndiceHost", "torch.LongTensor");
  THLongTensor *outputIndiceHost = (THLongTensor*)luaT_getfieldcheckudata(L, 1, "outputIndiceHost", "torch.LongTensor");

  // input is 3D when there are several input blocks, otherwise 2D
  const int inputDim = (nInputBlock > 1) ? 3 : 2;
  luaL_argcheck(L, input->nDimension == inputDim, 2,
                inputDim == 3 ? "3D(batch mode) tensor expected" : "2D(batch mode) tensor expected");
  luaL_argcheck(L, input->size[inputDim-1] == inputSize, 2, "invalid input size");
  luaL_argcheck(L, inputIndice->nDimension == 2, 3, "2D(batch mode) tensor expected");
  luaL_argcheck(L, outputIndice->nDimension == 2, 4, "2D(batch mode) tensor expected");
  luaL_argcheck(L, inputScale->nDimension == 2, 5, "2D(batch mode) tensor expected");
  luaL_argcheck(L, outputScale->nDimension == 2, 6, "2D(batch mode) tensor expected");
  // the kernel stages one gradOutput slice in a fixed-size shared buffer
  luaL_argcheck(L, outputSize <= BLOCKSPARSE_MAXOUTPUTBLOCKSIZE, 1, "outputSize is too large");

  // one CUDA block per example
  dim3 grid(input->size[0]);
  dim3 block(BLOCKSPARSE_THREADS);
  cunnx_BlockSparse_accGradParameters_kernel<<<grid, block>>>(
    THCudaTensor_data(state, gradWeight), THCudaTensor_data(state, gradBias),
    THCudaTensor_data(state, _gradOutput), THCudaTensor_data(state, input),
    THCudaTensor_data(state, inputIndice), THCudaTensor_data(state, outputIndice),
    inputSize, outputSize, nInputBlock, nOutputBlock,
    inputIndice->size[1], outputIndice->size[1], scale
  );

  cudaError launchStatus = cudaGetLastError();
  if (launchStatus != cudaSuccess)
    THError(cudaGetErrorString(launchStatus));
  return 0;
}
/* Lua-visible BlockSparse entry points, registered by cunnx_BlockSparse_init via
 * luaT_registeratname under the name "nn". The {NULL, NULL} entry terminates
 * the list. */
static const struct luaL_Reg cunnx_BlockSparse__ [] = {
  { "BlockSparse_updateOutput",      cunnx_BlockSparse_updateOutput      },
  { "BlockSparse_updateGradInput",   cunnx_BlockSparse_updateGradInput   },
  { "BlockSparse_accGradParameters", cunnx_BlockSparse_accGradParameters },
  { NULL,                            NULL                                }  // sentinel
};
/* Registers the cunnx_BlockSparse__ functions with luaT under "nn" while the
 * torch.CudaTensor metatable is on the Lua stack, then pops that metatable so
 * the stack is left as it was found. */
static void cunnx_BlockSparse_init(lua_State *L)
{
  luaT_pushmetatable(L, "torch.CudaTensor");          // stack: ..., mt
  luaT_registeratname(L, cunnx_BlockSparse__, "nn");  // register functions under "nn"
  lua_pop(L, 1);                                       // stack: ...
}