Skip to content

Commit

Permalink
Merge pull request #4 from PINTO0309/fix_irversion
Browse files Browse the repository at this point in the history
Fix to preserve `domain` and `ir_version`
  • Loading branch information
PINTO0309 authored Apr 30, 2024
2 parents fdb1cae + 051ab50 commit 2c7056f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sne4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sne4onnx.onnx_network_extraction import extraction, main

__version__ = '1.0.11'
__version__ = '1.0.12'
8 changes: 6 additions & 2 deletions sne4onnx/onnx_network_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def extraction(
if node.domain not in ONNX_STANDARD_DOMAINS
]

# domain, ir_version
domain: str = onnx_graph.domain
ir_version: int = onnx_graph.ir_version

graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort()

Expand Down Expand Up @@ -172,9 +176,9 @@ def extraction(
# Shape Estimation
extracted_graph = None
try:
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
except Exception as e:
extracted_graph = gs.export_onnx(graph)
extracted_graph = gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})
if not non_verbose:
print(
f'{Color.YELLOW}WARNING:{Color.RESET} '+
Expand Down

0 comments on commit 2c7056f

Please sign in to comment.