diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index d0ccd39..0f6db53 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -23,7 +23,7 @@ jobs:
         with:
           path: |
             ~/.cache/bazel
-          key: ${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE', 'requirements.txt') }}
+          key: ${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE', 'requirements_dev.txt') }}
 
       - name: "Run `bazel build`"
         run: |
diff --git a/BUILD b/BUILD
index 1e7ff54..8da6d15 100644
--- a/BUILD
+++ b/BUILD
@@ -40,6 +40,8 @@ py_library(
     deps = [
         "@jaxite_deps_jax//:pkg",
         "@jaxite_deps_jaxlib//:pkg",
+        # copybara: jax:pallas_lib
+        # copybara: jax:pallas_tpu
     ],
 )
 
diff --git a/jaxite/jaxite_lib/matrix_utils.py b/jaxite/jaxite_lib/matrix_utils.py
index bfde41d..fb60d15 100644
--- a/jaxite/jaxite_lib/matrix_utils.py
+++ b/jaxite/jaxite_lib/matrix_utils.py
@@ -3,6 +3,8 @@ import functools
 
 import jax
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 
 
@@ -137,6 +139,60 @@ def toeplitz(x: jnp.ndarray) -> jnp.ndarray:
   return toeplitz(jnp.concatenate([x, r], axis=-1))
 
 
+@jax.jit
+def toeplitz_kernelized(x: jnp.ndarray) -> jnp.ndarray:
+  """Use pltpu.roll op to implement toeplitz + sign matrix.
+
+  Note:
+    * Only works on TPU v5+.
+    * Current implementation assumes
+      - both input and output can fit in VMEM.
+      - size of input is a multiple of 128.
+
+  Args:
+    x: the 1D array of length n to shift
+
+  Returns:
+    A 2D matrix of shape (n, n), with row i containing the input rolled
+    rightward i times, with the lower-diagonal sign-flipped.
+  """
+  if len(x.shape) == 1:
+    x = x.reshape(1, x.shape[0])
+  assert len(x.shape) == 2
+  n = x.shape[-1]
+  if n % 128 != 0:
+    raise ValueError(f"Input size {n} is not a multiple of 128")
+
+  if x.dtype != jnp.float32 and x.dtype != jnp.int32:
+    raise ValueError(f"Input {x.dtype} is not supported")
+
+  def _toeplitz(inp_ref, out_ref):
+    chunk = jnp.broadcast_to(inp_ref[...], (128, n))
+    chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
+    chunk_row_indices = jax.lax.broadcasted_iota(
+        dtype=jnp.int32, shape=(128, n), dimension=0
+    )
+    chunk_col_indices = jax.lax.broadcasted_iota(
+        dtype=jnp.int32, shape=(128, n), dimension=1
+    )
+    for r in range(0, n, 128):
+      out_ref[pl.ds(r, 128), slice(None)] = jnp.where(
+          chunk_row_indices > chunk_col_indices, -chunk, chunk
+      )
+      # Because the vector registers are aligned to size 128, this roll
+      # operation lowers to telling the TPU to refer to a different register,
+      # rather than actually applying any rolling operation. Hence, the op
+      # produces no hardware instructions.
+      chunk = pltpu.roll(chunk, 128, 1)
+      chunk_row_indices = chunk_row_indices + 128
+
+  return pl.pallas_call(
+      _toeplitz,
+      out_shape=jax.ShapeDtypeStruct((n, n), x.dtype),
+      interpret=(jax.default_backend() == "cpu"),
+  )(x)
+
+
 @jax.named_call
 @jax.jit
 def toeplitz_poly_mul(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
diff --git a/jaxite/jaxite_lib/matrix_utils_test.py b/jaxite/jaxite_lib/matrix_utils_test.py
index 7bdeedf..a068f13 100644
--- a/jaxite/jaxite_lib/matrix_utils_test.py
+++ b/jaxite/jaxite_lib/matrix_utils_test.py
@@ -4,6 +4,7 @@ import hypothesis
 from hypothesis import strategies
 import jax.numpy as jnp
+from jaxite.jaxite_lib import jax_helpers
 from jaxite.jaxite_lib import matrix_utils
 import numpy as np
 
 
@@ -235,6 +236,15 @@ def test_i32_as_u8_matmul(self, lhs, rhs):
     )
     np.testing.assert_array_equal(expected, actual)
 
+  @hypothesis.given(vectors(512))
+  @hypothesis.settings(deadline=None)
+  def test_toeplitz_kernelized(self, poly):
+    if jax_helpers.get_tpu_version() >= 5:
+      multiplier = matrix_utils._generate_sign_matrix(len(poly))
+      exp = multiplier.transpose() * matrix_utils.toeplitz(poly)
+      actual = matrix_utils.toeplitz_kernelized(poly)
+      np.testing.assert_array_equal(exp, actual)
+
   @hypothesis.given(strategies.integers(min_value=0, max_value=10), vectors(16))
   @hypothesis.settings(deadline=None)
   def test_scale_by_x_power_n_minus_1(self, power, poly):
diff --git a/requirements.txt b/requirements.txt
index 2f50fad..bc489a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-jax~=0.4.13
-jaxlib~=0.4.13
+jax~=0.4.24
+jaxlib~=0.4.24
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 99bb2b9..6d992a8 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -1,7 +1,7 @@
 attrs==23.1.0
 hypothesis==6.79.1
-jax==0.4.13
-jaxlib==0.4.13
+jax==0.4.24
+jaxlib==0.4.24
 ml-dtypes==0.2.0
 numpy==1.25.1
 opt-einsum==3.3.0
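Not part of the patch: a minimal usage sketch of the new kernel, mirroring the added test_toeplitz_kernelized test. It assumes a TPU v5+ backend (the test only runs the check when jax_helpers.get_tpu_version() >= 5; on CPU the kernel falls back to interpret=True per the diff), an input length that is a multiple of 128, and an int32 or float32 dtype. The length 512 and the jnp.arange input are arbitrary choices for illustration.

    import jax.numpy as jnp
    import numpy as np

    from jaxite.jaxite_lib import matrix_utils

    # Length must be a multiple of 128; dtype must be float32 or int32.
    poly = jnp.arange(512, dtype=jnp.int32)

    # Fused Pallas path: Toeplitz matrix with the lower diagonal sign-flipped,
    # built in a single pallas_call.
    fused = matrix_utils.toeplitz_kernelized(poly)

    # Reference path from the new test: plain Toeplitz matrix multiplied
    # elementwise by an explicit sign matrix.
    signs = matrix_utils._generate_sign_matrix(len(poly))
    reference = signs.transpose() * matrix_utils.toeplitz(poly)

    np.testing.assert_array_equal(reference, fused)

The fused kernel should produce the same matrix as building the Toeplitz matrix and applying the sign matrix separately, which is exactly what the new hypothesis test asserts.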