Hgq/support for qonnx #123

Open · wants to merge 13 commits into base: main
9 changes: 9 additions & 0 deletions README.md
@@ -154,6 +154,15 @@ pre-commit install
Every time you commit some code, the pre-commit hooks will first run, performing various checks and fixes. In some cases pre-commit won't be able to
fix the issues, and you may have to fix them manually, then run `git commit` once again. The checks are configured in `.pre-commit-config.yaml` under the repo root.

### GitHub Actions

Using [act](https://github.com/nektos/act) you can also run all GitHub Actions locally with Docker. After installing it, run:

```
cd qonnx
act
```

## Why QONNX?

The QONNX representation has several advantages over the alternatives, as summarized in the table below.
413 changes: 413 additions & 0 deletions notebooks/5_hgq_to_qonnx.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions setup.cfg
@@ -54,6 +54,7 @@ install_requires =
    onnxruntime>=1.16.1
    sigtools>=4.0.1
    toposort>=1.7.0
    hgq


[options.packages.find]
Empty file.
48 changes: 48 additions & 0 deletions src/qonnx/converters/HGQ/hgqlayers.py
@@ -0,0 +1,48 @@
import tensorflow as tf


class HGQIdentity(tf.keras.layers.Layer):
    """Passthrough layer that stands in for a stripped HGQ quantizer."""

    def __init__(self, name, dtype):
        super().__init__(name=name, dtype=dtype)

    def call(self, inputs):
        return inputs


# TODO: this should be implemented in HGQ so we can import it here
# from HGQ.utils import REGISTERED_LAYERS as HGQ_LAYERS
HGQ_LAYERS = ["FixedPointQuantizer"]


def extract_quantizers_from_hgq_layer(layer, model):
    """Map an HGQ layer to (keras class name, layer config, quantizer dict).

    Non-HGQ layers are returned unchanged, with None in place of the quantizers.
    """
    layer_class = layer.__class__.__name__
    if layer_class in HGQ_LAYERS:
        handler = handler_map.get(layer_class, None)
        if handler:
            return handler(layer, model)
    return layer_class, layer.get_config(), None


def extract_FixedPointQuantizer(layer, model):
    quantizers = layer.get_config()

    if "overrides" not in quantizers:
        # TODO: add support for FixedPointQuantizers that don't override
        raise ValueError("Not supported: FixedPointQuantizer has no layers to override")

    quantizers["inputs"] = {
        "keep_negative": layer.keep_negative.numpy(),
        "bits": layer.bits.numpy(),
        "integer_bits": layer.integers.numpy(),
    }
    quantizers["keras_layer"] = "FixedPointQuantizer"
    keras_config = {"name": quantizers["name"], "dtype": "float32"}

    return "HGQIdentity", keras_config, quantizers


handler_map = {"FixedPointQuantizer": extract_FixedPointQuantizer}
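As an aside, here is a minimal sketch (not part of the diff) of how this dispatch can be driven over a whole proxy model; the `collect_hgq_quantizers` helper is hypothetical:

```python
from qonnx.converters.HGQ.hgqlayers import extract_quantizers_from_hgq_layer


def collect_hgq_quantizers(model):
    """Collect {layer_name: quantizer_dict} for every HGQ layer in `model`."""
    quantizers = {}
    for layer in model.layers:
        # Non-HGQ layers come back with layer_quantizers == None and are skipped.
        _cls_name, _config, layer_quantizers = extract_quantizers_from_hgq_layer(layer, model)
        if layer_quantizers is not None:
            quantizers[layer_quantizers["name"]] = layer_quantizers
    return quantizers
```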
65 changes: 65 additions & 0 deletions src/qonnx/converters/HGQ/onnx.py
@@ -0,0 +1,65 @@
import numpy as np

from .quantizers import get_quant_params


def get_hgq_onnx_handlers(all_quantizers):
    """Returns the handlers for each kind of layer.

    Args:
        all_quantizers: all the quantizers of the model, as a dictionary keyed by layer name

    Returns:
        Dictionary containing the handler information for every type of layer
    """
    return {
        # NOTE: we replace the StatefulPartitionedCall layers with HGQIdentity layers
        # and insert FixedPoint nodes after them to carry the quantization
        "StatefulPartitionedCall": (FixedPoint, ["FixedPoint", all_quantizers]),
    }


def _extract_node_name(onnx_node, keras_quantizers):
    """Match an ONNX node back to the Keras quantizer it belongs to.

    Args:
        onnx_node: the ONNX node to get the information from
        keras_quantizers: the dictionary of all the Keras quantizers

    Returns:
        The name of the matching Keras layer, or None if there is no match.
    """
    onnx_name = onnx_node.name
    keras_names = keras_quantizers.keys()
    for keras_name in keras_names:
        match = "/" + keras_name + "/"
        if match in onnx_name:
            return keras_name

    return None


def FixedPoint(ctx, node, name, args):
    all_quantizers = args[0]
    keras_name = _extract_node_name(node, all_quantizers)
    if not keras_name:
        return  # not found in quantizers, nothing to do
    quantizers = all_quantizers[keras_name]
    # if we have overrides we are converting a FixedPointQuantizer from HGQ
    if quantizers.get("overrides"):
        quant_params = get_quant_params(None, quantizers)
        attr = quant_params["attributes"]
        input_nodes = [node.output[0]]
        for key in quantizers["inputs"].keys():
            name = f"{node.name}_FixedPointQuantizer_quantizer_{key}"
            np_val = np.asarray(quant_params["inputs"][key])
            ctx.make_const(name, np_val)
            input_nodes.append(name)
        quant_fixed_node = ctx.make_node(
            "FixedPoint",
            input_nodes,
            dtypes=None,  # TODO: we have to get the type here
            name=node.name + "_FixedPoint_quantizer",
            attr=attr,
            domain="qonnx",
        )
        ctx.insert_node_on_output(quant_fixed_node, node.output[0])
        ctx.set_shape(quant_fixed_node.output[0], ctx.get_shape(node.output[0]))
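For orientation, a hedged sketch of how these handlers plug into tf2onnx (mirroring what `from_keras` in `keras.py` does further down); `keras_model` and `quantizers` are assumed to come from `_strip_hgq_model`:

```python
import tf2onnx

from qonnx.converters.HGQ.onnx import get_hgq_onnx_handlers

# The returned dict maps op types to (handler, args) tuples, which is the
# format tf2onnx expects for custom_op_handlers.
handlers = get_hgq_onnx_handlers(quantizers)
model_proto, external_storage = tf2onnx.convert.from_keras(
    keras_model,
    custom_op_handlers=handlers,
)
```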
23 changes: 23 additions & 0 deletions src/qonnx/converters/HGQ/quantizers.py
@@ -0,0 +1,23 @@
def get_quant_params(tensor, hgq_quantizer):
    return handler_map[hgq_quantizer["keras_layer"]](tensor, hgq_quantizer)


def convert_quantized_bits(tensor, quantizer):
    settings = {
        "attributes": {
            "RND": quantizer["RND"],
            "SAT": quantizer["SAT"],
        },
        "inputs": {
            "integer_bits": quantizer["inputs"]["integer_bits"],
            "keep_negative": quantizer["inputs"]["keep_negative"],
            "bits": quantizer["inputs"]["bits"],
        },
    }

    return settings


handler_map = {
    "FixedPointQuantizer": convert_quantized_bits,
}
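A small illustrative call, assuming a quantizer dict shaped like the one `extract_FixedPointQuantizer` builds in `hgqlayers.py`; the concrete values are made up:

```python
from qonnx.converters.HGQ.quantizers import get_quant_params

# Illustrative quantizer dict; RND/SAT modes and bit widths are placeholders.
hgq_quantizer = {
    "keras_layer": "FixedPointQuantizer",
    "RND": "RND",
    "SAT": "SAT",
    "inputs": {"keep_negative": [1], "bits": [8], "integer_bits": [3]},
}

params = get_quant_params(None, hgq_quantizer)
assert params["attributes"] == {"RND": "RND", "SAT": "SAT"}
assert params["inputs"]["bits"] == [8]
```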
95 changes: 78 additions & 17 deletions src/qonnx/converters/keras.py
@@ -7,14 +7,18 @@
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup_model

from .HGQ.hgqlayers import HGQ_LAYERS, HGQIdentity, extract_quantizers_from_hgq_layer
from .HGQ.onnx import get_hgq_onnx_handlers
from .qkeras.onnx import get_qkeras_onnx_handlers
from .qkeras.qlayers import extract_quantizers_from_qkeras_layer

# NOTE: this list covers both QKeras and HGQ layers
_unsupported_layers = [
    # These require some extra work
    "QBatchNormalization",
    "QConv2DBatchnorm",
    "QDepthwiseConv2DBatchnorm",
    # TODO: add HGQ layers
]

# Skip remove_identity optimizer
@@ -102,6 +106,30 @@ def iterate_model(model):
    return iterate_model(model)


def _is_hgq_model(model):
    """Check if the model has any HGQ layers, so we can handle the HGQ layers separately.

    Args:
        model: the model we want to convert

    Returns:
        True if the model contains any HGQ layer
    """

    def iterate_model(model):
        for layer in model.layers:
            if isinstance(layer, tf.keras.Model):
                found_hgq = iterate_model(layer)
                if found_hgq:
                    return True
            elif layer.__class__.__name__ in HGQ_LAYERS:
                return True

        return False

    return iterate_model(model)


def _check_supported_layers(model):
    """Check if all the layers in the model are supported for conversion

@@ -117,7 +145,7 @@ def iterate_model(model):
            if isinstance(layer, tf.keras.Model):
                iterate_model(layer)
            elif layer.__class__.__name__ in _unsupported_layers:
                raise Exception("Currently unsupported layer found in model: {}".format(layer.__class__.__name__))

    iterate_model(model)

@@ -134,7 +162,7 @@ def _strip_qkeras_model(model):
    quantizers = OrderedDict()

    def extract_quantizers(layer):
        keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_qkeras_layer(layer)
        if layer_quantizers:
            layer_quantizers = {
                k: None if v == "None" else v for k, v in layer_quantizers.items()
@@ -150,18 +178,44 @@ def extract_quantizers(layer):

    stripped_model = tf.keras.models.clone_model(model, clone_function=extract_quantizers)
    stripped_model.set_weights(model.get_weights())

    return stripped_model, quantizers


def _strip_hgq_model(model):
    """Strip an HGQ model to obtain the Keras model and the quant nodes.

    Args:
        model: the proxy tf.keras model from HGQ

    Returns:
        The stripped model, and the quantizers in a dictionary format
    """
    quantizers = OrderedDict()
    layer_dict = tf.keras.layers.__dict__
    layer_dict["HGQIdentity"] = HGQIdentity

    def extract_quantizers(layer):
        keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_hgq_layer(layer, model)
        if layer_quantizers:
            layer_quantizers["input"] = layer.input.name
            quantizers[layer_quantizers["name"]] = layer_quantizers

        layer_class = layer_dict.get(keras_cls_name, None)
        if layer_class is None:
            raise Exception("Cannot create Keras layer from HGQ class {}".format(keras_cls_name))

        return layer_class.from_config(layer_cfg)

    stripped_model = tf.keras.models.clone_model(model, clone_function=extract_quantizers)
    for layer in model.layers:
        if layer.__class__.__name__ in HGQ_LAYERS:
            # NOTE: the FixedPointQuantizer has no weights; it only carries
            # self.keep_negative, self.bits and self.integers, which we extract
            # later from the quantizer
            continue
        stripped_model.get_layer(layer.name).set_weights(layer.get_weights())
    return stripped_model, quantizers


def from_keras(
@@ -203,23 +257,30 @@ def from_keras(

    assert not large_model  # TODO for now, let's focus only on models that don't store tensors externally

    keras_op_handlers = {}
    is_HGQ = False
    if _is_qkeras_model(model):
        _check_supported_layers(model)
        keras_model, quantizers = _strip_qkeras_model(model)
        keras_op_handlers.update(get_qkeras_onnx_handlers(quantizers))
    elif _is_hgq_model(model):
        # keras_model is reset to the original proxy model below;
        # _strip_hgq_model is used here only to collect the quantizers
        keras_model, quantizers = _strip_hgq_model(model)
        keras_op_handlers.update(get_hgq_onnx_handlers(quantizers))
        keras_model = model
        is_HGQ = True
    else:
        keras_model, quantizers = model, {}

    keras_model.summary()
    if custom_op_handlers is not None:
        keras_op_handlers.update(custom_op_handlers)

    model_proto, external_storage = tf2onnx.convert.from_keras(
        keras_model,
        input_signature=input_signature,
        opset=opset,
        custom_ops=custom_ops,
        custom_op_handlers=keras_op_handlers,
        custom_rewriter=custom_rewriter,
        inputs_as_nchw=inputs_as_nchw,
        extra_opset=extra_opset,
@@ -242,15 +303,15 @@ def from_keras(
    onnx_model.set_tensor_shape(onnx_model.graph.output[0].name, out_shape)

    # Set all Quant output tensors to float32 datatype, otherwise they are undefined and crash ONNX execution
    qonnx_domain_ops = ["FixedPoint", "Quant", "Trunc", "BipolarQuant"]
    for q_op_type in qonnx_domain_ops:
        quant_nodes = onnx_model.get_nodes_by_op_type(q_op_type)
        q_node_outputs = [qn.output[0] for qn in quant_nodes]
        for tensor in onnx_model.graph.value_info:
            if tensor.name in q_node_outputs:
                tensor.type.tensor_type.elem_type = 1

    onnx_model = cleanup_model(onnx_model, is_HGQ=is_HGQ)
    onnx_model.model = add_value_info_for_constants(onnx_model.model)

    if output_path is not None:
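Taken together, the intended end-to-end flow looks roughly like the sketch below; it assumes an HGQ proxy model `proxy_model` (e.g. from HGQ's `to_proxy_model`) and that `from_keras` keeps returning the ONNX protobuf plus the external-storage handle:

```python
from qonnx.converters.keras import from_keras
from qonnx.core.modelwrapper import ModelWrapper

# `proxy_model` is assumed to be an HGQ proxy tf.keras model.
model_proto, external_storage = from_keras(proxy_model, output_path="hgq_model.onnx")

# The saved model can then be inspected/transformed with the usual QONNX tooling.
qonnx_model = ModelWrapper("hgq_model.onnx")
print([n.op_type for n in qonnx_model.graph.node])
```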
2 changes: 1 addition & 1 deletion src/qonnx/converters/qkeras/qlayers.py
@@ -5,7 +5,7 @@
from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS


def extract_quantizers_from_qkeras_layer(layer):
    """Map a QKeras layer to (keras class name, layer config, quantizer dict)."""
    layer_class = layer.__class__.__name__
    if layer_class in QKERAS_LAYERS:
3 changes: 3 additions & 0 deletions src/qonnx/core/datatype.py
@@ -376,6 +376,9 @@ def get_canonical_name(self):


def resolve_datatype(name):
    if not isinstance(name, str):
        raise TypeError(f"Input 'name' must be of type 'str', but got type '{type(name).__name__}'")

    _special_types = {
        "BINARY": IntType(1, False),
        "BIPOLAR": BipolarType(),
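A quick illustration of the new guard (a sketch against qonnx's public datatype API):

```python
from qonnx.core.datatype import resolve_datatype

print(resolve_datatype("UINT4"))  # resolves to the UINT4 datatype as before

# With the added isinstance check, non-string inputs now fail fast:
try:
    resolve_datatype(8)
except TypeError as err:
    print(err)  # Input 'name' must be of type 'str', but got type 'int'
```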