diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py
index c30478a1352..2f0bc2caadc 100644
--- a/neural_compressor/adaptor/ox_utils/util.py
+++ b/neural_compressor/adaptor/ox_utils/util.py
@@ -29,7 +29,6 @@
 numpy_helper = LazyImport("onnx.numpy_helper")
 onnx_proto = LazyImport("onnx.onnx_pb")
 torch = LazyImport("torch")
-onnxruntime = LazyImport("onnxruntime")
 symbolic_shape_infer = LazyImport("onnxruntime.tools.symbolic_shape_infer")
 onnx = LazyImport("onnx")
 
@@ -599,40 +598,36 @@ def to_numpy(data):
     else:
         return data
 
+def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""):
+    """Symbolic shape inference."""
+
+    class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference):
+        def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""):
+            super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix)
+            self.base_dir = base_dir
+
+        def _get_value(self, node, idx):
+            name = node.input[idx]
+            assert name in self.sympy_data_ or name in self.initializers_
+            return (
+                self.sympy_data_[name]
+                if name in self.sympy_data_
+                else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir)
+            )
 
-class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference):
-    """Shape inference for ONNX model."""
-
-    def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""):
-        """Initialize Shape inference class."""
-        super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix)
-        self.base_dir = base_dir
-
-    def _get_value(self, node, idx):
-        name = node.input[idx]
-        assert name in self.sympy_data_ or name in self.initializers_
-        return (
-            self.sympy_data_[name]
-            if name in self.sympy_data_
-            else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir)
-        )
-
-    @staticmethod
-    def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""):
-        """Symbolic shape inference."""
-        onnx_opset = symbolic_shape_infer.get_opset(in_mp)
-        if (not onnx_opset) or onnx_opset < 7:
-            logger.warning("Only support models of onnx opset 7 and above.")
-            return None
-        symbolic_shape_inference = SymbolicShapeInference(
-            int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir
-        )
-        all_shapes_inferred = False
-        symbolic_shape_inference._preprocess(in_mp)
-        while symbolic_shape_inference.run_:
-            all_shapes_inferred = symbolic_shape_inference._infer_impl()
-        symbolic_shape_inference._update_output_from_vi()
-        if not all_shapes_inferred:
-            onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
-            raise Exception("Incomplete symbolic shape inference")
-        return symbolic_shape_inference.out_mp_
+    onnx_opset = symbolic_shape_infer.get_opset(in_mp)
+    if (not onnx_opset) or onnx_opset < 7:
+        logger.warning("Only support models of onnx opset 7 and above.")
+        return None
+    symbolic_shape_inference = SymbolicShapeInference(
+        int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir
+    )
+    all_shapes_inferred = False
+    symbolic_shape_inference._preprocess(in_mp)
+    while symbolic_shape_inference.run_:
+        all_shapes_inferred = symbolic_shape_inference._infer_impl()
+    symbolic_shape_inference._update_output_from_vi()
+    if not all_shapes_inferred:
+        onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
+        raise Exception("Incomplete symbolic shape inference")
+    return symbolic_shape_inference.out_mp_
diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py
index d301964a2eb..ba9f35abf40 100644
--- a/neural_compressor/model/onnx_model.py
+++ b/neural_compressor/model/onnx_model.py
@@ -1045,9 +1045,9 @@ def split_model_with_node(
         if shape_infer:
             try:
                 # need ort.GraphOptimizationLevel <= ORT_ENABLE_BASIC
-                from neural_compressor.adaptor.ox_utils.util import SymbolicShapeInference
+                from neural_compressor.adaptor.ox_utils.util import infer_shapes
 
-                self._model = SymbolicShapeInference.infer_shapes(
+                self._model = infer_shapes(
                     self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path)
                 )
             except Exception as e:  # pragma: no cover
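For reviewers, a minimal usage sketch of the refactored helper (not part of the patch; the model filename below is hypothetical):

    import os

    import onnx

    from neural_compressor.adaptor.ox_utils.util import infer_shapes

    # infer_shapes expects a loaded ModelProto, not a file path.
    model_path = "model.onnx"  # hypothetical file
    model = onnx.load(model_path)

    # auto_merge=True lets conflicting symbolic dims be merged instead of failing;
    # base_dir is the directory used to resolve external-data initializers.
    inferred = infer_shapes(
        model, auto_merge=True, base_dir=os.path.dirname(os.path.abspath(model_path))
    )

    # infer_shapes returns None for models below opset 7.
    if inferred is not None:
        onnx.save(inferred, "model_sym_shapes.onnx")

A note on the design: defining the SymbolicShapeInference subclass inside infer_shapes means the onnxruntime.tools.symbolic_shape_infer LazyImport is only resolved when the function is actually called. The previous module-level class statement evaluated its base class at import time, pulling in onnxruntime whenever util.py was imported; that appears to be the motivation for this refactor.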