Fix mypy errors attributed to pytorch_lightning.core.mixins.device_dtype_mixin #13704

Merged 12 commits on Jul 26, 2022
pyproject.toml (1 change: 0 additions & 1 deletion)

@@ -53,7 +53,6 @@ module = [
     "pytorch_lightning.callbacks.stochastic_weight_avg",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",
-    "pytorch_lightning.core.mixins.device_dtype_mixin",
     "pytorch_lightning.core.module",
     "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",
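
Removing the module from this list means mypy now type-checks it instead of suppressing its errors. For context, a minimal sketch of the kind of per-module override block such a list typically sits in (hypothetical excerpt; the exact section layout and flag spelling in Lightning's pyproject.toml may differ):

# hypothetical pyproject.toml excerpt, not the verbatim Lightning config
[[tool.mypy.overrides]]
module = [
    # modules still awaiting type fixes; entries are removed one by one
    # as their mypy errors get fixed, which is what this PR does here
    "pytorch_lightning.core.datamodule",
]
ignore_errors = true
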
src/pytorch_lightning/core/mixins/device_dtype_mixin.py (26 changes: 9 additions & 17 deletions)

@@ -16,16 +16,7 @@

 import torch
 from torch.nn import Module
-
-try:
-    from typing_extensions import Self
-except ImportError:
-    # workaround for Python 3.7.
-    # see https://www.python.org/dev/peps/pep-0673/
-    from typing import TypeVar
-
-    Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin")
+from typing_extensions import Self

 import pytorch_lightning as pl
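
The try/except fallback existed because ``Self`` (PEP 673) only reached the standard library in Python 3.11 and, per the removed comment, the typing_extensions pin usable on Python 3.7 apparently predated it. With a typing_extensions release that backports ``Self``, the unconditional import suffices. A minimal sketch (toy code, not from Lightning) of what ``Self`` buys in a mixin:

from typing_extensions import Self  # backport of PEP 673


class ChainableMixin:
    def configure(self) -> Self:
        # ``Self`` binds to the runtime class, so a subclass instance
        # keeps its subclass type through chained calls
        return self


class MyModel(ChainableMixin):
    hidden_size = 128


model = MyModel().configure()
print(model.hidden_size)  # type checks: ``configure`` returns ``MyModel``, not the mixin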

@@ -57,7 +48,7 @@ def device(self) -> Union[str, torch.device]:

         return device

-    def to(self, *args: Any, **kwargs: Any) -> Self:
+    def to(self, *args: Any, **kwargs: Any) -> Self:  # type: ignore[valid-type]
         """Moves and/or casts the parameters and buffers.

         This can be called as
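
The ignores are scoped to the [valid-type] error code: mypy releases available at the time evidently did not yet understand PEP 673 and reject ``Self`` as a return annotation, and a scoped ignore silences only that complaint while the rest of the line stays checked. A self-contained sketch of the pattern, assuming such a mypy release:

from typing import Any

from typing_extensions import Self


class Chainable:
    # hypothetical toy class: with a pre-PEP 673 mypy, the ``Self``
    # annotation raises error code [valid-type]; the scoped ignore
    # silences only that code, unlike a bare ``# type: ignore``,
    # which would also hide any unrelated error on the line
    def to(self, *args: Any, **kwargs: Any) -> Self:  # type: ignore[valid-type]
        return self
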
@@ -121,7 +112,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self:
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)

-    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # type: ignore[valid-type]
         """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
         different objects. So it should be called before constructing optimizer if the module will live on GPU
         while being optimized.
@@ -134,11 +125,12 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
             Module: self
         """
         if device is None or isinstance(device, int):
+            assert isinstance(device, (int, type(None)))
             device = torch.device("cuda", index=device)
         self.__update_properties(device=device)
         return super().cuda(device=device)

-    def cpu(self) -> Self:
+    def cpu(self) -> Self:  # type: ignore[valid-type]
         """Moves all model parameters and buffers to the CPU.

         Returns:
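
The added assert is redundant at runtime inside that branch, but it hands mypy an explicit fact to narrow with: the ``or``-joined condition alone was evidently not enough for the mypy release used here to narrow ``device`` from ``Optional[Union[torch.device, int]]`` to the ``Optional[int]`` that ``torch.device``'s ``index`` argument expects. A standalone sketch of the same trick:

from typing import Optional, Union

import torch


def as_cuda_device(device: Optional[Union[torch.device, int]]) -> torch.device:
    # toy function, not Lightning code: after the assert, mypy treats
    # ``device`` as ``Optional[int]``, so passing it as ``index`` checks
    if device is None or isinstance(device, int):
        assert isinstance(device, (int, type(None)))
        device = torch.device("cuda", index=device)
    return device
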
@@ -147,7 +139,7 @@ def cpu(self) -> Self:
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()

-    def type(self, dst_type: Union[str, torch.dtype]) -> Self:
+    def type(self, dst_type: Union[str, torch.dtype]) -> Self:  # type: ignore[valid-type]
         """Casts all parameters and buffers to :attr:`dst_type`.

         Arguments:
@@ -159,7 +151,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Self:
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)

-    def float(self) -> Self:
+    def float(self) -> Self:  # type: ignore[valid-type]
         """Casts all floating point parameters and buffers to ``float`` datatype.

         Returns:
@@ -168,7 +160,7 @@ def float(self) -> Self:
         self.__update_properties(dtype=torch.float)
         return super().float()

-    def double(self) -> Self:
+    def double(self) -> Self:  # type: ignore[valid-type]
         """Casts all floating point parameters and buffers to ``double`` datatype.

         Returns:
@@ -177,7 +169,7 @@ def double(self) -> Self:
         self.__update_properties(dtype=torch.double)
         return super().double()

-    def half(self) -> Self:
+    def half(self) -> Self:  # type: ignore[valid-type]
         """Casts all floating point parameters and buffers to ``half`` datatype.

         Returns: