feat: add brevitas channel-wise support #807
base: main
Conversation
Force-pushed from 7848d8e to 30a1791
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.to("cpu")
        else:
            dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

        # Export to ONNX
        torch.onnx.export(
            torch_module.to("cpu"),
            dummy_input,
            str(output_onnx_file_path),
            opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
        )
We need the module and the inputs to be on CPU for the exporter to work properly
class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
    """CommonIntWeightPerChannelQuant."""

    scaling_per_output_channel = True
The per-channel quantizer from Brevitas
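For context, a minimal sketch of what this flag changes, assuming Brevitas' quant_weight() API; the toy QuantConv2d below is hypothetical and only there to inspect the resulting scale shape:

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat


class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
    """Per-channel variant of Brevitas' default per-tensor weight quantizer."""

    scaling_per_output_channel = True


# Toy layer (hypothetical, for illustration only)
conv = qnn.QuantConv2d(
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    weight_quant=CommonIntWeightPerChannelQuant,
    weight_bit_width=4,
)

# With scaling_per_output_channel=True, the weight scale holds one value per
# output channel, broadcastable against the (out, in, kH, kW) weight tensor
print(conv.quant_weight().scale.shape)  # expected: torch.Size([8, 1, 1, 1])

With the base Int8WeightPerTensorFloat, the same call would return a single scalar scale, which is what the per-tensor code path in the diffs below checks for.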
        Returns:
            Neural network prediction
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x


# pylint: disable-next=too-many-instance-attributes
class QuantLeNet(FloatLeNet):
    """Quantized LeNet with per-channel quantization."""

    def __init__(
        self,
        weight_bit_width=4,
        act_bit_width=4,
        acc_bit_width=32,
        weight_quant=CommonIntAccumulatorAwareWeightQuant,
    ):
        super().__init__()

        self.conv1 = qnn.QuantConv2d(
            bias=False,
            in_channels=1,
            out_channels=6,
            kernel_size=5,
            stride=1,
            padding=0,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu1 = qnn.QuantReLU(
            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
        )

        self.conv2 = qnn.QuantConv2d(
            bias=False,
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=0,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu2 = qnn.QuantReLU(
            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
        )

        self.fc1 = qnn.QuantLinear(
            400,
            120,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.relu3 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
        self.fc2 = qnn.QuantLinear(
            120,
            84,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.relu4 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
        self.fc3 = qnn.QuantLinear(
            84,
            10,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )

        self.apply(weight_init)
A LeNet provided by a user
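As a usage sketch only (the 32x32 grayscale input shape and the calibration data are assumptions derived from fc1 expecting 400 = 16 * 5 * 5 features), this is how such a user model would typically go through Concrete ML's Brevitas QAT compilation path:

import torch
from concrete.ml.torch.compile import compile_brevitas_qat_model

# Hypothetical calibration set: 100 grayscale 32x32 images
torch_inputset = torch.randn(100, 1, 32, 32)

model = QuantLeNet(weight_bit_width=4, act_bit_width=4)

# Compile the Brevitas QAT model; the per-channel weight scales must survive this step
quantized_module = compile_brevitas_qat_model(model, torch_inputset)

# The returned quantized module is called on numpy inputs, as in the test further below
out_qm = quantized_module(torch_inputset[:1].numpy())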
Thanks a lot for tackling this.
My main comment is: how do we make sure per-channel quantization is indeed activated and used in Concrete ML?
Currently we only test that a per-channel Brevitas quantization can be compiled in Concrete ML, but are we sure we use all the scales properly?
I am thinking: maybe have a simple model with a single conv and make sure that the number of scales equals the number of channels?
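Something along those lines, as a rough sketch (the single-conv model, bit-widths and shapes are all made up for illustration; the Concrete ML-side scale check is left as a comment since the internal attribute path is not shown in this PR):

import torch
import brevitas.nn as qnn
from torch import nn

from concrete.ml.torch.compile import compile_brevitas_qat_model


class SingleConv(nn.Module):
    """Hypothetical single-conv model for checking per-channel scales."""

    def __init__(self, out_channels=8):
        super().__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4, return_quant_tensor=True)
        self.conv = qnn.QuantConv2d(
            in_channels=1,
            out_channels=out_channels,
            kernel_size=3,
            weight_quant=CommonIntWeightPerChannelQuant,
            weight_bit_width=4,
        )

    def forward(self, x):
        return self.conv(self.quant_inp(x))


model = SingleConv(out_channels=8)

# Brevitas level: one weight scale per output channel
assert model.conv.quant_weight().scale.numel() == 8

# Concrete ML level: compile, then check that the conv's weight quantizer in the
# resulting quantized module also carries 8 scales (attribute path to be filled in)
quantized_module = compile_brevitas_qat_model(model, torch.randn(50, 1, 8, 8))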
if q_input2.quantizer.scale.shape == tuple():
    m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
else:
    # TODO: add assert on shapes
TODO to remove or to convert into a FIXME.
    m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
else:
    # TODO: add assert on shapes
    weight_quant_scale = numpy.transpose(q_input2.quantizer.scale, axes=(1, 0))
Are these axes 1, 0 always going to be true?
agree, might be worth a comment if that's the case
yes always true, the assert will catch any issue
Which assert? 🤔
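For reference, the shape reasoning behind axes=(1, 0), as a small numpy sketch (the (out_features, 1) scale shape is an assumption about what a per-channel QuantLinear exports):

import numpy

batch, out_features = 4, 5

input_scale = numpy.float64(0.02)                   # per-tensor input scale
weight_scale = numpy.full((out_features, 1), 0.01)  # one scale per output channel

# The Gemm output has shape (batch, out_features), so the per-channel weight scale
# must broadcast along the last axis: (out_features, 1) -> (1, out_features)
weight_quant_scale = numpy.transpose(weight_scale, axes=(1, 0))
m_matmul = input_scale * weight_quant_scale

dequantized = m_matmul * numpy.ones((batch, out_features))
assert dequantized.shape == (batch, out_features)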
if q_weights.quantizer.scale.shape == tuple():
    m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale
else:
    # TODO: add assert on shapes
TODO to remove or to convert into a FIXME + issue.
    # TODO: add assert on shapes
    weight_quant_scale = numpy.transpose(
        q_weights.quantizer.scale,
        axes=(1, 0, 2, 3),
Same remark as for the (1, 0) axes above.
Same answer: the assert above (to be added) will catch any issue, but the scale should always have this shape for Conv2d.
I'll add some errors for the Conv1d and Conv3d cases.
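Same reasoning for the conv case, as a numpy sketch (assuming the per-channel scale has shape (out_channels, 1, 1, 1), matching the Conv2d weight layout):

import numpy

n, out_channels, h, w = 2, 16, 10, 10

input_scale = numpy.float64(0.03)                         # per-tensor input scale
weight_scale = numpy.full((out_channels, 1, 1, 1), 0.01)  # one scale per output channel

# Conv2d outputs are laid out as (N, C, H, W): the per-channel scale has to sit on
# axis 1, so (out_channels, 1, 1, 1) is transposed to (1, out_channels, 1, 1)
weight_quant_scale = numpy.transpose(weight_scale, axes=(1, 0, 2, 3))
m_matmul = input_scale * weight_quant_scale

dequantized = m_matmul * numpy.ones((n, out_channels, h, w))
assert dequantized.shape == (n, out_channels, h, w)

A Conv1d or Conv3d weight scale would have 3 or 5 dimensions instead, which is why explicit errors for those cases make sense.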
tests/torch/test_brevitas_qat.py (outdated)
out_qm = quantized_module(images.detach().numpy())
mse = ((out - out_qm) ** 2).mean()
# Arbitrary threshold to check that the predictions are relatively similar
assert mse < 1e-4
Probably better to use our tools like check_float_array_equal, which can provide rtol or atol, or check_r2_score or similar.
If not, then it is probably better to make this MSE check a fixture so that we can re-use it elsewhere and not forget about this arbitrary value here.
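check_float_array_equal and check_r2_score are test helpers from this repository whose exact signatures are not shown here; as a stand-in, the intent is roughly the following, with numpy.testing.assert_allclose playing their role and an arbitrary placeholder tolerance:

import numpy

out = numpy.array([0.1, 0.2, 0.3])        # float model predictions (placeholder values)
out_qm = numpy.array([0.1, 0.205, 0.295]) # quantized module predictions (placeholder values)

# Explicit tolerances instead of a hand-picked MSE bound
numpy.testing.assert_allclose(out_qm, out, rtol=0, atol=1e-2)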
Yeah, true, check_float_array_equal would probably be better indeed. I'll change that.
Removing my request as I will be off for some time. Good luck with the PR!
Force-pushed from a9ac260 to fad224d
Force-pushed from fad224d to d645e82
Coverage passed ✅