Skip to content

Commit

Permalink
Speed up sample extraction
Browse files Browse the repository at this point in the history
The prior code was using bad jax practices, mainly in how it set up the extraction_indices and used that as an index to the array, which lowered to a loop on TPU. This shaves off about 0.5ms from bootstrap.

PiperOrigin-RevId: 607031190
  • Loading branch information
j2kun authored and copybara-github committed Feb 14, 2024
1 parent a3fd4a9 commit 94a56b5
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,27 +526,22 @@ def jit_sample_extract(
An LWE encryption of the constant term of the input polynomial.
"""
k = jnp.shape(rlwe_ciphertext)[0] - 1 # rlwe_dimension

# generates an array like [1, -1, -1, -1, ..., -1] of length poly_deg
# then tiles it vertically rlwe_dimension times.
extraction_coefficients = jnp.tile(
jnp.concatenate([
jnp.ones(1, dtype=jnp.uint32),
-1 * jnp.ones(poly_deg - 1, dtype=jnp.uint32),
]),
(k, 1),
ones = jnp.ones(poly_deg, dtype=jnp.int32)
indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=ones.shape, dimension=0
)
signed_ones = jnp.where(indices > 0, -ones, ones)
extraction_coefficients = jnp.broadcast_to(signed_ones, (k, poly_deg))

# generates an array like [0, poly_deg-1, poly_deg-2, ..., 1] of len poly_deg,
extraction_indices = jnp.flip(jnp.roll(jnp.arange(poly_deg), poly_deg - 1))

extracted_sample = (
extraction_coefficients * rlwe_ciphertext[:-1, extraction_indices]
# extracts the rlwe_ciphertext into a matrix, accessing the last axis by
# indices via [0, poly_deg-1, poly_deg-2, ..., 1]
extracted_values = jnp.flip(
jnp.roll(rlwe_ciphertext[:-1, :], -1, axis=-1), axis=-1
)

extracted_sample = extraction_coefficients * extracted_values
b_term_constant_coeff = rlwe_ciphertext[-1][0]

return jnp.append(
extracted_sample.reshape((poly_deg * k,)),
extracted_sample.flatten(),
jnp.array([b_term_constant_coeff], dtype=jnp.uint32),
)

0 comments on commit 94a56b5

Please sign in to comment.