Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Core][Model] torch.compile for layernorm in commandr (vllm-project#3985
Browse files Browse the repository at this point in the history
)

[Core][Model] Use torch.compile to accelerate layernorm in commandr (vllm-project#3985)
  • Loading branch information
youkaichao authored and SageMoore committed Apr 11, 2024
1 parent 09e9e8a commit 782c1cf
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@
from vllm.sequence import SamplerOutput


@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
variance_epsilon)
hidden_states = weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)


class LayerNorm(nn.Module):

def __init__(self, param_shape=None, eps=1e-5):
Expand All @@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5):
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

def forward(self, hidden_states, residuals=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states -
mean) * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype), residuals
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
Expand Down

0 comments on commit 782c1cf

Please sign in to comment.