Add swin #165

Merged
merged 5 commits on Aug 6, 2024
14 changes: 8 additions & 6 deletions onnx2kerastl/elementwise_layers.py
@@ -4,7 +4,7 @@
from .utils import is_numpy, ensure_tf_type
from .tfops_funcs import tf_tensor_scatter_nd_update, tf_maximum, tf_minimum, tf_cast, tf_expand_dims, tf_repeat,\
    tf_equal, tf_where, tf_round, tf_sign, tf_abs, tf_math_mod, tf_bitwise_left_shift, tf_bitwise_right_shift,\
-    tf_logical_not
+    tf_logical_not, tf_add
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor

@@ -81,11 +81,13 @@ def convert_elementwise_add(node, params, layers, lambda_func, node_name, keras_name):
    try:
        if not input_0_is_non_keras and not input_1_is_non_keras:
            to_add = input_1
            if input_0.shape != input_1.shape and input_0.shape[:-1] == input_1.shape:
                to_add = tf_repeat(tf_expand_dims(input_1, axis=-1, tf_name=f"{params['cleaned_name']}_expand"),
                                   input_0.shape[-1], axis=-1, tf_name=f"{params['cleaned_name']}_repeat")

-            layers[node_name] = keras.layers.Add(name=f"{params['cleaned_name']}_add")([input_0, to_add])
+            # We probably need to separate two possibilities here. Currently we only handle the second one:
+            # [Batch] + [Batch,1] -> [Batch,1]
+            # [Not-Batch] + [Not-Batch,1] -> [Not-Batch, Not-Batch]
+            if len(input_0.shape) != len(input_1.shape):
+                layers[node_name] = tf_add(input_0, to_add, tf_name=f"{params['cleaned_name']}_add")
+            else:
+                layers[node_name] = keras.layers.Add(name=f"{params['cleaned_name']}_add")([input_0, to_add])
        else:
            raise ValueError('Operands are different.')
    except (IndexError, ValueError):
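For context on the new rank check: keras.layers.Add expects inputs of the same shape, while tf.add applies NumPy-style broadcasting. A minimal sketch in plain TensorFlow (outside the converter, with hypothetical shapes) of the second case the comment describes:

import tensorflow as tf

a = tf.ones([4])     # rank 1, the [Not-Batch] operand
b = tf.ones([4, 1])  # rank 2, the [Not-Batch, 1] operand

# Broadcasting aligns trailing dims: [4] + [4, 1] -> [4, 4].
summed = tf.add(a, b)
print(summed.shape)  # (4, 4)
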
5 changes: 3 additions & 2 deletions onnx2kerastl/layers.py
@@ -20,7 +20,7 @@
convert_concat, convert_reshape, convert_flatten, convert_slice, convert_squeeze, convert_expand, convert_resize, \
convert_tile, convert_gather_elements
from .constant_layers import convert_constant, convert_constant_of_shape, convert_one_hot
-from .normalization_layers import convert_batchnorm, convert_instancenorm, convert_dropout, convert_lrn
+from .normalization_layers import convert_batchnorm, convert_instancenorm, convert_dropout, convert_lrn, convert_layernorm
from .pooling_layers import convert_avgpool, convert_maxpool, convert_global_avg_pool, convert_topk, convert_roi_align
from .padding_layers import convert_padding
from .upsampling_layers import convert_upsample
@@ -152,5 +152,6 @@
    'Unique': convert_unique,
    'If': convert_if,
    'RoiAlign': convert_roi_align,
-    'Einsum': convert_einsum
+    'Einsum': convert_einsum,
+    'LayerNormalization': convert_layernorm
}
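This mapping is what routes each ONNX op type to its handler, so registering 'LayerNormalization' here is all that is needed to activate the new converter. A hypothetical sketch of the dispatch pattern (names and signature are assumptions; the project's actual driver code differs):

# Hypothetical illustration only, not the project's real conversion loop.
def dispatch_node(node, converters, params, layers, lambda_func, node_name, keras_name):
    converter = converters.get(node.op_type)  # e.g. 'LayerNormalization'
    if converter is None:
        raise AttributeError(f'Op {node.op_type} has no registered converter')
    converter(node, params, layers, lambda_func, node_name, keras_name)
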
38 changes: 37 additions & 1 deletion onnx2kerastl/normalization_layers.py
@@ -5,7 +5,7 @@
import tensorflow as tf

from .utils import ensure_tf_type
-from .tfops_funcs import tf_math_reduce_mean, tf_math_reduce_variance, tf_sqrt
+from .tfops_funcs import tf_math_reduce_mean, tf_math_reduce_variance, tf_sqrt, tf_rank, tf_concat, tf_reshape


def convert_batchnorm(node, params, layers, lambda_func, node_name, keras_name):
@@ -155,3 +155,39 @@ def target_layer(x, depth_radius=params['size'], bias=params['bias'], alpha=para
    lambda_layer = keras.layers.Lambda(target_layer, name=f"{params['cleaned_name']}_lrn")
    layers[node_name] = lambda_layer(input_0)
    lambda_func[keras_name] = target_layer


def convert_layernorm(node, params, layers, lambda_func, node_name, keras_name):
    axis = params.get('axis', -1)
    epsilon = params.get('epsilon', 1e-05)
    stash_type = params.get('stash_type')
    if stash_type is not None:
        raise Exception("LayerNorm stash_type attribute is not implemented")
    input_x = layers[node.input[0]]
    weight = layers[node.input[1]]
    if len(node.input) > 2:
        bias = layers[node.input[2]]
    else:
        bias = None
    center = True if bias is not None else False
    layer_norm = tf.keras.layers.LayerNormalization(
        axis=axis,
        epsilon=epsilon,
        center=center,
        name=f"{params['cleaned_name']}_LayerNorm"
    )
    input_shape = input_x.shape.as_list()
    if input_shape[axis] is None:
        # reshape the input so the normalized axis has a static size (taken from the weights)
        tf_input_shape = tf.shape(input_x)
        if axis < 0:
            axis = tf_rank(input_x, tf_name=f"{params['cleaned_name']}_LayerNorm_rank")._inferred_value[0] + axis
        tf_new_shape = tf_concat([tf_input_shape[:axis], [weight.shape[0]], tf_input_shape[axis+1:]], axis=-1,
                                 tf_name=f"{params['cleaned_name']}_LayerNorm_new_shape")
        input_x = tf_reshape(input_x, tf_new_shape, tf_name=f"{params['cleaned_name']}_LayerNorm_reshape_none_axis")
    layer_norm.build(input_x.shape)
    if center:
        layer_norm.set_weights([weight, bias])
    else:
        layer_norm.set_weights([weight])
    layers[node_name] = layer_norm(input_x)
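A standalone sketch (hypothetical sizes, not the converter code itself) of the weight transfer convert_layernorm performs: build the Keras layer so its gamma/beta variables exist, then load the ONNX scale and bias tensors via set_weights:

import numpy as np
import tensorflow as tf

hidden = 96  # hypothetical channel size, e.g. a Swin embedding dim
scale = np.ones(hidden, dtype=np.float32)   # ONNX input[1] (gamma)
bias = np.zeros(hidden, dtype=np.float32)   # ONNX input[2] (beta), optional

ln = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-5, center=True)
ln.build((None, 49, hidden))          # creates the gamma/beta variables
ln.set_weights([scale, bias])         # order is [gamma, beta] when center=True

x = tf.random.normal([2, 49, hidden])
print(ln(x).shape)                    # (2, 49, 96), normalized over the last axis
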
9 changes: 8 additions & 1 deletion onnx2kerastl/padding_layers.py
@@ -43,7 +43,14 @@ def convert_padding(node, params, layers, lambda_func, node_name, keras_name):
                     tf_name=f"{params['cleaned_name']}_pad")
    # Magic ordering
    else:
-        if pads.shape[0] == 8:
+        if isinstance(pads, keras.engine.keras_tensor.KerasTensor) and pads.shape[0] == 8:
+            padding_layer = lambda x: tf_pad(x,
+                                             [[pads[0], pads[4]],
+                                              [pads[1], pads[5]],
+                                              [pads[2], pads[6]],
+                                              [pads[3], pads[7]]],
+                                             tf_name=f"{params['cleaned_name']}_pad_3")
+        elif pads.shape[0] == 8:
            padding_layer = keras.layers.ZeroPadding2D(
                padding=((pads[2], pads[6]), (pads[3], pads[7])),
                name=f"{params['cleaned_name']}_pad_0"
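The "magic ordering" above comes from the ONNX Pad layout, which stores begins for all axes first and then ends for all axes, whereas tf.pad takes one [begin, end] pair per axis. A minimal sketch with hypothetical values:

import tensorflow as tf

x = tf.zeros([1, 3, 8, 8])        # NCHW input, hypothetical
pads = [0, 0, 1, 2, 0, 0, 3, 4]   # ONNX: [begin_0..begin_3, end_0..end_3]

# Re-pair into tf.pad's per-axis format: pads[i] with pads[i + 4]
paddings = [[pads[i], pads[i + 4]] for i in range(4)]
print(tf.pad(x, paddings).shape)  # (1, 3, 12, 14)
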
6 changes: 4 additions & 2 deletions onnx2kerastl/reshape_layers.py
@@ -10,7 +10,8 @@
from .utils import is_numpy, ensure_tf_type, unsqueeze_tensors_of_rank_one
from .tfops_funcs import tf_reshape, tf_shape, tf_cast, tf_stack, tf_image_resize, tf_strided_slice,\
    tf_squeeze, tf_transpose, tf_where, tf_gather, tf_range, tf_reduce_sum, tf_abs, tf_expand_dims, tf_concat, \
-    tf_shape, tf_tile, tf_fill, tf_gather_nd, tf_reduce_sum, tf_zeros_like, tf_multiply, tf_tensor_scatter_nd_update
+    tf_shape, tf_tile, tf_fill, tf_gather_nd, tf_reduce_sum, tf_zeros_like, tf_multiply, tf_tensor_scatter_nd_update,\
+    tf_ones


def convert_transpose(node, params, layers, lambda_func, node_name, keras_name):
@@ -671,7 +672,8 @@ def convert_expand(node, params, layers, lambda_func, node_name, keras_name):
    input_1 = layers[node.input[1]]
    if input_0.dtype.is_bool:
        input_0 = tf_cast(input_0, dtype='int32', tf_name=f"{params['cleaned_name']}_bool_to_int")
-    multiply_res = input_0 * tf.ones(shape=input_1, dtype=input_0.dtype)
+    multiply_res = input_0 * tf_ones(shape=input_1, dtype=input_0.dtype,
+                                     tf_name=f"{params['cleaned_name']}_expand_use_ones")
    # input_0.dtype == np.int32 since we can't serialize constants as int64; cast back to the true type
    if layers[node.input[0]].dtype == np.int64:
        multiply_res = tf_cast(multiply_res, tf.int64, tf_name=f"{params['cleaned_name']}_to_int64")
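The switch from tf.ones to the tf_name-aware tf_ones wrapper presumably just keeps the op named and tracked by the converter; the underlying trick for ONNX Expand is unchanged: multiplying by a ones tensor of the target shape applies broadcasting, which matches Expand's semantics. A minimal sketch:

import tensorflow as tf

x = tf.constant([[1], [2], [3]])       # shape (3, 1)
target_shape = tf.constant([3, 4])     # ONNX Expand's "shape" input

expanded = x * tf.ones(shape=target_shape, dtype=x.dtype)
print(expanded.shape)                  # (3, 4); values broadcast along axis 1
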
25 changes: 25 additions & 0 deletions test/models/private_tests/test_swin.py
@@ -0,0 +1,25 @@
import onnxruntime as ort
import numpy as np
import onnx
from onnx2kerastl import onnx_to_keras
from keras_data_format_converter import convert_channels_first_to_last
import tensorflow as tf
from test.models.private_tests.aws_utils import aws_s3_download
import pytest


@pytest.mark.parametrize('aws_s3_download', [["swin/", "swin/", False]], indirect=True)
def test_swin(aws_s3_download):
    model_path = f'{aws_s3_download}/swin_v2_t.onnx'
    inpt = np.load(f'{aws_s3_download}/input.npy')
    result = np.load(f'{aws_s3_download}/output.npy')
    onnx_model = onnx.load(model_path)
    keras_model = onnx_to_keras(onnx_model, ['input'], name_policy='attach_weights_name',
                                allow_partial_compilation=False).converted_model
    final_model = convert_channels_first_to_last(keras_model, should_transform_inputs_and_outputs=True)
    res = final_model(inpt)
    mean_error = (res - result).numpy().__abs__().mean()
    max_error = (res - result).numpy().__abs__().max()
    eps = 5e-6
    assert mean_error < eps
    assert max_error < eps