From 5b7e0d524f701d7a4023dd8cc92272d2f73f7712 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 16 Nov 2023 17:58:00 +0800 Subject: [PATCH] fix a bug in fused rope (#1750) Signed-off-by: Xin Yao --- csrc/megatron/fused_rotary_positional_embedding.h | 5 ++--- tests/L0/run_transformer/test_fused_rope.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/megatron/fused_rotary_positional_embedding.h b/csrc/megatron/fused_rotary_positional_embedding.h index 7ac13932d..28dca70a5 100644 --- a/csrc/megatron/fused_rotary_positional_embedding.h +++ b/csrc/megatron/fused_rotary_positional_embedding.h @@ -52,8 +52,7 @@ __global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2, int offset_head = offset_block + head_id * hn; #pragma unroll for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - int offset_src_dst = offset_head + hn_id; - dst[offset_src_dst] = src[offset_src_dst]; + dst[offset_head + hn_id] = src[offset_head + hn_id]; } } } @@ -89,7 +88,7 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2, int offset_head = offset_block + head_id * hn; #pragma unroll for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - dst[offset_head + hn_id] = 1.0; + dst[offset_head + hn_id] = src[offset_head + hn_id]; } } } diff --git a/tests/L0/run_transformer/test_fused_rope.py b/tests/L0/run_transformer/test_fused_rope.py index be557054e..5e5167119 100644 --- a/tests/L0/run_transformer/test_fused_rope.py +++ b/tests/L0/run_transformer/test_fused_rope.py @@ -84,13 +84,15 @@ def test_forward_backward(self): # unfused output_unfused = apply_rotary_pos_emb(t, emb) - output_unfused.sum().backward() + loss_unfused = output_unfused.sum() * 2 + loss_unfused.backward() grad_unfused = t.grad.detach().clone() t.grad = None # fused output_fused = fused_apply_rotary_pos_emb(t, emb) - output_fused.sum().backward() + loss_fused = output_fused.sum() * 2 + loss_fused.backward() grad_fused = t.grad.detach().clone() self.assertEqual(