diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h index 4683239d8..5532d07dd 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h @@ -717,19 +717,27 @@ template void selu(data_T data[CO initialized = true; } - #pragma HLS PIPELINE + typedef ap_ufixed<16,2> selu_const_t; + static const selu_const_t lambda = 1.0507009873554805; - data_T datareg; - // Index into the lookup table based on data - int index; + #pragma HLS PIPELINE for (int ii = 0; ii < CONFIG_T::n_in; ii++) { - datareg = data[ii]; + data_T datareg = data[ii]; + if (datareg >= 0) { - res[ii] = res_T(1.0507009873554804934193349852946) * datareg; + // Positive branch y = λ · x + res[ii] = lambda * datareg; } else { - index = datareg * CONFIG_T::table_size / -8; - if (index > CONFIG_T::table_size - 1) + // Negative branch y = table(x) + int index = datareg * CONFIG_T::table_size / -8; + + // clamp index to [0, table_size-1] + if (index < 0) + index = 0; + else if (index > CONFIG_T::table_size - 1) { index = CONFIG_T::table_size - 1; + } + res[ii] = selu_table[index]; } }