Skip to content

Commit

Permalink
Fix matrix_utils_test to be compatible with NumPy 2. Update dtype to …
Browse files Browse the repository at this point in the history
…jnp.int32 where applicable to avoid overflow errors.

PiperOrigin-RevId: 673207362
  • Loading branch information
code-perspective authored and copybara-github committed Sep 11, 2024
1 parent cd303ce commit 19969ee
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion jaxite/jaxite_lib/matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def x_power_n_minus_1(n: jnp.uint32, poly_mod_deg: jnp.uint32) -> jnp.ndarray:
degree = n % (2 * poly_mod_deg)
flip = degree // poly_mod_deg
reduced_degree = degree % poly_mod_deg
zeros = jnp.zeros(poly_mod_deg, dtype=jnp.uint32)
zeros = jnp.zeros(poly_mod_deg, dtype=jnp.int32)
return zeros.at[reduced_degree].set((-1) ** flip) - zeros.at[0].set(1)


Expand Down
8 changes: 4 additions & 4 deletions jaxite/jaxite_lib/matrix_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,22 @@ def test_monomial_mul_32_bit_modulus(self, degree, expected):
np.testing.assert_array_equal(expected, mono_mul_output)

def test_x_power_n_minus_1(self):
expected = jnp.array([-1, 0, 1, 0], dtype=jnp.uint32)
expected = jnp.array([-1, 0, 1, 0], dtype=jnp.int32)
actual = matrix_utils.x_power_n_minus_1(n=2, poly_mod_deg=4)
np.testing.assert_array_equal(expected, actual)

def test_x_power_n_minus_1_zero(self):
expected = jnp.array([0, 0, 0, 0], dtype=jnp.uint32)
expected = jnp.array([0, 0, 0, 0], dtype=jnp.int32)
actual = matrix_utils.x_power_n_minus_1(n=0, poly_mod_deg=4)
np.testing.assert_array_equal(expected, actual)

def test_x_power_n_minus_1_reduced_degree_with_sign_flip(self):
expected = jnp.array([-1, 0, -1, 0], dtype=jnp.uint32)
expected = jnp.array([-1, 0, -1, 0], dtype=jnp.int32)
actual = matrix_utils.x_power_n_minus_1(n=6, poly_mod_deg=4)
np.testing.assert_array_equal(expected, actual)

def test_x_power_n_minus_1_reduced_degree_without_sign_flip(self):
expected = jnp.array([-1, 0, 1, 0], dtype=jnp.uint32)
expected = jnp.array([-1, 0, 1, 0], dtype=jnp.int32)
actual = matrix_utils.x_power_n_minus_1(n=10, poly_mod_deg=4)
np.testing.assert_array_equal(expected, actual)

Expand Down

0 comments on commit 19969ee

Please sign in to comment.