From 538a935089f6c5926d0e73bf0d77cac61bfdf10a Mon Sep 17 00:00:00 2001 From: Harish Date: Sun, 14 Jan 2024 18:50:51 +0000 Subject: [PATCH 1/6] Transformation pass to introduce quantnodes --- .../transformation/introduce_quantnode.py | 263 ++++++++++++++++++ .../test_introduce_quantnode.py | 147 ++++++++++ 2 files changed, 410 insertions(+) create mode 100644 src/qonnx/transformation/introduce_quantnode.py create mode 100644 tests/transformation/test_introduce_quantnode.py diff --git a/src/qonnx/transformation/introduce_quantnode.py b/src/qonnx/transformation/introduce_quantnode.py new file mode 100644 index 00000000..f7e25edf --- /dev/null +++ b/src/qonnx/transformation/introduce_quantnode.py @@ -0,0 +1,263 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import numpy as np +import onnx +from onnx import TensorProto + +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.base import Transformation +from qonnx.transformation.general import SortGraph +from qonnx.transformation.infer_shapes import InferShapes +from qonnx.util.basic import qonnx_make_model +from qonnx.util.cleanup import cleanup_model + + +class graph_util: + def get_node_id(self, model): + node_index = {} + node_ind = 0 + for node in model.graph.node: + node_index[node.name] = node_ind + node_ind += 1 + return node_index + + def node_from_name(self, model, node_name): + for node in model.graph.node: + if node.name == node_name: + return node + + def identify_nodes(self, model, node_type): + node_list = [] + for node in model.graph.node: + if node.op_type == node_type: + node_list.append(node) + return node_list + + def create_node( + self, + model, + quantnode_input, + quantnode_output_shape, + node_count, + tensor_count, + scale_value, + zeropoint_value, + bitwidth_value, + narrow, + signed, + rounding_mode, + ): + quantnode_output_dtype = DataType["UINT8"] + quant_tensor = onnx.helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape + ) + model.graph.value_info.append(quant_tensor) + model.set_tensor_datatype(quant_tensor.name, quantnode_output_dtype) + + stationary_input_dtype = DataType["FLOAT32"] + 
scale_tensor = np.array(scale_value).astype(np.float32) + s_value = onnx.helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape + ) + model.graph.value_info.append(s_value) + model.set_tensor_datatype(s_value.name, stationary_input_dtype) + model.set_initializer(s_value.name, scale_tensor) + + zeropt_tensor = np.array(zeropoint_value).astype(np.float32) + z_value = onnx.helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape + ) + model.graph.value_info.append(z_value) + model.set_tensor_datatype(z_value.name, stationary_input_dtype) + model.set_initializer(z_value.name, zeropt_tensor) + + bitwidth_tensor = np.array(bitwidth_value).astype(np.float32) + b_value = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, [1]) + model.graph.value_info.append(b_value) + model.set_tensor_datatype(b_value.name, stationary_input_dtype) + model.set_initializer(b_value.name, bitwidth_tensor) + + quant_node = onnx.helper.make_node( + "Quant", + inputs=[quantnode_input, s_value.name, z_value.name, b_value.name], + outputs=[quant_tensor.name], + name="Quant_node_" + str(node_count) + str(tensor_count), + narrow=narrow, + signed=signed, + rounding_mode=rounding_mode, + ) + + return quant_node, quant_tensor + + def adjust_graph(self, model, input_positions, node_in_focus, quantized_nodes, node_count): + tensor_count = 0 + for pos in input_positions: + node_details = (node_in_focus.name, pos[0]) + if ( + node_details not in quantized_nodes + ): # This is to ensure that we don't quantize the same node for the same input/output index. 
+ if pos[0][0] == "input": + input_to_quantnode = node_in_focus.input[pos[0][1]] + consumer_node = node_in_focus + producer_node = model.find_producer(input_to_quantnode) + if producer_node is None or producer_node.op_type != "Quant": + quantization_to_perform = "yes" + else: + quantization_to_perform = "no" + else: + input_to_quantnode = node_in_focus.output[pos[0][1]] + consumer_node = model.find_consumer(input_to_quantnode) + producer_node = model.find_producer(input_to_quantnode) + if consumer_node is None or consumer_node.op_type != "Quant": + quantization_to_perform = "yes" + else: + quantization_to_perform = "no" + if quantization_to_perform == "yes": + node_indx = self.get_node_id(model) # Getting index of each node in the graph. + quantnode_output_shape = model.get_tensor_shape(input_to_quantnode) # Step: 3 + + quant_node, quant_tensor = self.create_node( + model, + input_to_quantnode, + quantnode_output_shape, + node_count, + tensor_count, + scale_value=pos[1][0], + zeropoint_value=pos[1][1], + bitwidth_value=pos[1][2], + narrow=pos[1][3], + signed=pos[1][4], + rounding_mode=pos[1][5], + ) + + if consumer_node is not None: + node_pos = node_indx[consumer_node.name] + model.graph.node[node_pos].input[pos[0][1]] = quant_tensor.name + model.graph.node.append(quant_node) + else: + model.graph.value_info.remove(quant_tensor) + model.graph.node.append(quant_node) + model.graph.output.insert(0, quant_tensor) + model.graph.output.pop(1) + + model = model.transform(SortGraph()) + tensor_count += 1 + quantized_nodes.append(node_details) + else: + print(f"{pos[0][0]} index {pos[0][1]} of {node_in_focus.name} is already quantized.") + else: + print(f"{pos[0][0]} index {pos[0][1]} of {node_in_focus.name} is already quantized.") + continue + + return model + + +class IntroduceQuantnode(Transformation): + """This transformation can be used to introduce a Quant node for a specific type of node in the graph. 
+ Users would be able to specify the location of the quant node by providing the input and output indexs + as the parameters. + + 1) Expectations: + a) Onnx model in the modelwraper format. + b) Model must be cleaned using cleanup_model qonnx.util.cleanup.cleanup_model() + c) Batchsize to be set. + + 2) Steps to transform are + Step1: Finding the input for the quant node. + Step2: Finding the consumer of the quant node output. + Step3: Finding the shape for the output tensor of quant node. + Note: The output tensor of the quant node must have the same shape as the + consumer of the input to the quant node. + + 3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type". + + 4) Assert: + a) The input is a dictionary representing the node names as keys and a list of quant positions + as values. + b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation. + + 5) Return: + Returns a cleaned version of the model. + + """ + + def __init__(self, quant_node_inputs): + super().__init__() + self.quant_node_inputs = quant_node_inputs + self.graph_util = graph_util() + + def apply(self, model): + model = model.transform(InferShapes()) + if type(self.quant_node_inputs) == dict: + selection_type = self.quant_node_inputs.keys() + if set(selection_type) <= {"name", "op_type"}: + node_count = 0 + quantized_nodes = [] + if "name" in selection_type: + by_name = self.quant_node_inputs[ + "name" + ] # by_name is a dictionary with the unique node names as keys and the list of positions as values. + node_list_by_name = by_name.keys() # name of all the nodes specified by the user for an quant node. + for node_name in node_list_by_name: + node_in_focus = self.graph_util.node_from_name(model, node_name) + input_positions = by_name[ + node_name + ] # input positions specified by the user to introduce quant node. 
+ model = self.graph_util.adjust_graph( + model, input_positions, node_in_focus, quantized_nodes, node_count + ) + node_count += 1 + if "op_type" in selection_type: + by_op_type = self.quant_node_inputs[ + "op_type" + ] # by_name is a dictionary with the unique node names as keys and the list of positions as values. + op_list = by_op_type.keys() + for op in op_list: + node_list = self.graph_util.identify_nodes( + model, op + ) # List of all nodes with the operation type "op". + input_positions = by_op_type[op] + for node_in_focus in node_list: + model = self.graph_util.adjust_graph( + model, input_positions, node_in_focus, quantized_nodes, node_count + ) + node_count += 1 + model = qonnx_make_model(model.graph) + model = ModelWrapper(model) + model = cleanup_model(model) + else: + raise Exception("Unsupported selection type") + else: + raise TypeError("Input must be a dictionary.") + + graph_modified = False + + return (model, graph_modified) diff --git a/tests/transformation/test_introduce_quantnode.py b/tests/transformation/test_introduce_quantnode.py new file mode 100644 index 00000000..cc2e88ef --- /dev/null +++ b/tests/transformation/test_introduce_quantnode.py @@ -0,0 +1,147 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +import os +import random +import urllib.request + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.introduce_quantnode import IntroduceQuantnode, graph_util +from qonnx.util.cleanup import cleanup +from qonnx.util.inference_cost import inference_cost + +random.seed(42) + +graph_util = graph_util() + +a = "https://github.com/onnx/models/raw/main/validated/vision/" +b = "classification/resnet/model/resnet18-v1-7.onnx?download=" + +model_details = { + "resnet18-v1-7": { + "description": "Resnet18 Opset version 7.", + "url": (a + b), + "test_input": { + "name": { + "Conv_0": [ + (("input", 0), (1, 0, 8, 0, 1, "ROUND")), + (("input", 1), (1, 0, 8, 0, 1, "ROUND")), + (("output", 0), (1, 0, 8, 0, 1, "ROUND")), + ], + "Conv_1": [(("input", 0), (1, 0, 8, 0, 1, "ROUND"))], + "Conv_2": [(("input", 1), (1, 0, 8, 0, 1, "ROUND")), (("output", 0), (1, 0, 8, 0, 1, "ROUND"))], + }, + "op_type": { + "Gemm": [ + (("input", 0), (1, 0, 8, 0, 1, "ROUND")), + (("input", 1), (1, 0, 8, 0, 1, "ROUND")), + (("input", 2), (1, 0, 8, 0, 1, "ROUND")), + (("output", 0), (1, 0, 8, 0, 1, "ROUND")), + ] + }, + }, + }, +} + + +def download_model(test_model, do_cleanup=False, 
return_modelwrapper=False): + qonnx_url = model_details[test_model]["url"] + # download test data + dl_dir = "/tmp" + dl_file = dl_dir + f"/{test_model}.onnx" + ret = dl_file + if not os.path.isfile(dl_file): + urllib.request.urlretrieve(qonnx_url, dl_file) + if do_cleanup: + out_file = dl_dir + f"/{test_model}_clean.onnx" + cleanup(dl_file, out_file=out_file, override_batchsize=1) + ret = out_file + if return_modelwrapper: + ret = ModelWrapper(ret) + return ret + + +def to_verify(model, test_details): + by = random.choice(list(test_details.keys())) # by "name" or "op_type" + + if by == "name": + sample_node_name = random.choice(list(test_details["name"].keys())) + sample_node = graph_util.node_from_name(model, sample_node_name) + sample_pos = random.choice(test_details["name"][sample_node_name]) + if by == "op_type": + node_type = random.choice(list(test_details["op_type"].keys())) + sample_node = random.choice(graph_util.identify_nodes(model, node_type)) + sample_pos = random.choice(test_details["op_type"][node_type]) + + if sample_pos[0][0] == "input": + tensor_to_verify = sample_node.input[sample_pos[0][1]] + producer_node = model.find_producer(tensor_to_verify) + if producer_node.op_type == "Quant": + verification = "Success" + else: + verification = "Failure" + if sample_pos[0][0] == "output": + tensor_to_verify = sample_node.output[sample_pos[0][1]] + consumer_node = model.find_consumer(tensor_to_verify) + if consumer_node.op_type == "Quant": + verification = "Success" + else: + verification = "Failure" + + return verification + + +@pytest.mark.parametrize("test_model", model_details.keys()) +def test_introduce_quantnode(test_model): + test_details = model_details[test_model] + model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) + original_model_inf_cost = inference_cost(model, discount_sparsity=False) + nodes_pos = test_details["test_input"] + model = model.transform(IntroduceQuantnode(nodes_pos)) + quantnodes_added = 
len(model.get_nodes_by_op_type("Quant")) + assert quantnodes_added == 10 # 10 positions are specified. + verification = to_verify(model, nodes_pos) + assert verification == "Success" + inf_cost = inference_cost(model, discount_sparsity=False) + assert ( + inf_cost["total_macs"] == original_model_inf_cost["total_macs"] + ) # "1814073344.0" must be same as the original model. + assert ( + inf_cost["total_mem_w_elems"] == original_model_inf_cost["total_mem_w_elems"] + ) # "11678912.0" must be same as the original model. + assert ( + inf_cost["total_mem_o_bits"] == original_model_inf_cost["total_mem_o_bits"] + ) # "79510784.0" must be same as the original model. + assert ( + inf_cost["total_mem_o_elems"] == original_model_inf_cost["total_mem_o_elems"] + ) # "2484712.0" must be same as the original model. + assert inf_cost["total_bops"] == 1566256136192.0 + assert inf_cost["total_mem_w_bits"] == 360326656.0 + assert inf_cost["op_mac_INT8_INT8"] == 118525952.0 From 2feab8455a240faa421ac2aa6748b7e6950a6348 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Mon, 5 Feb 2024 16:46:34 +0100 Subject: [PATCH 2/6] [Test] override_batchsize -> override_inpsize for cleanup --- tests/transformation/test_introduce_quantnode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transformation/test_introduce_quantnode.py b/tests/transformation/test_introduce_quantnode.py index cc2e88ef..f53dbf63 100644 --- a/tests/transformation/test_introduce_quantnode.py +++ b/tests/transformation/test_introduce_quantnode.py @@ -81,7 +81,7 @@ def download_model(test_model, do_cleanup=False, return_modelwrapper=False): urllib.request.urlretrieve(qonnx_url, dl_file) if do_cleanup: out_file = dl_dir + f"/{test_model}_clean.onnx" - cleanup(dl_file, out_file=out_file, override_batchsize=1) + cleanup(dl_file, out_file=out_file, override_inpsize=1) ret = out_file if return_modelwrapper: ret = ModelWrapper(ret) From 3e132fe2dd892032a75013ff0c6956ab9be172a5 Mon Sep 17 00:00:00 2001 
From: Yaman Umuroglu Date: Mon, 5 Feb 2024 17:01:10 +0100 Subject: [PATCH 3/6] [GraphQnt] some cleanup and renaming --- ...troduce_quantnode.py => quantize_graph.py} | 42 +++++++++---------- ...ce_quantnode.py => test_quantize_graph.py} | 10 ++--- 2 files changed, 25 insertions(+), 27 deletions(-) rename src/qonnx/transformation/{introduce_quantnode.py => quantize_graph.py} (89%) rename tests/transformation/{test_introduce_quantnode.py => test_quantize_graph.py} (95%) diff --git a/src/qonnx/transformation/introduce_quantnode.py b/src/qonnx/transformation/quantize_graph.py similarity index 89% rename from src/qonnx/transformation/introduce_quantnode.py rename to src/qonnx/transformation/quantize_graph.py index f7e25edf..af290730 100644 --- a/src/qonnx/transformation/introduce_quantnode.py +++ b/src/qonnx/transformation/quantize_graph.py @@ -180,32 +180,30 @@ def adjust_graph(self, model, input_positions, node_in_focus, quantized_nodes, n return model -class IntroduceQuantnode(Transformation): - """This transformation can be used to introduce a Quant node for a specific type of node in the graph. - Users would be able to specify the location of the quant node by providing the input and output indexs - as the parameters. +class QuantizeGraph(Transformation): + """This transformation can be used to introduce a Quant node for particular nodes in the graph, + determined based on either op_type or node name. + For the particular nodes identified, users can specify the location of the Quant nodes by providing + the input and output indices where Quant nodes are to be inserted. + Assumes the input model is cleaned-up with all intermediate shapes specified and nodes given + unique names already. - 1) Expectations: - a) Onnx model in the modelwraper format. - b) Model must be cleaned using cleanup_model qonnx.util.cleanup.cleanup_model() - c) Batchsize to be set. + 2) Steps to transform are + Step1: Finding the input for the quant node. 
+ Step2: Finding the consumer of the quant node output. + Step3: Finding the shape for the output tensor of quant node. + Note: The output tensor of the quant node must have the same shape as the + consumer of the input to the quant node. - 2) Steps to transform are - Step1: Finding the input for the quant node. - Step2: Finding the consumer of the quant node output. - Step3: Finding the shape for the output tensor of quant node. - Note: The output tensor of the quant node must have the same shape as the - consumer of the input to the quant node. + 3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type". - 3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type". + 4) Assert: + a) The input is a dictionary representing the node names as keys and a list of quant positions + as values. + b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation. - 4) Assert: - a) The input is a dictionary representing the node names as keys and a list of quant positions - as values. - b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation. - - 5) Return: - Returns a cleaned version of the model. + 5) Return: + Returns a cleaned version of the model. 
""" diff --git a/tests/transformation/test_introduce_quantnode.py b/tests/transformation/test_quantize_graph.py similarity index 95% rename from tests/transformation/test_introduce_quantnode.py rename to tests/transformation/test_quantize_graph.py index f53dbf63..c0ceb456 100644 --- a/tests/transformation/test_introduce_quantnode.py +++ b/tests/transformation/test_quantize_graph.py @@ -33,7 +33,7 @@ import urllib.request from qonnx.core.modelwrapper import ModelWrapper -from qonnx.transformation.introduce_quantnode import IntroduceQuantnode, graph_util +from qonnx.transformation.quantize_graph import QuantizeGraph, graph_util from qonnx.util.cleanup import cleanup from qonnx.util.inference_cost import inference_cost @@ -41,13 +41,13 @@ graph_util = graph_util() -a = "https://github.com/onnx/models/raw/main/validated/vision/" -b = "classification/resnet/model/resnet18-v1-7.onnx?download=" +download_url = "https://github.com/onnx/models/raw/main/validated/vision/" +download_url += "classification/resnet/model/resnet18-v1-7.onnx?download=" model_details = { "resnet18-v1-7": { "description": "Resnet18 Opset version 7.", - "url": (a + b), + "url": download_url, "test_input": { "name": { "Conv_0": [ @@ -124,7 +124,7 @@ def test_introduce_quantnode(test_model): model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) original_model_inf_cost = inference_cost(model, discount_sparsity=False) nodes_pos = test_details["test_input"] - model = model.transform(IntroduceQuantnode(nodes_pos)) + model = model.transform(QuantizeGraph(nodes_pos)) quantnodes_added = len(model.get_nodes_by_op_type("Quant")) assert quantnodes_added == 10 # 10 positions are specified. 
verification = to_verify(model, nodes_pos) From 42df7a016584cef8023191ff18e74792f102f3c9 Mon Sep 17 00:00:00 2001 From: Harish Date: Thu, 8 Feb 2024 14:22:24 +0000 Subject: [PATCH 4/6] Revised version for QuantizeGraph --- src/qonnx/core/modelwrapper.py | 9 + src/qonnx/transformation/quantize_graph.py | 347 +++++++++----------- tests/transformation/test_quantize_graph.py | 10 +- 3 files changed, 175 insertions(+), 191 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index f78e1334..f21efdab 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -542,6 +542,15 @@ def get_node_index(self, node): except ValueError: return None + def get_node_from_name(self, node_name): + """Returns the node with the specified name.""" + try: + for node in self.graph.node: + if node.name == node_name: + return node + except ValueError: + return None + def get_tensor_layout(self, tensor_name): """Returns the data layout annotation of tensor with given name. 
The data layout is expressed as a list of strings with as many diff --git a/src/qonnx/transformation/quantize_graph.py b/src/qonnx/transformation/quantize_graph.py index af290730..20feb049 100644 --- a/src/qonnx/transformation/quantize_graph.py +++ b/src/qonnx/transformation/quantize_graph.py @@ -31,225 +31,202 @@ import onnx from onnx import TensorProto -from qonnx.core.datatype import DataType -from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation from qonnx.transformation.general import SortGraph from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import qonnx_make_model from qonnx.util.cleanup import cleanup_model -class graph_util: - def get_node_id(self, model): - node_index = {} - node_ind = 0 - for node in model.graph.node: - node_index[node.name] = node_ind - node_ind += 1 - return node_index - - def node_from_name(self, model, node_name): - for node in model.graph.node: - if node.name == node_name: - return node - - def identify_nodes(self, model, node_type): - node_list = [] - for node in model.graph.node: - if node.op_type == node_type: - node_list.append(node) - return node_list - - def create_node( - self, - model, - quantnode_input, - quantnode_output_shape, - node_count, - tensor_count, - scale_value, - zeropoint_value, - bitwidth_value, - narrow, - signed, - rounding_mode, - ): - quantnode_output_dtype = DataType["UINT8"] - quant_tensor = onnx.helper.make_tensor_value_info( - model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape - ) - model.graph.value_info.append(quant_tensor) - model.set_tensor_datatype(quant_tensor.name, quantnode_output_dtype) - - stationary_input_dtype = DataType["FLOAT32"] - scale_tensor = np.array(scale_value).astype(np.float32) - s_value = onnx.helper.make_tensor_value_info( - model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape - ) - model.graph.value_info.append(s_value) - 
model.set_tensor_datatype(s_value.name, stationary_input_dtype) - model.set_initializer(s_value.name, scale_tensor) - - zeropt_tensor = np.array(zeropoint_value).astype(np.float32) - z_value = onnx.helper.make_tensor_value_info( - model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape - ) - model.graph.value_info.append(z_value) - model.set_tensor_datatype(z_value.name, stationary_input_dtype) - model.set_initializer(z_value.name, zeropt_tensor) - - bitwidth_tensor = np.array(bitwidth_value).astype(np.float32) - b_value = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, [1]) - model.graph.value_info.append(b_value) - model.set_tensor_datatype(b_value.name, stationary_input_dtype) - model.set_initializer(b_value.name, bitwidth_tensor) - - quant_node = onnx.helper.make_node( - "Quant", - inputs=[quantnode_input, s_value.name, z_value.name, b_value.name], - outputs=[quant_tensor.name], - name="Quant_node_" + str(node_count) + str(tensor_count), - narrow=narrow, - signed=signed, - rounding_mode=rounding_mode, - ) - - return quant_node, quant_tensor - - def adjust_graph(self, model, input_positions, node_in_focus, quantized_nodes, node_count): - tensor_count = 0 - for pos in input_positions: - node_details = (node_in_focus.name, pos[0]) - if ( - node_details not in quantized_nodes - ): # This is to ensure that we don't quantize the same node for the same input/output index. 
- if pos[0][0] == "input": - input_to_quantnode = node_in_focus.input[pos[0][1]] - consumer_node = node_in_focus - producer_node = model.find_producer(input_to_quantnode) - if producer_node is None or producer_node.op_type != "Quant": - quantization_to_perform = "yes" - else: - quantization_to_perform = "no" +def create_quantnode( + model, + quantnode_input, + quantnode_output_shape, + scale_value, + zeropoint_value, + bitwidth_value, + narrow, + signed, + rounding_mode, +): + quant_tensor = onnx.helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape + ) + model.graph.value_info.append(quant_tensor) + + scale_tensor = np.array(scale_value).astype(np.float32) + s_value = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape) + model.graph.value_info.append(s_value) + model.set_initializer(s_value.name, scale_tensor) + + zeropt_tensor = np.array(zeropoint_value).astype(np.float32) + z_value = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape) + model.graph.value_info.append(z_value) + model.set_initializer(z_value.name, zeropt_tensor) + + bitwidth_tensor = np.array(bitwidth_value).astype(np.float32) + b_value = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, [1]) + model.graph.value_info.append(b_value) + model.set_initializer(b_value.name, bitwidth_tensor) + + quantnode = onnx.helper.make_node( + "Quant", + inputs=[quantnode_input, s_value.name, z_value.name, b_value.name], + outputs=[quant_tensor.name], + name="Quant_" + quantnode_input, + narrow=narrow, + signed=signed, + rounding_mode=rounding_mode, + ) + + return quantnode, quant_tensor + + +def adjust_graph(model, input_positions, node_name, quantized_nodes): + for pos in input_positions: + node_details = (node_name, pos[0]) + if node_details not in quantized_nodes: # not quantizing for same node_inp/out 
index. + node_in_focus = model.get_node_from_name(node_name) + + if pos[0][0] == "input": + quantnode_input = node_in_focus.input[pos[0][1]] + consumer_node = node_in_focus + producer_node = model.find_producer(quantnode_input) + if producer_node is None or producer_node.op_type != "Quant": + quantization_to_perform = True else: - input_to_quantnode = node_in_focus.output[pos[0][1]] - consumer_node = model.find_consumer(input_to_quantnode) - producer_node = model.find_producer(input_to_quantnode) - if consumer_node is None or consumer_node.op_type != "Quant": - quantization_to_perform = "yes" - else: - quantization_to_perform = "no" - if quantization_to_perform == "yes": - node_indx = self.get_node_id(model) # Getting index of each node in the graph. - quantnode_output_shape = model.get_tensor_shape(input_to_quantnode) # Step: 3 - - quant_node, quant_tensor = self.create_node( - model, - input_to_quantnode, - quantnode_output_shape, - node_count, - tensor_count, - scale_value=pos[1][0], - zeropoint_value=pos[1][1], - bitwidth_value=pos[1][2], - narrow=pos[1][3], - signed=pos[1][4], - rounding_mode=pos[1][5], - ) - - if consumer_node is not None: - node_pos = node_indx[consumer_node.name] - model.graph.node[node_pos].input[pos[0][1]] = quant_tensor.name - model.graph.node.append(quant_node) - else: - model.graph.value_info.remove(quant_tensor) - model.graph.node.append(quant_node) - model.graph.output.insert(0, quant_tensor) - model.graph.output.pop(1) - - model = model.transform(SortGraph()) - tensor_count += 1 - quantized_nodes.append(node_details) + quantization_to_perform = False + else: + quantnode_input = node_in_focus.output[pos[0][1]] + consumer_node = model.find_consumer(quantnode_input) + producer_node = model.find_producer(quantnode_input) + if consumer_node is None or consumer_node.op_type != "Quant": + quantization_to_perform = True + else: + quantization_to_perform = False + if quantization_to_perform is True: + quantnode_output_shape = 
model.get_tensor_shape(quantnode_input) # Step: 3 + quantnode, quant_tensor = create_quantnode( + model, + quantnode_input, + quantnode_output_shape, + scale_value=pos[1][0], + zeropoint_value=pos[1][1], + bitwidth_value=pos[1][2], + narrow=pos[1][3], + signed=pos[1][4], + rounding_mode=pos[1][5], + ) + + if consumer_node is not None: + node_pos = model.get_node_index(consumer_node) + model.graph.node[node_pos].input[pos[0][1]] = quant_tensor.name + model.graph.node.append(quantnode) else: - print(f"{pos[0][0]} index {pos[0][1]} of {node_in_focus.name} is already quantized.") + model.graph.value_info.remove(quant_tensor) + model.graph.node.append(quantnode) + model.graph.output.insert(0, quant_tensor) + model.graph.output.pop(1) + + model = model.transform(SortGraph()) + quantized_nodes.append(node_details) else: - print(f"{pos[0][0]} index {pos[0][1]} of {node_in_focus.name} is already quantized.") - continue + print(f"{pos[0][0]} index {pos[0][1]} of {node_name} is already quantized.") + else: + print(f"{pos[0][0]} index {pos[0][1]} of {node_name} is already quantized.") + continue - return model + return model class QuantizeGraph(Transformation): - """This transformation can be used to introduce a Quant node for particular nodes in the graph, - determined based on either op_type or node name. - For the particular nodes identified, users can specify the location of the Quant nodes by providing - the input and output indices where Quant nodes are to be inserted. - Assumes the input model is cleaned-up with all intermediate shapes specified and nodes given - unique names already. - - 2) Steps to transform are + """This transformation can be used to introduce a Quant node for a specific type of node in the graph. + Users would be able to specify the location of the quant node by providing the input and output indexs + as the parameters. + + 1) Expectations: + a) Onnx model in the modelwraper format. 
+ b) Model must be cleaned using cleanup_model qonnx.util.cleanup.cleanup_model() + c) Batchsize to be set. + + 2) S.teps to transform are: Step1: Finding the input for the quant node. Step2: Finding the consumer of the quant node output. Step3: Finding the shape for the output tensor of quant node. - Note: The output tensor of the quant node must have the same shape as the - consumer of the input to the quant node. - - 3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type". + Note: The output tensor of the quant node must have the same shape as the consumer of the input + to the quant node. + + 3) Input: + A dict "quantnode_map" specifying the criterion, positions, and input parameters like + scale, bitwidth, zeropoint, and others for the particular quantnode. + + Criterion: + a) name: This will allow users to add quant nodes for specific node like "Conv_0" and "Gemm_0". + Note: using this users can have quant nodes with different parameters. Ex: quantizing + "Conv_0" and "Conv_1" with bitwidth of 4 and 6, respectively. + b) op_type: This will allow users to add quant nodes for all nodes of a particular op_type such + as, "Conv", "Gemm", and others. + Note: All quant nodes created using op_type criterion will have the same input + parameters (scale, zeropoint, bitwidth, and others.) + c) name and op_type: In this case, quant nodes will be added with precedence to "Name" + in comparison to "op_type". + + Positions: ("input", index) or ("output", index) + a) "input": specifies that the user want to quantize the input of the selected node. + b) "output": specifies that the user want to quantize the input of the selected node. + c) index: specifies which input/output to quantize (as a node can have multiple inputs and outputs) + + Parameters (to quant node) are provided as (scale, zeropoint, bitwidth, narrow, signed, rounding_mode) + + a) Inputs: scale, zeropoint, bitwidth. + b) Attributes: narrow, signed, rounding_mode. 
4) Assert: - a) The input is a dictionary representing the node names as keys and a list of quant positions - as values. - b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation. + a) The input is a dictionary representing the node names as keys and a list of quant positions + as values. + b) The input dictionary must have at least one mac node (Conv, gemm, matmul) for the transformation. 5) Return: - Returns a cleaned version of the model. - + Returns a model with new quant nodes created at the positions specified using the "quantnode_map". + + 6) Example: + quantnode_map = {"name": {"Conv_0": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")), + (("input", 1), (1, 0, 8, 0, 1, "ROUND")), + (("output", 0), (1, 0, 8, 0, 1, "ROUND"))], + "Conv_1": [(("input", 0), (1, 0, 8, 0, 1, "ROUND"))], + "Conv_2": [(("input", 1), (1, 0, 8, 0, 1, "ROUND")), + (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]}, + + "op_type": {"Gemm": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")), + (("input", 1), (1, 0, 8, 0, 1, "ROUND")), + (("input", 2), (1, 0, 8, 0, 1, "ROUND")), + (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]}} """ - def __init__(self, quant_node_inputs): + def __init__(self, quantnode_map): super().__init__() - self.quant_node_inputs = quant_node_inputs - self.graph_util = graph_util() + self.quantnode_map = quantnode_map def apply(self, model): model = model.transform(InferShapes()) - if type(self.quant_node_inputs) == dict: - selection_type = self.quant_node_inputs.keys() + if type(self.quantnode_map) == dict: + selection_type = self.quantnode_map.keys() if set(selection_type) <= {"name", "op_type"}: - node_count = 0 quantized_nodes = [] if "name" in selection_type: - by_name = self.quant_node_inputs[ - "name" - ] # by_name is a dictionary with the unique node names as keys and the list of positions as values. - node_list_by_name = by_name.keys() # name of all the nodes specified by the user for an quant node.
+ by_name = self.quantnode_map["name"] # dict with unique names and list of positions. + node_list_by_name = by_name.keys() # node names specified by the user for quant nodes. for node_name in node_list_by_name: - node_in_focus = self.graph_util.node_from_name(model, node_name) - input_positions = by_name[ - node_name - ] # input positions specified by the user to introduce quant node. - model = self.graph_util.adjust_graph( - model, input_positions, node_in_focus, quantized_nodes, node_count - ) - node_count += 1 + input_positions = by_name[node_name] # input positions to introduce quant nodes. + model = adjust_graph(model, input_positions, node_name, quantized_nodes) if "op_type" in selection_type: - by_op_type = self.quant_node_inputs[ - "op_type" - ] # by_name is a dictionary with the unique node names as keys and the list of positions as values. + by_op_type = self.quantnode_map["op_type"] # dict with the unique names and list of positions. op_list = by_op_type.keys() for op in op_list: - node_list = self.graph_util.identify_nodes( - model, op - ) # List of all nodes with the operation type "op". + node_list = model.get_nodes_by_op_type(op) # List of all nodes with the operation type "op". 
input_positions = by_op_type[op] - for node_in_focus in node_list: - model = self.graph_util.adjust_graph( - model, input_positions, node_in_focus, quantized_nodes, node_count - ) - node_count += 1 - model = qonnx_make_model(model.graph) - model = ModelWrapper(model) + for node in node_list: + node_name = node.name + model = adjust_graph(model, input_positions, node_name, quantized_nodes) model = cleanup_model(model) else: raise Exception("Unsupported selection type") diff --git a/tests/transformation/test_quantize_graph.py b/tests/transformation/test_quantize_graph.py index c0ceb456..e613bd17 100644 --- a/tests/transformation/test_quantize_graph.py +++ b/tests/transformation/test_quantize_graph.py @@ -33,14 +33,12 @@ import urllib.request from qonnx.core.modelwrapper import ModelWrapper -from qonnx.transformation.quantize_graph import QuantizeGraph, graph_util +from qonnx.transformation.quantize_graph import QuantizeGraph from qonnx.util.cleanup import cleanup from qonnx.util.inference_cost import inference_cost random.seed(42) -graph_util = graph_util() - download_url = "https://github.com/onnx/models/raw/main/validated/vision/" download_url += "classification/resnet/model/resnet18-v1-7.onnx?download=" @@ -93,11 +91,11 @@ def to_verify(model, test_details): if by == "name": sample_node_name = random.choice(list(test_details["name"].keys())) - sample_node = graph_util.node_from_name(model, sample_node_name) + sample_node = model.node_from_name(model, sample_node_name) sample_pos = random.choice(test_details["name"][sample_node_name]) if by == "op_type": node_type = random.choice(list(test_details["op_type"].keys())) - sample_node = random.choice(graph_util.identify_nodes(model, node_type)) + sample_node = random.choice(model.get_nodes_by_op_type(node_type)) sample_pos = random.choice(test_details["op_type"][node_type]) if sample_pos[0][0] == "input": @@ -119,7 +117,7 @@ def to_verify(model, test_details): @pytest.mark.parametrize("test_model", 
model_details.keys()) -def test_introduce_quantnode(test_model): +def test_quantize_graph(test_model): test_details = model_details[test_model] model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) original_model_inf_cost = inference_cost(model, discount_sparsity=False) From 04619a397670dd3c76001a30ddf6c82bab5356be Mon Sep 17 00:00:00 2001 From: Harish Date: Thu, 15 Feb 2024 15:04:31 +0000 Subject: [PATCH 5/6] revised version of quantize_graph --- src/qonnx/transformation/quantize_graph.py | 14 +++++++------- tests/transformation/test_quantize_graph.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/qonnx/transformation/quantize_graph.py b/src/qonnx/transformation/quantize_graph.py index 20feb049..230650bd 100644 --- a/src/qonnx/transformation/quantize_graph.py +++ b/src/qonnx/transformation/quantize_graph.py @@ -140,15 +140,15 @@ def adjust_graph(model, input_positions, node_name, quantized_nodes): class QuantizeGraph(Transformation): """This transformation can be used to introduce a Quant node for a specific type of node in the graph. - Users would be able to specify the location of the quant node by providing the input and output indexs + Users would be able to specify the location of the quant node by providing the input and output index as the parameters. 1) Expectations: a) Onnx model in the modelwraper format. - b) Model must be cleaned using cleanup_model qonnx.util.cleanup.cleanup_model() + b) Model must be cleaned using qonnx.util.cleanup.cleanup_model() c) Batchsize to be set. - 2) S.teps to transform are: + 2) Steps to transform are: Step1: Finding the input for the quant node. Step2: Finding the consumer of the quant node output. Step3: Finding the shape for the output tensor of quant node. 
@@ -157,7 +157,7 @@ class QuantizeGraph(Transformation): 3) Input: A dict "quantnode_map" specifying the criterion, positions, and input parameters like - scale, bitwidth, zeropoint, and others for the particular quantnode. + scale, bitwidth, zeropoint, and others for a specific quantnode. Criterion: a) name: This will allow users to add quant nodes for specific node like "Conv_0" and "Gemm_0". @@ -171,9 +171,9 @@ class QuantizeGraph(Transformation): in comparison to "op_type". Positions: ("input", index) or ("output", index) - a) "input": specifies that the user want to quantize the input of the selected node. - b) "output": specifies that the user want to quantize the input of the selected node. - c) index: specifies which input/output to quantize (as a node can have multiple inputs and outputs) + a) "input": indicates that the user wants to quantize the input of the selected node. + b) "output": indicates that the user wants to quantize the output of the selected node. + c) index: refers to the input/output index to quantize (a node can have multiple inputs and outputs) Parameters (to quant node) are provided as (scale, zeropoint, bitwidth, narrow, signed, rounding_mode) diff --git a/tests/transformation/test_quantize_graph.py b/tests/transformation/test_quantize_graph.py index e613bd17..867f9b34 100644 --- a/tests/transformation/test_quantize_graph.py +++ b/tests/transformation/test_quantize_graph.py @@ -91,7 +91,7 @@ def to_verify(model, test_details): if by == "name": sample_node_name = random.choice(list(test_details["name"].keys())) - sample_node = model.node_from_name(model, sample_node_name) + sample_node = model.get_node_from_name(sample_node_name) sample_pos = random.choice(test_details["name"][sample_node_name]) if by == "op_type": node_type = random.choice(list(test_details["op_type"].keys())) From 5e2d0b808a3333157be6b6b397b9983f88ab7ec9 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 23 Feb 2024 14:47:43 +0100 Subject: [PATCH 6/6] [Wrapper]
explicitly return None for name/index finder functions --- src/qonnx/core/modelwrapper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index f21efdab..2abf9d9d 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -532,7 +532,7 @@ def get_non_finn_nodes(self): return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node)) def get_node_index(self, node): - """Returns current index of given node.""" + """Returns current index of given node, or None if not found.""" n_ind = 0 try: for n in self.graph.node: @@ -541,15 +541,17 @@ def get_node_index(self, node): n_ind += 1 except ValueError: return None + return None def get_node_from_name(self, node_name): - """Returns the node with the specified name.""" + """Returns the node with the specified name, or None if not found.""" try: for node in self.graph.node: if node.name == node_name: return node except ValueError: return None + return None def get_tensor_layout(self, tensor_name): """Returns the data layout annotation of tensor with given name.