Skip to content

Commit

Permalink
typing hell
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Nov 19, 2021
1 parent 24a85ce commit a60037d
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions pytorch_lightning/lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union, cast

import torch
import torch.nn as nn
Expand Down Expand Up @@ -188,7 +188,7 @@ def setup(

def setup_dataloaders(
self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Union[_LiteDataLoader, List[_LiteDataLoader]]:
) -> Union[DataLoader, List[DataLoader]]:
"""Setup one or multiple dataloaders for accelerated training. If you need different settings for each
dataloader, call this method individually for each one.
Expand All @@ -213,7 +213,7 @@ def setup_dataloaders(

def _setup_dataloader(
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> _LiteDataLoader:
) -> DataLoader:
"""Setup a single dataloader for accelerated training.
Args:
Expand Down Expand Up @@ -246,7 +246,9 @@ def _setup_dataloader(

dataloader = self._strategy.process_dataloader(dataloader)
device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
return _LiteDataLoader(dataloader=dataloader, device=device)
lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
lite_dataloader = cast(DataLoader, lite_dataloader)
return lite_dataloader

def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
Expand Down

0 comments on commit a60037d

Please sign in to comment.