From 033f46987a602982a8a89bbf8efd6db411486900 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Mon, 18 Jul 2022 06:45:47 -0400 Subject: [PATCH 1/9] initial changes to variable --- pyproject.toml | 1 - .../core/mixins/device_dtype_mixin.py | 24 +++++++------------ 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 989e63122f640..ad5feed0c24a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,6 @@ module = [ "pytorch_lightning.callbacks.stochastic_weight_avg", "pytorch_lightning.core.datamodule", "pytorch_lightning.core.decorators", - "pytorch_lightning.core.mixins.device_dtype_mixin", "pytorch_lightning.core.module", "pytorch_lightning.core.saving", "pytorch_lightning.demos.boring_classes", diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 5f6397e4562e5..2b2a4fd2072ce 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -12,19 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Optional, Union, TypeVar import torch from torch.nn import Module -try: - from typing_extensions import Self -except ImportError: - # workaround for Python 3.7. - # see https://www.python.org/dev/peps/pep-0673/ - from typing import TypeVar - Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") +TDeviceDtypeModuleMixin = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") import pytorch_lightning as pl @@ -57,7 +51,7 @@ def device(self) -> Union[str, torch.device]: return device - def to(self, *args: Any, **kwargs: Any) -> Self: + def to(self, *args: Any, **kwargs: Any) -> TDeviceDtypeModuleMixin: """Moves and/or casts the parameters and buffers. This can be called as @@ -121,7 +115,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self: self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> DeviceDtypeModuleMixin: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -138,7 +132,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: self.__update_properties(device=device) return super().cuda(device=device) - def cpu(self) -> Self: + def cpu(self) -> TDeviceDtypeModuleMixin: """Moves all model parameters and buffers to the CPU. Returns: @@ -147,7 +141,7 @@ def cpu(self) -> Self: self.__update_properties(device=torch.device("cpu")) return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> Self: + def type(self, dst_type: Union[str, torch.dtype]) -> TDeviceDtypeModuleMixin: """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -159,7 +153,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Self: self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) - def float(self) -> Self: + def float(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``float`` datatype. Returns: @@ -168,7 +162,7 @@ def float(self) -> Self: self.__update_properties(dtype=torch.float) return super().float() - def double(self) -> Self: + def double(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``double`` datatype. Returns: @@ -177,7 +171,7 @@ def double(self) -> Self: self.__update_properties(dtype=torch.double) return super().double() - def half(self) -> Self: + def half(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``half`` datatype. Returns: From ff08ba5099681b16ea95eb9ad97411a5661d857e Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Mon, 18 Jul 2022 06:48:37 -0400 Subject: [PATCH 2/9] add assert --- src/pytorch_lightning/core/mixins/device_dtype_mixin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 2b2a4fd2072ce..c08a246ca8d72 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -115,7 +115,7 @@ def to(self, *args: Any, **kwargs: Any) -> TDeviceDtypeModuleMixin: self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> DeviceDtypeModuleMixin: + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtypeModuleMixin: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -128,6 +128,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> DeviceDtype Module: self """ if device is None or isinstance(device, int): + assert isinstance(device, int) device = torch.device("cuda", index=device) self.__update_properties(device=device) return super().cuda(device=device) From 3cb210d11d9ea8419fa87b19b60b21ae74786755 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Jul 2022 10:54:22 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/core/mixins/device_dtype_mixin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index c08a246ca8d72..c605789681299 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union, TypeVar +from typing import Any, Optional, TypeVar, Union import torch from torch.nn import Module - TDeviceDtypeModuleMixin = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") From 275d7bbf9e4bacce1a39afa51d1cad3c2930824e Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Tue, 19 Jul 2022 15:29:23 -0400 Subject: [PATCH 4/9] add ignore --- .../core/mixins/device_dtype_mixin.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index c08a246ca8d72..8041b6b1c410d 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -113,7 +113,7 @@ def to(self, *args: Any, **kwargs: Any) -> TDeviceDtypeModuleMixin: # there is diff nb vars in PT 1.5 out = torch._C._nn._parse_to(*args, **kwargs) self.__update_properties(device=out[0], dtype=out[1]) - return super().to(*args, **kwargs) + return super().to(*args, **kwargs) # type: ignore def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtypeModuleMixin: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers @@ -131,7 +131,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtyp assert isinstance(device, int) device = torch.device("cuda", index=device) self.__update_properties(device=device) - return super().cuda(device=device) + return super().cuda(device=device) # type: ignore def cpu(self) -> TDeviceDtypeModuleMixin: """Moves all model parameters and buffers to the CPU. @@ -140,7 +140,7 @@ def cpu(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(device=torch.device("cpu")) - return super().cpu() + return super().cpu() # type: ignore def type(self, dst_type: Union[str, torch.dtype]) -> TDeviceDtypeModuleMixin: """Casts all parameters and buffers to :attr:`dst_type`. @@ -152,7 +152,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=dst_type) - return super().type(dst_type=dst_type) + return super().type(dst_type=dst_type) # type: ignore def float(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``float`` datatype. @@ -161,7 +161,7 @@ def float(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=torch.float) - return super().float() + return super().float() # type: ignore def double(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``double`` datatype. @@ -170,7 +170,7 @@ def double(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=torch.double) - return super().double() + return super().double() # type: ignore def half(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``half`` datatype. @@ -179,7 +179,7 @@ def half(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=torch.half) - return super().half() + return super().half() # type: ignore def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None From 20db3e3be3fbed7640c44866063a0d4cc0dc3930 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Jul 2022 19:33:03 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/mixins/device_dtype_mixin.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index e8cebea8132eb..2eeab6da997be 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -112,7 +112,7 @@ def to(self, *args: Any, **kwargs: Any) -> TDeviceDtypeModuleMixin: # there is diff nb vars in PT 1.5 out = torch._C._nn._parse_to(*args, **kwargs) self.__update_properties(device=out[0], dtype=out[1]) - return super().to(*args, **kwargs) # type: ignore + return super().to(*args, **kwargs) # type: ignore def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtypeModuleMixin: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers @@ -130,7 +130,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtyp assert isinstance(device, int) device = torch.device("cuda", index=device) self.__update_properties(device=device) - return super().cuda(device=device) # type: ignore + return super().cuda(device=device) # type: ignore def cpu(self) -> TDeviceDtypeModuleMixin: """Moves all model parameters and buffers to the CPU. @@ -139,7 +139,7 @@ def cpu(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(device=torch.device("cpu")) - return super().cpu() # type: ignore + return super().cpu() # type: ignore def type(self, dst_type: Union[str, torch.dtype]) -> TDeviceDtypeModuleMixin: """Casts all parameters and buffers to :attr:`dst_type`. @@ -151,7 +151,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=dst_type) - return super().type(dst_type=dst_type) # type: ignore + return super().type(dst_type=dst_type) # type: ignore def float(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``float`` datatype. @@ -160,7 +160,7 @@ def float(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=torch.float) - return super().float() # type: ignore + return super().float() # type: ignore def double(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``double`` datatype. @@ -169,7 +169,7 @@ def double(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=torch.double) - return super().double() # type: ignore + return super().double() # type: ignore def half(self) -> TDeviceDtypeModuleMixin: """Casts all floating point parameters and buffers to ``half`` datatype. @@ -178,7 +178,7 @@ def half(self) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=torch.half) - return super().half() # type: ignore + return super().half() # type: ignore def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None From bc1509cfe2b1cf5286b3340dd8c93b668260ea38 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Tue, 19 Jul 2022 15:44:20 -0400 Subject: [PATCH 6/9] add validation for Nonr --- src/pytorch_lightning/core/mixins/device_dtype_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index e8cebea8132eb..e3d80605ab040 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -127,7 +127,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtyp Module: self """ if device is None or isinstance(device, int): - assert isinstance(device, int) + assert isinstance(device, int) or device is None device = torch.device("cuda", index=device) self.__update_properties(device=device) return super().cuda(device=device) # type: ignore From 46e2498b0f248c44e503e652c18b71023b7d434d Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Wed, 20 Jul 2022 17:31:51 -0400 Subject: [PATCH 7/9] type changes based on reviews --- .../core/mixins/device_dtype_mixin.py | 39 +++++++++---------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 0e5eef12861f6..aefc3e4879287 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, Union import torch from torch.nn import Module - -TDeviceDtypeModuleMixin = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") - - +from typing_extensions import Self import pytorch_lightning as pl @@ -50,7 +47,7 @@ def device(self) -> Union[str, torch.device]: return device - def to(self, *args: Any, **kwargs: Any) -> TDeviceDtypeModuleMixin: + def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type] """Moves and/or casts the parameters and buffers. This can be called as @@ -112,9 +109,9 @@ def to(self, *args: Any, **kwargs: Any) -> TDeviceDtypeModuleMixin: # there is diff nb vars in PT 1.5 out = torch._C._nn._parse_to(*args, **kwargs) self.__update_properties(device=out[0], dtype=out[1]) - return super().to(*args, **kwargs) # type: ignore + return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtypeModuleMixin: + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type] """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -127,21 +124,21 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> TDeviceDtyp Module: self """ if device is None or isinstance(device, int): - assert isinstance(device, int) or device is None + assert isinstance(device, (int, type(None))) device = torch.device("cuda", index=device) self.__update_properties(device=device) - return super().cuda(device=device) # type: ignore + return super().cuda(device=device) - def cpu(self) -> TDeviceDtypeModuleMixin: + def cpu(self) -> Self: # type: ignore[valid-type] """Moves all model parameters and buffers to the CPU. Returns: Module: self """ self.__update_properties(device=torch.device("cpu")) - return super().cpu() # type: ignore + return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> TDeviceDtypeModuleMixin: + def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type] """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -151,34 +148,34 @@ def type(self, dst_type: Union[str, torch.dtype]) -> TDeviceDtypeModuleMixin: Module: self """ self.__update_properties(dtype=dst_type) - return super().type(dst_type=dst_type) # type: ignore + return super().type(dst_type=dst_type) - def float(self) -> TDeviceDtypeModuleMixin: + def float(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``float`` datatype. Returns: Module: self """ self.__update_properties(dtype=torch.float) - return super().float() # type: ignore + return super().float() - def double(self) -> TDeviceDtypeModuleMixin: + def double(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``double`` datatype. Returns: Module: self """ self.__update_properties(dtype=torch.double) - return super().double() # type: ignore + return super().double() - def half(self) -> TDeviceDtypeModuleMixin: + def half(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``half`` datatype. Returns: Module: self """ self.__update_properties(dtype=torch.half) - return super().half() # type: ignore + return super().half() def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None @@ -193,4 +190,4 @@ def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: if dtype is not None: module._dtype = dtype - self.apply(apply_fn) + self.apply(apply_fn) \ No newline at end of file From be7a95f644d6c65b04d6cb291525e5127bee8fb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Jul 2022 21:33:32 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/core/mixins/device_dtype_mixin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index aefc3e4879287..f430dd7f8e139 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,7 @@ import torch from torch.nn import Module from typing_extensions import Self + import pytorch_lightning as pl @@ -47,7 +48,7 @@ def device(self) -> Union[str, torch.device]: return device - def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type] + def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type] """Moves and/or casts the parameters and buffers. This can be called as @@ -190,4 +191,4 @@ def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: if dtype is not None: module._dtype = dtype - self.apply(apply_fn) \ No newline at end of file + self.apply(apply_fn) From 451cfd360b68ff0bd504ad71a39c260b94f6335c Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Mon, 25 Jul 2022 09:28:23 +0200 Subject: [PATCH 9/9] suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/core/mixins/device_dtype_mixin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index f430dd7f8e139..b12e1cf042a1f 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -125,8 +125,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # ty Module: self """ if device is None or isinstance(device, int): - assert isinstance(device, (int, type(None))) - device = torch.device("cuda", index=device) + device = torch.device("cuda", index=(device or 0)) self.__update_properties(device=device) return super().cuda(device=device)