diff --git a/python/test.py b/python/test.py index f58da743..d1d57612 100644 --- a/python/test.py +++ b/python/test.py @@ -479,15 +479,16 @@ def test_dot_complex_explicit(ndim): @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") -@pytest.mark.repeat(200) +@pytest.mark.repeat(300) @pytest.mark.parametrize("dtype", ["uint16", "uint32"]) -def test_intersect(dtype): +@pytest.mark.parametrize("length_bound", [10, 25, 1000]) +def test_intersect(dtype, length_bound): """Compares the simd.intersect() function with numpy.intersect1d.""" np.random.seed() - a_length = np.random.randint(1, 1024) - b_length = np.random.randint(1, 1024) - a = np.random.randint(2048, size=a_length, dtype=dtype) - b = np.random.randint(2048, size=b_length, dtype=dtype) + a_length = np.random.randint(1, length_bound) + b_length = np.random.randint(1, length_bound) + a = np.random.randint(length_bound * 2, size=a_length, dtype=dtype) + b = np.random.randint(length_bound * 2, size=b_length, dtype=dtype) # Remove duplicates, converting into sorted arrays a = np.unique(a) @@ -496,7 +497,7 @@ def test_intersect(dtype): expected = baseline_intersect(a, b) result = simd.intersect(a, b) - assert int(expected) == int(result) + assert int(expected) == int(result), f"Missing {np.intersect1d(a, b)} from {a} and {b}" @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")