diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index de9273c710bfd8..9d59f03e279ddf 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2271,11 +2271,13 @@ def _load_pretrained_model(
         offload_state_dict=False,
         dtype=None,
     ):
-        if device_map is not None and "disk" in device_map.values() and offload_folder is None:
-            raise ValueError(
-                "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for"
-                " them."
-            )
+        if device_map is not None and "disk" in device_map.values():
+            if offload_folder is None:
+                raise ValueError(
+                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+                    " for them."
+                )
+            os.makedirs(offload_folder, exist_ok=True)
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
@@ -2449,6 +2451,15 @@ def _find_mismatched_keys(
                 gc.collect()
 
         if offload_index is not None and len(offload_index) > 0:
+            if model != model_to_load:
+                # We need to add the prefix of the base model
+                prefix = cls.base_model_prefix
+                for weight_name in offload_index:
+                    shutil.move(
+                        os.path.join(offload_folder, f"{weight_name}.dat"),
+                        os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
+                    )
+                offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
             save_offload_index(offload_index, offload_folder)
 
         if offload_state_dict:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index d83454d4e63dc2..e1ff5851e1a8e5 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2811,6 +2811,48 @@ def test_model_parallelism_gpt2(self):
         text_output = tokenizer.decode(output[0].tolist())
         self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
 
+    @require_accelerate
+    @require_torch_gpu
+    def test_from_pretrained_disk_offload_task_model(self):
+        model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        device_map = {
+            "transformer.wte": 0,
+            "transformer.wpe": 0,
+            "transformer.h.0": "cpu",
+            "transformer.h.1": "cpu",
+            "transformer.h.2": "cpu",
+            "transformer.h.3": "disk",
+            "transformer.h.4": "disk",
+            "transformer.ln_f": 0,
+            "lm_head": 0,
+        }
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            inputs = torch.tensor([[1, 2, 3]]).to(0)
+
+            model.save_pretrained(tmp_dir)
+            new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
+            outputs1 = new_model.to(0)(inputs)
+
+            offload_folder = os.path.join(tmp_dir, "offload")
+            new_model_with_offload = AutoModelForCausalLM.from_pretrained(
+                tmp_dir, device_map=device_map, offload_folder=offload_folder
+            )
+            outputs2 = new_model_with_offload(inputs)
+
+            self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
+
+            # With state dict temp offload
+            offload_folder = os.path.join(tmp_dir, "offload")
+            new_model_with_offload = AutoModelForCausalLM.from_pretrained(
+                tmp_dir,
+                device_map=device_map,
+                offload_folder=offload_folder,
+                offload_state_dict=True,
+            )
+            outputs2 = new_model_with_offload(inputs)
+
+            self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
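For context, a minimal usage sketch of the path this patch fixes, condensed from the new test above: saving a base-model checkpoint and reloading it as a task model triggers the `model != model_to_load` branch, because the task model nests the base model under `cls.base_model_prefix` ("transformer" for GPT-2), so the disk-offload index keys must be renamed to carry that prefix. The checkpoint name and device map are taken from the test; it assumes `accelerate` is installed and CUDA device 0 is available.

import os
import tempfile

import torch
from transformers import AutoModel, AutoModelForCausalLM

# Device map from the test: some blocks on GPU 0, some on CPU, two on disk.
device_map = {
    "transformer.wte": 0,
    "transformer.wpe": 0,
    "transformer.h.0": "cpu",
    "transformer.h.1": "cpu",
    "transformer.h.2": "cpu",
    "transformer.h.3": "disk",
    "transformer.h.4": "disk",
    "transformer.ln_f": 0,
    "lm_head": 0,
}

base = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
with tempfile.TemporaryDirectory() as tmp_dir:
    # Save a *base* checkpoint, then load it into a *task* model with disk offload.
    base.save_pretrained(tmp_dir)
    model = AutoModelForCausalLM.from_pretrained(
        tmp_dir,
        device_map=device_map,
        # With the patch, the offload folder no longer has to exist beforehand:
        # _load_pretrained_model calls os.makedirs(offload_folder, exist_ok=True).
        offload_folder=os.path.join(tmp_dir, "offload"),
    )
    logits = model(torch.tensor([[1, 2, 3]]).to(0)).logits

Before the fix, the `.dat` files and the offload index were written under the un-prefixed weight names (e.g. `h.3.attn...` instead of `transformer.h.3.attn...`), so the forward pass could not locate the disk-offloaded weights; the `shutil.move` loop plus the rebuilt `offload_index` dict bring both in line with the task model's parameter names.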