diff --git a/docs/source/api/core.rst b/docs/source/api/core.rst index 15362aa12b..9455691b39 100644 --- a/docs/source/api/core.rst +++ b/docs/source/api/core.rst @@ -48,8 +48,8 @@ _____________________ ~flash.core.finetuning.NoFreeze ~flash.core.finetuning.UnfreezeMilestones -flash.core.integration.fiftyone -_______________________________ +flash.core.integrations.fiftyone +________________________________ .. autosummary:: :toctree: generated/ @@ -57,6 +57,17 @@ _______________________________ ~flash.core.integrations.fiftyone.utils.visualize +flash.core.integrations.icevision +_________________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter + ~flash.core.integrations.icevision.transforms.default_transforms + ~flash.core.integrations.icevision.transforms.train_default_transforms + flash.core.model ________________ diff --git a/docs/source/index.rst b/docs/source/index.rst index 3d4b48be5c..91ea1a09e5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -84,6 +84,7 @@ Lightning Flash integrations/providers integrations/fiftyone + integrations/icevision .. toctree:: :maxdepth: 1 diff --git a/docs/source/integrations/icevision.rst b/docs/source/integrations/icevision.rst new file mode 100644 index 0000000000..ff21565a4e --- /dev/null +++ b/docs/source/integrations/icevision.rst @@ -0,0 +1,44 @@ +.. _ice_vision: + +######### +IceVision +######### + +IceVision from airctic is an awesome computer vision framework which offers a curated collection of hundreds of high-quality pre-trained models for: object detection, keypoint detection, and instance segmentation. +In Flash, we've integrated the IceVision framework to provide: data loading, augmentation, backbones, and heads. +We use IceVision components in our: :ref:`object detection `, :ref:`instance segmentation `, and :ref:`keypoint detection ` tasks. +Take a look at `their documentation `_ and star `IceVision on GitHub `_ to spread the open source love! + +IceData +_______ + +The `IceData library `_ is a community driven dataset hub for IceVision. +All of the datasets in IceData can be used out of the box with flash using our ``.from_folders`` methods and the ``parser`` argument. +Take a look at our :ref:`keypoint_detection` page for an example. + +Albumentations with IceVision and Flash +_______________________________________ + +IceVision provides two utilities for using the `albumentations library `_ with their models: +- the ``Adapter`` helper class for adapting an any albumentations transform to work with IceVision records, +- the ``aug_tfms`` utility function that returns a standard augmentation recipe to get the most out of your model. + +In Flash, we use the ``aug_tfms`` as default transforms for the: :ref:`object detection `, :ref:`instance segmentation `, and :ref:`keypoint detection ` tasks. +You can also provide custom transforms from albumentations using the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter` (which relies on the IceVision ``Adapter`` underneath). +Here's an example: + +.. code-block:: python + + import albumentations as A + + from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter + from flash.image import ObjectDetectionData + + train_transform = { + "pre_tensor_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]), + } + + datamodule = ObjectDetectionData.from_coco( + ..., + train_transform=train_transform, + ) diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index c5a5968160..3d347c730c 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -174,7 +174,7 @@ def from_icevision_record(record: "BaseRecord"): class IceVisionTransformAdapter(nn.Module): def __init__(self, transform): super().__init__() - self.transform = transform + self.transform = A.Adapter(transform) def forward(self, x): record = to_icevision_record(x) @@ -186,7 +186,7 @@ def forward(self, x): def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """The default transforms from IceVision.""" return { - "pre_tensor_transform": IceVisionTransformAdapter(A.Adapter([*A.resize_and_pad(image_size), A.Normalize()])), + "pre_tensor_transform": IceVisionTransformAdapter([*A.resize_and_pad(image_size), A.Normalize()]), } @@ -194,5 +194,5 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """The default augmentations from IceVision.""" return { - "pre_tensor_transform": IceVisionTransformAdapter(A.Adapter([*A.aug_tfms(size=image_size), A.Normalize()])), + "pre_tensor_transform": IceVisionTransformAdapter([*A.aug_tfms(size=image_size), A.Normalize()]), }