From a60037de22838d9ce135920dbd3560d59fbcfaae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Nov 2021 22:37:27 +0100 Subject: [PATCH] typing hell --- pytorch_lightning/lite/lite.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index ba545c38a4917..e554a3e61495a 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union, cast import torch import torch.nn as nn @@ -188,7 +188,7 @@ def setup( def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> Union[_LiteDataLoader, List[_LiteDataLoader]]: + ) -> Union[DataLoader, List[DataLoader]]: """Setup one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one. @@ -213,7 +213,7 @@ def setup_dataloaders( def _setup_dataloader( self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> _LiteDataLoader: + ) -> DataLoader: """Setup a single dataloader for accelerated training. Args: @@ -246,7 +246,9 @@ def _setup_dataloader( dataloader = self._strategy.process_dataloader(dataloader) device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None - return _LiteDataLoader(dataloader=dataloader, device=device) + lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device) + lite_dataloader = cast(DataLoader, lite_dataloader) + return lite_dataloader def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.