Skip to content

Commit

Permalink
Add DetectionDataModule unit-test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 10, 2021
1 parent 2adee19 commit 73c6220
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
from pathlib import Path
import unittest

import torch
import torch.utils.data
from torch import Tensor

from yolort.data import DetectionDataModule
from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms
from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset

from typing import Dict


class DataPipelineTester(unittest.TestCase):

def test_vanilla_dataloader(self):
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
Expand All @@ -27,19 +30,44 @@ def test_vanilla_dataloader(self):
dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
# Test the datasets
image, target = next(iter(dataset))
self.assertIsInstance(image, torch.Tensor)
self.assertIsInstance(image, Tensor)
self.assertIsInstance(target, Dict)

batch_size = 4
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
loader = torch.utils.data.DataLoader(
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
# Test the dataloader
images, targets = next(iter(loader))
images, targets = next(iter(data_loader))

self.assertEqual(len(images), batch_size)
self.assertIsInstance(images[0], Tensor)
self.assertEqual(len(images[0]), 3)
self.assertEqual(len(targets), batch_size)
self.assertIsInstance(targets[0], Dict)
self.assertIsInstance(targets[0]["image_id"], Tensor)
self.assertIsInstance(targets[0]["boxes"], Tensor)
self.assertIsInstance(targets[0]["labels"], Tensor)
self.assertIsInstance(targets[0]["orig_size"], Tensor)

def test_detection_data_module(self):
# Setup the DataModule
batch_size = 4
train_dataset = DummyCOCODetectionDataset(num_samples=128)
data_module = DetectionDataModule(train_dataset, batch_size=batch_size)
self.assertEqual(data_module.batch_size, batch_size)

data_loader = data_module.train_dataloader(batch_size=batch_size)
images, targets = next(iter(data_loader))
self.assertEqual(len(images), batch_size)
self.assertIsInstance(images[0], Tensor)
self.assertEqual(len(images[0]), 3)
self.assertEqual(len(targets), batch_size)
self.assertIsInstance(targets[0], Dict)
self.assertIsInstance(targets[0]["image_id"], Tensor)
self.assertIsInstance(targets[0]["boxes"], Tensor)
self.assertIsInstance(targets[0]["labels"], Tensor)

def test_prepare_coco128(self):
data_path = Path('data-bin')
Expand Down

0 comments on commit 73c6220

Please sign in to comment.