From 60d27336282e32dd56dd1e232271c904364b2343 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Sat, 7 Aug 2021 16:33:15 +0200
Subject: [PATCH] update pystiche dependency (#642)

* update pystiche dependency

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes

* Fixes

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris
---
 README.md                                |  4 ++--
 docs/source/reference/style_transfer.rst |  2 +-
 flash/image/style_transfer/model.py      | 20 ++++++++------------
 requirements/datatype_image.txt          |  2 +-
 tests/image/style_transfer/test_model.py |  3 +++
 5 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index 7f62b2508d..97cb374b52 100644
--- a/README.md
+++ b/README.md
@@ -446,9 +446,9 @@ python flash_examples/finetuning/semantic_segmentation.py
 
-### Example 7: Style Transfer with Pystiche
+### Example 7: Style Transfer with pystiche
 
-Flash has a [Style Transfer task](https://lightning-flash.readthedocs.io/en/latest/reference/style_transfer.html) for Neural Style Transfer (NST) with [Pystiche](https://github.com/pystiche/pystiche).
+Flash has a [Style Transfer task](https://lightning-flash.readthedocs.io/en/latest/reference/style_transfer.html) for Neural Style Transfer (NST) with [pystiche](https://pystiche.org).
 
 View example
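For reference, a minimal sketch of driving the Style Transfer task that the README section above links to. The folder path and trainer settings are illustrative assumptions, not part of this patch:

```python
# Minimal sketch of the Flash style-transfer workflow the README links to.
# "data/images" and the trainer settings are placeholders, not from this patch.
import flash
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# Build a DataModule from a folder of content images.
datamodule = StyleTransferData.from_folders(train_folder="data/images")

# Without an explicit style image, the task falls back to pystiche's
# demo "paint" image (see default_style_image() in model.py below).
model = StyleTransfer()

trainer = flash.Trainer(max_epochs=2)
trainer.fit(model, datamodule=datamodule)
```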
diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst
index 759cc988ad..1200e315b0 100644
--- a/docs/source/reference/style_transfer.rst
+++ b/docs/source/reference/style_transfer.rst
@@ -12,7 +12,7 @@ The Task
 
 The Neural Style Transfer Task is an optimization method which extract the style from an image and apply it another image while preserving its content. The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.
 
-.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg
+.. image:: https://raw.githubusercontent.com/pystiche/pystiche/main/docs/source/graphics/banner/banner.jpg
    :alt: style_transfer_example
 
 The :class:`~flash.image.style_transfer.model.StyleTransfer` and :class:`~flash.image.style_transfer.data.StyleTransferData` classes internally rely on `pystiche <https://pystiche.org>`_.
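The model.py changes below port the loss construction from pystiche's removed `ops` namespace to the `loss` namespace of pystiche 1.0. A standalone sketch of the new-style construction, assuming the VGG16 multi-layer encoder from `pystiche.enc` and illustrative layer and weight choices:

```python
# Sketch of the pystiche 1.x loss construction that the model.py hunks below
# port to: the former ops.*Operator classes now live in pystiche.loss.
# Encoder choice, layer names, and score weights here are illustrative.
from pystiche import enc, loss

mle = enc.vgg16_multi_layer_encoder()

# Content: compare encodings of a single layer.
content_loss = loss.FeatureReconstructionLoss(
    mle.extract_encoder("relu2_2"), score_weight=1e5
)

# Style: compare Gram matrices across several layers.
style_loss = loss.MultiLayerEncodingLoss(
    mle,
    ("relu1_2", "relu2_2", "relu3_3", "relu4_3"),
    lambda encoder, layer_weight: loss.GramLoss(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=1e10,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
```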
diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py
index 2908df52e6..86a6b723e5 100644
--- a/flash/image/style_transfer/model.py
+++ b/flash/image/style_transfer/model.py
@@ -26,7 +26,7 @@
 
 if _IMAGE_AVAILABLE:
     import pystiche.demo
-    from pystiche import enc, loss, ops
+    from pystiche import enc, loss
     from pystiche.image import read_image
 
 else:
@@ -34,12 +34,10 @@ class enc:
         Encoder = None
         MultiLayerEncoder = None
 
-    class ops:
-        EncodingComparisonOperator = None
-        FeatureReconstructionOperator = None
-        MultiLayerEncodingOperator = None
-
     class loss:
+        class GramLoss:
+            pass
+
         class PerceptualLoss:
             pass
 
@@ -128,11 +126,11 @@ def default_style_image() -> torch.Tensor:
         return pystiche.demo.images()["paint"].read(size=256)
 
     @staticmethod
-    def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.EncodingComparisonOperator:
+    def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> loss.GramLoss:
         # The official PyTorch examples as well as the reference implementation of the original author contain an
         # oversight: they normalize the representation twice by the number of channels. To be compatible with them, we
         # do the same here.
-        class GramOperator(ops.GramOperator):
+        class GramOperator(loss.GramLoss):
             def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                 repr = super().enc_to_repr(enc)
                 num_channels = repr.size()[1]
@@ -150,10 +148,8 @@ def _get_perceptual_loss(
         style_weight: float,
     ) -> loss.PerceptualLoss:
         mle, _ = cast(enc.MultiLayerEncoder, self.backbones.get(backbone)())
-        content_loss = ops.FeatureReconstructionOperator(
-            mle.extract_encoder(content_layer), score_weight=content_weight
-        )
-        style_loss = ops.MultiLayerEncodingOperator(
+        content_loss = loss.FeatureReconstructionLoss(mle.extract_encoder(content_layer), score_weight=content_weight)
+        style_loss = loss.MultiLayerEncodingLoss(
             mle,
             style_layers,
             lambda encoder, layer_weight: self._modified_gram_loss(encoder, score_weight=layer_weight),
diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt
index d39ad59395..3be9ed638d 100644
--- a/requirements/datatype_image.txt
+++ b/requirements/datatype_image.txt
@@ -3,5 +3,5 @@ timm>=0.4.5
 lightning-bolts>=0.3.3
 Pillow>=7.2
 kornia>=0.5.1,<0.5.4
-pystiche>=0.7.2
+pystiche==1.*
 segmentation-models-pytorch
diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py
index f6458369f7..93ccb32ece 100644
--- a/tests/image/style_transfer/test_model.py
+++ b/tests/image/style_transfer/test_model.py
@@ -49,6 +49,9 @@ def test_jit(tmpdir):
     model = StyleTransfer()
     model.eval()
 
+    model.loss_fn = None
+    model.perceptual_loss = None  # TODO: Document this
+
     model = torch.jit.trace(model, torch.rand(1, 3, 32, 32))  # torch.jit.script doesn't work with pystiche
     torch.jit.save(model, path)
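The test change above sidesteps the fact that pystiche's loss modules cannot go through `torch.jit` (the existing comment in the test already notes that `torch.jit.script` does not work with pystiche). A standalone sketch of the same workaround, under the assumption that the task's forward pass only exercises the transformer network:

```python
# Standalone version of the tracing workaround from test_jit above:
# pystiche's loss modules are not torch.jit-compatible, so drop them
# before tracing; forward() then only runs the transformer network.
import torch

from flash.image.style_transfer import StyleTransfer

model = StyleTransfer()
model.eval()

# Detach the untraceable pystiche losses before torch.jit.trace.
model.loss_fn = None
model.perceptual_loss = None

traced = torch.jit.trace(model, torch.rand(1, 3, 32, 32))
stylized = traced(torch.rand(1, 3, 32, 32))
```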