From 74321ec05684045682252ed6a45898fe81c04fa0 Mon Sep 17 00:00:00 2001 From: ranhomri Date: Tue, 6 Aug 2024 12:02:18 +0300 Subject: [PATCH] conversion script --- onnx2kerastl/convert_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnx2kerastl/convert_model.py b/onnx2kerastl/convert_model.py index 00807804..a012cd6d 100644 --- a/onnx2kerastl/convert_model.py +++ b/onnx2kerastl/convert_model.py @@ -5,7 +5,7 @@ import tensorflow as tf from keras_data_format_converter import convert_channels_first_to_last -def convert_onnx_to_keras(onnx_model_path): +def convert_onnx_to_keras(onnx_model_path, transform_io:bool = True): # Load ONNX model save_model_path = onnx_model_path.replace('.onnx', '.h5') onnx_model = onnx.load(onnx_model_path) @@ -18,7 +18,7 @@ def convert_onnx_to_keras(onnx_model_path): name_policy='attach_weights_name', allow_partial_compilation=False).converted_model # Convert from channels-first to channels-last format - final_model = convert_channels_first_to_last(keras_model, should_transform_inputs_and_outputs=False, + final_model = convert_channels_first_to_last(keras_model, should_transform_inputs_and_outputs=transform_io, verbose=True) # Save the final Keras model @@ -29,8 +29,9 @@ def convert_onnx_to_keras(onnx_model_path): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Convert ONNX model to Keras') parser.add_argument('onnx_path', type=str, help='Path to the input ONNX model') + parser.add_argument('transform_input_output', type=bool, help='Whether to transform input and output data format') args = parser.parse_args() # Convert input_shape string to tuple of integers - convert_onnx_to_keras(args.onnx_path) \ No newline at end of file + convert_onnx_to_keras(args.onnx_path, args.transform_input_output) \ No newline at end of file