From c25fb15fe3901c987cd318fb69bfc4b1ec53d5ba Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 13 Feb 2024 19:24:18 -0800 Subject: [PATCH] Implement TPU kernel for signed toeplitz This CL just adds the kernel to the codebase, without incorporating it into the cryptosystem yet. PiperOrigin-RevId: 606823780 --- BUILD | 2 + jaxite/jaxite_lib/matrix_utils.py | 56 ++++++++++++++++++++++++++ jaxite/jaxite_lib/matrix_utils_test.py | 10 +++++ 3 files changed, 68 insertions(+) diff --git a/BUILD b/BUILD index 1e7ff54..8ecb56d 100644 --- a/BUILD +++ b/BUILD @@ -38,6 +38,8 @@ py_library( ), visibility = [":internal"], deps = [ + "# copybara: jax:pallas_lib", + "# copybara: jax:pallas_tpu", "@jaxite_deps_jax//:pkg", "@jaxite_deps_jaxlib//:pkg", ], diff --git a/jaxite/jaxite_lib/matrix_utils.py b/jaxite/jaxite_lib/matrix_utils.py index bfde41d..16b3626 100644 --- a/jaxite/jaxite_lib/matrix_utils.py +++ b/jaxite/jaxite_lib/matrix_utils.py @@ -3,6 +3,8 @@ import functools import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -137,6 +139,60 @@ def toeplitz(x: jnp.ndarray) -> jnp.ndarray: return toeplitz(jnp.concatenate([x, r], axis=-1)) +@jax.jit +def toeplitz_kernelized(x: jnp.ndarray) -> jnp.ndarray: + """Use pltpu.roll op to implement toeplitz + sign matrix. + + Note: + * Only works on TPU v5+. + * Current implementation assumes + - both input and output can fit in VMEM. + - size of input is a multiple of 128. + + Args: + x: the 1D array to shift of length n + + Returns: + A 2D matrix of shape (n, n), with row i containing the input rolled + rightward i times, with the lower-diagonal sign-flipped. + """ + if len(x.shape) == 1: + x = x.reshape(1, x.shape[0]) + assert len(x.shape) == 2 + n = x.shape[-1] + if n % 128 != 0: + raise ValueError(f"Input size {n} is not a multiple of 128") + + if x.dtype != jnp.float32 and x.dtype != jnp.int32: + raise ValueError(f"Input {x.dtype} is not supported") + + def _toeplitz(inp_ref, out_ref): + chunk = jnp.broadcast_to(inp_ref[...], (128, n)) + chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0) + chunk_row_indices = jax.lax.broadcasted_iota( + dtype=jnp.int32, shape=(128, n), dimension=0 + ) + chunk_col_indices = jax.lax.broadcasted_iota( + dtype=jnp.int32, shape=(128, n), dimension=1 + ) + for r in range(0, n, 128): + out_ref[pl.ds(r, 128), slice(None)] = jnp.where( + chunk_row_indices > chunk_col_indices, -chunk, chunk + ) + # Because the vector registers are aligned to size 128, this roll + # operation lowers to telling the TPU to refer to a different register, + # rather than actually apply and rolling operation. Hence, the op is + # compiled away. + chunk = pltpu.roll(chunk, 128, 1) + chunk_row_indices = chunk_row_indices + 128 + + return pl.pallas_call( + _toeplitz, + out_shape=jax.ShapeDtypeStruct((n, n), x.dtype), + interpreted=(jax.default_backend() == "cpu"), + )(x) + + @jax.named_call @jax.jit def toeplitz_poly_mul(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: diff --git a/jaxite/jaxite_lib/matrix_utils_test.py b/jaxite/jaxite_lib/matrix_utils_test.py index 7bdeedf..a068f13 100644 --- a/jaxite/jaxite_lib/matrix_utils_test.py +++ b/jaxite/jaxite_lib/matrix_utils_test.py @@ -4,6 +4,7 @@ import hypothesis from hypothesis import strategies import jax.numpy as jnp +from jaxite.jaxite_lib import jax_helpers from jaxite.jaxite_lib import matrix_utils import numpy as np @@ -235,6 +236,15 @@ def test_i32_as_u8_matmul(self, lhs, rhs): ) np.testing.assert_array_equal(expected, actual) + @hypothesis.given(vectors(512)) + @hypothesis.settings(deadline=None) + def test_toeplitz_kernelized(self, poly): + if jax_helpers.get_tpu_version() >= 5: + multiplier = matrix_utils._generate_sign_matrix(len(poly)) + exp = multiplier.transpose() * matrix_utils.toeplitz(poly) + actual = matrix_utils.toeplitz_kernelized(poly) + np.testing.assert_array_equal(exp, actual) + @hypothesis.given(strategies.integers(min_value=0, max_value=10), vectors(16)) @hypothesis.settings(deadline=None) def test_scale_by_x_power_n_minus_1(self, power, poly):