Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

feat: Add Detection Task #56

Merged
merged 36 commits into from
Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ce852c8
add CustomCOCODataset for detection
kaushikb11 Feb 2, 2021
a8c6315
add steps for the detection task
kaushikb11 Feb 3, 2021
3db086d
update steps for detector task
kaushikb11 Feb 3, 2021
837417a
add data pipeline
kaushikb11 Feb 3, 2021
f7ccfe2
add import error for coco api
kaushikb11 Feb 3, 2021
135543c
add ref to bolts
kaushikb11 Feb 3, 2021
c518a2b
add base finetuning
kaushikb11 Feb 3, 2021
45543ec
add dataset
kaushikb11 Feb 3, 2021
16db73a
add fine tuning script
kaushikb11 Feb 3, 2021
c3deb78
update fine tuning script
kaushikb11 Feb 3, 2021
8d3954b
add imge detector data module
kaushikb11 Feb 4, 2021
3fa1a49
handle images with no annotations
kaushikb11 Feb 4, 2021
9fd752f
add test step
kaushikb11 Feb 4, 2021
696515b
fix crowd coco assign
kaushikb11 Feb 4, 2021
14939bf
add test for COCO dataloader
kaushikb11 Feb 5, 2021
932cd10
update example format
kaushikb11 Feb 5, 2021
d1254fb
update test for COCO dataloader
kaushikb11 Feb 5, 2021
c1891dd
add pycoco to requirements
kaushikb11 Feb 5, 2021
cfd90e1
add pycoco to requirements with python version
kaushikb11 Feb 5, 2021
4461121
skip test if coco not installed
kaushikb11 Feb 5, 2021
9a8cd35
add test for data model integration
kaushikb11 Feb 5, 2021
ebc2fcf
skip integration test if coco not installed
kaushikb11 Feb 5, 2021
010402e
add model predict in integration tests
kaushikb11 Feb 5, 2021
a80b569
update dummy image shapes
kaushikb11 Feb 5, 2021
9a15863
add save hyperparameters
kaushikb11 Feb 5, 2021
ff52ffc
fix labels assignment
kaushikb11 Feb 7, 2021
fa46955
add singular data pipeline
kaushikb11 Feb 8, 2021
6fb5935
add updates for predict
kaushikb11 Feb 8, 2021
cef3f45
add cython to requirements
kaushikb11 Feb 8, 2021
71ebf17
add cython to requirements
kaushikb11 Feb 8, 2021
720b1cc
fix failing tests
kaushikb11 Feb 8, 2021
84c5dec
use download_data with coco128 url
kaushikb11 Feb 8, 2021
068c3a4
update imports
kaushikb11 Feb 8, 2021
cbcc2b1
handle degenerated boxes
kaushikb11 Feb 9, 2021
9f97160
Apply suggestions from code review
Borda Feb 10, 2021
3330387
Merge branch 'master' into feat/detect
Borda Feb 10, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
Args:
attr_names: Name(s) of the module attributes of the model to be frozen.

train_bn: Wether to train Batch Norm layer
train_bn: Whether to train Batch Norm layer

