Skip to content

Commit

Permalink
[Bugfix] Marlin 2:4 temp fix for large M dim (>256) (#10464)
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <[email protected]>
  • Loading branch information
LucasWilkinson authored Nov 20, 2024
1 parent d5b68ab commit d200972
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
15 changes: 11 additions & 4 deletions csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
// than better compute utilization
thread_k = 128;
thread_m = 128;
} else if (prob_n <= 256) {
} else {
thread_k = 64;
thread_m = 256;
} else {
thread_k = 32;
thread_m = 512;
}
// Also had
// if prob_n > 256
// thread_k = 32;
// thread_m = 512;
// but this is broken,
// TODO(Lucas, Alex M): figure out why
}

int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
Expand Down Expand Up @@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(a.dtype() == torch::kFloat16,
"A is not float16, currently only float16 is supported");

// Verify B device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
Expand All @@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Verify scales device and strides
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
"A is not float16, currently only float16 is supported");

// Alloc C matrix
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
Expand Down
2 changes: 2 additions & 0 deletions tests/kernels/test_marlin_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
(257, 13, 11),
(658, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]
Expand Down

0 comments on commit d200972

Please sign in to comment.