diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index a0c7ca9084..e1ccba2294 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -23,7 +23,8 @@ class TestEvaluator(Evaluator): def _iteration(self, engine, batchdata): - pass + engine.state.output = "called" + return engine.state.output class TestHandlerValidation(unittest.TestCase): @@ -44,6 +45,7 @@ def _train_func(engine, batch): engine.run(data, max_epochs=1) self.assertEqual(evaluator.state.max_epochs, 1) self.assertEqual(evaluator.state.epoch_length, 8) + self.assertEqual(evaluator.state.output, "called") engine.run(data, max_epochs=5) self.assertEqual(evaluator.state.max_epochs, 4)