From 4771734eb30310b5f6179f668c2ee56ddd04e9ba Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Sun, 29 Nov 2020 21:55:22 +0900
Subject: [PATCH 1/2] Use cached datasets

---
 tests/callbacks/test_data_monitor.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py
index ac1ec6835b..701b6641d6 100644
--- a/tests/callbacks/test_data_monitor.py
+++ b/tests/callbacks/test_data_monitor.py
@@ -16,11 +16,11 @@
 )
 @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
 def test_base_log_interval_override(
-    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls
+    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
 ):
     """ Test logging interval set by log_every_n_steps argument. """
     monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
-    model = LitMNIST(num_workers=0)
+    model = LitMNIST(data_dir=datadir, num_workers=0)
     trainer = Trainer(
         default_root_dir=tmpdir,
         log_every_n_steps=1,
@@ -43,11 +43,11 @@ def test_base_log_interval_override(
 )
 @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
 def test_base_log_interval_fallback(
-    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls
+    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
 ):
     """ Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer. """
     monitor = TrainingDataMonitor()
-    model = LitMNIST(num_workers=0)
+    model = LitMNIST(data_dir=datadir, num_workers=0)
     trainer = Trainer(
         default_root_dir=tmpdir,
         log_every_n_steps=log_every_n_steps,
@@ -81,10 +81,10 @@ def test_base_unsupported_logger_warning(tmpdir):
 
 
 @mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram")
-def test_training_data_monitor(log_histogram, tmpdir):
+def test_training_data_monitor(log_histogram, tmpdir, datadir):
     """ Test that the TrainingDataMonitor logs histograms of data points going into training_step. """
     monitor = TrainingDataMonitor()
-    model = LitMNIST()
+    model = LitMNIST(data_dir=datadir)
     trainer = Trainer(
         default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor],
     )

From e5a2ffa07e8f02b649ff3d2f2e793753e2397e2c Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Sun, 29 Nov 2020 21:58:55 +0900
Subject: [PATCH 2/2] Use cached datasets in doctests

---
 pl_bolts/datasets/cifar10_dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py
index 1038c7aa87..fcc783a72a 100644
--- a/pl_bolts/datasets/cifar10_dataset.py
+++ b/pl_bolts/datasets/cifar10_dataset.py
@@ -41,7 +41,7 @@ class CIFAR10(LightDataset):
         >>> from torchvision import transforms
         >>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
         >>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()])
-        >>> dataset = CIFAR10(download=True, transform=cf10_transforms)
+        >>> dataset = CIFAR10(download=True, transform=cf10_transforms, data_dir="datasets")
         >>> len(dataset)
         50000
         >>> torch.bincount(dataset.targets)
@@ -167,7 +167,7 @@ class TrialCIFAR10(CIFAR10):
     without the torchvision dependency.
 
     Examples:
-        >>> dataset = TrialCIFAR10(download=True, num_samples=150, labels=(1, 5, 8))
+        >>> dataset = TrialCIFAR10(download=True, num_samples=150, labels=(1, 5, 8), data_dir="datasets")
        >>> len(dataset)
        450
        >>> sorted(set([d.item() for d in dataset.targets]))
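
Note: the updated tests assume a datadir pytest fixture (typically provided by the test suite's conftest.py) that resolves to a directory holding the cached MNIST data; that fixture is not part of this patch series. Below is a minimal sketch of what such a fixture could look like. The session scope, the PL_BOLTS_DATA_DIR environment variable, and the tmp-dir fallback are illustrative assumptions, not the repository's actual implementation.

    # conftest.py -- hypothetical sketch, not the fixture shipped in the repository
    import os

    import pytest


    @pytest.fixture(scope="session")
    def datadir(tmp_path_factory):
        # Prefer a pre-populated cache directory so tests reuse already
        # downloaded datasets; otherwise fall back to a session-scoped tmp dir.
        cached = os.environ.get("PL_BOLTS_DATA_DIR")  # assumed env var, for illustration only
        if cached:
            return cached
        return str(tmp_path_factory.mktemp("datasets"))

The doctest changes follow the same idea: passing data_dir="datasets" points the examples at a shared "datasets" directory that can be cached between runs, rather than the dataset's default download location.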