diff --git a/BUILD b/BUILD index 8da6d15..ea185a0 100644 --- a/BUILD +++ b/BUILD @@ -1,9 +1,8 @@ # An FHE cryptosystem built in JAX -load("@rules_python//python:defs.bzl", "py_library") -load("@rules_python//python:defs.bzl", "py_test") load("@jaxite//bazel:test_oss.bzl", "cpu_gpu_tpu_test", "gpu_tpu_test", "multichip_tpu_test") load("@rules_license//rules:license.bzl", "license") +load("@rules_python//python:defs.bzl", "py_library", "py_test") package( default_applicable_licenses = ["@jaxite//:license"], diff --git a/jaxite/jaxite_lib/bootstrap.py b/jaxite/jaxite_lib/bootstrap.py index 60c7c56..a51ecd9 100644 --- a/jaxite/jaxite_lib/bootstrap.py +++ b/jaxite/jaxite_lib/bootstrap.py @@ -1,4 +1,4 @@ -"""The API for bootstrapping in TFHE.""" +"""The API for bootstrapping in CGGI.""" import dataclasses import functools from typing import Any, Callable, Optional diff --git a/jaxite/jaxite_lib/decomposition.py b/jaxite/jaxite_lib/decomposition.py index 568402a..6695858 100644 --- a/jaxite/jaxite_lib/decomposition.py +++ b/jaxite/jaxite_lib/decomposition.py @@ -210,7 +210,7 @@ def signed_decomposition( # reversed because the digits are computed from least-significant bit to # highest-significant bit.. - return result[::-1][:num_levels] + return jnp.flip(result, axis=-1)[:num_levels] # Applies the signed decomposition to each coefficient of a polynomial diff --git a/jaxite/jaxite_lib/matrix_utils.py b/jaxite/jaxite_lib/matrix_utils.py index 36b019a..0984de6 100644 --- a/jaxite/jaxite_lib/matrix_utils.py +++ b/jaxite/jaxite_lib/matrix_utils.py @@ -190,7 +190,6 @@ def _toeplitz(inp_ref, out_ref): return pl.pallas_call( _toeplitz, out_shape=jax.ShapeDtypeStruct((n, n), x.dtype), - interpret=(jax.default_backend() == "cpu"), )(x) @@ -258,27 +257,20 @@ def monomial_mul( polynomial """ n = poly.shape[0] - - # After a multiplication by X^{2N}, the polynomial is unchanged. degree = degree % (2 * n) - flip = degree // n shift = degree % n - - # equivalent to "if degree >= n: poly = -poly" because 0 <= degree < 2N - poly = jnp.uint32((-1) ** flip) * poly - rolled = jnp.roll(poly, shift) - - # trick: generate an array like [1, 1, ..., 1, -1, -1, ..., -1] - # and rolling that gives the right argument to use in an element-wise product - # with the tail truncated. - ones = jnp.ones(n, dtype=jnp.uint32) - sign = jnp.roll(jnp.concatenate([ones, -ones]), shift) - output = rolled * sign[:n] + flip = (degree // n) % 2 == 1 + indices = jax.lax.broadcasted_iota( + dtype=jnp.int32, shape=poly.shape, dimension=0 + ) + rolled = jnp.roll(poly, degree) + rolled = jnp.where(flip, -rolled, rolled) + output = jnp.where(indices < shift, -rolled, rolled) if 0 < log_modulus < 32: output = jnp.mod(output, jnp.uint32(2) ** log_modulus) - return output + return output.astype(poly.dtype) monomial_mul_list = jax.vmap(monomial_mul, in_axes=(0, None, None), out_axes=0)