diff --git a/pyproject.toml b/pyproject.toml index 177410cba79a6..5a710faf3544b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ module = [ "pytorch_lightning.callbacks.stochastic_weight_avg", "pytorch_lightning.core.datamodule", "pytorch_lightning.core.decorators", - "pytorch_lightning.core.mixins.device_dtype_mixin", "pytorch_lightning.core.module", "pytorch_lightning.core.saving", "pytorch_lightning.demos.boring_classes", diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 5f6397e4562e5..b12e1cf042a1f 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -16,16 +16,7 @@ import torch from torch.nn import Module - -try: - from typing_extensions import Self -except ImportError: - # workaround for Python 3.7. - # see https://www.python.org/dev/peps/pep-0673/ - from typing import TypeVar - - Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") - +from typing_extensions import Self import pytorch_lightning as pl @@ -57,7 +48,7 @@ def device(self) -> Union[str, torch.device]: return device - def to(self, *args: Any, **kwargs: Any) -> Self: + def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type] """Moves and/or casts the parameters and buffers. This can be called as @@ -121,7 +112,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self: self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type] """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -134,11 +125,11 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: Module: self """ if device is None or isinstance(device, int): - device = torch.device("cuda", index=device) + device = torch.device("cuda", index=(device or 0)) self.__update_properties(device=device) return super().cuda(device=device) - def cpu(self) -> Self: + def cpu(self) -> Self: # type: ignore[valid-type] """Moves all model parameters and buffers to the CPU. Returns: @@ -147,7 +138,7 @@ def cpu(self) -> Self: self.__update_properties(device=torch.device("cpu")) return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> Self: + def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type] """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -159,7 +150,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Self: self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) - def float(self) -> Self: + def float(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``float`` datatype. Returns: @@ -168,7 +159,7 @@ def float(self) -> Self: self.__update_properties(dtype=torch.float) return super().float() - def double(self) -> Self: + def double(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``double`` datatype. Returns: @@ -177,7 +168,7 @@ def double(self) -> Self: self.__update_properties(dtype=torch.double) return super().double() - def half(self) -> Self: + def half(self) -> Self: # type: ignore[valid-type] """Casts all floating point parameters and buffers to ``half`` datatype. Returns: