feat: add brevitas channel-wise support #807

Open: fd0r wants to merge 1 commit into main from channelwise_quantization_support
Conversation

@fd0r (Collaborator) commented Jul 24, 2024

No description provided.

cla-bot added the cla-signed label on Jul 24, 2024
fd0r force-pushed the channelwise_quantization_support branch 3 times, most recently from 7848d8e to 30a1791 on July 25, 2024 08:33
Comment on lines +149 to 166
if isinstance(dummy_input, torch.Tensor):
    dummy_input = dummy_input.to("cpu")
else:
    dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

# Export to ONNX
torch.onnx.export(
    torch_module.to("cpu"),
    dummy_input,
    str(output_onnx_file_path),
    opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
fd0r (Collaborator, author):

We need the module and the inputs to be on CPU for the exporter to work properly
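
For reference, a minimal sketch of the pattern this diff establishes (the helper name export_to_onnx and the default opset value are placeholders, not the project's actual function):

import torch

# Sketch only: both the module and the dummy input(s) are moved to CPU before
# calling torch.onnx.export, since the exporter can fail if they live on GPU.
def export_to_onnx(torch_module, dummy_input, output_onnx_file_path, opset_version=14):
    torch_module = torch_module.to("cpu")
    if isinstance(dummy_input, torch.Tensor):
        dummy_input = dummy_input.to("cpu")
    else:
        dummy_input = tuple(elt.to("cpu") for elt in dummy_input)
    torch.onnx.export(
        torch_module,
        dummy_input,
        str(output_onnx_file_path),
        opset_version=opset_version,
    )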

Comment on lines +1615 to +1691
class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
    """CommonIntWeightPerChannelQuant."""

    scaling_per_output_channel = True
fd0r (Collaborator, author):

The per-channel quantizer from Brevitas
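
As an illustration only (layer sizes and bit widths below are arbitrary), attaching such a quantizer to a Brevitas layer is expected to produce one weight scale per output channel rather than a single scalar:

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat

# Sketch of the quantizer added in this PR: inherit the per-tensor base class
# and switch on per-output-channel scaling.
class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
    """CommonIntWeightPerChannelQuant."""

    scaling_per_output_channel = True

conv = qnn.QuantConv2d(
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    weight_bit_width=4,
    weight_quant=CommonIntWeightPerChannelQuant,
)

# With per-channel scaling, the weight scale carries one value per output
# channel (8 here) instead of being a single scalar.
print(conv.quant_weight().scale.shape)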

Comment on lines +1658 to +1857
        Returns:
            Neural network prediction
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x


# pylint: disable-next=too-many-instance-attributes
class QuantLeNet(FloatLeNet):
    """Quantized LeNet with per-channel quantization."""

    def __init__(
        self,
        weight_bit_width=4,
        act_bit_width=4,
        acc_bit_width=32,
        weight_quant=CommonIntAccumulatorAwareWeightQuant,
    ):
        super().__init__()

        self.conv1 = qnn.QuantConv2d(
            bias=False,
            in_channels=1,
            out_channels=6,
            kernel_size=5,
            stride=1,
            padding=0,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu1 = qnn.QuantReLU(
            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
        )

        self.conv2 = qnn.QuantConv2d(
            bias=False,
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=0,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu2 = qnn.QuantReLU(
            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
        )

        self.fc1 = qnn.QuantLinear(
            400,
            120,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.relu3 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
        self.fc2 = qnn.QuantLinear(
            120,
            84,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.relu4 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
        self.fc3 = qnn.QuantLinear(
            84,
            10,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )

        self.apply(weight_init)
fd0r (Collaborator, author):

A LeNet provided by a user

fd0r marked this pull request as ready for review on July 25, 2024 12:07
fd0r requested a review from a team as a code owner on July 25, 2024 12:07
jfrery (Collaborator) left a comment:

Thanks a lot for tackling this.

My main comment is: how do we make sure per-channel quantization is indeed activated and used in Concrete ML?

Currently, we only test that a per-channel Brevitas quantization can be compiled in CML, but are we sure we use all the scales properly?

I am thinking: maybe have a simple model with a single conv and make sure that the number of scales == the number of channels?
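
A hedged sketch of what such a test could look like (the layer sizes are arbitrary and it only checks the Brevitas side; the compiled Concrete ML module would need a similar check on its own quantizers):

import brevitas.nn as qnn

def test_per_channel_scale_count():
    # Hypothetical test: a single per-channel quantized convolution should expose
    # exactly one weight scale per output channel.
    out_channels = 8
    conv = qnn.QuantConv2d(
        in_channels=3,
        out_channels=out_channels,
        kernel_size=3,
        weight_quant=CommonIntWeightPerChannelQuant,  # quantizer from this PR
    )
    assert conv.quant_weight().scale.numel() == out_channels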

if q_input2.quantizer.scale.shape == tuple():
    m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
else:
    # TODO: add assert on shapes
Collaborator:

TODO to remove or convert into a FIXME.

    m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale
else:
    # TODO: add assert on shapes
    weight_quant_scale = numpy.transpose(q_input2.quantizer.scale, axes=(1, 0))
Collaborator:

Are these axes (1, 0) always going to be correct?

Collaborator:

Agreed, it might be worth a comment if that's the case.

fd0r (Collaborator, author):

Yes, always true; the assert will catch any issue.

Collaborator:

Which assert? 🤔
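
For what it's worth, a hypothetical version of the assert being discussed (not in the PR as-is): Brevitas stores a per-output-channel weight scale for a linear layer with shape (out_features, 1), so the (1, 0) transpose turns it into a row vector that broadcasts over the matmul output.

# Hypothetical shape guard before the transpose:
scale = q_input2.quantizer.scale
assert scale.ndim == 2 and scale.shape[1] == 1, (
    f"Expected a per-output-channel scale of shape (out_features, 1), got {scale.shape}"
)
weight_quant_scale = numpy.transpose(scale, axes=(1, 0))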

if q_weights.quantizer.scale.shape == tuple():
    m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale
else:
    # TODO: add assert on shapes
Collaborator:

TODO to remove or convert into a FIXME, plus open an issue.

    # TODO: add assert on shapes
    weight_quant_scale = numpy.transpose(
        q_weights.quantizer.scale,
        axes=(1, 0, 2, 3),
Collaborator:

Same comment as above.

fd0r (Collaborator, author):

Same answer: the assert above (to be added) will catch any issue, but the scale should always have this shape for Conv2d.

I'll add explicit errors for the Conv1d and Conv3d cases.
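
A sketch of what that guard could look like (names follow the snippet above; the exception type and message are placeholders):

# Hypothetical guard: the axes=(1, 0, 2, 3) transpose assumes a Conv2d weight
# scale of shape (out_channels, 1, 1, 1); other ranks (Conv1d / Conv3d weights)
# are rejected explicitly until they are supported.
scale = q_weights.quantizer.scale
if scale.ndim != 4:
    raise ValueError(
        "Per-channel quantization is only supported for Conv2d weights, "
        f"got a scale of rank {scale.ndim}"
    )
weight_quant_scale = numpy.transpose(scale, axes=(1, 0, 2, 3))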

out_qm = quantized_module(images.detach().numpy())
mse = ((out - out_qm) ** 2).mean()
# Arbitrary threshold to check that the predictions are relatively similar
assert mse < 1e-4
Collaborator:

It is probably better to use our tools like check_float_array_equal, which can take rtol or atol, or check_r2_score or similar.

If not, then it is probably better to make this MSE check a fixture so that we can reuse it elsewhere and not leave this arbitrary value hidden here.

fd0r (Collaborator, author):

Yeah, true, check_float_array_equal would probably be better indeed. I'll change that.
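
A sketch of the suggested change (the tolerances are placeholders; in the repo, the check_float_array_equal helper mentioned above would play the role that numpy.testing.assert_allclose plays here):

import numpy

out_qm = quantized_module(images.detach().numpy())
# Compare the float and quantized outputs with explicit tolerances instead of a
# hand-rolled MSE threshold.
numpy.testing.assert_allclose(out, out_qm, rtol=0, atol=1e-2)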

jfrery previously approved these changes on Jul 26, 2024
jfrery (Collaborator) left a comment:

Removing my request as I will be off for some time. Good luck with the PR!

fd0r force-pushed the channelwise_quantization_support branch 5 times, most recently from a9ac260 to fad224d on July 29, 2024 21:51
fd0r force-pushed the channelwise_quantization_support branch from fad224d to d645e82 on July 31, 2024 11:50

Coverage passed ✅

Coverage details

---------- coverage: platform linux, python 3.8.18-final-0 -----------
Name    Stmts   Miss  Cover   Missing
-------------------------------------
TOTAL    8179      0   100%

60 files skipped due to complete coverage.

3 participants