Skip to content

Commit

Permalink
Unify import format
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 10, 2021
1 parent d69e311 commit 2adee19
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
22 changes: 11 additions & 11 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import unittest

import torch
from torch.utils import data

from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms
Expand All @@ -14,13 +13,6 @@

class DataPipelineTester(unittest.TestCase):

def test_prepare_coco128(self):
data_path = Path('data-bin')
coco128_dirname = 'coco128'
prepare_coco128(data_path, dirname=coco128_dirname)
annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
self.assertTrue(annotation_file.is_file())

def test_vanilla_dataloader(self):
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
Expand All @@ -39,11 +31,19 @@ def test_vanilla_dataloader(self):
self.assertIsInstance(target, Dict)

batch_size = 4
sampler = data.RandomSampler(dataset)
batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=True)
loader = data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
loader = torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
# Test the dataloader
images, targets = next(iter(loader))

self.assertEqual(len(images), batch_size)
self.assertEqual(len(targets), batch_size)

def test_prepare_coco128(self):
data_path = Path('data-bin')
coco128_dirname = 'coco128'
prepare_coco128(data_path, dirname=coco128_dirname)
annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
self.assertTrue(annotation_file.is_file())
5 changes: 2 additions & 3 deletions yolort/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path

import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from pytorch_lightning import LightningDataModule
Expand Down Expand Up @@ -50,7 +49,7 @@ def train_dataloader(self, batch_size: int = 16) -> None:
sampler = torch.utils.data.RandomSampler(self._train_dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)

loader = DataLoader(
loader = torch.utils.data.DataLoader(
self._train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
Expand All @@ -69,7 +68,7 @@ def val_dataloader(self, batch_size: int = 16) -> None:
# Creating data loaders
sampler = torch.utils.data.SequentialSampler(self._val_dataset)

loader = DataLoader(
loader = torch.utils.data.DataLoader(
self._val_dataset,
batch_size,
sampler=sampler,
Expand Down
3 changes: 2 additions & 1 deletion yolort/data/detection_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import Callable, Any, Optional, Type
from collections.abc import Sequence

from torch import Tensor
Expand All @@ -8,6 +7,8 @@
from .transforms import collate_fn
from .data_pipeline import DataPipeline

from typing import Callable, Any, Optional, Type


class ObjectDetectionDataPipeline(DataPipeline):
"""
Expand Down

0 comments on commit 2adee19

Please sign in to comment.