diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index 2e6f10d356fe0..ad3b75b59ff46 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -41,10 +41,10 @@
 )
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.data import has_iterable_dataset
+from pytorch_lightning.utilities.device_parser import _parse_devices
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
 
@@ -86,7 +86,7 @@ def __init__(
     ) -> None:
         self._check_accelerator_support(accelerator)
         self._check_strategy_support(strategy)
-        gpu_ids, tpu_cores = Trainer._parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
+        gpu_ids, tpu_cores = _parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
         self._accelerator_connector = AcceleratorConnector(
             num_processes=1,
             devices=devices,
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e83efea7a6e89..b6dfcbfee8bc6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -60,7 +60,6 @@
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
-from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import (
@@ -1488,13 +1487,7 @@ def _parse_devices(
         auto_select_gpus: bool,
         tpu_cores: Optional[Union[List[int], str, int]],
     ) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
-        if auto_select_gpus and isinstance(gpus, int):
-            gpus = pick_multiple_gpus(gpus)
-
-        # TODO (@seannaren, @kaushikb11): Include IPU parsing logic here
-        gpu_ids = device_parser.parse_gpu_ids(gpus)
-        tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
-        return gpu_ids, tpu_cores
+        return device_parser._parse_devices(gpus, auto_select_gpus, tpu_cores)
 
     @staticmethod
     def _log_api_event(event: str) -> None:
diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index aadfd28a510e8..c0913633b5717 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -16,6 +16,7 @@
 import torch
 
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
+from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -48,6 +49,20 @@ def determine_root_gpu_device(gpus: List[int]) -> Optional[int]:
     return root_gpu
 
 
+def _parse_devices(
+    gpus: Optional[Union[List[int], str, int]],
+    auto_select_gpus: bool,
+    tpu_cores: Optional[Union[List[int], str, int]],
+) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
+    if auto_select_gpus and isinstance(gpus, int):
+        gpus = pick_multiple_gpus(gpus)
+
+    # TODO (@seannaren, @kaushikb11): Include IPU parsing logic here
+    gpu_ids = parse_gpu_ids(gpus)
+    tpu_cores = parse_tpu_cores(tpu_cores)
+    return gpu_ids, tpu_cores
+
+
 def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
     """
     Parses the GPU ids given in the format as accepted by the
@@ -89,7 +104,7 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
     return _sanitize_gpu_ids(gpus)
 
 
-def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[int, List[int]]]:
+def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
     """
     Parses the tpu_cores given in the format as accepted by the
     :class:`~pytorch_lightning.trainer.Trainer`.