Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safetensors serialization by default #27064

Merged
merged 18 commits into from
Oct 31, 2023
Merged

Safetensors serialization by default #27064

merged 18 commits into from
Oct 31, 2023

Conversation

LysandreJik
Copy link
Member

@LysandreJik LysandreJik commented Oct 25, 2023

This PR aims to do one thing but is larger than expected. I'm happy to break it down into smaller PRs if it helps for reviewing.

This PR aims to switch safe serialization to True by default for torch models. In doing so, it revealed a few bugs in the existing implementation and safetensors support that this PR fixes.

Additionally, support for safetensors for Flax models is added so that models saved from PyTorch after merging this PR can be used in both TensorFlow and Flax, and for models saved from TensorFlow/Flax to be loaded in PyTorch models.

The following should be worked on shortly to enable switching to safetensors by default for TensorFlow and Flax as well:

  • There is no support for sharded weights in TensorFlow
  • There is no support for sharded weights in Flax

Additionally, I'll contribute some documentation making the following clear:

  • TensorFlow models can load models in safetensors saved from PyTorch and TensorFlow, but it cannot load them from Flax. This can be eventually worked on; meanwhile, I'll write this in the docs with workarounds to get models saved in Flax to work in TensorFlow for those interested.
  • Same, but for Flax models loaded from TensorFlow

Thanks, @Rocketknight1, for the help on TensorFlow's side.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 25, 2023

The documentation is not available anymore as the PR was closed or merged.

@LysandreJik LysandreJik force-pushed the safetensors-by-default branch from 4398da0 to e4ebba7 Compare October 25, 2023 14:07
@LysandreJik LysandreJik force-pushed the safetensors-by-default branch from f86a14f to 66a896a Compare October 30, 2023 10:45
Comment on lines +3060 to +3081
if (
is_safetensors_available()
and isinstance(resolved_archive_file, str)
and resolved_archive_file.endswith(".safetensors")
):
with safe_open(resolved_archive_file, framework="pt") as f:
metadata = f.metadata()

if metadata.get("format") == "pt":
pass
elif metadata.get("format") == "tf":
from_tf = True
logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
elif metadata.get("format") == "flax":
from_flax = True
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
else:
raise ValueError(
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
)

from_pt = not (from_tf | from_flax)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is necessary to enable loading safetensors files saved from TensorFlow/Jax into PyTorch models

@LysandreJik LysandreJik marked this pull request as ready for review October 30, 2023 14:50
Copy link
Contributor

@Narsil Narsil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice !

Not that bad for such a big change.

For the testing part I see loading in PT from PT/TF/Flax, but not the other ways

TF -TF
TF - Flax
TF - Pt
Flax - Flax
Flax - TF
Flax - Pt.

From you initial comment I understand it's not possible, but it's not entirely clear for me as to why (you mention sharded weights, is it the only restriction? If yes, from what I read it should be okay-ish to be able to at least load for those, no ?)

tests/models/auto/test_modeling_tf_auto.py Show resolved Hide resolved
tests/test_modeling_utils.py Outdated Show resolved Hide resolved
Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this @LysandreJik!

src/transformers/modeling_flax_pytorch_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Show resolved Hide resolved
tests/test_modeling_flax_utils.py Outdated Show resolved Hide resolved
tests/test_modeling_flax_utils.py Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Outdated Show resolved Hide resolved
@LysandreJik
Copy link
Member Author

@Narsil, this is what is currently supported and not supported:

  • TF -TF - Supported, tested here:
    def test_safetensors_tf_from_tf(self):
    model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
    with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, safe_serialization=True)
    new_model = TFBertModel.from_pretrained(tmp_dir)
    for p1, p2 in zip(model.weights, new_model.weights):
    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
  • TF - Flax - Not supported
  • TF - Pt - Supported, tested here:
    @require_safetensors
    @is_pt_tf_cross_test
    def test_safetensors_tf_from_torch(self):
    hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
    model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
    with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, safe_serialization=True)
    new_model = TFBertModel.from_pretrained(tmp_dir)
    for p1, p2 in zip(hub_model.weights, new_model.weights):
    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
  • Flax - Flax - Supported, tested here:
    @require_safetensors
    def test_safetensors_flax_from_flax(self):
    model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
    with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, safe_serialization=True)
    new_model = FlaxBertModel.from_pretrained(tmp_dir)
    self.assertTrue(check_models_equal(model, new_model))
  • Flax - TF - Not supported
  • Flax - Pt - Supported, tested here:
    @require_safetensors
    @require_torch
    def test_safetensors_flax_from_torch(self):
    hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
    model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
    with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, safe_serialization=True)
    new_model = FlaxBertModel.from_pretrained(tmp_dir)
    self.assertTrue(check_models_equal(hub_model, new_model))

