Skip to content

Commit

Permalink
Unify import and unit-test format
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 10, 2021
1 parent 73c6220 commit 31261ce
Showing 1 changed file with 43 additions and 40 deletions.
83 changes: 43 additions & 40 deletions test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest

import torch
from torch.utils import data
from torch import Tensor
import torch.utils.data
from torchvision.io import read_image

import pytorch_lightning as pl

from yolort.data import DetectionDataModule
from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms

from yolort.models.yolo import yolov5_darknet_pan_s_r31
from yolort.models.transform import nested_tensor_from_tensor_list
from yolort.models import yolov5s

from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms
from yolort.data import DetectionDataModule

from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset
Expand Down Expand Up @@ -52,9 +54,9 @@ def test_train_with_vanilla_model(self):
model.train()
out = model(images, targets)
self.assertIsInstance(out, Dict)
self.assertIsInstance(out["cls_logits"], torch.Tensor)
self.assertIsInstance(out["bbox_regression"], torch.Tensor)
self.assertIsInstance(out["objectness"], torch.Tensor)
self.assertIsInstance(out["cls_logits"], Tensor)
self.assertIsInstance(out["bbox_regression"], Tensor)
self.assertIsInstance(out["objectness"], Tensor)

def test_train_with_vanilla_module(self):
"""
Expand All @@ -77,11 +79,12 @@ def test_train_with_vanilla_module(self):
batch_size = 4

dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
sampler = data.RandomSampler(dataset)
batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True)
loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
# Sample a pair of images/targets
images, targets = next(iter(loader))
images, targets = next(iter(data_loader))
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

Expand All @@ -91,22 +94,22 @@ def test_train_with_vanilla_module(self):

out = model(images, targets)
self.assertIsInstance(out, Dict)
self.assertIsInstance(out["cls_logits"], torch.Tensor)
self.assertIsInstance(out["bbox_regression"], torch.Tensor)
self.assertIsInstance(out["objectness"], torch.Tensor)
self.assertIsInstance(out["cls_logits"], Tensor)
self.assertIsInstance(out["bbox_regression"], Tensor)
self.assertIsInstance(out["objectness"], Tensor)

def test_train_one_step(self):
def test_train_one_epoch(self):
# Setup the DataModule
train_dataset = DummyCOCODetectionDataset(num_samples=128)
data_module = DetectionDataModule(train_dataset, batch_size=16)
# Load model
model = yolov5s()
model.train()
# Setup the DataModule
train_dataset = DummyCOCODetectionDataset(num_samples=128)
datamodule = DetectionDataModule(train_dataset, batch_size=16)
# Trainer
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule)
trainer.fit(model, data_module)

def test_inference(self):
def test_predict_with_vanilla_model(self):
# Set image inputs
img_name = "test/assets/zidane.jpg"
img_input = default_loader(img_name)
Expand All @@ -119,11 +122,11 @@ def test_inference(self):
self.assertIsInstance(out, list)
self.assertEqual(len(out), 1)
self.assertIsInstance(out[0], Dict)
self.assertIsInstance(out[0]["boxes"], torch.Tensor)
self.assertIsInstance(out[0]["labels"], torch.Tensor)
self.assertIsInstance(out[0]["scores"], torch.Tensor)
self.assertIsInstance(out[0]["boxes"], Tensor)
self.assertIsInstance(out[0]["labels"], Tensor)
self.assertIsInstance(out[0]["scores"], Tensor)

def test_predict_tensor(self):
def test_predict_with_tensor(self):
# Set image inputs
img_name = "test/assets/zidane.jpg"
img_tensor = default_loader(img_name)
Expand All @@ -136,11 +139,11 @@ def test_predict_tensor(self):
self.assertIsInstance(predictions, list)
self.assertEqual(len(predictions), 1)
self.assertIsInstance(predictions[0], Dict)
self.assertIsInstance(predictions[0]["boxes"], torch.Tensor)
self.assertIsInstance(predictions[0]["labels"], torch.Tensor)
self.assertIsInstance(predictions[0]["scores"], torch.Tensor)
self.assertIsInstance(predictions[0]["boxes"], Tensor)
self.assertIsInstance(predictions[0]["labels"], Tensor)
self.assertIsInstance(predictions[0]["scores"], Tensor)

def test_predict_tensors(self):
def test_predict_with_tensors(self):
# Set image inputs
img_tensor1 = default_loader("test/assets/zidane.jpg")
self.assertEqual(img_tensor1.ndim, 3)
Expand All @@ -155,11 +158,11 @@ def test_predict_tensors(self):
self.assertIsInstance(predictions, list)
self.assertEqual(len(predictions), 2)
self.assertIsInstance(predictions[0], Dict)
self.assertIsInstance(predictions[0]["boxes"], torch.Tensor)
self.assertIsInstance(predictions[0]["labels"], torch.Tensor)
self.assertIsInstance(predictions[0]["scores"], torch.Tensor)
self.assertIsInstance(predictions[0]["boxes"], Tensor)
self.assertIsInstance(predictions[0]["labels"], Tensor)
self.assertIsInstance(predictions[0]["scores"], Tensor)

def test_predict_image_file(self):
def test_predict_with_image_file(self):
# Set image inputs
img_name = "test/assets/zidane.jpg"
# Load model
Expand All @@ -170,11 +173,11 @@ def test_predict_image_file(self):
self.assertIsInstance(predictions, list)
self.assertEqual(len(predictions), 1)
self.assertIsInstance(predictions[0], Dict)
self.assertIsInstance(predictions[0]["boxes"], torch.Tensor)
self.assertIsInstance(predictions[0]["labels"], torch.Tensor)
self.assertIsInstance(predictions[0]["scores"], torch.Tensor)
self.assertIsInstance(predictions[0]["boxes"], Tensor)
self.assertIsInstance(predictions[0]["labels"], Tensor)
self.assertIsInstance(predictions[0]["scores"], Tensor)

def test_predict_image_files(self):
def test_predict_with_image_files(self):
# Set image inputs
img_name1 = "test/assets/zidane.jpg"
img_name2 = "test/assets/bus.jpg"
Expand All @@ -187,6 +190,6 @@ def test_predict_image_files(self):
self.assertIsInstance(predictions, list)
self.assertEqual(len(predictions), 2)
self.assertIsInstance(predictions[0], Dict)
self.assertIsInstance(predictions[0]["boxes"], torch.Tensor)
self.assertIsInstance(predictions[0]["labels"], torch.Tensor)
self.assertIsInstance(predictions[0]["scores"], torch.Tensor)
self.assertIsInstance(predictions[0]["boxes"], Tensor)
self.assertIsInstance(predictions[0]["labels"], Tensor)
self.assertIsInstance(predictions[0]["scores"], Tensor)

0 comments on commit 31261ce

Please sign in to comment.