From baf3e5c12f7fe3e8e7a4c2a1bd67fbe2e4af7db5 Mon Sep 17 00:00:00 2001
From: Zhiyuan Li
Date: Sat, 19 Oct 2024 02:00:03 +1100
Subject: [PATCH] enhance 3rd-party device support in mixed precision

---
 src/lightning/fabric/connector.py | 6 +++++-
 src/lightning/fabric/plugins/precision/amp.py | 2 +-
 src/lightning/fabric/plugins/precision/fsdp.py | 5 +++--
 src/lightning/fabric/strategies/ddp.py | 2 +-
 src/lightning/fabric/strategies/deepspeed.py | 2 +-
 src/lightning/fabric/strategies/strategy.py | 2 +-
 .../pytorch/accelerators/accelerator.py | 5 +++++
 src/lightning/pytorch/accelerators/cpu.py | 5 +++++
 src/lightning/pytorch/accelerators/cuda.py | 5 +++++
 src/lightning/pytorch/accelerators/mps.py | 5 +++++
 src/lightning/pytorch/plugins/precision/amp.py | 2 +-
 .../pytorch/plugins/precision/fsdp.py | 5 +++--
 src/lightning/pytorch/strategies/ddp.py | 2 +-
 src/lightning/pytorch/strategies/strategy.py | 2 +-
 .../connectors/accelerator_connector.py | 18 +++++++++++++++---
 15 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index 9fb66255830c61..d654e84e040cc3 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -141,6 +141,8 @@ def __init__(
             self._accelerator_flag = self._choose_auto_accelerator()
         elif self._accelerator_flag == "gpu":
             self._accelerator_flag = self._choose_gpu_accelerator_backend()
+        elif isinstance(self._accelerator_flag, Accelerator):
+            pass  # do nothing

         self._set_parallel_devices_and_init_accelerator()

@@ -461,7 +463,7 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(precision=self._precision_input)  # type: ignore[arg-type]
+            return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)  # type: ignore[arg-type]
         mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
         if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
             raise ValueError(
@@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision:
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            if isinstance(self._accelerator_flag, Accelerator):
+                device = self._accelerator_flag.get_device()
             return MixedPrecision(precision=self._precision_input, device=device)  # type: ignore[arg-type]

         raise RuntimeError("No precision set")
diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index c624e821af28cb..96046386834688 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -50,7 +50,7 @@ def __init__(
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
+            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
         if scaler is not None and self.precision == "bf16-mixed":
             raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 179fc21cdd90de..aa8d17017f2d51 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -48,13 +48,14 @@ class FSDPPrecision(Precision):

     """

-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in FSDP."
                 f" `precision` must be one of: {supported_precision}."
             )
+        self.device = device if device is not None else "cuda"

         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

@@ -110,7 +111,7 @@ def module_init_context(self) -> ContextManager:
     @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return self.tensor_init_context()

     @override
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index c38780655ce6ea..01fe81b51b694a 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -124,7 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 93a17f10c8998b..c792ee3405440c 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -506,7 +506,7 @@ def load_checkpoint(
         optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())

-        torch.cuda.empty_cache()
+        getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
         _, client_state = engine.load_checkpoint(
             path,
             tag="checkpoint",
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 6bfed6a270b683..96f4edbc27373f 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -325,7 +325,7 @@ def load_checkpoint(
             given, the full checkpoint will be returned.

""" - torch.cuda.empty_cache() + getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None checkpoint = self.checkpoint_io.load_checkpoint(path) if not state: return checkpoint diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 0490c2d86431c1..d65ad3a7cd3955 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """ raise NotImplementedError + + @staticmethod + def get_device() -> str: + """Get the device for the current process.""" + raise NotImplementedError diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 735312b363d111..ab6304053f314e 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "cpu" + # CPU device metrics _CPU_VM_PERCENT = "cpu_vm_percent" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 6df3bc6b468eea..cfb85cb2c29903 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "cuda" + def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. 
diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py
index 6efe6292de624c..d8bda9dae80871 100644
--- a/src/lightning/pytorch/accelerators/mps.py
+++ b/src/lightning/pytorch/accelerators/mps.py
@@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
             description=cls.__name__,
         )

+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "mps"
+

 # device metrics
 _VM_PERCENT = "M1_vm_percent"
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index e63ccd6912b63e..e639d21165387c 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -50,7 +50,7 @@ def __init__(
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
+            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
         if scaler is not None and self.precision == "bf16-mixed":
             raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index e6c684967ed404..978edfeed55716 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -47,13 +47,14 @@ class FSDPPrecision(Precision):

     """

-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in FSDP."
                 f" `precision` must be one of: {supported_precision}."
             )
+        self.device = device if device is not None else "cuda"

         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

@@ -119,7 +120,7 @@ def module_init_context(self) -> ContextManager:
     @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return _DtypeContextManager(self._desired_input_dtype)

     @override
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 9031b6ee177f3b..a8dd1a9bfe6b61 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -190,7 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py
index 314007f497f597..7114fe1407ec70 100644
--- a/src/lightning/pytorch/strategies/strategy.py
+++ b/src/lightning/pytorch/strategies/strategy.py
@@ -363,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
         return self._lightning_module

     def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
-        torch.cuda.empty_cache()
+        getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
         return self.checkpoint_io.load_checkpoint(checkpoint_path)

     def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 06f3ee366bcaa5..8ef39e1edd8e3a 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -141,6 +141,8 @@ def __init__(
             self._accelerator_flag = self._choose_auto_accelerator()
         elif self._accelerator_flag == "gpu":
             self._accelerator_flag = self._choose_gpu_accelerator_backend()
+        elif isinstance(self._accelerator_flag, Accelerator):
+            pass  # do nothing

         self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
         self._set_parallel_devices_and_init_accelerator()
@@ -301,15 +303,18 @@ def _check_config_and_set_final_flags(
                             f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                         )
                     self._accelerator_flag = "cpu"
-                if self._strategy_flag.parallel_devices[0].type == "cuda":
+                elif self._strategy_flag.parallel_devices[0].type == "cuda":
                     if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"):
                         raise MisconfigurationException(
                             f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                             f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                         )
                     self._accelerator_flag = "cuda"
+                else:
+                    pass  # 3rd party accelerator
                 self._parallel_devices = self._strategy_flag.parallel_devices

+
     def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
         if not isinstance(num_nodes, int) or num_nodes < 1:
             raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.")
@@ -458,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:
         if (
             strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and self._accelerator_flag not in ("cuda", "gpu") and isinstance(self._accelerator_flag, str):
             raise ValueError(
                 f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
                 f" {self._accelerator_flag}"
             )
+        elif isinstance(self._accelerator_flag, Accelerator):
+            rank_zero_warn(
+                f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
+                f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
+            )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
                 f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
@@ -496,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(self._precision_flag)  # type: ignore[arg-type]
+            return FSDPPrecision(precision=self._precision_flag, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
         if self._precision_flag in ("16-true", "bf16-true"):
             return HalfPrecision(self._precision_flag)  # type: ignore
         if self._precision_flag == "32-true":
@@ -520,6 +530,8 @@ def _check_and_init_precision(self) -> Precision:
                 f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
             )
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            if isinstance(self._accelerator_flag, Accelerator):
+                device = self._accelerator_flag.get_device()
             return MixedPrecision(self._precision_flag, device)  # type: ignore[arg-type]

         raise RuntimeError("No precision set")
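
Usage example (illustrative sketch, not part of the patch): the snippet below shows how a third-party accelerator could plug into the new `get_device()` hook so that the connectors above forward its device string to `MixedPrecision`/`FSDPPrecision` instead of the previously hard-coded "cuda". The "npu" name, the `MyNPUAccelerator` class, and the `torch.npu` calls are assumptions standing in for any vendor backend that registers its device module on `torch` (which the `getattr(torch, ...)` calls in this patch rely on).

from typing import Any, Dict, List, Union

import torch

from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import Accelerator


class MyNPUAccelerator(Accelerator):
    """Hypothetical third-party accelerator; assumes a `torch.npu` module mirroring the `torch.cuda` API."""

    def setup_device(self, device: torch.device) -> None:
        torch.npu.set_device(device)

    def teardown(self) -> None:
        torch.npu.empty_cache()

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Any:
        return devices

    @staticmethod
    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
        if isinstance(devices, int):
            return [torch.device("npu", i) for i in range(devices)]
        return [torch.device("npu", int(i)) for i in devices]

    @staticmethod
    def auto_device_count() -> int:
        return torch.npu.device_count()

    @staticmethod
    def is_available() -> bool:
        return torch.npu.is_available()

    def get_device_stats(self, device: torch.device) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_device() -> str:
        # The string returned here is what the connectors now pass as the autocast /
        # GradScaler device, e.g. MixedPrecision(precision="16-mixed", device="npu").
        return "npu"


# With the accelerator instance passed in, the AMP plugin is built for "npu"
# rather than assuming CUDA.
trainer = Trainer(accelerator=MyNPUAccelerator(), devices=1, precision="16-mixed")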