Showing 7 changed files with 433 additions and 36 deletions.
@@ -0,0 +1,133 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import GRU

recurrent_include_list = ['nnet_utils/nnet_recurrent.h', 'nnet_utils/nnet_recurrent_stream.h']

# Shared Matrix Multiplication Template (Dense), used for both the Wx and Wh products
recr_mult_config_template = '''struct config{index}_mult : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned rf_pad = {rfpad};
    static const unsigned bf_pad = {bfpad};
    static const unsigned reuse_factor = {reuse};
    static const unsigned reuse_factor_rounded = reuse_factor + rf_pad;
    static const unsigned block_factor = DIV_ROUNDUP(n_in*n_out, reuse_factor);
    static const unsigned block_factor_rounded = block_factor + bf_pad;
    static const unsigned multiplier_factor = MIN(n_in, reuse_factor);
    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in*n_out, multiplier_factor);
    static const unsigned multiplier_scale = multiplier_limit/n_out;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};\n'''
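As a quick illustration of how this template renders, here is a minimal sketch; the numeric values, type names and product type below are hypothetical stand-ins for what `GRUConfigTemplate.format` (further down) pulls from node attributes:

from types import SimpleNamespace

# Hypothetical stand-ins for hls4ml's precision type objects, which expose a .name attribute
def t(name):
    return SimpleNamespace(name=name)

example_params = {
    'index': '1_x',          # node.index plus the '_x' suffix -> struct config1_x_mult
    'n_in': 10,              # input feature dimension (illustrative)
    'n_out': '20 * 3',       # n_units * 3: reset, update and candidate gates share one multiply
    'rfpad': 0, 'bfpad': 0,  # padding attributes set by the ApplyResourceStrategy pass
    'reuse': 1,
    'accum_t': t('ac_fixed<33,17,true>'),
    'weight_t': t('ac_fixed<16,6,true>'),
    'bias_t': t('ac_fixed<16,6,true>'),
    'product_type': 'mult',  # illustrative; really chosen via get_backend('quartus').product_type(...)
}
print(recr_mult_config_template.format(**example_params))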
# Activation Template
activ_config_template = '''struct {type}_config{index} : nnet::activ_config {{
    static const unsigned n_in = {n_in};
    static const unsigned table_size = {table_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
}};\n'''

# GRU Template
gru_config_template = '''struct config{index} : nnet::gru_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned n_units = {n_units};
    static const unsigned n_timesteps = {n_timesteps};
    static const unsigned n_outputs = {n_outputs};
    static const bool return_sequences = {return_sequences};
    typedef {accum_t.name} accum_t;
    typedef {weight_t.name} weight_t;
    typedef {bias_t.name} bias_t;
    typedef {config_mult_x} mult_config_x;
    typedef {config_mult_h} mult_config_h;
    typedef {act_t} ACT_CONFIG_T;
    template<class x_T, class y_T, class config_T>
    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
    typedef {act_recurrent_t} ACT_CONFIG_RECURRENT_T;
    template<class x_T, class y_T, class config_T>
    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
    static const unsigned reuse_factor = {reuse};
    static const bool store_weights_in_bram = false;
}};\n'''

gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
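For reference, a hypothetical rendering of this call template; all identifiers below are illustrative placeholders, not names hls4ml necessarily generates:

# Hypothetical substitution into gru_function_template for a first GRU layer
example_call = gru_function_template.format(
    input_t='input_t', output_t='layer1_t', config='config1',
    input='input_1', output='layer1_out',
    w='w1', wr='wr1', b='b1', br='br1',
)
print(example_call)
# nnet::gru<input_t, layer1_t, config1>(input_1, layer1_out, w1, wr1, b1, br1);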
class GRUConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(GRU)
        self.gru_template = gru_config_template
        self.act_template = activ_config_template
        self.recr_act_template = activ_config_template
        self.mult_x_template = recr_mult_config_template
        self.mult_h_template = recr_mult_config_template

    def format(self, node):
        # Input has shape (n_timesteps, inp_dimensionality)
        # Output / hidden units has shape (1 if !return_sequences else n_timesteps, n_units)
        params = self._default_config_params(node)
        params['n_units'] = node.get_attr('n_out')
        params['n_outputs'] = node.get_attr('n_timesteps') if node.get_attr('return_sequences', False) else '1'
        params['return_sequences'] = 'true' if node.get_attr('return_sequences', False) else 'false'
        params['config_mult_x'] = 'config{}_x_mult'.format(node.index)
        params['config_mult_h'] = 'config{}_h_mult'.format(node.index)
        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
        params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
        gru_config = self.gru_template.format(**params)

        # Activation is applied to the candidate hidden state, dimensionality (1, n_units)
        act_params = self._default_config_params(node)
        act_params['type'] = node.get_attr('activation')
        act_params['n_in'] = node.get_attr('n_out')
        act_params['index'] = str(node.index) + '_act'
        act_config = self.act_template.format(**act_params)

        # Recurrent activation is applied to both the reset and update gates (hence the factor of 2), dimensionality (1, n_units)
        recr_act_params = self._default_config_params(node)
        recr_act_params['type'] = node.get_attr('recurrent_activation')
        recr_act_params['n_in'] = str(node.get_attr('n_out')) + ' * 2'
        recr_act_params['index'] = str(node.index) + '_rec_act'
        recr_act_config = self.recr_act_template.format(**recr_act_params)

        # Multiplication config for matrix multiplications of type Wx (reset, update and candidate states)
        mult_params_x = self._default_config_params(node)
        mult_params_x['n_in'] = node.get_attr('n_in')
        mult_params_x['n_out'] = str(node.get_attr('n_out')) + ' * 3'
        mult_params_x['product_type'] = get_backend('quartus').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        mult_params_x['index'] = str(node.index) + '_x'
        mult_config_x = self.mult_x_template.format(**mult_params_x)

        # Multiplication config for matrix multiplications of type Wh (reset, update and candidate states)
        mult_params_h = self._default_config_params(node)
        mult_params_h['n_in'] = node.get_attr('n_out')
        mult_params_h['n_out'] = str(node.get_attr('n_out')) + ' * 3'
        mult_params_h['reuse_factor'] = params['recurrent_reuse_factor']
        mult_params_h['product_type'] = get_backend('quartus').product_type(
            node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
        mult_params_h['index'] = str(node.index) + '_h'
        mult_config_h = self.mult_h_template.format(**mult_params_h)

        return mult_config_x + '\n' + mult_config_h + '\n' + recr_act_config + '\n' + act_config + '\n' + gru_config
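Taken together, the five structs emitted by format() must cross-reference by name, since the top-level GRU config refers to the others through its typedefs. A sketch of the naming scheme for a hypothetical node (index 1, tanh activation, sigmoid recurrent activation):

# Struct names produced by format() for a hypothetical node
index, act, recr_act = 1, 'tanh', 'sigmoid'
print('config{}_x_mult'.format(index))                # Wx multiply config
print('config{}_h_mult'.format(index))                # Wh multiply config (recurrent reuse factor)
print('{}_config{}_rec_act'.format(recr_act, index))  # reset/update gate activation
print('{}_config{}_act'.format(act, index))           # candidate state activation
print('config{}'.format(index))                       # top-level GRU config tying them together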
class GRUFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(GRU, include_header=recurrent_include_list)
        self.template = gru_function_template

    def format(self, node):
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name
        params['wr'] = node.get_weights('recurrent_weight').name
        params['br'] = node.get_weights('recurrent_bias').name
        return self.template.format(**params)
@@ -0,0 +1,46 @@
import numpy as np

from hls4ml.model.layers import Dense, GRU
from hls4ml.model.optimizer import OptimizerPass


class ApplyResourceStrategy(OptimizerPass):
    '''Transposes the weights to use the dense_resource matrix multiply routine.'''

    def match(self, node):
        node_matches = isinstance(node, (Dense, GRU))
        # Quartus only supports the resource strategy, so no strategy check is needed
        # (otherwise: node.get_attr('strategy', '').lower() == 'resource')
        is_resource_strategy = True
        already_transformed = node.get_attr('_weights_transposed', False)
        return node_matches and is_resource_strategy and not already_transformed

    def transform(self, model, node):
        if isinstance(node, Dense) and not node.model.config.get_compression(node):
            rf = node.get_attr('reuse_factor')
            bf = int((node.attributes['n_in'] * node.attributes['n_out']) / rf)
            bf_rounded = int(pow(2, np.ceil(np.log2(bf))))
            rf_rounded = int(pow(2, np.ceil(np.log2(rf))))

            node.weights['weight'].data = np.transpose(node.weights['weight'].data).flatten()

            if node.attributes['n_in'] * node.attributes['n_out'] > 2048 and rf_rounded != rf:
                node.set_attr('rfpad', rf_rounded - rf)
                node.set_attr('bfpad', bf_rounded - bf)

                # Pad the flattened weights out to power-of-two reuse and block factors,
                # laying out each block of rf weights as one zero-padded row
                temp = np.empty([bf_rounded, rf_rounded])
                for i in range(rf_rounded):
                    for j in range(bf_rounded):
                        if i < rf and j < bf:
                            w_index = i + rf * j
                            temp[j][i] = node.weights['weight'].data[w_index]
                        else:
                            temp[j][i] = 0
                node.weights['weight'].data = temp.flatten()
                node.weights['weight'].data_length = node.weights['weight'].data.size

        elif isinstance(node, GRU):
            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)

        else:
            raise Exception('Unexpected layer {} with resource strategy'.format(node.class_name))

        node.set_attr('_weights_transposed', True)
        return False
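To make the padded layout concrete, here is a toy sketch of the transpose-and-pad step above; the sizes are illustrative and far below the 2048-element threshold that actually gates the padding in the pass:

import numpy as np

# Toy illustration of the padding loop in ApplyResourceStrategy.transform
n_in, n_out, rf = 3, 2, 3
w = np.arange(n_in * n_out).reshape(n_in, n_out)  # (n_in, n_out) weight matrix
flat = np.transpose(w).flatten()                  # [0, 2, 4, 1, 3, 5]

bf = (n_in * n_out) // rf                         # block factor: 2
rf_rounded = int(2 ** np.ceil(np.log2(rf)))       # 3 -> 4
bf_rounded = int(2 ** np.ceil(np.log2(bf)))       # 2 -> 2

temp = np.zeros([bf_rounded, rf_rounded])         # zeros stand in for the explicit padding
for j in range(bf):
    for i in range(rf):
        temp[j][i] = flat[i + rf * j]
print(temp)
# [[0. 2. 4. 0.]
#  [1. 3. 5. 0.]]   <- each row is one block of rf weights, zero-padded to rf_rounded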