diff --git a/test/test_data_pipeline.py b/test/test_data_pipeline.py
index 1368fdec..bf414457 100644
--- a/test/test_data_pipeline.py
+++ b/test/test_data_pipeline.py
@@ -2,17 +2,20 @@
 from pathlib import Path
 import unittest
 
-import torch
+import torch.utils.data
+from torch import Tensor
 
+from yolort.data import DetectionDataModule
 from yolort.data.coco import CocoDetection
 from yolort.data.transforms import collate_fn, default_train_transforms
 from yolort.utils import prepare_coco128
 
+from .dataset_utils import DummyCOCODetectionDataset
+
 from typing import Dict
 
 
 class DataPipelineTester(unittest.TestCase):
-
     def test_vanilla_dataloader(self):
         # Acquire the images and labels from the coco128 dataset
         data_path = Path('data-bin')
@@ -27,19 +30,44 @@ def test_vanilla_dataloader(self):
         dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
         # Test the datasets
         image, target = next(iter(dataset))
-        self.assertIsInstance(image, torch.Tensor)
+        self.assertIsInstance(image, Tensor)
         self.assertIsInstance(target, Dict)
 
         batch_size = 4
         sampler = torch.utils.data.RandomSampler(dataset)
         batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
-        loader = torch.utils.data.DataLoader(
+        data_loader = torch.utils.data.DataLoader(
             dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
         # Test the dataloader
-        images, targets = next(iter(loader))
+        images, targets = next(iter(data_loader))
+
+        self.assertEqual(len(images), batch_size)
+        self.assertIsInstance(images[0], Tensor)
+        self.assertEqual(len(images[0]), 3)
+        self.assertEqual(len(targets), batch_size)
+        self.assertIsInstance(targets[0], Dict)
+        self.assertIsInstance(targets[0]["image_id"], Tensor)
+        self.assertIsInstance(targets[0]["boxes"], Tensor)
+        self.assertIsInstance(targets[0]["labels"], Tensor)
+        self.assertIsInstance(targets[0]["orig_size"], Tensor)
+
+    def test_detection_data_module(self):
+        # Setup the DataModule
+        batch_size = 4
+        train_dataset = DummyCOCODetectionDataset(num_samples=128)
+        data_module = DetectionDataModule(train_dataset, batch_size=batch_size)
+        self.assertEqual(data_module.batch_size, batch_size)
+        data_loader = data_module.train_dataloader(batch_size=batch_size)
+        images, targets = next(iter(data_loader))
 
         self.assertEqual(len(images), batch_size)
+        self.assertIsInstance(images[0], Tensor)
+        self.assertEqual(len(images[0]), 3)
         self.assertEqual(len(targets), batch_size)
+        self.assertIsInstance(targets[0], Dict)
+        self.assertIsInstance(targets[0]["image_id"], Tensor)
+        self.assertIsInstance(targets[0]["boxes"], Tensor)
+        self.assertIsInstance(targets[0]["labels"], Tensor)
 
     def test_prepare_coco128(self):
         data_path = Path('data-bin')
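
For context, the new `test_detection_data_module` test imports a `DummyCOCODetectionDataset` fixture from `test/dataset_utils.py`, whose implementation is not part of this diff. Below is a minimal sketch of what such a fixture could look like, inferred only from how the test uses it: the class name and `num_samples` argument come from the test above, while the image size, number of boxes, and label range are illustrative assumptions chosen to satisfy the assertions in the test (a 3-channel image tensor and a target dict with `image_id`, `boxes`, and `labels`).

```python
import torch
from torch.utils.data import Dataset


class DummyCOCODetectionDataset(Dataset):
    """Sketch of a synthetic COCO-style dataset for exercising the data pipeline.

    Not the actual test/dataset_utils.py implementation; shapes and ranges are
    assumptions matching the assertions in test_detection_data_module.
    """

    def __init__(self, num_samples: int = 128, image_size: int = 320, num_boxes: int = 3):
        self.num_samples = num_samples
        self.image_size = image_size
        self.num_boxes = num_boxes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Random 3-channel image in [0, 1]
        image = torch.rand(3, self.image_size, self.image_size)
        # Random well-formed boxes in (x1, y1, x2, y2) format
        x1y1 = torch.randint(0, self.image_size // 2, (self.num_boxes, 2)).float()
        wh = torch.randint(1, self.image_size // 2, (self.num_boxes, 2)).float()
        boxes = torch.cat([x1y1, x1y1 + wh], dim=1)
        target = {
            "image_id": torch.tensor([idx]),
            "boxes": boxes,
            "labels": torch.randint(0, 80, (self.num_boxes,)),
        }
        return image, target
```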