Skip to content

Commit

Permalink
treat bollean in cons of shape
Browse files Browse the repository at this point in the history
  • Loading branch information
ranhomri committed Dec 5, 2024
1 parent 74cca8b commit 614874f
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions onnx2kerastl/constant_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,23 @@ def convert_constant_of_shape(node, params, layers, lambda_func, node_name, kera
value = params.get('value')
if value is None:
raise NotImplementedError("ConstantOfShape should have a value param")

input_0 = layers[node.input[0]]
if not is_numpy(input_0) and not isinstance(input_0, list) and K.is_keras_tensor(
input_0):
layers[node_name] = tf.ones(layers[node.input[0]], dtype=tf.as_dtype(value.dtype)) * params['value']

if not is_numpy(input_0) and not isinstance(input_0, list) and K.is_keras_tensor(input_0):
# Boolean case
if value.dtype == np.bool_:
layers[node_name] = tf.fill(layers[node.input[0]], tf.constant(value.item(), dtype=tf.bool))
else:
# Non-boolean case
layers[node_name] = tf.ones(layers[node.input[0]], dtype=tf.as_dtype(value.dtype)) * value
else:
layers[node_name] = np.ones(layers[node.input[0]], dtype=value.dtype) * params['value']
# Handle numpy inputs or non-Keras tensors
if value.dtype == np.bool_:
layers[node_name] = np.full(layers[node.input[0]], value.item(), dtype=bool)
else:
layers[node_name] = np.ones(layers[node.input[0]], dtype=value.dtype) * value



def convert_one_hot(node, params, layers, lambda_func, node_name, keras_name):
Expand Down

0 comments on commit 614874f

Please sign in to comment.