Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Nov 28, 2023
1 parent 344a9d2 commit ad7818c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/torch/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def test_train_fn():

model = CNNClassifier((224, 224), 5, 3, [13, 14], [15, 16], n_classes=3).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
iter_losses, epoch_losses = train_fn(model, loss_fn, inputs.to(device), outputs.to(device), lr, n_epochs)
iter_losses, epoch_losses = train_fn(
model, loss_fn, input=inputs.to(device), output=outputs.to(device), lr=lr, epochs=n_epochs
)

assert epoch_losses[-1] < epoch_losses[0], "Loss should decrease"
assert len(iter_losses) == n_epochs
Expand Down

0 comments on commit ad7818c

Please sign in to comment.