diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py
new file mode 100644
index 000000000..598c1c59d
--- /dev/null
+++ b/hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -0,0 +1,171 @@
+
+from hls4ml.backends.backend import get_backend
+from hls4ml.model.layers import LSTM, GRU
+from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
+
+# recurrent multiplication template
+
+recr_mult_config_template = """struct config{index} : nnet::dense_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+    static const unsigned strategy = nnet::{strategy};
+    static const unsigned reuse_factor = {reuse};
+    static const unsigned n_zeros = {nzeros};
+    static const unsigned n_nonzeros = {nonzeros};
+    static const bool store_weights_in_bram = false;
+    typedef {accum_t.name} accum_t;
+    typedef {bias_t.name} bias_t;
+    typedef {weight_t.name} weight_t;
+    typedef ap_{index_t} index_t;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n"""
+
+# activation templates
+
+activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    typedef ap_{table_t} table_t;
+}};\n"""
+
+recr_activ_config_template = """struct {type}_config{index}_recr : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    typedef ap_{table_t} table_t;
+}};\n"""
+
+# LSTM + GRU templates
+
+recr_config_template = """struct config{index} : nnet::{recr_type}_config {{
+    typedef {accum_t.name} accum_t;
+    typedef {weight_t.name} weight_t; // Matrix
+    typedef {bias_t.name} bias_t;     // Vector
+    typedef {config_mult_t1} mult_config1;
+    typedef {config_mult_t2} mult_config2;
+    typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE};
+    template<class x_T, class y_T, class config_T>
+    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
+    typedef {act_t} ACT_CONFIG_T;
+    template<class x_T, class y_T, class config_T>
+    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+    static const unsigned n_state = {n_state};
+    static const unsigned n_sequence = {n_sequence};
+    static const unsigned n_sequence_out = {n_sequence_out};
+    static const unsigned io_type = nnet::{strategy};
+    static const unsigned reuse_factor = {reuse};
+    static const bool store_weights_in_bram = false;
+    static const bool use_static = {static};
+}};\n"""
+
+recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+
+recr_include_list = ['nnet_utils/nnet_recurrent.h']
+
+class RecurrentConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__((LSTM, GRU))
+        self.template = recr_config_template
+        self.act_template = activ_config_template
+        self.recr_act_template = recr_activ_config_template
+        self.mult1_template = recr_mult_config_template
+        self.mult2_template = recr_mult_config_template
+
+    def format(self, node):
+
+        params = self._default_config_params(node)
+
+        params['n_in'] = node.get_input_variable().dim_names[1]
+        params['n_sequence'] = node.get_input_variable().dim_names[0]
+        if node.get_attr('return_sequences'):
+            params['n_sequence_out'] = node.get_output_variable().dim_names[0]
+            params['n_state'] = node.get_output_variable().dim_names[1]
+            params['n_out'] = node.get_output_variable().dim_names[1]
+        else:
+            params['n_sequence_out'] = 1
+            params['n_state'] = node.get_output_variable().dim_names[0]
+            params['n_out'] = node.get_output_variable().dim_names[0]
+        params['config_mult_t1'] = 'config{}_1'.format(node.index)
+        params['config_mult_t2'] = 'config{}_2'.format(node.index)
+        params['recr_act_t'] = '{}_config{}_recr'.format(node.get_attr('recurrent_activation'), node.index)
+        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
+        params['strategy'] = node.get_attr('strategy')
+        params['static'] = 'true' if node.attributes['static'] else 'false'
+        params['recr_type'] = node.class_name.lower()
+        params['RECR_TYPE'] = node.class_name
+
+        if node.class_name=='LSTM':
+            n_recr_mult = 4
+        else: #GRU
+            n_recr_mult = 3
+
+        recr_config = self.template.format(**params)
+
+        act_params = self._default_config_params(node)
+        recr_act_params = self._default_config_params(node)
+
+        act_params['type'] = node.get_attr('activation')
+        recr_act_params['type'] = node.get_attr('recurrent_activation')
+        if node.get_attr('return_sequences'):
+            act_params['n_in'] = node.get_output_variable().dim_names[1]
+            recr_act_params['n_in'] = node.get_output_variable().dim_names[1] + ' * %i'%(n_recr_mult-1)
+        else:
+            act_params['n_in'] = node.get_output_variable().dim_names[0]
+            recr_act_params['n_in'] = node.get_output_variable().dim_names[0] + ' * %i'%(n_recr_mult-1)
+
+        act_config = self.act_template.format(**act_params)
+        recr_act_config = self.recr_act_template.format(**recr_act_params)
+
+        mult_params1 = self._default_config_params(node)
+        mult_params2 = self._default_config_params(node)
+
+        mult_params1['n_in'] = node.get_input_variable().dim_names[1]
+        if node.get_attr('return_sequences'):
+            mult_params1['n_out'] = node.get_output_variable().dim_names[1] + ' * %i'%n_recr_mult
+        else:
+            mult_params1['n_out'] = node.get_output_variable().dim_names[0] + ' * %i'%n_recr_mult
+        mult_params1['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
+        mult_params1['reuse'] = params['reuse']
+        mult_params1['index'] = str(node.index) + '_1'
+        mult_params1['nzeros'] = node.get_weights('weight').nzeros
+        mult_params1['nonzeros'] = node.get_weights('weight').nonzeros
+        if node.get_attr('return_sequences'):
+            mult_params2['n_in'] = node.get_output_variable().dim_names[1]
+            mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * %i'%n_recr_mult
+        else:
+            mult_params2['n_in'] = node.get_output_variable().dim_names[0]
+            mult_params2['n_out'] = node.get_output_variable().dim_names[0] + ' * %i'%n_recr_mult
+        mult_params2['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
+        mult_params2['reuse'] = node.attributes['recurrent_reuse_factor']
+        mult_params2['index'] = str(node.index) + '_2'
+        mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros
+        mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros
+
+        mult_config1 = self.mult1_template.format(**mult_params1)
+        mult_config2 = self.mult2_template.format(**mult_params2)
+
+        return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config
+
+class RecurrentFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__((LSTM, GRU), include_header=recr_include_list)
+        self.template = recr_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['w'] = node.get_weights('weight').name
+        params['b'] = node.get_weights('bias').name
+        params['wr'] = node.get_weights('recurrent_weight').name
+        params['br'] = node.get_weights('recurrent_bias').name
+        params['activation'] = node.get_attr('activation')
+        params['recurrent_activation'] = node.get_attr('recurrent_activation')
+        params['recr_type'] = node.class_name.lower()
+
+        return self.template.format(**params)
+
diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py
index f423eeb49..9e41456f5 100644
--- a/hls4ml/backends/vivado/passes/resource_strategy.py
+++ b/hls4ml/backends/vivado/passes/resource_strategy.py
@@ -1,13 +1,13 @@
 import numpy as np
 
 from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D
+from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU
 
 class ApplyResourceStrategy(OptimizerPass):
     ''' Transposes the weights to use the dense_resource matrix multiply routine '''
     def match(self, node):
-        node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
+        node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU))
         is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource'
         already_transformed = node.get_attr('_weights_transposed', False) == True
 
@@ -26,9 +26,12 @@ def transform(self, model, node):
         elif isinstance(node, SeparableConv2D):
             node.weights['depthwise'].data = np.transpose(node.weights['depthwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
             node.weights['pointwise'].data = np.transpose(node.weights['pointwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
+        elif isinstance(node, (LSTM, GRU)):
+            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
+            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
         else:
             raise Exception('Unexpected layer {} with resource strategy'.format(node.class_name))
 
         node.set_attr('_weights_transposed', True)
 
-        return False
\ No newline at end of file
+        return False
diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py
index 0140e55c1..6fb1cc247 100644
--- a/hls4ml/backends/vivado/vivado_backend.py
+++ b/hls4ml/backends/vivado/vivado_backend.py
@@ -23,9 +23,9 @@ def __init__(self):
 
     def _register_layer_attributes(self):
         extended_attrs = {
-            SimpleRNN: [Attribute('recurrent_reuse_factor', default=1)],
-            LSTM: [Attribute('recurrent_reuse_factor', default=1)],
-            GRU: [Attribute('recurrent_reuse_factor', default=1)],
+            SimpleRNN: [Attribute('recurrent_reuse_factor', default=1), Attribute('static', value_type=bool, default=True)],
+            LSTM: [Attribute('recurrent_reuse_factor', default=1), Attribute('static', value_type=bool, default=True)],
+            GRU: [Attribute('recurrent_reuse_factor', default=1), Attribute('static', value_type=bool, default=True)],
         }
         self.attribute_map.update(extended_attrs)
 
diff --git a/hls4ml/model/optimizer/passes/multi_dense.py b/hls4ml/model/optimizer/passes/multi_dense.py
index 9ecc97f84..d2eb80ef7 100644
--- a/hls4ml/model/optimizer/passes/multi_dense.py
+++ b/hls4ml/model/optimizer/passes/multi_dense.py
@@ -5,7 +5,10 @@ class ReplaceMultidimensionalDenseWithConv(OptimizerPass):
     def match(self, node):
         return isinstance(node, Dense) and \
-            len(node.get_input_variable().shape) > 1
+            len(node.get_input_variable().shape) - sum(d==1 for d in node.get_input_variable().shape) > 1
+            # The sum above counts the dimensions of the Dense input that have size 1
+            # Subtracting it from the total means only dimensions with size > 1 are counted,
+            # so, for example, a Dense layer with input shape (1, N) is not matched
 
     def transform(self, model, node):
         dim = len(node.get_input_variable().shape) - 1
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_recr_activations.h b/hls4ml/templates/vivado/nnet_utils/nnet_recr_activations.h
new file mode 100644
index 000000000..487dbebd3
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_recr_activations.h
@@ -0,0 +1,63 @@
+#ifndef NNET_RECR_ACTIVATION_H_
+#define NNET_RECR_ACTIVATION_H_
+
+#include "nnet_common.h"
+#include "nnet_helpers.h"
+#include "nnet_activation.h"
+#include "hls_stream.h"
+#include <math.h>
+
+namespace nnet {
+
+namespace activation{
+
+template<class data_T, class res_T, typename CONFIG_T>
+class Activation{
+  public:
+    // *************************************************
+    //       Blank Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {} // Nothing to do here
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+class relu : public Activation<data_T, res_T, CONFIG_T>{
+  public:
+    // *************************************************
+    //       Relu Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
+    {
+        nnet::relu<data_T, res_T, CONFIG_T>(data, res);
+    }
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+class sigmoid : public Activation<data_T, res_T, CONFIG_T>{
+  public:
+    // *************************************************
+    //       Sigmoid Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
+    {
+        nnet::sigmoid<data_T, res_T, CONFIG_T>(data, res);
+    }
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+class tanh : public Activation<data_T, res_T, CONFIG_T>{
+  public:
+    // *************************************************
+    //       TanH Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
+    {
+        nnet::tanh<data_T, res_T, CONFIG_T>(data, res);
+    }
+};
+
+}
+
+}
+
+#endif
\ No newline at end of file
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h
new file mode 100644
index 000000000..a7096dde1
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h
@@ -0,0 +1,593 @@
+
+//
+// rfnoc-hls-neuralnet: Vivado HLS code for neural-net building blocks
+//
+// Copyright (C) 2017 EJ Kreinar
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+// + +#ifndef NNET_RECURSIVE_H_ +#define NNET_RECURSIVE_H_ + +#include "nnet_common.h" +#include "nnet_activation.h" +#include "nnet_recr_activations.h" +#include "nnet_dense.h" +#include "hls_stream.h" + + +namespace nnet { + +struct lstm_config +{ + // Internal data type definitions + typedef float weight_t; + typedef float bias_t; + + // Layer Sizes + static const unsigned n_in = 2; + static const unsigned n_parts = 20; + static const unsigned n_out = 2; + static const unsigned n_state = 2; + static const unsigned n_4state = 8; + static const unsigned table_size = 1024; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const unsigned n_zeros = 0; + static const bool store_weights_in_bram = false; + static const bool use_static = true; + + template + using activation_recr = nnet::activation::relu; + template + using activation = nnet::activation::relu; +}; +// Long Short term Memory NN (LSTM) +// Resources: +// https://github.com/nicodjimenez/lstm/blob/master/lstm.py +// https://github.com/llSourcell/LSTM_Networks/blob/master/LSTM%20Demo.ipynb +// https://en.wikipedia.org/wiki/Long_short-term_memory +// Notes: +// - LSTM naming conventions adopted from the above links +// - s_newstate = activation(U*input + W*state) +// - h_output = activation(U*input + W*state)*activation(s_newstate) +// - If softmax is needed on output, perform *outside* this operations +// Originall had a version allows for the state in each layer to be saved, moved this to above (this requires are LARGE dense network at the end) +template + void lstm(bool reset_state, + data_T data [CONFIG_T::n_in], + res_T h_newstate[CONFIG_T::n_state], + res_T s_newstate[CONFIG_T::n_state], + typename CONFIG_T::weight_t param [CONFIG_T::n_state*4*CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state*4*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state*4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state*4] + ) { + // Initialize the state variable -- will maintain state between function calls + + typename CONFIG_T::accum_t tmpres [CONFIG_T::n_state*4]; + typename CONFIG_T::accum_t tmpres_state[CONFIG_T::n_state*4]; + typename CONFIG_T::accum_t tmpres_ifo [CONFIG_T::n_state*3]; //activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_c [CONFIG_T::n_state]; //activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_ifo[CONFIG_T::n_state*3]; //i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_c [CONFIG_T::n_state]; //c-matrix (keras notation) + typename CONFIG_T::accum_t s_actstate[CONFIG_T::n_state]; + + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + #pragma HLS ARRAY_PARTITION variable=s_newstate complete + #pragma HLS ARRAY_PARTITION variable=tmpres complete + #pragma HLS ARRAY_PARTITION variable=tmpres_state complete + #pragma HLS ARRAY_PARTITION variable=tmpres_ifo complete + #pragma HLS ARRAY_PARTITION variable=tmpres_c complete + #pragma HLS ARRAY_PARTITION variable=inputacc_ifo complete + #pragma HLS ARRAY_PARTITION variable=inputacc_c complete + #pragma HLS ARRAY_PARTITION variable=s_actstate complete + + nnet::dense(data ,tmpres , param,param_b); + nnet::dense(h_newstate,tmpres_state, param_r, param_br); + + for(int iacc = 0; iacc < (3*CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc; + if(iacc > 2*CONFIG_T::n_state-1) index = iacc + CONFIG_T::n_state; + inputacc_ifo[iacc] = tmpres[index] + tmpres_state[index]; + 
} + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state*2; + inputacc_c[iacc] = tmpres[index] + tmpres_state[index]; + } + + CONFIG_T::template activation_recr::activation(inputacc_ifo, tmpres_ifo); + + //Now for the confusion matrix + CONFIG_T::template activation::activation(inputacc_c, tmpres_c); + + // Operation: s=g*i+sold*f (update state with buffer to avoid timing issues) + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { +#pragma HLS UNROLL + s_newstate[iacc] = tmpres_c[iacc]*tmpres_ifo[iacc] + s_newstate[iacc]*tmpres_ifo[iacc+(CONFIG_T::n_state)]; + } + // Operation: h=act(s)*o + CONFIG_T::template activation::activation(s_newstate, s_actstate); + + for(int iacc = 0; iacc < CONFIG_T::n_state; iacc++) { +#pragma HLS UNROLL + h_newstate[iacc] = tmpres_ifo[iacc+2*(CONFIG_T::n_state)]*s_actstate[iacc]; + } +} + +template + void lstm_static(bool reset_state, + data_T data [CONFIG_T::n_in], + res_T h_newstate[CONFIG_T::n_state], + res_T s_newstate[CONFIG_T::n_state], + typename CONFIG_T::weight_t param [CONFIG_T::n_state*4*CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state*4*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state*4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state*4] + ) { + static res_T h_state[CONFIG_T::n_state]; + static res_T s_state[CONFIG_T::n_state]; + // Initialize the state variable -- will maintain state between function calls + typename CONFIG_T::accum_t tmpres [CONFIG_T::n_state*4]; + typename CONFIG_T::accum_t tmpres_state[CONFIG_T::n_state*4]; + typename CONFIG_T::accum_t tmpres_ifo [CONFIG_T::n_state*3]; //activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_c [CONFIG_T::n_state]; //activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_ifo[CONFIG_T::n_state*3]; //i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_c [CONFIG_T::n_state]; //c-matrix (keras notation) + typename CONFIG_T::accum_t s_actstate[CONFIG_T::n_state]; + + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + #pragma HLS ARRAY_PARTITION variable=s_newstate complete + #pragma HLS ARRAY_PARTITION variable=h_state complete + #pragma HLS ARRAY_PARTITION variable=s_state complete + #pragma HLS ARRAY_PARTITION variable=tmpres complete + #pragma HLS ARRAY_PARTITION variable=tmpres_state complete + #pragma HLS ARRAY_PARTITION variable=tmpres_ifo complete + #pragma HLS ARRAY_PARTITION variable=tmpres_c complete + #pragma HLS ARRAY_PARTITION variable=inputacc_ifo complete + #pragma HLS ARRAY_PARTITION variable=inputacc_c complete + #pragma HLS ARRAY_PARTITION variable=s_actstate complete + + if(reset_state){ + for(int i_state = 0; i_state < (CONFIG_T::n_state); i_state++) { + #pragma HLS UNROLL + s_state[i_state] = 0; + h_state[i_state] = 0; + } + } + + nnet::dense(data ,tmpres , param,param_b); + nnet::dense(h_state,tmpres_state, param_r, param_br); + + for(int iacc = 0; iacc < (3*CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc; + if(iacc > 2*CONFIG_T::n_state-1) index = iacc + CONFIG_T::n_state; + inputacc_ifo[iacc] = tmpres[index] + tmpres_state[index]; + } + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state*2; + inputacc_c[iacc] = tmpres[index] + tmpres_state[index]; + } + + CONFIG_T::template activation_recr::activation(inputacc_ifo, tmpres_ifo); + + //Now for the confusion matrix + CONFIG_T::template 
activation::activation(inputacc_c, tmpres_c); + + // Operation: s=g*i+sold*f (update state with buffer to avoid timing issues) + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + s_state[iacc] = tmpres_c[iacc]*tmpres_ifo[iacc] + s_state[iacc]*tmpres_ifo[iacc+(CONFIG_T::n_state)]; + s_newstate[iacc] = s_state[iacc]; + } + // Operation: h=act(s)*o + CONFIG_T::template activation::activation(s_state, s_actstate); + + for(int iacc = 0; iacc < CONFIG_T::n_state; iacc++) { +#pragma HLS UNROLL + h_state[iacc] = tmpres_ifo[iacc+2*(CONFIG_T::n_state)]*s_actstate[iacc]; + h_newstate[iacc] = h_state[iacc]; + } +} + +template + void lstm_stack( + data_T data [CONFIG_T::n_sequence*CONFIG_T::n_in], + res_T res [CONFIG_T::n_sequence_out*CONFIG_T::n_state], + typename CONFIG_T::weight_t param [CONFIG_T::n_state*4*CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state*4*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state*4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state*4] + ) { + + res_T h_newstate[CONFIG_T::n_state]; + res_T s_newstate[CONFIG_T::n_state]; + data_T data_in[CONFIG_T::n_in]; + bool reset_state = true; + + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + #pragma HLS ARRAY_PARTITION variable=s_newstate complete + + for(int ii = 0; ii < CONFIG_T::n_state; ii++) { + #pragma HLS UNROLL + h_newstate[ii] = 0; + s_newstate[ii] = 0; + } + for(int iloop = 0; iloop < CONFIG_T::n_sequence; iloop++) { + for(int j = 0; j < CONFIG_T::n_in; j++) { + #pragma HLS UNROLL + data_in[j] = data[j + iloop*CONFIG_T::n_in]; + } + if (CONFIG_T::use_static) + nnet::lstm_static(reset_state,data_in,h_newstate, s_newstate, param,param_r,param_b, param_br); + else + nnet::lstm(reset_state,data_in,h_newstate, s_newstate, param,param_r,param_b, param_br); + if (CONFIG_T::n_sequence_out > 1) + for(int i=CONFIG_T::n_state*iloop, j=0; i<(CONFIG_T::n_state*(iloop+1)); i++,j++){ + #pragma HLS UNROLL + res[i] = h_newstate[j]; + } + reset_state = false; + } + if (CONFIG_T::n_sequence_out == 1) + for(int i=0; i<(CONFIG_T::n_state); i++){ + #pragma HLS UNROLL + res[i] = h_newstate[i]; + } +} + +template + void lstm_stack( + hls::stream &data_stream, + hls::stream &res_stream, + typename CONFIG_T::weight_t param [CONFIG_T::n_state*4*CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state*4*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state*4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state*4] + ) { + + typename res_T::value_type h_newstate[CONFIG_T::n_state]; + typename res_T::value_type s_newstate[CONFIG_T::n_state]; + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + #pragma HLS ARRAY_PARTITION variable=s_newstate complete + + for(int ii = 0; ii < CONFIG_T::n_state; ii++) { + #pragma HLS UNROLL + h_newstate[ii] = 0; + s_newstate[ii] = 0; + } + + typename data_T::value_type data_in[CONFIG_T::n_in]; + bool reset_state = true; + + DataPropagation: for(int i_in = 0; i_in < CONFIG_T::n_sequence*CONFIG_T::n_in / data_T::size; i_in++) { + if (CONFIG_T::n_sequence*CONFIG_T::n_in / data_T::size > 1) { + // #pragma HLS PIPELINE + } + data_T data_pack = data_stream.read(); + DataPack: for (int i_pack = 0; i_pack < data_T::size; i_pack++) { + #pragma HLS UNROLL + data_in[i_pack] = data_pack[i_pack]; + } + if (CONFIG_T::use_static) + nnet::lstm_static(reset_state,data_in,h_newstate, param,param_r,param_b, param_br); + else + nnet::lstm(reset_state,data_in,h_newstate, s_newstate, param,param_r,param_b, 
param_br); + if (CONFIG_T::n_sequence_out > 1){ + res_T res_pack; + #pragma HLS DATA_PACK variable=res_pack + ResPack_sequences: for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + #pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } + reset_state = false; + } + + if (CONFIG_T::n_sequence_out == 1){ + res_T res_pack; + #pragma HLS DATA_PACK variable=res_pack + ResPack: for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + #pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } + +} + +// Struct for the GRU template + +struct gru_config +{ + // Internal data type definitions + typedef float weight_t; + typedef float bias_t; + typedef float accum_t; + + // Layer Sizes + static const unsigned n_in = 2; + static const unsigned n_out = 2; + static const unsigned n_state = 2; + static const unsigned n_sequence = 2; + static const unsigned n_4state = 8; + static const unsigned table_size = 1024; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const bool use_static = true; + static const unsigned n_zeros = 0; + + template + using activation_recr = nnet::activation::relu; + template + using activation = nnet::activation::relu; +}; + +template + void gru(bool reset_state, + data_T data [CONFIG_T::n_in], + res_T h_newstate[CONFIG_T::n_state], + typename CONFIG_T::weight_t param [CONFIG_T::n_state*3*CONFIG_T::n_in], // TODO - Check the layout of the param weights - refer page in copy!! + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state*3*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b [CONFIG_T::n_state*3], + typename CONFIG_T::bias_t param_br [CONFIG_T::n_state*3] + ) { + // Initialize the state variable -- will maintain state between function calls + typename CONFIG_T::accum_t tmpres [CONFIG_T::n_state*3]; + typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state*3]; + typename CONFIG_T::accum_t tmpres_state_h [CONFIG_T::n_state]; + typename CONFIG_T::accum_t tmpres_zr [CONFIG_T::n_state*2]; //activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_h [CONFIG_T::n_state]; //activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_zr [CONFIG_T::n_state*2]; //i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_h [CONFIG_T::n_state]; //c-matrix (keras notation) + + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + #pragma HLS ARRAY_PARTITION variable=tmpres complete + #pragma HLS ARRAY_PARTITION variable=tmpres_state_zr complete + #pragma HLS ARRAY_PARTITION variable=tmpres_state_h complete + #pragma HLS ARRAY_PARTITION variable=tmpres_zr complete + #pragma HLS ARRAY_PARTITION variable=tmpres_h complete + #pragma HLS ARRAY_PARTITION variable=inputacc_zr complete + #pragma HLS ARRAY_PARTITION variable=inputacc_h complete + + nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_newstate, tmpres_state_zr, param_zr, param_br); + + // Adding the individual vectors from the multiplication of tmpres = Wx*x(t); tmpres_state_zr = Wh*h(t-1); tmpres initialized with biases -- DONE + for(int iacc = 0; iacc < (2*CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc; + inputacc_zr[iacc] = tmpres[index] + tmpres_state_zr[index]; + } + + // Activation function Sub layer -- START + CONFIG_T::template activation_recr::activation(inputacc_zr, tmpres_zr); + + // Activation function Sub layer -- END 
+ + // Hadamrd product of r(t) = inputacc_zr[2*n_state:n_state] and h(t-1) = h_newstate + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + tmpres_state_h[iacc] = tmpres_zr[iacc+(CONFIG_T::n_state)]*tmpres_state_zr[iacc + (2*CONFIG_T::n_state)]; + } + + //Assuming reset_after is false + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state*2; + inputacc_h[iacc] = tmpres[index] + tmpres_state_h[iacc]; + } + + //Now run the activation on this guy + CONFIG_T::template activation::activation(inputacc_h, tmpres_h); + + //Mix the stat with the previous state + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + h_newstate[iacc] = (res_T)(tmpres_h[iacc]*(1-tmpres_zr[iacc]) + h_newstate[iacc]*tmpres_zr[iacc]); + } +} + +template + void gru_static(bool reset_state, + data_T data [CONFIG_T::n_in], + res_T h_newstate[CONFIG_T::n_state], + typename CONFIG_T::weight_t param [CONFIG_T::n_state*3*CONFIG_T::n_in], + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state*3*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b [CONFIG_T::n_state*3], + typename CONFIG_T::bias_t param_br [CONFIG_T::n_state*3] + ) { + // Initialize the state variable -- will maintain state between function calls + + static res_T h_state[CONFIG_T::n_state]; + typename CONFIG_T::accum_t tmpres [CONFIG_T::n_state*3]; + typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state*3]; + typename CONFIG_T::accum_t tmpres_state_h [CONFIG_T::n_state]; + typename CONFIG_T::accum_t tmpres_zr [CONFIG_T::n_state*2]; //activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_h [CONFIG_T::n_state]; //activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_zr [CONFIG_T::n_state*2]; //i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_h [CONFIG_T::n_state]; //c-matrix (keras notation) + + #pragma HLS ARRAY_PARTITION variable=h_state complete + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + #pragma HLS ARRAY_PARTITION variable=tmpres complete + #pragma HLS ARRAY_PARTITION variable=tmpres_state_zr complete + #pragma HLS ARRAY_PARTITION variable=tmpres_state_h complete + #pragma HLS ARRAY_PARTITION variable=tmpres_zr complete + #pragma HLS ARRAY_PARTITION variable=tmpres_h complete + #pragma HLS ARRAY_PARTITION variable=inputacc_zr complete + #pragma HLS ARRAY_PARTITION variable=inputacc_h complete + + if(reset_state){ + for(int i_h_state = 0; i_h_state < (CONFIG_T::n_state); i_h_state++) { + #pragma HLS UNROLL + h_state[i_h_state] = 0; + } + } + + nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_state, tmpres_state_zr, param_zr, param_br); + + // Adding the individual vectors from the multiplication of tmpres = Wx*x(t); tmpres_state_zr = Wh*h(t-1); tmpres initialized with biases -- DONE + for(int iacc = 0; iacc < (2*CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc; + inputacc_zr[iacc] = tmpres[index] + tmpres_state_zr[index]; + } + + // Activation function Sub layer -- START + CONFIG_T::template activation_recr::activation(inputacc_zr, tmpres_zr); + + // Activation function Sub layer -- END + + // Hadamrd product of r(t) = inputacc_zr[2*n_state:n_state] and h(t-1) = h_newstate + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + tmpres_state_h[iacc] = tmpres_zr[iacc+(CONFIG_T::n_state)]*tmpres_state_zr[iacc + (2*CONFIG_T::n_state)]; + } + + //Assuming reset_after is false + for(int iacc = 0; iacc < 
(CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state*2; + inputacc_h[iacc] = tmpres[index] + tmpres_state_h[iacc]; + } + + //Now run the activation on this guy + CONFIG_T::template activation::activation(inputacc_h, tmpres_h); + + //Mix the stat with the previous state + for(int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + #pragma HLS UNROLL + h_state[iacc] = (res_T)(tmpres_h[iacc]*(1-tmpres_zr[iacc]) + h_state[iacc]*tmpres_zr[iacc]); + h_newstate[iacc] = h_state[iacc]; + } +} + +template + void gru_stack( + data_T data [CONFIG_T::n_sequence*CONFIG_T::n_in], + res_T res[CONFIG_T::n_sequence_out*CONFIG_T::n_state], + typename CONFIG_T::weight_t param [CONFIG_T::n_state*3*CONFIG_T::n_in], + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state*3*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b [CONFIG_T::n_state*3], + typename CONFIG_T::bias_t param_br [CONFIG_T::n_state*3] + ) { + + res_T h_state[CONFIG_T::n_state]; + data_T data_in[CONFIG_T::n_in]; + bool reset_state = true; + + #pragma HLS ARRAY_PARTITION variable=h_state complete + #pragma HLS ARRAY_PARTITION variable=data_in complete + + for(int ii = 0; ii < CONFIG_T::n_state; ii++) { + #pragma HLS UNROLL + h_state[ii] = 0; + } + for(int iloop = 0; iloop < CONFIG_T::n_sequence; iloop++) { + for(int j = 0; j < CONFIG_T::n_in; j++) { + #pragma HLS UNROLL + data_in[j] = data[j + iloop*CONFIG_T::n_in]; + } + if (CONFIG_T::use_static) + nnet::gru_static(reset_state,data_in,h_state,param,param_zr,param_b, param_br); + else + nnet::gru(reset_state,data_in,h_state,param,param_zr,param_b, param_br); + if (CONFIG_T::n_sequence_out > 1) + for(int i=CONFIG_T::n_state*iloop, j=0; i<(CONFIG_T::n_state*(iloop+1)); i++,j++){ + #pragma HLS UNROLL + res[i] = h_state[j]; + } + reset_state = false; + } + if (CONFIG_T::n_sequence_out == 1) + for(int i=0; i<(CONFIG_T::n_state); i++){ + #pragma HLS UNROLL + res[i] = h_state[i]; + } + } + +template + void gru_stack( + hls::stream &data_stream, + hls::stream &res_stream, + typename CONFIG_T::weight_t param [CONFIG_T::n_state*3*CONFIG_T::n_in], + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state*3*CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b [CONFIG_T::n_state*3], + typename CONFIG_T::bias_t param_br [CONFIG_T::n_state*3] + ) { + + typename res_T::value_type h_newstate[CONFIG_T::n_state]; + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + for(int ii = 0; ii < CONFIG_T::n_state; ii++) { + #pragma HLS UNROLL + h_newstate[ii] = 0; + } + + typename data_T::value_type data_in[CONFIG_T::n_in]; + bool reset_state = true; + + DataPropagation: for(int i_in = 0; i_in < CONFIG_T::n_sequence*CONFIG_T::n_in / data_T::size; i_in++) { + if (CONFIG_T::n_sequence*CONFIG_T::n_in / data_T::size > 1) { + // #pragma HLS PIPELINE + } + data_T data_pack = data_stream.read(); + DataPack: for (int i_pack = 0; i_pack < data_T::size; i_pack++) { + #pragma HLS UNROLL + data_in[i_pack] = data_pack[i_pack]; + } + if (CONFIG_T::use_static) + nnet::gru_static(reset_state,data_in,h_newstate,param,param_zr,param_b, param_br); + else + nnet::gru(reset_state,data_in,h_newstate,param,param_zr,param_b, param_br); + if (CONFIG_T::n_sequence_out > 1){ + res_T res_pack; + #pragma HLS DATA_PACK variable=res_pack + ResPack_sequences: for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + #pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } + reset_state = false; + } + + if (CONFIG_T::n_sequence_out == 1){ + res_T res_pack; + #pragma 
HLS DATA_PACK variable=res_pack + ResPack: for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + #pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } + +} + + +}//end namespace + +#endif
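
A minimal sketch of how the LSTM/GRU support added by this patch might be exercised from the Python side, assuming TensorFlow/Keras and hls4ml are installed and Vivado is targeted. The layer name 'lstm' and the commented-out per-layer 'static' key are illustrative assumptions, not part of the patch; only config_from_keras_model, convert_from_keras_model, compile and predict are standard hls4ml API calls.

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM

import hls4ml

# Toy single-layer LSTM: 5 time steps, 3 features, 8 hidden units
model = Sequential()
model.add(LSTM(8, input_shape=(5, 3), return_sequences=False))

# Per-layer config; reuse factors and precisions feed the recurrent templates above
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
# config['LayerName']['lstm']['static'] = False  # assumed knob to select the non-static kernel

hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, output_dir='hls_lstm_prj', backend='Vivado')
hls_model.compile()  # builds the C-simulation library that includes nnet_recurrent.h

# Compare Keras and HLS emulation outputs on one random sequence
x = np.random.rand(1, 5, 3).astype(np.float32)
print(model.predict(x))
print(hls_model.predict(x))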