This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MXNET-533] MXNet-ONNX export (#11213)
* Resolve conflicts * Export module Test Framework * refactoring export to work with pretrained models * comments added * 1. Refactored export module. 2. Refactored test framework to support ONNX backened tests. 2. Added Operator support: - Convolution2D - BatchNorm - Add * Added Arithmetic operators: - Add, Sub, Mul, Div, Sum * Added operator support: - sigmoid, relu, pad( constant, edge, reflect), tanh - enabled corresponding ONNX backend tests. * Enabled ONNX tests: test_conv, test_basic_conv Added Operators : Ceil, Floor * Added support for: MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul * adding more operators * Added Operator support: ArgMax, ArgMin, maximum, minimum * Enabled more BASIC_MODEL tests * Added power operator tests * Added support for reshape. ONNX only supports 0, -1 special values. Added only for these. Fixed logic error with convert_string_to_list() * some tests enabled * enabling squeezenet * LRN Op support * mul_scalar modified to take scalar input * cleaning some code * Resolving conlicts on rebase * Resolving rebase conflicts * id mapping updated for all operators * save onnx models added, some code cleanup * enabled more tests * conv pad calc fixed * reshape op fix * Added support for elu, leakyRelu, prelu * Cleanup - Removed run_node, not needed anymore. - Used correct get_metadata api * valueinfoproto fix, googlenet test added * Removed redundant code. - run_node - Using correct get_metadata_api * dilation added * Lint fixes * lint fixes * some fixes to make export work with onx1.2.1 * enabled more tests * mxnet_export_test file added * duplicate file deleted * reduce ops added * some small fixes * some lint fixes * Add tests for inception_v1 and inception_v2 * Add CI runs for export module * docstring added * lint fixes, pooling attr fix * fix * fix global_pool * CI run fix * code cleanup * lint fix * some code cleanup * pad in pooling added * slicechannel notimplementederror raised * Added required license comments * Lint fixes * lint fix * lint fix * lint fix * lint fix * Correct license statement * Adding onnx a runtime dependency * Fix import module error for string_types * Making ONNX runtime dependency * fixing some comments * addressing some comments * params rename * lint fixes * fixes * spatial disabled, path fixed * fixing some comments * Added support for remaining act_type(softsign, sigmoid, softrelu) in Activation operator * changing import * adding some comments * Add squeeze op * Refactored logic to handle extra node(output label node) for saved mxnet model Added comments * minor fix for squeeze operator. Also, added error handling * identity operator added * scalar ops added * Renamed onnx support folders to mark it public folders Changed underline files public or private as per usage Resolved conflicts with the latest * Added support L2Normalization op Added some error checking * added comments and warning * added comments and warning * doc API ref added
- Loading branch information
1 parent
9b27262
commit 7d91602
Showing
25 changed files
with
3,028 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
# Based on | ||
# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/# | ||
# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
"""ONNX Export module""" | ||
from __future__ import absolute_import | ||
|
||
from . import export_model | ||
from . import export_onnx | ||
from . import _op_translations |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""export helper functions""" | ||
# coding: utf-8 | ||
import os | ||
import logging | ||
import mxnet as mx | ||
|
||
|
||
def load_module(sym_filepath, params_filepath): | ||
"""Loads the MXNet model file and | ||
returns MXNet symbol and params (weights). | ||
Parameters | ||
---------- | ||
json_path : str | ||
Path to the json file | ||
params_path : str | ||
Path to the params file | ||
Returns | ||
------- | ||
sym : MXNet symbol | ||
Model symbol object | ||
params : params object | ||
Model weights including both arg and aux params. | ||
""" | ||
if not (os.path.isfile(sym_filepath) and os.path.isfile(params_filepath)): | ||
raise ValueError("Symbol and params files provided are invalid") | ||
else: | ||
try: | ||
# reads symbol.json file from given path and | ||
# retrieves model prefix and number of epochs | ||
model_name = sym_filepath.rsplit('.', 1)[0].rsplit('-', 1)[0] | ||
params_file_list = params_filepath.rsplit('.', 1)[0].rsplit('-', 1) | ||
# Setting num_epochs to 0 if not present in filename | ||
num_epochs = 0 if len(params_file_list) == 1 else int(params_file_list[1]) | ||
except IndexError: | ||
logging.info("Model and params name should be in format: " | ||
"prefix-symbol.json, prefix-epoch.params") | ||
raise | ||
|
||
sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs) | ||
|
||
# Merging arg and aux parameters | ||
params = {} | ||
params.update(arg_params) | ||
params.update(aux_params) | ||
|
||
return sym, params |
Oops, something went wrong.