diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index e82a234c16..dd70e4bfed 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -73,3 +73,35 @@ To view configuration options and options for running the object detector with y .. code-block:: bash flash object_detection --help + +------ + +********************** +Custom Transformations +********************** + +Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case. +The base :class:`~flash.core.data.process.Preprocess` defines 7 hooks for different stages in the data loading pipeline. +For object-detection tasks, you can leverage the transformations from `Albumentations `__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`. + +.. code-block:: python + + import albumentations as alb + from icevision.tfms import A + + from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter + from flash.image import ObjectDetectionData + + train_transform = { + "pre_tensor_transform": transforms.IceVisionTransformAdapter( + [*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()] + ) + } + + datamodule = ObjectDetectionData.from_coco( + train_folder="data/coco128/images/train2017/", + train_ann_file="data/coco128/annotations/instances_train2017.json", + val_split=0.1, + image_size=128, + train_transform=train_transform, + ) diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 3859bfa2ff..5619dfd5af 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -196,7 +196,13 @@ def from_icevision_predictions(predictions: List["Prediction"]): class IceVisionTransformAdapter(nn.Module): - def __init__(self, transform): + """ + Args: + transform: list of transformation functions to apply + + """ + + def __init__(self, transform: List[Callable]): super().__init__() self.transform = A.Adapter(transform) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index abe7ba931f..1ec4d016e4 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -17,7 +17,6 @@ import types from importlib.util import find_spec from typing import List, Union -from warnings import warn from pkg_resources import DistributionNotFound @@ -107,19 +106,8 @@ def _compare_version(package: str, op, version) -> bool: from PIL import Image # noqa: F401 else: - class MetaImage(type): - def __init__(cls, name, bases, dct): - super().__init__(name, bases, dct) - - cls._Image = None - - @property - def Image(cls): - warn("Mock object called due to missing PIL library. Please use \"pip install 'lightning-flash[image]'\".") - return cls._Image - - class Image(metaclass=MetaImage): - pass + class Image: + Image = object if Version: