[Model] H2O Danube3-4b (vllm-project#6451)
Signed-off-by: Alvant <[email protected]>
g-eoj authored and Alvant committed Oct 26, 2024
1 parent d0afac3 commit c6c5cc2
Showing 10 changed files with 79 additions and 7 deletions.
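For context, the net effect of this commit is that H2O Danube3-4B (attention head size 120) can run on vLLM's GPU backends. A minimal usage sketch with the offline LLM API, mirroring the greedy, half-precision settings of the new test added below (the prompt is illustrative; a CUDA GPU and access to the h2oai/h2o-danube3-4b-base weights are assumed):

    from vllm import LLM, SamplingParams

    # Load the newly supported model in half precision, as the new test does.
    llm = LLM(model="h2oai/h2o-danube3-4b-base", dtype="half")

    # Greedy decoding for up to 32 tokens, matching the test's max_tokens.
    params = SamplingParams(temperature=0.0, max_tokens=32)

    outputs = llm.generate(["The H2O Danube3 models are"], params)
    print(outputs[0].outputs[0].text)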
2 changes: 1 addition & 1 deletion .buildkite/run-cpu-test.sh
@@ -23,7 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "
pip install pytest Pillow protobuf
-pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
+pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B are not supported on CPU

# online inference
docker exec cpu-test bash -c "
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -175,7 +175,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
-choices=[64, 80, 96, 112, 128, 192, 256],
+choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true")
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_rope.py
@@ -94,7 +94,7 @@ def benchmark_rope_kernels_multi_lora(
parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
-choices=[64, 80, 96, 112, 128, 192, 256],
+choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype",
6 changes: 6 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V1(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
@@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
case 112:
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V2(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
4 changes: 3 additions & 1 deletion tests/kernels/test_attention.py
@@ -28,7 +28,7 @@

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
@@ -134,6 +134,8 @@ def test_paged_attention(
seed: int,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
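The new skip exists because the fp8 KV-cache path only covers head sizes that are multiples of 16 (the same condition enforced by the guard added to vllm/utils.py further down); 120 is the only value in HEAD_SIZES that trips it. A quick illustrative check of that arithmetic:

    # Which of the parametrized head sizes get exercised with an fp8 KV cache.
    HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
    for head_size in HEAD_SIZES:
        status = "fp8 ok" if head_size % 16 == 0 else "skipped for fp8"
        print(f"head_size={head_size}: {status}")
    # Only 120 fails the check: 120 % 16 == 8.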
8 changes: 7 additions & 1 deletion tests/kernels/test_cache.py
@@ -11,7 +11,7 @@
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
@@ -52,6 +52,8 @@ def test_copy_blocks(
kv_cache_dtype: str,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
@@ -124,6 +126,8 @@ def test_reshape_and_cache(
device: str,
kv_cache_dtype: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
@@ -325,6 +329,8 @@ def test_swap_blocks(
) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip()
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
2 changes: 1 addition & 1 deletion tests/kernels/test_pos_encoding.py
@@ -10,7 +10,7 @@

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 17] # Arbitrary values for testing
BATCH_SIZES = [1, 5] # Arbitrary values for testing
52 changes: 52 additions & 0 deletions tests/models/test_danube3_4b.py
@@ -0,0 +1,52 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
This tests danube3 separately because its head size isn't supported on CPU yet.
Run `pytest tests/models/test_danube3_4b.py`.
"""
import pytest

from .utils import check_outputs_equal

MODELS = ["h2oai/h2o-danube3-4b-base"]

target_dtype = "half"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [target_dtype])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
@@ -31,7 +31,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
-return [64, 80, 96, 112, 128, 192, 256]
+return [64, 80, 96, 112, 120, 128, 192, 256]

@staticmethod
def get_kv_cache_shape(
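This Python-side whitelist mirrors the new CUDA switch cases above. A hypothetical pre-flight check against it follows; the 3840-hidden-size / 32-head arithmetic is an assumption about Danube3-4B's configuration, not something stated in this diff:

    from vllm.attention.ops.paged_attn import PagedAttention

    hidden_size, num_heads = 3840, 32  # assumed Danube3-4B configuration
    head_size = hidden_size // num_heads  # 120

    # With this commit applied the check passes; previously 120 was missing.
    assert head_size in PagedAttention.get_supported_head_sizes()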
6 changes: 6 additions & 0 deletions vllm/utils.py
@@ -508,6 +508,12 @@ def create_kv_caches_with_random(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:

if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
f"Does not support key cache of type fp8 with head_size {head_size}"
)

torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
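A hypothetical sketch of the new guard in action; the keyword names other than head_size, cache_dtype, seed, and device are assumed here, since the rest of the signature is outside this hunk. The ValueError fires before any cache tensors are allocated:

    import pytest

    from vllm.utils import create_kv_caches_with_random

    # An fp8 cache with head_size 120 (not a multiple of 16) is rejected up front.
    with pytest.raises(ValueError):
        create_kv_caches_with_random(num_blocks=16, block_size=16, num_layers=1,
                                     num_heads=8, head_size=120, cache_dtype="fp8",
                                     seed=0, device="cuda")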
