Skip to content

Commit

Permalink
[torch.compile] auto infer dynamic_arg_dims from type annotation (vll…
Browse files Browse the repository at this point in the history
…m-project#9589)

Signed-off-by: Erkin Sagiroglu <[email protected]>
  • Loading branch information
youkaichao authored and Erkin Sagiroglu committed Oct 26, 2024
1 parent 6a709e3 commit 6b9d4ac
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 19 deletions.
68 changes: 63 additions & 5 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,58 @@
import inspect
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

import torch

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo

logger = init_logger(__name__)

def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):

def support_torch_compile(
cls: Optional[type] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
"""
A decorator to add support for compiling the forward method of a class.
Usage 1: use directly as a decorator without arguments:
```python
@support_torch_compile
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
Usage 2: use as a decorator with arguments:
```python
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers.
Depending on the value of arguments:
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
of the `forward` method, based on the following default rules:
- if the argument is annotated as `torch.Tensor` or
`Optional[torch.Tensor]`, the first dimension will be
marked as dynamic.
- if the argument is annotated as `IntermediateTensors`, the first
dimension of all the tensors in the intermediate tensors
will be marked as dynamic.
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
Expand All @@ -38,11 +72,35 @@ def cls_decorator_helper(cls: type):
if not hasattr(cls, 'forward'):
raise TypeError("decorated class should have a forward method.")
sig = inspect.signature(cls.forward)
for k in dynamic_arg_dims:
inferred_dynamic_arg_dims = dynamic_arg_dims
if inferred_dynamic_arg_dims is None:
inferred_dynamic_arg_dims = {}
for k, v in sig.parameters.items():
if v.annotation in [
torch.Tensor, Optional[torch.Tensor],
IntermediateTensors, Optional[IntermediateTensors]
]:
inferred_dynamic_arg_dims[k] = 0

logger.debug(("Inferred dynamic dimensions for "
"forward method of %s: %s"), cls,
list(inferred_dynamic_arg_dims.keys()))

if len(inferred_dynamic_arg_dims) == 0:
raise ValueError(
"No dynamic dimensions found in the forward method of "
f"{cls}. Please provide dynamic_arg_dims explicitly.")

for k in inferred_dynamic_arg_dims:
if k not in sig.parameters:
raise ValueError(
f"Argument {k} not found in the forward method of {cls}")
return _support_torch_compile(cls, dynamic_arg_dims)
return _support_torch_compile(cls, inferred_dynamic_arg_dims)

if cls is not None:
# use `support_torch_compile` as a decorator without arguments
assert isinstance(cls, type)
return cls_decorator_helper(cls)

return cls_decorator_helper

Expand Down
8 changes: 1 addition & 7 deletions vllm/model_executor/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,7 @@ def forward(
return hidden_states, residual


@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class Gemma2Model(nn.Module):

def __init__(
Expand Down
8 changes: 1 addition & 7 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,7 @@ def forward(
return hidden_states, residual


@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class LlamaModel(nn.Module):

def __init__(
Expand Down

0 comments on commit 6b9d4ac

Please sign in to comment.