From d1083298b057ea3775ab8fc275c7a7cc2edb6850 Mon Sep 17 00:00:00 2001
From: Atharva Phatak
Date: Tue, 18 Oct 2022 03:46:28 -0400
Subject: [PATCH] swav-improvements (#903)

---
 .../models/self_supervised/ssl_finetuner.py   |   2 -
 .../models/self_supervised/swav/__init__.py   |   2 +
 pl_bolts/models/self_supervised/swav/loss.py  | 131 ++++++++++++++++++
 .../self_supervised/swav/swav_finetuner.py    |   2 -
 .../self_supervised/swav/swav_module.py       | 112 +++------------
 .../self_supervised/swav/swav_resnet.py       |  13 --
 .../models/self_supervised/swav/transforms.py |  40 +-----
 tests/models/self_supervised/test_models.py   |  29 +++-
 .../self_supervised/unit/test_transforms.py   |  52 +++++++
 9 files changed, 233 insertions(+), 150 deletions(-)
 create mode 100644 pl_bolts/models/self_supervised/swav/loss.py

diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py
index 306a7c7552..49919c54cb 100644
--- a/pl_bolts/models/self_supervised/ssl_finetuner.py
+++ b/pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -6,10 +6,8 @@
 from torchmetrics import Accuracy
 
 from pl_bolts.models.self_supervised import SSLEvaluator
-from pl_bolts.utils.stability import under_review
 
 
-@under_review()
 class SSLFineTuner(LightningModule):
     """Finetunes a self-supervised learning backbone using the standard evaluation protocol of a single layer MLP
     with 1024 units.
diff --git a/pl_bolts/models/self_supervised/swav/__init__.py b/pl_bolts/models/self_supervised/swav/__init__.py
index 6a1ba8c546..ddddff1890 100644
--- a/pl_bolts/models/self_supervised/swav/__init__.py
+++ b/pl_bolts/models/self_supervised/swav/__init__.py
@@ -1,3 +1,4 @@
+from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
 from pl_bolts.models.self_supervised.swav.swav_module import SwAV
 from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
 from pl_bolts.models.self_supervised.swav.transforms import (
@@ -13,4 +14,5 @@
     "SwAVEvalDataTransform",
     "SwAVFinetuneTransform",
     "SwAVTrainDataTransform",
+    "SWAVLoss",
 ]
diff --git a/pl_bolts/models/self_supervised/swav/loss.py b/pl_bolts/models/self_supervised/swav/loss.py
new file mode 100644
index 0000000000..d322b60a73
--- /dev/null
+++ b/pl_bolts/models/self_supervised/swav/loss.py
@@ -0,0 +1,131 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import distributed as dist
+
+
+class SWAVLoss(nn.Module):
+    def __init__(
+        self,
+        temperature: float,
+        crops_for_assign: tuple,
+        nmb_crops: tuple,
+        sinkhorn_iterations: int,
+        epsilon: float,
+        gpus: int,
+        num_nodes: int,
+    ):
+        """Implementation of the SwAV loss function.
+
+        Args:
+            temperature: loss temperature
+            crops_for_assign: tuple of crop ids used for computing the assignments
+            nmb_crops: number of global and local crops, e.g. [2, 6]
+            sinkhorn_iterations: number of iterations for Sinkhorn normalization
+            epsilon: epsilon value for the SwAV assignments
+            gpus: number of gpus per node used in training, passed to the SwAV module
+                to manage the queue and select the distributed sinkhorn
+            num_nodes: number of nodes to train on
+        """
+        super().__init__()
+        self.temperature = temperature
+        self.crops_for_assign = crops_for_assign
+        self.softmax = nn.Softmax(dim=1)
+        self.sinkhorn_iterations = sinkhorn_iterations
+        self.epsilon = epsilon
+        self.nmb_crops = nmb_crops
+        self.gpus = gpus
+        self.num_nodes = num_nodes
+        if self.gpus * self.num_nodes > 1:
+            self.assignment_fn = self.distributed_sinkhorn
+        else:
+            self.assignment_fn = self.sinkhorn
+
+    def forward(
+        self,
+        output: torch.Tensor,
+        embedding: torch.Tensor,
+        prototype_weights: torch.Tensor,
+        batch_size: int,
+        queue: Optional[torch.Tensor] = None,
+        use_queue: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], bool]:
+        loss = 0
+        for i, crop_id in enumerate(self.crops_for_assign):
+            with torch.no_grad():
+                out = output[batch_size * crop_id : batch_size * (crop_id + 1)]
+
+                # Time to use the queue
+                if queue is not None:
+                    if use_queue or not torch.all(queue[i, -1, :] == 0):
+                        use_queue = True
+                        out = torch.cat((torch.mm(queue[i], prototype_weights.t()), out))
+                    # fill the queue
+                    queue[i, batch_size:] = queue[i, :-batch_size].clone()
+                    queue[i, :batch_size] = embedding[crop_id * batch_size : (crop_id + 1) * batch_size]
+                # get assignments
+                q = torch.exp(out / self.epsilon).t()
+                q = self.assignment_fn(q, self.sinkhorn_iterations)[-batch_size:]
+
+            # cluster assignment prediction
+            subloss = 0
+            for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
+                p = self.softmax(output[batch_size * v : batch_size * (v + 1)] / self.temperature)
+                subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
+            loss += subloss / (np.sum(self.nmb_crops) - 1)
+        loss /= len(self.crops_for_assign)  # type: ignore
+        return loss, queue, use_queue
+
+    def sinkhorn(self, Q: torch.Tensor, nmb_iters: int) -> torch.Tensor:
+        """Implementation of Sinkhorn clustering."""
+        with torch.no_grad():
+            sum_Q = torch.sum(Q)
+            Q /= sum_Q
+
+            K, B = Q.shape
+
+            if self.gpus > 0:
+                u = torch.zeros(K).cuda()
+                r = torch.ones(K).cuda() / K
+                c = torch.ones(B).cuda() / B
+            else:
+                u = torch.zeros(K)
+                r = torch.ones(K) / K
+                c = torch.ones(B) / B
+
+            for _ in range(nmb_iters):
+                u = torch.sum(Q, dim=1)
+
+                Q *= (r / u).unsqueeze(1)
+                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
+
+            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
+
+    def distributed_sinkhorn(self, Q: torch.Tensor, nmb_iters: int) -> torch.Tensor:
+        """Implementation of Distributed Sinkhorn."""
+        with torch.no_grad():
+            sum_Q = torch.sum(Q)
+            dist.all_reduce(sum_Q)
+            Q /= sum_Q
+
+            if self.gpus > 0:
+                u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
+                r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
+                c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
+            else:
+                u = torch.zeros(Q.shape[0])
+                r = torch.ones(Q.shape[0]) / Q.shape[0]
+                c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])
+
+            curr_sum = torch.sum(Q, dim=1)
+            dist.all_reduce(curr_sum)
+
+            for _ in range(nmb_iters):
+                u = curr_sum
+                Q *= (r / u).unsqueeze(1)
+                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
+                curr_sum = torch.sum(Q, dim=1)
+                dist.all_reduce(curr_sum)
+            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
diff --git a/pl_bolts/models/self_supervised/swav/swav_finetuner.py b/pl_bolts/models/self_supervised/swav/swav_finetuner.py
index 5c754ad838..4754846bff 100644
--- a/pl_bolts/models/self_supervised/swav/swav_finetuner.py
+++ b/pl_bolts/models/self_supervised/swav/swav_finetuner.py
@@ -7,10 +7,8 @@
 from pl_bolts.models.self_supervised.swav.swav_module import SwAV
 from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
 from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization
-from pl_bolts.utils.stability import under_review
 
 
-@under_review()
 def cli_main():  # pragma: no cover
     from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule
 
diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py
index 2763e99244..c253a08431 100644
--- a/pl_bolts/models/self_supervised/swav/swav_module.py
+++ b/pl_bolts/models/self_supervised/swav/swav_module.py
@@ -2,13 +2,12 @@
 import os
 from argparse import ArgumentParser
 
-import numpy as np
 import torch
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-from torch import distributed as dist
 from torch import nn
 
+from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
 from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
 from pl_bolts.optimizers.lars import LARS
 from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
@@ -17,10 +16,8 @@
     imagenet_normalization,
     stl10_normalization,
 )
-from pl_bolts.utils.stability import under_review
 
 
-@under_review()
 class SwAV(LightningModule):
     def __init__(
         self,
@@ -129,19 +126,21 @@ def __init__(
         self.warmup_epochs = warmup_epochs
         self.max_epochs = max_epochs
 
-        if self.gpus * self.num_nodes > 1:
-            self.get_assignments = self.distributed_sinkhorn
-        else:
-            self.get_assignments = self.sinkhorn
-
         self.model = self.init_model()
-
+        self.criterion = SWAVLoss(
+            gpus=self.gpus,
+            num_nodes=self.num_nodes,
+            temperature=self.temperature,
+            crops_for_assign=self.crops_for_assign,
+            nmb_crops=self.nmb_crops,
+            sinkhorn_iterations=self.sinkhorn_iterations,
+            epsilon=self.epsilon,
+        )
+        self.use_the_queue = None
 
         # compute iters per epoch
         global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
         self.train_iters_per_epoch = self.num_samples // global_batch_size
-        self.queue = None
-        self.softmax = nn.Softmax(dim=1)
 
     def setup(self, stage):
         if self.queue_length > 0:
@@ -216,33 +215,17 @@ def shared_step(self, batch):
         embedding = embedding.detach()
         bs = inputs[0].size(0)
 
-        # 3. swav loss computation
-        loss = 0
-        for i, crop_id in enumerate(self.crops_for_assign):
-            with torch.no_grad():
-                out = output[bs * crop_id : bs * (crop_id + 1)]
-
-                # 4. time to use the queue
-                if self.queue is not None:
-                    if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
-                        self.use_the_queue = True
-                        out = torch.cat((torch.mm(self.queue[i], self.model.prototypes.weight.t()), out))
-                    # fill the queue
-                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
-                    self.queue[i, :bs] = embedding[crop_id * bs : (crop_id + 1) * bs]
-
-                # 5. get assignments
-                q = torch.exp(out / self.epsilon).t()
-                q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]
-
-            # cluster assignment prediction
-            subloss = 0
-            for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
-                p = self.softmax(output[bs * v : bs * (v + 1)] / self.temperature)
-                subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
-            loss += subloss / (np.sum(self.nmb_crops) - 1)
-        loss /= len(self.crops_for_assign)
-
+        # SwAV loss computation
+        loss, queue, use_queue = self.criterion(
+            output=output,
+            embedding=embedding,
+            prototype_weights=self.model.prototypes.weight,
+            batch_size=bs,
+            queue=self.queue,
+            use_queue=self.use_the_queue,
+        )
+        self.queue = queue
+        self.use_the_queue = use_queue
         return loss
 
     def training_step(self, batch, batch_idx):
@@ -302,56 +285,6 @@ def configure_optimizers(self):
 
         return [optimizer], [scheduler]
 
-    def sinkhorn(self, Q, nmb_iters):
-        with torch.no_grad():
-            sum_Q = torch.sum(Q)
-            Q /= sum_Q
-
-            K, B = Q.shape
-
-            if self.gpus > 0:
-                u = torch.zeros(K).cuda()
-                r = torch.ones(K).cuda() / K
-                c = torch.ones(B).cuda() / B
-            else:
-                u = torch.zeros(K)
-                r = torch.ones(K) / K
-                c = torch.ones(B) / B
-
-            for _ in range(nmb_iters):
-                u = torch.sum(Q, dim=1)
-
-                Q *= (r / u).unsqueeze(1)
-                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
-
-            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
-
-    def distributed_sinkhorn(self, Q, nmb_iters):
-        with torch.no_grad():
-            sum_Q = torch.sum(Q)
-            dist.all_reduce(sum_Q)
-            Q /= sum_Q
-
-            if self.gpus > 0:
-                u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
-                r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
-                c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
-            else:
-                u = torch.zeros(Q.shape[0])
-                r = torch.ones(Q.shape[0]) / Q.shape[0]
-                c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])
-
-            curr_sum = torch.sum(Q, dim=1)
-            dist.all_reduce(curr_sum)
-
-            for it in range(nmb_iters):
-                u = curr_sum
-                Q *= (r / u).unsqueeze(1)
-                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
-                curr_sum = torch.sum(Q, dim=1)
-                dist.all_reduce(curr_sum)
-            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
-
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
@@ -446,7 +379,6 @@ def add_model_specific_args(parent_parser):
     return parser
 
 
-@under_review()
 def cli_main():
     from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
     from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
diff --git a/pl_bolts/models/self_supervised/swav/swav_resnet.py b/pl_bolts/models/self_supervised/swav/swav_resnet.py
index 1a2e5de63e..fb24651a11 100644
--- a/pl_bolts/models/self_supervised/swav/swav_resnet.py
+++ b/pl_bolts/models/self_supervised/swav/swav_resnet.py
@@ -2,10 +2,7 @@
 import torch
 from torch import nn
 
-from pl_bolts.utils.stability import under_review
-
 
-@under_review()
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     """3x3 convolution with padding."""
     return nn.Conv2d(
@@ -20,13 +17,11 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     )
 
 
-@under_review()
 def conv1x1(in_planes, out_planes, stride=1):
     """1x1 convolution."""
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 
 
-@under_review()
 class BasicBlock(nn.Module):
     expansion = 1
     __constants__ = ["downsample"]
@@ -77,7 +72,6 @@ def forward(self, x):
         return out
 
 
-@under_review()
 class Bottleneck(nn.Module):
     expansion = 4
     __constants__ = ["downsample"]
@@ -131,7 +125,6 @@ def forward(self, x):
         return out
 
 
-@under_review()
 class ResNet(nn.Module):
     def __init__(
         self,
@@ -343,7 +336,6 @@ def forward(self, inputs):
         return self.forward_head(output)
 
 
-@under_review()
 class MultiPrototypes(nn.Module):
     def __init__(self, output_dim, nmb_prototypes):
         super().__init__()
@@ -358,26 +350,21 @@ def forward(self, x):
         return out
 
 
-@under_review()
 def resnet18(**kwargs):
     return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
 
 
-@under_review()
 def resnet50(**kwargs):
     return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
 
 
-@under_review()
 def resnet50w2(**kwargs):
     return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs)
 
 
-@under_review()
 def resnet50w4(**kwargs):
     return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs)
 
 
-@under_review()
 def resnet50w5(**kwargs):
     return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs)
diff --git a/pl_bolts/models/self_supervised/swav/transforms.py b/pl_bolts/models/self_supervised/swav/transforms.py
index 76c3ea871d..2047563a50 100644
--- a/pl_bolts/models/self_supervised/swav/transforms.py
+++ b/pl_bolts/models/self_supervised/swav/transforms.py
@@ -1,9 +1,6 @@
 from typing import Tuple
 
-import numpy as np
-
-from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE
-from pl_bolts.utils.stability import under_review
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
 if _TORCHVISION_AVAILABLE:
@@ -11,13 +8,7 @@
 else:  # pragma: no cover
     warn_missing_pkg("torchvision")
 
-if _OPENCV_AVAILABLE:
-    import cv2
-else:  # pragma: no cover
-    warn_missing_pkg("cv2", pypi_name="opencv-python")
-
 
-@under_review()
 class SwAVTrainDataTransform:
     def __init__(
         self,
@@ -56,7 +47,8 @@ def __init__(
         if kernel_size % 2 == 0:
             kernel_size += 1
 
-        color_transform.append(GaussianBlur(kernel_size=kernel_size, p=0.5))
+        # Use torchvision's GaussianBlur instead of the custom OpenCV implementation
+        color_transform.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))
 
         self.color_transform = transforms.Compose(color_transform)
 
@@ -100,7 +92,6 @@ def __call__(self, sample):
         return multi_crops
 
 
-@under_review()
 class SwAVEvalDataTransform(SwAVTrainDataTransform):
     def __init__(
         self,
@@ -135,7 +126,6 @@ def __init__(
 
         self.transform[-1] = test_transform
 
 
-@under_review()
 class SwAVFinetuneTransform:
     def __init__(
         self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False
@@ -175,27 +165,3 @@ def __init__(
 
     def __call__(self, sample):
         return self.transform(sample)
-
-
-@under_review()
-class GaussianBlur:
-    # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
-        self.min = min
-        self.max = max
-
-        # kernel size is set to be 10% of the image height/width
-        self.kernel_size = kernel_size
-        self.p = p
-
-    def __call__(self, sample):
-        sample = np.array(sample)
-
-        # blur the image with a 50% chance
-        prob = np.random.random_sample()
-
-        if prob < self.p:
-            sigma = (self.max - self.min) * np.random.random_sample() + self.min
-            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
-
-        return sample
diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py
index bd8957928d..9db7ad411c 100644
--- a/tests/models/self_supervised/test_models.py
+++ b/tests/models/self_supervised/test_models.py
@@ -86,8 +86,14 @@ def test_simclr(tmpdir, datadir):
     trainer.fit(model, datamodule=datamodule)
 
 
-def test_swav(tmpdir, datadir, batch_size=2):
-    # inputs, y = batch (doesn't receive y for some reason)
+def test_swav(tmpdir, datadir, catch_warnings):
+    """Test SwAV on CIFAR-10."""
+    warnings.filterwarnings(
+        "ignore",
+        message=".+does not have many workers which may be a bottleneck.+",
+        category=PossibleUserWarning,
+    )
+    batch_size = 2
     datamodule = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
 
     datamodule.train_transforms = SwAVTrainDataTransform(
@@ -96,12 +102,18 @@ def test_swav(tmpdir, datadir, batch_size=2):
     datamodule.val_transforms = SwAVEvalDataTransform(
         normalize=cifar10_normalization(), size_crops=[32, 16], nmb_crops=[2, 1], gaussian_blur=False
     )
+    if torch.cuda.device_count() >= 1:
+        devices = torch.cuda.device_count()
+        accelerator = "gpu"
+    else:
+        devices = None
+        accelerator = "cpu"
 
     model = SwAV(
         arch="resnet18",
         hidden_mlp=512,
-        gpus=0,
         nodes=1,
+        gpus=0 if devices is None else devices,
         num_samples=datamodule.num_samples,
         batch_size=batch_size,
         nmb_crops=[2, 1],
@@ -112,9 +124,14 @@ def test_swav(tmpdir, datadir, batch_size=2):
         first_conv=False,
         dataset="cifar10",
     )
-
-    trainer = Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
-
+    trainer = Trainer(
+        accelerator=accelerator,
+        devices=devices,
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+        log_every_n_steps=1,
+        max_epochs=1,
+    )
     trainer.fit(model, datamodule=datamodule)
 
 
diff --git a/tests/models/self_supervised/unit/test_transforms.py b/tests/models/self_supervised/unit/test_transforms.py
index 737af74bcb..eaf97ab3ee 100644
--- a/tests/models/self_supervised/unit/test_transforms.py
+++ b/tests/models/self_supervised/unit/test_transforms.py
@@ -8,6 +8,58 @@
     SimCLRFinetuneTransform,
     SimCLRTrainDataTransform,
 )
+from pl_bolts.models.self_supervised.swav.transforms import (
+    SwAVEvalDataTransform,
+    SwAVFinetuneTransform,
+    SwAVTrainDataTransform,
+)
+
+
+@pytest.mark.parametrize(
+    "transform_cls",
+    [pytest.param(SwAVTrainDataTransform, id="train-data"), pytest.param(SwAVEvalDataTransform, id="eval-data")],
+)
+def test_swav_train_data_transform(catch_warnings, transform_cls):
+    # dummy image
+    img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
+    img = Image.fromarray(img)
+
+    # sizes of the generated views
+    crop_sizes = (96, 36)
+    transform = transform_cls(size_crops=crop_sizes)
+    views = transform(img)
+
+    # the transform must output a list or a tuple of images
+    assert isinstance(views, (list, tuple))
+
+    # the transform must output seven views
+    # (2 global crops, 4 local crops, and the online evaluation view)
+    assert len(views) == 7
+
+    # all views are tensors
+    assert all(torch.is_tensor(v) for v in views)
+
+    # global views have the expected size
+    assert all(v.size(1) == v.size(2) == crop_sizes[0] for v in views[:2])
+    # local views have the expected size
+    assert all(v.size(1) == v.size(2) == crop_sizes[1] for v in views[2 : len(views) - 1])  # ignore online view
+
+
+def test_swav_finetune_transform(catch_warnings):
+    # dummy image
+    img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
+    img = Image.fromarray(img)
+
+    # size of the generated view
+    input_height = 96
+    transform = SwAVFinetuneTransform(input_height=input_height)
+    view = transform(img)
+
+    # the generated view is a tensor
+    assert torch.is_tensor(view)
+
+    # the view has the expected size
+    assert view.size(1) == view.size(2) == input_height
 
 
 @pytest.mark.parametrize(
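
Below is a minimal usage sketch of the SWAVLoss module this patch factors out, run standalone on CPU. Only the SWAVLoss import path and call signature come from the patch above; the batch size, feature dimension, prototype count, and hyperparameter values are illustrative choices, not defaults taken from the PR.

    import torch
    import torch.nn.functional as F

    from pl_bolts.models.self_supervised.swav.loss import SWAVLoss

    # Two global crops (nmb_crops=(2,)), both used for assignments, no queue.
    batch_size, feat_dim, num_prototypes = 4, 16, 10  # illustrative sizes

    criterion = SWAVLoss(
        temperature=0.1,  # illustrative hyperparameters
        crops_for_assign=(0, 1),
        nmb_crops=(2,),
        sinkhorn_iterations=3,
        epsilon=0.05,
        gpus=0,  # gpus * num_nodes <= 1 selects the plain (non-distributed) sinkhorn
        num_nodes=1,
    )

    # `output` holds the prototype scores for all crops concatenated along dim 0;
    # `embedding` holds the matching L2-normalized features.
    prototypes = torch.nn.Linear(feat_dim, num_prototypes, bias=False)
    embedding = F.normalize(torch.randn(2 * batch_size, feat_dim), dim=1)
    output = prototypes(embedding)

    loss, queue, use_queue = criterion(
        output=output,
        embedding=embedding,
        prototype_weights=prototypes.weight,
        batch_size=batch_size,
        queue=None,  # the queue path only runs once SwAV passes a non-None queue
        use_queue=False,
    )
    print(loss)  # scalar loss tensor, differentiable through `output`

Inside the SwAV LightningModule the same call happens in shared_step, which then stores the returned queue and use_queue flags back on the module, so the loss module itself stays stateless apart from its hyperparameters.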