Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformation pass to introduce quant nodes #94

Merged
merged 7 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def get_non_finn_nodes(self):
return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node))

def get_node_index(self, node):
"""Returns current index of given node."""
"""Returns current index of given node, or None if not found."""
n_ind = 0
try:
for n in self.graph.node:
Expand All @@ -541,6 +541,17 @@ def get_node_index(self, node):
n_ind += 1
except ValueError:
return None
return None

def get_node_from_name(self, node_name):
    """Returns the node with the specified name, or None if not found.

    The graph's nodes are scanned in order and the first node whose ``name``
    equals ``node_name`` is returned.
    """
    # A plain attribute comparison cannot raise ValueError, so no exception
    # handler is needed; a missing name simply falls through to the default.
    return next((node for node in self.graph.node if node.name == node_name), None)

def get_tensor_layout(self, tensor_name):
"""Returns the data layout annotation of tensor with given name.
Expand Down
238 changes: 238 additions & 0 deletions src/qonnx/transformation/quantize_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# Copyright (c) 2024 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of qonnx nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import numpy as np
import onnx
from onnx import TensorProto

from qonnx.transformation.base import Transformation
from qonnx.transformation.general import SortGraph
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.cleanup import cleanup_model


def create_quantnode(
    model,
    quantnode_input,
    quantnode_output_shape,
    scale_value,
    zeropoint_value,
    bitwidth_value,
    narrow,
    signed,
    rounding_mode,
):
    """Build a "Quant" node that reads ``quantnode_input``.

    Four fresh tensors are registered on ``model`` (names come from
    ``model.make_new_valueinfo_name()``): the Quant output plus three float32
    initializers holding the scale, zero point and bitwidth.

    Args:
        model: model wrapper whose graph receives the new value infos and
            initializers.
        quantnode_input: name of the tensor the Quant node consumes.
        quantnode_output_shape: shape declared for the Quant output tensor
            (also used as the declared shape of the scale and zero-point
            value infos).
        scale_value: scale, converted to a float32 array.
        zeropoint_value: zero point, converted to a float32 array.
        bitwidth_value: bitwidth, converted to a float32 array.
        narrow: value of the node's ``narrow`` attribute.
        signed: value of the node's ``signed`` attribute.
        rounding_mode: value of the node's ``rounding_mode`` attribute.

    Returns:
        Tuple ``(quantnode, quant_tensor)``: the new node (not yet added to
        the graph) and the value info of its output tensor.
    """

    def _declare_tensor(shape):
        # Appending the value info right away keeps the freshly generated
        # name registered before the next name is requested.
        info = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, shape)
        model.graph.value_info.append(info)
        return info

    def _register_constant(value, shape):
        data = np.array(value).astype(np.float32)
        info = _declare_tensor(shape)
        model.set_initializer(info.name, data)
        return info.name

    # Order matters only for name generation: output first, then the constants.
    quant_tensor = _declare_tensor(quantnode_output_shape)
    scale_name = _register_constant(scale_value, quantnode_output_shape)
    zeropt_name = _register_constant(zeropoint_value, quantnode_output_shape)
    bitwidth_name = _register_constant(bitwidth_value, [1])

    quantnode = onnx.helper.make_node(
        "Quant",
        inputs=[quantnode_input, scale_name, zeropt_name, bitwidth_name],
        outputs=[quant_tensor.name],
        name="Quant_" + quantnode_input,
        narrow=narrow,
        signed=signed,
        rounding_mode=rounding_mode,
    )

    return quantnode, quant_tensor


def adjust_graph(model, input_positions, node_name, quantized_nodes):
    """Insert Quant nodes at the requested positions of the node ``node_name``.

    Args:
        model: model wrapper to modify.
        input_positions: iterable of ``((kind, index), params)`` entries.
            ``kind`` is "input" to quantize the node's input ``index``; any
            other value is treated as "output". ``params`` is indexed as
            (scale, zeropoint, bitwidth, narrow, signed, rounding_mode).
        node_name: name of the node to quantize around.
        quantized_nodes: list of ``(node_name, (kind, index))`` pairs that
            were already handled; extended in place for each Quant node
            inserted here.

    Returns:
        The model after all insertions. The graph is re-sorted with
        ``SortGraph`` after every insertion, so the returned object may
        differ from the one passed in.

    Raises:
        ValueError: if the graph has no node named ``node_name``.
    """
    for pos in input_positions:
        position, params = pos[0], pos[1]
        kind, index = position[0], position[1]
        already_quantized_msg = f"{kind} index {index} of {node_name} is already quantized."

        node_details = (node_name, position)
        if node_details in quantized_nodes:  # not quantizing for same node_inp/out index.
            print(already_quantized_msg)
            continue

        # Looked up on every iteration because `model` is rebound below.
        node_in_focus = model.get_node_from_name(node_name)
        if node_in_focus is None:
            raise ValueError(f"No node named {node_name!r} found in the graph.")

        if kind == "input":
            quantnode_input = node_in_focus.input[index]
            consumer_node = node_in_focus
            consumer_slots = [index]
            # An input fed by a Quant node is already quantized.
            neighbour = model.find_producer(quantnode_input)
        else:
            quantnode_input = node_in_focus.output[index]
            consumer_node = model.find_consumer(quantnode_input)
            # The consumer's input slot is unrelated to the producer's output
            # index, so locate the slot(s) by tensor name.
            # NOTE(review): only the consumer returned by find_consumer is
            # rewired; confirm how tensors with several consumers should be handled.
            if consumer_node is None:
                consumer_slots = []
            else:
                consumer_slots = [i for i, name in enumerate(consumer_node.input) if name == quantnode_input]
            # An output already feeding a Quant node is already quantized.
            neighbour = consumer_node

        if neighbour is not None and neighbour.op_type == "Quant":
            print(already_quantized_msg)
            continue

        quantnode_output_shape = model.get_tensor_shape(quantnode_input)  # Step: 3
        quantnode, quant_tensor = create_quantnode(
            model,
            quantnode_input,
            quantnode_output_shape,
            scale_value=params[0],
            zeropoint_value=params[1],
            bitwidth_value=params[2],
            narrow=params[3],
            signed=params[4],
            rounding_mode=params[5],
        )

        if consumer_node is not None:
            # Route the consumer through the Quant node's output.
            for slot in consumer_slots:
                consumer_node.input[slot] = quant_tensor.name
            model.graph.node.append(quantnode)
        else:
            model.graph.node.append(quantnode)
            # No consumer: if the tensor is a graph output, the Quant output
            # takes over that exact output slot. Other outputs are left alone.
            output_slot = next(
                (i for i, out in enumerate(model.graph.output) if out.name == quantnode_input),
                None,
            )
            if output_slot is not None:
                # The tensor becomes a graph output, so drop its value_info entry.
                model.graph.value_info.remove(quant_tensor)
                model.graph.output.insert(output_slot, quant_tensor)
                model.graph.output.pop(output_slot + 1)

        # The Quant node was appended at the end; restore a sorted node order.
        model = model.transform(SortGraph())
        quantized_nodes.append(node_details)

    return model


class QuantizeGraph(Transformation):
    """This transformation can be used to introduce a Quant node for a specific type of node in the graph.
    Users are able to specify the location of the quant node by providing the input and output index
    as the parameters.

    1) Expectations:
        a) Onnx model in the ModelWrapper format.
        b) Model must be cleaned using qonnx.util.cleanup.cleanup_model()
        c) Batchsize to be set.

    2) Steps to transform are:
        Step1: Finding the input for the quant node.
        Step2: Finding the consumer of the quant node output.
        Step3: Finding the shape for the output tensor of quant node.
        Note: The output tensor of the quant node must have the same shape as the consumer of the input
              to the quant node.

    3) Input:
        A dict "quantnode_map" specifying the criterion, positions, and input parameters like
        scale, bitwidth, zeropoint, and others for a specific quantnode.

        Criterion:
            a) name: This will allow users to add quant nodes for specific nodes like "Conv_0" and "Gemm_0".
                     Note: using this, users can have quant nodes with different parameters. Ex: quantizing
                     "Conv_0" and "Conv_1" with bitwidth of 4 and 6, respectively.
            b) op_type: This will allow users to add quant nodes for all nodes of a particular op_type such
                        as "Conv", "Gemm", and others.
                        Note: All quant nodes created using the op_type criterion will have the same input
                        parameters (scale, zeropoint, bitwidth, and others.)
            c) name and op_type: In this case, quant nodes will be added with precedence to "name"
                        in comparison to "op_type".

        Positions: ("input", index) or ("output", index)
            a) "input": indicates that the user wants to quantize the input of the selected node.
            b) "output": indicates that the user wants to quantize the output of the selected node.
            c) index: refers to the input/output index to quantize (a node can have multiple inputs and outputs)

        Parameters (to quant node) are provided as (scale, zeropoint, bitwidth, narrow, signed, rounding_mode)

            a) Inputs: scale, zeropoint, bitwidth.
            b) Attributes: narrow, signed, rounding_mode.

    4) Requirements on the input:
        a) The input is a dictionary whose keys are a subset of {"name", "op_type"}; each value maps
           node names (or op types) to a list of quant positions. A non-dict input raises TypeError,
           any other key raises ValueError.
        b) The input dictionary must have at least one mac node (Conv, Gemm, MatMul) for the
           transformation. Note: this is not checked by the transformation itself.

    5) Return:
        Returns a model with new quant nodes created at the positions specified using the "quantnode_map".

    6) Example:
        quantnode_map = {"name": {"Conv_0": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")),
                                             (("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                             (("output", 0), (1, 0, 8, 0, 1, "ROUND"))],
                                  "Conv_1": [(("input", 0), (1, 0, 8, 0, 1, "ROUND"))],
                                  "Conv_2": [(("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                             (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]},

                         "op_type": {"Gemm": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")),
                                              (("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                              (("input", 2), (1, 0, 8, 0, 1, "ROUND")),
                                              (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]}}
    """

    def __init__(self, quantnode_map):
        super().__init__()
        self.quantnode_map = quantnode_map

    def apply(self, model):
        """Insert the Quant nodes described by ``self.quantnode_map``.

        Returns:
            Tuple ``(model, graph_modified)``; ``graph_modified`` is always False.

        Raises:
            TypeError: if ``quantnode_map`` is not a dictionary.
            ValueError: if ``quantnode_map`` has keys other than "name" and "op_type".
        """
        quantnode_map = self.quantnode_map
        # Validate up front so a malformed map fails before any graph pass runs.
        if not isinstance(quantnode_map, dict):
            raise TypeError("Input must be a dictionary.")
        if not set(quantnode_map) <= {"name", "op_type"}:
            raise ValueError("Unsupported selection type")

        model = model.transform(InferShapes())

        # Shared across both criteria: positions handled by "name" are recorded
        # here first, so the later "op_type" pass skips them.
        quantized_nodes = []

        if "name" in quantnode_map:
            # dict with unique node names and their list of positions.
            for node_name, input_positions in quantnode_map["name"].items():
                model = adjust_graph(model, input_positions, node_name, quantized_nodes)

        if "op_type" in quantnode_map:
            # dict with op types and the list of positions applied to every node of that type.
            for op, input_positions in quantnode_map["op_type"].items():
                # Only the names are kept, because adjust_graph may return a
                # different model object on each call.
                node_names = [node.name for node in model.get_nodes_by_op_type(op)]
                for node_name in node_names:
                    model = adjust_graph(model, input_positions, node_name, quantized_nodes)

        model = cleanup_model(model)

        # NOTE(review): presumably reported as unmodified so the transformation
        # is applied only once; confirm against the Transformation contract.
        graph_modified = False

        return (model, graph_modified)
Loading
Loading