From 140c5f64472b51c3f830aedcd71f9f744c5b7ee6 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 18 Jul 2022 12:58:16 +0100
Subject: [PATCH] [1/2] Use any task as an embedder for any layer (#1396)

Co-authored-by: Kushashwa Ravi Shrimali
---
 CHANGELOG.md                          |   2 +
 flash/core/model.py                   |  22 ++++++
 flash/core/utilities/embedder.py      |  85 ++++++++++++++++++++
 tests/core/test_model.py              |  30 +++++++
 tests/core/utilities/test_embedder.py | 108 ++++++++++++++++++++++++++
 5 files changed, 247 insertions(+)
 create mode 100644 flash/core/utilities/embedder.py
 create mode 100644 tests/core/utilities/test_embedder.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index eb1e945ef2..84c7d8e890 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for more formats when loading audio files ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387))
 
+- Added support to use any task as an embedder by calling `as_embedder` ([#1396](https://github.com/PyTorchLightning/lightning-flash/pull/1396))
+
 ### Changed
 
 - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
diff --git a/flash/core/model.py b/flash/core/model.py
index 0237500ab9..3728c3f172 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -286,6 +286,9 @@ def __new__(mcs, *args, **kwargs):
         patterns = ["load_from_checkpoint", "available_*"]  # must match classmethods only
         regex = "(" + ")|(".join(patterns) + ")"
         for attribute_name, attribute_value in filter(lambda x: re.match(regex, x[0]), inspect.getmembers(result)):
+            # TODO: Find a better way to do this
+            if attribute_name in ["available_layers"]:
+                continue
             setattr(
                 result, attribute_name, classmethod(requires(*result.required_extras)(attribute_value.__func__))
             )
@@ -570,6 +573,25 @@ def configure_finetune_callback(
 
         return [finetuning_strategy_fn(**finetuning_strategy_metadata)]
 
+    def as_embedder(self, layer: str):
+        """Convert this task to an embedder. Note that the parameters are not copied, so any optimization of the
+        embedder will also apply to the converted ``Task``.
+
+        Args:
+            layer: The layer to embed to. This should be one of the :meth:`~flash.core.model.Task.available_layers`.
+        """
+        from flash.core.utilities.embedder import Embedder  # Avoid circular import
+
+        return Embedder(self, layer)
+
+    def available_layers(self):
+        """Get the list of available layers for use with the :meth:`~flash.core.model.Task.as_embedder` method."""
+        available_layers = []
+        for name, _ in self.named_modules():
+            if name not in ["train_metrics", "val_metrics", "test_metrics"]:
+                available_layers.append(name)
+        return ["output"] + available_layers
+
     @classmethod
     def available_backbones(
         cls, head: Optional[str] = None
diff --git a/flash/core/utilities/embedder.py b/flash/core/utilities/embedder.py
new file mode 100644
index 0000000000..dca570680e
--- /dev/null
+++ b/flash/core/utilities/embedder.py
@@ -0,0 +1,85 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+from pytorch_lightning import LightningModule
+
+from flash.core.model import Task
+
+
+class StopForward(Exception):
+    pass
+
+
+class Embedder(Task):
+    def __init__(self, model: LightningModule, layer: str):
+        super().__init__()
+
+        self.model = model
+        self.layer = layer
+
+        self._module, self._hook = self._make_hook()
+        self._handle = None
+        self._out = None
+
+    def _make_hook(self):
+        def hook(_, __, output):
+            self._out = output
+            raise StopForward
+
+        available_layers = {"output", ""}
+
+        if self.layer in available_layers:
+            return None, None
+
+        for name, module in self.model.named_modules():
+            available_layers.add(name)
+            if name == self.layer:
+                return module, hook
+
+        raise ValueError(
+            "The requested layer is not available in `model.named_modules`. The available layers are: "
+            f"{', '.join(available_layers)}."
+        )
+
+    def _register_hook(self):
+        if self._module is not None:
+            self._handle = self._module.register_forward_hook(self._hook)
+
+    def _remove_hook(self):
+        if self._handle is not None:
+            self._handle.remove()
+            self._handle = None
+
+    def training_step(self, batch: Any, batch_idx: int) -> Any:
+        raise NotImplementedError("Training an `Embedder` is not supported.")
+
+    def validation_step(self, batch: Any, batch_idx: int) -> Any:
+        raise NotImplementedError("Validating an `Embedder` is not supported.")
+
+    def test_step(self, batch: Any, batch_idx: int) -> Any:
+        raise NotImplementedError("Testing an `Embedder` is not supported.")
+
+    def forward(self, batch: Any) -> Any:
+        try:
+            self._register_hook()
+            return self.model.predict_step(batch, 0, dataloader_idx=0)
+        except StopForward:
+            return self._out
+        finally:
+            self._remove_hook()
+            self._out = None
+
+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+        return self(batch)
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index a6ab8031e6..08c80d9a41 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -37,6 +37,7 @@
 from flash.core.classification import ClassificationTask
 from flash.core.data.io.input_transform import InputTransform
 from flash.core.data.io.output_transform import OutputTransform
+from flash.core.utilities.embedder import Embedder
 from flash.core.utilities.imports import (
     _AUDIO_TESTING,
     _CORE_TESTING,
@@ -290,6 +291,35 @@ def test_model_download(tmpdir, cls, filename):
     assert isinstance(task, cls)
 
 
+class DummyTask(Task):
+    def __init__(self):
+        super().__init__()
+
+        self.backbone = nn.Sequential(
+            nn.Linear(10, 20),
+            nn.Linear(20, 30),
+            nn.Linear(30, 40),
+        )
+
+    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
+        return self.backbone(batch)
+
+
+@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
+def test_as_embedder():
+    layer_number = 1
+    embedder = DummyTask().as_embedder(f"backbone.{layer_number}")
+
+    assert isinstance(embedder, Embedder)
+    assert embedder.predict_step(torch.rand(10, 10), 0, 0).size(1) == embedder.model.backbone[layer_number].out_features
+
+
+@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
+def test_available_layers():
+    task = DummyTask()
+    assert task.available_layers() == ["output", "", "backbone", "backbone.0", "backbone.1", "backbone.2"]
+
+
 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
 def test_available_backbones():
     backbones = ImageClassifier.available_backbones()
diff --git a/tests/core/utilities/test_embedder.py b/tests/core/utilities/test_embedder.py
new file mode 100644
index 0000000000..f75d3d5d6f
--- /dev/null
+++ b/tests/core/utilities/test_embedder.py
@@ -0,0 +1,108 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+import pytest
+import torch
+from pytorch_lightning import LightningModule
+from torch import nn
+
+from flash.core.utilities.embedder import Embedder
+from flash.core.utilities.imports import _CORE_TESTING
+
+
+class EmbedderTestModel(LightningModule):
+    def __init__(self, backbone):
+        super().__init__()
+
+        self.backbone = backbone
+
+    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
+        return self.backbone(batch)
+
+
+class NLayerModel(EmbedderTestModel):
+    def __init__(self, n_layers):
+        super().__init__(nn.Sequential(*[nn.Linear(1000, 1000) for _ in range(n_layers)]))
+
+
+@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
+@pytest.mark.parametrize("layer, size", [("backbone.1", 30), ("output", 40), ("", 40)])
+def test_embedder(layer, size):
+    """Tests that the embedder ``predict_step`` correctly returns the output from the requested layer."""
+    model = EmbedderTestModel(
+        nn.Sequential(
+            nn.Linear(10, 20),
+            nn.Linear(20, 30),
+            nn.Linear(30, 40),
+        )
+    )
+
+    embedder = Embedder(model, layer)
+
+    assert embedder.predict_step(torch.rand(10, 10), 0, 0).size(1) == size
+    assert embedder(torch.rand(10, 10)).size(1) == size
+
+
+@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
+def test_embedder_scaling_overhead():
+    """Tests the scaling overhead of the embedder.
+
+    Embedding to the 3rd layer of a 200 layer model should take less than double the time of embedding to the same
+    layer of a 3 layer model, making it 10s - 100s of times faster than executing the full 200 layer model.
+
+    Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
+    """
+    shallow_embedder = Embedder(NLayerModel(3), "backbone.2")
+
+    start = time.perf_counter()
+    shallow_embedder.predict_step(torch.rand(10, 1000), 0, 0)
+    end = time.perf_counter()
+
+    shallow_time = end - start
+
+    deep_embedder = Embedder(NLayerModel(200), "backbone.2")
+
+    start = time.perf_counter()
+    deep_embedder.predict_step(torch.rand(10, 1000), 0, 0)
+    end = time.perf_counter()
+
+    deep_time = end - start
+
+    assert (abs(deep_time - shallow_time) / shallow_time) < 1
+
+
+@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
+def test_embedder_raising_overhead():
+    """Tests that embedding to the output layer of a 10 layer model takes less than 10ms more than the time taken to
+    execute the model without the embedder.
+
+    Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
+    """
+    model = NLayerModel(10)
+    embedder = Embedder(model, "output")
+
+    start = time.perf_counter()
+    model.predict_step(torch.rand(10, 1000), 0, 0)
+    end = time.perf_counter()
+
+    model_time = end - start
+
+    start = time.perf_counter()
+    embedder.predict_step(torch.rand(10, 1000), 0, 0)
+    end = time.perf_counter()
+
+    embedder_time = end - start
+
+    assert abs(embedder_time - model_time) < 0.01
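
For reference, a minimal usage sketch of the API this patch adds. It is not part of the diff; `TinyTask` and its layer sizes are hypothetical (modeled on the `DummyTask` in the tests above), while `as_embedder`, `available_layers`, and `Embedder` behave as implemented in `flash/core/model.py` and `flash/core/utilities/embedder.py`:

import torch
from torch import nn

from flash.core.model import Task


class TinyTask(Task):
    """Hypothetical task with a two-layer backbone, mirroring the ``DummyTask`` used in the tests."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30))

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        return self.backbone(batch)


task = TinyTask()

# Inspect the layers that can be embedded to; "output" and "" both mean the full model output.
print(task.available_layers())  # ['output', '', 'backbone', 'backbone.0', 'backbone.1']

# Embed to an intermediate layer: a forward hook captures that layer's output and raises
# StopForward, so "backbone.1" is never executed.
embedder = task.as_embedder("backbone.0")
print(embedder.predict_step(torch.rand(4, 10), 0, 0).shape)  # torch.Size([4, 20])

Because the hook raises `StopForward` as soon as the requested layer has run, the cost of embedding scales with the depth of the chosen layer rather than with the full model, which is what `test_embedder_scaling_overhead` verifies.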