From 2adee19c1ef935110ed0741e4031f52447ffb05f Mon Sep 17 00:00:00 2001
From: zhiqwang
Date: Fri, 9 Apr 2021 23:49:53 -0400
Subject: [PATCH] Unify import format

---
 test/test_data_pipeline.py        | 22 +++++++++++-----------
 yolort/data/data_module.py        |  5 ++---
 yolort/data/detection_pipeline.py |  3 ++-
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/test/test_data_pipeline.py b/test/test_data_pipeline.py
index 8b678e32..1368fdec 100644
--- a/test/test_data_pipeline.py
+++ b/test/test_data_pipeline.py
@@ -3,7 +3,6 @@
 import unittest
 
 import torch
-from torch.utils import data
 
 from yolort.data.coco import CocoDetection
 from yolort.data.transforms import collate_fn, default_train_transforms
@@ -14,13 +13,6 @@
 
 
 class DataPipelineTester(unittest.TestCase):
-    def test_prepare_coco128(self):
-        data_path = Path('data-bin')
-        coco128_dirname = 'coco128'
-        prepare_coco128(data_path, dirname=coco128_dirname)
-        annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
-        self.assertTrue(annotation_file.is_file())
-
     def test_vanilla_dataloader(self):
         # Acquire the images and labels from the coco128 dataset
         data_path = Path('data-bin')
@@ -39,11 +31,19 @@ def test_vanilla_dataloader(self):
         self.assertIsInstance(target, Dict)
 
         batch_size = 4
-        sampler = data.RandomSampler(dataset)
-        batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True)
-        loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
+        sampler = torch.utils.data.RandomSampler(dataset)
+        batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
+        loader = torch.utils.data.DataLoader(
+            dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
 
         # Test the dataloader
         images, targets = next(iter(loader))
         self.assertEqual(len(images), batch_size)
         self.assertEqual(len(targets), batch_size)
+
+    def test_prepare_coco128(self):
+        data_path = Path('data-bin')
+        coco128_dirname = 'coco128'
+        prepare_coco128(data_path, dirname=coco128_dirname)
+        annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
+        self.assertTrue(annotation_file.is_file())
diff --git a/yolort/data/data_module.py b/yolort/data/data_module.py
index eeae1bc8..f025aea5 100644
--- a/yolort/data/data_module.py
+++ b/yolort/data/data_module.py
@@ -2,7 +2,6 @@
 from pathlib import Path
 
 import torch.utils.data
-from torch.utils.data import DataLoader
 from torch.utils.data.dataset import Dataset
 
 from pytorch_lightning import LightningDataModule
@@ -50,7 +49,7 @@ def train_dataloader(self, batch_size: int = 16) -> None:
         sampler = torch.utils.data.RandomSampler(self._train_dataset)
         batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
 
-        loader = DataLoader(
+        loader = torch.utils.data.DataLoader(
             self._train_dataset,
             batch_sampler=batch_sampler,
             collate_fn=collate_fn,
@@ -69,7 +68,7 @@ def val_dataloader(self, batch_size: int = 16) -> None:
         # Creating data loaders
         sampler = torch.utils.data.SequentialSampler(self._val_dataset)
 
-        loader = DataLoader(
+        loader = torch.utils.data.DataLoader(
             self._val_dataset,
             batch_size,
             sampler=sampler,
diff --git a/yolort/data/detection_pipeline.py b/yolort/data/detection_pipeline.py
index 0c882919..0b04528d 100644
--- a/yolort/data/detection_pipeline.py
+++ b/yolort/data/detection_pipeline.py
@@ -1,5 +1,4 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
-from typing import Callable, Any, Optional, Type
 from collections.abc import Sequence
 
 from torch import Tensor
@@ -8,6 +7,8 @@
 from .transforms import collate_fn
 from .data_pipeline import DataPipeline
 
+from typing import Callable, Any, Optional, Type
+
 
 class ObjectDetectionDataPipeline(DataPipeline):
     """
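
For reference, the snippet below is a minimal standalone sketch (not part of the patch) of the import convention this commit unifies on: samplers, batch samplers, and loaders are referred to through the fully qualified torch.utils.data namespace rather than imported individually. The toy TensorDataset is an assumption made for illustration and stands in for the CocoDetection dataset used by the real tests.

# Illustrative sketch only, not part of the patch above.
# Shows the unified import style: data utilities are accessed via the
# fully qualified torch.utils.data namespace.
import torch
import torch.utils.data

# Toy dataset standing in for yolort's CocoDetection.
dataset = torch.utils.data.TensorDataset(torch.randn(16, 3, 32, 32))

batch_size = 4
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)

# Each batch is a tuple holding one stacked tensor of shape (batch_size, 3, 32, 32).
(images,) = next(iter(loader))
assert images.shape == (batch_size, 3, 32, 32)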