Skip to content

Commit

Permalink
Fix GlobalPooling1D Layers (#399)
Browse files Browse the repository at this point in the history
* add global_pooling1d_cl

* io_parallel global_pooling1d_cl

* add globalpooling1d testing; add missing include in nnet_conv_stream.h

* update test

* Update test_globalpooling1d.py

* Change project directory name

Co-authored-by: Jovan Mitrevski <[email protected]>
Co-authored-by: Sioni Summers <[email protected]>
  • Loading branch information
3 people committed Oct 7, 2021
1 parent fd8a618 commit f5acaca
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 16 deletions.
4 changes: 2 additions & 2 deletions hls4ml/converters/keras/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def parse_pooling_layer(keras_layer, input_names, input_shapes, data_reader, con

return layer, output_shape

pooling_layers = ['GlobalMaxPooling1D', 'GlobalMaxPooling2D', 'GlobalAveragePooling1D', 'GlobalAveragePooling2D']
@keras_handler(*pooling_layers)
global_pooling_layers = ['GlobalMaxPooling1D', 'GlobalMaxPooling2D', 'GlobalAveragePooling1D', 'GlobalAveragePooling2D']
@keras_handler(*global_pooling_layers)
def parse_global_pooling_layer(keras_layer, input_names, input_shapes, data_reader, config):
assert('Pooling' in keras_layer['class_name'])

Expand Down
15 changes: 10 additions & 5 deletions hls4ml/model/hls_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,7 @@ def config_cpp(self):
params['n_filt'] = self.get_output_variable().dim_names[1]
else:
params['n_in'] = self.get_input_variable().dim_names[1]
params['n_out'] = self.get_input_variable().dim_names[1]
params['n_out'] = self.get_output_variable().dim_names[1]
params['n_filt'] = self.get_output_variable().dim_names[0]

return self._config_template.format(**params)
Expand Down Expand Up @@ -1208,19 +1208,24 @@ def config_cpp(self):

class GlobalPooling1D(Layer):
def initialize(self):
shape = [self.attributes['n_out'], self.attributes['n_filt']]
dims = ['N_OUTPUTS_{}'.format(self.index), 'N_FILT_{}'.format(self.index)]
shape = [self.attributes['n_filt']]
dims = ['N_FILT_{}'.format(self.index)]
self.add_output_variable(shape, dims)
self.set_attr('pool_op', self.get_attr('class_name').split('Pooling')[0].replace('Global', ''))

def function_cpp(self):
params = self._default_function_params()

params['data_format'] = 'cf' if self.get_attr('data_format') == 'channels_first' else 'cl'
return [self._function_template.format(**params)]

def config_cpp(self):
params = self._default_config_params()
params['n_in'] = self.get_input_variable().size_cpp()
if self.get_attr('data_format') == 'channels_last':
params['n_in'] = self.get_input_variable().dim_names[0]
params['n_filt'] = self.get_input_variable().dim_names[1]
else:
params['n_in'] = self.get_input_variable().dim_names[1]
params['n_filt'] = self.get_input_variable().dim_names[0]

return self._config_template.format(**params)

Expand Down
1 change: 1 addition & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "ap_shift_reg.h"
#include "nnet_common.h"
#include "hls_stream.h"
#include "nnet_dense.h"

namespace nnet {

Expand Down
22 changes: 21 additions & 1 deletion hls4ml/templates/vivado/nnet_utils/nnet_pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ struct pooling1d_config{
// IO size
static const unsigned n_in = 10;
static const unsigned pool_width = 2;
static const unsigned n_out = n_in / pool_width;
static const unsigned stride_width = 2;
static const unsigned n_out = (n_in - pool_width) / stride_width + 1;
static const unsigned pad_left = 0;
static const unsigned pad_right = 0;
// Pooling function
Expand Down Expand Up @@ -141,6 +142,25 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
}
}

template<class data_T, class res_T, typename CONFIG_T>
void global_pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONFIG_T::n_filt]) {
assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0);
assert(CONFIG_T::pool_width == CONFIG_T::stride_width);

// TODO partition the arrays according to the reuse factor
const int limit = pool_op_limit_1d<CONFIG_T>();
#pragma HLS ALLOCATION instances=pool_op limit=limit function

for(int ff = 0; ff < CONFIG_T::n_filt; ff++) {
data_T pool[CONFIG_T::n_in];
for(int jj = 0; jj < CONFIG_T::n_in; jj++) {
pool[jj] = data[jj * CONFIG_T::n_filt + ff];
}
// do the pooling
res[ff] = pool_op<data_T, CONFIG_T::n_in, CONFIG_T::pool_op>(pool);
}
}

