From 191de7c97ece2214c2878b30a19fc66c9370a342 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 24 Nov 2024 13:09:19 +0000 Subject: [PATCH] Improve: Use complex types for `dense.h` --- c/lib.c | 107 ++-- include/simsimd/dot.h | 1054 ++++++++++++++++++------------------- include/simsimd/simsimd.h | 36 +- python/lib.c | 47 +- scripts/test.c | 22 +- scripts/test.py | 70 +-- 6 files changed, 675 insertions(+), 661 deletions(-) diff --git a/c/lib.c b/c/lib.c index 42872632..815f8b29 100644 --- a/c/lib.c +++ b/c/lib.c @@ -61,9 +61,10 @@ extern "C" { // If no metric is found, it returns NaN. We can obtain NaN by dividing 0.0 by 0.0, but that annoys // the MSVC compiler. Instead we can directly write-in the signaling NaN (0x7FF0000000000001) // or the qNaN (0x7FF8000000000000). -#define SIMSIMD_DECLARATION_DENSE(name, extension, type) \ - SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const *a, simsimd_##type##_t const *b, \ - simsimd_size_t n, simsimd_distance_t *results) { \ +#define SIMSIMD_DECLARATION_DENSE(name, extension) \ + SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##extension##_t const *a, \ + simsimd_##extension##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *results) { \ static simsimd_metric_dense_punned_t metric = 0; \ if (metric == 0) { \ simsimd_capability_t used_capability; \ @@ -143,54 +144,54 @@ extern "C" { } // Dot products -SIMSIMD_DECLARATION_DENSE(dot, i8, i8) -SIMSIMD_DECLARATION_DENSE(dot, u8, u8) -SIMSIMD_DECLARATION_DENSE(dot, f16, f16) -SIMSIMD_DECLARATION_DENSE(dot, bf16, bf16) -SIMSIMD_DECLARATION_DENSE(dot, f32, f32) -SIMSIMD_DECLARATION_DENSE(dot, f64, f64) -SIMSIMD_DECLARATION_DENSE(dot, f16c, f16) -SIMSIMD_DECLARATION_DENSE(dot, bf16c, bf16) -SIMSIMD_DECLARATION_DENSE(dot, f32c, f32) -SIMSIMD_DECLARATION_DENSE(dot, f64c, f64) -SIMSIMD_DECLARATION_DENSE(vdot, f16c, f16) -SIMSIMD_DECLARATION_DENSE(vdot, bf16c, bf16) -SIMSIMD_DECLARATION_DENSE(vdot, f32c, f32) -SIMSIMD_DECLARATION_DENSE(vdot, f64c, f64) +SIMSIMD_DECLARATION_DENSE(dot, i8) +SIMSIMD_DECLARATION_DENSE(dot, u8) +SIMSIMD_DECLARATION_DENSE(dot, f16) +SIMSIMD_DECLARATION_DENSE(dot, bf16) +SIMSIMD_DECLARATION_DENSE(dot, f32) +SIMSIMD_DECLARATION_DENSE(dot, f64) +SIMSIMD_DECLARATION_DENSE(dot, f16c) +SIMSIMD_DECLARATION_DENSE(dot, bf16c) +SIMSIMD_DECLARATION_DENSE(dot, f32c) +SIMSIMD_DECLARATION_DENSE(dot, f64c) +SIMSIMD_DECLARATION_DENSE(vdot, f16c) +SIMSIMD_DECLARATION_DENSE(vdot, bf16c) +SIMSIMD_DECLARATION_DENSE(vdot, f32c) +SIMSIMD_DECLARATION_DENSE(vdot, f64c) // Spatial distances -SIMSIMD_DECLARATION_DENSE(cos, i8, i8) -SIMSIMD_DECLARATION_DENSE(cos, u8, u8) -SIMSIMD_DECLARATION_DENSE(cos, f16, f16) -SIMSIMD_DECLARATION_DENSE(cos, bf16, bf16) -SIMSIMD_DECLARATION_DENSE(cos, f32, f32) -SIMSIMD_DECLARATION_DENSE(cos, f64, f64) -SIMSIMD_DECLARATION_DENSE(l2sq, i8, i8) -SIMSIMD_DECLARATION_DENSE(l2sq, u8, u8) -SIMSIMD_DECLARATION_DENSE(l2sq, f16, f16) -SIMSIMD_DECLARATION_DENSE(l2sq, bf16, bf16) -SIMSIMD_DECLARATION_DENSE(l2sq, f32, f32) -SIMSIMD_DECLARATION_DENSE(l2sq, f64, f64) -SIMSIMD_DECLARATION_DENSE(l2, i8, i8) -SIMSIMD_DECLARATION_DENSE(l2, u8, u8) -SIMSIMD_DECLARATION_DENSE(l2, f16, f16) -SIMSIMD_DECLARATION_DENSE(l2, bf16, bf16) -SIMSIMD_DECLARATION_DENSE(l2, f32, f32) -SIMSIMD_DECLARATION_DENSE(l2, f64, f64) +SIMSIMD_DECLARATION_DENSE(cos, i8) +SIMSIMD_DECLARATION_DENSE(cos, u8) +SIMSIMD_DECLARATION_DENSE(cos, f16) +SIMSIMD_DECLARATION_DENSE(cos, bf16) +SIMSIMD_DECLARATION_DENSE(cos, f32) +SIMSIMD_DECLARATION_DENSE(cos, f64) +SIMSIMD_DECLARATION_DENSE(l2sq, i8) +SIMSIMD_DECLARATION_DENSE(l2sq, u8) +SIMSIMD_DECLARATION_DENSE(l2sq, f16) +SIMSIMD_DECLARATION_DENSE(l2sq, bf16) +SIMSIMD_DECLARATION_DENSE(l2sq, f32) +SIMSIMD_DECLARATION_DENSE(l2sq, f64) +SIMSIMD_DECLARATION_DENSE(l2, i8) +SIMSIMD_DECLARATION_DENSE(l2, u8) +SIMSIMD_DECLARATION_DENSE(l2, f16) +SIMSIMD_DECLARATION_DENSE(l2, bf16) +SIMSIMD_DECLARATION_DENSE(l2, f32) +SIMSIMD_DECLARATION_DENSE(l2, f64) // Binary distances -SIMSIMD_DECLARATION_DENSE(hamming, b8, b8) -SIMSIMD_DECLARATION_DENSE(jaccard, b8, b8) +SIMSIMD_DECLARATION_DENSE(hamming, b8) +SIMSIMD_DECLARATION_DENSE(jaccard, b8) // Probability distributions -SIMSIMD_DECLARATION_DENSE(kl, f16, f16) -SIMSIMD_DECLARATION_DENSE(kl, bf16, bf16) -SIMSIMD_DECLARATION_DENSE(kl, f32, f32) -SIMSIMD_DECLARATION_DENSE(kl, f64, f64) -SIMSIMD_DECLARATION_DENSE(js, f16, f16) -SIMSIMD_DECLARATION_DENSE(js, bf16, bf16) -SIMSIMD_DECLARATION_DENSE(js, f32, f32) -SIMSIMD_DECLARATION_DENSE(js, f64, f64) +SIMSIMD_DECLARATION_DENSE(kl, f16) +SIMSIMD_DECLARATION_DENSE(kl, bf16) +SIMSIMD_DECLARATION_DENSE(kl, f32) +SIMSIMD_DECLARATION_DENSE(kl, f64) +SIMSIMD_DECLARATION_DENSE(js, f16) +SIMSIMD_DECLARATION_DENSE(js, bf16) +SIMSIMD_DECLARATION_DENSE(js, f32) +SIMSIMD_DECLARATION_DENSE(js, f64) // Sparse sets SIMSIMD_DECLARATION_SPARSE(intersect, u16, u16) @@ -259,14 +260,14 @@ SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void) { simsimd_dot_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); simsimd_dot_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); - simsimd_dot_f16c((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); - simsimd_dot_bf16c((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); - simsimd_dot_f32c((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); - simsimd_dot_f64c((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); - simsimd_vdot_f16c((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); - simsimd_vdot_bf16c((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); - simsimd_vdot_f32c((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); - simsimd_vdot_f64c((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + simsimd_dot_f16c((simsimd_f16c_t *)x, (simsimd_f16c_t *)x, 0, dummy_results); + simsimd_dot_bf16c((simsimd_bf16c_t *)x, (simsimd_bf16c_t *)x, 0, dummy_results); + simsimd_dot_f32c((simsimd_f32c_t *)x, (simsimd_f32c_t *)x, 0, dummy_results); + simsimd_dot_f64c((simsimd_f64c_t *)x, (simsimd_f64c_t *)x, 0, dummy_results); + simsimd_vdot_f16c((simsimd_f16c_t *)x, (simsimd_f16c_t *)x, 0, dummy_results); + simsimd_vdot_bf16c((simsimd_bf16c_t *)x, (simsimd_bf16c_t *)x, 0, dummy_results); + simsimd_vdot_f32c((simsimd_f32c_t *)x, (simsimd_f32c_t *)x, 0, dummy_results); + simsimd_vdot_f64c((simsimd_f64c_t *)x, (simsimd_f64c_t *)x, 0, dummy_results); simsimd_cos_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results); simsimd_cos_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results); diff --git a/include/simsimd/dot.h b/include/simsimd/dot.h index 21445b50..d2449dcf 100644 --- a/include/simsimd/dot.h +++ b/include/simsimd/dot.h @@ -39,20 +39,20 @@ extern "C" { * For double-precision computation check out the "*_accurate" variants of those "*_serial" functions. */ SIMSIMD_PUBLIC void simsimd_dot_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f64c_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f64c_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f64c_serial(simsimd_f64c_t const* a, simsimd_f64c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f64c_serial(simsimd_f64c_t const* a, simsimd_f64c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f32c_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f32c_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f32c_serial(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_serial(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f16c_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f16c_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f16c_serial(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_serial(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_bf16_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_bf16c_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_bf16c_serial(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_serial(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_u8_serial(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* result); @@ -61,50 +61,50 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_serial(simsimd_u8_t const* a, simsimd_u8_t co * For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions. */ SIMSIMD_PUBLIC void simsimd_dot_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f32c_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f32c_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f32c_accurate(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_accurate(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f16c_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f16c_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f16c_accurate(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_accurate(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_bf16c_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_bf16c_accurate(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_accurate(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); /* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words. * By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all * server CPUs produced before 2023. */ SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); /* SIMD-powered backends for Arm SVE, mostly using 32-bit arithmetic over variable-length platform-defined word sizes. * Designed for Arm Graviton 3, Microsoft Cobalt, as well as Nvidia Grace and newer Ampere Altra CPUs. */ SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64c_t const* a, simsimd_f64c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64c_t const* a, simsimd_f64c_t const* b, simsimd_size_t n, simsimd_distance_t* results); /* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words. * First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420. @@ -113,12 +113,12 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t * properly vectorized by recent compilers. */ SIMSIMD_PUBLIC void simsimd_dot_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); -SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); @@ -132,23 +132,23 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t c * Sapphire Rapids added tiled matrix operations, but we are most interested in the new mixed-precision FMA instructions. */ SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64c_t const* a, simsimd_f64c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64c_t const* a, simsimd_f64c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32c_t const* a, simsimd_f32c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); // clang-format on @@ -166,70 +166,70 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t co *result = ab; \ } -#define SIMSIMD_MAKE_COMPLEX_DOT(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_dot_##input_type##c_##name(simsimd_##input_type##_t const *a, \ - simsimd_##input_type##_t const *b, simsimd_size_t n, \ - simsimd_distance_t *results) { \ - simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ - for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ - simsimd_##accumulator_type##_t ar = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t br = load_and_convert(b + i); \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i + 1); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i + 1); \ - ab_real += ar * br - ai * bi; \ - ab_imag += ar * bi + ai * br; \ - } \ - results[0] = ab_real; \ - results[1] = ab_imag; \ +#define SIMSIMD_MAKE_COMPLEX_DOT(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_dot_##input_type##_##name(simsimd_##input_type##_t const *a_pairs, \ + simsimd_##input_type##_t const *b_pairs, \ + simsimd_size_t count_pairs, simsimd_distance_t *results) { \ + simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ + for (simsimd_size_t i = 0; i != count_pairs; ++i) { \ + simsimd_##accumulator_type##_t ar = load_and_convert(&(a_pairs + i)->real); \ + simsimd_##accumulator_type##_t br = load_and_convert(&(b_pairs + i)->real); \ + simsimd_##accumulator_type##_t ai = load_and_convert(&(a_pairs + i)->imag); \ + simsimd_##accumulator_type##_t bi = load_and_convert(&(b_pairs + i)->imag); \ + ab_real += ar * br - ai * bi; \ + ab_imag += ar * bi + ai * br; \ + } \ + results[0] = ab_real; \ + results[1] = ab_imag; \ } -#define SIMSIMD_MAKE_COMPLEX_VDOT(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_vdot_##input_type##c_##name(simsimd_##input_type##_t const *a, \ - simsimd_##input_type##_t const *b, simsimd_size_t n, \ - simsimd_distance_t *results) { \ - simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ - for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ - simsimd_##accumulator_type##_t ar = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t br = load_and_convert(b + i); \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i + 1); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i + 1); \ - ab_real += ar * br + ai * bi; \ - ab_imag += ar * bi - ai * br; \ - } \ - results[0] = ab_real; \ - results[1] = ab_imag; \ +#define SIMSIMD_MAKE_COMPLEX_VDOT(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_vdot_##input_type##_##name(simsimd_##input_type##_t const *a_pairs, \ + simsimd_##input_type##_t const *b_pairs, \ + simsimd_size_t count_pairs, simsimd_distance_t *results) { \ + simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ + for (simsimd_size_t i = 0; i != count_pairs; ++i) { \ + simsimd_##accumulator_type##_t ar = load_and_convert(&(a_pairs + i)->real); \ + simsimd_##accumulator_type##_t br = load_and_convert(&(b_pairs + i)->real); \ + simsimd_##accumulator_type##_t ai = load_and_convert(&(a_pairs + i)->imag); \ + simsimd_##accumulator_type##_t bi = load_and_convert(&(b_pairs + i)->imag); \ + ab_real += ar * br + ai * bi; \ + ab_imag += ar * bi - ai * br; \ + } \ + results[0] = ab_real; \ + results[1] = ab_imag; \ } -SIMSIMD_MAKE_DOT(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f64_serial -SIMSIMD_MAKE_COMPLEX_DOT(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f64c_serial -SIMSIMD_MAKE_COMPLEX_VDOT(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_vdot_f64c_serial +SIMSIMD_MAKE_DOT(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f64_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, f64c, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f64c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, f64c, f64, SIMSIMD_DEREFERENCE) // simsimd_vdot_f64c_serial -SIMSIMD_MAKE_DOT(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_dot_f32_serial -SIMSIMD_MAKE_COMPLEX_DOT(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_dot_f32c_serial -SIMSIMD_MAKE_COMPLEX_VDOT(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_vdot_f32c_serial +SIMSIMD_MAKE_DOT(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_dot_f32_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, f32c, f32, SIMSIMD_DEREFERENCE) // simsimd_dot_f32c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, f32c, f32, SIMSIMD_DEREFERENCE) // simsimd_vdot_f32c_serial -SIMSIMD_MAKE_DOT(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_dot_f16_serial -SIMSIMD_MAKE_COMPLEX_DOT(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_dot_f16c_serial -SIMSIMD_MAKE_COMPLEX_VDOT(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_vdot_f16c_serial +SIMSIMD_MAKE_DOT(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_dot_f16_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, f16c, f32, SIMSIMD_F16_TO_F32) // simsimd_dot_f16c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, f16c, f32, SIMSIMD_F16_TO_F32) // simsimd_vdot_f16c_serial -SIMSIMD_MAKE_DOT(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16_serial -SIMSIMD_MAKE_COMPLEX_DOT(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16c_serial -SIMSIMD_MAKE_COMPLEX_VDOT(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_vdot_bf16c_serial +SIMSIMD_MAKE_DOT(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, bf16c, f32, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, bf16c, f32, SIMSIMD_BF16_TO_F32) // simsimd_vdot_bf16c_serial SIMSIMD_MAKE_DOT(serial, i8, i64, SIMSIMD_DEREFERENCE) // simsimd_dot_i8_serial SIMSIMD_MAKE_DOT(serial, u8, i64, SIMSIMD_DEREFERENCE) // simsimd_dot_u8_serial -SIMSIMD_MAKE_DOT(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f32_accurate -SIMSIMD_MAKE_COMPLEX_DOT(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f32c_accurate -SIMSIMD_MAKE_COMPLEX_VDOT(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_vdot_f32c_accurate +SIMSIMD_MAKE_DOT(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f32_accurate +SIMSIMD_MAKE_COMPLEX_DOT(accurate, f32c, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f32c_accurate +SIMSIMD_MAKE_COMPLEX_VDOT(accurate, f32c, f64, SIMSIMD_DEREFERENCE) // simsimd_vdot_f32c_accurate -SIMSIMD_MAKE_DOT(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_dot_f16_accurate -SIMSIMD_MAKE_COMPLEX_DOT(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_dot_f16c_accurate -SIMSIMD_MAKE_COMPLEX_VDOT(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_vdot_f16c_accurate +SIMSIMD_MAKE_DOT(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_dot_f16_accurate +SIMSIMD_MAKE_COMPLEX_DOT(accurate, f16c, f64, SIMSIMD_F16_TO_F32) // simsimd_dot_f16c_accurate +SIMSIMD_MAKE_COMPLEX_VDOT(accurate, f16c, f64, SIMSIMD_F16_TO_F32) // simsimd_vdot_f16c_accurate -SIMSIMD_MAKE_DOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16_accurate -SIMSIMD_MAKE_COMPLEX_DOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16c_accurate -SIMSIMD_MAKE_COMPLEX_VDOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_vdot_bf16c_accurate +SIMSIMD_MAKE_DOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16_accurate +SIMSIMD_MAKE_COMPLEX_DOT(accurate, bf16c, f64, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16c_accurate +SIMSIMD_MAKE_COMPLEX_VDOT(accurate, bf16c, f64, SIMSIMD_BF16_TO_F32) // simsimd_vdot_bf16c_accurate #if _SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON @@ -237,40 +237,40 @@ SIMSIMD_MAKE_COMPLEX_VDOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_v #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -SIMSIMD_INTERNAL float32x4_t _simsimd_partial_load_f32x4_neon(simsimd_f32_t const *a, simsimd_size_t n) { +SIMSIMD_INTERNAL float32x4_t _simsimd_partial_load_f32x4_neon(simsimd_f32_t const *x, simsimd_size_t n) { union { float32x4_t vec; simsimd_f32_t scalars[4]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < n; ++i) result.scalars[i] = x[i]; for (; i < 4; ++i) result.scalars[i] = 0; return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const *a_scalars, simsimd_f32_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { float32x4_t ab_vec = vdupq_n_f32(0); - simsimd_size_t i = 0; - for (; i + 4 <= n; i += 4) { - float32x4_t a_vec = vld1q_f32(a + i); - float32x4_t b_vec = vld1q_f32(b + i); + simsimd_size_t idx_scalars = 0; + for (; idx_scalars + 4 <= count_scalars; idx_scalars += 4) { + float32x4_t a_vec = vld1q_f32(a_scalars + idx_scalars); + float32x4_t b_vec = vld1q_f32(b_scalars + idx_scalars); ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); } simsimd_f32_t ab = vaddvq_f32(ab_vec); - for (; i < n; ++i) ab += a[i] * b[i]; + for (; idx_scalars < count_scalars; ++idx_scalars) ab += a_scalars[idx_scalars] * b_scalars[idx_scalars]; *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const *a_pairs, simsimd_f32_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); - simsimd_size_t i = 0; - for (; i + 8 <= n; i += 8) { + simsimd_size_t idx_pairs = 0; + for (; idx_pairs + 4 <= count_pairs; idx_pairs += 4) { // Unpack the input arrays into real and imaginary parts: - float32x4x2_t a_vec = vld2q_f32(a + i); - float32x4x2_t b_vec = vld2q_f32(b + i); + float32x4x2_t a_vec = vld2q_f32(a_pairs + idx_pairs); + float32x4x2_t b_vec = vld2q_f32(b_pairs + idx_pairs); float32x4_t a_real_vec = a_vec.val[0]; float32x4_t a_imag_vec = a_vec.val[1]; float32x4_t b_real_vec = b_vec.val[0]; @@ -288,8 +288,9 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); // Handle the tail: - for (; i + 2 <= n; i += 2) { - simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + for (; idx_pairs != count_pairs; ++idx_pairs) { + simsimd_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs]; + simsimd_f32_t ar = a_pair.real, ai = a_pair.imag, br = b_pair.real, bi = b_pair.imag; ab_real += ar * br - ai * bi; ab_imag += ar * bi + ai * br; } @@ -297,15 +298,15 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t results[1] = ab_imag; } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const *a_pairs, simsimd_f32_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); - simsimd_size_t i = 0; - for (; i + 8 <= n; i += 8) { + simsimd_size_t idx_pairs = 0; + for (; idx_pairs + 4 <= count_pairs; idx_pairs += 4) { // Unpack the input arrays into real and imaginary parts: - float32x4x2_t a_vec = vld2q_f32(a + i); - float32x4x2_t b_vec = vld2q_f32(b + i); + float32x4x2_t a_vec = vld2q_f32(a_pairs + idx_pairs); + float32x4x2_t b_vec = vld2q_f32(b_pairs + idx_pairs); float32x4_t a_real_vec = a_vec.val[0]; float32x4_t a_imag_vec = a_vec.val[1]; float32x4_t b_real_vec = b_vec.val[0]; @@ -323,8 +324,9 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); // Handle the tail: - for (; i + 2 <= n; i += 2) { - simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + for (; idx_pairs != count_pairs; ++idx_pairs) { + simsimd_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs]; + simsimd_f32_t ar = a_pair.real, ai = a_pair.imag, br = b_pair.real, bi = b_pair.imag; ab_real += ar * br + ai * bi; ab_imag += ar * bi - ai * br; } @@ -341,51 +343,51 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t #pragma GCC target("arch=armv8.2-a+dotprod") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const *a_scalars, simsimd_i8_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { int32x4_t ab_vec = vdupq_n_s32(0); - simsimd_size_t i = 0; + simsimd_size_t idx_scalars = 0; // If the 128-bit `vdot_s32` intrinsic is unavailable, we can use the 64-bit `vdot_s32`. - // for (simsimd_size_t i = 0; i != n; i += 8) { - // int16x8_t a_vec = vmovl_s8(vld1_s8(a + i)); - // int16x8_t b_vec = vmovl_s8(vld1_s8(b + i)); + // for (simsimd_size_t idx_scalars = 0; idx_scalars != n; idx_scalars += 8) { + // int16x8_t a_vec = vmovl_s8(vld1_s8(a_scalars + idx_scalars)); + // int16x8_t b_vec = vmovl_s8(vld1_s8(b_scalars + idx_scalars)); // int16x8_t ab_part_vec = vmulq_s16(a_vec, b_vec); // ab_vec = vaddq_s32(ab_vec, vaddq_s32(vmovl_s16(vget_high_s16(ab_part_vec)), // // vmovl_s16(vget_low_s16(ab_part_vec)))); // } - for (; i + 16 <= n; i += 16) { - int8x16_t a_vec = vld1q_s8(a + i); - int8x16_t b_vec = vld1q_s8(b + i); + for (; idx_scalars + 16 <= count_scalars; idx_scalars += 16) { + int8x16_t a_vec = vld1q_s8(a_scalars + idx_scalars); + int8x16_t b_vec = vld1q_s8(b_scalars + idx_scalars); ab_vec = vdotq_s32(ab_vec, a_vec, b_vec); } // Take care of the tail: simsimd_i32_t ab = vaddvq_s32(ab_vec); - for (; i < n; ++i) { - simsimd_i32_t ai = a[i], bi = b[i]; + for (; idx_scalars < count_scalars; ++idx_scalars) { + simsimd_i32_t ai = a[idx_scalars], bi = b[idx_scalars]; ab += ai * bi; } *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const *a_scalars, simsimd_u8_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { uint32x4_t ab_vec = vdupq_n_u32(0); - simsimd_size_t i = 0; - for (; i + 16 <= n; i += 16) { - uint8x16_t a_vec = vld1q_u8(a + i); - uint8x16_t b_vec = vld1q_u8(b + i); + simsimd_size_t idx_scalars = 0; + for (; idx_scalars + 16 <= count_scalars; idx_scalars += 16) { + uint8x16_t a_vec = vld1q_u8(a_scalars + idx_scalars); + uint8x16_t b_vec = vld1q_u8(b_scalars + idx_scalars); ab_vec = vdotq_u32(ab_vec, a_vec, b_vec); } // Take care of the tail: simsimd_u32_t ab = vaddvq_u32(ab_vec); - for (; i < n; ++i) { - simsimd_u32_t ai = a[i], bi = b[i]; + for (; idx_scalars < count_scalars; ++idx_scalars) { + simsimd_u32_t ai = a_scalars[idx_scalars], bi = b_scalars[idx_scalars]; ab += ai * bi; } @@ -401,7 +403,7 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const *a, simsimd_u8_t cons #pragma GCC target("arch=armv8.2-a+simd+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) -SIMSIMD_INTERNAL float16x4_t _simsimd_partial_load_f16x4_neon(simsimd_f16_t const *a, simsimd_size_t n) { +SIMSIMD_INTERNAL float16x4_t _simsimd_partial_load_f16x4_neon(simsimd_f16_t const *x, simsimd_size_t n) { // In case the software emulation for `f16` scalars is enabled, the `simsimd_f16_to_f32` // function will run. It is extremely slow, so even for the tail, let's combine serial // loads and stores with vectorized math. @@ -410,35 +412,35 @@ SIMSIMD_INTERNAL float16x4_t _simsimd_partial_load_f16x4_neon(simsimd_f16_t cons simsimd_f16_t scalars[4]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < n; ++i) result.scalars[i] = x[i]; for (; i < 4; ++i) result.scalars[i] = 0; return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - float32x4_t ab_vec = vdupq_n_f32(0); +SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const *a_scalars, simsimd_f16_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { float32x4_t a_vec, b_vec; + float32x4_t ab_vec = vdupq_n_f32(0); simsimd_size_t i = 0; simsimd_dot_f16_neon_cycle: - if (n < 4) { - a_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a, n)); - b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b, n)); - n = 0; + if (count_scalars < 4) { + a_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a_scalars, count_scalars)); + b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b_scalars, count_scalars)); + count_scalars = 0; } else { - a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)a)); - b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)b)); - a += 4, b += 4, n -= 4; + a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)a_scalars)); + b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)b_scalars)); + a_scalars += 4, b_scalars += 4, count_scalars -= 4; } ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); - if (n) goto simsimd_dot_f16_neon_cycle; + if (count_scalars) goto simsimd_dot_f16_neon_cycle; *result = vaddvq_f32(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { // A nicer approach is to use `f16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -446,12 +448,12 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); - while (n >= 8) { + while (count_pairs >= 4) { // Unpack the input arrays into real and imaginary parts. // MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards. - int16x4x2_t a_vec = vld2_s16((short *)a); - int16x4x2_t b_vec = vld2_s16((short *)b); + int16x4x2_t a_vec = vld2_s16((short *)a_pairs); + int16x4x2_t b_vec = vld2_s16((short *)b_pairs); float32x4_t a_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[0])); float32x4_t a_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[1])); float32x4_t b_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[0])); @@ -463,17 +465,17 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec); - n -= 8, a += 8, b += 8; + count_pairs -= 4, a_pairs += 4, b_pairs += 4; } // Reduce horizontal sums and aggregate with the tail: - simsimd_dot_f16c_serial(a, b, n, results); + simsimd_dot_f16c_serial(a_pairs, b_pairs, count_pairs, results); results[0] += vaddvq_f32(ab_real_vec); results[1] += vaddvq_f32(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { // A nicer approach is to use `f16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -481,12 +483,12 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); - while (n >= 8) { + while (count_pairs >= 4) { // Unpack the input arrays into real and imaginary parts. // MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards. - int16x4x2_t a_vec = vld2_s16((short *)a); - int16x4x2_t b_vec = vld2_s16((short *)b); + int16x4x2_t a_vec = vld2_s16((short *)a_pairs); + int16x4x2_t b_vec = vld2_s16((short *)b_pairs); float32x4_t a_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[0])); float32x4_t a_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[1])); float32x4_t b_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[0])); @@ -498,11 +500,11 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = vfmsq_f32(ab_imag_vec, a_imag_vec, b_real_vec); - n -= 8, a += 8, b += 8; + count_pairs -= 4, a_pairs += 4, b_pairs += 4; } // Reduce horizontal sums and aggregate with the tail: - simsimd_vdot_f16c_serial(a, b, n, results); + simsimd_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, results); results[0] += vaddvq_f32(ab_real_vec); results[1] += vaddvq_f32(ab_imag_vec); } @@ -516,41 +518,41 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t #pragma GCC target("arch=armv8.6-a+simd+bf16") #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function) -SIMSIMD_INTERNAL bfloat16x8_t _simsimd_partial_load_bf16x8_neon(simsimd_bf16_t const *a, simsimd_size_t n) { +SIMSIMD_INTERNAL bfloat16x8_t _simsimd_partial_load_bf16x8_neon(simsimd_bf16_t const *x, simsimd_size_t n) { union { bfloat16x8_t vec; simsimd_bf16_t scalars[8]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < n; ++i) result.scalars[i] = x[i]; for (; i < 8; ++i) result.scalars[i] = 0; return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - float32x4_t ab_vec = vdupq_n_f32(0); +SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const *a_scalars, simsimd_bf16_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { bfloat16x8_t a_vec, b_vec; + float32x4_t ab_vec = vdupq_n_f32(0); simsimd_dot_bf16_neon_cycle: - if (n < 8) { - a_vec = _simsimd_partial_load_bf16x8_neon(a, n); - b_vec = _simsimd_partial_load_bf16x8_neon(b, n); - n = 0; + if (count_scalars < 4) { + a_vec = _simsimd_partial_load_bf16x8_neon(a_scalars, count_scalars); + b_vec = _simsimd_partial_load_bf16x8_neon(b_scalars, count_scalars); + count_scalars = 0; } else { - a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)a); - b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)b); - a += 8, b += 8, n -= 8; + a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)a_scalars); + b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)b_scalars); + a += 4, b += 4, count_scalars -= 4; } ab_vec = vbfdotq_f32(ab_vec, a_vec, b_vec); - if (n) goto simsimd_dot_bf16_neon_cycle; + if (count_scalars) goto simsimd_dot_bf16_neon_cycle; *result = vaddvq_f32(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16c_t const *a_pairs, simsimd_bf16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { // A nicer approach is to use `bf16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -558,12 +560,12 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf16 float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); - while (n >= 8) { + while (count_pairs >= 4) { // Unpack the input arrays into real and imaginary parts. // MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards. - int16x4x2_t a_vec = vld2_s16((short *)a); - int16x4x2_t b_vec = vld2_s16((short *)b); + int16x4x2_t a_vec = vld2_s16((short *)a_pairs); + int16x4x2_t b_vec = vld2_s16((short *)b_pairs); float32x4_t a_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[0])); float32x4_t a_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[1])); float32x4_t b_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[0])); @@ -575,17 +577,17 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf16 ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec); - n -= 8, a += 8, b += 8; + count_pairs -= 4, a_pairs += 4, b_pairs += 4; } // Reduce horizontal sums and aggregate with the tail: - simsimd_dot_bf16c_serial(a, b, n, results); + simsimd_dot_bf16c_serial(a_pairs, b_pairs, count_pairs, results); results[0] += vaddvq_f32(ab_real_vec); results[1] += vaddvq_f32(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16c_t const *a_pairs, simsimd_bf16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { // A nicer approach is to use `bf16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -593,7 +595,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf1 float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); - while (n >= 8) { + while (count_pairs >= 4) { // Unpack the input arrays into real and imaginary parts. // MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards. @@ -610,11 +612,11 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf1 ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = vfmsq_f32(ab_imag_vec, a_imag_vec, b_real_vec); - n -= 8, a += 8, b += 8; + count_pairs -= 4, a_pairs += 4, b_pairs += 4; } // Reduce horizontal sums and aggregate with the tail: - simsimd_vdot_bf16c_serial(a, b, n, results); + simsimd_vdot_bf16c_serial(a_pairs, b_pairs, count_pairs, results); results[0] += vaddvq_f32(ab_real_vec); results[1] += vaddvq_f32(ab_imag_vec); } @@ -629,29 +631,29 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf1 #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const *a_scalars, simsimd_f32_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { + simsimd_size_t idx_scalars = 0; svfloat32_t ab_vec = svdup_f32(0.f); do { - svbool_t pg_vec = svwhilelt_b32((unsigned int)i, (unsigned int)n); - svfloat32_t a_vec = svld1_f32(pg_vec, a + i); - svfloat32_t b_vec = svld1_f32(pg_vec, b + i); + svbool_t pg_vec = svwhilelt_b32((unsigned int)idx_scalars, (unsigned int)count_scalars); + svfloat32_t a_vec = svld1_f32(pg_vec, a_scalars + idx_scalars); + svfloat32_t b_vec = svld1_f32(pg_vec, b_scalars + idx_scalars); ab_vec = svmla_f32_x(pg_vec, ab_vec, a_vec, b_vec); - i += svcntw(); - } while (i < n); + idx_scalars += svcntw(); + } while (idx_scalars < count_scalars); *result = svaddv_f32(svptrue_b32(), ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32c_t const *a_pairs, simsimd_f32c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + simsimd_size_t idx_pairs = 0; svfloat32_t ab_real_vec = svdup_f32(0.f); svfloat32_t ab_imag_vec = svdup_f32(0.f); do { - svbool_t pg_vec = svwhilelt_b32((unsigned int)i / 2, (unsigned int)n / 2); - svfloat32x2_t a_vec = svld2_f32(pg_vec, a + i); - svfloat32x2_t b_vec = svld2_f32(pg_vec, b + i); + svbool_t pg_vec = svwhilelt_b32((unsigned int)idx_pairs, (unsigned int)count_pairs); + svfloat32x2_t a_vec = svld2_f32(pg_vec, (simsimd_f32_t const *)(a_pairs + idx_pairs)); + svfloat32x2_t b_vec = svld2_f32(pg_vec, (simsimd_f32_t const *)(b_pairs + idx_pairs)); svfloat32_t a_real_vec = svget2_f32(a_vec, 0); svfloat32_t a_imag_vec = svget2_f32(a_vec, 1); svfloat32_t b_real_vec = svget2_f32(b_vec, 0); @@ -660,21 +662,21 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const *a, simsimd_f32_t c ab_real_vec = svmls_f32_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); ab_imag_vec = svmla_f32_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = svmla_f32_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); - i += svcntw() * 2; - } while (i < n); + idx_pairs += svcntw(); + } while (idx_pairs < count_pairs); results[0] = svaddv_f32(svptrue_b32(), ab_real_vec); results[1] = svaddv_f32(svptrue_b32(), ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32c_t const *a_pairs, simsimd_f32c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + simsimd_size_t idx_pairs = 0; svfloat32_t ab_real_vec = svdup_f32(0.f); svfloat32_t ab_imag_vec = svdup_f32(0.f); do { - svbool_t pg_vec = svwhilelt_b32((unsigned int)i / 2, (unsigned int)n / 2); - svfloat32x2_t a_vec = svld2_f32(pg_vec, a + i); - svfloat32x2_t b_vec = svld2_f32(pg_vec, b + i); + svbool_t pg_vec = svwhilelt_b32((unsigned int)idx_pairs, (unsigned int)count_pairs); + svfloat32x2_t a_vec = svld2_f32(pg_vec, (simsimd_f32_t const *)(a_pairs + idx_pairs)); + svfloat32x2_t b_vec = svld2_f32(pg_vec, (simsimd_f32_t const *)(b_pairs + idx_pairs)); svfloat32_t a_real_vec = svget2_f32(a_vec, 0); svfloat32_t a_imag_vec = svget2_f32(a_vec, 1); svfloat32_t b_real_vec = svget2_f32(b_vec, 0); @@ -683,35 +685,35 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const *a, simsimd_f32_t ab_real_vec = svmla_f32_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); ab_imag_vec = svmla_f32_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = svmls_f32_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); - i += svcntw() * 2; - } while (i < n); + idx_pairs += svcntw(); + } while (idx_pairs < count_pairs); results[0] = svaddv_f32(svptrue_b32(), ab_real_vec); results[1] = svaddv_f32(svptrue_b32(), ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const *a_scalars, simsimd_f64_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { + simsimd_size_t idx_scalars = 0; svfloat64_t ab_vec = svdup_f64(0.); do { - svbool_t pg_vec = svwhilelt_b64((unsigned int)i, (unsigned int)n); - svfloat64_t a_vec = svld1_f64(pg_vec, a + i); - svfloat64_t b_vec = svld1_f64(pg_vec, b + i); + svbool_t pg_vec = svwhilelt_b64((unsigned int)idx_scalars, (unsigned int)count_scalars); + svfloat64_t a_vec = svld1_f64(pg_vec, a_scalars + idx_scalars); + svfloat64_t b_vec = svld1_f64(pg_vec, b_scalars + idx_scalars); ab_vec = svmla_f64_x(pg_vec, ab_vec, a_vec, b_vec); - i += svcntd(); - } while (i < n); + idx_scalars += svcntd(); + } while (idx_scalars < count_scalars); *result = svaddv_f64(svptrue_b32(), ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64c_t const *a_pairs, simsimd_f64c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + simsimd_size_t idx_pairs = 0; svfloat64_t ab_real_vec = svdup_f64(0.); svfloat64_t ab_imag_vec = svdup_f64(0.); do { - svbool_t pg_vec = svwhilelt_b64((unsigned int)i / 2, (unsigned int)n / 2); - svfloat64x2_t a_vec = svld2_f64(pg_vec, a + i); - svfloat64x2_t b_vec = svld2_f64(pg_vec, b + i); + svbool_t pg_vec = svwhilelt_b64((unsigned int)idx_pairs, (unsigned int)count_pairs); + svfloat64x2_t a_vec = svld2_f64(pg_vec, (simsimd_f64_t const *)(a_pairs + idx_pairs)); + svfloat64x2_t b_vec = svld2_f64(pg_vec, (simsimd_f64_t const *)(b_pairs + idx_pairs)); svfloat64_t a_real_vec = svget2_f64(a_vec, 0); svfloat64_t a_imag_vec = svget2_f64(a_vec, 1); svfloat64_t b_real_vec = svget2_f64(b_vec, 0); @@ -720,21 +722,21 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const *a, simsimd_f64_t c ab_real_vec = svmls_f64_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); ab_imag_vec = svmla_f64_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = svmla_f64_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); - i += svcntd() * 2; - } while (i < n); + idx_pairs += svcntd(); + } while (idx_pairs < count_pairs); results[0] = svaddv_f64(svptrue_b64(), ab_real_vec); results[1] = svaddv_f64(svptrue_b64(), ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64c_t const *a_pairs, simsimd_f64c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + simsimd_size_t idx_pairs = 0; svfloat64_t ab_real_vec = svdup_f64(0.); svfloat64_t ab_imag_vec = svdup_f64(0.); do { - svbool_t pg_vec = svwhilelt_b64((unsigned int)i / 2, (unsigned int)n / 2); - svfloat64x2_t a_vec = svld2_f64(pg_vec, a + i); - svfloat64x2_t b_vec = svld2_f64(pg_vec, b + i); + svbool_t pg_vec = svwhilelt_b64((unsigned int)idx_pairs, (unsigned int)count_pairs); + svfloat64x2_t a_vec = svld2_f64(pg_vec, (simsimd_f64_t const *)(a_pairs + idx_pairs)); + svfloat64x2_t b_vec = svld2_f64(pg_vec, (simsimd_f64_t const *)(b_pairs + idx_pairs)); svfloat64_t a_real_vec = svget2_f64(a_vec, 0); svfloat64_t a_imag_vec = svget2_f64(a_vec, 1); svfloat64_t b_real_vec = svget2_f64(b_vec, 0); @@ -743,8 +745,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const *a, simsimd_f64_t ab_real_vec = svmla_f64_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); ab_imag_vec = svmla_f64_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = svmls_f64_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); - i += svcntd() * 2; - } while (i < n); + idx_pairs += svcntd(); + } while (idx_pairs < count_pairs); results[0] = svaddv_f64(svptrue_b64(), ab_real_vec); results[1] = svaddv_f64(svptrue_b64(), ab_imag_vec); } @@ -756,32 +758,30 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const *a, simsimd_f64_t #pragma GCC target("arch=armv8.2-a+sve+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+fp16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const *a_enum, simsimd_f16_t const *b_enum, simsimd_size_t n, - simsimd_distance_t *result) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const *a_scalars, simsimd_f16_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { + simsimd_size_t idx_scalars = 0; svfloat16_t ab_vec = svdup_f16(0); - simsimd_f16_for_arm_simd_t const *a = (simsimd_f16_for_arm_simd_t const *)(a_enum); - simsimd_f16_for_arm_simd_t const *b = (simsimd_f16_for_arm_simd_t const *)(b_enum); do { - svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); - svfloat16_t a_vec = svld1_f16(pg_vec, a + i); - svfloat16_t b_vec = svld1_f16(pg_vec, b + i); + svbool_t pg_vec = svwhilelt_b16((unsigned int)idx_scalars, (unsigned int)count_scalars); + svfloat16_t a_vec = svld1_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)(a_scalars + idx_scalars)); + svfloat16_t b_vec = svld1_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)(b_scalars + idx_scalars)); ab_vec = svmla_f16_x(pg_vec, ab_vec, a_vec, b_vec); - i += svcnth(); - } while (i < n); + idx_scalars += svcnth(); + } while (idx_scalars < count_scalars); simsimd_f16_for_arm_simd_t ab = svaddv_f16(svptrue_b16(), ab_vec); *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + simsimd_size_t idx_scalars = 0; svfloat16_t ab_real_vec = svdup_f16(0); svfloat16_t ab_imag_vec = svdup_f16(0); do { - svbool_t pg_vec = svwhilelt_b16((unsigned int)i / 2, (unsigned int)n / 2); - svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)a + i); - svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)b + i); + svbool_t pg_vec = svwhilelt_b32((unsigned int)idx_pairs, (unsigned int)count_pairs); + svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)(a_pairs + idx_pairs)); + svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)(b_pairs + idx_pairs)); svfloat16_t a_real_vec = svget2_f16(a_vec, 0); svfloat16_t a_imag_vec = svget2_f16(a_vec, 1); svfloat16_t b_real_vec = svget2_f16(b_vec, 0); @@ -790,21 +790,21 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const *a, simsimd_f16_t c ab_real_vec = svmls_f16_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); ab_imag_vec = svmla_f16_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = svmla_f16_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); - i += svcnth() * 2; - } while (i < n); + idx_pairs += svcnth(); + } while (idx_pairs < count_pairs); results[0] = svaddv_f16(svptrue_b16(), ab_real_vec); results[1] = svaddv_f16(svptrue_b16(), ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { - simsimd_size_t i = 0; +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + simsimd_size_t idx_scalars = 0; svfloat16_t ab_real_vec = svdup_f16(0); svfloat16_t ab_imag_vec = svdup_f16(0); do { - svbool_t pg_vec = svwhilelt_b16((unsigned int)i / 2, (unsigned int)n / 2); - svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)a + i); - svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)b + i); + svbool_t pg_vec = svwhilelt_b32((unsigned int)idx_pairs, (unsigned int)count_pairs); + svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)(a_pairs + idx_pairs)); + svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)(b_pairs + idx_pairs)); svfloat16_t a_real_vec = svget2_f16(a_vec, 0); svfloat16_t a_imag_vec = svget2_f16(a_vec, 1); svfloat16_t b_real_vec = svget2_f16(b_vec, 0); @@ -813,8 +813,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const *a, simsimd_f16_t ab_real_vec = svmla_f16_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); ab_imag_vec = svmla_f16_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); ab_imag_vec = svmls_f16_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); - i += svcnth() * 2; - } while (i < n); + idx_pairs += svcnth(); + } while (idx_pairs < count_pairs); results[0] = svaddv_f16(svptrue_b16(), ab_real_vec); results[1] = svaddv_f16(svptrue_b16(), ab_imag_vec); } @@ -867,23 +867,23 @@ SIMSIMD_INTERNAL simsimd_i32_t _simsimd_reduce_i32x8_haswell(__m256i vec) { return _mm_cvtsi128_si32(sum); } -SIMSIMD_PUBLIC void simsimd_dot_f32_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f32_haswell(simsimd_f32_t const *a_scalars, simsimd_f32_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *results) { __m256 ab_vec = _mm256_setzero_ps(); - simsimd_size_t i = 0; - for (; i + 8 <= n; i += 8) { - __m256 a_vec = _mm256_loadu_ps(a + i); - __m256 b_vec = _mm256_loadu_ps(b + i); + simsimd_size_t idx_scalars = 0; + for (; idx_scalars + 8 <= count_scalars; idx_scalars += 8) { + __m256 a_vec = _mm256_loadu_ps(a_scalars + idx_scalars); + __m256 b_vec = _mm256_loadu_ps(b_scalars + idx_scalars); ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); } simsimd_f64_t ab = _simsimd_reduce_f32x8_haswell(ab_vec); - for (; i < n; ++i) ab += a[i] * b[i]; + for (; idx_scalars < count_scalars; ++idx_scalars) ab += a_scalars[idx_scalars] * b_scalars[idx_scalars]; *results = ab; } -SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32c_t const *a_pairs, simsimd_f32c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { // The naive approach would be to use FMA and FMS instructions on different parts of the vectors. // Prior to that we would need to shuffle the input vectors to separate real and imaginary parts. @@ -891,10 +891,10 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32 // __m128 ab_real_vec = _mm_setzero_ps(); // __m128 ab_imag_vec = _mm_setzero_ps(); // __m256i permute_vec = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - // simsimd_size_t i = 0; - // for (; i + 8 <= n; i += 8) { - // __m256 a_vec = _mm256_loadu_ps(a + i); - // __m256 b_vec = _mm256_loadu_ps(b + i); + // simsimd_size_t idx_pairs = 0; + // for (; idx_pairs + 4 <= count_pairs; idx_pairs += 4) { + // __m256 a_vec = _mm256_loadu_ps((simsimd_f32_t const *)(a_pairs + idx_pairs)); + // __m256 b_vec = _mm256_loadu_ps((simsimd_f32_t const *)(b_pairs + idx_pairs)); // __m256 a_shuffled = _mm256_permutevar8x32_ps(a_vec, permute_vec); // __m256 b_shuffled = _mm256_permutevar8x32_ps(b_vec, permute_vec); // __m128 a_real_vec = _mm256_extractf128_ps(a_shuffled, 0); @@ -926,10 +926,11 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32 3, 2, 1, 0, // Points to the first f32 in 128-bit lane 7, 6, 5, 4 // Points to the second f32 in 128-bit lane ); - simsimd_size_t i = 0; - for (; i + 8 <= n; i += 8) { - __m256 a_vec = _mm256_loadu_ps(a + i); - __m256 b_vec = _mm256_loadu_ps(b + i); + + simsimd_size_t idx_pairs = 0; + for (; idx_pairs + 4 <= count_pairs; idx_pairs += 4) { + __m256 a_vec = _mm256_loadu_ps((simsimd_f32_t const *)(a_pairs + idx_pairs)); + __m256 b_vec = _mm256_loadu_ps((simsimd_f32_t const *)(b_pairs + idx_pairs)); __m256 b_swapped_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); ab_imag_vec = _mm256_fmadd_ps(a_vec, b_swapped_vec, ab_imag_vec); @@ -943,8 +944,9 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32 simsimd_distance_t ab_imag = _simsimd_reduce_f32x8_haswell(ab_imag_vec); // Handle the tail: - for (; i + 2 <= n; i += 2) { - simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + for (; idx_pairs != count_pairs; ++idx_pairs) { + simsimd_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs]; + simsimd_f32_t ar = a_pair.real, ai = a_pair.imag, br = b_pair.real, bi = b_pair.imag; ab_real += ar * br - ai * bi; ab_imag += ar * bi + ai * br; } @@ -952,8 +954,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32 results[1] = ab_imag; } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32c_t const *a_pairs, simsimd_f32c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { __m256 ab_real_vec = _mm256_setzero_ps(); __m256 ab_imag_vec = _mm256_setzero_ps(); @@ -968,10 +970,11 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const *a, simsimd_f3 3, 2, 1, 0, // Points to the first f32 in 128-bit lane 7, 6, 5, 4 // Points to the second f32 in 128-bit lane ); - simsimd_size_t i = 0; - for (; i + 8 <= n; i += 8) { - __m256 a_vec = _mm256_loadu_ps(a + i); - __m256 b_vec = _mm256_loadu_ps(b + i); + + simsimd_size_t idx_pairs = 0; + for (; idx_pairs + 4 <= count_pairs; idx_pairs += 4) { + __m256 a_vec = _mm256_loadu_ps((simsimd_f32_t const *)(a_pairs + idx_pairs)); + __m256 b_vec = _mm256_loadu_ps((simsimd_f32_t const *)(b_pairs + idx_pairs)); ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); b_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); ab_imag_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_imag_vec); @@ -985,8 +988,9 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const *a, simsimd_f3 simsimd_distance_t ab_imag = _simsimd_reduce_f32x8_haswell(ab_imag_vec); // Handle the tail: - for (; i + 2 <= n; i += 2) { - simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + for (; idx_pairs != count_pairs; ++idx_pairs) { + simsimd_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs]; + simsimd_f32_t ar = a_pair.real, ai = a_pair.imag, br = b_pair.real, bi = b_pair.imag; ab_real += ar * br + ai * bi; ab_imag += ar * bi - ai * br; } @@ -1008,21 +1012,21 @@ SIMSIMD_INTERNAL __m256 _simsimd_partial_load_f16x8_haswell(simsimd_f16_t const return _mm256_cvtph_ps(result.vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const *a_scalars, simsimd_f16_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m256 a_vec, b_vec; __m256 ab_vec = _mm256_setzero_ps(); simsimd_dot_f16_haswell_cycle: - if (n < 8) { - a_vec = _simsimd_partial_load_f16x8_haswell(a, n); - b_vec = _simsimd_partial_load_f16x8_haswell(b, n); - n = 0; + if (count_scalars < 8) { + a_vec = _simsimd_partial_load_f16x8_haswell(a_scalars, count_scalars); + b_vec = _simsimd_partial_load_f16x8_haswell(b_scalars, count_scalars); + count_scalars = 0; } else { - a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); - b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); - n -= 8, a += 8, b += 8; + a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a_scalars)); + b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b_scalars)); + count_scalars -= 8, a_scalars += 8, b_scalars += 8; } // We can silence the NaNs using blends: // @@ -1031,13 +1035,14 @@ SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const *a, simsimd_f16_ // ab_vec = _mm256_blendv_ps(_mm256_fmadd_ps(a_vec, b_vec, ab_vec), ab_vec, _mm256_or_ps(a_is_nan, b_is_nan)); // ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); - if (n) goto simsimd_dot_f16_haswell_cycle; + if (count_scalars) goto simsimd_dot_f16_haswell_cycle; *result = _simsimd_reduce_f32x8_haswell(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + // Ideally the implementation would load 256 bits worth of vector data at a time, // shuffle those within a register, split in halfs, and only then upcast. // That way, we are stepping through 32x 16-bit vector components at a time, or 16 dimensions. @@ -1059,28 +1064,34 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const *a, simsimd_f16 7, 6, 5, 4 // Points to the second f32 in 128-bit lane ); - while (n >= 8) { - __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); - __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); + while (count_pairs >= 4) { + __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a_pairs)); + __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b_pairs)); __m256 b_swapped_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); ab_imag_vec = _mm256_fmadd_ps(a_vec, b_swapped_vec, ab_imag_vec); - - n -= 8, a += 8, b += 8; + count_pairs -= 4, a_pairs += 4, b_pairs += 4; } // Flip the sign bit in every second scalar before accumulation: ab_real_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(ab_real_vec), sign_flip_vec)); // Reduce horizontal sums and aggregate with the tail: - simsimd_dot_f16c_serial(a, b, n, results); + simsimd_dot_f16c_serial(a_pairs, b_pairs, count_pairs, results); results[0] += _simsimd_reduce_f32x8_haswell(ab_real_vec); results[1] += _simsimd_reduce_f32x8_haswell(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + // Ideally the implementation would load 256 bits worth of vector data at a time, + // shuffle those within a register, split in halfs, and only then upcast. + // That way, we are stepping through 32x 16-bit vector components at a time, or 16 dimensions. + // Sadly, shuffling 16-bit entries in a YMM register is hard to implement efficiently. + // + // Simpler approach is to load 128 bits at a time, upcast, and then shuffle. + // This mostly replicates the `simsimd_vdot_f32c_haswell`. __m256 ab_real_vec = _mm256_setzero_ps(); __m256 ab_imag_vec = _mm256_setzero_ps(); __m256i sign_flip_vec = _mm256_set1_epi64x(0x8000000000000000); @@ -1095,27 +1106,26 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const *a, simsimd_f1 7, 6, 5, 4 // Points to the second f32 in 128-bit lane ); - while (n >= 8) { - __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); - __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); + while (count_pairs >= 4) { + __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a_pairs)); + __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b_pairs)); ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); b_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); ab_imag_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_imag_vec); - - n -= 8, a += 8, b += 8; + count_pairs -= 4, a_pairs += 4, b_pairs += 4; } // Flip the sign bit in every second scalar before accumulation: ab_imag_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(ab_imag_vec), sign_flip_vec)); // Reduce horizontal sums and aggregate with the tail: - simsimd_dot_f16c_serial(a, b, n, results); + simsimd_dot_f16c_serial(a_pairs, b_pairs, count_pairs, results); results[0] += _simsimd_reduce_f32x8_haswell(ab_real_vec); results[1] += _simsimd_reduce_f32x8_haswell(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const *a_scalars, simsimd_i8_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m256i ab_i32_low_vec = _mm256_setzero_si256(); __m256i ab_i32_high_vec = _mm256_setzero_si256(); @@ -1132,10 +1142,10 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const *a, simsimd_i8_t c // The problem with this approach, however, is the `-128` value in the second vector. // Flipping it's sign will do nothing, and the result will be incorrect. // This can easily lead to noticeable numerical errors in the final result. - simsimd_size_t i = 0; - for (; i + 32 <= n; i += 32) { - __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); - __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); + simsimd_size_t idx_scalars = 0; + for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) { + __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a_scalars + idx_scalars)); + __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b_scalars + idx_scalars)); // Upcast `int8` to `int16` __m256i a_i16_low_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_i8_vec, 0)); @@ -1152,12 +1162,12 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const *a, simsimd_i8_t c int ab = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); // Take care of the tail: - for (; i < n; ++i) ab += (int)(a[i]) * b[i]; + for (; idx_scalars < count_scalars; ++idx_scalars) ab += (int)(a_scalars[idx_scalars]) * b_scalars[idx_scalars]; *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const *a_scalars, simsimd_u8_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m256i ab_i32_low_vec = _mm256_setzero_si256(); __m256i ab_i32_high_vec = _mm256_setzero_si256(); @@ -1165,10 +1175,10 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const *a, simsimd_u8_t c // AVX2 has no instructions for unsigned 8-bit integer dot-products, // but it has a weird instruction for mixed signed-unsigned 8-bit dot-product. - simsimd_size_t i = 0; - for (; i + 32 <= n; i += 32) { - __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); - __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); + simsimd_size_t idx_scalars = 0; + for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) { + __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const *)(a_scalars + idx_scalars)); + __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const *)(b_scalars + idx_scalars)); // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking // instructions instead of extracts, as they are much faster and more efficient. @@ -1186,22 +1196,22 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const *a, simsimd_u8_t c int ab = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); // Take care of the tail: - for (; i < n; ++i) ab += (int)(a[i]) * b[i]; + for (; idx_scalars < count_scalars; ++idx_scalars) ab += (int)(a_scalars[idx_scalars]) * b_scalars[idx_scalars]; *result = ab; } -SIMSIMD_INTERNAL __m256 _simsimd_bf16x8_to_f32x8_haswell(__m128i a) { +SIMSIMD_INTERNAL __m256 _simsimd_bf16x8_to_f32x8_haswell(__m128i x) { // Upcasting from `bf16` to `f32` is done by shifting the `bf16` values by 16 bits to the left, like: - return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16)); + return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(x), 16)); } -SIMSIMD_INTERNAL __m128i _simsimd_f32x8_to_bf16x8_haswell(__m256 a) { +SIMSIMD_INTERNAL __m128i _simsimd_f32x8_to_bf16x8_haswell(__m256 x) { // Pack the 32-bit integers into 16-bit integers. // This is less trivial than unpacking: https://stackoverflow.com/a/77781241/2766161 // The best approach is to shuffle within lanes first: https://stackoverflow.com/a/49723746/2766161 // Our shuffling mask will drop the low 2-bytes from every 4-byte word. __m256i trunc_elements = _mm256_shuffle_epi8( // - _mm256_castps_si256(a), // + _mm256_castps_si256(x), // _mm256_set_epi8( // -1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 11, 10, 7, 6, 3, 2, // -1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 11, 10, 7, 6, 3, 2 // @@ -1225,24 +1235,24 @@ SIMSIMD_INTERNAL __m128i _simsimd_partial_load_bf16x8_haswell(simsimd_bf16_t con return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const *a_scalars, simsimd_bf16_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m128i a_vec, b_vec; __m256 ab_vec = _mm256_setzero_ps(); simsimd_dot_bf16_haswell_cycle: - if (n < 8) { - a_vec = _simsimd_partial_load_bf16x8_haswell(a, n); - b_vec = _simsimd_partial_load_bf16x8_haswell(b, n); - n = 0; + if (count_scalars < 8) { + a_vec = _simsimd_partial_load_bf16x8_haswell(a_scalars, count_scalars); + b_vec = _simsimd_partial_load_bf16x8_haswell(b_scalars, count_scalars); + count_scalars = 0; } else { - a_vec = _mm_lddqu_si128((__m128i const *)a); - b_vec = _mm_lddqu_si128((__m128i const *)b); - a += 8, b += 8, n -= 8; + a_vec = _mm_lddqu_si128((__m128i const *)a_scalars); + b_vec = _mm_lddqu_si128((__m128i const *)b_scalars); + a_scalars += 8, b_scalars += 8, count_scalars -= 8; } ab_vec = _mm256_fmadd_ps(_simsimd_bf16x8_to_f32x8_haswell(a_vec), _simsimd_bf16x8_to_f32x8_haswell(b_vec), ab_vec); - if (n) goto simsimd_dot_bf16_haswell_cycle; + if (count_scalars) goto simsimd_dot_bf16_haswell_cycle; *result = _simsimd_reduce_f32x8_haswell(ab_vec); } @@ -1274,58 +1284,57 @@ SIMSIMD_INTERNAL __m256i _simsimd_f32x16_to_bf16x16_skylake(__m512 a) { return _mm512_cvtepi32_epi16(x); } -SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - __m512 ab_vec = _mm512_setzero(); +SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const *a_scalars, simsimd_f32_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m512 a_vec, b_vec; + __m512 ab_vec = _mm512_setzero(); simsimd_dot_f32_skylake_cycle: - if (n < 16) { - __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_ps(mask, a); - b_vec = _mm512_maskz_loadu_ps(mask, b); - n = 0; + if (count_scalars < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, count_scalars); + a_vec = _mm512_maskz_loadu_ps(mask, a_scalars); + b_vec = _mm512_maskz_loadu_ps(mask, b_scalars); + count_scalars = 0; } else { - a_vec = _mm512_loadu_ps(a); - b_vec = _mm512_loadu_ps(b); - a += 16, b += 16, n -= 16; + a_vec = _mm512_loadu_ps(a_scalars); + b_vec = _mm512_loadu_ps(b_scalars); + a_scalars += 16, b_scalars += 16, count_scalars -= 16; } ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec); - if (n) goto simsimd_dot_f32_skylake_cycle; + if (count_scalars) goto simsimd_dot_f32_skylake_cycle; *result = _simsimd_reduce_f32x16_skylake(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - __m512d ab_vec = _mm512_setzero_pd(); +SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const *a_scalars, simsimd_f64_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m512d a_vec, b_vec; + __m512d ab_vec = _mm512_setzero_pd(); simsimd_dot_f64_skylake_cycle: - if (n < 8) { - __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_pd(mask, a); - b_vec = _mm512_maskz_loadu_pd(mask, b); - n = 0; + if (count_scalars < 8) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_scalars); + a_vec = _mm512_maskz_loadu_pd(mask, a_scalars); + b_vec = _mm512_maskz_loadu_pd(mask, b_scalars); + count_scalars = 0; } else { - a_vec = _mm512_loadu_pd(a); - b_vec = _mm512_loadu_pd(b); - a += 8, b += 8, n -= 8; + a_vec = _mm512_loadu_pd(a_scalars); + b_vec = _mm512_loadu_pd(b_scalars); + a_scalars += 8, b_scalars += 8, count_scalars -= 8; } ab_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_vec); - if (n) goto simsimd_dot_f64_skylake_cycle; + if (count_scalars) goto simsimd_dot_f64_skylake_cycle; *result = _mm512_reduce_add_pd(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32c_t const *a_pairs, simsimd_f32c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512 a_vec, b_vec; __m512 ab_real_vec = _mm512_setzero(); __m512 ab_imag_vec = _mm512_setzero(); - __m512 a_vec; - __m512 b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1340,21 +1349,21 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const *a, simsimd_f32 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 // 1st 128-bit lane ); simsimd_dot_f32c_skylake_cycle: - if (n < 16) { - __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_ps(mask, a); - b_vec = _mm512_maskz_loadu_ps(mask, b); - n = 0; + if (count_pairs < 8) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_ps(mask, a_pairs); + b_vec = _mm512_maskz_loadu_ps(mask, b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_ps(a); - b_vec = _mm512_loadu_ps(b); - a += 16, b += 16, n -= 16; + a_vec = _mm512_loadu_ps(a_pairs); + b_vec = _mm512_loadu_ps(b_pairs); + a_pairs += 8, b_pairs += 8, count_pairs -= 8; } ab_real_vec = _mm512_fmadd_ps(b_vec, a_vec, ab_real_vec); ab_imag_vec = _mm512_fmadd_ps( _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b_vec), swap_adjacent_vec)), a_vec, ab_imag_vec); - if (n) goto simsimd_dot_f32c_skylake_cycle; + if (count_pairs) goto simsimd_dot_f32c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_real_vec = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(ab_real_vec), sign_flip_vec)); @@ -1364,12 +1373,11 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const *a, simsimd_f32 results[1] = _simsimd_reduce_f32x16_skylake(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32c_t const *a_pairs, simsimd_f32c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512 a_vec, b_vec; __m512 ab_real_vec = _mm512_setzero(); __m512 ab_imag_vec = _mm512_setzero(); - __m512 a_vec; - __m512 b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1384,21 +1392,21 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const *a, simsimd_f3 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 // 1st 128-bit lane ); simsimd_vdot_f32c_skylake_cycle: - if (n < 16) { - __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_ps(mask, a); - b_vec = _mm512_maskz_loadu_ps(mask, b); - n = 0; + if (count_pairs < 8) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_ps(mask, (simsimd_f32_t const *)a_pairs); + b_vec = _mm512_maskz_loadu_ps(mask, (simsimd_f32_t const *)b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_ps(a); - b_vec = _mm512_loadu_ps(b); - a += 16, b += 16, n -= 16; + a_vec = _mm512_loadu_ps((simsimd_f32_t const *)a_pairs); + b_vec = _mm512_loadu_ps((simsimd_f32_t const *)b_pairs); + a_pairs += 8, b_pairs += 8, count_pairs -= 8; } ab_real_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_real_vec); b_vec = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b_vec), swap_adjacent_vec)); ab_imag_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_imag_vec); - if (n) goto simsimd_vdot_f32c_skylake_cycle; + if (count_pairs) goto simsimd_vdot_f32c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_imag_vec = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(ab_imag_vec), sign_flip_vec)); @@ -1408,12 +1416,11 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const *a, simsimd_f3 results[1] = _simsimd_reduce_f32x16_skylake(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64c_t const *a_pairs, simsimd_f64c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512d a_vec, b_vec; __m512d ab_real_vec = _mm512_setzero_pd(); __m512d ab_imag_vec = _mm512_setzero_pd(); - __m512d a_vec; - __m512d b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1431,21 +1438,21 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const *a, simsimd_f64 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 // 1st 128-bit lane ); simsimd_dot_f64c_skylake_cycle: - if (n < 8) { - __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_pd(mask, a); - b_vec = _mm512_maskz_loadu_pd(mask, b); - n = 0; + if (count_pairs < 4) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_pd(mask, a_pairs); + b_vec = _mm512_maskz_loadu_pd(mask, b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_pd(a); - b_vec = _mm512_loadu_pd(b); - a += 8, b += 8, n -= 8; + a_vec = _mm512_loadu_pd(a_pairs); + b_vec = _mm512_loadu_pd(b_pairs); + a_pairs += 4, b_pairs += 4, count_pairs -= 4; } ab_real_vec = _mm512_fmadd_pd(b_vec, a_vec, ab_real_vec); ab_imag_vec = _mm512_fmadd_pd( _mm512_castsi512_pd(_mm512_shuffle_epi8(_mm512_castpd_si512(b_vec), swap_adjacent_vec)), a_vec, ab_imag_vec); - if (n) goto simsimd_dot_f64c_skylake_cycle; + if (count_pairs) goto simsimd_dot_f64c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_real_vec = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(ab_real_vec), sign_flip_vec)); @@ -1455,12 +1462,11 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const *a, simsimd_f64 results[1] = _mm512_reduce_add_pd(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64c_t const *a_pairs, simsimd_f64c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512d a_vec, b_vec; __m512d ab_real_vec = _mm512_setzero_pd(); __m512d ab_imag_vec = _mm512_setzero_pd(); - __m512d a_vec; - __m512d b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1478,21 +1484,21 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const *a, simsimd_f6 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 // 1st 128-bit lane ); simsimd_vdot_f64c_skylake_cycle: - if (n < 8) { - __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_pd(mask, a); - b_vec = _mm512_maskz_loadu_pd(mask, b); - n = 0; + if (count_pairs < 4) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_pd(mask, (simsimd_f32_t const *)a_pairs); + b_vec = _mm512_maskz_loadu_pd(mask, (simsimd_f32_t const *)b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_pd(a); - b_vec = _mm512_loadu_pd(b); - a += 8, b += 8, n -= 8; + a_vec = _mm512_loadu_pd((simsimd_f32_t const *)a_pairs); + b_vec = _mm512_loadu_pd((simsimd_f32_t const *)b_pairs); + a_pairs += 4, b_pairs += 4, count_pairs -= 4; } ab_real_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_real_vec); b_vec = _mm512_castsi512_pd(_mm512_shuffle_epi8(_mm512_castpd_si512(b_vec), swap_adjacent_vec)); ab_imag_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_imag_vec); - if (n) goto simsimd_vdot_f64c_skylake_cycle; + if (count_pairs) goto simsimd_vdot_f64c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_imag_vec = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(ab_imag_vec), sign_flip_vec)); @@ -1512,35 +1518,34 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const *a, simsimd_f6 #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - __m512 ab_vec = _mm512_setzero_ps(); +SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const *a_scalars, simsimd_bf16_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m512i a_i16_vec, b_i16_vec; + __m512 ab_vec = _mm512_setzero_ps(); simsimd_dot_bf16_genoa_cycle: - if (n < 32) { - __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); - a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); - b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); - n = 0; + if (count_scalars < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a_scalars); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b_scalars); + count_scalars = 0; } else { - a_i16_vec = _mm512_loadu_epi16(a); - b_i16_vec = _mm512_loadu_epi16(b); - a += 32, b += 32, n -= 32; + a_i16_vec = _mm512_loadu_epi16(a_scalars); + b_i16_vec = _mm512_loadu_epi16(b_scalars); + a_scalars += 32, b_scalars += 32, count_scalars -= 32; } ab_vec = _mm512_dpbf16_ps(ab_vec, (__m512bh)(a_i16_vec), (__m512bh)(b_i16_vec)); - if (n) goto simsimd_dot_bf16_genoa_cycle; + if (count_scalars) goto simsimd_dot_bf16_genoa_cycle; *result = _simsimd_reduce_f32x16_skylake(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16c_t const *a_pairs, simsimd_bf16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512i a_vec, b_vec; __m512 ab_real_vec = _mm512_setzero_ps(); __m512 ab_imag_vec = _mm512_setzero_ps(); - __m512i a_vec; - __m512i b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1556,33 +1561,32 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const *a, simsimd_bf1 ); simsimd_dot_bf16c_genoa_cycle: - if (n < 32) { - __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_epi16(mask, a); - b_vec = _mm512_maskz_loadu_epi16(mask, b); - n = 0; + if (count_pairs < 16) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_epi16(mask, (simsimd_i16_t const *)a_pairs); + b_vec = _mm512_maskz_loadu_epi16(mask, (simsimd_i16_t const *)b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_epi16(a); - b_vec = _mm512_loadu_epi16(b); - a += 32, b += 32, n -= 32; + a_vec = _mm512_loadu_epi16((simsimd_i16_t const *)a_pairs); + b_vec = _mm512_loadu_epi16((simsimd_i16_t const *)b_pairs); + a_pairs += 16, b_pairs += 16, count_pairs -= 16; } ab_real_vec = _mm512_dpbf16_ps(ab_real_vec, (__m512bh)(_mm512_xor_si512(b_vec, sign_flip_vec)), (__m512bh)(a_vec)); ab_imag_vec = _mm512_dpbf16_ps(ab_imag_vec, (__m512bh)(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), (__m512bh)(a_vec)); - if (n) goto simsimd_dot_bf16c_genoa_cycle; + if (count_pairs) goto simsimd_dot_bf16c_genoa_cycle; // Reduce horizontal sums: results[0] = _simsimd_reduce_f32x16_skylake(ab_real_vec); results[1] = _simsimd_reduce_f32x16_skylake(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16c_t const *a_pairs, simsimd_bf16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512i a_vec, b_vec; __m512 ab_real_vec = _mm512_setzero_ps(); __m512 ab_imag_vec = _mm512_setzero_ps(); - __m512i a_vec; - __m512i b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1598,22 +1602,22 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const *a, simsimd_bf ); simsimd_dot_bf16c_genoa_cycle: - if (n < 32) { - __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_epi16(mask, a); - b_vec = _mm512_maskz_loadu_epi16(mask, b); - n = 0; + if (count_pairs < 16) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_epi16(mask, (simsimd_i16_t const *)a_pairs); + b_vec = _mm512_maskz_loadu_epi16(mask, (simsimd_i16_t const *)b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_epi16(a); - b_vec = _mm512_loadu_epi16(b); - a += 32, b += 32, n -= 32; + a_vec = _mm512_loadu_epi16((simsimd_i16_t const *)a_pairs); + b_vec = _mm512_loadu_epi16((simsimd_i16_t const *)b_pairs); + a_pairs += 16, b_pairs += 16, count_pairs -= 16; } ab_real_vec = _mm512_dpbf16_ps(ab_real_vec, (__m512bh)(a_vec), (__m512bh)(b_vec)); a_vec = _mm512_xor_si512(a_vec, sign_flip_vec); b_vec = _mm512_shuffle_epi8(b_vec, swap_adjacent_vec); ab_imag_vec = _mm512_dpbf16_ps(ab_imag_vec, (__m512bh)(a_vec), (__m512bh)(b_vec)); - if (n) goto simsimd_dot_bf16c_genoa_cycle; + if (count_pairs) goto simsimd_dot_bf16c_genoa_cycle; // Reduce horizontal sums: results[0] = _simsimd_reduce_f32x16_skylake(ab_real_vec); @@ -1630,35 +1634,34 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const *a, simsimd_bf #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - __m512h ab_vec = _mm512_setzero_ph(); +SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const *a_scalars, simsimd_f16_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m512i a_i16_vec, b_i16_vec; + __m512h ab_vec = _mm512_setzero_ph(); simsimd_dot_f16_sapphire_cycle: - if (n < 32) { - __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); - a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); - b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); - n = 0; + if (count_scalars < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a_scalars); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b_scalars); + count_scalars = 0; } else { - a_i16_vec = _mm512_loadu_epi16(a); - b_i16_vec = _mm512_loadu_epi16(b); - a += 32, b += 32, n -= 32; + a_i16_vec = _mm512_loadu_epi16(a_scalars); + b_i16_vec = _mm512_loadu_epi16(b_scalars); + a_scalars += 32, b_scalars += 32, count_scalars -= 32; } ab_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(b_i16_vec), ab_vec); - if (n) goto simsimd_dot_f16_sapphire_cycle; + if (count_scalars) goto simsimd_dot_f16_sapphire_cycle; *result = _mm512_reduce_add_ph(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512i a_vec, b_vec; __m512h ab_real_vec = _mm512_setzero_ph(); __m512h ab_imag_vec = _mm512_setzero_ph(); - __m512i a_vec; - __m512i b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1674,23 +1677,23 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f1 ); simsimd_dot_f16c_sapphire_cycle: - if (n < 32) { - __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_epi16(mask, a); - b_vec = _mm512_maskz_loadu_epi16(mask, b); - n = 0; + if (count_pairs < 16) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_epi16(mask, a_pairs); + b_vec = _mm512_maskz_loadu_epi16(mask, b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_epi16(a); - b_vec = _mm512_loadu_epi16(b); - a += 32, b += 32, n -= 32; + a_vec = _mm512_loadu_epi16(a_pairs); + b_vec = _mm512_loadu_epi16(b_pairs); + a_pairs += 16, b_pairs += 16, count_pairs -= 16; } // TODO: Consider using `_mm512_fmaddsub` ab_real_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(_mm512_xor_si512(b_vec, sign_flip_vec)), _mm512_castsi512_ph(a_vec), ab_real_vec); ab_imag_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), _mm512_castsi512_ph(a_vec), ab_imag_vec); - if (n) goto simsimd_dot_f16c_sapphire_cycle; + if (count_pairs) goto simsimd_dot_f16c_sapphire_cycle; // Reduce horizontal sums: // TODO: Optimize this with tree-like reductions @@ -1698,12 +1701,11 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f1 results[1] = _mm512_reduce_add_ph(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, - simsimd_distance_t *results) { +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16c_t const *a_pairs, simsimd_f16c_t const *b_pairs, + simsimd_size_t count_pairs, simsimd_distance_t *results) { + __m512i a_vec, b_vec; __m512h ab_real_vec = _mm512_setzero_ph(); __m512h ab_imag_vec = _mm512_setzero_ph(); - __m512i a_vec; - __m512i b_vec; // We take into account, that FMS is the same as FMA with a negative multiplier. // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. @@ -1719,22 +1721,22 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f ); simsimd_dot_f16c_sapphire_cycle: - if (n < 32) { - __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); - a_vec = _mm512_maskz_loadu_epi16(mask, a); - b_vec = _mm512_maskz_loadu_epi16(mask, b); - n = 0; + if (count_pairs < 16) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_pairs * 2); + a_vec = _mm512_maskz_loadu_epi16(mask, a_pairs); + b_vec = _mm512_maskz_loadu_epi16(mask, b_pairs); + count_pairs = 0; } else { - a_vec = _mm512_loadu_epi16(a); - b_vec = _mm512_loadu_epi16(b); - a += 32, b += 32, n -= 32; + a_vec = _mm512_loadu_epi16(a_pairs); + b_vec = _mm512_loadu_epi16(b_pairs); + a_pairs += 16, b_pairs += 16, count_pairs -= 16; } ab_real_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_vec), _mm512_castsi512_ph(b_vec), ab_real_vec); a_vec = _mm512_xor_si512(a_vec, sign_flip_vec); b_vec = _mm512_shuffle_epi8(b_vec, swap_adjacent_vec); ab_imag_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_vec), _mm512_castsi512_ph(b_vec), ab_imag_vec); - if (n) goto simsimd_dot_f16c_sapphire_cycle; + if (count_pairs) goto simsimd_dot_f16c_sapphire_cycle; // Reduce horizontal sums: results[0] = _mm512_reduce_add_ph(ab_real_vec); @@ -1751,53 +1753,53 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { - __m512i ab_i32_vec = _mm512_setzero_si512(); +SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const *a_scalars, simsimd_i8_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m512i a_i16_vec, b_i16_vec; + __m512i ab_i32_vec = _mm512_setzero_si512(); simsimd_dot_i8_ice_cycle: - if (n < 32) { - __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); - a_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a)); - b_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); - n = 0; + if (count_scalars < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars); + a_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a_scalars)); + b_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b_scalars)); + count_scalars = 0; } else { - a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)a)); - b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)b)); - a += 32, b += 32, n -= 32; + a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)a_scalars)); + b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)b_scalars)); + a_scalars += 32, b_scalars += 32, count_scalars -= 32; } // Unfortunately we can't use the `_mm512_dpbusd_epi32` intrinsics here either, // as it's asymmetric with respect to the sign of the input arguments: - // Signed(ZeroExtend16(a.byte[4*j]) * SignExtend16(b.byte[4*j])) + // Signed(ZeroExtend16(a_scalars.byte[4*j]) * SignExtend16(b_scalars.byte[4*j])) // So we have to use the `_mm512_dpwssd_epi32` intrinsics instead, upcasting // to 16-bit beforehand. ab_i32_vec = _mm512_dpwssd_epi32(ab_i32_vec, a_i16_vec, b_i16_vec); - if (n) goto simsimd_dot_i8_ice_cycle; + if (count_scalars) goto simsimd_dot_i8_ice_cycle; *result = _mm512_reduce_add_epi32(ab_i32_vec); } -SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const *a_scalars, simsimd_u8_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { + __m512i a_u8_vec, b_u8_vec; + __m512i a_i16_low_vec, a_i16_high_vec, b_i16_low_vec, b_i16_high_vec; __m512i ab_i32_low_vec = _mm512_setzero_si512(); __m512i ab_i32_high_vec = _mm512_setzero_si512(); __m512i const zeros_vec = _mm512_setzero_si512(); - __m512i a_i16_low_vec, a_i16_high_vec, b_i16_low_vec, b_i16_high_vec; - __m512i a_u8_vec, b_u8_vec; simsimd_dot_u8_ice_cycle: - if (n < 64) { - __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n); - a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); - b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); - n = 0; + if (count_scalars < 64) { + __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars); + a_u8_vec = _mm512_maskz_loadu_epi8(mask, a_scalars); + b_u8_vec = _mm512_maskz_loadu_epi8(mask, b_scalars); + count_scalars = 0; } else { - a_u8_vec = _mm512_loadu_si512(a); - b_u8_vec = _mm512_loadu_si512(b); - a += 64, b += 64, n -= 64; + a_u8_vec = _mm512_loadu_si512(a_scalars); + b_u8_vec = _mm512_loadu_si512(b_scalars); + a_scalars += 64, b_scalars += 64, count_scalars -= 64; } // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking @@ -1813,7 +1815,7 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const *a, simsimd_u8_t const // to 16-bit beforehand. ab_i32_low_vec = _mm512_dpwssd_epi32(ab_i32_low_vec, a_i16_low_vec, b_i16_low_vec); ab_i32_high_vec = _mm512_dpwssd_epi32(ab_i32_high_vec, a_i16_high_vec, b_i16_high_vec); - if (n) goto simsimd_dot_u8_ice_cycle; + if (count_scalars) goto simsimd_dot_u8_ice_cycle; *result = _mm512_reduce_add_epi32(_mm512_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); } @@ -1827,15 +1829,14 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const *a, simsimd_u8_t const #pragma GCC target("avx2", "bmi2", "avx2vnni") #pragma clang attribute push(__attribute__((target("avx2,bmi2,avx2vnni"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, - simsimd_distance_t *result) { +SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const *a_scalars, simsimd_i8_t const *b_scalars, + simsimd_size_t count_scalars, simsimd_distance_t *result) { __m256i ab_i32_vec = _mm256_setzero_si256(); - - simsimd_size_t i = 0; - for (; i + 32 <= n; i += 32) { - __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); - __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); + simsimd_size_t idx_scalars = 0; + for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) { + __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a_scalars + idx_scalars)); + __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b_scalars + idx_scalars)); ab_i32_vec = _mm256_dpbssds_epi32(ab_i32_vec, a_i8_vec, b_i8_vec); } @@ -1843,8 +1844,7 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const *a, simsimd_i8_t co int ab = _simsimd_reduce_i32x8_haswell(ab_i32_vec); // Take care of the tail: - for (; i < n; ++i) ab += (int)(a[i]) * b[i]; - + for (; idx_scalars < count_scalars; ++idx_scalars) ab += (int)(a_scalars[idx_scalars]) * b_scalars[idx_scalars]; *result = ab; } diff --git a/include/simsimd/simsimd.h b/include/simsimd/simsimd.h index 2eb4f805..94a8cbb8 100644 --- a/include/simsimd/simsimd.h +++ b/include/simsimd/simsimd.h @@ -981,6 +981,7 @@ SIMSIMD_INTERNAL void _simsimd_find_kernel_punned_f64c(simsimd_capability_t v, s if (v & simsimd_cap_serial_k) switch (k) { case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_serial, *c = simsimd_cap_serial_k; return; case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f64c_serial, *c = simsimd_cap_serial_k; return; default: break; } } @@ -1019,6 +1020,7 @@ SIMSIMD_INTERNAL void _simsimd_find_kernel_punned_f32c(simsimd_capability_t v, s if (v & simsimd_cap_serial_k) switch (k) { case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_serial, *c = simsimd_cap_serial_k; return; case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f32c_serial, *c = simsimd_cap_serial_k; return; default: break; } } @@ -1057,6 +1059,7 @@ SIMSIMD_INTERNAL void _simsimd_find_kernel_punned_f16c(simsimd_capability_t v, s if (v & simsimd_cap_serial_k) switch (k) { case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_serial, *c = simsimd_cap_serial_k; return; case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16c_serial, *c = simsimd_cap_serial_k; return; default: break; } } @@ -1081,6 +1084,7 @@ SIMSIMD_INTERNAL void _simsimd_find_kernel_punned_bf16c(simsimd_capability_t v, if (v & simsimd_cap_serial_k) switch (k) { case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_serial, *c = simsimd_cap_serial_k; return; case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_bf16c_serial, *c = simsimd_cap_serial_k; return; default: break; } } @@ -1314,21 +1318,21 @@ SIMSIMD_DYNAMIC void simsimd_dot_f32(simsimd_f32_t const *a, simsimd_f32_t const simsimd_distance_t *d); SIMSIMD_DYNAMIC void simsimd_dot_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_dot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_dot_f16c(simsimd_f16c_t const *a, simsimd_f16c_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_dot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_dot_bf16c(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_dot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_dot_f32c(simsimd_f32c_t const *a, simsimd_f32c_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_dot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_dot_f64c(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_vdot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_vdot_f16c(simsimd_f16c_t const *a, simsimd_f16c_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_vdot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_vdot_bf16c(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_vdot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_vdot_f32c(simsimd_f32c_t const *a, simsimd_f32c_t const *b, simsimd_size_t n, simsimd_distance_t *d); -SIMSIMD_DYNAMIC void simsimd_vdot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, +SIMSIMD_DYNAMIC void simsimd_vdot_f64c(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_size_t n, simsimd_distance_t *d); /* Spatial distances @@ -1562,7 +1566,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f64(simsimd_f64_t const *a, simsimd_f64_t const simsimd_dot_f64_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_dot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_dot_f16c(simsimd_f16c_t const *a, simsimd_f16c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE_F16 simsimd_dot_f16c_sve(a, b, n, d); @@ -1576,7 +1580,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c(simsimd_f16_t const *a, simsimd_f16_t const simsimd_dot_f16c_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_dot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_dot_bf16c(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_dot_bf16c_genoa(a, b, n, d); @@ -1586,7 +1590,7 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t co simsimd_dot_bf16c_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_dot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_dot_f32c(simsimd_f32c_t const *a, simsimd_f32c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_dot_f32c_sve(a, b, n, d); @@ -1600,7 +1604,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c(simsimd_f32_t const *a, simsimd_f32_t const simsimd_dot_f32c_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_dot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_dot_f64c(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_dot_f64c_sve(a, b, n, d); @@ -1610,7 +1614,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c(simsimd_f64_t const *a, simsimd_f64_t const simsimd_dot_f64c_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_vdot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_vdot_f16c(simsimd_f16c_t const *a, simsimd_f16c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_vdot_f16c_sve(a, b, n, d); @@ -1624,7 +1628,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c(simsimd_f16_t const *a, simsimd_f16_t cons simsimd_vdot_f16c_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_vdot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_vdot_bf16c(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_vdot_bf16c_genoa(a, b, n, d); @@ -1634,7 +1638,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t c simsimd_vdot_bf16c_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_vdot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_vdot_f32c(simsimd_f32c_t const *a, simsimd_f32c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_vdot_f32c_sve(a, b, n, d); @@ -1648,7 +1652,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c(simsimd_f32_t const *a, simsimd_f32_t cons simsimd_vdot_f32c_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_vdot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, +SIMSIMD_PUBLIC void simsimd_vdot_f64c(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_vdot_f64c_sve(a, b, n, d); diff --git a/python/lib.c b/python/lib.c index 2aca7ee5..7cbaabbe 100644 --- a/python/lib.c +++ b/python/lib.c @@ -540,8 +540,6 @@ int parse_tensor(PyObject *tensor, Py_buffer *buffer, TensorArgument *parsed) { return 0; } - // We handle complex numbers differently - if (is_complex(parsed->datatype)) parsed->dimensions *= 2; return 1; } @@ -729,17 +727,12 @@ static PyObject *implement_dense_metric( // // If the distance is computed between two vectors, rather than matrices, return a scalar int const dtype_is_complex = is_complex(dtype); if (a_parsed.rank == 1 && b_parsed.rank == 1) { - // For complex numbers we are going to use `PyComplex_FromDoubles`. - if (dtype_is_complex) { - simsimd_distance_t distances[2]; - metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, distances); - return_obj = PyComplex_FromDoubles(distances[0], distances[1]); - } - else { - simsimd_distance_t distance; - metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, &distance); - return_obj = PyFloat_FromDouble(distance); - } + simsimd_distance_t distances[2]; + metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, distances); + return_obj = // + dtype_is_complex // + ? PyComplex_FromDoubles(distances[0], distances[1]) + : PyFloat_FromDouble(distances[0]); goto cleanup; } @@ -931,9 +924,14 @@ static PyObject *implement_curved_metric( // goto cleanup; } - simsimd_distance_t distance; - metric(a_parsed.start, b_parsed.start, c_parsed.start, a_parsed.dimensions, &distance); - return_obj = PyFloat_FromDouble(distance); + // If the distance is computed between two vectors, rather than matrices, return a scalar + int const dtype_is_complex = is_complex(dtype); + simsimd_distance_t distances[2]; + metric(a_parsed.start, b_parsed.start, c_parsed.start, a_parsed.dimensions, &distances[0]); + return_obj = // + dtype_is_complex // + ? PyComplex_FromDoubles(distances[0], distances[1]) + : PyFloat_FromDouble(distances[0]); cleanup: PyBuffer_Release(&a_buffer); @@ -1083,17 +1081,12 @@ static PyObject *implement_cdist( // // If the distance is computed between two vectors, rather than matrices, return a scalar int const dtype_is_complex = is_complex(dtype); if (a_parsed.rank == 1 && b_parsed.rank == 1) { - // For complex numbers we are going to use `PyComplex_FromDoubles`. - if (dtype_is_complex) { - simsimd_distance_t distances[2]; - metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, distances); - return_obj = PyComplex_FromDoubles(distances[0], distances[1]); - } - else { - simsimd_distance_t distance; - metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, &distance); - return_obj = PyFloat_FromDouble(distance); - } + simsimd_distance_t distances[2]; + metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, distances); + return_obj = // + dtype_is_complex // + ? PyComplex_FromDoubles(distances[0], distances[1]) + : PyFloat_FromDouble(distances[0]); goto cleanup; } diff --git a/scripts/test.c b/scripts/test.c index b8302cf5..4d590c74 100644 --- a/scripts/test.c +++ b/scripts/test.c @@ -90,6 +90,12 @@ void test_distance_from_itself(void) { simsimd_f32_t f32s[1536]; simsimd_f16_t f16s[1536]; simsimd_bf16_t bf16s[1536]; + + simsimd_f64c_t f64cs[768]; + simsimd_f32c_t f32cs[768]; + simsimd_f16c_t f16cs[768]; + simsimd_bf16c_t bf16cs[768]; + simsimd_i8_t i8s[1536]; simsimd_u8_t u8s[1536]; simsimd_b8_t b8s[1536 / 8]; // 8 bits per word @@ -120,16 +126,16 @@ void test_distance_from_itself(void) { simsimd_dot_f64(f64s, f64s, 1536, &distance[0]); // Complex inner product between two vectors - simsimd_dot_f16c(f16s, f16s, 1536, &distance[0]); - simsimd_dot_bf16c(bf16s, bf16s, 1536, &distance[0]); - simsimd_dot_f32c(f32s, f32s, 1536, &distance[0]); - simsimd_dot_f64c(f64s, f64s, 1536, &distance[0]); + simsimd_dot_bf16c(bf16cs, bf16cs, 768, &distance[0]); + simsimd_dot_f16c(f16cs, f16cs, 768, &distance[0]); + simsimd_dot_f32c(f32cs, f32cs, 768, &distance[0]); + simsimd_dot_f64c(f64cs, f64cs, 768, &distance[0]); // Complex conjugate inner product between two vectors - simsimd_vdot_f16c(f16s, f16s, 1536, &distance[0]); - simsimd_vdot_bf16c(bf16s, bf16s, 1536, &distance[0]); - simsimd_vdot_f32c(f32s, f32s, 1536, &distance[0]); - simsimd_vdot_f64c(f64s, f64s, 1536, &distance[0]); + simsimd_vdot_bf16c(bf16cs, bf16cs, 768, &distance[0]); + simsimd_vdot_f16c(f16cs, f16cs, 768, &distance[0]); + simsimd_vdot_f32c(f32cs, f32cs, 768, &distance[0]); + simsimd_vdot_f64c(f64cs, f64cs, 768, &distance[0]); // Hamming distance between two vectors simsimd_hamming_b8(b8s, b8s, 1536 / 8, &distance[0]); diff --git a/scripts/test.py b/scripts/test.py index 985dc67c..c1ca39a8 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -672,6 +672,40 @@ def test_curved(ndim, dtypes, metric, capability, stats_fixture): collect_errors(metric, ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture) +@pytest.mark.skipif(is_running_under_qemu(), reason="Complex math in QEMU fails") +@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") +@pytest.mark.repeat(50) +@pytest.mark.parametrize("ndim", [11, 97]) +@pytest.mark.parametrize("dtype", ["complex128", "complex64"]) +@pytest.mark.parametrize("capability", possible_capabilities) +def test_curved_complex(ndim, dtype, capability, stats_fixture): + """Compares various SIMD kernels (like Bilinear Forms and Mahalanobis distances) for curved spaces + with their NumPy or baseline counterparts, testing accuracy for complex IEEE standard floating-point types.""" + + # Let's generate some uniform complex numbers + np.random.seed() + a = (np.random.randn(ndim) + 1.0j * np.random.randn(ndim)).astype(dtype) + b = (np.random.randn(ndim) + 1.0j * np.random.randn(ndim)).astype(dtype) + c = (np.random.randn(ndim, ndim) + 1.0j * np.random.randn(ndim, ndim)).astype(dtype) + + keep_one_capability(capability) + baseline_kernel, simd_kernel = name_to_kernels("bilinear") + accurate_dt, accurate = profile( + baseline_kernel, + a.astype(np.complex128), + b.astype(np.complex128), + c.astype(np.complex128), + ) + expected_dt, expected = profile(baseline_kernel, a, b, c) + result_dt, result = profile(simd_kernel, a, b, c) + result = np.array(result) + + np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) + collect_errors( + "bilinear", ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture + ) + + @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") @pytest.mark.repeat(50) @pytest.mark.parametrize("ndim", [11, 97, 1536]) @@ -968,15 +1002,14 @@ def test_overflow_i8(ndim, metric, capability): @pytest.mark.skipif(is_running_under_qemu(), reason="Complex math in QEMU fails") @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") @pytest.mark.repeat(50) -@pytest.mark.parametrize("ndim", [22, 66, 1536]) -@pytest.mark.parametrize("dtype", ["float32", "float64"]) +@pytest.mark.parametrize("ndim", [11, 97, 1536]) +@pytest.mark.parametrize("dtype", ["complex128", "complex64"]) @pytest.mark.parametrize("capability", possible_capabilities) def test_dot_complex(ndim, dtype, capability, stats_fixture): """Compares the simd.dot() and simd.vdot() against NumPy for complex numbers.""" np.random.seed() - dtype_view = np.complex64 if dtype == "float32" else np.complex128 - a = np.random.randn(ndim).astype(dtype=dtype).view(dtype_view) - b = np.random.randn(ndim).astype(dtype=dtype).view(dtype_view) + a = (np.random.randn(ndim) + 1.0j * np.random.randn(ndim)).astype(dtype) + b = (np.random.randn(ndim) + 1.0j * np.random.randn(ndim)).astype(dtype) keep_one_capability(capability) accurate_dt, accurate = profile(np.dot, a.astype(np.complex128), b.astype(np.complex128)) @@ -986,7 +1019,7 @@ def test_dot_complex(ndim, dtype, capability, stats_fixture): np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors( - "dot", ndim, dtype + "c", accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture + "dot", ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture ) accurate_dt, accurate = profile(np.vdot, a.astype(np.complex128), b.astype(np.complex128)) @@ -996,33 +1029,10 @@ def test_dot_complex(ndim, dtype, capability, stats_fixture): np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors( - "vdot", ndim, dtype + "c", accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture + "vdot", ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture ) -@pytest.mark.skipif(is_running_under_qemu(), reason="Complex math in QEMU fails") -@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") -@pytest.mark.repeat(50) -@pytest.mark.parametrize("ndim", [22, 66, 1536]) -@pytest.mark.parametrize("capability", possible_capabilities) -def test_dot_complex_explicit(ndim, capability): - """Compares the simd.dot() and simd.vdot() against NumPy for complex numbers.""" - np.random.seed() - a = np.random.randn(ndim).astype(dtype=np.float32) - b = np.random.randn(ndim).astype(dtype=np.float32) - - keep_one_capability(capability) - expected = np.dot(a.view(np.complex64), b.view(np.complex64)) - result = simd.dot(a, b, "complex64") - - np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) - - expected = np.vdot(a.view(np.complex64), b.view(np.complex64)) - result = simd.vdot(a, b, "complex64") - - np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) - - @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") @pytest.mark.repeat(100) @pytest.mark.parametrize("dtype", ["uint16", "uint32"])