Skip to content

Commit

Permalink
Merge pull request #802 from vloncar/conv_bn_fix
Browse files Browse the repository at this point in the history
Fix parsing of QConv2DBatchnorm weights
  • Loading branch information
jmitrevs authored Jun 6, 2023
2 parents 2e71ff4 + 2d5c42d commit dc35658
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hls4ml/converters/keras/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader):

(layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(input_shapes[0], layer['data_format'])

if layer['class_name'] in ['Conv2D', 'QConv2D']:
if layer['class_name'] in ['Conv2D', 'QConv2D', 'QConv2DBatchnorm']:
layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'kernel')
elif layer['class_name'] in ['SeparableConv2D', 'QSeparableConv2D']:
layer['depthwise_data'], layer['pointwise_data'] = get_weights_data(
Expand Down
42 changes: 42 additions & 0 deletions test/pytest/test_qkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
from qkeras.qconv2d_batchnorm import QConv2DBatchnorm
from qkeras.qlayers import QActivation, QDense
from qkeras.quantizers import binary, quantized_bits, quantized_relu, quantized_sigmoid, quantized_tanh, ternary
from qkeras.utils import _add_supported_quantized_objects
Expand Down Expand Up @@ -348,3 +349,44 @@ def test_qactivation_kwarg(randX_100_10, activation_quantizer, weight_quantizer)
y_hls4ml = np.where(y_hls4ml == 0, -1, 1)
wrong = (y_hls4ml != y_qkeras).ravel()
assert sum(wrong) / len(wrong) <= 0.005


@pytest.fixture(scope='module')
def randX_100_8_8_1():
    """Module-scoped random batch of 100 single-channel 8x8 images in [0, 1)."""
    shape = (100, 8, 8, 1)
    return np.random.random(shape)


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
def test_qconv2dbn(randX_100_8_8_1, backend, io_type):
    '''
    Test proper handling of QConv2DBatchnorm.
    '''
    # Quantize the random input so every value is exactly representable as
    # ap_fixed<16,6>; this makes a bit-exact comparison below meaningful.
    data = np.round(randX_100_8_8_1 * 2**10) * 2**-10

    # Single fused conv + batchnorm layer with quantized weights/activation.
    fused_conv = QConv2DBatchnorm(
        4,
        kernel_size=(3, 3),
        input_shape=(8, 8, 1),
        kernel_quantizer='quantized_bits(8, 0, alpha=1)',
        kernel_initializer='ones',
        bias_quantizer='quantized_bits(8, 0, alpha=1)',
        bias_initializer='zeros',
        activation='quantized_relu(8, 0)',
    )
    model = Sequential([fused_conv])
    model.compile()

    # Convert with per-layer precision configuration and build the HLS model.
    hls_config = hls4ml.utils.config_from_keras_model(model, granularity='name')
    out_dir = str(test_root_path / f'hls4mlprj_qkeras_qconv2dbn_{backend}_{io_type}')
    hls_model = hls4ml.converters.convert_from_keras_model(
        model, hls_config=hls_config, output_dir=out_dir, backend=backend, io_type=io_type
    )
    hls_model.compile()

    # The fused layer's quantized output must match hls4ml exactly.
    expected = model.predict(data)
    actual = hls_model.predict(data).reshape(expected.shape)
    np.testing.assert_array_equal(expected, actual)

0 comments on commit dc35658

Please sign in to comment.