diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 494367b813bf3e..f0cad5f338d7e8 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2086,6 +2086,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
         variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
 
         if trust_remote_code is True:
             logger.warning(
@@ -2222,14 +2223,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             ):
                 # Load from a Flax checkpoint in priority if from_flax
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
-            elif is_safetensors_available() and os.path.isfile(
+            elif use_safetensors is not False and os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
             ):
                 # Load from a safetensors checkpoint
                 archive_file = os.path.join(
                     pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
                 )
-            elif is_safetensors_available() and os.path.isfile(
+            elif use_safetensors is not False and os.path.isfile(
                 os.path.join(
                     pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
                 )
@@ -2295,7 +2296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     filename = TF2_WEIGHTS_NAME
                 elif from_flax:
                     filename = FLAX_WEIGHTS_NAME
-                elif is_safetensors_available():
+                elif use_safetensors is not False:
                     filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
                 else:
                     filename = _add_variant(WEIGHTS_NAME, variant)
@@ -2328,6 +2329,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         )
                         if resolved_archive_file is not None:
                             is_sharded = True
+                    elif use_safetensors:
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                        )
                     else:
                         # This repo has no safetensors file of any kind, we switch to PyTorch.
                         filename = _add_variant(WEIGHTS_NAME, variant)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index cb06400e9a777a..f71366d2183829 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -15,6 +15,7 @@
 
 import copy
 import gc
+import glob
 import inspect
 import json
 import os
@@ -119,6 +120,7 @@
         AutoTokenizer,
         BertConfig,
         BertModel,
+        CLIPTextModel,
         PreTrainedModel,
         T5Config,
         T5ForConditionalGeneration,
@@ -3327,6 +3329,49 @@ def test_legacy_load_from_url(self):
             "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
         )
 
+    @require_safetensors
+    def test_use_safetensors(self):
+        # test nice error message if no safetensors files are available
+        with self.assertRaises(OSError) as env_error:
+            AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
+
+        self.assertTrue(
+            "model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
+            in str(env_error.exception)
+        )
+
+        # test that an error is raised when use_safetensors=False but only safetensors weights are available
+        with self.assertRaises(OSError) as env_error:
+            BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
+
+        self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))
+
+        # test that only the PyTorch bin file is downloaded when both formats exist and use_safetensors=False
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            CLIPTextModel.from_pretrained(
+                "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
+                subfolder="text_encoder",
+                use_safetensors=False,
+                cache_dir=tmp_dir,
+            )
+
+            all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
+            self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
+            self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))
+
+        # test that only the safetensors file is downloaded when both formats exist and use_safetensors=True
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            CLIPTextModel.from_pretrained(
+                "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
+                subfolder="text_encoder",
+                use_safetensors=True,
+                cache_dir=tmp_dir,
+            )
+
+            all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
+            self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
+            self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
+
     @require_safetensors
     def test_safetensors_save_and_load(self):
         model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")