Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added encoder argument to BYOL constructor #637

Merged
merged 10 commits into from
Jun 16, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Replaced `load_boston` with `load_diabetes` in the docs and tests ([#629](https://github.com/PyTorchLightning/lightning-bolts/pull/629))


- Added base encoder and MLP dimension arguments to BYOL constructor ([#637](https://github.com/PyTorchLightning/lightning-bolts/pull/637))
Borda marked this conversation as resolved.
Show resolved Hide resolved


### Deprecated


Expand Down
14 changes: 11 additions & 3 deletions pl_bolts/models/self_supervised/byol/byol_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import ArgumentParser
from copy import deepcopy
from typing import Any
from typing import Any, Union

import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -72,6 +72,10 @@ def __init__(
num_workers: int = 0,
warmup_epochs: int = 10,
max_epochs: int = 1000,
base_encoder: Union[str, torch.nn.Module] = 'resnet50',
encoder_out_dim: int = 2048,
projector_hidden_size: int = 4096,
projector_out_dim: int = 256,
**kwargs
):
"""
Expand All @@ -84,11 +88,15 @@ def __init__(
num_workers: number of workers
warmup_epochs: num of epochs for scheduler warm up
max_epochs: max epochs for scheduler
base_encoder: the base encoder module or resnet name
encoder_out_dim: output dimension of base_encoder
projector_hidden_size: hidden layer size of projector MLP
projector_out_dim: output size of projector MLP
"""
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore='base_encoder')

self.online_network = SiameseArm()
self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_size, projector_out_dim)
self.target_network = deepcopy(self.online_network)
self.weight_callback = BYOLMAWeightUpdate()

Expand Down
10 changes: 5 additions & 5 deletions pl_bolts/models/self_supervised/byol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ def forward(self, x):

class SiameseArm(nn.Module):

def __init__(self, encoder=None):
def __init__(self, encoder='resnet50', encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256):
super().__init__()

if encoder is None:
encoder = torchvision_ssl_encoder('resnet50')
if isinstance(encoder, str):
encoder = torchvision_ssl_encoder(encoder)
# Encoder
self.encoder = encoder
# Projector
self.projector = MLP()
self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim)
# Predictor
self.predictor = MLP(input_dim=256)
self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim)

def forward(self, x):
y = self.encoder(x)[0]
Expand Down