Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Aug 15, 2022
1 parent ea270e2 commit 3ab0c6a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.types import (
_PATH,
LRSchedulerConfig,
Expand Down Expand Up @@ -181,7 +181,7 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
from fairscale.optim import OSS

if (_TORCH_GREATER_EQUAL_1_10 and isinstance(optimizer, ZeroRedundancyOptimizer)) or (
_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
):
optimizer.consolidate_state_dict()
return optimizer.state_dict() if self.is_global_zero else {}
Expand Down
53 changes: 53 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@
from pytorch_lightning.plugins.environments import ClusterEnvironment, LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
from tests_pytorch.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import ZeroRedundancyOptimizer


class BoringModelGPU(BoringModel):
def on_train_start(self) -> None:
Expand Down Expand Up @@ -252,3 +258,50 @@ def test_ddp_strategy_set_timeout(mock_init_process_group):
mock_init_process_group.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
)


class BoringFairScaleOptimizerModel(BoringModel):
    """BoringModel variant whose optimizer is fairscale's sharded ``OSS`` wrapper."""

    def configure_optimizers(self):
        # Build a plain SGD optimizer first, then hand its param groups and
        # defaults over to OSS so the optimizer state is sharded across ranks.
        sgd = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return OSS(params=sgd.param_groups, optim=type(sgd), **sgd.defaults)


@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn"))
def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
    """Test to ensure that checkpoint is saved correctly when using fairscale optimizer."""
    model = BoringFairScaleOptimizerModel()
    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_epochs=2)

    trainer.fit(model)

    # Saving must consolidate the sharded OSS optimizer state onto rank zero.
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

    # Assert model parameters are identical after loading
    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
        assert torch.equal(ddp_param.to("cpu"), shard_param)


class BoringZeroRedundancyOptimizerModel(BoringModel):
    """BoringModel variant that shards Adam state via ``ZeroRedundancyOptimizer``."""

    def configure_optimizers(self):
        # ZeroRedundancyOptimizer instantiates the wrapped optimizer class
        # itself, so we pass the class and its kwargs rather than an instance.
        return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)


@RunIf(min_gpus=2, skip_windows=True, min_torch="1.10")
@pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn"))
def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
    """Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer."""
    model = BoringZeroRedundancyOptimizerModel()
    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_epochs=2)
    trainer.fit(model)

    # Persist a checkpoint, then load it back into a fresh model instance.
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    restored = BoringModel.load_from_checkpoint(checkpoint_path)

    # The trained and restored parameters must match exactly.
    for trained_param, restored_param in zip(model.parameters(), restored.parameters()):
        assert torch.equal(trained_param.to("cpu"), restored_param)
25 changes: 25 additions & 0 deletions tests/tests_pytorch/strategies/test_sharded_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS


@pytest.mark.parametrize("clip_val", [0, 10])
Expand Down Expand Up @@ -314,3 +315,27 @@ def test_block_backward_sync():
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
trainer = Trainer(strategy=strategy_name)
assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs


class BoringFairScaleOptimizerModel(BoringModel):
    """BoringModel variant whose optimizer is fairscale's sharded ``OSS`` wrapper."""

    def configure_optimizers(self):
        # Create vanilla SGD, then rewrap its param groups in OSS so the
        # optimizer state is partitioned across the DDP ranks.
        sgd = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return OSS(params=sgd.param_groups, optim=type(sgd), **sgd.defaults)


@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", ("ddp_sharded", "ddp_sharded_spawn"))
def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
    """Test to ensure that checkpoint is saved correctly when using fairscale optimizers."""
    model = BoringFairScaleOptimizerModel()
    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_epochs=2)
    trainer.fit(model)

    # Write a checkpoint and load it back into a clean model.
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    restored = BoringModel.load_from_checkpoint(checkpoint_path)

    # The trained and restored parameters must match exactly.
    for trained_param, restored_param in zip(model.parameters(), restored.parameters()):
        assert torch.equal(trained_param.to("cpu"), restored_param)

0 comments on commit 3ab0c6a

Please sign in to comment.