From you initial comment I understand it's not possible, but it's not entirely clear for me as to why (you mention sharded weights, is it the only restriction? If yes, from what I read it should be okay-ish to be able to at least load for those, no ?)

I mention this in the PR description:

TensorFlow models can load models in safetensors saved from PyTorch and TensorFlow, but it cannot load them from Flax. This can be eventually worked on; meanwhile, I'll write this in the docs with workarounds to get models saved in Flax to work in TensorFlow for those interested.

It should be pretty straightforward to enable it, but I suspect extremely little usage for a TF <> Flax conversion where no PyTorch conversion exists. I'm planning to add this to the documentation and IMO we can work on it afterwards if there are requests.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! 🔥

Just some small nits and Qs for my own understanding

src/transformers/modeling_tf_utils.py Outdated Show resolved Hide resolved
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the msgpack sharded weights
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure I've understood correctly: TF and Flax models can't load sharded weights from safetensors. So, if this passes, we know the model has successfully loaded the msgpack sharded weights?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's exactly right! This was raised by Sanchit as a previous version of the implementation priorized safetensors, realized they were sharded, and errored-out; but if sharded msgpack are also in the repo, we would want to load these first

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining!

tests/test_modeling_tf_utils.py Outdated Show resolved Hide resolved
tests/test_modeling_tf_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_flax_utils.py Outdated Show resolved Hide resolved
# init random models
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

if from_pt:
if from_pt or safetensors_from_pt:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exposing my lack of knowledge about safe tensors here: if safetensors_from_pt is True here then is the reason we do load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) because the serialized weights are in "pytorch format" and therefore can't be loaded using cls.load_flax_weights?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's correct! This way we call load_pytorch_checkpoint_in_flax_state_dict with the safetensors file, and that method checks if the file ends with .safetensors to load it the pytorch-way

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining!

@@ -986,15 +988,15 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
# Read the safetensors file
with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
mismatched_layers = []
weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As:

  • The previous function format_weight_name added _prefix, whereas the new function strip_model_name_and_prefix removes _prefix if it's present.
  • weight_names are compared to those in safetensors_archive

Can we load in previously saved safetensors (before this PR) into our TF models?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts I believe the previous code was just completely incorrect! Essentially, the relevant workflow (saving TF composite encoder-decoder models as safetensors and then reloading the checkpoint in TF) was not being tested, and actually didn't work because of the name prefix bug.

As such, I don't think there's a backwards compatibility issue here, because previous checkpoints weren't working at all. My suspicion is that not many people were saving encoder-decoder models in TF, and not many TF users were saving safetensors, and so the intersection of that venn diagram was tiny enough that no-one noticed the bug for a long time!

Copy link
Member

@Rocketknight1 Rocketknight1 Oct 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, just to clarify: _prefix is almost always None or "" when this function is called, in which case the behaviour is unchanged after this bugfix. _prefix is only defined when loading composite models like EncoderDecoder, which is the workflow that was broken before this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining!

src/transformers/modeling_tf_utils.py Outdated Show resolved Hide resolved
@LysandreJik
Copy link
Member Author

I will proceed to merge this and write a small explanatory doc tomorrow. I would like for the slow tests to run on this before the release.

@LysandreJik LysandreJik merged commit 113ebf8 into main Oct 31, 2023
3 checks passed
@LysandreJik LysandreJik deleted the safetensors-by-default branch October 31, 2023 18:16
@Narsil
Copy link
Contributor

Narsil commented Nov 1, 2023

Awesome ! Thanks a LOT for this.

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <[email protected]>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: Matt <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants