Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Propagate collate_fn to InputTransform in ImageEmbedder #1217

Merged
merged 13 commits into from
Mar 28, 2022
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
- Fixed a bug where collate functions were never called in the `ImageEmbedder` class. ([#1217](https://github.com/PyTorchLightning/lightning-flash/pull/1217))

- Fixed a bug where `BASE_MODEL_NAME` was not in the dict for dino and moco strategies. ([1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))

- Fixed a bug where `BASE_MODEL_NAME` was not in the dict for dino and moco strategies. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))

- Fixed normalizing inputs to video classification ([#1213](https://github.com/PyTorchLightning/lightning-flash/pull/1213))

Expand Down
9 changes: 9 additions & 0 deletions flash/core/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ def input_transform(self) -> Optional[INPUT_TRANSFORM_TYPE]:
def input_transform(self, input_transform: INPUT_TRANSFORM_TYPE) -> None:
self.adapter.input_transform = input_transform

@torch.jit.unused
@property
def collate_fn(self) -> Optional[Callable]:
return self.adapter.collate_fn

@collate_fn.setter
def collate_fn(self, collate_fn: Callable) -> None:
self.adapter.collate_fn = collate_fn

@torch.jit.unused
@property
def backbone(self) -> nn.Module:
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/io/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def create_or_configure_input_transform(
)
return transform(**transform_kwargs)

if isinstance(transform, partial) and transform.func.__name__ == "LambdaInputTransform":
if isinstance(transform, partial):
return transform(**transform_kwargs)

if isinstance(transform, Callable):
Expand Down
19 changes: 13 additions & 6 deletions flash/image/embedding/heads/vissl_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def simclr_head(
dims: List[int] = [2048, 2048, 256],
num_features: int = 2048,
embedding_dim: int = 128,
dims: List[int] = [2048],
use_bn: bool = True,
**kwargs,
) -> nn.Module:
cfg = VISSLAdapter.get_model_config_template()
head_kwargs = {
"dims": dims,
"dims": [num_features] + dims + [embedding_dim],
"use_bn": use_bn,
}

Expand All @@ -108,7 +110,9 @@ def simclr_head(


def swav_head(
dims: List[int] = [2048, 2048, 128],
num_features: int = 2048,
embedding_dim: int = 128,
dims: List[int] = [2048],
use_bn: bool = True,
num_clusters: Union[int, List[int]] = [3000],
use_bias: bool = True,
Expand All @@ -121,7 +125,7 @@ def swav_head(
) -> nn.Module:
cfg = VISSLAdapter.get_model_config_template()
head_kwargs = {
"dims": dims,
"dims": [num_features] + dims + [embedding_dim],
"use_bn": use_bn,
"num_clusters": [num_clusters] if isinstance(num_clusters, int) else num_clusters,
"use_bias": use_bias,
Expand All @@ -140,8 +144,11 @@ def swav_head(
return head


def barlow_twins_head(**kwargs) -> nn.Module:
return simclr_head(**kwargs)
def barlow_twins_head(
latent_embedding_dim: int = 8192,
**kwargs,
) -> nn.Module:
return simclr_head(embedding_dim=latent_embedding_dim, **kwargs)


def moco_head(**kwargs) -> nn.Module:
Expand Down
13 changes: 13 additions & 0 deletions flash/image/embedding/losses/vissl_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
from typing import List, Union

import torch.cuda
from torch import nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE

Expand All @@ -26,11 +29,21 @@
ClassyLoss = object


def _recursive_register(module):
for key, value in module.__dict__.items():
if isinstance(value, torch.Tensor):
delattr(module, key)
module.register_buffer(key, value)
if isinstance(value, nn.Module):
_recursive_register(value)


def get_loss_fn(loss_name: str, cfg: AttrDict):
set_cpu_device()
loss_fn = LOSS_REGISTRY[loss_name](cfg)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
loss_fn.__dict__["loss_name"] = loss_name

_recursive_register(loss_fn)
return loss_fn


Expand Down
12 changes: 3 additions & 9 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import LambdaInputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
Expand Down Expand Up @@ -92,10 +88,10 @@ def __init__(
if pretraining_transform_kwargs is None:
pretraining_transform_kwargs = {}

backbone, _ = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)
backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

metadata = self.training_strategies.get(training_strategy, with_metadata=True)
loss_fn, head, hooks = metadata["fn"](head=head, **training_strategy_kwargs)
loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)

adapter = metadata["metadata"]["adapter"].from_task(
self,
Expand All @@ -112,9 +108,7 @@ def __init__(
learning_rate=learning_rate,
)

input_transform, self.collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
output = ApplyToKeys(DataKeys.INPUT, input_transform)
self.input_transform = partial(LambdaInputTransform, transform=output)
self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)

warnings.warn(
"Warning: VISSL ImageEmbedder overrides any user provided transforms"
Expand Down
21 changes: 11 additions & 10 deletions flash/image/embedding/transforms/vissl_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE
from flash.image.embedding.vissl.transforms import moco_collate_fn, multicrop_collate_fn, simclr_collate_fn

if _VISSL_AVAILABLE:
from classy_vision.dataset.transforms import TRANSFORM_REGISTRY
from flash.image.embedding.vissl.transforms.multicrop import StandardMultiCropSSLTransform


def simclr_transform(
Expand All @@ -33,19 +30,21 @@ def simclr_transform(
jitter_strength: float = 1.0,
normalize: Optional[nn.Module] = None,
collate_fn: Callable = simclr_collate_fn,
) -> nn.Module:
) -> partial:
"""For simclr, barlow twins and moco."""
transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
transform = partial(
StandardMultiCropSSLTransform,
total_num_crops=total_num_crops,
num_crops=num_crops,
size_crops=size_crops,
crop_scales=crop_scales,
gaussian_blur=gaussian_blur,
jitter_strength=jitter_strength,
normalize=normalize,
collate_fn=collate_fn,
)

return transform, collate_fn
return transform


def swav_transform(
Expand All @@ -57,19 +56,21 @@ def swav_transform(
jitter_strength: float = 1.0,
normalize: Optional[nn.Module] = None,
collate_fn: Callable = multicrop_collate_fn,
) -> nn.Module:
) -> partial:
"""For swav and dino."""
transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
transform = partial(
StandardMultiCropSSLTransform,
total_num_crops=total_num_crops,
num_crops=num_crops,
size_crops=size_crops,
crop_scales=crop_scales,
gaussian_blur=gaussian_blur,
jitter_strength=jitter_strength,
normalize=normalize,
collate_fn=collate_fn,
)

return transform, collate_fn
return transform


barlow_twins_transform = partial(simclr_transform, collate_fn=simclr_collate_fn)
Expand Down
7 changes: 0 additions & 7 deletions flash/image/embedding/vissl/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,6 @@ def from_task(

return result

def on_epoch_start(self) -> None:
use_gpu = self.adapter_task.device != torch.device("cpu") and self.adapter_task.device != "cpu"
if hasattr(self.loss_fn, "info_criterion"):
self.loss_fn.info_criterion.use_gpu = use_gpu
if hasattr(self.loss_fn, "swav_criterion"):
self.loss_fn.swav_criterion.use_gpu = use_gpu

ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def get_model_config_template():
cfg = AttrDict(
Expand Down
6 changes: 5 additions & 1 deletion flash/image/embedding/vissl/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") ->

# get around vissl distributed training by setting MockTask flags
num_nodes = lightning_module.trainer.num_nodes
accelerators_ids = accelerator_connector(lightning_module.trainer).parallel_device_ids
accelerators_ids = getattr(
lightning_module.trainer,
"device_ids",
getattr(accelerator_connector(lightning_module.trainer), "parallel_device_ids", None),
)
accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1
task.world_size = num_nodes * accelerator_per_node

Expand Down
5 changes: 0 additions & 5 deletions flash/image/embedding/vissl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,3 @@
multicrop_collate_fn,
simclr_collate_fn,
)

if _VISSL_AVAILABLE:
from classy_vision.dataset.transforms import register_transform # noqa: F401

register_transform("multicrop_ssl_transform")(StandardMultiCropSSLTransform)
Loading