[ROCm] Fix some kernels failed unit tests #2498

Merged
10 commits merged on Feb 5, 2024
18 changes: 18 additions & 0 deletions tests/kernels/allclose_default.py
@@ -0,0 +1,18 @@
import torch

# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {
    torch.float16: 1e-3,
    torch.bfloat16: 1.6e-2,
    torch.float: 1.3e-6
}


def get_default_atol(output) -> float:
    return default_atol[output.dtype]


def get_default_rtol(output) -> float:
    return default_rtol[output.dtype]
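For context, here is a minimal sketch of how these helpers are meant to be used in a test assertion. The tensors below are made up for illustration and are not part of the PR:

import torch
from allclose_default import get_default_atol, get_default_rtol

# Hypothetical kernel output and reference pair.
out = torch.tensor([1.0002, 2.0001], dtype=torch.bfloat16)
ref = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)

# Tolerances are looked up by the output dtype, so bfloat16 gets the
# looser rtol (1.6e-2) while float32 keeps the tight defaults.
assert torch.allclose(out,
                      ref,
                      atol=get_default_atol(out),
                      rtol=get_default_rtol(out))

This lets every kernel test read the same tolerance table instead of hard-coding per-dtype constants.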
16 changes: 13 additions & 3 deletions tests/kernels/test_activation.py
@@ -2,6 +2,7 @@
import torch

from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
from allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
@@ -33,7 +34,10 @@ def test_silu_and_mul(
    layer = SiluAndMul()
    out = layer(x)
    ref_out = layer._forward(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out,
                          ref_out,
                          atol=get_default_atol(out),
                          rtol=get_default_rtol(out))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -57,7 +61,10 @@ def test_gelu_new(
    layer = NewGELU()
    out = layer(x)
    ref_out = layer._forward(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out,
                          ref_out,
                          atol=get_default_atol(out),
                          rtol=get_default_rtol(out))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -80,4 +87,7 @@ def test_gelu_fast(
    layer = FastGELU()
    out = layer(x)
    ref_out = layer._forward(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out,
                          ref_out,
                          atol=get_default_atol(out),
                          rtol=get_default_rtol(out))
22 changes: 17 additions & 5 deletions tests/kernels/test_attention.py
@@ -8,6 +8,8 @@

from vllm._C import ops, cache_ops
from vllm.utils import get_max_shared_memory_bytes
from vllm.utils import is_hip
from allclose_default import get_default_atol, get_default_rtol

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
@@ -17,12 +19,18 @@
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
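The ROCm-aware DTYPES and HEAD_SIZES lists above feed directly into pytest parametrization, so unsupported combinations are never generated rather than failing at run time. Here is a self-contained sketch of the same pattern, with a hypothetical test body that only illustrates the constraint:

import pytest

from vllm.utils import is_hip

# Mirrors the constraint above: the ROCm flash-attention backend supports
# head sizes up to 128, while CUDA builds also cover 256.
HEAD_SIZES = [64, 80, 96, 112, 128, 256] if not is_hip() else [64, 80, 96, 112, 128]


@pytest.mark.parametrize("head_size", HEAD_SIZES)
def test_head_size_within_backend_limit(head_size: int) -> None:
    # Illustrative assertion only; the real tests exercise the attention kernels.
    assert head_size <= (128 if is_hip() else 256)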
@@ -251,9 +259,11 @@ def test_paged_attention(
    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8_e5m2":
        atol, rtol = 1e-2, 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
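The same is_hip()-dependent tolerance selection appears again in test_multi_query_kv_attention below. As a purely hypothetical sketch (the helper name attention_tolerances is invented here and is not part of the PR), the pattern could be factored out like this:

from typing import Tuple

import torch

from vllm.utils import is_hip
from allclose_default import get_default_atol, get_default_rtol


def attention_tolerances(output: torch.Tensor,
                         kv_cache_dtype: str = "auto") -> Tuple[float, float]:
    # ROCm kernels use the dtype-aware defaults; CUDA keeps the old constants.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    # FP8 KV cache adds quantization error, so the tolerance is loosened further.
    if kv_cache_dtype == "fp8_e5m2":
        atol, rtol = 1e-2, 1e-5
    return atol, rtol

A test would then write atol, rtol = attention_tolerances(output, kv_cache_dtype) right before its allclose assertion.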
@@ -357,4 +367,6 @@ def test_multi_query_kv_attention(
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
6 changes: 5 additions & 1 deletion tests/kernels/test_cache.py
@@ -6,6 +6,7 @@
from typing import Tuple

from vllm._C import cache_ops
from vllm.utils import is_hip

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -14,7 +15,10 @@
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
# Reduce the size for the ROCm tests to avoid HIP OOM.
NUM_BLOCKS = [1024, 36000] if not is_hip() else [
    1024, 10000
]  # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
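A side note on the NUM_BLOCKS guard above: is_hip has to be called, not merely referenced, or the truthiness of the function object decides the branch. A tiny self-contained illustration, using a hypothetical stub in place of vllm.utils.is_hip:

def is_hip() -> bool:
    # Stand-in for vllm.utils.is_hip; pretend this is a CUDA build.
    return False


# A bare function object is always truthy, so `not is_hip` is always False
# and the ROCm-sized list would be chosen unconditionally:
assert ([1024, 36000] if not is_hip else [1024, 10000]) == [1024, 10000]

# Calling the function evaluates the actual platform check:
assert ([1024, 36000] if not is_hip() else [1024, 10000]) == [1024, 36000]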
12 changes: 9 additions & 3 deletions tests/kernels/test_pos_encoding.py
@@ -2,7 +2,7 @@

import pytest
import torch

from allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope

IS_NEOX_STYLE = [True, False]
@@ -64,5 +64,11 @@ def test_rotary_embedding(
    ref_query, ref_key = rope._forward(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_query,
                          ref_query,
                          atol=get_default_atol(out_query),
                          rtol=get_default_rtol(out_query))
    assert torch.allclose(out_key,
                          ref_key,
                          atol=get_default_atol(out_key),
                          rtol=get_default_rtol(out_key))