elementwise sub with proper const handling
ranhomri committed Nov 3, 2024
1 parent 185db00 commit 11ca70c
Showing 1 changed file with 80 additions and 45 deletions.
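In outline, the commit comments out the previous convert_elementwise_sub and replaces it with a version that distinguishes constant operands: tensor - tensor pairs still go through keras.layers.Subtract, while constant/tensor pairs fall back to a keras.layers.Lambda, collapsing a uniform constant to a single scalar before it is captured. A minimal standalone sketch of that pattern (names such as constant and inp are illustrative, not the converter's):

import numpy as np
from tensorflow import keras

# Stand-ins for an ONNX initializer and a graph tensor (illustrative only).
constant = np.full((1, 4), 5.0, dtype=np.float32)
inp = keras.Input(shape=(4,))

if np.all(constant == constant.flat[0]):
    # Uniform constant: capture only a Python scalar in the closure.
    scalar = float(constant.flat[0])
    out = keras.layers.Lambda(lambda x: scalar - x)(inp)
else:
    # Non-uniform constant: the full array has to be embedded.
    out = keras.layers.Lambda(lambda x: constant - x)(inp)

model = keras.Model(inp, out)
print(model(np.ones((1, 4), dtype=np.float32)))  # [[4. 4. 4. 4.]]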
125 changes: 80 additions & 45 deletions onnx2kerastl/elementwise_layers.py
@@ -99,38 +99,6 @@ def convert_elementwise_add(node, params, layers, lambda_func, node_name, keras_
    layers[node_name] = input_0 + input_1


# def convert_elementwise_mul(node, params, layers, lambda_func, node_name, keras_name):
#     """
#     Convert element-wise mul.
#     :param node: current operation node
#     :param params: operation attributes
#     :param layers: available keras layers
#     :param lambda_func: function for keras Lambda layer
#     :param node_name: internal converter name
#     :param keras_name: resulting layer name
#     :return: None
#     """
#     logger = logging.getLogger('onnx2keras.mul')

#     if len(node.input) != 2:
#         raise AttributeError('Number of inputs is not equal 2 for element-wise layer')

#     input_0 = layers[node.input[0]]
#     input_1 = layers[node.input[1]]

#     input_0_is_constant = is_numpy(input_0) or isinstance(input_0, EagerTensor)
#     input_1_is_constant = is_numpy(input_1) or isinstance(input_1, EagerTensor)
#     try:
#         if not input_0_is_constant and not input_1_is_constant:
#             mul = keras.layers.Multiply(name=f"{params['cleaned_name']}_mul")
#             layers[node_name] = mul([input_0, input_1])
#         else:
#             raise ValueError('Operands are different.')

#     except (IndexError, ValueError):
#         logger.warning('Failed to use keras.layers.Multiply. Fallback to TF lambda.')
#         layers[node_name] = input_0 * input_1

def convert_elementwise_mul(node, params, layers, lambda_func, node_name, keras_name):
    """
    Convert element-wise mul.
@@ -197,16 +165,44 @@ def convert_elementwise_mul(node, params, layers, lambda_func, node_name, keras_
            # Both inputs are constants; compute the result now
            layers[node_name] = input_0 * input_1
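When both operands are constants, the converter folds the product eagerly and stores a plain array instead of building a layer, as the tail of convert_elementwise_mul above shows. A hedged sketch of that effect (the layers dict and its key are illustrative):

import numpy as np

layers = {}
a = np.array([3.0, 4.0])   # both operands already constant
b = np.array([0.5, 2.0])
layers["mul_out"] = a * b  # folded now; no Keras layer enters the graph
print(layers["mul_out"])   # [1.5 8. ]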

# def convert_elementwise_sub(node, params, layers, lambda_func, node_name, keras_name):
#     """
#     Convert element-wise sub.
#     :param node: current operation node
#     :param params: operation attributes
#     :param layers: available keras layers
#     :param lambda_func: function for keras Lambda layer
#     :param node_name: internal converter name
#     :param keras_name: resulting layer name
#     :return: None
#     """
#     logger = logging.getLogger('onnx2keras.sub')

#     if len(node.input) != 2:
#         raise AttributeError('Number of inputs is not equal 2 for element-wise layer')

#     input_0 = layers[node.input[0]]
#     input_1 = layers[node.input[1]]
#     input_0_is_np = is_numpy(input_0) or isinstance(input_0, EagerTensor)
#     input_1_is_np = is_numpy(input_1) or isinstance(input_1, EagerTensor)

#     try:
#         if not input_0_is_np and not input_1_is_np:
#             sub = keras.layers.Subtract(name=f"{params['cleaned_name']}_sub")
#             layers[node_name] = sub([input_0, input_1])
#         else:
#             raise ValueError('Operands are different.')

#     except (IndexError, ValueError):
#         logger.warning('Failed to use keras.layers.Subtract. Fallback to TF lambda.')
#         if input_0_is_np and not input_1_is_np:  # constant - tensor does not parse well
#             layers[node_name] = -(input_1 - input_0)
#         else:
#             layers[node_name] = input_0 - input_1
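The old fallback above dodged the constant - tensor ordering issue (the "does not parse well" comment) by rewriting c - t as -(t - c), which is algebraically the same thing; the rewrite below swaps that trick for explicit Lambda layers. A quick NumPy check of the identity (standalone, illustrative values):

import numpy as np

c = np.array([5.0, 5.0])              # constant operand
t = np.array([1.0, 2.0])              # stand-in for the tensor operand
assert np.allclose(c - t, -(t - c))   # both evaluate to [4. 3.]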

def convert_elementwise_sub(node, params, layers, lambda_func, node_name, keras_name):
    """
    Convert element-wise sub.
    :param node: current operation node
    :param params: operation attributes
    :param layers: available keras layers
    :param lambda_func: function for keras Lambda layer
    :param node_name: internal converter name
    :param keras_name: resulting layer name
    :return: None
    """
    logger = logging.getLogger('onnx2keras.sub')

@@ -215,24 +211,63 @@ def convert_elementwise_sub(node, params, layers, lambda_func, node_name, keras_

    input_0 = layers[node.input[0]]
    input_1 = layers[node.input[1]]

    input_0_is_constant = is_numpy(input_0) or isinstance(input_0, EagerTensor)
    input_1_is_constant = is_numpy(input_1) or isinstance(input_1, EagerTensor)

    try:
        if not input_0_is_constant and not input_1_is_constant:
            sub = keras.layers.Subtract(name=f"{params['cleaned_name']}_sub")
            layers[node_name] = sub([input_0, input_1])
        else:
            raise ValueError('Operands are different.')

    except (IndexError, ValueError):
        logger.warning('Failed to use keras.layers.Subtract. Fallback to Lambda layer.')

        if input_0_is_constant and not input_1_is_constant:
            # input_0 is constant, input_1 is variable: constant - variable
            constant_value = np.asarray(tf.cast(input_0, dtype=input_1.dtype))
            variable_input = input_1

            if np.all(constant_value == constant_value.flat[0]):
                # Constant tensor has the same value throughout
                const_val = constant_value.flat[0]
                layers[node_name] = keras.layers.Lambda(
                    lambda x: const_val - x,
                    name=keras_name
                )(variable_input)
            else:
                # Cannot avoid embedding the constant tensor
                layers[node_name] = keras.layers.Lambda(
                    lambda x: constant_value - x,
                    name=keras_name
                )(variable_input)

        elif not input_0_is_constant and input_1_is_constant:
            # input_0 is variable, input_1 is constant: variable - constant
            constant_value = np.asarray(tf.cast(input_1, dtype=input_0.dtype))
            variable_input = input_0

            if np.all(constant_value == constant_value.flat[0]):
                # Constant tensor has the same value throughout
                const_val = constant_value.flat[0]
                layers[node_name] = keras.layers.Lambda(
                    lambda x: x - const_val,
                    name=keras_name
                )(variable_input)
            else:
                # Cannot avoid embedding the constant tensor
                layers[node_name] = keras.layers.Lambda(
                    lambda x: x - constant_value,
                    name=keras_name
                )(variable_input)
        else:
            # Both inputs are constants; compute the result now
            layers[node_name] = input_0 - input_1
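The scalar fast path rests on np.all(constant_value == constant_value.flat[0]), which holds only when every element equals the first; capturing a lone scalar keeps the Lambda closure, and therefore the serialized model, from embedding the entire constant array. A small check, assuming nothing beyond NumPy:

import numpy as np

c = np.full((3, 3), 2.5)
assert np.all(c == c.flat[0])        # uniform: only the scalar 2.5 need be captured
c[0, 0] = 7.0
assert not np.all(c == c.flat[0])    # mixed values: the full array must be embedded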



def convert_min(node, params, layers, lambda_func, node_name, keras_name):
    """
    Convert Min layer
