diff --git a/include/simsimd/curved.h b/include/simsimd/curved.h index 18b605e7..52684ad0 100644 --- a/include/simsimd/curved.h +++ b/include/simsimd/curved.h @@ -97,10 +97,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const* a, simsi * properly vectorized by recent compilers. */ SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_bilinear_f16c_haswell(simsimd_f16c_t const* a, simsimd_f16c_t const* b, simsimd_f16c_t const* c, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_bf16_t const* c, simsimd_size_t n, simsimd_distance_t* result); -SIMSIMD_PUBLIC void simsimd_bilinear_bf16c_haswell(simsimd_bf16c_t const* a, simsimd_bf16c_t const* b, simsimd_bf16c_t const* c, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_bf16_t const* c, simsimd_size_t n, simsimd_distance_t* result); /* SIMD-powered backends for various generations of AVX512 CPUs. @@ -128,14 +126,14 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim simsimd_size_t n, simsimd_distance_t *result) { \ simsimd_##accumulator_type##_t sum = 0; \ for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t partial = 0; \ + simsimd_##accumulator_type##_t cb_j = 0; \ simsimd_##accumulator_type##_t a_i = load_and_convert(a + i); \ for (simsimd_size_t j = 0; j != n; ++j) { \ simsimd_##accumulator_type##_t b_j = load_and_convert(b + j); \ simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \ - partial += c_ij * b_j; \ + cb_j += c_ij * b_j; \ } \ - sum += a_i * partial; \ + sum += a_i * cb_j; \ } \ *result = (simsimd_distance_t)sum; \ } @@ -147,8 +145,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim simsimd_##accumulator_type##_t sum_real = 0; \ simsimd_##accumulator_type##_t sum_imag = 0; \ for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t partial_real = 0; \ - simsimd_##accumulator_type##_t partial_imag = 0; \ + simsimd_##accumulator_type##_t cb_j_real = 0; \ + simsimd_##accumulator_type##_t cb_j_imag = 0; \ simsimd_##accumulator_type##_t a_i_real = load_and_convert(&(a_pairs + i)->real); \ simsimd_##accumulator_type##_t a_i_imag = load_and_convert(&(a_pairs + i)->imag); \ for (simsimd_size_t j = 0; j != n; ++j) { \ @@ -157,16 +155,12 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim simsimd_##accumulator_type##_t c_ij_real = load_and_convert(&(c_pairs + i * n + j)->real); \ simsimd_##accumulator_type##_t c_ij_imag = load_and_convert(&(c_pairs + i * n + j)->imag); \ /* Complex multiplication: (c_ij * b_j) */ \ - simsimd_##accumulator_type##_t prod_real = c_ij_real * b_j_real - c_ij_imag * b_j_imag; \ - simsimd_##accumulator_type##_t prod_imag = c_ij_real * b_j_imag + c_ij_imag * b_j_real; \ - partial_real += prod_real; \ - partial_imag += prod_imag; \ + cb_j_real += c_ij_real * b_j_real - c_ij_imag * b_j_imag; \ + cb_j_imag += c_ij_real * b_j_imag + c_ij_imag * b_j_real; \ } \ - /* Complex multiplication: (a_i * partial) */ \ - 
simsimd_##accumulator_type##_t final_real = a_i_real * partial_real - a_i_imag * partial_imag; \ - simsimd_##accumulator_type##_t final_imag = a_i_real * partial_imag + a_i_imag * partial_real; \ - sum_real += final_real; \ - sum_imag += final_imag; \ + /* Complex multiplication: (a_i * cb_j) */ \ + sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag; \ + sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real; \ } \ results[0] = (simsimd_distance_t)sum_real; \ results[1] = (simsimd_distance_t)sum_imag; \ @@ -178,14 +172,14 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim simsimd_size_t n, simsimd_distance_t *result) { \ simsimd_##accumulator_type##_t sum = 0; \ for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t partial = 0; \ + simsimd_##accumulator_type##_t cdiff_j = 0; \ simsimd_##accumulator_type##_t diff_i = load_and_convert(a + i) - load_and_convert(b + i); \ for (simsimd_size_t j = 0; j != n; ++j) { \ simsimd_##accumulator_type##_t diff_j = load_and_convert(a + j) - load_and_convert(b + j); \ simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \ - partial += c_ij * diff_j; \ + cdiff_j += c_ij * diff_j; \ } \ - sum += diff_i * partial; \ + sum += diff_i * cdiff_j; \ } \ *result = (simsimd_distance_t)SIMSIMD_SQRT(sum); \ } @@ -229,13 +223,13 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_neon(simsimd_f32_t const *a, simsimd_f3 float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { float32x4_t a_vec = vdupq_n_f32(a[i]); - float32x4_t partial_sum_vec = vdupq_n_f32(0); + float32x4_t cb_j_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 4 <= n; j += 4) { float32x4_t b_vec = vld1q_f32(b + j); float32x4_t c_vec = vld1q_f32(c + i * n + j); - partial_sum_vec = vmlaq_f32(partial_sum_vec, b_vec, c_vec); + cb_j_vec = vmlaq_f32(cb_j_vec, b_vec, c_vec); } - sum_vec = vmlaq_f32(sum_vec, a_vec, partial_sum_vec); + sum_vec = vmlaq_f32(sum_vec, a_vec, cb_j_vec); } // Handle the tail of every row @@ -245,9 +239,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_neon(simsimd_f32_t const *a, simsimd_f3 if (tail_length) { for (simsimd_size_t i = 0; i != n; ++i) { simsimd_f32_t a_i = a[i]; - simsimd_f32_t partial_sum = 0; - for (simsimd_size_t j = tail_start; j != n; ++j) partial_sum += b[j] * c[i * n + j]; - sum += a[i] * partial_sum; + simsimd_f32_t cb_j = 0; + for (simsimd_size_t j = tail_start; j != n; ++j) cb_j += b[j] * c[i * n + j]; + sum += a[i] * cb_j; } } @@ -262,14 +256,14 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_neon(simsimd_f32_t const *a, simsimd float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { float32x4_t diff_i_vec = vdupq_n_f32(a[i] - b[i]); - float32x4_t partial_sum_vec = vdupq_n_f32(0); + float32x4_t cdiff_j_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 4 <= n; j += 4) { float32x4_t diff_j_vec = vsubq_f32(vld1q_f32(a + j), vld1q_f32(b + j)); float32x4_t c_vec = vld1q_f32(c + i * n + j); - partial_sum_vec = vmlaq_f32(partial_sum_vec, diff_j_vec, c_vec); + cdiff_j_vec = vmlaq_f32(cdiff_j_vec, diff_j_vec, c_vec); } - sum_vec = vmlaq_f32(sum_vec, diff_i_vec, partial_sum_vec); + sum_vec = vmlaq_f32(sum_vec, diff_i_vec, cdiff_j_vec); } // Handle the tail of every row @@ -279,12 +273,12 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_neon(simsimd_f32_t const *a, simsimd if (tail_length) { for (simsimd_size_t i = 0; i != n; ++i) { simsimd_f32_t diff_i = a[i] - b[i]; - simsimd_f32_t partial_sum = 0; + simsimd_f32_t cdiff_j = 0; for (simsimd_size_t j = 
tail_start; j != n; ++j) { simsimd_f32_t diff_j = a[j] - b[j]; - partial_sum += diff_j * c[i * n + j]; + cdiff_j += diff_j * c[i * n + j]; } - sum += diff_i * partial_sum; + sum += diff_i * cdiff_j; } } @@ -306,13 +300,13 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_neon(simsimd_f16_t const *a, simsimd_f1 for (simsimd_size_t i = 0; i != n; ++i) { // MSVC doesn't recognize `vdup_n_f16` as a valid intrinsic float32x4_t a_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const *)(a + i)))); - float32x4_t partial_sum_vec = vdupq_n_f32(0); + float32x4_t cb_j_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 4 <= n; j += 4) { float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(b + j))); float32x4_t c_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(c + i * n + j))); - partial_sum_vec = vmlaq_f32(partial_sum_vec, b_vec, c_vec); + cb_j_vec = vmlaq_f32(cb_j_vec, b_vec, c_vec); } - sum_vec = vmlaq_f32(sum_vec, a_vec, partial_sum_vec); + sum_vec = vmlaq_f32(sum_vec, a_vec, cb_j_vec); } // Handle the tail of every row @@ -324,8 +318,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_neon(simsimd_f16_t const *a, simsimd_f1 simsimd_f32_t a_i = vaddvq_f32(vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a + i, 1))); float32x4_t b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b + tail_start, tail_length)); float32x4_t c_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(c + i * n + tail_start, tail_length)); - simsimd_f32_t partial_sum = vaddvq_f32(vmulq_f32(b_vec, c_vec)); - sum += a_i * partial_sum; + simsimd_f32_t cb_j = vaddvq_f32(vmulq_f32(b_vec, c_vec)); + sum += a_i * cb_j; } } @@ -344,15 +338,15 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const *a, simsimd float32x4_t a_i_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const *)(a + i)))); float32x4_t b_i_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const *)(b + i)))); float32x4_t diff_i_vec = vsubq_f32(a_i_vec, b_i_vec); - float32x4_t partial_sum_vec = vdupq_n_f32(0); + float32x4_t cdiff_j_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 4 <= n; j += 4) { float32x4_t a_j_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(a + j))); float32x4_t b_j_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(b + j))); float32x4_t diff_j_vec = vsubq_f32(a_j_vec, b_j_vec); float32x4_t c_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(c + i * n + j))); - partial_sum_vec = vmlaq_f32(partial_sum_vec, diff_j_vec, c_vec); + cdiff_j_vec = vmlaq_f32(cdiff_j_vec, diff_j_vec, c_vec); } - sum_vec = vmlaq_f32(sum_vec, diff_i_vec, partial_sum_vec); + sum_vec = vmlaq_f32(sum_vec, diff_i_vec, cdiff_j_vec); } // Handle the tail of every row @@ -368,8 +362,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const *a, simsimd float32x4_t b_j_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b + tail_start, tail_length)); float32x4_t diff_j_vec = vsubq_f32(a_j_vec, b_j_vec); float32x4_t c_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(c + i * n + tail_start, tail_length)); - simsimd_f32_t partial_sum = vaddvq_f32(vmulq_f32(diff_j_vec, c_vec)); - sum += diff_i * partial_sum; + simsimd_f32_t cdiff_j = vaddvq_f32(vmulq_f32(diff_j_vec, c_vec)); + sum += diff_i * cdiff_j; } } @@ -390,13 +384,13 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_ float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { float32x4_t a_vec = vdupq_n_f32(simsimd_bf16_to_f32(a + i)); 
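/*
 * A minimal scalar sketch of the two quantities every kernel in this header
 * specializes, assuming plain `double` inputs and a row-major n-by-n matrix
 * `c`; the helper names are illustrative and not part of SimSIMD. The NEON,
 * Haswell, and AVX-512 variants above and below compute exactly these sums,
 * only in wider lanes and lower-precision types.
 */
#include <math.h>
#include <stddef.h>

/* Bilinear form: a^T * C * b. */
static double reference_bilinear(double const *a, double const *b, double const *c, size_t n) {
    double sum = 0;
    for (size_t i = 0; i != n; ++i) {
        double cb_i = 0; /* dot product of the i-th row of C with b */
        for (size_t j = 0; j != n; ++j) cb_i += c[i * n + j] * b[j];
        sum += a[i] * cb_i;
    }
    return sum;
}

/* Mahalanobis distance: sqrt((a - b)^T * C * (a - b)). */
static double reference_mahalanobis(double const *a, double const *b, double const *c, size_t n) {
    double sum = 0;
    for (size_t i = 0; i != n; ++i) {
        double cdiff_i = 0; /* dot product of the i-th row of C with (a - b) */
        for (size_t j = 0; j != n; ++j) cdiff_i += c[i * n + j] * (a[j] - b[j]);
        sum += (a[i] - b[i]) * cdiff_i;
    }
    return sqrt(sum);
}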
- float32x4_t partial_sum_vec = vdupq_n_f32(0); + float32x4_t cb_j_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(b + j)); bfloat16x8_t c_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(c + i * n + j)); - partial_sum_vec = vbfdotq_f32(partial_sum_vec, b_vec, c_vec); + cb_j_vec = vbfdotq_f32(cb_j_vec, b_vec, c_vec); } - sum_vec = vmlaq_f32(sum_vec, a_vec, partial_sum_vec); + sum_vec = vmlaq_f32(sum_vec, a_vec, cb_j_vec); } // Handle the tail of every row @@ -408,8 +402,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_ simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i); bfloat16x8_t b_vec = _simsimd_partial_load_bf16x8_neon(b + tail_start, tail_length); bfloat16x8_t c_vec = _simsimd_partial_load_bf16x8_neon(c + i * n + tail_start, tail_length); - simsimd_f32_t partial_sum = vaddvq_f32(vbfdotq_f32(vdupq_n_f32(0), b_vec, c_vec)); - sum += a_i * partial_sum; + simsimd_f32_t cb_j = vaddvq_f32(vbfdotq_f32(vdupq_n_f32(0), b_vec, c_vec)); + sum += a_i * cb_j; } } @@ -428,7 +422,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i); simsimd_f32_t b_i = simsimd_bf16_to_f32(b + i); float32x4_t diff_i_vec = vdupq_n_f32(a_i - b_i); - float32x4_t partial_sum_vec = vdupq_n_f32(0); + float32x4_t cdiff_j_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { bfloat16x8_t a_j_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(a + j)); bfloat16x8_t b_j_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(b + j)); @@ -445,9 +439,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi bfloat16x8_t diff_j_vec = vcombine_bf16(vcvt_bf16_f32(diff_j_vec_low), vcvt_bf16_f32(diff_j_vec_high)); bfloat16x8_t c_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(c + i * n + j)); - partial_sum_vec = vbfdotq_f32(partial_sum_vec, diff_j_vec, c_vec); + cdiff_j_vec = vbfdotq_f32(cdiff_j_vec, diff_j_vec, c_vec); } - sum_vec = vmlaq_f32(sum_vec, diff_i_vec, partial_sum_vec); + sum_vec = vmlaq_f32(sum_vec, diff_i_vec, cdiff_j_vec); } // Handle the tail of every row @@ -472,8 +466,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi bfloat16x8_t diff_j_vec = vcombine_bf16(vcvt_bf16_f32(diff_j_vec_low), vcvt_bf16_f32(diff_j_vec_high)); bfloat16x8_t c_vec = _simsimd_partial_load_bf16x8_neon(c + i * n + tail_start, tail_length); - simsimd_f32_t partial_sum = vaddvq_f32(vbfdotq_f32(vdupq_n_f32(0), diff_j_vec, c_vec)); - sum += diff_i * partial_sum; + simsimd_f32_t cdiff_j = vaddvq_f32(vbfdotq_f32(vdupq_n_f32(0), diff_j_vec, c_vec)); + sum += diff_i * cdiff_j; } } @@ -497,13 +491,13 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const *a, simsimd __m256 sum_vec = _mm256_setzero_ps(); for (simsimd_size_t i = 0; i != n; ++i) { __m256 a_vec = _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))); - __m256 partial_sum_vec = _mm256_setzero_ps(); + __m256 cb_j_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(b + j))); __m256 c_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); - partial_sum_vec = _mm256_fmadd_ps(b_vec, c_vec, partial_sum_vec); + cb_j_vec = _mm256_fmadd_ps(b_vec, c_vec, cb_j_vec); } - sum_vec = _mm256_fmadd_ps(a_vec, partial_sum_vec, sum_vec); + sum_vec = _mm256_fmadd_ps(a_vec, cb_j_vec, 
sum_vec); } // Handle the tail of every row @@ -515,18 +509,14 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const *a, simsimd simsimd_f32_t a_i = _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i)))); __m256 b_vec = _simsimd_partial_load_f16x8_haswell(b + tail_start, tail_length); __m256 c_vec = _simsimd_partial_load_f16x8_haswell(c + i * n + tail_start, tail_length); - simsimd_f32_t partial_sum = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(b_vec, c_vec)); - sum += a_i * partial_sum; + simsimd_f32_t cb_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(b_vec, c_vec)); + sum += a_i * cb_j; } } *result = sum; } -SIMSIMD_PUBLIC void simsimd_bilinear_f16c_haswell(simsimd_f16c_t const *a, simsimd_f16c_t const *b, - simsimd_f16c_t const *c, simsimd_size_t n, - simsimd_distance_t *results); - SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, simsimd_size_t n, simsimd_distance_t *result) { @@ -535,15 +525,15 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const *a, sims __m256 diff_i_vec = _mm256_sub_ps( // _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))), // _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(b + i)))); - __m256 partial_sum_vec = _mm256_setzero_ps(); + __m256 cdiff_j_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { __m256 diff_j_vec = _mm256_sub_ps( // _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(a + j))), _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(b + j)))); __m256 c_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); - partial_sum_vec = _mm256_fmadd_ps(diff_j_vec, c_vec, partial_sum_vec); + cdiff_j_vec = _mm256_fmadd_ps(diff_j_vec, c_vec, cdiff_j_vec); } - sum_vec = _mm256_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); + sum_vec = _mm256_fmadd_ps(diff_i_vec, cdiff_j_vec, sum_vec); } // Handle the tail of every row @@ -559,8 +549,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const *a, sims _simsimd_partial_load_f16x8_haswell(a + tail_start, tail_length), _simsimd_partial_load_f16x8_haswell(b + tail_start, tail_length)); __m256 c_vec = _simsimd_partial_load_f16x8_haswell(c + i * n + tail_start, tail_length); - simsimd_f32_t partial_sum = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(diff_j_vec, c_vec)); - sum += diff_i * partial_sum; + simsimd_f32_t cdiff_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(diff_j_vec, c_vec)); + sum += diff_i * cdiff_j; } } @@ -574,13 +564,13 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi for (simsimd_size_t i = 0; i != n; ++i) { // The `simsimd_bf16_to_f32` is cheaper than `_simsimd_bf16x8_to_f32x8_haswell` __m256 a_vec = _mm256_set1_ps(simsimd_bf16_to_f32(a + i)); - __m256 partial_sum_vec = _mm256_setzero_ps(); + __m256 cb_j_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { __m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(b + j))); __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); - partial_sum_vec = _mm256_fmadd_ps(b_vec, c_vec, partial_sum_vec); + cb_j_vec = _mm256_fmadd_ps(b_vec, c_vec, cb_j_vec); } - sum_vec = _mm256_fmadd_ps(a_vec, partial_sum_vec, sum_vec); + sum_vec = _mm256_fmadd_ps(a_vec, cb_j_vec, sum_vec); } // Handle the tail of every row @@ -594,18 +584,14 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi 
_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length)); __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell( // _simsimd_partial_load_bf16x8_haswell(c + i * n + tail_start, tail_length)); - simsimd_f32_t partial_sum = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(b_vec, c_vec)); - sum += a_i * partial_sum; + simsimd_f32_t cb_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(b_vec, c_vec)); + sum += a_i * cb_j; } } *result = sum; } -SIMSIMD_PUBLIC void simsimd_bilinear_bf16c_haswell(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b, - simsimd_bf16c_t const *c, simsimd_size_t n, - simsimd_distance_t *result) {} - SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, simsimd_size_t n, simsimd_distance_t *result) { @@ -614,15 +600,15 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si __m256 diff_i_vec = _mm256_sub_ps( // _mm256_set1_ps(simsimd_bf16_to_f32(a + i)), // _mm256_set1_ps(simsimd_bf16_to_f32(b + i))); - __m256 partial_sum_vec = _mm256_setzero_ps(); + __m256 cdiff_j_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { __m256 diff_j_vec = _mm256_sub_ps( // _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(a + j))), // _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(b + j)))); __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); - partial_sum_vec = _mm256_fmadd_ps(diff_j_vec, c_vec, partial_sum_vec); + cdiff_j_vec = _mm256_fmadd_ps(diff_j_vec, c_vec, cdiff_j_vec); } - sum_vec = _mm256_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); + sum_vec = _mm256_fmadd_ps(diff_i_vec, cdiff_j_vec, sum_vec); } // Handle the tail of every row @@ -637,8 +623,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si _simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length))); __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell( _simsimd_partial_load_bf16x8_haswell(c + i * n + tail_start, tail_length)); - simsimd_f32_t partial_sum = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(diff_j_vec, c_vec)); - sum += diff_i * partial_sum; + simsimd_f32_t cdiff_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(diff_j_vec, c_vec)); + sum += diff_i * cdiff_j; } } @@ -660,43 +646,43 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_skylake(simsimd_f32_t const *a, simsimd // On modern x86 CPUs we have enough register space to load fairly large matrices with up to 16 cells // per row and 16 rows at a time, still keeping enough register space for temporaries. if (n <= 16) { - // The goal of this optimization is to avoid horizontal accumulation of the partial sums + // The goal of this optimization is to avoid horizontal accumulation of the cb_j sums // until the very end of the computation. 
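/*
 * A minimal sketch of the accumulator-splitting idea behind this `n <= 16`
 * fast path, shown with plain doubles and illustrative names: four
 * independent running sums break the serial dependency between consecutive
 * fused multiply-adds, so the comparatively expensive horizontal combine
 * happens only once, at the very end.
 */
#include <stddef.h>

static double sum_of_products_4way(double const *x, double const *y, size_t n) {
    double acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        acc0 += x[i + 0] * y[i + 0]; /* each accumulator owns its own dependency chain */
        acc1 += x[i + 1] * y[i + 1];
        acc2 += x[i + 2] * y[i + 2];
        acc3 += x[i + 3] * y[i + 3];
    }
    for (; i != n; ++i) acc0 += x[i] * y[i]; /* scalar tail */
    return (acc0 + acc1) + (acc2 + acc3);    /* the only horizontal reduction */
}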
simsimd_size_t const row_length = n % 16; __mmask16 const row_mask = (__mmask16)_bzhi_u32(0xFFFF, row_length); __m512 const b_vec = _mm512_maskz_loadu_ps(row_mask, b); - __m512 partial_sum1 = _mm512_setzero_ps(); - __m512 partial_sum2 = _mm512_setzero_ps(); - __m512 partial_sum3 = _mm512_setzero_ps(); - __m512 partial_sum4 = _mm512_setzero_ps(); + __m512 cb_j1 = _mm512_setzero_ps(); + __m512 cb_j2 = _mm512_setzero_ps(); + __m512 cb_j3 = _mm512_setzero_ps(); + __m512 cb_j4 = _mm512_setzero_ps(); // clang-format off - if (n > 0) partial_sum1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 0), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[0])), partial_sum1); - if (n > 1) partial_sum2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 1), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[1])), partial_sum2); - if (n > 2) partial_sum3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 2), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[2])), partial_sum3); - if (n > 3) partial_sum4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 3), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[3])), partial_sum4); - - if (n > 4) partial_sum1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 4), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[4])), partial_sum1); - if (n > 5) partial_sum2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 5), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[5])), partial_sum2); - if (n > 6) partial_sum3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 6), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[6])), partial_sum3); - if (n > 7) partial_sum4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 7), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[7])), partial_sum4); - - if (n > 8) partial_sum1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 8), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[8])), partial_sum1); - if (n > 9) partial_sum2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 9), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[9])), partial_sum2); - if (n > 10) partial_sum3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 10), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[10])), partial_sum3); - if (n > 11) partial_sum4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 11), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[11])), partial_sum4); - - if (n > 12) partial_sum1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 12), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[12])), partial_sum1); - if (n > 13) partial_sum2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 13), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[13])), partial_sum2); - if (n > 14) partial_sum3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 14), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[14])), partial_sum3); - if (n > 15) partial_sum4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 15), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[15])), partial_sum4); + if (n > 0) cb_j1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 0), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[0])), cb_j1); + if (n > 1) cb_j2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 1), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[1])), cb_j2); + if (n > 2) cb_j3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 2), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[2])), cb_j3); + if (n > 3) cb_j4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 3), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[3])), cb_j4); + + if (n > 4) cb_j1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 4), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[4])), 
cb_j1); + if (n > 5) cb_j2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 5), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[5])), cb_j2); + if (n > 6) cb_j3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 6), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[6])), cb_j3); + if (n > 7) cb_j4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 7), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[7])), cb_j4); + + if (n > 8) cb_j1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 8), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[8])), cb_j1); + if (n > 9) cb_j2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 9), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[9])), cb_j2); + if (n > 10) cb_j3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 10), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[10])), cb_j3); + if (n > 11) cb_j4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 11), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[11])), cb_j4); + + if (n > 12) cb_j1 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 12), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[12])), cb_j1); + if (n > 13) cb_j2 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 13), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[13])), cb_j2); + if (n > 14) cb_j3 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 14), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[14])), cb_j3); + if (n > 15) cb_j4 = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(row_mask, c + n * 15), _mm512_mul_ps(b_vec, _mm512_set1_ps(a[15])), cb_j4); // clang-format on - // Combine partial sums - __m512 sum_vec = _mm512_add_ps( // - _mm512_add_ps(partial_sum1, partial_sum2), // - _mm512_add_ps(partial_sum3, partial_sum4)); + // Combine cb_j sums + __m512 sum_vec = _mm512_add_ps( // + _mm512_add_ps(cb_j1, cb_j2), // + _mm512_add_ps(cb_j3, cb_j4)); *result = _mm512_reduce_add_ps(sum_vec); return; } @@ -709,7 +695,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_skylake(simsimd_f32_t const *a, simsimd for (simsimd_size_t i = 0; i != n; ++i) { __m512 a_vec = _mm512_set1_ps(a[i]); - __m512 partial_sum_vec = _mm512_setzero_ps(); + __m512 cb_j_vec = _mm512_setzero_ps(); __m512 b_vec, c_vec; simsimd_size_t j = 0; @@ -722,10 +708,10 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_skylake(simsimd_f32_t const *a, simsimd b_vec = _mm512_maskz_loadu_ps(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_ps(tail_mask, c + i * n + tail_start); } - partial_sum_vec = _mm512_fmadd_ps(b_vec, c_vec, partial_sum_vec); + cb_j_vec = _mm512_fmadd_ps(b_vec, c_vec, cb_j_vec); j += 16; if (j < n) goto simsimd_bilinear_f32_skylake_cycle; - sum_vec = _mm512_fmadd_ps(a_vec, partial_sum_vec, sum_vec); + sum_vec = _mm512_fmadd_ps(a_vec, cb_j_vec, sum_vec); } *result = _mm512_reduce_add_ps(sum_vec); @@ -741,7 +727,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const *a, sims for (simsimd_size_t i = 0; i != n; ++i) { __m512 diff_i_vec = _mm512_set1_ps(a[i] - b[i]); - __m512 partial_sum_vec = _mm512_setzero_ps(), partial_sum_bot_vec = _mm512_setzero_ps(); + __m512 cdiff_j_vec = _mm512_setzero_ps(), cdiff_j_bot_vec = _mm512_setzero_ps(); __m512 a_j_vec, b_j_vec, diff_j_vec, c_vec; simsimd_size_t j = 0; @@ -758,10 +744,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const *a, sims c_vec = _mm512_maskz_loadu_ps(tail_mask, c + i * n + tail_start); } diff_j_vec = _mm512_sub_ps(a_j_vec, b_j_vec); - partial_sum_vec = _mm512_fmadd_ps(diff_j_vec, c_vec, partial_sum_vec); + cdiff_j_vec = _mm512_fmadd_ps(diff_j_vec, c_vec, cdiff_j_vec); j += 16; if (j < n) goto 
simsimd_bilinear_f32_skylake_cycle; - sum_vec = _mm512_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); + sum_vec = _mm512_fmadd_ps(diff_i_vec, cdiff_j_vec, sum_vec); } *result = _simsimd_sqrt_f64_haswell(_mm512_reduce_add_ps(sum_vec)); @@ -769,7 +755,61 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const *a, sims SIMSIMD_PUBLIC void simsimd_bilinear_f32c_skylake(simsimd_f32c_t const *a, simsimd_f32c_t const *b, simsimd_f32c_t const *c, simsimd_size_t n, - simsimd_distance_t *results) {} + simsimd_distance_t *results) { + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. + __m512i const sign_flip_vec = _mm512_set1_epi64(0x8000000000000000); + + // Default case for arbitrary size `n` + simsimd_size_t const tail_length = n % 8; + simsimd_size_t const tail_start = n - tail_length; + __mmask16 const tail_mask = (__mmask16)_bzhi_u32(0xFFFF, tail_length * 2); + simsimd_f32_t sum_real = 0; + simsimd_f32_t sum_imag = 0; + + for (simsimd_size_t i = 0; i != n; ++i) { + simsimd_f32_t const a_i_real = a[i].real; + simsimd_f32_t const a_i_imag = a[i].imag; + __m512 cb_j_real_vec = _mm512_setzero_ps(); + __m512 cb_j_imag_vec = _mm512_setzero_ps(); + __m512 b_vec, c_vec; + simsimd_size_t j = 0; + + simsimd_bilinear_f32c_skylake_cycle: + if (j + 8 <= n) { + b_vec = _mm512_loadu_ps((simsimd_f32_t const *)(b + j)); + c_vec = _mm512_loadu_ps((simsimd_f32_t const *)(c + i * n + j)); + } + else { + b_vec = _mm512_maskz_loadu_ps(tail_mask, (simsimd_f32_t const *)(b + tail_start)); + c_vec = _mm512_maskz_loadu_ps(tail_mask, (simsimd_f32_t const *)(c + i * n + tail_start)); + } + // The real part of the product: b.real * c.real - b.imag * c.imag. + // The subtraction will be performed later with a sign flip. + cb_j_real_vec = _mm512_fmadd_ps(c_vec, b_vec, cb_j_real_vec); + // The imaginary part of the product: b.real * c.imag + b.imag * c.real. + // Swap the imaginary and real parts of `c` before multiplication: + c_vec = _mm512_permute_ps(c_vec, 0xB1); //? 
Swap adjacent entries within each pair + cb_j_imag_vec = _mm512_fmadd_ps(c_vec, b_vec, cb_j_imag_vec); + j += 8; + if (j < n) goto simsimd_bilinear_f32c_skylake_cycle; + // Flip the sign bit in every second scalar before accumulation: + cb_j_real_vec = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(cb_j_real_vec), sign_flip_vec)); + // Horizontal sums are the expensive part of the computation: + simsimd_f32_t const cb_j_real = _mm512_reduce_add_ps(cb_j_real_vec); + simsimd_f32_t const cb_j_imag = _mm512_reduce_add_ps(cb_j_imag_vec); + sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag; + sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real; + } + + // Reduce horizontal sums: + results[0] = sum_real; + results[1] = sum_imag; +} SIMSIMD_PUBLIC void simsimd_bilinear_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, simsimd_size_t n, simsimd_distance_t *result) { @@ -777,33 +817,33 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f64_skylake(simsimd_f64_t const *a, simsimd // On modern x86 CPUs we have enough register space to load fairly large matrices with up to 16 cells // per row and 8 rows at a time, still keeping enough register space for temporaries. if (n <= 8) { - // The goal of this optimization is to avoid horizontal accumulation of the partial sums + // The goal of this optimization is to avoid horizontal accumulation of the cb_j sums // until the very end of the computation. simsimd_size_t const row_length = n % 8; __mmask8 const row_mask = (__mmask8)_bzhi_u32(0xFFFF, row_length); __m512d const b_vec = _mm512_maskz_loadu_pd(row_mask, b); - __m512d partial_sum1 = _mm512_setzero_pd(); - __m512d partial_sum2 = _mm512_setzero_pd(); - __m512d partial_sum3 = _mm512_setzero_pd(); - __m512d partial_sum4 = _mm512_setzero_pd(); + __m512d cb_j1 = _mm512_setzero_pd(); + __m512d cb_j2 = _mm512_setzero_pd(); + __m512d cb_j3 = _mm512_setzero_pd(); + __m512d cb_j4 = _mm512_setzero_pd(); // clang-format off - if (n > 0) partial_sum1 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 0), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[0])), partial_sum1); - if (n > 1) partial_sum2 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 1), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[1])), partial_sum2); - if (n > 2) partial_sum3 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 2), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[2])), partial_sum3); - if (n > 3) partial_sum4 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 3), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[3])), partial_sum4); - - if (n > 4) partial_sum1 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 4), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[4])), partial_sum1); - if (n > 5) partial_sum2 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 5), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[5])), partial_sum2); - if (n > 6) partial_sum3 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 6), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[6])), partial_sum3); - if (n > 7) partial_sum4 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 7), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[7])), partial_sum4); + if (n > 0) cb_j1 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 0), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[0])), cb_j1); + if (n > 1) cb_j2 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 1), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[1])), cb_j2); + if (n > 2) cb_j3 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 2), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[2])), cb_j3); + 
if (n > 3) cb_j4 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 3), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[3])), cb_j4); + + if (n > 4) cb_j1 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 4), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[4])), cb_j1); + if (n > 5) cb_j2 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 5), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[5])), cb_j2); + if (n > 6) cb_j3 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 6), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[6])), cb_j3); + if (n > 7) cb_j4 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(row_mask, c + n * 7), _mm512_mul_pd(b_vec, _mm512_set1_pd(a[7])), cb_j4); // clang-format on - // Combine partial sums - __m512d sum_vec = _mm512_add_pd( // - _mm512_add_pd(partial_sum1, partial_sum2), // - _mm512_add_pd(partial_sum3, partial_sum4)); + // Combine cb_j sums + __m512d sum_vec = _mm512_add_pd( // + _mm512_add_pd(cb_j1, cb_j2), // + _mm512_add_pd(cb_j3, cb_j4)); *result = _mm512_reduce_add_pd(sum_vec); return; } @@ -816,7 +856,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f64_skylake(simsimd_f64_t const *a, simsimd for (simsimd_size_t i = 0; i != n; ++i) { __m512d a_vec = _mm512_set1_pd(a[i]); - __m512d partial_sum_vec = _mm512_setzero_pd(); + __m512d cb_j_vec = _mm512_setzero_pd(); __m512d b_vec, c_vec; simsimd_size_t j = 0; @@ -829,15 +869,51 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f64_skylake(simsimd_f64_t const *a, simsimd b_vec = _mm512_maskz_loadu_pd(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_pd(tail_mask, c + i * n + tail_start); } - partial_sum_vec = _mm512_fmadd_pd(b_vec, c_vec, partial_sum_vec); + cb_j_vec = _mm512_fmadd_pd(b_vec, c_vec, cb_j_vec); j += 8; if (j < n) goto simsimd_bilinear_f64_skylake_cycle; - sum_vec = _mm512_fmadd_pd(a_vec, partial_sum_vec, sum_vec); + sum_vec = _mm512_fmadd_pd(a_vec, cb_j_vec, sum_vec); } *result = _mm512_reduce_add_pd(sum_vec); } +SIMSIMD_PUBLIC void simsimd_mahalanobis_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, + simsimd_f64_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { + simsimd_size_t const tail_length = n % 8; + simsimd_size_t const tail_start = n - tail_length; + __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFF, tail_length); + __m512d sum_vec = _mm512_setzero_pd(); + + for (simsimd_size_t i = 0; i != n; ++i) { + __m512d diff_i_vec = _mm512_set1_pd(a[i] - b[i]); + __m512d cdiff_j_vec = _mm512_setzero_pd(); + __m512d a_j_vec, b_j_vec, diff_j_vec, c_vec; + simsimd_size_t j = 0; + + // The nested loop is cleaner to implement with a `goto` in this case: + simsimd_bilinear_f64_skylake_cycle: + if (j + 8 <= n) { + a_j_vec = _mm512_loadu_pd(a + j); + b_j_vec = _mm512_loadu_pd(b + j); + c_vec = _mm512_loadu_pd(c + i * n + j); + } + else { + a_j_vec = _mm512_maskz_loadu_pd(tail_mask, a + tail_start); + b_j_vec = _mm512_maskz_loadu_pd(tail_mask, b + tail_start); + c_vec = _mm512_maskz_loadu_pd(tail_mask, c + i * n + tail_start); + } + diff_j_vec = _mm512_sub_pd(a_j_vec, b_j_vec); + cdiff_j_vec = _mm512_fmadd_pd(diff_j_vec, c_vec, cdiff_j_vec); + j += 8; + if (j < n) goto simsimd_bilinear_f64_skylake_cycle; + sum_vec = _mm512_fmadd_pd(diff_i_vec, cdiff_j_vec, sum_vec); + } + + *result = _simsimd_sqrt_f64_haswell(_mm512_reduce_add_pd(sum_vec)); +} + SIMSIMD_PUBLIC void simsimd_bilinear_f64c_skylake(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_f64c_t const *c, simsimd_size_t n, simsimd_distance_t *results) { @@ -851,29 +927,23 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f64c_skylake(simsimd_f64c_t const 
*a, simsi 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, // 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000 // ); - __m512i const swap_adjacent_vec = _mm512_set_epi8( // - 55, 54, 53, 52, 51, 50, 49, 48, 63, 62, 61, 60, 59, 58, 57, 56, // 4th 128-bit lane - 39, 38, 37, 36, 35, 34, 33, 32, 47, 46, 45, 44, 43, 42, 41, 40, // 3rd 128-bit lane - 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24, // 2nd 128-bit lane - 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 // 1st 128-bit lane - ); // Default case for arbitrary size `n` simsimd_size_t const tail_length = n % 4; simsimd_size_t const tail_start = n - tail_length; __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFF, tail_length * 2); - __m512d sum_real_vec = _mm512_setzero_pd(); - __m512d sum_imag_vec = _mm512_setzero_pd(); + simsimd_f64_t sum_real = 0; + simsimd_f64_t sum_imag = 0; for (simsimd_size_t i = 0; i != n; ++i) { - __m512d a_real_vec = _mm512_set1_pd(a[i].real); - __m512d a_imag_vec = _mm512_set1_pd(a[i].imag); - __m512d partial_sum_real_vec = _mm512_setzero_pd(); - __m512d partial_sum_imag_vec = _mm512_setzero_pd(); + simsimd_f64_t const a_i_real = a[i].real; + simsimd_f64_t const a_i_imag = a[i].imag; + __m512d cb_j_real_vec = _mm512_setzero_pd(); + __m512d cb_j_imag_vec = _mm512_setzero_pd(); __m512d b_vec, c_vec; simsimd_size_t j = 0; - simsimd_bilinear_f64_skylake_cycle: + simsimd_bilinear_f64c_skylake_cycle: if (j + 4 <= n) { b_vec = _mm512_loadu_pd((simsimd_f64_t const *)(b + j)); c_vec = _mm512_loadu_pd((simsimd_f64_t const *)(c + i * n + j)); @@ -882,58 +952,27 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f64c_skylake(simsimd_f64c_t const *a, simsi b_vec = _mm512_maskz_loadu_pd(tail_mask, (simsimd_f64_t const *)(b + tail_start)); c_vec = _mm512_maskz_loadu_pd(tail_mask, (simsimd_f64_t const *)(c + i * n + tail_start)); } - partial_sum_real_vec = _mm512_fmadd_pd(c_vec, b_vec, partial_sum_real_vec); - partial_sum_imag_vec = _mm512_fmadd_pd( // - _mm512_castsi512_pd(_mm512_shuffle_epi8(_mm512_castpd_si512(c_vec), swap_adjacent_vec)), b_vec, - partial_sum_imag_vec); + // The real part of the product: b.real * c.real - b.imag * c.imag. + // The subtraction will be performed later with a sign flip. + cb_j_real_vec = _mm512_fmadd_pd(c_vec, b_vec, cb_j_real_vec); + // The imaginary part of the product: b.real * c.imag + b.imag * c.real. + // Swap the imaginary and real parts of `c` before multiplication: + c_vec = _mm512_permute_pd(c_vec, 0x55); //? Same as 0b01010101. 
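/*
 * A minimal scalar sketch of the interleaved complex-product bookkeeping used
 * here, assuming the same (real, imag) pair layout; names are illustrative.
 * Multiplying the operands lane-wise and negating every odd lane before the
 * reduction yields the real part, while swapping the halves of one operand
 * before the lane-wise multiply yields the imaginary part, which is exactly
 * what the XOR with `sign_flip_vec` and the `_mm512_permute_pd` shuffle do.
 */
#include <stddef.h>

static void complex_dot_rowwise(double const *b, double const *c, size_t n_pairs,
                                double *out_real, double *out_imag) {
    double real = 0, imag = 0;
    for (size_t j = 0; j != n_pairs; ++j) {
        double b_re = b[2 * j + 0], b_im = b[2 * j + 1];
        double c_re = c[2 * j + 0], c_im = c[2 * j + 1];
        real += c_re * b_re - c_im * b_im; /* lane-wise product with the odd lane negated */
        imag += c_im * b_re + c_re * b_im; /* lane-wise product against the swapped pair */
    }
    *out_real = real;
    *out_imag = imag;
}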
+ cb_j_imag_vec = _mm512_fmadd_pd(c_vec, b_vec, cb_j_imag_vec); j += 4; - if (j < n) goto simsimd_bilinear_f64_skylake_cycle; + if (j < n) goto simsimd_bilinear_f64c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: - partial_sum_real_vec = - _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(partial_sum_real_vec), sign_flip_vec)); - - sum_real_vec = _mm512_fmadd_pd(a_real_vec, partial_sum_real_vec, sum_real_vec); - sum_imag_vec = _mm512_fmadd_pd(a_imag_vec, partial_sum_imag_vec, sum_imag_vec); - } - - results[0] = _mm512_reduce_add_pd(sum_real_vec); - results[1] = _mm512_reduce_add_pd(sum_imag_vec); -} - -SIMSIMD_PUBLIC void simsimd_mahalanobis_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, - simsimd_f64_t const *c, simsimd_size_t n, - simsimd_distance_t *result) { - simsimd_size_t const tail_length = n % 8; - simsimd_size_t const tail_start = n - tail_length; - __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFF, tail_length); - __m512d sum_vec = _mm512_setzero_pd(); - - for (simsimd_size_t i = 0; i != n; ++i) { - __m512d diff_i_vec = _mm512_set1_pd(a[i] - b[i]); - __m512d partial_sum_vec = _mm512_setzero_pd(), partial_sum_bot_vec = _mm512_setzero_pd(); - __m512d a_j_vec, b_j_vec, diff_j_vec, c_vec; - simsimd_size_t j = 0; - - // The nested loop is cleaner to implement with a `goto` in this case: - simsimd_bilinear_f64_skylake_cycle: - if (j + 8 <= n) { - a_j_vec = _mm512_loadu_pd(a + j); - b_j_vec = _mm512_loadu_pd(b + j); - c_vec = _mm512_loadu_pd(c + i * n + j); - } - else { - a_j_vec = _mm512_maskz_loadu_pd(tail_mask, a + tail_start); - b_j_vec = _mm512_maskz_loadu_pd(tail_mask, b + tail_start); - c_vec = _mm512_maskz_loadu_pd(tail_mask, c + i * n + tail_start); - } - diff_j_vec = _mm512_sub_pd(a_j_vec, b_j_vec); - partial_sum_vec = _mm512_fmadd_pd(diff_j_vec, c_vec, partial_sum_vec); - j += 8; - if (j < n) goto simsimd_bilinear_f64_skylake_cycle; - sum_vec = _mm512_fmadd_pd(diff_i_vec, partial_sum_vec, sum_vec); + cb_j_real_vec = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(cb_j_real_vec), sign_flip_vec)); + // Horizontal sums are the expensive part of the computation: + simsimd_f64_t const cb_j_real = _mm512_reduce_add_pd(cb_j_real_vec); + simsimd_f64_t const cb_j_imag = _mm512_reduce_add_pd(cb_j_imag_vec); + sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag; + sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real; } - *result = _simsimd_sqrt_f64_haswell(_mm512_reduce_add_pd(sum_vec)); + // Reduce horizontal sums: + results[0] = sum_real; + results[1] = sum_imag; } #pragma clang attribute pop @@ -955,7 +994,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const *a, simsimd for (simsimd_size_t i = 0; i != n; ++i) { __m512 a_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i)); - __m512 partial_sum_vec = _mm512_setzero_ps(); + __m512 cb_j_vec = _mm512_setzero_ps(); __m512i b_vec, c_vec; simsimd_size_t j = 0; @@ -968,10 +1007,10 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const *a, simsimd b_vec = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); } - partial_sum_vec = _mm512_dpbf16_ps(partial_sum_vec, (__m512bh)(b_vec), (__m512bh)(c_vec)); + cb_j_vec = _mm512_dpbf16_ps(cb_j_vec, (__m512bh)(b_vec), (__m512bh)(c_vec)); j += 32; if (j < n) goto simsimd_bilinear_bf16_genoa_cycle; - sum_vec = _mm512_fmadd_ps(a_vec, partial_sum_vec, sum_vec); + sum_vec = _mm512_fmadd_ps(a_vec, cb_j_vec, sum_vec); } *result = 
_mm512_reduce_add_ps(sum_vec); @@ -987,7 +1026,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const *a, sims for (simsimd_size_t i = 0; i != n; ++i) { __m512 diff_i_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i) - simsimd_bf16_to_f32(b + i)); - __m512 partial_sum_vec = _mm512_setzero_ps(); + __m512 cdiff_j_vec = _mm512_setzero_ps(); __m512i a_j_vec, b_j_vec, diff_j_vec, c_vec; simsimd_size_t j = 0; @@ -1004,10 +1043,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const *a, sims c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); } diff_j_vec = _simsimd_substract_bf16x32_genoa(a_j_vec, b_j_vec); - partial_sum_vec = _mm512_dpbf16_ps(partial_sum_vec, (__m512bh)(diff_j_vec), (__m512bh)(c_vec)); + cdiff_j_vec = _mm512_dpbf16_ps(cdiff_j_vec, (__m512bh)(diff_j_vec), (__m512bh)(c_vec)); j += 32; if (j < n) goto simsimd_mahalanobis_bf16_genoa_cycle; - sum_vec = _mm512_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); + sum_vec = _mm512_fmadd_ps(diff_i_vec, cdiff_j_vec, sum_vec); } *result = _simsimd_sqrt_f32_haswell(_mm512_reduce_add_ps(sum_vec)); @@ -1015,7 +1054,66 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const *a, sims SIMSIMD_PUBLIC void simsimd_bilinear_bf16c_genoa(simsimd_bf16c_t const *a, simsimd_bf16c_t const *b, simsimd_bf16c_t const *c, simsimd_size_t n, - simsimd_distance_t *results) {} + simsimd_distance_t *results) { + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. 
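/*
 * A minimal sketch of why the sign flip and byte shuffle below work for
 * brain-float pairs, assuming the layout this file uses: `real` in the low
 * 16 bits and `imag` in the high 16 bits of each 32-bit group. Flipping the
 * top bit of a bf16 negates it, so a per-pair dot product of (b.re, b.im)
 * with (c.re, -c.im) equals Re(c * b), and the same dot product against the
 * swapped pair (c.im, c.re) equals Im(c * b). The helpers are illustrative
 * stand-ins, not SimSIMD's own conversion routines.
 */
#include <stdint.h>
#include <string.h>

static float bf16_bits_to_f32(uint16_t bits) {
    uint32_t widened = (uint32_t)bits << 16; /* a bf16 holds the top half of an f32 */
    float value;
    memcpy(&value, &widened, sizeof(value));
    return value;
}

static uint16_t bf16_bits_negate(uint16_t bits) { return (uint16_t)(bits ^ 0x8000u); /* flip the sign bit */ }

/* One 32-bit lane of `_mm512_dpbf16_ps` after the sign flip: Re(c * b). */
static float packed_pair_product_real(uint16_t b_re, uint16_t b_im, uint16_t c_re, uint16_t c_im) {
    return bf16_bits_to_f32(c_re) * bf16_bits_to_f32(b_re) +
           bf16_bits_to_f32(bf16_bits_negate(c_im)) * bf16_bits_to_f32(b_im);
}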
+ __m512i const sign_flip_vec = _mm512_set1_epi32(0x80000000); + __m512i const swap_adjacent_vec = _mm512_set_epi8( // + 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane + 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane + 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane + ); + + // Default case for arbitrary size `n` + simsimd_size_t const tail_length = n % 16; + simsimd_size_t const tail_start = n - tail_length; + __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length * 2); + simsimd_f64_t sum_real = 0; + simsimd_f64_t sum_imag = 0; + + for (simsimd_size_t i = 0; i != n; ++i) { + simsimd_f32_t const a_i_real = a[i].real; + simsimd_f32_t const a_i_imag = a[i].imag; + __m512 cb_j_real_vec = _mm512_setzero_ps(); + __m512 cb_j_imag_vec = _mm512_setzero_ps(); + __m512i b_vec, c_vec; + simsimd_size_t j = 0; + + simsimd_bilinear_bf16c_skylake_cycle: + if (j + 16 <= n) { + b_vec = _mm512_loadu_epi16((simsimd_i16_t const *)(b + j)); + c_vec = _mm512_loadu_epi16((simsimd_i16_t const *)(c + i * n + j)); + } + else { + b_vec = _mm512_maskz_loadu_epi16(tail_mask, (simsimd_i16_t const *)(b + tail_start)); + c_vec = _mm512_maskz_loadu_epi16(tail_mask, (simsimd_i16_t const *)(c + i * n + tail_start)); + } + cb_j_real_vec = _mm512_dpbf16_ps( // + cb_j_real_vec, // + (__m512bh)(_mm512_xor_si512(c_vec, sign_flip_vec)), // + (__m512bh)b_vec); + cb_j_imag_vec = _mm512_dpbf16_ps( // + cb_j_imag_vec, // + (__m512bh)(_mm512_shuffle_epi8(c_vec, swap_adjacent_vec)), // + (__m512bh)b_vec); + j += 16; + if (j < n) goto simsimd_bilinear_bf16c_skylake_cycle; + // Horizontal sums are the expensive part of the computation: + simsimd_f64_t const cb_j_real = _simsimd_reduce_f32x16_skylake(cb_j_real_vec); + simsimd_f64_t const cb_j_imag = _simsimd_reduce_f32x16_skylake(cb_j_imag_vec); + sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag; + sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real; + } + + // Reduce horizontal sums: + results[0] = sum_real; + results[1] = sum_imag; +} #pragma clang attribute pop #pragma GCC pop_options @@ -1037,7 +1135,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_sapphire(simsimd_f16_t const *a, simsim for (simsimd_size_t i = 0; i != n; ++i) { __m512h a_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const *)(a + i))); - __m512h partial_sum_vec = _mm512_setzero_ph(); + __m512h cb_j_vec = _mm512_setzero_ph(); __m512i b_vec, c_vec; simsimd_size_t j = 0; @@ -1050,10 +1148,10 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_sapphire(simsimd_f16_t const *a, simsim b_vec = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); } - partial_sum_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(b_vec), _mm512_castsi512_ph(c_vec), partial_sum_vec); + cb_j_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(b_vec), _mm512_castsi512_ph(c_vec), cb_j_vec); j += 32; if (j < n) goto simsimd_bilinear_f16_sapphire_cycle; - sum_vec = _mm512_fmadd_ph(a_vec, partial_sum_vec, sum_vec); + sum_vec = _mm512_fmadd_ph(a_vec, cb_j_vec, sum_vec); } *result = _mm512_reduce_add_ph(sum_vec); @@ -1071,7 +1169,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const *a, sim __m512h a_i_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const *)(a + i))); __m512h b_i_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const *)(b + i))); __m512h diff_i_vec = 
_mm512_sub_ph(a_i_vec, b_i_vec); - __m512h partial_sum_vec = _mm512_setzero_ph(); + __m512h cdiff_j_vec = _mm512_setzero_ph(); __m512h diff_j_vec; __m512i a_j_vec, b_j_vec, c_vec; simsimd_size_t j = 0; @@ -1089,10 +1187,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const *a, sim c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); } diff_j_vec = _mm512_sub_ph(_mm512_castsi512_ph(a_j_vec), _mm512_castsi512_ph(b_j_vec)); - partial_sum_vec = _mm512_fmadd_ph(diff_j_vec, _mm512_castsi512_ph(c_vec), partial_sum_vec); + cdiff_j_vec = _mm512_fmadd_ph(diff_j_vec, _mm512_castsi512_ph(c_vec), cdiff_j_vec); j += 32; if (j < n) goto simsimd_mahalanobis_f16_sapphire_cycle; - sum_vec = _mm512_fmadd_ph(diff_i_vec, partial_sum_vec, sum_vec); + sum_vec = _mm512_fmadd_ph(diff_i_vec, cdiff_j_vec, sum_vec); } *result = _simsimd_sqrt_f32_haswell(_mm512_reduce_add_ph(sum_vec)); @@ -1100,7 +1198,64 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const *a, sim SIMSIMD_PUBLIC void simsimd_bilinear_f16c_sapphire(simsimd_f16c_t const *a, simsimd_f16c_t const *b, simsimd_f16c_t const *c, simsimd_size_t n, - simsimd_distance_t *results) {} + simsimd_distance_t *results) { + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. + __m512i const sign_flip_vec = _mm512_set1_epi32(0x80000000); + __m512i const swap_adjacent_vec = _mm512_set_epi8( // + 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane + 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane + 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane + ); + + // Default case for arbitrary size `n` + simsimd_size_t const tail_length = n % 16; + simsimd_size_t const tail_start = n - tail_length; + __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length * 2); + simsimd_f32_t sum_real = 0; + simsimd_f32_t sum_imag = 0; + + for (simsimd_size_t i = 0; i != n; ++i) { + simsimd_f32_t const a_i_real = a[i].real; + simsimd_f32_t const a_i_imag = a[i].imag; + __m512h cb_j_real_vec = _mm512_setzero_ph(); + __m512h cb_j_imag_vec = _mm512_setzero_ph(); + __m512i b_vec, c_vec; + simsimd_size_t j = 0; + + simsimd_bilinear_f16c_skylake_cycle: + if (j + 16 <= n) { + b_vec = _mm512_loadu_epi16((simsimd_i16_t const *)(b + j)); + c_vec = _mm512_loadu_epi16((simsimd_i16_t const *)(c + i * n + j)); + } + else { + b_vec = _mm512_maskz_loadu_epi16(tail_mask, (simsimd_i16_t const *)(b + tail_start)); + c_vec = _mm512_maskz_loadu_epi16(tail_mask, (simsimd_i16_t const *)(c + i * n + tail_start)); + } + cb_j_real_vec = _mm512_fmadd_ph( // + _mm512_castsi512_ph(_mm512_xor_si512(c_vec, sign_flip_vec)), // + _mm512_castsi512_ph(b_vec), cb_j_real_vec); + cb_j_imag_vec = _mm512_fmadd_ph( // + _mm512_castsi512_ph(_mm512_shuffle_epi8(c_vec, swap_adjacent_vec)), // + _mm512_castsi512_ph(b_vec), cb_j_imag_vec); + j += 16; + if (j < n) goto simsimd_bilinear_f16c_skylake_cycle; + // Horizontal sums are the expensive part of the computation: + simsimd_f32_t const cb_j_real = _mm512_reduce_add_ph(cb_j_real_vec); + 
simsimd_f32_t const cb_j_imag = _mm512_reduce_add_ph(cb_j_imag_vec); + sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag; + sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real; + } + + // Reduce horizontal sums: + results[0] = sum_real; + results[1] = sum_imag; +} #pragma clang attribute pop #pragma GCC pop_options diff --git a/include/simsimd/dot.h b/include/simsimd/dot.h index f90b758e..1033a6d7 100644 --- a/include/simsimd/dot.h +++ b/include/simsimd/dot.h @@ -1438,7 +1438,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64c_t const *a_pairs, sims a_pairs += 4, b_pairs += 4, count_pairs -= 4; } ab_real_vec = _mm512_fmadd_pd(b_vec, a_vec, ab_real_vec); - b_vec = _mm512_permute_pd(b_vec, 0xAA); //? Same as 0b10101010. + b_vec = _mm512_permute_pd(b_vec, 0x55); //? Same as 0b01010101. ab_imag_vec = _mm512_fmadd_pd(b_vec, a_vec, ab_imag_vec); if (count_pairs) goto simsimd_dot_f64c_skylake_cycle; @@ -1478,7 +1478,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64c_t const *a_pairs, sim a_pairs += 4, b_pairs += 4, count_pairs -= 4; } ab_real_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_real_vec); - b_vec = _mm512_permute_pd(b_vec, 0xAA); //? Same as 0b10101010. + b_vec = _mm512_permute_pd(b_vec, 0x55); //? Same as 0b01010101. ab_imag_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_imag_vec); if (count_pairs) goto simsimd_vdot_f64c_skylake_cycle; @@ -1670,7 +1670,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16c_t const *a_pairs, sim b_vec = _mm512_loadu_epi16(b_pairs); a_pairs += 16, b_pairs += 16, count_pairs -= 16; } - // TODO: Consider using `_mm512_fmaddsub` + // TODO: Consider using `_mm512_fmaddsub` and `_mm512_fcmadd_pch` ab_real_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(_mm512_xor_si512(b_vec, sign_flip_vec)), _mm512_castsi512_ph(a_vec), ab_real_vec); ab_imag_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), @@ -1714,6 +1714,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16c_t const *a_pairs, si b_vec = _mm512_loadu_epi16(b_pairs); a_pairs += 16, b_pairs += 16, count_pairs -= 16; } + // TODO: Consider using `_mm512_fmaddsub` and `_mm512_fcmadd_pch` ab_real_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_vec), _mm512_castsi512_ph(b_vec), ab_real_vec); a_vec = _mm512_xor_si512(a_vec, sign_flip_vec); b_vec = _mm512_shuffle_epi8(b_vec, swap_adjacent_vec); diff --git a/include/simsimd/simsimd.h b/include/simsimd/simsimd.h index 351f55c5..d8bae75f 100644 --- a/include/simsimd/simsimd.h +++ b/include/simsimd/simsimd.h @@ -2175,8 +2175,6 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16c(simsimd_f16c_t const *a, simsimd_f16c_ simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SAPPHIRE simsimd_bilinear_f16c_sapphire(a, b, c, n, d); -#elif SIMSIMD_TARGET_HASWELL - simsimd_bilinear_f16c_haswell(a, b, c, n, d); #elif SIMSIMD_TARGET_NEON simsimd_bilinear_f16c_neon(a, b, c, n, d); #else @@ -2187,8 +2185,6 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16c(simsimd_bf16c_t const *a, simsimd_bf1 simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_bilinear_bf16c_genoa(a, b, c, n, d); -#elif SIMSIMD_TARGET_HASWELL - simsimd_bilinear_bf16c_haswell(a, b, c, n, d); #elif SIMSIMD_TARGET_NEON simsimd_bilinear_bf16c_neon(a, b, c, n, d); #else diff --git a/scripts/bench.cxx b/scripts/bench.cxx index 6e359206..3c6f3f24 100644 --- a/scripts/bench.cxx +++ b/scripts/bench.cxx @@ -327,14 +327,14 @@ void measure_curved(bm::State &state, metric_at metric, metric_at baseline, std: using vector_t = typename 
pair_at::vector_t; auto call_baseline = [&](pair_t const &pair, vector_t const &tensor) -> double { - simsimd_distance_t results[2] = {signaling_distance, signaling_distance}; + simsimd_distance_t results[2] = {signaling_distance, 0}; baseline(pair.a.data(), pair.b.data(), tensor.data(), pair.a.size(), &results[0]); - return results[1] != signaling_distance ? results[0] + results[1] : results[0]; + return results[0] + results[1]; }; auto call_contender = [&](pair_t const &pair, vector_t const &tensor) -> double { - simsimd_distance_t results[2] = {signaling_distance, signaling_distance}; + simsimd_distance_t results[2] = {signaling_distance, 0}; metric(pair.a.data(), pair.b.data(), tensor.data(), pair.a.size(), &results[0]); - return results[1] != signaling_distance ? results[0] + results[1] : results[0]; + return results[0] + results[1]; }; // Let's average the distance results over many pairs. @@ -962,12 +962,12 @@ int main(int argc, char **argv) { dense_("cos_bf16_genoa", simsimd_cos_bf16_genoa, simsimd_cos_bf16_accurate); dense_("l2sq_bf16_genoa", simsimd_l2sq_bf16_genoa, simsimd_l2sq_bf16_accurate); dense_("l2_bf16_genoa", simsimd_l2_bf16_genoa, simsimd_l2_bf16_accurate); - dense_("dot_bf16c_genoa", simsimd_dot_bf16c_genoa, simsimd_dot_bf16c_accurate); dense_("vdot_bf16c_genoa", simsimd_vdot_bf16c_genoa, simsimd_vdot_bf16c_accurate); curved_("bilinear_bf16_genoa", simsimd_bilinear_bf16_genoa, simsimd_bilinear_bf16_accurate); curved_("mahalanobis_bf16_genoa", simsimd_mahalanobis_bf16_genoa, simsimd_mahalanobis_bf16_accurate); + curved_("bilinear_bf16c_genoa", simsimd_bilinear_bf16c_genoa, simsimd_bilinear_bf16c_accurate); #endif #if SIMSIMD_TARGET_SAPPHIRE @@ -985,6 +985,10 @@ int main(int argc, char **argv) { fma_("wsum_u8_sapphire", simsimd_wsum_u8_sapphire, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial); fma_("fma_i8_sapphire", simsimd_fma_i8_sapphire, simsimd_fma_i8_accurate, simsimd_l2_i8_serial); fma_("wsum_i8_sapphire", simsimd_wsum_i8_sapphire, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial); + + curved_("bilinear_f16_sapphire", simsimd_bilinear_f16_sapphire, simsimd_bilinear_f16_accurate); + curved_("mahalanobis_f16_sapphire", simsimd_mahalanobis_f16_sapphire, simsimd_mahalanobis_f16_accurate); + curved_("bilinear_f16c_sapphire", simsimd_bilinear_f16c_sapphire, simsimd_bilinear_f16c_accurate); #endif #if SIMSIMD_TARGET_ICE
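/*
 * A hedged usage sketch for the new complex bilinear kernels, assuming a
 * build with SIMSIMD_TARGET_SKYLAKE enabled on an AVX-512-capable CPU; it
 * relies only on the `.real` / `.imag` members visible in this patch, and
 * the initializer layout is an assumption. The bilinear form a^T * C * b of
 * complex vectors comes back as two values: the real part in results[0] and
 * the imaginary part in results[1].
 */
#include <stdio.h>
#include <simsimd/simsimd.h>

int main(void) {
    simsimd_f32c_t a[2] = {{.real = 1.f, .imag = 0.f}, {.real = 0.f, .imag = 1.f}};
    simsimd_f32c_t b[2] = {{.real = 2.f, .imag = 1.f}, {.real = 1.f, .imag = -1.f}};
    simsimd_f32c_t c[4] = {
        {.real = 1.f, .imag = 0.f}, {.real = 0.f, .imag = 0.f}, /* row 0 of a 2x2 identity */
        {.real = 0.f, .imag = 0.f}, {.real = 1.f, .imag = 0.f}, /* row 1 */
    };
    simsimd_distance_t results[2] = {0, 0};
    simsimd_bilinear_f32c_skylake(a, b, c, 2, results);
    printf("a^T C b = %f + %fi\n", results[0], results[1]);
    return 0;
}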