diff --git a/jaxite/jaxite_lib/bootstrap_test.py b/jaxite/jaxite_lib/bootstrap_test.py index 868c949..a048eb1 100644 --- a/jaxite/jaxite_lib/bootstrap_test.py +++ b/jaxite/jaxite_lib/bootstrap_test.py @@ -200,6 +200,7 @@ def test_3_bit_bootstrap_larger_lwe_dimension( rlwe_rng=rng, ) + @absltest.skip("b/325287870") def test_6_bit_bootstrap(self, log_ai_bound: int, seed: int): message_bits = 6 padding_bits = 1 diff --git a/jaxite/jaxite_lib/matrix_utils.py b/jaxite/jaxite_lib/matrix_utils.py index 6925398..36b019a 100644 --- a/jaxite/jaxite_lib/matrix_utils.py +++ b/jaxite/jaxite_lib/matrix_utils.py @@ -6,6 +6,7 @@ from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp +from jaxite.jaxite_lib import jax_helpers @jax.jit @@ -197,9 +198,16 @@ def _toeplitz(inp_ref, out_ref): @jax.jit def toeplitz_poly_mul(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: """Computes a poly multiplication mod (X^N + 1) where N = len(a).""" - multiplier = _generate_sign_matrix(len(a)) - left_matrix = multiplier * toeplitz(a).transpose() - return i32_as_u8_matmul(b, left_matrix.transpose()) + tpu_version = jax_helpers.get_tpu_version() + n = a.shape[-1] + if n % 128 == 0 and tpu_version >= 5: + toeplitzed = toeplitz_kernelized(a.astype(jnp.int32)) + return i32_as_u8_matmul(b, toeplitzed) + else: + # This branch is non-optimized, does not lower well on most platforms. + multiplier = _generate_sign_matrix(len(a)) + left_matrix = multiplier.transpose() * toeplitz(a) + return i32_as_u8_matmul(b, left_matrix) @jax.named_call