
[QST] Why did I get a wrong result from GemmGrouped? #1924

Open
WangNorthSea opened this issue Nov 7, 2024 · 3 comments


I'm using GemmGrouped in this way:

using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass_t,                                      // Element A
    cutlass::layout::RowMajor,                      // Layout A
    cutlass::ComplexTransform::kNone,               //
    1,                                              // Granularity A
    cutlass_t,                                      // Element B
    cutlass::layout::RowMajor,                      // Layout B
    cutlass::ComplexTransform::kNone,               //
    1,                                              // Granularity B
    cutlass_t,                                      // Element C&D
    cutlass::layout::RowMajor,                      // Layout C&D
    float,                                          // Element Accumulator
    cutlass::arch::OpClassTensorOp,                 // Operator Class Tag
    cutlass::arch::Sm80,                            // Architecture
    cutlass::gemm::GemmShape<16, 64, 64>,           // Thread Block Shape
    cutlass::gemm::GemmShape<16, 16, 64>,           // Warp Shape
    cutlass::gemm::GemmShape<16, 8, 16>,            // Instruction Shape
    LinearCombination<cutlass_t, 1, float, float>,  // Epilogue
    GemmIdentityThreadblockSwizzle<>,              // Swizzling Operator
    2                                               // Stages
    >::GemmKernel;

using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
typename EpilogueOutputOp::Params epilogue_op(1.0, 0.0);

using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
typename GemmGrouped::Arguments args_(
    all_problems, num_problems, 512, epilogue_op, ptr_X, ptr_W, ptr_Y,
    ptr_Y, ld_X, ld_W, ld_Y, ld_Y);

GemmGrouped gemm;

gemm.initialize(args_, nullptr, stream);
gemm.run(stream);

With num_problems = 1, M = 2035, N = 48, K = 3584:
all_problems[0] = cutlass::gemm::GemmCoord(2035, 48, 3584)
ptr_X[0] = matrix A, ptr_W[0] = matrix B, ptr_Y[0] = matrix C
ld_X[0] = K, ld_W[0] = N, ld_Y[0] = N
cutlass_t = cutlass::half_t
GPU: A100
cutlass version: 3.5.0
nvcc version: 12.4

The calculation I'm expecting is C = matmul(A, B), so I set alpha = 1.0 and beta = 0.0 for epilogue_op.
The shape of C is (2035, 48); however, only the elements in the first 16 columns were correct, and all other elements of C were wrong.
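
A host-side reference check along these lines (a minimal sketch; the array names h_A, h_B, h_C and the error tolerance are placeholders, assuming row-major host copies of A, B and the CUTLASS output) makes it easy to count, per output column, how many elements disagree with a plain fp32 matmul:

// Hypothetical check: per-column mismatch count against a naive fp32 reference
// of C = matmul(A, B). Assumes row-major cutlass::half_t buffers on the host.
#include <cmath>
#include <cstdio>
#include <vector>

#include "cutlass/numeric_types.h"

void check_columns(const cutlass::half_t *h_A, const cutlass::half_t *h_B,
                   const cutlass::half_t *h_C, int M, int N, int K) {
  std::vector<int> bad_per_col(N, 0);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += float(h_A[m * K + k]) * float(h_B[k * N + n]);
      }
      float got = float(h_C[m * N + n]);
      // Loose tolerance; fp16 inputs with fp32 accumulation are not bit-exact.
      if (std::fabs(got - acc) > 1e-1f * std::fabs(acc) + 1.f) {
        ++bad_per_col[n];
      }
    }
  }
  for (int n = 0; n < N; ++n) {
    std::printf("col %2d: %d mismatches\n", n, bad_per_col[n]);
  }
}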

I spent a lot of time tracing the execution with cuda-gdb, and I found that something goes wrong when the warp fragments of matrix B are loaded.
The loading is done by the code below:
https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h#L395-L415

for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) {
  for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) {
    int access_idx = c + s * Policy::LdsmIterations::kContiguous;

    AccessType const *source_ptr =
        pointer_[c % kPointerCount] +
        Layout::TileShape::kContiguous * (c / kPointerCount) +
        Policy::kLdsmOpInner * Policy::LdsmShape::kStrided * s * stride_;

    char const *source_byte_ptr = reinterpret_cast<char const *>(source_ptr) + byte_offset + byte_offset_;

    cutlass::arch::ldsm<layout::ColumnMajor, Policy::LdsmShape::kCount>(
      fetch_ptr[access_idx],
      source_byte_ptr
    );
  }
}

Before the 'ldsm' here, I printed the values in memory at address 'source_byte_ptr' for the threads of block 0; the values and 'source_byte_ptr' of some of those threads are shown below:
[Image: printed values and 'source_byte_ptr' for several threads of block 0]
It looks like CUTLASS has automatically padded and swizzled matrix B, and I think the memory layout of B is correct. The 'source_byte_ptr' of the 32 threads in Warp 0 (threads 0-31) all pointed to B[threadIdx.x][0:8]; however, in Warp 1, thread 34 got the address of 8 zeros... and, from my point of view, many threads in Warps 1-3 got wrong memory addresses.
As a result, after the 'ldsm' here, thread 0 got four 32-bit values consisting of {B[0][0], B[1][0], B[8][0], B[9][0], B[0][8], B[1][8], B[8][8], B[9][8]}.
[Image: thread/element layout of the m16n8k16 mma instruction]
According to the figure above, which shows the element layout of an m16n8k16 mma instruction, the threads in Warp 0 got correct fragments, but threads in the other warps got wrong ones; for example, thread 33 got all zeros.
I think that's why only the part of matrix C belonging to Warp 0 has correct values, but I still don't know why the fragment loading goes wrong.
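
For reference, the per-lane operand-B mapping that an f16 m16n8k16 mma expects, as described in the PTX ISA, can be printed with a small standalone host sketch (not CUTLASS code). Lane 0 owns B[0][0], B[1][0], B[8][0], B[9][0] of each 8-column tile, which matches what thread 0 loaded above; with the warp tile being n = 16, the iterator fetches two such tiles, hence the 8 values per thread.

// Sketch of the per-lane operand-B element mapping for an f16 mma.m16n8k16,
// following the PTX ISA fragment description: each lane holds 4 elements of
// the 16x8 (K x N) B tile.
#include <cstdio>

int main() {
  for (int lane = 0; lane < 32; ++lane) {
    int group = lane / 4; // the n (column) index owned by this lane
    int tig = lane % 4;   // thread index within its group of 4
    std::printf("lane %2d:", lane);
    for (int i = 0; i < 4; ++i) {
      int row = tig * 2 + (i & 1) + (i < 2 ? 0 : 8); // k (row) index
      std::printf(" B[%2d][%d]", row, group);
    }
    std::printf("\n");
  }
  return 0;
}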

WangNorthSea (Author) commented Nov 7, 2024

This is a simple test case to reproduce the problem:

#include <fstream>
#include <iostream>
#include <stdexcept>
#include <vector>

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

using cutlass_t = typename cutlass::half_t;

template <typename T>
__global__ void precompute_sgmv_args(cutlass::gemm::GemmCoord *all_problems,
                                     T **ptr_y, T **ptr_x, T **ptr_w,
                                     int64_t *ld_y, int64_t *ld_x,
                                     int64_t *ld_w, T *y, T *x, T *w) {
  int m = 2035;
  int k = 3584;
  int n = 48;
  all_problems[0] = cutlass::gemm::GemmCoord(m, n, k);
  ptr_w[0] = w;
  ptr_x[0] = x;
  ptr_y[0] = y;
  ld_x[0] = k;
  ld_w[0] = n;
  ld_y[0] = n;
}

size_t sgmv_tmp_size(int num_problems) {
  constexpr auto sz = sizeof(void *) * 3 + sizeof(int64_t) * 3 +
                      sizeof(cutlass::gemm::GemmCoord);
  return sz * num_problems;
}

template <typename T> inline T *alloc_from_buf(void **buf, int n) {
  auto *p = (T *)*buf;
  *buf = (void *)(p + n);
  return p;
}

bool SgmvCutlass(cutlass_t *y, const cutlass_t *x, const cutlass_t *w,
                 void *tmp_d, int num_problems, cudaStream_t stream) {
  auto ptr_Y = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
  auto ptr_X = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
  auto ptr_W = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
  auto ld_Y = alloc_from_buf<int64_t>(&tmp_d, num_problems);
  auto ld_X = alloc_from_buf<int64_t>(&tmp_d, num_problems);
  auto ld_W = alloc_from_buf<int64_t>(&tmp_d, num_problems);
  auto all_problems =
      alloc_from_buf<cutlass::gemm::GemmCoord>(&tmp_d, num_problems);

  precompute_sgmv_args<<<num_problems, 1, 0, stream>>>(
      all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y,
      (cutlass_t *)x, (cutlass_t *)w);

  using cutlass::epilogue::thread::LinearCombination;
  using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;

  using GemmKernelSm80 = typename cutlass::gemm::kernel::DefaultGemmGrouped<
      cutlass_t,                                     // Element A
      cutlass::layout::RowMajor,                     // Layout A
      cutlass::ComplexTransform::kNone,              //
      1,                                             // Granularity A
      cutlass_t,                                     // Element B
      cutlass::layout::RowMajor,                     // Layout B
      cutlass::ComplexTransform::kNone,              //
      1,                                             // Granularity B
      cutlass_t,                                     // Element C&D
      cutlass::layout::RowMajor,                     // Layout C&D
      float,                                         // Element Accumulator
      cutlass::arch::OpClassTensorOp,                // Operator Class Tag
      cutlass::arch::Sm80,                           // Architecture
      cutlass::gemm::GemmShape<16, 64, 64>,          // Thread Block Shape
      cutlass::gemm::GemmShape<16, 16, 64>,          // Warp Shape
      cutlass::gemm::GemmShape<16, 8, 16>,           // Instruction Shape
      LinearCombination<cutlass_t, 1, float, float>, // Epilogue
      GemmIdentityThreadblockSwizzle<>,              // Swizzling Operator
      2                                              // Stages
      >::GemmKernel;

  using EpilogueOutputOpSm80 = typename GemmKernelSm80::Epilogue::OutputOp;
  typename EpilogueOutputOpSm80::Params epilogue_op_sm80(1.0, 0.0);

  using GemmGroupedSm80 = cutlass::gemm::device::GemmGrouped<GemmKernelSm80>;
  typename GemmGroupedSm80::Arguments args_sm80(
      all_problems, num_problems, 512, epilogue_op_sm80, ptr_X, ptr_W, ptr_Y,
      ptr_Y, ld_X, ld_W, ld_Y, ld_Y);

  GemmGroupedSm80 gemm_sm80;

  auto status = gemm_sm80.initialize(args_sm80, nullptr, stream);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("SgmvCutlass gemm.initialize failed: " +
                             std::string(cutlassGetStatusString(status)) +
                             "\n");
  }
  status = gemm_sm80.run(stream);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("SgmvCutlass gemm.run failed: " +
                             std::string(cutlassGetStatusString(status)) +
                             "\n");
  }

  return true;
}

bool read_binary_file(const char *filename, void *&data, size_t &size) {
  std::ifstream file(filename, std::ios::binary | std::ios::ate);
  if (!file.is_open()) {
    std::cerr << "Failed to open file: " << filename << std::endl;
    return false;
  }

  size = file.tellg();
  file.seekg(0, std::ios::beg);

  data = malloc(size);
  if (!data) {
    std::cerr << "Failed to allocate memory for data" << std::endl;
    return false;
  }

  file.read(static_cast<char *>(data), size);
  file.close();

  return true;
}

bool write_binary_file(const char *filename, const void *data, size_t size) {
  std::ofstream file(filename, std::ios::binary);
  if (!file.is_open()) {
    std::cerr << "Failed to open file: " << filename << std::endl;
    return false;
  }

  file.write(static_cast<const char *>(data), size);
  file.close();

  return true;
}

int main(int argc, const char *argv[]) {
  cudaStream_t stream;

  cudaError_t err = cudaStreamCreate(&stream);
  if (err != cudaSuccess) {
    std::cerr << "Failed to create CUDA stream: " << cudaGetErrorString(err)
              << std::endl;
    return 1;
  }

  const char *input_name = "input.bin";
  const char *weight_name = "weight.bin";
  const char *output_name = "output.bin";
  void *input_data;
  void *weight_data;
  size_t input_size;
  size_t weight_size;

  if (!read_binary_file(input_name, input_data, input_size)) {
    return 1;
  }
  if (!read_binary_file(weight_name, weight_data, weight_size)) {
    return 1;
  }

  int M = 2035;
  int N = 48;
  int K = 3584;
  size_t element_size = sizeof(cutlass_t);
  size_t input_elements = M * K;
  size_t weight_elements = K * N;
  size_t output_elements = M * N;

  if (input_size != input_elements * element_size) {
    std::cerr << "\'input\' file size does not match expected tensor size"
              << std::endl;
    return 1;
  }
  if (weight_size != weight_elements * element_size) {
    std::cerr << "\'weight\' file size does not match expected tensor size"
              << std::endl;
    return 1;
  }

  cutlass_t *input_tensor;
  cutlass_t *weight_tensor;
  void *buffer;
  cutlass_t *output_tensor;
  cudaMalloc(&input_tensor, input_elements * element_size);
  cudaMalloc(&weight_tensor, weight_elements * element_size);
  cudaMalloc(&buffer, sgmv_tmp_size(1));
  cudaMalloc(&output_tensor, output_elements * element_size);
  void *output_data = malloc(output_elements * element_size);

  cudaMemcpyAsync(input_tensor, input_data, input_elements * element_size,
                  cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(weight_tensor, weight_data, weight_elements * element_size,
                  cudaMemcpyHostToDevice, stream);

  try {
    SgmvCutlass(output_tensor, input_tensor, weight_tensor, buffer, 1, stream);
  } catch (const std::exception &e) {
    std::cerr << "Caught an exception: " << e.what() << std::endl;
  }

  cudaMemcpyAsync(output_data, output_tensor, output_elements * element_size,
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream); // ensure the copy has finished before writing the file
  write_binary_file(output_name, output_data, output_elements * element_size);

  return 0;
}

The binary tensor files 'input.bin' and 'weight.bin' can be generated with NumPy.
Simply compile this test case with nvcc; the executable produces a binary tensor file 'output.bin'.
I found that the values in 'output.bin' do not match the values computed directly with numpy.matmul().
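
If NumPy is not at hand, a small standalone C++ generator along these lines (a sketch; the uniform distribution and fixed seed are arbitrary choices) can produce the two input files in the row-major half-precision layout the test case expects:

// Sketch: write random fp16 tensors to 'input.bin' (M x K) and 'weight.bin'
// (K x N), row-major, matching the sizes used in the test case above.
#include <cstdio>
#include <fstream>
#include <random>
#include <vector>

#include "cutlass/numeric_types.h"

static void write_random_half(const char *name, size_t count, std::mt19937 &rng) {
  std::uniform_real_distribution<float> dist(-1.f, 1.f);
  std::vector<cutlass::half_t> data(count);
  for (auto &v : data) {
    v = cutlass::half_t(dist(rng));
  }
  std::ofstream file(name, std::ios::binary);
  file.write(reinterpret_cast<const char *>(data.data()),
             data.size() * sizeof(cutlass::half_t));
}

int main() {
  int M = 2035, N = 48, K = 3584;
  std::mt19937 rng(0);
  write_random_half("input.bin", size_t(M) * K, rng);
  write_random_half("weight.bin", size_t(K) * N, rng);
  std::printf("wrote input.bin and weight.bin\n");
  return 0;
}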

hwu36 (Collaborator) commented Nov 7, 2024

Your align1 can be the problem. Since you are doing A:row x B:row -> C:row, your leading dimensions are A: K = 3584, B: N = 48, C: N = 48, so you can just use alignment = 8. Also, your tile sizes are not common ones; you could start from this:

using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass_t,                                      // Element A
    cutlass::layout::RowMajor,                      // Layout A
    cutlass::ComplexTransform::kNone,               //
    8,                                              // Granularity A
    cutlass_t,                                      // Element B
    cutlass::layout::RowMajor,                      // Layout B
    cutlass::ComplexTransform::kNone,               //
    8,                                              // Granularity B
    cutlass_t,                                      // Element C&D
    cutlass::layout::RowMajor,                      // Layout C&D
    float,                                          // Element Accumulator
    cutlass::arch::OpClassTensorOp,                 // Operator Class Tag
    cutlass::arch::Sm80,                            // Architecture
    cutlass::gemm::GemmShape<64, 64, 64>,           // Thread Block Shape
    cutlass::gemm::GemmShape<32, 32, 64>,           // Warp Shape
    cutlass::gemm::GemmShape<16, 8, 16>,            // Instruction Shape
    LinearCombination<cutlass_t, 1, float, float>,  // Epilogue
    GemmIdentityThreadblockSwizzle<>,              // Swizzling Operator
    6                                               // Stages
    >::GemmKernel;

If it works, then switch the tile size to smaller ones like what you tried.
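
To see why alignment = 8 is legal here: a 128-bit access on half_t covers 8 elements, and every leading dimension in this problem (3584, 48, 48) is a multiple of 8. A small sketch (max_alignment is a hypothetical helper, not a CUTLASS API):

// Hypothetical helper: largest element alignment, capped at 128 bits per
// access for half_t (8 elements), that evenly divides a leading dimension.
#include <cstdint>
#include <cstdio>

int max_alignment(int64_t ld) {
  const int candidates[] = {8, 4, 2, 1};
  for (int a : candidates) {
    if (ld % a == 0) {
      return a;
    }
  }
  return 1;
}

int main() {
  // Row-major A (lda = K = 3584), B (ldb = N = 48), C/D (ldc = N = 48).
  std::printf("A: %d  B: %d  C/D: %d\n",
              max_alignment(3584), max_alignment(48), max_alignment(48));
  // Prints "A: 8  B: 8  C/D: 8", so kAlignmentA/B = 8 is valid for this problem.
  return 0;
}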

WangNorthSea (Author) commented Nov 8, 2024

@hwu36
The thread block and warp shapes you provided helped me get the correct result.
I also tried this combination:

cutlass::gemm::GemmShape<32, 64, 64>,          // Thread Block Shape
cutlass::gemm::GemmShape<16, 32, 64>,          // Warp Shape
cutlass::gemm::GemmShape<16, 8, 16>,           // Instruction Shape

It also works, but I can't make the tile sizes any smaller.
For example:

cutlass::gemm::GemmShape<32, 32, 64>,          // Thread Block Shape
cutlass::gemm::GemmShape<16, 16, 64>,          // Warp Shape
cutlass::gemm::GemmShape<16, 8, 16>,           // Instruction Shape

This causes an incorrect result: the 16 columns in the middle of the output matrix are wrong.

Anyway, thanks very much for your help!
