Vivado Backend GRU/LSTM support #560

Merged
11 commits merged on Jun 24, 2022
249 changes: 249 additions & 0 deletions hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -0,0 +1,249 @@

from hls4ml.backends.backend import get_backend
from hls4ml.model.layers import LSTM, GRU
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate

# recurrent multiplication template
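# instantiated twice per recurrent layer: once for the input kernel (applied to the layer input)
# and once for the recurrent kernel (applied to the previous hidden state)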

recr_mult_config_template = """struct config{index} : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned reuse_factor = {reuse};
    static const unsigned n_zeros = {nzeros};
    static const unsigned n_nonzeros = {nonzeros};
    static const bool store_weights_in_bram = false;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    typedef ap_{index_t} index_t;
    template<class x_T, class y_T, class res_T>
    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
}};\n"""

# activation templates
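# one config for the layer's main activation plus a '_recr' variant for the recurrent (gate) activation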

activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
    static const unsigned n_in = {n_in};
    static const unsigned table_size = {table_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
    typedef ap_{table_t} table_t;
}};\n"""

recr_activ_config_template = """struct {type}_config{index}_recr : nnet::activ_config {{
    static const unsigned n_in = {n_in};
    static const unsigned table_size = {table_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
    typedef ap_{table_t} table_t;
}};\n"""

# LSTM + GRU templates
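# top-level config structs consumed by nnet_utils/nnet_recurrent.h; n_state is the hidden-state width
# and use_static selects the static-state variant of the HLS implementation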

lstm_config_template = """struct config{index} : nnet::lstm_config {{
    typedef {accum_t.name} accum_t;
    typedef {weight_t.name} weight_t; // Matrix
    typedef {bias_t.name} bias_t; // Vector
    typedef {config_mult_t1} mult_config1;
    typedef {config_mult_t2} mult_config2;
    typedef {lstm_act_t} ACT_CONFIG_LSTM;
    template<class x_T, class y_T, class config_T>
    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
    typedef {act_t} ACT_CONFIG_T;
    template<class x_T, class y_T, class config_T>
    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned n_state = {n_state};
    static const unsigned n_sequence = {n_sequence};
    static const unsigned n_sequence_out = {n_sequence_out};
    static const unsigned io_type = nnet::{strategy};
    static const unsigned reuse_factor = {reuse};
    static const bool store_weights_in_bram = false;
    static const bool use_static = {static};
}};\n"""

lstm_function_template = 'nnet::lstm_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'

gru_config_template = """struct config{index} : nnet::gru_config {{
    typedef {accum_t.name} accum_t;
    typedef {weight_t.name} weight_t; // Matrix
    typedef {bias_t.name} bias_t; // Vector
    typedef {config_mult_t1} mult_config1;
    typedef {config_mult_t2} mult_config2;
    typedef {gru_act_t} ACT_CONFIG_GRU;
    template<class x_T, class y_T, class config_T>
    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
    typedef {act_t} ACT_CONFIG_T;
    template<class x_T, class y_T, class config_T>
    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned n_state = {n_state};
    static const unsigned n_sequence = {n_sequence};
    static const unsigned n_sequence_out = {n_sequence_out};
    static const unsigned io_type = nnet::{strategy};
    static const unsigned reuse_factor = {reuse};
    static const bool store_weights_in_bram = false;
    static const bool use_static = {static};
}};\n"""

gru_function_template = 'nnet::gru_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
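# both stack functions take the input kernel (w), recurrent kernel (wr), bias (b) and recurrent bias (br)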

recr_include_list = ['nnet_utils/nnet_recurrent.h']

class LSTMConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(LSTM)
        self.template = lstm_config_template
        self.act_template = activ_config_template
        self.recr_act_template = recr_activ_config_template
        self.mult1_template = recr_mult_config_template
        self.mult2_template = recr_mult_config_template

    def format(self, node):

        params = self._default_config_params(node)

        params['n_in'] = node.get_input_variable().dim_names[1]
        params['n_sequence'] = node.get_input_variable().dim_names[0]
        params['n_sequence_out'] = node.get_output_variable().dim_names[0]
        params['n_state'] = node.get_output_variable().dim_names[1]
        params['n_out'] = node.get_output_variable().dim_names[1]
        params['config_mult_t1'] = 'config{}_1'.format(node.index)
        params['config_mult_t2'] = 'config{}_2'.format(node.index)
        params['lstm_act_t'] = '{}_config{}_recr'.format(node.get_attr('recurrent_activation'), node.index)
        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
        params['strategy'] = node.get_attr('strategy')
        params['static'] = 'true' if node.attributes['static'] else 'false'

        recr_config = self.template.format(**params)

        act_params = self._default_config_params(node)
        recr_act_params = self._default_config_params(node)

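        # recurrent activation (typically sigmoid) is applied to the i, f and o gates, hence 3 * n_out;
        # the main activation (typically tanh) is applied per hidden unit, so its table covers n_out inputs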
        act_params['type'] = node.get_attr('activation')
        act_params['n_in'] = node.get_output_variable().dim_names[1]
        recr_act_params['type'] = node.get_attr('recurrent_activation')
        recr_act_params['n_in'] = node.get_output_variable().dim_names[1] + ' * 3'

        act_config = self.act_template.format(**act_params)
        recr_act_config = self.recr_act_template.format(**recr_act_params)

        mult_params1 = self._default_config_params(node)
        mult_params2 = self._default_config_params(node)

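        # mult 1 is the input kernel (n_in -> 4 * n_out, one block per i, f, g, o gate);
        # mult 2 is the recurrent kernel (n_out -> 4 * n_out), with its own reuse factor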
        mult_params1['n_in'] = node.get_input_variable().dim_names[1]
        mult_params1['n_out'] = node.get_output_variable().dim_names[1] + ' * 4'
        mult_params1['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        mult_params1['reuse'] = params['reuse']
        mult_params1['index'] = str(node.index) + '_1'
        mult_params1['nzeros'] = node.get_weights('weight').nzeros
        mult_params1['nonzeros'] = node.get_weights('weight').nonzeros
        mult_params2['n_in'] = node.get_output_variable().dim_names[1]
        mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * 4'
        mult_params2['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
        mult_params2['reuse'] = node.attributes['recurrent_reuse_factor']
        mult_params2['index'] = str(node.index) + '_2'
        mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros
        mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros

        mult_config1 = self.mult1_template.format(**mult_params1)
        mult_config2 = self.mult2_template.format(**mult_params2)

        return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config

class LSTMFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(LSTM, include_header=recr_include_list)
        self.template = lstm_function_template

    def format(self, node):
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name
        params['wr'] = node.get_weights('recurrent_weight').name
        params['br'] = node.get_weights('recurrent_bias').name
        params['activation'] = node.get_attr('activation')
        params['recurrent_activation'] = node.get_attr('recurrent_activation')

        return self.template.format(**params)

class GRUConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(GRU)
        self.template = gru_config_template
        self.act_template = activ_config_template
        self.recr_act_template = recr_activ_config_template
        self.mult1_template = recr_mult_config_template
        self.mult2_template = recr_mult_config_template

    def format(self, node):

        params = self._default_config_params(node)

        params['n_in'] = node.get_input_variable().dim_names[1]
        params['n_sequence'] = node.get_input_variable().dim_names[0]
        params['n_sequence_out'] = node.get_output_variable().dim_names[0]
        params['n_state'] = node.get_output_variable().dim_names[1]
        params['n_out'] = node.get_output_variable().dim_names[1]
        params['config_mult_t1'] = 'config{}_1'.format(node.index)
        params['config_mult_t2'] = 'config{}_2'.format(node.index)
        params['gru_act_t'] = '{}_config{}_recr'.format(node.get_attr('recurrent_activation'), node.index)
        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
        params['strategy'] = node.get_attr('strategy')
        params['static'] = 'true' if node.attributes['static'] else 'false'

        recr_config = self.template.format(**params)

        act_params = self._default_config_params(node)
        recr_act_params = self._default_config_params(node)

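        # recurrent activation (typically sigmoid) covers the update and reset gates, hence 2 * n_out;
        # the main activation (typically tanh) is applied to the candidate state, n_out elements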
        act_params['type'] = node.get_attr('activation')
        act_params['n_in'] = node.get_output_variable().dim_names[1]
        recr_act_params['type'] = node.get_attr('recurrent_activation')
        recr_act_params['n_in'] = node.get_output_variable().dim_names[1] + ' * 2'

        act_config = self.act_template.format(**act_params)
        recr_act_config = self.recr_act_template.format(**recr_act_params)

        mult_params1 = self._default_config_params(node)
        mult_params2 = self._default_config_params(node)

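        # mult 1 is the input kernel (n_in -> 3 * n_out, one block per z, r and candidate state);
        # mult 2 is the recurrent kernel (n_out -> 3 * n_out), with its own reuse factor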
        mult_params1['n_in'] = node.get_input_variable().dim_names[1]
        mult_params1['n_out'] = node.get_output_variable().dim_names[1] + ' * 3'
        mult_params1['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        mult_params1['reuse'] = params['reuse']
        mult_params1['index'] = str(node.index) + '_1'
        mult_params1['nzeros'] = node.get_weights('weight').nzeros
        mult_params1['nonzeros'] = node.get_weights('weight').nonzeros
        mult_params2['n_in'] = node.get_output_variable().dim_names[1]
        mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * 3'
        mult_params2['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
        mult_params2['reuse'] = node.attributes['recurrent_reuse_factor']
        mult_params2['index'] = str(node.index) + '_2'
        mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros
        mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros

        mult_config1 = self.mult1_template.format(**mult_params1)
        mult_config2 = self.mult2_template.format(**mult_params2)

        return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config

class GRUFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(GRU, include_header=recr_include_list)
        self.template = gru_function_template

    def format(self, node):
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name
        params['wr'] = node.get_weights('recurrent_weight').name
        params['br'] = node.get_weights('recurrent_bias').name
        params['activation'] = node.get_attr('activation')
        params['recurrent_activation'] = node.get_attr('recurrent_activation')

        return self.template.format(**params)

4 changes: 2 additions & 2 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -241,7 +241,7 @@ def init_lstm(self, layer):
reuse_factor = layer.model.config.get_reuse_factor(layer)
layer.set_attr('recurrent_reuse_factor', reuse_factor)

recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[0 if layer.model.config.is_resource_strategy(layer) else 1])
layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

index_t = IntegerPrecisionType(width=1, signed=False)
@@ -267,7 +267,7 @@ def init_gru(self, layer):
reuse_factor = layer.model.config.get_reuse_factor(layer)
layer.set_attr('recurrent_reuse_factor', reuse_factor)

recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[0 if layer.model.config.is_resource_strategy(layer) else 1])
layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

index_t = IntegerPrecisionType(width=1, signed=False)
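Note on the recurrent_bias change in the two hunks above: with Resource strategy the LSTM/GRU kernels are transposed at layer initialization (see the layers.py changes below), so by the time these backend passes run, the gate dimension sits on axis 0 of the recurrent weight instead of axis 1. A minimal sketch of the intent, using hypothetical shapes and a stand-in flag for layer.model.config.is_resource_strategy(layer):

import numpy as np

n_hidden = 8
recurrent_kernel = np.zeros((n_hidden, 4 * n_hidden))  # Keras-style LSTM recurrent kernel (hypothetical)

resource_strategy = True  # stand-in for is_resource_strategy(layer)
if resource_strategy:
    recurrent_kernel = recurrent_kernel.T  # transposed up front for the resource implementation

# the bias length must match the gate dimension, whichever axis it ended up on
recurrent_bias = np.zeros(recurrent_kernel.shape[0 if resource_strategy else 1])
assert recurrent_bias.shape == (4 * n_hidden,)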
32 changes: 21 additions & 11 deletions hls4ml/model/layers.py
@@ -850,6 +850,7 @@ class SimpleRNN(Layer):
Attribute('return_sequences', value_type=bool, default=False),
Attribute('return_state', value_type=bool, default=False),
ChoiceAttribute('direction', ['forward', 'backward'], default='forward'),
Attribute('static', value_type=bool, default=True),

WeightAttribute('weight'),
WeightAttribute('bias'),
@@ -863,10 +864,9 @@ class SimpleRNN(Layer):
def initialize(self):
if self.attributes['return_sequences']:
shape = [self.attributes['n_timesteps'], self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]
else:
shape = [self.attributes['n_out']]
dims = ['N_OUT_{}'.format(self.index)]
shape = [1, self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]

self.add_output_variable(shape, dims)

@@ -891,6 +891,7 @@ class LSTM(Layer):
Attribute('return_state', value_type=bool, default=False),
ChoiceAttribute('direction', ['forward', 'backward'], default='forward'),
Attribute('time_major', value_type=bool, default=False),
Attribute('static', value_type=bool, default=True),

WeightAttribute('weight'),
WeightAttribute('bias'),
@@ -904,10 +905,9 @@ class LSTM(Layer):
def initialize(self):
if self.attributes['return_sequences']:
shape = [self.attributes['n_timesteps'], self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]
else:
shape = [self.attributes['n_out']]
dims = ['N_OUT_{}'.format(self.index)]
shape = [1, self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]

self.add_output_variable(shape, dims)

@@ -917,10 +917,15 @@ def initialize(self):
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')

self.add_weights()
data = self.model.get_weights_data(self.name, 'kernel')
if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
data = np.transpose(data)
self.add_weights_variable(name='weight', var_name='w{index}', data=data)
self.add_bias()

recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
recurrent_weight = np.transpose(recurrent_weight)
self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)

class GRU(Layer):
@@ -933,6 +938,7 @@ class GRU(Layer):
ChoiceAttribute('direction', ['forward', 'backward'], default='forward'),
Attribute('time_major', value_type=bool, default=False),
ChoiceAttribute('apply_reset_gate', ['before', 'after'], default='after'),
Attribute('static', value_type=bool, default=True),

WeightAttribute('weight'),
WeightAttribute('bias'),
@@ -946,10 +952,9 @@ class GRU(Layer):
def initialize(self):
if self.attributes['return_sequences']:
shape = [self.attributes['n_timesteps'], self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]
else:
shape = [self.attributes['n_out']]
dims = ['N_OUT_{}'.format(self.index)]
shape = [1, self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]

self.add_output_variable(shape, dims)

@@ -959,10 +964,15 @@ def initialize(self):
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')

self.add_weights()
data = self.model.get_weights_data(self.name, 'kernel')
if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
data = np.transpose(data)
self.add_weights_variable(name='weight', var_name='w{index}', data=data)
self.add_bias()

recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
recurrent_weight = np.transpose(recurrent_weight)
self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)

class GarNet(Layer):
2 changes: 1 addition & 1 deletion hls4ml/model/optimizer/passes/multi_dense.py
@@ -5,7 +5,7 @@
class ReplaceMultidimensionalDenseWithConv(OptimizerPass):
def match(self, node):
return isinstance(node, Dense) and \
len(node.get_input_variable().shape) > 1
len(node.get_input_variable().shape) - sum(d==1 for d in node.get_input_variable().shape) > 1

def transform(self, model, node):
dim = len(node.get_input_variable().shape) - 1
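Note on the relaxed match condition above: an LSTM/GRU with return_sequences=False now produces an output shaped [1, n_out] (see the layers.py changes above), so dimensions of size 1 are ignored when deciding whether a Dense layer really has a multi-dimensional input; otherwise a Dense following such a layer would be needlessly rewritten as a pointwise Conv. A quick check with hypothetical values:

shape = [1, 16]  # e.g. output of a GRU with return_sequences=False
effective_rank = len(shape) - sum(d == 1 for d in shape)
assert effective_rank == 1  # no longer matched by this optimizer pass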