Skip to content

Commit

Permalink
Simplify application of bmmp precompute routine
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605468230
  • Loading branch information
j2kun authored and copybara-github committed Feb 9, 2024
1 parent 39e6346 commit 45c3228
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 19 deletions.
19 changes: 19 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,25 @@ cpu_gpu_tpu_test(
],
)

cpu_gpu_tpu_test(
name = "bmmp_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/bmmp_test.py"],
python_version = "PY3",
shard_count = 3,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "decomposition_test",
size = "small",
Expand Down
93 changes: 93 additions & 0 deletions jaxite/jaxite_lib/bmmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Routines related to the BMMP17 bootstrapping trick.
Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier
(BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger
bootstrapping key to reduce the number of external products required by 1/2.
Rather than encrypt the secret key bits of the LWE key separately, we
encrypt:
BSK_{3i} = s_{2i} * s_{2i+1},
BSK_{3i+1} = s_{2i} * (1 − s_{2i+1}),
BSK_{3i+2} = (1 − s_{2i}) * s_{2i+1}
which enables a bootstrap operation that involves 1/2 as many external
products, though this causes the bootstrapping key to be 50% larger.
"""

import functools

import jax
import jax.numpy as jnp
from jaxite.jaxite_lib import types


@jax.jit
def scale_by_x_power_n_minus_1_vanilla_jax(
power: jnp.int32, matrix: jnp.ndarray
) -> jnp.ndarray:
indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=matrix.shape, dimension=2
)
n = matrix.shape[2]

power = power % (2 * n)
shift = power % n
flip = (power // n) % 2 == 1
rolled = jnp.roll(matrix, shift, axis=2)
rolled = jnp.where(flip, -rolled, rolled)
x_power_n_part = jnp.where(indices < shift, -rolled, rolled)
return x_power_n_part - matrix


@functools.partial(jax.jit, static_argnames="log_modulus")
def scale_by_x_power_n_minus_1(
power: jnp.int32, matrix: jnp.ndarray, log_modulus: int
) -> jnp.ndarray:
"""An optimized poly mul for scaling a matrix of polynomials by x^n - 1.
Args:
power: The exponent n of x^n - 1 to scale each matrix entry by
matrix: The matrix to be scaled.
log_modulus: the base-2 logarithm of the polynomial coefficient modulus.
Returns:
An `jnp.ndarray` of the same shape as `matrix`, containing the
entries of `matrix` each scaled by x^n - 1.
"""
output = scale_by_x_power_n_minus_1_vanilla_jax(power, matrix)

if 0 < log_modulus < 32:
output = jnp.mod(output, jnp.uint32(2) ** log_modulus)

return output


@jax.named_call
@functools.partial(jax.jit, static_argnums=(2))
def compute_bmmp_factors(
coefficient_index: types.LweCiphertext,
bsk: jnp.ndarray,
log_coefficient_modulus: int,
):
"""Pre-process the bootstrapping key in preparation for blind rotate."""
num_loop_terms = (coefficient_index.shape[0] - 1) // 2

def one_bmmp_factor(j):
power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1]
power2 = coefficient_index[2 * j]
power3 = coefficient_index[2 * j + 1]
return (
scale_by_x_power_n_minus_1(
power1, bsk[3 * j], log_modulus=log_coefficient_modulus
)
+ scale_by_x_power_n_minus_1(
power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus
)
+ scale_by_x_power_n_minus_1(
power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus
)
).astype(jnp.uint32)

return jax.vmap(one_bmmp_factor, in_axes=(0,), out_axes=0)(
jnp.arange(num_loop_terms)
)
119 changes: 119 additions & 0 deletions jaxite/jaxite_lib/bmmp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Tests for bmmp."""

import hypothesis
from hypothesis import strategies
import jax.numpy as jnp
from jaxite.jaxite_lib import bmmp
from jaxite.jaxite_lib import matrix_utils
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized


@hypothesis.strategies.composite
def vectors(draw, size, min_value=-(2**31), max_value=2**31 - 1):
# Note hypothesis.extras.numpy has no build rule in google3
return np.array(
draw(
strategies.lists(
strategies.integers(min_value=min_value, max_value=max_value),
min_size=size,
max_size=size,
),
),
dtype=np.int32,
)


class BmmpTest(parameterized.TestCase):

def test_kernel_equivalence(self):
N = 64
# create a 8x8x8 matrix, with each polynomial the same
poly = jnp.arange(N).astype(jnp.int32)
matrix = jnp.tile(poly, reps=jnp.array([8, 8, 1]))
power = 4
transformed_poly = jnp.array([
-60,
-62,
-64,
-66,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
-4,
])
expected = jnp.tile(transformed_poly, reps=jnp.array([8, 8, 1]))
actual = bmmp.scale_by_x_power_n_minus_1(power, matrix, log_modulus=32)
np.testing.assert_array_equal(expected, actual)

@hypothesis.given(strategies.integers(min_value=0, max_value=10), vectors(16))
@hypothesis.settings(deadline=None)
def test_scale_by_x_power_n_minus_1(self, power, poly):
matrix = jnp.tile(jnp.array(list(poly)), reps=jnp.array([8, 8, 1]))
poly_term = matrix_utils.x_power_n_minus_1(power, poly_mod_deg=16)
expected = matrix_utils.poly_mul_const_matrix(poly_term, matrix)
actual = matrix_utils.scale_by_x_power_n_minus_1(
power, matrix, log_modulus=32
)
np.testing.assert_array_equal(expected, actual)


if __name__ == '__main__':
absltest.main()
22 changes: 3 additions & 19 deletions jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jax
import jax.numpy as jnp
from jaxite.jaxite_lib import bmmp
from jaxite.jaxite_lib import decomposition
from jaxite.jaxite_lib import key_switch
from jaxite.jaxite_lib import lwe
Expand Down Expand Up @@ -472,25 +473,8 @@ def jit_blind_rotate(
# (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger
# bootstrapping key to reduce the number of external products required by 1/2.
num_loop_terms = (coefficient_index.shape[0] - 1) // 2

def one_bmmp_factor(j):
power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1]
power2 = coefficient_index[2 * j]
power3 = coefficient_index[2 * j + 1]
return (
matrix_utils.scale_by_x_power_n_minus_1(
power1, bsk[3 * j], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus
)
).astype(jnp.uint32)

bmmp_factors = jax.vmap(one_bmmp_factor, in_axes=(0,), out_axes=0)(
jnp.arange(num_loop_terms, dtype=jnp.uint32)
bmmp_factors = bmmp.compute_bmmp_factors(
coefficient_index, bsk, log_coefficient_modulus
)

def one_external_product(j, c_prime_accum):
Expand Down

0 comments on commit 45c3228

Please sign in to comment.