"""

Expand Down
192 changes: 192 additions & 0 deletions flash/vision/detection/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Optional, Tuple

import torch
from PIL import Image
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torchvision import transforms as T

from flash.core.data import TaskDataPipeline
from flash.core.data.datamodule import DataModule
from flash.core.data.utils import _contains_any_tensor
from flash.vision.classification.data import _pil_loader

try:
from pycocotools.coco import COCO
except ImportError:
COCO = None
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved


class CustomCOCODataset(torch.utils.data.Dataset):
    """Detection dataset backed by a COCO-format annotation file.

    Each item is an ``(image, target)`` pair. ``target`` is a dict with the
    tensors ``boxes`` (shape ``[N, 4]``, ``xmin, ymin, xmax, ymax`` order),
    ``labels``, ``image_id``, ``area`` and ``iscrowd``.

    Args:
        root: Folder that the ``file_name`` entries of the annotation file
            are joined onto.
        ann_file: Path of the COCO annotation file handed to ``COCO``.
        transforms: Optional callable applied to the loaded image only; the
            target dict is not passed through it.

    Raises:
        ImportError: If ``pycocotools`` could not be imported.
    """

    def __init__(
        self,
        root: str,
        ann_file: str,
        transforms: Optional[Callable] = None,
    ):
        if COCO is None:
            raise ImportError("Kindly install the COCO API `pycocotools` to use the Dataset")

        self.root = root
        self.transforms = transforms
        self.coco = COCO(ann_file)
        self.ids = sorted(self.coco.imgs.keys())

    @property
    def num_classes(self) -> int:
        """Return the largest category id plus one (index 0 is left for background).

        Raises:
            ValueError: If the annotation file declares no categories.
        """
        categories = self.coco.loadCats(self.coco.getCatIds())
        if not categories:
            raise ValueError("No Categories found")
        # Do not rely on the categories being listed in ascending id order.
        return max(category["id"] for category in categories) + 1

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image at position ``index`` together with its target dict."""
        coco = self.coco
        img_idx = self.ids[index]

        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_idx))

        image_path = coco.loadImgs(img_idx)[0]["file_name"]
        img = Image.open(os.path.join(self.root, image_path))

        boxes = []
        labels = []
        areas = []
        iscrowd = []

        for obj in annotations:
            # COCO stores boxes as [x, y, width, height].
            xmin, ymin, width, height = obj["bbox"][:4]
            if width <= 0 or height <= 0:
                # Degenerate boxes are rejected by torchvision detection models.
                continue

            boxes.append([xmin, ymin, xmin + width, ymin + height])
            # Keep the raw category id: label 0 is reserved for background,
            # so shifting ids down by one would turn category 1 into background.
            labels.append(obj["category_id"])
            areas.append(obj["area"])
            iscrowd.append(obj["iscrowd"])

        target = {
            # reshape keeps the [N, 4] layout even when the image has no boxes
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([img_idx]),
            "area": torch.as_tensor(areas, dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.int64),
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self) -> int:
        """Return the number of images listed in the annotation file."""
        return len(self.ids)


# Default transform used by the detection data classes: converts its input to a
# tensor with torchvision's ``ToTensor`` and applies nothing else.
_default_transform = T.Compose([T.ToTensor()])
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved


def collate_fn(batch):
    """Regroup a sequence of per-sample tuples into one tuple per field.

    For ``(image, target)`` samples this yields ``(images, targets)`` where
    both entries are tuples; nothing is stacked into a single tensor.
    """
    per_field = zip(*batch)
    return tuple(per_field)


def _coco_remove_images_without_annotations(dataset):
# Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py

def _has_only_empty_bbox(anno):
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
Borda marked this conversation as resolved.
Show resolved Hide resolved
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
return True

ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if _has_valid_annotation(anno):
ids.append(ds_idx)

dataset = torch.utils.data.Subset(dataset, ids)
return dataset


class ImageDetectorDataPipeline(TaskDataPipeline):
    """Data pipeline that turns image paths into transformed images before collation.

    Args:
        valid_transform: Callable applied to every loaded image. ``None``
            disables the transform and the loader output is returned as is.
        loader: Callable that loads an image from a path string.
    """

    def __init__(self, valid_transform: Optional[Callable] = _default_transform, loader: Callable = _pil_loader):
        self._valid_transform = valid_transform
        self._loader = loader

    def before_collate(self, samples: Any) -> Any:
        """Load and transform ``samples`` unless they already contain tensors.

        Args:
            samples: A single path, a list/tuple of paths, or anything for
                which ``_contains_any_tensor`` is true (returned unchanged).

        Returns:
            ``samples`` itself when it already holds tensors, otherwise a list
            with one loaded (and, if configured, transformed) image per path.

        Raises:
            MisconfigurationException: If ``samples`` is none of the accepted forms.
        """
        if _contains_any_tensor(samples):
            return samples

        if isinstance(samples, str):
            samples = [samples]

        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
            outputs = []
            for sample in samples:
                output = self._loader(sample)
                # valid_transform is Optional: only call it when one was given
                if self._valid_transform is not None:
                    output = self._valid_transform(output)
                outputs.append(output)
            return outputs
        raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.")


