tests: Use cached datasets in LitMNIST and the doctests (Lightning-Universe#414)

* Use cached datasets

* Use cached datasets in doctests
akihironitta authored and chris-clem committed Dec 16, 2020
1 parent f8335f9 commit 5b299ac
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pl_bolts/datasets/cifar10_dataset.py
@@ -41,7 +41,7 @@ class CIFAR10(LightDataset):
     >>> from torchvision import transforms
     >>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
     >>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()])
-    >>> dataset = CIFAR10(download=True, transform=cf10_transforms)
+    >>> dataset = CIFAR10(download=True, transform=cf10_transforms, data_dir="datasets")
     >>> len(dataset)
     50000
     >>> torch.bincount(dataset.targets)
@@ -167,7 +167,7 @@ class TrialCIFAR10(CIFAR10):
     without the torchvision dependency.
     Examples:
-    >>> dataset = TrialCIFAR10(download=True, num_samples=150, labels=(1, 5, 8))
+    >>> dataset = TrialCIFAR10(download=True, num_samples=150, labels=(1, 5, 8), data_dir="datasets")
     >>> len(dataset)
     450
     >>> sorted(set([d.item() for d in dataset.targets]))
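The doctest edits above route both CIFAR10 and TrialCIFAR10 downloads into a shared "datasets" directory, so repeated doctest runs reuse the cached archive instead of re-fetching it. A minimal sketch of that download-once pattern follows; fetch_once and its urllib-based body are illustrative assumptions, not pl_bolts code (the real cache check lives inside LightDataset and is not shown in this diff):

# Sketch only: an assumed stand-in for the cache guard, not pl_bolts code.
import os
import urllib.request

def fetch_once(data_dir: str, filename: str, url: str) -> str:
    """Download `url` into `data_dir` unless a cached copy already exists."""
    os.makedirs(data_dir, exist_ok=True)
    path = os.path.join(data_dir, filename)
    if not os.path.exists(path):  # first run downloads; later runs hit the cache
        urllib.request.urlretrieve(url, path)
    return path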
12 changes: 6 additions & 6 deletions tests/callbacks/test_data_monitor.py
@@ -16,11 +16,11 @@
 )
 @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
 def test_base_log_interval_override(
-    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls
+    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
 ):
     """ Test logging interval set by log_every_n_steps argument. """
     monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
-    model = LitMNIST(num_workers=0)
+    model = LitMNIST(data_dir=datadir, num_workers=0)
     trainer = Trainer(
         default_root_dir=tmpdir,
         log_every_n_steps=1,
@@ -43,11 +43,11 @@ def test_base_log_interval_override(
 )
 @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
 def test_base_log_interval_fallback(
-    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls
+    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
 ):
     """ Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer. """
     monitor = TrainingDataMonitor()
-    model = LitMNIST(num_workers=0)
+    model = LitMNIST(data_dir=datadir, num_workers=0)
     trainer = Trainer(
         default_root_dir=tmpdir,
         log_every_n_steps=log_every_n_steps,
@@ -81,10 +81,10 @@ def test_base_unsupported_logger_warning(tmpdir):
 
 
 @mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram")
-def test_training_data_monitor(log_histogram, tmpdir):
+def test_training_data_monitor(log_histogram, tmpdir, datadir):
     """ Test that the TrainingDataMonitor logs histograms of data points going into training_step. """
     monitor = TrainingDataMonitor()
-    model = LitMNIST()
+    model = LitMNIST(data_dir=datadir)
     trainer = Trainer(
         default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor],
     )
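Each updated test now declares a datadir parameter, which pytest injects as a fixture; the fixture definition itself is outside this diff. A plausible sketch of how it might look in tests/conftest.py (the session scope and the datasets path are assumptions, not the project's actual conftest):

# Assumed fixture for illustration -- the real definition lives elsewhere
# in the repository and may differ.
import os
import pytest

@pytest.fixture(scope="session")
def datadir():
    """A persistent directory so MNIST is downloaded at most once across runs."""
    path = os.path.join(os.path.dirname(__file__), "..", "datasets")
    os.makedirs(path, exist_ok=True)
    return path

With such a fixture in place, LitMNIST(data_dir=datadir) reads the cached MNIST files rather than downloading them in every test function.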
