Skip to content

Commit

Permalink
Fix LightningDataModule hparams parsing (#12806)
Browse files Browse the repository at this point in the history
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
4 people authored and lexierule committed Aug 31, 2022
1 parent 7e3504c commit c762804
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 36 deletions.
1 change: 1 addition & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))



Expand Down
13 changes: 8 additions & 5 deletions src/pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,14 +319,17 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str)
holders.append(model)

# Check if attribute in model.hparams, either namespace or dict
if hasattr(model, "hparams"):
if attribute in model.hparams:
holders.append(model.hparams)
if hasattr(model, "hparams") and attribute in model.hparams:
holders.append(model.hparams)

trainer = model._trainer
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)
if trainer is not None and trainer.datamodule is not None:
if hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)

if hasattr(trainer.datamodule, "hparams") and attribute in trainer.datamodule.hparams:
holders.append(trainer.datamodule.hparams)

return holders

Expand Down
69 changes: 43 additions & 26 deletions tests/tests_pytorch/tuner/test_scale_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@


class BatchSizeDataModule(BoringDataModule):
def __init__(self, batch_size):
super().__init__()
def __init__(self, data_dir, batch_size):
super().__init__(data_dir)
if batch_size is not None:
self.batch_size = batch_size

Expand Down Expand Up @@ -58,7 +58,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b
tuner = Tuner(trainer)

model = BatchSizeModel(model_bs)
datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
datamodule = BatchSizeDataModule(tmpdir, dm_bs) if dm_bs != -1 else None

new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
assert new_batch_size == 16
Expand Down Expand Up @@ -140,47 +140,64 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
assert not os.path.exists(tmpdir / "scale_batch_size_temp_model.ckpt")


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("use_hparams", [True, False])
def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
"""Test that new batch size gets written to the correct hyperparameter attribute."""
"""Test that new batch size gets written to the correct hyperparameter attribute for model."""
tutils.reset_seed()

hparams = {"batch_size": 2}
before_batch_size = hparams.get("batch_size")
before_batch_size = hparams["batch_size"]

class HparamsBatchSizeModel(BatchSizeModel):
class HparamsBatchSizeModel(BoringModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__()
self.save_hyperparameters()

def dataloader(self, *args, **kwargs):
# artificially set batch_size so we can get a dataloader
# remove it immediately after, because we want only self.hparams.batch_size
setattr(self, "batch_size", before_batch_size)
dataloader = super().dataloader(*args, **kwargs)
del self.batch_size
return dataloader
def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
model = model_class(**hparams)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert before_batch_size != after_batch_size
assert after_batch_size <= len(trainer.train_dataloader.dataset)


@pytest.mark.parametrize("use_hparams", [True, False])
def test_auto_scale_batch_size_set_datamodule_attribute(tmpdir, use_hparams):
"""Test that new batch size gets written to the correct hyperparameter attribute for datamodule."""
tutils.reset_seed()

hparams = {"batch_size": 2}
before_batch_size = hparams["batch_size"]

class HparamsBatchSizeDataModule(BoringDataModule):
def __init__(self, data_dir, batch_size):
super().__init__(data_dir)
self.batch_size = batch_size
self.save_hyperparameters()

def train_dataloader(self):
return DataLoader(self.random_train, batch_size=self.batch_size)
return DataLoader(self.random_train, batch_size=self.hparams.batch_size)

datamodule_fit = HparamsBatchSizeDataModule(data_dir=tmpdir, batch_size=before_batch_size)
model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
model = model_class(**hparams)
def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True, accelerator="gpu", devices=1)
trainer.tune(model, datamodule_fit)
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert trainer.datamodule == datamodule_fit
assert before_batch_size != after_batch_size
datamodule_class = HparamsBatchSizeDataModule if use_hparams else BatchSizeDataModule
datamodule = datamodule_class(data_dir=tmpdir, batch_size=before_batch_size)
model = BatchSizeModel(**hparams)

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model, datamodule=datamodule, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
after_batch_size = datamodule.hparams.batch_size if use_hparams else datamodule.batch_size
assert trainer.datamodule == datamodule
assert before_batch_size < after_batch_size
assert after_batch_size <= len(trainer.train_dataloader.dataset)
assert datamodule_fit.batch_size == after_batch_size


def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
Expand Down
22 changes: 17 additions & 5 deletions tests/tests_pytorch/utilities/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class TestModel4(LightningModule): # fail case
batch_size = 1

model4 = TestModel4()

trainer = Trainer()
model4.trainer = trainer
datamodule = LightningDataModule()
datamodule.batch_size = 8
trainer.datamodule = datamodule
Expand All @@ -87,12 +87,21 @@ class TestModel7(LightningModule): # test for datamodule w/ hparams w/ attribut
model7 = TestModel7()
model7.trainer = trainer

return model1, model2, model3, model4, model5, model6, model7
class TestDataModule8(LightningDataModule): # test for hparams dict
hparams = TestHparamsDict2

model8 = TestModel1()
trainer = Trainer()
model8.trainer = trainer
datamodule = TestDataModule8()
trainer.datamodule = datamodule

return model1, model2, model3, model4, model5, model6, model7, model8


def test_lightning_hasattr():
"""Test that the lightning_hasattr works in all cases."""
model1, model2, model3, model4, model5, model6, model7 = models = model_cases()
model1, model2, model3, model4, model5, model6, model7, model8 = models = model_cases()
assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable"
assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable"
assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable"
Expand All @@ -104,6 +113,7 @@ def test_lightning_hasattr():
assert lightning_hasattr(
model7, "batch_size"
), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present"
assert lightning_hasattr(model8, "batch_size")

for m in models:
assert not lightning_hasattr(m, "this_attr_not_exist")
Expand All @@ -116,10 +126,11 @@ def test_lightning_getattr():
value = lightning_getattr(m, "learning_rate")
assert value == i, "attribute not correctly extracted"

model5, model6, model7 = models[4:]
model5, model6, model7, model8 = models[4:]
assert lightning_getattr(model5, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model6, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model7, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model8, "batch_size") == 2, "batch_size not correctly extracted"

for m in models:
with pytest.raises(
Expand All @@ -136,13 +147,14 @@ def test_lightning_setattr(tmpdir):
lightning_setattr(m, "learning_rate", 10)
assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"

model5, model6, model7 = models[4:]
model5, model6, model7, model8 = models[4:]
lightning_setattr(model5, "batch_size", 128)
lightning_setattr(model6, "batch_size", 128)
lightning_setattr(model7, "batch_size", 128)
assert lightning_getattr(model5, "batch_size") == 128, "batch_size not correctly set"
assert lightning_getattr(model6, "batch_size") == 128, "batch_size not correctly set"
assert lightning_getattr(model7, "batch_size") == 128, "batch_size not correctly set"
assert lightning_getattr(model8, "batch_size") == 128, "batch_size not correctly set"

for m in models:
with pytest.raises(
Expand Down

0 comments on commit c762804

Please sign in to comment.