Skip to content

Commit

Permalink
minor tweaks for improved use of jax API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623305044
  • Loading branch information
j2kun authored and copybara-github committed Apr 10, 2024
1 parent 9a189bc commit a5bfadc
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 20 deletions.
3 changes: 1 addition & 2 deletions BUILD
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# An FHE cryptosystem built in JAX

load("@rules_python//python:defs.bzl", "py_library")
load("@rules_python//python:defs.bzl", "py_test")
load("@jaxite//bazel:test_oss.bzl", "cpu_gpu_tpu_test", "gpu_tpu_test", "multichip_tpu_test")
load("@rules_license//rules:license.bzl", "license")
load("@rules_python//python:defs.bzl", "py_library", "py_test")

package(
default_applicable_licenses = ["@jaxite//:license"],
Expand Down
2 changes: 1 addition & 1 deletion jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""The API for bootstrapping in TFHE."""
"""The API for bootstrapping in CGGI."""
import dataclasses
import functools
from typing import Any, Callable, Optional
Expand Down
2 changes: 1 addition & 1 deletion jaxite/jaxite_lib/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def signed_decomposition(

# reversed because the digits are computed from least-significant bit to
# highest-significant bit..
return result[::-1][:num_levels]
return jnp.flip(result, axis=-1)[:num_levels]


# Applies the signed decomposition to each coefficient of a polynomial
Expand Down
24 changes: 8 additions & 16 deletions jaxite/jaxite_lib/matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def _toeplitz(inp_ref, out_ref):
return pl.pallas_call(
_toeplitz,
out_shape=jax.ShapeDtypeStruct((n, n), x.dtype),
interpret=(jax.default_backend() == "cpu"),
)(x)


Expand Down Expand Up @@ -258,27 +257,20 @@ def monomial_mul(
polynomial
"""
n = poly.shape[0]

# After a multiplication by X^{2N}, the polynomial is unchanged.
degree = degree % (2 * n)
flip = degree // n
shift = degree % n

# equivalent to "if degree >= n: poly = -poly" because 0 <= degree < 2N
poly = jnp.uint32((-1) ** flip) * poly
rolled = jnp.roll(poly, shift)

# trick: generate an array like [1, 1, ..., 1, -1, -1, ..., -1]
# and rolling that gives the right argument to use in an element-wise product
# with the tail truncated.
ones = jnp.ones(n, dtype=jnp.uint32)
sign = jnp.roll(jnp.concatenate([ones, -ones]), shift)
output = rolled * sign[:n]
flip = (degree // n) % 2 == 1
indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=poly.shape, dimension=0
)
rolled = jnp.roll(poly, degree)
rolled = jnp.where(flip, -rolled, rolled)
output = jnp.where(indices < shift, -rolled, rolled)

if 0 < log_modulus < 32:
output = jnp.mod(output, jnp.uint32(2) ** log_modulus)

return output
return output.astype(poly.dtype)


monomial_mul_list = jax.vmap(monomial_mul, in_axes=(0, None, None), out_axes=0)
Expand Down

0 comments on commit a5bfadc

Please sign in to comment.