
Qonnx binary quant #1292


Open: wants to merge 15 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ docs/autodoc/*
hls4mlprj_*
*~
*.ipynb_checkpoints/
*.bak
2 changes: 1 addition & 1 deletion example-models
11 changes: 11 additions & 0 deletions hls4ml/converters/onnx/core.py
@@ -120,3 +120,14 @@ def parse_quant_layer(node, input_names, input_shapes, graph):
layer['signed'] = bool(get_onnx_attribute(node, 'signed'))

return layer


@onnx_handler('BipolarQuant')
def parse_bipolar_quant_layer(node, input_names, input_shapes, graph):
layer = {}

layer['class_name'] = 'BipolarQuant'
layer['name'] = node.name
layer['inputs'] = input_names
layer['outputs'] = list(node.output)
return layer
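For reference, a QONNX BipolarQuant node carries the tensor as input 0 and the scale as input 1, which is why the handler simply records `input_names`. A minimal sketch of constructing such a node (the tensor names are illustrative, not from this PR):

```python
import onnx

# Hypothetical example node; BipolarQuant is a QONNX custom op, assumed here
# to live in the qonnx.custom_op.general domain.
node = onnx.helper.make_node(
    'BipolarQuant',
    domain='qonnx.custom_op.general',
    inputs=['act', 'scale'],  # input 0: tensor to quantize, input 1: scale
    outputs=['act_q'],
    name='BipolarQuant_0',
)
```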
16 changes: 16 additions & 0 deletions hls4ml/model/layers.py
@@ -421,6 +421,21 @@ def initialize(self):
self.add_output_variable(shape, dims)


class BipolarQuant(Layer): # The QONNX quantization layer
"""
This is a QONNX quantization layer. Optimizations should convert it
before HLS is produced.
"""

_expected_attributes = []

def initialize(self):
inp = self.get_input_variable(self.inputs[0])
shape = inp.shape
dims = inp.dim_names
self.add_output_variable(shape, dims)


class Reshape(Layer):
_expected_attributes = [
Attribute('target_shape', value_type=typing.Sequence),
@@ -1724,6 +1739,7 @@ def initialize(self):
'GarNet': GarNet,
'GarNetStack': GarNetStack,
'Quant': Quant,
'BipolarQuant': BipolarQuant,
'ApplyAlpha': ApplyAlpha,
'BatchNormOnnx': BatchNormOnnx,
'LayerGroup': LayerGroup,
3 changes: 3 additions & 0 deletions hls4ml/model/optimizer/__init__.py
@@ -40,6 +40,9 @@
'fuse_quant_with_constant',
'const_quant_to_const_alpha',
'quant_to_alpha_activation_alpha',
'bipolar_quant_constant_parameters',
'bipolar_quant_to_activation',
'fuse_bipolar_quant_with_constant',
'batch_norm_onnx_constant_parameters',
'constant_batch_norm_fusion',
'merge_two_constants',
133 changes: 133 additions & 0 deletions hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -0,0 +1,133 @@
"""
This file includes optimizations related to BipolarQuant nodes.

"""

import numpy as np

from hls4ml.model.layers import Activation, BipolarQuant, Constant
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.quantizers import BinaryQuantizer
from hls4ml.model.types import XnorPrecisionType


class BipolarQuantConstantParameters(OptimizerPass):
"""Remove Constant from the BipolarQaunt node parameters (but not input[0])"""

def match(self, node):
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 2
and (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
)

return is_match

def transform(self, model, node):
"""
Remove Constant from the BipolarQuant node parameters (but not input[0])
"""
if node.get_input_node(node.inputs[1]):
scale_node = node.get_input_node(node.inputs[1])
if isinstance(scale_node, Constant):
node.set_attr('scale', scale_node.get_attr('value'))
Contributor:

We don't seem to handle the case when scale != 1. Ideally we should be able to extract ApplyAlpha scales in such a case and propagate them up and down. I think basic support can be fairly straightforwardly added, in the style of the Quant support. (If we don't support scale != 1, we should catch those cases and exit gracefully, with an error message.)

Author:

I checked the BinaryQuantizer code and it does not define a scaling factor, meaning this can only work for scale factors of 1. Furthermore, this whole optimizer pass becomes irrelevant, so I will delete it.
What does ApplyAlpha do? I am not familiar with it.

Contributor:

So a quant layer with a scale and/or zero offset really means scale/shift, then quantize, then unscale/unshift. The ApplyAlpha layers are scale-and-shift layers in the hls4ml IR. When a quant node is applied to a weight, the initial scaling/shifting can actually be done to the weights directly (assuming they are constant and not updatable). Otherwise, the hope is that the scaling and unscaling can be moved around the graph to where the implementation is easiest. There are optimizers that already exist for that.
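To make the decomposition above concrete, here is a minimal standalone sketch (a hypothetical helper, not hls4ml API) of what a scaled quant node computes and where the two ApplyAlpha stages would sit:

```python
import numpy as np

# Sketch of: scale/shift, then quantize, then unscale/unshift, i.e.
#   y = s * (quantize(x / s + z) - z)
# The two affine stages map onto ApplyAlpha layers in the hls4ml IR.
def quant_with_scale(x, s, z=0.0, quantize=np.round):
    pre = x / s + z     # first ApplyAlpha: scale by 1/s, shift by z
    q = quantize(pre)   # the actual quantizer
    return s * (q - z)  # second ApplyAlpha: unshift and unscale back
```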

node.inputs[1] = ''
model.remove_node(scale_node)

node.inputs = [inp for inp in node.inputs if inp]
if len(node.inputs) != 1:
raise RuntimeError("hls4ml only supports constant scale")

return True


class BipolarQuantToActivation(OptimizerPass):
"""
This is for the case when scale is po2.
It is a 1:1 transformation of a BipolarQuant to an Activation.
As an optimization, this is not called when the input is constant.
"""

def match(self, node):
# only matches after the other inputs are already folded
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 1
and not isinstance(node.get_input_node(node.inputs[0]), Constant)
)

# Only match if the scale is po2
if is_match: # to make sure this is a quant node with inputs
scale = node.get_attr('scale')
scale_unit_or_po2 = (scale == 1.0).all()
# This optimization only works if all scales are the same
if np.all(scale[0] == scale):
mantissa, _ = np.frexp(scale[0])
scale_unit_or_po2 = mantissa == 0.5
is_match = scale_unit_or_po2

return is_match
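As an aside, the power-of-two test above relies on np.frexp, which decomposes x as mantissa * 2**exponent with the mantissa in [0.5, 1), so a positive scale is an exact power of two exactly when its mantissa equals 0.5. A standalone illustration (not part of the PR):

```python
import numpy as np

# np.frexp(x) returns (mantissa, exponent) with x == mantissa * 2**exponent
# and mantissa in [0.5, 1); a positive float is an exact power of two
# iff its mantissa is exactly 0.5.
for scale in [1.0, 2.0, 0.25, 3.0]:
    mantissa, _ = np.frexp(scale)
    print(scale, mantissa == 0.5)  # 1.0 True, 2.0 True, 0.25 True, 3.0 False
```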

def transform(self, model, node):
"""
Change BipolarQuant node to Activation
"""
precision = XnorPrecisionType()
quantizer = BinaryQuantizer(bits=1)

attributes = {'activation': 'linear', 'quantizer': quantizer}

# update the configuration
config = model.config.get_layer_config(node)
prec_config = config.setdefault('Precision', {})
prec_config['result'] = str(precision)
new_name = f'{node.name}_act'
model.config.set_name_config(new_name, config)
model.config.parse_name_config(new_name, config)
print(f"Node {new_name} inputs: {[node.inputs[0]]}, outputs: {list(node.outputs)}.")
new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], list(node.outputs))
model.replace_node(node, new_node)
return True
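For context on what this replacement activation does downstream: BinaryQuantizer(bits=1) paired with XnorPrecisionType corresponds to hls4ml's 1-bit XNOR encoding, where, to our understanding, a stored bit 0 represents -1 and a stored bit 1 represents +1. A rough standalone sketch of that mapping (an assumption for illustration, not the hls4ml source):

```python
import numpy as np

# Rough sketch of XNOR-style binarization (assumed encoding, not hls4ml code):
# the sign of the input selects the stored bit; 0 encodes -1, 1 encodes +1.
def binarize_xnor(data):
    bits = np.where(data > 0, 1, 0)          # 1-bit stored representation
    values = np.where(bits == 1, 1.0, -1.0)  # real value each bit represents
    return bits, values
```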


class FuseBipolarQuantWithConstant(OptimizerPass):
"""
This is for the case when scale is po2.
"""

def match(self, node):

# only matches after the other inputs are already folded
# and scale is unit or a power of two
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 1
and isinstance(node.get_input_node(node.inputs[0]), Constant)
Contributor:

Same here, does this really work if scale != 1? If it doesn't, the matching criteria should change.

)

# Only match if the scale is po2
if is_match: # to make sure this is a quant node with inputs
scale = node.get_attr('scale')
scale_unit_or_po2 = (scale == 1.0).all()
# This optimization only works if all scales are the same
if np.all(scale[0] == scale):
mantissa, _ = np.frexp(scale[0])
scale_unit_or_po2 = mantissa == 0.5
is_match = scale_unit_or_po2

return is_match

def transform(self, model, node):
"""
Fuse BipolarQuant with Constant.
"""
precision = XnorPrecisionType()
quantizer = BinaryQuantizer(bits=1)

const_node = node.get_input_node(node.inputs[0])
const_node.set_attr('quantizer', quantizer)
const_node.get_output_variable().type.precision = precision

# remove the Quant node
model.remove_node(node)
return True
108 changes: 108 additions & 0 deletions test/pytest/test_qonnx.py
@@ -1,8 +1,10 @@
import copy
import os
import urllib
from pathlib import Path

import numpy as np
import onnx
import pytest
import qonnx.core.onnx_exec as oxe
import qonnx.util.cleanup
@@ -12,6 +14,7 @@
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
from qonnx.transformation.gemm_to_matmul import GemmToMatMul
from qonnx.util.cleanup import cleanup_model

import hls4ml

@@ -231,6 +234,69 @@ def conv2d_small_mp_keras_model():
return model


@pytest.fixture(scope='module')
def bnn_fc_small_qonnx_model():
"""
Load a small binarized model of a single fully connected layer.
"""
dl_file = str(example_model_path / "onnx/bnn_model_fc_1layer.onnx")
assert os.path.isfile(dl_file)

model = ModelWrapper(dl_file)
model = cleanup_model(model)
model = model.transform(GemmToMatMul()) # ishape = (1, 3)
model = qonnx.util.cleanup.cleanup_model(model)
return model


@pytest.fixture(scope='module')
def bnn_fc_small_qonnx_model_scale_nonunit(bnn_fc_small_qonnx_model):
"""
Use scale factors of 0.5 to see if that works.
This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors.
"""

model = copy.deepcopy(bnn_fc_small_qonnx_model)  # is copying necessary?
new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [0.5])
new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [0.5])
old_iscale = old_wscale = None
for init in model.graph.initializer:
if init.name == "BipolarQuant_0_param0":
old_iscale = init
elif init.name == "BipolarQuant_1_param1":
old_wscale = init
model.graph.initializer.remove(old_iscale)
model.graph.initializer.remove(old_wscale)
model.graph.initializer.append(new_iscale)
model.graph.initializer.append(new_wscale)
model = qonnx.util.cleanup.cleanup_model(model)
return model


@pytest.fixture(scope='module')
def bnn_fc_small_qonnx_model_scale_nonunit2(bnn_fc_small_qonnx_model):
"""
Use po2 scale factors to see if that works.
This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors.
"""

model = copy.deepcopy(bnn_fc_small_qonnx_model)  # is copying necessary?
new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [2])
new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [4])
old_iscale = old_wscale = None
for init in model.graph.initializer:
if init.name == "BipolarQuant_0_param0":
old_iscale = init
elif init.name == "BipolarQuant_1_param1":
old_wscale = init
model.graph.initializer.remove(old_iscale)
model.graph.initializer.remove(old_wscale)
model.graph.initializer.append(new_iscale)
model.graph.initializer.append(new_wscale)
model = qonnx.util.cleanup.cleanup_model(model)
return model


# The actual tests


@@ -428,3 +494,45 @@ def test_simple_model(model_name, io_type, backend, request):
y_hls4ml = hls_model.predict(X)

np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1)


@pytest.mark.parametrize(
'model_name',
['bnn_fc_small_qonnx_model', 'bnn_fc_small_qonnx_model_scale_nonunit', 'bnn_fc_small_qonnx_model_scale_nonunit2'],
)
@pytest.mark.parametrize('backend', ['Vitis'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
def test_bnn(model_name, io_type, backend, request):
"Checks if a basic binarized model works correctly."
qonnx_model = request.getfixturevalue(model_name)

config = hls4ml.utils.config.config_from_onnx_model(
qonnx_model, granularity='name', backend=backend, default_precision='fixed<16,6>'
)
hls_model = hls4ml.converters.convert_from_onnx_model(
qonnx_model,
output_dir=str(test_root_path / f'hls4mlprj_onnx_{model_name}_{io_type}_{backend}'),
io_type=io_type,
backend=backend,
hls_config=config,
)
hls_model.compile()

data_x = np.array(
[
[[+1, +1, +1]],
[[+1, +1, -1]],
[[+1, -1, +1]],
[[+1, -1, -1]],
[[-1, +1, +1]],
[[-1, +1, -1]],
[[-1, -1, +1]],
[[-1, -1, -1]],
],
dtype=np.float32,
)
for x in data_x:
idict = {qonnx_model.graph.input[0].name: x}
y_qonnx = oxe.execute_onnx(qonnx_model, idict)[qonnx_model.graph.output[0].name]
y_hls4ml = hls_model.predict(x)
np.testing.assert_array_equal(y_qonnx.ravel(), y_hls4ml.ravel())
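As a side note, the eight sign patterns in data_x could equivalently be generated programmatically (an alternative sketch, not what the PR uses):

```python
import itertools

import numpy as np

# All 2**3 sign combinations for a 3-element input, shaped (8, 1, 3).
data_x = np.array(
    [[list(signs)] for signs in itertools.product([+1.0, -1.0], repeat=3)],
    dtype=np.float32,
)
```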