Split-K matmul is slow even for few partitions #1316
This PR introduces kernel IR nodes for the syncs that need to occur before and after a loop containing a serial grid reduction. That grid reduction can be inlined with other computation in a loop nest, and the sync nodes will be placed around the outer loop in the generated kernel. The `kir::GridReduction` node itself is modified to have an attribute available via `bool kir::GridReduction::isSerial() const` indicating whether this is a serial grid reduction. This PR tests that codegen is correct. Default CUDA kernel for the included test:
```
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T4, Tensor<int64_t, 1, 1> T5) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  nvfuser_index_t i0;
  i0 = 32LL * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i1;
  i1 = 32768LL * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i2;
  i2 = 262144LL * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i3;
  i3 = ((((2097152LL * ((nvfuser_index_t)blockIdx.z)) + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  nvfuser_index_t i4;
  i4 = ((((nvfuser_index_t)threadIdx.x) + i0) + i1) + i2;
  nvfuser_index_t i5;
  i5 = (((-2097152LL + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  bool b6;
  b6 = ((nvfuser_index_t)blockIdx.z) == (((nvfuser_index_t)gridDim.z) + -1LL);
  // Allocate global tensor T4
  // Allocate global tensor T5
  float T2[128LL];
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      T2[(i7 + (32LL * i8))] = 0.000000000e+00f;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    nvfuser_index_t i9;
    i9 = 256LL * i7;
    nvfuser_index_t i10;
    i10 = i3 + i9;
    nvfuser_index_t i11;
    i11 = -i9;
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      nvfuser_index_t i12;
      i12 = 8192LL * (i8 + nvfuser_zero);
      if ((i5 < (i11 - i12))) {
        T2[(i7 + (32LL * i8))] = T0[(i10 + i12)];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 32LL * i13;
    nvfuser_index_t i15;
    i15 = 8192LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 32LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
      reduction::gridReduce<false, false, true, false, false, false, false, true>(
        T3[0LL],
        T2[(i14 + i18)],
        [](float &a, float b) { a = a + b; },
        &T4[0],
        &T5[0],
        static_cast<float*>(shared_mem),
        true,
        true,
        float(0.000000000e+00f),
        ((i13 * 32LL) + i18),
        128LL);
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[(i16 + i19)] = T3[0LL];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
}
```
The serial reduction kernel looks like this:
```
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T6, Tensor<int64_t, 1, 1> T5) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  nvfuser_index_t i0;
  i0 = 32LL * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i1;
  i1 = 32768LL * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i2;
  i2 = 262144LL * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i3;
  i3 = ((((2097152LL * ((nvfuser_index_t)blockIdx.z)) + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  nvfuser_index_t i4;
  i4 = ((((nvfuser_index_t)threadIdx.x) + i0) + i1) + i2;
  nvfuser_index_t i5;
  i5 = (((-2097152LL + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  bool b6;
  b6 = ((nvfuser_index_t)blockIdx.z) == (((nvfuser_index_t)gridDim.z) + -1LL);
  // Allocate global tensor T6
  // Allocate global tensor T5
  float T2[128LL];
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      T2[(i7 + (32LL * i8))] = 0.000000000e+00f;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    nvfuser_index_t i9;
    i9 = 256LL * i7;
    nvfuser_index_t i10;
    i10 = i3 + i9;
    nvfuser_index_t i11;
    i11 = -i9;
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      nvfuser_index_t i12;
      i12 = 8192LL * (i8 + nvfuser_zero);
      if ((i5 < (i11 - i12))) {
        T2[(i7 + (32LL * i8))] = T0[(i10 + i12)];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeWait<false, false, true, false>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 32LL * i13;
    nvfuser_index_t i15;
    i15 = 8192LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 32LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
      reduction::serialReductionStep(
        T3[0LL],
        T2[(i14 + i18)],
        0.000000000e+00f,
        T6[(i16 + i19)],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,
        true);
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[(i16 + i19)] = T3[0LL];
      }
    }
  }
  grid_sync::blockSerializeRelease<false, false, true, false>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  NVFUSER_UPDATE_MAGIC_ZERO;
}
```

#### What is not included in this PR

There is no automatic scheduling or lowering to serial reductions in this PR. The included test works via a post-lowering hook in `FusionExecutor` to simply test that we can codegen the nodes properly once they are manually placed.

There is also no re-use of global buffers currently, so this is not yet an "in-place" reduction. That is, we must manually allocate a work buffer that is the full size of the grid reduction output at this time. In the future, we can avoid the need for that workspace by aliasing an output buffer.

The work buffer must currently be the same dtype as the reduction element. In the future, we could relax this in order to cast to lower precision in the work buffer. This would enable us to re-use the global memory allocated for TST and HSH matmul output, at the expense of a small loss in precision.

Related to #1316 and #991.
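As a reading aid (not part of the PR text above), here is a rough sketch of the semantics that `reduction::serialReductionStep` appears to have, inferred from the call site in the serial kernel. It illustrates the read-modify-write handoff between serialized blocks; it is not nvFuser's actual runtime implementation:
```
// Sketch only: approximate semantics of reduction::serialReductionStep as
// inferred from the call site above; not the actual nvFuser runtime code.
template <typename T, typename Func>
__device__ void serialReductionStep(
    T& out,          // thread-local result, only meaningful on the last step
    T in,            // this block's partial value for this output element
    T init,          // identity element of the reduction
    T& work,         // this element's slot in the global work buffer (T6 above)
    Func reduction_op,
    bool first_step, // true for the first block in the serialized grid dimension
    bool last_step,  // true for the last block in the serialized grid dimension
    bool read_pred,
    bool write_pred) {
  if (!write_pred) {
    return;
  }
  // The first block starts from the init value; later blocks continue from the
  // partial result the previous block left in the work buffer.
  T partial = first_step ? init : work;
  reduction_op(partial, read_pred ? in : init);
  if (last_step) {
    out = partial;  // the final block keeps the result for the epilogue
  } else {
    work = partial; // hand the partial result to the next block in line
  }
}
```
The surrounding `blockSerializeWait`/`blockSerializeRelease` calls appear to be what make this safe: each block spins on the semaphore in `T5` until the previous block in `blockIdx.z` order has released it, so the read-modify-write of each `T6` element is serialized across blocks.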
Update: I did some benchmarking today now that we have some functional PR branches. This compares current main (61b0ab6) against #1510, which enables serial grid reductions for split-K matmuls, for two problem sizes.
Although #1534 is not yet finished, I also compared it. That PR unswizzles using shared memory before the grid reduction. NOTE: I suspect the unswizzling is incorrect, as noted on that PR. I don't yet see any improvement from it, but I will keep an eye on it:
Our current split-K kernels are quite slow. For example, these are the two split-K problems I encountered with a bfloat16 fwd+bwd nanogpt run (both are TST, NN, no epilogue), measured on an A100:
Note that M and N match between these cases; the only difference is K, which is much larger in the first plot. For the larger-K case, quantization has a large effect on runtime, whereas for the smaller problem the reduction needed for our split-K seems to drastically increase runtime.
cuBLAS uses a single kernel with an in-place reduction in this case, and in most of the important GEMMs I've observed in real networks so far (please reply with an example if there are important two-kernel GEMMs I've missed).
Possible approach: grouped reduction
In the case of allreduce and welford+broadcast, we have the ability to do a "cross-iteration grouped reduction". This notion is related to, but distinct from, "horizontally grouped reduction", which refers to executing multiple reduction operations (e.g. a sum and a max) simultaneously so that they can share synchronization. Cross-iteration grouping instead refers to the use of `ParallelType::Group` to indicate unparallelized non-reduction axes in a single reduction that should be converted to a tuple and reduced as one unit. This lets us reduce the amount of syncing required by a factor equal to the number of elements in the grouped loop nest. For example, consider the pseudo-code sketched below, which performs 8 separate reductions. Each grid reduction performs a separate block and grid sync, and each element is reduced separately. Incidentally, placing this heavy inline function inside an unrolled loop nest also slows down compilation considerably (#1242).
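As an illustrative sketch (the names and the 2×4 shape are made up for this example, and the call is abbreviated rather than real generated code), the ungrouped case looks roughly like this:
```
// Pseudo-code only: 2 x 4 = 8 unparallelized output elements, each doing its
// own grid reduction, and therefore its own block + grid synchronization.
#pragma unroll
for (nvfuser_index_t i1 = 0; i1 < 2; ++i1) {
  #pragma unroll
  for (nvfuser_index_t i2 = 0; i2 < 4; ++i2) {
    float T3[1];
    T3[0] = 0.f;
    reduction::gridReduce(
        T3[0],                                // reduction output
        T2[i1 * 4 + i2],                      // this thread's input element
        [](float& a, float b) { a = a + b; }, // reduction op
        /*work_buf=*/&T4[0],
        /*sync_buf=*/&T5[0] /* ...remaining arguments elided... */);
    if (is_last_block) {
      T1[output_index(i1, i2)] = T3[0];
    }
  }
}
```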
Instead, grouping the i1 and i2 axes results in something like the following:
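Again only as a sketch with made-up names (`groupedGridReduce` here stands in for whatever the grouped runtime entry point is), the grouped version reduces all 8 elements as a single unit, paying for only one block + grid sync:
```
// Pseudo-code only: the i1 and i2 loops are folded into the group, so the 8
// elements form one tuple-valued reduction covered by a single synchronization.
float T3[8];
float inputs[8];
#pragma unroll
for (nvfuser_index_t i = 0; i < 8; ++i) {
  inputs[i] = T2[i];
}
reduction::groupedGridReduce(   // hypothetical name for the grouped entry point
    /*out=*/T3,
    /*in=*/inputs,
    [](float& a, float b) { a = a + b; },
    /*work_buf=*/&T4[0],
    /*sync_buf=*/&T5[0] /* ...remaining arguments elided... */);
if (is_last_block) {
  #pragma unroll
  for (nvfuser_index_t i = 0; i < 8; ++i) {
    T1[output_index(i)] = T3[i];
  }
}
```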
However, this is currently implemented with a type-level tuple and only supports grouping sizes of 1 through 8, or 16. Most of our matmul kernels have 128 reductions to group, which would be cumbersome to support with the tuple approach.
Eliminating the workspace
Reducing the amount of synchronization means that we would typically need to use a larger workspace. To see this, consider that in our current approach only a single element in each reduction segment is reduced at once, so we only need a global workspace buffer whose size is `num_segments * blocks_per_segment`. Synchronizing at the outer loop level means we would need `num_segments * blocks_per_segment * reductions_per_segment` to avoid data hazards. The factor `blocks_per_segment` comes from our current approach, where we asynchronously write from every block and then the last block in each segment waits to perform the final sum. Instead, we can serialize the blocks so that each block watches a semaphore for its turn to read and write its individual work buffer element. That reduces the required workspace to `num_segments * reductions_per_segment`, which is equal to the total number of reductions being performed; in the case of a matmul, this is the size of the output tensor. We can then enable our aliasing analysis to re-use the output buffer for this purpose instead of creating a new intermediate tensor.
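To make the sizes concrete, here is a small worked example with hypothetical tile sizes (the numbers are illustrative and not taken from the benchmarks above):
```
// Hypothetical split-K GEMM used only to illustrate the workspace sizes above.
constexpr long M = 2048, N = 2048;         // output size
constexpr long tile_m = 128, tile_n = 128; // CTA tile (one reduction segment)
constexpr long splitk_factor = 8;          // split-K partitions along K

constexpr long num_segments = (M / tile_m) * (N / tile_n);  // 256 output tiles
constexpr long reductions_per_segment = tile_m * tile_n;    // 16384 elements per tile
constexpr long blocks_per_segment = splitk_factor;

// Current approach: one element per segment in flight at a time.
constexpr long ws_current = num_segments * blocks_per_segment;  // 2048 elements

// Syncing at the outer loop level without serializing blocks.
constexpr long ws_outer =
    num_segments * blocks_per_segment * reductions_per_segment; // 33,554,432 elements

// Serialized blocks: one work-buffer element per output element.
constexpr long ws_serial = num_segments * reductions_per_segment; // 4,194,304 elements
static_assert(ws_serial == M * N, "equal to the size of the output tensor");
```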
Tasks