Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU][PHI Kernels] refine bf16 test for fused_rope #60439

Merged
merged 2 commits into from
Dec 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 71 additions & 23 deletions test/xpu/test_fused_rotary_position_embedding_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,17 @@ def get_forward_backward(
fw.append(out_q)
fw.append(out_k)
fw.append(out_v)
paddle.seed(seed + 1)
out_gq = paddle.randn(out_q.shape, self.dtype)
out_gk = paddle.randn(out_q.shape, self.dtype)
out_gv = paddle.randn(out_q.shape, self.dtype)

paddle.autograd.backward(
[out_q, out_k, out_v], [out_gq, out_gk, out_gv], True
)
bw.append(tensor_q)
bw.append(tensor_k)
bw.append(tensor_v)
bw.append(tensor_q.grad)
bw.append(tensor_k.grad)
bw.append(tensor_v.grad)

return fw, bw

Expand Down Expand Up @@ -368,42 +369,89 @@ def setUp(self):
self.shape = [2, 8, 2, 16]

def test_api(self):
q_fp32 = paddle.rand(self.shape, dtype="float32")
k_fp32 = paddle.rand(self.shape, dtype="float32")
v_fp32 = paddle.rand(self.shape, dtype="float32")
sin_fp32 = paddle.rand(
[1, self.shape[1], 1, self.shape[3]], dtype="float32"
paddle.disable_static()
q_bf16 = paddle.randn(self.shape, dtype="bfloat16")
k_bf16 = paddle.randn(self.shape, dtype="bfloat16")
v_bf16 = paddle.randn(self.shape, dtype="bfloat16")
sin_bf16 = paddle.randn(
[1, self.shape[1], 1, self.shape[3]], dtype="bfloat16"
)
cos_fp32 = paddle.rand(
[1, self.shape[1], 1, self.shape[3]], dtype="float32"
cos_bf16 = paddle.randn(
[1, self.shape[1], 1, self.shape[3]], dtype="bfloat16"
)
q_bf16 = paddle.to_tensor(q_fp32, dtype="bfloat16")
k_bf16 = paddle.to_tensor(k_fp32, dtype="bfloat16")
v_bf16 = paddle.to_tensor(v_fp32, dtype="bfloat16")
sin_bf16 = paddle.to_tensor(sin_fp32, dtype="bfloat16")
cos_bf16 = paddle.to_tensor(cos_fp32, dtype="bfloat16")

out_fp32 = fused_rotary_position_embedding(
q_fp32,
k_fp32,
v_fp32,
sin_fp32,
cos_fp32,
use_neox_rotary_style=False,
q_bf16.stop_gradient = False
k_bf16.stop_gradient = False
v_bf16.stop_gradient = False
q_fp32 = paddle.to_tensor(q_bf16, dtype="float32", stop_gradient=False)
k_fp32 = paddle.to_tensor(k_bf16, dtype="float32", stop_gradient=False)
v_fp32 = paddle.to_tensor(v_bf16, dtype="float32", stop_gradient=False)
sin_fp32 = paddle.to_tensor(sin_bf16, dtype="float32")
cos_fp32 = paddle.to_tensor(cos_bf16, dtype="float32")

position_ids = paddle.arange(0, self.shape[1], dtype="int64")
position_ids = paddle.stack(
[position_ids for _ in range(self.shape[0])], axis=0
)
out_bf16 = fused_rotary_position_embedding(
q_bf16,
k_bf16,
v_bf16,
sin_bf16,
cos_bf16,
position_ids=position_ids,
use_neox_rotary_style=False,
)

grad_out_q_bf16 = paddle.randn(self.shape, dtype="bfloat16")
grad_out_k_bf16 = paddle.randn(self.shape, dtype="bfloat16")
grad_out_v_bf16 = paddle.randn(self.shape, dtype="bfloat16")

paddle.autograd.backward(
out_bf16, [grad_out_q_bf16, grad_out_k_bf16, grad_out_v_bf16], True
)
grad_bf16 = [q_bf16.grad, k_bf16.grad, v_bf16.grad]

out_fp32 = paddle_fused_rotary_position_embedding(
q_fp32,
k_fp32,
v_fp32,
sin_fp32,
cos_fp32,
position_ids=position_ids,
use_neox_rotary_style=False,
)

grad_out_q_fp32 = paddle.to_tensor(grad_out_q_bf16, dtype="float32")
grad_out_k_fp32 = paddle.to_tensor(grad_out_k_bf16, dtype="float32")
grad_out_v_fp32 = paddle.to_tensor(grad_out_v_bf16, dtype="float32")
paddle.autograd.backward(
out_fp32, [grad_out_q_fp32, grad_out_k_fp32, grad_out_v_fp32], True
)
grad_fp32 = [q_fp32.grad, k_fp32.grad, v_fp32.grad]

for fp32_val, bf16_val in zip(out_fp32, out_bf16):
bf16_val = convert_uint16_to_float(bf16_val.numpy())
np.testing.assert_allclose(
fp32_val.numpy(), bf16_val, rtol=1e-2, atol=1e-2
)
for grad_fp32_val, grad_bf16_val in zip(grad_fp32, grad_bf16):
grad_bf16_val = convert_uint16_to_float(grad_bf16_val.numpy())
np.testing.assert_allclose(
grad_fp32_val.numpy(), grad_bf16_val, rtol=1e-2, atol=1e-2
)


class XPUTestFusedRotaryPositionEmbeddingBf16_2(
XPUTestFusedRotaryPositionEmbeddingBf16_1
):
def setUp(self):
self.shape = [2, 2048, 16, 128]


# too long for CI
# class XPUTestFusedRotaryPositionEmbeddingBf16_3(XPUTestFusedRotaryPositionEmbeddingBf16_1):
# def setUp(self):
# self.shape = [2, 8192, 8, 128]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLaMA 65B的规模耗时过长(超过两分钟),CI里先暂时不跑,本地能通过



if __name__ == '__main__':
Expand Down