-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Harley Seal AVX-512 implementations #138
base: main-dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,9 +148,7 @@ void measure(bm::State& state, metric_at metric, metric_at baseline) { | |
// The actual benchmarking loop. | ||
std::size_t iterations = 0; | ||
for (auto _ : state) | ||
bm::DoNotOptimize((results_contender[iterations & (pairs_count - 1)] = | ||
call_contender(pairs[iterations & (pairs_count - 1)]))), | ||
iterations++; | ||
bm::DoNotOptimize((call_contender(pairs[iterations & (pairs_count - 1)]))), iterations++; | ||
|
||
// Measure the mean absolute delta and relative error. | ||
double mean_delta = 0, mean_relative_error = 0; | ||
|
@@ -174,9 +172,9 @@ void register_(std::string name, metric_at* distance_func, metric_at* baseline_f | |
std::size_t seconds = 10; | ||
std::size_t threads = 1; | ||
|
||
using pair_dims_t = vectors_pair_gt<datatype_ak, 1536>; | ||
using pair_dims_t = vectors_pair_gt<datatype_ak, 4096>; | ||
using scalar_t = typename pair_dims_t::scalar_t; | ||
using pair_bytes_t = vectors_pair_gt<datatype_ak, 1536 / sizeof(scalar_t)>; | ||
using pair_bytes_t = vectors_pair_gt<datatype_ak, 4096 / sizeof(scalar_t)>; | ||
|
||
std::string name_dims = name + "_" + std::to_string(pair_dims_t{}.dimensions()) + "d"; | ||
bm::RegisterBenchmark(name_dims.c_str(), measure<pair_dims_t, metric_at*>, distance_func, baseline_func) | ||
|
@@ -223,6 +221,182 @@ void vdot_f64c_blas(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size | |
|
||
#endif | ||
|
||
namespace AVX512_harley_seal { | ||
|
||
uint8_t lookup8bit[256] = { | ||
/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2, | ||
/* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3, | ||
/* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3, | ||
/* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4, | ||
/* 10 */ 1, /* 11 */ 2, /* 12 */ 2, /* 13 */ 3, | ||
/* 14 */ 2, /* 15 */ 3, /* 16 */ 3, /* 17 */ 4, | ||
/* 18 */ 2, /* 19 */ 3, /* 1a */ 3, /* 1b */ 4, | ||
/* 1c */ 3, /* 1d */ 4, /* 1e */ 4, /* 1f */ 5, | ||
/* 20 */ 1, /* 21 */ 2, /* 22 */ 2, /* 23 */ 3, | ||
/* 24 */ 2, /* 25 */ 3, /* 26 */ 3, /* 27 */ 4, | ||
/* 28 */ 2, /* 29 */ 3, /* 2a */ 3, /* 2b */ 4, | ||
/* 2c */ 3, /* 2d */ 4, /* 2e */ 4, /* 2f */ 5, | ||
/* 30 */ 2, /* 31 */ 3, /* 32 */ 3, /* 33 */ 4, | ||
/* 34 */ 3, /* 35 */ 4, /* 36 */ 4, /* 37 */ 5, | ||
/* 38 */ 3, /* 39 */ 4, /* 3a */ 4, /* 3b */ 5, | ||
/* 3c */ 4, /* 3d */ 5, /* 3e */ 5, /* 3f */ 6, | ||
/* 40 */ 1, /* 41 */ 2, /* 42 */ 2, /* 43 */ 3, | ||
/* 44 */ 2, /* 45 */ 3, /* 46 */ 3, /* 47 */ 4, | ||
/* 48 */ 2, /* 49 */ 3, /* 4a */ 3, /* 4b */ 4, | ||
/* 4c */ 3, /* 4d */ 4, /* 4e */ 4, /* 4f */ 5, | ||
/* 50 */ 2, /* 51 */ 3, /* 52 */ 3, /* 53 */ 4, | ||
/* 54 */ 3, /* 55 */ 4, /* 56 */ 4, /* 57 */ 5, | ||
/* 58 */ 3, /* 59 */ 4, /* 5a */ 4, /* 5b */ 5, | ||
/* 5c */ 4, /* 5d */ 5, /* 5e */ 5, /* 5f */ 6, | ||
/* 60 */ 2, /* 61 */ 3, /* 62 */ 3, /* 63 */ 4, | ||
/* 64 */ 3, /* 65 */ 4, /* 66 */ 4, /* 67 */ 5, | ||
/* 68 */ 3, /* 69 */ 4, /* 6a */ 4, /* 6b */ 5, | ||
/* 6c */ 4, /* 6d */ 5, /* 6e */ 5, /* 6f */ 6, | ||
/* 70 */ 3, /* 71 */ 4, /* 72 */ 4, /* 73 */ 5, | ||
/* 74 */ 4, /* 75 */ 5, /* 76 */ 5, /* 77 */ 6, | ||
/* 78 */ 4, /* 79 */ 5, /* 7a */ 5, /* 7b */ 6, | ||
/* 7c */ 5, /* 7d */ 6, /* 7e */ 6, /* 7f */ 7, | ||
/* 80 */ 1, /* 81 */ 2, /* 82 */ 2, /* 83 */ 3, | ||
/* 84 */ 2, /* 85 */ 3, /* 86 */ 3, /* 87 */ 4, | ||
/* 88 */ 2, /* 89 */ 3, /* 8a */ 3, /* 8b */ 4, | ||
/* 8c */ 3, /* 8d */ 4, /* 8e */ 4, /* 8f */ 5, | ||
/* 90 */ 2, /* 91 */ 3, /* 92 */ 3, /* 93 */ 4, | ||
/* 94 */ 3, /* 95 */ 4, /* 96 */ 4, /* 97 */ 5, | ||
/* 98 */ 3, /* 99 */ 4, /* 9a */ 4, /* 9b */ 5, | ||
/* 9c */ 4, /* 9d */ 5, /* 9e */ 5, /* 9f */ 6, | ||
/* a0 */ 2, /* a1 */ 3, /* a2 */ 3, /* a3 */ 4, | ||
/* a4 */ 3, /* a5 */ 4, /* a6 */ 4, /* a7 */ 5, | ||
/* a8 */ 3, /* a9 */ 4, /* aa */ 4, /* ab */ 5, | ||
/* ac */ 4, /* ad */ 5, /* ae */ 5, /* af */ 6, | ||
/* b0 */ 3, /* b1 */ 4, /* b2 */ 4, /* b3 */ 5, | ||
/* b4 */ 4, /* b5 */ 5, /* b6 */ 5, /* b7 */ 6, | ||
/* b8 */ 4, /* b9 */ 5, /* ba */ 5, /* bb */ 6, | ||
/* bc */ 5, /* bd */ 6, /* be */ 6, /* bf */ 7, | ||
/* c0 */ 2, /* c1 */ 3, /* c2 */ 3, /* c3 */ 4, | ||
/* c4 */ 3, /* c5 */ 4, /* c6 */ 4, /* c7 */ 5, | ||
/* c8 */ 3, /* c9 */ 4, /* ca */ 4, /* cb */ 5, | ||
/* cc */ 4, /* cd */ 5, /* ce */ 5, /* cf */ 6, | ||
/* d0 */ 3, /* d1 */ 4, /* d2 */ 4, /* d3 */ 5, | ||
/* d4 */ 4, /* d5 */ 5, /* d6 */ 5, /* d7 */ 6, | ||
/* d8 */ 4, /* d9 */ 5, /* da */ 5, /* db */ 6, | ||
/* dc */ 5, /* dd */ 6, /* de */ 6, /* df */ 7, | ||
/* e0 */ 3, /* e1 */ 4, /* e2 */ 4, /* e3 */ 5, | ||
/* e4 */ 4, /* e5 */ 5, /* e6 */ 5, /* e7 */ 6, | ||
/* e8 */ 4, /* e9 */ 5, /* ea */ 5, /* eb */ 6, | ||
/* ec */ 5, /* ed */ 6, /* ee */ 6, /* ef */ 7, | ||
/* f0 */ 4, /* f1 */ 5, /* f2 */ 5, /* f3 */ 6, | ||
/* f4 */ 5, /* f5 */ 6, /* f6 */ 6, /* f7 */ 7, | ||
/* f8 */ 5, /* f9 */ 6, /* fa */ 6, /* fb */ 7, | ||
/* fc */ 6, /* fd */ 7, /* fe */ 7, /* ff */ 8}; | ||
|
||
uint64_t lower_qword(const __m128i v) { return _mm_cvtsi128_si64(v); } | ||
|
||
uint64_t higher_qword(const __m128i v) { return lower_qword(_mm_srli_si128(v, 8)); } | ||
|
||
uint64_t simd_sum_epu64(const __m128i v) { return lower_qword(v) + higher_qword(v); } | ||
|
||
uint64_t simd_sum_epu64(const __m256i v) { | ||
Comment on lines
+292
to
+298
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think modern compilers might do this without asking in some cases, but using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The changes I've suggested so far are just low hanging fruit though. Have you used profiling tools to find which lines of code each approach is spending the most time in? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Most time is spent in the main loop computing CSAs. Sadly, I can't access hardware performance counters on those machines. |
||
|
||
return static_cast<uint64_t>(_mm256_extract_epi64(v, 0)) + static_cast<uint64_t>(_mm256_extract_epi64(v, 1)) + | ||
static_cast<uint64_t>(_mm256_extract_epi64(v, 2)) + static_cast<uint64_t>(_mm256_extract_epi64(v, 3)); | ||
} | ||
|
||
uint64_t simd_sum_epu64(const __m512i v) { | ||
|
||
const __m256i lo = _mm512_extracti64x4_epi64(v, 0); | ||
const __m256i hi = _mm512_extracti64x4_epi64(v, 1); | ||
|
||
return simd_sum_epu64(lo) + simd_sum_epu64(hi); | ||
} | ||
|
||
__m512i popcount(const __m512i v) { | ||
const __m512i m1 = _mm512_set1_epi8(0x55); | ||
const __m512i m2 = _mm512_set1_epi8(0x33); | ||
const __m512i m4 = _mm512_set1_epi8(0x0F); | ||
|
||
const __m512i t1 = _mm512_sub_epi8(v, (_mm512_srli_epi16(v, 1) & m1)); | ||
const __m512i t2 = _mm512_add_epi8(t1 & m2, (_mm512_srli_epi16(t1, 2) & m2)); | ||
const __m512i t3 = _mm512_add_epi8(t2, _mm512_srli_epi16(t2, 4)) & m4; | ||
return _mm512_sad_epu8(t3, _mm512_setzero_si512()); | ||
} | ||
|
||
void CSA(__m512i& h, __m512i& l, __m512i a, __m512i b, __m512i c) { | ||
/* | ||
c b a | l h | ||
------+---- | ||
0 0 0 | 0 0 | ||
0 0 1 | 1 0 | ||
0 1 0 | 1 0 | ||
0 1 1 | 0 1 | ||
1 0 0 | 1 0 | ||
1 0 1 | 0 1 | ||
1 1 0 | 0 1 | ||
1 1 1 | 1 1 | ||
|
||
l - digit | ||
h - carry | ||
*/ | ||
|
||
l = _mm512_ternarylogic_epi32(c, b, a, 0x96); | ||
h = _mm512_ternarylogic_epi32(c, b, a, 0xe8); | ||
} | ||
|
||
uint64_t popcnt(__m512i const* a, __m512i const* b, const uint64_t size) { | ||
__m512i total = _mm512_setzero_si512(); | ||
__m512i ones = _mm512_setzero_si512(); | ||
__m512i twos = _mm512_setzero_si512(); | ||
__m512i fours = _mm512_setzero_si512(); | ||
__m512i eights = _mm512_setzero_si512(); | ||
__m512i sixteens = _mm512_setzero_si512(); | ||
__m512i twosA, twosB, foursA, foursB, eightsA, eightsB; | ||
|
||
const uint64_t limit = size - size % 16; | ||
uint64_t i = 0; | ||
|
||
for (; i < limit; i += 16) { | ||
CSA(twosA, ones, ones, a[i + 0] ^ b[i + 0], a[i + 1] ^ b[i + 1]); | ||
CSA(twosB, ones, ones, a[i + 2] ^ b[i + 2], a[i + 3] ^ b[i + 3]); | ||
CSA(foursA, twos, twos, twosA, twosB); | ||
CSA(twosA, ones, ones, a[i + 4] ^ b[i + 4], a[i + 5] ^ b[i + 5]); | ||
CSA(twosB, ones, ones, a[i + 6] ^ b[i + 6], a[i + 7] ^ b[i + 7]); | ||
CSA(foursB, twos, twos, twosA, twosB); | ||
CSA(eightsA, fours, fours, foursA, foursB); | ||
CSA(twosA, ones, ones, a[i + 8] ^ b[i + 8], a[i + 9] ^ b[i + 9]); | ||
CSA(twosB, ones, ones, a[i + 10] ^ b[i + 10], a[i + 11] ^ b[i + 11]); | ||
CSA(foursA, twos, twos, twosA, twosB); | ||
CSA(twosA, ones, ones, a[i + 12] ^ b[i + 12], a[i + 13] ^ b[i + 13]); | ||
CSA(twosB, ones, ones, a[i + 14] ^ b[i + 14], a[i + 15] ^ b[i + 15]); | ||
CSA(foursB, twos, twos, twosA, twosB); | ||
CSA(eightsB, fours, fours, foursA, foursB); | ||
CSA(sixteens, eights, eights, eightsA, eightsB); | ||
|
||
total = _mm512_add_epi64(total, popcount(sixteens)); | ||
} | ||
|
||
total = _mm512_slli_epi64(total, 4); // * 16 | ||
total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(eights), 3)); // += 8 * ... | ||
total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(fours), 2)); // += 4 * ... | ||
total = _mm512_add_epi64(total, _mm512_slli_epi64(popcount(twos), 1)); // += 2 * ... | ||
total = _mm512_add_epi64(total, popcount(ones)); | ||
|
||
for (; i < size; i++) | ||
total = _mm512_add_epi64(total, popcount(a[i] ^ b[i])); | ||
|
||
return simd_sum_epu64(total); | ||
} | ||
|
||
} // namespace AVX512_harley_seal | ||
|
||
void popcnt_AVX512_harley_seal(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t size, | ||
simsimd_distance_t* results) { | ||
uint64_t total = AVX512_harley_seal::popcnt((const __m512i*)a, (const __m512i*)b, size / 64); | ||
|
||
for (size_t i = size - size % 64; i < size; i++) | ||
total += AVX512_harley_seal::lookup8bit[a[i] ^ b[i]]; | ||
|
||
results[0] = total; | ||
} | ||
|
||
int main(int argc, char** argv) { | ||
simsimd_capability_t runtime_caps = simsimd_capabilities(); | ||
|
||
|
@@ -377,6 +551,7 @@ int main(int argc, char** argv) { | |
register_<simsimd_datatype_f64_k>("l2sq_f64_skylake", simsimd_l2sq_f64_skylake, simsimd_l2sq_f64_serial); | ||
|
||
register_<simsimd_datatype_b8_k>("hamming_b8_ice", simsimd_hamming_b8_ice, simsimd_hamming_b8_serial); | ||
register_<simsimd_datatype_b8_k>("hamming_b8_icehs", popcnt_AVX512_harley_seal, simsimd_hamming_b8_serial); | ||
register_<simsimd_datatype_b8_k>("jaccard_b8_ice", simsimd_jaccard_b8_ice, simsimd_jaccard_b8_serial); | ||
#endif | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it help to add const/constexpr? I wonder if it would encourage the table to be cached. It might also help to run a loop over it to pre-load it into cache too (although I figure prefetching would most likely get the whole table in the first access).
In my own experiments in the past, I did find the built in instructions to be faster vs LUTs, however.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming the size of the inputs - the tail will never be evaluated separately. I've just copied that part of the code for completeness.