From 4a523ad4a50ad6b95f0d52d3d9d95fa2f5374614 Mon Sep 17 00:00:00 2001
From: Sandeep Subramanian
Date: Tue, 22 Nov 2022 10:08:01 -0800
Subject: [PATCH] Change to kwargs (#5475)

Signed-off-by: MaximumEntropy

Signed-off-by: MaximumEntropy
---
 .../nlp/modules/common/megatron/transformer.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index 12e18c853bce..6a0f0e5be6f4 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -2185,13 +2185,13 @@ def custom_forward(*inputs):
                     for index in range(start, end):
                         layer = self._get_layer(index)
                         hidden_states = layer(
-                            hidden_states,
-                            attention_mask,
-                            encoder_output,
-                            enc_dec_attn_mask,
-                            rotary_pos_emb,
-                            self_attention_relative_position_bias,
-                            cross_attention_relative_position_bias,
+                            hidden_states=hidden_states,
+                            attention_mask=attention_mask,
+                            encoder_output=encoder_output,
+                            enc_dec_attn_mask=enc_dec_attn_mask,
+                            rotary_pos_emb=rotary_pos_emb,
+                            self_attention_relative_position_bias=self_attention_relative_position_bias,
+                            cross_attention_relative_position_bias=cross_attention_relative_position_bias,
                         )
                         if isinstance(hidden_states, tuple):
                             pass
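Note: the hunk above switches the checkpointed layer call from positional to
keyword arguments, so the call no longer depends on the parameter order of the
layer's forward(). Below is a minimal standalone sketch of the failure mode
this avoids; the Layer class is hypothetical (not NeMo's
ParallelTransformerLayer) and exists only to illustrate a signature gaining a
new parameter.

    import torch
    import torch.nn as nn

    class Layer(nn.Module):
        # Hypothetical layer whose signature gained a new `layer_past`
        # parameter inserted before `rotary_pos_emb` in a later release.
        def forward(self, hidden_states, attention_mask, encoder_output=None,
                    enc_dec_attn_mask=None, layer_past=None, rotary_pos_emb=None):
            return hidden_states

    layer = Layer()
    h = torch.zeros(2, 4)
    mask = None
    rope = "rotary-embedding-placeholder"

    # Positional call written against the old signature: `rope` silently
    # lands in the new `layer_past` slot and no error is raised.
    out = layer(h, mask, None, None, rope)

    # Keyword call: each value binds to the intended parameter, and a
    # renamed or reordered parameter fails loudly instead of silently.
    out = layer(hidden_states=h, attention_mask=mask, rotary_pos_emb=rope)

With seven same-typed arguments threaded through activation checkpointing,
keyword binding makes any future mismatch an immediate TypeError rather than
a silent mis-wiring.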