[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
sroy745 authored Nov 2, 2024
1 parent d522034 commit a78dd33
Showing 11 changed files with 716 additions and 317 deletions.
88 changes: 51 additions & 37 deletions tests/encoder_decoder/test_e2e_correctness.py
@@ -7,12 +7,18 @@
 import pytest
 from transformers import AutoModelForSeq2SeqLM
 
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
 from ..conftest import DecoderPromptType
 from ..models.utils import check_logprobs_close
 
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [
+    _Backend.XFORMERS, _Backend.FLASH_ATTN, None
+]
+
 
 def vllm_to_hf_output(
     vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
@@ -29,7 +35,8 @@ def vllm_to_hf_output(
 
 
 @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
     num_logprobs: int,
     decoder_prompt_type: DecoderPromptType,
     enforce_eager: bool,
+    attn_backend: _Backend,
 ) -> None:
     '''
     End-to-End (E2E) test for the encoder-decoder framework.
@@ -56,43 +64,49 @@
     implementations to ensure that both implementations produce consistent
     and correct results.
     '''
-    test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
+    with global_force_attn_backend_context_manager(attn_backend):
+        if attn_backend == _Backend.FLASH_ATTN:
+            # Flash Attention works only with bfloat16 data-type
+            dtype = 'bfloat16'
+        test_case_prompts = example_encoder_decoder_prompts[
+            decoder_prompt_type]
 
-    # Configuration settings for HF baseline
-    hf_kwargs = {
-        "top_k": None,
-        "num_beams": 1,
-        "repetition_penalty": 1.0,
-        "top_p": 1.0,
-        "length_penalty": 1.0,
-        "early_stopping": False,
-        "no_repeat_ngram_size": None,
-        "min_length": 0
-    }
+        # Configuration settings for HF baseline
+        hf_kwargs = {
+            "top_k": None,
+            "num_beams": 1,
+            "repetition_penalty": 1.0,
+            "top_p": 1.0,
+            "length_penalty": 1.0,
+            "early_stopping": False,
+            "no_repeat_ngram_size": None,
+            "min_length": 0
+        }
 
-    with hf_runner(model, dtype=dtype,
-                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
-        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
-            test_case_prompts,
-            max_tokens,
-            num_logprobs,
-            **hf_kwargs,
-        ))
-    with vllm_runner(model, dtype=dtype,
-                     enforce_eager=enforce_eager) as vllm_model:
-        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
-            test_case_prompts, max_tokens, num_logprobs)
+        with hf_runner(model, dtype=dtype,
+                       auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+            hf_outputs = (
+                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+                    test_case_prompts,
+                    max_tokens,
+                    num_logprobs,
+                    **hf_kwargs,
+                ))
+        with vllm_runner(model, dtype=dtype,
+                         enforce_eager=enforce_eager) as vllm_model:
+            vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+                test_case_prompts, max_tokens, num_logprobs)
 
-    hf_skip_tokens = (1
-                      if decoder_prompt_type == DecoderPromptType.NONE else 0)
+        hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
+                          else 0)
 
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=[
-            vllm_to_hf_output(vllm_output, decoder_prompt_type)
-            for vllm_output in vllm_outputs
-        ],
-        name_0="hf",
-        name_1="vllm",
-        num_outputs_0_skip_tokens=hf_skip_tokens,
-    )
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, decoder_prompt_type)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+            num_outputs_0_skip_tokens=hf_skip_tokens,
+        )
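The pattern this test introduces, forcing a specific attention backend around model construction and generation, can also be used outside the test harness. A minimal sketch follows, using the public vLLM LLM entry point; the model name, prompt, dtype, and sampling settings are illustrative assumptions and are not part of this commit.

from vllm import LLM, SamplingParams
from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)

# Force the FLASH_ATTN backend for everything built inside this block,
# mirroring how the test wraps both the HF and vLLM runners.
with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    # FlashAttention needs a 16-bit dtype, hence bfloat16 (as in the test).
    llm = LLM(model="facebook/bart-large-cnn",
              dtype="bfloat16",
              enforce_eager=True)
    outputs = llm.generate(["A short article to summarize."],
                           SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)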
