diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index aa27f0a96c745..aa9b28b676e0b 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -48,6 +48,18 @@
 from vllm.sequence import SamplerOutput
 
 
+@torch.compile
+def layer_norm_func(hidden_states, weight, variance_epsilon):
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    mean = hidden_states.mean(-1, keepdim=True)
+    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
+                                                         variance_epsilon)
+    hidden_states = weight.to(torch.float32) * hidden_states
+    return hidden_states.to(input_dtype)
+
+
 class LayerNorm(nn.Module):
 
     def __init__(self, param_shape=None, eps=1e-5):
@@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5):
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
     def forward(self, hidden_states, residuals=None):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        mean = hidden_states.mean(-1, keepdim=True)
-        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-        hidden_states = (hidden_states -
-                         mean) * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = self.weight.to(torch.float32) * hidden_states
-        return hidden_states.to(input_dtype), residuals
+        hidden_states = layer_norm_func(hidden_states, self.weight,
+                                        self.variance_epsilon)
+        return hidden_states, residuals
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
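
A minimal sketch of how the new compiled path could be sanity-checked against PyTorch's reference LayerNorm, assuming the patch above is applied; the standalone harness, its shapes, and its tolerances are assumptions, not part of the patch:

import torch
import torch.nn.functional as F

from vllm.model_executor.models.commandr import layer_norm_func

# Prefer GPU so torch.compile can fuse the chain of elementwise ops into a
# single kernel; the function also compiles and runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical shapes: (batch, seq_len, hidden_size) activations in fp16.
hidden = torch.randn(2, 16, 1024, dtype=torch.float16, device=device)
weight = torch.ones(1024, dtype=torch.float16, device=device)

out = layer_norm_func(hidden, weight, 1e-5)

# Reference computed in float32 and cast back, mirroring the patched code's
# dtype round-trip (upcast, normalize with mean subtraction, downcast).
ref = F.layer_norm(hidden.float(), (1024,), weight.float(),
                   eps=1e-5).to(hidden.dtype)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)

Note that layer_norm_func subtracts the mean, so it is a true LayerNorm rather than the RMSNorm most vLLM models use; the compiled version should be numerically equivalent to the removed eager code up to fusion-level floating-point differences.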