diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5306861b7f0af..190cd1a070845 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610))
 - Fixed an issue with collecting logged test results with multiple dataloaders ([#10522](https://github.com/PyTorchLightning/pytorch-lightning/pull/10522))
+- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))
+
+
 - Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 6b8e44b610352..c13800cb842d6 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -46,7 +46,7 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_LiteOptimizer
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
+        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
         self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._accelerator = accelerator
@@ -55,6 +55,9 @@
     def optimizer(self) -> Optimizer:
         return self._optimizer
 
+    def state_dict(self) -> Dict[str, Tensor]:
+        return self._accelerator.optimizer_state(self.optimizer)
+
     def step(self, closure: Optional[Callable] = None) -> None:
         closure = closure or _do_nothing_closure
         self._accelerator.optimizer_step(
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py
index c271d3b3163ed..a732390e1d00a 100644
--- a/tests/lite/test_wrappers.py
+++ b/tests/lite/test_wrappers.py
@@ -142,6 +142,15 @@ def test_lite_optimizer_wraps():
     assert isinstance(lite_optimizer, optimizer_cls)
 
 
+def test_lite_optimizer_state_dict():
+    """Test that the LiteOptimizer calls into the accelerator/strategy to collect the state."""
+    optimizer = Mock()
+    accelerator = Mock()
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    lite_optimizer.state_dict()
+    accelerator.optimizer_state.assert_called_with(optimizer)
+
+
 def test_lite_optimizer_steps():
     """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
     optimizer = Mock()
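
Reviewer note (not part of the patch): `_LiteOptimizer` copies the wrapped optimizer's instance `__dict__` and dynamically subclasses the original optimizer class, so a per-instance attribute named `state_dict` would shadow the new delegating method; that is why `"state_dict"` joins `"step"` and `"__del__"` in the exclusion tuple. Routing `state_dict()` through the accelerator lets strategies holding sharded optimizer state consolidate it before saving, which is the consolidation error this PR fixes. Below is a minimal, self-contained sketch of the same pattern; `ToyOptimizer`, `ToyAccelerator`, and `WrappedOptimizer` are hypothetical stand-ins for illustration, not Lightning APIs.

```python
# Hypothetical sketch of the _LiteOptimizer wrapper pattern; the Toy* names
# are illustrative stand-ins, not Lightning or PyTorch APIs.
from typing import Any, Callable, Dict, Optional


class ToyOptimizer:
    """Stand-in for a torch.optim.Optimizer with per-instance attributes."""

    def __init__(self) -> None:
        self.param_groups = [{"lr": 0.1}]

    def state_dict(self) -> Dict[str, Any]:
        return {"param_groups": self.param_groups}

    def step(self, closure: Optional[Callable] = None) -> None:
        print("raw optimizer step")


class ToyAccelerator:
    """Stand-in for the accelerator/strategy layer."""

    def optimizer_state(self, optimizer: ToyOptimizer) -> Dict[str, Any]:
        # A real sharded strategy would gather and consolidate the optimizer
        # state across processes here before returning it.
        return optimizer.state_dict()


class WrappedOptimizer:
    def __init__(self, optimizer: ToyOptimizer, accelerator: ToyAccelerator) -> None:
        # Copy the wrapped optimizer's instance attributes, skipping any keys
        # that would shadow the methods this wrapper overrides.
        self.__dict__ = {
            k: v
            for k, v in optimizer.__dict__.items()
            if k not in ("state_dict", "step", "__del__")
        }
        # Dynamically subclass so isinstance() checks against the original
        # optimizer class still succeed on the wrapper.
        self.__class__ = type(
            "Wrapped" + optimizer.__class__.__name__,
            (self.__class__, optimizer.__class__),
            {},
        )
        self._optimizer = optimizer
        self._accelerator = accelerator

    def state_dict(self) -> Dict[str, Any]:
        # Route through the accelerator so sharded state gets consolidated.
        return self._accelerator.optimizer_state(self._optimizer)


wrapped = WrappedOptimizer(ToyOptimizer(), ToyAccelerator())
assert isinstance(wrapped, ToyOptimizer)
print(wrapped.state_dict())  # -> {'param_groups': [{'lr': 0.1}]}
```

The unit test in the patch exercises exactly this delegation: with both collaborators mocked, calling `lite_optimizer.state_dict()` must forward to `accelerator.optimizer_state(optimizer)` rather than to the optimizer's own `state_dict`.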