Skip to content

Commit

Permalink
Fix test_poly_mul to be compatible with NumPy 2. The issue is recasti…
Browse files Browse the repository at this point in the history
…ng float64 to int32. jnp.int32() seems to be rounding down overflows to signed_int32.MAX and signed_int32.MIN.

PiperOrigin-RevId: 675184853
  • Loading branch information
code-perspective authored and copybara-github committed Sep 16, 2024
1 parent 19969ee commit d5f3f09
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions jaxite/jaxite_lib/matrix_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ def test_get_cyclic_matrix(self):
cyclic_matrix = matrix_utils.toeplitz(inp)
self.assertEqual(cyclic_matrix.tolist(), [[1, 9, 2], [2, 1, 9], [9, 2, 1]])

def cast_float64_to_int32(self, x):
i32max = np.iinfo(np.int32).max
i32min = np.iinfo(np.int32).min
if x > i32max:
return jnp.int32(i32min + x % (i32max + 1))
elif x < i32min:
return jnp.int32((x % i32min) + i32max + 1)
else:
return jnp.int32(x)

def _np_polymul(self, poly1, poly2, q_mod=None):
# poly_mod represents the polynomial to divide by: x^N + 1, N = len(a)
poly_mod = jnp.zeros(len(poly1) + 1, jnp.uint32)
Expand All @@ -79,7 +89,10 @@ def _np_polymul(self, poly1, poly2, q_mod=None):
'constant',
constant_values=(0, 0),
)
return jnp.array(list(reversed(np_pad)), dtype=int)
result = jnp.array([
self.cast_float64_to_int32(x) for x in list(reversed(np_pad))
])
return result

@hypothesis.given(
vectors(10),
Expand All @@ -93,7 +106,7 @@ def test_poly_mul(self, poly1, poly2, impl):
jnp.array(poly1, dtype=int),
jnp.array(poly2, dtype=int),
)
np.testing.assert_array_equal(expected, actual)
np.testing.assert_array_equal(actual, expected)

@parameterized.named_parameters(
dict(testcase_name='no_mul', degree=0, expected=[0, 1, 2, 3]),
Expand Down

0 comments on commit d5f3f09

Please sign in to comment.