Avoid initializing optimizers during deepspeed evaluation #14944

Merged · 8 commits · Oct 21, 2022
src/pytorch_lightning/CHANGELOG.md — 1 change: 1 addition & 0 deletions
```diff
@@ -179,6 +179,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Do not update on-plateau schedulers when reloading from an end-of-epoch checkpoint ([#14702](https://github.com/Lightning-AI/lightning/pull/14702))
 - Fixed `Trainer` support for PyTorch built without distributed support ([#14971](https://github.com/Lightning-AI/lightning/pull/14971))
 - Fixed batch normalization statistics calculation in `StochasticWeightAveraging` callback ([#14866](https://github.com/Lightning-AI/lightning/pull/14866))
+- Avoided initializing optimizers during deepspeed inference ([#14944](https://github.com/Lightning-AI/lightning/pull/14944))
 - Fixed `LightningCLI` parse_env and description in subcommands ([#15138](https://github.com/Lightning-AI/lightning/pull/15138))
```


src/pytorch_lightning/strategies/deepspeed.py — 15 changes: 3 additions & 12 deletions
```diff
@@ -559,17 +559,8 @@ def _set_deepspeed_activation_checkpointing(self) -> None:
         )

     def _initialize_deepspeed_inference(self, model: Module) -> None:
-        # todo: Currently DeepSpeed requires optimizers at inference to partition weights correctly
         assert isinstance(self.config, dict)
-        optimizer, scheduler = None, None
-        if "optimizer" not in self.config:
-            rank_zero_info(
-                "You have not specified an optimizer or scheduler within the DeepSpeed config."
-                " Using `configure_optimizers` to define optimizer and scheduler."
-            )
-            optimizer, lr_scheduler, _ = self._init_optimizers()
-            if lr_scheduler is not None:
-                scheduler = lr_scheduler.scheduler

         # todo: this is required for DeepSpeed throughput timers
         inference_config = {"train_micro_batch_size_per_gpu": 1}
         if "fp16" in self.config:
@@ -587,8 +578,8 @@ def _initialize_deepspeed_inference(self, model: Module) -> None:
             args=argparse.Namespace(device_rank=self.root_device.index),
             config=inference_config,
             model=model,
-            optimizer=optimizer,
-            lr_scheduler=scheduler,
+            optimizer=None,
+            lr_scheduler=None,
             model_parameters=[],
             dist_init_required=False,
         )
```
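For orientation, the engine setup above boils down to a plain `deepspeed.initialize` call with the optimizer and scheduler slots left empty. Below is a minimal standalone sketch of that inference-only setup; the toy `nn.Linear` model, `device_rank=0`, and the one-entry config are illustrative assumptions, not code from this repository, and it presumes a single-GPU environment with `deepspeed` installed.

```python
# Hedged sketch: inference-only DeepSpeed engine setup, mirroring the diff above.
# The toy model and config values are stand-ins, not taken from the PR.
import argparse

import deepspeed
import torch.nn as nn

model = nn.Linear(32, 4)  # hypothetical stand-in for the user's module

engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    args=argparse.Namespace(device_rank=0),
    config={"train_micro_batch_size_per_gpu": 1},  # needed by DeepSpeed's throughput timers
    model=model,
    optimizer=None,      # evaluation never steps an optimizer
    lr_scheduler=None,   # so no scheduler is required either
    model_parameters=[],
    dist_init_required=False,
)

# With nothing to optimize, DeepSpeed hands back None for both slots.
assert optimizer is None and lr_scheduler is None
```

Because no optimizer is constructed on this path, evaluation no longer triggers the user's `configure_optimizers` hook — the behavior the test change below pins down.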
tests/tests_pytorch/strategies/test_deepspeed_strategy.py — 12 changes: 9 additions & 3 deletions
```diff
@@ -468,10 +468,16 @@ def test_deepspeed_multigpu(tmpdir):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
+
+    with mock.patch.object(
+        model, "configure_optimizers", wraps=model.configure_optimizers
+    ) as mock_configure_optimizers:
+        trainer.test(model)
+    assert mock_configure_optimizers.call_count == 0

     with mock.patch("deepspeed.init_distributed", wraps=deepspeed.init_distributed) as mock_deepspeed_distributed:
         trainer.fit(model)
     mock_deepspeed_distributed.assert_called_once()
     trainer.test(model)

     _assert_save_model_is_equal(model, tmpdir, trainer)
```
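The `wraps=` argument used twice above turns a `MagicMock` into a spy: the real callable still runs, while the mock records every call for later assertions. A self-contained sketch of the idiom follows; the `Greeter` class is a hypothetical stand-in, not part of the test suite.

```python
# Hedged sketch of the mock-as-spy idiom from the test above.
from unittest import mock


class Greeter:
    """Hypothetical stand-in for the object under test."""

    def greet(self, name: str) -> str:
        return f"hello {name}"


greeter = Greeter()

# wraps= delegates to the real bound method, so behavior is unchanged,
# while the mock wrapper counts calls and captures arguments.
with mock.patch.object(greeter, "greet", wraps=greeter.greet) as spy:
    assert greeter.greet("ada") == "hello ada"  # real method still executed

spy.assert_called_once_with("ada")
```

The new test relies on exactly this property: `trainer.test(model)` runs normally, and the zero call count afterwards proves that `configure_optimizers` was never invoked on the evaluation path.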

```diff
@@ -655,8 +661,8 @@ def test_deepspeed_multigpu_stage_3(tmpdir):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    trainer.fit(model)
     trainer.test(model)
+    trainer.fit(model)

     _assert_save_model_is_equal(model, tmpdir, trainer)
```

```diff
@@ -676,8 +682,8 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    trainer.fit(model)
     trainer.test(model)
+    trainer.fit(model)

     _assert_save_model_is_equal(model, tmpdir, trainer)
```