Skip to content

Commit

Permalink
Implement TPU kernel for signed toeplitz
Browse files Browse the repository at this point in the history
This CL just adds the kernel to the codebase, without incorporating it into the cryptosystem yet.

PiperOrigin-RevId: 606823780
  • Loading branch information
j2kun authored and copybara-github committed Feb 14, 2024
1 parent 39e6346 commit c25fb15
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
2 changes: 2 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ py_library(
),
visibility = [":internal"],
deps = [
"# copybara: jax:pallas_lib",
"# copybara: jax:pallas_tpu",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
Expand Down
56 changes: 56 additions & 0 deletions jaxite/jaxite_lib/matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp


Expand Down Expand Up @@ -137,6 +139,60 @@ def toeplitz(x: jnp.ndarray) -> jnp.ndarray:
return toeplitz(jnp.concatenate([x, r], axis=-1))


@jax.jit
def toeplitz_kernelized(x: jnp.ndarray) -> jnp.ndarray:
"""Use pltpu.roll op to implement toeplitz + sign matrix.
Note:
* Only works on TPU v5+.
* Current implementation assumes
- both input and output can fit in VMEM.
- size of input is a multiple of 128.
Args:
x: the 1D array to shift of length n
Returns:
A 2D matrix of shape (n, n), with row i containing the input rolled
rightward i times, with the lower-diagonal sign-flipped.
"""
if len(x.shape) == 1:
x = x.reshape(1, x.shape[0])
assert len(x.shape) == 2
n = x.shape[-1]
if n % 128 != 0:
raise ValueError(f"Input size {n} is not a multiple of 128")

if x.dtype != jnp.float32 and x.dtype != jnp.int32:
raise ValueError(f"Input {x.dtype} is not supported")

def _toeplitz(inp_ref, out_ref):
chunk = jnp.broadcast_to(inp_ref[...], (128, n))
chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
chunk_row_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=0
)
chunk_col_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=1
)
for r in range(0, n, 128):
out_ref[pl.ds(r, 128), slice(None)] = jnp.where(
chunk_row_indices > chunk_col_indices, -chunk, chunk
)
# Because the vector registers are aligned to size 128, this roll
# operation lowers to telling the TPU to refer to a different register,
# rather than actually apply and rolling operation. Hence, the op is
# compiled away.
chunk = pltpu.roll(chunk, 128, 1)
chunk_row_indices = chunk_row_indices + 128

return pl.pallas_call(
_toeplitz,
out_shape=jax.ShapeDtypeStruct((n, n), x.dtype),
interpreted=(jax.default_backend() == "cpu"),
)(x)


@jax.named_call
@jax.jit
def toeplitz_poly_mul(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
Expand Down
10 changes: 10 additions & 0 deletions jaxite/jaxite_lib/matrix_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hypothesis
from hypothesis import strategies
import jax.numpy as jnp
from jaxite.jaxite_lib import jax_helpers
from jaxite.jaxite_lib import matrix_utils
import numpy as np

Expand Down Expand Up @@ -235,6 +236,15 @@ def test_i32_as_u8_matmul(self, lhs, rhs):
)
np.testing.assert_array_equal(expected, actual)

@hypothesis.given(vectors(512))
@hypothesis.settings(deadline=None)
def test_toeplitz_kernelized(self, poly):
if jax_helpers.get_tpu_version() >= 5:
multiplier = matrix_utils._generate_sign_matrix(len(poly))
exp = multiplier.transpose() * matrix_utils.toeplitz(poly)
actual = matrix_utils.toeplitz_kernelized(poly)
np.testing.assert_array_equal(exp, actual)

@hypothesis.given(strategies.integers(min_value=0, max_value=10), vectors(16))
@hypothesis.settings(deadline=None)
def test_scale_by_x_power_n_minus_1(self, power, poly):
Expand Down

0 comments on commit c25fb15

Please sign in to comment.