feat: add batch creation for log reg DF training
andrei-stoian-zama committed Aug 1, 2024
1 parent 37e8941 commit 26530e5
Showing 7 changed files with 438 additions and 79 deletions.
@@ -1,7 +1,8 @@
"""Utility functions for FHE training."""

-from typing import Tuple
+import itertools
+from typing import Tuple

import numpy
import torch
from torch.nn.functional import binary_cross_entropy_with_logits
@@ -19,6 +20,7 @@ def binary_cross_entropy(y_true: numpy.ndarray, logits: numpy.ndarray):
"""
return binary_cross_entropy_with_logits(torch.Tensor(logits), torch.Tensor(y_true)).item()


def make_training_inputset(x_min, x_max, param_min, param_max, batch_size, fit_intercept):
"""Get the quantized module for FHE training.
@@ -31,7 +33,7 @@ def make_training_inputset(x_min, x_max, param_min, param_max, batch_size, fit_intercept):
Returns:
(List): The input set used to compile the FHE training circuit.
"""

combinations = list(
itertools.product(
[1.0, 0.0], # Labels
214 changes: 178 additions & 36 deletions src/concrete/ml/pandas/_development.py
@@ -5,25 +5,31 @@
from pathlib import Path
from typing import Dict, List, Tuple, Union

+import numpy
from concrete.fhe import Configuration
from concrete.fhe.compilation.module import FheModule
from concrete.fhe.tracing import Tracer
-import numpy

from concrete import fhe

from ..quantization.quantized_module import _get_inputset_generator

script_dir = Path(__file__).parent


CURRENT_API_VERSION = 2

API_VERSION_SPECS = {
1: {"configuration": Configuration()},
1: {"configuration": Configuration(), "join_function": "main"},
2: {
"configuration": Configuration(
-compress_input_ciphertexts=True, compress_evaluation_keys=True
-)
+compress_evaluation_keys=True  # compress_input_ciphertexts=True,
+),
"join_function": "left_right_join_to_compile",
"batch_1d_function": "build_batch_1d",
"batch_2d_function": "build_batch_2d",
"train_log_reg_function": "train_log_reg",
"create_batch_2d": "create_batch_2d",
"create_batch_1d": "create_batch_1d"
},
}
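# Each API version records the compile-time Configuration together with the
# names of the functions exposed by the compiled FHE module; operator code
# looks these names up at run time (see the "join_function" lookup in
# _operators.py below).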

@@ -32,24 +38,27 @@
CLIENT_PATH = CLIENT_SERVER_DIR / "client.zip"
SERVER_PATH = CLIENT_SERVER_DIR / "server.zip"

-N_BITS_PANDAS = 2
+N_BITS_PANDAS = 4

-from ..sklearn._fhe_training_utils import LogisticRegressionTraining, make_training_inputset
-from ..torch.compile import build_quantized_module
from ..common.utils import generate_proxy_function
+from ..common._fhe_training_utils import LogisticRegressionTraining, make_training_inputset
+from ..torch.compile import build_quantized_module
+from concrete.fhe import Wired, Wire, Output, Input, AllInputs, AllOutputs

class DFApiV2StaticHelper:
-_N_DIMS_TRAINING = 1
-_BATCH_SIZE = 1
+N_DIMS_TRAINING = 16
+BATCH_SIZE = 8
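# Training circuit dimensions: each step consumes a batch of BATCH_SIZE rows
# with N_DIMS_TRAINING features per row; these sizes shape the input set
# built by make_training_inputset below.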

_training_input_set = make_training_inputset(
-numpy.zeros((_N_DIMS_TRAINING, ), dtype=numpy.int64),
-numpy.ones((_N_DIMS_TRAINING, ) , dtype=numpy.int64) * 2**N_BITS_PANDAS - 1,
-0,
-2**N_BITS_PANDAS-1,
-_BATCH_SIZE, True
+numpy.ones((N_DIMS_TRAINING,), dtype=numpy.float64) * -1.0,
+numpy.ones((N_DIMS_TRAINING,), dtype=numpy.float64) * 1.0,
+0,
+2**N_BITS_PANDAS - 1,
+BATCH_SIZE,
+True,
)
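# The input set spans features in [-1, 1] and quantized parameters in
# [0, 2**N_BITS_PANDAS - 1], so compilation calibrates over the full
# training domain.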


def create_api_v2():
class DFApiV2Helper:
# Build the quantized module
@@ -62,29 +71,38 @@ class DFApiV2Helper:
torch_inputset=DFApiV2StaticHelper._training_input_set,
import_qat=False,
n_bits=N_BITS_PANDAS,
-rounding_threshold_bits={"n_bits": 6, "method": fhe.Exactness.EXACT },
+rounding_threshold_bits={"n_bits": 6, "method": fhe.Exactness.EXACT},
)

_forward_proxy, _orig_args_to_proxy_func_args = generate_proxy_function(
_training_module._clear_forward, _training_module.ordered_module_input_names
)


@fhe.module()
class DFApiV2:
@fhe.function(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
{
"features": "encrypted",
"targets": "encrypted",
"weights": "encrypted",
"bias": "encrypted",
}
)
def train_log_reg(
-val_1: Union[Tracer, int],
-val_2: Union[Tracer, int],
-left_key: Union[Tracer, int],
-right_key: Union[Tracer, int],
+features: Union[Tracer, int],
+targets: Union[Tracer, int],
+weights: Union[Tracer, int],
+bias: Union[Tracer, int],
):
-return DFApiV2Helper._forward_proxy(val_1, val_2, left_key, right_key)
+return DFApiV2Helper._forward_proxy(features, targets, weights, bias)
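# One training iteration: outputs 0 and 1 (the updated weights and bias) are
# wired back into inputs 2 and 3 below, so repeated calls iterate gradient
# descent without leaving FHE.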

@fhe.function(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
{
"val_1": "encrypted",
"val_2": "encrypted",
"left_key": "encrypted",
"right_key": "encrypted",
}
)
def left_right_join_to_compile(
val_1: Union[Tracer, int],
@@ -94,8 +112,79 @@ def left_right_join_to_compile(
) -> Union[Tracer, int]:
return _left_right_join_to_compile_internal(val_1, val_2, left_key, right_key)

@fhe.function({"value": "encrypted"})
def create_batch_2d(value):
batch = fhe.zeros((DFApiV2StaticHelper.BATCH_SIZE, DFApiV2StaticHelper.N_DIMS_TRAINING))
batch[0, 0] = value
return batch

@fhe.function({"value": "encrypted"})
def create_batch_1d(value):
batch = fhe.zeros((DFApiV2StaticHelper.BATCH_SIZE,))
batch[0] = value
return batch
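# The create/build pair assembles batches incrementally: create_batch_* returns
# an encrypted zero batch holding a first value, and build_batch_* (below)
# fills the remaining positions one encrypted value at a time.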

@fhe.function(
{
"batch": "encrypted",
"value": "encrypted",
"index1": "clear",
"index2": "clear",
}
)
def build_batch_2d(
batch: Union[Tracer, int],
value: Union[Tracer, int],
index1: Union[Tracer, int],
index2: Union[Tracer, int],
):
batch[index1, index2] = value
return batch

@fhe.function(
{
"batch": "encrypted",
"value": "encrypted",
"index1": "clear",
}
)
def build_batch_1d(
batch: Union[Tracer, int],
value: Union[Tracer, int],
index1: Union[Tracer, int],
):
batch[index1] = value
return batch
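# Batch indices are clear inputs: the server knows where each value is placed
# in the batch, while the value itself stays encrypted.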

composition = Wired(
[
# Compose every input -> output of the join function
Wire(AllOutputs(left_right_join_to_compile), AllInputs(left_right_join_to_compile)),

# The output of the join function is used to build the training batch or the labels batch
Wire(Output(left_right_join_to_compile, 0), Input(build_batch_2d, 1)),
Wire(Output(left_right_join_to_compile, 0), Input(build_batch_1d, 1)),

# Batch creation
Wire(Output(create_batch_2d, 0), Input(build_batch_2d, 0)),
Wire(Output(create_batch_1d, 0), Input(build_batch_1d, 0)),

# Batch building is composable
Wire(Output(build_batch_2d, 0), Input(build_batch_2d, 0)),
Wire(Output(build_batch_1d, 0), Input(build_batch_1d, 0)),

# Batches of training data and labels are inputs to log reg training
Wire(Output(build_batch_2d, 0), Input(train_log_reg, 0)),
Wire(Output(build_batch_1d, 0), Input(train_log_reg, 1)),

Wire(Output(train_log_reg, 0), Input(train_log_reg, 2)),
Wire(Output(train_log_reg, 1), Input(train_log_reg, 3)),
]
)
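# A sketch of the call sequence these wires allow (hypothetical driver code,
# not part of this module):
#
#   features = create_batch_2d(first_feature_value)      # zero batch, value at [0, 0]
#   for i, j, v in remaining_feature_cells:
#       features = build_batch_2d(features, v, i, j)     # fill one cell at a time
#   labels = create_batch_1d(first_label)
#   for i, v in remaining_labels:
#       labels = build_batch_1d(labels, v, i)
#   weights, bias = initial_weights, initial_bias
#   for _ in range(n_iterations):
#       weights, bias = train_log_reg(features, labels, weights, bias)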

return DFApiV2


def identity_pbs(value: Union[Tracer, int]) -> Union[Tracer, int]:
"""Define an identity TLU.
@@ -107,9 +196,15 @@
"""
return fhe.univariate(lambda x: x)(value)


def create_api_v1():
@fhe.compiler(
{"val_1": "encrypted", "val_2": "encrypted", "left_key": "encrypted", "right_key": "encrypted"}
{
"val_1": "encrypted",
"val_2": "encrypted",
"left_key": "encrypted",
"right_key": "encrypted",
}
)
def left_right_join_to_compile(
val_1: Union[Tracer, int],
@@ -131,6 +226,7 @@ def left_right_join_to_compile(

return left_right_join_to_compile


def _left_right_join_to_compile_internal(
val_1: Union[Tracer, int],
val_2: Union[Tracer, int],
@@ -200,13 +296,52 @@ def get_left_right_join_inputset(n_bits: int) -> List:
# the input-set needs to consider 0 although pre-processing requires data-frame to provide
# integers values greater or equal to 1
inputset = list(itertools.product([0, high], [0, high], [0, high], [0, high]))

-inputset = [numpy.asarray(v).reshape(1, 1, -1) for v in inputset]
-inputset = [numpy.repeat(v, DFApiV2StaticHelper._BATCH_SIZE, axis=1) for v in inputset]

return inputset


def get_training_inputset():
-return list(_get_inputset_generator(tuple(map(lambda x: x.astype(numpy.int64), DFApiV2StaticHelper._training_input_set))))
+return list(
+_get_inputset_generator(
+tuple(map(lambda x: x.astype(numpy.int64), DFApiV2StaticHelper._training_input_set))
+)
+)

def get_batch_build_dataset_2d():
    batch_min = numpy.zeros(
        (DFApiV2StaticHelper.BATCH_SIZE, DFApiV2StaticHelper.N_DIMS_TRAINING), dtype=numpy.uint64
    )
    batch_max = numpy.ones(
        (DFApiV2StaticHelper.BATCH_SIZE, DFApiV2StaticHelper.N_DIMS_TRAINING), dtype=numpy.uint64
    ) * (2**N_BITS_PANDAS - 1)
    value_min = 0
    value_max = 2**N_BITS_PANDAS - 1
    # index1 indexes rows (batch dimension); index2 indexes feature columns, so
    # its calibration range must cover N_DIMS_TRAINING, not BATCH_SIZE
    index1_min, index1_max = 0, DFApiV2StaticHelper.BATCH_SIZE - 1
    index2_min, index2_max = 0, DFApiV2StaticHelper.N_DIMS_TRAINING - 1
    return [
        (batch_min, value_max, index1_min, index2_max),
        (batch_max, value_min, index1_max, index2_min),
        (batch_min, value_min, index1_min, index2_max),
        (batch_max, value_max, index1_max, index2_min),
    ]

def get_batch_build_dataset_1d():
    batch_min = numpy.zeros((DFApiV2StaticHelper.BATCH_SIZE,), dtype=numpy.uint64)
    batch_max = numpy.ones((DFApiV2StaticHelper.BATCH_SIZE,), dtype=numpy.uint64) * (
        2**N_BITS_PANDAS - 1
    )
    value_min = 0
    value_max = 2**N_BITS_PANDAS - 1
    index_min = 0
    index_max = DFApiV2StaticHelper.BATCH_SIZE - 1
    return [
        (batch_min, value_max, index_min),
        (batch_max, value_min, index_max),
        (batch_min, value_min, index_min),
        (batch_max, value_max, index_max),
    ]

def get_batch_create_dataset():
value_min = 0
value_max = 2**N_BITS_PANDAS - 1
return [
(value_max,),
(value_min,),
]
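# Input sets only need extreme corner values: compilation calibrates each
# input's bit-width from the minimum and maximum it observes, not from
# realistic data distributions.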

# Map the configuration functions and parameters to their associated operators
PANDAS_OPS_TO_CIRCUIT_CONFIG = {
@@ -219,15 +354,25 @@ def get_training_inputset():
},
},
2: {
"get_inputset": { "left_right_join_to_compile": partial(get_left_right_join_inputset, n_bits=N_BITS_PANDAS), "train_log_reg": get_training_inputset},
"get_inputset": {
"left_right_join_to_compile": partial(
get_left_right_join_inputset, n_bits=N_BITS_PANDAS
),
"train_log_reg": get_training_inputset,
"build_batch_2d": get_batch_build_dataset_2d,
"build_batch_1d": get_batch_build_dataset_1d,
"create_batch_1d": get_batch_create_dataset,
"create_batch_2d": get_batch_create_dataset,
},
"to_compile": create_api_v2,
"encrypt_config": {
"n": 4,
"pos": 1,
},
-}
+},
}


def get_encrypt_config() -> Dict:
"""Get the configuration parameters to use when encrypting the input values.
@@ -274,16 +419,13 @@ def save_client_server(client_path: Path = CLIENT_PATH, server_path: Path = SERVER_PATH):
else:
inputset = config["get_inputset"]()

-cp_func = config["to_compile"]
+cp_func = config["to_compile"]()

# Configuration used for this API version
cfg = API_VERSION_SPECS[CURRENT_API_VERSION]["configuration"]
cfg.parameter_selection_strategy = "v0"
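# "v0" selects a single global crypto-parameter set, presumably so that the
# outputs of one composed function remain valid inputs to the next.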

# Compile the circuit and allow it to be composable with itself
-merge_circuit = cp_func.compile(
-inputset, composable=True, configuration=cfg
-)
+merge_circuit = cp_func.compile(inputset, composable=True, configuration=cfg)

# Save the client and server files using the MLIR
if isinstance(merge_circuit, FheModule):
4 changes: 3 additions & 1 deletion src/concrete/ml/pandas/_operators.py
@@ -6,6 +6,7 @@
import pandas
from concrete.fhe import Server
from pandas.core.reshape.merge import _MergeOperation
from ._development import API_VERSION_SPECS, CURRENT_API_VERSION

# List of Pandas parameters per operator that are not currently supported
UNSUPPORTED_PANDAS_PARAMETERS = {
@@ -207,7 +208,8 @@ def encrypted_left_right_join(
# In practice, keys match at most once in this loop, as they are assumed to
# be unique in both data-frames.
right_value_to_join = server.run(
-*merge_inputs, evaluation_keys=left_encrypted.evaluation_keys
+*merge_inputs, evaluation_keys=left_encrypted.evaluation_keys,
+function_name=API_VERSION_SPECS[CURRENT_API_VERSION]["join_function"]
)
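# The deployed artifact is now a multi-function FHE module, so the function to
# run must be selected by name ("main" for API v1,
# "left_right_join_to_compile" for API v2).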

right_row_to_join.append(right_value_to_join)