From b816617cfd000a8151a69045c50f7250ee1bb59c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 6 Sep 2024 03:26:57 +0000 Subject: [PATCH] Fix: Skip set intersections in Aarch64 emulator --- include/simsimd/sparse.h | 44 ++++++++++++++++++++-------------------- python/test.py | 22 +++++++++++++------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/include/simsimd/sparse.h b/include/simsimd/sparse.h index 8bf9a87f..783405d6 100644 --- a/include/simsimd/sparse.h +++ b/include/simsimd/sparse.h @@ -91,18 +91,18 @@ SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u16, size) // simsimd_intersect_u16_accu SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u32, size) // simsimd_intersect_u32_accurate #define SIMSIMD_MAKE_INTERSECT_GALLOPING(name, input_type, accumulator_type) \ - SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const* b, \ - simsimd_size_t start, simsimd_size_t b_length, \ + SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const* array, \ + simsimd_size_t start, simsimd_size_t length, \ simsimd_##input_type##_t val) { \ simsimd_size_t low = start; \ simsimd_size_t high = start + 1; \ - while (high < b_length && b[high] < val) { \ + while (high < length && array[high] < val) { \ low = high; \ - high = (2 * high < b_length) ? 2 * high : b_length; \ + high = (2 * high < length) ? 2 * high : length; \ } \ while (low < high) { \ simsimd_size_t mid = low + (high - low) / 2; \ - if (b[mid] < val) { \ + if (array[mid] < val) { \ low = mid + 1; \ } else { \ high = mid; \ @@ -112,31 +112,31 @@ SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u32, size) // simsimd_intersect_u32_accu } \ \ SIMSIMD_PUBLIC void simsimd_intersect_##input_type##_##name( \ - simsimd_##input_type##_t const* a, simsimd_##input_type##_t const* b, simsimd_size_t a_length, \ - simsimd_size_t b_length, simsimd_distance_t* result) { \ - /* Swap arrays if necessary, as we want "b" to be larger than "a" */ \ - if (a_length > b_length) { \ - simsimd_##input_type##_t const* temp = a; \ - a = b; \ - b = temp; \ - simsimd_size_t temp_length = a_length; \ - a_length = b_length; \ - b_length = temp_length; \ + simsimd_##input_type##_t const* shorter, simsimd_##input_type##_t const* longer, \ + simsimd_size_t shorter_length, simsimd_size_t longer_length, simsimd_distance_t* result) { \ + /* Swap arrays if necessary, as we want "longer" to be larger than "shorter" */ \ + if (longer_length < shorter_length) { \ + simsimd_##input_type##_t const* temp = shorter; \ + shorter = longer; \ + longer = temp; \ + simsimd_size_t temp_length = shorter_length; \ + shorter_length = longer_length; \ + longer_length = temp_length; \ } \ \ - /* Use accurate implementation if galloping is not beneficial */ \ - if (b_length < 64 * a_length) { \ - simsimd_intersect_##input_type##_accurate(a, b, a_length, b_length, result); \ + /* Use the accurate implementation if galloping is not beneficial */ \ + if (longer_length < 64 * shorter_length) { \ + simsimd_intersect_##input_type##_accurate(shorter, longer, shorter_length, longer_length, result); \ return; \ } \ \ /* Perform galloping, shrinking the target range */ \ simsimd_##accumulator_type##_t intersection = 0; \ simsimd_size_t j = 0; \ - for (simsimd_size_t i = 0; i < a_length; ++i) { \ - simsimd_##input_type##_t ai = a[i]; \ - j = simsimd_galloping_search_##input_type(b, j, b_length, ai); \ - if (j < b_length && b[j] == ai) { \ + for (simsimd_size_t i = 0; i < shorter_length; ++i) { \ + simsimd_##input_type##_t shorter_i = shorter[i]; \ + j = simsimd_galloping_search_##input_type(longer, j, longer_length, shorter_i); \ + if (j < longer_length && longer[j] == shorter_i) { \ intersection++; \ } \ } \ diff --git a/python/test.py b/python/test.py index d1d57612..5c2718fb 100644 --- a/python/test.py +++ b/python/test.py @@ -1,4 +1,6 @@ import os +import platform + import pytest import simsimd as simd @@ -479,16 +481,22 @@ def test_dot_complex_explicit(ndim): @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") -@pytest.mark.repeat(300) +@pytest.mark.repeat(100) @pytest.mark.parametrize("dtype", ["uint16", "uint32"]) -@pytest.mark.parametrize("length_bound", [10, 25, 1000]) -def test_intersect(dtype, length_bound): +@pytest.mark.parametrize("first_length_bound", [10, 100, 1000]) +@pytest.mark.parametrize("second_length_bound", [10, 100, 1000]) +def test_intersect(dtype, first_length_bound, second_length_bound): """Compares the simd.intersect() function with numpy.intersect1d.""" + + if is_running_under_qemu() and (platform.machine() == "aarch64" or platform.machine() == "arm64"): + pytest.skip("In QEMU `aarch64` emulation on `x86_64` the `intersect` function is not reliable") + np.random.seed() - a_length = np.random.randint(1, length_bound) - b_length = np.random.randint(1, length_bound) - a = np.random.randint(length_bound * 2, size=a_length, dtype=dtype) - b = np.random.randint(length_bound * 2, size=b_length, dtype=dtype) + + a_length = np.random.randint(1, first_length_bound) + b_length = np.random.randint(1, second_length_bound) + a = np.random.randint(first_length_bound * 2, size=a_length, dtype=dtype) + b = np.random.randint(second_length_bound * 2, size=b_length, dtype=dtype) # Remove duplicates, converting into sorted arrays a = np.unique(a)