Skip to content

Commit

Permalink
Correct the channel-count calculation when convolution groups are present
Browse files Browse the repository at this point in the history
  • Loading branch information
tomkoren21 committed Nov 20, 2024
1 parent a3ef47c commit adb6df5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onnx2kerastl/convolution_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def convert_conv(node, params, layers, lambda_func, node_name, keras_name):
"name": f"{params['cleaned_name']}_" + 'conv',
"groups": n_groups}
partial_conv = partial(keras.layers.Conv3D, **conv_args)
layers[node_name] = permute_wrap_conv_if_constant(partial_conv, input_0, is_constant, weights[0].shape[-2], params)
layers[node_name] = permute_wrap_conv_if_constant(partial_conv, input_0, is_constant, weights[0].shape[-2]*n_groups, params)

elif len(W.shape) == 4: # 2D conv
logger.debug('2D convolution')
Expand Down Expand Up @@ -149,7 +149,7 @@ def convert_conv(node, params, layers, lambda_func, node_name, keras_name):
"groups": n_groups}

partial_conv = partial(keras.layers.Conv2D, **conv_args)
layers[node_name] = permute_wrap_conv_if_constant(partial_conv, input_0, is_constant, weights[0].shape[-2], params)
layers[node_name] = permute_wrap_conv_if_constant(partial_conv, input_0, is_constant, weights[0].shape[-2]*n_groups, params)
else:
input_0_nhwc = tf_transpose(input_0, [0, 2, 3, 1],
tf_name=f"{params['cleaned_name']}_" + 'conv_transpose_nhwc')
Expand Down Expand Up @@ -196,7 +196,7 @@ def convert_conv(node, params, layers, lambda_func, node_name, keras_name):
else:
conv_args['padding'] = 'valid'
partial_conv = partial(keras.layers.Conv1D, **conv_args)
res = permute_wrap_conv_if_constant(partial_conv, input_0, is_constant, weights[0].shape[-2], params)
res = permute_wrap_conv_if_constant(partial_conv, input_0, is_constant, weights[0].shape[-2]*n_groups, params)
if has_bias:
res_shape = np.asarray(keras.backend.int_shape(res))
bias_dim = np.argwhere(res_shape == bias.shape)[0][0]
Expand Down

0 comments on commit adb6df5

Please sign in to comment.