From cefda8797f73014de5865a49def0778a064cf02d Mon Sep 17 00:00:00 2001 From: Drew Oldag <47493171+drewoldag@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:03:39 -0700 Subject: [PATCH] Set default models to `training` mode in the `train_step`. (#103) * It turns out that we have to set the model to `training` mode to actually train it. * Looks like moving the model.train() and model.eval() here should work. --- src/fibad/pytorch_ignite.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/fibad/pytorch_ignite.py b/src/fibad/pytorch_ignite.py index 9028352..cf6d25e 100644 --- a/src/fibad/pytorch_ignite.py +++ b/src/fibad/pytorch_ignite.py @@ -138,6 +138,7 @@ def create_evaluator(model: torch.nn.Module, save_function: Callable[[torch.Tens Engine object which when run will evaluate the model. """ device = idist.device() + model.eval() model = idist.auto_model(model) evaluator = create_engine("forward", device, model) @@ -178,6 +179,7 @@ def create_trainer(model: torch.nn.Module, config: ConfigDict, results_directory Engine object that will be used to train the model. """ device = idist.device() + model.train() model = idist.auto_model(model) trainer = create_engine("train_step", device, model)