Skip to content

Commit

Permalink
fix: compilable module for api v2
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Jul 23, 2024
1 parent 20ae12e commit 8ba942b
Showing 1 changed file with 75 additions and 63 deletions.
138 changes: 75 additions & 63 deletions src/concrete/ml/pandas/_development.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, List, Tuple, Union

from concrete.fhe import Configuration
from concrete.fhe.compilation.module import FheModule
from concrete.fhe.tracing import Tracer
import numpy

Expand Down Expand Up @@ -37,7 +38,7 @@
from ..torch.compile import build_quantized_module
from ..common.utils import generate_proxy_function

class DFApiV2Helper:
class DFApiV2StaticHelper:
_N_DIMS_TRAINING = 1

_training_input_set = make_training_inputset(
Expand All @@ -47,47 +48,52 @@ class DFApiV2Helper:
2**N_BITS_PANDAS-1,
8, True
)
# Build the quantized module
_training_module = build_quantized_module(
model=LogisticRegressionTraining(
learning_rate=1,
iterations=1,
fit_bias=False,
),
torch_inputset=_training_input_set,
import_qat=False,
n_bits=N_BITS_PANDAS,
rounding_threshold_bits=6,
)

_forward_proxy, _orig_args_to_proxy_func_args = generate_proxy_function(
_training_module._clear_forward, _training_module.ordered_module_input_names
)
def create_api_v2():
class DFApiV2Helper:
# Build the quantized module
_training_module = build_quantized_module(
model=LogisticRegressionTraining(
learning_rate=1,
iterations=1,
fit_bias=False,
),
torch_inputset=DFApiV2StaticHelper._training_input_set,
import_qat=False,
n_bits=N_BITS_PANDAS,
rounding_threshold_bits={"n_bits": 6, "method": fhe.Exactness.EXACT },
)

_forward_proxy, _orig_args_to_proxy_func_args = generate_proxy_function(
_training_module._clear_forward, _training_module.ordered_module_input_names
)

@fhe.module()
class DFApiV2:
@fhe.function(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
)
def train_log_reg(
val_1: Union[Tracer, int],
val_2: Union[Tracer, int],
left_key: Union[Tracer, int],
right_key: Union[Tracer, int],
):
return DFApiV2Helper._forward_proxy(val_1, val_2, left_key, right_key)

@fhe.function(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
)
def left_right_join_to_compile(
val_1: Union[Tracer, int],
val_2: Union[Tracer, int],
left_key: Union[Tracer, int],
right_key: Union[Tracer, int],
) -> Union[Tracer, int]:
return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key)
@fhe.module()
class DFApiV2:
@fhe.function(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
)
def train_log_reg(
val_1: Union[Tracer, int],
val_2: Union[Tracer, int],
left_key: Union[Tracer, int],
right_key: Union[Tracer, int],
):
return DFApiV2Helper._forward_proxy(val_1, val_2, left_key, right_key)

@fhe.function(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
)
def left_right_join_to_compile(
val_1: Union[Tracer, int],
val_2: Union[Tracer, int],
left_key: Union[Tracer, int],
right_key: Union[Tracer, int],
) -> Union[Tracer, int]:
return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key)

return DFApiV2

def identity_pbs(value: Union[Tracer, int]) -> Union[Tracer, int]:
"""Define an identity TLU.
Expand All @@ -100,27 +106,29 @@ def identity_pbs(value: Union[Tracer, int]) -> Union[Tracer, int]:
"""
return fhe.univariate(lambda x: x)(value)

def create_api_v1():
@fhe.compiler(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
)
def left_right_join_to_compile(
val_1: Union[Tracer, int],
val_2: Union[Tracer, int],
left_key: Union[Tracer, int],
right_key: Union[Tracer, int],
) -> Union[Tracer, int]:
"""Runs the atomic left/right join in FHE.
Args:
val_1 (Union[Tracer, int]): The value used for accumulating the sum.
val_2 (Union[Tracer, int]): The value to add if the keys match.
left_key (Union[Tracer, int]): The left data-frame's encrypted key to consider.
right_key (Union[Tracer, int]): The right data-frame's encrypted key to consider.
Returns:
Union[Tracer, int]): The new accumulated sum.
"""
return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key)

@fhe.compiler(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
)
def left_right_join_to_compile(
val_1: Union[Tracer, int],
val_2: Union[Tracer, int],
left_key: Union[Tracer, int],
right_key: Union[Tracer, int],
) -> Union[Tracer, int]:
"""Runs the atomic left/right join in FHE.
Args:
val_1 (Union[Tracer, int]): The value used for accumulating the sum.
val_2 (Union[Tracer, int]): The value to add if the keys match.
left_key (Union[Tracer, int]): The left data-frame's encrypted key to consider.
right_key (Union[Tracer, int]): The right data-frame's encrypted key to consider.
Returns:
Union[Tracer, int]): The new accumulated sum.
"""
return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key)
return left_right_join_to_compile

def _left_right_join_to_compile_internal(
val_1: Union[Tracer, int],
Expand Down Expand Up @@ -195,21 +203,21 @@ def get_left_right_join_inputset(n_bits: int) -> List:
return inputset

def get_training_inputset():
return _get_inputset_generator(tuple(map(lambda x: x.astype(numpy.int64), DFApiV2Helper._training_input_set)))
return _get_inputset_generator(tuple(map(lambda x: x.astype(numpy.int64), DFApiV2StaticHelper._training_input_set)))

# Store the configuration functions and parameters to their associated operator
PANDAS_OPS_TO_CIRCUIT_CONFIG = {
1: {
"get_inputset": partial(get_left_right_join_inputset, n_bits=N_BITS_PANDAS),
"to_compile": left_right_join_to_compile,
"to_compile": create_api_v1,
"encrypt_config": {
"n": 4,
"pos": 1,
},
},
2: {
"get_inputset": { "left_right_join_to_compile": partial(get_left_right_join_inputset, n_bits=N_BITS_PANDAS), "train_log_reg": get_training_inputset},
"to_compile": DFApiV2,
"to_compile": create_api_v2,
"encrypt_config": {
"n": 4,
"pos": 1,
Expand Down Expand Up @@ -275,8 +283,12 @@ def save_client_server(client_path: Path = CLIENT_PATH, server_path: Path = SERV
)

# Save the client and server files using the MLIR
merge_circuit.client.save(client_path)
merge_circuit.server.save(server_path, via_mlir=True)
if isinstance(merge_circuit, FheModule):
merge_circuit.runtime.server.save(server_path, via_mlir=True)
merge_circuit.runtime.client.save(client_path)
else:
merge_circuit.server.save(server_path, via_mlir=True)
merge_circuit.client.save(client_path)


def load_server() -> fhe.Server:
Expand Down

0 comments on commit 8ba942b

Please sign in to comment.