Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomkoren21 committed Nov 27, 2024
1 parent 7ec48a1 commit 74cca8b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion onnx2kerastl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def onnx_to_keras(onnx_model, input_names, name_policy=None, verbose=True, chang
else:
keras_names = keras_names[0]
node_names.append(keras_names)
pattern = r'[#:]' # Example pattern to match #, /, and :
pattern = r'[#:@]' # Pattern to match #, :, and @ (leading/trailing "/" is stripped separately below)
cleaned_node_name = re.sub(pattern, '_', node.name.rstrip("/").lstrip("/"))
if len(cleaned_node_name) == 0:
cleaned_node_name = re.sub(pattern, '_', node_name.rstrip("/").lstrip("/"))
Expand Down
7 changes: 5 additions & 2 deletions onnx2kerastl/elementwise_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,11 @@ def convert_where(node, params, layers, lambda_func, node_name, keras_name):
layers[node.input[1]],
tf_name=f"{params['cleaned_name']}_where_1")
else:
layers[node_name] = tf_where(casted, layers[node.input[1]], layers[node.input[2]],
tf_name=f"{params['cleaned_name']}_where_2")
try:
layers[node_name] = tf_where(casted, layers[node.input[1]], layers[node.input[2]],
tf_name=f"{params['cleaned_name']}_where_2")
except Exception as e:
print(1)


def convert_scatter_nd(node, params, layers, lambda_func, node_name, keras_name):
Expand Down
11 changes: 8 additions & 3 deletions onnx2kerastl/pooling_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,16 @@ def target_layer(composed_input, to_sort=to_sort, axis=axis):
out = tf.transpose(topk_concat, ord_permute)
return out
in_shape = tf_shape(in_tensor, tf_name=f"{params['cleaned_name']}_topk_in_shape")
k_needed_shape = tf_concat(
[(in_shape)[:-1],[1]], axis=-1,tf_name=f"{params['cleaned_name']}_topk_k_needed_shape")._inferred_value
k_needed_shape_possible_keras_tensor = tf_concat(
[(in_shape)[:-1],[1]], axis=-1,tf_name=f"{params['cleaned_name']}_topk_k_needed_shape")

if hasattr(k_needed_shape_possible_keras_tensor, "_inferred_value"): #is keras tensor
k_needed_shape = k_needed_shape_possible_keras_tensor._inferred_value
else:
k_needed_shape = k_needed_shape_possible_keras_tensor
k_unsqueezed = tf_ones(k_needed_shape, tf_name=f"{params['cleaned_name']}_topk_k_shape")*\
tf_cast(k, tf.float32, tf_name=f"{params['cleaned_name']}_topk_k_cast")
k_reshaped = tf_cast(k_unsqueezed, tf.float32, tf_name=f"{params['cleaned_name']}_topk_k_reshaped")
k_reshaped = tf_cast(k_unsqueezed, in_tensor.dtype, tf_name=f"{params['cleaned_name']}_topk_k_reshaped")
composed_input = tf_concat([in_tensor, k_reshaped], axis=-1,
tf_name=f"{params['cleaned_name']}_topk_k_concat")
lambda_layer = keras.layers.Lambda(target_layer, name=f"{params['cleaned_name']}_topk")
Expand Down
4 changes: 4 additions & 0 deletions onnx2kerastl/reshape_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,13 +670,17 @@ def convert_expand(node, params, layers, lambda_func, node_name, keras_name):

input_0 = ensure_tf_type(layers[node.input[0]], name="%s_const" % keras_name)
input_1 = layers[node.input[1]]
was_converted_from_bool = False
if input_0.dtype.is_bool:
input_0 = tf_cast(input_0, dtype='int32', tf_name=f"{params['cleaned_name']}_bool_to_int")
was_converted_from_bool = True
multiply_res = input_0 * tf_ones(shape=input_1, dtype=input_0.dtype,
tf_name=f"{params['cleaned_name']}_expand_use_ones")
# input_0.dtype == np.int32 since we can't serialize constants as int64, need to cast to true type
if layers[node.input[0]].dtype == np.int64:
multiply_res = tf_cast(multiply_res, tf.int64, tf_name=f"{params['cleaned_name']}_to_int64")
if was_converted_from_bool:
multiply_res = tf_cast(multiply_res, tf.bool, tf_name=f"{params['cleaned_name']}_int_to_bool")
layers[node_name] = multiply_res


Expand Down

0 comments on commit 74cca8b

Please sign in to comment.