From 97b65c7d43f4ce28366d17b907051b1ef3d9f643 Mon Sep 17 00:00:00 2001 From: tianhaodongbd <137985359+tianhaodongbd@users.noreply.github.com> Date: Mon, 8 Jan 2024 17:12:00 +0800 Subject: [PATCH] fix fused_rope diff (#60217) (#60593) --- paddle/phi/kernels/fusion/gpu/fused_rope_utils.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h index 972f5ee633bbb..0db16ffb7e20b 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h @@ -125,10 +125,18 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel( MPType p0 = static_cast(input[pr_index]); MPType p1 = static_cast(input[ls_index]); - result[pr_index] = - cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1; - result[ls_index] = - cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0; + if (sign == 1) { + result[pr_index] = cos_value[pr_index] * p0; + result[pr_index] -= sin_value[pr_index] * p1; + + result[ls_index] = sin_value[ls_index] * p0; + result[ls_index] += cos_value[ls_index] * p1; + } else if (sign == -1) { + result[pr_index] = + cos_value[pr_index] * p0 + sin_value[ls_index] * p1; + result[ls_index] = + cos_value[ls_index] * p1 - sin_value[pr_index] * p0; + } store[pr_index] = static_cast(result[pr_index]); store[ls_index] = static_cast(result[ls_index]);