From 94329dfc4d7f7fb28b8dcf761914782af0027683 Mon Sep 17 00:00:00 2001 From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com> Date: Fri, 1 Jul 2022 10:48:12 +0800 Subject: [PATCH] Nano : add length check in trace (#4972) * add length check * fix style * fix * modify as reviews * fix style --- .../src/bigdl/nano/deps/onnxruntime/core/onnxruntime_model.py | 3 +++ python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/nano/src/bigdl/nano/deps/onnxruntime/core/onnxruntime_model.py b/python/nano/src/bigdl/nano/deps/onnxruntime/core/onnxruntime_model.py index f9d4620532d..6a88cca9e61 100644 --- a/python/nano/src/bigdl/nano/deps/onnxruntime/core/onnxruntime_model.py +++ b/python/nano/src/bigdl/nano/deps/onnxruntime/core/onnxruntime_model.py @@ -30,6 +30,9 @@ def forward_step(self, *inputs): ''' This function run through the onnxruntime forwarding step ''' + invalidInputError(len(self._forward_args) >= len(inputs), "The length of inputs is " + "inconsistent with the length of ONNX Runtime session's inputs, " + "there may be some redundant inputs.") inputs = dict(zip(self._forward_args, inputs)) ort_outs = self.ortsess.run(None, inputs) return ort_outs diff --git a/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py b/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py index fcc397aca07..faec6f958a5 100644 --- a/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py +++ b/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py @@ -321,7 +321,7 @@ def quantize(model, # remove the type requirement for type checking elif accelerator == 'openvino': model_type = type(model).__name__ if not model_type == 'PytorchOpenVINOModel': - if not input_sample: + if input_sample is None: # input_sample can be a dataloader input_sample = calib_dataloader model = Trainer.trace(model,