
Commit bda8e68: format

SolitaryThinker committed Aug 9, 2024
1 parent 78c4ff8
Showing 5 changed files with 61 additions and 53 deletions.
23 changes: 11 additions & 12 deletions vllm/engine/async_llm_engine.py
@@ -1,13 +1,13 @@
import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from dataclasses import dataclass

import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
import torch

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -31,7 +31,6 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)

from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)
@@ -411,15 +410,15 @@ def _get_cached_sampled_token_ids_for_multi_step(
def _cache_output_for_multi_step(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1):
if len(output) > 0 and output[0] is not None:
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_numpy is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_numpy is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
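
Note on the hunk above: the flattened condition caches the final SamplerOutput of a multi-step batch only when pipeline parallelism is enabled and the step actually produced output. A minimal, self-contained sketch of that guard-and-cache pattern, using hypothetical stand-in classes rather than the real vLLM types:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeSamplerOutput:
    """Stand-in for SamplerOutput; only the fields the guard inspects."""
    sampled_token_ids_numpy: Optional[object] = None
    sampled_token_ids: Optional[object] = None
    sampled_token_probs: Optional[object] = None


@dataclass
class CachedSchedulerOutputs:
    last_output: Optional[FakeSamplerOutput] = None


def cache_output_for_multi_step(
        caches: List[CachedSchedulerOutputs], virtual_engine: int,
        pipeline_parallel_size: int,
        output: List[Optional[FakeSamplerOutput]]) -> None:
    # Cache the final step's output only when pipeline parallelism is on
    # and the step actually produced sampler output.
    if (pipeline_parallel_size > 1 and len(output) > 0
            and output[0] is not None):
        last_output = output[-1]
        assert last_output is not None
        # Here only the numpy token ids are expected to be populated.
        assert last_output.sampled_token_ids_numpy is not None
        caches[virtual_engine].last_output = last_output


caches = [CachedSchedulerOutputs() for _ in range(2)]
step_outputs = [FakeSamplerOutput(sampled_token_ids_numpy=[[1], [2]])]
cache_output_for_multi_step(caches, 0, 2, step_outputs)
assert caches[0].last_output is step_outputs[-1]

The real method additionally asserts that the GPU-side sampled_token_ids and sampled_token_probs were left unset, since only the numpy token ids are needed downstream.
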
11 changes: 7 additions & 4 deletions vllm/sequence.py
@@ -8,8 +8,8 @@
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union, cast)

import torch
import numpy
import torch

from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest
@@ -1147,21 +1147,24 @@ class ExecuteModelRequest:

@property
def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of steps
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
return first_seq_group.state.current_step == 0

@property
def is_last_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of steps
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
return first_seq_group.state.remaining_steps == 1

@property
def current_step(self) -> int:
# TODO(will) make this be able to handle batches with variable number of steps
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
return self.seq_group_metadata_list[0].state.current_step

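
The wrapped TODOs above annotate properties that read per-batch multi-step state from the first sequence group only. A small sketch of that idea, with hypothetical stand-in classes (the real ExecuteModelRequest and SequenceGroupMetadata carry far more fields):

from dataclasses import dataclass, field
from typing import List


@dataclass
class MultiStepState:
    current_step: int = 0
    remaining_steps: int = 0


@dataclass
class FakeSeqGroupMetadata:
    state: MultiStepState = field(default_factory=MultiStepState)


@dataclass
class FakeExecuteModelRequest:
    seq_group_metadata_list: List[FakeSeqGroupMetadata]

    @property
    def is_first_multi_step(self) -> bool:
        # Assumes every group in the batch shares the same step counters,
        # which is exactly what the TODOs above want to generalize.
        assert len(self.seq_group_metadata_list) > 0
        return self.seq_group_metadata_list[0].state.current_step == 0

    @property
    def is_last_step(self) -> bool:
        assert len(self.seq_group_metadata_list) > 0
        return self.seq_group_metadata_list[0].state.remaining_steps == 1


req = FakeExecuteModelRequest(
    [FakeSeqGroupMetadata(MultiStepState(current_step=0, remaining_steps=4))])
assert req.is_first_multi_step and not req.is_last_step
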
53 changes: 30 additions & 23 deletions vllm/worker/multi_step_model_runner.py
@@ -1,28 +1,29 @@
import dataclasses
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Union)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)

from ..model_executor.model_loader.tensorizer import TensorizerConfig
import torch

from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
BroadcastableModelInput, _init_frozen_model_input_from_tensor_dict,
_init_attn_metadata_from_tensor_dict,
BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
_init_frozen_model_input_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
GPUModelRunnerBase)
from vllm.logger import init_logger
from vllm.distributed import get_pp_group
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata, SequenceOutput,
CompletionSequenceGroupOutput, Logprob)
from vllm import _custom_ops as ops

import torch
from ..model_executor.model_loader.tensorizer import TensorizerConfig

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -43,7 +44,8 @@ class ModelOutput:
There are two scenarios:
1. The output tensors are ready and we can pythonize them immediately.
2. The output tensors are not ready and we need to wait for the event to be ready.
2. The output tensors are not ready and we need to wait for the event to be
ready.
"""
sampler_output: SamplerOutput
sampler_output_ready_event: torch.cuda.Event
@@ -217,10 +219,11 @@ class MultiStepModelRunner(MultiStepModelRunnerBase):
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]
) -> MutableModelInputForGPUWithMultiStepMetadata:
model_input = MutableModelInputForGPUWithMultiStepMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
model_input = (MutableModelInputForGPUWithMultiStepMetadata.
from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
return model_input

def prepare_model_input(
@@ -271,9 +274,11 @@ def execute_model(
device="cpu",
pin_memory=True)

self._base_model_runner.model.sampler.include_gpu_probs_tensor = True
self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
True)
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = True
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
True)
# TODO(will) Will need to benchmark and look at torch profiler for
# the exact location we should do this. If the CPU is very ahead, it
# does not matter if we call this before executable or after, as the
@@ -296,7 +301,8 @@ def execute_model(
# changing batch sizes, will remove afterwards and potentially leave
# comment for future optimization
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
frozen_model_input.sampling_metadata.reuse_sampling_tensors = (
False)
else:
# This is not needed for flashattn backend, but for other attn
# backends such as flashinfer that performs we may need to
@@ -309,7 +315,8 @@ def execute_model(
# changing batch sizes, will remove afterwards and potentially leave
# comment for future optimization
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
frozen_model_input.sampling_metadata.reuse_sampling_tensors = (
False)

# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
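
The reflowed ModelOutput docstring earlier in this file distinguishes outputs whose tensors are already ready from outputs gated on a CUDA event. A rough sketch of that deferred-pythonization pattern, assuming torch is installed; the class and method names here are illustrative, not the vLLM API:

import torch


class DeferredOutput:
    """Holds GPU sampler output and pythonizes it only when it is ready."""

    def __init__(self, gpu_token_ids: torch.Tensor,
                 ready_event: "torch.cuda.Event") -> None:
        self.gpu_token_ids = gpu_token_ids
        self.ready_event = ready_event
        self.pythonized_ids = None  # filled in lazily

    def maybe_pythonize(self, blocking: bool) -> bool:
        """Copy token ids to Python lists once the event has fired.

        Returns True once the output has been pythonized.
        """
        if self.pythonized_ids is not None:
            return True
        # Scenario 1: the event already fired, so pythonize immediately.
        # Scenario 2: not ready yet; either wait (blocking) or try again later.
        if blocking:
            self.ready_event.synchronize()
        elif not self.ready_event.query():
            return False
        self.pythonized_ids = self.gpu_token_ids.cpu().tolist()
        return True


if torch.cuda.is_available():
    ids = torch.tensor([[1], [2]], device="cuda")
    event = torch.cuda.Event()
    event.record()
    out = DeferredOutput(ids, event)
    while not out.maybe_pythonize(blocking=False):
        pass
    print(out.pythonized_ids)
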
21 changes: 10 additions & 11 deletions vllm/worker/multi_step_worker.py
@@ -1,14 +1,12 @@
from vllm.worker.worker import Worker
from dataclasses import dataclass
from vllm.worker.worker import WorkerInput
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from typing import Tuple, Optional, List
from dataclasses import field
from typing import List, Optional, Tuple

from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (
MutableModelInputForGPUWithMultiStepMetadata)
from vllm.worker.worker import Worker, WorkerInput


@dataclass
@@ -70,8 +68,8 @@ def _get_driver_input_and_broadcast(
# otherwise we need to get the cached sampled token ids from the
# execute_model_req
assert execute_model_req.last_sampled_token_ids is not None
model_input.last_sampled_token_ids = execute_model_req.last_sampled_token_ids.cuda(
)
model_input.last_sampled_token_ids = (
execute_model_req.last_sampled_token_ids.cuda())
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
@@ -143,8 +141,9 @@ def prepare_input(

assert isinstance(
model_input, MutableModelInputForGPUWithMultiStepMetadata)
# we need to update the last sampled token ids in the model input
# for the workers so that they can run inplace advance_step
# we need to update the last sampled token ids in the model
# input for the workers so that they can run inplace
# advance_step
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
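
The rewrapped comment above describes workers attaching the broadcast last_sampled_token_ids to the model input so that a later advance_step can run in place. A simplified, hypothetical sketch of that bookkeeping (these stand-in classes are not the actual MutableModelInputForGPUWithMultiStepMetadata API):

from dataclasses import dataclass, field
from typing import List, Optional

import torch


@dataclass
class FakeSamplerOutput:
    outputs: list
    sampled_token_ids: Optional[torch.Tensor] = None


@dataclass
class FakeMultiStepModelInput:
    cached_outputs: List[FakeSamplerOutput] = field(default_factory=list)
    last_sampled_token_ids: Optional[torch.Tensor] = None

    def add_sampler_output(self, output: FakeSamplerOutput,
                           token_ids: torch.Tensor) -> None:
        # Record the token ids sampled on the previous step so a subsequent
        # in-place advance_step can read them without re-sampling.
        output.sampled_token_ids = token_ids
        self.cached_outputs.append(output)


model_input = FakeMultiStepModelInput()
model_input.last_sampled_token_ids = torch.tensor([[7], [11]])
model_input.add_sampler_output(
    FakeSamplerOutput(outputs=[]), model_input.last_sampled_token_ids)
assert model_input.cached_outputs[-1].sampled_token_ids is not None
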
6 changes: 3 additions & 3 deletions vllm/worker/worker_base.py
@@ -16,9 +16,9 @@
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
BroadcastableModelInput)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)

logger = init_logger(__name__)

