diff --git a/sne4onnx/__init__.py b/sne4onnx/__init__.py index 573cb3f..a44eb25 100644 --- a/sne4onnx/__init__.py +++ b/sne4onnx/__init__.py @@ -1,3 +1,3 @@ from sne4onnx.onnx_network_extraction import extraction, main -__version__ = '1.0.11' +__version__ = '1.0.12' diff --git a/sne4onnx/onnx_network_extraction.py b/sne4onnx/onnx_network_extraction.py index e63d739..f571347 100644 --- a/sne4onnx/onnx_network_extraction.py +++ b/sne4onnx/onnx_network_extraction.py @@ -119,6 +119,10 @@ def extraction( if node.domain not in ONNX_STANDARD_DOMAINS ] + # domain, ir_version + domain: str = onnx_graph.domain + ir_version: int = onnx_graph.ir_version + graph = gs.import_onnx(onnx_graph) graph.cleanup().toposort() @@ -172,9 +176,9 @@ def extraction( # Shape Estimation extracted_graph = None try: - extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) + extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})) except Exception as e: - extracted_graph = gs.export_onnx(graph) + extracted_graph = gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}) if not non_verbose: print( f'{Color.YELLOW}WARNING:{Color.RESET} '+