
Commit

fix test
rohitgr7 committed Aug 25, 2022
1 parent 9f6cc32 commit c06912e
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -25,6 +25,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
+from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
@@ -144,6 +145,9 @@ def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
         return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy, DeepSpeedStrategy)):
+            raise MisconfigurationException("SWA does not currently support sharded models.")
+
         # copy the model before moving it to accelerator device.
         with pl_module._prevent_trainer_and_dataloaders_deepcopy():
             self._average_model = deepcopy(pl_module)
@@ -155,9 +159,6 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if len(trainer.lr_scheduler_configs) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
 
-        if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
-            raise MisconfigurationException("SWA does not currently support sharded models.")
-
         if isinstance(self._swa_epoch_start, float):
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
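The change above moves the sharded-model guard from `on_fit_start` into `setup` and extends it to the native FSDP strategy, so an unsupported configuration now fails before the callback deep-copies the model. A minimal, illustrative sketch of the user-visible effect (not part of the commit; `BoringModel` and the Trainer arguments are assumptions for the example):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Combining SWA with a sharded strategy is now rejected during setup.
trainer = pl.Trainer(
    strategy="fsdp_native",  # the same applies to "fsdp" and "deepspeed"
    accelerator="gpu",
    devices=1,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
)

try:
    trainer.fit(BoringModel())
except MisconfigurationException as err:
    print(err)  # "SWA does not currently support sharded models."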
@@ -348,6 +348,7 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
     [
         pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)),
         pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
+        pytest.param("fsdp_native", marks=RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="1.12")),
     ],
 )
 def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
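The test change adds the `fsdp_native` alias to the strategies covered by the existing misconfiguration test. A hedged sketch of what such a parametrized test body typically asserts (the actual test in the repository may differ; the helper name, `BoringModel`, and the Trainer arguments are illustrative):

import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException

def check_swa_rejects_sharded_strategy(tmpdir, strategy: str) -> None:
    # Each parametrized strategy ("fsdp", "deepspeed", "fsdp_native") should be
    # rejected by the SWA callback before training starts.
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=strategy,
        accelerator="gpu",
        devices=1,
        fast_dev_run=True,
        callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
    )
    with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
        trainer.fit(BoringModel())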
