Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Document object detector augmentations (#776)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
gianscarpe and ethanwharris authored Sep 22, 2021
1 parent 742fe2e commit ae525d9
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 15 deletions.
32 changes: 32 additions & 0 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,35 @@ To view configuration options and options for running the object detector with y
.. code-block:: bash
flash object_detection --help
------

**********************
Custom Transformations
**********************

Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case.
The base :class:`~flash.core.data.process.Preprocess` defines 7 hooks for different stages in the data loading pipeline.
For object-detection tasks, you can leverage the transformations from `Albumentations <https://github.com/albumentations-team/albumentations>`__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`.

.. code-block:: python
import albumentations as alb
from icevision.tfms import A
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData
train_transform = {
"pre_tensor_transform": transforms.IceVisionTransformAdapter(
[*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()]
)
}
datamodule = ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=0.1,
image_size=128,
train_transform=train_transform,
)
8 changes: 7 additions & 1 deletion flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,13 @@ def from_icevision_predictions(predictions: List["Prediction"]):


class IceVisionTransformAdapter(nn.Module):
def __init__(self, transform):
"""
Args:
transform: list of transformation functions to apply
"""

def __init__(self, transform: List[Callable]):
super().__init__()
self.transform = A.Adapter(transform)

Expand Down
16 changes: 2 additions & 14 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import types
from importlib.util import find_spec
from typing import List, Union
from warnings import warn

from pkg_resources import DistributionNotFound

Expand Down Expand Up @@ -107,19 +106,8 @@ def _compare_version(package: str, op, version) -> bool:
from PIL import Image # noqa: F401
else:

class MetaImage(type):
def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)

cls._Image = None

@property
def Image(cls):
warn("Mock object called due to missing PIL library. Please use \"pip install 'lightning-flash[image]'\".")
return cls._Image

class Image(metaclass=MetaImage):
pass
class Image:
Image = object


if Version:
Expand Down

0 comments on commit ae525d9

Please sign in to comment.