This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Make dependency of vissl optional #1276

Merged
Changes from all commits (37 commits)

Commits
0ecf240  Make dependency of VISSL optional (ar90n, Apr 4, 2022)
8cec10d  Add tests for embedding (ar90n, Apr 4, 2022)
b386776  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 4, 2022)
f6b953c  No need for the optional typing (ar90n, Apr 4, 2022)
ebdb45d  Raise exception when training ImageEmbedder with DefaultAdapter (ar90n, Apr 4, 2022)
f834f3a  Remove head and loss_fn selection (ar90n, Apr 5, 2022)
7008093  Add copyright header (ar90n, Apr 5, 2022)
04b0235  Remove unused head argument (ar90n, Apr 6, 2022)
019aa03  Update CHANGELOG.md (ar90n, Apr 6, 2022)
5ce471b  Update CHANGELOG.md (ar90n, Apr 6, 2022)
098b324  Remove detect_anomaly (ar90n, Apr 6, 2022)
1a776f0  Improve exception message (ar90n, Apr 6, 2022)
df60ec4  Revert "Remove unused head argument" (ar90n, Apr 6, 2022)
45d9f2f  Fix parameter mismatch in overridden method (ar90n, Apr 6, 2022)
5dd499b  Use exception instead of assert (ar90n, Apr 6, 2022)
5488088  Raise NotImplementedError when calling validation_step (ar90n, Apr 6, 2022)
3eba6e4  Remove redundant required extras (ar90n, Apr 6, 2022)
ffb296b  Add backslash to escape correctly (ar90n, Apr 7, 2022)
3ea9037  Raise NotImplementedError when calling test_step (ar90n, Apr 7, 2022)
d7ad7f1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2022)
7ab0eec  Fix E501 issue (ar90n, Apr 7, 2022)
e444d16  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2022)
dd844ef  Make task a positional argument (ar90n, Apr 7, 2022)
5f65ebc  Remove unused head argument again (ar90n, Apr 7, 2022)
a70aebb  Remove *args (ar90n, Apr 7, 2022)
bf31ef2  Use CPU to run image embedding tests (ar90n, Apr 7, 2022)
4da9eff  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2022)
97eb2b5  Add condition not to run tests without torchvision (ar90n, Apr 7, 2022)
7d08042  Add new tests for ImageEmbedder (ar90n, Apr 7, 2022)
0330c88  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2022)
705e8f2  Change test condition (ar90n, Apr 7, 2022)
a66c197  Add docstring (ar90n, Apr 7, 2022)
16698ba  Improve docstring (ar90n, Apr 8, 2022)
30d628a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 8, 2022)
d19508a  Remove duplicated test (ar90n, Apr 8, 2022)
023a25c  Remove unused forward (ar90n, Apr 8, 2022)
1dffc9d  Update (ethanwharris, Apr 8, 2022)
CHANGELOG.md (2 additions, 0 deletions)

@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
+
 ### Deprecated

 ### Removed
flash/core/utilities/imports.py (3 additions, 1 deletion)

@@ -194,7 +194,9 @@ def decorator(func):

         if not available:
             modules = [f"'{module}'" for module in modules]
-            modules.append(f"'lightning-flash[{','.join(extras)}]'")
+
+            if extras:
+                modules.append(f"'lightning-flash[{','.join(extras)}]'")

             @functools.wraps(func)
             def wrapper(*args, **kwargs):
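The guard above means the `'lightning-flash[...]'` install hint is only appended when the decorated function actually declares extras, so functions gated only on bare packages no longer advertise an empty extras target. For missing VISSL dependencies the hint would look like `pip install 'lightning-flash[vissl]'`. The following is a minimal sketch of how a `requires`-style decorator like this behaves; it is illustrative only, not Flash's actual implementation:

    import functools
    import importlib.util


    def requires(modules, extras=None):
        """Sketch of an import guard: replace the function with one that raises
        an actionable error when any required package is missing."""
        missing = [m for m in modules if importlib.util.find_spec(m) is None]

        def decorator(func):
            if not missing:
                return func

            hints = [f"'{m}'" for m in missing]
            # Mirror the fix above: only suggest an extras target when extras exist.
            if extras:
                hints.append(f"'lightning-flash[{','.join(extras)}]'")

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                raise ModuleNotFoundError(f"Please install: {' '.join(hints)}")

            return wrapper

        return decorator


    @requires(["vissl", "classy_vision"], extras=["vissl"])
    def swav_strategy():
        ...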
flash/image/embedding/model.py (14 additions, 10 deletions)

@@ -72,13 +72,13 @@ class ImageEmbedder(AdapterTask):
     backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
     transforms: FlashRegistry = IMAGE_EMBEDDER_TRANSFORMS

-    required_extras: List[str] = ["image", "vissl", "fairscale"]
+    required_extras: str = "image"

     def __init__(
         self,
-        training_strategy: str,
-        head: str,
-        pretraining_transform: str,
+        training_strategy: str = "default",
+        head: Optional[str] = None,
+        pretraining_transform: Optional[str] = None,
         backbone: str = "resnet18",
         pretrained: bool = False,
         optimizer: OPTIMIZER_TYPE = "Adam",

@@ -113,7 +113,7 @@ def __init__(
         loss_fn, head, hooks = metadata["fn"](head=head, num_features=num_features, **training_strategy_kwargs)

         adapter = metadata["metadata"]["adapter"].from_task(
-            self,
+            task=self,
             loss_fn=loss_fn,
             backbone=model,
             head=head,

@@ -128,12 +128,16 @@ def __init__(
             learning_rate=learning_rate,
         )

-        self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
+        if pretraining_transform is not None:
+            warnings.warn(
+                "Overriding any transforms from the `DataModule` with the pretraining transform: "
+                f"{pretraining_transform}."
+            )
+            self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)

-        warnings.warn(
-            "Warning: VISSL ImageEmbedder overrides any user provided transforms"
-            " with pre-defined transforms for the training strategy."
-        )
+        if "providers" in metadata["metadata"] and metadata["metadata"]["providers"].name == "Facebook Research/vissl":
+            if pretraining_transform is None:
+                raise ValueError("Correct pretraining_transform must be set to use VISSL")

     def forward(self, x: torch.Tensor) -> Any:
         return self.model(x)
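With these new defaults, constructing an `ImageEmbedder` for embedding extraction needs no VISSL at all, while self-supervised training opts in explicitly. A minimal usage sketch (assuming the `image` extra is installed, and `vissl` only for the training path; the strategy, head, and transform names are taken from the registries exercised in the tests below):

    from flash.image import ImageEmbedder

    # Prediction only: the "default" strategy, no head, no pretraining transform.
    embedder = ImageEmbedder(backbone="resnet18")

    # Self-supervised training: a VISSL strategy must be named explicitly, and
    # omitting pretraining_transform now raises a ValueError.
    ssl_embedder = ImageEmbedder(
        training_strategy="simclr",
        head="simclr_head",
        pretraining_transform="simclr_transform",
        backbone="resnet18",
    )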
flash/image/embedding/strategies/__init__.py (2 additions, 0 deletions)

@@ -1,5 +1,7 @@
 from flash.core.registry import FlashRegistry  # noqa: F401
+from flash.image.embedding.strategies.default import register_default_strategy
 from flash.image.embedding.strategies.vissl_strategies import register_vissl_strategies  # noqa: F401

 IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies")
 register_vissl_strategies(IMAGE_EMBEDDER_STRATEGIES)
+register_default_strategy(IMAGE_EMBEDDER_STRATEGIES)
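Registering the default strategy here is what lets `ImageEmbedder` resolve `training_strategy="default"` by name at construction time, pulling the strategy function and its adapter from the same registry the VISSL strategies live in. A simplified sketch of that name-to-callable lookup pattern (a toy class, loosely modeled on `FlashRegistry` and not the real implementation):

    class Registry:
        """Toy name -> (callable, metadata) registry."""

        def __init__(self, name):
            self.name = name
            self._items = {}

        def __call__(self, fn, name, **metadata):
            self._items[name] = {"fn": fn, "metadata": metadata}

        def get(self, name, with_metadata=False):
            item = self._items[name]
            return item if with_metadata else item["fn"]


    strategies = Registry("embedder_training_strategies")
    strategies(lambda **kwargs: (None, None, []), name="default", adapter="DefaultAdapter")

    # This mirrors how ImageEmbedder.__init__ consumes the registry:
    metadata = strategies.get("default", with_metadata=True)
    loss_fn, head, hooks = metadata["fn"]()
    adapter_cls = metadata["metadata"]["adapter"]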
flash/image/embedding/strategies/default.py (new file, 89 additions)

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Optional

import torch

from flash.core.adapter import Adapter, AdapterTask
from flash.core.data.io.input import DataKeys
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.url_error import catch_url_error


class DefaultAdapter(Adapter):
    """The ``DefaultAdapter`` is an :class:`~flash.core.adapter.Adapter`."""

    required_extras: str = "image"

    def __init__(self, backbone: torch.nn.Module):
        super().__init__()

        self.backbone = backbone

    @classmethod
    @catch_url_error
    def from_task(
        cls,
        task: AdapterTask,
        backbone: torch.nn.Module,
        **kwargs,
    ) -> Adapter:
        adapter = cls(backbone)
        adapter.__dict__["_task"] = task
        return adapter

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError(
            'Training an `ImageEmbedder` with `strategy="default"` is not supported. '
            "Use a different strategy instead."
        )

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError(
            'Validating an `ImageEmbedder` with `strategy="default"` is not supported. '
            "Use a different strategy instead."
        )

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError(
            'Testing an `ImageEmbedder` with `strategy="default"` is not supported. '
            "Use a different strategy instead."
        )

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        batch[DataKeys.PREDS] = Task.predict_step(
            self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx
        )
        return batch


def default(head: Optional[str] = None, loss_fn: Optional[str] = None, **kwargs):
    """Return ``(None, None, [])`` as the loss function, head, and hooks, since the default strategy only
    supports prediction."""
    if head is not None:
        warnings.warn(f"The default strategy has no heads, so the given head ({head}) is ignored.")

    if loss_fn is not None:
        warnings.warn(f"The default strategy has no loss functions, so the given loss_fn ({loss_fn}) is ignored.")

    return None, None, []


def register_default_strategy(register: FlashRegistry):
    """Register the default strategy with the given ``FlashRegistry``."""
    register(default, name="default", adapter=DefaultAdapter)
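One detail worth flagging in `from_task` above: the task is stored with `adapter.__dict__["_task"] = task` rather than plain attribute assignment. Writing straight into `__dict__` bypasses `torch.nn.Module.__setattr__`, so the parent task is not registered as a submodule of its own adapter, which would otherwise create a cyclic module hierarchy. A small standalone illustration of the difference (illustrative names, not Flash code):

    import torch

    parent = torch.nn.Module()
    parent.child = torch.nn.Linear(2, 2)  # registered: appears in named_modules()
    print([name for name, _ in parent.named_modules()])  # ['', 'child']

    other = torch.nn.Module()
    other.__dict__["hidden"] = torch.nn.Linear(2, 2)  # bypasses registration
    print([name for name, _ in other.named_modules()])  # ['']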
flash/image/embedding/strategies/vissl_strategies.py (6 additions, 4 deletions)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _VISSL_AVAILABLE
+from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
 from flash.core.utilities.providers import _VISSL
 from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS
 from flash.image.embedding.losses import IMAGE_EMBEDDER_LOSS_FUNCTIONS

@@ -23,20 +23,23 @@
     from vissl.hooks.swav_hooks import NormalizePrototypesHook, SwAVUpdateQueueScoresHook


+@requires(["vissl", "classy_vision"])
 def swav(head: str = "swav_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("swav_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

     return loss_fn, head, [SwAVUpdateQueueScoresHook(), NormalizePrototypesHook(), TrainingSetupHook()]


+@requires(["vissl", "classy_vision"])
 def simclr(head: str = "simclr_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("simclr_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

     return loss_fn, head, [SimCLRTrainingSetupHook()]


+@requires(["vissl", "classy_vision"])
 def barlow_twins(head: str = "barlow_twins_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("barlow_twins_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

@@ -45,6 +48,5 @@ def barlow_twins(head: str = "barlow_twins_head", **kwargs):


 def register_vissl_strategies(register: FlashRegistry):
-    if _VISSL_AVAILABLE:
-        for training_strategy in (swav, simclr, barlow_twins):
-            register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL)
+    for training_strategy in (swav, simclr, barlow_twins):
+        register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL)
flash_examples/integrations/fiftyone/image_embedding.py (11 additions, 3 deletions)

@@ -14,9 +14,12 @@
 import fiftyone as fo
 import fiftyone.brain as fob
 import numpy as np
+import torch

+import flash
 from flash.core.data.utils import download_data
 from flash.image import ImageEmbedder
+from flash.image.classification.data import ImageClassificationData

 # 1 Download data
 download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

@@ -26,13 +29,18 @@
     "data/hymenoptera_data/test/",
     fo.types.ImageClassificationDirectoryTree,
 )
+datamodule = ImageClassificationData.from_files(
+    predict_files=dataset.values("filepath"),
+    batch_size=16,
+)

 # 3 Load model
-embedder = ImageEmbedder(backbone="resnet101")
+embedder = ImageEmbedder(backbone="resnet18")

 # 4 Generate embeddings
-filepaths = dataset.values("filepath")
-embeddings = np.stack(embedder.predict(filepaths))
+trainer = flash.Trainer(gpus=torch.cuda.device_count())
+embedding_batches = trainer.predict(embedder, datamodule=datamodule)
+embeddings = np.stack(sum(embedding_batches, []))

 # 5 Visualize in FiftyOne App
 results = fob.compute_visualization(dataset, embeddings=embeddings)
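`Trainer.predict` returns one list of predictions per batch, so the updated example flattens the nested list before stacking; `sum(embedding_batches, [])` is the terse idiom for that. An equivalent spelling with the standard library (a sketch that avoids quadratic list concatenation on large datasets):

    from itertools import chain

    import numpy as np

    # embedding_batches: list of per-batch lists of 1-D tensors -> (N, D) array
    embeddings = np.stack(list(chain.from_iterable(embedding_batches)))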
tests/image/embedding/test_model.py (56 additions, 0 deletions)

@@ -87,3 +87,59 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
     for prediction_batch in predictions:
         for prediction in prediction_batch:
             assert prediction.size(0) == embedding_size
+
+
+@pytest.mark.skipif(not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
+@pytest.mark.parametrize(
+    "backbone, training_strategy, head, pretraining_transform, expected_exception",
+    [
+        ("resnet18", "simclr", "simclr_head", None, ValueError),
+        ("resnet18", "simclr", None, "simclr_transform", KeyError),
+    ],
+)
+def test_vissl_training_with_wrong_arguments(
+    backbone, training_strategy, head, pretraining_transform, expected_exception
+):
+    with pytest.raises(expected_exception):
+        ImageEmbedder(
+            backbone=backbone,
+            training_strategy=training_strategy,
+            head=head,
+            pretraining_transform=pretraining_transform,
+        )
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="torchvision not installed.")
+@pytest.mark.parametrize(
+    "backbone, embedding_size",
+    [
+        ("resnet18", 512),
+        ("vit_small_patch16_224", 384),
+    ],
+)
+def test_only_embedding(backbone, embedding_size):
+    datamodule = ImageClassificationData.from_datasets(
+        predict_dataset=FakeData(8),
+        batch_size=4,
+        transform_kwargs=dict(image_size=(224, 224)),
+    )
+
+    embedder = ImageEmbedder(backbone=backbone)
+    trainer = flash.Trainer()
+
+    predictions = trainer.predict(embedder, datamodule=datamodule)
+    for prediction_batch in predictions:
+        for prediction in prediction_batch:
+            assert prediction.size(0) == embedding_size
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="torchvision not installed.")
+def test_not_implemented_steps():
+    embedder = ImageEmbedder(backbone="resnet18")
+
+    with pytest.raises(NotImplementedError):
+        embedder.training_step([], 0)
+    with pytest.raises(NotImplementedError):
+        embedder.validation_step([], 0)
+    with pytest.raises(NotImplementedError):
+        embedder.test_step([], 0)