diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index 5c864f08df..23f8f50c34 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -33,51 +33,40 @@ class CPCV2(pl.LightningModule): def __init__( self, - datamodule: Optional[pl.LightningDataModule] = None, encoder_name: str = 'cpc_encoder', patch_size: int = 8, patch_overlap: int = 4, - online_ft: int = True, + online_ft: bool = True, task: str = 'cpc', num_workers: int = 4, - learning_rate: int = 1e-4, - data_dir: str = '', - batch_size: int = 32, + num_classes: int = 10, + learning_rate: float = 1e-4, pretrained: Optional[str] = None, **kwargs, ): """ Args: - datamodule: A Datamodule (optional). Otherwise set the dataloaders directly encoder_name: A string for any of the resnets in torchvision, or the original CPC encoder, or a custon nn.Module encoder patch_size: How big to make the image patches - patch_overlap: How much overlap should each patch have. - online_ft: Enable a 1024-unit MLP to fine-tune online + patch_overlap: How much overlap each patch should have + online_ft: If True, enables a 1024-unit MLP to fine-tune online task: Which self-supervised task to use ('cpc', 'amdim', etc...) - num_workers: num dataloader worksers - learning_rate: what learning rate to use - data_dir: where to store data - batch_size: batch size + num_workers: number of dataloader workers + num_classes: number of classes + learning_rate: learning rate pretrained: If true, will use the weights pretrained (using CPC) on Imagenet """ super().__init__() self.save_hyperparameters() - # HACK - datamodule not pickleable so we remove it from hparams. - # TODO - remove datamodule from init. data should be decoupled from models. - del self.hparams['datamodule'] - self.online_evaluator = self.hparams.online_ft if pretrained: self.hparams.dataset = pretrained self.online_evaluator = True - assert datamodule - self.datamodule = datamodule - self.encoder = self.init_encoder() # info nce loss @@ -85,13 +74,11 @@ def __init__( self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1) self.z_dim = c * h * h - self.num_classes = self.datamodule.num_classes + self.num_classes = num_classes if pretrained: self.load_pretrained(self.hparams.encoder_name) - print(self.hparams) - def load_pretrained(self, encoder_name): available_weights = {'resnet18'} @@ -212,19 +199,9 @@ def add_model_specific_args(parent_parser): 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2' ] parser.add_argument('--encoder', default='cpc_encoder', type=str, choices=possible_resnets) - - # training params - parser.add_argument('--batch_size', type=int, default=128) - # cifar10: 1e-5, stl10: 3e-5, imagenet: 4e-4 parser.add_argument('--learning_rate', type=float, default=1e-5) - # data - parser.add_argument('--dataset', default='cifar10', type=str) - parser.add_argument('--data_dir', default='.', type=str) - parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet') - parser.add_argument('--num_workers', default=8, type=int) - return parser @@ -237,9 +214,13 @@ def cli_main(): parser = ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) parser = CPCV2.add_model_specific_args(parser) + parser.add_argument('--dataset', default='cifar10', type=str) + parser.add_argument('--data_dir', default='.', type=str) + parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet') + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--batch_size', type=int, default=128) args = parser.parse_args() - args.online_ft = True datamodule = None @@ -276,9 +257,9 @@ def to_device(batch, device): datamodule.val_transforms = CPCEvalTransformsImageNet128() args.patch_size = 32 - model = CPCV2(**vars(args), datamodule=datamodule) + model = CPCV2(**vars(args)) trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator]) - trainer.fit(model) + trainer.fit(model, datamodule) if __name__ == '__main__': diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 149239e8f7..6ef6d3f1b1 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -22,9 +22,9 @@ def test_cpcv2(tmpdir, datadir): datamodule.train_transforms = CPCTrainTransformsCIFAR10() datamodule.val_transforms = CPCEvalTransformsCIFAR10() - model = CPCV2(encoder='resnet18', data_dir=datadir, batch_size=2, online_ft=True, datamodule=datamodule) + model = CPCV2(encoder='resnet18', online_ft=True, num_classes=datamodule.num_classes) trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir) - trainer.fit(model) + trainer.fit(model, datamodule) loss = trainer.progress_bar_dict['val_nce'] assert float(loss) > 0