
Commit

don't pass ga loss kwargs to flash_attention
winglian committed Dec 6, 2024
1 parent 56811fd commit c605d68
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -1159,6 +1159,9 @@ def forward(
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
+        num_items_in_batch = kwargs.pop("num_items_in_batch")
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -1180,7 +1183,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
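For context, the change applies a general pattern: pop loss-only keyword arguments (here the gradient-accumulation kwarg num_items_in_batch) out of **kwargs before calling the backbone, so they are never forwarded down to the attention implementation, and pass them to the loss function explicitly instead. The sketch below is a minimal, framework-free illustration of that pattern with made-up names (ToyModel, backbone, toy_loss); it is not the transformers code, and it pops with a None default, whereas the committed change pops without one and so assumes the key is always present in kwargs.

def toy_loss(logits, labels, num_items_in_batch=None):
    # Under gradient accumulation, normalizing by the label-token count of the
    # whole accumulated batch keeps the loss scale consistent across micro-batches.
    total = sum(abs(p - y) for p, y in zip(logits, labels))
    denom = num_items_in_batch if num_items_in_batch is not None else len(labels)
    return total / denom


class ToyModel:
    def backbone(self, input_ids, **kwargs):
        # A real backbone forwards **kwargs down to its attention layers, so an
        # unexpected key like num_items_in_batch would surface as a TypeError there.
        if "num_items_in_batch" in kwargs:
            raise TypeError("attention got an unexpected keyword argument 'num_items_in_batch'")
        return input_ids  # stand-in for logits

    def forward(self, input_ids, labels=None, **kwargs):
        # Same move as the diff above: remove the gradient-accumulation kwarg
        # before it can reach the attention implementation via **kwargs...
        num_items_in_batch = kwargs.pop("num_items_in_batch", None)
        logits = self.backbone(input_ids, **kwargs)
        loss = None
        if labels is not None:
            # ...and hand it to the loss function explicitly instead.
            loss = toy_loss(logits, labels, num_items_in_batch=num_items_in_batch)
        return logits, loss


model = ToyModel()
_, loss = model.forward([1, 2, 3], labels=[1, 2, 2], num_items_in_batch=12)
print(loss)  # 1/12 (~0.083) instead of 1/3: normalized over the accumulated batch

Running the sketch prints 1/12 rather than 1/3: the loss is divided by the token count of the full accumulated batch, while the remaining kwargs reach the backbone untouched.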
