Split-K matmul is slow even for few partitions #1316

Open · 5 of 6 tasks
jacobhinkle opened this issue Nov 16, 2023 · 3 comments

jacobhinkle (Collaborator) commented Nov 16, 2023

Our current split-K kernels are quite slow. For example, these are the two split-K problems I encountered in a bfloat16 fwd+bwd nanoGPT run (both are TST, NN, no epilogue), measured on an A100:

[runtime plots for the two problems]

Note that M and N match between these cases: the only difference is K, which is much larger in the first plot. For the larger-K problem, quantization has a large effect on runtime, while for the smaller-K problem the reduction needed for our split-K seems to drastically increase runtime.

cuBLAS uses a single kernel with an in-place reduction in this case, and in most of the important GEMMs I've observed in real networks so far (please reply with an example if there are important two-kernel GEMMs I missed).

Possible approach: grouped reduction

In the case of allreduce and welford+broadcast, we have the ability to do a "cross-iteration grouped reduction". This is related to but distinct from "horizontally grouped reduction", which refers to executing multiple reduction operations (e.g. a sum and a max) simultaneously so that they can share synchronization. Cross-iteration grouping instead refers to using ParallelType::Group to mark some unparallelized non-reduction axes of a single reduction, so that the elements they index are converted to a tuple and reduced as one unit. This reduces the amount of synchronization required by a factor equal to the number of elements in the grouped loop nest. For example, consider this pseudo-code, which performs 8 separate reductions:

```
#pragma unroll
for (auto i1 : irange(2)) {
  #pragma unroll
  for (auto i2 : irange(4)) {
    out[i1 * 4 + i2] = gridReduce(in[i1 * 4 + i2], [](float a, float b) { return a + b; });
  }
}
```

Each grid reduction will perform a separate block and grid sync, and each element will be reduced separately. Incidentally, placing this heavy inline function inside an unrolled loop nest slows down compilation a lot (#1242).

Instead, grouping the i1 and i2 axes results in something like

```
groupedGridReduce(  // actually uses reduction::ParallelReduce
  /*out=*/makeTuple(
      out[0], out[1], out[2], out[3],
      out[4 + 0], out[4 + 1], out[4 + 2], out[4 + 3]),
  /*in=*/makeTuple(
      in[0], in[1], in[2], in[3],
      in[4 + 0], in[4 + 1], in[4 + 2], in[4 + 3]),
  [](float a, float b) { return a + b; });
```

However, this is currently implemented with a type-level tuple and only supports grouping sizes of 1 through 8, or 16. Most of our matmul kernels have 128 reductions to group, which would be cumbersome to support with the tuple approach.

Eliminating the workspace

Reducing the amount of synchronization means we would typically need a larger workspace. To see this, note that in our current approach only a single element in each reduction segment is reduced at a time, so we only need a global workspace buffer of size num_segments * blocks_per_segment. Synchronizing at the outer loop level instead would require num_segments * blocks_per_segment * reductions_per_segment elements to avoid data hazards. The factor blocks_per_segment comes from our current approach, in which every block writes asynchronously and then the last block in each segment waits and performs the final sum.

Instead, we can serialize the blocks so that each block watches a semaphore for its turn to read and write its individual work-buffer elements. That reduces the required workspace to num_segments * reductions_per_segment, which is equal to the total number of reductions being performed; in the case of a matmul, this is the size of the output tensor. We can then extend our aliasing analysis to re-use the output buffer for this purpose instead of creating a new intermediate tensor.
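
As a rough illustration of the serialized scheme, here is a minimal standalone sketch. This is not nvFuser's actual grid_sync/serialReductionStep implementation; the kernel name, the [gridDim.z][num_outputs] layout of the partial sums, and the assumption of 1-D thread blocks are all illustrative.

```
// Serial grid reduction along blockIdx.z: blocks sharing the same (x, y) take
// turns accumulating into a work buffer holding one element per output.
// In practice `work` could alias the output buffer itself.
__global__ void serialSplitKReduce(
    const float* partial,  // [gridDim.z][num_outputs] partial sums (assumed layout)
    float* work,           // [num_outputs] running sums, one per output element
    float* out,            // [num_outputs] final results
    int* semaphore,        // [gridDim.x * gridDim.y], zero-initialized
    int num_outputs) {
  int tile = blockIdx.y * gridDim.x + blockIdx.x;
  int idx = tile * blockDim.x + threadIdx.x;
  bool in_bounds = idx < num_outputs;

  // Wait for this block's turn: the semaphore counts how many z-blocks
  // for this (x, y) tile have already completed their serial step.
  if (threadIdx.x == 0) {
    while (atomicAdd(&semaphore[tile], 0) != (int)blockIdx.z) {
    }
  }
  __syncthreads();

  if (in_bounds) {
    float p = partial[(long long)blockIdx.z * num_outputs + idx];
    float acc = (blockIdx.z == 0) ? p : work[idx] + p;
    if (blockIdx.z == gridDim.z - 1) {
      out[idx] = acc;   // last block in the segment produces the final sum
    } else {
      work[idx] = acc;  // otherwise stash the running sum for the next block
    }
  }

  // Make our global writes visible, then release the next z-block.
  __threadfence();
  __syncthreads();
  if (threadIdx.x == 0) {
    atomicAdd(&semaphore[tile], 1);
  }
}
```

The total workspace here is just num_outputs elements plus one semaphore per (x, y) tile, matching the num_segments * reductions_per_segment figure above.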

Tasks

jacobhinkle self-assigned this Nov 16, 2023
jacobhinkle mentioned this issue Nov 16, 2023
jacobhinkle added a commit that referenced this issue Dec 8, 2023
This PR introduces kernel IR nodes for the syncs that need to occur
before and after a loop containing a serial grid reduction. That grid
reduction can be inlined with other computation in a loop nest, and the
sync nodes will be placed around the outer loop in the generated kernel.
The `kir::GridReduction` node itself is modified to have an attribute
available via `bool kir::GridReduction::isSerial() const` indicating
whether this is a serial grid reduction. This PR tests that codegen is
correct.

Default CUDA kernel for the included test:
```
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T4, Tensor<int64_t, 1, 1> T5) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  nvfuser_index_t i0;
  i0 = 32LL * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i1;
  i1 = 32768LL * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i2;
  i2 = 262144LL * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i3;
  i3 = ((((2097152LL * ((nvfuser_index_t)blockIdx.z)) + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  nvfuser_index_t i4;
  i4 = ((((nvfuser_index_t)threadIdx.x) + i0) + i1) + i2;
  nvfuser_index_t i5;
  i5 = (((-2097152LL + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  bool b6;
  b6 = ((nvfuser_index_t)blockIdx.z) == (((nvfuser_index_t)gridDim.z) + -1LL);
  // Allocate global tensor T4
  // Allocate global tensor T5
  float T2[128LL];
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      T2[(i7 + (32LL * i8))] = 0.000000000e+00f;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    nvfuser_index_t i9;
    i9 = 256LL * i7;
    nvfuser_index_t i10;
    i10 = i3 + i9;
    nvfuser_index_t i11;
    i11 = -i9;
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      nvfuser_index_t i12;
      i12 = 8192LL * (i8 + nvfuser_zero);
      if ((i5 < (i11 - i12))) {
        T2[(i7 + (32LL * i8))]
           = T0[(i10 + i12)];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 32LL * i13;
    nvfuser_index_t i15;
    i15 = 8192LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 32LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
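      // A separate grid reduction (with its own grid-wide synchronization) is
      // performed for each of the 4 * 32 = 128 output elements handled here.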
      reduction::gridReduce<false, false, true, false, false, false, false, true>(
        T3[0LL],
        T2[(i14 + i18)],
        [](float &a, float b) { a = a + b; },
        &T4[0],
        &T5[0],
        static_cast<float*>(shared_mem),
        true,
        true,
        float(0.000000000e+00f),
        ((i13 * 32LL) + i18),
        128LL);
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[(i16 + i19)]
           = T3[0LL];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
}
```
The serial reduction kernel looks like this:
```
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T6, Tensor<int64_t, 1, 1> T5) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  nvfuser_index_t i0;
  i0 = 32LL * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i1;
  i1 = 32768LL * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i2;
  i2 = 262144LL * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i3;
  i3 = ((((2097152LL * ((nvfuser_index_t)blockIdx.z)) + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  nvfuser_index_t i4;
  i4 = ((((nvfuser_index_t)threadIdx.x) + i0) + i1) + i2;
  nvfuser_index_t i5;
  i5 = (((-2097152LL + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  bool b6;
  b6 = ((nvfuser_index_t)blockIdx.z) == (((nvfuser_index_t)gridDim.z) + -1LL);
  // Allocate global tensor T6
  // Allocate global tensor T5
  float T2[128LL];
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      T2[(i7 + (32LL * i8))] = 0.000000000e+00f;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    nvfuser_index_t i9;
    i9 = 256LL * i7;
    nvfuser_index_t i10;
    i10 = i3 + i9;
    nvfuser_index_t i11;
    i11 = -i9;
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      nvfuser_index_t i12;
      i12 = 8192LL * (i8 + nvfuser_zero);
      if ((i5 < (i11 - i12))) {
        T2[(i7 + (32LL * i8))]
           = T0[(i10 + i12)];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
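  // Serialize the split-K (z) blocks for this (x, y) tile: wait on the
  // semaphore in T5 until all earlier z-blocks have finished their step.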
  grid_sync::blockSerializeWait<false, false, true, false>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 32LL * i13;
    nvfuser_index_t i15;
    i15 = 8192LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 32LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
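      // Serial reduction step: combine this block's partial sum with the
      // running value in the T6 work buffer (first/last z-block flags below).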
      reduction::serialReductionStep(
        T3[0LL],
        T2[(i14 + i18)],
        0.000000000e+00f,
        T6[(i16 + i19)],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,
        true);
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[(i16 + i19)]
           = T3[0LL];
      }
    }
  }
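  // Signal that this block's serial step is done so that the next z-block
  // for this (x, y) tile may proceed.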
  grid_sync::blockSerializeRelease<false, false, true, false>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  NVFUSER_UPDATE_MAGIC_ZERO;
}
```

#### What is not included in this PR

There is no automatic scheduling or lowering to serial reductions in
this PR. The included test works via a post-lowering hook in
`FusionExecutor` to simply test that we can codegen the nodes properly
once they are manually placed.

There is also no re-use of global buffers currently, so this is not yet an
"in-place" reduction: for now we must manually allocate a work buffer that
is the full size of the grid reduction output. In the future, we can avoid
the need for that workspace by aliasing an output buffer.

The work buffer must currently be the same dtype as the reduction
element. In the future, we could relax this in order to cast to lower
precision in the work buffer. This would enable us to re-use the global
memory allocated for TST and HSH matmul output, at the expense of a
small loss in precision.

Related to #1316 and #991.
jacobhinkle (Collaborator, Author) commented:

Re: "Enable inner aliasing of global buffers". This would be useful in matmuls like TSS where we want to output a float tensor and the accumulator is also float. In those cases we could use the output buffer as the reduction work buffer, saving us from allocating a separate one. However, we expect our most common use cases to look like TST, or to be TSS but with conversion to half precision at the end of the epilogue. In those cases the output buffer is half precision, so we cannot use it as the work buffer anyway. It would be possible to use a reduced-precision work buffer there, casting to and from bfloat16 when doing the serial reduction; since that slightly reduces precision, we are not considering it at this point, but we can revisit it if an important memory-constrained use case would benefit.
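
For reference, the reduced-precision idea would look roughly like the following. This is a hedged sketch with made-up names (not an existing nvFuser API): the running sum lives in a bfloat16 work buffer, e.g. an aliased half-precision output, and is widened to float for each accumulation.

```
#include <cuda_bf16.h>

// Hypothetical serial reduction step with a reduced-precision work buffer.
// Each round trip through bfloat16 loses a little precision, which is why
// this is not being pursued right now.
__device__ void serialStepBf16(
    float partial,             // this block's partial sum for one output element
    __nv_bfloat16* work_elem,  // aliased low-precision work buffer element
    bool is_first_block,
    bool is_last_block,
    float* final_out) {
  float acc = is_first_block ? partial
                             : __bfloat162float(*work_elem) + partial;
  if (is_last_block) {
    *final_out = acc;                    // epilogue may cast this down later
  } else {
    *work_elem = __float2bfloat16(acc);  // store running sum in low precision
  }
}
```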

jacobhinkle (Collaborator, Author) commented Jan 10, 2024

Update: I did some benchmarking today now that we have some functional PR branches. This compares current main (61b0ab6) against #1510, which enables serial grid reductions for split-K matmuls, for two problem sizes:

[benchmark plot]

I had previously made some estimates of speedups based on manually tweaking a kernel for splitk_factor=2 and comparing to main. I expected to get to around 3x the perf of cuBLAS with this PR, but these results show that at splitk_factor=2 we are around 1.4x for the small problem and 1.09x for the large problem. If we use a higher split we can get better perf for the large problem, but that may be true for cuBLAS as well.

jacobhinkle (Collaborator, Author) commented:

Although #1534 is not yet finished, I also compared it. That PR unswizzles using shared memory before the grid reduction. NOTE: I suspect the unswizzling is incorrect, as noted on that PR. I don't yet see any improvement from it, but I will keep an eye on it:

[benchmark plot]
