From d09391306e011c3ac4b67dcd251f4c2cfa31b0f8 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Mon, 28 Apr 2025 16:16:50 +0200
Subject: [PATCH 01/13] Added support for BipolarQuant. It's converted to
 BinaryQuant in hls4ml.

---
 hls4ml/converters/onnx/core.py            |  12 ++
 hls4ml/model/layers.py                    |  16 ++
 hls4ml/model/optimizer/__init__.py        |   3 +
 .../optimizer/passes/bipolar_quant_opt.py | 155 ++++++++++++++++++
 4 files changed, 186 insertions(+)
 create mode 100644 hls4ml/model/optimizer/passes/bipolar_quant_opt.py

diff --git a/hls4ml/converters/onnx/core.py b/hls4ml/converters/onnx/core.py
index 8ad851426d..ead34ee777 100644
--- a/hls4ml/converters/onnx/core.py
+++ b/hls4ml/converters/onnx/core.py
@@ -120,3 +120,15 @@ def parse_quant_layer(node, input_names, input_shapes, graph):
     layer['signed'] = bool(get_onnx_attribute(node, 'signed'))

     return layer
+
+
+@onnx_handler('BipolarQuant')
+def parse_bipolar_quant_layer(node, input_names, input_shapes, graph):
+    layer = {}
+
+    layer['class_name'] = 'BipolarQuant'
+    layer['name'] = node.name
+    layer['inputs'] = input_names
+    layer['outputs'] = list(node.output)
+    return layer
+
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index 0efeaafa3d..a63730f139 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -403,6 +403,21 @@ def initialize(self):
         self.add_output_variable(shape, dims)


+class BipolarQuant(Layer):  # The QONNX quantization layer
+    """
+    This is a QONNX quantization layer. Optimizations should convert it
+    before HLS is produced.
+    """
+
+    _expected_attributes = []
+
+    def initialize(self):
+        inp = self.get_input_variable(self.inputs[0])
+        shape = inp.shape
+        dims = inp.dim_names
+        self.add_output_variable(shape, dims)
+
+
 class Reshape(Layer):
     _expected_attributes = [
         Attribute('target_shape', value_type=typing.Sequence),
@@ -1706,6 +1721,7 @@ def initialize(self):
     'GarNet': GarNet,
     'GarNetStack': GarNetStack,
     'Quant': Quant,
+    'BipolarQuant': BipolarQuant,
     'ApplyAlpha': ApplyAlpha,
     'BatchNormOnnx': BatchNormOnnx,
     'LayerGroup': LayerGroup,
diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py
index c474970448..8b25805fc7 100644
--- a/hls4ml/model/optimizer/__init__.py
+++ b/hls4ml/model/optimizer/__init__.py
@@ -40,6 +40,9 @@
     'fuse_quant_with_constant',
     'const_quant_to_const_alpha',
     'quant_to_alpha_activation_alpha',
+    'bipolar_quant_constant_parameters',
+    'bipolar_quant_to_activation',
+    'fuse_bipolar_quant_with_constant',
     'batch_norm_onnx_constant_parameters',
     'constant_batch_norm_fusion',
     'merge_two_constants',
diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
new file mode 100644
index 0000000000..894f432c2d
--- /dev/null
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -0,0 +1,155 @@
+"""
+This file includes optimizations related to BipolarQuant nodes.
+
+As a first step, QuantConstantParameters converts the extra inputs to attributes.
+
+The next step differs between the case of (1) (positive) power-of-2 scale and zero offset, or (2) other cases. In the first
+case no explicit scaling is required, so a Quant node logically becomes a linear activation. (Cases when the scale is a
+power of 2 not equal to one are implicitly scaled with fixed precision types.) When the activation is applied to a constant
+weight, the activation is immediately merged with the weight, quantizing the weights. In case (2), we need to explicitly
+scale and unscale, so the Quant node becomes 3 nodes, an ApplyAlpha node to apply a scale/shift, a Linear node to apply the
+quantization, and another ApplyAlpha to unscale/shift. We depend on optimization steps to move the unscaling ApplyAlpha
+down as needed so that we can do integer or fixed-point calculations. When the Quant is applied to a weight, the scaling
+and Linear nodes are immediately merged into the Constant.
+
+"""
+
+import copy
+import math  # prefer to use math.ceil for scalar values
+
+import numpy as np
+
+from hls4ml.model.layers import Activation, ApplyAlpha, Constant, BipolarQuant
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.quantizers import BinaryQuantizer
+from hls4ml.model.types import XnorPrecisionType
+
+_ALSO_MATCH_PO2 = True
+
+
+class BipolarQuantConstantParameters(OptimizerPass):
+    """Remove Constant from the Quant node parameters (but not input[0])"""
+
+    def match(self, node):
+        is_match = (
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 2
+            and (
+                (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
+            )
+        )
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Remove Constant from the Quant node parameters (but not input[0])
+        """
+        if node.get_input_node(node.inputs[1]):
+            scale_node = node.get_input_node(node.inputs[1])
+            if isinstance(scale_node, Constant):
+                node.set_attr('scale', scale_node.get_attr('value'))
+                node.inputs[1] = ''
+                model.remove_node(scale_node)
+
+        node.inputs = [inp for inp in node.inputs if inp]
+        if len(node.inputs) != 1:
+            raise RuntimeError("hls4ml only supports constant scale")
+
+        return True
+
+
+class BipolarQuantToActivation(OptimizerPass):
+    """
+    This is for the case when scale is a (positive) power of 2 and zeropt is 0. It is a 1:1 transformation of
+    a BipolarQuant to an Activation.
+
+    As an optimization, this is not called when the input is constant.
+    """
+
+    def match(self, node):
+        # only matches after the other inputs are already folded
+
+        is_match = (
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 1
+            and not isinstance(node.get_input_node(node.inputs[0]), Constant)
+        )
+
+        # Only match if the scale is power of 2 and the zero-point is 0s
+        if is_match:  # to make sure this is a quant node with inputs
+            scale = node.get_attr('scale')
+            # check if scale is ones-like or a power of two
+            scale_unit_or_po2 = (scale == np.ones_like(scale)).all()
+            is_match = scale_unit_or_po2
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Change quant node to Activation
+        """
+        scale = node.get_attr('scale')
+        assert np.all(scale == 1.0)  # TODO: Is this required?
+
+        precision = XnorPrecisionType()
+        quantizer = BinaryQuantizer(bits=1)
+
+        attributes = {'activation': 'linear', 'quantizer': quantizer}
+
+        # update the configuration
+        config = model.config.get_layer_config(node)
+        prec_config = config.setdefault('Precision', {})
+        prec_config['result'] = str(precision)
+        new_name = f'{node.name}_act'
+        model.config.set_name_config(new_name, config)
+        model.config.parse_name_config(new_name, config)
+
+        new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [x for x in node.outputs])
+        model.replace_node(node, new_node)
+        return True
+
+
+class FuseBipolarQuantWithConstant(OptimizerPass):
+    """
+    This is for the case when scale is a positive power of 2 and zeropt is 0.
+    """
+
+    def match(self, node):
+        # only matches after the other inputs are already folded
+        is_match = (
+            isinstance(node, BipolarQuant) and len(node.inputs) == 1 and isinstance(node.get_input_node(node.inputs[0]), Constant)
+        )
+
+        # Only match if the scale is power of 2 and the zero-point is 0s
+        if is_match:  # to make sure this is a quant node with inputs
+            scale = node.get_attr('scale')
+
+            # check if scale is ones-like or a power of two
+            scale_unit_or_po2 = (scale == np.ones_like(scale)).all()
+            is_match = scale_unit_or_po2
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Fuse Quant with Constant.
+        """
+
+        scale = node.get_attr('scale')
+        assert np.all(scale == 1.0)  # TODO: Is this required?
+
+        precision = XnorPrecisionType()
+        quantizer = BinaryQuantizer(bits=1)
+
+        const_node = node.get_input_node(node.inputs[0])
+        const_node.set_attr('quantizer', quantizer)
+        const_node.get_output_variable().type.precision = precision
+
+        # Should we update the configuration to reflect the new precision? I don't think it's necessary
+
+        # remove the Quant node
+        model.remove_node(node)
+
+        return True
+

From 4b661802852a428352845d623fa2560b6da23b50 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Wed, 7 May 2025 10:09:45 +0200
Subject: [PATCH 02/13] Added binarized qonnx model for testing the binary
 quant transformation

---
 test/pytest/bnn_model_fc_1layer.onnx | Bin 0 -> 2321 bytes
 test/pytest/test_qonnx.py            |  40 ++++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 1 deletion(-)
 create mode 100644 test/pytest/bnn_model_fc_1layer.onnx

diff --git a/test/pytest/bnn_model_fc_1layer.onnx b/test/pytest/bnn_model_fc_1layer.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..72e16fae4e1fbe62024250b285846fd2094ec111
GIT binary patch
literal 2321
[base85 payload truncated; the test/pytest/test_qonnx.py hunk of this patch and the opening header lines of the next commit are missing here]

Date: Wed, 7 May 2025 10:57:43 +0200
Subject: [PATCH 03/13] Pre-commit fixes

---
 .gitignore                                |  1 +
 hls4ml/converters/onnx/core.py            |  1 -
 .../optimizer/passes/bipolar_quant_opt.py | 16 ++++------
 test/pytest/test_qonnx.py                 | 29 +++++++++++--------
 4 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/.gitignore b/.gitignore
index 22c8ff685b..8fb87927ce 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ docs/autodoc/*
 hls4mlprj_*
 *~
 *.ipynb_checkpoints/
+*.bak
diff --git a/hls4ml/converters/onnx/core.py b/hls4ml/converters/onnx/core.py
index ead34ee777..efb5f20c4a 100644
--- a/hls4ml/converters/onnx/core.py
+++ b/hls4ml/converters/onnx/core.py
@@ -131,4 +131,3 @@ def parse_bipolar_quant_layer(node, input_names, input_shapes, graph):
     layer['inputs'] = input_names
     layer['outputs'] = list(node.output)
     return layer
-
diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index 894f432c2d..f9662bda60 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -14,12 +14,9 @@

 """

-import copy
-import math  # prefer to use math.ceil for scalar values
-
 import numpy as np

-from hls4ml.model.layers import Activation, ApplyAlpha, Constant, BipolarQuant
+from hls4ml.model.layers import Activation, BipolarQuant, Constant
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.quantizers import BinaryQuantizer
 from hls4ml.model.types import XnorPrecisionType
@@ -34,9 +31,7 @@ def match(self, node):
         is_match = (
             isinstance(node, BipolarQuant)
             and len(node.inputs) == 2
-            and (
-                (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
-            )
+            and (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
         )

         return is_match
@@ -118,7 +113,9 @@ class FuseBipolarQuantWithConstant(OptimizerPass):
     def match(self, node):
         # only matches after the other inputs are already folded
         is_match = (
-            isinstance(node, BipolarQuant) and len(node.inputs) == 1 and isinstance(node.get_input_node(node.inputs[0]), Constant)
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 1
+            and isinstance(node.get_input_node(node.inputs[0]), Constant)
         )

         # Only match if the scale is power of 2 and the zero-point is 0s
@@ -138,7 +135,7 @@ def transform(self, model, node):

         scale = node.get_attr('scale')
         assert np.all(scale == 1.0)  # TODO: Is this required?
-
+
         precision = XnorPrecisionType()
         quantizer = BinaryQuantizer(bits=1)

@@ -152,4 +149,3 @@ def transform(self, model, node):
         model.remove_node(node)

         return True
-
diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py
index bed762f053..05edfe2b09 100644
--- a/test/pytest/test_qonnx.py
+++ b/test/pytest/test_qonnx.py
@@ -13,6 +13,7 @@
 from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
 from qonnx.transformation.gemm_to_matmul import GemmToMatMul
 from qonnx.util.cleanup import cleanup_model
+
 import hls4ml

 test_root_path = Path(__file__).parent
@@ -437,7 +438,7 @@ def test_bnn(io_type, backend):
     test_dir = os.path.dirname(os.path.abspath(__file__))
     qonnx_model = ModelWrapper(f'{test_dir}/bnn_model_fc_1layer.onnx')
     qonnx_model = cleanup_model(qonnx_model)
-    qonnx_model = qonnx_model.transform(GemmToMatMul()) # ishape = (1, 3)
+    qonnx_model = qonnx_model.transform(GemmToMatMul())  # ishape = (1, 3)
     qonnx_model = qonnx.util.cleanup.cleanup_model(qonnx_model)
     config = hls4ml.utils.config.config_from_onnx_model(
         qonnx_model, granularity='name', backend=backend, default_precision='fixed<16,6>'
@@ -451,18 +452,22 @@ def test_bnn(io_type, backend):
         hls_config=config,
     )
     hls_model.compile()
-
-    X = np.array([[[+1, +1, +1]],
-                  [[+1, +1, -1]],
-                  [[+1, -1, +1]],
-                  [[-1, -1, -1]],
-                  [[-1, +1, +1]],
-                  [[-1, +1, -1]],
-                  [[-1, -1, +1]],
-                  [[-1, -1, -1]],
-                  ], dtype=np.float32)
+
+    X = np.array(
+        [
+            [[+1, +1, +1]],
+            [[+1, +1, -1]],
+            [[+1, -1, +1]],
+            [[-1, -1, -1]],
+            [[-1, +1, +1]],
+            [[-1, +1, -1]],
+            [[-1, -1, +1]],
+            [[-1, -1, -1]],
+        ],
+        dtype=np.float32,
+    )
     for x in X:
-        idict = {qonnx_model.graph.input[0].name: X[0]}
+        idict = {qonnx_model.graph.input[0].name: x}
         y_qonnx = oxe.execute_onnx(qonnx_model, idict)[qonnx_model.graph.output[0].name]
         y_hls4ml = hls_model.predict(X)
         np.array_equal(y_qonnx.ravel(), y_hls4ml.ravel())

From 6a74bfbc76fb48e7a18365f40b3eda3c50fa18df Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 8 May 2025 10:46:54 +0200
Subject: [PATCH 04/13] Removed BipolarQuantConstantParameters, because such
 an optimization is not applicable for BinaryQuantizer since it does not
 have a scaling factor

---
 .../optimizer/passes/bipolar_quant_opt.py | 30 ------------------
 1 file changed, 30 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index f9662bda60..b9071bb7be 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -24,36 +24,6 @@
 _ALSO_MATCH_PO2 = True


-class BipolarQuantConstantParameters(OptimizerPass):
-    """Remove Constant from the Quant node parameters (but not input[0])"""
-
-    def match(self, node):
-        is_match = (
-            isinstance(node, BipolarQuant)
-            and len(node.inputs) == 2
-            and (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
-        )
-
-        return is_match
-
-    def transform(self, model, node):
-        """
-        Remove Constant from the Quant node parameters (but not input[0])
-        """
-        if node.get_input_node(node.inputs[1]):
-            scale_node = node.get_input_node(node.inputs[1])
-            if isinstance(scale_node, Constant):
-                node.set_attr('scale', scale_node.get_attr('value'))
-                node.inputs[1] = ''
-                model.remove_node(scale_node)
-
-        node.inputs = [inp for inp in node.inputs if inp]
-        if len(node.inputs) != 1:
-            raise RuntimeError("hls4ml only supports constant scale")
-
-        return True
-
-
 class BipolarQuantToActivation(OptimizerPass):
     """
     This is for the case when scale is a (positive) power of 2 and zeropt is 0. It is a 1:1 transformation of

From 2dfdb2585c72a15efde2993475016aa046dd035a Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 8 May 2025 11:07:22 +0200
Subject: [PATCH 05/13] Limited FuseBipolarQuantWithConstant to only support
 scale factors of 1

---
 .../optimizer/passes/bipolar_quant_opt.py | 20 ++++---------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index b9071bb7be..1fccf05b7e 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -77,35 +77,25 @@ def transform(self, model, node):

 class FuseBipolarQuantWithConstant(OptimizerPass):
     """
-    This is for the case when scale is a positive power of 2 and zeropt is 0.
+    This is for the case when scale is 1 and zeropt is 0.
     """

     def match(self, node):
+        scale = node.get_attr('scale')
         # only matches after the other inputs are already folded
+        # and scale is unit
         is_match = (
             isinstance(node, BipolarQuant)
             and len(node.inputs) == 1
             and isinstance(node.get_input_node(node.inputs[0]), Constant)
+            and (scale == 1.0).all()
         )
-
-        # Only match if the scale is power of 2 and the zero-point is 0s
-        if is_match:  # to make sure this is a quant node with inputs
-            scale = node.get_attr('scale')
-
-            # check if scale is ones-like or a power of two
-            scale_unit_or_po2 = (scale == np.ones_like(scale)).all()
-            is_match = scale_unit_or_po2
-
         return is_match

     def transform(self, model, node):
         """
         Fuse Quant with Constant.
         """
-
-        scale = node.get_attr('scale')
-        assert np.all(scale == 1.0)  # TODO: Is this required?
-
+        precision = XnorPrecisionType()
         quantizer = BinaryQuantizer(bits=1)

@@ -113,8 +103,6 @@ def transform(self, model, node):
         const_node.set_attr('quantizer', quantizer)
         const_node.get_output_variable().type.precision = precision

-        # Should we update the configuration to reflect the new precision? I don't think it's necessary
-
         # remove the Quant node
         model.remove_node(node)

From 89e2136f18590518b4aaa669f939e26804895e8f Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 8 May 2025 11:09:52 +0200
Subject: [PATCH 06/13] Removed bipolar_quant_constant_parameters from list of
 optimizations, since it was already removed.
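Note on the mechanism behind this entry: each string in the 'convert' flow list in hls4ml/model/optimizer/__init__.py must resolve to a registered OptimizerPass, so a stale name whose class was deleted breaks flow setup. A rough sketch of the match/transform contract these entries rely on (illustrative only, not hls4ml's actual scheduler; the run_pass helper and its loop are assumptions):

    # Illustrative sketch, not hls4ml's real flow runner. A registered pass
    # exposes match()/transform(); transform() returning True signals that the
    # graph changed, so matching restarts over a fresh node list.
    def run_pass(model, opt):
        changed = True
        while changed:
            changed = False
            for node in list(model.get_layers()):
                if opt.match(node):
                    changed = opt.transform(model, node)
                    if changed:
                        break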
---
 hls4ml/model/optimizer/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py
index 8b25805fc7..36cc45b155 100644
--- a/hls4ml/model/optimizer/__init__.py
+++ b/hls4ml/model/optimizer/__init__.py
@@ -40,7 +40,6 @@
     'fuse_quant_with_constant',
     'const_quant_to_const_alpha',
     'quant_to_alpha_activation_alpha',
-    'bipolar_quant_constant_parameters',
     'bipolar_quant_to_activation',
     'fuse_bipolar_quant_with_constant',
     'batch_norm_onnx_constant_parameters',

From 76968b52bceb8fd8b06b2c356e4681e5dc0d4d3b Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 8 May 2025 11:26:35 +0200
Subject: [PATCH 07/13] Modified the optimizations to only consider transform
 when scaling factor is unit.

---
 .../optimizer/passes/bipolar_quant_opt.py | 31 +++++++++----------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index 1fccf05b7e..334a9440b7 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -14,8 +14,6 @@

 """

-import numpy as np
-
 from hls4ml.model.layers import Activation, BipolarQuant, Constant
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.quantizers import BinaryQuantizer
@@ -26,27 +24,24 @@

 class BipolarQuantToActivation(OptimizerPass):
     """
-    This is for the case when scale is a (positive) power of 2 and zeropt is 0. It is a 1:1 transformation of
-    a BipolarQuant to an Activation.
-
+    This is for the case when scale is a (positive) 1 and zeropt is 0.
+    It is a 1:1 transformation of a BipolarQuant to an Activation.
     As an optimization, this is not called when the input is constant.
     """

     def match(self, node):
         # only matches after the other inputs are already folded
-
         is_match = (
             isinstance(node, BipolarQuant)
             and len(node.inputs) == 1
             and not isinstance(node.get_input_node(node.inputs[0]), Constant)
         )

-        # Only match if the scale is power of 2 and the zero-point is 0s
+        # Only match if the scale is 1 and the zero-point is 0s
         if is_match:  # to make sure this is a quant node with inputs
             scale = node.get_attr('scale')
-            # check if scale is ones-like or a power of two
-            scale_unit_or_po2 = (scale == np.ones_like(scale)).all()
-            is_match = scale_unit_or_po2
+            scale_unit = (scale == 1.0).all()
+            is_match = scale_unit

         return is_match

@@ -54,9 +49,6 @@ def transform(self, model, node):
         """
         Change quant node to Activation
         """
-        scale = node.get_attr('scale')
-        assert np.all(scale == 1.0)  # TODO: Is this required?
-
         precision = XnorPrecisionType()
         quantizer = BinaryQuantizer(bits=1)

@@ -70,7 +62,7 @@ def transform(self, model, node):
         model.config.set_name_config(new_name, config)
         model.config.parse_name_config(new_name, config)

-        new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [x for x in node.outputs])
+        new_node = model.make_node(Activation, new_name, attributes, list(node.inputs[0]), list(node.outputs))
         model.replace_node(node, new_node)
         return True

@@ -81,15 +73,21 @@ class FuseBipolarQuantWithConstant(OptimizerPass):
     """

     def match(self, node):
-        scale = node.get_attr('scale')
+
         # only matches after the other inputs are already folded
         # and scale is unit
         is_match = (
             isinstance(node, BipolarQuant)
             and len(node.inputs) == 1
             and isinstance(node.get_input_node(node.inputs[0]), Constant)
-            and (scale == 1.0).all()
         )
+
+        # Only match if the scale is 1 and the zero-point is 0s
+        if is_match:  # to make sure this is a quant node with inputs
+            scale = node.get_attr('scale')
+            scale_unit = (scale == 1.0).all()
+            is_match = scale_unit
+
         return is_match

     def transform(self, model, node):
@@ -105,5 +103,4 @@ def transform(self, model, node):

         # remove the Quant node
         model.remove_node(node)
-
         return True

From 8a20361e7b986d514c06904fe788cbfd8ede6b1c Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 8 May 2025 12:49:09 +0200
Subject: [PATCH 08/13] Removed left-over docs from copying.

---
 hls4ml/model/optimizer/passes/bipolar_quant_opt.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index 334a9440b7..de250567c9 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -1,17 +1,6 @@
 """
 This file includes optimizations related to BipolarQuant nodes.

-As a first step, QuantConstantParameters converts the extra inputs to attributes.
-
-The next step differs between the case of (1) (positive) power-of-2 scale and zero offset, or (2) other cases. In the first
-case no explicit scaling is required, so a Quant node logically becomes a linear activation. (Cases when the scale is a
-power of 2 not equal to one are implicitly scaled with fixed precision types.) When the activation is applied to a constant
-weight, the activation is immediately merged with the weight, quantizing the weights. In case (2), we need to explicitly
-scale and unscale, so the Quant node becomes 3 nodes, an ApplyAlpha node to apply a scale/shift, a Linear node to apply the
-quantization, and another ApplyAlpha to unscale/shift. We depend on optimization steps to move the unscaling ApplyAlpha
-down as needed so that we can do integer or fixed-point calculations. When the Quant is applied to a weight, the scaling
-and Linear nodes are immediately merged into the Constant.
-
 """

 from hls4ml.model.layers import Activation, BipolarQuant, Constant

From 7bd4d94e4db724901b592b4c166d6de7df9993e3 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Fri, 9 May 2025 09:15:46 +0200
Subject: [PATCH 09/13] Revert "Removed BipolarQuantConstantParameters"

This reverts commit 6a74bfbc76fb48e7a18365f40b3eda3c50fa18df.
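The pass turns out to be needed after all: in QONNX, a BipolarQuant node carries its scale as a second graph input rather than as a node attribute, and the passes above only match once len(node.inputs) == 1, so the Constant feeding input[1] must be folded first. A hedged sketch of the on-the-wire form this pass folds away (the tensor names and the domain string are assumptions here, not taken from this patch):

    # Sketch of a QONNX BipolarQuant node: input[0] is the data, input[1] is
    # the scale constant that BipolarQuantConstantParameters folds into a
    # 'scale' attribute on the hls4ml layer.
    import onnx
    from onnx import TensorProto, helper

    scale_init = helper.make_tensor('scale', TensorProto.FLOAT, [1], [1.0])
    bq_node = helper.make_node(
        'BipolarQuant',
        inputs=['x', 'scale'],
        outputs=['y'],
        domain='qonnx.custom_op.general',  # assumed QONNX custom-op domain
    )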
---
 .../optimizer/passes/bipolar_quant_opt.py | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index de250567c9..42927c6ddf 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -11,6 +11,36 @@
 _ALSO_MATCH_PO2 = True


+class BipolarQuantConstantParameters(OptimizerPass):
+    """Remove Constant from the Quant node parameters (but not input[0])"""
+
+    def match(self, node):
+        is_match = (
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 2
+            and (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
+        )
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Remove Constant from the Quant node parameters (but not input[0])
+        """
+        if node.get_input_node(node.inputs[1]):
+            scale_node = node.get_input_node(node.inputs[1])
+            if isinstance(scale_node, Constant):
+                node.set_attr('scale', scale_node.get_attr('value'))
+                node.inputs[1] = ''
+                model.remove_node(scale_node)
+
+        node.inputs = [inp for inp in node.inputs if inp]
+        if len(node.inputs) != 1:
+            raise RuntimeError("hls4ml only supports constant scale")
+
+        return True
+
+
 class BipolarQuantToActivation(OptimizerPass):
     """
     This is for the case when scale is a (positive) 1 and zeropt is 0.

From 144d4273ba530e891bf9e3309166fa1acd92ec50 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Fri, 9 May 2025 09:17:25 +0200
Subject: [PATCH 10/13] Revert "Removed bipolar_quant_constant_parameters from
 list of optimizations, since it was already removed."

This reverts commit 89e2136f18590518b4aaa669f939e26804895e8f.

---
 hls4ml/model/optimizer/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py
index 36cc45b155..8b25805fc7 100644
--- a/hls4ml/model/optimizer/__init__.py
+++ b/hls4ml/model/optimizer/__init__.py
@@ -40,6 +40,7 @@
     'fuse_quant_with_constant',
     'const_quant_to_const_alpha',
     'quant_to_alpha_activation_alpha',
+    'bipolar_quant_constant_parameters',
     'bipolar_quant_to_activation',
     'fuse_bipolar_quant_with_constant',
     'batch_norm_onnx_constant_parameters',

From 8d6aae28cd185888a6c3c4fa36c75ac2446ee841 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Fri, 9 May 2025 09:53:10 +0200
Subject: [PATCH 11/13] Removed onnx model from repo. Using example-models for
 that instead. Removed irrelevant _ALSO_MATCH_PO2 constant.
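With the .onnx file deleted from the repo, the test fixture presumably loads the same model out of the example-models submodule instead. A hedged sketch of what such a fixture body could look like (the submodule layout and file name are assumptions; only ModelWrapper and cleanup_model are taken from the test file itself):

    # Sketch only: load the 1-layer BNN model from the example-models checkout.
    from pathlib import Path

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.util.cleanup import cleanup_model

    example_models = Path(__file__).parent / '..' / '..' / 'example-models'
    model = ModelWrapper(str(example_models / 'onnx' / 'bnn_model_fc_1layer.onnx'))
    model = cleanup_model(model)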
---
 example-models                            |   2 +-
 .../optimizer/passes/bipolar_quant_opt.py |   2 --
 test/pytest/bnn_model_fc_1layer.onnx      | Bin 2321 -> 0 bytes
 test/pytest/test_qonnx.py                 |  23 ++++++++++++----
 4 files changed, 19 insertions(+), 8 deletions(-)
 delete mode 100644 test/pytest/bnn_model_fc_1layer.onnx

diff --git a/example-models b/example-models
index c6bb3c0686..e7a9dee394 160000
--- a/example-models
+++ b/example-models
@@ -1 +1 @@
-Subproject commit c6bb3c0686d52439d8c53d7407903bf78e852562
+Subproject commit e7a9dee394b6c1f6e0eb23178d34e55f077297fe
diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index 42927c6ddf..6c8ac9a37e 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -8,8 +8,6 @@
 from hls4ml.model.quantizers import BinaryQuantizer
 from hls4ml.model.types import XnorPrecisionType

-_ALSO_MATCH_PO2 = True
-

 class BipolarQuantConstantParameters(OptimizerPass):
     """Remove Constant from the Quant node parameters (but not input[0])"""
diff --git a/test/pytest/bnn_model_fc_1layer.onnx b/test/pytest/bnn_model_fc_1layer.onnx
deleted file mode 100644
index 72e16fae4e1fbe62024250b285846fd2094ec111..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 2321
[base85 payload truncated; the test/pytest/test_qonnx.py hunk of this patch and the opening header lines of the next commit are missing here]

Date: Tue, 13 May 2025 15:01:04 +0200
Subject: [PATCH 12/13] Added test for non-unit (po2) scaling factors

---
 .../optimizer/passes/bipolar_quant_opt.py | 38 +++++---
 test/pytest/test_qonnx.py                 | 90 +++++++++++++++++--
 2 files changed, 107 insertions(+), 21 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index 6c8ac9a37e..8b0f2f5803 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -3,14 +3,16 @@

 """

+import numpy as np
 from hls4ml.model.layers import Activation, BipolarQuant, Constant
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.quantizers import BinaryQuantizer
 from hls4ml.model.types import XnorPrecisionType

+
 class BipolarQuantConstantParameters(OptimizerPass):
-    """Remove Constant from the Quant node parameters (but not input[0])"""
+    """Remove Constant from the BipolarQuant node parameters (but not input[0])"""

     def match(self, node):
         is_match = (
@@ -23,7 +25,7 @@ def match(self, node):

     def transform(self, model, node):
         """
-        Remove Constant from the Quant node parameters (but not input[0])
+        Remove Constant from the BipolarQuant node parameters (but not input[0])
         """
         if node.get_input_node(node.inputs[1]):
             scale_node = node.get_input_node(node.inputs[1])
@@ -41,7 +43,7 @@ def transform(self, model, node):

 class BipolarQuantToActivation(OptimizerPass):
     """
-    This is for the case when scale is a (positive) 1 and zeropt is 0.
+    This is for the case when scale is po2.
     It is a 1:1 transformation of a BipolarQuant to an Activation.
     As an optimization, this is not called when the input is constant.
""" @@ -54,17 +56,21 @@ def match(self, node): and not isinstance(node.get_input_node(node.inputs[0]), Constant) ) - # Only match if the scale is 1 and the zero-point is 0s + # Only match if the scale is po2 if is_match: # to make sure this is a quant node with inputs scale = node.get_attr('scale') - scale_unit = (scale == 1.0).all() - is_match = scale_unit + scale_unit_or_po2 = (scale == 1.0).all() + # This optimization only works if all scales are the same + if np.all(scale[0] == scale): + mantissa, _ = np.frexp(scale[0]) + scale_unit_or_po2 = mantissa == 0.5 + is_match = scale_unit_or_po2 return is_match def transform(self, model, node): """ - Change quant node to Activation + Change BipolarQuant node to Activation """ precision = XnorPrecisionType() quantizer = BinaryQuantizer(bits=1) @@ -78,15 +84,15 @@ def transform(self, model, node): new_name = f'{node.name}_act' model.config.set_name_config(new_name, config) model.config.parse_name_config(new_name, config) - - new_node = model.make_node(Activation, new_name, attributes, list(node.inputs[0]), list(node.outputs)) + print(f"Node {new_name} inputs: {[node.inputs[0]]}, outputs: {list(node.outputs)}.") + new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], list(node.outputs)) model.replace_node(node, new_node) return True class FuseBipolarQuantWithConstant(OptimizerPass): """ - This is for the case when scale is 1 and zeropt is 0. + This is for the case when scale is po2. """ def match(self, node): @@ -99,17 +105,21 @@ def match(self, node): and isinstance(node.get_input_node(node.inputs[0]), Constant) ) - # Only match if the scale is 1 and the zero-point is 0s + # Only match if the scale is po2 if is_match: # to make sure this is a quant node with inputs scale = node.get_attr('scale') - scale_unit = (scale == 1.0).all() - is_match = scale_unit + scale_unit_or_po2 = (scale == 1.0).all() + # This optimization only works if all scales are the same + if np.all(scale[0] == scale): + mantissa, _ = np.frexp(scale[0]) + scale_unit_or_po2 = mantissa == 0.5 + is_match = scale_unit_or_po2 return is_match def transform(self, model, node): """ - Fuse Quant with Constant. + Fuse BipolarQuant with Constant. """ precision = XnorPrecisionType() quantizer = BinaryQuantizer(bits=1) diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py index 6af82a332d..5752de7c80 100644 --- a/test/pytest/test_qonnx.py +++ b/test/pytest/test_qonnx.py @@ -1,9 +1,11 @@ import os import urllib +import copy from pathlib import Path import numpy as np import pytest +import onnx import qonnx.core.onnx_exec as oxe import qonnx.util.cleanup import qonnx.util.to_channels_last @@ -247,6 +249,74 @@ def bnn_fc_small_qonnx_model(): return model +@pytest.fixture(scope='module') +def bnn_fc_small_qonnx_model_scale_nonunit(bnn_fc_small_qonnx_model): + """ + Use scale factors of 0.5 to see if that works. + This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors. + """ + + model = copy.deepcopy(bnn_fc_small_qonnx_model) # is copying neccessary? 
+    new_iscale = onnx.helper.make_tensor(
+        "BipolarQuant_0_param0",
+        1,
+        [1],
+        [0.5]
+    )
+    new_wscale = onnx.helper.make_tensor(
+        "BipolarQuant_1_param1",
+        1,
+        [1],
+        [0.5]
+    )
+    old_iscale = old_wscale = None
+    for init in model.graph.initializer:
+        if init.name == "BipolarQuant_0_param0":
+            old_iscale = init
+        elif init.name == "BipolarQuant_1_param1":
+            old_wscale = init
+    model.graph.initializer.remove(old_iscale)
+    model.graph.initializer.remove(old_wscale)
+    model.graph.initializer.append(new_iscale)
+    model.graph.initializer.append(new_wscale)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def bnn_fc_small_qonnx_model_scale_nonunit2(bnn_fc_small_qonnx_model):
+    """
+    Use po2 scale factors to see if that works.
+    This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors.
+    """
+
+    model = copy.deepcopy(bnn_fc_small_qonnx_model)  # is copying necessary?
+    new_iscale = onnx.helper.make_tensor(
+        "BipolarQuant_0_param0",
+        1,
+        [1],
+        [2]
+    )
+    new_wscale = onnx.helper.make_tensor(
+        "BipolarQuant_1_param1",
+        1,
+        [1],
+        [4]
+    )
+    old_iscale = old_wscale = None
+    for init in model.graph.initializer:
+        if init.name == "BipolarQuant_0_param0":
+            old_iscale = init
+        elif init.name == "BipolarQuant_1_param1":
+            old_wscale = init
+    model.graph.initializer.remove(old_iscale)
+    model.graph.initializer.remove(old_wscale)
+    model.graph.initializer.append(new_iscale)
+    model.graph.initializer.append(new_wscale)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
 # The actual tests


@@ -446,17 +516,23 @@ def test_simple_model(model_name, io_type, backend, request):
     np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1)


+@pytest.mark.parametrize(
+    'model_name',
+    [
+        'bnn_fc_small_qonnx_model',
+        'bnn_fc_small_qonnx_model_scale_nonunit',
+        'bnn_fc_small_qonnx_model_scale_nonunit2'
+    ],
+)
 @pytest.mark.parametrize('backend', ['Vitis'])
 @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
-def test_bnn(bnn_fc_small_qonnx_model, io_type, backend):
+def test_bnn(model_name, io_type, backend, request):
     "Checks if a basic binarized model works correctly."
-    test_dir = os.path.dirname(os.path.abspath(__file__))
-    qonnx_model = bnn_fc_small_qonnx_model
+    qonnx_model = request.getfixturevalue(model_name)
     config = hls4ml.utils.config.config_from_onnx_model(
         qonnx_model, granularity='name', backend=backend, default_precision='fixed<16,6>'
     )

-    model_name = 'bnn_model_fc_1layer'
     hls_model = hls4ml.converters.convert_from_onnx_model(
         qonnx_model,
         output_dir=str(test_root_path / f'hls4mlprj_onnx_{model_name}_{io_type}_{backend}'),
@@ -466,7 +542,7 @@ def test_bnn(model_name, io_type, backend, request):
         io_type=io_type,
         backend=backend,
         hls_config=config,
     )
     hls_model.compile()

-    X = np.array(
+    data_x = np.array(
         [
             [[+1, +1, +1]],
             [[+1, +1, -1]],
             [[+1, -1, +1]],
             [[-1, -1, -1]],
             [[-1, +1, +1]],
             [[-1, +1, -1]],
             [[-1, -1, +1]],
             [[-1, -1, -1]],
         ],
         dtype=np.float32,
     )
-    for x in X:
+    for x in data_x:
         idict = {qonnx_model.graph.input[0].name: x}
         y_qonnx = oxe.execute_onnx(qonnx_model, idict)[qonnx_model.graph.output[0].name]
-        y_hls4ml = hls_model.predict(X)
+        y_hls4ml = hls_model.predict(x)
         np.array_equal(y_qonnx.ravel(), y_hls4ml.ravel())

From 08dafdfccd0587bb8054f97c76f4c8fadf25aaa1 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Tue, 13 May 2025 15:07:24 +0200
Subject: [PATCH 13/13] Pre-commit fixes.
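For reference while reading the po2 matching introduced in the previous commit: np.frexp(v) returns (mantissa, exponent) with v == mantissa * 2**exponent and 0.5 <= |mantissa| < 1, so a positive power of two has mantissa exactly 0.5. Distilled into a standalone sketch (the helper name is hypothetical; the logic mirrors the diff above):

    import numpy as np

    def scale_is_uniform_po2(scale: np.ndarray) -> bool:
        # the optimization requires one shared scale across all channels
        if not np.all(scale == scale.flat[0]):
            return False
        mantissa, _ = np.frexp(scale.flat[0])
        return mantissa == 0.5

    assert scale_is_uniform_po2(np.array([4.0, 4.0]))  # 4 = 0.5 * 2**3
    assert not scale_is_uniform_po2(np.array([3.0]))   # 3 = 0.75 * 2**2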
---
 .../optimizer/passes/bipolar_quant_opt.py |  4 +-
 test/pytest/test_qonnx.py                 | 38 ++++---------------
 2 files changed, 9 insertions(+), 33 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
index 8b0f2f5803..5fece76d50 100644
--- a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
+++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -4,13 +4,13 @@
 """

 import numpy as np
+
 from hls4ml.model.layers import Activation, BipolarQuant, Constant
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.quantizers import BinaryQuantizer
 from hls4ml.model.types import XnorPrecisionType

-
 class BipolarQuantConstantParameters(OptimizerPass):
     """Remove Constant from the BipolarQuant node parameters (but not input[0])"""
@@ -56,7 +56,7 @@ def match(self, node):
             and not isinstance(node.get_input_node(node.inputs[0]), Constant)
         )

-        # Only match if the scale is po2
+        # Only match if the scale is po2
         if is_match:  # to make sure this is a quant node with inputs
             scale = node.get_attr('scale')
             scale_unit_or_po2 = (scale == 1.0).all()
diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py
index 5752de7c80..0a386c482d 100644
--- a/test/pytest/test_qonnx.py
+++ b/test/pytest/test_qonnx.py
@@ -1,11 +1,11 @@
+import copy
 import os
 import urllib
-import copy
 from pathlib import Path

 import numpy as np
-import pytest
 import onnx
+import pytest
 import qonnx.core.onnx_exec as oxe
 import qonnx.util.cleanup
 import qonnx.util.to_channels_last
@@ -257,18 +257,8 @@ def bnn_fc_small_qonnx_model_scale_nonunit(bnn_fc_small_qonnx_model):
     """

     model = copy.deepcopy(bnn_fc_small_qonnx_model)  # is copying necessary?
-    new_iscale = onnx.helper.make_tensor(
-        "BipolarQuant_0_param0",
-        1,
-        [1],
-        [0.5]
-    )
-    new_wscale = onnx.helper.make_tensor(
-        "BipolarQuant_1_param1",
-        1,
-        [1],
-        [0.5]
-    )
+    new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [0.5])
+    new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [0.5])
     old_iscale = old_wscale = None
     for init in model.graph.initializer:
         if init.name == "BipolarQuant_0_param0":
@@ -291,18 +281,8 @@ def bnn_fc_small_qonnx_model_scale_nonunit2(bnn_fc_small_qonnx_model):
     """

     model = copy.deepcopy(bnn_fc_small_qonnx_model)  # is copying necessary?
-    new_iscale = onnx.helper.make_tensor(
-        "BipolarQuant_0_param0",
-        1,
-        [1],
-        [2]
-    )
-    new_wscale = onnx.helper.make_tensor(
-        "BipolarQuant_1_param1",
-        1,
-        [1],
-        [4]
-    )
+    new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [2])
+    new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [4])
     old_iscale = old_wscale = None
     for init in model.graph.initializer:
         if init.name == "BipolarQuant_0_param0":
@@ -498,11 +478,7 @@ def test_simple_model(model_name, io_type, backend, request):

 @pytest.mark.parametrize(
     'model_name',
-    [
-        'bnn_fc_small_qonnx_model',
-        'bnn_fc_small_qonnx_model_scale_nonunit',
-        'bnn_fc_small_qonnx_model_scale_nonunit2'
-    ],
+    ['bnn_fc_small_qonnx_model', 'bnn_fc_small_qonnx_model_scale_nonunit', 'bnn_fc_small_qonnx_model_scale_nonunit2'],
 )
 @pytest.mark.parametrize('backend', ['Vitis'])
 @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
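For context, a hedged end-to-end sketch of the conversion path this series enables, mirroring test_bnn above (the local model path and output_dir are assumptions; the API calls are the ones the test itself uses):

    import numpy as np
    import qonnx.core.onnx_exec as oxe
    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.util.cleanup import cleanup_model

    import hls4ml

    qonnx_model = cleanup_model(ModelWrapper('bnn_model_fc_1layer.onnx'))
    config = hls4ml.utils.config.config_from_onnx_model(
        qonnx_model, granularity='name', backend='Vitis', default_precision='fixed<16,6>'
    )
    hls_model = hls4ml.converters.convert_from_onnx_model(
        qonnx_model,
        output_dir='hls4mlprj_onnx_bnn_sketch',
        io_type='io_parallel',
        backend='Vitis',
        hls_config=config,
    )
    hls_model.compile()

    x = np.array([[+1, -1, +1]], dtype=np.float32)  # one bipolar sample
    y_qonnx = oxe.execute_onnx(qonnx_model, {qonnx_model.graph.input[0].name: x})
    y_hls4ml = hls_model.predict(x)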