Skip to content

Commit

Permalink
Fix: Skip set intersections in Aarch64 emulator
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Sep 6, 2024
1 parent db2b793 commit b816617
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
44 changes: 22 additions & 22 deletions include/simsimd/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,18 @@ SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u16, size) // simsimd_intersect_u16_accu
SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u32, size) // simsimd_intersect_u32_accurate

#define SIMSIMD_MAKE_INTERSECT_GALLOPING(name, input_type, accumulator_type) \
SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const* b, \
simsimd_size_t start, simsimd_size_t b_length, \
SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const* array, \
simsimd_size_t start, simsimd_size_t length, \
simsimd_##input_type##_t val) { \
simsimd_size_t low = start; \
simsimd_size_t high = start + 1; \
while (high < b_length && b[high] < val) { \
while (high < length && array[high] < val) { \
low = high; \
high = (2 * high < b_length) ? 2 * high : b_length; \
high = (2 * high < length) ? 2 * high : length; \
} \
while (low < high) { \
simsimd_size_t mid = low + (high - low) / 2; \
if (b[mid] < val) { \
if (array[mid] < val) { \
low = mid + 1; \
} else { \
high = mid; \
Expand All @@ -112,31 +112,31 @@ SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u32, size) // simsimd_intersect_u32_accu
} \
\
SIMSIMD_PUBLIC void simsimd_intersect_##input_type##_##name( \
simsimd_##input_type##_t const* a, simsimd_##input_type##_t const* b, simsimd_size_t a_length, \
simsimd_size_t b_length, simsimd_distance_t* result) { \
/* Swap arrays if necessary, as we want "b" to be larger than "a" */ \
if (a_length > b_length) { \
simsimd_##input_type##_t const* temp = a; \
a = b; \
b = temp; \
simsimd_size_t temp_length = a_length; \
a_length = b_length; \
b_length = temp_length; \
simsimd_##input_type##_t const* shorter, simsimd_##input_type##_t const* longer, \
simsimd_size_t shorter_length, simsimd_size_t longer_length, simsimd_distance_t* result) { \
/* Swap arrays if necessary, as we want "longer" to be larger than "shorter" */ \
if (longer_length < shorter_length) { \
simsimd_##input_type##_t const* temp = shorter; \
shorter = longer; \
longer = temp; \
simsimd_size_t temp_length = shorter_length; \
shorter_length = longer_length; \
longer_length = temp_length; \
} \
\
/* Use accurate implementation if galloping is not beneficial */ \
if (b_length < 64 * a_length) { \
simsimd_intersect_##input_type##_accurate(a, b, a_length, b_length, result); \
/* Use the accurate implementation if galloping is not beneficial */ \
if (longer_length < 64 * shorter_length) { \
simsimd_intersect_##input_type##_accurate(shorter, longer, shorter_length, longer_length, result); \
return; \
} \
\
/* Perform galloping, shrinking the target range */ \
simsimd_##accumulator_type##_t intersection = 0; \
simsimd_size_t j = 0; \
for (simsimd_size_t i = 0; i < a_length; ++i) { \
simsimd_##input_type##_t ai = a[i]; \
j = simsimd_galloping_search_##input_type(b, j, b_length, ai); \
if (j < b_length && b[j] == ai) { \
for (simsimd_size_t i = 0; i < shorter_length; ++i) { \
simsimd_##input_type##_t shorter_i = shorter[i]; \
j = simsimd_galloping_search_##input_type(longer, j, longer_length, shorter_i); \
if (j < longer_length && longer[j] == shorter_i) { \
intersection++; \
} \
} \
Expand Down
22 changes: 15 additions & 7 deletions python/test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import platform

import pytest
import simsimd as simd

Expand Down Expand Up @@ -479,16 +481,22 @@ def test_dot_complex_explicit(ndim):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(300)
@pytest.mark.repeat(100)
@pytest.mark.parametrize("dtype", ["uint16", "uint32"])
@pytest.mark.parametrize("length_bound", [10, 25, 1000])
def test_intersect(dtype, length_bound):
@pytest.mark.parametrize("first_length_bound", [10, 100, 1000])
@pytest.mark.parametrize("second_length_bound", [10, 100, 1000])
def test_intersect(dtype, first_length_bound, second_length_bound):
"""Compares the simd.intersect() function with numpy.intersect1d."""

if is_running_under_qemu() and (platform.machine() == "aarch64" or platform.machine() == "arm64"):
pytest.skip("In QEMU `aarch64` emulation on `x86_64` the `intersect` function is not reliable")

np.random.seed()
a_length = np.random.randint(1, length_bound)
b_length = np.random.randint(1, length_bound)
a = np.random.randint(length_bound * 2, size=a_length, dtype=dtype)
b = np.random.randint(length_bound * 2, size=b_length, dtype=dtype)

a_length = np.random.randint(1, first_length_bound)
b_length = np.random.randint(1, second_length_bound)
a = np.random.randint(first_length_bound * 2, size=a_length, dtype=dtype)
b = np.random.randint(second_length_bound * 2, size=b_length, dtype=dtype)

# Remove duplicates, converting into sorted arrays
a = np.unique(a)
Expand Down

0 comments on commit b816617

Please sign in to comment.