Added multi-optimizer tests with hpu (#13217)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
jerome-habana and pre-commit-ci[bot] authored Jun 21, 2022
1 parent 176ca1f commit cd44512
Showing 4 changed files with 43 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -203,8 +203,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed deprecated `get_progress_bar_dict` property from `LightningModule` ([#12839](https://github.com/PyTorchLightning/pytorch-lightning/pull/12839))

- Removed sanity check for multi-optimizer support with habana backends ([#13217](https://github.com/PyTorchLightning/pytorch-lightning/pull/13217))


- Removed the need to explicitly load habana module ([#13338](https://github.com/PyTorchLightning/pytorch-lightning/pull/13338))


### Fixed

- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
1 change: 0 additions & 1 deletion docs/source-pytorch/accelerators/hpu_basic.rst
@@ -78,7 +78,6 @@ Check out the `Get Started Guide with AWS and Habana <https://docs.habana.ai/en/
Known limitations
-----------------

* Multiple optimizers are not supported.
* `Habana dataloader <https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html#habana-data-loader>`__ is not supported.
* :class:`~pytorch_lightning.callbacks.device_stats_monitor.DeviceStatsMonitor` is not supported.
* :func:`torch.inference_mode` is not supported
3 changes: 0 additions & 3 deletions src/pytorch_lightning/strategies/single_hpu.py
@@ -61,9 +61,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        super().setup_optimizers(trainer)

        if len(self.optimizers) > 1:
            raise MisconfigurationException("HPUs currently support only one optimizer.")

    def model_to_device(self) -> None:
        self.model.to(self.root_device)  # type: ignore

39 changes: 39 additions & 0 deletions tests/tests_pytorch/accelerators/test_hpu.py
@@ -28,6 +28,13 @@
from tests_pytorch.helpers.simple_models import ClassificationModel


class HPUTestModel(BoringModel):
    def configure_optimizers(self):
        opt_a = torch.optim.Adam(self.layer.parameters(), lr=0.001)
        opt_b = torch.optim.SGD(self.layer.parameters(), lr=0.001)
        return opt_a, opt_b


@RunIf(hpu=True)
def test_availability():
    assert HPUAccelerator.is_available()
@@ -258,3 +265,35 @@ def test_strategy_params_with_hpu_parallel_strategy():
    assert strategy._ddp_kwargs["gradient_as_bucket_view"] == gradient_as_bucket_view
    assert strategy._ddp_kwargs["static_graph"] == static_graph
    assert strategy._ddp_kwargs["find_unused_parameters"] == find_unused_parameters


@RunIf(hpu=True)
def test_multi_optimizers_with_hpu(tmpdir):
    class TestModel(HPUTestModel):

        optims = [False, False]

        def training_step(self, batch, batch_idx, optimizer_idx):
            self.optims[optimizer_idx] = True
            return super().training_step(batch, batch_idx)

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="hpu",
        devices=1,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        enable_model_summary=False,
    )
    trainer.fit(model)

    assert all(model.optims)
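
For context, a minimal standalone sketch (not part of this commit) of what the removed restriction now allows: a LightningModule that returns two optimizers and trains on a single HPU device. The model, dataset, and Trainer arguments below are illustrative and simply mirror the new test; they assume a Habana-enabled PyTorch Lightning installation of roughly this version.

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Illustrative random-feature dataset; not part of the commit."""

    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class TwoOptimizerModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # With multiple optimizers, automatic optimization passes the optimizer index.
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # Returning more than one optimizer is no longer rejected by the HPU strategy.
        opt_a = torch.optim.Adam(self.layer.parameters(), lr=1e-3)
        opt_b = torch.optim.SGD(self.layer.parameters(), lr=1e-3)
        return opt_a, opt_b


if __name__ == "__main__":
    model = TwoOptimizerModel()
    trainer = pl.Trainer(accelerator="hpu", devices=1, max_epochs=1, limit_train_batches=2)
    trainer.fit(model, DataLoader(RandomDataset(), batch_size=8))

On a machine without Gaudi hardware, the same script should run by swapping accelerator="hpu" for accelerator="cpu".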
