Skip to content

Commit

Permalink
Set default models to training mode in the train_step. (#103)
Browse files Browse the repository at this point in the history
* It turns out that we have to set the model to `training` mode to actually train it.

* Looks like moving the model.train() and model.eval() here should work.
  • Loading branch information
drewoldag authored Oct 23, 2024
1 parent dad7062 commit cefda87
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/fibad/pytorch_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def create_evaluator(model: torch.nn.Module, save_function: Callable[[torch.Tens
Engine object which when run will evaluate the model.
"""
device = idist.device()
model.eval()
model = idist.auto_model(model)
evaluator = create_engine("forward", device, model)

Expand Down Expand Up @@ -178,6 +179,7 @@ def create_trainer(model: torch.nn.Module, config: ConfigDict, results_directory
Engine object that will be used to train the model.
"""
device = idist.device()
model.train()
model = idist.auto_model(model)
trainer = create_engine("train_step", device, model)

Expand Down

0 comments on commit cefda87

Please sign in to comment.