Skip to content

Commit

Permalink
Decouple DataModules from Models - CPCV2 (#386)
Browse files Browse the repository at this point in the history
* Decouple dms from CPCV2

* Update tests
  • Loading branch information
akihironitta authored Nov 26, 2020
1 parent 040eec3 commit 2e903c3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 37 deletions.
51 changes: 16 additions & 35 deletions pl_bolts/models/self_supervised/cpc/cpc_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,65 +33,52 @@ class CPCV2(pl.LightningModule):

def __init__(
self,
datamodule: Optional[pl.LightningDataModule] = None,
encoder_name: str = 'cpc_encoder',
patch_size: int = 8,
patch_overlap: int = 4,
online_ft: int = True,
online_ft: bool = True,
task: str = 'cpc',
num_workers: int = 4,
learning_rate: int = 1e-4,
data_dir: str = '',
batch_size: int = 32,
num_classes: int = 10,
learning_rate: float = 1e-4,
pretrained: Optional[str] = None,
**kwargs,
):
"""
Args:
datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
encoder_name: A string for any of the resnets in torchvision, or the original CPC encoder,
or a custon nn.Module encoder
patch_size: How big to make the image patches
patch_overlap: How much overlap should each patch have.
online_ft: Enable a 1024-unit MLP to fine-tune online
patch_overlap: How much overlap each patch should have
online_ft: If True, enables a 1024-unit MLP to fine-tune online
task: Which self-supervised task to use ('cpc', 'amdim', etc...)
num_workers: num dataloader worksers
learning_rate: what learning rate to use
data_dir: where to store data
batch_size: batch size
num_workers: number of dataloader workers
num_classes: number of classes
learning_rate: learning rate
pretrained: If true, will use the weights pretrained (using CPC) on Imagenet
"""

super().__init__()
self.save_hyperparameters()

# HACK - datamodule not pickleable so we remove it from hparams.
# TODO - remove datamodule from init. data should be decoupled from models.
del self.hparams['datamodule']

self.online_evaluator = self.hparams.online_ft

if pretrained:
self.hparams.dataset = pretrained
self.online_evaluator = True

assert datamodule
self.datamodule = datamodule

self.encoder = self.init_encoder()

# info nce loss
c, h = self.__compute_final_nb_c(self.hparams.patch_size)
self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)

self.z_dim = c * h * h
self.num_classes = self.datamodule.num_classes
self.num_classes = num_classes

if pretrained:
self.load_pretrained(self.hparams.encoder_name)

print(self.hparams)

def load_pretrained(self, encoder_name):
available_weights = {'resnet18'}

Expand Down Expand Up @@ -212,19 +199,9 @@ def add_model_specific_args(parent_parser):
'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
]
parser.add_argument('--encoder', default='cpc_encoder', type=str, choices=possible_resnets)

# training params
parser.add_argument('--batch_size', type=int, default=128)

# cifar10: 1e-5, stl10: 3e-5, imagenet: 4e-4
parser.add_argument('--learning_rate', type=float, default=1e-5)

# data
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--data_dir', default='.', type=str)
parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
parser.add_argument('--num_workers', default=8, type=int)

return parser


Expand All @@ -237,9 +214,13 @@ def cli_main():
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = CPCV2.add_model_specific_args(parser)
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--data_dir', default='.', type=str)
parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
parser.add_argument('--num_workers', default=8, type=int)
parser.add_argument('--batch_size', type=int, default=128)

args = parser.parse_args()
args.online_ft = True

datamodule = None

Expand Down Expand Up @@ -276,9 +257,9 @@ def to_device(batch, device):
datamodule.val_transforms = CPCEvalTransformsImageNet128()
args.patch_size = 32

model = CPCV2(**vars(args), datamodule=datamodule)
model = CPCV2(**vars(args))
trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
trainer.fit(model)
trainer.fit(model, datamodule)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/models/self_supervised/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def test_cpcv2(tmpdir, datadir):
datamodule.train_transforms = CPCTrainTransformsCIFAR10()
datamodule.val_transforms = CPCEvalTransformsCIFAR10()

model = CPCV2(encoder='resnet18', data_dir=datadir, batch_size=2, online_ft=True, datamodule=datamodule)
model = CPCV2(encoder='resnet18', online_ft=True, num_classes=datamodule.num_classes)
trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)
trainer.fit(model, datamodule)
loss = trainer.progress_bar_dict['val_nce']

assert float(loss) > 0
Expand Down

0 comments on commit 2e903c3

Please sign in to comment.