diff --git a/.gitignore b/.gitignore index 22c8ff685b..8fb87927ce 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ docs/autodoc/* hls4mlprj_* *~ *.ipynb_checkpoints/ +*.bak diff --git a/example-models b/example-models index c6bb3c0686..e7a9dee394 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit c6bb3c0686d52439d8c53d7407903bf78e852562 +Subproject commit e7a9dee394b6c1f6e0eb23178d34e55f077297fe diff --git a/hls4ml/converters/onnx/core.py b/hls4ml/converters/onnx/core.py index 8ad851426d..efb5f20c4a 100644 --- a/hls4ml/converters/onnx/core.py +++ b/hls4ml/converters/onnx/core.py @@ -120,3 +120,14 @@ def parse_quant_layer(node, input_names, input_shapes, graph): layer['signed'] = bool(get_onnx_attribute(node, 'signed')) return layer + + +@onnx_handler('BipolarQuant') +def parse_bipolar_quant_layer(node, input_names, input_shapes, graph): + layer = {} + + layer['class_name'] = 'BipolarQuant' + layer['name'] = node.name + layer['inputs'] = input_names + layer['outputs'] = list(node.output) + return layer diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index e9f596c7e5..5f7f1ba40a 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -421,6 +421,21 @@ def initialize(self): self.add_output_variable(shape, dims) +class BipolarQuant(Layer): # The QONNX quantization layer + """ + This is a QONNX quantization layer. Optimizations should convert it + before HLS is produced. + """ + + _expected_attributes = [] + + def initialize(self): + inp = self.get_input_variable(self.inputs[0]) + shape = inp.shape + dims = inp.dim_names + self.add_output_variable(shape, dims) + + class Reshape(Layer): _expected_attributes = [ Attribute('target_shape', value_type=typing.Sequence), @@ -1724,6 +1739,7 @@ def initialize(self): 'GarNet': GarNet, 'GarNetStack': GarNetStack, 'Quant': Quant, + 'BipolarQuant': BipolarQuant, 'ApplyAlpha': ApplyAlpha, 'BatchNormOnnx': BatchNormOnnx, 'LayerGroup': LayerGroup, diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index c474970448..8b25805fc7 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -40,6 +40,9 @@ 'fuse_quant_with_constant', 'const_quant_to_const_alpha', 'quant_to_alpha_activation_alpha', + 'bipolar_quant_constant_parameters', + 'bipolar_quant_to_activation', + 'fuse_bipolar_quant_with_constant', 'batch_norm_onnx_constant_parameters', 'constant_batch_norm_fusion', 'merge_two_constants', diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py new file mode 100644 index 0000000000..5fece76d50 --- /dev/null +++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py @@ -0,0 +1,133 @@ +""" +This file includes optimizations related to BipolarQuant nodes. + +""" + +import numpy as np + +from hls4ml.model.layers import Activation, BipolarQuant, Constant +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.quantizers import BinaryQuantizer +from hls4ml.model.types import XnorPrecisionType + + +class BipolarQuantConstantParameters(OptimizerPass): + """Remove Constant from the BipolarQaunt node parameters (but not input[0])""" + + def match(self, node): + is_match = ( + isinstance(node, BipolarQuant) + and len(node.inputs) == 2 + and (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant)) + ) + + return is_match + + def transform(self, model, node): + """ + Remove Constant from the BipolarQuant node parameters (but not input[0]) + """ + if node.get_input_node(node.inputs[1]): + scale_node = node.get_input_node(node.inputs[1]) + if isinstance(scale_node, Constant): + node.set_attr('scale', scale_node.get_attr('value')) + node.inputs[1] = '' + model.remove_node(scale_node) + + node.inputs = [inp for inp in node.inputs if inp] + if len(node.inputs) != 1: + raise RuntimeError("hls4ml only supports constant scale") + + return True + + +class BipolarQuantToActivation(OptimizerPass): + """ + This is for the case when scale is po2. + It is a a 1:1 transformation of a BipolarQuant to an Activation. + As an optimization, this is not called when the input is constant. + """ + + def match(self, node): + # only matches after the other inputs are already folded + is_match = ( + isinstance(node, BipolarQuant) + and len(node.inputs) == 1 + and not isinstance(node.get_input_node(node.inputs[0]), Constant) + ) + + # Only match if the scale is po2 + if is_match: # to make sure this is a quant node with inputs + scale = node.get_attr('scale') + scale_unit_or_po2 = (scale == 1.0).all() + # This optimization only works if all scales are the same + if np.all(scale[0] == scale): + mantissa, _ = np.frexp(scale[0]) + scale_unit_or_po2 = mantissa == 0.5 + is_match = scale_unit_or_po2 + + return is_match + + def transform(self, model, node): + """ + Change BipolarQuant node to Activation + """ + precision = XnorPrecisionType() + quantizer = BinaryQuantizer(bits=1) + + attributes = {'activation': 'linear', 'quantizer': quantizer} + + # update the configuration + config = model.config.get_layer_config(node) + prec_config = config.setdefault('Precision', {}) + prec_config['result'] = str(precision) + new_name = f'{node.name}_act' + model.config.set_name_config(new_name, config) + model.config.parse_name_config(new_name, config) + print(f"Node {new_name} inputs: {[node.inputs[0]]}, outputs: {list(node.outputs)}.") + new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], list(node.outputs)) + model.replace_node(node, new_node) + return True + + +class FuseBipolarQuantWithConstant(OptimizerPass): + """ + This is for the case when scale is po2. + """ + + def match(self, node): + + # only matches after the other inputs are already folded + # and scale is unit + is_match = ( + isinstance(node, BipolarQuant) + and len(node.inputs) == 1 + and isinstance(node.get_input_node(node.inputs[0]), Constant) + ) + + # Only match if the scale is po2 + if is_match: # to make sure this is a quant node with inputs + scale = node.get_attr('scale') + scale_unit_or_po2 = (scale == 1.0).all() + # This optimization only works if all scales are the same + if np.all(scale[0] == scale): + mantissa, _ = np.frexp(scale[0]) + scale_unit_or_po2 = mantissa == 0.5 + is_match = scale_unit_or_po2 + + return is_match + + def transform(self, model, node): + """ + Fuse BipolarQuant with Constant. + """ + precision = XnorPrecisionType() + quantizer = BinaryQuantizer(bits=1) + + const_node = node.get_input_node(node.inputs[0]) + const_node.set_attr('quantizer', quantizer) + const_node.get_output_variable().type.precision = precision + + # remove the Quant node + model.remove_node(node) + return True diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py index bfa6e0a49c..0a386c482d 100644 --- a/test/pytest/test_qonnx.py +++ b/test/pytest/test_qonnx.py @@ -1,8 +1,10 @@ +import copy import os import urllib from pathlib import Path import numpy as np +import onnx import pytest import qonnx.core.onnx_exec as oxe import qonnx.util.cleanup @@ -12,6 +14,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean from qonnx.transformation.gemm_to_matmul import GemmToMatMul +from qonnx.util.cleanup import cleanup_model import hls4ml @@ -231,6 +234,69 @@ def conv2d_small_mp_keras_model(): return model +@pytest.fixture(scope='module') +def bnn_fc_small_qonnx_model(): + """ + Load a small binarized model of a single fully connected layer. + """ + dl_file = str(example_model_path / "onnx/bnn_model_fc_1layer.onnx") + assert os.path.isfile(dl_file) + + model = ModelWrapper(dl_file) + model = cleanup_model(model) + model = model.transform(GemmToMatMul()) # ishape = (1, 3) + model = qonnx.util.cleanup.cleanup_model(model) + return model + + +@pytest.fixture(scope='module') +def bnn_fc_small_qonnx_model_scale_nonunit(bnn_fc_small_qonnx_model): + """ + Use scale factors of 0.5 to see if that works. + This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors. + """ + + model = copy.deepcopy(bnn_fc_small_qonnx_model) # is copying neccessary? + new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [0.5]) + new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [0.5]) + old_iscale = old_wscale = None + for init in model.graph.initializer: + if init.name == "BipolarQuant_0_param0": + old_iscale = init + elif init.name == "BipolarQuant_1_param1": + old_wscale = init + model.graph.initializer.remove(old_iscale) + model.graph.initializer.remove(old_wscale) + model.graph.initializer.append(new_iscale) + model.graph.initializer.append(new_wscale) + model = qonnx.util.cleanup.cleanup_model(model) + return model + + +@pytest.fixture(scope='module') +def bnn_fc_small_qonnx_model_scale_nonunit2(bnn_fc_small_qonnx_model): + """ + Use po2 scale factors to see if that works. + This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors. + """ + + model = copy.deepcopy(bnn_fc_small_qonnx_model) # is copying neccessary? + new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [2]) + new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [4]) + old_iscale = old_wscale = None + for init in model.graph.initializer: + if init.name == "BipolarQuant_0_param0": + old_iscale = init + elif init.name == "BipolarQuant_1_param1": + old_wscale = init + model.graph.initializer.remove(old_iscale) + model.graph.initializer.remove(old_wscale) + model.graph.initializer.append(new_iscale) + model.graph.initializer.append(new_wscale) + model = qonnx.util.cleanup.cleanup_model(model) + return model + + # The actual tests @@ -428,3 +494,45 @@ def test_simple_model(model_name, io_type, backend, request): y_hls4ml = hls_model.predict(X) np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1) + + +@pytest.mark.parametrize( + 'model_name', + ['bnn_fc_small_qonnx_model', 'bnn_fc_small_qonnx_model_scale_nonunit', 'bnn_fc_small_qonnx_model_scale_nonunit2'], +) +@pytest.mark.parametrize('backend', ['Vitis']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_bnn(model_name, io_type, backend, request): + "Checks if a basic binarized model works correctly." + qonnx_model = request.getfixturevalue(model_name) + + config = hls4ml.utils.config.config_from_onnx_model( + qonnx_model, granularity='name', backend=backend, default_precision='fixed<16,6>' + ) + hls_model = hls4ml.converters.convert_from_onnx_model( + qonnx_model, + output_dir=str(test_root_path / f'hls4mlprj_onnx_{model_name}_{io_type}_{backend}'), + io_type=io_type, + backend=backend, + hls_config=config, + ) + hls_model.compile() + + data_x = np.array( + [ + [[+1, +1, +1]], + [[+1, +1, -1]], + [[+1, -1, +1]], + [[-1, -1, -1]], + [[-1, +1, +1]], + [[-1, +1, -1]], + [[-1, -1, +1]], + [[-1, -1, -1]], + ], + dtype=np.float32, + ) + for x in data_x: + idict = {qonnx_model.graph.input[0].name: x} + y_qonnx = oxe.execute_onnx(qonnx_model, idict)[qonnx_model.graph.output[0].name] + y_hls4ml = hls_model.predict(x) + np.array_equal(y_qonnx.ravel(), y_hls4ml.ravel())