Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Document custom image transformations #620

Merged
merged 31 commits into from
Aug 11, 2021
Merged
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
568ad4a
add weights path
aniketmaurya Jul 13, 2021
0efba03
Merge branch 'master' of https://github.com/PyTorchLightning/lightnin…
aniketmaurya Jul 14, 2021
729f607
add available weights
aniketmaurya Jul 14, 2021
974049c
remove weight path
aniketmaurya Jul 14, 2021
16e91fc
add tests :white_check_mark:
aniketmaurya Jul 14, 2021
f018b4a
fix
aniketmaurya Jul 14, 2021
9428c29
update
aniketmaurya Jul 14, 2021
2070ba7
add str pretrained
aniketmaurya Jul 14, 2021
9aad7e6
add test :white_check_mark:
aniketmaurya Jul 14, 2021
55e1102
fix
aniketmaurya Jul 14, 2021
3782923
Merge branch 'master' into task_a_thon-weight_path
ethanwharris Jul 14, 2021
c283bb6
Update flash/image/segmentation/heads.py
ethanwharris Jul 14, 2021
d996f32
Update CHANGELOG.md
ethanwharris Jul 14, 2021
5492d31
Merge branch 'master' into task_a_thon-weight_path
ethanwharris Jul 14, 2021
f864aa1
merge master
aniketmaurya Jul 15, 2021
bd0f047
Merge branch 'master' of https://github.com/PyTorchLightning/lightnin…
aniketmaurya Jul 29, 2021
e400692
add transformation documentation
aniketmaurya Jul 29, 2021
79d460d
fix
aniketmaurya Jul 29, 2021
813eb84
fix
aniketmaurya Jul 29, 2021
6f099b5
fix
aniketmaurya Jul 29, 2021
4d80295
Merge branch 'master' into documentation/img_transform
ethanwharris Jul 30, 2021
94c1161
apply suggestions
aniketmaurya Jul 30, 2021
d5d0c8b
update to testcode
aniketmaurya Aug 5, 2021
df7ca3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2021
604f016
Merge branch 'master' into documentation/img_transform
aniketmaurya Aug 5, 2021
cf98528
Merge branch 'master' into documentation/img_transform
ethanwharris Aug 11, 2021
1e367e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 11, 2021
56c685a
Updates
ethanwharris Aug 11, 2021
d49b856
Updates
ethanwharris Aug 11, 2021
ed3f500
Merge branch 'documentation/img_transform' of https://github.com/anik…
ethanwharris Aug 11, 2021
6c4d2ff
Try fix
ethanwharris Aug 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,57 @@ Here's the full example:

ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
------

**********************
Custom Transformations
**********************

Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case.
The base :class:`~flash.core.data.process.Preprocess` defines 7 hooks for different stages in the data loading pipeline.
To apply image augmentations you can directly import the ``default_transforms`` from ``flash.image.classification.transforms`` and then merge your custom image transformations with them using the :func:`~flash.core.data.transforms.merge_transforms` helper function.
Here's an example where we load the default transforms and merge with custom `torchvision` transformations.
We use the `post_tensor_transform` hook to apply the transformations after the image has been converted to a `torch.Tensor`.


.. testsetup:: transformations

from flash.core.data.utils import download_data

download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

.. testcode:: transformations

from torchvision import transforms as T

import flash
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.transforms import default_transforms

post_tensor_transform = ApplyToKeys(
DefaultDataKeys.INPUT,
T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]),
)

new_transforms = merge_transforms(default_transforms((64, 64)), {"post_tensor_transform": post_tensor_transform})

datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", train_transform=new_transforms
)

model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")


.. testoutput:: transformations
:hide:

...

------

**********
Flash Zero
**********
Expand Down