From 8ba942bc4775f138c5c3643d7c4252ecb550b192 Mon Sep 17 00:00:00 2001 From: Andrei Stoian Date: Fri, 3 May 2024 18:37:04 +0200 Subject: [PATCH] fix: compilable module for api v2 --- src/concrete/ml/pandas/_development.py | 138 ++++++++++++++----------- 1 file changed, 75 insertions(+), 63 deletions(-) diff --git a/src/concrete/ml/pandas/_development.py b/src/concrete/ml/pandas/_development.py index 54549ad74..11f09c1a4 100644 --- a/src/concrete/ml/pandas/_development.py +++ b/src/concrete/ml/pandas/_development.py @@ -6,6 +6,7 @@ from typing import Dict, List, Tuple, Union from concrete.fhe import Configuration +from concrete.fhe.compilation.module import FheModule from concrete.fhe.tracing import Tracer import numpy @@ -37,7 +38,7 @@ from ..torch.compile import build_quantized_module from ..common.utils import generate_proxy_function -class DFApiV2Helper: +class DFApiV2StaticHelper: _N_DIMS_TRAINING = 1 _training_input_set = make_training_inputset( @@ -47,47 +48,52 @@ class DFApiV2Helper: 2**N_BITS_PANDAS-1, 8, True ) - # Build the quantized module - _training_module = build_quantized_module( - model=LogisticRegressionTraining( - learning_rate=1, - iterations=1, - fit_bias=False, - ), - torch_inputset=_training_input_set, - import_qat=False, - n_bits=N_BITS_PANDAS, - rounding_threshold_bits=6, - ) - _forward_proxy, _orig_args_to_proxy_func_args = generate_proxy_function( - _training_module._clear_forward, _training_module.ordered_module_input_names - ) +def create_api_v2(): + class DFApiV2Helper: + # Build the quantized module + _training_module = build_quantized_module( + model=LogisticRegressionTraining( + learning_rate=1, + iterations=1, + fit_bias=False, + ), + torch_inputset=DFApiV2StaticHelper._training_input_set, + import_qat=False, + n_bits=N_BITS_PANDAS, + rounding_threshold_bits={"n_bits": 6, "method": fhe.Exactness.EXACT }, + ) + _forward_proxy, _orig_args_to_proxy_func_args = generate_proxy_function( + _training_module._clear_forward, _training_module.ordered_module_input_names + ) -@fhe.module() -class DFApiV2: - @fhe.function( - {"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"} - ) - def train_log_reg( - val_1: Union[Tracer, int], - val_2: Union[Tracer, int], - left_key: Union[Tracer, int], - right_key: Union[Tracer, int], - ): - return DFApiV2Helper._forward_proxy(val_1, val_2, left_key, right_key) - @fhe.function( - {"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"} - ) - def left_right_join_to_compile( - val_1: Union[Tracer, int], - val_2: Union[Tracer, int], - left_key: Union[Tracer, int], - right_key: Union[Tracer, int], - ) -> Union[Tracer, int]: - return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key) + @fhe.module() + class DFApiV2: + @fhe.function( + {"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"} + ) + def train_log_reg( + val_1: Union[Tracer, int], + val_2: Union[Tracer, int], + left_key: Union[Tracer, int], + right_key: Union[Tracer, int], + ): + return DFApiV2Helper._forward_proxy(val_1, val_2, left_key, right_key) + + @fhe.function( + {"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"} + ) + def left_right_join_to_compile( + val_1: Union[Tracer, int], + val_2: Union[Tracer, int], + left_key: Union[Tracer, int], + right_key: Union[Tracer, int], + ) -> Union[Tracer, int]: + return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key) + + return DFApiV2 def identity_pbs(value: Union[Tracer, int]) -> Union[Tracer, int]: """Define an identity TLU. @@ -100,27 +106,29 @@ def identity_pbs(value: Union[Tracer, int]) -> Union[Tracer, int]: """ return fhe.univariate(lambda x: x)(value) +def create_api_v1(): + @fhe.compiler( + {"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"} + ) + def left_right_join_to_compile( + val_1: Union[Tracer, int], + val_2: Union[Tracer, int], + left_key: Union[Tracer, int], + right_key: Union[Tracer, int], + ) -> Union[Tracer, int]: + """Runs the atomic left/right join in FHE. + Args: + val_1 (Union[Tracer, int]): The value used for accumulating the sum. + val_2 (Union[Tracer, int]): The value to add if the keys match. + left_key (Union[Tracer, int]): The left data-frame's encrypted key to consider. + right_key (Union[Tracer, int]): The right data-frame's encrypted key to consider. + + Returns: + Union[Tracer, int]): The new accumulated sum. + """ + return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key) -@fhe.compiler( - {"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"} -) -def left_right_join_to_compile( - val_1: Union[Tracer, int], - val_2: Union[Tracer, int], - left_key: Union[Tracer, int], - right_key: Union[Tracer, int], -) -> Union[Tracer, int]: - """Runs the atomic left/right join in FHE. - Args: - val_1 (Union[Tracer, int]): The value used for accumulating the sum. - val_2 (Union[Tracer, int]): The value to add if the keys match. - left_key (Union[Tracer, int]): The left data-frame's encrypted key to consider. - right_key (Union[Tracer, int]): The right data-frame's encrypted key to consider. - - Returns: - Union[Tracer, int]): The new accumulated sum. - """ - return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key) + return left_right_join_to_compile def _left_right_join_to_compile_internal( val_1: Union[Tracer, int], @@ -195,13 +203,13 @@ def get_left_right_join_inputset(n_bits: int) -> List: return inputset def get_training_inputset(): - return _get_inputset_generator(tuple(map(lambda x: x.astype(numpy.int64), DFApiV2Helper._training_input_set))) + return _get_inputset_generator(tuple(map(lambda x: x.astype(numpy.int64), DFApiV2StaticHelper._training_input_set))) # Store the configuration functions and parameters to their associated operator PANDAS_OPS_TO_CIRCUIT_CONFIG = { 1: { "get_inputset": partial(get_left_right_join_inputset, n_bits=N_BITS_PANDAS), - "to_compile": left_right_join_to_compile, + "to_compile": create_api_v1, "encrypt_config": { "n": 4, "pos": 1, @@ -209,7 +217,7 @@ def get_training_inputset(): }, 2: { "get_inputset": { "left_right_join_to_compile": partial(get_left_right_join_inputset, n_bits=N_BITS_PANDAS), "train_log_reg": get_training_inputset}, - "to_compile": DFApiV2, + "to_compile": create_api_v2, "encrypt_config": { "n": 4, "pos": 1, @@ -275,8 +283,12 @@ def save_client_server(client_path: Path = CLIENT_PATH, server_path: Path = SERV ) # Save the client and server files using the MLIR - merge_circuit.client.save(client_path) - merge_circuit.server.save(server_path, via_mlir=True) + if isinstance(merge_circuit, FheModule): + merge_circuit.runtime.server.save(server_path, via_mlir=True) + merge_circuit.runtime.client.save(client_path) + else: + merge_circuit.server.save(server_path, via_mlir=True) + merge_circuit.client.save(client_path) def load_server() -> fhe.Server: