Skip to content

Commit

Permalink
Move device parser utility function (#10230)
Browse files Browse the repository at this point in the history
* move parser function to utils

* fix types

* keep static method

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
awaelchli and pre-commit-ci[bot] authored Nov 5, 2021
1 parent 9c4112c commit 348fc4b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@
)
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import has_iterable_dataset
from pytorch_lightning.utilities.device_parser import _parse_devices
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(
) -> None:
self._check_accelerator_support(accelerator)
self._check_strategy_support(strategy)
gpu_ids, tpu_cores = Trainer._parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
gpu_ids, tpu_cores = _parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
self._accelerator_connector = AcceleratorConnector(
num_processes=1,
devices=devices,
Expand Down
9 changes: 1 addition & 8 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import (
Expand Down Expand Up @@ -1488,13 +1487,7 @@ def _parse_devices(
auto_select_gpus: bool,
tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
if auto_select_gpus and isinstance(gpus, int):
gpus = pick_multiple_gpus(gpus)

# TODO (@seannaren, @kaushikb11): Include IPU parsing logic here
gpu_ids = device_parser.parse_gpu_ids(gpus)
tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
return gpu_ids, tpu_cores
return device_parser._parse_devices(gpus, auto_select_gpus, tpu_cores)

@staticmethod
def _log_api_event(event: str) -> None:
Expand Down
17 changes: 16 additions & 1 deletion pytorch_lightning/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch

from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

Expand Down Expand Up @@ -48,6 +49,20 @@ def determine_root_gpu_device(gpus: List[int]) -> Optional[int]:
return root_gpu


def _parse_devices(
gpus: Optional[Union[List[int], str, int]],
auto_select_gpus: bool,
tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
if auto_select_gpus and isinstance(gpus, int):
gpus = pick_multiple_gpus(gpus)

# TODO (@seannaren, @kaushikb11): Include IPU parsing logic here
gpu_ids = parse_gpu_ids(gpus)
tpu_cores = parse_tpu_cores(tpu_cores)
return gpu_ids, tpu_cores


def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
"""
Parses the GPU ids given in the format as accepted by the
Expand Down Expand Up @@ -89,7 +104,7 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
return _sanitize_gpu_ids(gpus)


def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[int, List[int]]]:
def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
"""
Parses the tpu_cores given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.
Expand Down

0 comments on commit 348fc4b

Please sign in to comment.