Commit 9546a22

Merge branch 'master' into master

lantiga authored Dec 10, 2024
2 parents b633228 + 601c060

Showing 6 changed files with 73 additions and 1 deletion.
24 changes: 23 additions & 1 deletion docs/source-pytorch/common/checkpointing_basic.rst
@@ -20,6 +20,13 @@ PyTorch Lightning checkpoints are fully usable in plain PyTorch.

----

.. important::

   **Deprecated:** ``resume_from_checkpoint``

   The ``resume_from_checkpoint`` Trainer argument was deprecated in PyTorch Lightning v1.5 and has since been removed. To resume training from a checkpoint, pass the ``ckpt_path`` argument to the ``fit()`` method instead.
   Please update your code accordingly to avoid compatibility issues.
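
   A minimal before/after sketch (``LitModel`` and the checkpoint path are placeholders, matching the examples shown later in this document):

   .. code-block:: python

       model = LitModel()

       # Deprecated (removed): passing the checkpoint to the Trainer constructor
       # trainer = Trainer(resume_from_checkpoint="path/to/your/checkpoint.ckpt")

       # Current API: pass the checkpoint path to fit()
       trainer = Trainer()
       trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt")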

************************
Contents of a checkpoint
************************
@@ -197,16 +204,31 @@ You can disable checkpointing by passing:

----


*********************
Resume training state
*********************

If you don't just want to load weights, but instead restore the full training state, do the following:

Correct usage:

.. code-block:: python

    model = LitModel()
    trainer = Trainer()

    # automatically restores model, epoch, step, LR schedulers, etc...
    trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt")

.. warning::

   The ``resume_from_checkpoint`` Trainer argument was deprecated in PyTorch Lightning v1.5 and has since been removed.
   To resume training from a checkpoint, use the ``ckpt_path`` argument of the ``fit()`` method instead.

Incorrect (deprecated) usage:

.. code-block:: python

    trainer = Trainer(resume_from_checkpoint="path/to/your/checkpoint.ckpt")
    trainer.fit(model)

6 changes: 6 additions & 0 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -74,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
        }
        self._desired_input_dtype = precision_to_type[self.precision]

    @override
    def convert_module(self, module: Module) -> Module:
        if "true" in self.precision:
            return module.to(dtype=self._desired_input_dtype)
        return module

    @property
    def mixed_precision_config(self) -> "TorchMixedPrecision":
        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
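The new `convert_module` override casts the module's parameters only for precisions containing "true" ("32-true", "bf16-true", "16-true"); the mixed modes leave the weights in float32 and rely on FSDP's `MixedPrecision` config instead. A minimal sketch of the resulting behavior, mirroring the tests added below (the import path is inferred from the file location shown above and may differ from the public import):

```python
import torch
from torch.nn import Linear

# Import path inferred from src/lightning/fabric/plugins/precision/fsdp.py above.
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

# "bf16-true": convert_module casts the parameters to the target dtype.
module = FSDPPrecision(precision="bf16-true").convert_module(Linear(2, 2))
assert module.weight.dtype == module.bias.dtype == torch.bfloat16

# "16-mixed": weights stay in float32; FSDP's MixedPrecision config casts at runtime.
module = FSDPPrecision(precision="16-mixed").convert_module(Linear(2, 2))
assert module.weight.dtype == module.bias.dtype == torch.float32
```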
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
- Fixed PyTorch Lightning FSDP taking more memory than PyTorch FSDP ([#20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))


## [2.3.0] - 2024-06-13
7 changes: 7 additions & 0 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -17,6 +17,7 @@
import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from typing_extensions import get_args, override

import lightning.pytorch as pl
@@ -73,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
        }
        self._desired_input_dtype = precision_to_type[self.precision]

    @override
    def convert_module(self, module: Module) -> Module:
        if "true" in self.precision:
            return module.to(dtype=self._desired_input_dtype)
        return module

    @override
    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
        # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
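At the Trainer level, the same `convert_module` hook means a "*-true" precision casts the whole module before FSDP shards it, which is what closes the memory gap noted in the changelog entry above. A hedged usage sketch; the model, data, and Trainer arguments here are illustrative placeholders (assuming a 2-GPU machine), not part of this commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class TinyModel(pl.LightningModule):
    """Minimal placeholder LightningModule, for illustration only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# "bf16-true" casts the module's parameters to bfloat16 via `convert_module`;
# "16-mixed" would keep float32 weights and use FSDP's MixedPrecision config instead.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp", precision="bf16-true", max_epochs=1)
trainer.fit(TinyModel())
```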
18 changes: 18 additions & 0 deletions tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -127,3 +127,21 @@ def test_invalid_precision_with_fsdp_precision():

    with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"):
        FSDPPrecision(precision="64-true")


@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("bf16-mixed", torch.float32),
        ("16-mixed", torch.float32),
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.float16),
    ],
)
def test_convert_module(precision, expected_dtype):
    precision = FSDPPrecision(precision=precision)
    module = torch.nn.Linear(2, 2)
    assert module.weight.dtype == module.bias.dtype == torch.float32
    module = precision.convert_module(module)
    assert module.weight.dtype == module.bias.dtype == expected_dtype
18 changes: 18 additions & 0 deletions tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -40,6 +40,24 @@ def test_fsdp_precision_config(precision, expected):
    assert config.reduce_dtype == expected[2]


@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("bf16-mixed", torch.float32),
        ("16-mixed", torch.float32),
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.float16),
    ],
)
def test_convert_module(precision, expected_dtype):
    precision = FSDPPrecision(precision=precision)
    module = torch.nn.Linear(2, 2)
    assert module.weight.dtype == module.bias.dtype == torch.float32
    module = precision.convert_module(module)
    assert module.weight.dtype == module.bias.dtype == expected_dtype


def test_fsdp_precision_default_scaler():
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

