diff --git a/test/test_engine.py b/test/test_engine.py index 3f638c56..09b01a46 100644 --- a/test/test_engine.py +++ b/test/test_engine.py @@ -1,20 +1,22 @@ # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. from pathlib import Path import unittest + import torch -from torch.utils import data +from torch import Tensor +import torch.utils.data from torchvision.io import read_image import pytorch_lightning as pl +from yolort.data import DetectionDataModule +from yolort.data.coco import CocoDetection +from yolort.data.transforms import collate_fn, default_train_transforms + from yolort.models.yolo import yolov5_darknet_pan_s_r31 from yolort.models.transform import nested_tensor_from_tensor_list from yolort.models import yolov5s -from yolort.data.coco import CocoDetection -from yolort.data.transforms import collate_fn, default_train_transforms -from yolort.data import DetectionDataModule - from yolort.utils import prepare_coco128 from .dataset_utils import DummyCOCODetectionDataset @@ -52,9 +54,9 @@ def test_train_with_vanilla_model(self): model.train() out = model(images, targets) self.assertIsInstance(out, Dict) - self.assertIsInstance(out["cls_logits"], torch.Tensor) - self.assertIsInstance(out["bbox_regression"], torch.Tensor) - self.assertIsInstance(out["objectness"], torch.Tensor) + self.assertIsInstance(out["cls_logits"], Tensor) + self.assertIsInstance(out["bbox_regression"], Tensor) + self.assertIsInstance(out["objectness"], Tensor) def test_train_with_vanilla_module(self): """ @@ -77,11 +79,12 @@ def test_train_with_vanilla_module(self): batch_size = 4 dataset = CocoDetection(image_root, annotation_file, default_train_transforms()) - sampler = data.RandomSampler(dataset) - batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True) - loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0) + sampler = torch.utils.data.RandomSampler(dataset) + batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True) + data_loader = torch.utils.data.DataLoader( + dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0) # Sample a pair of images/targets - images, targets = next(iter(loader)) + images, targets = next(iter(data_loader)) images = [img.to(device) for img in images] targets = [{k: v.to(device) for k, v in t.items()} for t in targets] @@ -91,22 +94,22 @@ def test_train_with_vanilla_module(self): out = model(images, targets) self.assertIsInstance(out, Dict) - self.assertIsInstance(out["cls_logits"], torch.Tensor) - self.assertIsInstance(out["bbox_regression"], torch.Tensor) - self.assertIsInstance(out["objectness"], torch.Tensor) + self.assertIsInstance(out["cls_logits"], Tensor) + self.assertIsInstance(out["bbox_regression"], Tensor) + self.assertIsInstance(out["objectness"], Tensor) - def test_train_one_step(self): + def test_train_one_epoch(self): + # Setup the DataModule + train_dataset = DummyCOCODetectionDataset(num_samples=128) + data_module = DetectionDataModule(train_dataset, batch_size=16) # Load model model = yolov5s() model.train() - # Setup the DataModule - train_dataset = DummyCOCODetectionDataset(num_samples=128) - datamodule = DetectionDataModule(train_dataset, batch_size=16) # Trainer trainer = pl.Trainer(max_epochs=1) - trainer.fit(model, datamodule) + trainer.fit(model, data_module) - def test_inference(self): + def test_predict_with_vanilla_model(self): # Set image inputs img_name = "test/assets/zidane.jpg" img_input = default_loader(img_name) @@ -119,11 +122,11 @@ def test_inference(self): self.assertIsInstance(out, list) self.assertEqual(len(out), 1) self.assertIsInstance(out[0], Dict) - self.assertIsInstance(out[0]["boxes"], torch.Tensor) - self.assertIsInstance(out[0]["labels"], torch.Tensor) - self.assertIsInstance(out[0]["scores"], torch.Tensor) + self.assertIsInstance(out[0]["boxes"], Tensor) + self.assertIsInstance(out[0]["labels"], Tensor) + self.assertIsInstance(out[0]["scores"], Tensor) - def test_predict_tensor(self): + def test_predict_with_tensor(self): # Set image inputs img_name = "test/assets/zidane.jpg" img_tensor = default_loader(img_name) @@ -136,11 +139,11 @@ def test_predict_tensor(self): self.assertIsInstance(predictions, list) self.assertEqual(len(predictions), 1) self.assertIsInstance(predictions[0], Dict) - self.assertIsInstance(predictions[0]["boxes"], torch.Tensor) - self.assertIsInstance(predictions[0]["labels"], torch.Tensor) - self.assertIsInstance(predictions[0]["scores"], torch.Tensor) + self.assertIsInstance(predictions[0]["boxes"], Tensor) + self.assertIsInstance(predictions[0]["labels"], Tensor) + self.assertIsInstance(predictions[0]["scores"], Tensor) - def test_predict_tensors(self): + def test_predict_with_tensors(self): # Set image inputs img_tensor1 = default_loader("test/assets/zidane.jpg") self.assertEqual(img_tensor1.ndim, 3) @@ -155,11 +158,11 @@ def test_predict_tensors(self): self.assertIsInstance(predictions, list) self.assertEqual(len(predictions), 2) self.assertIsInstance(predictions[0], Dict) - self.assertIsInstance(predictions[0]["boxes"], torch.Tensor) - self.assertIsInstance(predictions[0]["labels"], torch.Tensor) - self.assertIsInstance(predictions[0]["scores"], torch.Tensor) + self.assertIsInstance(predictions[0]["boxes"], Tensor) + self.assertIsInstance(predictions[0]["labels"], Tensor) + self.assertIsInstance(predictions[0]["scores"], Tensor) - def test_predict_image_file(self): + def test_predict_with_image_file(self): # Set image inputs img_name = "test/assets/zidane.jpg" # Load model @@ -170,11 +173,11 @@ def test_predict_image_file(self): self.assertIsInstance(predictions, list) self.assertEqual(len(predictions), 1) self.assertIsInstance(predictions[0], Dict) - self.assertIsInstance(predictions[0]["boxes"], torch.Tensor) - self.assertIsInstance(predictions[0]["labels"], torch.Tensor) - self.assertIsInstance(predictions[0]["scores"], torch.Tensor) + self.assertIsInstance(predictions[0]["boxes"], Tensor) + self.assertIsInstance(predictions[0]["labels"], Tensor) + self.assertIsInstance(predictions[0]["scores"], Tensor) - def test_predict_image_files(self): + def test_predict_with_image_files(self): # Set image inputs img_name1 = "test/assets/zidane.jpg" img_name2 = "test/assets/bus.jpg" @@ -187,6 +190,6 @@ def test_predict_image_files(self): self.assertIsInstance(predictions, list) self.assertEqual(len(predictions), 2) self.assertIsInstance(predictions[0], Dict) - self.assertIsInstance(predictions[0]["boxes"], torch.Tensor) - self.assertIsInstance(predictions[0]["labels"], torch.Tensor) - self.assertIsInstance(predictions[0]["scores"], torch.Tensor) + self.assertIsInstance(predictions[0]["boxes"], Tensor) + self.assertIsInstance(predictions[0]["labels"], Tensor) + self.assertIsInstance(predictions[0]["scores"], Tensor)