diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 8e402dff56..774ef162c6 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -47,7 +47,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo Args: attr_names: Name(s) of the module attributes of the model to be frozen. - train_bn: Wether to train Batch Norm layer + train_bn: Whether to train Batch Norm layer """ diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py index 6d7c326281..d4d1430989 100644 --- a/flash/vision/__init__.py +++ b/flash/vision/__init__.py @@ -1,3 +1,3 @@ from flash.vision.classification import ImageClassificationData, ImageClassifier -from flash.vision.detection import ImageDetector +from flash.vision.detection import ImageDetectionData, ImageDetector from flash.vision.embedding import ImageEmbedder diff --git a/flash/vision/detection/__init__.py b/flash/vision/detection/__init__.py index e605d1e3db..4cfbc195d5 100644 --- a/flash/vision/detection/__init__.py +++ b/flash/vision/detection/__init__.py @@ -1 +1,2 @@ +from flash.vision.detection.data import ImageDetectionData from flash.vision.detection.model import ImageDetector diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py new file mode 100644 index 0000000000..d512d5a37f --- /dev/null +++ b/flash/vision/detection/data.py @@ -0,0 +1,201 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Callable, List, Optional, Tuple + +import torch +from PIL import Image +from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import Tensor +from torch._six import container_abcs +from torch.utils.data._utils.collate import default_collate +from torchvision import transforms as T + +from flash.core.data import TaskDataPipeline +from flash.core.data.datamodule import DataModule +from flash.core.data.utils import _contains_any_tensor +from flash.vision.classification.data import _pil_loader + +_COCO_AVAILABLE = _module_available("pycocotools") +if _COCO_AVAILABLE: + from pycocotools.coco import COCO + + +class CustomCOCODataset(torch.utils.data.Dataset): + + def __init__( + self, + root: str, + ann_file: str, + transforms: Optional[Callable] = None, + ): + if not _COCO_AVAILABLE: + raise ImportError("Kindly install the COCO API `pycocotools` to use the Dataset") + + self.root = root + self.transforms = transforms + self.coco = COCO(ann_file) + self.ids = list(sorted(self.coco.imgs.keys())) + + @property + def num_classes(self): + categories = self.coco.loadCats(self.coco.getCatIds()) + if not categories: + raise ValueError("No Categories found") + return categories[-1]["id"] + 1 + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + coco = self.coco + img_idx = self.ids[index] + + ann_ids = coco.getAnnIds(imgIds=img_idx) + annotations = coco.loadAnns(ann_ids) + + image_path = coco.loadImgs(img_idx)[0]["file_name"] + img = Image.open(os.path.join(self.root, image_path)) + + boxes = [] + labels = [] + areas = [] + iscrowd = [] + + for obj in annotations: + xmin = obj["bbox"][0] + ymin = obj["bbox"][1] + xmax = xmin + obj["bbox"][2] + ymax = ymin + obj["bbox"][3] + + bbox = [xmin, ymin, xmax, ymax] + keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) + if keep: + boxes.append(bbox) + labels.append(obj["category_id"]) + areas.append(obj["area"]) + iscrowd.append(obj["iscrowd"]) + + target = {} + target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32) + target["labels"] = torch.as_tensor(labels, dtype=torch.int64) + target["image_id"] = torch.tensor([img_idx]) + target["area"] = torch.as_tensor(areas, dtype=torch.float32) + target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.int64) + + if self.transforms is not None: + img = self.transforms(img) + + return img, target + + def __len__(self): + return len(self.ids) + + +def _coco_remove_images_without_annotations(dataset): + # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py + + def _has_only_empty_bbox(anno: List): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + def _has_valid_annotation(anno: List): + # if it's empty, there is no annotation + if not anno: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + return True + + ids = [] + for ds_idx, img_id in enumerate(dataset.ids): + ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = dataset.coco.loadAnns(ann_ids) + if _has_valid_annotation(anno): + ids.append(ds_idx) + + dataset = torch.utils.data.Subset(dataset, ids) + return dataset + + +_default_transform = T.ToTensor() + + +class ImageDetectorDataPipeline(TaskDataPipeline): + + def __init__(self, valid_transform: Optional[Callable] = _default_transform, loader: Callable = _pil_loader): + self._valid_transform = valid_transform + self._loader = loader + + def before_collate(self, samples: Any) -> Any: + if _contains_any_tensor(samples): + return samples + + if isinstance(samples, str): + samples = [samples] + + if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): + outputs = [] + for sample in samples: + output = self._loader(sample) + outputs.append(self._valid_transform(output)) + return outputs + raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.") + + def collate(self, samples: Any) -> Any: + if not isinstance(samples, Tensor): + elem = samples[0] + if isinstance(elem, container_abcs.Sequence): + return tuple(zip(*samples)) + return default_collate(samples) + return samples.unsqueeze(dim=0) + + +class ImageDetectionData(DataModule): + + @classmethod + def from_coco( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + train_transform: Optional[Callable] = _default_transform, + valid_folder: Optional[str] = None, + valid_ann_file: Optional[str] = None, + valid_transform: Optional[Callable] = _default_transform, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + test_transform: Optional[Callable] = _default_transform, + batch_size: int = 4, + num_workers: Optional[int] = None, + **kwargs + ): + train_ds = CustomCOCODataset(train_folder, train_ann_file, train_transform) + num_classes = train_ds.num_classes + train_ds = _coco_remove_images_without_annotations(train_ds) + + valid_ds = ( + CustomCOCODataset(valid_folder, valid_ann_file, valid_transform) if valid_folder is not None else None + ) + + test_ds = (CustomCOCODataset(test_folder, test_ann_file, test_transform) if test_folder is not None else None) + + datamodule = cls( + train_ds=train_ds, + valid_ds=valid_ds, + test_ds=test_ds, + batch_size=batch_size, + num_workers=num_workers, + ) + + datamodule.num_classes = num_classes + datamodule.data_pipeline = ImageDetectorDataPipeline() + return datamodule diff --git a/flash/vision/detection/finetuning.py b/flash/vision/detection/finetuning.py new file mode 100644 index 0000000000..2ac49ce821 --- /dev/null +++ b/flash/vision/detection/finetuning.py @@ -0,0 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytorch_lightning as pl + +from flash.core.finetuning import FlashBaseFinetuning + + +class ImageDetectorFineTuning(FlashBaseFinetuning): + """ + Freezes the backbone during Detector training. + """ + + def __init__(self, train_bn: bool = True): + self.train_bn = train_bn + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + model = pl_module.model + self.freeze(module=model.backbone, train_bn=self.train_bn) diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index 9cd7fac5ce..318ac8016e 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -11,21 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Mapping, Sequence, Type, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union import torch import torchvision from torch import nn from torch.optim import Optimizer +from torchvision.ops import box_iou -from flash.core.classification import ClassificationTask +from flash.core import Task +from flash.core.data import DataPipeline +from flash.vision.detection.data import ImageDetectorDataPipeline +from flash.vision.detection.finetuning import ImageDetectorFineTuning _models = {"fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn} -class ImageDetector(ClassificationTask): +def _evaluate_iou(target, pred): + """ + Evaluate intersection over union (IOU) for target from dataset and output prediction + from model + """ + if pred["boxes"].shape[0] == 0: + # no box detected, 0 IOU + return torch.tensor(0.0, device=pred["boxes"].device) + return box_iou(target["boxes"], pred["boxes"]).diag().mean() + + +class ImageDetector(Task): """Image detection task + Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts Args: num_classes: the number of classes for detection, including background model: either a string of :attr`_models` or a custom nn.Module. @@ -52,6 +68,9 @@ def __init__( learning_rate=1e-3, **kwargs, ): + + self.save_hyperparameters() + if model in _models: model = _models[model](pretrained=pretrained) if isinstance(model, torchvision.models.detection.FasterRCNN): @@ -59,10 +78,6 @@ def __init__( head = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes) model.roi_heads.box_predictor = head - if loss is None: - # TODO: maybe better way of handling no loss, - loss = {} - super().__init__( model=model, loss_fn=loss, @@ -81,7 +96,36 @@ def training_step(self, batch, batch_idx) -> Any: # fasterrcnn takes both images and targets for training, returns loss_dict loss_dict = self.model(images, targets) loss = sum(loss_dict.values()) - for k, v in loss_dict.items(): - self.log("train_k", v) - + self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, prog_bar=True) return loss + + def validation_step(self, batch, batch_idx): + images, targets = batch + # fasterrcnn takes only images for eval() mode + outs = self.model(images) + iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() + return {"val_iou": iou} + + def validation_epoch_end(self, outs): + avg_iou = torch.stack([o["val_iou"] for o in outs]).mean() + logs = {"val_iou": avg_iou} + return {"avg_val_iou": avg_iou, "log": logs} + + def test_step(self, batch, batch_idx): + images, targets = batch + # fasterrcnn takes only images for eval() mode + outs = self.model(images) + iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() + return {"test_iou": iou} + + def test_epoch_end(self, outs): + avg_iou = torch.stack([o["test_iou"] for o in outs]).mean() + logs = {"test_iou": avg_iou} + return {"avg_test_iou": avg_iou, "log": logs} + + @staticmethod + def default_pipeline() -> ImageDetectorDataPipeline: + return ImageDetectorDataPipeline() + + def configure_finetune_callback(self): + return [ImageDetectorFineTuning(train_bn=True)] diff --git a/flash_examples/finetuning/image_detection.py b/flash_examples/finetuning/image_detection.py new file mode 100644 index 0000000000..dd5153dda6 --- /dev/null +++ b/flash_examples/finetuning/image_detection.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.core.data import download_data +from flash.vision import ImageDetectionData, ImageDetector + +# 1. Download the data +# Dataset Credit: https://www.kaggle.com/ultralytics/coco128 +download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") + +# 2. Load the Data +datamodule = ImageDetectionData.from_coco( + train_folder="data/coco128/images/train2017/", + train_ann_file="data/coco128/annotations/instances_train2017.json", + batch_size=2 +) + +# 3. Build the model +model = ImageDetector(num_classes=datamodule.num_classes) + +# 4. Create the trainer. Run twice on data +trainer = flash.Trainer(max_epochs=2) + +# 5. Finetune the model +trainer.finetune(model, datamodule) + +# 6. Save it! +trainer.save_checkpoint("image_detection_model.pt") diff --git a/requirements.txt b/requirements.txt index 1a2fc3b8e0..58376cba59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,5 @@ tqdm # comes with 3rd-party dependency rouge-score>=0.0.4 sentencepiece>=0.1.95 pytorch-lightning-bolts==0.3.0 +filelock # comes with 3rd-party dependency +pycocotools>=2.0.2 ; python_version >= "3.7" diff --git a/tests/vision/detection/test_data.py b/tests/vision/detection/test_data.py new file mode 100644 index 0000000000..9e7bd44f66 --- /dev/null +++ b/tests/vision/detection/test_data.py @@ -0,0 +1,126 @@ +import json +import os +from pathlib import Path + +import pytest +import torch +from PIL import Image +from pytorch_lightning.utilities import _module_available +from torchvision import transforms as T + +from flash.vision.detection.data import ImageDetectionData + +_COCO_AVAILABLE = _module_available("pycocotools") +if _COCO_AVAILABLE: + from pycocotools.coco import COCO + + +def _create_dummy_coco_json(dummy_json_path): + + dummy_json = { + "images": [{ + "id": 0, + 'width': 1920, + 'height': 1080, + 'file_name': 'sample_one.png', + }, { + "id": 1, + "width": 1920, + "height": 1080, + "file_name": "sample_two.png", + }], + "annotations": [{ + "id": 1, + "image_id": 0, + "category_id": 0, + "area": 150, + "bbox": [30, 40, 20, 20], + "iscrowd": 0, + }, { + "id": 2, + "image_id": 1, + "category_id": 0, + "area": 240, + "bbox": [50, 100, 280, 15], + "iscrowd": 0, + }, { + "id": 3, + "image_id": 1, + "category_id": 0, + "area": 170, + "bbox": [230, 130, 90, 180], + "iscrowd": 0, + }], + "categories": [{ + "id": 0, + "name": "person", + "supercategory": "person", + }] + } + + with open(dummy_json_path, "w") as fp: + json.dump(dummy_json, fp) + + +def _create_synth_coco_dataset(tmpdir): + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "images").mkdir() + Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_one.png") + Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_two.png") + + (train_dir / "annotations").mkdir() + dummy_json = train_dir / "annotations" / "sample.json" + + train_folder = os.fspath(Path(train_dir / "images")) + coco_ann_path = os.fspath(dummy_json) + _create_dummy_coco_json(coco_ann_path) + + return train_folder, coco_ann_path + + +@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +def test_image_detector_data_from_coco(tmpdir): + + train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) + + datamodule = ImageDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) + + data = next(iter(datamodule.train_dataloader())) + imgs, labels = data + + assert len(imgs) == 1 + assert imgs[0].shape == (3, 1080, 1920) + assert len(labels) == 1 + assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + + assert datamodule.val_dataloader() is None + assert datamodule.test_dataloader() is None + + datamodule = ImageDetectionData.from_coco( + train_folder=train_folder, + train_ann_file=coco_ann_path, + valid_folder=train_folder, + valid_ann_file=coco_ann_path, + test_folder=train_folder, + test_ann_file=coco_ann_path, + batch_size=1, + num_workers=0 + ) + + data = next(iter(datamodule.val_dataloader())) + imgs, labels = data + + assert len(imgs) == 1 + assert imgs[0].shape == (3, 1080, 1920) + assert len(labels) == 1 + assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + + data = next(iter(datamodule.test_dataloader())) + imgs, labels = data + + assert len(imgs) == 1 + assert imgs[0].shape == (3, 1080, 1920) + assert len(labels) == 1 + assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] diff --git a/tests/vision/detection/test_data_model_integration.py b/tests/vision/detection/test_data_model_integration.py new file mode 100644 index 0000000000..00bbaa5ed3 --- /dev/null +++ b/tests/vision/detection/test_data_model_integration.py @@ -0,0 +1,50 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +from PIL import Image +from pytorch_lightning.utilities import _module_available + +import flash +from flash.vision import ImageDetector +from flash.vision.detection.data import ImageDetectionData +from tests.vision.detection.test_data import _create_synth_coco_dataset + +_COCO_AVAILABLE = _module_available("pycocotools") +if _COCO_AVAILABLE: + from pycocotools.coco import COCO + + +@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +def test_detection(tmpdir): + + train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) + + data = ImageDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) + model = ImageDetector(num_classes=data.num_classes) + + trainer = flash.Trainer(fast_dev_run=True) + + trainer.finetune(model, data) + + test_image_one = os.fspath(tmpdir / "test_one.png") + test_image_two = os.fspath(tmpdir / "test_two.png") + + Image.new('RGB', (1920, 1080)).save(test_image_one) + Image.new('RGB', (1920, 1080)).save(test_image_two) + + test_images = [test_image_one, test_image_two] + + model.predict(test_images)