From 31bff69151606220e9db7ed37603e41b3f2e3230 Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Tue, 19 Dec 2023 16:52:46 -0800
Subject: [PATCH] Make _prepare_sample non-blocking and use pinned memory for
 input buffers (#2207)

---
 vllm/worker/model_runner.py | 55 +++++++++++++++++++++++++------------
 1 file changed, 38 insertions(+), 17 deletions(-)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 5623d27df3a36..fb7a0c17d6f9f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -10,6 +10,7 @@
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl
 
 logger = init_logger(__name__)
 
@@ -52,6 +53,8 @@ def __init__(
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # cache in_wsl result
+        self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config)
@@ -203,24 +206,29 @@ def _prepare_decode(
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -229,7 +237,7 @@
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
@@ -297,11 +305,11 @@
                               categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs
 
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -334,8 +342,6 @@ def execute_model(
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
         input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
 
         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@
             input_metadata=input_metadata,
         )
 
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -502,11 +511,14 @@ def forward(
         del kv_caches
 
         # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids)
-        self.input_buffers["positions"].copy_(positions)
-        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
-        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
-        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
+                                                 non_blocking=True)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
+                                                 non_blocking=True)
 
         # Run the graph.
         self.graph.replay()
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
     pad: int,
     dtype: torch.dtype,
     device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return 4
     else:
         return (batch_size + 7) // 8 * 8
+
+
+def _async_h2d(data: list, dtype, pin_memory):
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+    return t.to(device="cuda", non_blocking=True)
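
The `_async_h2d` helper added at the end of the patch works because its two halves are used together: `pin_memory=True` allocates the staging tensor in page-locked host memory, and `non_blocking=True` then lets `Tensor.to` enqueue the host-to-device copy on the current CUDA stream and return immediately (from ordinary pageable memory, the same call falls back to a blocking copy). A minimal standalone sketch of that pattern follows; it is illustrative only, and the `async_h2d` name and the `__main__` harness are assumptions, not part of the patch:

    import torch

    def async_h2d(data: list, dtype: torch.dtype, pin_memory: bool) -> torch.Tensor:
        # Build the tensor in pinned (page-locked) host memory; without
        # pinning, non_blocking=True has no effect for CPU -> GPU copies.
        t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
        # Enqueue the copy on the current CUDA stream and return at once,
        # so the CPU can keep preparing the next batch of inputs.
        return t.to(device="cuda", non_blocking=True)

    if __name__ == "__main__":
        # Requires a CUDA-capable device.
        x = async_h2d(list(range(1024)), dtype=torch.long, pin_memory=True)
        # Kernels launched on the same stream are ordered after the copy; an
        # explicit synchronize is only needed before reading back on the host.
        torch.cuda.synchronize()
        print(int(x.sum()))

The same reasoning explains the two guards in the patch: `pin_memory and str(device) == "cpu"` pins only host-side tensors (pinning is meaningless for tensors created directly on the GPU), and `not self.in_wsl` skips pinning under WSL, presumably because pinned host memory does not behave well there.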