diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py
index c0654712b71b5..27cdf5f339ede 100644
--- a/tests/worker/test_encoder_decoder_model_runner.py
+++ b/tests/worker/test_encoder_decoder_model_runner.py
@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size):
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-def test_prepare_decode(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce decode-phase model inputs & attention metadata.
@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size):
     Arguments:
 
     * batch_size
+    * multiple_seqs_per_seq_group
     * backend_name: The attention backend under test
     * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
     '''
@@ -305,22 +307,29 @@ def test_prepare_decode(batch_size):
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     cross_block_table = [2]
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
@@ -328,6 +337,10 @@ def test_prepare_decode(batch_size):
         )
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
 
     # Build
     # * Decoder model inputs
@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size):
     # Verify block tables are correct for prompts
     # - Decoder self-attention
-    expected = torch.tensor(
-        [block_tables[0] for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    flattened_block_tables = [
+        block_table for block_table in block_tables.values()
+    ]
+    expected = torch.tensor(flattened_block_tables *
+                            len(seq_group_metadata_list),
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.block_tables,
         expected,
     )
     # - Encoder/decoder cross-attention
-    expected = torch.tensor(
-        [cross_block_table for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    expected = torch.tensor([
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ],
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.cross_block_tables,
         expected,
     )
@@ -474,7 +492,8 @@ def test_prepare_decode(batch_size):
 
 
 @pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_decode_cuda_graph(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     """
     Tests that for encoder-decoder models with CUDA Graph capture and replay
     enabled, the tensors used during the decode phase are correctly padded
@@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size):
         enable_chunked_prefill=False,
         enforce_eager=False,
     )
-
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+
     cross_block_table = [2]
+    expanded_batch_size = 0
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        expanded_batch_size = expanded_batch_size + len(
+            seq_group_metadata.seq_data)
         seq_group_metadata_list.append(seq_group_metadata)
 
     model_input = model_runner.prepare_model_input(seq_group_metadata_list)
@@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(batch_size)
-    cuda_graph_pad_size = graph_batch_size - batch_size
+    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
         itertools.repeat(1, cuda_graph_pad_size))
@@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size):
 
     # Verify block tables are correct for prompts
     # - Decoder self-attention. Pad the block tables as expected.
-    expected = [block_tables[0] for _ in range(batch_size)]
-    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    flattened_block_tables = [
+        block_table for _ in range(len(seq_group_metadata_list))
+        for block_table in block_tables.values()
+    ]
+    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
-        expected,
+        flattened_block_tables,
         max_len=64,
         pad=0,
         dtype=torch.int32,
@@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size):
     )
     # - Encoder/decoder cross-attention. Pad the cross-attention block tables
     # as expected.
-    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected = [
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ]
     expected.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
         expected,
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 09dab0135f390..709efdc8b9d57 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -435,18 +435,18 @@ def _prepare_encoder_model_input_tensors(
             encoder_input_tokens_tensor = self._empty_long_tensor()
             encoder_input_positions_tensor = self._empty_long_tensor()
             cross_slot_mapping_tensor = self._empty_long_tensor()
-
             # Extract cross-attention block tables &
             # seq len from each sequence group metadata.
             # Cross-attention block tables are empty
             # during vLLM memory profiling.
             cross_block_tables = []
             for seq_group_metadata in seq_group_metadata_list:
-                encoder_seq_lens.append(
-                    seq_group_metadata.encoder_seq_data.get_len())
-                cross_block_table = seq_group_metadata.cross_block_table
-                cross_block_tables.append([] if (
-                    cross_block_table is None) else cross_block_table)
+                for _ in range(len(seq_group_metadata.seq_data)):
+                    encoder_seq_lens.append(
+                        seq_group_metadata.encoder_seq_data.get_len())
+                    cross_block_table = seq_group_metadata.cross_block_table
+                    cross_block_tables.append([] if (
+                        cross_block_table is None) else cross_block_table)
 
         if (model_input.attn_metadata is not None
                 and model_input.attn_metadata.use_cuda_graph):
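Note: both halves of this patch apply the same expansion rule. Encoder sequence lengths and cross-attention block tables are now emitted once per sequence rather than once per sequence group, so they stay index-aligned with the per-sequence decoder inputs; the CUDA Graph test additionally sizes its padding from the expanded sequence count instead of the seq-group count. Below is a minimal, self-contained sketch of that expansion; StubSeqGroup and expand_encoder_inputs are hypothetical stand-ins for illustration, not vLLM APIs.

    from dataclasses import dataclass
    from typing import Dict, List, Optional


    @dataclass
    class StubSeqGroup:
        # Hypothetical stand-in for SequenceGroupMetadata: seq_data holds
        # one entry per sequence in the group.
        seq_data: Dict[int, object]
        encoder_seq_len: int  # stand-in for encoder_seq_data.get_len()
        cross_block_table: Optional[List[int]] = None


    def expand_encoder_inputs(groups: List[StubSeqGroup]):
        # Mirror of the fixed loop: repeat each group's encoder metadata
        # once per decoder sequence so both lists stay aligned with the
        # per-sequence decoder inputs.
        encoder_seq_lens: List[int] = []
        cross_block_tables: List[List[int]] = []
        for group in groups:
            for _ in range(len(group.seq_data)):
                encoder_seq_lens.append(group.encoder_seq_len)
                # Cross-attention block tables are None during vLLM
                # memory profiling; substitute an empty table.
                cross_block_tables.append(group.cross_block_table or [])
        return encoder_seq_lens, cross_block_tables


    # A group with two sequences contributes two aligned entries; a
    # single-sequence group contributes one.
    groups = [
        StubSeqGroup(seq_data={0: None, 1: None}, encoder_seq_len=5,
                     cross_block_table=[2]),
        StubSeqGroup(seq_data={0: None}, encoder_seq_len=7),
    ]
    lens, tables = expand_encoder_inputs(groups)
    assert lens == [5, 5, 7]
    assert tables == [[2], [2], []]

In the patched _prepare_encoder_model_input_tensors, the inner `for _ in range(len(seq_group_metadata.seq_data))` loop plays the same role, and the expanded count (3 in the example above) corresponds to the tests' expanded_batch_size, which is what now gets passed to _get_graph_batch_size.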