feat: add brevitas channel-wise support
fd0r committed Jul 24, 2024
1 parent e4c21fa commit 7848d8e
Showing 9 changed files with 316 additions and 54 deletions.
24 changes: 14 additions & 10 deletions deps_licenses/licenses_mac_silicon_user.txt
@@ -1,18 +1,20 @@
Name, Version, License
+Jinja2, 3.1.4, BSD License
+MarkupSafe, 2.1.5, BSD License
PyYAML, 6.0.1, MIT License
-brevitas, 0.8.0, UNKNOWN
-certifi, 2024.6.2, Mozilla Public License 2.0 (MPL 2.0)
+brevitas, 0.10.2, UNKNOWN
+certifi, 2024.7.4, Mozilla Public License 2.0 (MPL 2.0)
charset-normalizer, 3.3.2, MIT License
coloredlogs, 15.0.1, MIT License
concrete-python, 2.7.0, BSD-3-Clause
dependencies, 2.0.1, BSD License
dill, 0.3.8, BSD License
-filelock, 3.15.3, The Unlicense (Unlicense)
+filelock, 3.15.4, The Unlicense (Unlicense)
flatbuffers, 24.3.25, Apache Software License
-fsspec, 2024.6.0, BSD License
+fsspec, 2024.6.1, BSD License
huggingface-hub, 0.23.4, Apache Software License
humanfriendly, 10.0, MIT License
-hummingbird-ml, 0.4.8, MIT License
+hummingbird-ml, 0.4.11, MIT License
idna, 3.7, BSD License
importlib_resources, 6.4.0, Apache Software License
joblib, 1.4.2, BSD License
@@ -22,7 +24,7 @@ networkx, 3.1, BSD License
numpy, 1.23.5, BSD License
onnx, 1.16.1, Apache License v2.0
onnxconverter-common, 1.13.0, MIT License
-onnxmltools, 1.11.0, Apache Software License
+onnxmltools, 1.12.0, Apache Software License
onnxoptimizer, 0.3.13, Apache License v2.0
onnxruntime, 1.18.0, MIT License
packaging, 24.1, Apache Software License; BSD License
@@ -35,16 +37,18 @@ requests, 2.32.3, Apache Software License
scikit-learn, 1.1.3, BSD License
scipy, 1.10.1, BSD License
six, 1.16.0, MIT License
-skl2onnx, 1.12, Apache Software License
+skl2onnx, 1.17.0, Apache Software License
skops, 0.5.0, MIT
skorch, 0.11.0, new BSD 3-Clause
-sympy, 1.12.1, BSD License
+sympy, 1.13.0, BSD License
tabulate, 0.8.10, MIT License
threadpoolctl, 3.5.0, BSD License
-torch, 1.13.1, BSD License
+torch, 2.3.1, BSD License
tqdm, 4.66.4, MIT License; Mozilla Public License 2.0 (MPL 2.0)
-typing_extensions, 4.5.0, Python Software Foundation License
+typing_extensions, 4.12.2, Python Software Foundation License
tzdata, 2024.1, Apache Software License
unfoldNd, 0.2.2, MIT License
urllib3, 2.2.2, MIT License
xgboost, 1.6.2, Apache Software License
z3-solver, 4.13.0.0, MIT License
zipp, 3.19.2, MIT License
2 changes: 1 addition & 1 deletion deps_licenses/licenses_mac_silicon_user.txt.md5
@@ -1 +1 @@
-adb925c3b7be3e651975febcf49b6543
+6d367701c3ef5eff8763f4e994e03681
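The companion .md5 file pins a checksum for the license list so stale entries can be detected. A quick way to reproduce the new hash, assuming the checksum is a plain MD5 digest of the file's bytes (an assumption about the repository's convention, not stated in the diff):

import hashlib
from pathlib import Path

# Recompute the digest of the regenerated license list; it should match the
# value stored in licenses_mac_silicon_user.txt.md5
data = Path("deps_licenses/licenses_mac_silicon_user.txt").read_bytes()
print(hashlib.md5(data).hexdigest())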
13 changes: 11 additions & 2 deletions src/concrete/ml/onnx/convert.py
@@ -78,7 +78,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
# Create a GEMM node which combines the MatMul and Add operations
gemm_node = helper.make_node(
    "Gemm",  # op_type
-   [matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name],  # inputs
+   [
+       matmul_node.input[0],
+       matmul_node.input[1],
+       bias_other_input_node_name,
+   ],  # inputs
    [add_node.output[0]],  # outputs
    name="Gemm_Node",
    alpha=1.0,
@@ -142,9 +146,14 @@ def get_equivalent_numpy_forward_from_torch(

arguments = list(inspect.signature(torch_module.forward).parameters)

+if isinstance(dummy_input, torch.Tensor):
+    dummy_input = dummy_input.to("cpu")
+else:
+    dummy_input = tuple(elt.to("cpu") for elt in dummy_input)
+
# Export to ONNX
torch.onnx.export(
-   torch_module,
+   torch_module.to("cpu"),
    dummy_input,
    str(output_onnx_file_path),
    opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
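The hunk above makes the ONNX export device-safe: both the module and the dummy inputs are normalized to CPU, since torch.onnx.export can fail when the two sit on different devices. A minimal standalone sketch of the same pattern, with TinyModel as a hypothetical stand-in for the module being exported:

import torch

class TinyModel(torch.nn.Module):
    """Hypothetical linear model, used only to illustrate the export pattern."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
dummy_input = torch.randn(1, 4)

# Accept either a single tensor or a tuple of tensors, as in the hunk above
if isinstance(dummy_input, torch.Tensor):
    dummy_input = dummy_input.to("cpu")
else:
    dummy_input = tuple(elt.to("cpu") for elt in dummy_input)

torch.onnx.export(model.to("cpu"), dummy_input, "tiny_model.onnx", opset_version=14)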
200 changes: 191 additions & 9 deletions src/concrete/ml/pytest/torch_models.py
@@ -6,7 +6,15 @@
import brevitas.nn as qnn
import numpy
import torch
-from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat, IntBias
+from brevitas.core.restrict_val import FloatRestrictValue, RestrictValueType
+from brevitas.quant import (
+    Int8AccumulatorAwareWeightQuant,
+    Int8AccumulatorAwareZeroCenterWeightQuant,
+    Int8ActPerTensorFloat,
+    Int8WeightPerTensorFloat,
+    IntBias,
+    Uint8ActPerTensorFloat,
+)
from torch import nn
from torch.nn.utils import prune

@@ -38,7 +46,7 @@ def forward(self, x, y):
return x + y + self.value, (x - y) ** 2


-class SimpleNet(torch.nn.Module):
+class SimpleNet(nn.Module):
"""Fake torch model used to generate some onnx."""

def __init__(self) -> None:
@@ -292,7 +300,7 @@ def forward(self, x):
return x


-class NetWithLoops(torch.nn.Module):
+class NetWithLoops(nn.Module):
"""Torch model, where we reuse some elements in a loop.
Torch model, where we reuse some elements in a loop in the forward and don't expect the
@@ -538,7 +546,7 @@ def step(x, bias):
return x


-class NetWithConcatUnsqueeze(torch.nn.Module):
+class NetWithConcatUnsqueeze(nn.Module):
"""Torch model to test the concat and unsqueeze operators."""

def __init__(self, activation_function, input_output, n_fc_layers):
@@ -1004,6 +1012,7 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits):
layer_obj = self.mixing_layer

layer_obj.weight.data = torch.from_numpy(np_weights).float()
+assert layer_obj.bias is not None
layer_obj.bias.data = torch.rand(size=(1,))

def forward(self, x):
@@ -1216,12 +1225,12 @@ def forward(self, x):
# for example a 4d tensor NCHW, padded with [1, 2, 2, 3] is padded
# along the last 2 dimensions, with 1 cell to the left and 2 to the right (dimension 4: W)
# and 2 cells at the top and 3 at the bottom (dimension 3: H)
-x = torch.nn.functional.pad(x, (3, 2))
-x = torch.nn.functional.pad(x, (1, 2, 3, 4))
+x = nn.functional.pad(x, (3, 2))
+x = nn.functional.pad(x, (1, 2, 3, 4))

# Concrete ML only supports padding on the last two dimensions as this is the
# most common setting
-x = torch.nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
+x = nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0))
return x


@@ -1393,7 +1402,7 @@ def forward(self, x):
return x


-class PartialQATModel(torch.nn.Module):
+class PartialQATModel(nn.Module):
"""A model with a QAT Module."""

def __init__(self, input_shape: int, output_shape: int, n_bits: int):
@@ -1442,7 +1451,7 @@ def forward(self, input1):
return output


-class ManualLogisticRegressionTraining(torch.nn.Module):
+class ManualLogisticRegressionTraining(nn.Module):
"""PyTorch module for performing SGD training."""

def __init__(self, learning_rate=0.1):
@@ -1600,3 +1609,176 @@ def forward(self, x): # pylint: disable-next=no-self-use
Tuple[torch.Tensor, torch.Tensor]: Outputs of the network.
"""
return x, x.unsqueeze(0)


# pylint: disable-next=too-many-ancestors
class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
    """CommonIntWeightPerChannelQuant."""

    scaling_per_output_channel = True


# pylint: disable-next=too-many-ancestors
class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
    """CommonIntAccumulatorAwareWeightQuant."""

    restrict_scaling_impl = FloatRestrictValue  # backwards compatibility
    bit_width = None


# pylint: disable-next=too-many-ancestors
class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
    """CommonIntAccumulatorAwareZeroCenterWeightQuant."""

    bit_width = None


# pylint: disable-next=too-many-ancestors
class CommonUintActQuant(Uint8ActPerTensorFloat):
    """CommonUintActQuant."""

    bit_width = None
    restrict_scaling_type = RestrictValueType.LOG_FP


def weight_init(layer: nn.Module):
    """Initialize layer weights.

    Arguments:
        layer (nn.Module): a conv2d layer
    """

    if isinstance(layer, nn.Conv2d):
        nn.init.kaiming_normal_(layer.weight, nn.init.calculate_gain("relu"))
        if layer.bias is not None:
            layer.bias.data.zero_()


# pylint: disable-next=too-many-instance-attributes
class FloatLeNet(nn.Module):
    """Floating point LeNet."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc1 = nn.Linear(400, 120, bias=True)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84, bias=True)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10, bias=True)

        self.apply(weight_init)

    def forward(self, x: torch.Tensor):
        """Forward function.

        Arguments:
            x (torch.Tensor): input image

        Returns:
            Neural network prediction
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x


# pylint: disable-next=too-many-instance-attributes
class QuantLeNet(FloatLeNet):
    """Quantized LeNet with per-channel quantization."""

    def __init__(
        self,
        weight_bit_width=4,
        act_bit_width=4,
        acc_bit_width=32,
        weight_quant=CommonIntAccumulatorAwareWeightQuant,
    ):
        super().__init__()

        self.conv1 = qnn.QuantConv2d(
            bias=False,
            in_channels=1,
            out_channels=6,
            kernel_size=5,
            stride=1,
            padding=0,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu1 = qnn.QuantReLU(
            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
        )

        self.conv2 = qnn.QuantConv2d(
            bias=False,
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=0,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu2 = qnn.QuantReLU(
            inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width
        )

        self.fc1 = qnn.QuantLinear(
            400,
            120,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.relu3 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
        self.fc2 = qnn.QuantLinear(
            120,
            84,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )
        self.relu4 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width)
        self.fc3 = qnn.QuantLinear(
            84,
            10,
            bias=True,
            input_bit_width=act_bit_width,
            input_quant=CommonUintActQuant,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_restrict_scaling_type=RestrictValueType.LOG_FP,
            weight_quant=weight_quant,
        )

        self.apply(weight_init)
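These additions are the heart of the commit: scaling_per_output_channel = True switches the Brevitas weight quantizer from one scale per tensor to one scale per output channel, and QuantLeNet threads accumulator-aware weight quantizers through every layer as a test model. A sketch of how such a model might be compiled to FHE with Concrete ML's compile_brevitas_qat_model entry point, using a hypothetical random calibration set (classic LeNet-5 expects 1x32x32 inputs, matching the 400-unit fc1):

import numpy

from concrete.ml.torch.compile import compile_brevitas_qat_model

# Hypothetical calibration inputs: 100 grayscale 32x32 images
calibration_set = numpy.random.uniform(size=(100, 1, 32, 32)).astype(numpy.float32)

model = QuantLeNet(weight_bit_width=4, act_bit_width=4)
model.eval()

# Convert the Brevitas QAT model to a quantized module; per-channel weight
# scales are carried through the ONNX graph produced during conversion
quantized_module = compile_brevitas_qat_model(model, calibration_set)

# Simulated (non-encrypted) inference on a single sample
prediction = quantized_module.forward(calibration_set[:1], fhe="simulate")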
9 changes: 7 additions & 2 deletions src/concrete/ml/quantization/base_quantized_op.py
@@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, Type, Union, cast

import numpy
+import numpy.typing as npt

from concrete import fhe

@@ -122,6 +123,7 @@ def __init__(
input_quant_opts: Optional[QuantizationOptions] = None,
**attrs,
) -> None:

self.n_bits = n_bits_output

if input_quant_opts is not None:
@@ -913,7 +915,7 @@ def can_fuse(self) -> bool:
def make_output_quant_parameters(
self,
q_values: Union[numpy.ndarray, Any],
-scale: numpy.float64,
+scale: npt.NDArray[numpy.float64],
zero_point: Union[int, float, numpy.ndarray],
) -> QuantizedArray:
"""Build a quantized array from quantized integer results of the op and quantization params.
@@ -1013,6 +1015,9 @@ def cnp_round(
# Rounding to low bit-width with approximate can cause issues with overflow protection
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345
x = fhe.round_bit_pattern(
-x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False
+x,
+lsbs_to_remove=lsbs_value,
+exactness=exactness,
+overflow_protection=False,
)
return x
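The reflowed call keeps behavior identical: fhe.round_bit_pattern drops the given number of least-significant bits ahead of a table lookup, and overflow protection stays disabled per the FIXME above. A small sketch of the rounding itself, assuming eager (non-encrypted) evaluation on plain integers:

import numpy

from concrete import fhe

# Outside a circuit, round_bit_pattern evaluates eagerly: each value is rounded
# to the nearest multiple of 2**lsbs_to_remove, which shrinks the table lookup
# by one bit per removed LSB when run inside an actual FHE circuit
x = numpy.arange(16)
print(fhe.round_bit_pattern(x, lsbs_to_remove=2))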