
Finetuning doesn't initialize microsoft/resnet classifier weights with _fast_init #31841

Closed
2 of 4 tasks
williford opened this issue Jul 8, 2024 · 5 comments · Fixed by #31851

@williford

williford commented Jul 8, 2024

System Info

It seems that the changes in #11471 broke fine-tuning of ResNet when the number of classes is changed.

It seems like most models handle this by including nn.Linear in their _init_weights method, for example:

def _init_weights(self, module):
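For illustration, here is a minimal sketch of such a hook, assuming a ResNet-style _init_weights (kaiming-normal for Conv2d, constants for BatchNorm/GroupNorm) extended with an nn.Linear branch that simply re-runs PyTorch's default initialization via reset_parameters. The function name resnet_style_init_weights and the toy head are hypothetical; this is not the actual transformers code.

```python
import torch
from torch import nn


def resnet_style_init_weights(module: nn.Module) -> None:
    """Hypothetical _init_weights-style hook with an added nn.Linear branch."""
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Linear):
        # Re-run PyTorch's own default init for Linear layers, i.e. what the
        # layer would get if its initialization had not been skipped.
        module.reset_parameters()


# Toy head shaped like ResNet-50's classifier: 2048 pooled features -> 10 labels.
head = nn.Sequential(nn.Flatten(), nn.Linear(2048, 10))

# Simulate a skipped initialization by filling the parameters with NaN.
with torch.no_grad():
    head[1].weight.fill_(float("nan"))
    head[1].bias.fill_(float("nan"))

head.apply(resnet_style_init_weights)
print(torch.isfinite(head[1].weight).all().item())  # True
```

Delegating to reset_parameters keeps the fast-init result identical in distribution to the regular (non-fast) path instead of inventing a separate scheme for the classifier.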

However, it might be better to handle this at the point where the size mismatch is detected, in modeling_utils.py:

mismatched_keys += _find_mismatched_keys(
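A rough sketch of that alternative, assuming only a list of mismatched parameter names is available at that point (the real _find_mismatched_keys call is truncated above, so its exact return format is not assumed): map each name to the submodule that owns it and re-run that submodule's reset_parameters, if it has one. reinit_mismatched and ToyClassifierModel are hypothetical names for illustration, not transformers APIs.

```python
import torch
from torch import nn


def reinit_mismatched(model: nn.Module, mismatched_param_names) -> list:
    """Hypothetical helper: re-run default init for modules owning mismatched params."""
    # "classifier.1.weight" -> owning module path "classifier.1"
    module_paths = {name.rpartition(".")[0] for name in mismatched_param_names}
    reset = []
    for path in sorted(module_paths):
        module = model.get_submodule(path)
        # Only modules that know how to initialize themselves are touched.
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
            reset.append(path)
    return reset


class ToyClassifierModel(nn.Module):
    def __init__(self, num_labels: int):
        super().__init__()
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(2048, num_labels))


model = ToyClassifierModel(num_labels=10)
with torch.no_grad():
    model.classifier[1].weight.fill_(float("nan"))  # stand-in for a skipped init

print(reinit_mismatched(model, ["classifier.1.bias", "classifier.1.weight"]))  # ['classifier.1']
```

The appeal of this placement is that it would fix every model at once, rather than relying on each model's _init_weights to list nn.Linear.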

Who can help?

@amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

E.g.

> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(1.7014e+38, grad_fn=<MaxBackward1>)

# Sometimes the same command gives NaN:
> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(nan, grad_fn=<MaxBackward1>)


# no change in the number of labels
> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=1000, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
tensor(4.7245, grad_fn=<MaxBackward1>)

# change the number of labels
> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=1001, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1001]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([1001, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(1.8520e-40, grad_fn=<MaxBackward1>)

> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10000, ignore_mismatched_sizes=True).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10000]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10000, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(0., grad_fn=<MaxBackward1>)

Disabling _fast_init fixes the issue:

> AutoModelForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=10000, ignore_mismatched_sizes=True, _fast_init=False).classifier[1].weight.absolute().max()
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10000]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10000, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tensor(0.0221, grad_fn=<MaxBackward1>)
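As a sanity check of the _fast_init=False result: PyTorch's default nn.Linear initialization (reset_parameters) uses kaiming_uniform_ with a=sqrt(5), which works out to weights drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)). For the ResNet-50 classifier, fan_in is 2048, so with 10000 x 2048 draws the observed maximum should sit right at that bound. This assumes the default initializer of current PyTorch releases; the arithmetic:

```python
import math

fan_in = 2048      # input features of the classifier (weight shape [num_labels, 2048])
a = math.sqrt(5)   # negative slope that nn.Linear passes to kaiming_uniform_

gain = math.sqrt(2.0 / (1 + a ** 2))    # leaky_relu gain used by kaiming_uniform_
bound = gain * math.sqrt(3.0 / fan_in)  # weights ~ U(-bound, bound)

print(round(bound, 4))  # 0.0221
```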

Expected behavior

The statistics of the initialized weights should be similar with and without _fast_init. Importantly, the weights shouldn't contain NaNs, and the maximum absolute value shouldn't be 0 or really large (e.g. > 1e20).
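The expected behavior can be turned into a quick check. A minimal sketch using the thresholds quoted above (no NaN, maximum absolute value not 0 and not above 1e20); looks_initialized is a hypothetical helper, and the lower cutoff of 1e-12 is an assumption chosen so that denormal values such as 1.8520e-40 also count as "effectively 0".

```python
import torch


def looks_initialized(weight: torch.Tensor, lower: float = 1e-12, upper: float = 1e20) -> bool:
    """Hypothetical check mirroring the expected behavior described in this issue."""
    w = weight.detach()
    if not torch.isfinite(w).all():
        return False  # NaN or inf entries
    max_abs = w.abs().max().item()
    # `lower` is an assumed cutoff; the issue itself only says "shouldn't be 0".
    return lower < max_abs <= upper


# A default-initialized layer with the same shape as the issue's classifier passes.
layer = torch.nn.Linear(2048, 10)
print(looks_initialized(layer.weight))  # True
```

Applied to model.classifier[1].weight from the reproduction, this would flag the NaN, 1.7014e+38, 1.8520e-40 and 0 cases.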

@NielsRogge
Contributor

cc @ydshieh, who worked on a similar issue that was fixed by #28122

@ydshieh
Collaborator

ydshieh commented Jul 8, 2024

Hi @williford

Could you share your system info with us? You can run the command transformers-cli env and copy-paste its output below.

@williford
Author

For the reproduction I installed transformers with pip install git+https://github.com/huggingface/transformers:

  • transformers version: 4.43.0.dev0
  • Platform: Linux
  • Python version: 3.12.4
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1.post300 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: no
  • GPU type: NVIDIA GeForce RTX 3090

ydshieh self-assigned this Jul 8, 2024
@williford
Author

williford commented Jul 8, 2024

@ydshieh If I'm understanding the code correctly, your change makes sure that model._initialize_weights is called. ResNetForImageClassification inherits from ResNetPreTrainedModel, which overrides _init_weights. However, ResNetPreTrainedModel's _init_weights doesn't do anything when the module is a torch.nn.modules.linear.Linear.
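A minimal sketch of that effect: a hook with only Conv2d and BatchNorm/GroupNorm branches (modeled loosely on ResNetPreTrainedModel._init_weights, not copied from it) leaves an nn.Linear untouched, so whatever values its weight already held survive model.apply. The sentinel value 123.0 and the ModuleDict layout are illustrative assumptions.

```python
import torch
from torch import nn


def init_without_linear(module: nn.Module) -> None:
    """Hypothetical ResNet-style hook with no nn.Linear branch."""
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    # nn.Linear falls through: nothing is (re)initialized.


model = nn.ModuleDict({
    "conv": nn.Conv2d(3, 8, 3),
    "norm": nn.BatchNorm2d(8),
    "classifier": nn.Linear(8, 10),
})
with torch.no_grad():
    model["classifier"].weight.fill_(123.0)  # sentinel standing in for leftover memory

model.apply(init_without_linear)
print(torch.all(model["classifier"].weight == 123.0).item())  # True
```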

When _fast_init is disabled, the Linear module initializes its own weights via its reset_parameters method.
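PyTorch has a comparable mechanism that makes the difference easy to reproduce outside transformers: torch.nn.utils.skip_init constructs a module without running its initialization (the parameters hold uninitialized memory), and calling reset_parameters afterwards restores the default initialization. This is an analogy for illustration, not how transformers implements _fast_init.

```python
import torch
from torch import nn
from torch.nn.utils import skip_init

# Parameters are allocated but never initialized, so their contents are arbitrary
# (they may be NaN, zero or huge, much like the values reported in this issue).
linear = skip_init(nn.Linear, 2048, 10)

# Running the layer's own initializer gives well-behaved values again:
# weight and bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
linear.reset_parameters()

bound = 1 / 2048 ** 0.5
print(torch.isfinite(linear.weight).all().item())       # True
print(linear.weight.abs().max().item() <= bound + 1e-6)  # True
```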

@ydshieh
Collaborator

ydshieh commented Jul 9, 2024

@williford Thank you for diving into this issue. Yes, you are correct! I opened a PR to fix it and it works now.
