Make _prepare_sample non-blocking and use pinned memory for input buf…
hanzhi713 authored Dec 20, 2023
1 parent ba4f826 commit 31bff69
1 changed file: vllm/worker/model_runner.py (38 additions, 17 deletions)
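In short: host-side input tensors are now staged in pinned (page-locked) memory and copied to the GPU with non_blocking=True, so tensor preparation can overlap with model execution instead of serializing with it. A minimal standalone sketch of the pattern (illustrative only, not vllm code; assumes a CUDA build of PyTorch):

import torch

# Pinned host memory lets cudaMemcpyAsync run truly asynchronously;
# pageable memory would force a synchronous staging copy instead.
host = torch.tensor([1, 2, 3, 4], dtype=torch.long, pin_memory=True)

# The copy is enqueued on the current CUDA stream and returns
# immediately; the CPU is free to do other work while it completes.
device_t = host.to(device="cuda", non_blocking=True)

# ... overlapping CPU-side work goes here ...

torch.cuda.current_stream().synchronize()  # wait before reading device_t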
@@ -10,6 +10,7 @@
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+ from vllm.utils import in_wsl

logger = init_logger(__name__)

@@ -52,6 +53,8 @@ def __init__(
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling.
+ # cache in_wsl result
+ self.in_wsl = in_wsl()

def load_model(self) -> None:
self.model = get_model(self.model_config)
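The WSL check exists because CUDA pinned-memory allocation is known to be problematic (unsupported or very slow) under Windows Subsystem for Linux, so the code falls back to pageable memory there; caching the result avoids re-probing the platform on every batch. A hedged sketch of what such a probe can look like (the real vllm.utils.in_wsl may differ):

import platform

def in_wsl() -> bool:
    # WSL kernels advertise themselves in the release string,
    # e.g. "5.15.90.1-microsoft-standard-WSL2".
    return "microsoft" in platform.uname().release.lower()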
@@ -203,24 +206,29 @@ def _prepare_decode(
# When using CUDA graph, we don't need to make the tensors on the GPU
# because they will be eventually copied to the designated GPU buffer.
device = "cpu" if use_captured_graph else "cuda"
+ pin_memory = use_captured_graph and not self.in_wsl
input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long,
- device=device)
+ device=device,
+ pin_memory=pin_memory)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
- device=device)
+ device=device,
+ pin_memory=pin_memory)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long,
- device=device)
+ device=device,
+ pin_memory=pin_memory)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
- device=device)
+ device=device,
+ pin_memory=pin_memory)

if use_captured_graph:
# The shape of graph_block_tables is
@@ -229,7 +237,7 @@
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
- block_tables = torch.from_numpy(input_block_tables).to(device)
+ block_tables = torch.tensor(input_block_tables, device=device)
else:
block_tables = _make_tensor_with_pad(
block_tables,
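Two details in this hunk: the decode-path tensors are pinned only when they stay on the CPU (the captured-graph path, hence pin_memory depends on use_captured_graph), and the block-table conversion moved from torch.from_numpy to torch.tensor. The difference: from_numpy returns a zero-copy view sharing memory with the array, while torch.tensor always copies and can place the result on the target device in one call. A small CPU-only illustration (values are arbitrary):

import numpy as np
import torch

arr = np.zeros((4, 16), dtype=np.int64)

shared = torch.from_numpy(arr)   # zero-copy view of arr
copied = torch.tensor(arr)       # independent copy of arr

arr[0, 0] = 7
assert shared[0, 0] == 7         # the view sees the mutation
assert copied[0, 0] == 0         # the copy does not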
@@ -297,11 +305,11 @@ def _prepare_sample(
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs

- selected_token_indices = torch.tensor(selected_token_indices,
- dtype=torch.long,
- device="cuda")
+ selected_token_indices = _async_h2d(selected_token_indices,
+ dtype=torch.long,
+ pin_memory=not self.in_wsl)
categorized_sample_indices = {
- t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+ t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
for t, seq_ids in categorized_sample_indices.items()
}
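Here the sampling indices are staged in pinned memory by _async_h2d (defined at the bottom of this diff) and the copies are enqueued back-to-back, so the CPU never blocks between them. A rough standalone equivalent of the dictionary path (hypothetical names; assumes CUDA):

import torch

categorized = {"greedy": [0, 1], "random": [2, 3, 4]}
on_gpu = {
    name: torch.tensor(ids, dtype=torch.int, pin_memory=True)
              .to(device="cuda", non_blocking=True)
    for name, ids in categorized.items()
}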

@@ -334,8 +342,6 @@ def execute_model(
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
- sampling_metadata = self._prepare_sample(seq_group_metadata_list,
- input_metadata.prompt_lens)

# Execute the model.
if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@
input_metadata=input_metadata,
)

+ sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+ input_metadata.prompt_lens)

# Sample the next token.
output = self.model.sample(
hidden_states=hidden_states,
@@ -502,11 +511,14 @@ def forward(
del kv_caches

# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids)
self.input_buffers["positions"].copy_(positions)
self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
non_blocking=True)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
non_blocking=True)

# Run the graph.
self.graph.replay()
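The non-blocking copies into the graph's static input buffers need no explicit synchronization: graph.replay() is queued on the same CUDA stream after them, and stream ordering guarantees the copies finish before the replayed kernels read the buffers. A minimal sketch (buffer names illustrative; assumes CUDA):

import torch

input_buf = torch.empty(8, dtype=torch.long, device="cuda")  # static buffer
src = torch.arange(8, dtype=torch.long).pin_memory()

# Enqueued on the current stream; returns without waiting. Anything
# queued on the same stream afterwards (e.g. a graph replay) observes
# the updated contents.
input_buf.copy_(src, non_blocking=True)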
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
pad: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
+ pin_memory: bool = False,
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
- return torch.tensor(padded_x, dtype=dtype, device=device)
+ return torch.tensor(padded_x,
+ dtype=dtype,
+ device=device,
+ pin_memory=pin_memory and str(device) == "cpu")
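The str(device) == "cpu" guard is there because pinning is a property of host memory; asking for a pinned CUDA tensor raises an error. Quick illustration (assumes a CUDA build):

import torch

ok = torch.zeros(4, dtype=torch.long, device="cpu", pin_memory=True)
# torch.zeros(4, device="cuda", pin_memory=True)  # raises: only CPU
#                                                 # tensors can be pinned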


def _get_graph_batch_size(batch_size: int) -> int:
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
return 4
else:
return (batch_size + 7) // 8 * 8
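For reference, this branch rounds the batch size up to the next multiple of 8, so only a handful of padded sizes ever need a captured graph. A few worked values:

# (batch_size + 7) // 8 * 8 rounds up to a multiple of 8:
assert (9 + 7) // 8 * 8 == 16
assert (16 + 7) // 8 * 8 == 16
assert (17 + 7) // 8 * 8 == 24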


+ def _async_h2d(data: list, dtype, pin_memory):
+     t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+     return t.to(device="cuda", non_blocking=True)
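One caveat: the copy in _async_h2d is only truly asynchronous when pin_memory is True. From pageable memory (the WSL fallback) non_blocking=True silently degrades to a synchronous copy, which is correct but gains nothing. Hypothetical usage mirroring _prepare_sample:

import torch

indices = _async_h2d([0, 2, 5], dtype=torch.long, pin_memory=True)
# ... further CPU-side preparation overlaps with the copy ...
torch.cuda.current_stream().synchronize()  # before consuming indices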
