Skip to content

Commit

Permalink
avoid unnecessary unwrapping of the model on each train step
Browse files · Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 18, 2024
1 parent 9a94dfe commit 7185dd5
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,12 @@ def __init__(
forward_params = inspect.signature(model_forward).parameters
self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())

if _is_peft_model(unwrapped_model):
self.model_name = unwrapped_model.base_model.model._get_name()
else:
self.model_name = unwrapped_model._get_name()


self.neftune_noise_alpha = args.neftune_noise_alpha

self.compute_metrics = compute_metrics
Expand Down Expand Up @@ -3727,15 +3733,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
self._past = outputs[self.args.past_index]

if labels is not None:
unwrapped_model = self.accelerator.unwrap_model(model)
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
# User-defined compute_loss function
if self.compute_loss_func is not None:
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
elif self.model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
Expand Down

0 comments on commit 7185dd5

Please sign in to comment.