Suggestion for introducing "shift_labels" argument for Trainer #17960

Closed
seungeunrho opened this issue Jun 30, 2022 · 6 comments · Fixed by #17987

Comments

seungeunrho commented Jun 30, 2022

Feature request

Add an argument that determines whether or not to shift the labels.

An argument named shift_labels should be added to the TrainingArguments class.

During training, at the two linked places in the modeling code, the model must check both that labels is not None and that self.shift_labels is True, e.g.:

if labels is not None and self.shift_labels:          # changed
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

The default value of shift_labels would be False, except for causal language models such as GPT2PreTrainedModel.

Related to GPT-2: @patil-suraj, and to the Trainer: @sgugger.
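
A minimal sketch of the proposed TrainingArguments field (hypothetical; as the discussion below explains, this exact flag was never merged):

from dataclasses import dataclass, field

@dataclass
class TrainingArguments:
    # ... existing fields ...
    shift_labels: bool = field(
        default=False,
        metadata={"help": "Whether the model shifts labels internally when computing the loss."},
    )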

Motivation

In the current state of the code, whether the labels are shifted when training GPT2LMHeadModel changes depending on whether label_smoothing is used, which I assume is unintended.

Specifically, when training a GPT2LMHeadModel with args.label_smoothing_factor == 0 (the default), the code shifts the labels and computes the loss inside model.forward(). This path assumes the labels have not been shifted, i.e. that they are still aligned one-to-one with the corresponding input_ids.

However, if I train GPT2LMHeadModel with args.label_smoothing_factor > 0, the loss is instead computed inside the compute_loss() function of the Trainer. This path assumes the labels are already shifted and does not shift them again.

I believe whether to shift the labels should be controlled explicitly by its own argument, not implicitly by an unrelated one like label_smoothing_factor. Our team was very frustrated to find that our training results were totally different after only changing label_smoothing, with the same labels and input_ids. The cause was the misalignment of labels and input_ids once label_smoothing was turned on.
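
To illustrate the two paths (a simplified sketch using plain cross-entropy in place of the actual label smoother):

import torch
import torch.nn.functional as F

vocab = 100
logits = torch.randn(1, 5, vocab)         # (batch, seq_len, vocab)
labels = torch.randint(0, vocab, (1, 5))  # unshifted, aligned with input_ids

# Path 1 (label_smoothing_factor == 0): loss inside model.forward(), labels shifted
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_shifted = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))

# Path 2 (pre-fix, label_smoothing_factor > 0): label smoother, labels NOT shifted
loss_unshifted = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))

# The two paths optimize different objectives: position t predicting token t+1
# (correct for a causal LM) versus position t predicting token t.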

Your contribution

I'm willing to make a PR after your confirmation.

sgugger commented Jun 30, 2022

I don't think a new TrainingArgument is the right answer here. Some models shift the labels internally; I think it's all the causal LM models, not just GPT-2. So instead of a flag, when the Trainer computes the loss for label smoothing, it should check whether the model's class name is in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES (imported from the auto module) and, if so, shift the labels.
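
A sketch of that check (this mirrors the shape the fix eventually took; compare the Trainer snippet quoted later in this thread):

from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

# Inside Trainer.compute_loss, when label smoothing is active:
if labels is not None:
    if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
        # Causal LMs shift labels inside forward(), so the smoother must shift too.
        loss = self.label_smoother(outputs, labels, shift_labels=True)
    else:
        loss = self.label_smoother(outputs, labels)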

Let me know if you'd like to proceed with a PR for this fix!

seungeunrho (author) commented:
Thanks for the quick reply. Your approach seems plausible, and I'd like to proceed with it.
I've read the contribution guide thoroughly. Can I just start now, or is there anything I should know before I begin?

sgugger commented Jun 30, 2022

You can start, good luck! :-)

seungeunrho added a commit to seungeunrho/transformers that referenced this issue Jul 1, 2022
When training a causal LM, the loss is computed within the model's forward() function and
the labels are shifted internally. However, if label smoothing is applied, the loss is
computed in the Trainer's compute_loss function and the labels are not shifted.
This causes unintended misalignment between the labels and the corresponding
inputs. This commit resolves that confusion.

Resolves huggingface#17960

On branch shift_labels_for_causalLM
Changes to be committed:
	modified:   src/transformers/trainer.py
	modified:   src/transformers/trainer_pt_utils.py
sgugger added a commit that referenced this issue Jul 1, 2022
* Shifting labels for causal LM when using label smoother

When training a causal LM, the loss is computed within the model's forward() function and
the labels are shifted internally. However, if label smoothing is applied, the loss is
computed in the Trainer's compute_loss function and the labels are not shifted.
This causes unintended misalignment between the labels and the corresponding
inputs. This commit resolves that confusion.

Resolves #17960

On branch shift_labels_for_causalLM
Changes to be committed:
	modified:   src/transformers/trainer.py
	modified:   src/transformers/trainer_pt_utils.py

* Update trainer.py

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <[email protected]>

Co-authored-by: Sylvain Gugger <[email protected]>
viclzhu pushed a commit to viclzhu/transformers that referenced this issue Jul 18, 2022
Junpliu commented Dec 12, 2023

I want to know more about what the predicted text looks like in the label-smoothing case before the bug fix. Does the model learn an identity transformation and always predict the last input token repeatedly? I am curious about this.

zzaebok commented Aug 9, 2024

@sgugger May I ask how to shift labels for a custom causal LM model?
Let's assume I made a CustomCausalLM model that is not mapped in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.
Like other models, my causal LM (modeled after modeling_llama.py) shifts labels in its forward function.
When I use the HF Trainer, my CustomCausalLM, and the label smoother, how do I shift the labels in this case?

The current Trainer code is:

if labels is not None:
    if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
        loss = self.label_smoother(outputs, labels, shift_labels=True)
    else:
        loss = self.label_smoother(outputs, labels)

I think the workaround is to add shift_labels=True in the else branch.
Is there a correct, or better, way to shift labels for my custom causal LM model when using the label smoother?
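
One way to apply that workaround without patching the library is to subclass the Trainer and force the shift (a hypothetical sketch; MyTrainer is illustrative, and it assumes label_smoothing_factor > 0 so that self.label_smoother is set):

from transformers import Trainer

class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Force shift_labels=True for a custom causal LM that is not
        # registered in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.
        loss = self.label_smoother(outputs, labels, shift_labels=True)
        return (loss, outputs) if return_outputs else loss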

ArthurZucker commented:
By default, if you copy the end of the forward pass of LlamaForCausalLM, you will see that there is a part where the labels are shifted!
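
The pattern referred to here looks like this (a simplified excerpt of the standard causal-LM loss computation, not the exact Llama source):

import torch.nn as nn

loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))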
