FSDP checkpointing uses deprecated APIs with PyTorch 2.2 #19462

Open
carmocca opened this issue Feb 13, 2024 · 6 comments
Labels
bug Something isn't working checkpointing Related to checkpointing strategy: fsdp Fully Sharded Data Parallel

@carmocca
Contributor

carmocca commented Feb 13, 2024

Bug description

See the deprecation warnings added in pytorch/pytorch#113867.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

The warning originates from this call:

save_state_dict(converted_state, writer)

We already use the newer API for loading:

if _TORCH_GREATER_EQUAL_2_2:
    from torch.distributed.checkpoint import load
else:
    from torch.distributed.checkpoint import load_state_dict as load  # deprecated
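The same version gate could be applied to saving. The following is a minimal sketch, not the fix that was merged: `resolve_save_name` and `import_save` are hypothetical helpers, and the version string is parsed by hand here instead of using the `_TORCH_GREATER_EQUAL_2_2` flag so that the selection logic can be exercised without PyTorch installed.

```python
import importlib
import re


def _major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '2.2.0+cu121' or '2.3.0.dev20240213'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def resolve_save_name(torch_version: str) -> str:
    """Pick the name of the save entry point for a given PyTorch version.

    Hypothetical helper: 'save' for 2.2 and later, otherwise the
    deprecated 'save_state_dict'.
    """
    return "save" if _major_minor(torch_version) >= (2, 2) else "save_state_dict"


def import_save():
    """Import the selected function from torch.distributed.checkpoint (requires PyTorch)."""
    import torch  # imported lazily so the helpers above work without PyTorch

    module = importlib.import_module("torch.distributed.checkpoint")
    return getattr(module, resolve_save_name(torch.__version__))
```

With this in place, the call site would obtain the function once via `import_save()` and call it in place of `save_state_dict`.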

Error messages and logs

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py:31: UserWarning: 'save_state_dict' is deprecated and will be removed in future versions.Please use 'save' instead.
  warnings.warn(

Environment

No response

More info

No response

cc @awaelchli @carmocca

@carmocca carmocca added bug Something isn't working strategy: fsdp Fully Sharded Data Parallel labels Feb 13, 2024
@carmocca carmocca added this to the 2.2.x milestone Feb 13, 2024
@carmocca
Contributor Author

carmocca commented Feb 13, 2024

Two more deprecation warnings, which probably need to be fixed in PyTorch itself:

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/_shard/sharded_tensor/api.py:1132: UserWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  warnings.warn(DEPRECATE_MSG)

From (print_stack added by me):

  File "/home/carlos/stuff.py", line 29, in <module>
    fabric.save(f"{compile}_before_fwd", {"model": fmodel})
  File "/home/carlos/lightning/src/lightning/fabric/fabric.py", line 770, in save
    self._strategy.save_checkpoint(path=path, state=_unwrap_objects(state), filter=filter)
  File "/home/carlos/lightning/src/lightning/fabric/strategies/fsdp.py", line 484, in save_checkpoint
    converted = obj.state_dict()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1922, in state_dict
    hook_result = hook(self, destination, prefix, local_metadata)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 737, in _post_state_dict_hook
    local_shape = tensor.shape
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1134, in __torch_function__
    traceback.print_stack()

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py:151: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():

From (print_stack added by me):

  File "/home/carlos/stuff.py", line 29, in <module>
    fabric.save(f"{compile}_before_fwd", {"model": fmodel})
  File "/home/carlos/lightning/src/lightning/fabric/fabric.py", line 770, in save
    self._strategy.save_checkpoint(path=path, state=_unwrap_objects(state), filter=filter)
  File "/home/carlos/lightning/src/lightning/fabric/strategies/fsdp.py", line 496, in save_checkpoint
    save_state_dict(converted_state, writer)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 40, in save_state_dict
    return _save_state_dict(
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 280, in _save_state_dict
    return distW.all_reduce("write", write_data, finish_checkpoint)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 210, in all_reduce
    local_data = map_fun()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 270, in write_data
    all_writes = storage_writer.write_data(final_local_plan, planner)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 470, in write_data
    _write_files_from_queue(
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 284, in _write_files_from_queue
    loader.start_loading()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 179, in start_loading
    self._refill()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 150, in _refill
    traceback.print_stack()

@carmocca
Contributor Author

If the newer save is used, the argument order seems to have changed in pytorch/pytorch#117772

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py:409: UserWarning: The argument order of save has been changed. Please check the document to avoid future breakages.
  warnings.warn(

This probably applies to load too, though I haven't tried it.
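One way to stay independent of a positional-order change is to pass everything after the state dict by keyword. The sketch below illustrates the idea with two stand-in functions, `save_old_order` and `save_new_order`. Their parameter lists are illustrative assumptions, not the real PyTorch signatures, and whether the real `save` accepts `storage_writer=` as a keyword in every release should be checked against the PyTorch documentation.

```python
def save_old_order(state_dict, storage_writer, process_group=None, planner=None):
    """Stand-in for an API that takes the writer as its second positional argument."""
    return {"state_dict": state_dict, "storage_writer": storage_writer}


def save_new_order(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None):
    """Stand-in for an API where everything after the state dict is keyword-only."""
    return {"state_dict": state_dict, "storage_writer": storage_writer}


def save_checkpoint(save_fn, converted_state, writer):
    """Call save_fn so that the result does not depend on positional argument order."""
    # Only the state dict is positional; the writer is always named explicitly.
    return save_fn(converted_state, storage_writer=writer)


state = {"model": {"weight": [1.0, 2.0]}}
writer = object()  # placeholder for a real storage writer

# The same keyword-based call works against both stand-in signatures.
old_result = save_checkpoint(save_old_order, state, writer)
new_result = save_checkpoint(save_new_order, state, writer)
```

The same keyword-only style would apply to load if its argument order changed in the same way.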

@awaelchli
Contributor

I agree we need to update these imports.
The change in argument order is only in nightly, but since lit-gpt relies on that, we should start incorporating this asap.

@carmocca
Contributor Author

Technically, lit-gpt hasn't relied on nightly since the 2.2 release.

I opened #19463

@carmocca
Contributor Author

I also opened pytorch/pytorch#119802 upstream. We might want to silence these warnings once that is resolved.
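If silencing turns out to be the way to go, a scoped filter would keep unrelated warnings visible. This is a minimal sketch using only the standard `warnings` module; `silence_upstream_deprecations` is a hypothetical name, and the message prefixes are copied from the two warnings quoted earlier in this thread.

```python
import contextlib
import warnings

# Prefixes of the two upstream messages quoted in this thread.
# warnings.filterwarnings matches the regex against the start of the message.
_UPSTREAM_MESSAGES = (
    r"Please use DTensor instead",
    r"TypedStorage is deprecated",
)


@contextlib.contextmanager
def silence_upstream_deprecations():
    """Ignore the known upstream UserWarnings inside the with-block only."""
    with warnings.catch_warnings():
        for pattern in _UPSTREAM_MESSAGES:
            warnings.filterwarnings("ignore", message=pattern, category=UserWarning)
        yield
```

Wrapping only the `state_dict()` and save calls in `with silence_upstream_deprecations():` would restore the previous filters on exit, so the same warnings raised elsewhere would still be shown.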

@carmocca
Contributor Author

pytorch/pytorch#119800 (comment) suggests that we should replace (in 2.2+) most of what we have with {get,set}_{model,optimizer}_state_dict functions in https://github.com/pytorch/pytorch/blob/v2.2.0/torch/distributed/checkpoint/state_dict.py
