Remove the requirement for FSDPStrategy subclasses to only support GPU
awaelchli authored Apr 16, 2024
1 parent 58ad56a commit c235f20
Showing 3 changed files with 16 additions and 12 deletions.
@@ -455,10 +455,11 @@ def _check_strategy_and_fallback(self) -> None:
         strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
 
         if (
-            strategy_flag in FSDPStrategy.get_registered_strategies() or isinstance(self._strategy_flag, FSDPStrategy)
+            strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
         ) and self._accelerator_flag not in ("cuda", "gpu"):
-            raise MisconfigurationException(
-                f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used."
+            raise ValueError(
+                f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
+                f" {self._accelerator_flag}"
             )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
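For context, the effect of switching from `isinstance(...)` to `type(...) is FSDPStrategy` can be shown with a minimal stand-alone sketch; the classes below are stand-ins rather than the real Lightning imports:

class FSDPStrategy:  # stand-in for lightning.pytorch.strategies.FSDPStrategy
    pass


class CustomFSDPStrategy(FSDPStrategy):  # hypothetical user subclass
    pass


base, custom = FSDPStrategy(), CustomFSDPStrategy()

# isinstance() also matches subclasses, so the old check rejected every
# FSDPStrategy subclass on non-GPU accelerators.
assert isinstance(base, FSDPStrategy) and isinstance(custom, FSDPStrategy)

# `type(...) is FSDPStrategy` matches only the exact class, so subclasses
# that target other accelerators now pass the check.
assert type(base) is FSDPStrategy
assert type(custom) is not FSDPStrategy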
6 changes: 1 addition & 5 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -28,7 +28,6 @@
 from lightning.pytorch.strategies import FSDPStrategy
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
 from torchmetrics import Accuracy
@@ -216,10 +215,7 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
 
 def test_invalid_on_cpu(tmp_path, cuda_count_0):
     """Test to ensure that we raise Misconfiguration for FSDP on CPU."""
-    with pytest.raises(
-        MisconfigurationException,
-        match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
-    ):
+    with pytest.raises(ValueError, match="The strategy `fsdp` requires a GPU accelerator"):
         trainer = Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy="fsdp")
         assert isinstance(trainer.strategy, FSDPStrategy)
         trainer.strategy.setup_environment()
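Note that `pytest.raises(..., match=...)` treats the pattern as a regular expression and applies it with `re.search` to the string form of the exception, which is why the shorter unanchored substring above is sufficient. A small illustrative sketch (the helper below is hypothetical, not part of the test suite):

import pytest


def _raise_fsdp_error():
    # error message shaped like the connector's new ValueError
    raise ValueError("The strategy `fsdp` requires a GPU accelerator, but got: cpu")


def test_match_uses_re_search():
    # an unanchored substring of the message is enough to match
    with pytest.raises(ValueError, match="requires a GPU accelerator"):
        _raise_fsdp_error()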
@@ -564,12 +564,19 @@ def test_strategy_choice_ddp_cpu_slurm(cuda_count_0, strategy):
 
 
 def test_check_fsdp_strategy_and_fallback():
-    with pytest.raises(
-        MisconfigurationException,
-        match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
-    ):
+    with pytest.raises(ValueError, match="The strategy `fsdp` requires a GPU accelerator"):
         Trainer(accelerator="cpu", strategy="fsdp")
 
+    class FSDPStrategySubclass(FSDPStrategy):
+        pass
+
+    class AcceleratorSubclass(CPUAccelerator):
+        pass
+
+    # we allow subclasses of FSDPStrategy to be used with other accelerators
+    Trainer(accelerator="cpu", strategy=FSDPStrategySubclass())
+    Trainer(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())
+
 
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_unsupported_tpu_choice(xla_available, tpu_available):
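The new test mirrors the intended end-user pattern; roughly, under the assumption of this commit and with an illustrative subclass name (only the construction-time check is relaxed, running FSDP still requires a backend that supports it):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy


class CPUBackedFSDPStrategy(FSDPStrategy):  # hypothetical subclass
    """A subclass that intends to target a non-GPU accelerator."""


# No longer raises at Trainer construction, since only the exact FSDPStrategy
# class is rejected on non-GPU accelerators.
trainer = Trainer(accelerator="cpu", strategy=CPUBackedFSDPStrategy())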
