test_bart.py (from a fork of vllm-project/vllm; the repository was archived by its owner on Oct 11, 2024 and is now read-only)
"""Compare the outputs of HF and vLLM for BART models using greedy sampling.
Run `pytest tests/models/test_bart.py`.
"""
import pytest
from tests.kernels.utils import override_backend_env_variable
from vllm.utils import STR_XFORMERS_ATTN_VAL
from .utils import check_logprobs_close
MODELS = ["facebook/bart-large-cnn"]
# Backends under test
#
# Currently only XFormers is supported
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
def test_models(
hf_runner,
vllm_runner,
example_encoder_decoder_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
backend_name: str,
monkeypatch,
) -> None:
# TODO(sang): Sliding window should be tested separately.
# Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name)
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit(
example_encoder_decoder_prompts, max_tokens, num_logprobs)
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
example_encoder_decoder_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
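
For context, `check_logprobs_close` (imported here from the sibling tests/models utils module) asserts that the two models agree closely enough under greedy decoding. The sketch below illustrates the kind of comparison such a check can perform; it is an assumption-based illustration, not the actual implementation. The `Output` alias and `sketch_check_logprobs_close` are invented names, and the real utility's tolerance rules may differ. The assumed rule: wherever the two token streams diverge, each model's chosen token must still appear in the other model's top-`num_logprobs` candidates for that step.

from typing import Dict, List, Tuple

# Illustrative output shape: (token_ids, generated_text, per-step logprobs),
# where each per-step entry maps candidate token_id -> logprob.
Output = Tuple[List[int], str, List[Dict[int, float]]]


def sketch_check_logprobs_close(outputs_0: List[Output],
                                outputs_1: List[Output]) -> None:
    """Assumption-based sketch of a top-k logprob closeness check."""
    for (ids_0, _, lp_0), (ids_1, _, lp_1) in zip(outputs_0, outputs_1):
        for step, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue  # Exact agreement at this step.
            # On a mismatch, each model's token must still rank within the
            # other model's reported top-k candidates for this step.
            assert tok_0 in lp_1[step] and tok_1 in lp_0[step], (
                f"Models diverged at step {step}: {tok_0} vs {tok_1}")


# Example: identical outputs trivially pass the check.
out = ([1, 2, 3], "abc", [{1: -0.1}, {2: -0.2}, {3: -0.3}])
sketch_check_logprobs_close([out], [out])

As for the backend forcing step, `override_backend_env_variable` presumably sets vLLM's attention-backend environment variable (`VLLM_ATTENTION_BACKEND` in upstream vLLM) through pytest's `monkeypatch` fixture, which is why the test takes `monkeypatch` as a parameter; the exact helper behavior is defined in tests/kernels/utils.py.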