Skip to content

Commit

Permalink
enhance 3d-party devices in mix-precision
Browse files Browse the repository at this point in the history
  • Loading branch information
uniartisan committed Oct 19, 2024
1 parent 8ad3e29 commit baf3e5c
Show file tree
Hide file tree
Showing 15 changed files with 53 additions and 15 deletions.
6 changes: 5 additions & 1 deletion src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def __init__(
self._accelerator_flag = self._choose_auto_accelerator()
elif self._accelerator_flag == "gpu":
self._accelerator_flag = self._choose_gpu_accelerator_backend()
elif isinstance(self._accelerator_flag, Accelerator):
pass # do nothing

self._set_parallel_devices_and_init_accelerator()

Expand Down Expand Up @@ -461,7 +463,7 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type]
return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None) # type: ignore[arg-type]
mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
raise ValueError(
Expand Down Expand Up @@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision:
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
if isinstance(self._accelerator_flag, Accelerator):
device = self._accelerator_flag.get_device()
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]

raise RuntimeError("No precision set")
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
Expand Down
5 changes: 3 additions & 2 deletions src/lightning/fabric/plugins/precision/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ class FSDPPrecision(Precision):
"""

def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)
self.device = device if device is not None else "cuda"

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

Expand Down Expand Up @@ -110,7 +111,7 @@ def module_init_context(self) -> ContextManager:
@override
def forward_context(self) -> ContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return self.tensor_init_context()

@override
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def load_checkpoint(

optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())

torch.cuda.empty_cache()
getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
_, client_state = engine.load_checkpoint(
path,
tag="checkpoint",
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def load_checkpoint(
given, the full checkpoint will be returned.
"""
torch.cuda.empty_cache()
getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
checkpoint = self.checkpoint_io.load_checkpoint(path)
if not state:
return checkpoint
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
"""
raise NotImplementedError

@staticmethod
def get_device() -> str:
"""Get the device for the current process."""
raise NotImplementedError
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@staticmethod
@override
def get_device() -> str:
return "cpu"


# CPU device metrics
_CPU_VM_PERCENT = "cpu_vm_percent"
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@staticmethod
@override
def get_device() -> str:
return "cuda"


def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@staticmethod
@override
def get_device() -> str:
return "mps"


# device metrics
_VM_PERCENT = "M1_vm_percent"
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
Expand Down
5 changes: 3 additions & 2 deletions src/lightning/pytorch/plugins/precision/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,14 @@ class FSDPPrecision(Precision):
"""

def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)
self.device = device if device is not None else "cuda"

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

Expand Down Expand Up @@ -119,7 +120,7 @@ def module_init_context(self) -> ContextManager:
@override
def forward_context(self) -> ContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return _DtypeContextManager(self._desired_input_dtype)

@override
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
with ctx:
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
return self._lightning_module

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
return self.checkpoint_io.load_checkpoint(checkpoint_path)

def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
Expand Down
18 changes: 15 additions & 3 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def __init__(
self._accelerator_flag = self._choose_auto_accelerator()
elif self._accelerator_flag == "gpu":
self._accelerator_flag = self._choose_gpu_accelerator_backend()
elif isinstance(self._accelerator_flag, Accelerator):
pass # do nothing

self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
self._set_parallel_devices_and_init_accelerator()
Expand Down Expand Up @@ -301,15 +303,18 @@ def _check_config_and_set_final_flags(
f" but accelerator set to {self._accelerator_flag}, please choose one device type"
)
self._accelerator_flag = "cpu"
if self._strategy_flag.parallel_devices[0].type == "cuda":
elif self._strategy_flag.parallel_devices[0].type == "cuda":
if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"):
raise MisconfigurationException(
f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
f" but accelerator set to {self._accelerator_flag}, please choose one device type"
)
self._accelerator_flag = "cuda"
else:
pass # 3rd party accelerator
self._parallel_devices = self._strategy_flag.parallel_devices


def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
if not isinstance(num_nodes, int) or num_nodes < 1:
raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.")
Expand Down Expand Up @@ -458,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:

if (
strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
) and self._accelerator_flag not in ("cuda", "gpu"):
) and self._accelerator_flag not in ("cuda", "gpu") and isinstance(self._accelerator_flag, str):
raise ValueError(
f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
f" {self._accelerator_flag}"
)
elif isinstance(self._accelerator_flag, Accelerator):
Warning(
f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
)
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
raise ValueError(
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
Expand Down Expand Up @@ -496,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type]
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(self._precision_flag) # type: ignore[arg-type]
return FSDPPrecision(precision=self._precision_flag, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
if self._precision_flag in ("16-true", "bf16-true"):
return HalfPrecision(self._precision_flag) # type: ignore
if self._precision_flag == "32-true":
Expand All @@ -520,6 +530,8 @@ def _check_and_init_precision(self) -> Precision:
f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
if isinstance(self._accelerator_flag, Accelerator):
device = self._accelerator_flag.get_device()
return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type]

raise RuntimeError("No precision set")
Expand Down

0 comments on commit baf3e5c

Please sign in to comment.