Skip to content

Commit

Permalink
Quartus Softmax optimize LUT to store only used values
Browse files Browse the repository at this point in the history
  • Loading branch information
bo3z committed Aug 4, 2022
1 parent c8e8f75 commit 41fd158
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 28 deletions.
36 changes: 15 additions & 21 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,17 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
enum class softmax_implementation {latency=0, legacy=1, stable=2};

template<class data_T, typename CONFIG_T>
inline unsigned softmax_idx_from_real_val(const data_T x){
inline unsigned softmax_stable_idx_from_real_val(const data_T x){
// Number of address bits for table
static constexpr int N = ceillog2(CONFIG_T::table_size);

// Slice the top N bits of the input
hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
return y.to_uint();
}

template<class data_T, typename CONFIG_T>
inline unsigned softmax_latency_idx_from_real_val(const data_T x){
// Number of address bits for table
static constexpr int N = ceillog2(CONFIG_T::table_size);

Expand All @@ -148,27 +158,20 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
// Find maximum
Op_max<data_T> op_max;
hls_register data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);

// Calculate differences from the maximum, forcing rounding and saturation for better accuracy
hls_register ac_fixed<data_T::width, data_T::i_width, true, AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
d_xi_xmax[i] = data[i] - x_max;
}

// Calculate all the e^x's
hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
exp_res[i] = exp_table[softmax_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i])];
exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(data[i] - x_max)];
}

// Explicitly sum previously calculated exponentials with an adder tree
Op_add<typename CONFIG_T::exp_table_t> op_add;
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

// Multiply previously calculated exponentials with the reciprocal of the sum
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_stable_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
res[i] = exp_res[i] * inv_exp_sum;
Expand All @@ -178,31 +181,22 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
// TODO - Improve accuracy
template <class data_T, class res_T, typename CONFIG_T>
void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
/*
* Note: The latency tables are equivalent to stable tables
* However, the compiler cannot include the same table twice
* Therefore, an out-of-scope exception is thrown in one of the functions
* Temporary solution - Create the same table twice in quartus_writer.py
* Long-term solution - Only create tables needed by the network;
* Currently, quartus-writer.py generates LUTs for all activations,
* Regardless if they are present in the network or not
*/
#include "activation_tables/exp_table_latency.tb"
#include "activation_tables/invert_table_latency.tb"

// Calculate all the e^x's
hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
exp_res[i] = exp_table_latency[softmax_idx_from_real_val<data_T, CONFIG_T>(data[i])];
exp_res[i] = exp_table_latency[softmax_latency_idx_from_real_val<data_T, CONFIG_T>(data[i])];
}

// Explicitly sum the results with an adder tree.
Op_add<typename CONFIG_T::exp_table_t> op_add;
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

// Multiply previously calculated exponentials with the reciprocal of the sum
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_latency_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++){
res[i] = exp_res[i] * inv_exp_sum;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,15 +283,15 @@ void softmax_stable(stream<data_T> &data, stream<res_T> &res) {
hls_register typename CONFIG_T::exp_table_t exp_res[data_T::size];
#pragma unroll
for(unsigned j = 0; j < data_T::size; j++) {
exp_res[j] = exp_table[softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[j])];
exp_res[j] = exp_table[softmax_stable_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[j])];
}

// Explicitly sum the results with an adder tree.
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing
Op_add<typename CONFIG_T::exp_table_t> op_add;
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, data_T::size, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_stable_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
res_T out_pack;

SoftmaxInvPackLoop:
Expand Down Expand Up @@ -327,7 +327,7 @@ void softmax_latency(stream<data_T> &data, stream<res_T> &res){
SoftmaxExpPackLoop:
#pragma unroll
for(unsigned j = 0; j < data_T::size; j++) {
exp_res[j] = exp_table_latency[softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(in_pack[j])];
exp_res[j] = exp_table_latency[softmax_latency_idx_from_real_val<typename data_T::value_type, CONFIG_T>(in_pack[j])];
}

// Explicitly sum the results with an adder tree.
Expand All @@ -336,7 +336,7 @@ void softmax_latency(stream<data_T> &data, stream<res_T> &res){
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

// Multiply previously calculated exponentials with the reciprocal of the sum
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_latency_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];

res_T out_pack;
SoftmaxInvPackLoop:
Expand Down
17 changes: 14 additions & 3 deletions hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,12 +918,19 @@ def __write_exp_table(self, model, path):
except:
# FixedPrecisionType wasn't correctly stored in layer attributes, use default values
pass
if fp_signed is False:
raise Exception('Softmax types need to be signed')

sep = ''
N = ceil_log2(table_size)
for i in range(table_size):
f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
f.set_msb_bits(uint_to_binary(i, N))
b = uint_to_binary(i, N)
if i == 0:
b.insert(0, 0)
else:
b.insert(0, 1)
f.set_msb_bits(b)
real_val = f.exp_float()
h_file.write(sep + str(real_val))
sep = ", "
Expand Down Expand Up @@ -957,19 +964,23 @@ def __write_invert_table(self, model, path):
except:
# FixedPrecisionType wasn't correctly stored in layer attributes, use default values
pass
if fp_signed is False:
raise Exception('Softmax types need to be signed')

sep = ''
N = ceil_log2(table_size)
for i in range(table_size):
f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
f.set_msb_bits(uint_to_binary(i, N))
b = uint_to_binary(i, N)
b.insert(0, 0)
f.set_msb_bits(b)
real_val = f.inv_float()
h_file.write(sep + str(real_val))
sep = ", "

h_file.write('};\n')
h_file.close()

def __write_exp_table_latency(self, model, path):
table_name = 'exp_table_latency'
table_size = self.__get_table_size(model, 'softmax')
Expand Down

0 comments on commit 41fd158

Please sign in to comment.