From d66443d3d83b8f7277bdf01762bbf4dfe54d252d Mon Sep 17 00:00:00 2001 From: "O'Donnell, Garry (DLSLtd,RAL,LSCI)" Date: Mon, 10 May 2021 13:41:15 +0100 Subject: [PATCH 1/5] Added base_encoder argument to BYOL constructor --- pl_bolts/models/self_supervised/byol/byol_module.py | 10 +++++++--- pl_bolts/models/self_supervised/byol/models.py | 10 +++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 3107a1956f..f01fc13180 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -1,6 +1,6 @@ from argparse import ArgumentParser from copy import deepcopy -from typing import Any +from typing import Any, Union import pytorch_lightning as pl import torch @@ -72,6 +72,10 @@ def __init__( num_workers: int = 0, warmup_epochs: int = 10, max_epochs: int = 1000, + base_encoder: Union[str, torch.nn.Module] = 'resnet50', + emb_dim: int = 2048, + hidden_size: int = 4096, + proj_dim: int = 256, **kwargs ): """ @@ -86,9 +90,9 @@ def __init__( max_epochs: max epochs for scheduler """ super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore='base_encoder') - self.online_network = SiameseArm() + self.online_network = SiameseArm(base_encoder, emb_dim, hidden_size, proj_dim) self.target_network = deepcopy(self.online_network) self.weight_callback = BYOLMAWeightUpdate() diff --git a/pl_bolts/models/self_supervised/byol/models.py b/pl_bolts/models/self_supervised/byol/models.py index 53b90bf6ef..765fb2eb56 100644 --- a/pl_bolts/models/self_supervised/byol/models.py +++ b/pl_bolts/models/self_supervised/byol/models.py @@ -23,17 +23,17 @@ def forward(self, x): class SiameseArm(nn.Module): - def __init__(self, encoder=None): + def __init__(self, encoder='resnet50', input_dim=2048, hidden_size=4096, output_dim=256): super().__init__() - if encoder is None: - encoder = torchvision_ssl_encoder('resnet50') + if isinstance(encoder, str): + encoder = torchvision_ssl_encoder(encoder) # Encoder self.encoder = encoder # Projector - self.projector = MLP() + self.projector = MLP(input_dim, hidden_size, output_dim) # Predictor - self.predictor = MLP(input_dim=256) + self.predictor = MLP(output_dim, hidden_size, output_dim) def forward(self, x): y = self.encoder(x)[0] From a18143f8d5a647b11a6174ee49765556cb3b55b4 Mon Sep 17 00:00:00 2001 From: "O'Donnell, Garry (DLSLtd,RAL,LSCI)" Date: Mon, 10 May 2021 14:13:38 +0100 Subject: [PATCH 2/5] Added encoder and MLP dimension args to BYOL --- pl_bolts/models/self_supervised/byol/byol_module.py | 8 ++++---- pl_bolts/models/self_supervised/byol/models.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index f01fc13180..627dbd2b27 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -73,9 +73,9 @@ def __init__( warmup_epochs: int = 10, max_epochs: int = 1000, base_encoder: Union[str, torch.nn.Module] = 'resnet50', - emb_dim: int = 2048, - hidden_size: int = 4096, - proj_dim: int = 256, + encoder_out_dim: int = 2048, + projector_hidden_size: int = 4096, + projector_out_dim: int = 256, **kwargs ): """ @@ -92,7 +92,7 @@ def __init__( super().__init__() self.save_hyperparameters(ignore='base_encoder') - self.online_network = SiameseArm(base_encoder, emb_dim, hidden_size, proj_dim) + self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_size, projector_out_dim) self.target_network = deepcopy(self.online_network) self.weight_callback = BYOLMAWeightUpdate() diff --git a/pl_bolts/models/self_supervised/byol/models.py b/pl_bolts/models/self_supervised/byol/models.py index 765fb2eb56..d7e5e87a29 100644 --- a/pl_bolts/models/self_supervised/byol/models.py +++ b/pl_bolts/models/self_supervised/byol/models.py @@ -23,7 +23,7 @@ def forward(self, x): class SiameseArm(nn.Module): - def __init__(self, encoder='resnet50', input_dim=2048, hidden_size=4096, output_dim=256): + def __init__(self, encoder='resnet50', encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256): super().__init__() if isinstance(encoder, str): @@ -31,9 +31,9 @@ def __init__(self, encoder='resnet50', input_dim=2048, hidden_size=4096, output_ # Encoder self.encoder = encoder # Projector - self.projector = MLP(input_dim, hidden_size, output_dim) + self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim) # Predictor - self.predictor = MLP(output_dim, hidden_size, output_dim) + self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim) def forward(self, x): y = self.encoder(x)[0] From de878e93f86109422ae2ea9aba6051c52cefd0ed Mon Sep 17 00:00:00 2001 From: "O'Donnell, Garry (DLSLtd,RAL,LSCI)" Date: Mon, 10 May 2021 14:21:03 +0100 Subject: [PATCH 3/5] Update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5064c5c528..aaf5157684 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Added base encoder and MLP dimension arguments to BYOL constructor ([#637](https://github.com/PyTorchLightning/lightning-bolts/pull/637)) + ### Deprecated From c909458cb4263d83015b9208b433398d37f7b70b Mon Sep 17 00:00:00 2001 From: "O'Donnell, Garry (DLSLtd,RAL,LSCI)" Date: Mon, 10 May 2021 14:34:56 +0100 Subject: [PATCH 4/5] Updated docstring --- pl_bolts/models/self_supervised/byol/byol_module.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 627dbd2b27..514e6e2964 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -88,6 +88,10 @@ def __init__( num_workers: number of workers warmup_epochs: num of epochs for scheduler warm up max_epochs: max epochs for scheduler + base_encoder: the base encoder module or resnet name + encoder_out_dim: output dimension of base_encoder + projector_hidden_size: hidden layer size of projector MLP + projector_out_dim: output size of projector MLP """ super().__init__() self.save_hyperparameters(ignore='base_encoder') From 234797995ce79ef368298a2b54f44987c98b326d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 16 Jun 2021 00:31:26 +0200 Subject: [PATCH 5/5] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4165524169..93db4fe211 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Replaced `load_boston` with `load_diabetes` in the docs and tests ([#629](https://github.com/PyTorchLightning/lightning-bolts/pull/629)) + + - Added base encoder and MLP dimension arguments to BYOL constructor ([#637](https://github.com/PyTorchLightning/lightning-bolts/pull/637))