From 5802dcfb6ebdaf322950fe0a57190c7fc6ab8d4b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 18:56:38 +0100 Subject: [PATCH 01/46] Initial commit --- .gitignore | 1 + flash/core/data/data_module.py | 5 +- flash/core/data/data_pipeline.py | 5 + flash/core/registry.py | 4 +- flash/core/utilities/imports.py | 2 + flash/image/detection/data.py | 186 ++++++++++-------- flash/image/detection/heads.py | 118 ++++++++++++ flash/image/detection/model.py | 283 +++++++++++++++++++--------- flash/image/detection/transforms.py | 47 ++--- flash_examples/object_detection.py | 12 +- requirements/datatype_image.txt | 1 + 11 files changed, 467 insertions(+), 197 deletions(-) create mode 100644 flash/image/detection/heads.py diff --git a/.gitignore b/.gitignore index 48be6f46a7..4636d8cef2 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ jigsaw_toxic_comments flash_examples/serve/tabular_classification/data logs/cache/* flash_examples/data +flash_examples/checkpoints diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 47f309b856..d2db6cd22a 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -363,12 +363,13 @@ def _predict_dataloader(self) -> DataLoader: pin_memory = True if isinstance(getattr(self, "trainer", None), pl.Trainer): - return self.trainer.lightning_module.process_test_dataset( + return self.trainer.lightning_module.process_predict_dataset( predict_ds, batch_size=batch_size, num_workers=self.num_workers, pin_memory=pin_memory, - collate_fn=collate_fn + collate_fn=collate_fn, + convert_to_dataloader=True, ) return DataLoader( diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 2d4a2bf1d7..50552e78ad 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -223,6 +223,7 @@ def _create_collate_preprocessors( prefix: str = _STAGES_PREFIX[stage] if collate_fn is not None: + preprocess._original_default_collate = preprocess._default_collate preprocess._default_collate = collate_fn func_names: Dict[str, str] = { @@ -486,6 +487,10 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin elif isinstance(stage, RunningStage): stages = [stage] + self._preprocess_pipeline._default_collate = getattr( + self._preprocess_pipeline, "_original_default_collate", self._preprocess_pipeline._default_collate + ) + for stage in stages: device_collate = None diff --git a/flash/core/registry.py b/flash/core/registry.py index ff3c99c336..7f41099120 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
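The `_predict_dataloader` change above fixes a real bug: prediction previously routed through the task's `process_test_dataset` hook. It also introduces a `convert_to_dataloader` flag so a task can hand back either a ready `DataLoader` or the bare dataset. Relatedly, `_create_collate_preprocessors` now stashes the original collate function in `_original_default_collate` so that `_detach_preprocessing_from_model` can restore it when the pipeline is detached. A minimal sketch of the new hook contract (`MyTask` and the keyword handling are hypothetical):

    from torch.utils.data import DataLoader

    from flash.core.model import Task


    class MyTask(Task):

        def process_predict_dataset(self, dataset, convert_to_dataloader: bool = True, **kwargs):
            # Return a DataLoader for Trainer.predict, or the bare dataset when the
            # caller wants to handle batching itself (as the IceVision-backed task below does).
            if convert_to_dataloader:
                return DataLoader(dataset, batch_size=kwargs.get("batch_size", 1))
            return dataset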
-from functools import partial -from types import FunctionType from typing import Any, Callable, Dict, List, Optional, Union from pytorch_lightning.utilities import rank_zero_info @@ -76,7 +74,7 @@ def _register_function( override: bool = False, metadata: Optional[Dict[str, Any]] = None ): - if not isinstance(fn, FunctionType) and not isinstance(fn, partial): + if not callable(fn): raise MisconfigurationException(f"You can only register a function, found: {fn}") name = name or fn.__name__ diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 9922f49eba..eda24ef98c 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -89,6 +89,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric") +_ICEVISION_AVAILABLE = _module_available("icevision") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") @@ -103,6 +104,7 @@ def _compare_version(package: str, op, version) -> bool: _KORNIA_AVAILABLE, _PYSTICHE_AVAILABLE, _SEGMENTATION_MODELS_AVAILABLE, + _ICEVISION_AVAILABLE, ]) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index bc378567b6..0e72026a39 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -11,25 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
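The `registry.py` change above relaxes the old `FunctionType`/`partial` check to a plain `callable(fn)`. The IceVision integration needs this: the per-head backbone registries added later in this patch store IceVision `BackboneConfig` instances, which are callable objects rather than functions. A hedged sketch of what now registers cleanly (the registry and its entries are illustrative):

    from flash.core.registry import FlashRegistry


    class BackboneConfigLike:
        """Stand-in for a callable config object such as IceVision's BackboneConfig."""

        def __call__(self, pretrained: bool = True):
            return {"pretrained": pretrained}


    MY_BACKBONES = FlashRegistry("backbones")

    # A callable instance: rejected by the old isinstance checks, accepted by callable(fn).
    MY_BACKBONES(BackboneConfigLike(), name="resnet18_fpn")

    config = MY_BACKBONES.get("resnet18_fpn")(pretrained=False)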
-import os -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TYPE_CHECKING + +import numpy as np from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, FiftyOneDataSource +from flash.core.data.data_source import DefaultDataKeys, FiftyOneDataSource from flash.core.data.process import Preprocess from flash.core.utilities.imports import ( _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, + _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE, lazy_import, - requires, ) from flash.image.data import ImagePathsDataSource from flash.image.detection.transforms import default_transforms if _COCO_AVAILABLE: - from pycocotools.coco import COCO + pass SampleCollection = None if _FIFTYONE_AVAILABLE: @@ -42,75 +43,102 @@ if _TORCHVISION_AVAILABLE: from torchvision.datasets.folder import default_loader +if _ICEVISION_AVAILABLE: + from icevision.core import BaseRecord, ClassMapRecordComponent, ImageRecordComponent, tasks + from icevision.data import SingleSplitSplitter + from icevision.parsers import Parser -class COCODataSource(DataSource[Tuple[str, str]]): - - @requires("pycocotools") - def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - root, ann_file = data - - coco = COCO(ann_file) - - categories = coco.loadCats(coco.getCatIds()) - if categories: - dataset.num_classes = categories[-1]["id"] + 1 - - img_ids = list(sorted(coco.imgs.keys())) - paths = coco.loadImgs(img_ids) - - data = [] - - for img_id, path in zip(img_ids, paths): - path = path["file_name"] - ann_ids = coco.getAnnIds(imgIds=img_id) - annotations = coco.loadAnns(ann_ids) +class IceVisionPathsDataSource(ImagePathsDataSource): - boxes, labels, areas, iscrowd = [], [], [], [] + def __init__(self, parser: Type[Parser]): + self.parser = parser - # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py - if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations): - continue - - for obj in annotations: - xmin = obj["bbox"][0] - ymin = obj["bbox"][1] - xmax = xmin + obj["bbox"][2] - ymax = ymin + obj["bbox"][3] + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root, ann_file = data - bbox = [xmin, ymin, xmax, ymax] - keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) - if keep: - boxes.append(bbox) - labels.append(obj["category_id"]) - areas.append(obj["area"]) - iscrowd.append(obj["iscrowd"]) - - data.append( - dict( - input=os.path.join(root, path), - target=dict( - boxes=boxes, - labels=labels, - image_id=img_id, - area=areas, - iscrowd=iscrowd, - ) - ) - ) - return data + parser = self.parser(ann_file, root) + dataset.num_classes = len(parser.class_map) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + + def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + return super().predict_load_data(data, dataset) + + # coco = COCO(ann_file) + # + # categories = coco.loadCats(coco.getCatIds()) + # if categories: + # dataset.num_classes = categories[-1]["id"] + 1 + # + # img_ids = list(sorted(coco.imgs.keys())) + # paths = coco.loadImgs(img_ids) + # + # data = [] + # + # for img_id, path in zip(img_ids, paths): + # path = 
path["file_name"] + # + # ann_ids = coco.getAnnIds(imgIds=img_id) + # annotations = coco.loadAnns(ann_ids) + # + # boxes, labels, areas, iscrowd = [], [], [], [] + # + # # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py + # if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations): + # continue + # + # for obj in annotations: + # xmin = obj["bbox"][0] + # ymin = obj["bbox"][1] + # xmax = xmin + obj["bbox"][2] + # ymax = ymin + obj["bbox"][3] + # + # bbox = [xmin, ymin, xmax, ymax] + # keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) + # if keep: + # boxes.append(bbox) + # labels.append(obj["category_id"]) + # areas.append(obj["area"]) + # iscrowd.append(obj["iscrowd"]) + # + # data.append( + # dict( + # input=os.path.join(root, path), + # target=dict( + # boxes=boxes, + # labels=labels, + # image_id=img_id, + # area=areas, + # iscrowd=iscrowd, + # ) + # ) + # ) + # return data def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - filepath = sample[DefaultDataKeys.INPUT] - img = default_loader(filepath) - sample[DefaultDataKeys.INPUT] = img - w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { - "filepath": filepath, - "size": (h, w), - } - return sample - return sample + # TODO: get image size for metadata + # sample[DefaultDataKeys.INPUT] = sample[DefaultDataKeys.INPUT].load() + return sample[DefaultDataKeys.INPUT].load() + # filepath = sample[DefaultDataKeys.INPUT] + # img = default_loader(filepath) + # sample[DefaultDataKeys.INPUT] = img + # w, h = img.size # WxH + # sample[DefaultDataKeys.METADATA] = { + # "filepath": filepath, + # "size": (h, w), + # } + # return sample + # return sample + + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample = super().load_sample(sample) + image = np.array(sample[DefaultDataKeys.INPUT]) + record = BaseRecord([ImageRecordComponent()]) + # record.set_record_id(i) + record.set_img(image) + record.add_component(ClassMapRecordComponent(task=tasks.detection)) + return record class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource): @@ -205,22 +233,27 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), **data_source_kwargs: Any, ): + self.image_size = image_size + super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), - DefaultDataSources.FILES: ImagePathsDataSource(), - DefaultDataSources.FOLDERS: ImagePathsDataSource(), - "coco": COCODataSource(), + # DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), + # DefaultDataSources.FILES: ObjectDetectionPathsDataSource(), + # DefaultDataSources.FOLDERS: ObjectDetectionPathsDataSource(), + # "coco": COCODataSource(), }, - default_data_source=DefaultDataSources.FILES, + default_data_source="coco", ) + self._default_collate = self._identity + def get_state_dict(self) -> Dict[str, Any]: return {**self.transforms} @@ -229,7 +262,10 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) def default_transforms(self) -> Optional[Dict[str, Callable]]: - return default_transforms() + return default_transforms(self.image_size) + + def 
train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) class ObjectDetectionData(DataModule): @@ -237,7 +273,6 @@ class ObjectDetectionData(DataModule): preprocess_cls = ObjectDetectionPreprocess @classmethod - @requires("pycocotools") def from_coco( cls, train_folder: Optional[str] = None, @@ -246,9 +281,11 @@ def from_coco( val_ann_file: Optional[str] = None, test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -298,6 +335,7 @@ def from_coco( (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, (test_folder, test_ann_file) if test_folder else None, + predict_folder, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, diff --git a/flash/image/detection/heads.py b/flash/image/detection/heads.py new file mode 100644 index 0000000000..57cd9a3ef0 --- /dev/null +++ b/flash/image/detection/heads.py @@ -0,0 +1,118 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
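With the data changes above, COCO detection data is loaded through an IceVision parser and the preprocess collates with the identity (IceVision's dataloaders do their own batching). A usage sketch mirroring the updated example script further down (paths hypothetical):

    from flash.image import ObjectDetectionData

    datamodule = ObjectDetectionData.from_coco(
        train_folder="data/coco128/images/train2017/",
        train_ann_file="data/coco128/annotations/instances_train2017.json",
        val_split=0.1,
        image_size=(256, 256),
    )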
+from functools import partial +from inspect import getmembers + +import torch +from torch import nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + from icevision.backbones import BackboneConfig + +OBJECT_DETECTION_HEADS = FlashRegistry("heads") + +if _ICEVISION_AVAILABLE: + + def _icevision_model_adapter(model_type): + + class IceVisionModelAdapter(model_type.lightning.ModelAdapter): + + def log(self, name, value, **kwargs): + if "prog_bar" not in kwargs: + kwargs["prog_bar"] = True + return super().log(name, value, **kwargs) + + return IceVisionModelAdapter + + def _load_icevision(adapter, model_type, backbone, num_classes, **kwargs): + model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs) + + backbone = nn.Module() + params = model.param_groups()[0] + for i, param in enumerate(params): + backbone.register_parameter(f"backbone_{i}", param) + + return model_type, model, adapter(model_type), backbone + + def _load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + return _load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + def _load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + kwargs["img_size"] = image_size + return _load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + def _get_backbones(model_type): + _BACKBONES = FlashRegistry("backbones") + + for backbone_name, backbone_config in getmembers(model_type.backbones, lambda x: isinstance(x, BackboneConfig)): + _BACKBONES( + backbone_config, + name=backbone_name, + ) + return _BACKBONES + + if _TORCHVISION_AVAILABLE: + for model_type in [icevision_models.torchvision.retinanet, icevision_models.torchvision.faster_rcnn]: + OBJECT_DETECTION_HEADS( + partial(_load_icevision_ignore_image_size, _icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=_get_backbones(model_type), + ) + + if _module_available("yolov5"): + model_type = icevision_models.ultralytics.yolov5 + OBJECT_DETECTION_HEADS( + partial(_load_icevision_with_image_size, _icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=_get_backbones(model_type), + ) + + if _module_available("mmdet"): + for model_type in [ + icevision_models.mmdet.faster_rcnn, + icevision_models.mmdet.retinanet, + icevision_models.mmdet.fcos, + icevision_models.mmdet.sparse_rcnn, + ]: + OBJECT_DETECTION_HEADS( + partial(_load_icevision_ignore_image_size, _icevision_model_adapter, model_type), + f"mmdet_{model_type.__name__.split('.')[-1]}", + backbones=_get_backbones(model_type), + ) + + if _module_available("effdet"): + + def _icevision_effdet_model_adapter(model_type): + + class IceVisionEffdetModelAdapter(_icevision_model_adapter(model_type)): + + def validation_step(self, batch, batch_idx): + images = batch[0][0] + batch[0][1]["img_scale"] = torch.ones_like(images[:, 0, 0, 0]).unsqueeze(1) + batch[0][1]["img_size"] = (torch.ones_like(images[:, 0, 0, 0]) * + images[0].shape[-1]).unsqueeze(1).repeat(1, 2) + return super().validation_step(batch, batch_idx) + + return IceVisionEffdetModelAdapter + + model_type = icevision_models.ross.efficientdet + OBJECT_DETECTION_HEADS( + partial(_load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), + 
model_type.__name__.split(".")[-1], + backbones=_get_backbones(model_type), + ) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 41edea48ee..7dbb615d76 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -16,20 +16,18 @@ import torch from torch import nn, tensor from torch.optim import Optimizer +from torch.utils.data import DataLoader, Sampler -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.process import Serializer from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE -from flash.image.backbones import OBJ_DETECTION_BACKBONES -from flash.image.detection.finetuning import ObjectDetectionFineTuning +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.image.detection.heads import OBJECT_DETECTION_HEADS from flash.image.detection.serialization import DetectionLabels if _TORCHVISION_AVAILABLE: import torchvision - from torchvision.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor - from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead from torchvision.models.detection.rpn import AnchorGenerator from torchvision.ops import box_iou @@ -41,6 +39,11 @@ else: AnchorGenerator = None +if _ICEVISION_AVAILABLE: + from icevision.core import BaseRecord + from icevision.data import Dataset + from icevision.metrics import COCOMetric, COCOMetricType + def _evaluate_iou(target, pred): """ @@ -76,16 +79,17 @@ class ObjectDetector(Task): """ - backbones: FlashRegistry = OBJ_DETECTION_BACKBONES + # backbones: FlashRegistry = OBJ_DETECTION_BACKBONES + + heads: FlashRegistry = OBJECT_DETECTION_HEADS required_extras: str = "image" def __init__( self, num_classes: int, - model: str = "fasterrcnn", - backbone: Optional[str] = None, - fpn: bool = True, + backbone: Optional[str] = "resnet18_fpn", + head: Optional[str] = "retinanet", pretrained: bool = True, pretrained_backbone: bool = True, trainable_backbone_layers: int = 3, @@ -99,108 +103,201 @@ def __init__( ): self.save_hyperparameters() - if model in _models: - model = ObjectDetector.get_model( - model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, - anchor_generator, **kwargs - ) - else: - ValueError(f"{model} is not supported yet.") + # if model in _models: + # model = ObjectDetector.get_model( + # model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, + # anchor_generator, **kwargs + # ) + # else: + # ValueError(f"{model} is not supported yet.") super().__init__( - model=model, + model=None, loss_fn=loss, - metrics=metrics, + metrics=None, learning_rate=learning_rate, optimizer=optimizer, serializer=serializer or DetectionLabels(), ) - @staticmethod - def get_model( - model_name, - num_classes, - backbone, - fpn, - pretrained, - pretrained_backbone, - trainable_backbone_layers, - anchor_generator, - **kwargs, - ): - if backbone is None: - # Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. 
- if model_name == "fasterrcnn": - model = _models[model_name]( - pretrained=pretrained, - pretrained_backbone=pretrained_backbone, - trainable_backbone_layers=trainable_backbone_layers, - ) - in_features = model.roi_heads.box_predictor.cls_score.in_features - head = FastRCNNPredictor(in_features, num_classes) - model.roi_heads.box_predictor = head - else: - model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone) - model.head = RetinaNetHead( - in_channels=model.backbone.out_channels, - num_anchors=model.head.classification_head.num_anchors, - num_classes=num_classes, - **kwargs - ) - else: - backbone_model, num_features = ObjectDetector.backbones.get(backbone)( - pretrained=pretrained_backbone, - trainable_layers=trainable_backbone_layers, - **kwargs, - ) - backbone_model.out_channels = num_features - if anchor_generator is None: - anchor_generator = AnchorGenerator( - sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), ) - ) if not hasattr(backbone_model, "fpn") else None + metadata = self.heads.get(head, with_metadata=True) + backbones = metadata["metadata"]["backbones"] + backbone_config = backbones.get(backbone)(pretrained) + self.model_type, self.model, adapter, self.backbone = metadata["fn"](backbone_config, num_classes, **kwargs) + self.adapter = adapter(model=self.model, metrics=metrics or [COCOMetric(metric_type=COCOMetricType.bbox)]) - if model_name == "fasterrcnn": - model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator) - else: - model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator) - return model + @classmethod + def available_backbones(cls, head: str) -> List[str]: + metadata = cls.heads.get(head, with_metadata=True) + backbones = metadata["metadata"]["backbones"] + return backbones.available_keys() - def forward(self, x: List[torch.Tensor]) -> Any: - return self.model(x) + # @staticmethod + # def get_model( + # model_name, + # num_classes, + # backbone, + # fpn, + # pretrained, + # pretrained_backbone, + # trainable_backbone_layers, + # anchor_generator, + # **kwargs, + # ): + # if backbone is None: + # # Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. 
+ # if model_name == "fasterrcnn": + # model = _models[model_name]( + # pretrained=pretrained, + # pretrained_backbone=pretrained_backbone, + # trainable_backbone_layers=trainable_backbone_layers, + # ) + # in_features = model.roi_heads.box_predictor.cls_score.in_features + # head = FastRCNNPredictor(in_features, num_classes) + # model.roi_heads.box_predictor = head + # else: + # model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone) + # model.head = RetinaNetHead( + # in_channels=model.backbone.out_channels, + # num_anchors=model.head.classification_head.num_anchors, + # num_classes=num_classes, + # **kwargs + # ) + # else: + # backbone_model, num_features = ObjectDetector.backbones.get(backbone)( + # pretrained=pretrained_backbone, + # trainable_layers=trainable_backbone_layers, + # **kwargs, + # ) + # backbone_model.out_channels = num_features + # if anchor_generator is None: + # anchor_generator = AnchorGenerator( + # sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), ) + # ) if not hasattr(backbone_model, "fpn") else None + # + # if model_name == "fasterrcnn": + # model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator) + # else: + # model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator) + # return model - def training_step(self, batch, batch_idx) -> Any: - """The training step. Overrides ``Task.training_step`` - """ - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - targets = [dict(t.items()) for t in targets] + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: + checkpoint['_data_pipeline_state'] = self._data_pipeline_state + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self.model_type.train_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) - # fasterrcnn takes both images and targets for training, returns loss_dict - loss_dict = self.model(images, targets) - loss = sum(loss_dict.values()) - self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, prog_bar=True) - return loss + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + 
pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + convert_to_dataloader: bool = True + ) -> Union[DataLoader, BaseAutoDataset]: + if convert_to_dataloader: + return self.model_type.infer_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + return dataset + + def training_step(self, batch, batch_idx) -> Any: + return self.adapter.training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - # fasterrcnn takes only images for eval() mode - outs = self(images) - iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() - self.log("val_iou", iou) + return self.adapter.validation_step(batch, batch_idx) def test_step(self, batch, batch_idx): - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - # fasterrcnn takes only images for eval() mode - outs = self(images) - iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() - self.log("test_iou", iou) + return self.adapter.validation_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - images = batch[DefaultDataKeys.INPUT] - batch[DefaultDataKeys.PREDS] = self(images) - return batch + if isinstance(batch, list) and isinstance(batch[0], BaseRecord): + data = Dataset(batch) + return self.model_type.predict(self.model, data) + return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) + + def training_epoch_end(self, outputs) -> None: + return self.adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) - def configure_finetune_callback(self): - return [ObjectDetectionFineTuning(train_bn=True)] + # def configure_finetune_callback(self): + # return [ObjectDetectionFineTuning(train_bn=True)] def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """ diff --git a/flash/image/detection/transforms.py b/flash/image/detection/transforms.py index 1f54854376..e70dc9118d 100644 --- a/flash/image/detection/transforms.py +++ b/flash/image/detection/transforms.py @@ -11,37 +11,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
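A note on the EffDet adapter registered in `heads.py` above: `effdet` models expect `img_scale` and `img_size` entries in the validation batch, which the torchvision-style batches do not carry, so the adapter injects them. The tensor expressions reduce to the following (batch shape assumed for illustration):

    import torch

    images = torch.rand(4, 3, 256, 256)  # (N, C, H, W) validation images
    # One scale factor of 1.0 per image -> shape (N, 1)
    img_scale = torch.ones_like(images[:, 0, 0, 0]).unsqueeze(1)
    # The (square) image size as (height, width) per image -> shape (N, 2), each row (256., 256.)
    img_size = (torch.ones_like(images[:, 0, 0, 0]) * images[0].shape[-1]).unsqueeze(1).repeat(1, 2)
    assert img_scale.shape == (4, 1) and img_size.shape == (4, 2)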
-from typing import Any, Callable, Dict, Sequence +from typing import Any, Callable, Dict, Sequence, Tuple -import torch -from torch import nn - -from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE, requires_extras if _TORCHVISION_AVAILABLE: - import torchvision + pass + +if _ICEVISION_AVAILABLE: + from icevision.tfms import A def collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]: return {key: [sample[key] for sample in samples] for key in samples[0]} -def default_transforms() -> Dict[str, Callable]: +@requires_extras("image") +def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + """The default transforms for object detection: convert the image and targets to a tensor, collate the batch.""" + return { + # "pre_tensor_transform": ApplyToKeys( + # DefaultDataKeys.INPUT, + # tfms.A.Adapter([*tfms.A.resize_and_pad(image_size), tfms.A.Normalize()]), + # ) + "pre_tensor_transform": A.Adapter([*A.resize_and_pad(image_size), A.Normalize()]), + } + + +@requires_extras("image") +def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """The default transforms for object detection: convert the image and targets to a tensor, collate the batch.""" return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys( - 'target', - nn.Sequential( - ApplyToKeys('boxes', torch.as_tensor), - ApplyToKeys('labels', torch.as_tensor), - ApplyToKeys('image_id', torch.as_tensor), - ApplyToKeys('area', torch.as_tensor), - ApplyToKeys('iscrowd', torch.as_tensor), - ) - ), - ), - "collate": collate, + # "pre_tensor_transform": ApplyToKeys( + # DefaultDataKeys.INPUT, + # tfms.A.Adapter([*tfms.A.resize_and_pad(image_size), tfms.A.Normalize()]), + # ) + "pre_tensor_transform": A.Adapter([*A.aug_tfms(size=image_size), A.Normalize()]), } diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 118bdc5c67..c8d58ef133 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -23,16 +23,22 @@ train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", val_split=0.1, + predict_folder="data/coco128/images/train2017/", + image_size=(256, 256), ) # 2. Build the task -model = ObjectDetector(model="retinanet", num_classes=datamodule.num_classes) +print(ObjectDetector.available_heads()) +print(ObjectDetector.available_backbones("efficientdet")) + +model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=(256, 256)) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) -trainer.finetune(model, datamodule=datamodule) +trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! 
+print(trainer.predict(model, datamodule=datamodule)) predictions = model.predict([ "data/coco128/images/train2017/000000000625.jpg", "data/coco128/images/train2017/000000000626.jpg", diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index d39ad59395..f289b3c3a7 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -5,3 +5,4 @@ Pillow>=7.2 kornia>=0.5.1,<0.5.4 pystiche>=0.7.2 segmentation-models-pytorch +icevision>=0.8 From 35c04659acfac486092348a7816bf3acca745ed3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 20 Jul 2021 14:35:58 +0100 Subject: [PATCH 02/46] Add instance segmentation and keypoint detection tasks --- flash/core/integrations/icevision/__init__.py | 0 .../core/integrations/icevision/backbones.py | 67 ++++ flash/core/integrations/icevision/data.py | 79 +++++ flash/core/integrations/icevision/model.py | 222 +++++++++++++ .../integrations/icevision}/transforms.py | 0 flash/core/utilities/imports.py | 4 + flash/image/__init__.py | 4 +- flash/image/backbones.py | 18 -- .../detection/{heads.py => backbones.py} | 66 +--- flash/image/detection/data.py | 299 +++++++++++------- flash/image/detection/finetuning.py | 29 -- flash/image/detection/model.py | 249 +-------------- flash/image/instance_segmentation/__init__.py | 2 + .../image/instance_segmentation/backbones.py | 44 +++ flash/image/instance_segmentation/data.py | 235 ++++++++++++++ flash/image/instance_segmentation/model.py | 91 ++++++ flash/image/keypoint_detection/__init__.py | 2 + flash/image/keypoint_detection/backbones.py | 36 +++ flash/image/keypoint_detection/data.py | 155 +++++++++ flash/image/keypoint_detection/model.py | 92 ++++++ flash_examples/graph_classification.py | 9 +- flash_examples/instance_segmentation.py | 54 ++++ flash_examples/keypoint_detection.py | 53 ++++ flash_examples/object_detection.py | 11 +- requirements/datatype_image.txt | 2 + 25 files changed, 1361 insertions(+), 462 deletions(-) create mode 100644 flash/core/integrations/icevision/__init__.py create mode 100644 flash/core/integrations/icevision/backbones.py create mode 100644 flash/core/integrations/icevision/data.py create mode 100644 flash/core/integrations/icevision/model.py rename flash/{image/detection => core/integrations/icevision}/transforms.py (100%) rename flash/image/detection/{heads.py => backbones.py} (51%) delete mode 100644 flash/image/detection/finetuning.py create mode 100644 flash/image/instance_segmentation/__init__.py create mode 100644 flash/image/instance_segmentation/backbones.py create mode 100644 flash/image/instance_segmentation/data.py create mode 100644 flash/image/instance_segmentation/model.py create mode 100644 flash/image/keypoint_detection/__init__.py create mode 100644 flash/image/keypoint_detection/backbones.py create mode 100644 flash/image/keypoint_detection/data.py create mode 100644 flash/image/keypoint_detection/model.py create mode 100644 flash_examples/instance_segmentation.py create mode 100644 flash_examples/keypoint_detection.py diff --git a/flash/core/integrations/icevision/__init__.py b/flash/core/integrations/icevision/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/core/integrations/icevision/backbones.py b/flash/core/integrations/icevision/backbones.py new file mode 100644 index 0000000000..82225d8eb9 --- /dev/null +++ b/flash/core/integrations/icevision/backbones.py @@ -0,0 +1,67 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from inspect import getmembers + +from torch import nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.backbones import BackboneConfig + +OBJECT_DETECTION_HEADS = FlashRegistry("heads") + + +def icevision_model_adapter(model_type): + + class IceVisionModelAdapter(model_type.lightning.ModelAdapter): + + def log(self, name, value, **kwargs): + if "prog_bar" not in kwargs: + kwargs["prog_bar"] = True + return super().log(name.split("/")[-1], value, **kwargs) + + return IceVisionModelAdapter + + +def load_icevision(adapter, model_type, backbone, num_classes, **kwargs): + model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs) + + backbone = nn.Module() + params = model.param_groups()[0] + for i, param in enumerate(params): + backbone.register_parameter(f"backbone_{i}", param) + + return model_type, model, adapter(model_type), backbone + + +def load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + +def load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + kwargs["img_size"] = image_size + return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + +def get_backbones(model_type): + _BACKBONES = FlashRegistry("backbones") + + for backbone_name, backbone_config in getmembers(model_type.backbones, lambda x: isinstance(x, BackboneConfig)): + _BACKBONES( + backbone_config, + name=backbone_name, + ) + return _BACKBONES diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py new file mode 100644 index 0000000000..6375ddc360 --- /dev/null +++ b/flash/core/integrations/icevision/data.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
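To see how the helpers in `backbones.py` above compose, here is a hedged sketch of resolving a torchvision RetinaNet through them (the `resnet18_fpn` key and the `pretrained` keyword are assumptions about the installed IceVision version):

    from icevision import models as icevision_models

    from flash.core.integrations.icevision.backbones import (
        get_backbones,
        icevision_model_adapter,
        load_icevision_ignore_image_size,
    )

    model_type = icevision_models.torchvision.retinanet
    backbones = get_backbones(model_type)  # a FlashRegistry of BackboneConfig entries
    backbone_config = backbones.get("resnet18_fpn")(pretrained=True)

    # Returns the model type, the built model, the Lightning adapter class, and a stub
    # nn.Module collecting the backbone parameters (so e.g. strategy="freeze" can target them).
    model_type, model, adapter_cls, backbone = load_icevision_ignore_image_size(
        icevision_model_adapter, model_type, backbone_config, num_classes=91
    )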
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type + +import numpy as np + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.image.data import ImagePathsDataSource + +if _ICEVISION_AVAILABLE: + from icevision.core import BaseRecord, ClassMapRecordComponent, ImageRecordComponent, tasks + from icevision.data import SingleSplitSplitter + from icevision.parsers import Parser + + +class IceVisionPathsDataSource(ImagePathsDataSource): + + def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + return super().predict_load_data(data, dataset) + + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + return sample[DefaultDataKeys.INPUT].load() + + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample = super().load_sample(sample) + image = np.array(sample[DefaultDataKeys.INPUT]) + record = BaseRecord([ImageRecordComponent()]) + + record.set_img(image) + record.add_component(ClassMapRecordComponent(task=tasks.detection)) + return record + + +class IceVisionParserDataSource(IceVisionPathsDataSource): + + def __init__(self, parser: Optional[Type[Parser]] = None): + super().__init__() + self.parser = parser + + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root, ann_file = data + + if self.parser is not None: + parser = self.parser(ann_file, root) + dataset.num_classes = len(parser.class_map) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + else: + raise ValueError("The parser type must be provided") + + +class IceDataParserDataSource(IceVisionPathsDataSource): + + def __init__(self, parser: Optional[Callable] = None): + super().__init__() + self.parser = parser + + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root = data + + if self.parser is not None: + parser = self.parser(root) + dataset.num_classes = len(parser.class_map) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + else: + raise ValueError("The parser must be provided") diff --git a/flash/core/integrations/icevision/model.py b/flash/core/integrations/icevision/model.py new file mode 100644 index 0000000000..ad8a80ab0e --- /dev/null +++ b/flash/core/integrations/icevision/model.py @@ -0,0 +1,222 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
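The parser-backed data sources above reduce each split to a list of `{DefaultDataKeys.INPUT: record}` dicts and record the class count on the dataset. Roughly, for the COCO case (paths hypothetical, mirroring `load_data`):

    from icevision.data import SingleSplitSplitter
    from icevision.parsers import COCOBBoxParser

    parser = COCOBBoxParser("annotations.json", "images/")  # (ann_file, root), as in load_data
    records = parser.parse(data_splitter=SingleSplitSplitter())  # a single split: [list_of_records]
    num_classes = len(parser.class_map)
    samples = [{"input": record} for record in records[0]]  # DefaultDataKeys.INPUT == "input"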
+from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader, Sampler
+
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.process import Deserializer, Postprocess, Preprocess, Serializer
+from flash.core.model import Task
+from flash.core.utilities.imports import _ICEVISION_AVAILABLE
+
+if _ICEVISION_AVAILABLE:
+    from icevision.core import BaseRecord
+    from icevision.data import Dataset
+    from icevision.metrics import COCOMetric
+    from icevision.metrics import Metric as IceVisionMetric
+
+
+class SimpleCOCOMetric(COCOMetric):
+
+    def finalize(self) -> Dict[str, float]:
+        logs = super().finalize()
+        return {
+            "Precision (IoU=0.50:0.95,area=all)": logs["AP (IoU=0.50:0.95) area=all"],
+            "Recall (IoU=0.50:0.95,area=all,maxDets=100)": logs["AR (IoU=0.50:0.95) area=all maxDets=100"],
+        }
+
+
+class IceVisionTask(Task):
+    """The ``IceVisionTask`` is a base :class:`~flash.Task` for integrating with IceVision.
+
+    Args:
+        num_classes: The number of classes for detection, including the background class.
+        backbone: The name of the backbone (from the head's ``backbones`` registry) to build the model with.
+        head: The name of the IceVision model type to use as the head.
+        pretrained: If ``True``, the selected backbone uses pre-trained weights.
+        metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
+        image_size: The image size to pass through to heads that require one.
+    """
+
+    required_extras: str = "image"
+
+    def __init__(
+        self,
+        num_classes: int,
+        backbone: str,
+        head: str,
+        pretrained: bool = True,
+        optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
+        metrics: Optional[IceVisionMetric] = None,
+        learning_rate: float = 5e-4,
+        deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None,
+        preprocess: Optional[Preprocess] = None,
+        postprocess: Optional[Postprocess] = None,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+        image_size: Optional = None,
+        **kwargs,
+    ):
+        self.save_hyperparameters()
+
+        super().__init__(
+            model=None,
+            metrics=None,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
+            learning_rate=learning_rate,
+            deserializer=deserializer,
+            preprocess=preprocess,
+            postprocess=postprocess,
+            serializer=serializer,
+        )
+
+        metadata = self.heads.get(head, with_metadata=True)
+        backbones = metadata["metadata"]["backbones"]
+        backbone_config = backbones.get(backbone)(pretrained)
+        self.model_type, self.model, adapter, self.backbone = metadata["fn"](
+            backbone_config,
+            num_classes,
+            image_size=image_size,
+            **kwargs,
+        )
+        self.adapter = adapter(model=self.model, metrics=metrics)
+
+    @classmethod
+    def available_backbones(cls, head: str) -> List[str]:
+        metadata = cls.heads.get(head, with_metadata=True)
+        backbones = metadata["metadata"]["backbones"]
+        return backbones.available_keys()
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint:
+            checkpoint['_data_pipeline_state'] = self._data_pipeline_state
+
+    def process_train_dataset(
+        self,
+        dataset: BaseAutoDataset,
+        batch_size: 
int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self.model_type.train_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + convert_to_dataloader: bool = True + ) -> Union[DataLoader, BaseAutoDataset]: + if convert_to_dataloader: + return self.model_type.infer_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + return dataset + + def training_step(self, batch, batch_idx) -> Any: + return self.adapter.training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.adapter.validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.adapter.validation_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + if isinstance(batch, list) and isinstance(batch[0], BaseRecord): + data = Dataset(batch) + return self.model_type.predict(self.model, data) + return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) + + def training_epoch_end(self, outputs) -> None: + return self.adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) diff --git a/flash/image/detection/transforms.py b/flash/core/integrations/icevision/transforms.py similarity index 100% rename from flash/image/detection/transforms.py rename to flash/core/integrations/icevision/transforms.py diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 654ae3a165..f2c50c123c 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -163,6 +163,10 @@ def requires_extras(extras: Union[str, List]): ) +def example_requires(extras: Union[str, List[str]]): + return requires_extras(extras)(lambda: None)() + + def lazy_import(module_name, callback=None): """Returns a proxy module object that will lazily import the given module the first time it is used. 
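`example_requires` above is a small convenience: `requires_extras(extras)` returns a decorator that wraps a function to raise when the extra's dependencies are missing, so decorating a no-op and calling it immediately makes an example script fail fast with the standard install hint:

    from flash.core.utilities.imports import example_requires

    example_requires("image")  # raises with a `pip install 'lightning-flash[image]'`-style message if missing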
diff --git a/flash/image/__init__.py b/flash/image/__init__.py index c099e1c086..6e500cde4c 100644 --- a/flash/image/__init__.py +++ b/flash/image/__init__.py @@ -1,4 +1,4 @@ -from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES # noqa: F401 +from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401 from flash.image.classification import ( # noqa: F401 ImageClassificationData, ImageClassificationPreprocess, @@ -6,6 +6,8 @@ ) from flash.image.detection import ObjectDetectionData, ObjectDetector # noqa: F401 from flash.image.embedding import ImageEmbedder # noqa: F401 +from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData # noqa: F401 +from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector # noqa: F401 from flash.image.segmentation import ( # noqa: F401 SemanticSegmentation, SemanticSegmentationData, diff --git a/flash/image/backbones.py b/flash/image/backbones.py index 267f4f8018..fc32dfa09c 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -29,7 +29,6 @@ if _TORCHVISION_AVAILABLE: import torchvision - from torchvision.models.detection.backbone_utils import resnet_fpn_backbone MOBILENET_MODELS = ["mobilenet_v2"] VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"] @@ -38,7 +37,6 @@ TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") -OBJ_DETECTION_BACKBONES = FlashRegistry("backbones") def catch_url_error(fn): @@ -127,15 +125,6 @@ def _fn_resnet(model_name: str, return backbone, num_features - def _fn_resnet_fpn( - model_name: str, - pretrained: bool = True, - trainable_layers: bool = True, - **kwargs, - ) -> Tuple[nn.Module, int]: - backbone = resnet_fpn_backbone(model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs) - return backbone, 256 - for model_name in RESNET_MODELS: clf_kwargs = dict( fn=catch_url_error(partial(_fn_resnet, model_name=model_name)), @@ -158,13 +147,6 @@ def _fn_resnet_fpn( ) IMAGE_CLASSIFIER_BACKBONES(**clf_kwargs) - OBJ_DETECTION_BACKBONES( - fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), - name=model_name, - package="torchvision", - type="resnet-fpn" - ) - def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) diff --git a/flash/image/detection/heads.py b/flash/image/detection/backbones.py similarity index 51% rename from flash/image/detection/heads.py rename to flash/image/detection/backbones.py index 57cd9a3ef0..3d8e64dad1 100644 --- a/flash/image/detection/heads.py +++ b/flash/image/detection/backbones.py @@ -12,74 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from inspect import getmembers import torch -from torch import nn +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, + load_icevision_with_image_size, +) from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE if _ICEVISION_AVAILABLE: from icevision import models as icevision_models - from icevision.backbones import BackboneConfig OBJECT_DETECTION_HEADS = FlashRegistry("heads") if _ICEVISION_AVAILABLE: - - def _icevision_model_adapter(model_type): - - class IceVisionModelAdapter(model_type.lightning.ModelAdapter): - - def log(self, name, value, **kwargs): - if "prog_bar" not in kwargs: - kwargs["prog_bar"] = True - return super().log(name, value, **kwargs) - - return IceVisionModelAdapter - - def _load_icevision(adapter, model_type, backbone, num_classes, **kwargs): - model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs) - - backbone = nn.Module() - params = model.param_groups()[0] - for i, param in enumerate(params): - backbone.register_parameter(f"backbone_{i}", param) - - return model_type, model, adapter(model_type), backbone - - def _load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): - return _load_icevision(adapter, model_type, backbone, num_classes, **kwargs) - - def _load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): - kwargs["img_size"] = image_size - return _load_icevision(adapter, model_type, backbone, num_classes, **kwargs) - - def _get_backbones(model_type): - _BACKBONES = FlashRegistry("backbones") - - for backbone_name, backbone_config in getmembers(model_type.backbones, lambda x: isinstance(x, BackboneConfig)): - _BACKBONES( - backbone_config, - name=backbone_name, - ) - return _BACKBONES - if _TORCHVISION_AVAILABLE: for model_type in [icevision_models.torchvision.retinanet, icevision_models.torchvision.faster_rcnn]: OBJECT_DETECTION_HEADS( - partial(_load_icevision_ignore_image_size, _icevision_model_adapter, model_type), + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], - backbones=_get_backbones(model_type), + backbones=get_backbones(model_type), ) if _module_available("yolov5"): model_type = icevision_models.ultralytics.yolov5 OBJECT_DETECTION_HEADS( - partial(_load_icevision_with_image_size, _icevision_model_adapter, model_type), + partial(load_icevision_with_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], - backbones=_get_backbones(model_type), + backbones=get_backbones(model_type), ) if _module_available("mmdet"): @@ -90,16 +54,16 @@ def _get_backbones(model_type): icevision_models.mmdet.sparse_rcnn, ]: OBJECT_DETECTION_HEADS( - partial(_load_icevision_ignore_image_size, _icevision_model_adapter, model_type), + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", - backbones=_get_backbones(model_type), + backbones=get_backbones(model_type), ) if _module_available("effdet"): def _icevision_effdet_model_adapter(model_type): - class IceVisionEffdetModelAdapter(_icevision_model_adapter(model_type)): + class IceVisionEffdetModelAdapter(icevision_model_adapter(model_type)): def validation_step(self, batch, batch_idx): images = batch[0][0] @@ -112,7 +76,7 @@ 
def validation_step(self, batch, batch_idx): model_type = icevision_models.ross.efficientdet OBJECT_DETECTION_HEADS( - partial(_load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), + partial(load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), model_type.__name__.split(".")[-1], - backbones=_get_backbones(model_type), + backbones=get_backbones(model_type), ) diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 0e72026a39..71a0df7613 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -11,26 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TYPE_CHECKING - -import numpy as np +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, FiftyOneDataSource +from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource from flash.core.data.process import Preprocess -from flash.core.utilities.imports import ( - _COCO_AVAILABLE, - _FIFTYONE_AVAILABLE, - _ICEVISION_AVAILABLE, - _TORCHVISION_AVAILABLE, - lazy_import, +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, ) -from flash.image.data import ImagePathsDataSource -from flash.image.detection.transforms import default_transforms - -if _COCO_AVAILABLE: - pass +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE, lazy_import SampleCollection = None if _FIFTYONE_AVAILABLE: @@ -44,101 +37,7 @@ from torchvision.datasets.folder import default_loader if _ICEVISION_AVAILABLE: - from icevision.core import BaseRecord, ClassMapRecordComponent, ImageRecordComponent, tasks - from icevision.data import SingleSplitSplitter - from icevision.parsers import Parser - - -class IceVisionPathsDataSource(ImagePathsDataSource): - - def __init__(self, parser: Type[Parser]): - self.parser = parser - - def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - root, ann_file = data - - parser = self.parser(ann_file, root) - dataset.num_classes = len(parser.class_map) - records = parser.parse(data_splitter=SingleSplitSplitter()) - return [{DefaultDataKeys.INPUT: record} for record in records[0]] - - def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - return super().predict_load_data(data, dataset) - - # coco = COCO(ann_file) - # - # categories = coco.loadCats(coco.getCatIds()) - # if categories: - # dataset.num_classes = categories[-1]["id"] + 1 - # - # img_ids = list(sorted(coco.imgs.keys())) - # paths = coco.loadImgs(img_ids) - # - # data = [] - # - # for img_id, path in zip(img_ids, paths): - # path = path["file_name"] - # - # ann_ids = coco.getAnnIds(imgIds=img_id) - # annotations = coco.loadAnns(ann_ids) - # - # boxes, labels, areas, iscrowd = [], [], [], [] - # - # # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py - # if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) 
for obj in annotations): - # continue - # - # for obj in annotations: - # xmin = obj["bbox"][0] - # ymin = obj["bbox"][1] - # xmax = xmin + obj["bbox"][2] - # ymax = ymin + obj["bbox"][3] - # - # bbox = [xmin, ymin, xmax, ymax] - # keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) - # if keep: - # boxes.append(bbox) - # labels.append(obj["category_id"]) - # areas.append(obj["area"]) - # iscrowd.append(obj["iscrowd"]) - # - # data.append( - # dict( - # input=os.path.join(root, path), - # target=dict( - # boxes=boxes, - # labels=labels, - # image_id=img_id, - # area=areas, - # iscrowd=iscrowd, - # ) - # ) - # ) - # return data - - def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - # TODO: get image size for metadata - # sample[DefaultDataKeys.INPUT] = sample[DefaultDataKeys.INPUT].load() - return sample[DefaultDataKeys.INPUT].load() - # filepath = sample[DefaultDataKeys.INPUT] - # img = default_loader(filepath) - # sample[DefaultDataKeys.INPUT] = img - # w, h = img.size # WxH - # sample[DefaultDataKeys.METADATA] = { - # "filepath": filepath, - # "size": (h, w), - # } - # return sample - # return sample - - def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - sample = super().load_sample(sample) - image = np.array(sample[DefaultDataKeys.INPUT]) - record = BaseRecord([ImageRecordComponent()]) - # record.set_record_id(i) - record.set_img(image) - record.add_component(ClassMapRecordComponent(task=tasks.detection)) - return record + from icevision.parsers import COCOBBoxParser, VIABBoxParser, VOCBBoxParser class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource): @@ -234,6 +133,7 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, **data_source_kwargs: Any, ): self.image_size = image_size @@ -244,12 +144,13 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - # DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), - # DefaultDataSources.FILES: ObjectDetectionPathsDataSource(), - # DefaultDataSources.FOLDERS: ObjectDetectionPathsDataSource(), - # "coco": COCODataSource(), + "coco": IceVisionParserDataSource(parser=COCOBBoxParser), + "via": IceVisionParserDataSource(parser=VIABBoxParser), + "voc": IceVisionParserDataSource(parser=VOCBBoxParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), }, - default_data_source="coco", + default_data_source=DefaultDataSources.FILES, ) self._default_collate = self._identity @@ -293,8 +194,8 @@ def from_coco( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data - folders and corresponding target folders. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and + annotation files in the COCO format. Args: train_folder: The folder containing the train data. @@ -303,12 +204,15 @@ def from_coco( val_ann_file: The COCO format annotation file. test_folder: The folder containing the test data. test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. 
            val_transform: The dictionary of transforms to use during validation which maps
                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
            test_transform: The dictionary of transforms to use during testing which maps
                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
                :class:`~flash.core.data.data_module.DataModule`.
            preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
                will be constructed and used.
@@ -325,7 +229,7 @@ def from_coco(

         Examples::

-            data_module = SemanticSegmentationData.from_coco(
+            data_module = ObjectDetectionData.from_coco(
                 train_folder="train_folder",
                 train_ann_file="annotations.json",
             )
@@ -339,6 +243,165 @@ def from_coco(
             train_transform=train_transform,
             val_transform=val_transform,
             test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            **preprocess_kwargs,
+        )
+
+    @classmethod
+    def from_voc(
+        cls,
+        train_folder: Optional[str] = None,
+        train_ann_file: Optional[str] = None,
+        val_folder: Optional[str] = None,
+        val_ann_file: Optional[str] = None,
+        test_folder: Optional[str] = None,
+        test_ann_file: Optional[str] = None,
+        predict_folder: Optional[str] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: Optional[int] = None,
+        **preprocess_kwargs: Any,
+    ):
+        """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and
+        annotation files in the VOC format.
+
+        Args:
+            train_folder: The folder containing the train data.
+            train_ann_file: The VOC format annotation file.
+            val_folder: The folder containing the validation data.
+            val_ann_file: The VOC format annotation file.
+            test_folder: The folder containing the test data.
+            test_ann_file: The VOC format annotation file.
+            predict_folder: The folder containing the predict data.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+
+        Examples::
+
+            data_module = ObjectDetectionData.from_voc(
+                train_folder="train_folder",
+                train_ann_file="annotations.json",
+            )
+        """
+        return cls.from_data_source(
+            "voc",
+            (train_folder, train_ann_file) if train_folder else None,
+            (val_folder, val_ann_file) if val_folder else None,
+            (test_folder, test_ann_file) if test_folder else None,
+            predict_folder,
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            **preprocess_kwargs,
+        )
+
+    @classmethod
+    def from_via(
+        cls,
+        train_folder: Optional[str] = None,
+        train_ann_file: Optional[str] = None,
+        val_folder: Optional[str] = None,
+        val_ann_file: Optional[str] = None,
+        test_folder: Optional[str] = None,
+        test_ann_file: Optional[str] = None,
+        predict_folder: Optional[str] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: Optional[int] = None,
+        **preprocess_kwargs: Any,
+    ):
+        """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and
+        annotation files in the VIA format.
+
+        Args:
+            train_folder: The folder containing the train data.
+            train_ann_file: The VIA format annotation file.
+            val_folder: The folder containing the validation data.
+            val_ann_file: The VIA format annotation file.
+            test_folder: The folder containing the test data.
+            test_ann_file: The VIA format annotation file.
+            predict_folder: The folder containing the predict data.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = ObjectDetectionData.from_via( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "via", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, data_fetcher=data_fetcher, preprocess=preprocess, val_split=val_split, diff --git a/flash/image/detection/finetuning.py b/flash/image/detection/finetuning.py deleted file mode 100644 index c1ca20072d..0000000000 --- a/flash/image/detection/finetuning.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytorch_lightning as pl - -from flash.core.finetuning import FlashBaseFinetuning - - -class ObjectDetectionFineTuning(FlashBaseFinetuning): - """ - Freezes the backbone during Detector training. - """ - - def __init__(self, train_bn: bool = True) -> None: - super().__init__(train_bn=train_bn) - - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - model = pl_module.model - self.freeze(modules=model.backbone, train_bn=self.train_bn) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 7dbb615d76..979a2d71c7 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -11,51 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Dict, List, Mapping, Optional, Type, Union import torch -from torch import nn, tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, Sampler -from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.process import Serializer -from flash.core.model import Task +from flash.core.integrations.icevision.model import IceVisionTask, SimpleCOCOMetric from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE -from flash.image.detection.heads import OBJECT_DETECTION_HEADS +from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.image.detection.backbones import OBJECT_DETECTION_HEADS from flash.image.detection.serialization import DetectionLabels -if _TORCHVISION_AVAILABLE: - import torchvision - from torchvision.models.detection.rpn import AnchorGenerator - from torchvision.ops import box_iou - - _models = { - "fasterrcnn": torchvision.models.detection.fasterrcnn_resnet50_fpn, - "retinanet": torchvision.models.detection.retinanet_resnet50_fpn, - } - -else: - AnchorGenerator = None - if _ICEVISION_AVAILABLE: - from icevision.core import BaseRecord - from icevision.data import Dataset - from icevision.metrics import COCOMetric, COCOMetricType - + from icevision.metrics import COCOMetricType + from icevision.metrics import Metric as IceVisionMetric -def _evaluate_iou(target, pred): - """ - Evaluate intersection over union (IOU) for target from dataset and output prediction from model - """ - if pred["boxes"].shape[0] == 0: - # no box detected, 0 IOU - return tensor(0.0, device=pred["boxes"].device) - return box_iou(target["boxes"], pred["boxes"]).diag().mean() - -class ObjectDetector(Task): +class ObjectDetector(IceVisionTask): """The ``ObjectDetector`` is a :class:`~flash.Task` for detecting objects in images. For more details, see :ref:`object_detection`. 
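
The task now resolves its model through the ``heads`` registry rather than the hard-coded ``_models`` dict removed above. A minimal usage sketch of the new API (the head, backbone, and class count here are illustrative, not prescribed by this patch)::

    from flash.image import ObjectDetector

    # Inspect the registered heads, then the backbones available for one of them.
    print(ObjectDetector.available_heads())
    print(ObjectDetector.available_backbones("retinanet"))

    # num_classes includes the background class.
    model = ObjectDetector(head="retinanet", backbone="resnet18_fpn", num_classes=91)
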
@@ -79,8 +52,6 @@ class ObjectDetector(Task): """ - # backbones: FlashRegistry = OBJ_DETECTION_BACKBONES - heads: FlashRegistry = OBJECT_DETECTION_HEADS required_extras: str = "image" @@ -91,214 +62,28 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "retinanet", pretrained: bool = True, - pretrained_backbone: bool = True, - trainable_backbone_layers: int = 3, - anchor_generator: Optional[Type['AnchorGenerator']] = None, - loss=None, - metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, + metrics: Optional[IceVisionMetric] = None, optimizer: Type[Optimizer] = torch.optim.AdamW, - learning_rate: float = 1e-3, + learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + image_size: Optional[int] = None, **kwargs: Any, ): self.save_hyperparameters() - # if model in _models: - # model = ObjectDetector.get_model( - # model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, - # anchor_generator, **kwargs - # ) - # else: - # ValueError(f"{model} is not supported yet.") - super().__init__( - model=None, - loss_fn=loss, - metrics=None, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.bbox)], + image_size=image_size, learning_rate=learning_rate, optimizer=optimizer, serializer=serializer or DetectionLabels(), + **kwargs, ) - metadata = self.heads.get(head, with_metadata=True) - backbones = metadata["metadata"]["backbones"] - backbone_config = backbones.get(backbone)(pretrained) - self.model_type, self.model, adapter, self.backbone = metadata["fn"](backbone_config, num_classes, **kwargs) - self.adapter = adapter(model=self.model, metrics=metrics or [COCOMetric(metric_type=COCOMetricType.bbox)]) - - @classmethod - def available_backbones(cls, head: str) -> List[str]: - metadata = cls.heads.get(head, with_metadata=True) - backbones = metadata["metadata"]["backbones"] - return backbones.available_keys() - - # @staticmethod - # def get_model( - # model_name, - # num_classes, - # backbone, - # fpn, - # pretrained, - # pretrained_backbone, - # trainable_backbone_layers, - # anchor_generator, - # **kwargs, - # ): - # if backbone is None: - # # Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. 
- # if model_name == "fasterrcnn": - # model = _models[model_name]( - # pretrained=pretrained, - # pretrained_backbone=pretrained_backbone, - # trainable_backbone_layers=trainable_backbone_layers, - # ) - # in_features = model.roi_heads.box_predictor.cls_score.in_features - # head = FastRCNNPredictor(in_features, num_classes) - # model.roi_heads.box_predictor = head - # else: - # model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone) - # model.head = RetinaNetHead( - # in_channels=model.backbone.out_channels, - # num_anchors=model.head.classification_head.num_anchors, - # num_classes=num_classes, - # **kwargs - # ) - # else: - # backbone_model, num_features = ObjectDetector.backbones.get(backbone)( - # pretrained=pretrained_backbone, - # trainable_layers=trainable_backbone_layers, - # **kwargs, - # ) - # backbone_model.out_channels = num_features - # if anchor_generator is None: - # anchor_generator = AnchorGenerator( - # sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), ) - # ) if not hasattr(backbone_model, "fpn") else None - # - # if model_name == "fasterrcnn": - # model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator) - # else: - # model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator) - # return model - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: - checkpoint['_data_pipeline_state'] = self._data_pipeline_state - - def process_train_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return self.model_type.train_dl( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - def process_val_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return self.model_type.valid_dl( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - def process_test_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return self.model_type.valid_dl( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - def process_predict_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int = 1, - num_workers: int = 0, - pin_memory: bool = False, - collate_fn: Callable = lambda x: x, - shuffle: bool = False, - drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True - ) -> Union[DataLoader, BaseAutoDataset]: - if convert_to_dataloader: - return self.model_type.infer_dl( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - return dataset - - def training_step(self, batch, batch_idx) -> 
Any: - return self.adapter.training_step(batch, batch_idx) - - def validation_step(self, batch, batch_idx): - return self.adapter.validation_step(batch, batch_idx) - - def test_step(self, batch, batch_idx): - return self.adapter.validation_step(batch, batch_idx) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - if isinstance(batch, list) and isinstance(batch[0], BaseRecord): - data = Dataset(batch) - return self.model_type.predict(self.model, data) - return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) - - def training_epoch_end(self, outputs) -> None: - return self.adapter.training_epoch_end(outputs) - - def validation_epoch_end(self, outputs) -> None: - return self.adapter.validation_epoch_end(outputs) - - def test_epoch_end(self, outputs) -> None: - return self.adapter.validation_epoch_end(outputs) - - # def configure_finetune_callback(self): - # return [ObjectDetectionFineTuning(train_bn=True)] - def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """ This function is used only for debugging usage with CI diff --git a/flash/image/instance_segmentation/__init__.py b/flash/image/instance_segmentation/__init__.py new file mode 100644 index 0000000000..c5659822c8 --- /dev/null +++ b/flash/image/instance_segmentation/__init__.py @@ -0,0 +1,2 @@ +from flash.image.instance_segmentation.data import InstanceSegmentationData # noqa: F401 +from flash.image.instance_segmentation.model import InstanceSegmentation # noqa: F401 diff --git a/flash/image/instance_segmentation/backbones.py b/flash/image/instance_segmentation/backbones.py new file mode 100644 index 0000000000..ff4f4efabe --- /dev/null +++ b/flash/image/instance_segmentation/backbones.py @@ -0,0 +1,44 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
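
The module below registers instance segmentation heads with the same ``FlashRegistry`` pattern used for detection earlier in this patch. For orientation, a sketch of how a registered head is later looked up (the calls mirror those in ``ObjectDetector``; the ``"mask_rcnn"`` key assumes the torchvision branch below is taken)::

    from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS

    metadata = INSTANCE_SEGMENTATION_HEADS.get("mask_rcnn", with_metadata=True)
    backbones = metadata["metadata"]["backbones"]  # a nested FlashRegistry of backbone configs
    print(backbones.available_keys())
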
+from functools import partial + +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, +) +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + +INSTANCE_SEGMENTATION_HEADS = FlashRegistry("heads") + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + model_type = icevision_models.torchvision.mask_rcnn + INSTANCE_SEGMENTATION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + ) + + if _module_available("mmdet"): + model_type = icevision_models.mmdet.mask_rcnn + INSTANCE_SEGMENTATION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + f"mmdet_{model_type.__name__.split('.')[-1]}", + backbones=get_backbones(model_type), + ) diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py new file mode 100644 index 0000000000..ca21552d25 --- /dev/null +++ b/flash/image/instance_segmentation/data.py @@ -0,0 +1,235 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
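
The data module in this file wires named data sources ("coco", "voc") to IceVision mask parsers and exposes ``from_coco`` / ``from_voc`` helpers. Those helpers delegate to the generic ``from_data_source`` entry point; a hedged sketch of calling it directly (paths hypothetical; the tuple layout is taken from the class methods below)::

    datamodule = InstanceSegmentationData.from_data_source(
        "coco",
        ("train_folder", "annotations.json"),  # (images folder, annotation file)
        batch_size=4,
    )
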
+from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, +) +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.parsers import COCOMaskParser, VOCMaskParser + + +class InstanceSegmentationPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + "coco": IceVisionParserDataSource(parser=COCOMaskParser), + "voc": IceVisionParserDataSource(parser=VOCMaskParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + }, + default_data_source=DefaultDataSources.FILES, + ) + + self._default_collate = self._identity + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + +class InstanceSegmentationData(DataModule): + + preprocess_cls = InstanceSegmentationPreprocess + + @classmethod + def from_coco( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given + data folders and annotation files in the COCO format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. 
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+
+        Examples::
+
+            data_module = InstanceSegmentationData.from_coco(
+                train_folder="train_folder",
+                train_ann_file="annotations.json",
+            )
+        """
+        return cls.from_data_source(
+            "coco",
+            (train_folder, train_ann_file) if train_folder else None,
+            (val_folder, val_ann_file) if val_folder else None,
+            (test_folder, test_ann_file) if test_folder else None,
+            predict_folder,
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            **preprocess_kwargs,
+        )
+
+    @classmethod
+    def from_voc(
+        cls,
+        train_folder: Optional[str] = None,
+        train_ann_file: Optional[str] = None,
+        val_folder: Optional[str] = None,
+        val_ann_file: Optional[str] = None,
+        test_folder: Optional[str] = None,
+        test_ann_file: Optional[str] = None,
+        predict_folder: Optional[str] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: Optional[int] = None,
+        **preprocess_kwargs: Any,
+    ):
+        """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given
+        data folders and annotation files in the VOC format.
+
+        Args:
+            train_folder: The folder containing the train data.
+            train_ann_file: The VOC format annotation file.
+            val_folder: The folder containing the validation data.
+            val_ann_file: The VOC format annotation file.
+            test_folder: The folder containing the test data.
+            test_ann_file: The VOC format annotation file.
+            predict_folder: The folder containing the predict data.
+ train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = InstanceSegmentationData.from_voc( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "voc", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py new file mode 100644 index 0000000000..6cd75488d3 --- /dev/null +++ b/flash/image/instance_segmentation/model.py @@ -0,0 +1,91 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
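
Before the task definition that follows, note that folder-based loading goes through ``IceDataParserDataSource``, so a custom parser can be supplied. A brief sketch (``build_parser`` is a hypothetical stand-in for any callable that returns an IceVision parser, such as the ``icedata`` parsers used in the example scripts added later in this patch)::

    from flash.image import InstanceSegmentationData

    datamodule = InstanceSegmentationData.from_folders(
        train_folder="path/to/images",
        val_split=0.1,
        parser=build_parser,  # hypothetical: forwarded to IceDataParserDataSource
    )
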
+from typing import Any, Dict, List, Mapping, Optional, Type, Union
+
+import torch
+from torch.optim import Optimizer
+
+from flash.core.data.process import Serializer
+from flash.core.integrations.icevision.model import IceVisionTask, SimpleCOCOMetric
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _ICEVISION_AVAILABLE
+from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS
+
+if _ICEVISION_AVAILABLE:
+    from icevision.metrics import COCOMetricType
+    from icevision.metrics import Metric as IceVisionMetric
+
+
+class InstanceSegmentation(IceVisionTask):
+    """The ``InstanceSegmentation`` is a :class:`~flash.Task` for segmenting object instances in images. For more
+    details, see :ref:`object_detection`.
+
+    Args:
+        num_classes: the number of classes for instance segmentation, including background
+        backbone: String indicating the backbone CNN architecture to use, from the backbones registered for the
+            chosen ``head``.
+        head: String indicating the head to use from the ``INSTANCE_SEGMENTATION_HEADS`` registry.
+        pretrained: Whether the model should be loaded with its pretrained weights.
+        metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
+        optimizer: The optimizer to use for training. Can either be the actual class or the class name.
+ learning_rate: The learning rate to use for training + + """ + + heads: FlashRegistry = INSTANCE_SEGMENTATION_HEADS + + required_extras: str = "image" + + def __init__( + self, + num_classes: int, + backbone: Optional[str] = "resnet18_fpn", + head: Optional[str] = "mask_rcnn", + pretrained: bool = True, + metrics: Optional[IceVisionMetric] = None, + optimizer: Type[Optimizer] = torch.optim.Adam, + learning_rate: float = 5e-4, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + image_size: Optional[int] = None, + **kwargs: Any, + ): + self.save_hyperparameters() + + super().__init__( + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.mask)], + image_size=image_size, + learning_rate=learning_rate, + optimizer=optimizer, + serializer=serializer, + **kwargs, + ) + + def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: + """ + This function is used only for debugging usage with CI + """ + # todo (tchaton) Improve convergence + # history[-1]["val_iou"] diff --git a/flash/image/keypoint_detection/__init__.py b/flash/image/keypoint_detection/__init__.py new file mode 100644 index 0000000000..d397086e24 --- /dev/null +++ b/flash/image/keypoint_detection/__init__.py @@ -0,0 +1,2 @@ +from flash.image.keypoint_detection.data import KeypointDetectionData # noqa: F401 +from flash.image.keypoint_detection.model import KeypointDetector # noqa: F401 diff --git a/flash/image/keypoint_detection/backbones.py b/flash/image/keypoint_detection/backbones.py new file mode 100644 index 0000000000..f8f65c6e3a --- /dev/null +++ b/flash/image/keypoint_detection/backbones.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, +) +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + +KEYPOINT_DETECTION_HEADS = FlashRegistry("heads") + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + model_type = icevision_models.torchvision.keypoint_rcnn + KEYPOINT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + ) diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py new file mode 100644 index 0000000000..60b6b33fdb --- /dev/null +++ b/flash/image/keypoint_detection/data.py @@ -0,0 +1,155 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, +) +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.parsers import COCOKeyPointsParser + + +class KeypointDetectionPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + "coco": IceVisionParserDataSource(parser=COCOKeyPointsParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + }, + default_data_source=DefaultDataSources.FILES, + ) + + self._default_collate = self._identity + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + +class KeypointDetectionData(DataModule): + + preprocess_cls = KeypointDetectionPreprocess + + @classmethod + def from_coco( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given + data folders and annotation files in the COCO format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. 
+ val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = KeypointDetectionData.from_coco( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "coco", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py new file mode 100644 index 0000000000..0a029803a6 --- /dev/null +++ b/flash/image/keypoint_detection/model.py @@ -0,0 +1,92 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
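
The task defined next takes both a keypoint count and a class count. A construction sketch (values illustrative; assuming the torchvision convention that ``num_classes`` includes background, so the default of 2 means background plus one object class)::

    from flash.image import KeypointDetector

    model = KeypointDetector(
        head="keypoint_rcnn",
        backbone="resnet18_fpn",
        num_keypoints=1,  # keypoints predicted per detected instance
        num_classes=2,    # background + a single foreground class
    )
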
+from typing import Any, Dict, List, Mapping, Optional, Type, Union
+
+import torch
+from torch.optim import Optimizer
+
+from flash.core.data.process import Serializer
+from flash.core.integrations.icevision.model import IceVisionTask
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _ICEVISION_AVAILABLE
+from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS
+
+if _ICEVISION_AVAILABLE:
+    from icevision.metrics import Metric as IceVisionMetric
+
+
+class KeypointDetector(IceVisionTask):
+    """The ``KeypointDetector`` is a :class:`~flash.Task` for detecting keypoints in images. For more details, see
+    :ref:`object_detection`.
+
+    Args:
+        num_keypoints: the number of keypoints to predict for each detected instance
+        num_classes: the number of classes for detection, including background
+        backbone: String indicating the backbone CNN architecture to use, from the backbones registered for the
+            chosen ``head``.
+        head: String indicating the head to use from the ``KEYPOINT_DETECTION_HEADS`` registry.
+        pretrained: Whether the model should be loaded with its pretrained weights.
+        metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
+        optimizer: The optimizer to use for training. Can either be the actual class or the class name.
+        learning_rate: The learning rate to use for training
+
+    """
+
+    heads: FlashRegistry = KEYPOINT_DETECTION_HEADS
+
+    required_extras: str = "image"
+
+    def __init__(
+        self,
+        num_keypoints: int,
+        num_classes: int = 2,
+        backbone: Optional[str] = "resnet18_fpn",
+        head: Optional[str] = "keypoint_rcnn",
+        pretrained: bool = True,
+        metrics: Optional[IceVisionMetric] = None,
+        optimizer: Type[Optimizer] = torch.optim.Adam,
+        learning_rate: float = 5e-4,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+        image_size: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        self.save_hyperparameters()
+
+        super().__init__(
+            num_classes=num_classes,
+            num_keypoints=num_keypoints,
+            backbone=backbone,
+            head=head,
+            pretrained=pretrained,
+            metrics=metrics,
+            image_size=image_size,
+            learning_rate=learning_rate,
+            optimizer=optimizer,
+            serializer=serializer,
+            **kwargs,
+        )
+
+    def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
+        """
+        This function is used only for debugging usage with CI
+        """
+        # todo (tchaton) Improve convergence
+        # history[-1]["val_iou"]
diff --git a/flash_examples/graph_classification.py b/flash_examples/graph_classification.py
index 2737e7126a..8a085b6225 100644
--- a/flash_examples/graph_classification.py
+++ b/flash_examples/graph_classification.py
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
import flash -from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE +from flash.core.utilities.imports import example_requires from flash.graph.classification.data import GraphClassificationData from flash.graph.classification.model import GraphClassifier -if _TORCH_GEOMETRIC_AVAILABLE: - from torch_geometric.datasets import TUDataset -else: - raise ModuleNotFoundError("Please, pip install -e '.[graph]'") +example_requires("graph") + +from torch_geometric.datasets import TUDataset # noqa: E402 # 1. Create the DataModule dataset = TUDataset(root="data", name="KKI") diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py new file mode 100644 index 0000000000..5e451dcb9b --- /dev/null +++ b/flash_examples/instance_segmentation.py @@ -0,0 +1,54 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import flash +from flash.core.utilities.imports import example_requires +from flash.image import InstanceSegmentation, InstanceSegmentationData + +example_requires("image") + +import icedata # noqa: E402 + +# 1. Create the DataModule +data_dir = icedata.pets.load_data() + +datamodule = InstanceSegmentationData.from_folders( + train_folder=data_dir, + val_split=0.1, + image_size=128, + parser=partial(icedata.pets.parser, mask=True), +) + +# 2. Build the task +model = InstanceSegmentation( + head="mask_rcnn", + backbone="resnet18_fpn", + num_classes=datamodule.num_classes, +) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=1) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Detect objects in a few images! +predictions = model.predict([ + str(data_dir / "images/yorkshire_terrier_9.jpg"), + str(data_dir / "images/english_cocker_spaniel_1.jpg"), + str(data_dir / "images/scottish_terrier_1.jpg"), +]) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("instance_segmentation_model.pt") diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py new file mode 100644 index 0000000000..ea53dfde78 --- /dev/null +++ b/flash_examples/keypoint_detection.py @@ -0,0 +1,53 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.core.utilities.imports import example_requires +from flash.image import KeypointDetectionData, KeypointDetector + +example_requires("image") + +import icedata # noqa: E402 + +# 1. 
Create the DataModule
+data_dir = icedata.biwi.load_data()
+
+datamodule = KeypointDetectionData.from_folders(
+    train_folder=data_dir,
+    val_split=0.1,
+    image_size=128,
+    parser=icedata.biwi.parser,
+)
+
+# 2. Build the task
+model = KeypointDetector(
+    head="keypoint_rcnn",
+    backbone="resnet18_fpn",
+    num_keypoints=1,
+    num_classes=datamodule.num_classes,
+)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1)
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 4. Detect keypoints in a few images!
+predictions = model.predict([
+    str(data_dir / "biwi_sample/images/0.jpg"),
+    str(data_dir / "biwi_sample/images/1.jpg"),
+    str(data_dir / "biwi_sample/images/10.jpg"),
+])
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("keypoint_detection_model.pt")
diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py
index c8d58ef133..8a0b438c85 100644
--- a/flash_examples/object_detection.py
+++ b/flash_examples/object_detection.py
@@ -23,22 +23,17 @@
     train_folder="data/coco128/images/train2017/",
     train_ann_file="data/coco128/annotations/instances_train2017.json",
     val_split=0.1,
-    predict_folder="data/coco128/images/train2017/",
-    image_size=(256, 256),
+    image_size=128,
 )

 # 2. Build the task
-print(ObjectDetector.available_heads())
-print(ObjectDetector.available_backbones("efficientdet"))
-
-model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=(256, 256))
+model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128)

 # 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
+trainer = flash.Trainer(max_epochs=1)
 trainer.finetune(model, datamodule=datamodule, strategy="freeze")

 # 4. Detect objects in a few images!
-print(trainer.predict(model, datamodule=datamodule))
 predictions = model.predict([
     "data/coco128/images/train2017/000000000625.jpg",
     "data/coco128/images/train2017/000000000626.jpg",
diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt
index f289b3c3a7..51010fe099 100644
--- a/requirements/datatype_image.txt
+++ b/requirements/datatype_image.txt
@@ -6,3 +6,5 @@ kornia>=0.5.1,<0.5.4
 pystiche>=0.7.2
 segmentation-models-pytorch
 icevision>=0.8
+icedata
+effdet

From 21a236d660a91c47652c342d74a6ad5b2ca83007 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 20 Jul 2021 20:45:02 +0100
Subject: [PATCH 03/46] Updates

---
 flash/image/detection/data.py      | 147 +++++++++++++++--------------
 tests/image/detection/test_data.py |  81 +++++++++------
 2 files changed, 126 insertions(+), 102 deletions(-)

diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py
index 71a0df7613..5a535024a8 100644
--- a/flash/image/detection/data.py
+++ b/flash/image/detection/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, Hashable, Optional, Sequence, Tuple, TYPE_CHECKING from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule @@ -23,7 +23,7 @@ IceVisionPathsDataSource, ) from flash.core.integrations.icevision.transforms import default_transforms -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE, lazy_import +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires SampleCollection = None if _FIFTYONE_AVAILABLE: @@ -33,84 +33,62 @@ else: foc, fol = None, None -if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import default_loader - if _ICEVISION_AVAILABLE: - from icevision.parsers import COCOBBoxParser, VIABBoxParser, VOCBBoxParser - - -class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource): - - def __init__(self, label_field: str = "ground_truth", iscrowd: str = "iscrowd"): - super().__init__(label_field=label_field) - self.iscrowd = iscrowd + from icevision.core import BBox, ClassMap, IsCrowdsRecordComponent, ObjectDetectionRecord + from icevision.data import SingleSplitSplitter + from icevision.parsers import COCOBBoxParser, Parser, VIABBoxParser, VOCBBoxParser + from icevision.utils import ImgSize +else: + Parser = object - @property - def label_cls(self): - return fol.Detections - def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - self._validate(data) +class FiftyOneParser(Parser): - data.compute_metadata() + def __init__(self, data, class_map, label_field, iscrowd): + template_record = ObjectDetectionRecord() + template_record.add_component(IsCrowdsRecordComponent()) + super().__init__(template_record=template_record) - filepaths = data.values("filepath") - widths = data.values("metadata.width") - heights = data.values("metadata.height") - labels = data.values(self.label_field + ".detections.label") - bboxes = data.values(self.label_field + ".detections.bounding_box") - iscrowds = data.values(self.label_field + ".detections." + self.iscrowd) + data = data + label_field = label_field + iscrowd = iscrowd - classes = self._get_classes(data) - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - if dataset is not None: - dataset.num_classes = len(classes) + self.data = [] + self.class_map = class_map - output_data = [] - img_id = 1 for fp, w, h, sample_labs, sample_boxes, sample_iscrowd in zip( - filepaths, widths, heights, labels, bboxes, iscrowds + data.values("filepath"), data.values("metadata.width"), data.values("metadata.height"), + data.values(label_field + ".detections.label"), data.values(label_field + ".detections.bounding_box"), + data.values(label_field + ".detections." 
+ iscrowd) ): - output_boxes = [] - output_labs = [] - output_iscrowd = [] - output_areas = [] for lab, box, iscrowd in zip(sample_labs, sample_boxes, sample_iscrowd): - output_box, output_area = self._reformat_bbox(box[0], box[1], box[2], box[3], w, h) - output_areas.append(output_area) - output_labs.append(class_to_idx[lab]) - output_boxes.append(output_box) - if iscrowd is None: - iscrowd = 0 - output_iscrowd.append(iscrowd) - output_data.append( - dict( - input=fp, - target=dict( - boxes=output_boxes, - labels=output_labs, - image_id=img_id, - area=output_areas, - iscrowd=output_iscrowd, - ) - ) - ) - img_id += 1 + self.data.append((fp, w, h, lab, box, iscrowd)) - return output_data + def __iter__(self) -> Any: + return iter(self.data) - @staticmethod - def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - filepath = sample[DefaultDataKeys.INPUT] - img = default_loader(filepath) - sample[DefaultDataKeys.INPUT] = img - w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { - "filepath": filepath, - "size": (h, w), - } - return sample + def __len__(self) -> int: + return len(self.data) + + def record_id(self, o) -> Hashable: + return o[0] + + def parse_fields(self, o, record, is_new): + fp, w, h, lab, box, iscrowd = o + + if iscrowd is None: + iscrowd = 0 + + if is_new: + record.set_filepath(fp) + record.set_img_size(ImgSize(width=w, height=h)) + record.detection.set_class_map(self.class_map) + + box = self._reformat_bbox(*box, w, h) + + record.detection.add_bboxes([BBox.from_xyxy(*box)]) + record.detection.add_labels([lab]) + record.detection.add_iscrowds([iscrowd]) @staticmethod def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): @@ -121,7 +99,37 @@ def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): xmax = xmin + box_w ymax = ymin + box_h output_bbox = [xmin, ymin, xmax, ymax] - return output_bbox, box_w * box_h + return output_bbox + + +class ObjectDetectionFiftyOneDataSource(IceVisionPathsDataSource, FiftyOneDataSource): + + def __init__(self, label_field: str = "ground_truth", iscrowd: str = "iscrowd"): + super().__init__() + self.label_field = label_field + self.iscrowd = iscrowd + + @property + @requires("fiftyone") + def label_cls(self): + return fol.Detections + + @requires("fiftyone") + def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + self._validate(data) + + data.compute_metadata() + classes = self._get_classes(data) + class_map = ClassMap(classes) + + parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + + @staticmethod + @requires("fiftyone") + def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] class ObjectDetectionPreprocess(Preprocess): @@ -149,6 +157,7 @@ def __init__( "voc": IceVisionParserDataSource(parser=VOCBBoxParser), DefaultDataSources.FILES: IceVisionPathsDataSource(), DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), }, default_data_source=DefaultDataSources.FILES, ) diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index d0ef137a24..5fdd36204c 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ 
-4,7 +4,6 @@ import pytest -from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image.detection.data import ObjectDetectionData @@ -126,15 +125,19 @@ def test_image_detector_data_from_coco(tmpdir): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) - datamodule = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) + datamodule = ObjectDetectionData.from_coco( + train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1, image_size=128 + ) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + record = data[0] + + assert record.detection.img.shape == (128, 128, 3) + assert record.detection.iscrowds[0] in (0, 1) + + assert record.img_size.height == 128 + assert record.img_size.width == 128 assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -148,23 +151,28 @@ def test_image_detector_data_from_coco(tmpdir): test_ann_file=coco_ann_path, batch_size=1, num_workers=0, + image_size=128, ) data = next(iter(datamodule.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + record = data[0] + + assert record.detection.img.shape == (128, 128, 3) + assert record.detection.iscrowds[0] in (0, 1) + + assert record.img_size.height == 128 + assert record.img_size.width == 128 data = next(iter(datamodule.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + record = data[0] + + assert record.detection.img.shape == (128, 128, 3) + assert record.detection.iscrowds[0] in (0, 1) + + assert record.img_size.height == 128 + assert record.img_size.width == 128 @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -173,15 +181,17 @@ def test_image_detector_data_from_fiftyone(tmpdir): train_dataset = _create_synth_fiftyone_dataset(tmpdir) - datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1) + datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1, image_size=128) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + record = data[0] + + assert record.detection.img.shape == (128, 128, 3) + assert record.detection.iscrowds[0] in (0, 1) + + assert record.img_size.height == 128 + assert record.img_size.width == 128 assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -192,20 +202,25 @@ def test_image_detector_data_from_fiftyone(tmpdir): test_dataset=train_dataset, batch_size=1, num_workers=0, + image_size=128, ) data = 
next(iter(datamodule.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + record = data[0] + + assert record.detection.img.shape == (128, 128, 3) + assert record.detection.iscrowds[0] in (0, 1) + + assert record.img_size.height == 128 + assert record.img_size.width == 128 data = next(iter(datamodule.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + record = data[0] + + assert record.detection.img.shape == (128, 128, 3) + assert record.detection.iscrowds[0] in (0, 1) + + assert record.img_size.height == 128 + assert record.img_size.width == 128 From b9dfc48709ec40df8bc75ae75d9f9b2a44cc362f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 20 Jul 2021 20:51:12 +0100 Subject: [PATCH 04/46] Updates --- requirements/datatype_image_extras.txt | 1 - .../detection/test_data_model_integration.py | 20 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 7e7370035f..f61e3f9c25 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -1,3 +1,2 @@ matplotlib -pycocotools>=2.0.2 ; python_version >= "3.7" fiftyone diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index cba7034319..ca0e1eac78 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -19,6 +19,7 @@ from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image import ObjectDetector from flash.image.detection import ObjectDetectionData +from tests.helpers.utils import _IMAGE_TESTING if _PIL_AVAILABLE: from PIL import Image @@ -32,19 +33,18 @@ from tests.image.detection.test_data import _create_synth_fiftyone_dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") -@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")]) -def test_detection(tmpdir, model, backbone): +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) +def test_detection(tmpdir, head, backbone): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) - model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes) + model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes) trainer = flash.Trainer(fast_dev_run=True) - trainer.finetune(model, data) + trainer.finetune(model, data, strategy="freeze") test_image_one = os.fspath(tmpdir / "test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") @@ -58,17 +58,17 @@ def test_detection(tmpdir, model, backbone): @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") 
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
-@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")])
-def test_detection_fiftyone(tmpdir, model, backbone):
+@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")])
+def test_detection_fiftyone(tmpdir, head, backbone):
 
     train_dataset = _create_synth_fiftyone_dataset(tmpdir)
 
     data = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1)
-    model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes)
+    model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes)
 
     trainer = flash.Trainer(fast_dev_run=True)
-    trainer.finetune(model, data)
+    trainer.finetune(model, data, strategy="freeze")
 
     test_image_one = os.fspath(tmpdir / "test_one.png")
     test_image_two = os.fspath(tmpdir / "test_two.png")

From 89385bd5f1c271d5918b1ad7ef822170d1521f15 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 20 Jul 2021 20:52:21 +0100
Subject: [PATCH 05/46] Updates

---
 tests/image/detection/test_data.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py
index 5fdd36204c..a5a7142c0d 100644
--- a/tests/image/detection/test_data.py
+++ b/tests/image/detection/test_data.py
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 import os
 from pathlib import Path

From addfe96e6310fce35858ebbc7162cf296f799aac Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 20 Jul 2021 21:24:53 +0100
Subject: [PATCH 06/46] Add docs

---
 docs/source/index.rst                        |  2 ++
 .../reference/instance_segmentation.rst      | 31 +++++++++++++++++++
 docs/source/reference/keypoint_detection.rst | 31 +++++++++++++++++++
 docs/source/reference/object_detection.rst   |  2 ++
 4 files changed, 66 insertions(+)
 create mode 100644 docs/source/reference/instance_segmentation.rst
 create mode 100644 docs/source/reference/keypoint_detection.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8f56b56214..fc16384087 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -36,6 +36,8 @@ Lightning Flash
    reference/image_classification_multi_label
    reference/image_embedder
    reference/object_detection
+   reference/keypoint_detection
+   reference/instance_segmentation
    reference/semantic_segmentation
    reference/style_transfer
    reference/video_classification
diff --git a/docs/source/reference/instance_segmentation.rst b/docs/source/reference/instance_segmentation.rst
new file mode 100644
index 0000000000..75408dc3fa
--- /dev/null
+++ b/docs/source/reference/instance_segmentation.rst
@@ -0,0 +1,31 @@
+
+.. _instance_segmentation:
+
+#####################
+Instance Segmentation
+#####################
+
+********
+The Task
+********
+
+Instance segmentation is the task of segmenting objects in images and determining their associated classes.
+ +The :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` classes internally rely on `IceVision `_. + +------ + +******* +Example +******* + +Let's look at instance segmentation with `The Oxford-IIIT Pet Dataset `_ from `IceData `_. +Once we've downloaded the data, we can create the :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData`. +We select a ``mask_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and fine-tune on the pets data. +We then use the trained :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/instance_segmentation.py + :language: python + :lines: 14- diff --git a/docs/source/reference/keypoint_detection.rst b/docs/source/reference/keypoint_detection.rst new file mode 100644 index 0000000000..76fd0dcdf5 --- /dev/null +++ b/docs/source/reference/keypoint_detection.rst @@ -0,0 +1,31 @@ + +.. _keypoint_detection: + +################## +Keypoint Detection +################## + +******** +The Task +******** + +Keypoint detection is the task of identifying keypoints in images and their associated classes. + +The :class:`~flash.image.keypoint_detection.model.KeypointDetector` and :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` classes internally rely on `IceVision `_. + +------ + +******* +Example +******* + +Let's look at keypoint detection with `BIWI Sample Keypoints (center of face) `_ from `IceData `_. +Once we've downloaded the data, we can create the :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`. +We select a ``keypoint_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.keypoint_detection.model.KeypointDetector` and fine-tune on the BIWI data. +We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDetector` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/keypoint_detection.py + :language: python + :lines: 14- diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index bf82bec153..6210970e3e 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -11,6 +11,8 @@ The Task Object detection is the task of identifying objects in images and their associated classes and bounding boxes. +The :class:`~flash.image.detection.model.ObjectDetector` and :class:`~flash.image.detection.data.ObjectDetectionData` classes internally rely on `IceVision `_. + ------ ******* From 22b41526097bcd6aef536cedeedc70092efa3624 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 20 Jul 2021 21:39:31 +0100 Subject: [PATCH 07/46] Update API reference --- docs/source/api/image.rst | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst index 067b4ef404..d1f64e605c 100644 --- a/docs/source/api/image.rst +++ b/docs/source/api/image.rst @@ -31,8 +31,8 @@ ______________ classification.transforms.default_transforms classification.transforms.train_default_transforms -Detection -_________ +Object Detection +________________ .. 
autosummary:: :toctree: generated/ @@ -42,21 +42,37 @@ _________ ~detection.model.ObjectDetector ~detection.data.ObjectDetectionData - detection.data.COCODataSource + detection.data.FiftyOneParser detection.data.ObjectDetectionFiftyOneDataSource detection.data.ObjectDetectionPreprocess - detection.finetuning.ObjectDetectionFineTuning - detection.model.ObjectDetector detection.serialization.DetectionLabels detection.serialization.FiftyOneDetectionLabels +Keypoint Detection +__________________ + .. autosummary:: :toctree: generated/ :nosignatures: - :template: + :template: classtemplate.rst + + ~keypoint_detection.model.KeypointDetector + ~keypoint_detection.data.KeypointDetectionData + + keypoint_detection.data.KeypointDetectionPreprocess + +Instance Segmentation +_____________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~instance_segmentation.model.InstanceSegmentation + ~instance_segmentation.data.InstanceSegmentationData - detection.transforms.collate - detection.transforms.default_transforms + instance_segmentation.data.InstanceSegmentationPreprocess Embedding _________ From 14dd36f2952cdf7fba241bcae64a346d06fcfaaa Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 20 Jul 2021 22:21:25 +0100 Subject: [PATCH 08/46] Fix some tests --- flash/core/integrations/icevision/model.py | 9 ++-- tests/image/detection/test_model.py | 48 ++++++++++++++-------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/flash/core/integrations/icevision/model.py b/flash/core/integrations/icevision/model.py index ad8a80ab0e..c423e04547 100644 --- a/flash/core/integrations/icevision/model.py +++ b/flash/core/integrations/icevision/model.py @@ -116,7 +116,7 @@ def process_train_dataset( batch_size: int, num_workers: int, pin_memory: bool, - collate_fn: Callable, + collate_fn: Optional[Callable] = None, shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None @@ -137,7 +137,7 @@ def process_val_dataset( batch_size: int, num_workers: int, pin_memory: bool, - collate_fn: Callable, + collate_fn: Optional[Callable] = None, shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None @@ -158,7 +158,7 @@ def process_test_dataset( batch_size: int, num_workers: int, pin_memory: bool, - collate_fn: Callable, + collate_fn: Optional[Callable] = None, shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None @@ -207,6 +207,9 @@ def test_step(self, batch, batch_idx): return self.adapter.validation_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(batch) + + def forward(self, batch: Any) -> Any: if isinstance(batch, list) and isinstance(batch[0], BaseRecord): data = Dataset(batch) return self.model_type.predict(self.model, data) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index a610122783..d61f82e4b9 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -12,18 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +import random import re +import numpy as np import pytest import torch from pytorch_lightning import Trainer -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset -from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ObjectDetector from tests.helpers.utils import _IMAGE_TESTING +if _ICEVISION_AVAILABLE: + from icevision.core import BBox, ClassMap, ObjectDetectionRecord + from icevision.data import Prediction + from icevision.utils import ImgSize + def collate_fn(samples): return {key: [sample[key] for sample in samples] for key in samples[0]} @@ -48,10 +54,19 @@ def _random_bbox(self): return [min(xs), min(ys), max(xs) + 1, max(ys) + 1] def __getitem__(self, idx): - img = torch.rand(self.img_shape) - boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) - labels = torch.randint(self.num_classes, (self.num_boxes, )) - return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}} + record = ObjectDetectionRecord() + + img = np.random.rand(*self.img_shape).astype(np.float32) + + record.set_img(img) + record.set_img_size(ImgSize(width=self.img_shape[0], height=self.img_shape[1])) + record.detection.set_class_map(ClassMap([f"test_{i}" for i in range(self.num_classes)], background=None)) + + for i in range(self.num_boxes): + record.detection.add_bboxes([BBox.from_xyxy(*self._random_bbox())]) + record.detection.add_labels([f"test_{random.randint(0, self.num_classes - 1)}"]) + + return record @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @@ -60,23 +75,22 @@ def test_init(): model.eval() batch_size = 2 - ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) - dl = DataLoader(ds, collate_fn=collate_fn, batch_size=batch_size) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + dl = model.process_predict_dataset(ds, batch_size=batch_size) data = next(iter(dl)) - img = data[DefaultDataKeys.INPUT] - out = model(img) + out = model(data) assert len(out) == batch_size - assert {"boxes", "labels", "scores"} <= out[0].keys() + assert all(isinstance(res, Prediction) for res in out) -@pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"]) +@pytest.mark.parametrize("head", ["faster_rcnn", "retinanet"]) @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_training(tmpdir, model): - model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False) - ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) - dl = DataLoader(ds, collate_fn=collate_fn) +def test_training(tmpdir, head): + model = ObjectDetector(num_classes=2, head=head, pretrained=False) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + dl = model.process_train_dataset(ds, 2, 0, False) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, dl) From 1b0642ef27d7ee380c54030a4ee379e00f452454 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 21 Jul 2021 09:28:18 +0100 Subject: [PATCH 09/46] Small fix --- flash/core/integrations/icevision/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 6375ddc360..f3d726e2fd 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -45,7 +45,7 @@ def 
predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class IceVisionParserDataSource(IceVisionPathsDataSource): - def __init__(self, parser: Optional[Type[Parser]] = None): + def __init__(self, parser: Optional[Type['Parser']] = None): super().__init__() self.parser = parser From 4a6c3995c099dbe000c4dc312917520c40b71a88 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 21 Jul 2021 09:33:47 +0100 Subject: [PATCH 10/46] Drop failing JIT test --- tests/image/detection/test_model.py | 37 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index d61f82e4b9..8cafd4f390 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import random import re @@ -95,24 +94,24 @@ def test_training(tmpdir, head): trainer.fit(model, dl) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_jit(tmpdir): - path = os.path.join(tmpdir, "test.pt") - - model = ObjectDetector(2) - model.eval() - - model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN - - torch.jit.save(model, path) - model = torch.jit.load(path) - - out = model([torch.rand(3, 32, 32)]) - - # torchvision RCNN always returns a (Losses, Detections) tuple in scripting - out = out[1] - - assert {"boxes", "labels", "scores"} <= out[0].keys() +# @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +# def test_jit(tmpdir): +# path = os.path.join(tmpdir, "test.pt") +# +# model = ObjectDetector(2) +# model.eval() +# +# model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN +# +# torch.jit.save(model, path) +# model = torch.jit.load(path) +# +# out = model([torch.rand(3, 32, 32)]) +# +# # torchvision RCNN always returns a (Losses, Detections) tuple in scripting +# out = out[1] +# +# assert {"boxes", "labels", "scores"} <= out[0].keys() @pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.") From 9e3003450be4744387b26dde9caba9e5261c7172 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 21 Jul 2021 09:34:48 +0100 Subject: [PATCH 11/46] Updates --- flash/core/integrations/icevision/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flash/core/integrations/icevision/model.py b/flash/core/integrations/icevision/model.py index c423e04547..fa413e6740 100644 --- a/flash/core/integrations/icevision/model.py +++ b/flash/core/integrations/icevision/model.py @@ -27,6 +27,8 @@ from icevision.data import Dataset from icevision.metrics import COCOMetric from icevision.metrics import Metric as IceVisionMetric +else: + COCOMetric = object class SimpleCOCOMetric(COCOMetric): From 00f391e4344630d337c35f31d288e37b70f8272d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 21 Jul 2021 09:41:35 +0100 Subject: [PATCH 12/46] Updates --- flash/core/integrations/icevision/model.py | 2 +- flash/image/detection/model.py | 2 +- flash/image/instance_segmentation/model.py | 2 +- flash/image/keypoint_detection/model.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flash/core/integrations/icevision/model.py b/flash/core/integrations/icevision/model.py index fa413e6740..34175dbb68 100644 --- 
a/flash/core/integrations/icevision/model.py +++ b/flash/core/integrations/icevision/model.py @@ -66,7 +66,7 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, - metrics: Optional[IceVisionMetric] = None, + metrics: Optional['IceVisionMetric'] = None, learning_rate: float = 5e-4, deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, preprocess: Optional[Preprocess] = None, diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 979a2d71c7..816690f6e9 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -62,7 +62,7 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "retinanet", pretrained: bool = True, - metrics: Optional[IceVisionMetric] = None, + metrics: Optional['IceVisionMetric'] = None, optimizer: Type[Optimizer] = torch.optim.AdamW, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 6cd75488d3..3b339a54ed 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -61,7 +61,7 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "mask_rcnn", pretrained: bool = True, - metrics: Optional[IceVisionMetric] = None, + metrics: Optional['IceVisionMetric'] = None, optimizer: Type[Optimizer] = torch.optim.Adam, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 0a029803a6..a5f9802735 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -61,7 +61,7 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "keypoint_rcnn", pretrained: bool = True, - metrics: Optional[IceVisionMetric] = None, + metrics: Optional['IceVisionMetric'] = None, optimizer: Type[Optimizer] = torch.optim.Adam, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, From e6ee9943f580901217bd16b229a7d863c8732301 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 21 Jul 2021 09:50:37 +0100 Subject: [PATCH 13/46] Fix a test --- flash/core/registry.py | 2 +- tests/core/test_registry.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/core/registry.py b/flash/core/registry.py index f07eaf6b39..20f7984305 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -75,7 +75,7 @@ def _register_function( metadata: Optional[Dict[str, Any]] = None ): if not callable(fn): - raise MisconfigurationException(f"You can only register a function, found: {fn}") + raise MisconfigurationException(f"You can only register a callable, found: {fn}") name = name or fn.__name__ diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py index 3af891aa3a..8431e3f766 100644 --- a/tests/core/test_registry.py +++ b/tests/core/test_registry.py @@ -27,8 +27,8 @@ def test_registry_raises(): def my_model(nc_input=5, nc_output=6): return nn.Linear(nc_input, nc_output), nc_input, nc_output - with pytest.raises(MisconfigurationException, match="You can only register a function, found: Linear"): - backbones(nn.Linear(1, 1), name="foo") + with 
pytest.raises(MisconfigurationException, match="You can only register a callable, found: 3"): + backbones(3, name="foo") backbones(my_model, name="foo", override=True) From d5486072c2711b91ce55e10f8f933f6363b2b737 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 26 Jul 2021 18:24:08 +0100 Subject: [PATCH 14/46] Initial credits support --- flash/core/registry.py | 21 +++++++++++++++++++++ flash/image/detection/backbones.py | 7 ++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/flash/core/registry.py b/flash/core/registry.py index 20f7984305..155bc131a6 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union from pytorch_lightning.utilities import rank_zero_info @@ -19,6 +20,16 @@ _REGISTERED_FUNCTION = Dict[str, Any] +@dataclass +class Credit: + + name: str + url: str + + def __str__(self): + return f"{self.name} ({self.url})" + + class FlashRegistry: """This class is used to register function or :class:`functools.partial` class to a registry.""" @@ -61,6 +72,16 @@ def get( if not matches: raise KeyError("Found no matches that fit your metadata criteria. Try removing some metadata") + for match in matches: + if "credits" in match["metadata"]: + credits = match["metadata"]["credits"] + if not isinstance(credits, List): + credits = [credits] + if len(credits) > 1: + credits[-2] = f"{str(credits[-2])} and {str(credits[-1])}" + credits = credits[:-1] + rank_zero_info(f"Using '{key}' provided by {', '.join(str(credit) for credit in credits)}.") + matches = [e if with_metadata else e["fn"] for e in matches] return matches[0] if strict else matches diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py index 3d8e64dad1..c6af4caf25 100644 --- a/flash/image/detection/backbones.py +++ b/flash/image/detection/backbones.py @@ -21,7 +21,7 @@ load_icevision_ignore_image_size, load_icevision_with_image_size, ) -from flash.core.registry import FlashRegistry +from flash.core.registry import Credit, FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE if _ICEVISION_AVAILABLE: @@ -36,6 +36,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + credits=Credit("IceVision", "www.icevision.com"), ) if _module_available("yolov5"): @@ -44,6 +45,7 @@ partial(load_icevision_with_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + credits=Credit("IceVision", "www.icevision.com"), ) if _module_available("mmdet"): @@ -57,6 +59,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), + credits=Credit("IceVision", "www.icevision.com"), ) if _module_available("effdet"): @@ -79,4 +82,6 @@ def validation_step(self, batch, batch_idx): partial(load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + credits=[Credit("IceVision", "www.icevision.com"), + Credit("effdet", "github")], ) From 7d9838b7f6bc51bdc73b1d32c7f69350d52cd015 Mon 
Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 27 Jul 2021 18:15:33 +0100 Subject: [PATCH 15/46] Credit -> provider --- flash/core/registry.py | 22 +++++++++++-------- flash/core/utilities/providers.py | 20 +++++++++++++++++ flash/image/detection/backbones.py | 12 +++++----- .../image/instance_segmentation/backbones.py | 3 +++ flash/image/keypoint_detection/backbones.py | 2 ++ 5 files changed, 44 insertions(+), 15 deletions(-) create mode 100644 flash/core/utilities/providers.py diff --git a/flash/core/registry.py b/flash/core/registry.py index 155bc131a6..7d62d10f38 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -21,7 +21,7 @@ @dataclass -class Credit: +class Provider: name: str url: str @@ -73,14 +73,14 @@ def get( raise KeyError("Found no matches that fit your metadata criteria. Try removing some metadata") for match in matches: - if "credits" in match["metadata"]: - credits = match["metadata"]["credits"] - if not isinstance(credits, List): - credits = [credits] - if len(credits) > 1: - credits[-2] = f"{str(credits[-2])} and {str(credits[-1])}" - credits = credits[:-1] - rank_zero_info(f"Using '{key}' provided by {', '.join(str(credit) for credit in credits)}.") + if "providers" in match["metadata"]: + providers = match["metadata"]["providers"] + if not isinstance(providers, List): + providers = [providers] + if len(providers) > 1: + providers[-2] = f"{str(providers[-2])} and {str(providers[-1])}" + providers = providers[:-1] + rank_zero_info(f"Using '{key}' provided by {', '.join(str(provider) for provider in providers)}.") matches = [e if with_metadata else e["fn"] for e in matches] return matches[0] if strict else matches @@ -126,6 +126,7 @@ def __call__( fn: Optional[Callable[..., Any]] = None, name: Optional[str] = None, override: bool = False, + providers: Optional[Union[Provider, List[Provider]]] = None, **metadata ) -> Callable: """ @@ -134,6 +135,9 @@ def __call__( Functions can be filtered using metadata using the ``get`` function. """ + if providers is not None: + metadata["providers"] = providers + if fn is not None: self._register_function(fn=fn, name=name, override=override, metadata=metadata) return fn diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py new file mode 100644 index 0000000000..ff464e690c --- /dev/null +++ b/flash/core/utilities/providers.py @@ -0,0 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
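To make the new providers metadata concrete, a small hedged usage sketch (the provider, registry, and backbone function here are hypothetical; the real provider constants follow below):

    from flash.core.registry import FlashRegistry, Provider

    _EXAMPLE = Provider("example/backbones", "https://github.com/example/backbones")  # hypothetical

    BACKBONES = FlashRegistry("backbones")

    def tiny_net(pretrained: bool = True):
        ...  # build and return a model

    BACKBONES(tiny_net, name="tiny_net", providers=_EXAMPLE)

    # Retrieval then logs (via rank_zero_info) something like:
    #   Using 'tiny_net' provided by example/backbones (https://github.com/example/backbones).
    fn = BACKBONES.get("tiny_net")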
+from flash.core.registry import Provider + +_ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision") +_TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision") +_ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5") +_MMDET = Provider("OpenMMLab/MMDetection", "https://github.com/open-mmlab/mmdetection") +_EFFDET = Provider("rwightman/efficientdet-pytorch", "https://github.com/rwightman/efficientdet-pytorch") diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py index c6af4caf25..929c1aa6d9 100644 --- a/flash/image/detection/backbones.py +++ b/flash/image/detection/backbones.py @@ -21,8 +21,9 @@ load_icevision_ignore_image_size, load_icevision_with_image_size, ) -from flash.core.registry import Credit, FlashRegistry +from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _EFFDET, _ICEVISION, _MMDET, _TORCHVISION, _ULTRALYTICS if _ICEVISION_AVAILABLE: from icevision import models as icevision_models @@ -36,7 +37,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), - credits=Credit("IceVision", "www.icevision.com"), + providers=[_ICEVISION, _TORCHVISION], ) if _module_available("yolov5"): @@ -45,7 +46,7 @@ partial(load_icevision_with_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), - credits=Credit("IceVision", "www.icevision.com"), + providers=[_ICEVISION, _ULTRALYTICS], ) if _module_available("mmdet"): @@ -59,7 +60,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), - credits=Credit("IceVision", "www.icevision.com"), + providers=[_ICEVISION, _MMDET], ) if _module_available("effdet"): @@ -82,6 +83,5 @@ def validation_step(self, batch, batch_idx): partial(load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), - credits=[Credit("IceVision", "www.icevision.com"), - Credit("effdet", "github")], + providers=[_ICEVISION, _EFFDET], ) diff --git a/flash/image/instance_segmentation/backbones.py b/flash/image/instance_segmentation/backbones.py index ff4f4efabe..5a95242b82 100644 --- a/flash/image/instance_segmentation/backbones.py +++ b/flash/image/instance_segmentation/backbones.py @@ -20,6 +20,7 @@ ) from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _ICEVISION, _MMDET, _TORCHVISION if _ICEVISION_AVAILABLE: from icevision import models as icevision_models @@ -33,6 +34,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + providers=[_ICEVISION, _TORCHVISION] ) if _module_available("mmdet"): @@ -41,4 +43,5 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), + providers=[_ICEVISION, _MMDET] ) diff --git a/flash/image/keypoint_detection/backbones.py b/flash/image/keypoint_detection/backbones.py index 
f8f65c6e3a..1fe4174282 100644
--- a/flash/image/keypoint_detection/backbones.py
+++ b/flash/image/keypoint_detection/backbones.py
@@ -20,6 +20,7 @@
 )
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE
+from flash.core.utilities.providers import _ICEVISION, _TORCHVISION
 
 if _ICEVISION_AVAILABLE:
     from icevision import models as icevision_models
@@ -33,4 +34,5 @@
         partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type),
         model_type.__name__.split(".")[-1],
         backbones=get_backbones(model_type),
+        providers=[_ICEVISION, _TORCHVISION]
     )

From 2e8a777cfeb064267b9f032a1753715cb29938da Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 27 Jul 2021 18:45:16 +0100
Subject: [PATCH 16/46] Update available backbones

---
 flash/core/integrations/icevision/model.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/flash/core/integrations/icevision/model.py b/flash/core/integrations/icevision/model.py
index 34175dbb68..b5624960ee 100644
--- a/flash/core/integrations/icevision/model.py
+++ b/flash/core/integrations/icevision/model.py
@@ -103,10 +103,21 @@ def __init__(
         self.adapter = adapter(model=self.model, metrics=metrics)
 
     @classmethod
-    def available_backbones(cls, head: str) -> List[str]:
-        metadata = cls.heads.get(head, with_metadata=True)
-        backbones = metadata["metadata"]["backbones"]
-        return backbones.available_keys()
+    def available_backbones(cls, head: Optional[str] = None) -> Union[Dict[str, List[str]], List[str]]:
+        if head is None:
+            heads = cls.available_heads()
+        else:
+            heads = [head]
+
+        result = {}
+        for head in heads:
+            metadata = cls.heads.get(head, with_metadata=True)
+            backbones = metadata["metadata"]["backbones"]
+            result[head] = backbones.available_keys()
+
+        if len(result) == 1:
+            result = next(iter(result.values()))
+        return result
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint:

From a102d3118387d6d3df18a55e646b716fa71f1a5b Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Thu, 29 Jul 2021 12:32:40 +0100
Subject: [PATCH 17/46] Add adapter

---
 flash/core/adapter.py                         | 184 ++++++++++++++++++
 .../icevision/{model.py => adapter.py}        | 101 +++-------
 .../core/integrations/icevision/backbones.py  |   2 -
 flash/core/model.py                           | 154 +++++++--------
 flash/core/registry.py                        |  32 ++-
 flash/image/detection/backbones.py            |  37 ++++
 flash/image/detection/model.py                |  31 ++-
 .../image/instance_segmentation/backbones.py  |  35 ++++
 flash/image/instance_segmentation/model.py    |  26 ++-
 flash/image/keypoint_detection/backbones.py   |  35 ++++
 flash/image/keypoint_detection/model.py       |  27 ++-
 tests/core/test_model.py                      |  29 ++-
 12 files changed, 478 insertions(+), 215 deletions(-)
 create mode 100644 flash/core/adapter.py
 rename flash/core/integrations/icevision/{model.py => adapter.py} (60%)

diff --git a/flash/core/adapter.py b/flash/core/adapter.py
new file mode 100644
index 0000000000..642c449802
--- /dev/null
+++ b/flash/core/adapter.py
@@ -0,0 +1,184 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Union + +from pytorch_lightning import LightningModule, Trainer +from torch import nn +from torch.utils.data import DataLoader, Sampler + +import flash +from flash.core.data.auto_dataset import BaseAutoDataset + + +class Wrapper: + + def __init__(self): + super().__init__() + + self._children = [] + + def __setattr__(self, key, value): + if isinstance(value, (LightningModule, Adapter)): + self._children.append(key) + patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"] + if isinstance(value, Trainer) or key in patched_attributes: + if hasattr(self, "_children"): + for child in self._children: + setattr(getattr(self, child), key, value) + super().__setattr__(key, value) + + +class Adapter(Wrapper, nn.Module): + + @classmethod + def from_task(cls, task: 'flash.Task', **kwargs) -> 'Adapter': + pass + + def forward(self, x: Any) -> Any: + pass + + def training_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def validation_step(self, batch: Any, batch_idx: int) -> None: + pass + + def test_step(self, batch: Any, batch_idx: int) -> None: + pass + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + pass + + def training_epoch_end(self, outputs) -> None: + pass + + def validation_epoch_end(self, outputs) -> None: + pass + + def test_epoch_end(self, outputs) -> None: + pass + + def _process_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + convert_to_dataloader: bool = True, + ) -> DataLoader: + if convert_to_dataloader: + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn + ) + return dataset + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + 
dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + convert_to_dataloader: bool = True + ) -> Union[DataLoader, BaseAutoDataset]: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + convert_to_dataloader=convert_to_dataloader + ) diff --git a/flash/core/integrations/icevision/model.py b/flash/core/integrations/icevision/adapter.py similarity index 60% rename from flash/core/integrations/icevision/model.py rename to flash/core/integrations/icevision/adapter.py index b5624960ee..d9e9b47279 100644 --- a/flash/core/integrations/icevision/model.py +++ b/flash/core/integrations/icevision/adapter.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Union -import torch -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader, Sampler +from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset -from flash.core.data.process import Deserializer, Postprocess, Preprocess, Serializer from flash.core.model import Task from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -41,87 +39,42 @@ def finalize(self) -> Dict[str, float]: } -class IceVisionTask(Task): - """The ``IceVisionTask`` is a base :class:`~flash.Task` for integrating with IceVision. - - Args: - num_classes: the number of classes for detection, including background - model: a string of :attr`_models`. Defaults to 'fasterrcnn'. - backbone: Pretrained backbone CNN architecture. Constructs a model with a - ResNet-50-FPN backbone when no backbone is specified. - pretrained: if true, returns a model pre-trained on COCO train2017. - metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. 
- image_size - """ +class IceVisionAdapter(Adapter): + """The ``IceVisionAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with IceVision.""" required_extras: str = "image" - def __init__( - self, + def __init__(self, model_type, model, icevision_adapter, backbone): + super().__init__() + + self.model_type = model_type + self.model = model + self.icevision_adapter = icevision_adapter + self.backbone = backbone + + @classmethod + def from_task( + cls, + task: Task, num_classes: int, backbone: str, head: str, pretrained: bool = True, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, metrics: Optional['IceVisionMetric'] = None, - learning_rate: float = 5e-4, - deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, image_size: Optional = None, **kwargs, - ): - self.save_hyperparameters() - - super().__init__( - model=None, - metrics=None, - optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, - learning_rate=learning_rate, - deserializer=deserializer, - preprocess=preprocess, - postprocess=postprocess, - serializer=serializer, - ) - - metadata = self.heads.get(head, with_metadata=True) + ) -> Adapter: + metadata = task.heads.get(head, with_metadata=True) backbones = metadata["metadata"]["backbones"] backbone_config = backbones.get(backbone)(pretrained) - self.model_type, self.model, adapter, self.backbone = metadata["fn"]( + model_type, model, icevision_adapter, backbone = metadata["fn"]( backbone_config, num_classes, image_size=image_size, **kwargs, ) - self.adapter = adapter(model=self.model, metrics=metrics) - - @classmethod - def available_backbones(cls, head: Optional[str] = None) -> Union[Dict[str, List[str]], List[str]]: - if head is None: - heads = cls.available_heads() - else: - heads = [head] - - result = {} - for head in heads: - metadata = cls.heads.get(head, with_metadata=True) - backbones = metadata["metadata"]["backbones"] - result[head] = backbones.available_keys() - - if len(result) == 1: - result = next(iter(result.values()[0])) - return result - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: - checkpoint['_data_pipeline_state'] = self._data_pipeline_state + icevision_adapter = icevision_adapter(model=model, metrics=metrics) + return cls(model_type, model, icevision_adapter, backbone) def process_train_dataset( self, @@ -211,13 +164,13 @@ def process_predict_dataset( return dataset def training_step(self, batch, batch_idx) -> Any: - return self.adapter.training_step(batch, batch_idx) + return self.icevision_adapter.training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - return self.adapter.validation_step(batch, batch_idx) + return self.icevision_adapter.validation_step(batch, batch_idx) def test_step(self, batch, batch_idx): - return self.adapter.validation_step(batch, batch_idx) + return self.icevision_adapter.validation_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: return self(batch) @@ -229,10 
+182,10 @@ def forward(self, batch: Any) -> Any: return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) def training_epoch_end(self, outputs) -> None: - return self.adapter.training_epoch_end(outputs) + return self.icevision_adapter.training_epoch_end(outputs) def validation_epoch_end(self, outputs) -> None: - return self.adapter.validation_epoch_end(outputs) + return self.icevision_adapter.validation_epoch_end(outputs) def test_epoch_end(self, outputs) -> None: - return self.adapter.validation_epoch_end(outputs) + return self.icevision_adapter.validation_epoch_end(outputs) diff --git a/flash/core/integrations/icevision/backbones.py b/flash/core/integrations/icevision/backbones.py index 82225d8eb9..831c84337a 100644 --- a/flash/core/integrations/icevision/backbones.py +++ b/flash/core/integrations/icevision/backbones.py @@ -21,8 +21,6 @@ if _ICEVISION_AVAILABLE: from icevision.backbones import BackboneConfig -OBJECT_DETECTION_HEADS = FlashRegistry("heads") - def icevision_model_adapter(model_type): diff --git a/flash/core/model.py b/flash/core/model.py index 21fa1a40f3..a4654cbd86 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import inspect +import pickle from abc import ABCMeta from copy import deepcopy from importlib import import_module @@ -24,6 +25,7 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim.lr_scheduler import _LRScheduler @@ -31,6 +33,7 @@ from torch.utils.data import DataLoader, Sampler import flash +from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource @@ -102,7 +105,7 @@ def __new__(mcs, *args, **kwargs): return result -class Task(LightningModule, metaclass=CheckDependenciesMeta): +class Task(Adapter, LightningModule, metaclass=CheckDependenciesMeta): """A general Task. Args: @@ -164,18 +167,6 @@ def __init__( self.deserializer = deserializer self.serializer = serializer - self._children = [] - - def __setattr__(self, key, value): - if isinstance(value, LightningModule): - self._children.append(key) - patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"] - if isinstance(value, pl.Trainer) or key in patched_attributes: - if hasattr(self, "_children"): - for child in self._children: - setattr(getattr(self, child), key, value) - super().__setattr__(key, value) - def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: """ The training/validation/test step. Override for custom behavior. 
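The child-patching ``__setattr__`` removed from ``Task`` here resurfaces later in the series as the ``Wrapper`` mixin, and the model-specific logic moves behind the new ``Adapter`` interface. As a rough sketch of that interface (mirroring the ``BasicAdapter`` added to the tests below; the class name and the wrapped module are illustrative, not part of the patch):

    from typing import Any

    from flash.core.adapter import Adapter


    class DelegatingAdapter(Adapter):
        """A minimal Adapter sketch: wraps a LightningModule-like object and
        forwards the Lightning hooks to it."""

        def __init__(self, wrapped):
            super().__init__()
            self.wrapped = wrapped

        def forward(self, x: Any) -> Any:
            return self.wrapped(x)

        def training_step(self, batch: Any, batch_idx: int) -> Any:
            return self.wrapped.training_step(batch, batch_idx)

        def validation_step(self, batch: Any, batch_idx: int) -> None:
            return self.wrapped.validation_step(batch, batch_idx)

        def test_step(self, batch: Any, batch_idx: int) -> None:
            return self.wrapped.test_step(batch, batch_idx)

The ``AdapterTask`` introduced below forwards each of its hooks to such an adapter, which is what lets tasks like ``ObjectDetector`` swap backends without subclassing ``Task``.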
@@ -535,7 +526,11 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: - checkpoint['data_pipeline'] = self.data_pipeline + try: + pickle.dumps(self.data_pipeline) # TODO: DataPipeline not always pickleable + checkpoint['data_pipeline'] = self.data_pipeline + except AttributeError: + rank_zero_warn("DataPipeline couldn't be added to the checkpoint.") if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: checkpoint['_data_pipeline_state'] = self._data_pipeline_state super().on_save_checkpoint(checkpoint) @@ -548,11 +543,27 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self._data_pipeline_state = checkpoint['_data_pipeline_state'] @classmethod - def available_backbones(cls) -> List[str]: - registry: Optional[FlashRegistry] = getattr(cls, "backbones", None) - if registry is None: - return [] - return registry.available_keys() + def available_backbones(cls, head: Optional[str] = None) -> Union[Dict[str, List[str]], List[str]]: + if head is None: + registry: Optional[FlashRegistry] = getattr(cls, "backbones", None) + if registry is not None: + return registry.available_keys() + heads = cls.available_heads() + else: + heads = [head] + + result = {} + for head in heads: + metadata = cls.heads.get(head, with_metadata=True)["metadata"] + if "backbones" in metadata: + backbones = metadata["backbones"].available_keys() + else: + backbones = cls.available_backbones() + result[head] = backbones + + if len(result) == 1: + result = next(iter(result.values())) + return result @classmethod def available_heads(cls) -> List[str]: @@ -711,29 +722,41 @@ def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'): for state in self._state.values(): data_pipeline_state.set_state(state) - def _process_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True, - ) -> DataLoader: - if convert_to_dataloader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=collate_fn - ) - return dataset + +class AdapterTask(Task): + + def __init__(self, adapter: Adapter, **kwargs): + super().__init__(**kwargs) + + self.adapter = adapter + + @property + def backbone(self) -> nn.Module: + return self.adapter.backbone + + def forward(self, x: Any) -> Any: + return self.adapter.forward(x) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + return self.adapter.training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + def training_epoch_end(self, outputs) -> None: + return self.adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, 
outputs) -> None: + return self.adapter.test_epoch_end(outputs) def process_train_dataset( self, @@ -746,15 +769,8 @@ def process_train_dataset( drop_last: bool = True, sampler: Optional[Sampler] = None ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler + return self.adapter.process_train_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) def process_val_dataset( @@ -768,15 +784,8 @@ def process_val_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler + return self.adapter.process_val_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) def process_test_dataset( @@ -787,18 +796,11 @@ def process_test_dataset( pin_memory: bool, collate_fn: Callable, shuffle: bool = False, - drop_last: bool = True, + drop_last: bool = False, sampler: Optional[Sampler] = None ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler + return self.adapter.process_test_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) def process_predict_dataset( @@ -813,14 +815,6 @@ def process_predict_dataset( sampler: Optional[Sampler] = None, convert_to_dataloader: bool = True ) -> Union[DataLoader, BaseAutoDataset]: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - convert_to_dataloader=convert_to_dataloader + return self.adapter.process_predict_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler, convert_to_dataloader ) diff --git a/flash/core/registry.py b/flash/core/registry.py index 7d62d10f38..8e0c83366f 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union @@ -30,6 +31,23 @@ def __str__(self): return f"{self.name} ({self.url})" +def print_provider_info(name, providers, func): + if not isinstance(providers, List): + providers = [providers] + providers = list(providers) + if len(providers) > 1: + providers[-2] = f"{str(providers[-2])} and {str(providers[-1])}" + providers = providers[:-1] + message = f"Using '{name}' provided by {', '.join(str(provider) for provider in providers)}." + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank_zero_info(message) + return func(*args, **kwargs) + + return wrapper + + class FlashRegistry: """This class is used to register function or :class:`functools.partial` class to a registry.""" @@ -72,16 +90,6 @@ def get( if not matches: raise KeyError("Found no matches that fit your metadata criteria. 
Try removing some metadata") - for match in matches: - if "providers" in match["metadata"]: - providers = match["metadata"]["providers"] - if not isinstance(providers, List): - providers = [providers] - if len(providers) > 1: - providers[-2] = f"{str(providers[-2])} and {str(providers[-1])}" - providers = providers[:-1] - rank_zero_info(f"Using '{key}' provided by {', '.join(str(provider) for provider in providers)}.") - matches = [e if with_metadata else e["fn"] for e in matches] return matches[0] if strict else matches @@ -103,6 +111,10 @@ def _register_function( if self._verbose: rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}") + if "providers" in metadata: + providers = metadata["providers"] + fn = print_provider_info(name, providers, fn) + item = {"fn": fn, "name": name, "metadata": metadata or {}} matching_index = self._find_matching_index(item) diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py index 929c1aa6d9..0a9a8bbe06 100644 --- a/flash/image/detection/backbones.py +++ b/flash/image/detection/backbones.py @@ -12,24 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial +from typing import Optional import torch +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric from flash.core.integrations.icevision.backbones import ( get_backbones, icevision_model_adapter, load_icevision_ignore_image_size, load_icevision_with_image_size, ) +from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE from flash.core.utilities.providers import _EFFDET, _ICEVISION, _MMDET, _TORCHVISION, _ULTRALYTICS if _ICEVISION_AVAILABLE: from icevision import models as icevision_models + from icevision.metrics import COCOMetricType + from icevision.metrics import Metric as IceVisionMetric OBJECT_DETECTION_HEADS = FlashRegistry("heads") + +class IceVisionObjectDetectionAdapter(IceVisionAdapter): + + @classmethod + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str = "resnet18_fpn", + head: str = "retinanet", + pretrained: bool = True, + metrics: Optional['IceVisionMetric'] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.bbox)], + image_size=image_size, + **kwargs + ) + + if _ICEVISION_AVAILABLE: if _TORCHVISION_AVAILABLE: for model_type in [icevision_models.torchvision.retinanet, icevision_models.torchvision.faster_rcnn]: @@ -37,6 +70,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, providers=[_ICEVISION, _TORCHVISION], ) @@ -46,6 +80,7 @@ partial(load_icevision_with_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, providers=[_ICEVISION, _ULTRALYTICS], ) @@ -60,6 +95,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), + 
adapter=IceVisionObjectDetectionAdapter, providers=[_ICEVISION, _MMDET], ) @@ -83,5 +119,6 @@ def validation_step(self, batch, batch_idx): partial(load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, providers=[_ICEVISION, _EFFDET], ) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 816690f6e9..aa1cdc0180 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -17,18 +17,12 @@ from torch.optim import Optimizer from flash.core.data.process import Serializer -from flash.core.integrations.icevision.model import IceVisionTask, SimpleCOCOMetric +from flash.core.model import AdapterTask from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.detection.backbones import OBJECT_DETECTION_HEADS -from flash.image.detection.serialization import DetectionLabels -if _ICEVISION_AVAILABLE: - from icevision.metrics import COCOMetricType - from icevision.metrics import Metric as IceVisionMetric - -class ObjectDetector(IceVisionTask): +class ObjectDetector(AdapterTask): """The ``ObjectDetector`` is a :class:`~flash.Task` for detecting objects in images. For more details, see :ref:`object_detection`. @@ -62,31 +56,32 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "retinanet", pretrained: bool = True, - metrics: Optional['IceVisionMetric'] = None, - optimizer: Type[Optimizer] = torch.optim.AdamW, + optimizer: Type[Optimizer] = torch.optim.Adam, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, - image_size: Optional[int] = None, **kwargs: Any, ): self.save_hyperparameters() - super().__init__( + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, num_classes=num_classes, backbone=backbone, head=head, pretrained=pretrained, - metrics=metrics or [SimpleCOCOMetric(COCOMetricType.bbox)], - image_size=image_size, + **kwargs, + ) + + super().__init__( + adapter, learning_rate=learning_rate, optimizer=optimizer, - serializer=serializer or DetectionLabels(), - **kwargs, + serializer=serializer, ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """ This function is used only for debugging usage with CI """ - # todo (tchaton) Improve convergence - # history[-1]["val_iou"] + # todo diff --git a/flash/image/instance_segmentation/backbones.py b/flash/image/instance_segmentation/backbones.py index 5a95242b82..aca88e8b6a 100644 --- a/flash/image/instance_segmentation/backbones.py +++ b/flash/image/instance_segmentation/backbones.py @@ -12,21 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial +from typing import Optional +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric from flash.core.integrations.icevision.backbones import ( get_backbones, icevision_model_adapter, load_icevision_ignore_image_size, ) +from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE from flash.core.utilities.providers import _ICEVISION, _MMDET, _TORCHVISION if _ICEVISION_AVAILABLE: from icevision import models as icevision_models + from icevision.metrics import COCOMetricType + from icevision.metrics import Metric as IceVisionMetric INSTANCE_SEGMENTATION_HEADS = FlashRegistry("heads") + +class IceVisionInstanceSegmentationAdapter(IceVisionAdapter): + + @classmethod + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str = "resnet18_fpn", + head: str = "mask_rcnn", + pretrained: bool = True, + metrics: Optional['IceVisionMetric'] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.mask)], + image_size=image_size, + **kwargs + ) + + if _ICEVISION_AVAILABLE: if _TORCHVISION_AVAILABLE: model_type = icevision_models.torchvision.mask_rcnn @@ -34,6 +67,7 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + adapter=IceVisionInstanceSegmentationAdapter, providers=[_ICEVISION, _TORCHVISION] ) @@ -43,5 +77,6 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), + adapter=IceVisionInstanceSegmentationAdapter, providers=[_ICEVISION, _MMDET] ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 3b339a54ed..221bf0955c 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -17,17 +17,12 @@ from torch.optim import Optimizer from flash.core.data.process import Serializer -from flash.core.integrations.icevision.model import IceVisionTask, SimpleCOCOMetric +from flash.core.model import AdapterTask from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS -if _ICEVISION_AVAILABLE: - from icevision.metrics import COCOMetricType - from icevision.metrics import Metric as IceVisionMetric - -class InstanceSegmentation(IceVisionTask): +class InstanceSegmentation(AdapterTask): """The ``InstanceSegmentation`` is a :class:`~flash.Task` for detecting objects in images. For more details, see :ref:`object_detection`. 
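The ``adapter=...`` metadata attached to each head registration above is what the task constructors below consume. A condensed sketch of that lookup (shown for the object detection registry; the instance segmentation one is identical), assuming the ``image`` extras including IceVision are installed:

    from flash.image.detection.backbones import OBJECT_DETECTION_HEADS

    # `get` with `with_metadata=True` returns the registered builder together
    # with its metadata dict, which now carries both the backbones registry
    # and the Adapter class to construct via `from_task`.
    metadata = OBJECT_DETECTION_HEADS.get("retinanet", with_metadata=True)
    adapter_cls = metadata["metadata"]["adapter"]  # IceVisionObjectDetectionAdapter
    backbones = metadata["metadata"]["backbones"]  # FlashRegistry of backbone configs
    print(adapter_cls.__name__, backbones.available_keys())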
@@ -61,31 +56,32 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "mask_rcnn", pretrained: bool = True, - metrics: Optional['IceVisionMetric'] = None, optimizer: Type[Optimizer] = torch.optim.Adam, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, - image_size: Optional[int] = None, **kwargs: Any, ): self.save_hyperparameters() - super().__init__( + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, num_classes=num_classes, backbone=backbone, head=head, pretrained=pretrained, - metrics=metrics or [SimpleCOCOMetric(COCOMetricType.mask)], - image_size=image_size, + **kwargs, + ) + + super().__init__( + adapter, learning_rate=learning_rate, optimizer=optimizer, serializer=serializer, - **kwargs, ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """ This function is used only for debugging usage with CI """ - # todo (tchaton) Improve convergence - # history[-1]["val_iou"] + # todo diff --git a/flash/image/keypoint_detection/backbones.py b/flash/image/keypoint_detection/backbones.py index 1fe4174282..5da4bd7d3c 100644 --- a/flash/image/keypoint_detection/backbones.py +++ b/flash/image/keypoint_detection/backbones.py @@ -12,21 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial +from typing import Optional +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter from flash.core.integrations.icevision.backbones import ( get_backbones, icevision_model_adapter, load_icevision_ignore_image_size, ) +from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE from flash.core.utilities.providers import _ICEVISION, _TORCHVISION if _ICEVISION_AVAILABLE: from icevision import models as icevision_models + from icevision.metrics import Metric as IceVisionMetric KEYPOINT_DETECTION_HEADS = FlashRegistry("heads") + +class IceVisionKeypointDetectionAdapter(IceVisionAdapter): + + @classmethod + def from_task( + cls, + task: Task, + num_keypoints: int, + num_classes: int = 2, + backbone: str = "resnet18_fpn", + head: str = "keypoint_rcnn", + pretrained: bool = True, + metrics: Optional['IceVisionMetric'] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_keypoints=num_keypoints, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics, + image_size=image_size, + **kwargs + ) + + if _ICEVISION_AVAILABLE: if _TORCHVISION_AVAILABLE: model_type = icevision_models.torchvision.keypoint_rcnn @@ -34,5 +68,6 @@ partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), + adapter=IceVisionKeypointDetectionAdapter, providers=[_ICEVISION, _TORCHVISION] ) diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index a5f9802735..d648e28680 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -17,16 +17,12 @@ from torch.optim import Optimizer from flash.core.data.process import Serializer -from flash.core.integrations.icevision.model import IceVisionTask +from flash.core.model import AdapterTask from flash.core.registry import 
FlashRegistry -from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS -if _ICEVISION_AVAILABLE: - from icevision.metrics import Metric as IceVisionMetric - -class KeypointDetector(IceVisionTask): +class KeypointDetector(AdapterTask): """The ``ObjectDetector`` is a :class:`~flash.Task` for detecting objects in images. For more details, see :ref:`object_detection`. @@ -61,32 +57,33 @@ def __init__( backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "keypoint_rcnn", pretrained: bool = True, - metrics: Optional['IceVisionMetric'] = None, optimizer: Type[Optimizer] = torch.optim.Adam, learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, - image_size: Optional[int] = None, **kwargs: Any, ): self.save_hyperparameters() - super().__init__( - num_classes=num_classes, + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, num_keypoints=num_keypoints, + num_classes=num_classes, backbone=backbone, head=head, pretrained=pretrained, - metrics=metrics, - image_size=image_size, + **kwargs, + ) + + super().__init__( + adapter, learning_rate=learning_rate, optimizer=optimizer, serializer=serializer, - **kwargs, ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """ This function is used only for debugging usage with CI """ - # todo (tchaton) Improve convergence - # history[-1]["val_iou"] + # todo diff --git a/tests/core/test_model.py b/tests/core/test_model.py index eb04ecdb68..17115aa03e 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -28,6 +28,7 @@ from torch.utils.data import DataLoader import flash +from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE @@ -124,6 +125,32 @@ def __init__(self, child): super().__init__(Parent(child)) +class BasicAdapter(Adapter): + + def __init__(self, child): + super().__init__() + + self.child = child + + def training_step(self, batch, batch_idx): + return self.child.training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.child.validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.child.test_step(batch, batch_idx) + + def forward(self, x): + return self.child(x) + + +class AdapterParent(Parent): + + def __init__(self, child): + super().__init__(BasicAdapter(child)) + + # ================================ @@ -139,7 +166,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): assert "test_nll_loss" in result[0] -@pytest.mark.parametrize("task", [Parent, GrandParent]) +@pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent]) def test_nested_tasks(tmpdir, task): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) From ad7722e33a79f1309a058c7843c3cf43e1b2a6e0 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 29 Jul 2021 13:00:50 +0100 Subject: [PATCH 18/46] Fix a test --- tests/core/data/test_callback.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index e11591f33a..4b1ee25863 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ 
-23,8 +23,9 @@ from flash.core.trainer import Trainer +@mock.patch("pickle.dumps") # need to mock pickle or we get pickle error @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_flash_callback(_, tmpdir): +def test_flash_callback(_, __, tmpdir): """Test the callback hook system for fit.""" callback_mock = MagicMock() From 22afaae78ba25e19646b657614dfac14b9c64e40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jul 2021 12:19:33 +0000 Subject: [PATCH 19/46] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/core/integrations/icevision/transforms.py | 6 ++++-- flash/image/detection/data.py | 12 ++++++------ flash/image/detection/model.py | 4 +--- flash/image/instance_segmentation/data.py | 8 ++++---- flash/image/instance_segmentation/model.py | 4 +--- flash/image/keypoint_detection/data.py | 4 ++-- flash/image/keypoint_detection/model.py | 4 +--- 7 files changed, 19 insertions(+), 23 deletions(-) diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index e70dc9118d..22692be89a 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -28,7 +28,8 @@ def collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]: @requires_extras("image") def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: - """The default transforms for object detection: convert the image and targets to a tensor, collate the batch.""" + """The default transforms for object detection: convert the image and targets to a tensor, collate the + batch.""" return { # "pre_tensor_transform": ApplyToKeys( # DefaultDataKeys.INPUT, @@ -40,7 +41,8 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: @requires_extras("image") def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: - """The default transforms for object detection: convert the image and targets to a tensor, collate the batch.""" + """The default transforms for object detection: convert the image and targets to a tensor, collate the + batch.""" return { # "pre_tensor_transform": ApplyToKeys( # DefaultDataKeys.INPUT, diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 5a535024a8..1738132e2f 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -203,8 +203,8 @@ def from_coco( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and - annotation files in the COCO format. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the COCO format. Args: train_folder: The folder containing the train data. @@ -282,8 +282,8 @@ def from_voc( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and - annotation files in the VOC format. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the VOC format. Args: train_folder: The folder containing the train data. 
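For context, the ``from_coco``/``from_voc`` constructors whose docstrings are reflowed here are used roughly as follows; the dataset paths are placeholders and ``train_ann_file`` is assumed from the full signature, which is elided in these hunks:

    import flash
    from flash.image import ObjectDetectionData, ObjectDetector

    datamodule = ObjectDetectionData.from_coco(
        train_folder="data/coco128/images/train2017/",  # placeholder path
        train_ann_file="data/coco128/annotations/instances_train2017.json",
        val_split=0.1,
        batch_size=2,
    )

    model = ObjectDetector(
        head="retinanet",
        backbone="resnet18_fpn",
        num_classes=datamodule.num_classes,
    )

    trainer = flash.Trainer(max_epochs=1)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")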
@@ -361,8 +361,8 @@ def from_via( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and - annotation files in the VIA format. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the VIA format. Args: train_folder: The folder containing the train data. diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index aa1cdc0180..3d3974d8cc 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -81,7 +81,5 @@ def __init__( ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" # todo diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index ca21552d25..32e9d91693 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -97,8 +97,8 @@ def from_coco( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given - data folders and annotation files in the COCO format. + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the + given data folders and annotation files in the COCO format. Args: train_folder: The folder containing the train data. @@ -176,8 +176,8 @@ def from_voc( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given - data folders and annotation files in the VOC format. + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the + given data folders and annotation files in the VOC format. Args: train_folder: The folder containing the train data. diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 221bf0955c..8c9c355362 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -81,7 +81,5 @@ def __init__( ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" # todo diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 60b6b33fdb..3bd7f2c6f2 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -96,8 +96,8 @@ def from_coco( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given - data folders and annotation files in the COCO format. + """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given data + folders and annotation files in the COCO format. Args: train_folder: The folder containing the train data. 
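A matching sketch for the keypoint task touched above (the paths, ``num_keypoints`` value, and the availability of these names under ``flash.image`` are assumptions based on the examples added elsewhere in this PR):

    import flash
    from flash.image import KeypointDetectionData, KeypointDetector

    datamodule = KeypointDetectionData.from_coco(
        train_folder="data/biwi_sample/images",  # placeholder path
        train_ann_file="data/biwi_sample/coco_annotations.json",
        val_split=0.1,
    )

    model = KeypointDetector(
        head="keypoint_rcnn",
        backbone="resnet18_fpn",
        num_keypoints=1,
        num_classes=datamodule.num_classes,
    )

    trainer = flash.Trainer(max_epochs=1)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")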
diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index d648e28680..641652189f 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -83,7 +83,5 @@ def __init__( ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" # todo From e19b4c27aa19026363f9bf21e6398d4ef8e47b21 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 4 Aug 2021 20:15:05 +0100 Subject: [PATCH 20/46] Updates --- flash/core/data/data_pipeline.py | 14 +++++++------- flash/core/model.py | 5 +++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 6598a8d923..0329471e26 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -163,8 +163,13 @@ def _identity(samples: Sequence[Any]) -> Sequence[Any]: def deserialize_processor(self) -> _DeserializeProcessor: return self._create_collate_preprocessors(RunningStage.PREDICTING)[0] - def worker_preprocessor(self, running_stage: RunningStage, is_serving: bool = False) -> _Preprocessor: - return self._create_collate_preprocessors(running_stage, is_serving=is_serving)[1] + def worker_preprocessor( + self, + running_stage: RunningStage, + collate_fn: Optional[Callable] = None, + is_serving: bool = False + ) -> _Preprocessor: + return self._create_collate_preprocessors(running_stage, collate_fn=collate_fn, is_serving=is_serving)[1] def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor: return self._create_collate_preprocessors(running_stage)[2] @@ -219,7 +224,6 @@ def _create_collate_preprocessors( prefix: str = _STAGES_PREFIX[stage] if collate_fn is not None: - preprocess._original_default_collate = preprocess._default_collate preprocess._default_collate = collate_fn func_names: Dict[str, str] = { @@ -481,10 +485,6 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin elif isinstance(stage, RunningStage): stages = [stage] - self._preprocess_pipeline._default_collate = getattr( - self._preprocess_pipeline, "_original_default_collate", self._preprocess_pipeline._default_collate - ) - for stage in stages: device_collate = None diff --git a/flash/core/model.py b/flash/core/model.py index 9506aa1a4c..2495078423 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -255,8 +255,9 @@ def predict( data_pipeline = self.build_data_pipeline(data_source or "default", deserializer, data_pipeline) dataset = data_pipeline.data_source.generate_dataset(x, running_stage) - x = list(self.process_predict_dataset(dataset, convert_to_dataloader=False)) - x = data_pipeline.worker_preprocessor(running_stage)(x) + dataloader = self.process_predict_dataset(dataset, convert_to_dataloader=True) + x = list(dataloader.dataset) + x = data_pipeline.worker_preprocessor(running_stage, collate_fn=dataloader.collate_fn)(x) # todo (tchaton): Remove this when sync with Lightning master. 
if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3: x = self.transfer_batch_to_device(x, self.device, 0) From 4cf6332914f5667a61d9da2bd6aa9493df6ae0f8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 4 Aug 2021 21:40:44 +0100 Subject: [PATCH 21/46] Fixes --- flash/core/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/core/adapter.py b/flash/core/adapter.py index 642c449802..5650821b5e 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -165,7 +165,7 @@ def process_predict_dataset( batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, - collate_fn: Callable = lambda x: x, + collate_fn: Callable = None, shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, From 858acdb98a3391ced45a8e168ea564cc00c069f1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 4 Aug 2021 22:05:25 +0100 Subject: [PATCH 22/46] Refactor --- flash/core/adapter.py | 33 +++++++++----------- flash/core/data/data_module.py | 3 +- flash/core/integrations/icevision/adapter.py | 27 +++++++--------- flash/core/model.py | 9 +++--- flash/pointcloud/detection/model.py | 29 +++++++---------- flash/pointcloud/segmentation/model.py | 29 +++++++---------- 6 files changed, 55 insertions(+), 75 deletions(-) diff --git a/flash/core/adapter.py b/flash/core/adapter.py index 5650821b5e..90cbbc4f86 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional from pytorch_lightning import LightningModule, Trainer from torch import nn @@ -78,20 +78,17 @@ def _process_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True, + sampler: Optional[Sampler] = None ) -> DataLoader: - if convert_to_dataloader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=collate_fn - ) - return dataset + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn + ) def process_train_dataset( self, @@ -168,9 +165,8 @@ def process_predict_dataset( collate_fn: Callable = None, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True - ) -> Union[DataLoader, BaseAutoDataset]: + sampler: Optional[Sampler] = None + ) -> DataLoader: return self._process_dataset( dataset, batch_size=batch_size, @@ -179,6 +175,5 @@ def process_predict_dataset( collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, - sampler=sampler, - convert_to_dataloader=convert_to_dataloader + sampler=sampler ) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 32c0f0fd82..403eb64cfa 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -366,8 +366,7 @@ def _predict_dataloader(self) -> DataLoader: batch_size=batch_size, num_workers=self.num_workers, pin_memory=pin_memory, - collate_fn=collate_fn, - convert_to_dataloader=True, + collate_fn=collate_fn ) return DataLoader( diff --git a/flash/core/integrations/icevision/adapter.py 
b/flash/core/integrations/icevision/adapter.py index d9e9b47279..d2b4150a4d 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional from torch.utils.data import DataLoader, Sampler @@ -148,20 +148,17 @@ def process_predict_dataset( collate_fn: Callable = lambda x: x, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True - ) -> Union[DataLoader, BaseAutoDataset]: - if convert_to_dataloader: - return self.model_type.infer_dl( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - return dataset + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self.model_type.infer_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) def training_step(self, batch, batch_idx) -> Any: return self.icevision_adapter.training_step(batch, batch_idx) diff --git a/flash/core/model.py b/flash/core/model.py index 2495078423..e02ad065ef 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -255,7 +255,7 @@ def predict( data_pipeline = self.build_data_pipeline(data_source or "default", deserializer, data_pipeline) dataset = data_pipeline.data_source.generate_dataset(x, running_stage) - dataloader = self.process_predict_dataset(dataset, convert_to_dataloader=True) + dataloader = self.process_predict_dataset(dataset) x = list(dataloader.dataset) x = data_pipeline.worker_preprocessor(running_stage, collate_fn=dataloader.collate_fn)(x) # todo (tchaton): Remove this when sync with Lightning master. 
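Combined with the ``worker_preprocessor`` change from the previous commit, the reworked predict path amounts to the following; a simplified sketch of internal flow (the wrapper function is illustrative, not public API):

    from torch.utils.data import DataLoader


    def sketch_predict_flow(task, data_pipeline, dataset, running_stage):
        # The task (or its adapter, e.g. IceVision's `infer_dl`) decides how
        # the predict DataLoader is built...
        dataloader: DataLoader = task.process_predict_dataset(dataset)
        samples = list(dataloader.dataset)
        # ...and the loader's collate_fn is handed to the worker preprocessor
        # so that integration-specific collation is applied during `predict`.
        preprocess = data_pipeline.worker_preprocessor(
            running_stage, collate_fn=dataloader.collate_fn
        )
        return preprocess(samples)

This is what makes the ``convert_to_dataloader`` flag, deleted in this commit, unnecessary.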
@@ -817,9 +817,8 @@ def process_predict_dataset( collate_fn: Callable = lambda x: x, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True - ) -> Union[DataLoader, BaseAutoDataset]: + sampler: Optional[Sampler] = None + ) -> DataLoader: return self.adapter.process_predict_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler, convert_to_dataloader + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index d1abee600a..a123cf5c3e 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -161,9 +161,8 @@ def _process_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True, - ) -> Union[DataLoader, BaseAutoDataset]: + sampler: Optional[Sampler] = None + ) -> DataLoader: if not _POINTCLOUD_AVAILABLE: raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.") @@ -171,17 +170,13 @@ def _process_dataset( dataset.preprocess_fn = self.model.preprocess dataset.transform_fn = self.model.transform - if convert_to_dataloader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - else: - return dataset + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index b6de290b25..099d3a39ca 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -191,9 +191,8 @@ def _process_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True, - ) -> Union[DataLoader, BaseAutoDataset]: + sampler: Optional[Sampler] = None + ) -> DataLoader: if not _POINTCLOUD_AVAILABLE: raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.") @@ -207,20 +206,16 @@ def _process_dataset( use_cache=False, ) - if convert_to_dataloader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler, - ) - - else: - return dataset + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) def configure_finetune_callback(self) -> List[Callback]: return [PointCloudSegmentationFinetuning()] From 7c6fb2f6fd5bf3ea67e2c2cecf10fef1556e317a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 4 Aug 2021 22:17:11 +0100 Subject: [PATCH 23/46] Refactor --- flash/core/adapter.py | 61 +++++++++++++++++++++++-------------------- flash/core/model.py | 4 +-- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/flash/core/adapter.py b/flash/core/adapter.py index 90cbbc4f86..ca7efc6e91 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -39,35 +39,7 @@ def __setattr__(self, key, value): super().__setattr__(key, value) -class Adapter(Wrapper, nn.Module): - - 
@classmethod - def from_task(cls, task: 'flash.Task', **kwargs) -> 'Adapter': - pass - - def forward(self, x: Any) -> Any: - pass - - def training_step(self, batch: Any, batch_idx: int) -> Any: - pass - - def validation_step(self, batch: Any, batch_idx: int) -> None: - pass - - def test_step(self, batch: Any, batch_idx: int) -> None: - pass - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - pass - - def training_epoch_end(self, outputs) -> None: - pass - - def validation_epoch_end(self, outputs) -> None: - pass - - def test_epoch_end(self, outputs) -> None: - pass +class DatasetProcessor: def _process_dataset( self, @@ -177,3 +149,34 @@ def process_predict_dataset( drop_last=drop_last, sampler=sampler ) + + +class Adapter(DatasetProcessor, Wrapper, nn.Module): + + @classmethod + def from_task(cls, task: 'flash.Task', **kwargs) -> 'Adapter': + pass + + def forward(self, x: Any) -> Any: + pass + + def training_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def validation_step(self, batch: Any, batch_idx: int) -> None: + pass + + def test_step(self, batch: Any, batch_idx: int) -> None: + pass + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + pass + + def training_epoch_end(self, outputs) -> None: + pass + + def validation_epoch_end(self, outputs) -> None: + pass + + def test_epoch_end(self, outputs) -> None: + pass diff --git a/flash/core/model.py b/flash/core/model.py index e02ad065ef..9f67a20f56 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -33,7 +33,7 @@ from torch.utils.data import DataLoader, Sampler import flash -from flash.core.adapter import Adapter +from flash.core.adapter import Adapter, DatasetProcessor, Wrapper from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource @@ -103,7 +103,7 @@ def __new__(mcs, *args, **kwargs): return result -class Task(Adapter, LightningModule, metaclass=CheckDependenciesMeta): +class Task(DatasetProcessor, Wrapper, LightningModule, metaclass=CheckDependenciesMeta): """A general Task. Args: From 89f697835eaf434932a7333acbd801d21340957e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 4 Aug 2021 22:46:07 +0100 Subject: [PATCH 24/46] Refactor --- flash/core/adapter.py | 176 +++++++--------- flash/core/model.py | 229 ++++++++++++--------- flash/image/detection/model.py | 2 +- flash/image/instance_segmentation/model.py | 2 +- flash/image/keypoint_detection/model.py | 2 +- 5 files changed, 206 insertions(+), 205 deletions(-) diff --git a/flash/core/adapter.py b/flash/core/adapter.py index ca7efc6e91..f63579dc66 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -11,56 +11,83 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from abc import abstractmethod from typing import Any, Callable, Optional -from pytorch_lightning import LightningModule, Trainer from torch import nn from torch.utils.data import DataLoader, Sampler import flash from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.model import DatasetProcessor, Task, Wrapper -class Wrapper: +class Adapter(DatasetProcessor, Wrapper, nn.Module): - def __init__(self): - super().__init__() + @classmethod + @abstractmethod + def from_task(cls, task: 'flash.Task', **kwargs) -> 'Adapter': + pass - self._children = [] + def forward(self, x: Any) -> Any: + pass - def __setattr__(self, key, value): - if isinstance(value, (LightningModule, Adapter)): - self._children.append(key) - patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"] - if isinstance(value, Trainer) or key in patched_attributes: - if hasattr(self, "_children"): - for child in self._children: - setattr(getattr(self, child), key, value) - super().__setattr__(key, value) + def training_step(self, batch: Any, batch_idx: int) -> Any: + pass + def validation_step(self, batch: Any, batch_idx: int) -> None: + pass -class DatasetProcessor: + def test_step(self, batch: Any, batch_idx: int) -> None: + pass - def _process_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = True, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=collate_fn - ) + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + pass + + def training_epoch_end(self, outputs) -> None: + pass + + def validation_epoch_end(self, outputs) -> None: + pass + + def test_epoch_end(self, outputs) -> None: + pass + + +class AdapterTask(Task): + + def __init__(self, adapter: Adapter, **kwargs): + super().__init__(**kwargs) + + self.adapter = adapter + + @property + def backbone(self) -> nn.Module: + return self.adapter.backbone + + def forward(self, x: Any) -> Any: + return self.adapter.forward(x) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + return self.adapter.training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + def training_epoch_end(self, outputs) -> None: + return self.adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + return self.adapter.test_epoch_end(outputs) def process_train_dataset( self, @@ -73,15 +100,8 @@ def process_train_dataset( drop_last: bool = True, sampler: Optional[Sampler] = None ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler + return self.adapter.process_train_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) def 
process_val_dataset( @@ -95,15 +115,8 @@ def process_val_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler + return self.adapter.process_val_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) def process_test_dataset( @@ -114,18 +127,11 @@ def process_test_dataset( pin_memory: bool, collate_fn: Callable, shuffle: bool = False, - drop_last: bool = True, + drop_last: bool = False, sampler: Optional[Sampler] = None ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler + return self.adapter.process_test_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) def process_predict_dataset( @@ -134,49 +140,11 @@ def process_predict_dataset( batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, - collate_fn: Callable = None, + collate_fn: Callable = lambda x: x, shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None ) -> DataLoader: - return self._process_dataset( - dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - collate_fn=collate_fn, - shuffle=shuffle, - drop_last=drop_last, - sampler=sampler + return self.adapter.process_predict_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler ) - - -class Adapter(DatasetProcessor, Wrapper, nn.Module): - - @classmethod - def from_task(cls, task: 'flash.Task', **kwargs) -> 'Adapter': - pass - - def forward(self, x: Any) -> Any: - pass - - def training_step(self, batch: Any, batch_idx: int) -> Any: - pass - - def validation_step(self, batch: Any, batch_idx: int) -> None: - pass - - def test_step(self, batch: Any, batch_idx: int) -> None: - pass - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - pass - - def training_epoch_end(self, outputs) -> None: - pass - - def validation_epoch_end(self, outputs) -> None: - pass - - def test_epoch_end(self, outputs) -> None: - pass diff --git a/flash/core/model.py b/flash/core/model.py index 9f67a20f56..2505e99465 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -22,7 +22,7 @@ import pytorch_lightning as pl import torch import torchmetrics -from pytorch_lightning import LightningModule +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn @@ -33,7 +33,6 @@ from torch.utils.data import DataLoader, Sampler import flash -from flash.core.adapter import Adapter, DatasetProcessor, Wrapper from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource @@ -53,6 +52,136 @@ from flash.core.utilities.imports import requires_extras +class Wrapper: + + def __init__(self): + super().__init__() + + self._children = [] + + def __setattr__(self, key, value): + if isinstance(value, (LightningModule, Wrapper)): + self._children.append(key) + patched_attributes = ["_current_fx_name", 
"_current_hook_fx_name", "_results"] + if isinstance(value, Trainer) or key in patched_attributes: + if hasattr(self, "_children"): + for child in self._children: + setattr(getattr(self, child), key, value) + super().__setattr__(key, value) + + +class DatasetProcessor: + + def _process_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn + ) + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = None, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + class BenchmarkConvergenceCI(Callback): def __init__(self): @@ -726,99 +855,3 @@ def set_state(self, state: ProcessState): def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'): for state in self._state.values(): data_pipeline_state.set_state(state) - - -class AdapterTask(Task): - - def __init__(self, adapter: Adapter, **kwargs): - super().__init__(**kwargs) - - self.adapter = adapter - - @property - def backbone(self) -> nn.Module: - return self.adapter.backbone - - def forward(self, x: Any) -> Any: - return self.adapter.forward(x) - - def training_step(self, batch: Any, batch_idx: int) -> Any: - return self.adapter.training_step(batch, batch_idx) - - def validation_step(self, batch: Any, batch_idx: int) -> None: - return self.adapter.validation_step(batch, batch_idx) - - def test_step(self, batch: Any, batch_idx: int) -> None: - return self.adapter.test_step(batch, batch_idx) - - def predict_step(self, batch: Any, batch_idx: 
int, dataloader_idx: int = 0) -> Any: - return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) - - def training_epoch_end(self, outputs) -> None: - return self.adapter.training_epoch_end(outputs) - - def validation_epoch_end(self, outputs) -> None: - return self.adapter.validation_epoch_end(outputs) - - def test_epoch_end(self, outputs) -> None: - return self.adapter.test_epoch_end(outputs) - - def process_train_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = True, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return self.adapter.process_train_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler - ) - - def process_val_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return self.adapter.process_val_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler - ) - - def process_test_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int, - num_workers: int, - pin_memory: bool, - collate_fn: Callable, - shuffle: bool = False, - drop_last: bool = False, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return self.adapter.process_test_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler - ) - - def process_predict_dataset( - self, - dataset: BaseAutoDataset, - batch_size: int = 1, - num_workers: int = 0, - pin_memory: bool = False, - collate_fn: Callable = lambda x: x, - shuffle: bool = False, - drop_last: bool = True, - sampler: Optional[Sampler] = None - ) -> DataLoader: - return self.adapter.process_predict_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler - ) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 3d3974d8cc..cc5e2bb545 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -16,8 +16,8 @@ import torch from torch.optim import Optimizer +from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer -from flash.core.model import AdapterTask from flash.core.registry import FlashRegistry from flash.image.detection.backbones import OBJECT_DETECTION_HEADS diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 8c9c355362..52f2706554 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -16,8 +16,8 @@ import torch from torch.optim import Optimizer +from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer -from flash.core.model import AdapterTask from flash.core.registry import FlashRegistry from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 641652189f..b85177d083 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -16,8 +16,8 @@ import torch from torch.optim import Optimizer +from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer -from flash.core.model import AdapterTask from flash.core.registry import FlashRegistry 
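
The hunks above complete the core of this refactor: each task module now imports ``AdapterTask`` from ``flash.core.adapter``, and the task forwards every Lightning hook and every ``process_*_dataset`` call to the adapter it wraps. A minimal sketch of the delegation pattern, using simplified, hypothetical class names rather than the exact Flash classes:

```python
# Sketch only: simplified stand-ins for Adapter / AdapterTask that illustrate
# the delegation pattern in the diffs above. All names here are hypothetical.
from typing import Any

import torch
from torch import nn


class MiniAdapter(nn.Module):
    """Encapsulates provider-specific logic behind Lightning-style hooks."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Any) -> Any:
        return self.model(x)

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        x, y = batch
        return nn.functional.mse_loss(self(x), y)


class MiniAdapterTask(nn.Module):
    """Owns no provider logic itself; every hook is forwarded to the adapter."""

    def __init__(self, adapter: MiniAdapter):
        super().__init__()
        self.adapter = adapter

    def forward(self, x: Any) -> Any:
        return self.adapter(x)

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        return self.adapter.training_step(batch, batch_idx)


# Swapping providers then only means constructing a different adapter:
task = MiniAdapterTask(MiniAdapter(nn.Linear(4, 2)))
loss = task.training_step((torch.randn(8, 4), torch.randn(8, 2)), 0)
```

The point of the indirection is that a provider package such as IceVision can own its own DataLoader construction (via the ``process_*_dataset`` hooks) without the base ``Task`` needing to know anything about it.
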
from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS From 53e171e49b5fbc9aeb3ffb2774b8a644c47bd1fe Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 6 Aug 2021 12:01:52 +0100 Subject: [PATCH 25/46] minor changes --- flash/image/detection/model.py | 2 +- flash/pointcloud/detection/data.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index cc5e2bb545..c2bcd606f6 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -29,7 +29,7 @@ class ObjectDetector(AdapterTask): Args: num_classes: the number of classes for detection, including background model: a string of :attr`_models`. Defaults to 'fasterrcnn'. - backbone: Pretained backbone CNN architecture. Constructs a model with a + backbone: Pretrained backbone CNN architecture. Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. pretrained: if true, returns a model pre-trained on COCO train2017 diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 4527eba22b..fcb8ad304e 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -6,7 +6,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.process import Deserializer, Preprocess -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, requires_extras if _POINTCLOUD_AVAILABLE: from flash.pointcloud.detection.open3d_ml.data_sources import ( @@ -14,7 +14,7 @@ PointCloudObjectDetectorFoldersDataSource, ) else: - PointCloudObjectDetectorFoldersDataSource = object() + PointCloudObjectDetectorFoldersDataSource = object class PointCloudObjectDetectionDataFormat: KITTI = None @@ -46,6 +46,7 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: class PointCloudObjectDetectorPreprocess(Preprocess): + @requires_extras("pointcloud") def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, From cb3a2f05840d1b73a3389b015232664267f33f1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Aug 2021 11:39:28 +0000 Subject: [PATCH 26/46] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/core/adapter.py | 12 +++++------- flash/core/integrations/icevision/adapter.py | 11 +++++------ flash/core/integrations/icevision/backbones.py | 2 -- flash/core/integrations/icevision/data.py | 5 +---- flash/image/detection/backbones.py | 12 +++++------- flash/image/instance_segmentation/backbones.py | 9 ++++----- flash/image/instance_segmentation/data.py | 1 - flash/image/keypoint_detection/backbones.py | 7 +++---- flash/image/keypoint_detection/data.py | 1 - flash_examples/instance_segmentation.py | 12 +++++++----- flash_examples/keypoint_detection.py | 12 +++++++----- 11 files changed, 37 insertions(+), 47 deletions(-) diff --git a/flash/core/adapter.py b/flash/core/adapter.py index f63579dc66..d55051a148 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -23,10 +23,9 @@ class Adapter(DatasetProcessor, Wrapper, nn.Module): - @classmethod @abstractmethod - def from_task(cls, task: 'flash.Task', **kwargs) -> 'Adapter': + def 
from_task(cls, task: "flash.Task", **kwargs) -> "Adapter": pass def forward(self, x: Any) -> Any: @@ -55,7 +54,6 @@ def test_epoch_end(self, outputs) -> None: class AdapterTask(Task): - def __init__(self, adapter: Adapter, **kwargs): super().__init__(**kwargs) @@ -98,7 +96,7 @@ def process_train_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_train_dataset( dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler @@ -113,7 +111,7 @@ def process_val_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = False, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_val_dataset( dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler @@ -128,7 +126,7 @@ def process_test_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = False, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_test_dataset( dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler @@ -143,7 +141,7 @@ def process_predict_dataset( collate_fn: Callable = lambda x: x, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_predict_dataset( dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index d2b4150a4d..5d3201116a 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -30,7 +30,6 @@ class SimpleCOCOMetric(COCOMetric): - def finalize(self) -> Dict[str, float]: logs = super().finalize() return { @@ -60,7 +59,7 @@ def from_task( backbone: str, head: str, pretrained: bool = True, - metrics: Optional['IceVisionMetric'] = None, + metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, **kwargs, ) -> Adapter: @@ -85,7 +84,7 @@ def process_train_dataset( collate_fn: Optional[Callable] = None, shuffle: bool = False, drop_last: bool = False, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.model_type.train_dl( dataset, @@ -106,7 +105,7 @@ def process_val_dataset( collate_fn: Optional[Callable] = None, shuffle: bool = False, drop_last: bool = False, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.model_type.valid_dl( dataset, @@ -127,7 +126,7 @@ def process_test_dataset( collate_fn: Optional[Callable] = None, shuffle: bool = False, drop_last: bool = False, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.model_type.valid_dl( dataset, @@ -148,7 +147,7 @@ def process_predict_dataset( collate_fn: Callable = lambda x: x, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self.model_type.infer_dl( dataset, diff --git a/flash/core/integrations/icevision/backbones.py b/flash/core/integrations/icevision/backbones.py index 831c84337a..dd30d3be56 100644 --- a/flash/core/integrations/icevision/backbones.py +++ b/flash/core/integrations/icevision/backbones.py @@ -23,9 +23,7 @@ def 
icevision_model_adapter(model_type): - class IceVisionModelAdapter(model_type.lightning.ModelAdapter): - def log(self, name, value, **kwargs): if "prog_bar" not in kwargs: kwargs["prog_bar"] = True diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index f3d726e2fd..fb9813658a 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -26,7 +26,6 @@ class IceVisionPathsDataSource(ImagePathsDataSource): - def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: return super().predict_load_data(data, dataset) @@ -44,8 +43,7 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class IceVisionParserDataSource(IceVisionPathsDataSource): - - def __init__(self, parser: Optional[Type['Parser']] = None): + def __init__(self, parser: Optional[Type["Parser"]] = None): super().__init__() self.parser = parser @@ -62,7 +60,6 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq class IceDataParserDataSource(IceVisionPathsDataSource): - def __init__(self, parser: Optional[Callable] = None): super().__init__() self.parser = parser diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py index 0a9a8bbe06..c3e9d5cfad 100644 --- a/flash/image/detection/backbones.py +++ b/flash/image/detection/backbones.py @@ -38,7 +38,6 @@ class IceVisionObjectDetectionAdapter(IceVisionAdapter): - @classmethod def from_task( cls, @@ -47,7 +46,7 @@ def from_task( backbone: str = "resnet18_fpn", head: str = "retinanet", pretrained: bool = True, - metrics: Optional['IceVisionMetric'] = None, + metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, **kwargs, ) -> Adapter: @@ -59,7 +58,7 @@ def from_task( pretrained=pretrained, metrics=metrics or [SimpleCOCOMetric(COCOMetricType.bbox)], image_size=image_size, - **kwargs + **kwargs, ) @@ -102,14 +101,13 @@ def from_task( if _module_available("effdet"): def _icevision_effdet_model_adapter(model_type): - class IceVisionEffdetModelAdapter(icevision_model_adapter(model_type)): - def validation_step(self, batch, batch_idx): images = batch[0][0] batch[0][1]["img_scale"] = torch.ones_like(images[:, 0, 0, 0]).unsqueeze(1) - batch[0][1]["img_size"] = (torch.ones_like(images[:, 0, 0, 0]) * - images[0].shape[-1]).unsqueeze(1).repeat(1, 2) + batch[0][1]["img_size"] = ( + (torch.ones_like(images[:, 0, 0, 0]) * images[0].shape[-1]).unsqueeze(1).repeat(1, 2) + ) return super().validation_step(batch, batch_idx) return IceVisionEffdetModelAdapter diff --git a/flash/image/instance_segmentation/backbones.py b/flash/image/instance_segmentation/backbones.py index aca88e8b6a..9811d6fa78 100644 --- a/flash/image/instance_segmentation/backbones.py +++ b/flash/image/instance_segmentation/backbones.py @@ -35,7 +35,6 @@ class IceVisionInstanceSegmentationAdapter(IceVisionAdapter): - @classmethod def from_task( cls, @@ -44,7 +43,7 @@ def from_task( backbone: str = "resnet18_fpn", head: str = "mask_rcnn", pretrained: bool = True, - metrics: Optional['IceVisionMetric'] = None, + metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, **kwargs, ) -> Adapter: @@ -56,7 +55,7 @@ def from_task( pretrained=pretrained, metrics=metrics or [SimpleCOCOMetric(COCOMetricType.mask)], image_size=image_size, - **kwargs + **kwargs, ) @@ -68,7 +67,7 @@ def from_task( model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), 
adapter=IceVisionInstanceSegmentationAdapter, - providers=[_ICEVISION, _TORCHVISION] + providers=[_ICEVISION, _TORCHVISION], ) if _module_available("mmdet"): @@ -78,5 +77,5 @@ def from_task( f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), adapter=IceVisionInstanceSegmentationAdapter, - providers=[_ICEVISION, _MMDET] + providers=[_ICEVISION, _MMDET], ) diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 32e9d91693..b67e606683 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -30,7 +30,6 @@ class InstanceSegmentationPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, diff --git a/flash/image/keypoint_detection/backbones.py b/flash/image/keypoint_detection/backbones.py index 5da4bd7d3c..72334761f2 100644 --- a/flash/image/keypoint_detection/backbones.py +++ b/flash/image/keypoint_detection/backbones.py @@ -34,7 +34,6 @@ class IceVisionKeypointDetectionAdapter(IceVisionAdapter): - @classmethod def from_task( cls, @@ -44,7 +43,7 @@ def from_task( backbone: str = "resnet18_fpn", head: str = "keypoint_rcnn", pretrained: bool = True, - metrics: Optional['IceVisionMetric'] = None, + metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, **kwargs, ) -> Adapter: @@ -57,7 +56,7 @@ def from_task( pretrained=pretrained, metrics=metrics, image_size=image_size, - **kwargs + **kwargs, ) @@ -69,5 +68,5 @@ def from_task( model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), adapter=IceVisionKeypointDetectionAdapter, - providers=[_ICEVISION, _TORCHVISION] + providers=[_ICEVISION, _TORCHVISION], ) diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 3bd7f2c6f2..48e4b06a44 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -30,7 +30,6 @@ class KeypointDetectionPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py index 5e451dcb9b..16e5699d14 100644 --- a/flash_examples/instance_segmentation.py +++ b/flash_examples/instance_segmentation.py @@ -43,11 +43,13 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! -predictions = model.predict([ - str(data_dir / "images/yorkshire_terrier_9.jpg"), - str(data_dir / "images/english_cocker_spaniel_1.jpg"), - str(data_dir / "images/scottish_terrier_1.jpg"), -]) +predictions = model.predict( + [ + str(data_dir / "images/yorkshire_terrier_9.jpg"), + str(data_dir / "images/english_cocker_spaniel_1.jpg"), + str(data_dir / "images/scottish_terrier_1.jpg"), + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py index ea53dfde78..731f0a8125 100644 --- a/flash_examples/keypoint_detection.py +++ b/flash_examples/keypoint_detection.py @@ -42,11 +42,13 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! 
-predictions = model.predict([ - str(data_dir / "biwi_sample/images/0.jpg"), - str(data_dir / "biwi_sample/images/1.jpg"), - str(data_dir / "biwi_sample/images/10.jpg"), -]) +predictions = model.predict( + [ + str(data_dir / "biwi_sample/images/0.jpg"), + str(data_dir / "biwi_sample/images/1.jpg"), + str(data_dir / "biwi_sample/images/10.jpg"), + ] +) print(predictions) # 5. Save the model! From 8725028e0c8762256d920761be6cdef17ea7980c Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 6 Aug 2021 16:37:12 +0200 Subject: [PATCH 27/46] 0.5.0dev --- flash/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/__about__.py b/flash/__about__.py index e57715c058..eab8629bc9 100644 --- a/flash/__about__.py +++ b/flash/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.1dev" +__version__ = "0.5.0dev" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" From 335073a7375c4e2478621a24e42c4d257beeca10 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 9 Aug 2021 19:52:56 +0200 Subject: [PATCH 28/46] pl --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0693689f06..2e0fcbab49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch torchmetrics -pytorch-lightning>=1.4.0rc0 +pytorch-lightning>=1.4.0 pyDeprecate PyYAML>=5.1 numpy From 19143db6d1a8e1ac3f191e02d9b929b61ce7105f Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 9 Aug 2021 22:18:00 +0200 Subject: [PATCH 29/46] imports --- flash/core/integrations/icevision/data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index fb9813658a..99022ecc19 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -20,9 +20,10 @@ from flash.image.data import ImagePathsDataSource if _ICEVISION_AVAILABLE: - from icevision.core import BaseRecord, ClassMapRecordComponent, ImageRecordComponent, tasks - from icevision.data import SingleSplitSplitter - from icevision.parsers import Parser + from icevision.core.record import BaseRecord + from icevision.core.record_components import ClassMapRecordComponent, ImageRecordComponent, tasks + from icevision.data.data_splitter import SingleSplitSplitter + from icevision.parsers.parser import Parser class IceVisionPathsDataSource(ImagePathsDataSource): From b72375eff796b9430a0d46c0b4fde447401e36cb Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 9 Aug 2021 21:21:15 +0100 Subject: [PATCH 30/46] Update adapter.py --- flash/core/integrations/icevision/adapter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 5d3201116a..384d7ecdc9 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -18,6 +18,7 @@ from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.model import Task +from flash.core.utilities.url_error import catch_url_error from flash.core.utilities.imports import _ICEVISION_AVAILABLE if _ICEVISION_AVAILABLE: @@ -52,6 +53,7 @@ def __init__(self, model_type, model, icevision_adapter, backbone): self.backbone = backbone @classmethod + @catch_url_error def from_task( cls, task: Task, From 5a1cb64a1fd1c319b9faff901c1194367fb285bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Aug 2021 20:21:45 +0000 Subject: [PATCH 31/46] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/core/integrations/icevision/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 384d7ecdc9..5a65cfd43c 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -18,8 +18,8 @@ from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.model import Task -from flash.core.utilities.url_error import catch_url_error from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.core.utilities.url_error import catch_url_error if _ICEVISION_AVAILABLE: from icevision.core import BaseRecord From 55377f1a90f59ae49fd6345dc58376e570c7fdf2 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 9 Aug 2021 22:21:55 +0100 Subject: [PATCH 32/46] Update adapter.py --- flash/core/integrations/icevision/adapter.py | 27 +++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 5a65cfd43c..c6fd79c104 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import urllib.error from typing import Any, Callable, Dict, Optional from torch.utils.data import DataLoader, Sampler @@ -67,13 +68,25 @@ def from_task( ) -> Adapter: metadata = task.heads.get(head, with_metadata=True) backbones = metadata["metadata"]["backbones"] - backbone_config = backbones.get(backbone)(pretrained) - model_type, model, icevision_adapter, backbone = metadata["fn"]( - backbone_config, - num_classes, - image_size=image_size, - **kwargs, - ) + try: + backbone_config = backbones.get(backbone)(pretrained) + model_type, model, icevision_adapter, backbone = metadata["fn"]( + backbone_config, + num_classes, + image_size=image_size, + **kwargs, + ) + except urllib.error.URLError: + pretrained = False + if "efficientdet" in head: + kwargs["pretrained_backbone"] = False + backbone_config = backbones.get(backbone)(pretrained) + model_type, model, icevision_adapter, backbone = metadata["fn"]( + backbone_config, + num_classes, + image_size=image_size, + **kwargs, + ) icevision_adapter = icevision_adapter(model=model, metrics=metrics) return cls(model_type, model, icevision_adapter, backbone) From 68648ab4820eeec2b649c7b6669cda71113c3518 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 18:23:26 +0100 Subject: [PATCH 33/46] Updates --- flash/core/integrations/icevision/adapter.py | 27 +++++--------------- flash/core/utilities/url_error.py | 3 +++ 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index c6fd79c104..5a65cfd43c 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import urllib.error from typing import Any, Callable, Dict, Optional from torch.utils.data import DataLoader, Sampler @@ -68,25 +67,13 @@ def from_task( ) -> Adapter: metadata = task.heads.get(head, with_metadata=True) backbones = metadata["metadata"]["backbones"] - try: - backbone_config = backbones.get(backbone)(pretrained) - model_type, model, icevision_adapter, backbone = metadata["fn"]( - backbone_config, - num_classes, - image_size=image_size, - **kwargs, - ) - except urllib.error.URLError: - pretrained = False - if "efficientdet" in head: - kwargs["pretrained_backbone"] = False - backbone_config = backbones.get(backbone)(pretrained) - model_type, model, icevision_adapter, backbone = metadata["fn"]( - backbone_config, - num_classes, - image_size=image_size, - **kwargs, - ) + backbone_config = backbones.get(backbone)(pretrained) + model_type, model, icevision_adapter, backbone = metadata["fn"]( + backbone_config, + num_classes, + image_size=image_size, + **kwargs, + ) icevision_adapter = icevision_adapter(model=model, metrics=metrics) return cls(model_type, model, icevision_adapter, backbone) diff --git a/flash/core/utilities/url_error.py b/flash/core/utilities/url_error.py index 83559131c9..6f0d28676a 100644 --- a/flash/core/utilities/url_error.py +++ b/flash/core/utilities/url_error.py @@ -23,6 +23,9 @@ def wrapper(*args, pretrained=False, **kwargs): try: return fn(*args, pretrained=pretrained, **kwargs) except urllib.error.URLError: + # Hack for icevision/efficientdet to work without internet access + if "efficientdet" in kwargs.get("head", ""): + kwargs["pretrained_backbone"] = False result = fn(*args, pretrained=False, **kwargs) rank_zero_warn( "Failed to download pretrained weights for the selected backbone. The backbone has been created with" From 12a89ddc86472d3688b6845dfe28048671fe6796 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 11 Aug 2021 15:04:52 +0100 Subject: [PATCH 34/46] Add transforms to and from icevision records --- flash/core/integrations/icevision/adapter.py | 32 ++- flash/core/integrations/icevision/data.py | 6 +- .../core/integrations/icevision/transforms.py | 183 ++++++++++++++++-- flash/core/model.py | 50 ++--- 4 files changed, 216 insertions(+), 55 deletions(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 5a65cfd43c..a550d43619 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -11,19 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
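
The ``catch_url_error`` decorator extended just above exists so that backbone constructors degrade gracefully when pretrained weights cannot be downloaded. A minimal sketch of the same fallback pattern, assuming a simplified name and signature (the real implementation lives in ``flash/core/utilities/url_error.py`` and warns via ``rank_zero_warn``):

```python
# Sketch of the URL-error fallback pattern; a simplified stand-in, not
# Flash's exact implementation.
import functools
import urllib.error
import warnings


def fall_back_to_random_weights(fn):
    @functools.wraps(fn)
    def wrapper(*args, pretrained=False, **kwargs):
        try:
            return fn(*args, pretrained=pretrained, **kwargs)
        except urllib.error.URLError:
            # No internet access (or the weight host is down): build the same
            # model with randomly initialised weights instead of crashing.
            warnings.warn("Failed to download pretrained weights; using random initialisation.")
            return fn(*args, pretrained=False, **kwargs)

    return wrapper


@fall_back_to_random_weights
def load_backbone(name: str, pretrained: bool = False):
    ...  # hypothetical constructor that downloads weights when pretrained=True
```

The efficientdet special case in the hunk above is needed, presumably, because that head fetches backbone weights through a separate ``pretrained_backbone`` flag, so retrying with ``pretrained=False`` alone would still attempt a download.
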
-from typing import Any, Callable, Dict, Optional +import functools +from typing import Any, Callable, Dict, List, Optional from torch.utils.data import DataLoader, Sampler from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_source import DefaultDataKeys +from flash.core.integrations.icevision.transforms import to_icevision_record from flash.core.model import Task from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.core.utilities.url_error import catch_url_error if _ICEVISION_AVAILABLE: - from icevision.core import BaseRecord - from icevision.data import Dataset from icevision.metrics import COCOMetric from icevision.metrics import Metric as IceVisionMetric else: @@ -77,6 +78,12 @@ def from_task( icevision_adapter = icevision_adapter(model=model, metrics=metrics) return cls(model_type, model, icevision_adapter, backbone) + @staticmethod + def _collate_fn(collate_fn, samples, metadata: List[Dict[str, Any]]): + return collate_fn( + [to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] + ) + def process_train_dataset( self, dataset: BaseAutoDataset, @@ -88,7 +95,7 @@ def process_train_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None, ) -> DataLoader: - return self.model_type.train_dl( + result = self.model_type.train_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -97,6 +104,8 @@ def process_train_dataset( drop_last=drop_last, sampler=sampler, ) + result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) + return result def process_val_dataset( self, @@ -109,7 +118,7 @@ def process_val_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None, ) -> DataLoader: - return self.model_type.valid_dl( + result = self.model_type.valid_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -118,6 +127,8 @@ def process_val_dataset( drop_last=drop_last, sampler=sampler, ) + result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) + return result def process_test_dataset( self, @@ -130,7 +141,7 @@ def process_test_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None, ) -> DataLoader: - return self.model_type.valid_dl( + result = self.model_type.valid_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -139,6 +150,8 @@ def process_test_dataset( drop_last=drop_last, sampler=sampler, ) + result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) + return result def process_predict_dataset( self, @@ -151,7 +164,7 @@ def process_predict_dataset( drop_last: bool = True, sampler: Optional[Sampler] = None, ) -> DataLoader: - return self.model_type.infer_dl( + result = self.model_type.infer_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -160,6 +173,8 @@ def process_predict_dataset( drop_last=drop_last, sampler=sampler, ) + result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) + return result def training_step(self, batch, batch_idx) -> Any: return self.icevision_adapter.training_step(batch, batch_idx) @@ -174,9 +189,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return self(batch) def forward(self, batch: Any) -> Any: - if isinstance(batch, list) and isinstance(batch[0], BaseRecord): - data = Dataset(batch) - return self.model_type.predict(self.model, data) return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) def training_epoch_end(self, outputs) 
-> None: diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 99022ecc19..80ce622616 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -16,6 +16,7 @@ import numpy as np from flash.core.data.data_source import DefaultDataKeys +from flash.core.integrations.icevision.transforms import from_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.data import ImagePathsDataSource @@ -31,7 +32,8 @@ def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None return super().predict_load_data(data, dataset) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - return sample[DefaultDataKeys.INPUT].load() + record = sample[DefaultDataKeys.INPUT].load() + return from_icevision_record(record) def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: sample = super().load_sample(sample) @@ -40,7 +42,7 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: record.set_img(image) record.add_component(ClassMapRecordComponent(task=tasks.detection)) - return record + return from_icevision_record(record) class IceVisionParserDataSource(IceVisionPathsDataSource): diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 22692be89a..cde7483c3f 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -11,42 +11,185 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import Any, Callable, Dict, Tuple -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE, requires_extras +from torch import nn -if _TORCHVISION_AVAILABLE: - pass +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires_extras if _ICEVISION_AVAILABLE: + from icevision.core import tasks + from icevision.core.bbox import BBox + from icevision.core.keypoints import KeyPoints + from icevision.core.mask import EncodedRLEs, MaskArray + from icevision.core.record import BaseRecord + from icevision.core.record_components import ( + BBoxesRecordComponent, + ClassMapRecordComponent, + FilepathRecordComponent, + ImageRecordComponent, + InstancesLabelsRecordComponent, + KeyPointsRecordComponent, + MasksRecordComponent, + RecordIDRecordComponent, + ) from icevision.tfms import A -def collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]: - return {key: [sample[key] for sample in samples] for key in samples[0]} +def to_icevision_record(sample: Dict[str, Any]): + record = BaseRecord([]) + + record_id_component = RecordIDRecordComponent() + record_id_component.set_record_id(sample[DefaultDataKeys.METADATA]["image_id"]) + + component = ClassMapRecordComponent(tasks.detection) + component.set_class_map(sample[DefaultDataKeys.METADATA].get("class_map", None)) + record.add_component(component) + + if "labels" in sample[DefaultDataKeys.TARGET]: + labels_component = InstancesLabelsRecordComponent() + labels_component.add_labels_by_id(sample[DefaultDataKeys.TARGET]["labels"]) + record.add_component(labels_component) + + if "bboxes" in sample[DefaultDataKeys.TARGET]: + bboxes = [ + BBox.from_xywh(bbox["xmin"], bbox["ymin"], 
bbox["width"], bbox["height"]) + for bbox in sample[DefaultDataKeys.TARGET]["bboxes"] + ] + component = BBoxesRecordComponent() + component.set_bboxes(bboxes) + record.add_component(component) + + if "masks" in sample[DefaultDataKeys.TARGET]: + mask_array = MaskArray(sample[DefaultDataKeys.TARGET]["masks"]) + component = MasksRecordComponent() + component.set_masks(mask_array) + record.add_component(component) + + if "keypoints" in sample[DefaultDataKeys.TARGET]: + keypoints = [] + + for keypoints_list, keypoints_metadata in zip( + sample[DefaultDataKeys.TARGET]["keypoints"], sample[DefaultDataKeys.TARGET]["keypoints_metadata"] + ): + xyv = [] + for keypoint in keypoints_list: + xyv.extend((keypoint["x"], keypoint["y"], keypoint["visible"])) + + keypoints.append(KeyPoints.from_xyv(xyv, keypoints_metadata)) + component = KeyPointsRecordComponent() + component.set_keypoints(keypoints) + record.add_component(component) + + if isinstance(sample[DefaultDataKeys.INPUT], str): + input_component = FilepathRecordComponent() + input_component.set_filepath(sample[DefaultDataKeys.INPUT]) + else: + if "filepath" in sample[DefaultDataKeys.METADATA]: + input_component = FilepathRecordComponent() + input_component.filepath = sample[DefaultDataKeys.METADATA]["filepath"] + else: + input_component = ImageRecordComponent() + input_component.composite = record + input_component.set_img(sample[DefaultDataKeys.INPUT]) + record.add_component(input_component) + + return record + + +def from_icevision_record(record: "BaseRecord"): + sample = { + DefaultDataKeys.METADATA: { + "image_id": record.record_id, + } + } + + if record.img is not None: + sample[DefaultDataKeys.INPUT] = record.img + filepath = getattr(record, "filepath", None) + if filepath is not None: + sample[DefaultDataKeys.METADATA]["filepath"] = filepath + elif record.filepath is not None: + sample[DefaultDataKeys.INPUT] = record.filepath + + sample[DefaultDataKeys.TARGET] = {} + + if hasattr(record.detection, "bboxes"): + sample[DefaultDataKeys.TARGET]["bboxes"] = [] + for bbox in record.detection.bboxes: + bbox_list = list(bbox.xywh) + bbox_dict = { + "xmin": bbox_list[0], + "ymin": bbox_list[1], + "width": bbox_list[2], + "height": bbox_list[3], + } + sample[DefaultDataKeys.TARGET]["bboxes"].append(bbox_dict) + + if hasattr(record.detection, "masks"): + masks = record.detection.masks + + if isinstance(masks, EncodedRLEs): + masks = masks.to_mask(record.height, record.width) + + if isinstance(masks, MaskArray): + sample[DefaultDataKeys.TARGET]["masks"] = masks.data + else: + raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.") + + if hasattr(record.detection, "keypoints"): + keypoints = record.detection.keypoints + + sample[DefaultDataKeys.TARGET]["keypoints"] = [] + sample[DefaultDataKeys.TARGET]["keypoints_metadata"] = [] + + for keypoint in keypoints: + keypoints_list = [] + for x, y, v in keypoint.xyv: + keypoints_list.append( + { + "x": x, + "y": y, + "visible": v, + } + ) + sample[DefaultDataKeys.TARGET]["keypoints"].append(keypoints_list) + + # TODO: Unpack keypoints_metadata + sample[DefaultDataKeys.TARGET]["keypoints_metadata"].append(keypoint.metadata) + + if getattr(record.detection, "label_ids", None) is not None: + sample[DefaultDataKeys.TARGET]["labels"] = list(record.detection.label_ids) + + if getattr(record.detection, "class_map", None) is not None: + sample[DefaultDataKeys.METADATA]["class_map"] = record.detection.class_map + + return sample + + +class IceVisionTransformAdapter(nn.Module): + def __init__(self, 
transform): + super().__init__() + self.transform = transform + + def forward(self, x): + record = to_icevision_record(x) + record = self.transform(record) + return from_icevision_record(record) @requires_extras("image") def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: - """The default transforms for object detection: convert the image and targets to a tensor, collate the - batch.""" + """The default transforms from IceVision.""" return { - # "pre_tensor_transform": ApplyToKeys( - # DefaultDataKeys.INPUT, - # tfms.A.Adapter([*tfms.A.resize_and_pad(image_size), tfms.A.Normalize()]), - # ) - "pre_tensor_transform": A.Adapter([*A.resize_and_pad(image_size), A.Normalize()]), + "pre_tensor_transform": IceVisionTransformAdapter(A.Adapter([*A.resize_and_pad(image_size), A.Normalize()])), } @requires_extras("image") def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: - """The default transforms for object detection: convert the image and targets to a tensor, collate the - batch.""" + """The default augmentations from IceVision.""" return { - # "pre_tensor_transform": ApplyToKeys( - # DefaultDataKeys.INPUT, - # tfms.A.Adapter([*tfms.A.resize_and_pad(image_size), tfms.A.Normalize()]), - # ) - "pre_tensor_transform": A.Adapter([*A.aug_tfms(size=image_size), A.Normalize()]), + "pre_tensor_transform": IceVisionTransformAdapter(A.Adapter([*A.aug_tfms(size=image_size), A.Normalize()])), } diff --git a/flash/core/model.py b/flash/core/model.py index 69702051b4..686a104058 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -58,16 +58,42 @@ def __init__(self): self._children = [] + # TODO: create enum values to define what are the exact states + self._data_pipeline_state: Optional[DataPipelineState] = None + + # model own internal state shared with the data pipeline. + self._state: Dict[Type[ProcessState], ProcessState] = {} + def __setattr__(self, key, value): if isinstance(value, (LightningModule, Wrapper)): self._children.append(key) - patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"] + patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results", "_data_pipeline_state"] if isinstance(value, Trainer) or key in patched_attributes: if hasattr(self, "_children"): for child in self._children: setattr(getattr(self, child), key, value) super().__setattr__(key, value) + def get_state(self, state_type): + if state_type in self._state: + return self._state[state_type] + if self._data_pipeline_state is not None: + return self._data_pipeline_state.get_state(state_type) + return None + + def set_state(self, state: ProcessState): + self._state[type(state)] = state + if self._data_pipeline_state is not None: + self._data_pipeline_state.set_state(state) + + def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): + for state in self._state.values(): + data_pipeline_state.set_state(state) + for child in self._children: + child = getattr(self, child) + if hasattr(child, "attach_data_pipeline_state"): + child.attach_data_pipeline_state(data_pipeline_state) + class DatasetProcessor: def _process_dataset( @@ -280,12 +306,6 @@ def __init__( self._postprocess: Optional[Postprocess] = postprocess self._serializer: Optional[Serializer] = None - # TODO: create enum values to define what are the exact states - self._data_pipeline_state: Optional[DataPipelineState] = None - - # model own internal state shared with the data pipeline. 
- self._state: Dict[Type[ProcessState], ProcessState] = {} - # Explicitly set the serializer to call the setter self.deserializer = deserializer self.serializer = serializer @@ -836,19 +856,3 @@ def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = composition = Composition(predict=comp, TESTING=flash._IS_TESTING) composition.serve(host=host, port=port) return composition - - def get_state(self, state_type): - if state_type in self._state: - return self._state[state_type] - if self._data_pipeline_state is not None: - return self._data_pipeline_state.get_state(state_type) - return None - - def set_state(self, state: ProcessState): - self._state[type(state)] = state - if self._data_pipeline_state is not None: - self._data_pipeline_state.set_state(state) - - def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): - for state in self._state.values(): - data_pipeline_state.set_state(state) From cee3edf73557f1993f47b7fa796d2f504120a928 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 11 Aug 2021 15:50:38 +0100 Subject: [PATCH 35/46] Fix tests --- flash/core/integrations/icevision/adapter.py | 3 +- .../core/integrations/icevision/transforms.py | 13 ++-- flash/image/detection/data.py | 1 + tests/image/detection/test_data.py | 60 ++++--------------- tests/image/detection/test_model.py | 24 ++++---- 5 files changed, 37 insertions(+), 64 deletions(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index a550d43619..f6bfaa3663 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -79,7 +79,8 @@ def from_task( return cls(model_type, model, icevision_adapter, backbone) @staticmethod - def _collate_fn(collate_fn, samples, metadata: List[Dict[str, Any]]): + def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): + metadata = metadata or [None] * len(samples) return collate_fn( [to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] ) diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index cde7483c3f..c5a5968160 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -40,11 +40,14 @@ def to_icevision_record(sample: Dict[str, Any]): record = BaseRecord([]) - record_id_component = RecordIDRecordComponent() - record_id_component.set_record_id(sample[DefaultDataKeys.METADATA]["image_id"]) + metadata = sample.get(DefaultDataKeys.METADATA, None) or {} + + if "image_id" in metadata: + record_id_component = RecordIDRecordComponent() + record_id_component.set_record_id(metadata["image_id"]) component = ClassMapRecordComponent(tasks.detection) - component.set_class_map(sample[DefaultDataKeys.METADATA].get("class_map", None)) + component.set_class_map(metadata.get("class_map", None)) record.add_component(component) if "labels" in sample[DefaultDataKeys.TARGET]: @@ -86,9 +89,9 @@ def to_icevision_record(sample: Dict[str, Any]): input_component = FilepathRecordComponent() input_component.set_filepath(sample[DefaultDataKeys.INPUT]) else: - if "filepath" in sample[DefaultDataKeys.METADATA]: + if "filepath" in metadata: input_component = FilepathRecordComponent() - input_component.filepath = sample[DefaultDataKeys.METADATA]["filepath"] + input_component.filepath = metadata["filepath"] else: input_component = ImageRecordComponent() input_component.composite 
= record diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index a7f788c886..d75ff23430 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -122,6 +122,7 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se data.compute_metadata() classes = self._get_classes(data) class_map = ClassMap(classes) + dataset.num_classes = len(class_map) parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd) records = parser.parse(data_splitter=SingleSplitSplitter()) diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 34a9bbb832..50ce9fb196 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -17,6 +17,7 @@ import pytest +from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image.detection.data import ObjectDetectionData @@ -152,14 +153,8 @@ def test_image_detector_data_from_coco(tmpdir): ) data = next(iter(datamodule.train_dataloader())) - - record = data[0] - - assert record.detection.img.shape == (128, 128, 3) - assert record.detection.iscrowds[0] in (0, 1) - - assert record.img_size.height == 128 - assert record.img_size.width == 128 + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -178,23 +173,12 @@ def test_image_detector_data_from_coco(tmpdir): data = next(iter(datamodule.val_dataloader())) - record = data[0] - - assert record.detection.img.shape == (128, 128, 3) - assert record.detection.iscrowds[0] in (0, 1) - - assert record.img_size.height == 128 - assert record.img_size.width == 128 + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) - - record = data[0] - - assert record.detection.img.shape == (128, 128, 3) - assert record.detection.iscrowds[0] in (0, 1) - - assert record.img_size.height == 128 - assert record.img_size.width == 128 + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -206,14 +190,8 @@ def test_image_detector_data_from_fiftyone(tmpdir): datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1, image_size=128) data = next(iter(datamodule.train_dataloader())) - - record = data[0] - - assert record.detection.img.shape == (128, 128, 3) - assert record.detection.iscrowds[0] in (0, 1) - - assert record.img_size.height == 128 - assert record.img_size.width == 128 + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -228,21 +206,9 @@ def test_image_detector_data_from_fiftyone(tmpdir): ) data = next(iter(datamodule.val_dataloader())) - - record = data[0] - - assert record.detection.img.shape == (128, 128, 3) - assert record.detection.iscrowds[0] in (0, 1) - - assert record.img_size.height == 128 - assert record.img_size.width == 128 + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) - - record = data[0] - - assert record.detection.img.shape == (128, 128, 3) - assert record.detection.iscrowds[0] in (0, 1) - - assert record.img_size.height == 128 - assert 
record.img_size.width == 128 + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 77a97f28ae..8b946fc9df 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -22,14 +22,13 @@ from torch.utils.data import Dataset from flash.__main__ import main +from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ObjectDetector from tests.helpers.utils import _IMAGE_TESTING if _ICEVISION_AVAILABLE: - from icevision.core import BBox, ClassMap, ObjectDetectionRecord from icevision.data import Prediction - from icevision.utils import ImgSize def collate_fn(samples): @@ -51,22 +50,25 @@ def _random_bbox(self): c, h, w = self.img_shape xs = torch.randint(w - 1, (2,)) ys = torch.randint(h - 1, (2,)) - return [min(xs), min(ys), max(xs) + 1, max(ys) + 1] + return {"xmin": min(xs), "ymin": min(ys), "width": max(xs) - min(xs) + 1, "height": max(ys) - min(ys) + 1} def __getitem__(self, idx): - record = ObjectDetectionRecord() + sample = {} img = np.random.rand(*self.img_shape).astype(np.float32) - record.set_img(img) - record.set_img_size(ImgSize(width=self.img_shape[0], height=self.img_shape[1])) - record.detection.set_class_map(ClassMap([f"test_{i}" for i in range(self.num_classes)], background=None)) + sample[DefaultDataKeys.INPUT] = img + + sample[DefaultDataKeys.TARGET] = { + "bboxes": [], + "labels": [], + } for i in range(self.num_boxes): - record.detection.add_bboxes([BBox.from_xyxy(*self._random_bbox())]) - record.detection.add_labels([f"test_{random.randint(0, self.num_classes - 1)}"]) + sample[DefaultDataKeys.TARGET]["bboxes"].append(self._random_bbox()) + sample[DefaultDataKeys.TARGET]["labels"].append(random.randint(0, self.num_classes - 1)) - return record + return sample @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @@ -90,7 +92,7 @@ def test_init(): def test_training(tmpdir, head): model = ObjectDetector(num_classes=2, head=head, pretrained=False) ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) - dl = model.process_train_dataset(ds, 2, 0, False) + dl = model.process_train_dataset(ds, 2, 0, False, None) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, dl) From 0b02c558fb8228750b632369db140ffd3c73a010 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 11 Aug 2021 15:56:38 +0100 Subject: [PATCH 36/46] Try fix --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 21ac8fbd45..46a1f79e96 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -137,7 +137,7 @@ jobs: run: | sudo apt-get install libsndfile1 pip install matplotlib - pip install '.[image]' --pre --upgrade + pip install '.[image]' --pre --upgrade torchaudio - name: Cache datasets uses: actions/cache@v2 From 1824e5ec11d973931c12cb1dbc04d5e83e4dd813 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 11 Aug 2021 16:04:59 +0100 Subject: [PATCH 37/46] Update CHANGELOG.md --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4461ceff74..84093d54bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,12 +34,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
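
The rewritten tests above pin down the plain-dict sample format that now flows through the object-detection pipeline in place of raw IceVision records. A small illustrative sample, using the key names and import paths from the diffs above (the concrete values are made up, and exact behaviour depends on the installed icevision version):

```python
# Sketch: a hand-built sample in the dict format that
# to_icevision_record / from_icevision_record translate.
import numpy as np

from flash.core.data.data_source import DefaultDataKeys
from flash.core.integrations.icevision.transforms import from_icevision_record, to_icevision_record

# Hypothetical sample, shaped like those produced by DummyDetectionDataset above.
sample = {
    DefaultDataKeys.INPUT: np.random.rand(128, 128, 3).astype(np.float32),
    DefaultDataKeys.TARGET: {
        "bboxes": [{"xmin": 10, "ymin": 20, "width": 30, "height": 40}],
        "labels": [1],
    },
    DefaultDataKeys.METADATA: {"image_id": 0},
}

record = to_icevision_record(sample)  # plain dict -> IceVision BaseRecord
roundtrip = from_icevision_record(record)  # and back to the dict format
```
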
- Added Flash Zero, a zero code command line ML platform built with flash ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611)) +- Added integration with IceVision for the `ObjectDetector` ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + +- Added keypoint detection task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + +- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) - Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) +- Changed arguments to `ObjectDetector`, use `head` instead of `model` and append `_fpn` to the backbone name instead of the `fpn` argument ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + ### Fixed - Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493)) From 6fb7ee39e383581ded2e31f624b7b8ce1575a922 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 11 Aug 2021 16:06:35 +0100 Subject: [PATCH 38/46] Fix tests --- tests/image/test_backbones.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index cc9f80c629..706a2dc68e 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -14,21 +14,18 @@ import urllib.error import pytest -from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE -from flash.core.utilities.imports import _TIMM_AVAILABLE from flash.core.utilities.url_error import catch_url_error from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES +from tests.helpers.utils import _IMAGE_TESTING @pytest.mark.parametrize( ["backbone", "expected_num_features"], [ - pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")), - pytest.param( - "mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") - ), + pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), + pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No timm")), + pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), ], ) def test_image_classifier_backbones_registry(backbone, expected_num_features): @@ -45,11 +42,9 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): "resnet50", "supervised", 2048, - marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision"), - ), - pytest.param( - "resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision"), ), + pytest.param("resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), ], ) def test_pretrained_weights_registry(backbone, pretrained, expected_num_features): From 221b01c15cf8dc67b0728b251a9d8e8d299e025d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 11 Aug 2021 
16:07:37 +0100 Subject: [PATCH 39/46] Fix a test --- tests/core/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 124534acab..2e85bf3da4 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -284,7 +284,7 @@ def test_available_backbones(): class Foo(ImageClassifier): backbones = None - assert Foo.available_backbones() == [] + assert Foo.available_backbones() == {} def test_optimization(tmpdir): From 1ca9b6b280da971107a22d8fcc5381d9ae01a03c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 12 Aug 2021 11:02:26 +0100 Subject: [PATCH 40/46] Try fix --- .github/workflows/ci-testing.yml | 2 +- flash/core/utilities/imports.py | 3 +-- requirements/datatype_audio.txt | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 46a1f79e96..21ac8fbd45 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -137,7 +137,7 @@ jobs: run: | sudo apt-get install libsndfile1 pip install matplotlib - pip install '.[image]' --pre --upgrade torchaudio + pip install '.[image]' --pre --upgrade - name: Cache datasets uses: actions/cache@v2 diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index e15d266c3b..1a4837c68b 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -86,7 +86,6 @@ def _compare_version(package: str, op, version) -> bool: _UVICORN_AVAILABLE = _module_available("uvicorn") _PIL_AVAILABLE = _module_available("PIL") _OPEN3D_AVAILABLE = _module_available("open3d") -_ASTEROID_AVAILABLE = _module_available("asteroid") _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch") _SOUNDFILE_AVAILABLE = _module_available("soundfile") _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") @@ -124,7 +123,7 @@ def _compare_version(package: str, op, version) -> bool: ) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE -_AUDIO_AVAILABLE = all([_ASTEROID_AVAILABLE, _TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE]) +_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE]) _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE _EXTRAS_AVAILABLE = { diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt index 570e7c89b8..4c198da250 100644 --- a/requirements/datatype_audio.txt +++ b/requirements/datatype_audio.txt @@ -1,4 +1,3 @@ -asteroid>=0.5.1 torchaudio soundfile>=0.10.2 transformers>=4.5 From d97dbdface2a05f8d271f8877f8624538583260d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 12 Aug 2021 11:07:38 +0100 Subject: [PATCH 41/46] Try fix --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 21ac8fbd45..254234c8fd 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -137,7 +137,7 @@ jobs: run: | sudo apt-get install libsndfile1 pip install matplotlib - pip install '.[image]' --pre --upgrade + pip install '.[audio,image]' --pre --upgrade - name: Cache datasets uses: actions/cache@v2 From 3b387f7996e7612985ecb679285d24621e28ec10 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 12 Aug 
2021 13:19:37 +0100 Subject: [PATCH 42/46] Add some docs --- flash/core/adapter.py | 20 +++++++++++++++++--- flash/core/model.py | 18 +++++++++++++++--- flash/core/serve/core.py | 2 +- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/flash/core/adapter.py b/flash/core/adapter.py index d55051a148..c7557b1977 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -19,14 +19,20 @@ import flash from flash.core.data.auto_dataset import BaseAutoDataset -from flash.core.model import DatasetProcessor, Task, Wrapper +from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task -class Adapter(DatasetProcessor, Wrapper, nn.Module): +class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module): + """The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular + provider within a :class:`~flash.core.model.Task`.""" + @classmethod @abstractmethod def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter": - pass + """Instantiate the adapter from the given :class:`~flash.core.model.Task`. + + This includes resolution / creation of backbones / heads and any other provider-specific options. + """ def forward(self, x: Any) -> Any: pass @@ -54,6 +60,14 @@ def test_epoch_end(self, outputs) -> None: class AdapterTask(Task): + """The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter` + and forwards all of the hooks. + + Args: + adapter: The :class:`~flash.core.adapter.Adapter` to wrap. + kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`. + """ + def __init__(self, adapter: Adapter, **kwargs): super().__init__(**kwargs) diff --git a/flash/core/model.py b/flash/core/model.py index 686a104058..3edeef238d 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -52,7 +52,16 @@ from flash.core.utilities.imports import requires_extras -class Wrapper: +class ModuleWrapperBase: + """The ``ModuleWrapperBase`` is a base for classes which wrap a ``LightningModule`` or an instance of + ``ModuleWrapperBase``. + + This class ensures that trainer attributes are forwarded to any wrapped or nested + ``LightningModule`` instances so that nested calls to ``.log`` are handled correctly. The ``ModuleWrapperBase`` is + also stateful, meaning that a :class:`~flash.core.data.data_pipeline.DataPipelineState` can be attached. Attached + state will be forwarded to any nested ``ModuleWrapperBase`` instances. 
+ """ + def __init__(self): super().__init__() @@ -65,7 +74,7 @@ def __init__(self): self._state: Dict[Type[ProcessState], ProcessState] = {} def __setattr__(self, key, value): - if isinstance(value, (LightningModule, Wrapper)): + if isinstance(value, (LightningModule, ModuleWrapperBase)): self._children.append(key) patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results", "_data_pipeline_state"] if isinstance(value, Trainer) or key in patched_attributes: @@ -96,6 +105,9 @@ def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): class DatasetProcessor: + """The ``DatasetProcessor`` mixin provides hooks for classes which need custom logic for producing the data + loaders for each running stage given the corresponding dataset.""" + def _process_dataset( self, dataset: BaseAutoDataset, @@ -254,7 +266,7 @@ def __new__(mcs, *args, **kwargs): return result -class Task(DatasetProcessor, Wrapper, LightningModule, metaclass=CheckDependenciesMeta): +class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=CheckDependenciesMeta): """A general Task. Args: diff --git a/flash/core/serve/core.py b/flash/core/serve/core.py index e05717212a..563c0d580e 100644 --- a/flash/core/serve/core.py +++ b/flash/core/serve/core.py @@ -83,7 +83,7 @@ def __call__(self, *args, **kwargs): class Servable: - """Wrapper around a model object to enable serving at scale. + """A wrapper around a model object to enable serving at scale. Create a ``Servable`` from either (LM, LOCATION) or (LOCATION,) From 16ed49c465b05efac1b192c0bbe815388c25daa3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 12 Aug 2021 13:28:23 +0100 Subject: [PATCH 43/46] Add API reference --- docs/source/api/core.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/source/api/core.rst b/docs/source/api/core.rst index 5b8674c37a..1b80d0e2c1 100644 --- a/docs/source/api/core.rst +++ b/docs/source/api/core.rst @@ -7,6 +7,17 @@ flash.core :local: :backlinks: top +flash.core.adapter +__________________ + +.. 
autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.adapter.Adapter + ~flash.core.adapter.AdapterTask + flash.core.classification _________________________ @@ -56,6 +67,8 @@ ________________ ~flash.core.model.BenchmarkConvergenceCI ~flash.core.model.CheckDependenciesMeta + ~flash.core.model.ModuleWrapperBase + ~flash.core.model.DatasetProcessor ~flash.core.model.Task flash.core.registry From 40b7c9bce35c6bda1bc55aa533e075dd72f67c91 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 12 Aug 2021 13:41:17 +0100 Subject: [PATCH 44/46] Small updates --- flash/core/integrations/icevision/adapter.py | 24 ++++++++++---------- tests/image/detection/test_model.py | 1 + 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index f6bfaa3663..af95da9a52 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -96,7 +96,7 @@ def process_train_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None, ) -> DataLoader: - result = self.model_type.train_dl( + data_loader = self.model_type.train_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -105,8 +105,8 @@ def process_train_dataset( drop_last=drop_last, sampler=sampler, ) - result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) - return result + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader def process_val_dataset( self, @@ -119,7 +119,7 @@ def process_val_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None, ) -> DataLoader: - result = self.model_type.valid_dl( + data_loader = self.model_type.valid_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -128,8 +128,8 @@ def process_val_dataset( drop_last=drop_last, sampler=sampler, ) - result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) - return result + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader def process_test_dataset( self, @@ -142,7 +142,7 @@ def process_test_dataset( drop_last: bool = False, sampler: Optional[Sampler] = None, ) -> DataLoader: - result = self.model_type.valid_dl( + data_loader = self.model_type.valid_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -151,8 +151,8 @@ def process_test_dataset( drop_last=drop_last, sampler=sampler, ) - result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) - return result + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader def process_predict_dataset( self, @@ -165,7 +165,7 @@ def process_predict_dataset( drop_last: bool = True, sampler: Optional[Sampler] = None, ) -> DataLoader: - result = self.model_type.infer_dl( + data_loader = self.model_type.infer_dl( dataset, batch_size=batch_size, num_workers=num_workers, @@ -174,8 +174,8 @@ def process_predict_dataset( drop_last=drop_last, sampler=sampler, ) - result.collate_fn = functools.partial(self._collate_fn, result.collate_fn) - return result + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader def training_step(self, batch, batch_idx) -> Any: return self.icevision_adapter.training_step(batch, batch_idx) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 8b946fc9df..c0220526d6 100644 --- 
a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -97,6 +97,7 @@ def test_training(tmpdir, head): trainer.fit(model, dl) +# TODO: resolve JIT issues # @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") # def test_jit(tmpdir): # path = os.path.join(tmpdir, "test.pt") From ac7743b4b15593b543a00b403030fa9c4cb24f25 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 16 Aug 2021 12:56:12 -0400 Subject: [PATCH 45/46] pep fix --- flash_examples/object_detection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 14b10806e5..1a5dddbce9 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch - import flash from flash.core.data.utils import download_data from flash.image import ObjectDetectionData, ObjectDetector From 6c74155d21aedb8a9f550c19426d7dd207e34e0c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 16 Aug 2021 17:58:39 +0100 Subject: [PATCH 46/46] Fixes --- flash_examples/object_detection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 14b10806e5..1a5dddbce9 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch - import flash from flash.core.data.utils import download_data from flash.image import ObjectDetectionData, ObjectDetector
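
The ``ModuleWrapperBase`` docstring added in PATCH 42/46 describes behaviour that the class implements by intercepting ``__setattr__``: assigning a ``LightningModule`` (or another wrapper) records it as a child, and assigning a trainer or other patched attribute is mirrored onto every recorded child, so nested ``.log`` calls resolve against the same trainer. A minimal, self-contained sketch of that interception pattern follows; ``WrapperBase`` and the plain-string trainer are hypothetical stand-ins, not the Flash implementation.

class WrapperBase:
    def __init__(self):
        super().__init__()
        self._children = []

    def __setattr__(self, key, value):
        # Record which attributes hold wrapped children.
        if isinstance(value, WrapperBase):
            self._children.append(key)
        # Mirror selected attributes onto every recorded child, so nested
        # wrappers see the same trainer state as the outermost one.
        if key == "trainer" and hasattr(self, "_children"):
            for child in self._children:
                setattr(getattr(self, child), key, value)
        super().__setattr__(key, value)


outer, inner = WrapperBase(), WrapperBase()
outer.child = inner        # recorded as a wrapped child
outer.trainer = "trainer"  # forwarded, so inner.trainer is set too
assert inner.trainer == "trainer"

Because children are recorded by attribute name and resolved via ``getattr`` at forwarding time, reassigning ``outer.child`` later still leaves the forwarding chain pointing at the current child.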
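PATCH 42/46 also documents the ``Adapter`` / ``AdapterTask`` split: the task stays provider-agnostic, while ``Adapter.from_task`` is the single point where provider-specific backbones and heads get resolved. Below is a condensed sketch of a concrete adapter under that contract; ``LinearAdapter`` and its hard-coded backbone are hypothetical illustrations, not a real Flash provider.

from abc import abstractmethod

import torch
from torch import nn


class Adapter(nn.Module):
    @classmethod
    @abstractmethod
    def from_task(cls, task, **kwargs) -> "Adapter":
        """Resolve provider-specific pieces (backbone, head, ...) for the task."""


class LinearAdapter(Adapter):
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    @classmethod
    def from_task(cls, task, num_features: int = 4, num_classes: int = 2, **kwargs):
        # A real adapter would look the backbone / head up on the task here.
        return cls(nn.Linear(num_features, num_classes))

    def forward(self, x):
        return self.backbone(x)


adapter = LinearAdapter.from_task(task=None)
print(adapter(torch.rand(1, 4)).shape)  # torch.Size([1, 2])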
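The renames in PATCH 44/46 touch the four ``process_*_dataset`` hooks of the IceVision adapter, all of which share one trick: the provider-built loader keeps its own ``collate_fn``, and the adapter composes with it through ``functools.partial`` instead of overwriting it. A standalone sketch of that composition, with hypothetical placeholder collate functions:

import functools

from torch.utils.data import DataLoader


def provider_collate(samples):
    # Stand-in for the collate_fn the provider's dataloader ships with.
    return list(samples)


def _collate_fn(base_collate, samples):
    # Run the provider's collate first, then add task-specific handling.
    batch = base_collate(samples)
    return {"batch": batch, "size": len(batch)}


data_loader = DataLoader(range(8), batch_size=4, collate_fn=provider_collate)
# Capture the existing collate_fn as the first argument of the partial, so
# the new collate_fn wraps it rather than discarding it.
data_loader.collate_fn = functools.partial(_collate_fn, data_loader.collate_fn)
print(next(iter(data_loader)))  # {'batch': [0, 1, 2, 3], 'size': 4}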