diff --git a/tests/torch/test_utils.py b/tests/torch/test_utils.py index 326ca6f..4c10424 100644 --- a/tests/torch/test_utils.py +++ b/tests/torch/test_utils.py @@ -42,7 +42,9 @@ def test_train_fn(): model = CNNClassifier((224, 224), 5, 3, [13, 14], [15, 16], n_classes=3).to(device) loss_fn = torch.nn.CrossEntropyLoss() - iter_losses, epoch_losses = train_fn(model, loss_fn, inputs.to(device), outputs.to(device), lr, n_epochs) + iter_losses, epoch_losses = train_fn( + model, loss_fn, input=inputs.to(device), output=outputs.to(device), lr=lr, epochs=n_epochs + ) assert epoch_losses[-1] < epoch_losses[0], "Loss should decrease" assert len(iter_losses) == n_epochs