From 9fec426acdf5935d67f8fbdf6d7d05f7e7f2b114 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 21 Jun 2024 16:40:31 +0200 Subject: [PATCH 01/11] add type check for input name of resolve_datatype() --- src/qonnx/core/datatype.py | 4 ++++ tests/core/test_datatypes.py | 39 ++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index f37d4eea..4bdffdf9 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -376,6 +376,10 @@ def get_canonical_name(self): def resolve_datatype(name): + + if not isinstance(name, str): + raise TypeError(f"Input 'name' must be of type 'str', but got type '{type(name).__name__}'") + _special_types = { "BINARY": IntType(1, False), "BIPOLAR": BipolarType(), diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index 1bd0fece..efd590b3 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -29,6 +29,7 @@ import numpy as np from qonnx.core.datatype import DataType +from qonnx.core.datatype import resolve_datatype def test_datatypes(): @@ -97,3 +98,41 @@ def test_smallest_possible(): assert DataType.get_smallest_possible(-1) == DataType["BIPOLAR"] assert DataType.get_smallest_possible(-3) == DataType["INT3"] assert DataType.get_smallest_possible(-3.2) == DataType["FLOAT32"] + + +def test_resolve_datatype(): + assert resolve_datatype("BIPOLAR") + assert resolve_datatype("BINARY") + assert resolve_datatype("TERNARY") + assert resolve_datatype("UINT2") + assert resolve_datatype("UINT3") + assert resolve_datatype("UINT4") + assert resolve_datatype("UINT8") + assert resolve_datatype("UINT16") + assert resolve_datatype("UINT32") + assert resolve_datatype("INT2") + assert resolve_datatype("INT3") + assert resolve_datatype("INT4") + assert resolve_datatype("INT8") + assert resolve_datatype("INT16") + assert resolve_datatype("INT32") + assert resolve_datatype("BINARY") + assert resolve_datatype("FLOAT32") + + +def test_input_type_error(): + # test with invalid input to check if the TypeError works + try: + resolve_datatype(123) # This should raise a TypeError + except TypeError as e: + pass + else: + print("Test with invalid input failed: No TypeError was raised.") + + # test with invalid input to check if the TypeError works + try: + resolve_datatype(1.23) # This should raise a TypeError + except TypeError as e: + pass + else: + print("Test with invalid input failed: No TypeError was raised.") From af8249cb529401e74dbf9d5124a10742390b6ded Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 21 Jun 2024 16:52:37 +0200 Subject: [PATCH 02/11] improve test_input_type_error --- tests/core/test_datatypes.py | 42 +++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index efd590b3..e98a885f 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -116,23 +116,35 @@ def test_resolve_datatype(): assert resolve_datatype("INT8") assert resolve_datatype("INT16") assert resolve_datatype("INT32") - assert resolve_datatype("BINARY") assert resolve_datatype("FLOAT32") def test_input_type_error(): - # test with invalid input to check if the TypeError works - try: - resolve_datatype(123) # This should raise a TypeError - except TypeError as e: - pass - else: - print("Test with invalid input failed: No TypeError was raised.") - # test with invalid input to check if the TypeError works - try: - resolve_datatype(1.23) # This should raise a TypeError - except TypeError as e: - pass - else: - print("Test with invalid input failed: No TypeError was raised.") + def test_resolve_datatype(input): + # test with invalid input to check if the TypeError works + try: + resolve_datatype(input) # This should raise a TypeError + except TypeError as e: + pass + else: + print("Test with invalid input failed: No TypeError was raised.") + + test_resolve_datatype(123) + test_resolve_datatype(1.23) + test_resolve_datatype(DataType["BIPOLAR"]) + test_resolve_datatype(DataType["BINARY"]) + test_resolve_datatype(DataType["TERNARY"]) + test_resolve_datatype(DataType["UINT2"]) + test_resolve_datatype(DataType["UINT3"]) + test_resolve_datatype(DataType["UINT4"]) + test_resolve_datatype(DataType["UINT8"]) + test_resolve_datatype(DataType["UINT16"]) + test_resolve_datatype(DataType["UINT32"]) + test_resolve_datatype(DataType["INT2"]) + test_resolve_datatype(DataType["INT3"]) + test_resolve_datatype(DataType["INT4"]) + test_resolve_datatype(DataType["INT8"]) + test_resolve_datatype(DataType["INT16"]) + test_resolve_datatype(DataType["INT32"]) + test_resolve_datatype(DataType["FLOAT32"]) From ccfd48e4b28358e846884b280424d246a578d7c0 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Wed, 26 Jun 2024 23:31:51 +0200 Subject: [PATCH 03/11] start with adding HGQ support for QONNX --- src/qonnx/converters/keras.py | 91 ++++++++++--- src/qonnx/converters/qkeras/qlayers.py | 2 +- src/qonnx/custom_op/general/FixedPoint.py | 150 ++++++++++++++++++++++ src/qonnx/custom_op/general/__init__.py | 3 + src/qonnx/util/cleanup.py | 2 +- 5 files changed, 230 insertions(+), 18 deletions(-) create mode 100644 src/qonnx/custom_op/general/FixedPoint.py diff --git a/src/qonnx/converters/keras.py b/src/qonnx/converters/keras.py index 5b9e7e09..7e434428 100644 --- a/src/qonnx/converters/keras.py +++ b/src/qonnx/converters/keras.py @@ -8,13 +8,18 @@ from qonnx.util.cleanup import cleanup_model from .qkeras.onnx import get_qkeras_onnx_handlers -from .qkeras.qlayers import extract_quantizers_from_layer +from .qkeras.qlayers import extract_quantizers_from_qkeras_layer +from .HGQ.onnx import get_hgq_onnx_handlers +from .HGQ.hgqlayers import HGQ_LAYERS, extract_quantizers_from_hgq_layer + +# NOTE: thats a list for qkeras & HGQ layers _unsupported_layers = [ # These require some extra work "QBatchNormalization", "QConv2DBatchnorm", "QDepthwiseConv2DBatchnorm", + # TODO: add HGQ layers ] # Skip remove_identity optimizer @@ -102,6 +107,30 @@ def iterate_model(model): return iterate_model(model) +def _is_hgq_model(model): + """Check if the model has any HGQ layers, so we can handle the HGQ layers separately + + Args: + model: the model we want to convert + + Returns: + True if the model contains any HGQ layer + """ + + def iterate_model(model): + for layer in model.layers: + if isinstance(layer, tf.keras.Model): + found_qkeras = iterate_model(layer) + if found_qkeras: + return True + elif layer.__class__.__name__ in HGQ_LAYERS: + return True + + return False + + return iterate_model(model) + + def _check_supported_layers(model): """Check if all the layers in the model are supported for conversion @@ -117,7 +146,7 @@ def iterate_model(model): if isinstance(layer, tf.keras.Model): iterate_model(layer) elif layer.__class__.__name__ in _unsupported_layers: - raise Exception("Currently unsupported layer found in QKeras model: {}".format(layer.__class__.__name__)) + raise Exception("Currently unsupported layer found in model: {}".format(layer.__class__.__name__)) iterate_model(model) @@ -134,7 +163,7 @@ def _strip_qkeras_model(model): quantizers = OrderedDict() def extract_quantizers(layer): - keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_layer(layer) + keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_qkeras_layer(layer) if layer_quantizers: layer_quantizers = { k: None if v == "None" else v for k, v in layer_quantizers.items() @@ -150,18 +179,43 @@ def extract_quantizers(layer): stripped_model = tf.keras.models.clone_model(model, clone_function=extract_quantizers) stripped_model.set_weights(model.get_weights()) + return stripped_model, quantizers -# tests run without this function -def _convert_quantizers_to_nodes(onnx_model, quantizers_dict): - for node_name, quantizers in quantizers_dict.items(): - print(node_name, quantizers) +def _strip_hgq_model(model): + """Strip a HGQ model to obtain the keras model and obtain the quant nodes. - for n in onnx_model.graph.node: - print(n) + Args: + model: the proxy tf.keras model from HGQ - return onnx_model.model + Returns: + The stripped model, and the quantizers in a dictionary format + """ + quantizers = OrderedDict() + + def extract_quantizers(layer): + keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_hgq_layer(layer, model) + if layer_quantizers: + layer_quantizers["input"] = layer.input.name + quantizers[layer_quantizers["name"]] = layer_quantizers + + layer_class = tf.keras.layers.__dict__.get(keras_cls_name, None) + if layer_class is None: + raise Exception("Cannot create Keras layer from QKeras class {}".format(keras_cls_name)) + + return layer_class.from_config(layer_cfg) + + stripped_model = tf.keras.models.clone_model(model, clone_function=extract_quantizers) + + for layer in model.layers: + if layer.__class__.__name__ in HGQ_LAYERS: + # NOTE: the FixedPointQuantizer does not have weights it only has + # self.keep_negative, self.bits and self.integers which we extract later + # from the quantizer + continue + stripped_model.get_layer(layer.name).set_weights(layer.get_weights()) + return stripped_model, quantizers def from_keras( @@ -203,23 +257,28 @@ def from_keras( assert not large_model # TODO for now, let's focus only on models that don't store tensors externally + _check_supported_layers(model) + keras_op_handlers = {} if _is_qkeras_model(model): - _check_supported_layers(model) keras_model, quantizers = _strip_qkeras_model(model) + keras_op_handlers.update(get_qkeras_onnx_handlers(quantizers)) + elif _is_hgq_model(model): + keras_model = model + keras_model, quantizers = _strip_hgq_model(model) + keras_op_handlers.update(get_hgq_onnx_handlers(quantizers)) else: keras_model, quantizers = model, {} - qkeras_op_handlers = get_qkeras_onnx_handlers(quantizers) - + keras_model.summary() if custom_op_handlers is not None: - qkeras_op_handlers.update(custom_op_handlers) + keras_op_handlers.update(custom_op_handlers) model_proto, external_storage = tf2onnx.convert.from_keras( keras_model, input_signature=input_signature, opset=opset, custom_ops=custom_ops, - custom_op_handlers=qkeras_op_handlers, + custom_op_handlers=keras_op_handlers, custom_rewriter=custom_rewriter, inputs_as_nchw=inputs_as_nchw, extra_opset=extra_opset, @@ -242,7 +301,7 @@ def from_keras( onnx_model.set_tensor_shape(onnx_model.graph.output[0].name, out_shape) # Set all Quant output tensors to float32 datatype, otherwise they are undefined and crash ONNX execution - qonnx_domain_ops = ["Quant", "Trunc", "BipolarQuant"] + qonnx_domain_ops = ["FixedPoint", "Quant", "Trunc", "BipolarQuant"] for q_op_type in qonnx_domain_ops: quant_nodes = onnx_model.get_nodes_by_op_type(q_op_type) q_node_outputs = [qn.output[0] for qn in quant_nodes] diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py index fdbca71b..932471dc 100644 --- a/src/qonnx/converters/qkeras/qlayers.py +++ b/src/qonnx/converters/qkeras/qlayers.py @@ -5,7 +5,7 @@ from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS -def extract_quantizers_from_layer(layer): +def extract_quantizers_from_qkeras_layer(layer): """ """ layer_class = layer.__class__.__name__ if layer_class in QKERAS_LAYERS: diff --git a/src/qonnx/custom_op/general/FixedPoint.py b/src/qonnx/custom_op/general/FixedPoint.py new file mode 100644 index 00000000..cf1e84ba --- /dev/null +++ b/src/qonnx/custom_op/general/FixedPoint.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021 Xilinx, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np +from onnx import TensorProto, helper + +from HGQ.proxy.fixed_point_quantizer import gfixed_quantizer + +from qonnx.core.datatype import DataType +from qonnx.custom_op.base import CustomOp + + +class FixedPoint(CustomOp): + """Generic quantization operation for HGQ FixedPoint layer to QONNX. + + Takes four inputs: + + - input tensor to quantize + - the integer_bits + - the keep_negative + - the bits + + The output is a tensor of the same shape as the input tensor, with quantized + values. + """ + + def get_nodeattr_types(self): + return { + # The rounding mode, which is used for the quant function + # (e.g. "TRN": Truncate towards negative infinity. Fast. Preferred when possible.) + "RND": ("s", True, "TRN"), + # Saturate between highest and lowest representable values. + # (e.g. "WRAP" Wrap around.) + "SAT": ("s", True, "WRAP"), + } + + def make_shape_compatible_op(self, model): + """Returns a standard ONNX op which is compatible with this CustomOp + for performing shape inference.""" + return helper.make_node( + "Cast", + inputs=[self.onnx_node.input[0]], + outputs=[self.onnx_node.output[0]], + to=int(TensorProto.FLOAT), + ) + + def get_integer_datatype(self, model): + signed = self.get_nodeattr("signed") + bit_width = model.get_initializer(self.onnx_node.input[3]) + bit_width = int(bit_width) + if bit_width == 1: + if signed: + finn_dt = DataType["BIPOLAR"] + else: + finn_dt = DataType["BINARY"] + else: + if signed: + finn_dt = DataType["INT" + str(bit_width)] + else: + finn_dt = DataType["UINT" + str(bit_width)] + return finn_dt + + def get_scaled_integer_datatype(self, model): + bit_width = model.get_initializer(self.onnx_node.input[3]) + bit_width = int(bit_width) + finn_dt = DataType["SCALEDINT<%d>" % (bit_width)] + return finn_dt + + def get_output_dtype(self, model): + node = self.onnx_node + # scale, zero-point and bitwidth must be read from initializers + integer_bits = model.get_initializer(node.input[1]) + keep_negative = model.get_initializer(node.input[2]) + bits = model.get_initializer(node.input[3]) + assert integer_bits is not None, "Found unspecified scale for Quant node: " + str(node) + assert keep_negative is not None, "Found unspecified zero point for Quant node: " + str(node) + assert bits is not None, "Found unspecified bitwidth for Quant node: " + str(node) + # extract the bitwidth (assume scalar) + assert bitwidth.ndim == 0, "Bitwidth must be scalar for Quant node: " + str(node) + bitwidth = bitwidth.item() + assert int(bitwidth) == bitwidth, "Bitwidth must be integer for Quant node: " + str(node) + bitwidth = int(bitwidth) + # determine the FINN DataType + unit_scale = np.all(scale == 1.0) + zero_zeropt = np.all(zeropt == 0.0) + assert zero_zeropt, "Only zero_point=0 Quant nodes supported for now" + if unit_scale and zero_zeropt: + finn_dt = self.get_integer_datatype(model) + else: + finn_dt = self.get_scaled_integer_datatype(model) + return finn_dt + + def infer_node_datatype(self, model): + try: + finn_dt = self.get_output_dtype(model) + except AssertionError: + finn_dt = DataType["FLOAT32"] + node = self.onnx_node + model.set_tensor_datatype(node.output[0], finn_dt) + + def execute_node(self, context, graph): + node = self.onnx_node + # save inputs + inp_tensor = context[node.input[0]] + integer_bits = context[node.input[1]] + keep_negative = context[node.input[2]] + bits = context[node.input[3]] + # save attributes + RND = self.get_nodeattr("RND") + SAT = self.get_nodeattr("SAT") + # calculate output + ret = gfixed_quantizer( + inp_tensor, keep_negative, bits, integer_bits, RND=RND, SAT=SAT) + # ensure output is ndarray (even if 0d) + # since numpy silently flattens 0d arrays to scalars + # more: https://github.com/numpy/numpy/issues/13105 + if not isinstance(ret, np.ndarray): + ret = np.asarray(ret, dtype=np.float32) + if not ret.dtype == np.float32: + ret = ret.astype(np.float32) + # set context according to output name + context[node.output[0]] = ret + + def verify_node(self): + pass diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index a656d4a5..390ce76c 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -33,6 +33,7 @@ from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold from qonnx.custom_op.general.quant import Quant +from qonnx.custom_op.general.FixedPoint import FixedPoint from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul @@ -49,3 +50,5 @@ custom_op["Quant"] = Quant custom_op["Trunc"] = Trunc custom_op["BipolarQuant"] = BipolarQuant +custom_op["FixedPoint"] = FixedPoint + diff --git a/src/qonnx/util/cleanup.py b/src/qonnx/util/cleanup.py index 933f729d..12974fdb 100644 --- a/src/qonnx/util/cleanup.py +++ b/src/qonnx/util/cleanup.py @@ -52,7 +52,7 @@ def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_c """ # temporary fix for QONNX op domains - qonnx_domain_ops = ["Quant", "Trunc", "BipolarQuant"] + qonnx_domain_ops = ["FixedPoint", "Quant", "Trunc", "BipolarQuant"] for q_op_type in qonnx_domain_ops: qnt_nodes = model.get_nodes_by_op_type(q_op_type) for qnt_node in qnt_nodes: From a3fb522be994437861615ff5a7dbcbec2c9ca2f6 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Wed, 26 Jun 2024 23:58:33 +0200 Subject: [PATCH 04/11] add HGQ folder --- src/qonnx/converters/HGQ/__init__.py | 0 src/qonnx/converters/HGQ/hgqlayers.py | 44 ++++++++++++++++ src/qonnx/converters/HGQ/onnx.py | 70 ++++++++++++++++++++++++++ src/qonnx/converters/HGQ/quantizers.py | 27 ++++++++++ 4 files changed, 141 insertions(+) create mode 100644 src/qonnx/converters/HGQ/__init__.py create mode 100644 src/qonnx/converters/HGQ/hgqlayers.py create mode 100644 src/qonnx/converters/HGQ/onnx.py create mode 100644 src/qonnx/converters/HGQ/quantizers.py diff --git a/src/qonnx/converters/HGQ/__init__.py b/src/qonnx/converters/HGQ/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qonnx/converters/HGQ/hgqlayers.py b/src/qonnx/converters/HGQ/hgqlayers.py new file mode 100644 index 00000000..f481ef6c --- /dev/null +++ b/src/qonnx/converters/HGQ/hgqlayers.py @@ -0,0 +1,44 @@ +import HGQ +import qkeras +from qkeras.quantizers import BaseQuantizer +import tensorflow as tf + + +# TODO: this should be implemented in HGQ so we can import it here +# from HGQ.utils import REGISTERED_LAYERS as HGQ_LAYERS +HGQ_LAYERS = ["FixedPointQuantizer"] + +def extract_quantizers_from_hgq_layer(layer, model): + """ """ + layer_class = layer.__class__.__name__ + if layer_class in HGQ_LAYERS: + handler = handler_map.get(layer_class, None) + if handler: + return handler_map[layer_class](layer, model) + else: + return layer_class, layer.get_config(), None + else: + return layer_class, layer.get_config(), None + + +def extract_FixedPointQuantizer(layer, model): + + quantizers = layer.get_config() + + if "overrides" not in quantizers: + # TODO: add support for FixedPointQuantizer which dont override + raise ValueError(f"Not supported: FixedPointQuantizer has no layers to override") + + quantizers["inputs"] = { + "keep_negative": layer.keep_negative.numpy(), + "bits": layer.bits.numpy(), + "integer_bits": layer.integers.numpy(), + } + keras_config = {'name': quantizers["name"], 'dtype': 'float32'} + + return "Identity", keras_config, quantizers + + +handler_map = { + "FixedPointQuantizer": extract_FixedPointQuantizer +} diff --git a/src/qonnx/converters/HGQ/onnx.py b/src/qonnx/converters/HGQ/onnx.py new file mode 100644 index 00000000..c0137de7 --- /dev/null +++ b/src/qonnx/converters/HGQ/onnx.py @@ -0,0 +1,70 @@ +import numpy as np + +from .quantizers import get_quant_params + + +def get_hgq_onnx_handlers(all_quantizers): + """Returns the handlers for each kind of layer + + Args: + all_quantizers: All the quantizers of the model in dictionary format *check + + Returns: + Dictionary containing the handler information for every type of layer + """ + return { + # NOTE: we replace the StatefulPartitionedCall layers with Identity layers + # after them we are adding now FixedPoint layers for the quantitzation + "Identity": ( + FixedPoint, ["FixedPoint", all_quantizers] + ), + } + + +def _extract_node_name(onnx_node, keras_quantizers): + """ + + Args: + onnx_node: The onnx node to get the information from + keras_quantizers: The dictionary of all the keras quantizers + + """ + onnx_name = onnx_node.name + print(onnx_node) + keras_names = keras_quantizers.keys() + print(keras_names, onnx_name) + for keras_name in keras_names: + match = "/" + keras_name + "/" + if match in onnx_name: + return keras_name + + return None + + +def FixedPoint(ctx, node, name, args): + all_quantizers = args[0] + keras_name = _extract_node_name(node, all_quantizers) + if not keras_name: + return # Not found in quantizers, nothing to do + quantizers = all_quantizers[keras_name] + # if we have overrides we are converting a FixedPointQuantizer from HGQ + if quantizers.get("overrides"): + quant_params = get_quant_params(None, quantizers) + attr = quant_params["attributes"] + input_nodes = [node.output[0]] + print(node.input[0]) + for key in quantizers["inputs"].keys(): + name = f"{node.name}_FixedPointQuantizer_quantizer_{key}" + np_val = np.asarray(quant_params["inputs"][key]) + ctx.make_const(name, np_val) + input_nodes.append(name) + quant_fixed_node = ctx.make_node( + "FixedPoint", + input_nodes, + dtypes=None, # TODO: we have to get the type here + name=node.name + "_FixedPoint_quantizer", + attr=attr, + domain="qonnx", + ) + ctx.insert_node_on_output(quant_fixed_node, node.output[0]) + ctx.set_shape(quant_fixed_node.output[0], ctx.get_shape(node.output[0])) diff --git a/src/qonnx/converters/HGQ/quantizers.py b/src/qonnx/converters/HGQ/quantizers.py new file mode 100644 index 00000000..21e12b58 --- /dev/null +++ b/src/qonnx/converters/HGQ/quantizers.py @@ -0,0 +1,27 @@ + + +def get_quant_params(tensor, hgq_quantizer): + + return handler_map[hgq_quantizer["keras_layer"]](tensor, hgq_quantizer) + + +def convert_quantized_bits(tensor, quantizer): + + settings = { + "attributes": { + "RND": quantizer["RND"], + "SAT": quantizer["SAT"], + }, + "inputs": { + "integer_bits": quantizer["inputs"]["integers"], + "keep_negative": quantizer["inputs"]["keep_negative"], + "bits": quantizer["inputs"]["bits"], + }, + } + + return settings + + +handler_map = { + "Identity": convert_quantized_bits, +} From 5917b3c50a89474f5a276c6470401ed2dc918222 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Thu, 27 Jun 2024 15:20:36 +0200 Subject: [PATCH 05/11] first prototype for HGQ converter --- notebooks/4_hgq_to_qonnx.ipynb | 414 +++++++++++++++++++++++++++++++++ tests/HGQ/test_hgq.py | 46 ++++ 2 files changed, 460 insertions(+) create mode 100644 notebooks/4_hgq_to_qonnx.ipynb create mode 100644 tests/HGQ/test_hgq.py diff --git a/notebooks/4_hgq_to_qonnx.ipynb b/notebooks/4_hgq_to_qonnx.ipynb new file mode 100644 index 00000000..022b9207 --- /dev/null +++ b/notebooks/4_hgq_to_qonnx.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import keras\n", + "import numpy as np\n", + "from HGQ.layers import HDense, HConv2D, PMaxPooling2D, PFlatten, PReshape, HQuantize\n", + "from HGQ import ResetMinMax, FreeBOPs\n", + "import HGQ\n", + "from HGQ import trace_minmax, to_proxy_model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", + "x_train = x_train.astype('float32') / 255\n", + "x_test = x_test.astype('float32') / 255" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model = keras.models.Sequential([\n", + " HQuantize(beta=3e-5),\n", + " PReshape((28, 28, 1)),\n", + " PMaxPooling2D((2, 2)),\n", + " HConv2D(1, (3, 3), activation='relu', beta=3e-5, parallel_factor=144),\n", + " PMaxPooling2D((2, 2)),\n", + " HConv2D(1, (3, 3), activation='relu', beta=3e-5, parallel_factor=16),\n", + " PMaxPooling2D((2, 2)),\n", + " PFlatten(),\n", + " HDense(10, beta=3e-5)\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.\n", + "2024-06-27 15:06:03.457498: I external/local_xla/xla/service/service.cc:168] XLA service 0x177343d00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n", + "2024-06-27 15:06:03.457520: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1719493563.465704 1 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n", + "2024-06-27 15:06:03.465881: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-06-27 15:06:05.925208: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", + "2024-06-27 15:06:06.118721: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/4 [======>.......................] - ETA: 8s - loss: 2.6829 - accuracy: 0.1250" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-27 15:06:06.284652: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-06-27 15:06:06.284716: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-06-27 15:06:06.285390: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4/4 [==============================] - 3s 106ms/step - loss: 2.6864 - accuracy: 0.1300 - bops: 13247.0000\n", + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " h_quantize (HQuantize) (None, 28, 28) 786 \n", + " \n", + " p_reshape (PReshape) (None, 28, 28, 1) 0 \n", + " \n", + " p_max_pooling2d (PMaxPooli (None, 14, 14, 1) 0 \n", + " ng2D) \n", + " \n", + " h_conv2d (HConv2D) (None, 12, 12, 1) 165 \n", + " \n", + " p_max_pooling2d_1 (PMaxPoo (None, 6, 6, 1) 0 \n", + " ling2D) \n", + " \n", + " h_conv2d_1 (HConv2D) (None, 4, 4, 1) 37 \n", + " \n", + " p_max_pooling2d_2 (PMaxPoo (None, 2, 2, 1) 0 \n", + " ling2D) \n", + " \n", + " p_flatten (PFlatten) (None, 4) 0 \n", + " \n", + " h_dense (HDense) (None, 10) 102 \n", + " \n", + "=================================================================\n", + "Total params: 1090 (4.26 KB)\n", + "Trainable params: 1082 (4.23 KB)\n", + "Non-trainable params: 8 (32.00 Byte)\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "opt = keras.optimizers.Adam(learning_rate=0.001)\n", + "loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", + "model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])\n", + "callbacks = [ResetMinMax(), FreeBOPs()]\n", + "\n", + "model.fit(x_train[:100], y_train[:100], epochs=1, batch_size=32, callbacks=callbacks)\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "h_quantize: 0.0\n", + "h_conv2d: 5920.0\n", + "h_conv2d_1: 894.0\n", + "h_dense: 219.0\n", + "Model: \"model_3\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_1 (InputLayer) [(None, 28, 28)] 0 \n", + " \n", + " h_quantize (FixedPointQuan (None, 28, 28) 2352 \n", + " tizer) \n", + " \n", + " p_reshape (Reshape) (None, 28, 28, 1) 0 \n", + " \n", + " p_max_pooling2d (MaxPoolin (None, 14, 14, 1) 0 \n", + " g2D) \n", + " \n", + " h_conv2d (Conv2D) (None, 12, 12, 1) 10 \n", + " \n", + " h_conv2d_quantizer (FixedP (None, 12, 12, 1) 432 \n", + " ointQuantizer) \n", + " \n", + " p_max_pooling2d_1 (MaxPool (None, 6, 6, 1) 0 \n", + " ing2D) \n", + " \n", + " h_conv2d_1 (Conv2D) (None, 4, 4, 1) 10 \n", + " \n", + " h_conv2d_1_quantizer (Fixe (None, 4, 4, 1) 48 \n", + " dPointQuantizer) \n", + " \n", + " p_max_pooling2d_2 (MaxPool (None, 2, 2, 1) 0 \n", + " ing2D) \n", + " \n", + " p_flatten (Flatten) (None, 4) 0 \n", + " \n", + " h_dense (Dense) (None, 10) 50 \n", + " \n", + " h_dense_quantizer (FixedPo (None, 10) 30 \n", + " intQuantizer) \n", + " \n", + "=================================================================\n", + "Total params: 2932 (3.07 KB)\n", + "Trainable params: 70 (280.00 Byte)\n", + "Non-trainable params: 2862 (2.79 KB)\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "trace_minmax(model, x_train, cover_factor=1.0)\n", + "proxy = to_proxy_model(model, aggressive=True)\n", + "proxy.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_3\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_1 (InputLayer) [(None, 28, 28)] 0 \n", + " \n", + " h_quantize (FixedPointQuan (None, 28, 28) 2352 \n", + " tizer) \n", + " \n", + " p_reshape (Reshape) (None, 28, 28, 1) 0 \n", + " \n", + " p_max_pooling2d (MaxPoolin (None, 14, 14, 1) 0 \n", + " g2D) \n", + " \n", + " h_conv2d (Conv2D) (None, 12, 12, 1) 10 \n", + " \n", + " h_conv2d_quantizer (FixedP (None, 12, 12, 1) 432 \n", + " ointQuantizer) \n", + " \n", + " p_max_pooling2d_1 (MaxPool (None, 6, 6, 1) 0 \n", + " ing2D) \n", + " \n", + " h_conv2d_1 (Conv2D) (None, 4, 4, 1) 10 \n", + " \n", + " h_conv2d_1_quantizer (Fixe (None, 4, 4, 1) 48 \n", + " dPointQuantizer) \n", + " \n", + " p_max_pooling2d_2 (MaxPool (None, 2, 2, 1) 0 \n", + " ing2D) \n", + " \n", + " p_flatten (Flatten) (None, 4) 0 \n", + " \n", + " h_dense (Dense) (None, 10) 50 \n", + " \n", + " h_dense_quantizer (FixedPo (None, 10) 30 \n", + " intQuantizer) \n", + " \n", + "=================================================================\n", + "Total params: 2932 (3.07 KB)\n", + "Trainable params: 70 (280.00 Byte)\n", + "Non-trainable params: 2862 (2.79 KB)\n", + "_________________________________________________________________\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/anaconda3/lib/python3.11/site-packages/keras/src/initializers/__init__.py:144: UserWarning: The `keras.initializers.serialize()` API should only be used for objects of type `keras.initializers.Initializer`. Found an instance of type , which may lead to improper serialization.\n", + " warnings.warn(\n", + "/opt/anaconda3/lib/python3.11/site-packages/keras/src/initializers/__init__.py:144: UserWarning: The `keras.initializers.serialize()` API should only be used for objects of type `keras.initializers.Initializer`. Found an instance of type , which may lead to improper serialization.\n", + " warnings.warn(\n", + "2024-06-27 15:06:14.062161: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session\n", + "2024-06-27 15:06:14.171857: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session\n" + ] + } + ], + "source": [ + "from qonnx.converters.keras import from_keras\n", + "import onnx\n", + "onnx_model, external_storage = from_keras(proxy, \"test_qkeras_conversion\", opset=9)\n", + "onnx.save(onnx_model, '/tmp/hgq.onnx')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import netron\n", + "import os\n", + "from IPython.display import IFrame\n", + "\n", + "def showInNetron(model_filename: str, localhost_url: str = None, port: int = None):\n", + " \"\"\"Shows a ONNX model file in the Jupyter Notebook using Netron.\n", + "\n", + " :param model_filename: The path to the ONNX model file.\n", + " :type model_filename: str\n", + "\n", + " :param localhost_url: The IP address used by the Jupyter IFrame to show the model.\n", + " Defaults to localhost.\n", + " :type localhost_url: str, optional\n", + "\n", + " :param port: The port number used by Netron and the Jupyter IFrame to show\n", + " the ONNX model. Defaults to 8088.\n", + " :type port: int, optional\n", + "\n", + " :return: The IFrame displaying the ONNX model.\n", + " :rtype: IPython.lib.display.IFrame\n", + " \"\"\"\n", + " try:\n", + " port = port or int(os.getenv(\"NETRON_PORT\", default=\"8088\"))\n", + " except ValueError:\n", + " port = 8088\n", + " localhost_url = localhost_url or os.getenv(\"LOCALHOST_URL\", default=\"localhost\")\n", + " netron.start(model_filename, address=(\"0.0.0.0\", port), browse=False)\n", + " return IFrame(src=f\"http://{localhost_url}:{port}/\", width=\"100%\", height=400)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Serving '/tmp/hgq.onnx' at http://0.0.0.0:8088\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "showInNetron('/tmp/hgq.onnx')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Batch [100/100]: running: 100%|██████████| 100/100 [00:01<00:00, 73.61it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4/4 [==============================] - 0s 23ms/step\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from qonnx.util.exec_qonnx import exec_qonnx\n", + "np.save(\"/tmp/x_test.npy\", x_test[:100])\n", + "qonnx_out = exec_qonnx('/tmp/hgq.onnx', \"/tmp/x_test.npy\")\n", + "hgq_out = proxy.predict(x_test[:100])\n", + "np.isclose(qonnx_out, hgq_out).all()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + }, + "vscode": { + "interpreter": { + "hash": "0502196cae5450340bff0232db0d443756691d3235661ea220d02e5c6b2dd97d" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/HGQ/test_hgq.py b/tests/HGQ/test_hgq.py new file mode 100644 index 00000000..9edae14f --- /dev/null +++ b/tests/HGQ/test_hgq.py @@ -0,0 +1,46 @@ +import keras, onnx +import numpy as np +from HGQ.layers import HDense, HConv2D, PMaxPooling2D, PFlatten, PReshape, HQuantize +from HGQ import ResetMinMax, FreeBOPs +from HGQ import trace_minmax, to_proxy_model +from qonnx.converters.keras import from_keras +from qonnx.util.exec_qonnx import exec_qonnx + + +def test_convert_HGQ_two_conv2d_to_QONNX(): + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + x_train = x_train.astype('float32') / 255 + x_test = x_test.astype('float32') / 255 + # NOTE: we just test a few samples + np.save("/tmp/x_test.npy", x_test[:100]) + + model = keras.models.Sequential([ + HQuantize(beta=3e-5), + PReshape((28, 28, 1)), + PMaxPooling2D((2, 2)), + HConv2D(1, (3, 3), activation='relu', beta=3e-5, parallel_factor=144), + PMaxPooling2D((2, 2)), + HConv2D(1, (3, 3), activation='relu', beta=3e-5, parallel_factor=16), + PMaxPooling2D((2, 2)), + PFlatten(), + HDense(10, beta=3e-5) + ]) + + opt = keras.optimizers.Adam(learning_rate=0.001) + loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + model.compile(optimizer=opt, loss=loss, metrics=['accuracy']) + callbacks = [ResetMinMax(), FreeBOPs()] + + model.fit(x_train, y_train, epochs=1, batch_size=32, callbacks=callbacks) + + trace_minmax(model, x_train, cover_factor=1.0) + proxy = to_proxy_model(model, aggressive=True) + + onnx_model, external_storage = from_keras(proxy, "test_qkeras_conversion", opset=9) + onnx.save(onnx_model, '/tmp/hgq.onnx') + + qonnx_out = exec_qonnx('/tmp/hgq.onnx', "/tmp/x_test.npy") + hgq_out = proxy.predict(x_test[:100]) + assert np.isclose( + qonnx_out, hgq_out + ).all(), "Output of HGQ proxy model and converted QONNX model should match." From 0ce8cf8eba15025478ed1ba459cb883019d629b5 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 23 Aug 2024 17:32:51 +0200 Subject: [PATCH 06/11] fix linting, add HGQ layer --- src/qonnx/converters/HGQ/hgqlayers.py | 24 ++++++++------ src/qonnx/converters/HGQ/onnx.py | 11 ++----- src/qonnx/converters/HGQ/quantizers.py | 8 ++--- src/qonnx/converters/keras.py | 16 +++++----- src/qonnx/custom_op/general/FixedPoint.py | 38 +++++------------------ src/qonnx/util/cleanup.py | 27 ++++++++++++++-- src/qonnx/util/exec_qonnx.py | 3 ++ 7 files changed, 64 insertions(+), 63 deletions(-) diff --git a/src/qonnx/converters/HGQ/hgqlayers.py b/src/qonnx/converters/HGQ/hgqlayers.py index f481ef6c..089fe0d8 100644 --- a/src/qonnx/converters/HGQ/hgqlayers.py +++ b/src/qonnx/converters/HGQ/hgqlayers.py @@ -1,13 +1,19 @@ -import HGQ -import qkeras -from qkeras.quantizers import BaseQuantizer import tensorflow as tf +class HGQIdentity(tf.keras.layers.Layer): + def __init__(self, name, dtype): + super(HGQIdentity, self).__init__(name=name, dtype=dtype) + + def call(self, inputs): + return inputs + + # TODO: this should be implemented in HGQ so we can import it here # from HGQ.utils import REGISTERED_LAYERS as HGQ_LAYERS HGQ_LAYERS = ["FixedPointQuantizer"] + def extract_quantizers_from_hgq_layer(layer, model): """ """ layer_class = layer.__class__.__name__ @@ -22,23 +28,21 @@ def extract_quantizers_from_hgq_layer(layer, model): def extract_FixedPointQuantizer(layer, model): - quantizers = layer.get_config() if "overrides" not in quantizers: # TODO: add support for FixedPointQuantizer which dont override - raise ValueError(f"Not supported: FixedPointQuantizer has no layers to override") + raise ValueError("Not supported: FixedPointQuantizer has no layers to override") quantizers["inputs"] = { "keep_negative": layer.keep_negative.numpy(), "bits": layer.bits.numpy(), "integer_bits": layer.integers.numpy(), } - keras_config = {'name': quantizers["name"], 'dtype': 'float32'} + quantizers["keras_layer"] = "FixedPointQuantizer" + keras_config = {"name": quantizers["name"], "dtype": "float32"} - return "Identity", keras_config, quantizers + return "HGQIdentity", keras_config, quantizers -handler_map = { - "FixedPointQuantizer": extract_FixedPointQuantizer -} +handler_map = {"FixedPointQuantizer": extract_FixedPointQuantizer} diff --git a/src/qonnx/converters/HGQ/onnx.py b/src/qonnx/converters/HGQ/onnx.py index c0137de7..3b89412e 100644 --- a/src/qonnx/converters/HGQ/onnx.py +++ b/src/qonnx/converters/HGQ/onnx.py @@ -13,11 +13,9 @@ def get_hgq_onnx_handlers(all_quantizers): Dictionary containing the handler information for every type of layer """ return { - # NOTE: we replace the StatefulPartitionedCall layers with Identity layers + # NOTE: we replace the StatefulPartitionedCall layers with HGQIdentity layers # after them we are adding now FixedPoint layers for the quantitzation - "Identity": ( - FixedPoint, ["FixedPoint", all_quantizers] - ), + "StatefulPartitionedCall": (FixedPoint, ["FixedPoint", all_quantizers]), } @@ -30,9 +28,7 @@ def _extract_node_name(onnx_node, keras_quantizers): """ onnx_name = onnx_node.name - print(onnx_node) keras_names = keras_quantizers.keys() - print(keras_names, onnx_name) for keras_name in keras_names: match = "/" + keras_name + "/" if match in onnx_name: @@ -52,7 +48,6 @@ def FixedPoint(ctx, node, name, args): quant_params = get_quant_params(None, quantizers) attr = quant_params["attributes"] input_nodes = [node.output[0]] - print(node.input[0]) for key in quantizers["inputs"].keys(): name = f"{node.name}_FixedPointQuantizer_quantizer_{key}" np_val = np.asarray(quant_params["inputs"][key]) @@ -61,7 +56,7 @@ def FixedPoint(ctx, node, name, args): quant_fixed_node = ctx.make_node( "FixedPoint", input_nodes, - dtypes=None, # TODO: we have to get the type here + dtypes=None, # TODO: we have to get the type here name=node.name + "_FixedPoint_quantizer", attr=attr, domain="qonnx", diff --git a/src/qonnx/converters/HGQ/quantizers.py b/src/qonnx/converters/HGQ/quantizers.py index 21e12b58..d48e5fdc 100644 --- a/src/qonnx/converters/HGQ/quantizers.py +++ b/src/qonnx/converters/HGQ/quantizers.py @@ -1,19 +1,15 @@ - - def get_quant_params(tensor, hgq_quantizer): - return handler_map[hgq_quantizer["keras_layer"]](tensor, hgq_quantizer) def convert_quantized_bits(tensor, quantizer): - settings = { "attributes": { "RND": quantizer["RND"], "SAT": quantizer["SAT"], }, "inputs": { - "integer_bits": quantizer["inputs"]["integers"], + "integer_bits": quantizer["inputs"]["integer_bits"], "keep_negative": quantizer["inputs"]["keep_negative"], "bits": quantizer["inputs"]["bits"], }, @@ -23,5 +19,5 @@ def convert_quantized_bits(tensor, quantizer): handler_map = { - "Identity": convert_quantized_bits, + "FixedPointQuantizer": convert_quantized_bits, } diff --git a/src/qonnx/converters/keras.py b/src/qonnx/converters/keras.py index 7e434428..2317844e 100644 --- a/src/qonnx/converters/keras.py +++ b/src/qonnx/converters/keras.py @@ -7,12 +7,11 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.util.cleanup import cleanup_model +from .HGQ.hgqlayers import HGQ_LAYERS, HGQIdentity, extract_quantizers_from_hgq_layer +from .HGQ.onnx import get_hgq_onnx_handlers from .qkeras.onnx import get_qkeras_onnx_handlers from .qkeras.qlayers import extract_quantizers_from_qkeras_layer -from .HGQ.onnx import get_hgq_onnx_handlers -from .HGQ.hgqlayers import HGQ_LAYERS, extract_quantizers_from_hgq_layer - # NOTE: thats a list for qkeras & HGQ layers _unsupported_layers = [ # These require some extra work @@ -193,6 +192,8 @@ def _strip_hgq_model(model): The stripped model, and the quantizers in a dictionary format """ quantizers = OrderedDict() + layer_dict = tf.keras.layers.__dict__ + layer_dict["HGQIdentity"] = HGQIdentity def extract_quantizers(layer): keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_hgq_layer(layer, model) @@ -200,14 +201,13 @@ def extract_quantizers(layer): layer_quantizers["input"] = layer.input.name quantizers[layer_quantizers["name"]] = layer_quantizers - layer_class = tf.keras.layers.__dict__.get(keras_cls_name, None) + layer_class = layer_dict.get(keras_cls_name, None) if layer_class is None: raise Exception("Cannot create Keras layer from QKeras class {}".format(keras_cls_name)) return layer_class.from_config(layer_cfg) stripped_model = tf.keras.models.clone_model(model, clone_function=extract_quantizers) - for layer in model.layers: if layer.__class__.__name__ in HGQ_LAYERS: # NOTE: the FixedPointQuantizer does not have weights it only has @@ -259,13 +259,15 @@ def from_keras( _check_supported_layers(model) keras_op_handlers = {} + is_HGQ = False if _is_qkeras_model(model): keras_model, quantizers = _strip_qkeras_model(model) keras_op_handlers.update(get_qkeras_onnx_handlers(quantizers)) elif _is_hgq_model(model): - keras_model = model keras_model, quantizers = _strip_hgq_model(model) keras_op_handlers.update(get_hgq_onnx_handlers(quantizers)) + keras_model = model + is_HGQ = True else: keras_model, quantizers = model, {} @@ -309,7 +311,7 @@ def from_keras( if tensor.name in q_node_outputs: tensor.type.tensor_type.elem_type = 1 - onnx_model = cleanup_model(onnx_model) + onnx_model = cleanup_model(onnx_model, is_HGQ=is_HGQ) onnx_model.model = add_value_info_for_constants(onnx_model.model) if output_path is not None: diff --git a/src/qonnx/custom_op/general/FixedPoint.py b/src/qonnx/custom_op/general/FixedPoint.py index cf1e84ba..b1c64bce 100644 --- a/src/qonnx/custom_op/general/FixedPoint.py +++ b/src/qonnx/custom_op/general/FixedPoint.py @@ -27,9 +27,8 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np -from onnx import TensorProto, helper - from HGQ.proxy.fixed_point_quantizer import gfixed_quantizer +from onnx import TensorProto, helper from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp @@ -37,7 +36,7 @@ class FixedPoint(CustomOp): """Generic quantization operation for HGQ FixedPoint layer to QONNX. - + Takes four inputs: - input tensor to quantize @@ -92,28 +91,7 @@ def get_scaled_integer_datatype(self, model): return finn_dt def get_output_dtype(self, model): - node = self.onnx_node - # scale, zero-point and bitwidth must be read from initializers - integer_bits = model.get_initializer(node.input[1]) - keep_negative = model.get_initializer(node.input[2]) - bits = model.get_initializer(node.input[3]) - assert integer_bits is not None, "Found unspecified scale for Quant node: " + str(node) - assert keep_negative is not None, "Found unspecified zero point for Quant node: " + str(node) - assert bits is not None, "Found unspecified bitwidth for Quant node: " + str(node) - # extract the bitwidth (assume scalar) - assert bitwidth.ndim == 0, "Bitwidth must be scalar for Quant node: " + str(node) - bitwidth = bitwidth.item() - assert int(bitwidth) == bitwidth, "Bitwidth must be integer for Quant node: " + str(node) - bitwidth = int(bitwidth) - # determine the FINN DataType - unit_scale = np.all(scale == 1.0) - zero_zeropt = np.all(zeropt == 0.0) - assert zero_zeropt, "Only zero_point=0 Quant nodes supported for now" - if unit_scale and zero_zeropt: - finn_dt = self.get_integer_datatype(model) - else: - finn_dt = self.get_scaled_integer_datatype(model) - return finn_dt + raise NotImplementedError("get_output_dtype for FixedPoint is not implemented") def infer_node_datatype(self, model): try: @@ -127,15 +105,15 @@ def execute_node(self, context, graph): node = self.onnx_node # save inputs inp_tensor = context[node.input[0]] - integer_bits = context[node.input[1]] - keep_negative = context[node.input[2]] - bits = context[node.input[3]] + # TODO: we assume here an order that the name of the input[1] = keep_negative etc. + keep_negative = context[node.input[1]] + bits = context[node.input[2]] + integer_bits = context[node.input[3]] # save attributes RND = self.get_nodeattr("RND") SAT = self.get_nodeattr("SAT") # calculate output - ret = gfixed_quantizer( - inp_tensor, keep_negative, bits, integer_bits, RND=RND, SAT=SAT) + ret = gfixed_quantizer(inp_tensor, keep_negative, bits, integer_bits, RND=RND, SAT=SAT) # ensure output is ndarray (even if 0d) # since numpy silently flattens 0d arrays to scalars # more: https://github.com/numpy/numpy/issues/13105 diff --git a/src/qonnx/util/cleanup.py b/src/qonnx/util/cleanup.py index 12974fdb..0144c066 100644 --- a/src/qonnx/util/cleanup.py +++ b/src/qonnx/util/cleanup.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import clize +import onnx from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.change_batchsize import ChangeBatchSize @@ -43,7 +44,7 @@ from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_conv_bias=False): +def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_conv_bias=False, is_HGQ=False): """Execute the transformations for the cleanup function on a model level. This allows the reuse of the cleanup transformations, without needing to read/write the model from/to disk. @@ -58,10 +59,32 @@ def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_c for qnt_node in qnt_nodes: qnt_node.domain = "qonnx.custom_op.general" if preserve_qnt_ops: - preserve_qnt_optypes = ["Quant", "BipolarQuant", "QuantizeLinear", "DequantizeLinear"] + preserve_qnt_optypes = ["FixedPoint", "Quant", "BipolarQuant", "QuantizeLinear", "DequantizeLinear"] else: preserve_qnt_optypes = [] + if is_HGQ: + nodes_to_keep = [] + for node in model.graph.node: + if "StatefulPartitionedCall" in node.name and "FixedPoint" not in node.name: + new_input_FixedPoint = node.input + continue + elif "StatefulPartitionedCall" in node.name and "FixedPoint" in node.name: + node.input[:] = new_input_FixedPoint + nodes_to_keep.append(node) + + # Create a new graph with the nodes to keep + new_graph = onnx.helper.make_graph( + nodes_to_keep, model.graph.name, model.graph.input, model.graph.output, model.graph.initializer + ) + model = onnx.helper.make_model(new_graph) + + # check the model + # onnx.checker.check_model(model) + + # Create a new model with the new graph + model = ModelWrapper(model) + if override_inpsize is not None: if type(override_inpsize) is str: override_inpsize = eval(override_inpsize) diff --git a/src/qonnx/util/exec_qonnx.py b/src/qonnx/util/exec_qonnx.py index 5c059281..6a5ca0cf 100644 --- a/src/qonnx/util/exec_qonnx.py +++ b/src/qonnx/util/exec_qonnx.py @@ -147,6 +147,7 @@ def exec_qonnx( pbar = tqdm(range(n_dset_iters)) + prediction = [] for iter in pbar: iter_suffix = "_batch%d" % iter idict = {} @@ -163,6 +164,7 @@ def exec_qonnx( if n_custom_nodes > 0: # run node-by-node in qonnx odict = execute_onnx(model, idict) + prediction.append(odict[model.graph.output[0].name].flatten()) else: # run using onnxruntime sess = rt.InferenceSession(model.model.SerializeToString()) @@ -190,6 +192,7 @@ def exec_qonnx( "Batch [%d/%d]: ok %d nok %d accuracy %f (overall ok %d nok %d accuracy %f)" % (iter + 1, n_dset_iters, ok_batch, nok_batch, accuracy_batch, ok, nok, accuracy_overall) ) + return np.array(prediction) def main(): From 41f394849d3a38baa268482943bfe128541292e9 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 23 Aug 2024 17:42:01 +0200 Subject: [PATCH 07/11] fix linting --- tests/core/test_datatypes.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index e98a885f..722a0ac4 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -28,8 +28,7 @@ import numpy as np -from qonnx.core.datatype import DataType -from qonnx.core.datatype import resolve_datatype +from qonnx.core.datatype import DataType, resolve_datatype def test_datatypes(): @@ -120,12 +119,11 @@ def test_resolve_datatype(): def test_input_type_error(): - def test_resolve_datatype(input): # test with invalid input to check if the TypeError works try: resolve_datatype(input) # This should raise a TypeError - except TypeError as e: + except TypeError: pass else: print("Test with invalid input failed: No TypeError was raised.") From 288070926f919c2d6f09a3c489b0def3806f9380 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 23 Aug 2024 18:05:29 +0200 Subject: [PATCH 08/11] rename notebook, more linting fixes --- ...gq_to_qonnx.ipynb => 5_hgq_to_qonnx.ipynb} | 1 - src/qonnx/custom_op/general/__init__.py | 3 +- tests/HGQ/test_hgq.py | 47 ++++++++++--------- 3 files changed, 25 insertions(+), 26 deletions(-) rename notebooks/{4_hgq_to_qonnx.ipynb => 5_hgq_to_qonnx.ipynb} (99%) diff --git a/notebooks/4_hgq_to_qonnx.ipynb b/notebooks/5_hgq_to_qonnx.ipynb similarity index 99% rename from notebooks/4_hgq_to_qonnx.ipynb rename to notebooks/5_hgq_to_qonnx.ipynb index 022b9207..762a13d5 100644 --- a/notebooks/4_hgq_to_qonnx.ipynb +++ b/notebooks/5_hgq_to_qonnx.ipynb @@ -10,7 +10,6 @@ "import numpy as np\n", "from HGQ.layers import HDense, HConv2D, PMaxPooling2D, PFlatten, PReshape, HQuantize\n", "from HGQ import ResetMinMax, FreeBOPs\n", - "import HGQ\n", "from HGQ import trace_minmax, to_proxy_model" ] }, diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index 390ce76c..5eecb6cd 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -28,12 +28,12 @@ from qonnx.custom_op.general.bipolar_quant import BipolarQuant from qonnx.custom_op.general.debugmarker import DebugMarker +from qonnx.custom_op.general.FixedPoint import FixedPoint from qonnx.custom_op.general.genericpartition import GenericPartition from qonnx.custom_op.general.im2col import Im2Col from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold from qonnx.custom_op.general.quant import Quant -from qonnx.custom_op.general.FixedPoint import FixedPoint from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul @@ -51,4 +51,3 @@ custom_op["Trunc"] = Trunc custom_op["BipolarQuant"] = BipolarQuant custom_op["FixedPoint"] = FixedPoint - diff --git a/tests/HGQ/test_hgq.py b/tests/HGQ/test_hgq.py index 9edae14f..01283dff 100644 --- a/tests/HGQ/test_hgq.py +++ b/tests/HGQ/test_hgq.py @@ -1,34 +1,37 @@ -import keras, onnx +import keras import numpy as np -from HGQ.layers import HDense, HConv2D, PMaxPooling2D, PFlatten, PReshape, HQuantize -from HGQ import ResetMinMax, FreeBOPs -from HGQ import trace_minmax, to_proxy_model +import onnx +from HGQ import FreeBOPs, ResetMinMax, to_proxy_model, trace_minmax +from HGQ.layers import HConv2D, HDense, HQuantize, PFlatten, PMaxPooling2D, PReshape + from qonnx.converters.keras import from_keras from qonnx.util.exec_qonnx import exec_qonnx def test_convert_HGQ_two_conv2d_to_QONNX(): (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() - x_train = x_train.astype('float32') / 255 - x_test = x_test.astype('float32') / 255 + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 # NOTE: we just test a few samples np.save("/tmp/x_test.npy", x_test[:100]) - model = keras.models.Sequential([ - HQuantize(beta=3e-5), - PReshape((28, 28, 1)), - PMaxPooling2D((2, 2)), - HConv2D(1, (3, 3), activation='relu', beta=3e-5, parallel_factor=144), - PMaxPooling2D((2, 2)), - HConv2D(1, (3, 3), activation='relu', beta=3e-5, parallel_factor=16), - PMaxPooling2D((2, 2)), - PFlatten(), - HDense(10, beta=3e-5) - ]) + model = keras.models.Sequential( + [ + HQuantize(beta=3e-5), + PReshape((28, 28, 1)), + PMaxPooling2D((2, 2)), + HConv2D(1, (3, 3), activation="relu", beta=3e-5, parallel_factor=144), + PMaxPooling2D((2, 2)), + HConv2D(1, (3, 3), activation="relu", beta=3e-5, parallel_factor=16), + PMaxPooling2D((2, 2)), + PFlatten(), + HDense(10, beta=3e-5), + ] + ) opt = keras.optimizers.Adam(learning_rate=0.001) loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) - model.compile(optimizer=opt, loss=loss, metrics=['accuracy']) + model.compile(optimizer=opt, loss=loss, metrics=["accuracy"]) callbacks = [ResetMinMax(), FreeBOPs()] model.fit(x_train, y_train, epochs=1, batch_size=32, callbacks=callbacks) @@ -37,10 +40,8 @@ def test_convert_HGQ_two_conv2d_to_QONNX(): proxy = to_proxy_model(model, aggressive=True) onnx_model, external_storage = from_keras(proxy, "test_qkeras_conversion", opset=9) - onnx.save(onnx_model, '/tmp/hgq.onnx') + onnx.save(onnx_model, "/tmp/hgq.onnx") - qonnx_out = exec_qonnx('/tmp/hgq.onnx', "/tmp/x_test.npy") + qonnx_out = exec_qonnx("/tmp/hgq.onnx", "/tmp/x_test.npy") hgq_out = proxy.predict(x_test[:100]) - assert np.isclose( - qonnx_out, hgq_out - ).all(), "Output of HGQ proxy model and converted QONNX model should match." + assert np.isclose(qonnx_out, hgq_out).all(), "Output of HGQ proxy model and converted QONNX model should match." From 40e41c7f1bc3a0d5e663e4f83388a8228ba9d7a3 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 23 Aug 2024 18:08:33 +0200 Subject: [PATCH 09/11] linting fix --- src/qonnx/core/datatype.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index 4bdffdf9..343416ba 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -376,7 +376,6 @@ def get_canonical_name(self): def resolve_datatype(name): - if not isinstance(name, str): raise TypeError(f"Input 'name' must be of type 'str', but got type '{type(name).__name__}'") From be5e979cb8de692b13873ff2fe32468a37d3eadf Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 23 Aug 2024 18:36:32 +0200 Subject: [PATCH 10/11] update README for github actions, add HGQ as a requirement --- README.md | 9 +++++++++ setup.cfg | 1 + 2 files changed, 10 insertions(+) diff --git a/README.md b/README.md index a89baa86..4b83ee5f 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,15 @@ pre-commit install Every time you commit some code, the pre-commit hooks will first run, performing various checks and fixes. In some cases pre-commit won’t be able to fix the issues and you may have to fix it manually, then run git commit once again. The checks are configured in .pre-commit-config.yaml under the repo root. +### Github Actions + +By using [act](https://github.com/nektos/act) you can also run all github actions locally using docker. After installation you can run: + +``` +cd qonnx +act +``` + ## Why QONNX? The QONNX representation has several advantages compared to other alternatives, as summarized in the table below. diff --git a/setup.cfg b/setup.cfg index 602d6ada..a98454bf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ install_requires = onnxruntime>=1.16.1 sigtools>=4.0.1 toposort>=1.7.0 + hgq>=0.2.3 [options.packages.find] From ecbe2aef785d07f48e1e6687e1986d615ce870cb Mon Sep 17 00:00:00 2001 From: makoeppel Date: Fri, 23 Aug 2024 18:39:05 +0200 Subject: [PATCH 11/11] update HQG version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a98454bf..37509771 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ install_requires = onnxruntime>=1.16.1 sigtools>=4.0.1 toposort>=1.7.0 - hgq>=0.2.3 + hgq [options.packages.find]