diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bb6e71f1c..1750e2a336 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `QuestionAnswering` task for extractive question answering ([#607](https://github.com/PyTorchLightning/lightning-flash/pull/607)) +- Added automatic unwrapping of IceVision prediction objects ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727)) + +- Added support for the `ObjectDetector` with FiftyOne ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst index 34d44164a8..ded8bccd33 100644 --- a/docs/source/api/image.rst +++ b/docs/source/api/image.rst @@ -45,7 +45,6 @@ ________________ detection.data.FiftyOneParser detection.data.ObjectDetectionFiftyOneDataSource detection.data.ObjectDetectionPreprocess - detection.serialization.DetectionLabels detection.serialization.FiftyOneDetectionLabels Keypoint Detection diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index dd0ed1e9dd..b781c17058 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -344,13 +344,13 @@ def default_uncollate(batch: Any): return batch return list(torch.unbind(batch, 0)) - if isinstance(batch, Mapping): + if isinstance(batch, dict): return [batch_type(dict(zip(batch, default_uncollate(t)))) for t in zip(*batch.values())] if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple - return [batch_type(*default_uncollate(sample)) for sample in zip(*batch)] + return [batch_type(*sample) for sample in zip(*batch)] if isinstance(batch, Sequence) and not isinstance(batch, str): - return [default_uncollate(sample) for sample in batch] + return [sample for sample in batch] return batch diff --git a/flash/core/data/serialization.py b/flash/core/data/serialization.py new file mode 100644 index 0000000000..190bbffe5b --- /dev/null +++ b/flash/core/data/serialization.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Union + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import Serializer + + +class Preds(Serializer): + """A :class:`~flash.core.data.process.Serializer` which returns the "preds" from the model outputs.""" + + def serialize(self, sample: Any) -> Union[int, List[int]]: + return sample.get(DefaultDataKeys.PREDS, sample) if isinstance(sample, dict) else sample diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index af95da9a52..83be7c3848 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -19,7 +19,7 @@ from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_source import DefaultDataKeys -from flash.core.integrations.icevision.transforms import to_icevision_record +from flash.core.integrations.icevision.transforms import from_icevision_predictions, to_icevision_record from flash.core.model import Task from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.core.utilities.url_error import catch_url_error @@ -81,9 +81,12 @@ def from_task( @staticmethod def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): metadata = metadata or [None] * len(samples) - return collate_fn( - [to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] - ) + return { + DefaultDataKeys.INPUT: collate_fn( + [to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] + ), + DefaultDataKeys.METADATA: metadata, + } def process_train_dataset( self, @@ -178,19 +181,20 @@ def process_predict_dataset( return data_loader def training_step(self, batch, batch_idx) -> Any: - return self.icevision_adapter.training_step(batch, batch_idx) + return self.icevision_adapter.training_step(batch[DefaultDataKeys.INPUT], batch_idx) def validation_step(self, batch, batch_idx): - return self.icevision_adapter.validation_step(batch, batch_idx) + return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx) def test_step(self, batch, batch_idx): - return self.icevision_adapter.validation_step(batch, batch_idx) + return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - return self(batch) + batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT]) + return batch def forward(self, batch: Any) -> Any: - return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) + return from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False)) def training_epoch_end(self, outputs) -> None: return self.icevision_adapter.training_epoch_end(outputs) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 80ce622616..ee1dfe1ed5 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -11,18 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type import numpy as np -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.data_source import DefaultDataKeys, LabelsState from flash.core.integrations.icevision.transforms import from_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.data import ImagePathsDataSource if _ICEVISION_AVAILABLE: from icevision.core.record import BaseRecord - from icevision.core.record_components import ClassMapRecordComponent, ImageRecordComponent, tasks + from icevision.core.record_components import ClassMapRecordComponent, FilepathRecordComponent, tasks from icevision.data.data_splitter import SingleSplitSplitter from icevision.parsers.parser import Parser @@ -36,10 +37,14 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: return from_icevision_record(record) def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord): + return self.load_sample(sample) + filepath = sample[DefaultDataKeys.INPUT] sample = super().load_sample(sample) image = np.array(sample[DefaultDataKeys.INPUT]) - record = BaseRecord([ImageRecordComponent()]) + record = BaseRecord([FilepathRecordComponent()]) + record.filepath = filepath record.set_img(image) record.add_component(ClassMapRecordComponent(task=tasks.detection)) return from_icevision_record(record) @@ -51,29 +56,23 @@ def __init__(self, parser: Optional[Type["Parser"]] = None): self.parser = parser def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - root, ann_file = data - if self.parser is not None: - parser = self.parser(ann_file, root) - dataset.num_classes = len(parser.class_map) + if inspect.isclass(self.parser) and issubclass(self.parser, Parser): + root, ann_file = data + parser = self.parser(ann_file, root) + elif isinstance(self.parser, Callable): + parser = self.parser(data) + else: + raise ValueError("The parser must be a callable or an IceVision Parser type.") + dataset.num_classes = parser.class_map.num_classes + self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)])) records = parser.parse(data_splitter=SingleSplitSplitter()) return [{DefaultDataKeys.INPUT: record} for record in records[0]] else: - raise ValueError("The parser type must be provided") - - -class IceDataParserDataSource(IceVisionPathsDataSource): - def __init__(self, parser: Optional[Callable] = None): - super().__init__() - self.parser = parser + raise ValueError("The parser argument must be provided.") - def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - root = data - - if self.parser is not None: - parser = self.parser(root) - dataset.num_classes = len(parser.class_map) - records = parser.parse(data_splitter=SingleSplitSplitter()) - return [{DefaultDataKeys.INPUT: record} for record in records[0]] - else: - raise ValueError("The parser must be provided") + def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + result = super().predict_load_data(data, dataset) + if len(result) == 0: + result = self.load_data(data, dataset) + return result diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 3d347c730c..3458ebf3d9 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, List, Tuple from torch import nn @@ -34,6 +34,7 @@ MasksRecordComponent, RecordIDRecordComponent, ) + from icevision.data.prediction import Prediction from icevision.tfms import A @@ -101,51 +102,38 @@ def to_icevision_record(sample: Dict[str, Any]): return record -def from_icevision_record(record: "BaseRecord"): - sample = { - DefaultDataKeys.METADATA: { - "image_id": record.record_id, - } - } +def from_icevision_detection(record: "BaseRecord"): + detection = record.detection - if record.img is not None: - sample[DefaultDataKeys.INPUT] = record.img - filepath = getattr(record, "filepath", None) - if filepath is not None: - sample[DefaultDataKeys.METADATA]["filepath"] = filepath - elif record.filepath is not None: - sample[DefaultDataKeys.INPUT] = record.filepath + result = {} - sample[DefaultDataKeys.TARGET] = {} - - if hasattr(record.detection, "bboxes"): - sample[DefaultDataKeys.TARGET]["bboxes"] = [] - for bbox in record.detection.bboxes: - bbox_list = list(bbox.xywh) - bbox_dict = { - "xmin": bbox_list[0], - "ymin": bbox_list[1], - "width": bbox_list[2], - "height": bbox_list[3], + if hasattr(detection, "bboxes"): + result["bboxes"] = [ + { + "xmin": bbox.xmin, + "ymin": bbox.ymin, + "width": bbox.width, + "height": bbox.height, } - sample[DefaultDataKeys.TARGET]["bboxes"].append(bbox_dict) + for bbox in detection.bboxes + ] - if hasattr(record.detection, "masks"): - masks = record.detection.masks + if hasattr(detection, "masks"): + masks = detection.masks if isinstance(masks, EncodedRLEs): masks = masks.to_mask(record.height, record.width) if isinstance(masks, MaskArray): - sample[DefaultDataKeys.TARGET]["masks"] = masks.data + result["masks"] = masks.data else: raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.") - if hasattr(record.detection, "keypoints"): - keypoints = record.detection.keypoints + if hasattr(detection, "keypoints"): + keypoints = detection.keypoints - sample[DefaultDataKeys.TARGET]["keypoints"] = [] - sample[DefaultDataKeys.TARGET]["keypoints_metadata"] = [] + result["keypoints"] = [] + result["keypoints_metadata"] = [] for keypoint in keypoints: keypoints_list = [] @@ -157,13 +145,42 @@ def from_icevision_record(record: "BaseRecord"): "visible": v, } ) - sample[DefaultDataKeys.TARGET]["keypoints"].append(keypoints_list) + result["keypoints"].append(keypoints_list) # TODO: Unpack keypoints_metadata - sample[DefaultDataKeys.TARGET]["keypoints_metadata"].append(keypoint.metadata) + result["keypoints_metadata"].append(keypoint.metadata) + + if getattr(detection, "label_ids", None) is not None: + result["labels"] = list(detection.label_ids) - if getattr(record.detection, "label_ids", None) is not None: - sample[DefaultDataKeys.TARGET]["labels"] = list(record.detection.label_ids) + if getattr(detection, "scores", None) is not None: + result["scores"] = list(detection.scores) + + return result + + +def from_icevision_record(record: "BaseRecord"): + sample = { + DefaultDataKeys.METADATA: { + "size": (record.height, record.width), + } + } + + if getattr(record, "record_id", None) is not None: + sample[DefaultDataKeys.METADATA]["image_id"] = record.record_id + + if getattr(record, "filepath", None) is not None: + sample[DefaultDataKeys.METADATA]["filepath"] = record.filepath + + if record.img is not None: + sample[DefaultDataKeys.INPUT] = record.img + filepath = getattr(record, "filepath", None) + if filepath is not None: + sample[DefaultDataKeys.METADATA]["filepath"] = filepath + elif record.filepath is not None: + sample[DefaultDataKeys.INPUT] = record.filepath + + sample[DefaultDataKeys.TARGET] = from_icevision_detection(record) if getattr(record.detection, "class_map", None) is not None: sample[DefaultDataKeys.METADATA]["class_map"] = record.detection.class_map @@ -171,6 +188,13 @@ def from_icevision_record(record: "BaseRecord"): return sample +def from_icevision_predictions(predictions: List["Prediction"]): + result = [] + for prediction in predictions: + result.append(from_icevision_detection(prediction.pred)) + return result + + class IceVisionTransformAdapter(nn.Module): def __init__(self, transform): super().__init__() diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 9a1be7b114..9b00375d99 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -17,11 +17,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource from flash.core.data.process import Preprocess -from flash.core.integrations.icevision.data import ( - IceDataParserDataSource, - IceVisionParserDataSource, - IceVisionPathsDataSource, -) +from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires @@ -160,7 +156,7 @@ def __init__( "via": IceVisionParserDataSource(parser=VIABBoxParser), "voc": IceVisionParserDataSource(parser=VOCBBoxParser), DefaultDataSources.FILES: IceVisionPathsDataSource(), - DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser), DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), }, default_data_source=DefaultDataSources.FILES, diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index c2bcd606f6..d431285336 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -18,6 +18,7 @@ from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer +from flash.core.data.serialization import Preds from flash.core.registry import FlashRegistry from flash.image.detection.backbones import OBJECT_DETECTION_HEADS @@ -57,7 +58,7 @@ def __init__( head: Optional[str] = "retinanet", pretrained: bool = True, optimizer: Type[Optimizer] = torch.optim.Adam, - learning_rate: float = 5e-4, + learning_rate: float = 5e-3, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, ): @@ -77,7 +78,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - serializer=serializer, + serializer=serializer or Preds(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/detection/serialization.py b/flash/image/detection/serialization.py index e50614d0ef..115f7d3118 100644 --- a/flash/image/detection/serialization.py +++ b/flash/image/detection/serialization.py @@ -28,14 +28,6 @@ fo = None -class DetectionLabels(Serializer): - """A :class:`.Serializer` which extracts predictions from sample dict.""" - - def serialize(self, sample: Any) -> Dict[str, Any]: - sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample - return sample - - class FiftyOneDetectionLabels(Serializer): """A :class:`.Serializer` which converts model outputs to FiftyOne detection format. @@ -81,21 +73,23 @@ def serialize(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] detections = [] - for det in sample[DefaultDataKeys.PREDS]: - confidence = det["scores"].tolist() + preds = sample[DefaultDataKeys.PREDS] + + for bbox, label, score in zip(preds["bboxes"], preds["labels"], preds["scores"]): + confidence = score.tolist() if self.threshold is not None and confidence < self.threshold: continue - xmin, ymin, xmax, ymax = (c.tolist() for c in det["boxes"]) + xmin, ymin, box_width, box_height = bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"] box = [ - xmin / width, - ymin / height, - (xmax - xmin) / width, - (ymax - ymin) / height, + (xmin / width).item(), + (ymin / height).item(), + (box_width / width).item(), + (box_height / height).item(), ] - label = det["labels"].tolist() + label = label.item() if labels is not None: label = labels[label] else: diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index a475e2abb4..91a1e8eeb1 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -17,11 +17,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataSources from flash.core.data.process import Preprocess -from flash.core.integrations.icevision.data import ( - IceDataParserDataSource, - IceVisionParserDataSource, - IceVisionPathsDataSource, -) +from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -53,7 +49,7 @@ def __init__( "coco": IceVisionParserDataSource(parser=COCOMaskParser), "voc": IceVisionParserDataSource(parser=VOCMaskParser), DefaultDataSources.FILES: IceVisionPathsDataSource(), - DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser), }, default_data_source=DefaultDataSources.FILES, ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 52f2706554..dd16a389b2 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -18,6 +18,7 @@ from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer +from flash.core.data.serialization import Preds from flash.core.registry import FlashRegistry from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS @@ -77,7 +78,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - serializer=serializer, + serializer=serializer or Preds(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 485ddc0d56..0e7f700e4a 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -17,11 +17,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataSources from flash.core.data.process import Preprocess -from flash.core.integrations.icevision.data import ( - IceDataParserDataSource, - IceVisionParserDataSource, - IceVisionPathsDataSource, -) +from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -51,7 +47,7 @@ def __init__( data_sources={ "coco": IceVisionParserDataSource(parser=COCOKeyPointsParser), DefaultDataSources.FILES: IceVisionPathsDataSource(), - DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser), }, default_data_source=DefaultDataSources.FILES, ) diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index b85177d083..d718c92587 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -18,6 +18,7 @@ from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer +from flash.core.data.serialization import Preds from flash.core.registry import FlashRegistry from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS @@ -79,7 +80,7 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - serializer=serializer, + serializer=serializer or Preds(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 771014bbb5..3162d1e3c8 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -35,8 +35,8 @@ class SemanticSegmentationPostprocess(Postprocess): def per_sample_transform(self, sample: Any) -> Any: resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear") - sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS])) - sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT])) + sample[DefaultDataKeys.PREDS] = resize(sample[DefaultDataKeys.PREDS]) + sample[DefaultDataKeys.INPUT] = resize(sample[DefaultDataKeys.INPUT]) return super().per_sample_transform(sample) diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py index bddcfe7e41..d4bd99e289 100644 --- a/flash/pointcloud/detection/open3d_ml/app.py +++ b/flash/pointcloud/detection/open3d_ml/app.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np -import torch from torch.utils.data.dataset import Dataset import flash @@ -157,7 +156,7 @@ def show_predictions(self, predictions): for pred in predictions: data = { - "points": torch.stack(pred[DefaultDataKeys.INPUT])[:, :3], + "points": pred[DefaultDataKeys.INPUT][:, :3], "name": pred[DefaultDataKeys.METADATA], } bounding_box = pred[DefaultDataKeys.PREDS] diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py index b1145c53b5..45edb8bbe3 100644 --- a/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/flash/pointcloud/segmentation/open3d_ml/app.py @@ -86,9 +86,9 @@ def show_predictions(self, predictions): for pred in predictions: predictions_visualizations.append( { - "points": torch.stack(pred[DefaultDataKeys.INPUT]), - "labels": torch.stack(pred[DefaultDataKeys.TARGET]), - "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1, + "points": pred[DefaultDataKeys.INPUT], + "labels": pred[DefaultDataKeys.TARGET], + "predictions": torch.argmax(pred[DefaultDataKeys.PREDS], axis=-1) + 1, "name": pred[DefaultDataKeys.METADATA]["name"], } ) diff --git a/flash_examples/integrations/fiftyone/object_detection.py b/flash_examples/integrations/fiftyone/object_detection.py new file mode 100644 index 0000000000..efec712477 --- /dev/null +++ b/flash_examples/integrations/fiftyone/object_detection.py @@ -0,0 +1,51 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import chain + +import flash +from flash.core.integrations.fiftyone import visualize +from flash.core.utilities.imports import example_requires +from flash.image import ObjectDetectionData, ObjectDetector +from flash.image.detection.serialization import FiftyOneDetectionLabels + +example_requires("image") + +import icedata # noqa: E402 + +# 1. Create the DataModule +data_dir = icedata.fridge.load_data() + +datamodule = ObjectDetectionData.from_folders( + train_folder=data_dir, + predict_folder=data_dir, + val_split=0.1, + image_size=128, + parser=icedata.fridge.parser, +) + +# 2. Build the task +model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=1) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Set the serializer and get some predictions +model.serializer = FiftyOneDetectionLabels(return_filepath=True) # output FiftyOne format +predictions = trainer.predict(model, datamodule=datamodule) +predictions = list(chain.from_iterable(predictions)) # flatten batches + +# 5. Visualize predictions in FiftyOne app +# Optional: pass `wait=True` to block execution until App is closed +session = visualize(predictions, wait=True) diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py index a03457ed77..725a678907 100644 --- a/tests/core/data/test_batch.py +++ b/tests/core/data/test_batch.py @@ -113,9 +113,9 @@ def test_sequence(self): for sample in output: assert list(sample.keys()) == ["a", "b", "c"] - assert isinstance(sample["a"], list) + assert isinstance(sample["a"], torch.Tensor) assert len(sample["a"]) == 4 - assert isinstance(sample["b"], list) + assert isinstance(sample["b"], torch.Tensor) assert len(sample["b"]) == 2 assert isinstance(sample["c"], torch.Tensor) assert len(sample["c"].shape) == 0 @@ -130,7 +130,7 @@ def test_named_tuple(self): for sample in output: assert isinstance(sample, Batch) - assert isinstance(sample.x, list) + assert isinstance(sample.x, torch.Tensor) assert len(sample.x) == 4 assert isinstance(sample.y, torch.Tensor) assert len(sample.y.shape) == 0 diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index f5fd1fba85..ad56a302b8 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -23,13 +23,10 @@ from flash.__main__ import main from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import ObjectDetector from tests.helpers.utils import _IMAGE_TESTING -if _ICEVISION_AVAILABLE: - from icevision.data import Prediction - def collate_fn(samples): return {key: [sample[key] for sample in samples] for key in samples[0]} @@ -81,10 +78,13 @@ def test_init(): dl = model.process_predict_dataset(ds, batch_size=batch_size) data = next(iter(dl)) - out = model(data) + out = model.forward(data[DefaultDataKeys.INPUT]) assert len(out) == batch_size - assert all(isinstance(res, Prediction) for res in out) + assert all(isinstance(res, dict) for res in out) + assert all("bboxes" in res for res in out) + assert all("labels" in res for res in out) + assert all("scores" in res for res in out) @pytest.mark.parametrize("head", ["faster_rcnn", "retinanet"]) diff --git a/tests/image/detection/test_serialization.py b/tests/image/detection/test_serialization.py index 8f707a229a..fcad6e5fe7 100644 --- a/tests/image/detection/test_serialization.py +++ b/tests/image/detection/test_serialization.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -23,13 +24,18 @@ def test_serialize_fiftyone(): labels_serial = FiftyOneDetectionLabels(labels=labels) sample = { - DefaultDataKeys.PREDS: [ - { - "boxes": [torch.tensor(20), torch.tensor(30), torch.tensor(40), torch.tensor(50)], - "labels": torch.tensor(0), - "scores": torch.tensor(0.5), - }, - ], + DefaultDataKeys.PREDS: { + "bboxes": [ + { + "xmin": torch.tensor(20), + "ymin": torch.tensor(30), + "width": torch.tensor(20), + "height": torch.tensor(20), + } + ], + "labels": [torch.tensor(0)], + "scores": [torch.tensor(0.5)], + }, DefaultDataKeys.METADATA: { "filepath": "something", "size": (100, 100), @@ -38,13 +44,13 @@ def test_serialize_fiftyone(): detections = serial.serialize(sample) assert len(detections.detections) == 1 - assert detections.detections[0].bounding_box == [0.2, 0.3, 0.2, 0.2] + np.testing.assert_array_almost_equal(detections.detections[0].bounding_box, [0.2, 0.3, 0.2, 0.2]) assert detections.detections[0].confidence == 0.5 assert detections.detections[0].label == "0" detections = filepath_serial.serialize(sample) assert len(detections["predictions"].detections) == 1 - assert detections["predictions"].detections[0].bounding_box == [0.2, 0.3, 0.2, 0.2] + np.testing.assert_array_almost_equal(detections["predictions"].detections[0].bounding_box, [0.2, 0.3, 0.2, 0.2]) assert detections["predictions"].detections[0].confidence == 0.5 assert detections["filepath"] == "something" @@ -53,6 +59,6 @@ def test_serialize_fiftyone(): detections = labels_serial.serialize(sample) assert len(detections.detections) == 1 - assert detections.detections[0].bounding_box == [0.2, 0.3, 0.2, 0.2] + np.testing.assert_array_almost_equal(detections.detections[0].bounding_box, [0.2, 0.3, 0.2, 0.2]) assert detections.detections[0].confidence == 0.5 assert detections.detections[0].label == "class_1" diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py index b337fa28da..3fa6248107 100644 --- a/tests/pointcloud/detection/test_data.py +++ b/tests/pointcloud/detection/test_data.py @@ -54,5 +54,5 @@ def training_step(self, batch, batch_idx: int): model.eval() predictions = model.predict([join(predict_path, "scans/000000.bin")]) - assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape[1] == 4 + assert predictions[0][DefaultDataKeys.INPUT].shape[1] == 4 assert len(predictions[0][DefaultDataKeys.PREDS]) == 158 diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py index a4c808fff2..400da2c0c4 100644 --- a/tests/pointcloud/segmentation/test_data.py +++ b/tests/pointcloud/segmentation/test_data.py @@ -51,6 +51,6 @@ def training_step(self, batch, batch_idx: int): trainer.fit(model, dm) predictions = model.predict(join(tmpdir, "SemanticKittiMicro", "predict")) - assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape == torch.Size([45056, 3]) - assert torch.stack(predictions[0][DefaultDataKeys.PREDS]).shape == torch.Size([45056, 19]) - assert torch.stack(predictions[0][DefaultDataKeys.TARGET]).shape == torch.Size([45056]) + assert predictions[0][DefaultDataKeys.INPUT].shape == torch.Size([45056, 3]) + assert predictions[0][DefaultDataKeys.PREDS].shape == torch.Size([45056, 19]) + assert predictions[0][DefaultDataKeys.TARGET].shape == torch.Size([45056])