Skip to content

Commit

Permalink
[torch.compile] avoid Dynamo guard evaluation overhead (vllm-project#…
Browse files Browse the repository at this point in the history
…7898)

Co-authored-by: Woosuk Kwon <[email protected]>
  • Loading branch information
youkaichao and WoosukKwon authored Aug 28, 2024
1 parent 3cdfe1f commit ce6bf3a
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .buildkite/run-tpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ steps:
- vllm/
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_wrapper.py


- label: Vision Language Models Test # 42min
Expand Down
59 changes: 59 additions & 0 deletions tests/compile/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Optional

import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther


class MyMod(torch.nn.Module):

def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
if cache is not None:
return x + cache
return x * 2


class MyWrapper(TorchCompileWrapperWithCustomDispacther):

def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)

def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
return self.model(x, cache)

def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# let torch.compile compile twice
if len(self.compiled_codes) == 2:
dispatch_id = 0 if cache is None else 1
with self.dispatch_to_code(dispatch_id):
return self.forward(x, cache)
else:
return self.compiled_callable(x, cache)


def test_torch_compile_wrapper():
mod = MyMod()
wrappers = []
for i in range(3):
torch._dynamo.reset()
wrapper = MyWrapper(mod)
wrappers.append(wrapper)
x = torch.tensor([1])
wrapper(x, None) # profile run, compile
# create a cache tensor
cache = torch.tensor([2])
wrapper(x, cache) # warm up with cache, recompile

# for new input, dispatch to the compiled code directly
new_x = torch.tensor([3])
assert wrapper(new_x,
None).item() == 6 # dispatch to the first compiled code
assert wrapper(
new_x, cache).item() == 5 # dispatch to the second compiled code

for wrapper in wrappers:
# make sure they have independent compiled codes
assert len(wrapper.compiled_codes) == 2
Empty file added tests/tpu/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions tests/tpu/test_custom_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ..utils import compare_two_settings


def test_custom_dispatcher():
compare_two_settings("google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
env2={})
Empty file added vllm/compilation/__init__.py
Empty file.
81 changes: 81 additions & 0 deletions vllm/compilation/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List

import torch

import vllm.envs as envs


class TorchCompileWrapperWithCustomDispacther:
"""
A wrapper class for torch.compile, with a custom dispatch logic.
Subclasses should:
1. Implement the forward method
2. Implement the dispatch logic in the __call__ method
It can use `self.compiled_codes` to access the compiled bytecode,
and `with self.dispatch_to_code(index):` to dispatch to
the compiled code.
3. Implement the `__init__` method to determine how to call
`torch.compile` over the forward method.
"""

def __init__(self, compiled_callable: Callable):
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: List[CodeType] = []
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

# read the env var to determine whether to use the custom dispatcher
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \
envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER

def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.
NOTE: this function can have additional arguments beyond the forward
method, for directly dispatching to the compiled code.
"""
return self.compiled_callable(*args, **kwargs)

@abstractmethod
def forward(self, *args, **kwargs):
...

def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
"""Hook to save the compiled bytecode for direct execution."""
if old_code is not self.original_code_object:
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame = sys._getframe()
while True:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
if code_name == "_compile" and file_name == "convert_frame.py":
break
frame = frame.f_locals["frame"]
assert frame.f_code == old_code

if frame.f_locals["self"] is not self:
return

self.compiled_codes.append(new_code)

@contextmanager
def dispatch_to_code(self, index: int):
"""Context manager to dispatch to the compiled code.
Why does this work? Because Dynamo guarantees that the compiled
bytecode has exactly the same arguments, cell variables, and free
variables as the original code. Therefore we can directly switch
the code object in the function and call it.
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
""" # noqa
self.__class__.forward.__code__ = self.compiled_codes[index]
yield
self.__class__.forward.__code__ = self.original_code_object
4 changes: 4 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ def get_default_config_root():
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
lambda:
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),

# local rank of the process in the distributed setting, used to determine
# the GPU device id
Expand Down
45 changes: 35 additions & 10 deletions vllm/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
Expand Down Expand Up @@ -144,11 +145,7 @@ def load_model(self) -> None:
)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
self.model = ModelWrapper(model)

def _dummy_run(
self,
Expand Down Expand Up @@ -235,8 +232,15 @@ def _dummy_run(
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
self.model(token_ids,
position_ids,
attn_metadata,
input_lens,
t,
p,
num_samples,
kv_caches,
is_prompt=is_prompt)

def warmup_model(
self,
Expand Down Expand Up @@ -530,7 +534,7 @@ def _execute_model(*args):
if getattr(arg, "context_lens", None) is not None:
arg.context_lens = arg.context_lens.to(self.device)
new_args.append(arg)
return self.model(*new_args)
return self.model(*new_args, is_prompt=is_prompt)

num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0
Expand Down Expand Up @@ -601,11 +605,32 @@ def _execute_model(*args):
return [SamplerOutput(sampler_outputs)]


class ModelWrapper(nn.Module):
class ModelWrapper(TorchCompileWrapperWithCustomDispacther):

def __init__(self, model: nn.Module):
super().__init__()
self.model = model
compiled_callable = torch.compile(self.forward,
backend="openxla",
fullgraph=True,
dynamic=False)
super().__init__(compiled_callable)

def __call__(self, *args, is_prompt: bool, **kwargs):
if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
# not fully compiled yet, or not using the custom dispatcher,
# let PyTorch handle it
return self.compiled_callable(*args, **kwargs)
# the 3 compiled codes are:
# 0: for profiling
# 1: for prompt
# 2: for decode
# dispatch to the compiled code directly, skip PyTorch
if is_prompt:
with self.dispatch_to_code(1):
return self.forward(*args, **kwargs)
else:
with self.dispatch_to_code(2):
return self.forward(*args, **kwargs)

def forward(
self,
Expand Down

0 comments on commit ce6bf3a

Please sign in to comment.