[torch.compile] directly register custom op (vllm-project#9896)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Richard Liu <[email protected]>
youkaichao authored and richardsliu committed Nov 4, 2024
1 parent 9efb5fe commit 0458bc6
Showing 9 changed files with 192 additions and 67 deletions.
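The helper itself, direct_register_custom_op, is added to vllm/utils.py, which is one of the files not rendered on this page. Judging only from how the call sites below use it (op_name, op_func, mutates_args, fake_impl, optional target_lib), a minimal sketch of such a helper could look like the following; the default "vllm" library, the schema-inference call, and the CUDA dispatch key are assumptions rather than the actual implementation.

# Hypothetical sketch of a direct registration helper; the real one lives in
# vllm/utils.py, which is not shown in this excerpt.
from typing import Callable, List, Optional

import torch
from torch.library import Library

# Assumed: a module-level default library used when target_lib is omitted.
vllm_lib = Library("vllm", "FRAGMENT")


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
) -> None:
    """Register op_func as <namespace>::<op_name> without the decorator API."""
    lib = target_lib or vllm_lib
    # Derive the schema string from the Python signature (torch >= 2.5).
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    lib.define(op_name + schema)
    # Register the eager kernel; the real code may target other dispatch keys.
    lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        # Register the fake (meta) implementation used during tracing.
        torch.library.register_fake(f"{lib.ns}::{op_name}", fake_impl, lib=lib)

Registering straight through Library objects, instead of via the torch.library.custom_op decorator, avoids the decorator's extra wrapper object around every call, which is presumably the motivation for this commit.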
20 changes: 16 additions & 4 deletions tests/compile/piecewise/test_simple.py
@@ -6,18 +6,22 @@

import torch
from torch import nn
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.utils import direct_register_custom_op

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

global_counter = 0

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
global global_counter
@@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out[0] += 1


@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return


direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)


@support_torch_compile
class SillyModel(nn.Module):

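For reference, nothing changes on the caller side: ops registered into the silly library are still reached through the torch.ops namespace, which is what keeps them opaque to torch.compile. An illustrative call, with arbitrary shapes:

# Illustrative call site for the op registered above. A CUDA device is assumed
# here, since the kernel is registered for GPU execution in these tests; the
# (4, 8) shapes are arbitrary. The op mutates `out` in place, as declared via
# mutates_args=["out"].
import torch

device = "cuda"
q = torch.randn(4, 8, device=device)
k = torch.randn(4, 8, device=device)
v = torch.randn(4, 8, device=device)
out = torch.empty(4, 8, device=device)

torch.ops.silly.attention(q, k, v, out)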
20 changes: 16 additions & 4 deletions tests/compile/piecewise/test_toy_llama.py
@@ -8,29 +8,41 @@

import torch
from torch import nn
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_compilation_config
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v


@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return


direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)


@dataclass
class LlamaConfig:
hidden_size: int = 128
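Because the fake implementation is now a plain function handed to direct_register_custom_op rather than a register_fake decoration, it is worth verifying that the pairing is still complete. One way to do that, sketched under the assumption that the op is registered as silly::attention with a GPU kernel, is torch.library.opcheck:

# Minimal sketch: check that the registered kernel and its fake implementation
# agree on schema and fake-tensor propagation (torch.library.opcheck requires
# torch >= 2.4). Restricting test_utils skips the autograd checks, which these
# inference-only ops do not register.
import torch

device = "cuda"
args = tuple(torch.randn(2, 4, device=device) for _ in range(3)) + (
    torch.empty(2, 4, device=device),)

torch.library.opcheck(torch.ops.silly.attention.default, args,
                      test_utils=("test_schema", "test_faketensor"))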
16 changes: 11 additions & 5 deletions vllm/attention/backends/flash_attn.py
@@ -14,7 +14,8 @@
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
make_tensor_with_pad)

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -595,8 +596,6 @@ def forward(
return output


@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -755,8 +754,7 @@ def unified_flash_attention(
return output.view(num_tokens, hidden_size)


@unified_flash_attention.register_fake
def _(
def unified_flash_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -773,3 +771,11 @@ def _(
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)


direct_register_custom_op(
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
fake_impl=unified_flash_attention_fake,
)
17 changes: 11 additions & 6 deletions vllm/attention/backends/flashinfer.py
@@ -28,8 +28,8 @@
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
get_kv_cache_torch_dtype, make_tensor_with_pad)

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -785,8 +785,6 @@ def forward(
)


@torch.library.custom_op("vllm::unified_flash_infer",
mutates_args=["kv_cache"])
def unified_flash_infer(
query: torch.Tensor,
key: torch.Tensor,
@@ -906,8 +904,7 @@ def unified_flash_infer(
return output.view(num_tokens, hidden_size)


@unified_flash_infer.register_fake
def _(
def unified_flash_infer_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -924,3 +921,11 @@ def _(
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()


direct_register_custom_op(
op_name="unified_flash_infer",
op_func=unified_flash_infer,
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)
34 changes: 23 additions & 11 deletions vllm/distributed/parallel_state.py
@@ -37,7 +37,7 @@
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
from vllm.utils import direct_register_custom_op, supports_custom_op


@dataclass
@@ -99,20 +99,23 @@ def _register_group(group: "GroupCoordinator") -> None:

if supports_custom_op():

@torch.library.custom_op("vllm::inplace_all_reduce",
mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce_in_place(tensor)

@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
return

@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
direct_register_custom_op(
op_name="inplace_all_reduce",
op_func=inplace_all_reduce,
mutates_args=["tensor"],
fake_impl=inplace_all_reduce_fake,
)

def outplace_all_reduce(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
@@ -121,10 +124,17 @@ def outplace_all_reduce(tensor: torch.Tensor,
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)

@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
def outplace_all_reduce_fake(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)

direct_register_custom_op(
op_name="outplace_all_reduce",
op_func=outplace_all_reduce,
mutates_args=[],
fake_impl=outplace_all_reduce_fake,
)


class GroupCoordinator:
"""
@@ -338,6 +348,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.world_size == 1:
return input_

if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_

if not supports_custom_op():
self._all_reduce_in_place(input_)
return input_
@@ -369,9 +384,6 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
elif input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
else:
torch.distributed.all_reduce(input_, group=self.device_group)

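Beyond the registration change, this file also moves the CPU branch out of _all_reduce_in_place and into all_reduce itself, so CPU tensors take the ipex path before any custom-op dispatch. A simplified, self-contained rendering of the resulting branch order (the reduce callback is a stand-in for the ipex / pynccl / torch.distributed backends):

# Simplified rendering of the branch order in GroupCoordinator.all_reduce after
# this commit; only the ordering mirrors the hunks above, and the names here
# are stand-ins for the coordinator's real state and backends.
from typing import Callable

import torch


def all_reduce_dispatch(input_: torch.Tensor,
                        world_size: int,
                        custom_ops_supported: bool,
                        reduce_in_place: Callable[[torch.Tensor], None]
                        ) -> torch.Tensor:
    if world_size == 1:
        return input_  # single rank: nothing to reduce

    if input_.is_cpu:
        # New here: CPU tensors are reduced (via ipex in the real code) and
        # return early, before any torch.ops.vllm.* custom-op dispatch.
        reduce_in_place(input_)
        return input_

    if not custom_ops_supported:
        reduce_in_place(input_)
        return input_

    # GPU with custom-op support: the real code goes through the registered
    # vllm::inplace_all_reduce / vllm::outplace_all_reduce ops so the
    # collective stays a single opaque node under torch.compile.
    reduce_in_place(input_)
    return input_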
25 changes: 19 additions & 6 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -8,6 +8,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op


def get_scalar_type(num_bits: int, has_zp: bool):
@@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
@@ -119,8 +119,7 @@ def single_marlin_moe(
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


@single_marlin_moe.register_fake
def _(
def single_marlin_moe_fake(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
@@ -136,7 +135,14 @@ def _(
return torch.empty_like(hidden_states)


@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
direct_register_custom_op(
op_name="single_marlin_moe",
op_func=single_marlin_moe,
mutates_args=[],
fake_impl=single_marlin_moe_fake,
)


def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -324,8 +330,7 @@ def fused_marlin_moe(
dim=1)


@fused_marlin_moe.register_fake
def _(
def fused_marlin_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -344,3 +349,11 @@ def _(
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)


direct_register_custom_op(
op_name="fused_marlin_moe",
op_func=fused_marlin_moe,
mutates_args=[],
fake_impl=fused_marlin_moe_fake,
)
(The remaining three changed files in this commit are not loaded on this page.)
