Suggestion for introducing "shift_labels" argument for Trainer #17960
Comments
I don't think a new TrainingArgument is the right answer here. Some models shift the labels internally; I think it's all the models for causal LM (not just GPT-2), so instead of a flag, there should be a check when the loss is computed by the label smoother.

Let me know if you'd like to proceed with a PR for this fix!
Thanks for the quick reply. Your approach seems plausible and I'd like to proceed with it.
You can start, good luck! :-)
When training a causal LM, the loss is computed within the model's `forward()` function and the labels are shifted internally. However, if label smoothing is applied, the loss is computed in the trainer's `compute_loss` function and the labels are not shifted. This causes unintended misalignment between the labels and the corresponding inputs. This commit resolves that confusion. Resolves huggingface#17960

On branch shift_labels_for_causalLM
Changes to be committed:
modified: src/transformers/trainer.py
modified: src/transformers/trainer_pt_utils.py
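The shift applied by the fix can be sketched as follows. This is a simplified, illustrative version of a label-smoothed loss (padding/`ignore_index` handling omitted, and the function name `smoothed_loss` is an assumption, not the library's API): with `shift_labels=True`, the last logit and the first label are dropped so that position `t` of the logits is scored against token `t+1`, matching what causal LMs do inside `forward()`.

```python
import torch
import torch.nn.functional as F

def smoothed_loss(logits, labels, epsilon=0.1, shift_labels=False):
    """Simplified label-smoothed loss (padding handling omitted).

    With shift_labels=True, position t of the logits is scored against
    token t+1, mirroring the internal shift causal LMs apply in forward().
    """
    if shift_labels:
        logits = logits[..., :-1, :].contiguous()  # drop prediction for position after the last token
        labels = labels[..., 1:].contiguous()      # drop the first token: nothing predicts it
    log_probs = -F.log_softmax(logits, dim=-1)
    nll = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # per-position negative log-likelihood
    smooth = log_probs.mean(dim=-1)                               # uniform smoothing term over the vocab
    return ((1.0 - epsilon) * nll + epsilon * smooth).mean()
```

With logits that confidently predict the next token, the shifted loss is small while the unshifted loss is large, which is exactly the misalignment the issue describes.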
* Shifting labels for causal LM when using label smoother (…17987)
* Update trainer.py
* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <[email protected]>
I want to know more about what the prediction text looks like in the label-smoothing case before the bug fix. Does the model learn an identity transformation and always predict the last input token repeatedly? I am curious about this.
@sgugger May I ask how to shift labels for a custom CausalLM model? The current Trainer code is:

```python
if labels is not None:
    if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
        loss = self.label_smoother(outputs, labels, shift_labels=True)
    else:
        loss = self.label_smoother(outputs, labels)
```

I think the workaround is to add `shift_labels=True` in the `else` branch.
By default, if you copy the end of the forward pass of the
Feature request
Add an argument to determine whether or not to shift the `labels`. In the `TrainingArguments` class, an argument named `shift_labels` should be added. During training, at here and here, the model must check both `labels is not None` and `self.shift_labels is True`, e.g.
The default value for `shift_labels` is `False`, except for causal language models such as `GPT2PreTrainedModel`.
Related to gpt2: @patil-suraj, and trainer: @sgugger
Motivation
In the current state of the code, whether the `labels` are shifted when training `GPT2LMHeadModel` changes depending on the use of `label_smoothing`, which I assume is unintended. Specifically, when training a `GPT2LMHeadModel` with `args.label_smoothing_factor == 0` (the default), the code shifts the `labels` and computes the loss inside `model.forward()`. This path assumes that the `labels` have not been shifted and are aligned one-to-one with the corresponding `input_ids`. However, if I train `GPT2LMHeadModel` with `args.label_smoothing_factor > 0`, then the loss is computed here, inside the `compute_loss()` function of the `Trainer`. This part assumes the `labels` are already shifted, and does not shift them.

I believe whether to shift the `labels` should be explicitly determined by its own argument, not implicitly by an unrelated argument like `label_smoothing_factor`. In my case, our team was very frustrated that our training results were totally different after only changing `label_smoothing`, with the same `labels` and `input_ids`. The cause was the misalignment of `labels` and `input_ids` when `label_smoothing` was turned on.
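The misalignment can be seen with a toy example. For next-token prediction, the target at output position `t` should be input token `t+1`; without the shift, position `t` is scored against its own input token, rewarding an identity mapping. A minimal sketch with plain Python lists (the helper `align_pairs` is illustrative, not part of any library):

```python
def align_pairs(input_ids, labels, shift):
    """Return the (input_token, target_token) pairs the loss would score.

    With shift=True (what causal LMs do internally), position t predicts
    token t+1. With shift=False, position t is scored against its own
    input token, i.e. the model is rewarded for copying its input.
    """
    if shift:
        return list(zip(input_ids[:-1], labels[1:]))
    return list(zip(input_ids, labels))

tokens = ["The", "cat", "sat", "down"]
shifted = align_pairs(tokens, tokens, shift=True)
unshifted = align_pairs(tokens, tokens, shift=False)
# shifted   -> [("The", "cat"), ("cat", "sat"), ("sat", "down")]  next-token targets
# unshifted -> [("The", "The"), ("cat", "cat"), ...]              identity targets
```

Flipping `label_smoothing` on therefore silently switches between these two target assignments, which explains the drastically different training results.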
Your contribution
I'm willing to make a PR after your confirmation.