From 2941d3f46ae7bb18ece919801e5afbb5f2cf0e30 Mon Sep 17 00:00:00 2001
From: Ankit Gupta
Date: Fri, 22 Mar 2024 19:22:27 -0400
Subject: [PATCH] Update long_conv_lm.py

---
 src/models/sequence/long_conv_lm.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/models/sequence/long_conv_lm.py b/src/models/sequence/long_conv_lm.py
index 53c8b10..ace9359 100644
--- a/src/models/sequence/long_conv_lm.py
+++ b/src/models/sequence/long_conv_lm.py
@@ -28,9 +28,9 @@
     ColumnParallelLinear = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm = None
+    layer_norm_fn = None
 
 from src.utils import instantiate
 import src.utils.registry as registry
@@ -301,8 +301,8 @@ def __init__(
         # nn.Dropout probabilities are changed.
         # This is for performance reason: we can fuse dropout + add + layer_norm.
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
 
         self.layers = nn.ModuleList(
             [
@@ -384,15 +384,14 @@ def forward(self, input_ids, position_ids=None, inference_params=None):
             hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
         else:
             # Set prenorm=False here since we don't need the residual
-            hidden_states = dropout_add_layer_norm(
+            hidden_states = layer_norm_fn(
                 hidden_states,
-                residual,
                 self.ln_f.weight,
                 self.ln_f.bias,
-                self.drop_f.p if self.training else 0.0,
-                self.ln_f.eps,
+                residual=residual,
+                eps=self.ln_f.eps,
+                dropout_p=self.drop_f.p if self.training else 0.0,
                 prenorm=False,
-                residual_in_fp32=self.residual_in_fp32,
             )
         return hidden_states
 
@@ -687,4 +686,4 @@ def shard_qkv_headdim(state_dict, key):
         for name in ["kernel.kernel.C", "ssm_k_kernel.kernel.C"]:
             if f"backbone.layers.{i}.mixer.{name}" in state_dict:
                 shard_dim(state_dict, f"backbone.layers.{i}.mixer.{name}", 1)
-    return state_dict
\ No newline at end of file
+    return state_dict
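Note on the change: the patch swaps the old fused CUDA extension dropout_add_layer_norm for
the Triton-based layer_norm_fn, which takes the residual, eps, and dropout probability as
keyword arguments. Dropping residual_in_fp32 looks consistent with prenorm=False, since the
fused call then returns only the normalized output and no residual. Below is a minimal
standalone sketch of the new call against an unfused PyTorch reference; the shapes, the
fp16/CUDA setup, and variable names are illustrative assumptions, not taken from
long_conv_lm.py.

import torch
import torch.nn.functional as F

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None

# Illustrative shapes and dtype; the Triton kernel needs a CUDA device.
batch, seqlen, d_model = 2, 128, 512
ln = torch.nn.LayerNorm(d_model, device="cuda", dtype=torch.float16)
drop_p = 0.0  # mirrors `self.drop_f.p if self.training else 0.0` at eval time
hidden_states = torch.randn(batch, seqlen, d_model, device="cuda", dtype=torch.float16)
residual = torch.randn_like(hidden_states)

# Unfused reference: dropout -> add -> layer_norm, as in the non-fused branch.
ref = ln((F.dropout(hidden_states, p=drop_p) + residual).to(ln.weight.dtype))

if layer_norm_fn is not None:
    # Fused Triton path, matching the call site in the patch: with
    # prenorm=False only the normalized output is returned.
    out = layer_norm_fn(
        hidden_states,
        ln.weight,
        ln.bias,
        residual=residual,
        eps=ln.eps,
        dropout_p=drop_p,
        prenorm=False,
    )
    # Small fp16 discrepancy is expected; the fused kernel accumulates in fp32.
    print((out.float() - ref.float()).abs().max())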