From 70ce8d53e2aedabbf1d5326222fcef580def7fff Mon Sep 17 00:00:00 2001 From: Ashot Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 11 Jun 2024 02:04:02 +0000 Subject: [PATCH] Add: Harley Seal AVX-512 implementation This commit adds the optimized Harley Seal kernel from the `WojciechMula/sse-popcount` library to the benchmarking suite to investigate optimization opportunities on Intel Sapphire Rapids and AMD Genoa chips. --- cpp/bench.cxx | 185 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 180 insertions(+), 5 deletions(-) diff --git a/cpp/bench.cxx b/cpp/bench.cxx index 1df1280f..1271a231 100644 --- a/cpp/bench.cxx +++ b/cpp/bench.cxx @@ -148,9 +148,7 @@ void measure(bm::State& state, metric_at metric, metric_at baseline) { // The actual benchmarking loop. std::size_t iterations = 0; for (auto _ : state) - bm::DoNotOptimize((results_contender[iterations & (pairs_count - 1)] = - call_contender(pairs[iterations & (pairs_count - 1)]))), - iterations++; + bm::DoNotOptimize((call_contender(pairs[iterations & (pairs_count - 1)]))), iterations++; // Measure the mean absolute delta and relative error. double mean_delta = 0, mean_relative_error = 0; @@ -174,9 +172,9 @@ void register_(std::string name, metric_at* distance_func, metric_at* baseline_f std::size_t seconds = 10; std::size_t threads = 1; - using pair_dims_t = vectors_pair_gt; + using pair_dims_t = vectors_pair_gt; using scalar_t = typename pair_dims_t::scalar_t; - using pair_bytes_t = vectors_pair_gt; + using pair_bytes_t = vectors_pair_gt; std::string name_dims = name + "_" + std::to_string(pair_dims_t{}.dimensions()) + "d"; bm::RegisterBenchmark(name_dims.c_str(), measure, distance_func, baseline_func) @@ -223,6 +221,182 @@ void vdot_f64c_blas(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size #endif +namespace AVX512_harley_seal { + +uint8_t lookup8bit[256] = { + /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2, + /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3, + /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3, + /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4, + /* 10 */ 1, /* 11 */ 2, /* 12 */ 2, /* 13 */ 3, + /* 14 */ 2, /* 15 */ 3, /* 16 */ 3, /* 17 */ 4, + /* 18 */ 2, /* 19 */ 3, /* 1a */ 3, /* 1b */ 4, + /* 1c */ 3, /* 1d */ 4, /* 1e */ 4, /* 1f */ 5, + /* 20 */ 1, /* 21 */ 2, /* 22 */ 2, /* 23 */ 3, + /* 24 */ 2, /* 25 */ 3, /* 26 */ 3, /* 27 */ 4, + /* 28 */ 2, /* 29 */ 3, /* 2a */ 3, /* 2b */ 4, + /* 2c */ 3, /* 2d */ 4, /* 2e */ 4, /* 2f */ 5, + /* 30 */ 2, /* 31 */ 3, /* 32 */ 3, /* 33 */ 4, + /* 34 */ 3, /* 35 */ 4, /* 36 */ 4, /* 37 */ 5, + /* 38 */ 3, /* 39 */ 4, /* 3a */ 4, /* 3b */ 5, + /* 3c */ 4, /* 3d */ 5, /* 3e */ 5, /* 3f */ 6, + /* 40 */ 1, /* 41 */ 2, /* 42 */ 2, /* 43 */ 3, + /* 44 */ 2, /* 45 */ 3, /* 46 */ 3, /* 47 */ 4, + /* 48 */ 2, /* 49 */ 3, /* 4a */ 3, /* 4b */ 4, + /* 4c */ 3, /* 4d */ 4, /* 4e */ 4, /* 4f */ 5, + /* 50 */ 2, /* 51 */ 3, /* 52 */ 3, /* 53 */ 4, + /* 54 */ 3, /* 55 */ 4, /* 56 */ 4, /* 57 */ 5, + /* 58 */ 3, /* 59 */ 4, /* 5a */ 4, /* 5b */ 5, + /* 5c */ 4, /* 5d */ 5, /* 5e */ 5, /* 5f */ 6, + /* 60 */ 2, /* 61 */ 3, /* 62 */ 3, /* 63 */ 4, + /* 64 */ 3, /* 65 */ 4, /* 66 */ 4, /* 67 */ 5, + /* 68 */ 3, /* 69 */ 4, /* 6a */ 4, /* 6b */ 5, + /* 6c */ 4, /* 6d */ 5, /* 6e */ 5, /* 6f */ 6, + /* 70 */ 3, /* 71 */ 4, /* 72 */ 4, /* 73 */ 5, + /* 74 */ 4, /* 75 */ 5, /* 76 */ 5, /* 77 */ 6, + /* 78 */ 4, /* 79 */ 5, /* 7a */ 5, /* 7b */ 6, + /* 7c */ 5, /* 7d */ 6, /* 7e */ 6, /* 7f */ 7, + /* 80 */ 1, /* 81 */ 2, /* 82 */ 2, /* 83 */ 3, + /* 84 */ 2, /* 85 */ 3, /* 86 */ 3, /* 87 */ 4, + /* 88 */ 2, /* 89 */ 3, /* 8a */ 3, /* 8b */ 4, + /* 8c */ 3, /* 8d */ 4, /* 8e */ 4, /* 8f */ 5, + /* 90 */ 2, /* 91 */ 3, /* 92 */ 3, /* 93 */ 4, + /* 94 */ 3, /* 95 */ 4, /* 96 */ 4, /* 97 */ 5, + /* 98 */ 3, /* 99 */ 4, /* 9a */ 4, /* 9b */ 5, + /* 9c */ 4, /* 9d */ 5, /* 9e */ 5, /* 9f */ 6, + /* a0 */ 2, /* a1 */ 3, /* a2 */ 3, /* a3 */ 4, + /* a4 */ 3, /* a5 */ 4, /* a6 */ 4, /* a7 */ 5, + /* a8 */ 3, /* a9 */ 4, /* aa */ 4, /* ab */ 5, + /* ac */ 4, /* ad */ 5, /* ae */ 5, /* af */ 6, + /* b0 */ 3, /* b1 */ 4, /* b2 */ 4, /* b3 */ 5, + /* b4 */ 4, /* b5 */ 5, /* b6 */ 5, /* b7 */ 6, + /* b8 */ 4, /* b9 */ 5, /* ba */ 5, /* bb */ 6, + /* bc */ 5, /* bd */ 6, /* be */ 6, /* bf */ 7, + /* c0 */ 2, /* c1 */ 3, /* c2 */ 3, /* c3 */ 4, + /* c4 */ 3, /* c5 */ 4, /* c6 */ 4, /* c7 */ 5, + /* c8 */ 3, /* c9 */ 4, /* ca */ 4, /* cb */ 5, + /* cc */ 4, /* cd */ 5, /* ce */ 5, /* cf */ 6, + /* d0 */ 3, /* d1 */ 4, /* d2 */ 4, /* d3 */ 5, + /* d4 */ 4, /* d5 */ 5, /* d6 */ 5, /* d7 */ 6, + /* d8 */ 4, /* d9 */ 5, /* da */ 5, /* db */ 6, + /* dc */ 5, /* dd */ 6, /* de */ 6, /* df */ 7, + /* e0 */ 3, /* e1 */ 4, /* e2 */ 4, /* e3 */ 5, + /* e4 */ 4, /* e5 */ 5, /* e6 */ 5, /* e7 */ 6, + /* e8 */ 4, /* e9 */ 5, /* ea */ 5, /* eb */ 6, + /* ec */ 5, /* ed */ 6, /* ee */ 6, /* ef */ 7, + /* f0 */ 4, /* f1 */ 5, /* f2 */ 5, /* f3 */ 6, + /* f4 */ 5, /* f5 */ 6, /* f6 */ 6, /* f7 */ 7, + /* f8 */ 5, /* f9 */ 6, /* fa */ 6, /* fb */ 7, + /* fc */ 6, /* fd */ 7, /* fe */ 7, /* ff */ 8}; + +uint64_t lower_qword(const __m128i v) { return _mm_cvtsi128_si64(v); } + +uint64_t higher_qword(const __m128i v) { return lower_qword(_mm_srli_si128(v, 8)); } + +uint64_t simd_sum_epu64(const __m128i v) { return lower_qword(v) + higher_qword(v); } + +uint64_t simd_sum_epu64(const __m256i v) { + + return static_cast(_mm256_extract_epi64(v, 0)) + static_cast(_mm256_extract_epi64(v, 1)) + + static_cast(_mm256_extract_epi64(v, 2)) + static_cast(_mm256_extract_epi64(v, 3)); +} + +uint64_t simd_sum_epu64(const __m512i v) { + + const __m256i lo = _mm512_extracti64x4_epi64(v, 0); + const __m256i hi = _mm512_extracti64x4_epi64(v, 1); + + return simd_sum_epu64(lo) + simd_sum_epu64(hi); +} + +__m512i popcount(const __m512i v) { + const __m512i m1 = _mm512_set1_epi8(0x55); + const __m512i m2 = _mm512_set1_epi8(0x33); + const __m512i m4 = _mm512_set1_epi8(0x0F); + + const __m512i t1 = _mm512_sub_epi8(v, (_mm512_srli_epi16(v, 1) & m1)); + const __m512i t2 = _mm512_add_epi8(t1 & m2, (_mm512_srli_epi16(t1, 2) & m2)); + const __m512i t3 = _mm512_add_epi8(t2, _mm512_srli_epi16(t2, 4)) & m4; + return _mm512_sad_epu8(t3, _mm512_setzero_si512()); +} + +void CSA(__m512i& h, __m512i& l, __m512i a, __m512i b, __m512i c) { + /* + c b a | l h + ------+---- + 0 0 0 | 0 0 + 0 0 1 | 1 0 + 0 1 0 | 1 0 + 0 1 1 | 0 1 + 1 0 0 | 1 0 + 1 0 1 | 0 1 + 1 1 0 | 0 1 + 1 1 1 | 1 1 + + l - digit + h - carry + */ + + l = _mm512_ternarylogic_epi32(c, b, a, 0x96); + h = _mm512_ternarylogic_epi32(c, b, a, 0xe8); +} + +uint64_t popcnt(__m512i const* a, __m512i const* b, const uint64_t size) { + __m512i total = _mm512_setzero_si512(); + __m512i ones = _mm512_setzero_si512(); + __m512i twos = _mm512_setzero_si512(); + __m512i fours = _mm512_setzero_si512(); + __m512i eights = _mm512_setzero_si512(); + __m512i sixteens = _mm512_setzero_si512(); + __m512i twosA, twosB, foursA, foursB, eightsA, eightsB; + + const uint64_t limit = size - size % 16; + uint64_t i = 0; + + for (; i < limit; i += 16) { + CSA(twosA, ones, ones, a[i + 0] ^ b[i + 0], a[i + 1] ^ b[i + 1]); + CSA(twosB, ones, ones, a[i + 2] ^ b[i + 2], a[i + 3] ^ b[i + 3]); + CSA(foursA, twos, twos, twosA, twosB); + CSA(twosA, ones, ones, a[i + 4] ^ b[i + 4], a[i + 5] ^ b[i + 5]); + CSA(twosB, ones, ones, a[i + 6] ^ b[i + 6], a[i + 7] ^ b[i + 7]); + CSA(foursB, twos, twos, twosA, twosB); + CSA(eightsA, fours, fours, foursA, foursB); + CSA(twosA, ones, ones, a[i + 8] ^ b[i + 8], a[i + 9] ^ b[i + 9]); + CSA(twosB, ones, ones, a[i + 10] ^ b[i + 10], a[i + 11] ^ b[i + 11]); + CSA(foursA, twos, twos, twosA, twosB); + CSA(twosA, ones, ones, a[i + 12] ^ b[i + 12], a[i + 13] ^ b[i + 13]); + CSA(twosB, ones, ones, a[i + 14] ^ b[i + 14], a[i + 15] ^ b[i + 15]); + CSA(foursB, twos, twos, twosA, twosB); + CSA(eightsB, fours, fours, foursA, foursB); + CSA(sixteens, eights, eights, eightsA, eightsB); + + total = _mm512_add_epi64(total, popcount(sixteens)); + } + + total = _mm512_slli_epi64(total, 4); // * 16 + total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(eights), 3)); // += 8 * ... + total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(fours), 2)); // += 4 * ... + total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(twos), 1)); // += 2 * ... + total = _mm512_add_epi64(total, popcount(ones)); + + for (; i < size; i++) + total = _mm512_add_epi64(total, popcount(a[i] ^ b[i])); + + return simd_sum_epu64(total); +} + +} // namespace AVX512_harley_seal + +void popcnt_AVX512_harley_seal(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t size, + simsimd_distance_t* results) { + uint64_t total = AVX512_harley_seal::popcnt((const __m512i*)a, (const __m512i*)b, size / 64); + + for (size_t i = size - size % 64; i < size; i++) + total += AVX512_harley_seal::lookup8bit[a[i] ^ b[i]]; + + results[0] = total; +} + int main(int argc, char** argv) { simsimd_capability_t runtime_caps = simsimd_capabilities(); @@ -377,6 +551,7 @@ int main(int argc, char** argv) { register_("l2sq_f64_skylake", simsimd_l2sq_f64_skylake, simsimd_l2sq_f64_serial); register_("hamming_b8_ice", simsimd_hamming_b8_ice, simsimd_hamming_b8_serial); + register_("hamming_b8_icehs", popcnt_AVX512_harley_seal, simsimd_hamming_b8_serial); register_("jaccard_b8_ice", simsimd_jaccard_b8_ice, simsimd_jaccard_b8_serial); #endif