Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[1/2] Use any task as an embedder for any layer (#1396)
Browse files Browse the repository at this point in the history
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
  • Loading branch information
ethanwharris and krshrimali authored Jul 18, 2022
1 parent 3e4c8bb commit 140c5f6
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for more formats when loading audio files ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387))

- Added support to use any task as an embedder by calling `as_embedder` ([#1396](https://github.com/PyTorchLightning/lightning-flash/pull/1396))

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
Expand Down
22 changes: 22 additions & 0 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def __new__(mcs, *args, **kwargs):
patterns = ["load_from_checkpoint", "available_*"] # must match classmethods only
regex = "(" + ")|(".join(patterns) + ")"
for attribute_name, attribute_value in filter(lambda x: re.match(regex, x[0]), inspect.getmembers(result)):
# TODO: Find a better way to do this
if attribute_name in ["available_layers"]:
continue
setattr(
result, attribute_name, classmethod(requires(*result.required_extras)(attribute_value.__func__))
)
Expand Down Expand Up @@ -570,6 +573,25 @@ def configure_finetune_callback(

return [finetuning_strategy_fn(**finetuning_strategy_metadata)]

def as_embedder(self, layer: str):
"""Convert this task to an embedder. Note that the parameters are not copied so that any optimization of
the embedder will also apply to the converted ``Task``.
Args:
layer: The layer to embed to. This should be one of the :meth:`~flash.core.model.Task.available_layers`.
"""
from flash.core.utilities.embedder import Embedder # Avoid circular import

return Embedder(self, layer)

def available_layers(self):
"""Get the list of available layers for use with the :meth:`~flash.core.model.Task.as_embedder` method."""
available_layers = []
for name, _ in self.named_modules():
if name not in ["train_metrics", "val_metrics", "test_metrics"]:
available_layers.append(name)
return ["output"] + available_layers

@classmethod
def available_backbones(
cls, head: Optional[str] = None
Expand Down
85 changes: 85 additions & 0 deletions flash/core/utilities/embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from pytorch_lightning import LightningModule

from flash.core.model import Task


class StopForward(Exception):
pass


class Embedder(Task):
def __init__(self, model: LightningModule, layer: str):
super().__init__()

self.model = model
self.layer = layer

self._module, self._hook = self._make_hook()
self._handle = None
self._out = None

def _make_hook(self):
def hook(_, __, output):
self._out = output
raise StopForward

available_layers = {"output", ""}

if self.layer in available_layers:
return None, None

for name, module in self.model.named_modules():
available_layers.add(name)
if name == self.layer:
return module, hook

raise ValueError(
"The requested layer is not available in `model.named_modules`. The available layers are: "
f"{', '.join(available_layers)}."
)

def _register_hook(self):
if self._module is not None:
self._handle = self._module.register_forward_hook(self._hook)

def _remove_hook(self):
if self._handle is not None:
self._handle.remove()
self._handle = None

def training_step(self, batch: Any, batch_idx: int) -> Any:
raise NotImplementedError("Training an `Embedder` is not supported.")

def validation_step(self, batch: Any, batch_idx: int) -> Any:
raise NotImplementedError("Validating an `Embedder` is not supported.")

def test_step(self, batch: Any, batch_idx: int) -> Any:
raise NotImplementedError("Testing an `Embedder` is not supported.")

def forward(self, batch: Any) -> Any:
try:
self._register_hook()
return self.model.predict_step(batch, 0, dataloader_idx=0)
except StopForward:
return self._out
finally:
self._remove_hook()
self._out = None

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(batch)
30 changes: 30 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from flash.core.classification import ClassificationTask
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.io.output_transform import OutputTransform
from flash.core.utilities.embedder import Embedder
from flash.core.utilities.imports import (
_AUDIO_TESTING,
_CORE_TESTING,
Expand Down Expand Up @@ -290,6 +291,35 @@ def test_model_download(tmpdir, cls, filename):
assert isinstance(task, cls)


class DummyTask(Task):
def __init__(self):
super().__init__()

self.backbone = nn.Sequential(
nn.Linear(10, 20),
nn.Linear(20, 30),
nn.Linear(30, 40),
)

def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
return self.backbone(batch)


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_as_embedder():
layer_number = 1
embedder = DummyTask().as_embedder(f"backbone.{layer_number}")

assert isinstance(embedder, Embedder)
assert embedder.predict_step(torch.rand(10, 10), 0, 0).size(1) == embedder.model.backbone[layer_number].out_features


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_available_layers():
task = DummyTask()
assert task.available_layers() == ["output", "", "backbone", "backbone.0", "backbone.1", "backbone.2"]


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_available_backbones():
backbones = ImageClassifier.available_backbones()
Expand Down
108 changes: 108 additions & 0 deletions tests/core/utilities/test_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import pytest
import torch
from pytorch_lightning import LightningModule
from torch import nn

from flash.core.utilities.embedder import Embedder
from flash.core.utilities.imports import _CORE_TESTING


class EmbedderTestModel(LightningModule):
def __init__(self, backbone):
super().__init__()

self.backbone = backbone

def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
return self.backbone(batch)


class NLayerModel(EmbedderTestModel):
def __init__(self, n_layers):
super().__init__(nn.Sequential(*[nn.Linear(1000, 1000) for _ in range(n_layers)]))


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
@pytest.mark.parametrize("layer, size", [("backbone.1", 30), ("output", 40), ("", 40)])
def test_embedder(layer, size):
"""Tests that the embedder ``predict_step`` correctly returns the output from the requested layer."""
model = EmbedderTestModel(
nn.Sequential(
nn.Linear(10, 20),
nn.Linear(20, 30),
nn.Linear(30, 40),
)
)

embedder = Embedder(model, layer)

assert embedder.predict_step(torch.rand(10, 10), 0, 0).size(1) == size
assert embedder(torch.rand(10, 10)).size(1) == size


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_embedder_scaling_overhead():
"""Tests that embedding to the 3rd layer of a 200 layer model takes less than double the time of embedding to.
the same layer of a 3 layer model and therefore in the order of 10s - 100s of times faster than executing the full
200 layer model.
Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
"""
shallow_embedder = Embedder(NLayerModel(3), "backbone.2")

start = time.perf_counter()
shallow_embedder.predict_step(torch.rand(10, 1000), 0, 0)
end = time.perf_counter()

shallow_time = end - start

deep_embedder = Embedder(NLayerModel(200), "backbone.2")

start = time.perf_counter()
deep_embedder.predict_step(torch.rand(10, 1000), 0, 0)
end = time.perf_counter()

deep_time = end - start

assert (abs(deep_time - shallow_time) / shallow_time) < 1


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_embedder_raising_overhead():
"""Tests that embedding to the output layer of a 3 layer model takes less than 10ms more than the time taken to
execute the model without the embedder.
Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
"""
model = NLayerModel(10)
embedder = Embedder(model, "output")

start = time.perf_counter()
model.predict_step(torch.rand(10, 1000), 0, 0)
end = time.perf_counter()

model_time = end - start

start = time.perf_counter()
embedder.predict_step(torch.rand(10, 1000), 0, 0)
end = time.perf_counter()

embedder_time = end - start

assert abs(embedder_time - model_time) < 0.01

0 comments on commit 140c5f6

Please sign in to comment.