diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py
index 55ec7f38268..82fce713f14 100644
--- a/test/test_datasets_download.py
+++ b/test/test_datasets_download.py
@@ -3,10 +3,12 @@
 import time
 import unittest.mock
 from datetime import datetime
+from distutils import dir_util
 from os import path
 from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request
+import tempfile
 import warnings
 
 import pytest
@@ -194,6 +196,17 @@ def collect_download_configs(dataset_loader, name=None, **kwargs):
     return make_download_configs(urls_and_md5s, name)
 
 
+# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
+# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
+ROOT = tempfile.mkdtemp()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def root():
+    yield ROOT
+    dir_util.remove_tree(ROOT)
+
+
 def places365():
     with log_download_attempts(patch=False) as urls_and_md5s:
         for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
@@ -206,26 +219,26 @@ def places365():
 
 
 def caltech101():
-    return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")
+    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")
 
 
 def caltech256():
-    return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")
+    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")
 
 
 def cifar10():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")
+    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")
 
 
 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")
 
 
 def voc():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.VOCSegmentation(".", year=year, download=True),
+                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
                 name=f"VOC, {year}",
                 file="voc",
             )
@@ -235,27 +248,27 @@ def voc():
 
 
 def mnist():
-    return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")
+    return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")
 
 
 def fashion_mnist():
-    return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")
+    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")
 
 
 def kmnist():
-    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")
 
 
 def emnist():
     # the 'split' argument can be any valid one, since everything is downloaded anyway
-    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")
 
 
 def qmnist():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.QMNIST(".", what=what, download=True),
+                lambda: datasets.QMNIST(ROOT, what=what, download=True),
                 name=f"QMNIST, {what}",
                 file="mnist",
             )
@@ -268,7 +281,7 @@ def omniglot():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.Omniglot(".", background=background, download=True),
+                lambda: datasets.Omniglot(ROOT, background=background, download=True),
                 name=f"Omniglot, {'background' if background else 'evaluation'}",
             )
             for background in (True, False)
@@ -280,7 +293,7 @@ def phototour():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.PhotoTour(".", name=name, download=True),
+                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
                 name=f"PhotoTour, {name}",
                 file="phototour",
             )
@@ -293,7 +306,7 @@ def phototour():
 
 def sbdataset():
     return collect_download_configs(
-        lambda: datasets.SBDataset(".", download=True),
+        lambda: datasets.SBDataset(ROOT, download=True),
         name="SBDataset",
         file="voc",
     )
@@ -301,7 +314,7 @@ def sbdataset():
 
 def sbu():
     return collect_download_configs(
-        lambda: datasets.SBU(".", download=True),
+        lambda: datasets.SBU(ROOT, download=True),
         name="SBU",
         file="sbu",
     )
@@ -309,7 +322,7 @@ def sbu():
 
 def semeion():
     return collect_download_configs(
-        lambda: datasets.SEMEION(".", download=True),
+        lambda: datasets.SEMEION(ROOT, download=True),
         name="SEMEION",
         file="semeion",
     )
@@ -317,7 +330,7 @@ def semeion():
 
 def stl10():
     return collect_download_configs(
-        lambda: datasets.STL10(".", download=True),
+        lambda: datasets.STL10(ROOT, download=True),
         name="STL10",
     )
 
@@ -326,7 +339,7 @@ def svhn():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.SVHN(".", split=split, download=True),
+                lambda: datasets.SVHN(ROOT, split=split, download=True),
                 name=f"SVHN, {split}",
                 file="svhn",
             )
@@ -339,7 +352,7 @@ def usps():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.USPS(".", train=train, download=True),
+                lambda: datasets.USPS(ROOT, train=train, download=True),
                 name=f"USPS, {'train' if train else 'test'}",
                 file="usps",
             )
@@ -350,7 +363,7 @@ def usps():
 
 def celeba():
     return collect_download_configs(
-        lambda: datasets.CelebA(".", download=True),
+        lambda: datasets.CelebA(ROOT, download=True),
         name="CelebA",
         file="celeba",
     )
@@ -358,7 +371,7 @@ def celeba():
 
 def widerface():
     return collect_download_configs(
-        lambda: datasets.WIDERFace(".", download=True),
+        lambda: datasets.WIDERFace(ROOT, download=True),
         name="WIDERFace",
         file="widerface",
     )
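For readers unfamiliar with the pytest detail the new comment alludes to: arguments to `@pytest.mark.parametrize` are evaluated at collection time, before any fixture runs, so a fixture-provided temporary directory cannot be referenced in a parametrization. Below is a minimal, standalone sketch of the pattern the patch uses, not part of the patch itself; the test and parameter names are made up, and it swaps `distutils.dir_util.remove_tree` for `shutil.rmtree`, which is equivalent here (`distutils` is deprecated since Python 3.10 and removed in 3.12).

```python
# Illustrative sketch only -- not torchvision code.
import shutil
import tempfile

import pytest

# Created once at import time, so it already exists when parametrizations
# are evaluated during test collection.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    # autouse + module scope: wraps every test in the module without being
    # requested explicitly; the code after `yield` is the teardown that runs
    # once all tests have finished.
    yield ROOT
    shutil.rmtree(ROOT)  # the patch uses distutils.dir_util.remove_tree(ROOT)


# Hypothetical test: the argument list is built from ROOT at collection time,
# which a fixture value could not provide.
@pytest.mark.parametrize("subdir", [f"{ROOT}/mnist", f"{ROOT}/cifar10"])
def test_root_is_shared(subdir):
    assert subdir.startswith(ROOT)
```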