From f671de1d3c75f84e3168bb3fe43c5fdb5f5af6b3 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 20 Nov 2020 11:46:55 +0900 Subject: [PATCH] Update tests --- tests/models/self_supervised/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 95175b7bb6..b89f178072 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -22,9 +22,9 @@ def test_cpcv2(tmpdir, datadir): datamodule.train_transforms = CPCTrainTransformsCIFAR10() datamodule.val_transforms = CPCEvalTransformsCIFAR10() - model = CPCV2(encoder='resnet18', data_dir=datadir, batch_size=2, online_ft=True, datamodule=datamodule) + model = CPCV2(encoder='resnet18', online_ft=True, num_classes=datamodule.num_classes) trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir) - trainer.fit(model) + trainer.fit(model, datamodule) loss = trainer.progress_bar_dict['val_nce'] assert float(loss) > 0