diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index d32c09b9cc..2c4e675989 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -67,10 +67,10 @@ def train_dataloader(self): dm = DummyDataModule() - model = SemSegment(datamodule=dm, num_classes=19) + model = SemSegment(num_classes=19) trainer = pl.Trainer(fast_dev_run=True, max_epochs=1) - trainer.fit(model) + trainer.fit(model, dm) loss = trainer.progress_bar_dict['loss'] assert float(loss) > 0