Safetensors serialization by default #27064
Conversation
The documentation is not available anymore as the PR was closed or merged.
if (
    is_safetensors_available()
    and isinstance(resolved_archive_file, str)
    and resolved_archive_file.endswith(".safetensors")
):
    with safe_open(resolved_archive_file, framework="pt") as f:
        metadata = f.metadata()

    if metadata.get("format") == "pt":
        pass
    elif metadata.get("format") == "tf":
        from_tf = True
        logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
    elif metadata.get("format") == "flax":
        from_flax = True
        logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
    else:
        raise ValueError(
            f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
        )

from_pt = not (from_tf | from_flax)
This is necessary to enable loading safetensors files saved from TensorFlow/Jax into PyTorch models.
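For context, the format value checked above comes from the safetensors file metadata written at save time. A minimal sketch of that round trip (the file path and tensor name are placeholders; transformers writes the metadata itself when saving):

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write a tiny file tagged as a PyTorch checkpoint via the metadata dict
save_file({"embeddings.weight": torch.zeros(8, 4)}, "tiny.safetensors", metadata={"format": "pt"})

# Reading the metadata back is exactly what the snippet above does
with safe_open("tiny.safetensors", framework="pt") as f:
    assert f.metadata().get("format") == "pt"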
Very nice!
Not that bad for such a big change.
For the testing part I see loading in PT from PT/TF/Flax, but not the other ways:
TF → TF
TF → Flax
TF → PT
Flax → Flax
Flax → TF
Flax → PT
From your initial comment I understand it's not possible, but it's not entirely clear to me why (you mention sharded weights; is that the only restriction? If so, from what I read it should be okay-ish to at least be able to load those, no?)
Thanks for working on this @LysandreJik!
Co-authored-by: Sanchit Gandhi <[email protected]>
@Narsil, this is what is currently supported and not supported:
I mention this in the PR description:
It should be pretty straightforward to enable it, but I suspect extremely little usage for a TF <> Flax conversion where no PyTorch conversion exists. I'm planning to add this to the documentation and IMO we can work on it afterwards if there are requests.
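As a sketch of the directions that are covered (the checkpoint name below is a placeholder): loading a PyTorch-saved safetensors checkpoint works from any framework, and TF/Flax-saved safetensors can be loaded into PyTorch, while a direct TF <> Flax crossload without a PyTorch path is the case left for later.

from transformers import BertModel, TFBertModel, FlaxBertModel

repo = "some-org/pt-saved-safetensors-checkpoint"  # placeholder repo id

pt_model = BertModel.from_pretrained(repo)        # PT -> PT
tf_model = TFBertModel.from_pretrained(repo)      # PT -> TF
flax_model = FlaxBertModel.from_pretrained(repo)  # PT -> Flax
# TF- or Flax-saved safetensors -> PT also works; TF <> Flax directly does not (yet).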
Very nice! 🔥
Just some small nits and Qs for my own understanding
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
    # This should not raise even if there are two types of sharded weights
    # This should discard the safetensors weights in favor of the msgpack sharded weights
    FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
Just to make sure I've understood correctly: TF and Flax models can't load sharded weights from safetensors. So, if this passes, we know the model has successfully loaded the msgpack sharded weights?
Yes, that's exactly right! This was raised by Sanchit: a previous version of the implementation prioritized safetensors, realized they were sharded, and errored out; but if sharded msgpack weights are also in the repo, we want to load those first.
Thanks for explaining!
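For illustration, one way to see why both shard formats coexist in that test repository is to list its files with huggingface_hub (a sketch; the exact shard file names are not reproduced here):

from huggingface_hub import list_repo_files

# The repo contains both msgpack shards (plus their index) and safetensors shards;
# the Flax loader is expected to pick the sharded msgpack weights.
files = list_repo_files("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
print(sorted(files))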
     # init random models
     model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

-    if from_pt:
+    if from_pt or safetensors_from_pt:
Exposing my lack of knowledge about safetensors here: if safetensors_from_pt is True here, then is the reason we do load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) because the serialized weights are in "pytorch format" and therefore can't be loaded using cls.load_flax_weights?
Yes, that's correct! This way we call load_pytorch_checkpoint_in_flax_state_dict with the safetensors file, and that method checks if the file ends with .safetensors to load it the PyTorch way.
Thanks for explaining!
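To make the dispatch concrete, here is a simplified, hypothetical sketch of the extension check described above; the actual logic lives inside transformers' load_pytorch_checkpoint_in_flax_state_dict and differs in detail (the helper name below is made up):

def _load_flax_compatible_state(resolved_archive_file):
    # Hypothetical helper: safetensors files are loaded with the flax bindings,
    # anything else is assumed to be a torch pickle and loaded the PyTorch way.
    if resolved_archive_file.endswith(".safetensors"):
        from safetensors.flax import load_file

        return load_file(resolved_archive_file)
    import torch

    return torch.load(resolved_archive_file, map_location="cpu")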
@@ -986,15 +988,15 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
     # Read the safetensors file
     with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
         mismatched_layers = []
-        weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
+        weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
As:
- The previous function format_weight_name added _prefix, whereas the new function strip_model_name_and_prefix removes _prefix if it's present.
- weight_names are compared to those in safetensors_archive.

Can we load previously saved safetensors (from before this PR) into our TF models?
@amyeroberts I believe the previous code was just completely incorrect! Essentially, the relevant workflow (saving TF composite encoder-decoder models as safetensors and then reloading the checkpoint in TF) was not being tested, and actually didn't work because of the name prefix bug.
As such, I don't think there's a backwards compatibility issue here, because previous checkpoints weren't working at all. My suspicion is that not many people were saving encoder-decoder models in TF, and not many TF users were saving safetensors, so the intersection of that Venn diagram was tiny enough that no one noticed the bug for a long time!
Also, just to clarify: _prefix is almost always None or "" when this function is called, in which case the behaviour is unchanged after this bugfix. _prefix is only defined when loading composite models like EncoderDecoder, which is the workflow that was broken before this.
Thanks for explaining!
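As a purely hypothetical illustration of the stripping behaviour discussed here (not the actual strip_model_name_and_prefix implementation), the transformation roughly amounts to:

def strip_name(name, _prefix=None):
    # Drop the ":0" suffix TensorFlow appends to variable names
    name = name.split(":")[0]
    # Remove the composite-model prefix if one was passed (the EncoderDecoder case)
    if _prefix and name.startswith(_prefix):
        name = name[len(_prefix):].lstrip("/")
    # Drop the leading model-name scope, e.g. "tf_bert_model/"
    return "/".join(name.split("/")[1:])

strip_name("tf_bert_model/bert/embeddings/word_embeddings/weight:0")
# -> "bert/embeddings/word_embeddings/weight"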
Co-authored-by: amyeroberts <[email protected]>
I will proceed to merge this and write a small explanatory doc tomorrow. I would like for the slow tests to run on this before the release.
Awesome! Thanks a LOT for this.
* Safetensors serialization by default
* First pass on the tests
* Second pass on the tests
* Third pass on the tests
* Fix TF weight loading from TF-format safetensors
* Specific encoder-decoder fixes for weight crossloading
* Add VisionEncoderDecoder fixes for TF too
* Change filename test for pt-to-tf
* One missing fix for TFVisionEncoderDecoder
* Fix the other crossload test
* Support for flax + updated tests
* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <[email protected]>
* Sanchit's comments
* Sanchit's comments 2
* Nico's comments
* Fix tests
* cleanup
* Apply suggestions from code review
Co-authored-by: amyeroberts <[email protected]>
---------
Co-authored-by: Matt <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
This PR aims to do one thing but is larger than expected. I'm happy to break it down into smaller PRs if it helps for reviewing.
This PR aims to switch safe serialization to True by default for torch models. In doing so, it revealed a few bugs in the existing implementation and safetensors support that this PR fixes.
Additionally, support for safetensors for Flax models is added so that models saved from PyTorch after merging this PR can be used in both TensorFlow and Flax, and so that models saved from TensorFlow/Flax can be loaded in PyTorch models.
The following should be worked on shortly to enable switching to safetensors by default for TensorFlow and Flax as well:
Additionally, I'll contribute some documentation making the following clear:
Thanks, @Rocketknight1, for the help on TensorFlow's side.
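For readers skimming this thread, a minimal sketch of the new default behaviour after this PR (directory names below are placeholders):

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# Saving a PyTorch model now writes model.safetensors by default...
model.save_pretrained("local-checkpoint")

# ...while the previous pytorch_model.bin format is still available on request.
model.save_pretrained("local-checkpoint-bin", safe_serialization=False)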