register meta functions to the kernels
mzusman committed Aug 25, 2024
1 parent 42d9c59 commit 308c922
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions vllm/_custom_ops.py
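These additions register fake ("meta") implementations for the Mamba-related custom kernels. A fake impl tells PyTorch only the shape, dtype, and device of each output, so FakeTensor tracing and torch.compile can reason about the ops without running the real CUDA kernels. Each registration is guarded by a bare attribute access (torch.ops._C.<op>) inside try/except, so it is silently skipped in builds where the _C extension does not provide that kernel. The sketch below is not part of the commit: it shows the same torch.library.register_fake mechanism (PyTorch 2.4+) on a made-up, Python-defined op named mylib::scale, whereas vLLM's kernels are registered from C++ and only need the fake side added.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Illustrative op, defined in Python so the example is self-contained; vLLM's
# real kernels live in the C++ _C extension and only receive the fake impls.
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@torch.library.register_fake("mylib::scale")
def scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Describes only the output metadata; never touches real data.
    return torch.empty_like(x)

# Under FakeTensorMode the fake impl is dispatched instead of the real kernel.
with FakeTensorMode():
    y = scale(torch.empty(8, 16), 2.0)
    print(y.shape)  # torch.Size([8, 16]), obtained without running scale()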
@@ -489,13 +489,41 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                                          silu_activation)


try:
    torch.ops._C.causal_conv1d_fwd  # noqa B018

    @torch.library.register_fake("_C::causal_conv1d_fwd")
    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                               bias_: Optional[torch.Tensor],
                               seq_idx_: Optional[torch.Tensor],
                               initial_states_: Optional[torch.Tensor],
                               final_states_out_: Optional[torch.Tensor],
                               silu_activation: bool) -> torch.Tensor:
        return torch.empty_like(x)
except Exception:
    pass


def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool) -> torch.Tensor:
    return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
                                             silu_activation)


try:
    torch.ops._C.causal_conv1d_update  # noqa B018

    @torch.library.register_fake("_C::causal_conv1d_update")
    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
                                  weight: torch.Tensor,
                                  bias_: Optional[torch.Tensor],
                                  silu_activation: bool) -> torch.Tensor:
        return torch.empty_like(x)
except Exception:
    pass


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
@@ -507,6 +535,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                                           x)


try:
    torch.ops._C.selective_scan_fwd  # noqa B018

    @torch.library.register_fake("_C::selective_scan_fwd")
    def selective_scan_fwd_fake(
            u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
            B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
            z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
            delta_softplus: bool, index_: Optional[torch.Tensor],
            x: Optional[torch.Tensor]) -> List[torch.Tensor]:
        return [
            torch.empty_like(u),
            torch.empty((u.size(0), u.size(1), A.size(1)),
                        dtype=u.dtype,
                        device=u.device)
        ]
except Exception:
    pass


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
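A quick sanity check (not part of the commit, and assuming a vLLM build in which the _C extension exposes these kernels) is to call one of the wrappers under FakeTensorMode: the newly registered fake supplies the output shape while the CUDA kernel itself never executes.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from vllm import _custom_ops as ops

with FakeTensorMode():
    x = torch.empty(2, 64, 128)  # illustrative shapes only
    weight = torch.empty(64, 4)
    out = ops.causal_conv1d_fwd(x, weight, None, None, None, None, False)
    print(out.shape)  # torch.Size([2, 64, 128]), from causal_conv1d_fwd_fake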
