diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 7a25a7742a..78c1d77ef7 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -46,7 +46,7 @@ Here's an outline: ... Once we've downloaded the data using :func:`~flash.core.data.download_data`, we can create the :class:`~flash.image.detection.data.ObjectDetectionData`. -We select a pre-trained RetinaNet to use for our :class:`~flash.image.detection.model.ObjectDetector` and fine-tune on the COCO 128 data. +We select a pre-trained EfficientDet to use for our :class:`~flash.image.detection.model.ObjectDetector` and fine-tune on the COCO 128 data. We then use the trained :class:`~flash.image.detection.model.ObjectDetector` for inference. Finally, we save the model. Here's the full example: @@ -82,26 +82,34 @@ Custom Transformations Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case. The base :class:`~flash.core.data.io.input_transform.InputTransform` defines 7 hooks for different stages in the data loading pipeline. -For object-detection tasks, you can leverage the transformations from `Albumentations `__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`. +For object-detection tasks, you can leverage the transformations from `Albumentations `__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`, +creating a subclass of :class:`~flash.core.data.io.input_transform.InputTransform` .. code-block:: python + from dataclasses import dataclass import albumentations as alb from icevision.tfms import A + from flash import InputTransform from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter from flash.image import ObjectDetectionData - train_transform = { - "per_sample_transform": transforms.IceVisionTransformAdapter( - [*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()] - ) - } + + @dataclass + class BrightnessContrastTransform(InputTransform): + image_size: int = 128 + + def per_sample_transform(self): + return IceVisionTransformAdapter( + [*A.aug_tfms(size=self.image_size), A.Normalize(), alb.RandomBrightnessContrast()] + ) + datamodule = ObjectDetectionData.from_coco( train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", val_split=0.1, - image_size=128, - train_transform=train_transform, + train_transform=BrightnessContrastTransform, + batch_size=4, )