This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

update pystiche dependency (#642)
* update pystiche dependency

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes

* Fixes

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
3 people authored Aug 7, 2021
1 parent 6483d1c commit 60d2733
Showing 5 changed files with 15 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
```diff
@@ -446,9 +446,9 @@ python flash_examples/finetuning/semantic_segmentation.py
 
 </details>
 
-### Example 7: Style Transfer with Pystiche
+### Example 7: Style Transfer with pystiche
 
-Flash has a [Style Transfer task](https://lightning-flash.readthedocs.io/en/latest/reference/style_transfer.html) for Neural Style Transfer (NST) with [Pystiche](https://github.com/pystiche/pystiche).
+Flash has a [Style Transfer task](https://lightning-flash.readthedocs.io/en/latest/reference/style_transfer.html) for Neural Style Transfer (NST) with [pystiche](https://pystiche.org).
 
 <details>
 <summary>View example</summary>
```
2 changes: 1 addition & 1 deletion docs/source/reference/style_transfer.rst
```diff
@@ -12,7 +12,7 @@ The Task
 The Neural Style Transfer Task is an optimization method which extracts the style from an image and applies it to another image while preserving its content.
 The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.
 
-.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg
+.. image:: https://raw.githubusercontent.com/pystiche/pystiche/main/docs/source/graphics/banner/banner.jpg
    :alt: style_transfer_example
 
 The :class:`~flash.image.style_transfer.model.StyleTransfer` and :class:`~flash.image.style_transfer.data.StyleTransferData` classes internally rely on `pystiche <https://pystiche.org>`_.
```
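The docs passage above describes NST as an optimization problem: nudge an output image so it keeps the content image's structure while adopting the style image's statistics. A toy, dependency-free sketch of such an objective follows; the mean/variance "style statistic" is a deliberate simplification standing in for the Gram-matrix statistics pystiche actually uses, and all names here are illustrative, not Flash or pystiche API.

```python
# Toy NST-style objective: content fidelity plus a weighted style-statistics
# match. Real NST compares CNN feature maps; plain lists keep this self-contained.

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def stats(xs):
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return mean, var

def nst_objective(output, content, style, style_weight=10.0):
    content_loss = mse(output, content)           # stay close to the content image
    om, ov = stats(output)
    sm, sv = stats(style)
    style_loss = (om - sm) ** 2 + (ov - sv) ** 2  # match the style image's statistics
    return content_loss + style_weight * style_loss

print(nst_objective([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
```

With identical content and style inputs both terms vanish; in real NST the optimizer descends this kind of combined objective with respect to the output image's pixels.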
20 changes: 8 additions & 12 deletions flash/image/style_transfer/model.py
```diff
@@ -26,20 +26,18 @@
 
 if _IMAGE_AVAILABLE:
     import pystiche.demo
-    from pystiche import enc, loss, ops
+    from pystiche import enc, loss
     from pystiche.image import read_image
 else:
 
     class enc:
         Encoder = None
         MultiLayerEncoder = None
 
-    class ops:
-        EncodingComparisonOperator = None
-        FeatureReconstructionOperator = None
-        MultiLayerEncodingOperator = None
-
     class loss:
+        class GramLoss:
+            pass
+
         class PerceptualLoss:
             pass
```
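The `else` branch in the hunk above is the usual optional-dependency stub pattern: when the real package cannot be imported, empty placeholder classes keep names like `loss.GramLoss` resolvable so the module still imports and annotations still evaluate. A minimal standalone sketch of the pattern (the `_AVAILABLE` flag and `build_loss` helper are illustrative stand-ins for Flash's `_IMAGE_AVAILABLE` and its model code, not part of the diff):

```python
_AVAILABLE = False  # stand-in for Flash's _IMAGE_AVAILABLE flag

if _AVAILABLE:
    from pystiche import loss  # only attempted when the optional extra is installed
else:

    class loss:  # stub namespace so `loss.GramLoss` still resolves
        class GramLoss:
            pass

        class PerceptualLoss:
            pass


def build_loss() -> "loss.GramLoss":
    # The string annotation resolves against the stub when pystiche is absent.
    return loss.GramLoss()

print(type(build_loss()).__name__)  # GramLoss
```

The commit simply swaps which names get stubbed: pystiche 1.x folded the old `ops` operators into the `loss` namespace, so the `ops` stub disappears and `GramLoss` joins the `loss` stub.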

```diff
@@ -128,11 +126,11 @@ def default_style_image() -> torch.Tensor:
         return pystiche.demo.images()["paint"].read(size=256)
 
     @staticmethod
-    def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.EncodingComparisonOperator:
+    def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> loss.GramLoss:
         # The official PyTorch examples as well as the reference implementation of the original author contain an
         # oversight: they normalize the representation twice by the number of channels. To be compatible with them, we
         # do the same here.
-        class GramOperator(ops.GramOperator):
+        class GramOperator(loss.GramLoss):
             def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                 repr = super().enc_to_repr(enc)
                 num_channels = repr.size()[1]
```
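The comment in the hunk above explains the quirk the subclass preserves: the reference implementations normalize the Gram representation by the channel count one extra time, and Flash reproduces that for compatibility. A dependency-free sketch of the arithmetic (plain lists instead of pystiche tensors; exactly where pystiche applies its own normalization is simplified here):

```python
def gram_matrix(feats, extra_channel_norm=False):
    """feats: C lists of N flattened feature values (C channels, N positions)."""
    c, n = len(feats), len(feats[0])
    # Standard normalized Gram matrix: G[i][j] = <feats[i], feats[j]> / (C * N)
    gram = [
        [sum(a * b for a, b in zip(feats[i], feats[j])) / (c * n) for j in range(c)]
        for i in range(c)
    ]
    if extra_channel_norm:
        # The "oversight" the comment describes: divide by C once more.
        gram = [[value / c for value in row] for row in gram]
    return gram

feats = [[1.0, 2.0], [3.0, 4.0]]
print(gram_matrix(feats)[0][0])                           # 1.25
print(gram_matrix(feats, extra_channel_norm=True)[0][0])  # 0.625
```

Because the extra division only rescales the style term, it mostly shifts the effective style weight; matching it keeps Flash's results comparable with the official examples.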
```diff
@@ -150,10 +148,8 @@ def _get_perceptual_loss(
         style_weight: float,
     ) -> loss.PerceptualLoss:
         mle, _ = cast(enc.MultiLayerEncoder, self.backbones.get(backbone)())
-        content_loss = ops.FeatureReconstructionOperator(
-            mle.extract_encoder(content_layer), score_weight=content_weight
-        )
-        style_loss = ops.MultiLayerEncodingOperator(
+        content_loss = loss.FeatureReconstructionLoss(mle.extract_encoder(content_layer), score_weight=content_weight)
+        style_loss = loss.MultiLayerEncodingLoss(
             mle,
             style_layers,
             lambda encoder, layer_weight: self._modified_gram_loss(encoder, score_weight=layer_weight),
```
2 changes: 1 addition & 1 deletion requirements/datatype_image.txt
```diff
@@ -3,5 +3,5 @@ timm>=0.4.5
 lightning-bolts>=0.3.3
 Pillow>=7.2
 kornia>=0.5.1,<0.5.4
-pystiche>=0.7.2
+pystiche==1.*
 segmentation-models-pytorch
```
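The new requirement uses PEP 440's wildcard syntax: `pystiche==1.*` accepts any release in the 1.x series (which renamed the `ops` operators into `loss` classes) while excluding the old 0.7 API and a future 2.0. A toy matcher illustrating the semantics of a `==X.*` pin (this is not how pip actually parses specifiers; use `packaging.specifiers` for real code):

```python
def matches_wildcard(version: str, spec: str) -> bool:
    """Toy check for a PEP 440 '==X.*' wildcard pin (illustration only)."""
    assert spec.startswith("==") and spec.endswith(".*")
    prefix = spec[2:-1]  # "==1.*" -> "1."
    return version.startswith(prefix)

for v in ("0.7.2", "1.0", "1.0.1", "2.0"):
    print(v, matches_wildcard(v, "==1.*"))
# 0.7.2 False / 1.0 True / 1.0.1 True / 2.0 False
```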
3 changes: 3 additions & 0 deletions tests/image/style_transfer/test_model.py
```diff
@@ -49,6 +49,9 @@ def test_jit(tmpdir):
     model = StyleTransfer()
     model.eval()
 
+    model.loss_fn = None
+    model.perceptual_loss = None  # TODO: Document this
+
     model = torch.jit.trace(model, torch.rand(1, 3, 32, 32))  # torch.jit.script doesn't work with pystiche
 
     torch.jit.save(model, path)
```
