Skip to content

Commit

Permalink
Fix for QActivations passed as an argument (fastmachinelearning#553)
Browse files Browse the repository at this point in the history
* Fix handling of QKeras activations passed as an argument

* Add a test for QKeras activations passed as an argument
  • Loading branch information
AdrianAlan authored Jun 7, 2022
1 parent 2dafb98 commit 7109f0e
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 16 deletions.
34 changes: 25 additions & 9 deletions hls4ml/converters/keras_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,16 +328,32 @@ def keras_to_hls(config):
layer_list.append( layer )
if 'activation' in layer and layer['class_name'] not in activation_layers + recurrent_layers:# + qkeras_layers:
act_layer = {}
act_layer['name'] = layer['name'] + '_' + layer['activation']
act_layer['activation'] = layer['activation']
if 'activ_param' in layer:
act_layer['activ_param'] = layer['activ_param']
act_layer['class_name'] = layer['activation']
elif layer['activation'] == 'softmax':
act_layer['class_name'] = 'Softmax'
act_layer['axis'] = -1
# Workaround for QKeras activations passed as an argument
if isinstance(layer['activation'], dict):
act_details = layer['activation']
act_layer['class_name'] = 'QActivation'
act_layer['config'] = {
'name': layer['name'] + '_' + act_details['class_name'],
'activation': act_details['class_name']
}
act_layer, output_shape = layer_handlers['QActivation'](
act_layer,
None,
[output_shape],
reader,
config
)
else:
act_layer['class_name'] = 'Activation'
act_layer['name'] = layer['name'] + '_' + layer['activation']
act_layer['activation'] = layer['activation']
if 'activ_param' in layer:
act_layer['activ_param'] = layer['activ_param']
act_layer['class_name'] = layer['activation']
elif layer['activation'] == 'softmax':
act_layer['class_name'] = 'Softmax'
act_layer['axis'] = -1
else:
act_layer['class_name'] = 'Activation'
inputs_map[layer['name']] = act_layer['name']
if output_layers is not None and layer['name'] in output_layers:
output_layers = [act_layer['name'] if name == layer['name'] else name for name in output_layers]
Expand Down
17 changes: 12 additions & 5 deletions hls4ml/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,21 @@ def config_from_keras_model(model, granularity='model', default_precision='ap_fi

print('Layer name: {}, layer type: {}'.format(layer['name'], layer['class_name']))
layer_list.append( layer )
if 'activation' in layer['config'] and layer['class_name'] not in activation_layers + qkeras_layers:
if 'activation' in layer['config'] and layer['class_name'] not in activation_layers:
act_layer = {}
act_layer['name'] = layer['name'] + '_' + layer['config']['activation']
act_layer['class_name'] = 'Activation'
print(' -> Activation ({}), layer name: {}'.format(layer['config']['activation'], layer['name']))
act_details = layer['config']['activation']
if isinstance(act_details, dict):
precision = _get_precision_from_quantizer(act_details)
act_details = act_details['class_name']
act_layer['precision'] = {}
act_layer['precision']['result'] = precision
act_layer['class_name'] = 'QActivation'
else:
act_layer['class_name'] = 'Activation'
act_layer['name'] = layer['name'] + '_' + act_details
print(' -> Activation ({}), layer name: {}'.format(act_details, layer['name']))
layer_list.append(act_layer)


def make_layer_config(layer):
layer_config = {}
if layer['class_name'] in dense_layers + conv_layers + rnn_layers:
Expand Down
84 changes: 82 additions & 2 deletions test/pytest/test_qkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.keras.models import Sequential, model_from_json
from tensorflow.keras.models import Sequential, Model, model_from_json
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l1
from tensorflow.keras.layers import Activation, BatchNormalization
from tensorflow.keras.layers import Activation, BatchNormalization, Input
from qkeras.qlayers import QDense, QActivation
from qkeras.quantizers import quantized_bits, quantized_relu, ternary, binary
from qkeras.utils import _add_supported_quantized_objects; co = {}; _add_supported_quantized_objects(co)
Expand Down Expand Up @@ -228,3 +228,83 @@ def test_quantizer(randX_1000_1, quantizer, backend):
y_hls4ml = hls_model.predict(X)
# Goal is to get it passing with all equal
np.testing.assert_array_equal(y_qkeras, y_hls4ml)


@pytest.mark.parametrize(
'weight_quantizer,activation_quantizer,', [
('binary', 'binary'),
('ternary', 'ternary'),
('quantized_bits(4, 0, alpha=1)', 'quantized_relu(2, 0)'),
('quantized_bits(4, 0, alpha=1)', 'quantized_relu(4, 0)'),
('quantized_bits(4, 0, alpha=1)', 'quantized_relu(8, 0)')
]
)
def test_qactivation_kwarg(randX_100_10,
activation_quantizer,
weight_quantizer):
if activation_quantizer in ['binary', 'ternary']:
name = 'bnbt_qdense_alpha'
else:
name = 'qdense_{}'.format(
eval(activation_quantizer).__class__.__name__)

inputs = Input(shape=(10,))

outputs = QDense(
10,
activation=activation_quantizer,
name='qdense',
kernel_quantizer=weight_quantizer,
bias_quantizer=weight_quantizer,
kernel_initializer='lecun_uniform'
)(inputs)
model = Model(inputs, outputs)

hls4ml.model.optimizer.get_optimizer(
'output_rounding_saturation_mode'
).configure(
layers=[name],
rounding_mode='AP_RND_CONV',
saturation_mode='AP_SAT'
)
config = hls4ml.utils.config_from_keras_model(
model,
granularity='name'
)

out_dir = str(
test_root_path / 'hls4mlprj_qactivation_kwarg_{}'.format(
activation_quantizer
)
)

hls_model = hls4ml.converters.convert_from_keras_model(
model,
hls_config=config,
output_dir=out_dir
)
hls4ml.model.optimizer.get_optimizer(
'output_rounding_saturation_mode'
).configure(layers=[])
hls_model.compile()

# Verify if activation in hls_model
assert name in [layer.name for layer in hls_model.get_layers()]

# Output tests
X = randX_100_10
X = np.round(X * 2**10) * 2**-10
y_qkeras = model.predict(X)
y_hls4ml = hls_model.predict(X)
if hasattr(eval(activation_quantizer), 'bits'):
np.testing.assert_allclose(
y_qkeras.ravel(),
y_hls4ml.ravel(),
atol=2**-eval(activation_quantizer).bits,
rtol=1.0
)
else:
if activation_quantizer == 'binary':
y_hls4ml = np.where(y_hls4ml == 0, -1, 1)
wrong = (y_hls4ml != y_qkeras).ravel()
assert sum(wrong) / len(wrong) <= 0.005

0 comments on commit 7109f0e

Please sign in to comment.