Fix mypy typing errors in pytorch_lightning/strategies/tpu_spawn.py (#13813)

Co-authored-by: awaelchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: otaj <[email protected]>
4 people authored Aug 2, 2022
1 parent 0fbfbf9 commit d8e5e7f
Showing 4 changed files with 32 additions and 21 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -66,7 +66,6 @@ module = [
"pytorch_lightning.strategies.ipu",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
- "pytorch_lightning.strategies.tpu_spawn",
"pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.data_connector",
48 changes: 30 additions & 18 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -13,7 +13,7 @@
# limitations under the License.
import io
import os
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch import Tensor
@@ -29,15 +29,17 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla import _XLALauncher
+ from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
+ from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
- from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+ from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
@@ -58,7 +60,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
- parallel_devices: Optional[List[int]] = None,
+ parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
debug: bool = False,
@@ -72,6 +74,7 @@ def __init__(
precision_plugin=precision_plugin,
start_method="fork",
)
+ self._checkpoint_io: Optional[CheckpointIO]
self.debug = debug
self._launched = False

@@ -95,17 +98,16 @@ def root_device(self) -> torch.device:
return xm.xla_device()

@staticmethod
- def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
- if not isinstance(dataloaders, list):
- dataloaders = [dataloaders]
-
- for dataloader in dataloaders:
+ def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
+ def check_has_len(dataloader: DataLoader) -> None:
if not has_len(dataloader):
raise MisconfigurationException(
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
" HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
)

+ apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)

@staticmethod
def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit."""
@@ -118,32 +120,37 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
)
for source in sources:
if not source.is_module():
+ assert source.instance is not None
+ assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
TPUSpawnStrategy._validate_dataloader(source.instance)

- def connect(self, model: "pl.LightningModule") -> None:
+ def connect(self, model: "pl.LightningModule") -> None: # type: ignore
TPUSpawnStrategy._validate_patched_dataloaders(model)
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
return super().connect(model)

- def _configure_launcher(self):
+ def _configure_launcher(self) -> None:
self._launcher = _XLALauncher(self)

def setup(self, trainer: "pl.Trainer") -> None:
+ assert self.accelerator
self.accelerator.setup(trainer)

if self.debug:
os.environ["PT_XLA_DEBUG"] = "1"

+ assert self.model
shared_params = find_shared_parameters(self.model)
self.model_to_device()
+ assert isinstance(self.model.module, Module)
set_shared_parameters(self.model.module, shared_params)
self.setup_precision_plugin()

if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

- def _setup_model(self, model: Module) -> Module:
+ def _setup_model(self, model: Module) -> Module: # type: ignore
return model

@property
@@ -168,11 +175,11 @@ def configure_ddp(self) -> None:
def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)

- def barrier(self, name: Optional[str] = None) -> None:
+ def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
if self.is_distributed:
rendezvous(name)

- def broadcast(self, obj: object, src: int = 0) -> object:
+ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not self.is_distributed:
return obj
buffer = io.BytesIO()
@@ -184,7 +191,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj

- def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+ def reduce(
+ self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+ ) -> Tensor:
if not isinstance(output, Tensor):
output = torch.tensor(output, device=self.root_device)

@@ -203,20 +212,23 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

return output

- def _worker_setup(self, process_idx: int):
+ def _worker_setup(self, process_idx: int) -> None:
self._launched = True
self.set_world_ranks(process_idx)
rank_zero_only.rank = self.global_rank

- def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+ assert self.model is not None
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)

- def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+ assert self.model is not None
with self.precision_plugin.test_step_context():
return self.model(*args, **kwargs)

- def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+ def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+ assert self.model is not None
with self.precision_plugin.predict_step_context():
return self.model(*args, **kwargs)

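Note on the new validation logic above: `_validate_dataloader` no longer special-cases a bare list. It hands whatever the user passed (a single loader, a sequence of loaders, or a mapping of loaders) to `apply_to_collection`, which recurses into `Sequence`/`Mapping` containers and runs the `has_len` check on each leaf. Below is a minimal standalone sketch of that traversal idea, using hypothetical `_FakeLoader`, `_check_has_len`, and `apply_to_leaves` names rather than Lightning's actual helpers:

from collections.abc import Mapping, Sequence
from typing import Any, Callable


class _FakeLoader:
    """Hypothetical stand-in for a DataLoader: sized, but not a Sequence or Mapping."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n


def _check_has_len(loader: Any) -> None:
    # Mirrors the check above: objects without a usable __len__ (for example,
    # loaders backed by an IterableDataset) are rejected.
    try:
        len(loader)
    except TypeError:
        raise TypeError(f"{type(loader).__name__} does not implement `__len__`")


def apply_to_leaves(data: Any, fn: Callable[[Any], None]) -> None:
    # Simplified model of apply_to_collection(dtype=object, wrong_dtype=(Sequence, Mapping), ...):
    # recurse into sequences and mappings, apply fn to everything else.
    if isinstance(data, Mapping):
        for value in data.values():
            apply_to_leaves(value, fn)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        for value in data:
            apply_to_leaves(value, fn)
    else:
        fn(data)


# A single loader, a list of loaders, and a dict of loaders all validate the same way:
apply_to_leaves(_FakeLoader(10), _check_has_len)
apply_to_leaves({"train": [_FakeLoader(10)], "val": _FakeLoader(5)}, _check_has_len)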
2 changes: 1 addition & 1 deletion src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -516,7 +516,7 @@ def is_defined(self) -> bool:
return not self.is_module() or is_overridden(self.name, self.instance)

def is_module(self) -> bool:
- """Returns whether the the DataLoader source is a LightningModule or a LightningDataModule.
+ """Returns whether the DataLoader source is a LightningModule or a LightningDataModule.
It does not check whether ``*_dataloader`` methods are actually overridden.
"""
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/apply_func.py
@@ -76,7 +76,7 @@ def apply_to_collection(
dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
*args: Any,
- wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+ wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
include_none: bool = True,
**kwargs: Any,
) -> Any:
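The apply_func.py tweak is what lets the call in tpu_spawn.py above type-check: `wrong_dtype` is forwarded directly to `isinstance`, and `(Sequence, Mapping)` is a two-element tuple, which `Tuple[type]` (a tuple of exactly one type) would reject under mypy. A minimal sketch of the distinction, using a hypothetical `is_excluded` helper rather than Lightning's actual code:

from typing import Optional, Tuple, Union


def is_excluded(value: object, wrong_dtype: Optional[Union[type, Tuple[type, ...]]]) -> bool:
    # Tuple[type, ...] means "a tuple of types of any length", which matches what
    # isinstance() accepts; Tuple[type] would only allow a single-element tuple.
    return wrong_dtype is not None and isinstance(value, wrong_dtype)


print(is_excluded([1, 2, 3], (list, dict)))  # True
print(is_excluded("abc", (list, dict)))      # False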
