Skip to content

Commit

Permalink
Reproduce the bug in #86
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 8, 2021
1 parent 40b4dba commit c72794f
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 7 deletions.
49 changes: 49 additions & 0 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest

import torch
from torch.utils import data

from yolort.datasets.coco import CocoDetection
from yolort.datasets.transforms import collate_fn, default_train_transforms
from yolort.utils import prepare_coco128

from typing import Dict


class DataPipelineTester(unittest.TestCase):

def test_prepare_coco128(self):
data_path = Path('data-bin')
coco128_dirname = 'coco128'
prepare_coco128(data_path, dirname=coco128_dirname)
annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
self.assertTrue(annotation_file.is_file())

def test_vanilla_dataloader(self):
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
coco128_path = data_path / coco128_dirname
image_root = coco128_path / 'images' / 'train2017'
annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

if not annotation_file.is_file():
prepare_coco128(data_path, dirname=coco128_dirname)

dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
# Test the datasets
image, target = next(iter(dataset))
self.assertIsInstance(image, torch.Tensor)
self.assertIsInstance(target, Dict)

batch_size = 4
sampler = data.RandomSampler(dataset)
batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True)
loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
# Test the dataloader
images, targets = next(iter(loader))

self.assertEqual(len(images), batch_size)
self.assertEqual(len(targets), batch_size)
57 changes: 50 additions & 7 deletions test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest
import torch
from torch.utils import data
from torchvision.io import read_image

import pytorch_lightning as pl

from yolort.models.yolo import yolov5_darknet_pan_s_r31
from yolort.models.transform import nested_tensor_from_tensor_list
from yolort.models import yolov5s

from yolort.datasets.coco import CocoDetection
from yolort.datasets.transforms import collate_fn, default_train_transforms
from yolort.datasets import DetectionDataModule

from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset

from typing import Dict

from torchvision.io import read_image


def default_loader(img_name, is_half=False):
"""
Expand All @@ -27,7 +34,7 @@ def default_loader(img_name, is_half=False):


class EngineTester(unittest.TestCase):
def test_train(self):
def test_train_with_vanilla_model(self):
# Do forward over image
img_name = "test/assets/zidane.jpg"
img_tensor = default_loader(img_name)
Expand All @@ -49,6 +56,46 @@ def test_train(self):
self.assertIsInstance(out["bbox_regression"], torch.Tensor)
self.assertIsInstance(out["objectness"], torch.Tensor)

def test_train_with_vanilla_module(self):
"""
For issue #86: <https://github.com/zhiqwang/yolov5-rt-stack/issues/86>
"""
# Define the device
device = torch.device('cpu')

# Prepare the datasets for training
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
coco128_path = data_path / coco128_dirname
image_root = coco128_path / 'images' / 'train2017'
annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

if not annotation_file.is_file():
prepare_coco128(data_path, dirname=coco128_dirname)

batch_size = 4

dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
sampler = data.RandomSampler(dataset)
batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True)
loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
# Sample a pair of images/targets
images, targets = next(iter(loader))
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

# Define the model
model = yolov5s(num_classes=80)
model.train()

out = model(images, targets)
self.assertIsInstance(out, Dict)
self.assertIsInstance(out["cls_logits"], torch.Tensor)
self.assertIsInstance(out["bbox_regression"], torch.Tensor)
self.assertIsInstance(out["objectness"], torch.Tensor)

@unittest.skip("Just ignore this.")
def test_train_one_step(self):
# Load model
model = yolov5s()
Expand Down Expand Up @@ -144,7 +191,3 @@ def test_predict_image_files(self):
self.assertIsInstance(predictions[0]["boxes"], torch.Tensor)
self.assertIsInstance(predictions[0]["labels"], torch.Tensor)
self.assertIsInstance(predictions[0]["scores"], torch.Tensor)


if __name__ == '__main__':
unittest.main()
Empty file removed test/test_models_utils.py
Empty file.
1 change: 1 addition & 0 deletions yolort/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .flash_utils import get_callable_dict
from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import update_module_state_from_ultralytics
from .file_utils import prepare_coco128
33 changes: 33 additions & 0 deletions yolort/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import PosixPath
from zipfile import ZipFile

import torch


def prepare_coco128(
data_path: PosixPath,
dirname: str = 'coco128',
) -> None:
"""
Prepare coco128 dataset to test.
Args:
data_path (PosixPath): root path of coco128 dataset.
dirname (str): the directory name of coco128 dataset. Default: 'coco128'.
"""
if not data_path.is_dir():
print(f'Create a new directory: {data_path}')
data_path.mkdir(parents=True, exist_ok=True)

zip_path = data_path / 'coco128.zip'
coco128_url = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip'
if not zip_path.is_file():
print(f'Downloading coco128 datasets form {coco128_url}')
torch.hub.download_url_to_file(coco128_url, zip_path, hash_prefix='a67d2887')

coco128_path = data_path / dirname
if not coco128_path.is_dir():
print(f'Unzipping dataset to {coco128_path}')
with ZipFile(zip_path, 'r') as zip_obj:
zip_obj.extractall(data_path)

0 comments on commit c72794f

Please sign in to comment.