-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable serial grid reduction for split-K in matmul scheduler #1510
Conversation
Will revisit once sync pass is done, when we have a TensorIndex
Still missing allocation/indexing of work buffer
I need to replay leaf transforms, then get index.
Codegen is now like ```c++ // Allocate global tensor T5 reduction::serialReductionStep( T3[0LL], T2[(i14 + i18)], 0.000000000e+00f, T5[((((((((((((nvfuser_index_t)blockIdx.x) * 8LL) + ((nvfuser_index_t)blockIdx.y)) * 4LL) + i13) * 8LL) + (i18 + nvfuser_zero)) * 4LL) + ((nvfuser_index_t)threadIdx.y)) * 32LL) + ((nvfuser_index_t)threadIdx.x))], [](float &a, float b) { a = a + b; }, index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0, index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1, true, true); ``` This looks OK, although it will get a little better with hoisting. This compiles, but I get an error in `runFusion`: ``` C++ exception with description "Expected T5_g[ iblockIdx.x59{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(262144, 32) ), 4) ), 8) ), 4) ), 8) )}, iblockIdx.y60{8}, ithreadIdx.y54{4}, ithreadIdx.x52{32}, iS58{4}, iS56{8}, rblockIdx.z49{5} ] to be bound to a tensor of rank 1, but got a tensor of rank 6 Exception raised from validateValWithConcreteValue at /opt/pytorch/nvfuser/csrc/expr_evaluator.cpp:38 (most recent call first): ``` This is happening when binding inputs I believe.
Fixes execution error. Test passes!
Generated kernel now looks like ```c++ // Allocate global tensor T4 grid_sync::blockSerializeWait<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]); #pragma unroll for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) { nvfuser_index_t i14; i14 = 8LL * i13; nvfuser_index_t i15; i15 = 2048LL * i13; nvfuser_index_t i16; i16 = i4 + i15; nvfuser_index_t i17; i17 = -i15; #pragma unroll for(nvfuser_index_t i18 = 0; i18 < 8LL; ++i18) { nvfuser_index_t i19; i19 = 256LL * (i18 + nvfuser_zero); nvfuser_index_t i20; i20 = i16 + i19; float T3[1LL]; T3[0LL] = 0.000000000e+00f; // Allocate global tensor T5 reduction::serialReductionStep( T3[0LL], T2[(i14 + i18)], 0.000000000e+00f, T5[i20], [](float &a, float b) { a = a + b; }, index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0, index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1, true, true); if ((b6 && (i5 < (i17 - i19)))) { T1[i20] = T3[0LL]; } } } NVFUSER_UPDATE_MAGIC_ZERO; grid_sync::blockSerializeRelease<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]); ``` Note that the index `i20` matches the output `T1`. This is what we need to reclaim `T1` in a later PR; it will still be a challenge in that work to exact map between `T5` and `T3` in order to get `T1` and `T5` exact mapped...
Also sort expected output by line to give clearer error messages.
6db7cbf
to
41f00b8
Compare
csrc/scheduler/mma_utils.cpp
Outdated
for (auto prop : props) { | ||
auto* init = IrBuilder::create<Val>(0.0, prop.out->getDataType().value()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above comment about init dtype.
I had to change some surrounding code to fix the following error, which was caused by using a // Allocate global tensor T13
grid_sync::blockSerializeWait<false, false, true>(&T13[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
#pragma unroll
for(nvfuser_index_t i93 = 0; i93 < 4LL; ++i93) {
nvfuser_index_t i94;
i94 = 32LL * i93;
nvfuser_index_t i95;
i95 = i35 + (i30 * i93);
nvfuser_index_t i96;
i96 = -(16LL * i93);
#pragma unroll
for(nvfuser_index_t i97 = 0; i97 < 8LL; ++i97) {
nvfuser_index_t i98;
i98 = i94 + (4LL * i97);
nvfuser_index_t i99;
i99 = i95 + (8LL * i97);
bool b100;
b100 = i41 < (-(8LL * (i97 + nvfuser_zero)));
bool b101;
b101 = b43 && b100;
#pragma unroll
for(nvfuser_index_t i102 = 0; i102 < 2LL; ++i102) {
nvfuser_index_t i103;
i103 = i98 + (2LL * i102);
nvfuser_index_t i104;
i104 = i99 + (i36 * i102);
nvfuser_index_t i105;
i105 = i102 + nvfuser_zero;
bool b106;
b106 = i42 < (i96 - (8LL * i105));
bool b107;
b107 = b100 && b106;
Array<float, 2LL, 2> T7;
#pragma unroll
for(nvfuser_index_t i108 = 0; i108 < 2LL; ++i108) {
T7[i108] = 0.00000000000000000e+00;
}
#pragma unroll
for(nvfuser_index_t i108 = 0; i108 < 2LL; ++i108) {
// Allocate global tensor T14
reduction::serialReductionStep(
T7[i108],
T12[(i103 + i108)],
0.00000000000000000e+00,
T14[(i104 + (i108 + nvfuser_zero))],
[](float &a, float b) { a = a + b; },
index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
b107,
b107);
}
if ((b101 && b106)) {
loadLocalToGlobal<float, /*vec_size=*/2, /*is_volatile=*/false>( &T6[(i99 + (i36 * i105))], &T7[0LL]);
}
}
}
}
NVFUSER_UPDATE_MAGIC_ZERO;
grid_sync::blockSerializeRelease<false, false, true>(&T13[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
}
}
/*
CUDA NVRTC compile error: __tmp_kernel_none_f0_c0_r0_g0.cu(9791): error: no instance of function template "<unnamed>::reduction::serialReductionStep" matches the argument list
argument types are: (float, float, double, float, lambda [](float &, float)->void, __nv_bool, __nv_bool, __nv_bool, __nv_bool)
reduction::serialReductionStep(
^
1 error detected in the compilation of "__tmp_kernel_none_f0_c0_r0_g0.cu".
*/ |
5dc198c
to
0bacbea
Compare
These were disabled in #1545 because of slow compilation with gridReduce
0bacbea
to
fc07a9a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Congrats on enabling split-K!
Serial grid reductions are used in split-K matmuls as of #1510. This means we load and store elements in the reduction tensor according to the indexing of the work buffer. This is unlike ordinary grid reductions that use `gridReduce`, which reduces individual elements using a scheme that ensures coalescing by indexing into the work buffer based on `threadIdx` and `blockIdx`. Currently these split-K accesses are inefficient due to this lack of coalescing. We currently already ensure coalesced output stores in matmuls when possible by using smem for the epilogue (#387). A shared memory buffer is used to communicate elements between threads so that the resulting tensor will have a proper global access pattern when it is written out to global memory as a tile of the output. Before this PR if we used split-K with `use_smem_epilogue = true`, the store to global memory will be coalesced but there will be uncoalesced accesses during the split-K reduction. This PR modifies scheduling so that in those cases, the smem epilogue tensor is placed before the split-K sum, so that unswizzling happens before completing the reduction. The result is that the reduction accesses are coalesced. This is a generated kernel from `NVFuserTest.FusionAmpereMatmulSplitKBias_CUDA`: ```c++ // ... (main loop) ... #pragma unroll for(nvfuser_index_t i59 = 0; i59 < 4LL; ++i59) { nvfuser_index_t i104; i104 = 8LL * i59; nvfuser_index_t i105; i105 = 32LL * i59; #pragma unroll for(nvfuser_index_t i61 = 0; i61 < 8LL; ++i61) { nvfuser_index_t i106; i106 = 4LL * i61; asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" :"=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]), "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]), "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]), "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3]) :"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[0]), "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[1]), "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[2]), "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[3]), "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[0]), "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[1]), "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]), "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]), "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]), "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3]) ); } } } NVFUSER_UPDATE_MAGIC_ZERO; __syncthreads(); } __syncthreads(); #pragma unroll for(nvfuser_index_t i107 = 0; i107 < 4LL; ++i107) { nvfuser_index_t i108; i108 = 32LL * i107; nvfuser_index_t i109; i109 = i38 + (2048LL * i107); #pragma unroll for(nvfuser_index_t i110 = 0; i110 < 8LL; ++i110) { nvfuser_index_t i111; i111 = i108 + (4LL * i110); nvfuser_index_t i112; i112 = i11 + i110; nvfuser_index_t i113; i113 = (i109 + (32LL * (i112 / 4LL))) + (8LL * (i39 ^ (i112 % 4LL))); #pragma unroll for(nvfuser_index_t i114 = 0; i114 < 2LL; ++i114) { loadGeneric<float, 2>( &T17[(i113 + (1024LL * i114))], &T16[(i111 + (2LL * i114))]); } } } NVFUSER_UPDATE_MAGIC_ZERO; // Allocate global tensor T19 grid_sync::blockSerializeWait<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]); __syncthreads(); #pragma unroll for(nvfuser_index_t i115 = 0; i115 < 32LL; ++i115) { nvfuser_index_t i116; i116 = i115 + nvfuser_zero; nvfuser_index_t i117; i117 = i44 + (i45 * i116); nvfuser_index_t i118; i118 = i47 + (4LL * i115); bool b119; b119 = i55 < (-(4LL * i116)); bool b120; b120 = b54 && b119; Array<float, 4LL, 4> T6; T6.set(float(0.000000000e+00f)); // Allocate global tensor T20 reduction::serialReductionStep</*vec_size=*/4>( &T6[0LL], &T17[(i42 + (512LL * i115))], 0.000000000e+00f, &T20[i117], [](float &a, float b) { a = a + b; }, index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0, index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1, b120, b120); Array<float, 4LL, 4> T10; #pragma unroll for(nvfuser_index_t i121 = 0; i121 < 4LL; ++i121) { __half T18[1LL]; T18[0LL] = 0LL; if (b119) { T18[0LL] = T2[(i118 + ((i48 + (i121 + nvfuser_zero)) / 128LL))]; } __half T7[1LL]; T7[0LL] = T18[0LL]; float T8[1LL]; T8[0LL] = __half2float(T7[0LL]); T10[i121] = T6[i121] + T8[0LL]; } if ((b56 && b119)) { loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T9[i117], &T10[0LL]); } } NVFUSER_UPDATE_MAGIC_ZERO; grid_sync::blockSerializeRelease<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]); } ``` Note that the `i135` loop will be smaller once we have #1528 at which point it would more clearly show reduction followed by the loop for the predicated bias epilogue. (Diff should be viewed hiding whitespace changes as many changes are to indentation).
Stacked on #1456.
This simply enables serial reduction in the matmul scheduler when split-K is used.