QDenseBatchnorm #718

Open · wants to merge 3 commits into base: main
6 changes: 3 additions & 3 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -1,6 +1,6 @@
 from hls4ml.backends.backend import get_backend
-from hls4ml.model.layers import Activation, BatchNormalization, Dense, Embedding, PReLU, ParametrizedActivation, Softmax
+from hls4ml.model.layers import Activation, BatchNormalization, Dense, DenseBatchnorm, Embedding, PReLU, ParametrizedActivation, Softmax
 from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
 
 # Dense templates
@@ -28,7 +28,7 @@
 
 class DenseConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__(Dense)
+        super().__init__((Dense, DenseBatchnorm))
         self.template = dense_config_template
 
     def format(self, node):
@@ -41,7 +41,7 @@ def format(self, node):
 
 class DenseFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
-        super().__init__(Dense, include_header=dense_include_list)
+        super().__init__((Dense, DenseBatchnorm), include_header=dense_include_list)
         self.template = dense_function_template
 
     def format(self, node):
2 changes: 1 addition & 1 deletion hls4ml/converters/keras/core.py
@@ -113,7 +113,7 @@ def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader):
 
 @keras_handler('BatchNormalization')
 def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
-    assert 'BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name']
+    assert('BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name'] or 'QDenseBatchnorm' in keras_layer['class_name'])
 
     layer = parse_default_keras_layer(keras_layer, input_names)
11 changes: 11 additions & 0 deletions hls4ml/converters/keras/qkeras_layers.py
@@ -114,3 +114,14 @@ def parse_qconv2dbatchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
     temp_shape = intermediate_shape
     batch_layer, out_shape = parse_batchnorm_layer(keras_layer, input_names, temp_shape, data_reader)
     return {**conv_layer, **batch_layer}, out_shape
+
+@keras_handler('QDenseBatchnorm')
+def parse_qdensebatchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
+    intermediate_shape = list()
+    dense_layer, shape_qdense = parse_qdense_layer(keras_layer, input_names, input_shapes, data_reader)
+    intermediate_shape.append(shape_qdense)
+    temp_shape = intermediate_shape
+    batch_layer, out_shape = parse_batchnorm_layer(keras_layer, input_names, temp_shape, data_reader)
+    # remove n_in from batch_layer so it does not overwrite n_in from dense_layer
+    batch_layer.pop('n_in')
+    return {**dense_layer, **batch_layer}, out_shape
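The merge at the end of the new handler relies on Python's dict-unpacking precedence (later keys win), which is why n_in is popped from batch_layer first. A minimal sketch with made-up shape values (illustrative only, not real hls4ml parser output) shows the effect:

# Illustrative sketch only; the dict contents below are made up.
dense_layer = {'class_name': 'QDenseBatchnorm', 'n_in': 16, 'n_out': 8}
batch_layer = {'n_in': 8, 'n_filt': -1}   # parse_batchnorm_layer sees the Dense output shape

# Without the pop, the right-hand dict would win for the shared key:
assert {**dense_layer, **batch_layer}['n_in'] == 8

# Popping n_in from batch_layer keeps the Dense layer's input size intact:
batch_layer.pop('n_in')
merged = {**dense_layer, **batch_layer}
assert merged['n_in'] == 16 and merged['n_out'] == 8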
46 changes: 46 additions & 0 deletions hls4ml/model/layers.py
@@ -394,6 +394,51 @@ def initialize(self):
         self.add_bias(quantizer=self.get_attr('bias_quantizer'))
 
 
+class DenseBatchnorm(Dense):
+    def _get_folded_weights(self):
+        """
+        Function to get the batchnorm-folded weights.
+        This function converts the weights by folding the batchnorm parameters into
+        the weights of the QDense layer. The high-level equations:
+            W_fold = gamma * W / sqrt(variance + epsilon)
+            bias_fold = gamma * (bias - moving_mean) / sqrt(variance + epsilon) + beta
+        """
+        kernel = self.model.get_weights_data(self.name, 'kernel')
+        bias = self.model.get_weights_data(self.name, 'bias')
+        if bias is None:
+            bias = 0
+
+        # get batchnorm weights and moving stats
+        gamma = self.model.get_weights_data(self.name, 'gamma')
+        beta = self.model.get_weights_data(self.name, 'beta')
+        moving_mean = self.model.get_weights_data(self.name, 'moving_mean')
+        moving_variance = self.model.get_weights_data(self.name, 'moving_variance')
+        # compute the inversion factor once so that division is replaced by multiplication
+        inv = np.reciprocal(np.sqrt(moving_variance + self.get_attr('epsilon')))
+        if gamma is not None:
+            inv *= gamma
+
+        # fold the batchnorm parameters into the dense kernel and bias
+        folded_kernel = inv * kernel
+        folded_bias = inv * (bias - moving_mean) + beta
+
+        return [folded_kernel, folded_bias]
+
+    def initialize(self):
+        super(DenseBatchnorm, self).initialize()
+        folded_weights, folded_bias = self._get_folded_weights()
+        if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
+            self.weights['weight'].data_unquantized = np.transpose(folded_weights)
+            self.weights['weight'].data = self.get_attr('weight_quantizer')(self.weights['weight'].data_unquantized)
+        else:
+            self.weights['weight'].data_unquantized = folded_weights
+            self.weights['weight'].data = self.get_attr('weight_quantizer')(folded_weights)
+        self.weights['bias'].data_unquantized = folded_bias
+        bias_q = self.get_attr('bias_quantizer')
+        if bias_q is not None:
+            self.weights['bias'].data = bias_q(folded_bias)
+
 class Conv1D(Layer):
     _expected_attributes = [
         Attribute('in_width'),
@@ -1269,6 +1314,7 @@ def _initialize_transforms(self):
     'BinaryDense': Dense,
     'TernaryDense': Dense,
     'QDense': Dense,
+    'QDenseBatchnorm': DenseBatchnorm,
     'Conv1D': Conv1D,
     'QConv1D': Conv1D,
     'Conv2D': Conv2D,
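For reference, the folding identity used by _get_folded_weights can be checked numerically. The sketch below is illustrative only: plain NumPy with random data, independent of hls4ml, assuming the standard inference-time batchnorm definition.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))                 # batch of inputs
W = rng.normal(size=(16, 8))                 # dense kernel
b = rng.normal(size=8)                       # dense bias
gamma, beta = rng.normal(size=8), rng.normal(size=8)
mean, var, eps = rng.normal(size=8), rng.uniform(0.5, 2.0, size=8), 1e-3

# Dense followed by batchnorm, evaluated the straightforward way
y_ref = gamma * ((x @ W + b) - mean) / np.sqrt(var + eps) + beta

# Folded weights, following the equations in the docstring above
inv = gamma / np.sqrt(var + eps)
W_fold = inv * W                             # scales each output column
b_fold = inv * (b - mean) + beta
y_fold = x @ W_fold + b_fold

assert np.allclose(y_ref, y_fold)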