class ImageDetectionData(DataModule):
    """Data module for detection tasks built from COCO-format annotations."""

    @classmethod
    def from_coco(
        cls,
        train_folder: Optional[str] = None,
        train_ann_file: Optional[str] = None,
        train_transform: Optional[Callable] = _default_transform,
        valid_folder: Optional[str] = None,
        valid_ann_file: Optional[str] = None,
        valid_transform: Optional[Callable] = _default_transform,
        test_folder: Optional[str] = None,
        test_ann_file: Optional[str] = None,
        test_transform: Optional[Callable] = _default_transform,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **kwargs
    ):
        """Create the data module from image folders and COCO annotation files.

        The training dataset is always built and then filtered to images that
        have a valid annotation. The validation and test datasets are built
        only when their folder is given and are not filtered.

        Args:
            train_folder: Folder holding the training images.
            train_ann_file: COCO annotation file for the training images.
            train_transform: Transform applied to each training image.
            valid_folder: Folder holding the validation images, if any.
            valid_ann_file: COCO annotation file for the validation images.
            valid_transform: Transform applied to each validation image.
            test_folder: Folder holding the test images, if any.
            test_ann_file: COCO annotation file for the test images.
            test_transform: Transform applied to each test image.
            batch_size: Forwarded to the data module constructor.
            num_workers: Forwarded to the data module constructor.
            **kwargs: Accepted but not used by this method.
                NOTE(review): confirm whether these should be forwarded to ``cls``.

        Returns:
            The data module, with ``num_classes`` set from the training
            annotations and an ``ImageDetectorDataPipeline`` using ``collate_fn``.
        """

        def _dataset_or_none(folder, ann_file, transform):
            # A split is only created when its image folder was provided.
            if folder is None:
                return None
            return CustomCOCODataset(folder, ann_file, transform)

        full_train_ds = CustomCOCODataset(train_folder, train_ann_file, train_transform)
        # Read the class count before filtering wraps the dataset in a Subset.
        num_classes = full_train_ds.num_classes
        train_ds = _coco_remove_images_without_annotations(full_train_ds)

        valid_ds = _dataset_or_none(valid_folder, valid_ann_file, valid_transform)
        test_ds = _dataset_or_none(test_folder, test_ann_file, test_transform)

        datamodule = cls(
            train_ds=train_ds,
            valid_ds=valid_ds,
            test_ds=test_ds,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        datamodule.num_classes = num_classes
        datamodule.data_pipeline = ImageDetectorDataPipeline()
        datamodule.data_pipeline.collate_fn = collate_fn
        return datamodule
20 changes: 20 additions & 0 deletions flash/vision/detection/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.data import download_data


def coco128_data_download(path: str):
    """Fetch the COCO128 sample archive by calling ``download_data`` with ``path``.

    Args:
        path: Target location handed to ``download_data``.
    """
    # Dataset Credit: https://www.kaggle.com/ultralytics/coco128
    download_data(
        "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
        path,
    )
29 changes: 29 additions & 0 deletions flash/vision/detection/finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl

from flash.core.finetuning import FlashBaseFinetuning


class ImageDetectorFineTuning(FlashBaseFinetuning):
    """
    Freezes the backbone during Detector training.

    Args:
        train_bn: Whether to train Batch Norm layer
    """

    def __init__(self, train_bn: bool = True):
        # Initialise the base callback so the state it defines is present too.
        super().__init__(train_bn=train_bn)
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        """Freeze ``pl_module.model.backbone`` using the configured ``train_bn``."""
        model = pl_module.model
        self.freeze(module=model.backbone, train_bn=self.train_bn)
80 changes: 70 additions & 10 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Sequence, Type, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
import torchvision
from torch import nn
from torch.optim import Optimizer
from torchvision.ops import box_iou

from flash.core.classification import ClassificationTask
from flash.core import Task
from flash.core.data import DataPipeline
from flash.core.model import predict_context
from flash.vision.detection.data import _default_transform, ImageDetectorDataPipeline
from flash.vision.detection.finetuning import ImageDetectorFineTuning

# Model names accepted by ``ImageDetector`` mapped to the torchvision
# constructor used to build them (called with ``pretrained=...``).
_models = {"fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn}


class ImageDetector(ClassificationTask):
def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""
if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
return box_iou(target["boxes"], pred["boxes"]).diag().mean()
tchaton marked this conversation as resolved.
Show resolved Hide resolved


class ImageDetector(Task):
"""Image detection task

Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts
Args:
num_classes: the number of classes for detection, including background
model: either a string of :attr:`_models` or a custom nn.Module.
Expand All @@ -52,17 +69,16 @@ def __init__(
learning_rate=1e-3,
**kwargs,
):

self.save_hyperparameters()

if model in _models:
model = _models[model](pretrained=pretrained)
if isinstance(model, torchvision.models.detection.FasterRCNN):
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head

if loss is None:
# TODO: maybe better way of handling no loss,
loss = {}

super().__init__(
model=model,
loss_fn=loss,
Expand All @@ -81,7 +97,51 @@ def training_step(self, batch, batch_idx) -> Any:
# fasterrcnn takes both images and targets for training, returns loss_dict
loss_dict = self.model(images, targets)
loss = sum(loss_dict.values())
for k, v in loss_dict.items():
self.log("train_k", v)

self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
    """Run the model on a validation batch and return its mean IOU.

    Returns:
        ``{"val_iou": iou}`` where ``iou`` is the mean of the per-image
        ``_evaluate_iou`` scores of the batch.
    """
    images, targets = batch
    # fasterrcnn takes only images for eval() mode
    predictions = self.model(images)
    per_image_iou = [_evaluate_iou(target, prediction) for target, prediction in zip(targets, predictions)]
    return {"val_iou": torch.stack(per_image_iou).mean()}

def validation_epoch_end(self, outs):
    """Average the ``"val_iou"`` values returned by the validation steps.

    Returns:
        A dict with the average under ``"avg_val_iou"`` and again under
        ``"log" -> "val_iou"``.
    """
    step_ious = [step_out["val_iou"] for step_out in outs]
    avg_iou = torch.stack(step_ious).mean()
    return {"avg_val_iou": avg_iou, "log": {"val_iou": avg_iou}}

def test_step(self, batch, batch_idx):
    """Run the model on a test batch and return its mean IOU.

    Returns:
        ``{"test_iou": iou}`` where ``iou`` is the mean of the per-image
        ``_evaluate_iou`` scores of the batch.
    """
    images, targets = batch
    # fasterrcnn takes only images for eval() mode
    predictions = self.model(images)
    per_image_iou = [_evaluate_iou(target, prediction) for target, prediction in zip(targets, predictions)]
    return {"test_iou": torch.stack(per_image_iou).mean()}

def test_epoch_end(self, outs):
avg_iou = torch.stack([o["test_iou"] for o in outs]).mean()
logs = {"test_iou": avg_iou}
return {"avg_test_iou": avg_iou, "log": logs}

@predict_context
def predict(
    self,
    x: Any,
    batch_idx: Optional[int] = None,
    skip_collate_fn: bool = False,
    dataloader_idx: Optional[int] = None,
    data_pipeline: Optional[DataPipeline] = None,
) -> Any:
    """Run a forward pass on ``x`` and post-process the result with the pipeline.

    Args:
        x: Input samples, or an already collated batch when
            ``skip_collate_fn`` is true.
        batch_idx: Not used by this method.
        skip_collate_fn: When true, ``x`` is passed to the model as is.
        dataloader_idx: Not used by this method.
        data_pipeline: Pipeline providing ``collate_fn`` and ``uncollate_fn``;
            ``self.default_pipeline()`` is used when none is given.

    Returns:
        Whatever ``data_pipeline.uncollate_fn`` returns for the model output.
    """
    if not data_pipeline:
        data_pipeline = self.default_pipeline()

    if skip_collate_fn:
        batch = x
    else:
        batch = data_pipeline.collate_fn(x)

    return data_pipeline.uncollate_fn(self.forward(batch))

@staticmethod
def default_pipeline() -> ImageDetectorDataPipeline:
    """Build the pipeline used when the caller does not supply one."""
    pipeline = ImageDetectorDataPipeline()
    return pipeline

def configure_finetune_callback(self):
    """Return the fine-tuning callbacks for this task: a single ``ImageDetectorFineTuning``."""
    finetuning_callback = ImageDetectorFineTuning(train_bn=True)
    return [finetuning_callback]
26 changes: 26 additions & 0 deletions flash_examples/finetuning/image_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import flash
from flash.vision.detection.data import ImageDetectionData
from flash.vision.detection.dataset import coco128_data_download
from flash.vision.detection.model import ImageDetector

# Locations used by the example.
DOWNLOAD_DIR = "data/"
TRAIN_IMAGES = "data/coco128/images/train2017/"
TRAIN_ANNOTATIONS = "data/coco128/annotations/instances_train2017.json"
CHECKPOINT_PATH = "image_detection_model.pt"

# 1. Download the data
coco128_data_download(DOWNLOAD_DIR)

# 2. Load the Data
datamodule = ImageDetectionData.from_coco(
    train_folder=TRAIN_IMAGES,
    train_ann_file=TRAIN_ANNOTATIONS,
    batch_size=2,
)

# 3. Build the model
model = ImageDetector(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run twice on data
trainer = flash.Trainer(max_epochs=2)

# 5. Finetune the model
trainer.finetune(model, datamodule)

# 6. Save it!
trainer.save_checkpoint(CHECKPOINT_PATH)
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
Loading