Skip to content

Commit

Permalink
chore: reduce nbits to make it compile
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Jul 23, 2024
1 parent 8ba942b commit 37e8941
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions src/concrete/ml/pandas/_development.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,22 @@
CLIENT_PATH = CLIENT_SERVER_DIR / "client.zip"
SERVER_PATH = CLIENT_SERVER_DIR / "server.zip"

N_BITS_PANDAS = 4
N_BITS_PANDAS = 2

from ..sklearn._fhe_training_utils import LogisticRegressionTraining, make_training_inputset
from ..torch.compile import build_quantized_module
from ..common.utils import generate_proxy_function

class DFApiV2StaticHelper:
_N_DIMS_TRAINING = 1
_BATCH_SIZE = 1

_training_input_set = make_training_inputset(
numpy.zeros((_N_DIMS_TRAINING, ), dtype=numpy.int64),
numpy.ones((_N_DIMS_TRAINING, ) , dtype=numpy.int64) * 2**N_BITS_PANDAS - 1,
0,
2**N_BITS_PANDAS-1,
8, True
_BATCH_SIZE, True
)

def create_api_v2():
Expand Down Expand Up @@ -200,10 +201,12 @@ def get_left_right_join_inputset(n_bits: int) -> List:
# integers values greater or equal to 1
inputset = list(itertools.product([0, high], [0, high], [0, high], [0, high]))

inputset = [numpy.asarray(v).reshape(1, 1, -1) for v in inputset]
inputset = [numpy.repeat(v, DFApiV2StaticHelper._BATCH_SIZE, axis=1) for v in inputset]
return inputset

def get_training_inputset():
return _get_inputset_generator(tuple(map(lambda x: x.astype(numpy.int64), DFApiV2StaticHelper._training_input_set)))
return list(_get_inputset_generator(tuple(map(lambda x: x.astype(numpy.int64), DFApiV2StaticHelper._training_input_set))))

# Store the configuration functions and parameters to their associated operator
PANDAS_OPS_TO_CIRCUIT_CONFIG = {
Expand Down
2 changes: 1 addition & 1 deletion tests/pandas/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def generate_pandas_dataframe(
pandas.DataFrame: The generated Pandas data-frame.
"""
if indexes is None:
indexes = 5
indexes = 3

allowed_dtype = ["int", "float", "str", "mixed"]
assert dtype in allowed_dtype, f"Parameter 'dtype' must be in {allowed_dtype}. Got {dtype}."
Expand Down

0 comments on commit 37e8941

Please sign in to comment.