Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[1/2] Use any task as an embedder for any layer #1396

Merged
merged 11 commits into from
Jul 18, 2022
22 changes: 22 additions & 0 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def __new__(mcs, *args, **kwargs):
patterns = ["load_from_checkpoint", "available_*"] # must match classmethods only
regex = "(" + ")|(".join(patterns) + ")"
for attribute_name, attribute_value in filter(lambda x: re.match(regex, x[0]), inspect.getmembers(result)):
# TODO: Find a better way to do this
if attribute_name in ["available_layers"]:
continue
setattr(
result, attribute_name, classmethod(requires(*result.required_extras)(attribute_value.__func__))
)
Expand Down Expand Up @@ -570,6 +573,25 @@ def configure_finetune_callback(

return [finetuning_strategy_fn(**finetuning_strategy_metadata)]

def as_embedder(self, layer: str):
    """Wrap this task in an ``Embedder`` targeting ``layer``.

    The task itself (not a copy) is handed to the embedder, so parameters are shared: any optimization of the
    embedder will also apply to this ``Task``.

    Args:
        layer: The layer to embed to. This should be one of the :meth:`~flash.core.model.Task.available_layers`.

    Returns:
        An ``Embedder`` constructed from this task and ``layer``.
    """
    # Imported here rather than at module level: the embedder module imports ``Task`` from this module.
    from flash.core.utilities.embedder import Embedder

    embedder = Embedder(self, layer)
    return embedder

def available_layers(self):
    """Get the list of available layers for use with the :meth:`~flash.core.model.Task.as_embedder` method.

    Returns:
        A list starting with ``"output"`` followed by every name yielded by ``self.named_modules()``, except the
        metric containers (``train_metrics``, ``val_metrics``, ``test_metrics``) and any module nested inside them.
    """
    metric_containers = {"train_metrics", "val_metrics", "test_metrics"}
    # Filter on the top-level name component so that nested entries such as "train_metrics.accuracy" are excluded
    # together with the container itself, while unrelated names like "train_metrics_head" are kept.
    layers = [name for name, _ in self.named_modules() if name.partition(".")[0] not in metric_containers]
    return ["output", *layers]

@classmethod
def available_backbones(
cls, head: Optional[str] = None
Expand Down
82 changes: 82 additions & 0 deletions flash/core/utilities/embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from pytorch_lightning import LightningModule

from flash.core.model import Task


class StopForward(Exception):
    """Signal raised by the ``Embedder`` forward hook to abort the forward pass once an output is captured."""


class Embedder(Task):
    """A ``Task`` whose ``predict_step`` returns the output of one named layer of a wrapped model.

    For the layers ``"output"`` and ``""`` the wrapped model's ``predict_step`` result is returned as-is. For any
    other layer, a forward hook is attached to the matching module for the duration of the call; the hook stores
    the module's output and raises ``StopForward`` so that the rest of the forward pass is not executed.

    Args:
        model: The module to wrap. Its ``predict_step`` is used to drive the forward pass.
        layer: ``"output"``, ``""``, or a name from ``model.named_modules()``.

    Raises:
        ValueError: If ``layer`` does not match any of the above.
    """

    def __init__(self, model: LightningModule, layer: str):
        super().__init__()

        self.model = model
        self.layer = layer

        # Per-call state: the handle of the registered hook and the output it captured.
        self._handle = None
        self._out = None

        # Resolved once up front; both are ``None`` when the full model output was requested.
        self._module, self._hook = self._make_hook()

    def _make_hook(self):
        """Resolve ``self.layer`` to a ``(module, hook)`` pair, or ``(None, None)`` when no hook is needed."""
        if self.layer in {"output", ""}:
            return None, None

        def capture_and_stop(_module, _inputs, output):
            # Keep the layer output, then abort the forward pass; ``predict_step`` catches the exception.
            self._out = output
            raise StopForward

        target = next((module for name, module in self.model.named_modules() if name == self.layer), None)
        if target is not None:
            return target, capture_and_stop

        # No match: collect every valid name for the error message.
        available_layers = {"output", ""}
        available_layers.update(name for name, _ in self.model.named_modules())
        raise ValueError(
            "The requested layer is not available in `model.named_modules`. The available layers are: "
            f"{', '.join(available_layers)}."
        )

    def _register_hook(self):
        """Attach the capturing hook to the target module, if there is one."""
        if self._module is None:
            return
        self._handle = self._module.register_forward_hook(self._hook)

    def _remove_hook(self):
        """Detach the capturing hook if it is currently registered."""
        if self._handle is None:
            return
        self._handle.remove()
        self._handle = None

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError("Training an `Embedder` is not supported.")

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError("Validating an `Embedder` is not supported.")

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError("Testing an `Embedder` is not supported.")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Run the wrapped model's ``predict_step`` and return the requested layer's output.

        The hook is registered only for the duration of this call and the captured output is cleared afterwards,
        whether the call succeeds or raises.
        """
        try:
            self._register_hook()
            result = self.model.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
        except StopForward:
            # The hook fired: the captured layer output is the result.
            result = self._out
        finally:
            self._remove_hook()
            self._out = None
        return result
28 changes: 28 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from flash.core.classification import ClassificationTask
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.io.output_transform import OutputTransform
from flash.core.utilities.embedder import Embedder
from flash.core.utilities.imports import (
_AUDIO_TESTING,
_CORE_TESTING,
Expand Down Expand Up @@ -290,6 +291,33 @@ def test_model_download(tmpdir, cls, filename):
assert isinstance(task, cls)


class DummyTask(Task):
    """Minimal ``Task`` whose ``model`` is three stacked linear layers (10 -> 20 -> 30 -> 40 features)."""

    def __init__(self):
        super().__init__()

        widths = (10, 20, 30, 40)
        # Pair consecutive widths to build Linear(10, 20), Linear(20, 30), Linear(30, 40) in that order.
        self.model = nn.Sequential(*(nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])))

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        predictions = self.model(batch)
        return predictions


def test_as_embedder():
    """Tests that ``Task.as_embedder`` returns an ``Embedder`` whose prediction has the width of the chosen layer."""
    # Must be an int: it is interpolated into the dotted layer name, but it is also used below to index the
    # ``nn.Sequential``, which accepts only integers and slices (a string index raises ``TypeError``).
    layer_number = 1
    embedder = DummyTask().as_embedder(f"model.{layer_number}")

    assert isinstance(embedder, Embedder)
    assert embedder.predict_step(torch.rand(10, 10), 0, 0).size(1) == embedder.model.model[layer_number].out_features


def test_available_layers():
    """Tests that ``available_layers`` lists ``"output"`` first, followed by the module names of a ``DummyTask``."""
    expected = ["output", "", "model", "model.0", "model.1", "model.2"]
    assert DummyTask().available_layers() == expected


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_available_backbones():
backbones = ImageClassifier.available_backbones()
Expand Down
102 changes: 102 additions & 0 deletions tests/core/utilities/test_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import pytest
import torch
from pytorch_lightning import LightningModule
from torch import nn

from flash.core.utilities.embedder import Embedder


class EmbedderTestModel(LightningModule):
    """Thin ``LightningModule`` whose ``predict_step`` simply calls the wrapped ``model`` on the batch."""

    def __init__(self, model):
        super().__init__()

        self.model = model

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        outputs = self.model(batch)
        return outputs


class NLayerModel(EmbedderTestModel):
    """``EmbedderTestModel`` wrapping a sequential stack of ``n_layers`` 1000x1000 linear layers."""

    def __init__(self, n_layers):
        stack = [nn.Linear(1000, 1000) for _ in range(n_layers)]
        super().__init__(nn.Sequential(*stack))


@pytest.mark.parametrize("layer, size", [("model.1", 30), ("output", 40), ("", 40)])
def test_embedder(layer, size):
    """Tests that the embedder ``predict_step`` yields an output whose feature dimension matches the requested
    layer of a 10 -> 20 -> 30 -> 40 linear stack."""
    backbone = nn.Sequential(
        nn.Linear(10, 20),
        nn.Linear(20, 30),
        nn.Linear(30, 40),
    )
    embedder = Embedder(EmbedderTestModel(backbone), layer)

    embedding = embedder.predict_step(torch.rand(10, 10), 0, 0)

    assert embedding.size(1) == size


def test_embedder_scaling_overhead():
    """Tests that embedding to the 3rd layer costs roughly the same regardless of model depth.

    The time to embed to ``model.2`` is measured for a 3 layer model and for a 1000 layer model, and the relative
    difference between the two must be below 1 (i.e. the deep model takes less than double the shallow time).

    Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
    """

    def time_embedding(n_layers):
        # The embedder is built outside the timed region; only the ``predict_step`` call is measured.
        embedder = Embedder(NLayerModel(n_layers), "model.2")
        started = time.perf_counter()
        embedder.predict_step(torch.rand(10, 1000), 0, 0)
        return time.perf_counter() - started

    shallow_time = time_embedding(3)
    deep_time = time_embedding(1000)

    relative_gap = abs(deep_time - shallow_time) / shallow_time
    assert relative_gap < 1


def test_embedder_raising_overhead():
    """Tests that embedding to the output layer of a 10 layer model takes less than 5ms more (or less) than the
    time taken to execute the model without the embedder.

    Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
    """
    model = NLayerModel(10)
    embedder = Embedder(model, "output")

    def timed_predict(predictor):
        # Measure a single ``predict_step`` call on a fresh random batch.
        started = time.perf_counter()
        predictor.predict_step(torch.rand(10, 1000), 0, 0)
        return time.perf_counter() - started

    model_time = timed_predict(model)
    embedder_time = timed_predict(embedder)

    assert abs(embedder_time - model_time) < 0.005