From 807435885ea265580fee9f4e69c063eace46def2 Mon Sep 17 00:00:00 2001
From: Tanmoy
Date: Fri, 26 Aug 2022 00:27:48 +0530
Subject: [PATCH] Fix `LightningDataModule` hparams parsing (#12806)

Co-authored-by: Akihiro Nitta
Co-authored-by: Jirka
Co-authored-by: Rohit Gupta
---
 src/pytorch_lightning/CHANGELOG.md            |  3 +
 src/pytorch_lightning/utilities/parsing.py    | 13 ++--
 .../tuner/test_scale_batch_size.py            | 69 ++++++++++++-------
 tests/tests_pytorch/utilities/test_parsing.py | 22 ++++--
 4 files changed, 71 insertions(+), 36 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 07c34bbc0e5793..642cb28d4db4ce 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))
 
 
+- Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))
+
+
 ## [1.7.2] - 2022-08-17
 
 ### Added
diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py
index 073423ab60773d..22dfb538828ab0 100644
--- a/src/pytorch_lightning/utilities/parsing.py
+++ b/src/pytorch_lightning/utilities/parsing.py
@@ -321,14 +321,17 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str)
         holders.append(model)
 
     # Check if attribute in model.hparams, either namespace or dict
-    if hasattr(model, "hparams"):
-        if attribute in model.hparams:
-            holders.append(model.hparams)
+    if hasattr(model, "hparams") and attribute in model.hparams:
+        holders.append(model.hparams)
 
     trainer = model._trainer
     # Check if the attribute in datamodule (datamodule gets registered in Trainer)
-    if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
-        holders.append(trainer.datamodule)
+    if trainer is not None and trainer.datamodule is not None:
+        if hasattr(trainer.datamodule, attribute):
+            holders.append(trainer.datamodule)
+
+        if hasattr(trainer.datamodule, "hparams") and attribute in trainer.datamodule.hparams:
+            holders.append(trainer.datamodule.hparams)
 
     return holders
 
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index d2fc8a61e01078..ce7c3613f50125 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -29,8 +29,8 @@
 
 
 class BatchSizeDataModule(BoringDataModule):
-    def __init__(self, batch_size):
-        super().__init__()
+    def __init__(self, data_dir, batch_size):
+        super().__init__(data_dir)
         if batch_size is not None:
             self.batch_size = batch_size
 
@@ -58,7 +58,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b
     tuner = Tuner(trainer)
 
     model = BatchSizeModel(model_bs)
-    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
+    datamodule = BatchSizeDataModule(tmpdir, dm_bs) if dm_bs != -1 else None
 
     new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
     assert new_batch_size == 16
@@ -140,47 +140,64 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
     assert not os.path.exists(tmpdir / "scale_batch_size_temp_model.ckpt")
 
 
-@RunIf(min_cuda_gpus=1)
 @pytest.mark.parametrize("use_hparams", [True, False])
 def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
-    """Test that new batch size gets written to the correct hyperparameter attribute."""
+    """Test that new batch size gets written to the correct hyperparameter attribute for model."""
     tutils.reset_seed()
 
     hparams = {"batch_size": 2}
-    before_batch_size = hparams.get("batch_size")
+    before_batch_size = hparams["batch_size"]
 
-    class HparamsBatchSizeModel(BatchSizeModel):
+    class HparamsBatchSizeModel(BoringModel):
         def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
+            super().__init__()
             self.save_hyperparameters()
 
-        def dataloader(self, *args, **kwargs):
-            # artificially set batch_size so we can get a dataloader
-            # remove it immediately after, because we want only self.hparams.batch_size
-            setattr(self, "batch_size", before_batch_size)
-            dataloader = super().dataloader(*args, **kwargs)
-            del self.batch_size
-            return dataloader
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)
+
+        def val_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)
+
+    model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
+    model = model_class(**hparams)
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    trainer.tune(model, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
+    after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
+    assert before_batch_size != after_batch_size
+    assert after_batch_size <= len(trainer.train_dataloader.dataset)
+
+
+@pytest.mark.parametrize("use_hparams", [True, False])
+def test_auto_scale_batch_size_set_datamodule_attribute(tmpdir, use_hparams):
+    """Test that new batch size gets written to the correct hyperparameter attribute for datamodule."""
+    tutils.reset_seed()
+
+    hparams = {"batch_size": 2}
+    before_batch_size = hparams["batch_size"]
 
     class HparamsBatchSizeDataModule(BoringDataModule):
         def __init__(self, data_dir, batch_size):
             super().__init__(data_dir)
-            self.batch_size = batch_size
+            self.save_hyperparameters()
 
         def train_dataloader(self):
-            return DataLoader(self.random_train, batch_size=self.batch_size)
+            return DataLoader(self.random_train, batch_size=self.hparams.batch_size)
 
-    datamodule_fit = HparamsBatchSizeDataModule(data_dir=tmpdir, batch_size=before_batch_size)
-    model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
-    model = model_class(**hparams)
+        def val_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)
 
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True, accelerator="gpu", devices=1)
-    trainer.tune(model, datamodule_fit)
-    after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
-    assert trainer.datamodule == datamodule_fit
-    assert before_batch_size != after_batch_size
+    datamodule_class = HparamsBatchSizeDataModule if use_hparams else BatchSizeDataModule
+    datamodule = datamodule_class(data_dir=tmpdir, batch_size=before_batch_size)
+    model = BatchSizeModel(**hparams)
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    trainer.tune(model, datamodule=datamodule, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
+    after_batch_size = datamodule.hparams.batch_size if use_hparams else datamodule.batch_size
+    assert trainer.datamodule == datamodule
+    assert before_batch_size < after_batch_size
     assert after_batch_size <= len(trainer.train_dataloader.dataset)
-    assert datamodule_fit.batch_size == after_batch_size
 
 
 def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py
index e918c9df2ac323..98b00a374d778d 100644
--- a/tests/tests_pytorch/utilities/test_parsing.py
+++ b/tests/tests_pytorch/utilities/test_parsing.py
@@ -64,8 +64,8 @@ class TestModel4(LightningModule):  # fail case
         batch_size = 1
 
     model4 = TestModel4()
-    trainer = Trainer()
+    model4.trainer = trainer
     datamodule = LightningDataModule()
     datamodule.batch_size = 8
     trainer.datamodule = datamodule
 
@@ -87,12 +87,21 @@ class TestModel7(LightningModule):  # test for datamodule w/ hparams w/ attribut
     model7 = TestModel7()
     model7.trainer = trainer
 
-    return model1, model2, model3, model4, model5, model6, model7
+    class TestDataModule8(LightningDataModule):  # test for hparams dict
+        hparams = TestHparamsDict2
+
+    model8 = TestModel1()
+    trainer = Trainer()
+    model8.trainer = trainer
+    datamodule = TestDataModule8()
+    trainer.datamodule = datamodule
+
+    return model1, model2, model3, model4, model5, model6, model7, model8
 
 
 def test_lightning_hasattr():
     """Test that the lightning_hasattr works in all cases."""
-    model1, model2, model3, model4, model5, model6, model7 = models = model_cases()
+    model1, model2, model3, model4, model5, model6, model7, model8 = models = model_cases()
     assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable"
     assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable"
     assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable"
@@ -104,6 +113,7 @@ def test_lightning_hasattr():
     assert lightning_hasattr(
         model7, "batch_size"
     ), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present"
+    assert lightning_hasattr(model8, "batch_size")
 
     for m in models:
         assert not lightning_hasattr(m, "this_attr_not_exist")
@@ -116,10 +126,11 @@ def test_lightning_getattr():
         value = lightning_getattr(m, "learning_rate")
         assert value == i, "attribute not correctly extracted"
 
-    model5, model6, model7 = models[4:]
+    model5, model6, model7, model8 = models[4:]
     assert lightning_getattr(model5, "batch_size") == 8, "batch_size not correctly extracted"
     assert lightning_getattr(model6, "batch_size") == 8, "batch_size not correctly extracted"
     assert lightning_getattr(model7, "batch_size") == 8, "batch_size not correctly extracted"
+    assert lightning_getattr(model8, "batch_size") == 2, "batch_size not correctly extracted"
 
     for m in models:
         with pytest.raises(
@@ -136,13 +147,14 @@ def test_lightning_setattr(tmpdir):
         lightning_setattr(m, "learning_rate", 10)
         assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"
 
-    model5, model6, model7 = models[4:]
+    model5, model6, model7, model8 = models[4:]
     lightning_setattr(model5, "batch_size", 128)
     lightning_setattr(model6, "batch_size", 128)
     lightning_setattr(model7, "batch_size", 128)
     assert lightning_getattr(model5, "batch_size") == 128, "batch_size not correctly set"
     assert lightning_getattr(model6, "batch_size") == 128, "batch_size not correctly set"
     assert lightning_getattr(model7, "batch_size") == 128, "batch_size not correctly set"
+    assert lightning_getattr(model8, "batch_size") == 128, "batch_size not correctly set"
 
     for m in models:
         with pytest.raises(
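
Illustrative note appended after the patch (not part of the commit): a minimal standalone sketch, assuming the PyTorch Lightning 1.7 APIs exercised by the tests above and using the hypothetical class names PlainBoringModel and HparamsDictDataModule, mirroring the new TestDataModule8/model8 case. It shows what the _lightning_get_all_attr_holders change enables: an attribute that exists only in a datamodule's hparams is now found by lightning_hasattr, lightning_getattr, and lightning_setattr.

    from pytorch_lightning import LightningDataModule, LightningModule, Trainer
    from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr


    class PlainBoringModel(LightningModule):
        """Hypothetical model that defines no `batch_size` of its own."""


    class HparamsDictDataModule(LightningDataModule):
        """Hypothetical datamodule whose only `batch_size` lives in its `hparams` dict."""

        hparams = {"batch_size": 2}


    model = PlainBoringModel()
    trainer = Trainer()
    model.trainer = trainer  # same wiring as the `model8` case in test_parsing.py above
    trainer.datamodule = HparamsDictDataModule()

    # Before the patch the datamodule's `hparams` was never inspected, so these lookups
    # failed; with the fix, `trainer.datamodule.hparams` is returned as an attribute holder.
    assert lightning_hasattr(model, "batch_size")
    assert lightning_getattr(model, "batch_size") == 2
    lightning_setattr(model, "batch_size", 128)  # writes into the datamodule's hparams dict
    assert lightning_getattr(model, "batch_size") == 128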