Skip to content

Commit

Permalink
conversion script
Browse files Browse the repository at this point in the history
  • Loading branch information
ranhomri committed Aug 6, 2024
1 parent 86a6fe7 commit 74321ec
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions onnx2kerastl/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tensorflow as tf
from keras_data_format_converter import convert_channels_first_to_last

def convert_onnx_to_keras(onnx_model_path):
def convert_onnx_to_keras(onnx_model_path, transform_io:bool = True):
# Load ONNX model
save_model_path = onnx_model_path.replace('.onnx', '.h5')
onnx_model = onnx.load(onnx_model_path)
Expand All @@ -18,7 +18,7 @@ def convert_onnx_to_keras(onnx_model_path):
name_policy='attach_weights_name', allow_partial_compilation=False).converted_model

# Convert from channels-first to channels-last format
final_model = convert_channels_first_to_last(keras_model, should_transform_inputs_and_outputs=False,
final_model = convert_channels_first_to_last(keras_model, should_transform_inputs_and_outputs=transform_io,
verbose=True)

# Save the final Keras model
Expand All @@ -29,8 +29,9 @@ def convert_onnx_to_keras(onnx_model_path):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert ONNX model to Keras')
parser.add_argument('onnx_path', type=str, help='Path to the input ONNX model')
parser.add_argument('transform_input_output', type=bool, help='Whether to transform input and output data format')
args = parser.parse_args()

# Convert input_shape string to tuple of integers

convert_onnx_to_keras(args.onnx_path)
convert_onnx_to_keras(args.onnx_path, args.transform_input_output)

0 comments on commit 74321ec

Please sign in to comment.