Skip to content

Commit

Permalink
[Serialization] Add is_main_process argument to `save_torch_state_d…
Browse files Browse the repository at this point in the history
…ict()` (#2648)

* Add is_main_process flag

* Update tests comments

* Fix failing test

* fix typos
  • Loading branch information
hanouticelina authored Nov 5, 2024
1 parent 0deb17f commit 0c98fbd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def save_torch_model(
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
):
"""
Saves a given torch model to disk, handling sharding and shared tensors issues.
Expand Down Expand Up @@ -88,6 +89,10 @@ def save_torch_model(
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions. Defaults to True.
Example:
Expand All @@ -112,6 +117,7 @@ def save_torch_model(
metadata=metadata,
safe_serialization=safe_serialization,
save_directory=save_directory,
is_main_process=is_main_process,
)


Expand All @@ -124,6 +130,7 @@ def save_torch_state_dict(
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
) -> None:
"""
Save a model state dictionary to the disk, handling sharding and shared tensors issues.
Expand Down Expand Up @@ -171,7 +178,10 @@ def save_torch_state_dict(
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions. Defaults to True.
Example:
```py
Expand Down Expand Up @@ -222,15 +232,18 @@ def save_torch_state_dict(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)

# Clean the folder from previous save
existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
for filename in os.listdir(save_directory):
if existing_files_regex.match(filename):
try:
logger.debug(f"Removing existing file '{filename}' from folder.")
os.remove(os.path.join(save_directory, filename))
except Exception as e:
logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
# Only main process should clean up existing files to avoid race conditions in distributed environment
if is_main_process:
existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
for filename in os.listdir(save_directory):
if existing_files_regex.match(filename):
try:
logger.debug(f"Removing existing file '{filename}' from folder.")
os.remove(os.path.join(save_directory, filename))
except Exception as e:
logger.warning(
f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
)

# Save each shard
per_file_metadata = {"format": "pt"}
Expand Down Expand Up @@ -442,7 +455,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

if is_traceable_wrapper_subclass(tensor):
return _get_unique_id(tensor)
return _get_unique_id(tensor) # type: ignore
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass
Expand Down
26 changes: 26 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
max_shard_size="3GB",
metadata={"foo": "bar"},
safe_serialization=True,
is_main_process=True,
)
safe_state_dict_mock.assert_called_once_with(
state_dict=model_mock.state_dict.return_value,
Expand All @@ -273,6 +274,7 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
max_shard_size="3GB",
metadata={"foo": "bar"},
safe_serialization=True,
is_main_process=True,
)


Expand Down Expand Up @@ -472,3 +474,27 @@ def test_save_torch_state_dict_delete_existing_files(
assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file()
assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file()
assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file()


def test_save_torch_state_dict_not_main_process(
tmp_path: Path,
torch_state_dict: Dict[str, "torch.Tensor"],
) -> None:
"""
Test that previous files in the directory are not deleted when is_main_process=False.
When is_main_process=True, previous files should be deleted,
this is already tested in `test_save_torch_state_dict_delete_existing_files`.
"""
# Create some .safetensors files before saving a new state dict.
(tmp_path / "model.safetensors").touch()
(tmp_path / "model-00001-of-00002.safetensors").touch()
(tmp_path / "model-00002-of-00002.safetensors").touch()
(tmp_path / "model.safetensors.index.json").touch()
# Save with is_main_process=False
save_torch_state_dict(torch_state_dict, tmp_path, is_main_process=False)

# Previous files should still exist (not deleted)
assert (tmp_path / "model.safetensors").is_file()
assert (tmp_path / "model-00001-of-00002.safetensors").is_file()
assert (tmp_path / "model-00002-of-00002.safetensors").is_file()
assert (tmp_path / "model.safetensors.index.json").is_file()

0 comments on commit 0c98fbd

Please sign in to comment.