Skip to content

Commit

Permalink
Merge pull request #2 from PINTO0309/support_custom_domain
Browse files Browse the repository at this point in the history
Support for models with custom domains
  • Loading branch information
PINTO0309 authored Jan 2, 2023
2 parents d29118d + f051dde commit 976418b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sam4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sam4onnx.onnx_attr_const_modify import modify, main

__version__ = '1.0.11'
__version__ = '1.0.12'
32 changes: 32 additions & 0 deletions sam4onnx/onnx_attr_const_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ class Color:
'complex128': np.complex128,
}

ONNX_STANDARD_DOMAINS = [
'ai.onnx',
'ai.onnx.ml',
'',
]

def __search_op_constant_from_input_constant_name(
graph: onnx_graphsurgeon.Graph,
Expand Down Expand Up @@ -223,8 +228,27 @@ def modify(
# onnx_graph If specified, onnx_graph is processed first
if not onnx_graph:
onnx_graph = onnx.load(input_onnx_file_path)

# Acquisition of Node with custom domain
custom_domain_check_onnx_nodes = []
custom_domain_check_onnx_nodes = \
custom_domain_check_onnx_nodes + \
[
node for node in onnx_graph.graph.node \
if node.domain not in ONNX_STANDARD_DOMAINS
]

graph = gs.import_onnx(onnx_graph)

# Check if Graph contains a custom domain (custom module)
contains_custom_domain = len(
[
domain \
for domain in graph.import_domains \
if domain.domain not in ONNX_STANDARD_DOMAINS
]
) > 0

# Search for OPs matching op_name
node_subject_to_change = None
if op_name:
Expand Down Expand Up @@ -355,6 +379,14 @@ def modify(
tracetxt = traceback.format_exc().splitlines()[-1]
print(f'{Color.YELLOW}WARNING:{Color.RESET} {tracetxt}')

## Restore a node's custom domain
if contains_custom_domain:
new_model_nodes = new_model.graph.node
for new_model_node in new_model_nodes:
for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes:
if new_model_node.name == custom_domain_check_onnx_node.name:
new_model_node.domain = custom_domain_check_onnx_node.domain

# Save
if output_onnx_file_path:
onnx.save(new_model, f'{output_onnx_file_path}')
Expand Down

0 comments on commit 976418b

Please sign in to comment.