Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Docstrings for IceVision data (#1102)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
ethanwharris and Borda authored Jan 11, 2022
1 parent 54d641d commit 1739ada
Show file tree
Hide file tree
Showing 12 changed files with 988 additions and 240 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added Flash zero support for tabular regression ([#1098](https://github.com/PyTorchLightning/lightning-flash/pull/1098))

- Added support for COCO annotations with non-default keypoint labels to `KeypointDetectionData.from_coco` ([#1102](https://github.com/PyTorchLightning/lightning-flash/pull/1102))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
Expand Down Expand Up @@ -44,6 +46,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where loading data for prediction with `SemanticSegmentationData.from_folders` raised an error ([#1101](https://github.com/PyTorchLightning/lightning-flash/pull/1101))

- Fixed a bug when passing a `predict_folder` argument to `from_coco` / `from_voc` / `from_via` in IceVision tasks ([#1102](https://github.com/PyTorchLightning/lightning-flash/pull/1102))

- Fixed `ObjectDetectionData.from_voc` and `ObjectDetectionData.from_via` ([#1102](https://github.com/PyTorchLightning/lightning-flash/pull/1102))

- Fixed a bug where `InstanceSegmentationData.from_coco` would raise an error if not using file-based masks ([#1102](https://github.com/PyTorchLightning/lightning-flash/pull/1102))

- Fixed `InstanceSegmentationData.from_voc` ([#1102](https://github.com/PyTorchLightning/lightning-flash/pull/1102))

### Removed

## [0.6.0] - 2021-13-12
Expand Down
27 changes: 17 additions & 10 deletions flash/core/integrations/icevision/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from flash.core.data.utilities.paths import list_valid_files
from flash.core.integrations.icevision.transforms import from_icevision_record
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.image.data import image_loader, IMG_EXTENSIONS, NP_EXTENSIONS

if _ICEVISION_AVAILABLE:
from icevision.core.record import BaseRecord
Expand All @@ -36,23 +35,29 @@ def load_data(
root: str,
ann_file: Optional[str] = None,
parser: Optional[Type["Parser"]] = None,
parser_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
if inspect.isclass(parser) and issubclass(parser, Parser):
parser = parser(ann_file, root)
elif isinstance(parser, Callable):
parser = parser(root)
parser_kwargs = {} if parser_kwargs is None else parser_kwargs
unwrapped_parser = getattr(parser, "func", parser)
if inspect.isclass(unwrapped_parser) and issubclass(unwrapped_parser, Parser):
parser = parser(ann_file, root, **parser_kwargs)
elif isinstance(unwrapped_parser, Callable):
parser = parser(root, **parser_kwargs)
else:
raise ValueError("The parser must be a callable or an IceVision Parser type.")
self.num_classes = parser.class_map.num_classes
self.set_state(ClassificationState([parser.class_map.get_by_id(i) for i in range(self.num_classes)]))
class_map = getattr(parser, "class_map", None)
if class_map is not None:
self.num_classes = class_map.num_classes
self.labels = [class_map.get_by_id(i) for i in range(self.num_classes)]
self.set_state(ClassificationState(self.labels))
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DataKeys.INPUT: record} for record in records[0]]

def predict_load_data(
self, paths: Union[str, List[str]], ann_file: Optional[str] = None, parser: Optional[Type["Parser"]] = None
self, paths: Union[str, List[str]], parser: Optional[Type["Parser"]] = None
) -> List[Dict[str, Any]]:
if parser is not None and parser != Parser:
return self.load_data(paths, ann_file, parser)
from flash.image.data import IMG_EXTENSIONS, NP_EXTENSIONS # Import locally to prevent circular import

paths = list_valid_files(paths, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
return [{DataKeys.INPUT: path} for path in paths]

Expand All @@ -61,6 +66,8 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return from_icevision_record(record)

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
from flash.image.data import image_loader # Import locally to prevent circular import

if isinstance(sample[DataKeys.INPUT], BaseRecord):
return self.load_sample(sample)
filepath = sample[DataKeys.INPUT]
Expand Down
82 changes: 39 additions & 43 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

import numpy as np
from torch import nn

from flash.core.data.io.input import DataKeys
Expand All @@ -24,7 +25,7 @@
from icevision.core import tasks
from icevision.core.bbox import BBox
from icevision.core.keypoints import KeyPoints
from icevision.core.mask import EncodedRLEs, MaskArray
from icevision.core.mask import Mask, MaskArray
from icevision.core.record import BaseRecord
from icevision.core.record_components import (
BBoxesRecordComponent,
Expand All @@ -37,14 +38,20 @@
)
from icevision.data.prediction import Prediction
from icevision.tfms import A
else:
MaskArray = object

if _ICEVISION_AVAILABLE and _ICEVISION_GREATER_EQUAL_0_11_0:
from icevision.core.mask import MaskFile
from icevision.core.record_components import InstanceMasksRecordComponent
elif _ICEVISION_AVAILABLE:
from icevision.core.record_components import MasksRecordComponent


def _split_mask_array(mask_array: MaskArray) -> List[MaskArray]:
"""Utility to split a single ``MaskArray`` object into a list of ``MaskArray`` objects (one per mask)."""
return [MaskArray(mask) for mask in mask_array.data]


def to_icevision_record(sample: Dict[str, Any]):
record = BaseRecord([])

Expand All @@ -58,6 +65,19 @@ def to_icevision_record(sample: Dict[str, Any]):
component.set_class_map(metadata.get("class_map", None))
record.add_component(component)

if isinstance(sample[DataKeys.INPUT], str):
input_component = FilepathRecordComponent()
input_component.set_filepath(sample[DataKeys.INPUT])
else:
if "filepath" in metadata:
input_component = FilepathRecordComponent()
input_component.filepath = metadata["filepath"]
else:
input_component = ImageRecordComponent()
input_component.composite = record
input_component.set_img(sample[DataKeys.INPUT])
record.add_component(input_component)

if "labels" in sample[DataKeys.TARGET]:
labels_component = InstancesLabelsRecordComponent()
labels_component.add_labels_by_id(sample[DataKeys.TARGET]["labels"])
Expand All @@ -73,19 +93,19 @@ def to_icevision_record(sample: Dict[str, Any]):
record.add_component(component)

if _ICEVISION_GREATER_EQUAL_0_11_0:
mask_array = sample[DataKeys.TARGET].get("mask_array", None)
# mask_array = sample[DataKeys.TARGET].get("mask_array", None)
masks = sample[DataKeys.TARGET].get("masks", None)

if mask_array is not None or masks is not None:
if masks is not None:
component = InstanceMasksRecordComponent()

if masks is not None:
masks = [MaskFile(mask) for mask in masks]
component.set_masks(masks)

if mask_array is not None:
mask_array = MaskArray(mask_array)
component.set_mask_array(mask_array)
if masks is not None and len(masks) > 0:
if isinstance(masks[0], Mask):
component.set_masks(masks)
else:
mask_array = MaskArray(np.stack(masks, axis=0))
component.set_mask_array(mask_array)
component.set_masks(_split_mask_array(mask_array))

record.add_component(component)
else:
Expand All @@ -110,19 +130,6 @@ def to_icevision_record(sample: Dict[str, Any]):
component.set_keypoints(keypoints)
record.add_component(component)

if isinstance(sample[DataKeys.INPUT], str):
input_component = FilepathRecordComponent()
input_component.set_filepath(sample[DataKeys.INPUT])
else:
if "filepath" in metadata:
input_component = FilepathRecordComponent()
input_component.filepath = metadata["filepath"]
else:
input_component = ImageRecordComponent()
input_component.composite = record
input_component.set_img(sample[DataKeys.INPUT])
record.add_component(input_component)

return record


Expand All @@ -142,26 +149,15 @@ def from_icevision_detection(record: "BaseRecord"):
for bbox in detection.bboxes
]

mask_array = (
getattr(detection, "mask_array", None) if _ICEVISION_GREATER_EQUAL_0_11_0 else getattr(detection, "masks", None)
)
if mask_array is not None:
if isinstance(mask_array, EncodedRLEs):
mask_array = mask_array.to_mask(record.height, record.width)

if isinstance(mask_array, MaskArray):
result["mask_array"] = mask_array.data
else:
raise RuntimeError("Mask arrays are expected to be a MaskArray or EncodedRLEs.")

masks = getattr(detection, "masks", None)
if masks is not None and _ICEVISION_GREATER_EQUAL_0_11_0:
result["masks"] = []
for mask in masks:
if isinstance(mask, MaskFile):
result["masks"].append(mask.filepath)
else:
raise RuntimeError("Masks are expected to be MaskFile objects.")
mask_array = getattr(detection, "mask_array", None)
if mask_array is not None or not _ICEVISION_GREATER_EQUAL_0_11_0:
if not isinstance(mask_array, MaskArray) or len(mask_array.data) == 0:
mask_array = MaskArray.from_masks(masks, record.height, record.width)

result["masks"] = [mask.data[0] for mask in _split_mask_array(mask_array)]
elif masks is not None:
result["masks"] = masks # Note - this doesn't unpack IceVision objects

if hasattr(detection, "keypoints"):
keypoints = detection.keypoints
Expand Down
3 changes: 2 additions & 1 deletion flash/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
ImageClassifier,
)
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401
from flash.image.detection import ObjectDetectionData, ObjectDetector # noqa: F401
from flash.image.detection.data import ObjectDetectionData # noqa: F401
from flash.image.detection.model import ObjectDetector # noqa: F401
from flash.image.embedding import ImageEmbedder # noqa: F401
from flash.image.face_detection import FaceDetectionData, FaceDetector # noqa: F401
from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData # noqa: F401
Expand Down
Loading

0 comments on commit 1739ada

Please sign in to comment.