Forked from vllm-project/vllm
Commit 346f1e3: Clean up kernel unit tests (vllm-project#938)
Parent: 04a531d
Showing 6 changed files with 364 additions and 399 deletions.
New file (+43 lines), a shared fixture module for the kernel tests: the `kv_cache_factory` fixture hands each test a factory that allocates randomly initialized key/value caches in vLLM's paged layout.

```python
from typing import List, Tuple

import pytest
import torch


def create_kv_caches(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = head_size**-0.5
    # The key cache packs the head dimension in 16-byte chunks:
    # x elements of `dtype` per chunk.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=dtype,
                                device='cuda')
        # Fill with uniform values in [-1/sqrt(head_size), 1/sqrt(head_size)].
        key_cache.uniform_(-scale, scale)
        key_caches.append(key_cache)

    # The value cache keeps the head dimension contiguous within each block.
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=dtype,
                                  device='cuda')
        value_cache.uniform_(-scale, scale)
        value_caches.append(value_cache)
    return key_caches, value_caches


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches
```
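For context, a minimal usage sketch (not part of the commit; the test name and shape values are hypothetical, and a CUDA device is assumed). A test requests the fixture and calls the factory it returns; the resulting shapes follow from the 16-byte packing above, since float16 is 2 bytes and so x = 16 // 2 = 8:

```python
import torch

# Hypothetical test using the fixture; requires CUDA.
def test_kv_cache_shapes(kv_cache_factory):
    key_caches, value_caches = kv_cache_factory(
        num_blocks=128, block_size=16, num_layers=2, num_heads=8,
        head_size=64, dtype=torch.half, seed=0)
    assert len(key_caches) == len(value_caches) == 2
    # With head_size=64 and x=8: head_size // x = 8.
    assert key_caches[0].shape == (128, 8, 8, 16, 8)
    assert value_caches[0].shape == (128, 8, 64, 16)
```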
Modified file, the activation kernel tests: the hand-rolled run_*/test_* pairs with nested loops and print statements are replaced by @pytest.mark.parametrize grids over module-level constants, and every test now seeds the RNGs explicitly.

```diff
@@ -1,72 +1,75 @@
 import pytest
 import torch
 import torch.nn.functional as F
 from transformers.activations import get_activation

 from vllm import activation_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
+SEEDS = [0]
+

 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_silu_and_mul(
+def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.silu_and_mul(out, x)
     ref_out = ref_silu_and_mul(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-def test_silu_and_mul() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_silu_and_mul(num_tokens, d, dtype)
-
-
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_gelu_new(
+def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_new(out, x)
     ref_out = get_activation("gelu_new")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-def test_gelu_new() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_new(num_tokens, d, dtype)
-
-
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_gelu_fast(
+def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
-
-
-def test_gelu_fast() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_fast(num_tokens, d, dtype)
```
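One note on the mechanics, as an illustration rather than code from the commit: stacked parametrize decorators multiply out, so each kernel still covers exactly the grid the deleted loops walked (3 token counts * 4 widths * 3 dtypes * 1 seed = 36 cases), but pytest now collects, reports, and can rerun each combination independently. A self-contained sketch, with string stand-ins for the dtypes so it runs without CUDA:

```python
import itertools

import pytest

# The full cross-product the deleted nested loops used to enumerate.
GRID = list(itertools.product([7, 83, 2048],                  # num_tokens
                              [512, 4096, 5120, 13824],       # d
                              ["half", "bfloat16", "float"],  # dtype stand-ins
                              [0]))                           # seed


@pytest.mark.parametrize("num_tokens", [7, 83, 2048])
@pytest.mark.parametrize("d", [512, 4096, 5120, 13824])
@pytest.mark.parametrize("dtype", ["half", "bfloat16", "float"])
@pytest.mark.parametrize("seed", [0])
def test_grid_membership(num_tokens, d, dtype, seed):
    # Each of the 3 * 4 * 3 * 1 = 36 combinations runs as its own test.
    assert (num_tokens, d, dtype, seed) in GRID
```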