struct pooling2d_config{
// IO size
static const unsigned in_height = 10;
Expand Down
58 changes: 55 additions & 3 deletions hls4ml/templates/vivado/nnet_utils/nnet_pooling_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,6 @@ T reduce_global_pool(T x, T y[N]) {

template<class data_T, class res_T, typename CONFIG_T>
void compute_global_pool(
const unsigned h_idx,
const unsigned w_idx,
const data_T& in_elem,
typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt]
) {
Expand Down Expand Up @@ -516,7 +514,7 @@ void global_pooling2d_cl(
ReadInputHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) {
ReadInputWidth: for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_filt); i_iw++) {
#pragma HLS LOOP_FLATTEN
compute_global_pool<data_T, res_T, CONFIG_T>(i_ih, i_iw, data.read(), data_window);
compute_global_pool<data_T, res_T, CONFIG_T>(data.read(), data_window);
}
}

Expand Down Expand Up @@ -548,6 +546,60 @@ void global_pooling2d_cl(

}

template<class data_T, class res_T, typename CONFIG_T>
void global_pooling1d_cl(
hls::stream<data_T> &data,
hls::stream<res_T> &res
) {
assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0);
assert(CONFIG_T::pool_width == CONFIG_T::stride_width);

typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt];
#pragma HLS ARRAY_PARTITION variable=data_window complete

typename CONFIG_T::accum_t init = 0;
if (CONFIG_T::pool_op == Max) {
init = hls::numeric_limits<typename CONFIG_T::accum_t>::min();
}

PoolInitLoop: for (unsigned i_init = 0; i_init < CONFIG_T::n_filt; i_init++) {
#pragma HLS UNROLL
data_window[i_init] = init;
}

ReadInput: for (unsigned i_iw = 0; i_iw < CONFIG_T::n_in / (data_T::size / CONFIG_T::n_filt); i_iw++) {
#pragma HLS LOOP_FLATTEN
compute_global_pool<data_T, res_T, CONFIG_T>(data.read(), data_window);
}

if (CONFIG_T::pool_op == Max) {
MaxPoolRes: for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) {
#pragma HLS PIPELINE

res_T res_pack;
#pragma HLS DATA_PACK variable=res_pack
MaxPoolPack: for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
#pragma HLS UNROLL
res_pack[i_pack] = data_window[i_pack];
}
res.write(res_pack);
}
} else {
AvgPoolRes: for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) {
#pragma HLS PIPELINE

res_T res_pack;
#pragma HLS DATA_PACK variable=res_pack
AvgPoolPack: for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
#pragma HLS UNROLL
res_pack[i_pack] = data_window[i_pack] / CONFIG_T::n_in;
}
res.write(res_pack);
}
}

}

}

#endif
8 changes: 3 additions & 5 deletions hls4ml/templates/vivado_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,9 @@

global_pooling1d_config_template = """struct config{index} : nnet::pooling1d_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned pad_left = {pad_left};
static const unsigned pad_right = {pad_right};
static const unsigned stride = {stride};
static const unsigned n_filt = {n_filt};
static const nnet::Pool_Op pool_op = nnet::{pool_op};
static const unsigned reuse = {reuse};
typedef {accum_t} accum_t;
}};\n"""

Expand Down Expand Up @@ -358,7 +356,7 @@
param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});'
pooling1d_function_template = 'nnet::pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
pooling2d_function_template = 'nnet::pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
global_pooling1d_function_template = 'nnet::global_pooling1d<{input_t}, {config}>({input}, {output});'
global_pooling1d_function_template = 'nnet::global_pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
global_pooling2d_function_template = 'nnet::global_pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
zeropad1d_function_template = 'nnet::zeropad1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
zeropad2d_function_template = 'nnet::zeropad2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
Expand Down
57 changes: 57 additions & 0 deletions test/pytest/test_globalpooling1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D
import numpy as np
import hls4ml


in_shape = 8
in_feat = 4
atol = 5e-3

@pytest.fixture(scope='module')
def data():
X = np.random.rand(100, in_shape, in_feat)
return X


@pytest.fixture(scope='module')
def keras_model_max():
model = Sequential()
model.add(GlobalMaxPooling1D(input_shape=(in_shape, in_feat)))
model.compile()
return model

@pytest.fixture(scope='module')
def keras_model_ave():
model = Sequential()
model.add(GlobalAveragePooling1D(input_shape=(in_shape, in_feat)))
model.compile()
return model


@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
@pytest.mark.parametrize('model_type', ['max', 'ave'])
def test_global_pool1d(keras_model_max, keras_model_ave, data, model_type, io_type):
if model_type == 'ave':
model = keras_model_ave
else:
model = keras_model_max
config = hls4ml.utils.config_from_keras_model(model,
default_precision='ap_fixed<32,1>',
granularity='name')
if model_type == 'ave':
config['LayerName']['global_average_pooling1d']['accum_t'] = 'ap_fixed<32,6>'

hls_model = hls4ml.converters.convert_from_keras_model(model,
hls_config=config,
io_type=io_type,
output_dir=f'hls4mlprj_globalplool1d_{model_type}_{io_type}',
part='xcvu9p-flgb2104-2-i')
hls_model.compile()


# Predict
y_keras = np.squeeze(model.predict(data))
y_hls = hls_model.predict(data)
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)

0 comments on commit f5acaca

Please sign in to comment.