Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Dec 14, 2021
1 parent 0f49513 commit 0731848
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
24 changes: 12 additions & 12 deletions flash/image/keypoint_detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform

if _ICEVISION_AVAILABLE:
from icevision.parsers import COCOKeyPointsParser, Parser
Expand All @@ -31,7 +31,7 @@

class KeypointDetectionData(DataModule):

input_transform_cls = IceVisionInputTransform
input_transform_cls = KeypointDetectionInputTransform

@classmethod
def from_icedata(
Expand All @@ -43,10 +43,10 @@ def from_icedata(
test_folder: Optional[str] = None,
test_ann_file: Optional[str] = None,
predict_folder: Optional[str] = None,
train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
parser: Optional[Union[Callable, Type[Parser]]] = None,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
Expand All @@ -73,10 +73,10 @@ def from_coco(
test_folder: Optional[str] = None,
test_ann_file: Optional[str] = None,
predict_folder: Optional[str] = None,
train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
parser: Optional[Type[Parser]] = COCOKeyPointsParser,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -119,7 +119,7 @@ def from_coco(
def from_folders(
cls,
predict_folder: Optional[str] = None,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
Expand Down Expand Up @@ -149,7 +149,7 @@ def from_folders(
def from_files(
cls,
predict_files: Optional[List[str]] = None,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
Expand Down
24 changes: 24 additions & 0 deletions flash/image/keypoint_detection/input_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.integrations.icevision.transforms import IceVisionInputTransform, IceVisionTransformAdapter
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires

if _ICEVISION_AVAILABLE:
from icevision.tfms import A


class KeypointDetectionInputTransform(IceVisionInputTransform):
@requires(["image", "icevision"])
def train_per_sample_transform(self):
return IceVisionTransformAdapter([*A.aug_tfms(size=self.image_size, crop_fn=None), A.Normalize()])

0 comments on commit 0731848

Please sign in to comment.