[ROCm] Fix some failing kernel unit tests #2498

Merged · 10 commits · Feb 5, 2024
18 changes: 18 additions & 0 deletions tests/kernels/allclose_default.py
@@ -0,0 +1,18 @@
import torch

# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {
torch.float16: 1e-3,
torch.bfloat16: 1.6e-2,
torch.float: 1.3e-6
}


def get_default_atol(output) -> float:
return default_atol[output.dtype]


def get_default_rtol(output) -> float:
return default_rtol[output.dtype]
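
The helpers above look up the tolerance from the dtype of the kernel output, so fp16/bf16 comparisons get looser thresholds while fp32 keeps the strict defaults. A minimal usage sketch (illustrative only, not part of the diff; it assumes it is run from tests/kernels so allclose_default is importable):

```python
import torch

from allclose_default import get_default_atol, get_default_rtol


def assert_close(out: torch.Tensor, ref: torch.Tensor) -> None:
    # Tolerances are chosen from the dtype of the kernel output.
    assert torch.allclose(out,
                          ref,
                          atol=get_default_atol(out),
                          rtol=get_default_rtol(out))


# A float16 comparison now tolerates ~1e-3 absolute error,
# while a float32 comparison still requires ~1e-5.
x = torch.randn(8, dtype=torch.float16)
assert_close(x, x + 5e-4)
```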
16 changes: 13 additions & 3 deletions tests/kernels/test_activation.py
@@ -2,6 +2,7 @@
import torch

from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
from allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
@@ -30,7 +31,10 @@ def test_silu_and_mul(
layer = SiluAndMul()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -53,7 +57,10 @@ def test_gelu_new(
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -75,4 +82,7 @@ def test_gelu_fast(
layer = FastGELU()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
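
Each activation test compares the layer's custom-kernel path (`layer(x)`, i.e. `forward`) against a pure-PyTorch reference (`layer._forward(x)`). A simplified sketch of that pattern with a stand-in layer (the class below is hypothetical; the real `SiluAndMul` dispatches `forward` to a fused CUDA/HIP kernel):

```python
import torch
import torch.nn as nn


class SiluAndMulSketch(nn.Module):
    """Illustrative stand-in for the forward/_forward test pattern."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real layer would call the fused kernel here; this sketch
        # just reuses the reference path so it runs anywhere.
        return self._forward(x)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reference implementation: SiLU on the first half of the last
        # dim, multiplied elementwise by the second half.
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


layer = SiluAndMulSketch()
x = torch.randn(7, 2 * 128, dtype=torch.float16)
out = layer(x)
ref_out = layer._forward(x)
# Identical here; real kernel outputs need the dtype-based tolerances.
assert torch.allclose(out, ref_out)
```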
24 changes: 19 additions & 5 deletions tests/kernels/test_attention.py
@@ -8,19 +8,27 @@

from vllm._C import ops
from vllm.utils import get_max_shared_memory_bytes
from vllm.utils import is_hip
from allclose_default import get_default_atol, get_default_rtol

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 12000 # Arbitrary values for testing
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 256
] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
@@ -230,7 +238,10 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
assert torch.allclose(output,
ref_output,
atol=get_default_atol(output),
rtol=get_default_rtol(output))


def ref_multi_query_kv_attention(
@@ -331,4 +342,7 @@ def test_multi_query_kv_attention(
scale,
dtype,
)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
assert torch.allclose(output,
ref_output,
atol=get_default_atol(output),
rtol=get_default_rtol(output))
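
The dtype and head-size grids above are trimmed on ROCm because the ROCm flash-attention build supports only fp16/bf16 and head dimensions up to 128. A hedged sketch of the same gating pattern on a trivial test (it assumes vLLM's `is_hip` helper, as imported in the diff; the test body is a placeholder):

```python
import pytest
import torch

from vllm.utils import is_hip

# Trim the parameter grid on ROCm, mirroring the diff above.
DTYPES = ([torch.half, torch.bfloat16, torch.float]
          if not is_hip() else [torch.half, torch.bfloat16])
HEAD_SIZES = ([64, 80, 96, 112, 128, 256]
              if not is_hip() else [64, 80, 96, 112, 128])


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
def test_grid_sketch(dtype: torch.dtype, head_size: int) -> None:
    # Placeholder body; the real tests launch the attention kernels here.
    assert head_size <= 256
```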
6 changes: 5 additions & 1 deletion tests/kernels/test_cache.py
@@ -4,14 +4,18 @@
import torch

from vllm._C import cache_ops
from vllm.utils import is_hip

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
# reduce the size for ROCm test to avoid HIP OOM
NUM_BLOCKS = [1024, 36000] if not is_hip() else [
    1024, 10000
]  # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
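
A back-of-the-envelope estimate (not taken from the PR) of why the larger `NUM_BLOCKS` value can exhaust memory on a ROCm CI machine: the cache tests allocate both a key cache and a value cache, and the largest parameter combination scales linearly with the block count.

```python
def kv_cache_bytes(num_blocks: int,
                   block_size: int = 32,
                   num_heads: int = 8,
                   head_size: int = 256,
                   dtype_bytes: int = 4) -> int:
    # Key cache + value cache; the other dimensions use the largest
    # values from the test constants above (float32, head size 256,
    # block size 32). Actual layout details do not change the total.
    return 2 * num_blocks * block_size * num_heads * head_size * dtype_bytes


print(f"{kv_cache_bytes(36000) / 2**30:.1f} GiB")  # ~17.6 GiB
print(f"{kv_cache_bytes(10000) / 2**30:.1f} GiB")  # ~4.9 GiB
```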
12 changes: 9 additions & 3 deletions tests/kernels/test_pos_encoding.py
@@ -2,7 +2,7 @@

import pytest
import torch

from allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope

IS_NEOX_STYLE = [True, False]
@@ -64,5 +64,11 @@ def test_rotary_embedding(
ref_query, ref_key = rope._forward(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
assert torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))