diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index 7c1fb6e6..dbb83634 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -1,11 +1,9 @@ import os +import catalyst import pytest +import torch from catalyst import dl -from catalyst.contrib.datasets import MNIST -from catalyst.utils.torch import get_available_engine -from torch import nn, optim -from torch.utils.data import DataLoader from dvclive import Live from dvclive.catalyst import DvcLiveCallback @@ -14,21 +12,10 @@ # pylint: disable=redefined-outer-name, unused-argument -@pytest.fixture(scope="session") -def loaders(tmp_path_factory): - path = tmp_path_factory.mktemp("catalyst_mnist") - train_data = MNIST(path, train=True, download=True) - valid_data = MNIST(path, train=False, download=True) - return { - "train": DataLoader(train_data, batch_size=32), - "valid": DataLoader(valid_data, batch_size=32), - } - - @pytest.fixture def runner(): return dl.SupervisedRunner( - engine=get_available_engine(), + engine=catalyst.utils.torch.get_available_engine(cpu=True), input_key="features", output_key="logits", target_key="targets", @@ -36,16 +23,35 @@ def runner(): ) -def test_catalyst_callback(tmp_dir, runner, loaders): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) - criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam(model.parameters(), lr=0.02) +# see: +# https://github.com/catalyst-team/catalyst/blob/e99f9/tests/catalyst/callbacks/test_batch_overfit.py +@pytest.fixture +def runner_params(): + from torch.utils.data import DataLoader, TensorDataset + + catalyst.utils.set_global_seed(42) + num_samples, num_features = int(32e1), int(1e1) + X, y = torch.rand(num_samples, num_features), torch.rand(num_samples) + dataset = TensorDataset(X, y) + loader = DataLoader(dataset, batch_size=32, num_workers=0) + loaders = {"train": loader, "valid": loader} + + model = torch.nn.Linear(num_features, 1) + criterion = torch.nn.MSELoss() + optimizer = torch.optim.Adam(model.parameters()) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6]) + return { + "model": model, + "criterion": criterion, + "optimizer": optimizer, + "scheduler": scheduler, + "loaders": loaders, + } + +def test_catalyst_callback(tmp_dir, runner, runner_params): runner.train( - model=model, - criterion=criterion, - optimizer=optimizer, - loaders=loaders, + **runner_params, num_epochs=2, callbacks=[ dl.AccuracyCallback(input_key="logits", target_key="targets"), @@ -70,17 +76,9 @@ def test_catalyst_callback(tmp_dir, runner, loaders): assert any("accuracy" in x.name for x in valid_path.iterdir()) -def test_catalyst_model_file(tmp_dir, runner, loaders): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) - criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam(model.parameters(), lr=0.02) - +def test_catalyst_model_file(tmp_dir, runner, runner_params): runner.train( - model=model, - engine=runner.engine, - criterion=criterion, - optimizer=optimizer, - loaders=loaders, + **runner_params, num_epochs=2, callbacks=[ dl.AccuracyCallback(input_key="logits", target_key="targets"),