
Commit

Add an initial vector-matrix polymul megakernel
PiperOrigin-RevId: 626056090
j2kun authored and copybara-github committed May 28, 2024
1 parent 4d5e5d4 commit b6b2b32
Showing 6 changed files with 263 additions and 15 deletions.
20 changes: 19 additions & 1 deletion BUILD
@@ -1,6 +1,6 @@
# An FHE cryptosystem built in JAX

load("@jaxite//bazel:test_oss.bzl", "cpu_gpu_tpu_test", "gpu_tpu_test", "multichip_tpu_test")
load("@jaxite//bazel:test_oss.bzl", "cpu_gpu_tpu_test", "gpu_tpu_test", "multichip_tpu_test", "tpu_test")
load("@rules_license//rules:license.bzl", "license")
load("@rules_python//python:defs.bzl", "py_library", "py_test")

@@ -75,6 +75,24 @@ cpu_gpu_tpu_test(
    ],
)

tpu_test(
    name = "polymul_kernel_test",
    size = "large",
    timeout = "moderate",
    srcs = ["jaxite/jaxite_lib/polymul_kernel_test.py"],
    python_version = "PY3",
    shard_count = 3,
    srcs_version = "PY3ONLY",
    deps = [
        ":jaxite",
        "@com_google_absl_py//absl/testing:absltest",
        "@com_google_absl_py//absl/testing:parameterized",
        "@jaxite_deps_jax//:pkg",
        "@jaxite_deps_jaxlib//:pkg",
        "@jaxite_deps_numpy//:pkg",
    ],
)

cpu_gpu_tpu_test(
    name = "decomposition_test",
    size = "small",
4 changes: 3 additions & 1 deletion jaxite/jaxite_bool/jaxite_bool_test.py
@@ -202,7 +202,9 @@ def test_ksk_decomposition_params(self, decomp_log_base: int, l: int) -> None:

  @parameterized.named_parameters(
      dict(testcase_name='_b=4_L=8', decomp_log_base=4, l=8),
      dict(testcase_name='_b=4_L=7', decomp_log_base=4, l=7),
      # TODO(b/335701655): odd L results in tensor shapes that conflict with
      # the TPU kernel's requirements in polymul_kernel.py.
      # dict(testcase_name='_b=4_L=7', decomp_log_base=4, l=7),
      dict(testcase_name='_b=4_L=6', decomp_log_base=4, l=6),
  )
  def test_bsk_decomposition_params(self, decomp_log_base: int, l: int) -> None:
14 changes: 5 additions & 9 deletions jaxite/jaxite_lib/bootstrap.py
@@ -1,4 +1,5 @@
"""The API for bootstrapping in CGGI."""

import dataclasses
import functools
from typing import Any, Callable, Optional
@@ -10,6 +11,7 @@
from jaxite.jaxite_lib import lwe
from jaxite.jaxite_lib import matrix_utils
from jaxite.jaxite_lib import parameters
from jaxite.jaxite_lib import polymul_kernel
from jaxite.jaxite_lib import random_source
from jaxite.jaxite_lib import rgsw
from jaxite.jaxite_lib import rlwe
@@ -326,14 +328,6 @@ def external_product(
)


# in_axes = (None, 1) means that the first argument is repeated across all
# calls, while the second argument is mapped across its second index
# (column-wise)
vector_matrix_polymul = jax.jit(
    jax.vmap(matrix_utils.poly_dot_product, in_axes=(None, 1), out_axes=0)
)


@functools.partial(jax.jit, static_argnames="decomposition_params")
def jit_external_product(
    rgsw_ct: jnp.ndarray,
@@ -344,7 +338,9 @@ def jit_external_product(
  decomposed_rlwe = decomposition.decompose_rlwe_ciphertext(
      rlwe_ct, decomposition_params
  )
  return vector_matrix_polymul(decomposed_rlwe, rgsw_ct)
  return polymul_kernel.negacyclic_vector_matrix_polymul(
      decomposed_rlwe, rgsw_ct
  )


def cmux(
10 changes: 6 additions & 4 deletions jaxite/jaxite_lib/bootstrap_test.py
@@ -181,14 +181,16 @@ def test_3_bit_bootstrap_larger_lwe_dimension(
    message_bits = 3
    padding_bits = 1
    lwe_dimension = 100
    mod_degree = 1024
    mod_degree = 512
    # TODO(b/339715397): make the kernel work for degree 1024
    # mod_degree = 1024

    rng = random_source.PseudorandomSource(
        uniform_bounds=(0, 2**log_ai_bound),
        normal_std=1,
        uniform_bounds=(0, 2**28),
        normal_std=0,
        seed=seed,
    )
    injected_noise = 2 ** (32 - padding_bits - message_bits - 2) - 1
    injected_noise = 2 ** (32 - padding_bits - message_bits - 3) - 1

    self.run_bootstrap_test(
        injected_noise=injected_noise,
182 changes: 182 additions & 0 deletions jaxite/jaxite_lib/polymul_kernel.py
@@ -0,0 +1,182 @@
"""Kernel for negacyclic vector-matrix polymul."""

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
from jaxite.jaxite_lib import jax_helpers
from jaxite.jaxite_lib import matrix_utils


# This fallback serves as a reference implementation, but does not lower well on
# TPU due to the semantics of the vmap.
#
# in_axes = (None, 1) means that the first argument is repeated across all
# calls, while the second argument is mapped across its second index
# (column-wise)
fallback_vector_matrix_polymul = jax.jit(
    jax.vmap(matrix_utils.poly_dot_product, in_axes=(None, 1), out_axes=0)
)

# i32_as_u8_matmul is a (m,) x (m, k) -> (k,) matmul, but _i32_matmul_unreduced
# is an (m, k) x (k, n) -> (m, n) matmul. To compare, we can vmap
# i32_as_u8_matmul over the first axis.
#
# in_axes = (0, None) means that the second argument is repeated across all
# calls, while the first argument is mapped across its first axis.
fallback_i32_matmul = jax.vmap(
    matrix_utils.i32_as_u8_matmul, in_axes=(0, None), out_axes=0
)
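
The in_axes convention is easiest to see with an ordinary dot product standing in for poly_dot_product. A minimal standalone sketch (illustrative only, not part of this module): holding the first argument fixed and mapping the second over its column axis turns a per-column function into a vector-matrix product.

import jax
import jax.numpy as jnp

def dot(v, col):
  # v: (k,), col: (k,) -> scalar
  return jnp.sum(v * col)

v = jnp.arange(4, dtype=jnp.int32)                  # shape (k,) = (4,)
m = jnp.arange(12, dtype=jnp.int32).reshape(4, 3)   # shape (k, n) = (4, 3)
column_wise = jax.vmap(dot, in_axes=(None, 1), out_axes=0)
assert jnp.array_equal(column_wise(v, m), v @ m)    # shape (n,) = (3,)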


def _i32_matmul_unreduced(lhs, rhs):
  lax = jax.lax
  m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[1]
  lhs_i8 = jnp.broadcast_to(lhs, (4, *lhs.shape))
  lhs_shift = lax.broadcasted_iota(jnp.int32, lhs_i8.shape, dimension=0) * 8
  lhs_i8 = lax.shift_right_logical(lhs_i8, lhs_shift)
  lhs_i8 = lax.bitwise_and(lhs_i8, jnp.broadcast_to(0xFF, lhs_i8.shape))
  lhs_i8 = lhs_i8.reshape((4 * m, k))

  acc = jnp.zeros((4 * m, n), dtype=jnp.int32)
  out_shift_base = lax.mul(
      lax.div(lax.broadcasted_iota(jnp.int32, (4 * m, n), dimension=0), m), 8
  )
  for rhs_shift in range(0, 32, 8):
    # TODO(b/201562458): Don't multiply lhs rows with large shift.
    rhs_i8 = lax.shift_right_logical(
        rhs, jnp.broadcast_to(rhs_shift, rhs.shape)
    )
    rhs_i8 = lax.bitwise_and(rhs_i8, jnp.broadcast_to(0xFF, rhs_i8.shape))
    # TODO(b/201562458): Use int8 matmuls once properly supported
    raw_out = lax.dot(
        lhs_i8.astype(jnp.float32),
        rhs_i8.astype(jnp.float32),
        preferred_element_type=jnp.float32,
    ).astype(jnp.int32)
    acc += jnp.left_shift(raw_out, out_shift_base + rhs_shift)
  return acc
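
The shift bookkeeping above is the usual byte-decomposition trick: each 32-bit operand is split into four byte slices, the byte products are computed with float32 matmuls, and every partial product is shifted back into position (out_shift_base carries the lhs byte index, the loop variable rhs_shift carries the rhs byte index) before the caller sums the four stacked row groups. A scalar sketch of the underlying identity, in plain Python and illustrative only:

x, y = 0x12345678, 0x0ABCDEF9
product = 0
for i in range(4):    # lhs byte index, applied via out_shift_base above
  for j in range(4):  # rhs byte index, applied via the rhs_shift loop above
    x_byte = (x >> (8 * i)) & 0xFF
    y_byte = (y >> (8 * j)) & 0xFF
    product += (x_byte * y_byte) << (8 * (i + j))
assert product == x * y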


def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray):
  b, n = poly_vec1.shape
  b2, m, n2 = poly_mat2.shape
  assert b == b2 and n == n2
  real_m = m
  m = 8
  poly_mat2 = jnp.pad(
      poly_mat2,
      ((0, 0), (0, m - real_m), (0, 0)),
      mode="constant",
      constant_values=(0,),
  )

  if n % 128 != 0:
    raise ValueError(f"Input size {n} is not a multiple of 128")
  dtype = poly_vec1.dtype
  # TODO: dtype checks

  def vec_mat_polymul_kernel_single_batch(vec_ref, mat_ref, out_ref):
    chunk = jnp.broadcast_to(vec_ref[...], (128, n))
    chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
    chunk_row_indices = jax.lax.broadcasted_iota(
        dtype=jnp.int32, shape=(128, n), dimension=0
    )
    chunk_col_indices = jax.lax.broadcasted_iota(
        dtype=jnp.int32, shape=(128, n), dimension=1
    )
    toeplitz_chunks = []
    for _ in range(0, n, 128):
      toeplitz_chunks.append(
          jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk)
      )
      # Because the vector registers are aligned to size 128, this roll
      # operation lowers to telling the TPU to refer to a different register,
      # rather than actually applying any rolling operation. Hence, the op
      # produces no hardware instructions.
      chunk = pltpu.roll(chunk, 128, 1)
      chunk_row_indices = chunk_row_indices + 128
    vec_toeplitz = jax.lax.concatenate(toeplitz_chunks, dimension=0)

    assert vec_toeplitz.shape == (n, n)
    result = _i32_matmul_unreduced(mat_ref[...], vec_toeplitz)
    assert result.shape == (4 * m, n), result.shape
    out_ref[...] = result

  def vec_mat_polymul_kernel(vec_ref, mat_ref, out_ref):
    for b in range(vec_ref.shape[0]):
      vec_mat_polymul_kernel_single_batch(
          vec_ref.at[b], mat_ref.at[b], out_ref.at[b]
      )

  block_b = 2
  steps_b, rem_b = divmod(b, block_b)
  if rem_b:
    raise ValueError(f"b={b} is not a multiple of block_b={block_b}")

  return jnp.sum(
      pl.pallas_call(
          vec_mat_polymul_kernel,
          in_specs=(
              pl.BlockSpec(lambda b: (b, 0, 0), (block_b, 1, n)),
              pl.BlockSpec(lambda b: (b, 0, 0), (block_b, m, n)),
          ),
          out_specs=pl.BlockSpec(lambda b: (b, 0, 0), (block_b, 4 * m, n)),
          out_shape=jax.ShapeDtypeStruct((b, 4 * m, n), jnp.int32),
          grid=(steps_b,),
      )(
          poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32)
      ).reshape(
          b, 4, m, n
      ),
      axis=(0, 1),
  ).astype(jnp.uint32)[:real_m]
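
The kernel relies on the fact that multiplication by a fixed polynomial mod X^n + 1 is a linear map, so the whole operation reduces to an ordinary matmul against a negacyclic (sign-flipped Toeplitz) matrix built from that polynomial's coefficients; the loop above assembles this matrix 128 rows at a time using pltpu.roll and the row/column sign mask. A NumPy reference of the same construction (hypothetical helper names, illustrative only, not part of jaxite):

import numpy as np

def negacyclic_matrix(a: np.ndarray) -> np.ndarray:
  # Row i holds the coefficients of X^i * a(X) mod X^n + 1: a rotated right
  # by i positions, with the wrapped-around coefficients negated.
  n = a.shape[0]
  return np.stack([np.concatenate([-a[n - i:], a[:n - i]]) for i in range(n)])

def negacyclic_polymul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
  # Schoolbook multiplication followed by reduction mod X^n + 1.
  n = a.shape[0]
  full = np.convolve(a, b)  # length 2n - 1
  return full[:n] - np.concatenate([full[n:], [0]])

rng = np.random.default_rng(0)
a = rng.integers(-5, 5, size=8)
b = rng.integers(-5, 5, size=8)
assert np.array_equal(b @ negacyclic_matrix(a), negacyclic_polymul(a, b))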


@jax.named_call
@jax.jit
def negacyclic_vector_matrix_polymul(
    vec: jnp.ndarray, matrix: jnp.ndarray
) -> jnp.ndarray:
  """Computes a vector-matrix poly multiplication mod (X^N + 1).

  Args:
    vec: a vector of polynomials
    matrix: a matrix of polynomials

  Returns:
    the vector-matrix product of the polynomials
  """
  n_matrix = matrix.shape[-1]
  n_vec = vec.shape[-1]
  if n_matrix != n_vec:
    raise ValueError(
        "Expected polynomial degree of the inputs to match, "
        f"but found {n_vec} != {n_matrix}"
    )

  tpu_version = jax_helpers.get_tpu_version()
  if n_vec % 128 == 0 and tpu_version >= 5:
    return _vector_matrix_polymul(
        vec.astype(jnp.int32), matrix.astype(jnp.int32)
    )
  else:
    return fallback_vector_matrix_polymul(vec, matrix)
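
Mirroring the shapes exercised in polymul_kernel_test.py below, a vector of 18 degree-512 polynomials times an 18 x 3 matrix of polynomials yields 3 output polynomials; the Pallas path is taken on TPU v5 or newer when the degree is a multiple of 128, and the vmapped fallback otherwise. A usage sketch, illustrative only:

import jax.numpy as jnp
import numpy as np
from jaxite.jaxite_lib import polymul_kernel

vec = jnp.array(
    np.random.randint(0, 2**31 - 1, size=(18, 512), dtype=np.int32)
)
mat = jnp.array(
    np.random.randint(0, 2**31 - 1, size=(18, 3, 512), dtype=np.int32)
)
out = polymul_kernel.negacyclic_vector_matrix_polymul(vec, mat)
assert out.shape == (3, 512)  # one output polynomial per matrix column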


def i32_matmul_unreduced(lhs, rhs, out):
  """A helper to isolate the matmul part of the kernel to test in isolation."""
  out[...] = _i32_matmul_unreduced(lhs[...], rhs[...])


def i32_matmul(lhs, rhs):
  """A helper to isolate the matmul part of the kernel to test in isolation."""
  m, k, k2, n = lhs.shape[0], lhs.shape[1], rhs.shape[0], rhs.shape[1]
  assert k == k2
  return jnp.sum(
      pl.pallas_call(
          i32_matmul_unreduced,
          out_shape=jax.ShapeDtypeStruct((4 * m, n), jnp.int32),
      )(lhs, rhs).reshape(4, m, n),
      axis=(0),
  ).astype(jnp.uint32)
48 changes: 48 additions & 0 deletions jaxite/jaxite_lib/polymul_kernel_test.py
@@ -0,0 +1,48 @@
import jax.numpy as jnp
from jaxite.jaxite_lib import polymul_kernel
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized


_SEEDS = list(range(3))


def random(shape, dtype=np.int32):
  return jnp.array(
      np.random.randint(low=0, high=2**31 - 1, size=shape, dtype=dtype)
  )


class PolymulKernelTest(parameterized.TestCase):

  @parameterized.product(seed=_SEEDS)
  def test_i32_matmul_vs_reference(self, seed: int):
    np.random.seed(seed)
    lhs = random(shape=(24, 512))  # leading dimension must be a multiple of 8
    rhs = random(shape=(512, 512))
    expected = polymul_kernel.fallback_i32_matmul(lhs, rhs).astype(jnp.uint32)
    actual = polymul_kernel.i32_matmul(lhs, rhs)
    np.testing.assert_array_equal(expected, actual)

  def test_vector_matrix_vs_reference(self):
    vector = random(shape=(18, 512))
    matrix = random(shape=(18, 3, 512))
    expected = polymul_kernel.fallback_vector_matrix_polymul(vector, matrix)
    actual = polymul_kernel.negacyclic_vector_matrix_polymul(vector, matrix)
    np.testing.assert_array_equal(expected, actual)

  @parameterized.product(
      seed=_SEEDS,
  )
  def test_many_seeds(self, seed: int):
    np.random.seed(seed)
    vector = random(shape=(18, 512), dtype=jnp.uint32)
    matrix = random(shape=(18, 3, 512), dtype=jnp.uint32)
    expected = polymul_kernel.fallback_vector_matrix_polymul(vector, matrix)
    actual = polymul_kernel.negacyclic_vector_matrix_polymul(vector, matrix)
    np.testing.assert_array_equal(expected, actual)


if __name__ == "__main__":
  absltest.main()
