Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make _prepare_sample non blocking and pin memory of CPU input buffers #2207

Merged
merged 9 commits into from
Dec 20, 2023
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import in_wsl

logger = init_logger(__name__)

Expand Down Expand Up @@ -52,6 +53,8 @@ def __init__(
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling.
# cache in_wsl result
self.in_wsl = in_wsl()

def load_model(self) -> None:
self.model = get_model(self.model_config)
Expand Down Expand Up @@ -203,24 +206,29 @@ def _prepare_decode(
# When using CUDA graph, we don't need to make the tensors on the GPU
# because they will be eventually copied to the designated GPU buffer.
device = "cpu" if use_captured_graph else "cuda"
pin_memory = use_captured_graph and not self.in_wsl
input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long,
device=device)
device=device,
pin_memory=pin_memory)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device=device)
device=device,
pin_memory=pin_memory)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device=device)
device=device,
pin_memory=pin_memory)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=device)
device=device,
pin_memory=pin_memory)

if use_captured_graph:
# The shape of graph_block_tables is
Expand All @@ -229,7 +237,7 @@ def _prepare_decode(
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.from_numpy(input_block_tables).to(device)
block_tables = torch.tensor(input_block_tables, device=device)
else:
block_tables = _make_tensor_with_pad(
block_tables,
Expand Down Expand Up @@ -297,11 +305,11 @@ def _prepare_sample(
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs

selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device="cuda")
selected_token_indices = async_h2d(selected_token_indices,
dtype=torch.long,
pin_memory=not self.in_wsl)
categorized_sample_indices = {
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
t: async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
for t, seq_ids in categorized_sample_indices.items()
}

Expand Down Expand Up @@ -334,8 +342,6 @@ def execute_model(
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
input_metadata.prompt_lens)

# Execute the model.
if input_metadata.use_cuda_graph:
Expand All @@ -350,6 +356,9 @@ def execute_model(
input_metadata=input_metadata,
)

sampling_metadata = self._prepare_sample(seq_group_metadata_list,
input_metadata.prompt_lens)

# Sample the next token.
output = self.model.sample(
hidden_states=hidden_states,
Expand Down Expand Up @@ -502,11 +511,14 @@ def forward(
del kv_caches

# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids)
self.input_buffers["positions"].copy_(positions)
self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
non_blocking=True)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
non_blocking=True)

# Run the graph.
self.graph.replay()
Expand All @@ -529,9 +541,13 @@ def _make_tensor_with_pad(
pad: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
pin_memory=False,
hanzhi713 marked this conversation as resolved.
Show resolved Hide resolved
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device)
return torch.tensor(padded_x,
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu")


def _get_graph_batch_size(batch_size: int) -> int:
Expand All @@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
return 4
else:
return (batch_size + 7) // 8 * 8


def async_h2d(data: list, dtype, pin_memory):
hanzhi713 marked this conversation as resolved.
Show resolved Hide resolved
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
return t.to(device="cuda", non_blocking=True)
Loading