Skip to content

Commit

Permalink
Add swin (#165)
Browse files Browse the repository at this point in the history
* add tf_name to ones

* add tf_pad for kerasTensor (dynamic) padding

* use tf.add that support correct broadcasting instead of layers.Add

* add LayerNorm

* add swin test
  • Loading branch information
tomkoren21 authored Aug 6, 2024
1 parent 75a232a commit 5942e44
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 12 deletions.
14 changes: 8 additions & 6 deletions onnx2kerastl/elementwise_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .utils import is_numpy, ensure_tf_type
from .tfops_funcs import tf_tensor_scatter_nd_update, tf_maximum, tf_minimum, tf_cast, tf_expand_dims, tf_repeat,\
tf_equal, tf_where, tf_round, tf_sign, tf_abs, tf_math_mod, tf_bitwise_left_shift, tf_bitwise_right_shift,\
tf_logical_not
tf_logical_not, tf_add
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor

Expand Down Expand Up @@ -81,11 +81,13 @@ def convert_elementwise_add(node, params, layers, lambda_func, node_name, keras_
try:
if not input_0_is_non_keras and not input_1_is_non_keras:
to_add = input_1
if input_0.shape != input_1.shape and input_0.shape[:-1] == input_1.shape:
to_add = tf_repeat(tf_expand_dims(input_1, axis=-1, tf_name=f"{params['cleaned_name']}_expand"),
input_0.shape[-1], axis=-1, tf_name=f"{params['cleaned_name']}_repeat")

layers[node_name] = keras.layers.Add(name=f"{params['cleaned_name']}_add")([input_0, to_add])
# We probably need to seperate two possibilities here. Currently we only deal with second option
# [Batch] + [Batch,1] -> [Batch,1]
# [Not-Batch] + [Not,Batch,1] -> [Not-batch, Not-batch]
if len(input_0.shape) != len(input_1.shape):
layers[node_name] = tf_add(input_0, to_add, tf_name=f"{params['cleaned_name']}_add")
else:
layers[node_name] = keras.layers.Add(name=f"{params['cleaned_name']}_add")([input_0, to_add])
else:
raise ValueError('Operands are different.')
except (IndexError, ValueError):
Expand Down
5 changes: 3 additions & 2 deletions onnx2kerastl/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
convert_concat, convert_reshape, convert_flatten, convert_slice, convert_squeeze, convert_expand, convert_resize, \
convert_tile, convert_gather_elements
from .constant_layers import convert_constant, convert_constant_of_shape, convert_one_hot
from .normalization_layers import convert_batchnorm, convert_instancenorm, convert_dropout, convert_lrn
from .normalization_layers import convert_batchnorm, convert_instancenorm, convert_dropout, convert_lrn, convert_layernorm
from .pooling_layers import convert_avgpool, convert_maxpool, convert_global_avg_pool, convert_topk, convert_roi_align
from .padding_layers import convert_padding
from .upsampling_layers import convert_upsample
Expand Down Expand Up @@ -152,5 +152,6 @@
'Unique': convert_unique,
'If': convert_if,
'RoiAlign': convert_roi_align,
'Einsum': convert_einsum
'Einsum': convert_einsum,
'LayerNormalization': convert_layernorm
}
38 changes: 37 additions & 1 deletion onnx2kerastl/normalization_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tensorflow as tf

from .utils import ensure_tf_type
from .tfops_funcs import tf_math_reduce_mean, tf_math_reduce_variance, tf_sqrt
from .tfops_funcs import tf_math_reduce_mean, tf_math_reduce_variance, tf_sqrt, tf_rank, tf_concat, tf_reshape


def convert_batchnorm(node, params, layers, lambda_func, node_name, keras_name):
Expand Down Expand Up @@ -155,3 +155,39 @@ def target_layer(x, depth_radius=params['size'], bias=params['bias'], alpha=para
lambda_layer = keras.layers.Lambda(target_layer, name=f"{params['cleaned_name']}_lrn")
layers[node_name] = lambda_layer(input_0)
lambda_func[keras_name] = target_layer


def convert_layernorm(node, params, layers, lambda_func, node_name, keras_name):
axis = params.get('axis', -1)
epsilon = params.get('epsilon', 1e-05)
stash_type = params.get('stash_type')
if stash_type is not None:
raise Exception("LayerNorm stash_type attribute is not implemented")
input_x = layers[node.input[0]]
weight = layers[node.input[1]]
if len(node.input) > 2:
bias = layers[node.input[2]]
else:
bias = None
center = True if bias is not None else False
layer_norm = tf.keras.layers.LayerNormalization(
axis=axis,
epsilon=epsilon,
center=center,
name=f"{params['cleaned_name']}_LayerNorm"
)
input_shape = input_x.shape.as_list()
if input_shape[axis] is None:
# reshape input such that the axis dim would be non-None (set by weights)
tf_input_shape = tf.shape(input_x)
if axis < 0:
axis = tf_rank(input_x, tf_name=f"{params['cleaned_name']}_LayerNorm_rank")._inferred_value[0] + axis
tf_new_shape = tf_concat([tf_input_shape[:axis], [weight.shape[0]], tf_input_shape[axis+1:]], axis=-1,
tf_name=f"{params['cleaned_name']}_LayerNorm_new_shape")
input_x = tf_reshape(input_x, tf_new_shape, tf_name=f"{params['cleaned_name']}_LayerNorm_reshape_none_axis")
layer_norm.build(input_x.shape)
if center:
layer_norm.set_weights([weight, bias])
else:
layer_norm.set_weights([weight])
layers[node_name] = layer_norm(input_x)
9 changes: 8 additions & 1 deletion onnx2kerastl/padding_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ def convert_padding(node, params, layers, lambda_func, node_name, keras_name):
tf_name=f"{params['cleaned_name']}_pad")
# Magic ordering
else:
if pads.shape[0] == 8:
if isinstance(pads, keras.engine.keras_tensor.KerasTensor) and pads.shape[0] == 8:
padding_layer = lambda x: tf_pad(x,
[[pads[0], pads[4]],
[pads[1], pads[5]],
[pads[2], pads[6]],
[pads[3], pads[7]]],
tf_name=f"{params['cleaned_name']}_pad_3")
elif pads.shape[0] == 8:
padding_layer = keras.layers.ZeroPadding2D(
padding=((pads[2], pads[6]), (pads[3], pads[7])),
name=f"{params['cleaned_name']}_pad_0"
Expand Down
6 changes: 4 additions & 2 deletions onnx2kerastl/reshape_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from .utils import is_numpy, ensure_tf_type, unsqueeze_tensors_of_rank_one
from .tfops_funcs import tf_reshape, tf_shape, tf_cast, tf_stack, tf_image_resize, tf_strided_slice,\
tf_squeeze, tf_transpose, tf_where, tf_gather, tf_range, tf_reduce_sum, tf_abs, tf_expand_dims, tf_concat, \
tf_shape, tf_tile, tf_fill, tf_gather_nd, tf_reduce_sum, tf_zeros_like, tf_multiply, tf_tensor_scatter_nd_update
tf_shape, tf_tile, tf_fill, tf_gather_nd, tf_reduce_sum, tf_zeros_like, tf_multiply, tf_tensor_scatter_nd_update,\
tf_ones


def convert_transpose(node, params, layers, lambda_func, node_name, keras_name):
Expand Down Expand Up @@ -671,7 +672,8 @@ def convert_expand(node, params, layers, lambda_func, node_name, keras_name):
input_1 = layers[node.input[1]]
if input_0.dtype.is_bool:
input_0 = tf_cast(input_0, dtype='int32', tf_name=f"{params['cleaned_name']}_bool_to_int")
multiply_res = input_0 * tf.ones(shape=input_1, dtype=input_0.dtype)
multiply_res = input_0 * tf_ones(shape=input_1, dtype=input_0.dtype,
tf_name=f"{params['cleaned_name']}_expand_use_ones")
# input_0.dtype == np.int32 since we can't serialize constants as int64, need to cast to true type
if layers[node.input[0]].dtype == np.int64:
multiply_res = tf_cast(multiply_res, tf.int64, tf_name=f"{params['cleaned_name']}_to_int64")
Expand Down
25 changes: 25 additions & 0 deletions test/models/private_tests/test_swin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import onnxruntime as ort
import numpy as np
import onnx
from onnx2kerastl import onnx_to_keras
from keras_data_format_converter import convert_channels_first_to_last
import tensorflow as tf
from test.models.private_tests.aws_utils import aws_s3_download
import pytest


@pytest.mark.parametrize('aws_s3_download', [["swin/", "swin/", False]], indirect=True)
def test_swin(aws_s3_download):
model_path = f'{aws_s3_download}/swin_v2_t.onnx'
inpt = np.load(f'{aws_s3_download}/input.npy')
result = np.load(f'{aws_s3_download}/output.npy')
onnx_model = onnx.load(model_path)
keras_model = onnx_to_keras(onnx_model, ['input'], name_policy='attach_weights_name',
allow_partial_compilation=False).converted_model
final_model = convert_channels_first_to_last(keras_model, should_transform_inputs_and_outputs=True)
res = final_model(inpt)
mean_error = (res-result).numpy().__abs__().mean()
max_error = (res-result).numpy().__abs__().max()
eps = 5e-6
assert mean_error < eps
assert max_error < eps

0 comments on commit 5942e44

Please sign in to comment.