diff --git a/hls4ml/backends/quartus/passes/recurrent_templates.py b/hls4ml/backends/quartus/passes/recurrent_templates.py
new file mode 100644
index 000000000..b37fc1f03
--- /dev/null
+++ b/hls4ml/backends/quartus/passes/recurrent_templates.py
@@ -0,0 +1,133 @@
+from hls4ml.backends.backend import get_backend
+from hls4ml.model.layers import GRU
+from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
+
+recurrent_include_list = ['nnet_utils/nnet_recurrent.h', 'nnet_utils/nnet_recurrent_stream.h']
+
+# Shared Matrix Multiplication Template (Dense)
+recr_mult_config_template = '''struct config{index}_mult : nnet::dense_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+
+    static const unsigned rf_pad = {rfpad};
+    static const unsigned bf_pad = {bfpad};
+    static const unsigned reuse_factor = {reuse};
+    static const unsigned reuse_factor_rounded = reuse_factor + rf_pad;
+    static const unsigned block_factor = DIV_ROUNDUP(n_in*n_out, reuse_factor);
+    static const unsigned block_factor_rounded = block_factor + bf_pad;
+    static const unsigned multiplier_factor = MIN(n_in, reuse_factor);
+    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in*n_out, multiplier_factor);
+    static const unsigned multiplier_scale = multiplier_limit/n_out;
+    typedef {accum_t.name} accum_t;
+    typedef {bias_t.name} bias_t;
+    typedef {weight_t.name} weight_t;
+
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n'''
+
+# Activation Template
+activ_config_template = '''struct {type}_config{index} : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+}};\n'''
+
+# GRU Template
+gru_config_template = '''struct config{index} : nnet::gru_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+    static const unsigned n_units = {n_units};
+    static const unsigned n_timesteps = {n_timesteps};
+    static const unsigned n_outputs = {n_outputs};
+    static const bool return_sequences = {return_sequences};
+
+    typedef {accum_t.name} accum_t;
+    typedef {weight_t.name} weight_t;
+    typedef {bias_t.name} bias_t;
+
+    typedef {config_mult_x} mult_config_x;
+    typedef {config_mult_h} mult_config_h;
+
+    typedef {act_t} ACT_CONFIG_T;
+    template<class x_T, class y_T, class config_T>
+    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
+
+    typedef {act_recurrent_t} ACT_CONFIG_RECURRENT_T;
+    template<class x_T, class y_T, class config_T>
+    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
+
+    static const unsigned reuse_factor = {reuse};
+    static const bool store_weights_in_bram = false;
+}};\n'''
+
+gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+
+class GRUConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(GRU)
+        self.gru_template = gru_config_template
+        self.act_template = activ_config_template
+        self.recr_act_template = activ_config_template
+        self.mult_x_template = recr_mult_config_template
+        self.mult_h_template = recr_mult_config_template
+
+    def format(self, node):
+        # Input has shape (n_timesteps, inp_dimensionality)
+        # Output / hidden units has shape (1 if !return_sequences else n_timesteps, n_units)
+        params = self._default_config_params(node)
+        params['n_units'] = node.get_attr('n_out')
+        params['n_outputs'] = node.get_attr('n_timesteps') if node.get_attr('return_sequences', False) else '1'
+        params['return_sequences'] = 'true' if node.get_attr('return_sequences', False) else 'false'
+        params['config_mult_x'] = 'config{}_x_mult'.format(node.index)
+        params['config_mult_h'] = 'config{}_h_mult'.format(node.index)
+        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
+        params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        gru_config = self.gru_template.format(**params)
+
+        # Activation is on the candidate hidden state, dimensionality (1, n_units)
+        act_params = self._default_config_params(node)
+        act_params['type'] = node.get_attr('activation')
+        act_params['n_in'] = node.get_attr('n_out')
+        act_params['index'] = str(node.index) + '_act'
+        act_config = self.act_template.format(**act_params)
+
+        # Recurrent activation is on the reset and update gates (therefore x2), dimensionality (1, n_units)
+        recr_act_params = self._default_config_params(node)
+        recr_act_params['type'] = node.get_attr('recurrent_activation')
+        recr_act_params['n_in'] = str(node.get_attr('n_out')) + ' * 2'
+        recr_act_params['index'] = str(node.index) + '_rec_act'
+        recr_act_config = self.recr_act_template.format(**recr_act_params)
+
+        # Multiplication config for matrix multiplications of type Wx (reset, update and candidate states)
+        mult_params_x = self._default_config_params(node)
+        mult_params_x['n_in'] = node.get_attr('n_in')
+        mult_params_x['n_out'] = str(node.get_attr('n_out')) + ' * 3'
+        mult_params_x['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
+        mult_params_x['index'] = str(node.index) + '_x'
+        mult_config_x = self.mult_x_template.format(**mult_params_x)
+
+        # Multiplication config for matrix multiplications of type Wh (reset, update and candidate states)
+        mult_params_h = self._default_config_params(node)
+        mult_params_h['n_in'] = node.get_attr('n_out')
+        mult_params_h['n_out'] = str(node.get_attr('n_out')) + ' * 3'
+        mult_params_h['reuse_factor'] = params['recurrent_reuse_factor']
+        mult_params_h['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
+        mult_params_h['index'] = str(node.index) + '_h'
+        mult_config_h = self.mult_h_template.format(**mult_params_h)
+
+        return mult_config_x + '\n' + mult_config_h + '\n' + recr_act_config + '\n' + act_config + '\n' + gru_config
+
+class GRUFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(GRU, include_header=recurrent_include_list)
+        self.template = gru_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['w'] = node.get_weights('weight').name
+        params['b'] = node.get_weights('bias').name
+        params['wr'] = node.get_weights('recurrent_weight').name
+        params['br'] = node.get_weights('recurrent_bias').name
+        return self.template.format(**params)
diff --git a/hls4ml/backends/quartus/passes/resource_strategy.py b/hls4ml/backends/quartus/passes/resource_strategy.py
new file mode 100644
index 000000000..797eb3352
--- /dev/null
+++ b/hls4ml/backends/quartus/passes/resource_strategy.py
@@ -0,0 +1,46 @@
+import numpy as np
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.layers import Dense, GRU
+
+class ApplyResourceStrategy(OptimizerPass):
+    ''' Transposes the weights to use the dense_resource matrix multiply routine '''
+    def match(self, node):
+        node_matches = isinstance(node, (Dense, GRU))
+        is_resource_strategy = True  # node.get_attr('strategy', '').lower() == 'resource' ... Quartus only supports resource strategy
+        already_transformed = node.get_attr('_weights_transposed', False) == True
+        return node_matches and is_resource_strategy and not already_transformed
+
+    def transform(self, model, node):
+        if isinstance(node, Dense) and not node.model.config.get_compression(node):
+            rf = node.get_attr('reuse_factor')
+            bf = int((node.attributes['n_in']*node.attributes['n_out'])/rf)
+            bf_rounded = int(pow(2, np.ceil(np.log2(bf))))
+            rf_rounded = int(pow(2, np.ceil(np.log2(rf))))
+
+            node.weights['weight'].data = np.transpose(node.weights['weight'].data).flatten()
+
+            if node.attributes['n_in']*node.attributes['n_out'] > 2048 and rf_rounded != rf:
+                node.set_attr('rfpad', rf_rounded-rf)
+                node.set_attr('bfpad', bf_rounded-bf)
+
+                temp = np.empty([bf_rounded, rf_rounded])
+                for i in range(rf_rounded):
+                    for j in range(bf_rounded):
+                        if i < rf and j < bf:
+                            w_index = i + rf * j
+                            temp[j][i] = node.weights['weight'].data[w_index]
+                        else:
+                            temp[j][i] = 0
+                node.weights['weight'].data = temp.flatten()
+                node.weights['weight'].data_length = node.weights['weight'].data.size
+
+        elif isinstance(node, GRU):
+            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
+            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
+
+        else:
+            raise Exception('Unexpected layer {} with resource strategy'.format(node.class_name))
+
+        node.set_attr('_weights_transposed', True)
+        return False
+
diff --git a/hls4ml/backends/quartus/quartus_backend.py b/hls4ml/backends/quartus/quartus_backend.py
index ab79bd820..a02be6fb6 100644
--- a/hls4ml/backends/quartus/quartus_backend.py
+++ b/hls4ml/backends/quartus/quartus_backend.py
@@ -1,12 +1,14 @@
-import numpy as np
 import os
+from hls4ml.model.attributes import Attribute
+import numpy as np
 from contextlib import contextmanager
+
+from hls4ml.backends import FPGABackend
 from hls4ml.model.types import NamedType, IntegerPrecisionType, FixedPrecisionType
-from hls4ml.model.layers import Layer, Dense, Activation, Softmax, Embedding
-from hls4ml.model.optimizer import get_backend_passes, layer_optimizer
+from hls4ml.model.layers import Embedding, Layer, Dense, Activation, Softmax, GRU
 from hls4ml.model.flow import register_flow
-from hls4ml.backends import FPGABackend
 from hls4ml.report import parse_quartus_report
+from hls4ml.model.optimizer import get_backend_passes, layer_optimizer
 
 @contextmanager
 def chdir(newdir):
@@ -20,8 +22,16 @@ def chdir(newdir):
 class QuartusBackend(FPGABackend):
     def __init__(self):
         super(QuartusBackend, self).__init__('Quartus')
+        self._register_layer_attributes()
         self._register_flows()
 
+    def _register_layer_attributes(self):
+        extended_attrs = {
+            GRU: [Attribute('recurrent_reuse_factor', default=1)],
+        }
+        self.attribute_map.update(extended_attrs)
+
+
     def _register_flows(self):
         initializers = self._get_layer_initializers()
         init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
@@ -33,6 +43,7 @@ def _register_flows(self):
 
         quartus_types = [
             'quartus:transform_types',
+            'quartus:apply_resource_strategy'
         ]
         quartus_types_flow = register_flow('specific_types', quartus_types, requires=[init_flow], backend=self.name)
 
@@ -86,31 +97,6 @@ def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_para
 
         return config
 
-    def gen_quartus_weight_array(self, layer):
-        rf = layer.get_attr('reuse_factor')
-        block_factor = int((layer.attributes['n_in']*layer.attributes['n_out'])/rf)
-        bf_rounded = int(pow(2, np.ceil(np.log2(block_factor))))
-        rf_rounded = int(pow(2, np.ceil(np.log2(rf))))
-
-        layer.weights['weight'].data = np.transpose(layer.weights['weight'].data).flatten()
-
-        if layer.attributes['n_in']*layer.attributes['n_out'] > 2048 and rf_rounded != rf:
-            layer.set_attr('rfpad', rf_rounded-rf)
-            layer.set_attr('bfpad', bf_rounded-block_factor)
-
-            temp = np.empty([bf_rounded, rf_rounded])
-            for i in range(rf_rounded):
-                for j in range(bf_rounded):
-                    if i < rf and j < block_factor:
-                        w_index = i + rf * j
-                        temp[j][i] = layer.weights['weight'].data[w_index]
-                    else:
-                        temp[j][i] = 0
-            layer.weights['weight'].data = temp.flatten()
-
-            layer.weights['weight'].data_length = layer.weights['weight'].data.size
-        return
-
     def build(self, model, synth=True, fpgasynth=False):
         """
         Builds the project using Intel HLS compiler.
@@ -163,7 +149,6 @@ def init_dense(self, layer):
         else:
             n_in, n_out = self.get_layer_mult_size(layer)
             self.set_closest_reuse_factor(layer, n_in, n_out)
-            self.gen_quartus_weight_array(layer)
             layer.set_attr('strategy', 'resource')
 
         if layer.model.config.is_resource_strategy(layer):
@@ -196,4 +181,27 @@ def init_softmax(self, layer):
     @layer_optimizer(Embedding)
     def init_embed(self, layer):
         if layer.attributes['n_in'] is None:
-            raise Exception('Input length of Embedding layer must be specified.')
\ No newline at end of file
+            raise Exception('Input length of Embedding layer must be specified.')
+
+    @layer_optimizer(GRU)
+    def init_gru(self, layer):
+        reuse_factor = layer.model.config.get_reuse_factor(layer)
+        layer.set_attr('recurrent_reuse_factor', reuse_factor)
+
+        # Dense multiplication properties
+        layer.set_attr('rfpad', 0)
+        layer.set_attr('bfpad', 0)
+
+        index_t = IntegerPrecisionType(width=1, signed=False)
+
+        if 'table_t' not in layer.attributes:
+            layer.set_attr('table_t', FixedPrecisionType(width=18, integer=8))
+        if 'table_size' not in layer.attributes:
+            layer.set_attr('table_size', 1024)
+        if True:  # layer.model.config.is_resource_strategy(layer): ... Quartus only supports Dense resource multiplication
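+            # Two independent reuse factors are tuned here: reuse_factor for the
+            # input-weight (Wx) multiplication and recurrent_reuse_factor for the
+            # recurrent (Wh) multiplication, matching mult_config_x / mult_config_h
+            # in recurrent_templates.py.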
+            n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)
+            self.set_closest_reuse_factor(layer, n_in, n_out)
+            self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor')
+            layer.set_attr('strategy', 'resource')
+
+        layer.set_attr('index_t', index_t)
\ No newline at end of file
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h
new file mode 100644
index 000000000..5ccaf4161
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h
@@ -0,0 +1,141 @@
+#ifndef NNET_RECURRENT_H_
+#define NNET_RECURRENT_H_
+
+#include "nnet_common.h"
+#include "nnet_dense.h"
+#include "nnet_recurrent_activation.h"
+
+namespace nnet {
+
+struct gru_config {
+    // Internal data type definitions
+    typedef float weight_t;
+    typedef float bias_t;
+    typedef float accum_t;
+
+    // Layer Sizes
+    static const unsigned n_in = 1;
+    static const unsigned n_out = 1;
+    static const unsigned n_units = 1;
+    static const unsigned n_timesteps = 1;
+    static const unsigned n_outputs = 1;
+    static const bool return_sequences = false;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+
+    template<class x_T, class y_T, class config_T>
+    using activation_recr = nnet::activation::relu<x_T, y_T, config_T>;
+
+    template<class x_T, class y_T, class config_T>
+    using activation = nnet::activation::relu<x_T, y_T, config_T>;
+};
+
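+// gru_cell computes a single GRU time step. In terms of the gate blocks used
+// below (order: update z, reset r, candidate), it evaluates, with N = n_units:
+//   [z(t); r(t)] = act_recr( (W*x(t) + b)[0:2N] + (U*h(t-1) + b_r)[0:2N] )
+//   h_cand(t)    = act( (W*x(t) + b)[2N:3N] + r(t) .* (U*h(t-1) + b_r)[2N:3N] )
+//   h(t)         = (1 - z(t)) .* h_cand(t) + z(t) .* h(t-1)
+// where W/U are the input/recurrent weights and .* is the elementwise product.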
+template<class data_T, class res_T, typename CONFIG_T>
+void gru_cell(
+    data_T x[CONFIG_T::n_in],
+    res_T  h[CONFIG_T::n_units],
+    const typename CONFIG_T::weight_t weights[3 * CONFIG_T::n_units * CONFIG_T::n_in],
+    const typename CONFIG_T::weight_t recurrent_weights[3 * CONFIG_T::n_units * CONFIG_T::n_units],
+    const typename CONFIG_T::bias_t   bias[3 * CONFIG_T::n_units],
+    const typename CONFIG_T::bias_t   recurrent_bias[3 * CONFIG_T::n_units]
+) {
+    static constexpr int recurrent_unroll_factor = CONFIG_T::n_units / CONFIG_T::reuse_factor;
+    // A matrix containing the values of the matrix product between the input (x) and the weights (weights), for the update, reset and candidate state gates, for each of the units
+    hls_register typename CONFIG_T::accum_t mat_mul_x_w[3 * CONFIG_T::n_units];
+    nnet::dense_resource<data_T, typename CONFIG_T::accum_t, typename CONFIG_T::mult_config_x>(x, mat_mul_x_w, weights, bias);
+
+    // A matrix containing the values of the matrix product between the previous state (h) and the recurrent weights (recurrent_weights), for the update, reset and candidate state gates, for each of the units
+    hls_register typename CONFIG_T::accum_t mat_mul_h_wr[3 * CONFIG_T::n_units];
+    nnet::dense_resource<res_T, typename CONFIG_T::accum_t, typename CONFIG_T::mult_config_h>(h, mat_mul_h_wr, recurrent_weights, recurrent_bias);
+
+    // A vector containing both the values of z(t) and r(t) for every state
+    hls_register typename CONFIG_T::accum_t z_r[2 * CONFIG_T::n_units];
+
+    // Add the individual vectors from the multiplication of mat_mul_x_w = Wx*x(t) and mat_mul_h_wr = Wh*h(t-1)
+    // Unrolled fully, no DSPs used
+    #pragma unroll
+    for(int i = 0; i < (2 * CONFIG_T::n_units); i++) {
+        z_r[i] = mat_mul_x_w[i] + mat_mul_h_wr[i];
+    }
+
+    // Activation on z(t) and r(t)
+    hls_register typename CONFIG_T::accum_t z_r_act[2 * CONFIG_T::n_units];
+    CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_RECURRENT_T>::activation(z_r, z_r_act);
+
+    // A matrix containing the values of the Hadamard product between r(t) = z_r_act[n_units:2*n_units] and h(t-1) = h
+    hls_register typename CONFIG_T::accum_t hadamard_r_h[CONFIG_T::n_units];
+    #pragma unroll recurrent_unroll_factor
+    for(int i = 0; i < (CONFIG_T::n_units); i++) {
+        hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+    }
+
+    // The candidate state; X * W_{hx} + hadamard(r(t), h_(t-1)) * W_{hh} + b_{h}
+    typename CONFIG_T::accum_t h_cand[CONFIG_T::n_units];
+    // Addition - can unroll fully; no DSPs used here
+    #pragma unroll
+    for(int i = 0; i < (CONFIG_T::n_units); i++) {
+        h_cand[i] = mat_mul_x_w[i + 2 * CONFIG_T::n_units] + hadamard_r_h[i];
+    }
+
+    // Activation on the candidate state
+    hls_register typename CONFIG_T::accum_t h_cand_act[CONFIG_T::n_units];
+    CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_T>::activation(h_cand, h_cand_act);
+
+    // Update state
+    #pragma unroll recurrent_unroll_factor
+    for(int i = 0; i < (CONFIG_T::n_units); i++) {
+        h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
+    }
+}
+
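+// Top-level GRU: steps gru_cell over n_timesteps sequentially. The time loop
+// carries h between iterations, hence the disable_loop_pipelining pragma; every
+// hidden state is written out when return_sequences is set, otherwise only the
+// final state is copied to res.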
+template<class data_T, class res_T, typename CONFIG_T>
+void gru(
+    data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in],
+    res_T  res[CONFIG_T::n_outputs * CONFIG_T::n_units],
+    const typename CONFIG_T::weight_t weights[3 * CONFIG_T::n_units * CONFIG_T::n_in],
+    const typename CONFIG_T::weight_t recurrent_weights[3 * CONFIG_T::n_units * CONFIG_T::n_units],
+    const typename CONFIG_T::bias_t   bias[3 * CONFIG_T::n_units],
+    const typename CONFIG_T::bias_t   recurrent_bias[3 * CONFIG_T::n_units]
+) {
+
+    hls_register data_T x[CONFIG_T::n_in];
+    hls_register res_T  h[CONFIG_T::n_units];
+
+    #pragma unroll
+    for(int i = 0; i < CONFIG_T::n_units; i++) {
+        h[i] = 0;
+    }
+
+    // Loop dependency - cannot pipeline
+    #pragma disable_loop_pipelining
+    for(int t = 0; t < CONFIG_T::n_timesteps; t++) {
+        // Get data at current time step
+        #pragma unroll
+        for(int j = 0; j < CONFIG_T::n_in; j++) {
+            x[j] = data[j + t * CONFIG_T::n_in];
+        }
+
+        nnet::gru_cell<data_T, res_T, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);
+
+        if (CONFIG_T::return_sequences) {
+            #pragma unroll
+            for(int i = 0; i < CONFIG_T::n_units; i++) {
+                res[CONFIG_T::n_units * t + i] = h[i];
+            }
+        }
+    }
+
+    if (!CONFIG_T::return_sequences) {
+        #pragma unroll
+        for(int i = 0; i < (CONFIG_T::n_units); i++) {
+            res[i] = h[i];
+        }
+    }
+}
+
+}
+
+#endif
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent_activation.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent_activation.h
new file mode 100644
index 000000000..d2e4b8da7
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent_activation.h
@@ -0,0 +1,57 @@
+#ifndef NNET_RECR_ACTIVATION_H_
+#define NNET_RECR_ACTIVATION_H_
+
+#include "nnet_common.h"
+#include "nnet_activation.h"
+
+namespace nnet {
+
+namespace activation {
+
+template<class data_T, class res_T, typename CONFIG_T>
+class Activation {
+  public:
+    // *************************************************
+    //       Blank Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {}
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+class relu : public Activation<data_T, res_T, CONFIG_T> {
+  public:
+    // *************************************************
+    //       ReLU Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+        nnet::relu<data_T, res_T, CONFIG_T>(data, res);
+    }
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+class sigmoid : public Activation<data_T, res_T, CONFIG_T> {
+  public:
+    // *************************************************
+    //       Sigmoid Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+        nnet::sigmoid<data_T, res_T, CONFIG_T>(data, res);
+    }
+};
+
+template<class data_T, class res_T, typename CONFIG_T>
+class tanh : public Activation<data_T, res_T, CONFIG_T> {
+  public:
+    // *************************************************
+    //       TanH Activation
+    // *************************************************
+    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+        nnet::dense_tanh<data_T, res_T, CONFIG_T>(data, res);
+    }
+};
+
+}
+
+}
+
+#endif
\ No newline at end of file
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent_stream.h
new file mode 100644
index 000000000..bf414b9a2
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent_stream.h
@@ -0,0 +1,10 @@
+/*
+* PLACEHOLDER - TODO - Implement once PR #557 is merged
+*/
+
+#ifndef NNET_RECURRENT_STREAM_H_
+#define NNET_RECURRENT_STREAM_H_
+
+namespace nnet {}
+
+#endif
diff --git a/test/pytest/test_rnn.py b/test/pytest/test_rnn.py
index bc7ecc7aa..444b26329 100644
--- a/test/pytest/test_rnn.py
+++ b/test/pytest/test_rnn.py
@@ -2,8 +2,6 @@
 import hls4ml
 import numpy as np
 from pathlib import Path
-import math
-from tensorflow.keras import backend as K
 from tensorflow.keras.models import Model, Sequential
 from tensorflow.keras.layers import Input, SimpleRNN, LSTM, GRU
 
@@ -65,10 +63,14 @@ def test_rnn_parsing(rnn_layer, return_sequences):
     else:
         np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[2])
 
-@pytest.mark.parametrize('rnn_layer', [LSTM, GRU])
+@pytest.mark.parametrize('rnn_layer, backend, io_type', [
+    (LSTM, 'Vivado', 'io_parallel'),
+    (LSTM, 'Vivado', 'io_stream'),
+    (GRU, 'Vivado', 'io_parallel'),
+    (GRU, 'Vivado', 'io_stream'),
+    (GRU, 'Quartus', 'io_parallel'),
+])
 @pytest.mark.parametrize('return_sequences', [True, False])
-@pytest.mark.parametrize('backend', ['Vivado'])
-@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
 @pytest.mark.parametrize('static', [True, False])
 def test_rnn_accuracy(rnn_layer, return_sequences, backend, io_type, static):
     # Subtract 0.5 to include negative values