Commit

Merge pull request #827 from joshlerner/main
Fix loading weights in GarNetStacked and GarNet internal array precisions
jmitrevs authored Jul 7, 2023
2 parents 9799d57 + f07c112 commit 3b458f9
Showing 6 changed files with 71 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,3 +13,4 @@ docs/_build
 docs/autodoc/*
 hls4mlprj_*
 *~
+*.ipynb_checkpoints/
31 changes: 16 additions & 15 deletions contrib/garnet.py
@@ -313,35 +313,36 @@ def _setup_transforms(self, n_aggregators, n_filters, n_propagate):
                     name=('Fout%d' % it),
                 )

-                # Check for correctness. This commented out because pre-commit showed it was unused.
-
-                # if self._output_activation is None or self._output_activation == "linear":
-                #     output_activation_transform = (QActivation("quantized_bits(%i, %i)"
-                #                                    % (self._total_bits, self._int_bits)))
-                # else:
-                #     output_activation_transform = QActivation(
-                #         "quantized_%s(%i, %i)" % (self._output_activation, self._total_bits, self._int_bits)
-                #     )
+                if self._output_activation is None or self._output_activation == "linear":
+                    output_activation_transform = QActivation("quantized_bits(%i, %i)" % (self._total_bits, self._int_bits))
+                else:
+                    output_activation_transform = QActivation(
+                        "quantized_%s(%i, %i)" % (self._output_activation, self._total_bits, self._int_bits)
+                    )
             else:
                 input_feature_transform = NamedDense(p, name=('FLR%d' % it))
                 output_feature_transform = NamedDense(f, name=('Fout%d' % it))
-                # output_activation_transform = keras.layers.Activation(self._output_activation)
+                output_activation_transform = keras.layers.Activation(self._output_activation)

             aggregator_distance = NamedDense(a, name=('S%d' % it))

-            self._transform_layers.append((input_feature_transform, aggregator_distance, output_feature_transform))
+            self._transform_layers.append(
+                (input_feature_transform, aggregator_distance, output_feature_transform, output_activation_transform)
+            )

         self._sublayers = sum((list(layers) for layers in self._transform_layers), [])

     def _build_transforms(self, data_shape):
-        for in_transform, d_compute, out_transform in self._transform_layers:
+        for in_transform, d_compute, out_transform, act_transform in self._transform_layers:
             in_transform.build(data_shape)
             d_compute.build(data_shape)
             if self._simplified:
-                out_transform.build(data_shape[:2] + (d_compute.units * in_transform.units,))
+                act_transform.build(out_transform.build(data_shape[:2] + (d_compute.units * in_transform.units,)))
             else:
-                out_transform.build(
-                    data_shape[:2] + (data_shape[2] + d_compute.units * in_transform.units + d_compute.units,)
+                act_transform.build(
+                    out_transform.build(
+                        data_shape[:2] + (data_shape[2] + d_compute.units * in_transform.units + d_compute.units,)
+                    )
                 )

             data_shape = data_shape[:2] + (out_transform.units,)
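For illustration, a minimal sketch of the activation strings the restored branch builds; the values total_bits=12, int_bits=3 and output_activation='relu' are assumed here purely for the example, the real ones come from the GarNetStack configuration:

# Assumed example values; the actual ones are taken from the layer configuration.
total_bits, int_bits, output_activation = 12, 3, 'relu'
# Non-linear output activation in the quantized branch:
print("quantized_%s(%i, %i)" % (output_activation, total_bits, int_bits))  # quantized_relu(12, 3)
# No (or linear) output activation in the quantized branch:
print("quantized_bits(%i, %i)" % (total_bits, int_bits))  # quantized_bits(12, 3)
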
3 changes: 2 additions & 1 deletion hls4ml/backends/vivado/passes/garnet_templates.py
@@ -114,7 +114,8 @@ def format(self, node):
             params[f'{vname}_t'], type_name = node.model.config.get_precision(node, var=vname)
             if type_name.endswith('default_t'):
                 params[f'{vname}_t'] = precision_converter.convert(default_precision).definition_cpp()
-
+            else:
+                params[f'{vname}_t'] = precision_converter.convert(params[f'{vname}_t']).definition_cpp()
         params['output_t'] = node.get_output_variable().type.name

         if node.attributes['collapse'] in ['mean', 'max']:
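With the else branch above, a precision set per layer in the hls4ml config is converted and used instead of always falling back to the default. A sketch of such an override follows; the layer name 'gar_1' and the variable names 'norm' and 'aggr' are taken from the comment in the test fixture below and are assumptions, not an exhaustive list of supported keys:

# Sketch only: per-layer precision overrides for the GarNet internal arrays.
# 'gar_1', 'norm' and 'aggr' are assumed names, following the test comment below.
config = {'LayerName': {'gar_1': {'Precision': {}}}}  # shape of a granularity='name' config
config['LayerName']['gar_1']['Precision']['norm'] = 'ap_fixed<16,6>'
config['LayerName']['gar_1']['Precision']['aggr'] = 'ap_fixed<18,8>'
print(config['LayerName']['gar_1']['Precision'])
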
8 changes: 6 additions & 2 deletions hls4ml/converters/keras/graph.py
@@ -46,12 +46,16 @@ def parse_garnet_layer(keras_layer, input_names, input_shapes, data_reader):
         layer['n_sublayers'] = keras_layer['config']['n_sublayers']
         layer['n_in_features'] = [input_shapes[0][2]]

-        for il in range(1, layer['n_sublayers']):
-            layer['n_in_features'].append(layer['n_out_features'][il - 1])
+        for il in range(layer['n_sublayers']):
+            if il > 0:
+                layer['n_in_features'].append(layer['n_out_features'][il - 1])

             weights_source = [
                 f'FLR{il}_kernel',
                 f'FLR{il}_bias',
                 f'S{il}_kernel',
                 f'S{il}_bias',
                 f'Fout{il}_kernel',
                 f'Fout{il}_bias',
             ]
             for weight in weights_source:
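A short sketch of what the corrected loop enumerates (the data_reader calls that actually fetch each array are omitted); with three sublayers, sublayer 0 is now visited as well:

# Sketch: per-sublayer weight/bias names visited by the loop above.
n_sublayers = 3  # e.g. a three-sublayer GarNetStack, as in the test below
for il in range(n_sublayers):
    weights_source = [
        f'FLR{il}_kernel', f'FLR{il}_bias',    # input feature transform
        f'S{il}_kernel', f'S{il}_bias',        # aggregator distance
        f'Fout{il}_kernel', f'Fout{il}_bias',  # output feature transform
    ]
    print(il, weights_source)
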
1 change: 0 additions & 1 deletion hls4ml/model/layers.py
@@ -1182,7 +1182,6 @@ def _initialize_transforms(self):

     def _make_input_transform_weights(self, n_propagate, n_aggregators, n_out_features, quantize=False, sublayer=''):
         # Due to linearity of the input transform, input weights and biases can be contracted away at conversion time
-
         output_transform_kernel = self.get_attr(
             f'Fout{sublayer}_kernel_data'
         )  # [(n_aggregators, n_propagate), n_out_features]
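The comment about contracting the input transform rests on a simple fact: the composition of two affine maps is again affine, so the input kernel and bias can be folded into the downstream weights at conversion time. A generic numerical sketch with plain 2-D shapes; the GarNet-specific [(n_aggregators, n_propagate), n_out_features] bookkeeping is omitted:

import numpy as np

# Two stacked affine (Dense-like) transforms: y = (x @ W1 + b1) @ W2 + b2
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))
W1, b1 = rng.normal(size=(8, 4)), rng.normal(size=(4,))
W2, b2 = rng.normal(size=(4, 3)), rng.normal(size=(3,))
y_two_steps = (x @ W1 + b1) @ W2 + b2

# Contracted single affine map: W = W1 @ W2, b = b1 @ W2 + b2
W, b = W1 @ W2, b1 @ W2 + b2
y_contracted = x @ W + b

np.testing.assert_allclose(y_two_steps, y_contracted)  # agrees up to float rounding
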
47 changes: 46 additions & 1 deletion test/pytest/test_garnet.py
@@ -6,7 +6,7 @@
 from tensorflow.keras.models import Model

 import hls4ml
-from contrib.garnet import GarNet
+from contrib.garnet import GarNet, GarNetStack

 test_root_path = Path(__file__).parent

@@ -49,6 +49,40 @@ def garnet_models():
     return model, hls_model


+@pytest.fixture(scope='module')
+def garnet_stack_models():
+    x = Input(shape=(vmax, feat))
+    n = Input(shape=(1,), dtype='uint16')
+    inputs = [x, n]
+    outputs = GarNetStack(
+        ([4, 4, 8]),
+        ([4, 4, 8]),
+        ([8, 8, 16]),
+        simplified=True,
+        collapse='mean',
+        input_format='xn',
+        output_activation=None,  # added output_activation_transform back in contrib.garnet.py
+        name='gar_1',
+        quantize_transforms=None,  # this should be false, not None...fix in contrib.garnet.py
+    )(inputs)
+    model = Model(inputs=inputs, outputs=outputs)
+    model.summary()
+
+    config = hls4ml.utils.config_from_keras_model(model, granularity='name')
+    config['Model'] = {}
+    config['Model']['ReuseFactor'] = 1
+    config['Model']['Strategy'] = 'Latency'
+    config['Model']['Precision'] = 'ap_fixed<32,6>'
+    # config should now have precisions specified for ['LayerName']['gar_1']['Precision']['norm', 'aggr', etc.]
+    cfg = hls4ml.converters.create_config(output_dir=str(test_root_path / 'hls4mlprj_garnet'), part='xc7z020clg400-1')
+    cfg['HLSConfig'] = config
+    cfg['KerasModel'] = model
+
+    hls_model = hls4ml.converters.keras_to_hls(cfg)
+    hls_model.compile()
+    return model, hls_model
+
+
 @pytest.mark.parametrize('batch', [1, 3])
 def test_accuracy(garnet_models, batch):
     model, hls_model = garnet_models
@@ -58,3 +92,14 @@ def test_accuracy(garnet_models, batch):
     y_hls = hls_model.predict(x_hls).reshape(y.shape)

     np.testing.assert_allclose(y_hls, y, rtol=0, atol=0.1)
+
+
+@pytest.mark.parametrize('batch', [1, 3])
+def test_accuracy_stack(garnet_stack_models, batch):
+    model, hls_model = garnet_stack_models
+    x = [np.random.rand(batch, vmax, feat), np.random.randint(0, vmax, size=(batch, 1))]
+    y = model.predict(x)
+    x_hls = [x[0], x[1].astype(np.float64)]
+    y_hls = hls_model.predict(x_hls).reshape(y.shape)
+
+    np.testing.assert_allclose(y_hls, y, rtol=0, atol=0.1)
