From 30a17914b70150e541a9e0e514d7bd6c0d9508fe Mon Sep 17 00:00:00 2001 From: Luis Montero Date: Wed, 24 Jul 2024 11:47:32 +0200 Subject: [PATCH] feat: add brevitas channel-wise support --- deps_licenses/licenses_mac_silicon_user.txt | 24 ++- .../licenses_mac_silicon_user.txt.md5 | 2 +- src/concrete/ml/onnx/convert.py | 13 +- src/concrete/ml/pytest/torch_models.py | 200 +++++++++++++++++- .../ml/quantization/base_quantized_op.py | 9 +- src/concrete/ml/quantization/post_training.py | 12 +- src/concrete/ml/quantization/quantized_ops.py | 53 ++++- src/concrete/ml/quantization/quantizers.py | 40 ++-- tests/torch/test_brevitas_qat.py | 24 ++- 9 files changed, 323 insertions(+), 54 deletions(-) diff --git a/deps_licenses/licenses_mac_silicon_user.txt b/deps_licenses/licenses_mac_silicon_user.txt index 66c6556898..a293f1613f 100644 --- a/deps_licenses/licenses_mac_silicon_user.txt +++ b/deps_licenses/licenses_mac_silicon_user.txt @@ -1,18 +1,20 @@ Name, Version, License +Jinja2, 3.1.4, BSD License +MarkupSafe, 2.1.5, BSD License PyYAML, 6.0.1, MIT License -brevitas, 0.8.0, UNKNOWN -certifi, 2024.6.2, Mozilla Public License 2.0 (MPL 2.0) +brevitas, 0.10.2, UNKNOWN +certifi, 2024.7.4, Mozilla Public License 2.0 (MPL 2.0) charset-normalizer, 3.3.2, MIT License coloredlogs, 15.0.1, MIT License concrete-python, 2.7.0, BSD-3-Clause dependencies, 2.0.1, BSD License dill, 0.3.8, BSD License -filelock, 3.15.3, The Unlicense (Unlicense) +filelock, 3.15.4, The Unlicense (Unlicense) flatbuffers, 24.3.25, Apache Software License -fsspec, 2024.6.0, BSD License +fsspec, 2024.6.1, BSD License huggingface-hub, 0.23.4, Apache Software License humanfriendly, 10.0, MIT License -hummingbird-ml, 0.4.8, MIT License +hummingbird-ml, 0.4.11, MIT License idna, 3.7, BSD License importlib_resources, 6.4.0, Apache Software License joblib, 1.4.2, BSD License @@ -22,7 +24,7 @@ networkx, 3.1, BSD License numpy, 1.23.5, BSD License onnx, 1.16.1, Apache License v2.0 onnxconverter-common, 1.13.0, MIT License -onnxmltools, 1.11.0, Apache Software License +onnxmltools, 1.12.0, Apache Software License onnxoptimizer, 0.3.13, Apache License v2.0 onnxruntime, 1.18.0, MIT License packaging, 24.1, Apache Software License; BSD License @@ -35,16 +37,18 @@ requests, 2.32.3, Apache Software License scikit-learn, 1.1.3, BSD License scipy, 1.10.1, BSD License six, 1.16.0, MIT License -skl2onnx, 1.12, Apache Software License +skl2onnx, 1.17.0, Apache Software License skops, 0.5.0, MIT skorch, 0.11.0, new BSD 3-Clause -sympy, 1.12.1, BSD License +sympy, 1.13.0, BSD License tabulate, 0.8.10, MIT License threadpoolctl, 3.5.0, BSD License -torch, 1.13.1, BSD License +torch, 2.3.1, BSD License tqdm, 4.66.4, MIT License; Mozilla Public License 2.0 (MPL 2.0) -typing_extensions, 4.5.0, Python Software Foundation License +typing_extensions, 4.12.2, Python Software Foundation License tzdata, 2024.1, Apache Software License +unfoldNd, 0.2.2, MIT License urllib3, 2.2.2, MIT License xgboost, 1.6.2, Apache Software License z3-solver, 4.13.0.0, MIT License +zipp, 3.19.2, MIT License diff --git a/deps_licenses/licenses_mac_silicon_user.txt.md5 b/deps_licenses/licenses_mac_silicon_user.txt.md5 index 4dfc9a8918..8a918180a0 100644 --- a/deps_licenses/licenses_mac_silicon_user.txt.md5 +++ b/deps_licenses/licenses_mac_silicon_user.txt.md5 @@ -1 +1 @@ -adb925c3b7be3e651975febcf49b6543 +6d367701c3ef5eff8763f4e994e03681 diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py index d339703ce7..5a6b37967b 100644 --- a/src/concrete/ml/onnx/convert.py +++ b/src/concrete/ml/onnx/convert.py @@ -78,7 +78,11 @@ def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto): # Create a GEMM node which combines the MatMul and Add operations gemm_node = helper.make_node( "Gemm", # op_type - [matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs + [ + matmul_node.input[0], + matmul_node.input[1], + bias_other_input_node_name, + ], # inputs [add_node.output[0]], # outputs name="Gemm_Node", alpha=1.0, @@ -142,9 +146,14 @@ def get_equivalent_numpy_forward_from_torch( arguments = list(inspect.signature(torch_module.forward).parameters) + if isinstance(dummy_input, torch.Tensor): + dummy_input = dummy_input.to("cpu") + else: + dummy_input = tuple(elt.to("cpu") for elt in dummy_input) + # Export to ONNX torch.onnx.export( - torch_module, + torch_module.to("cpu"), dummy_input, str(output_onnx_file_path), opset_version=OPSET_VERSION_FOR_ONNX_EXPORT, diff --git a/src/concrete/ml/pytest/torch_models.py b/src/concrete/ml/pytest/torch_models.py index 90e056990c..cc5f971947 100644 --- a/src/concrete/ml/pytest/torch_models.py +++ b/src/concrete/ml/pytest/torch_models.py @@ -6,7 +6,15 @@ import brevitas.nn as qnn import numpy import torch -from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat, IntBias +from brevitas.core.restrict_val import FloatRestrictValue, RestrictValueType +from brevitas.quant import ( + Int8AccumulatorAwareWeightQuant, + Int8AccumulatorAwareZeroCenterWeightQuant, + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + IntBias, + Uint8ActPerTensorFloat, +) from torch import nn from torch.nn.utils import prune @@ -38,7 +46,7 @@ def forward(self, x, y): return x + y + self.value, (x - y) ** 2 -class SimpleNet(torch.nn.Module): +class SimpleNet(nn.Module): """Fake torch model used to generate some onnx.""" def __init__(self) -> None: @@ -292,7 +300,7 @@ def forward(self, x): return x -class NetWithLoops(torch.nn.Module): +class NetWithLoops(nn.Module): """Torch model, where we reuse some elements in a loop. Torch model, where we reuse some elements in a loop in the forward and don't expect the @@ -538,7 +546,7 @@ def step(x, bias): return x -class NetWithConcatUnsqueeze(torch.nn.Module): +class NetWithConcatUnsqueeze(nn.Module): """Torch model to test the concat and unsqueeze operators.""" def __init__(self, activation_function, input_output, n_fc_layers): @@ -1004,6 +1012,7 @@ def __init__(self, use_conv, use_qat, inp_size, n_bits): layer_obj = self.mixing_layer layer_obj.weight.data = torch.from_numpy(np_weights).float() + assert layer_obj.bias is not None layer_obj.bias.data = torch.rand(size=(1,)) def forward(self, x): @@ -1216,12 +1225,12 @@ def forward(self, x): # for example a 4d tensor NCHW, padded with [1, 2, 2, 3] is padded # along the last 2 dimensions, with 1 cell to the left and 2 to the right (dimension 4: W) # and 2 cells at the top and 3 at the bottom (dimension 3: H) - x = torch.nn.functional.pad(x, (3, 2)) - x = torch.nn.functional.pad(x, (1, 2, 3, 4)) + x = nn.functional.pad(x, (3, 2)) + x = nn.functional.pad(x, (1, 2, 3, 4)) # Concrete ML only supports padding on the last two dimensions as this is the # most common setting - x = torch.nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0)) + x = nn.functional.pad(x, (1, 1, 2, 2, 0, 0, 0, 0)) return x @@ -1393,7 +1402,7 @@ def forward(self, x): return x -class PartialQATModel(torch.nn.Module): +class PartialQATModel(nn.Module): """A model with a QAT Module.""" def __init__(self, input_shape: int, output_shape: int, n_bits: int): @@ -1442,7 +1451,7 @@ def forward(self, input1): return output -class ManualLogisticRegressionTraining(torch.nn.Module): +class ManualLogisticRegressionTraining(nn.Module): """PyTorch module for performing SGD training.""" def __init__(self, learning_rate=0.1): @@ -1600,3 +1609,176 @@ def forward(self, x): # pylint: disable-next=no-self-use Tuple[torch.Tensor. torch.Tensor]: Outputs of the network. """ return x, x.unsqueeze(0) + + +# pylint: disable-next=too-many-ancestors +class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat): + """CommonIntWeightPerChannelQuant.""" + + scaling_per_output_channel = True + + +# pylint: disable-next=too-many-ancestors +class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): + """CommonIntAccumulatorAwareWeightQuant.""" + + restrict_scaling_impl = FloatRestrictValue # backwards compatibility + bit_width = None + + +# pylint: disable-next=too-many-ancestors +class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant): + """CommonIntAccumulatorAwareZeroCenterWeightQuant.""" + + bit_width = None + + +# pylint: disable-next=too-many-ancestors +class CommonUintActQuant(Uint8ActPerTensorFloat): + """CommonUintActQuant.""" + + bit_width = None + restrict_scaling_type = RestrictValueType.LOG_FP + + +def weight_init(layer: nn.Module): + """Initialize layer weights. + + Arguments: + layer (nn.Module): a conv2d layer + """ + + if isinstance(layer, nn.Conv2d): + nn.init.kaiming_normal_(layer.weight, nn.init.calculate_gain("relu")) + if layer.bias is not None: + layer.bias.data.zero_() + + +# pylint: disable-next=too-many-instance-attributes +class FloatLeNet(nn.Module): + """Floating point LeNet.""" + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0) + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0) + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu2 = nn.ReLU(inplace=True) + + self.fc1 = nn.Linear(400, 120, bias=True) + self.relu3 = nn.ReLU() + self.fc2 = nn.Linear(120, 84, bias=True) + self.relu4 = nn.ReLU() + self.fc3 = nn.Linear(84, 10, bias=True) + + self.apply(weight_init) + + def forward(self, x: torch.Tensor): + """Forward function. + + Arguments: + x (torch.Tensor): input image + + Returns: + Neural network prediction + """ + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = torch.flatten(x, 1) + x = self.relu3(self.fc1(x)) + x = self.relu4(self.fc2(x)) + x = self.fc3(x) + return x + + +# pylint: disable-next=too-many-instance-attributes +class QuantLeNet(FloatLeNet): + """Quantized LeNet with per-channel quantization.""" + + def __init__( + self, + weight_bit_width=4, + act_bit_width=4, + acc_bit_width=32, + weight_quant=CommonIntAccumulatorAwareWeightQuant, + ): + super().__init__() + + self.conv1 = qnn.QuantConv2d( + bias=False, + in_channels=1, + out_channels=6, + kernel_size=5, + stride=1, + padding=0, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu1 = qnn.QuantReLU( + inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width + ) + + self.conv2 = qnn.QuantConv2d( + bias=False, + in_channels=6, + out_channels=16, + kernel_size=5, + stride=1, + padding=0, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.relu2 = qnn.QuantReLU( + inplace=True, act_quant=CommonUintActQuant, bit_width=act_bit_width + ) + + self.fc1 = qnn.QuantLinear( + 400, + 120, + bias=True, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.relu3 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width) + self.fc2 = qnn.QuantLinear( + 120, + 84, + bias=True, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + self.relu4 = qnn.QuantReLU(act_quant=CommonUintActQuant, bit_width=act_bit_width) + self.fc3 = qnn.QuantLinear( + 84, + 10, + bias=True, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_accumulator_bit_width=acc_bit_width, + weight_bit_width=weight_bit_width, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_quant=weight_quant, + ) + + self.apply(weight_init) diff --git a/src/concrete/ml/quantization/base_quantized_op.py b/src/concrete/ml/quantization/base_quantized_op.py index 903fc4767f..ed0f668c94 100644 --- a/src/concrete/ml/quantization/base_quantized_op.py +++ b/src/concrete/ml/quantization/base_quantized_op.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, Type, Union, cast import numpy +import numpy.typing as npt from concrete import fhe @@ -122,6 +123,7 @@ def __init__( input_quant_opts: Optional[QuantizationOptions] = None, **attrs, ) -> None: + self.n_bits = n_bits_output if input_quant_opts is not None: @@ -913,7 +915,7 @@ def can_fuse(self) -> bool: def make_output_quant_parameters( self, q_values: Union[numpy.ndarray, Any], - scale: numpy.float64, + scale: npt.NDArray[numpy.float64], zero_point: Union[int, float, numpy.ndarray], ) -> QuantizedArray: """Build a quantized array from quantized integer results of the op and quantization params. @@ -1013,6 +1015,9 @@ def cnp_round( # Rounding to low bit-width with approximate can cause issues with overflow protection # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345 x = fhe.round_bit_pattern( - x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False + x, + lsbs_to_remove=lsbs_value, + exactness=exactness, + overflow_protection=False, ) return x diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index 456d839bde..b2734c73e4 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -393,14 +393,14 @@ def _calibrate_layers_activation( assert isinstance(quant_result, QuantizedArray) return ( quant_result.dequant(), - quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None, + (quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None), ) # For QAT, the calibration is performed on raw data, performing # calibration on quantized that would confound inferred QAT and PTQ. return ( raw_result, - quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None, + (quant_result.quantizer if isinstance(quant_result, QuantizedArray) else None), ) @abstractmethod @@ -625,7 +625,9 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray): for input_name in variable_input_names ) output_calibration_data, layer_quantizer = self._process_layer( - quantized_op_instance, *curr_calibration_data, quantizers=layer_quant + quantized_op_instance, + *curr_calibration_data, + quantizers=layer_quant, ) node_results[output_name] = output_calibration_data node_override_quantizer[output_name] = layer_quantizer @@ -719,7 +721,9 @@ def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule: return quantized_module def _process_input_quantizers( - self, quantized_module: QuantizedModule, calibration_data: Tuple[numpy.ndarray, ...] + self, + quantized_module: QuantizedModule, + calibration_data: Tuple[numpy.ndarray, ...], ): # pylint: disable=too-many-branches """Determine the quantizers for a quantized module. diff --git a/src/concrete/ml/quantization/quantized_ops.py b/src/concrete/ml/quantization/quantized_ops.py index cb50df6460..724dd6fe37 100644 --- a/src/concrete/ml/quantization/quantized_ops.py +++ b/src/concrete/ml/quantization/quantized_ops.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Sequence, Set, Union import numpy +import numpy.typing as npt from concrete.fhe import conv as fhe_conv from concrete.fhe import maxpool as fhe_maxpool from concrete.fhe import tag, univariate, zeros @@ -162,7 +163,7 @@ def __init__( f"Got alpha == {alpha} and beta == {beta}.", ) - # pylint: disable-next=too-many-statements,too-many-locals + # pylint: disable-next=too-many-statements,too-many-locals,too-many-branches def q_impl( self, *q_inputs: ONNXOpInputOutputType, @@ -420,7 +421,14 @@ def copy_function(x): # Note that here we do not rescale to the output_scale and we do not add a zero-point # Any following Gemm/MatMul/Conv layers will do the rescaling (during re-quantization) # by calling _prepare_inputs_with_constants(...quantize_real_values=True) - m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale + m_matmul: npt.NDArray[numpy.float64] + if q_input2.quantizer.scale.shape == tuple(): + m_matmul = q_input1.quantizer.scale * q_input2.quantizer.scale + else: + # TODO: add assert on shapes + weight_quant_scale = numpy.transpose(q_input2.quantizer.scale, axes=(1, 0)) + assert isinstance(weight_quant_scale, numpy.ndarray) + m_matmul = q_input1.quantizer.scale * weight_quant_scale # If this operation's result are network outputs, return # directly the integer values and a appropriate quantization parameters that @@ -566,7 +574,7 @@ def q_impl( # If this operator is the last one in the graph, # we rescale using the smallest scale to keep all information if self.produces_graph_output: - common_scale = min(q_input_0.quantizer.scale, q_input_1.quantizer.scale) + common_scale = numpy.minimum(q_input_0.quantizer.scale, q_input_1.quantizer.scale) # Otherwise we use the output op quantization scale else: common_scale = self.output_quant_params.scale @@ -953,7 +961,17 @@ def q_impl( # This is going to be compiled with a PBS (along with the following activation function) # Note that we don't re-quantize the output of the conv, this will be done by # any Gemm/Add/Conv layers that follow - m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale + m_matmul: npt.NDArray[numpy.float64] + if q_weights.quantizer.scale.shape == tuple(): + m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale + else: + # TODO: add assert on shapes + weight_quant_scale = numpy.transpose( + q_weights.quantizer.scale, + axes=(1, 0, 2, 3), + ) + assert isinstance(weight_quant_scale, numpy.ndarray) + m_matmul = q_input.quantizer.scale * weight_quant_scale bias_shape = (1, -1, 1) if is_conv1d else (1, -1, 1, 1) @@ -1046,7 +1064,10 @@ def __init__( self.auto_pad = attrs.get("auto_pad", "NOTSET") self.kernel_shape = attrs.get("kernel_shape", None) - assert_true(self.kernel_shape is not None, "Setting parameter 'kernel_shape' is required.") + assert_true( + self.kernel_shape is not None, + "Setting parameter 'kernel_shape' is required.", + ) self.count_include_pad = attrs.get("count_include_pad", 1) self.pads = attrs.get("pads", tuple([0] * 2 * (len(self.kernel_shape) - 2))) @@ -1365,7 +1386,10 @@ def q_impl( assert_true(pads.size == 4, "Not currently supporting padding of 3D tensors") pad_value = 0 if prepared_inputs[2] is None else prepared_inputs[2] - assert_true(pad_value == 0, "Concrete ML only supports padding with constant zero values") + assert_true( + pad_value == 0, + "Concrete ML only supports padding with constant zero values", + ) assert q_input.quantizer.zero_point is not None q_input_pad = numpy_onnx_pad(q_input.qvalues, pads, q_input.quantizer.zero_point, True) @@ -2037,7 +2061,7 @@ def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray: n_bits = int(self.constant_inputs[3]) self.output_quant_params = UniformQuantizationParameters( - scale=numpy.float64(self.constant_inputs[1]), + scale=numpy.array(self.constant_inputs[1], dtype=float), zero_point=int(self.constant_inputs[2]), offset=2 ** (n_bits - 1) if self.is_signed else 0, ) @@ -2661,7 +2685,11 @@ def q_impl( # Compute padding with floor and apply it to the input, pad with the input zero-point pool_pads = compute_onnx_pool_padding( - q_input.qvalues.shape, self.kernel_shape, self.pads, self.strides, ceil_mode=0 + q_input.qvalues.shape, + self.kernel_shape, + self.pads, + self.strides, + ceil_mode=0, ) # Can only pad with scalar zero-points, but zero-points can be float in special cases @@ -2677,7 +2705,14 @@ def q_impl( with tag(self.op_instance_name + ".unfold"): sum_result = fhe_conv( - q_input_pad, kernels, None, fake_pads, self.strides, None, None, n_in_channels + q_input_pad, + kernels, + None, + fake_pads, + self.strides, + None, + None, + n_in_channels, ) if self.debug_value_tracker is not None: diff --git a/src/concrete/ml/quantization/quantizers.py b/src/concrete/ml/quantization/quantizers.py index 3ce3eb825e..fc30d1ee83 100644 --- a/src/concrete/ml/quantization/quantizers.py +++ b/src/concrete/ml/quantization/quantizers.py @@ -8,6 +8,7 @@ import numpy from concrete.fhe.tracing.tracer import Tracer +from numpy import typing as npt from ..common.debugging import assert_true from ..common.serialization.dumpers import dump, dumps @@ -103,7 +104,11 @@ class QuantizationOptions: is_precomputed_qat: bool = False def __init__( - self, n_bits: int, is_signed: bool = False, is_symmetric: bool = False, is_qat: bool = False + self, + n_bits: int, + is_signed: bool = False, + is_symmetric: bool = False, + is_qat: bool = False, ): self.n_bits = n_bits self.is_signed = is_signed @@ -381,13 +386,13 @@ class UniformQuantizationParameters: The parameters are computed from quantization options and quantization statistics. """ - scale: Optional[numpy.float64] = None + scale: Optional[npt.NDArray[numpy.float64]] = None zero_point: Optional[Union[int, float, numpy.ndarray]] = None offset: Optional[int] = None def __init__( self, - scale: Optional[numpy.float64] = None, + scale: Optional[npt.NDArray[numpy.float64]] = None, zero_point: Optional[Union[int, float, numpy.ndarray]] = None, offset: Optional[int] = None, ): @@ -508,7 +513,7 @@ def compute_quantization_parameters( if numpy.abs(stats.rmax) < STABILITY_CONST: # If the value is a 0 we cannot do it since the scale would become 0 as well # resulting in division by 0 - self.scale = numpy.float64(1.0) + self.scale = numpy.array(1.0, dtype=float) # Ideally we should get rid of round here but it is risky # regarding the FHE compilation. # Indeed, the zero_point value for the weights has to be an integer @@ -517,7 +522,7 @@ def compute_quantization_parameters( else: # If the value is not a 0 we can tweak the scale factor so that # the value quantizes to 1 - self.scale = numpy.float64(stats.rmax) + self.scale = numpy.array(stats.rmax, dtype=float) self.zero_point = 0 else: if options.is_symmetric: @@ -552,13 +557,16 @@ def compute_quantization_parameters( "This can occur with a badly trained model.", ) unique_scales = numpy.unique(numpy.diff(stats.uvalues)) - self.scale = numpy.float64(unique_scales[0]) + self.scale = numpy.array(unique_scales[0], dtype=float) if self.scale is None: - self.scale = numpy.float64( - (stats.rmax - stats.rmin) / (2**options.n_bits - 1) - if stats.rmax != stats.rmin - else 1.0 + self.scale = numpy.array( + ( + (stats.rmax - stats.rmin) / (2**options.n_bits - 1) + if stats.rmax != stats.rmin + else 1.0 + ), + dtype=float, ) if options.is_qat: @@ -614,7 +622,7 @@ def __init__( # Force scale to be a float64 if self.scale is not None: - self.scale = numpy.float64(self.scale) + self.scale = numpy.array(self.scale, dtype=float) def __eq__(self, other) -> bool: @@ -789,7 +797,7 @@ def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer] assert_true( isinstance(self.scale, (numpy.floating, float)) - or (isinstance(self.scale, numpy.ndarray) and self.scale.dtype is numpy.float64), + or (isinstance(self.scale, numpy.ndarray) and self.scale.dtype == numpy.float64), "Scale is a of type " + type(self.scale).__name__ + ((" " + str(self.scale.dtype)) if isinstance(self.scale, numpy.ndarray) else ""), @@ -904,7 +912,7 @@ def _values_setup( elif isinstance(values, Tracer): self.values = values else: - self.values = numpy.array(values) + self.values = numpy.array(values, dtype=float) # If no stats are provided, compute them. # Note that this cannot be done during tracing @@ -940,7 +948,7 @@ def _values_setup( elif isinstance(values, Tracer): self.qvalues = values else: - self.qvalues = numpy.array(values) # pragma: no cover + self.qvalues = numpy.array(values, dtype=float) # pragma: no cover # Populate self.values self.dequant() @@ -1010,7 +1018,7 @@ def update_values(self, values: Union[numpy.ndarray, Tracer]) -> Union[numpy.nda elif isinstance(values, Tracer): # pragma: no cover self.values = values else: # pragma: no cover - self.values = numpy.array(values) + self.values = numpy.array(values, dtype=float) return self.quant() def update_quantized_values( @@ -1029,7 +1037,7 @@ def update_quantized_values( elif isinstance(qvalues, Tracer): # pragma: no cover self.qvalues = qvalues else: # pragma: no cover - self.qvalues = numpy.array(qvalues) + self.qvalues = numpy.array(qvalues, dtype=float) return self.dequant() def quant(self) -> Union[numpy.ndarray, Tracer]: diff --git a/tests/torch/test_brevitas_qat.py b/tests/torch/test_brevitas_qat.py index 32331fc42d..196979408b 100644 --- a/tests/torch/test_brevitas_qat.py +++ b/tests/torch/test_brevitas_qat.py @@ -23,6 +23,7 @@ from concrete.ml.pytest.torch_models import ( NetWithConstantsFoldedBeforeOps, QuantCustomModel, + QuantLeNet, TinyQATCNN, ) from concrete.ml.quantization.base_quantized_op import QuantizedMixingOp @@ -97,7 +98,11 @@ def train_brevitas_network_tinymnist(is_cnn, qat_bits, signed, narrow, pot_scali x_all = numpy.expand_dims(x_all.reshape((-1, 8, 8)), 1) x_train, x_test, y_train, y_test = train_test_split( - x_all, y_all, test_size=0.25, shuffle=True, random_state=numpy.random.randint(0, 2**15) + x_all, + y_all, + test_size=0.25, + shuffle=True, + random_state=numpy.random.randint(0, 2**15), ) def train_one_epoch(net, optimizer, train_loader): @@ -606,3 +611,20 @@ def test_brevitas_power_of_two( check_array_equal(y_pred_sim_round, y_pred_clear_round) check_array_equal(y_pred_clear_round, y_pred_clear_no_round) + + +def test_brevitas_channel_wise(): + """Make sure that we can compile brevitas channel-wise quantization""" + model = QuantLeNet() + model.eval() + + with torch.no_grad(): + batch_size = 3 + image_size = 1, 32, 32 + images = torch.rand((batch_size, *image_size)) + out = model(images).detach().numpy() + quantized_module = compile_brevitas_qat_model(model, images, rounding_threshold_bits=6) + out_qm = quantized_module(images.detach().numpy()) + mse = ((out - out_qm) ** 2).mean() + # Arbitrary threshold to check that the predictions are relatively similar + assert mse < 1e-4