Skip to content

Commit

Permalink
Use toeplitz kernel in polymul for TPU version >=5
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607028597
  • Loading branch information
j2kun authored and copybara-github committed Feb 15, 2024
1 parent 31248ec commit c46af01
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
1 change: 1 addition & 0 deletions jaxite/jaxite_lib/bootstrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def test_3_bit_bootstrap_larger_lwe_dimension(
rlwe_rng=rng,
)

@absltest.skip("b/325287870")
def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int):
message_bits = 6
padding_bits = 1
Expand Down
14 changes: 11 additions & 3 deletions jaxite/jaxite_lib/matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
from jaxite.jaxite_lib import jax_helpers


@jax.jit
Expand Down Expand Up @@ -197,9 +198,16 @@ def _toeplitz(inp_ref, out_ref):
@jax.jit
def toeplitz_poly_mul(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
"""Computes a poly multiplication mod (X^N + 1) where N = len(a)."""
multiplier = _generate_sign_matrix(len(a))
left_matrix = multiplier * toeplitz(a).transpose()
return i32_as_u8_matmul(b, left_matrix.transpose())
tpu_version = jax_helpers.get_tpu_version()
n = a.shape[-1]
if n % 128 == 0 and tpu_version >= 5:
toeplitzed = toeplitz_kernelized(a.astype(jnp.int32))
return i32_as_u8_matmul(b, toeplitzed)
else:
# This branch is non-optimized, does not lower well on most platforms.
multiplier = _generate_sign_matrix(len(a))
left_matrix = multiplier.transpose() * toeplitz(a)
return i32_as_u8_matmul(b, left_matrix)


@jax.named_call
Expand Down

0 comments on commit c46af01

Please sign in to comment.