Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove generated dataset folders after download tests #3376

Merged
merged 2 commits into from
Feb 11, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import time
import unittest.mock
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request
import tempfile
import warnings

import pytest
Expand Down Expand Up @@ -194,6 +196,17 @@ def collect_download_configs(dataset_loader, name=None, **kwargs):
return make_download_configs(urls_and_md5s, name)


# This is a workaround since fixtures, such as the built-in tmp_path/tmpdir, can only be used within a test but not
# within a parametrization. Thus, we use a single root directory for all datasets and remove it when all download
# tests are run (see the module-scoped 'root' fixture).
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    """Provide the shared download directory and remove it after the last test.

    ``autouse=True`` makes pytest trigger this fixture automatically even when a
    test does not request it, and ``scope="module"`` runs it once per module:
    everything before the ``yield`` is setup (nothing needed here, since ROOT is
    created at import time) and everything after it is teardown, executed once
    the final test in this module has finished.
    """
    yield ROOT
    dir_util.remove_tree(ROOT)


def places365():
with log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
Expand All @@ -206,26 +219,26 @@ def places365():


def caltech101():
    """Collect download configs for Caltech101, downloading into the shared ROOT."""
    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")


def caltech256():
    """Collect download configs for Caltech256, downloading into the shared ROOT."""
    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")


def cifar10():
    """Collect download configs for CIFAR10, downloading into the shared ROOT."""
    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")


def cifar100():
    """Collect download configs for CIFAR100, downloading into the shared ROOT."""
    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")


def voc():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.VOCSegmentation(".", year=year, download=True),
lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
name=f"VOC, {year}",
file="voc",
)
Expand All @@ -235,27 +248,27 @@ def voc():


def mnist():
    """Collect download configs for MNIST, downloading into the shared ROOT."""
    return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")


def fashion_mnist():
    """Collect download configs for FashionMNIST, downloading into the shared ROOT."""
    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")


def kmnist():
    """Collect download configs for KMNIST, downloading into the shared ROOT."""
    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")


def emnist():
    """Collect download configs for EMNIST, downloading into the shared ROOT.

    The 'split' argument can be any valid one, since everything is downloaded anyway.
    """
    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")


def qmnist():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.QMNIST(".", what=what, download=True),
lambda: datasets.QMNIST(ROOT, what=what, download=True),
name=f"QMNIST, {what}",
file="mnist",
)
Expand All @@ -268,7 +281,7 @@ def omniglot():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.Omniglot(".", background=background, download=True),
lambda: datasets.Omniglot(ROOT, background=background, download=True),
name=f"Omniglot, {'background' if background else 'evaluation'}",
)
for background in (True, False)
Expand All @@ -280,7 +293,7 @@ def phototour():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.PhotoTour(".", name=name, download=True),
lambda: datasets.PhotoTour(ROOT, name=name, download=True),
name=f"PhotoTour, {name}",
file="phototour",
)
Expand All @@ -293,31 +306,31 @@ def phototour():

def sbdataset():
    """Collect download configs for SBDataset, downloading into the shared ROOT.

    Uses ``file="voc"`` since the download machinery lives in the voc module.
    """
    return collect_download_configs(
        lambda: datasets.SBDataset(ROOT, download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    """Collect download configs for SBU, downloading into the shared ROOT."""
    return collect_download_configs(
        lambda: datasets.SBU(ROOT, download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    """Collect download configs for SEMEION, downloading into the shared ROOT."""
    return collect_download_configs(
        lambda: datasets.SEMEION(ROOT, download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    """Collect download configs for STL10, downloading into the shared ROOT."""
    return collect_download_configs(
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )

Expand All @@ -326,7 +339,7 @@ def svhn():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.SVHN(".", split=split, download=True),
lambda: datasets.SVHN(ROOT, split=split, download=True),
name=f"SVHN, {split}",
file="svhn",
)
Expand All @@ -339,7 +352,7 @@ def usps():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.USPS(".", train=train, download=True),
lambda: datasets.USPS(ROOT, train=train, download=True),
name=f"USPS, {'train' if train else 'test'}",
file="usps",
)
Expand All @@ -350,15 +363,15 @@ def usps():

def celeba():
    """Collect download configs for CelebA, downloading into the shared ROOT."""
    return collect_download_configs(
        lambda: datasets.CelebA(ROOT, download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    """Collect download configs for WIDERFace, downloading into the shared ROOT."""
    return collect_download_configs(
        lambda: datasets.WIDERFace(ROOT, download=True),
        name="WIDERFace",
        file="widerface",
    )
Expand Down