Skip to content

Commit

Permalink
fix import bug
Browse files Browse the repository at this point in the history
Signed-off-by: yuwenzho <[email protected]>
  • Loading branch information
yuwenzho committed Dec 1, 2023
1 parent 966aa9b commit 1f10a92
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 39 deletions.
69 changes: 32 additions & 37 deletions neural_compressor/adaptor/ox_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
numpy_helper = LazyImport("onnx.numpy_helper")
onnx_proto = LazyImport("onnx.onnx_pb")
torch = LazyImport("torch")
onnxruntime = LazyImport("onnxruntime")
symbolic_shape_infer = LazyImport("onnxruntime.tools.symbolic_shape_infer")
onnx = LazyImport("onnx")

Expand Down Expand Up @@ -599,40 +598,36 @@ def to_numpy(data):
else:
return data

def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""):
    """Run symbolic shape inference on an ONNX model.

    Args:
        in_mp: the onnx ModelProto to infer shapes for.
        int_max (int): value used to represent an unbounded dimension. Defaults to 2**31 - 1.
        auto_merge (bool): automatically merge conflicting symbolic dimensions. Defaults to False.
        guess_output_rank (bool): guess the rank of outputs of unknown ops. Defaults to False.
        verbose (int): verbosity level of the underlying inference tool. Defaults to 0.
        base_dir (str): directory used to resolve externally-stored tensor data. Defaults to "".

    Returns:
        The ModelProto with inferred shapes, or None when the model's opset is below 7.

    Raises:
        Exception: when symbolic shape inference could not infer all shapes; the partial
            result is saved to "sym_shape_infer_temp.onnx" for debugging before raising.
    """

    class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference):
        """Shape inference for ONNX model with support for externally-stored tensor data."""

        def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""):
            """Initialize Shape inference class."""
            super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix)
            # base_dir lets numpy_helper.to_array locate external data files for large initializers.
            self.base_dir = base_dir

        def _get_value(self, node, idx):
            # Resolve the value of a node input, preferring already-computed sympy data
            # over raw initializers (which may live in external data files).
            name = node.input[idx]
            # NOTE(review): assert is stripped under `python -O`; upstream tool uses the
            # same pattern, so it is kept for behavioral parity.
            assert name in self.sympy_data_ or name in self.initializers_
            return (
                self.sympy_data_[name]
                if name in self.sympy_data_
                else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir)
            )

    onnx_opset = symbolic_shape_infer.get_opset(in_mp)
    # Symbolic shape inference is only defined for opset 7 and above.
    if (not onnx_opset) or onnx_opset < 7:
        logger.warning("Only support models of onnx opset 7 and above.")
        return None
    symbolic_shape_inference = SymbolicShapeInference(
        int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir
    )
    all_shapes_inferred = False
    symbolic_shape_inference._preprocess(in_mp)
    # run_ stays truthy while the tool makes progress; each pass may infer more shapes.
    while symbolic_shape_inference.run_:
        all_shapes_inferred = symbolic_shape_inference._infer_impl()
    symbolic_shape_inference._update_output_from_vi()
    if not all_shapes_inferred:
        # Persist the partially-inferred model (with external data) to aid debugging.
        onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
        raise Exception("Incomplete symbolic shape inference")
    return symbolic_shape_inference.out_mp_
4 changes: 2 additions & 2 deletions neural_compressor/model/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,9 +1045,9 @@ def split_model_with_node(
if shape_infer:
try:
# need ort.GraphOptimizationLevel <= ORT_ENABLE_BASIC
from neural_compressor.adaptor.ox_utils.util import SymbolicShapeInference
from neural_compressor.adaptor.ox_utils.util import infer_shapes

self._model = SymbolicShapeInference.infer_shapes(
self._model = infer_shapes(
self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path)
)
except Exception as e: # pragma: no cover
Expand Down

0 comments on commit 1f10a92

Please sign in to comment.