diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index aaac24a146..079c370873 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -70,10 +70,9 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
 // LayerNorm implementation adapted from ggml, accumulation is made using f32.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
 template <typename T>
-__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
+__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
-    const int block_size = blockDim.x;
 
     float2 mean_var = make_float2(0.f, 0.f);
 
@@ -134,10 +133,9 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,
 // RmsNorm implementation adapted from ggml, accumulation is made using f32.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
 template <typename T>
-__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
+__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
-    const int block_size = blockDim.x;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
@@ -530,15 +528,15 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
 #define RMSNORM_OP(TYPENAME, FN_NAME) \
   extern "C" __global__ void FN_NAME( \
       const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
-      const int n_cols, const float eps) { \
-    rmsnorm(src, dst, alpha, n_cols, eps); \
+      const int n_cols, const int block_size, const float eps) { \
+    rmsnorm(src, dst, alpha, n_cols, block_size, eps); \
   } \
 
 #define LAYERNORM_OP(TYPENAME, FN_NAME) \
   extern "C" __global__ void FN_NAME( \
       const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
-      const TYPENAME *beta, const int n_cols, const float eps) { \
-    layernorm(src, dst, alpha, beta, n_cols, eps); \
+      const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \
+    layernorm(src, dst, alpha, beta, n_cols, block_size, eps); \
   } \
 
 #define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 9a360c472c..8a3c19fe38 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -543,15 +543,23 @@ impl candle::CustomOp2 for RmsNorm {
         let dim_m1 = dims[dims.len() - 1];
         let (n_rows, n_cols) = (el / dim_m1, dim_m1);
 
+        let block_size = if n_cols < 1024 { 32 } else { 1024 };
         let cfg = LaunchConfig {
             grid_dim: (n_rows as u32, 1, 1),
-            block_dim: (1024, 1, 1),
+            block_dim: (block_size, 1, 1),
             shared_mem_bytes: 0,
         };
         let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
         let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
+        let params = (
+            &src,
+            &dst,
+            &alpha,
+            n_cols as i32,
+            block_size as i32,
+            self.eps,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(dst)
@@ -776,15 +784,24 @@ impl candle::CustomOp3 for LayerNorm {
         let dim_m1 = dims[dims.len() - 1];
         let (n_rows, n_cols) = (el / dim_m1, dim_m1);
 
+        let block_size = if n_cols < 1024 { 32 } else { 1024 };
         let cfg = LaunchConfig {
             grid_dim: (n_rows as u32, 1, 1),
-            block_dim: (1024, 1, 1),
+            block_dim: (block_size, 1, 1),
             shared_mem_bytes: 0,
         };
         let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
         let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
+        let params = (
+            &src,
+            &dst,
+            &alpha,
+            &beta,
+            n_cols as i32,
+            block_size as i32,
+            self.eps,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(dst)
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 65a8fbf289..3a8a0bb915 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -77,6 +77,27 @@ fn rms_norm(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn rms_norml(device: &Device) -> Result<()> {
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+
+    let (b_size, seq_len, head_dim) = (24, 70, 64);
+    let el_count = b_size * seq_len * head_dim;
+    let mut rng = StdRng::seed_from_u64(299792458);
+    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
+    let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
+    let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
+    let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
+    let diff = (t - t2)?
+        .abs()?
+        .flatten_all()?
+        .max(0)?
+        .reshape(())?
+        .to_vec0::<f32>()?;
+    assert!(diff < 1e-5);
+    Ok(())
+}
+
 fn layer_norm(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
     let tensor = Tensor::new(data, device)?;
@@ -103,6 +124,28 @@ fn layer_norm(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn layer_norml(device: &Device) -> Result<()> {
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+
+    let (b_size, seq_len, head_dim) = (24, 70, 64);
+    let el_count = b_size * seq_len * head_dim;
+    let mut rng = StdRng::seed_from_u64(299792458);
+    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
+    let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
+    let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?;
+    let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
+    let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
+    let diff = (t - t2)?
+        .abs()?
+        .flatten_all()?
+        .max(0)?
+        .reshape(())?
+        .to_vec0::<f32>()?;
+    assert!(diff < 1e-5);
+    Ok(())
+}
+
 #[test]
 fn softmax_numerical_stability() -> Result<()> {
     let dev = &Device::Cpu;
@@ -211,5 +254,7 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
 test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
 test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
 test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
+test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
 test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
+test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
 test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
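To make the effect of the explicit `block_size` parameter concrete, below is a minimal, self-contained CUDA sketch of the rmsnorm reduction pattern used in this diff: one block per row, threads striding over columns, warp shuffles for the partial sums, and the cross-warp shared-memory stage taken only when `block_size` exceeds one warp. This is an illustration, not candle's kernel: the names `rmsnorm_sketch` and `WARP_SIZE`, the float-only specialization, and the toy driver in `main` are assumptions for the sketch; the real kernel is templated over `TYPENAME` and launched from Rust via `func.launch(cfg, params)` as shown above.

// Sketch under stated assumptions; mirrors the ggml-derived reduction, f32 only.
#include <cuda_runtime.h>
#include <stdio.h>

#define WARP_SIZE 32

// Butterfly reduction: every lane ends up holding the warp-wide sum.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    x += __shfl_xor_sync(0xffffffffu, x, offset);
  }
  return x;
}

// One block per row; each thread strides over the columns by block_size.
__global__ void rmsnorm_sketch(const float *x, float *dst, const float *alpha,
                               const int ncols, const int block_size,
                               const float eps) {
  const int row = blockIdx.x * blockDim.y + threadIdx.y;
  const int tid = threadIdx.x;

  float tmp = 0.0f; // partial sum of squares for this thread
  for (int col = tid; col < ncols; col += block_size) {
    const float v = x[row * ncols + col];
    tmp += v * v;
  }

  // Intra-warp reduction. When block_size == WARP_SIZE (the n_cols < 1024
  // case picked by the host code), this is the entire reduction: no shared
  // memory and no __syncthreads.
  tmp = warp_reduce_sum(tmp);
  if (block_size > WARP_SIZE) {
    // Cross-warp stage, only needed for 1024-thread blocks.
    __shared__ float s_sum[32];
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;
    if (lane_id == 0) s_sum[warp_id] = tmp;
    __syncthreads();
    tmp = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f;
    tmp = warp_reduce_sum(tmp);
  }

  const float scale = rsqrtf(tmp / ncols + eps);
  for (int col = tid; col < ncols; col += block_size) {
    dst[row * ncols + col] = scale * alpha[col] * x[row * ncols + col];
  }
}

int main() {
  const int nrows = 4, ncols = 64;
  // Same heuristic as the Rust side: a single warp for short rows.
  const int block_size = ncols < 1024 ? 32 : 1024;

  float hx[nrows * ncols], halpha[ncols], hdst[nrows * ncols];
  for (int i = 0; i < nrows * ncols; ++i) hx[i] = (float)(i % 7) - 3.0f;
  for (int i = 0; i < ncols; ++i) halpha[i] = 1.0f;

  float *dx, *ddst, *dalpha;
  cudaMalloc(&dx, sizeof hx);
  cudaMalloc(&ddst, sizeof hdst);
  cudaMalloc(&dalpha, sizeof halpha);
  cudaMemcpy(dx, hx, sizeof hx, cudaMemcpyHostToDevice);
  cudaMemcpy(dalpha, halpha, sizeof halpha, cudaMemcpyHostToDevice);

  rmsnorm_sketch<<<nrows, block_size>>>(dx, ddst, dalpha, ncols, block_size, 1e-5f);
  cudaMemcpy(hdst, ddst, sizeof hdst, cudaMemcpyDeviceToHost);
  printf("dst[0..3] = %f %f %f %f\n", hdst[0], hdst[1], hdst[2], hdst[3]);

  cudaFree(dx);
  cudaFree(ddst);
  cudaFree(dalpha);
  return 0;
}

The point of the `if n_cols < 1024 { 32 } else { 1024 }` heuristic in ops.rs is visible in the `block_size > WARP_SIZE` branch: with a 32-thread block, a short row is reduced entirely in registers, whereas a fixed 1024-thread block would leave most threads idle on the column loop and still pay for the shared-memory round trip and `__syncthreads`. The `rms_norml`/`layer_norml` tests with `head_dim = 64` exercise exactly this small-row path against the slow reference implementations.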