From 07318489c3cb2a39f48e4eae601018a01479f904 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 14 Dec 2021 21:05:11 +0000 Subject: [PATCH] Fixes --- flash/image/keypoint_detection/data.py | 24 +++++++++---------- .../keypoint_detection/input_transform.py | 24 +++++++++++++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) create mode 100644 flash/image/keypoint_detection/input_transform.py diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 3e0dfb366e..751a7513ec 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -17,10 +17,10 @@ from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.integrations.icevision.data import IceVisionInput -from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE +from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform if _ICEVISION_AVAILABLE: from icevision.parsers import COCOKeyPointsParser, Parser @@ -31,7 +31,7 @@ class KeypointDetectionData(DataModule): - input_transform_cls = IceVisionInputTransform + input_transform_cls = KeypointDetectionInputTransform @classmethod def from_icedata( @@ -43,10 +43,10 @@ def from_icedata( test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, + val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, + test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, parser: Optional[Union[Callable, Type[Parser]]] = None, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, @@ -73,10 +73,10 @@ def from_coco( test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, + val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, + test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, parser: Optional[Type[Parser]] = COCOKeyPointsParser, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, @@ -119,7 +119,7 @@ def from_coco( def from_folders( cls, predict_folder: Optional[str] = None, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -149,7 +149,7 @@ def from_folders( def from_files( cls, predict_files: Optional[List[str]] = None, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, diff --git a/flash/image/keypoint_detection/input_transform.py b/flash/image/keypoint_detection/input_transform.py new file mode 100644 index 0000000000..bd741bd822 --- /dev/null +++ b/flash/image/keypoint_detection/input_transform.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.integrations.icevision.transforms import IceVisionInputTransform, IceVisionTransformAdapter +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires + +if _ICEVISION_AVAILABLE: + from icevision.tfms import A + + +class KeypointDetectionInputTransform(IceVisionInputTransform): + @requires(["image", "icevision"]) + def train_per_sample_transform(self): + return IceVisionTransformAdapter([*A.aug_tfms(size=self.image_size, crop_fn=None), A.Normalize()])