Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models #9559

Merged
merged 104 commits into vllm-project:main from sroy745:sroy-vllm-encdec-flash
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
b3575e9
Merge branch 'vllm-project:main' into main
sroy745 Jun 20, 2024
94b0d43
Merge branch 'vllm-project:main' into main
sroy745 Jun 24, 2024
fa8fedf
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
6ed96b4
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
b71c533
Merge branch 'vllm-project:main' into main
sroy745 Jun 28, 2024
57babef
Merge branch 'vllm-project:main' into main
sroy745 Jun 29, 2024
4b19bac
Merge branch 'vllm-project:main' into main
sroy745 Jul 1, 2024
eb7a1c4
Merge branch 'vllm-project:main' into main
sroy745 Jul 6, 2024
7e2c87e
Merge branch 'vllm-project:main' into main
sroy745 Jul 10, 2024
6212d5f
Merge branch 'vllm-project:main' into main
sroy745 Jul 15, 2024
5491438
Merge branch 'vllm-project:main' into main
sroy745 Jul 17, 2024
68e080a
Merge branch 'vllm-project:main' into main
sroy745 Jul 31, 2024
55e4332
Merge branch 'vllm-project:main' into main
sroy745 Aug 13, 2024
532eb48
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
7cea056
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
185e056
Merge branch 'vllm-project:main' into main
sroy745 Aug 24, 2024
e2be95f
Merge branch 'vllm-project:main' into main
sroy745 Aug 27, 2024
2ed5473
Merge branch 'vllm-project:main' into main
sroy745 Aug 28, 2024
efa4714
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
fb87d34
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
5419e49
Merge branch 'vllm-project:main' into main
sroy745 Aug 31, 2024
9ba12f8
Merge branch 'vllm-project:main' into main
sroy745 Sep 2, 2024
25cef3d
Merge branch 'vllm-project:main' into main
sroy745 Sep 3, 2024
9d4cd09
Merge branch 'vllm-project:main' into main
sroy745 Sep 4, 2024
c48cacb
Merge branch 'vllm-project:main' into main
sroy745 Sep 5, 2024
c42c399
Merge branch 'vllm-project:main' into main
sroy745 Sep 7, 2024
3d13e43
Merge branch 'vllm-project:main' into main
sroy745 Sep 9, 2024
7479775
Merge branch 'vllm-project:main' into main
sroy745 Sep 11, 2024
df9b966
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
9a7ed92
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
118e838
Merge branch 'vllm-project:main' into main
sroy745 Sep 19, 2024
e640c69
Merge branch 'vllm-project:main' into main
sroy745 Sep 20, 2024
89fb6cd
Merge branch 'vllm-project:main' into main
sroy745 Sep 23, 2024
5d886cc
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
56f2065
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
28e103e
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
2fc1490
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
8805750
Merge branch 'vllm-project:main' into main
sroy745 Sep 26, 2024
b30e5af
Merge branch 'vllm-project:main' into main
sroy745 Sep 28, 2024
92322f1
Merge branch 'vllm-project:main' into main
sroy745 Sep 30, 2024
85e9001
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
cd4ff89
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
0dd96ed
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
9d4d969
Merge branch 'vllm-project:main' into main
sroy745 Oct 3, 2024
7d223b5
Merge branch 'vllm-project:main' into main
sroy745 Oct 5, 2024
f327d91
Merge branch 'vllm-project:main' into main
sroy745 Oct 5, 2024
b5adf28
Merge branch 'vllm-project:main' into main
sroy745 Oct 6, 2024
caf0d12
Merge branch 'vllm-project:main' into main
sroy745 Oct 7, 2024
28e77b1
Merge branch 'vllm-project:main' into main
sroy745 Oct 8, 2024
db7e46d
Merge branch 'vllm-project:main' into main
sroy745 Oct 9, 2024
59b35f0
Merge branch 'vllm-project:main' into main
sroy745 Oct 17, 2024
dd9affa
Merge branch 'vllm-project:main' into main
sroy745 Oct 17, 2024
f61a15d
Merge branch 'vllm-project:main' into main
sroy745 Oct 21, 2024
b3f42ed
Add flash attn kernel support for encoder-decoder models
sroy745 Oct 21, 2024
551008e
Run with flash attn
sroy745 Oct 22, 2024
99cfaf1
Flash Attn support
sroy745 Oct 23, 2024
1e1fb57
Some more fixes
sroy745 Oct 25, 2024
e16cbcb
More fixes
sroy745 Oct 25, 2024
1b2c060
More fixes to model_runner
sroy745 Oct 25, 2024
b619e32
Fixes
sroy745 Oct 26, 2024
0569773
Merge branch 'vllm-project:main' into main
sroy745 Oct 27, 2024
ffd82c0
commits
sroy745 Oct 27, 2024
d5f478e
Merge branch 'main' into sroy-vllm-encdec-flash
sroy745 Oct 27, 2024
a1c2e7c
Merge remote-tracking branch 'origin/main' into sroy-vllm-encdec-flash
sroy745 Oct 27, 2024
2adb7fd
Merge branch 'sroy-vllm-encdec-flash' of https://github.com/sroy745/v…
sroy745 Oct 27, 2024
d995221
Fixes
sroy745 Oct 28, 2024
d1f8140
Format
sroy745 Oct 28, 2024
b438166
Merge branch 'main' into sroy-vllm-encdec-flash
sroy745 Oct 28, 2024
d99370c
Remove unused import
sroy745 Oct 28, 2024
11bda4f
Reverting layer changes
sroy745 Oct 28, 2024
040f61e
Fixes
sroy745 Oct 28, 2024
1bc6fe1
Fixes
sroy745 Oct 29, 2024
4573f8f
Format
sroy745 Oct 29, 2024
ed587cb
Fix test reset logic
sroy745 Oct 29, 2024
18e8a97
Fixes
sroy745 Oct 29, 2024
9f7dc04
Dummu
sroy745 Oct 29, 2024
fce7f62
Fix
sroy745 Oct 29, 2024
d596c23
Dummy
sroy745 Oct 29, 2024
7bed5e6
Format
sroy745 Oct 29, 2024
3a7d05e
Format
sroy745 Oct 29, 2024
a2090e0
Merge branch 'vllm-project:main' into main
sroy745 Oct 30, 2024
a1e8c98
Merge remote-tracking branch 'origin/main' into sroy-vllm-encdec-flash
sroy745 Oct 30, 2024
0604c0a
Comments
sroy745 Oct 30, 2024
77ee5e2
Format
sroy745 Oct 30, 2024
b147fb9
Comments
sroy745 Oct 30, 2024
7284de5
Comment
sroy745 Oct 30, 2024
282a918
Comments
sroy745 Nov 1, 2024
c39d4c9
Comments
sroy745 Nov 1, 2024
cc58ebe
Comments
sroy745 Nov 1, 2024
5785714
Merge branch 'main' into sroy-vllm-encdec-flash
sroy745 Nov 1, 2024
c9a3f00
Merge branch 'vllm-project:main' into main
sroy745 Nov 1, 2024
834572f
Comments
sroy745 Nov 1, 2024
15dc714
Merge remote-tracking branch 'origin/main' into sroy-vllm-encdec-flash
sroy745 Nov 1, 2024
21946be
Merge branch 'main' into sroy-vllm-encdec-flash
sroy745 Nov 1, 2024
2264a62
Dummy
sroy745 Nov 2, 2024
7ca0ab7
Format
sroy745 Nov 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 51 additions & 37 deletions tests/encoder_decoder/test_e2e_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs

from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close

LIST_ENC_DEC_SUPPORTED_BACKENDS = [
_Backend.XFORMERS, _Backend.FLASH_ATTN, None
]


def vllm_to_hf_output(
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
Expand All @@ -29,7 +35,8 @@ def vllm_to_hf_output(


@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
sroy745 marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
Expand All @@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
num_logprobs: int,
decoder_prompt_type: DecoderPromptType,
enforce_eager: bool,
attn_backend: _Backend,
) -> None:
'''
End-to-End (E2E) test for the encoder-decoder framework.
Expand All @@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
implementations to ensure that both implementations produce consistent
and correct results.
'''
test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
test_case_prompts = example_encoder_decoder_prompts[
decoder_prompt_type]

# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}

with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
with vllm_runner(model, dtype=dtype,
enforce_eager=enforce_eager) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
with vllm_runner(model, dtype=dtype,
enforce_eager=enforce_eager) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)

hf_skip_tokens = (1
if decoder_prompt_type == DecoderPromptType.NONE else 0)
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
else 0)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
Loading