Skip to content

Commit

Permalink
Fix infinite recursion error in precision plugin graveyard (#19542)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Feb 27, 2024
1 parent 7880c11 commit a6c0a31
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130))


-
- Fixed infinite recursion error in precision plugin graveyard ([#19542](https://github.com/Lightning-AI/pytorch-lightning/pull/19542))



## [2.2.0] - 2024-02-08
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/_graveyard/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def init(self: type, *args: Any, **kwargs: Any) -> None:
f"The `{deprecated_name}` is deprecated."
f" Use `lightning.pytorch.plugins.precision.{new_class.__name__}` instead."
)
super(type(self), self).__init__(*args, **kwargs)
new_class.__init__(self, *args, **kwargs) # type: ignore[misc]

return type(deprecated_name, (new_class,), {"__init__": init})

Expand Down
23 changes: 23 additions & 0 deletions tests/tests_pytorch/graveyard/test_precision.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_precision_plugin_renamed_imports():
# base class
from lightning.pytorch.plugins import PrecisionPlugin as PrecisionPlugin2
Expand All @@ -9,6 +12,10 @@ def test_precision_plugin_renamed_imports():
assert issubclass(PrecisionPlugin1, Precision)
assert issubclass(PrecisionPlugin2, Precision)

for plugin_cls in [PrecisionPlugin0, PrecisionPlugin1, PrecisionPlugin2]:
with pytest.warns(DeprecationWarning, match="The `PrecisionPlugin` is deprecated"):
plugin_cls()

# bitsandbytes
from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin as BnbPlugin2
from lightning.pytorch.plugins.precision import BitsandbytesPrecisionPlugin as BnbPlugin1
Expand Down Expand Up @@ -39,6 +46,10 @@ def test_precision_plugin_renamed_imports():
assert issubclass(DoublePlugin1, DoublePrecision)
assert issubclass(DoublePlugin2, DoublePrecision)

for plugin_cls in [DoublePlugin0, DoublePlugin1, DoublePlugin2]:
with pytest.warns(DeprecationWarning, match="The `DoublePrecisionPlugin` is deprecated"):
plugin_cls()

# fsdp
from lightning.pytorch.plugins import FSDPPrecisionPlugin as FSDPPlugin2
from lightning.pytorch.plugins.precision import FSDPPrecisionPlugin as FSDPPlugin1
Expand All @@ -49,6 +60,10 @@ def test_precision_plugin_renamed_imports():
assert issubclass(FSDPPlugin1, FSDPPrecision)
assert issubclass(FSDPPlugin2, FSDPPrecision)

for plugin_cls in [FSDPPlugin0, FSDPPlugin1, FSDPPlugin2]:
with pytest.warns(DeprecationWarning, match="The `FSDPPrecisionPlugin` is deprecated"):
plugin_cls(precision="16-mixed")

# half
from lightning.pytorch.plugins import HalfPrecisionPlugin as HalfPlugin2
from lightning.pytorch.plugins.precision import HalfPrecisionPlugin as HalfPlugin1
Expand All @@ -59,6 +74,10 @@ def test_precision_plugin_renamed_imports():
assert issubclass(HalfPlugin1, HalfPrecision)
assert issubclass(HalfPlugin2, HalfPrecision)

for plugin_cls in [HalfPlugin0, HalfPlugin1, HalfPlugin2]:
with pytest.warns(DeprecationWarning, match="The `HalfPrecisionPlugin` is deprecated"):
plugin_cls()

# mixed
from lightning.pytorch.plugins import MixedPrecisionPlugin as MixedPlugin2
from lightning.pytorch.plugins.precision import MixedPrecisionPlugin as MixedPlugin1
Expand All @@ -69,6 +88,10 @@ def test_precision_plugin_renamed_imports():
assert issubclass(MixedPlugin1, MixedPrecision)
assert issubclass(MixedPlugin2, MixedPrecision)

for plugin_cls in [MixedPlugin0, MixedPlugin1, MixedPlugin2]:
with pytest.warns(DeprecationWarning, match="The `MixedPrecisionPlugin` is deprecated"):
plugin_cls(precision="bf16-mixed", device="cuda:0")

# transformer_engine
from lightning.pytorch.plugins import TransformerEnginePrecisionPlugin as TEPlugin2
from lightning.pytorch.plugins.precision import TransformerEnginePrecisionPlugin as TEPlugin1
Expand Down

0 comments on commit a6c0a31

Please sign in to comment.