swav-improvements (#903)
Atharva-Phatak authored Oct 18, 2022
1 parent 7f8ede8 commit d108329
Showing 9 changed files with 233 additions and 150 deletions.
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -6,10 +6,8 @@
from torchmetrics import Accuracy

from pl_bolts.models.self_supervised import SSLEvaluator
from pl_bolts.utils.stability import under_review


@under_review()
class SSLFineTuner(LightningModule):
"""Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP
with 1024 units.
2 changes: 2 additions & 0 deletions pl_bolts/models/self_supervised/swav/__init__.py
@@ -1,3 +1,4 @@
from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
@@ -13,4 +14,5 @@
"SwAVEvalDataTransform",
"SwAVFinetuneTransform",
"SwAVTrainDataTransform",
"SWAVLoss",
]
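
With SWAVLoss now re-exported from the package, the loss can be imported alongside the model. A one-line usage sketch (not part of the commit):

from pl_bolts.models.self_supervised.swav import SwAV, SWAVLoss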
131 changes: 131 additions & 0 deletions pl_bolts/models/self_supervised/swav/loss.py
@@ -0,0 +1,131 @@
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist


class SWAVLoss(nn.Module):
def __init__(
self,
temperature: float,
crops_for_assign: tuple,
nmb_crops: tuple,
sinkhorn_iterations: int,
epsilon: float,
gpus: int,
num_nodes: int,
):
"""Implementation for SWAV loss function.
Args:
temperature: loss temperature
crops_for_assign: list of crop ids for computing assignment
nmb_crops: number of global and local crops, ex: [2, 6]
sinkhorn_iterations: iterations for sinkhorn normalization
epsilon: epsilon val for swav assignments
gpus: number of gpus per node used in training, passed to SwAV module
to manage the queue and select distributed sinkhorn
num_nodes: num_nodes: number of nodes to train on
"""
super().__init__()
self.temperature = temperature
self.crops_for_assign = crops_for_assign
self.softmax = nn.Softmax(dim=1)
self.sinkhorn_iterations = sinkhorn_iterations
self.epsilon = epsilon
self.nmb_crops = nmb_crops
self.gpus = gpus
self.num_nodes = num_nodes
if self.gpus * self.num_nodes > 1:
self.assignment_fn = self.distributed_sinkhorn
else:
self.assignment_fn = self.sinkhorn

def forward(
self,
output: torch.Tensor,
embedding: torch.Tensor,
prototype_weights: torch.Tensor,
batch_size: int,
queue: Optional[torch.Tensor] = None,
use_queue: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], bool]:
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[batch_size * crop_id : batch_size * (crop_id + 1)]

# Time to use the queue
if queue is not None:
if use_queue or not torch.all(queue[i, -1, :] == 0):
use_queue = True
out = torch.cat((torch.mm(queue[i], prototype_weights.t()), out))
# fill the queue
queue[i, batch_size:] = queue[i, :-batch_size].clone()  # type: ignore
queue[i, :batch_size] = embedding[crop_id * batch_size : (crop_id + 1) * batch_size]
# get assignments
q = torch.exp(out / self.epsilon).t()
q = self.assignment_fn(q, self.sinkhorn_iterations)[-batch_size:]

# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
p = self.softmax(output[batch_size * v : batch_size * (v + 1)] / self.temperature)
subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
loss += subloss / (np.sum(self.nmb_crops) - 1)
loss /= len(self.crops_for_assign) # type: ignore
return loss, queue, use_queue

def sinkhorn(self, Q: torch.Tensor, nmb_iters: int) -> torch.Tensor:
"""Implementation of Sinkhorn clustering."""
with torch.no_grad():
sum_Q = torch.sum(Q)
Q /= sum_Q

K, B = Q.shape

if self.gpus > 0:
u = torch.zeros(K).cuda()
r = torch.ones(K).cuda() / K
c = torch.ones(B).cuda() / B
else:
u = torch.zeros(K)
r = torch.ones(K) / K
c = torch.ones(B) / B

for _ in range(nmb_iters):
u = torch.sum(Q, dim=1)

Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

def distributed_sinkhorn(self, Q: torch.Tensor, nmb_iters: int) -> torch.Tensor:
"""Implementation of Distributed Sinkhorn."""
with torch.no_grad():
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q

if self.gpus > 0:
u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
else:
u = torch.zeros(Q.shape[0])
r = torch.ones(Q.shape[0]) / Q.shape[0]
c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])

curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)

for _ in range(nmb_iters):
u = curr_sum
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
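
A minimal sketch of exercising the new SWAVLoss on its own, on CPU and without a queue; the shapes and hyperparameter values below are illustrative assumptions, not taken from this commit:

import torch

from pl_bolts.models.self_supervised.swav import SWAVLoss

batch_size, feat_dim, num_prototypes = 8, 128, 512

criterion = SWAVLoss(
    temperature=0.1,
    crops_for_assign=(0, 1),
    nmb_crops=(2,),  # two global crops, no local crops
    sinkhorn_iterations=3,
    epsilon=0.05,
    gpus=0,  # CPU run, so the single-process sinkhorn branch is selected
    num_nodes=1,
)

# Prototype scores and projected features for all crops, stacked along the batch axis.
output = torch.randn(2 * batch_size, num_prototypes)
embedding = torch.randn(2 * batch_size, feat_dim)
prototype_weights = torch.randn(num_prototypes, feat_dim)

loss, queue, use_queue = criterion(
    output=output,
    embedding=embedding,
    prototype_weights=prototype_weights,
    batch_size=batch_size,
    queue=None,  # without a queue, embedding and prototype_weights are not used
    use_queue=False,
)
print(loss, use_queue)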
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/swav/swav_finetuner.py
@@ -7,10 +7,8 @@
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization
from pl_bolts.utils.stability import under_review


@under_review()
def cli_main(): # pragma: no cover
from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule

112 changes: 22 additions & 90 deletions pl_bolts/models/self_supervised/swav/swav_module.py
@@ -2,13 +2,12 @@
import os
from argparse import ArgumentParser

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch import distributed as dist
from torch import nn

from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.optimizers.lars import LARS
from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
@@ -17,10 +16,8 @@
imagenet_normalization,
stl10_normalization,
)
from pl_bolts.utils.stability import under_review


@under_review()
class SwAV(LightningModule):
def __init__(
self,
@@ -129,19 +126,21 @@ def __init__(
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs

if self.gpus * self.num_nodes > 1:
self.get_assignments = self.distributed_sinkhorn
else:
self.get_assignments = self.sinkhorn

self.model = self.init_model()

self.criterion = SWAVLoss(
gpus=self.gpus,
num_nodes=self.num_nodes,
temperature=self.temperature,
crops_for_assign=self.crops_for_assign,
nmb_crops=self.nmb_crops,
sinkhorn_iterations=self.sinkhorn_iterations,
epsilon=self.epsilon,
)
self.use_the_queue = None
# compute iters per epoch
global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

self.queue = None
self.softmax = nn.Softmax(dim=1)

def setup(self, stage):
if self.queue_length > 0:
@@ -216,33 +215,17 @@ def shared_step(self, batch):
embedding = embedding.detach()
bs = inputs[0].size(0)

# 3. swav loss computation
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[bs * crop_id : bs * (crop_id + 1)]

# 4. time to use the queue
if self.queue is not None:
if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
self.use_the_queue = True
out = torch.cat((torch.mm(self.queue[i], self.model.prototypes.weight.t()), out))
# fill the queue
self.queue[i, bs:] = self.queue[i, :-bs].clone()
self.queue[i, :bs] = embedding[crop_id * bs : (crop_id + 1) * bs]

# 5. get assignments
q = torch.exp(out / self.epsilon).t()
q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]

# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
p = self.softmax(output[bs * v : bs * (v + 1)] / self.temperature)
subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
loss += subloss / (np.sum(self.nmb_crops) - 1)
loss /= len(self.crops_for_assign)

# SWAV loss computation
loss, queue, use_queue = self.criterion(
output=output,
embedding=embedding,
prototype_weights=self.model.prototypes.weight,
batch_size=bs,
queue=self.queue,
use_queue=self.use_the_queue,
)
self.queue = queue
self.use_the_queue = use_queue
return loss

def training_step(self, batch, batch_idx):
@@ -302,56 +285,6 @@ def configure_optimizers(self):

return [optimizer], [scheduler]

def sinkhorn(self, Q, nmb_iters):
with torch.no_grad():
sum_Q = torch.sum(Q)
Q /= sum_Q

K, B = Q.shape

if self.gpus > 0:
u = torch.zeros(K).cuda()
r = torch.ones(K).cuda() / K
c = torch.ones(B).cuda() / B
else:
u = torch.zeros(K)
r = torch.ones(K) / K
c = torch.ones(B) / B

for _ in range(nmb_iters):
u = torch.sum(Q, dim=1)

Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

def distributed_sinkhorn(self, Q, nmb_iters):
with torch.no_grad():
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q

if self.gpus > 0:
u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
else:
u = torch.zeros(Q.shape[0])
r = torch.ones(Q.shape[0]) / Q.shape[0]
c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])

curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)

for it in range(nmb_iters):
u = curr_sum
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
@@ -446,7 +379,6 @@ def add_model_specific_args(parent_parser):
return parser


@under_review()
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
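
After this refactor the LightningModule delegates loss computation to self.criterion, so the loss module can be inspected or swapped without touching shared_step. A hedged sketch; the constructor arguments are illustrative assumptions, not values from this commit:

from pl_bolts.models.self_supervised import SwAV
from pl_bolts.models.self_supervised.swav import SWAVLoss

# Hypothetical hyperparameters for a CPU smoke test; adjust to your own setup.
model = SwAV(
    gpus=0,
    num_nodes=1,
    num_samples=45000,
    batch_size=128,
    dataset="cifar10",
)

# The loss is now a regular nn.Module attribute rather than being inlined in shared_step.
assert isinstance(model.criterion, SWAVLoss)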