Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] expanding support and fix allgather compilation #9637

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,15 +392,20 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/gpt_bigcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from transformers import GPTBigCodeConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand Down Expand Up @@ -187,6 +188,7 @@ def forward(
return hidden_states


@support_torch_compile
class GPTBigCodeModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/gpt_j.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers import GPTJConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand Down Expand Up @@ -174,6 +175,7 @@ def forward(
return hidden_states


@support_torch_compile
class GPTJModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand Down Expand Up @@ -187,6 +188,7 @@ def forward(
return hidden_states


@support_torch_compile
class GPTNeoXModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from transformers import GraniteConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -254,6 +255,7 @@ def forward(
return hidden_states


@support_torch_compile
class GraniteModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -230,6 +231,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class InternLM2Model(nn.Module):

def __init__(
Expand Down