Skip to content

Commit

Permalink
6124-add-training-attribute-check (#6132)
Browse files Browse the repository at this point in the history
Fixes #6124 .

### Description

When running the inference with torchscript wrapped TensorRT models, the
evaluator would give an error. This is caused by the `with
engine.mode()` code run the `training` method of `engine.network`
without checking. In this PR, an attribute check has been added to cover
this issue.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: binliu <[email protected]>
  • Loading branch information
binliunls authored Mar 14, 2023
1 parent a8302ec commit 0a904fb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
1 change: 0 additions & 1 deletion monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

# execute forward computation
with engine.mode(engine.network):
if engine.amp:
Expand Down
18 changes: 11 additions & 7 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,17 +375,19 @@ def eval_mode(*nets: nn.Module):
print(p(t).sum().backward()) # will correctly raise an exception as gradients are calculated
"""

# Get original state of network(s)
training = [n for n in nets if n.training]
# Get original state of network(s).
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
training = [n for n in nets if hasattr(n, "training") and n.training]

try:
# set to eval mode
with torch.no_grad():
yield [n.eval() for n in nets]
yield [n.eval() if hasattr(n, "eval") else n for n in nets]
finally:
# Return required networks to training
for n in training:
n.train()
if hasattr(n, "train"):
n.train()


@contextmanager
Expand All @@ -410,16 +412,18 @@ def train_mode(*nets: nn.Module):
"""

# Get original state of network(s)
eval_list = [n for n in nets if not n.training]
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)]

try:
# set to train mode
with torch.set_grad_enabled(True):
yield [n.train() for n in nets]
yield [n.train() if hasattr(n, "train") else n for n in nets]
finally:
# Return required networks to eval_list
for n in eval_list:
n.eval()
if hasattr(n, "eval"):
n.eval()


def get_state_dict(obj: torch.nn.Module | Mapping):
Expand Down

0 comments on commit 0a904fb

Please sign in to comment.