diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 24790386d..9df1f5c16 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -687,6 +687,8 @@ def build_layer(layer_number):
         if args.position_embedding_type == PositionEmbeddingType.alibi:
             self.alibi = self._build_alibi_tensor(args.seq_length, args.num_attention_heads, args.micro_batch_size).to(torch.cuda.current_device())
             if args.params_dtype == torch.float16:
+                self.alibi = self.alibi.to(torch.float16)
+            elif args.params_dtype == torch.bfloat16:
                 self.alibi = self.alibi.to(torch.bfloat16)
         else:
             self.alibi = None
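The bug fixed here: the original code cast the ALiBi bias tensor to bfloat16 inside the `torch.float16` branch, so fp16 runs ended up with a bfloat16 bias that mismatched the parameter dtype. The patch casts to float16 in that branch and adds a separate `elif` for bfloat16. Below is a minimal, self-contained sketch of the corrected cast pattern; `cast_alibi` and the tensor shape are illustrative names chosen here, not part of the Megatron codebase.

```python
import torch

def cast_alibi(alibi: torch.Tensor, params_dtype: torch.dtype) -> torch.Tensor:
    """Match the ALiBi bias tensor to the model's parameter dtype.

    Sketch of the fixed logic: each half-precision dtype gets its own
    cast, instead of fp16 falling through to a bfloat16 cast.
    """
    if params_dtype == torch.float16:
        return alibi.to(torch.float16)    # previously cast to bfloat16 by mistake
    elif params_dtype == torch.bfloat16:
        return alibi.to(torch.bfloat16)
    return alibi                          # fp32: leave the tensor as built

# Usage: the bias is built in fp32, then matched to the training dtype.
alibi = torch.zeros(16, 2048, 2048)      # (heads, seq, seq), illustrative shape
alibi = cast_alibi(alibi, torch.float16)
assert alibi.dtype == torch.float16
```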