From 0a904fbeb19bc2d8b1cef6b813cea070bb2a60cc Mon Sep 17 00:00:00 2001 From: binliunls <107988372+binliunls@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:47:09 +0800 Subject: [PATCH] 6124-add-training-attribute-check (#6132) Fixes #6124 . ### Description When running the inference with torchscript wrapped TensorRT models, the evaluator would give an error. This is caused by the `with engine.mode()` code run the `training` method of `engine.network` without checking. In this PR, an attribute check has been added to cover this issue. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: binliu --- monai/engines/evaluator.py | 1 - monai/networks/utils.py | 18 +++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 43964ee8bc..7c6ddd5bdd 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -295,7 +295,6 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} - # execute forward computation with engine.mode(engine.network): if engine.amp: diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d5c0629c05..b79ae8e9bd 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -375,17 +375,19 @@ def eval_mode(*nets: nn.Module): print(p(t).sum().backward()) # will correctly raise an exception as gradients are calculated """ - # Get original state of network(s) - training = [n for n in nets if n.training] + # Get original state of network(s). + # Check the training attribute in case it's TensorRT based models which don't have this attribute. + training = [n for n in nets if hasattr(n, "training") and n.training] try: # set to eval mode with torch.no_grad(): - yield [n.eval() for n in nets] + yield [n.eval() if hasattr(n, "eval") else n for n in nets] finally: # Return required networks to training for n in training: - n.train() + if hasattr(n, "train"): + n.train() @contextmanager @@ -410,16 +412,18 @@ def train_mode(*nets: nn.Module): """ # Get original state of network(s) - eval_list = [n for n in nets if not n.training] + # Check the training attribute in case it's TensorRT based models which don't have this attribute. + eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)] try: # set to train mode with torch.set_grad_enabled(True): - yield [n.train() for n in nets] + yield [n.train() if hasattr(n, "train") else n for n in nets] finally: # Return required networks to eval_list for n in eval_list: - n.eval() + if hasattr(n, "eval"): + n.eval() def get_state_dict(obj: torch.nn.Module | Mapping):