
Commit

fix einsum

tip3x committed Sep 19, 2024
1 parent 396519b commit dbe6329
Showing 3 changed files with 59 additions and 9 deletions.
40 changes: 40 additions & 0 deletions onnx2kerastl/customonnxlayer/onnxeinsum.py
@@ -0,0 +1,40 @@
+from keras.layers import Layer
+import tensorflow as tf
+import numpy as np
+
+
+# This custom layer is needed because of a TensorFlow bug in einsum serialization.
+class OnnxEinsumLayer(Layer):
+    """
+    Layer wrapping a single tf.einsum operation.
+    Usage:
+        x = OnnxEinsumLayer("bmhwf,bmoh->bmowf", None, None)((x1, x2))
+    """
+
+    def __init__(self, equation: str, constant_input, constant_place):
+        super().__init__()
+        self.equation = equation
+        if constant_input is not None:
+            if hasattr(constant_input, 'numpy'):
+                constant_input = constant_input.numpy()
+            if not isinstance(constant_input, np.ndarray):
+                constant_input = np.array(constant_input)
+        self.constant_input = constant_input
+        self.constant_place = constant_place
+
+    def call(self, inputs, *args, **kwargs):
+        if self.constant_input is not None:
+            if self.constant_place == 1:
+                inputs = [inputs, self.constant_input]
+            else:
+                inputs = [self.constant_input, inputs]
+
+        return tf.einsum(self.equation, *inputs)
+
+    def get_config(self):
+        return {
+            "equation": self.equation,
+            "constant_input": self.constant_input,
+            "constant_place": self.constant_place
+        }
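
For context, here is a minimal usage sketch of the new layer (not part of the commit; the equations, shapes, and file name are illustrative, and it assumes a TF 2.x / Keras 2 environment like the one this repo targets):

import numpy as np
import tensorflow as tf
from onnx2kerastl.customonnxlayer.onnxeinsum import OnnxEinsumLayer

# Both operands are symbolic: the layer simply wraps tf.einsum.
x1 = tf.keras.Input(shape=(4, 5))   # (batch, 4, 5)
x2 = tf.keras.Input(shape=(5, 3))   # (batch, 5, 3)
out = OnnxEinsumLayer("bij,bjk->bik", None, None)([x1, x2])
model = tf.keras.Model([x1, x2], out)

# One operand is constant: it is stored on the layer and re-inserted as the
# second einsum argument at call time (constant_place=1).
const = np.ones((5, 3), dtype=np.float32)
out_const = OnnxEinsumLayer("bij,jk->bik", const, 1)(x1)

# get_config() round-trips the equation and constant, which is what lets a
# converted model be saved and reloaded where a bare tf.einsum lambda failed.
model.save("einsum_model.h5")
restored = tf.keras.models.load_model(
    "einsum_model.h5", custom_objects={"OnnxEinsumLayer": OnnxEinsumLayer}
)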
26 changes: 18 additions & 8 deletions onnx2kerastl/operation_layers.py
@@ -2,9 +2,10 @@
 
 import keras
 import numpy as np
-import tensorflow as tf
 from keras import backend as K
+import tensorflow as tf
 
+from .customonnxlayer.onnxeinsum import OnnxEinsumLayer
 from .exceptions import UnsupportedLayer
 from .utils import is_numpy, ensure_tf_type, ensure_float
 from .tfops_funcs import tf_math_abs, tf_clip_by_value, tf_math_negative, K_mean, tf_math_reduce_prod,\
@@ -889,10 +890,19 @@ def get_empty_array(x, dtype=new_dtype, keras_name=keras_name):
 
 
 def convert_einsum(node, params, layers, lambda_func, node_name, keras_name):
-    # input_0 = layers[node.input[0]]
-    # input_1 = layers[node.input[1]]
-    # equation = params['equation'].decode('utf-8')
-    # layers[node_name] = tf.einsum(equation, *[input_0, input_1], name=keras_name)
-    input_0 = tf_expand_dims(layers[node.input[0]], axis=2, tf_name=f"{params['cleaned_name']}_einsum_0")
-    input_1 = tf_expand_dims(layers[node.input[1]], axis=0, tf_name=f"{params['cleaned_name']}_einsum_1")
-    layers[node_name] = tf_multiply(input_0, input_1, tf_name=f"{params['cleaned_name']}_einsum_mult")
+    input_0 = layers[node.input[0]]
+    input_1 = layers[node.input[1]]
+    equation = params['equation'].decode('utf-8')
+
+    is_input_0_constant = isinstance(input_0, tf.Tensor)
+    is_input_1_constant = isinstance(input_1, tf.Tensor)
+    if is_input_0_constant and is_input_1_constant:
+        layers[node_name] = tf.einsum(equation, *[input_0, input_1], name=keras_name)
+    elif is_input_0_constant:
+        layers[node_name] = OnnxEinsumLayer(equation, input_0, 0)(input_1, name=keras_name)
+    elif is_input_1_constant:
+        layers[node_name] = OnnxEinsumLayer(equation, input_1, 1)(input_0, name=keras_name)
+    else:
+        layers[node_name] = OnnxEinsumLayer(equation, None, None)([input_0, input_1], name=keras_name)
+
+
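A note on the dispatch above: during conversion, an operand that came from an ONNX initializer has already been materialized as an eager tensor, while an operand flowing through the graph is a symbolic Keras tensor, which is not a tf.Tensor subclass in recent TF 2.x. A small illustrative check (not part of the commit):

import tensorflow as tf

symbolic = tf.keras.Input(shape=(3,))        # graph-fed operand
constant = tf.constant([[1.0, 2.0, 3.0]])    # e.g. a materialized initializer

print(isinstance(symbolic, tf.Tensor))       # False for KerasTensor (TF >= 2.4)
print(isinstance(constant, tf.Tensor))       # True for eager constants

This is why only genuinely constant operands get baked into OnnxEinsumLayer, while symbolic ones remain layer inputs.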
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "onnx2kerastl"
-version = "0.0.142"
+version = "0.0.145"
 description = ""
 authors = ["dorhar <[email protected]>"]
 license = "MIT"
