From d5f3f09cb6c00b973c049f6b427b636fa3140c69 Mon Sep 17 00:00:00 2001 From: Shruthi Gorantala Date: Mon, 16 Sep 2024 09:45:18 -0700 Subject: [PATCH] Fix test_poly_mul to be compatible with NumPy 2. The issue is recasting float64 to int32. jnp.int32() seems to be rounding down overflows to signed_int32.MAX and signed_int32.MIN. PiperOrigin-RevId: 675184853 --- jaxite/jaxite_lib/matrix_utils_test.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/jaxite/jaxite_lib/matrix_utils_test.py b/jaxite/jaxite_lib/matrix_utils_test.py index c6ec149..c897829 100644 --- a/jaxite/jaxite_lib/matrix_utils_test.py +++ b/jaxite/jaxite_lib/matrix_utils_test.py @@ -62,6 +62,16 @@ def test_get_cyclic_matrix(self): cyclic_matrix = matrix_utils.toeplitz(inp) self.assertEqual(cyclic_matrix.tolist(), [[1, 9, 2], [2, 1, 9], [9, 2, 1]]) + def cast_float64_to_int32(self, x): + i32max = np.iinfo(np.int32).max + i32min = np.iinfo(np.int32).min + if x > i32max: + return jnp.int32(i32min + x % (i32max + 1)) + elif x < i32min: + return jnp.int32((x % i32min) + i32max + 1) + else: + return jnp.int32(x) + def _np_polymul(self, poly1, poly2, q_mod=None): # poly_mod represents the polynomial to divide by: x^N + 1, N = len(a) poly_mod = jnp.zeros(len(poly1) + 1, jnp.uint32) @@ -79,7 +89,10 @@ def _np_polymul(self, poly1, poly2, q_mod=None): 'constant', constant_values=(0, 0), ) - return jnp.array(list(reversed(np_pad)), dtype=int) + result = jnp.array([ + self.cast_float64_to_int32(x) for x in list(reversed(np_pad)) + ]) + return result @hypothesis.given( vectors(10), @@ -93,7 +106,7 @@ def test_poly_mul(self, poly1, poly2, impl): jnp.array(poly1, dtype=int), jnp.array(poly2, dtype=int), ) - np.testing.assert_array_equal(expected, actual) + np.testing.assert_array_equal(actual, expected) @parameterized.named_parameters( dict(testcase_name='no_mul', degree=0, expected=[0, 1, 2, 3]),