fix transformers>=4.46 loss (#2365)
Jintao-Huang authored Oct 31, 2024
1 parent 59f58c3 · commit 44302ab
Showing 1 changed file with 4 additions and 1 deletion.
swift/trainers/trainers.py (5 changes: 4 additions & 1 deletion)
@@ -141,7 +141,7 @@ def prediction_step(
 
         return loss, generated_tokens, labels
 
-    def compute_loss(self, model, inputs, return_outputs=None, **kwargs):
+    def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
         if not hasattr(self, '_custom_metrics'):
             self._custom_metrics = {}
 
@@ -159,6 +159,9 @@ def compute_loss(self, model, inputs, return_outputs=None, **kwargs):
 
         loss_kwargs['labels'] = labels
         outputs = model(**inputs)
+        # fix https://github.com/huggingface/transformers/issues/34263
+        if outputs.loss is not None and num_items_in_batch is not None:
+            outputs.loss = outputs.loss * (inputs['labels'][:, 1:] != -100).sum() / num_items_in_batch
         if loss_name is not None:
             loss_func = get_loss_func(loss_name)
             outputs['loss'] = loss_func(outputs, **loss_kwargs)
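
For context: from transformers 4.46 the Trainer passes num_items_in_batch into compute_loss so that gradient accumulation normalizes by the token count of the whole accumulated batch rather than averaging per micro-batch. The added multiplier converts the mean cross-entropy returned by the model back into a token sum (count of non -100 target positions, shifted by one) and divides by num_items_in_batch. A minimal standalone sketch of that arithmetic follows; the function name and example tensors are illustrative, not part of the commit:

import torch

def rescale_mean_loss(mean_loss: torch.Tensor,
                      labels: torch.Tensor,
                      num_items_in_batch: int) -> torch.Tensor:
    # Count the target tokens that actually contribute to the loss:
    # labels are shifted by one position, hence the [:, 1:] slice,
    # and -100 marks positions ignored by cross-entropy.
    num_valid_tokens = (labels[:, 1:] != -100).sum()
    # mean -> sum over this micro-batch, then normalize by the token
    # count of the whole gradient-accumulated batch.
    return mean_loss * num_valid_tokens / num_items_in_batch

labels = torch.tensor([[0, 1, 2, 3, -100],
                       [0, 4, 5, 6, 7]])       # 3 + 4 = 7 valid target tokens
mean_loss = torch.tensor(2.0)                  # mean loss returned by the model
# Suppose the full accumulated batch contains 14 target tokens in total.
print(rescale_mean_loss(mean_loss, labels, num_items_in_batch=14))  # tensor(1.)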