Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support training with vanilla module #87

Merged
merged 4 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest

import torch
from torch.utils import data

from yolort.datasets.coco import CocoDetection
from yolort.datasets.transforms import collate_fn, default_train_transforms
from yolort.utils import prepare_coco128

from typing import Dict


class DataPipelineTester(unittest.TestCase):

def test_prepare_coco128(self):
data_path = Path('data-bin')
coco128_dirname = 'coco128'
prepare_coco128(data_path, dirname=coco128_dirname)
annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
self.assertTrue(annotation_file.is_file())

def test_vanilla_dataloader(self):
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
coco128_path = data_path / coco128_dirname
image_root = coco128_path / 'images' / 'train2017'
annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

if not annotation_file.is_file():
prepare_coco128(data_path, dirname=coco128_dirname)

dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
# Test the datasets
image, target = next(iter(dataset))
self.assertIsInstance(image, torch.Tensor)
self.assertIsInstance(target, Dict)

batch_size = 4
sampler = data.RandomSampler(dataset)
batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True)
loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
# Test the dataloader
images, targets = next(iter(loader))

self.assertEqual(len(images), batch_size)
self.assertEqual(len(targets), batch_size)
56 changes: 49 additions & 7 deletions test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest
import torch
from torch.utils import data
from torchvision.io import read_image

import pytorch_lightning as pl

from yolort.models.yolo import yolov5_darknet_pan_s_r31
from yolort.models.transform import nested_tensor_from_tensor_list
from yolort.models import yolov5s

from yolort.datasets.coco import CocoDetection
from yolort.datasets.transforms import collate_fn, default_train_transforms
from yolort.datasets import DetectionDataModule

from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset

from typing import Dict

from torchvision.io import read_image


def default_loader(img_name, is_half=False):
"""
Expand All @@ -27,7 +34,7 @@ def default_loader(img_name, is_half=False):


class EngineTester(unittest.TestCase):
def test_train(self):
def test_train_with_vanilla_model(self):
# Do forward over image
img_name = "test/assets/zidane.jpg"
img_tensor = default_loader(img_name)
Expand All @@ -49,6 +56,45 @@ def test_train(self):
self.assertIsInstance(out["bbox_regression"], torch.Tensor)
self.assertIsInstance(out["objectness"], torch.Tensor)

def test_train_with_vanilla_module(self):
"""
For issue #86: <https://github.com/zhiqwang/yolov5-rt-stack/issues/86>
"""
# Define the device
device = torch.device('cpu')

# Prepare the datasets for training
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
coco128_path = data_path / coco128_dirname
image_root = coco128_path / 'images' / 'train2017'
annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

if not annotation_file.is_file():
prepare_coco128(data_path, dirname=coco128_dirname)

batch_size = 4

dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
sampler = data.RandomSampler(dataset)
batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True)
loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
# Sample a pair of images/targets
images, targets = next(iter(loader))
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

# Define the model
model = yolov5s(num_classes=80)
model.train()

out = model(images, targets)
self.assertIsInstance(out, Dict)
self.assertIsInstance(out["cls_logits"], torch.Tensor)
self.assertIsInstance(out["bbox_regression"], torch.Tensor)
self.assertIsInstance(out["objectness"], torch.Tensor)

def test_train_one_step(self):
# Load model
model = yolov5s()
Expand Down Expand Up @@ -144,7 +190,3 @@ def test_predict_image_files(self):
self.assertIsInstance(predictions[0]["boxes"], torch.Tensor)
self.assertIsInstance(predictions[0]["labels"], torch.Tensor)
self.assertIsInstance(predictions[0]["scores"], torch.Tensor)


if __name__ == '__main__':
unittest.main()
Empty file removed test/test_models_utils.py
Empty file.
4 changes: 1 addition & 3 deletions test/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
import unittest

import torch
Expand Down Expand Up @@ -65,6 +66,3 @@ def test_yolotr_script(self):
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))

if __name__ == "__main__":
unittest.main()
5 changes: 2 additions & 3 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,8 @@ def compute_loss(
loss_box += (1.0 - ciou).mean() # iou loss

# Objectness head
# iou ratio
ciou_vals = torch.tensor(ciou.detach().clamp(0), dtype=obj_logits.dtype)
obj_logits[b, a, gj, gi] = (1.0 - self.iou_ratio) + (self.iou_ratio * ciou_vals)
# Compute the iou ratio
obj_logits[b, a, gj, gi] = (1.0 - self.iou_ratio) + self.iou_ratio * ciou.detach().clamp(0)

# Classification head
if num_classes > 1: # cls loss (only if multiple classes)
Expand Down
2 changes: 1 addition & 1 deletion yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def forward(

if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("YOLO always returns a (Losses, Detections) tuple in scripting")
warnings.warn("YOLO always returns a (Losses, Detections) tuple in scripting.")
self._has_warned = True
return losses, detections
else:
Expand Down
39 changes: 35 additions & 4 deletions yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import warnings
import argparse

import torch
Expand All @@ -8,7 +9,6 @@

from . import yolo
from .transform import GeneralizedYOLOTransform

from ..datasets import DetectionDataModule, DataPipeline

from typing import Any, List, Dict, Tuple, Optional
Expand Down Expand Up @@ -50,6 +50,9 @@ def __init__(

self._data_pipeline = None

# used only on torchscript mode
self._has_warned = False

def forward(
self,
inputs: List[Tensor],
Expand Down Expand Up @@ -77,9 +80,37 @@ def forward(
# Transform the input
samples, targets = self.transform(inputs, targets)
# Compute the detections
detections = self.model(samples.tensors, targets=targets)
# Rescale coordinate
detections = self.transform.postprocess(detections, samples.image_sizes, original_image_sizes)
outputs = self.model(samples.tensors, targets=targets)

losses = {}
detections: List[Dict[str, Tensor]] = []

if self.training:
# compute the losses
if torch.jit.is_scripting():
losses = outputs[0]
else:
losses = outputs
else:
# Rescale coordinate
detections = self.transform.postprocess(outputs, samples.image_sizes, original_image_sizes)

if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("YOLOModule always returns Detections in scripting.")
self._has_warned = True
return detections
else:
return self.eager_outputs(losses, detections)

@torch.jit.unused
def eager_outputs(
self,
losses: Dict[str, Tensor],
detections: List[Dict[str, Tensor]],
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training:
return losses

return detections

Expand Down
1 change: 1 addition & 0 deletions yolort/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .flash_utils import get_callable_dict
from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import update_module_state_from_ultralytics
from .file_utils import prepare_coco128
33 changes: 33 additions & 0 deletions yolort/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import PosixPath
from zipfile import ZipFile

import torch


def prepare_coco128(
data_path: PosixPath,
dirname: str = 'coco128',
) -> None:
"""
Prepare coco128 dataset to test.

Args:
data_path (PosixPath): root path of coco128 dataset.
dirname (str): the directory name of coco128 dataset. Default: 'coco128'.
"""
if not data_path.is_dir():
print(f'Create a new directory: {data_path}')
data_path.mkdir(parents=True, exist_ok=True)

zip_path = data_path / 'coco128.zip'
coco128_url = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip'
if not zip_path.is_file():
print(f'Downloading coco128 datasets form {coco128_url}')
torch.hub.download_url_to_file(coco128_url, zip_path, hash_prefix='a67d2887')

coco128_path = data_path / dirname
if not coco128_path.is_dir():
print(f'Unzipping dataset to {coco128_path}')
with ZipFile(zip_path, 'r') as zip_obj:
zip_obj.extractall(data_path)