Commit

simplify/speedup catalyst tests (#327)
skshetry authored Oct 17, 2022
1 parent 9a8549e commit 2d61715
Showing 1 changed file with 32 additions and 34 deletions.
66 changes: 32 additions & 34 deletions tests/test_catalyst.py
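The core of the speedup, in isolation: instead of downloading MNIST into a session-scoped fixture, the tests now build a small random dataset in memory, so they need no network or disk I/O. A minimal standalone sketch of that pattern (assuming only torch is installed; the variable names are illustrative, not part of the diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

# 320 random samples with 10 features each, generated in memory --
# nothing to download, so the loaders are available immediately.
X, y = torch.rand(320, 10), torch.rand(320)
loader = DataLoader(TensorDataset(X, y), batch_size=32)
loaders = {"train": loader, "valid": loader}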
@@ -1,11 +1,9 @@
 import os
 
+import catalyst
 import pytest
+import torch
 from catalyst import dl
-from catalyst.contrib.datasets import MNIST
-from catalyst.utils.torch import get_available_engine
-from torch import nn, optim
-from torch.utils.data import DataLoader
 
 from dvclive import Live
 from dvclive.catalyst import DvcLiveCallback
@@ -14,38 +12,46 @@
 # pylint: disable=redefined-outer-name, unused-argument
 
 
-@pytest.fixture(scope="session")
-def loaders(tmp_path_factory):
-    path = tmp_path_factory.mktemp("catalyst_mnist")
-    train_data = MNIST(path, train=True, download=True)
-    valid_data = MNIST(path, train=False, download=True)
-    return {
-        "train": DataLoader(train_data, batch_size=32),
-        "valid": DataLoader(valid_data, batch_size=32),
-    }
-
-
 @pytest.fixture
 def runner():
     return dl.SupervisedRunner(
-        engine=get_available_engine(),
+        engine=catalyst.utils.torch.get_available_engine(cpu=True),
         input_key="features",
         output_key="logits",
         target_key="targets",
         loss_key="loss",
     )
 
 
-def test_catalyst_callback(tmp_dir, runner, loaders):
-    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
-    criterion = nn.CrossEntropyLoss()
-    optimizer = optim.Adam(model.parameters(), lr=0.02)
-
+# see:
+# https://github.com/catalyst-team/catalyst/blob/e99f9/tests/catalyst/callbacks/test_batch_overfit.py
+@pytest.fixture
+def runner_params():
+    from torch.utils.data import DataLoader, TensorDataset
+
+    catalyst.utils.set_global_seed(42)
+    num_samples, num_features = int(32e1), int(1e1)
+    X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
+    dataset = TensorDataset(X, y)
+    loader = DataLoader(dataset, batch_size=32, num_workers=0)
+    loaders = {"train": loader, "valid": loader}
+
+    model = torch.nn.Linear(num_features, 1)
+    criterion = torch.nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters())
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])
+    return {
+        "model": model,
+        "criterion": criterion,
+        "optimizer": optimizer,
+        "scheduler": scheduler,
+        "loaders": loaders,
+    }
+
+
+def test_catalyst_callback(tmp_dir, runner, runner_params):
     runner.train(
-        model=model,
-        criterion=criterion,
-        optimizer=optimizer,
-        loaders=loaders,
+        **runner_params,
         num_epochs=2,
         callbacks=[
             dl.AccuracyCallback(input_key="logits", target_key="targets"),
@@ -70,17 +76,9 @@ def test_catalyst_callback(tmp_dir, runner, loaders):
     assert any("accuracy" in x.name for x in valid_path.iterdir())
 
 
-def test_catalyst_model_file(tmp_dir, runner, loaders):
-    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
-    criterion = nn.CrossEntropyLoss()
-    optimizer = optim.Adam(model.parameters(), lr=0.02)
-
+def test_catalyst_model_file(tmp_dir, runner, runner_params):
     runner.train(
-        model=model,
-        engine=runner.engine,
-        criterion=criterion,
-        optimizer=optimizer,
-        loaders=loaders,
+        **runner_params,
         num_epochs=2,
         callbacks=[
             dl.AccuracyCallback(input_key="logits", target_key="targets"),
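For reference, a test built on the new fixtures would look roughly like the sketch below. The DvcLiveCallback usage is an illustrative assumption -- the tail of each test (the remaining callbacks and assertions) is truncated in this diff and is not reproduced here.

# minimal sketch, assuming the runner/runner_params fixtures above
# and dvclive's catalyst integration imported as in the diff
def test_catalyst_sketch(tmp_dir, runner, runner_params):  # hypothetical name
    runner.train(
        **runner_params,
        num_epochs=2,
        callbacks=[
            dl.AccuracyCallback(input_key="logits", target_key="targets"),
            DvcLiveCallback(),  # records catalyst metrics through dvclive
        ],
    )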
