diff --git a/onnx2kerastl/customonnxlayer/onnxeinsum.py b/onnx2kerastl/customonnxlayer/onnxeinsum.py
new file mode 100644
index 00000000..ad48dea0
--- /dev/null
+++ b/onnx2kerastl/customonnxlayer/onnxeinsum.py
@@ -0,0 +1,46 @@
+from keras.layers import Layer
+import tensorflow as tf
+import numpy as np
+
+
+# This custom layer is needed because of a TensorFlow bug in einsum serialization.
+class OnnxEinsumLayer(Layer):
+    """
+    Layer wrapping a single tf.einsum operation.
+
+    If one operand is an ONNX constant, pass it as `constant_input` with
+    `constant_place` set to its position (0 or 1) among the einsum operands.
+
+    Usage:
+        x = OnnxEinsumLayer("bmhwf,bmoh->bmowf", None, None)((x1, x2))
+    """
+
+    def __init__(self, equation: str, constant_input, constant_place, **kwargs):
+        super().__init__(**kwargs)
+        self.equation = equation
+        if constant_input is not None:
+            if hasattr(constant_input, 'numpy'):
+                constant_input = constant_input.numpy()
+            if not isinstance(constant_input, np.ndarray):
+                constant_input = np.array(constant_input)
+        self.constant_input = constant_input
+        self.constant_place = constant_place
+
+    def call(self, inputs, *args, **kwargs):
+        if self.constant_input is not None:
+            if self.constant_place == 1:
+                inputs = [inputs, self.constant_input]
+            else:
+                inputs = [self.constant_input, inputs]
+
+        return tf.einsum(self.equation, *inputs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "equation": self.equation,
+            # tolist() keeps the config JSON-serializable when the model is saved
+            "constant_input": self.constant_input.tolist() if self.constant_input is not None else None,
+            "constant_place": self.constant_place,
+        })
+        return config
diff --git a/onnx2kerastl/operation_layers.py b/onnx2kerastl/operation_layers.py
index f9068bdc..1d814abe 100644
--- a/onnx2kerastl/operation_layers.py
+++ b/onnx2kerastl/operation_layers.py
@@ -2,9 +2,10 @@
 
 import keras
 import numpy as np
-import tensorflow as tf
 from keras import backend as K
+import tensorflow as tf
 
+from .customonnxlayer.onnxeinsum import OnnxEinsumLayer
 from .exceptions import UnsupportedLayer
 from .utils import is_numpy, ensure_tf_type, ensure_float
 from .tfops_funcs import tf_math_abs, tf_clip_by_value, tf_math_negative, K_mean, tf_math_reduce_prod,\
@@ -889,10 +890,15 @@ def get_empty_array(x, dtype=new_dtype, keras_name=keras_name):
 
 
 def convert_einsum(node, params, layers, lambda_func, node_name, keras_name):
-    # input_0 = layers[node.input[0]]
-    # input_1 = layers[node.input[1]]
-    # equation = params['equation'].decode('utf-8')
-    # layers[node_name] = tf.einsum(equation, *[input_0, input_1], name=keras_name)
-    input_0 = tf_expand_dims(layers[node.input[0]], axis=2, tf_name=f"{params['cleaned_name']}_einsum_0")
-    input_1 = tf_expand_dims(layers[node.input[1]], axis=0, tf_name=f"{params['cleaned_name']}_einsum_1")
-    layers[node_name] = tf_multiply(input_0, input_1, tf_name=f"{params['cleaned_name']}_einsum_mult")
+    input_0 = layers[node.input[0]]
+    input_1 = layers[node.input[1]]
+    equation = params['equation'].decode('utf-8')
+
+    if isinstance(input_0, tf.Tensor) and isinstance(input_1, tf.Tensor):
+        layers[node_name] = tf.einsum(equation, *[input_0, input_1], name=keras_name)
+    elif isinstance(input_0, tf.Tensor):
+        layers[node_name] = OnnxEinsumLayer(equation, input_0, 0, name=keras_name)(input_1)
+    elif isinstance(input_1, tf.Tensor):
+        layers[node_name] = OnnxEinsumLayer(equation, input_1, 1, name=keras_name)(input_0)
+    else:
+        layers[node_name] = OnnxEinsumLayer(equation, None, None, name=keras_name)([input_0, input_1])
diff --git a/pyproject.toml b/pyproject.toml
index fd7facc2..1f7344ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "onnx2kerastl"
-version = "0.0.142"
+version = "0.0.145"
 description = ""
 authors = ["dorhar "]
 license = "MIT"
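
Usage sketch (not part of the diff): a minimal end-to-end check of OnnxEinsumLayer in a functional model, plus a get_config round-trip, which is the serialization path the layer exists to support. The shapes, data, and variable names are illustrative; only the import path and the layer's argument order come from the diff.

```python
import numpy as np
from keras.layers import Input
from keras.models import Model

from onnx2kerastl.customonnxlayer.onnxeinsum import OnnxEinsumLayer

inp = Input(shape=(4, 5))                            # symbolic operand: b, i, j
weight = np.random.rand(5, 3).astype(np.float32)     # constant operand: j, k

# The constant fills slot 1 of "bij,jk->bik", hence constant_place=1.
out = OnnxEinsumLayer("bij,jk->bik", weight, 1)(inp)
model = Model(inp, out)

x = np.random.rand(2, 4, 5).astype(np.float32)
y = model(x)
assert y.shape == (2, 4, 3)

# Round-trip through get_config, as model saving would do.
cfg = model.layers[-1].get_config()
rebuilt = OnnxEinsumLayer(cfg["equation"], cfg["constant_input"], cfg["constant_place"])
np.testing.assert_allclose(rebuilt(x).numpy(), y.numpy(), rtol=1e-5)
```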
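A second illustrative check (again not from the PR) of the constant_place convention used by convert_einsum: it records which einsum slot the constant occupies, so the wrapped call reproduces the operands in their original ONNX order.

```python
import numpy as np
from onnx2kerastl.customonnxlayer.onnxeinsum import OnnxEinsumLayer

a = np.random.rand(2, 3).astype(np.float32)   # stands in for the dynamic input
b = np.random.rand(3, 4).astype(np.float32)   # stands in for an ONNX initializer

# Constant in slot 1: wrap it with constant_place=1 and feed only the dynamic operand.
out = OnnxEinsumLayer("ij,jk->ik", b, 1)(a)
np.testing.assert_allclose(out.numpy(), np.einsum("ij,jk->ik", a, b), rtol=1e-5)

# Constant in slot 0 of the flipped equation: constant_place=0.
out = OnnxEinsumLayer("jk,ij->ik", b, 0)(a)
np.testing.assert_allclose(out.numpy(), np.einsum("jk,ij->ik", b, a), rtol=1e-5)
```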