Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] directly register custom op #9896

Merged
merged 25 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/compile/piecewise/piecewise_compilation_config.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"use_cudagraph": true,
"non_cudagraph_ops": ["silly.attention"]
"non_cudagraph_ops": ["vllm.toy_attention"]
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
}
20 changes: 14 additions & 6 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.utils import direct_register_custom_op

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

global_counter = 0


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
global global_counter
Expand All @@ -27,12 +27,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out[0] += 1


@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return


direct_register_custom_op(
library_name="vllm",
op_name="toy_attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
)


@support_torch_compile
class SillyModel(nn.Module):

Expand All @@ -49,12 +57,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
x = x + 2
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
torch.ops.vllm.toy_attention(x, x, x, out)
x = out
x = x - 2
x = x - 1
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
torch.ops.vllm.toy_attention(x, x, x, out)
x = out
x = x + 1
return x
Expand Down
20 changes: 14 additions & 6 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,30 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_compilation_config
from vllm.utils import direct_register_custom_op


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v


@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return


direct_register_custom_op(
library_name="vllm",
op_name="toy_attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
)


@dataclass
class LlamaConfig:
hidden_size: int = 128
Expand Down Expand Up @@ -95,7 +103,7 @@ def forward(
k = k + positions.unsqueeze(1)

attn_output = torch.empty_like(q)
torch.ops.silly.attention(q, k, v, attn_output)
torch.ops.vllm.toy_attention(q, k, v, attn_output)

output = self.output_projection(attn_output)
return output
Expand Down Expand Up @@ -171,7 +179,7 @@ def run_model(llama_config,
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
non_cudagraph_ops=["vllm.toy_attention"],
))
else:
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
Expand Down
17 changes: 12 additions & 5 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
make_tensor_with_pad)

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
Expand Down Expand Up @@ -595,8 +596,6 @@ def forward(
return output


@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -755,8 +754,7 @@ def unified_flash_attention(
return output.view(num_tokens, hidden_size)


@unified_flash_attention.register_fake
def _(
def unified_flash_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand All @@ -773,3 +771,12 @@ def _(
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)


direct_register_custom_op(
library_name="vllm",
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
fake_impl=unified_flash_attention_fake,
)
18 changes: 12 additions & 6 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
get_kv_cache_torch_dtype, make_tensor_with_pad)

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
Expand Down Expand Up @@ -785,8 +785,6 @@ def forward(
)


@torch.library.custom_op("vllm::unified_flash_infer",
mutates_args=["kv_cache"])
def unified_flash_infer(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -906,8 +904,7 @@ def unified_flash_infer(
return output.view(num_tokens, hidden_size)


@unified_flash_infer.register_fake
def _(
def unified_flash_infer_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand All @@ -924,3 +921,12 @@ def _(
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()


direct_register_custom_op(
library_name="vllm",
op_name="unified_flash_infer",
op_func=unified_flash_infer,
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)
36 changes: 25 additions & 11 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
from vllm.utils import direct_register_custom_op, supports_custom_op


@dataclass
Expand Down Expand Up @@ -99,20 +99,24 @@ def _register_group(group: "GroupCoordinator") -> None:

if supports_custom_op():

@torch.library.custom_op("vllm::inplace_all_reduce",
mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce_in_place(tensor)

@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
return

@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
direct_register_custom_op(
library_name="vllm",
op_name="inplace_all_reduce",
op_func=inplace_all_reduce,
mutates_args=["tensor"],
fake_impl=inplace_all_reduce_fake,
)

def outplace_all_reduce(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
Expand All @@ -121,10 +125,18 @@ def outplace_all_reduce(tensor: torch.Tensor,
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)

@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
def outplace_all_reduce_fake(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)

direct_register_custom_op(
library_name="vllm",
op_name="outplace_all_reduce",
op_func=outplace_all_reduce,
mutates_args=[],
fake_impl=outplace_all_reduce_fake,
)


class GroupCoordinator:
"""
Expand Down Expand Up @@ -338,6 +350,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.world_size == 1:
return input_

if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_

if not supports_custom_op():
self._all_reduce_in_place(input_)
return input_
Expand Down Expand Up @@ -369,9 +386,6 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
elif input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
else:
torch.distributed.all_reduce(input_, group=self.device_group)

Expand Down
27 changes: 21 additions & 6 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op


def get_scalar_type(num_bits: int, has_zp: bool):
Expand All @@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
Expand Down Expand Up @@ -119,8 +119,7 @@ def single_marlin_moe(
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


@single_marlin_moe.register_fake
def _(
def single_marlin_moe_fake(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
Expand All @@ -136,7 +135,15 @@ def _(
return torch.empty_like(hidden_states)


@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
direct_register_custom_op(
library_name="vllm",
op_name="single_marlin_moe",
op_func=single_marlin_moe,
mutates_args=[],
fake_impl=single_marlin_moe_fake,
)


def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
Expand Down Expand Up @@ -324,8 +331,7 @@ def fused_marlin_moe(
dim=1)


@fused_marlin_moe.register_fake
def _(
def fused_marlin_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
Expand All @@ -344,3 +350,12 @@ def _(
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)


direct_register_custom_op(
library_name="vllm",
op_name="fused_marlin_moe",
op_func=fused_marlin_moe,
mutates_args=[],
fake_impl=fused_marlin_moe_fake,
)
Loading