Softmax LUT Optimization #570

Merged · 2 commits · Aug 12, 2022
36 changes: 15 additions & 21 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h
@@ -130,7 +130,17 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
enum class softmax_implementation {latency=0, legacy=1, stable=2};

template<class data_T, typename CONFIG_T>
inline unsigned softmax_idx_from_real_val(const data_T x){
inline unsigned softmax_stable_idx_from_real_val(const data_T x){
// Number of address bits for table
static constexpr int N = ceillog2(CONFIG_T::table_size);

// Slice the top N bits of the input
hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
return y.to_uint();
}

template<class data_T, typename CONFIG_T>
inline unsigned softmax_latency_idx_from_real_val(const data_T x){
// Number of address bits for table
static constexpr int N = ceillog2(CONFIG_T::table_size);

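Reviewer note: a rough Python sketch of what the new stable index helper computes, assuming a 16-bit input word and table_size = 1024 (so N = 10); the widths and the function name below are illustrative only, not part of the PR:

import math

def softmax_stable_idx_from_real_val(x_bits, width=16, table_size=1024):
    # Illustrative re-implementation: take the N bits directly below the sign
    # bit of a width-bit two's-complement word and use them as the table
    # address, mirroring ac_fixed's x.slc<N>(x.width - N - 1).
    n = math.ceil(math.log2(table_size))          # number of address bits
    return (x_bits >> (width - n - 1)) & ((1 << n) - 1)

# Example: with 16 bits and a 1024-entry table, bits [14:5] form the address.
print(softmax_stable_idx_from_real_val(0b0110_0000_0000_0000))  # -> 768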
@@ -148,27 +158,20 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
// Find maximum
Op_max<data_T> op_max;
hls_register data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);

// Calculate differences from the maximum, forcing rounding and saturation for better accuracy
hls_register ac_fixed<data_T::width, data_T::i_width, true, AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
d_xi_xmax[i] = data[i] - x_max;
}

// Calculate all the e^x's
hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
exp_res[i] = exp_table[softmax_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i])];
exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(data[i] - x_max)];
}

// Explicitly sum previously calculated exponentials with an adder tree
Op_add<typename CONFIG_T::exp_table_t> op_add;
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

// Multiply previously calculated exponentials with the reciprocal of the sum
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_stable_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
res[i] = exp_res[i] * inv_exp_sum;
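Reviewer note: as a sanity check on the dataflow above, a minimal Python model of the stable LUT softmax; the lookup tables are replaced by exact functions here, so this only illustrates the structure (subtract the maximum, exp per element, adder-tree sum, one reciprocal lookup, n multiplies), not the quantization:

import numpy as np

def lut_softmax_stable(x, exp_lut=np.exp, inv_lut=lambda s: 1.0 / s):
    # Illustrative model (names hypothetical): in hardware, exp_lut and
    # inv_lut are tables addressed by the top bits of the fixed-point operands.
    d = x - x.max()                      # d_i = x_i - x_max, all <= 0
    exp_res = exp_lut(d)                 # one exp_table lookup per element
    exp_sum = exp_res.sum()              # adder-tree reduction
    return exp_res * inv_lut(exp_sum)    # one invert_table lookup, then scale

# Matches np.exp(x) / np.exp(x).sum() up to table quantization.
print(lut_softmax_stable(np.array([0.5, -1.0, 2.0])))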
@@ -178,31 +181,22 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
// TODO - Improve accuracy
template <class data_T, class res_T, typename CONFIG_T>
void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
/*
* Note: The latency tables are equivalent to stable tables
* However, the compiler cannot include the same table twice
* Therefore, an out-of-scope exception is thrown in one of the functions
* Temporary solution - Create the same table twice in quartus_writer.py
* Long-term solution - Only create tables needed by the network;
* Currently, quartus-writer.py generates LUTs for all activations,
* Regardless if they are present in the network or not
*/
#include "activation_tables/exp_table_latency.tb"
#include "activation_tables/invert_table_latency.tb"

// Calculate all the e^x's
hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
exp_res[i] = exp_table_latency[softmax_idx_from_real_val<data_T, CONFIG_T>(data[i])];
exp_res[i] = exp_table_latency[softmax_latency_idx_from_real_val<data_T, CONFIG_T>(data[i])];
}

// Explicitly sum the results with an adder tree.
Op_add<typename CONFIG_T::exp_table_t> op_add;
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

// Multiply previously calculated exponentials with the reciprocal of the sum
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_latency_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
#pragma unroll
for(unsigned i = 0; i < CONFIG_T::n_in; i++){
res[i] = exp_res[i] * inv_exp_sum;
@@ -283,15 +283,15 @@ void softmax_stable(stream<data_T> &data, stream<res_T> &res) {
hls_register typename CONFIG_T::exp_table_t exp_res[data_T::size];
#pragma unroll
for(unsigned j = 0; j < data_T::size; j++) {
exp_res[j] = exp_table[softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[j])];
exp_res[j] = exp_table[softmax_stable_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[j])];
}

// Explicitly sum the results with an adder tree.
// Rounding & saturation mode, which improves accuracy, prevents Vivado from expression balancing
Op_add<typename CONFIG_T::exp_table_t> op_add;
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, data_T::size, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_stable_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
res_T out_pack;

SoftmaxInvPackLoop:
@@ -327,7 +327,7 @@ void softmax_latency(stream<data_T> &data, stream<res_T> &res){
SoftmaxExpPackLoop:
#pragma unroll
for(unsigned j = 0; j < data_T::size; j++) {
exp_res[j] = exp_table_latency[softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(in_pack[j])];
exp_res[j] = exp_table_latency[softmax_latency_idx_from_real_val<typename data_T::value_type, CONFIG_T>(in_pack[j])];
}

// Explicitly sum the results with an adder tree.
@@ -336,7 +336,7 @@ void softmax_latency(stream<data_T> &data, stream<res_T> &res){
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

// Multiply previously calculated exponentials with the reciprocal of the sum
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_latency_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];

res_T out_pack;
SoftmaxInvPackLoop:
6 changes: 3 additions & 3 deletions hls4ml/utils/fixed_point_utils.py
@@ -65,7 +65,7 @@ def set_msb_bits(self, bits):
for i in range(0, len(bits)):
if i < self.I:
self.integer_bits[i] = bits[i]
elif i >= self.I and i<self.F:
elif i >= self.I and i<self.N:
self.decimal_bits[i-self.I] = bits[i]

'''
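Reviewer note: the i < self.F to i < self.N change above fixes the bound of the loop in set_msb_bits. Assuming F is the number of fractional bits (N - I), the old bound stopped well before the end of the bit list, so the low fractional bits were never written. A tiny illustration with assumed widths:

I, N = 6, 16            # integer bits, total bits; F = N - I = 10 fractional bits
bits = list(range(N))   # stand-in for a full-width bit list
old = [i for i in range(len(bits)) if I <= i < (N - I)]  # old bound: i < self.F
new = [i for i in range(len(bits)) if I <= i < N]        # new bound: i < self.N
print(len(old), len(new))  # 4 vs 10: the old bound silently dropped six decimal bits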
@@ -77,7 +77,7 @@ def set_msb_bits(self, bits):
Notice:
- If e^x overflows, the maximum float value is used
'''
def exp_float(self, sig_figs=6):
def exp_float(self, sig_figs=12):
try:
return round(math.exp(self.to_float()), sig_figs)
except OverflowError:
@@ -89,7 +89,7 @@ def exp_float(self, sig_figs=6):
Returns:
- Float : 1/x, rounded to some number of decimal points
'''
def inv_float(self, sig_figs=10):
def inv_float(self, sig_figs=12):
if self.to_float()!=0:
return round(1.0/self.to_float(), sig_figs)
else:
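Reviewer note: the sig_figs bumps matter mostly for the tiny values that dominate a softmax exponent table; the parameter is passed straight to round(), so it is effectively a number of decimal places. A quick illustration in plain Python, independent of the hls4ml classes:

import math

x = -8.0
print(round(math.exp(x), 6))    # 0.000335        -> only ~3 significant digits survive
print(round(math.exp(x), 12))   # 0.000335462628  -> enough precision for a useful LUT entry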
17 changes: 14 additions & 3 deletions hls4ml/writer/quartus_writer.py
@@ -918,12 +918,19 @@ def __write_exp_table(self, model, path):
except:
# FixedPrecisionType wasn't correctly stored in layer attributes, use default values
pass
if fp_signed is False:
raise Exception('Softmax types need to be signed')

sep = ''
N = ceil_log2(table_size)
for i in range(table_size):
f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
f.set_msb_bits(uint_to_binary(i, N))
b = uint_to_binary(i, N)
if i == 0:
b.insert(0, 0)
else:
b.insert(0, 1)
f.set_msb_bits(b)
real_val = f.exp_float()
h_file.write(sep + str(real_val))
sep = ", "
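Reviewer note: to see what the prepended bit does, here is a small standalone sketch of the address-to-value mapping used when filling the exp table. Non-zero addresses correspond to negative differences x - x_max, hence the sign bit of 1; address 0 maps to 0.0 (e^0 = 1). For the invert table just below, the extra bit is always 0, since a sum of exponentials is non-negative. The widths are assumed (16 total bits, 6 integer bits, 1024 entries) and the fixed-point interpretation is re-implemented directly instead of going through FixedPointEmulator:

import math

def exp_table_entry(i, table_size=1024, width=16, int_bits=6):
    # Illustrative only: rebuild the signed bit pattern the writer constructs
    # for address i, read it as a two's-complement fixed-point value, and
    # return e^x for that value.
    n = math.ceil(math.log2(table_size))
    sign = 0 if i == 0 else 1                              # x - x_max < 0 for non-zero addresses
    bits = (sign << (width - 1)) | (i << (width - 1 - n))  # sign bit, then the N address bits
    raw = bits - (1 << width) if sign else bits            # two's-complement integer value
    value = raw / (1 << (width - int_bits))                # scale by the fractional bits
    return round(math.exp(value), 12)

print(exp_table_entry(0))     # address 0 -> x = 0.0 -> 1.0
print(exp_table_entry(512))   # a mid-range negative value under these assumed widths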
@@ -957,19 +964,23 @@ def __write_invert_table(self, model, path):
except:
# FixedPrecisionType wasn't correctly stored in layer attributes, use default values
pass
if fp_signed is False:
raise Exception('Softmax types need to be signed')

sep = ''
N = ceil_log2(table_size)
for i in range(table_size):
f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
f.set_msb_bits(uint_to_binary(i, N))
b = uint_to_binary(i, N)
b.insert(0, 0)
f.set_msb_bits(b)
real_val = f.inv_float()
h_file.write(sep + str(real_val))
sep = ", "

h_file.write('};\n')
h_file.close()

def __write_exp_table_latency(self, model, path):
table_name = 'exp_table_latency'
table_size = self.__get_table_size(model, 'softmax')