Skip to content

Commit

Permalink
Looks like moving the model.train() and model.eval() here should work.
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag committed Oct 23, 2024
1 parent c97284e commit 411de9d
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 6 deletions.
3 changes: 0 additions & 3 deletions src/fibad/models/example_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@ def train_step(self, batch):
Current loss value
The loss value for the current batch.
"""
# Set the model to train mode
self.train()

# When we run on a supervised dataset like CIFAR10, drop the labels given by the data loader
x = batch[0] if isinstance(batch, tuple) else batch

Expand Down
3 changes: 0 additions & 3 deletions src/fibad/models/example_cnn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def train_step(self, batch):
Current loss value
The loss value for the current batch.
"""
# Set the model to train mode
self.train()

inputs, labels = batch

self.optimizer.zero_grad()
Expand Down
2 changes: 2 additions & 0 deletions src/fibad/pytorch_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def create_evaluator(model: torch.nn.Module, save_function: Callable[[torch.Tens
Engine object which when run will evaluate the model.
"""
device = idist.device()
model.eval()

Check warning on line 141 in src/fibad/pytorch_ignite.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/pytorch_ignite.py#L141

Added line #L141 was not covered by tests
model = idist.auto_model(model)
evaluator = create_engine("forward", device, model)

Expand Down Expand Up @@ -178,6 +179,7 @@ def create_trainer(model: torch.nn.Module, config: ConfigDict, results_directory
Engine object that will be used to train the model.
"""
device = idist.device()
model.train()

Check warning on line 182 in src/fibad/pytorch_ignite.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/pytorch_ignite.py#L182

Added line #L182 was not covered by tests
model = idist.auto_model(model)
trainer = create_engine("train_step", device, model)

Expand Down

0 comments on commit 411de9d

Please sign in to comment.