Swap in u8 matmul for vanilla matmul
This requires making a small change to how the bootstrapping key is generated
for GPUs, because the polymul change results in a larger tensor generated
during encryption, which in turn exhausts the GPU's RAM. Replaced a naive vmap
with a looped one.

PiperOrigin-RevId: 578970147
j2kun authored and copybara-github committed Nov 2, 2023
1 parent 95ca578 commit e2ece1e
Showing 3 changed files with 103 additions and 2 deletions.
95 changes: 95 additions & 0 deletions jaxite/jaxite_lib/jax_helpers.py
@@ -0,0 +1,95 @@
"""A module containing JAX helper code."""

import functools
from typing import Any, Callable, Sequence, TypeVar
import jax
import jax.numpy as jnp


tree_flatten = jax.tree_util.tree_flatten
tree_unflatten = jax.tree_util.tree_unflatten
tree_map = jax.tree_util.tree_map


def _tree_map_multi_output(f, *args):
  """Like tree_map, but for functions that return tuples."""
  leaves, treedefs = zip(*map(tree_flatten, args))
  if any(treedef != treedefs[0] for treedef in treedefs):
    raise ValueError(f'argument treedefs do not match {treedefs=}')
  outputs = zip(*map(f, *leaves))
  return tuple(tree_unflatten(treedefs[0], out) for out in outputs)


def _lax_map(f, *xs):
  """Like lax.map, but supports multiple arguments like the built-in map."""
  g = lambda _, x: ((), f(*x))
  _, ys = jax.lax.scan(g, (), xs)
  return ys


F = TypeVar('F', bound=Callable)


def batch_vmap(
    f: F,
    in_axes: int | None | Sequence[Any] = 0,
    out_axes: Any = 0,
    *,
    batch_size: int,
) -> F:
  """jax.vmap, but looping when the batch dimension exceeds batch_size."""

  def preprocess(x, in_axis):
    # Split the mapped axis into full batches of `batch_size` plus a
    # remainder "tail" that is handled by a single vmap call.
    batch_count = x.shape[in_axis] // batch_size
    x = jnp.moveaxis(x, in_axis, 0)
    loop_elements = batch_count * batch_size
    x_loop = x[:loop_elements].reshape((batch_count, batch_size) + x.shape[1:])
    x_tail = x[loop_elements:]
    return x_loop, x_tail

  def postprocess(x_loop, x_tail, out_axis):
    # Flatten the looped batches, append the tail, and restore the output axis.
    shape = x_loop.shape
    x_loop = x_loop.reshape((shape[0] * shape[1],) + shape[2:])
    x = jnp.concatenate([x_loop, x_tail], axis=0)
    return jnp.moveaxis(x, 0, out_axis)

  def g(*args):
    if isinstance(in_axes, int) or in_axes is None:
      in_axes_tuple = (in_axes,) * len(args)
    else:
      in_axes_tuple = tuple(in_axes)

    unbatched = []
    loop_args = []
    tail_args = []
    for i, (arg, in_axis) in enumerate(zip(args, in_axes_tuple)):
      if in_axis is None:
        unbatched.append((i, arg))
      elif isinstance(in_axis, int):
        loop_arg, tail_arg = _tree_map_multi_output(
            functools.partial(preprocess, in_axis=in_axis), arg
        )
        loop_args.append(loop_arg)
        tail_args.append(tail_arg)
      else:
        loop_arg, tail_arg = _tree_map_multi_output(preprocess, arg, in_axis)
        loop_args.append(loop_arg)
        tail_args.append(tail_arg)

    def f2(*args):
      args2 = list(args)
      for i, arg in unbatched:
        args2.insert(i, arg)
      return f(*args2)

    # Full batches run sequentially via lax.scan, one vmapped batch per step;
    # the remainder runs under a single vmap.
    loop_out = _lax_map(jax.vmap(f2), *loop_args)
    tail_out = jax.vmap(f2)(*tail_args)
    if isinstance(out_axes, int):
      out = tree_map(
          functools.partial(postprocess, out_axis=out_axes), loop_out, tail_out
      )
    else:
      out = tree_map(postprocess, loop_out, tail_out, out_axes)
    return out

  return g
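
As a quick illustration of batch_vmap (a hypothetical usage sketch; the squaring function and shapes below are made up, not part of this commit), the wrapper behaves like jax.vmap while only materializing batch_size elements per lax.scan step:

import jax.numpy as jnp
from jaxite.jaxite_lib import jax_helpers

# Square 10 rows while vmapping at most 4 rows per scan step; the leftover
# 2 rows go through the single "tail" vmap.
xs = jnp.arange(30, dtype=jnp.int32).reshape(10, 3)
squared = jax_helpers.batch_vmap(lambda row: row * row, batch_size=4)(xs)
assert squared.shape == (10, 3)
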
2 changes: 1 addition & 1 deletion jaxite/jaxite_lib/matrix_utils.py
@@ -143,7 +143,7 @@ def toeplitz_poly_mul(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
"""Computes a poly multiplication mod (X^N + 1) where N = len(a)."""
multiplier = _generate_sign_matrix(len(a))
left_matrix = multiplier * toeplitz(a).transpose()
return jnp.matmul(b, left_matrix.transpose())
return i32_as_u8_matmul(b, left_matrix.transpose())


@jax.named_call
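
For context on the swap above: i32_as_u8_matmul is presumably defined elsewhere in matrix_utils.py and is not shown in this diff. The general technique the name suggests, decomposing 32-bit operands into 8-bit limbs so the multiplication can be fed to u8 matmul units, can be sketched roughly as follows (an illustrative reconstruction under that assumption, not the library's actual implementation):

import jax.numpy as jnp

def u8_limb_matmul_sketch(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  # Decompose the left operand into base-256 limbs, do one matmul per limb,
  # and recombine with shifts. int32 arithmetic wraps mod 2**32, which is the
  # modulus this decomposition relies on; a real u8 kernel would additionally
  # cast the limbs to uint8 before multiplying.
  out = None
  for j in range(4):
    limb = (x >> (8 * j)) & 0xFF
    partial = jnp.matmul(limb, y) << (8 * j)
    out = partial if out is None else out + partial
  return out
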
8 changes: 7 additions & 1 deletion jaxite/jaxite_lib/rgsw.py
@@ -7,6 +7,7 @@
import jax.numpy as jnp
from jaxite.jaxite_lib import decomposition
from jaxite.jaxite_lib import encoding
+from jaxite.jaxite_lib import jax_helpers
from jaxite.jaxite_lib import matrix_utils
from jaxite.jaxite_lib import parameters
from jaxite.jaxite_lib import random_source
@@ -249,10 +250,15 @@ def encrypt_block(ai_samples, error_samples, block, plaintext_message):
        out_axes=0,
    )(ai_samples, error_samples, levels_range, block, plaintext_message)

-  ciphertext = jax.vmap(
+  # We use batch_vmap here because on GPU, the additional use of
+  # i32_as_u8_matmul during encryption results in too much memory usage.
+  # However, note that because key generation will typically not happen on a TPU
+  # or GPU, this is mostly a mechanism to ensure we can run tests fast in CI.
+  ciphertext = jax_helpers.batch_vmap(
      encrypt_block,
      in_axes=(0, 0, 0, None),
      out_axes=0,
+      batch_size=1,
  )(ai_samples, error_samples, block_range, plaintext)

  return ciphertext.reshape(
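
For intuition about batch_size=1 above (an illustrative aside, not part of the commit): with a batch size of one, the looped path of batch_vmap handles a single block per lax.scan step, so its memory behavior is roughly that of a plain jax.lax.map over the block axis:

# Rough memory-equivalent of the batch_size=1 call above (sketch only).
ciphertext = jax.lax.map(
    lambda xs: encrypt_block(xs[0], xs[1], xs[2], plaintext),
    (ai_samples, error_samples, block_range),
)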
