Skip to content

Commit

Permalink
Add: sz_look_up_transform_neon
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Oct 12, 2024
1 parent be6c93b commit 3898481
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 19 deletions.
1 change: 1 addition & 0 deletions c/lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ static void sz_dispatch_table_init(void) {
impl->copy = sz_copy_neon;
impl->move = sz_move_neon;
impl->fill = sz_fill_neon;
impl->look_up_transform = sz_look_up_transform_neon;

impl->find = sz_find_neon;
impl->rfind = sz_rfind_neon;
Expand Down
84 changes: 72 additions & 12 deletions include/stringzilla/stringzilla.h
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,8 @@ SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length)
SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value);
/** @copydoc sz_look_up_transform */
SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_rfind_byte */
Expand Down Expand Up @@ -5780,13 +5782,11 @@ SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t

SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
sz_u128_vec_t a_vec, b_vec;

while (length >= 16) {
for (; length >= 16; a += 16, b += 16, length -= 16) {
a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a);
b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b);
uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16);
if (vmaxvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match
a += 16, b += 16, length -= 16;
}

// Handle remaining bytes
Expand All @@ -5795,19 +5795,27 @@ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
}

/**
 *  @brief  Copies `length` bytes from `source` to `target`, front to back, using 16-byte NEON
 *          loads and stores, and hands the remaining (fewer than 16) bytes to `sz_copy_serial`.
 *
 *  @param target   Destination buffer, written starting at its first byte.
 *  @param source   Source buffer, read starting at its first byte.
 *  @param length   Number of bytes to copy.
 */
SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    // In most cases the `source` and the `target` are not aligned, but we should
    // at least make sure that writes don't touch many cache lines.
    // NEON has an instruction to load and write 64 bytes at once.
    //
    //      sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
    //      sz_size_t tail_length = (sz_size_t)(target + length) % 64;    // 63 or less.
    //      for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source;
    //      length -= head_length;
    //      for (; length >= 64; target += 64, source += 64, length -= 64)
    //          vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source));
    //      for (; tail_length; tail_length -= 1, target += 1, source += 1) *target = *source;
    //
    // NOTE(review): if this sketch is ever revived, two things need fixing first:
    //  - `vst4q_u8` is an interleaving (structure) store, so a byte-exact copy has to pair
    //    `vld1q_u8_x4` with `vst1q_u8_x4` instead;
    //  - `length -= head_length` runs after the head loop has already counted `head_length`
    //    down to zero, and `tail_length` is never subtracted from `length`.
    //
    // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time:
    for (; length >= 16; target += 16, source += 16, length -= 16)
        vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source));

    // Fewer than 16 bytes are left at this point.
    if (length) sz_copy_serial(target, source, length);
}

SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
// When moving small buffers, using a small buffer on stack as a temporary storage is faster.

if (target < source || target >= source + length) {
// Non-overlapping, proceed forward
sz_copy_neon(target, source, length);
Expand Down Expand Up @@ -5843,6 +5851,56 @@ SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
if (length) sz_fill_serial(target, length, value);
}

/**
 *  @brief  Maps every byte of `source` through the 256-entry table `lut` and writes the result
 *          to `target`, using NEON 4-register table lookups for the bulk of the input.
 *
 *  Inputs of 128 bytes or fewer are forwarded to `sz_look_up_transform_serial`. Longer inputs are
 *  split into a scalar head (until `target` reaches a 16-byte boundary), a vectorized body that is
 *  a multiple of 16 bytes, and a scalar tail.
 *
 *  @param source   Input bytes; each one is used as an index into `lut`.
 *  @param length   Number of bytes to transform.
 *  @param lut      Look-up table; all 256 bytes are loaded for inputs longer than 128 bytes.
 *  @param target   Output buffer receiving `length` bytes.
 */
SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {

    // If the input is tiny (especially smaller than the look-up table itself), we may end up paying
    // more for organizing the SIMD registers and changing the CPU state, than for the actual computation.
    if (length <= 128) {
        sz_look_up_transform_serial(source, length, lut, target);
        return;
    }

    // Split the work around 16-byte boundaries of the `target`, so that all vector stores are aligned.
    // With `length > 128` the head and the tail (at most 15 bytes each) can't exceed the input, and
    // the remaining body is always a multiple of 16. The body length must be computed here, before
    // the head loop counts `head_length` down to zero.
    sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less.
    sz_size_t tail_length = (sz_size_t)(target + length) % 16;    // 15 or less.
    sz_size_t body_length = length - head_length - tail_length;

    // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers.
    // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput.
    uint8x16x4_t const lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0));
    uint8x16x4_t const lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64));
    uint8x16x4_t const lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128));
    uint8x16x4_t const lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192));

    // The two top bits of every input byte select one of the four 64-byte quarters of the table.
    // XOR-ing with these biases clears those bits for the bytes that belong to the matching quarter,
    // and leaves at least one of them set (index >= 64) for all the others.
    uint8x16_t const bias_64_vec = vdupq_n_u8(0x40);
    uint8x16_t const bias_128_vec = vdupq_n_u8(0x80);
    uint8x16_t const bias_192_vec = vdupq_n_u8(0xc0);

    // Process the head with serial code
    for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source];

    // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction
    // to perform a 4-table lookup in a single instruction. If the top bit is set in a byte of the input, then
    // its value comes from the `128_to_191` or the `192_to_255` lookup. If the second bit is set, it comes
    // from the `64_to_127` or the `192_to_255` lookup. `vqtbl4q_u8` yields zero for out-of-range indices
    // (64 and above), so exactly one of the four lookups is non-masked per byte and a plain OR blends them.
    // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/
    for (; body_length >= 16; source += 16, target += 16, body_length -= 16) {
        uint8x16_t const source_vec = vld1q_u8((sz_u8_t const *)source);
        uint8x16_t const lookup_0_to_63_vec = vqtbl4q_u8(lut_0_to_63_vec, source_vec);
        uint8x16_t const lookup_64_to_127_vec = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec, bias_64_vec));
        uint8x16_t const lookup_128_to_191_vec = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec, bias_128_vec));
        uint8x16_t const lookup_192_to_255_vec = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec, bias_192_vec));
        uint8x16_t const blended_0_to_255_vec = vorrq_u8(vorrq_u8(lookup_0_to_63_vec, lookup_64_to_127_vec),
                                                         vorrq_u8(lookup_128_to_191_vec, lookup_192_to_255_vec));
        vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec);
    }

    // Process the tail with serial code
    for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source];
}

SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
sz_u64_t matches;
sz_u128_vec_t h_vec, n_vec, matches_vec;
Expand Down Expand Up @@ -6276,6 +6334,8 @@ SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr
sz_look_up_transform_avx512(source, length, lut, target);
#elif SZ_USE_X86_AVX2
sz_look_up_transform_avx2(source, length, lut, target);
#elif SZ_USE_ARM_NEON
sz_look_up_transform_neon(source, length, lut, target);
#else
sz_look_up_transform_serial(source, length, lut, target);
#endif
Expand Down
19 changes: 12 additions & 7 deletions scripts/bench_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ tracked_unary_functions_t copy_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o
{"sz_copy_avx2" + suffix, wrap_sz(sz_copy_avx2)},
#endif
#if SZ_USE_ARM_SVE
{"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve), true},
{"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve)},
#endif
#if SZ_USE_ARM_NEON
{"sz_copy_neon" + suffix, wrap_sz(sz_copy_neon)},
Expand Down Expand Up @@ -116,7 +116,7 @@ tracked_unary_functions_t fill_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o
{"sz_fill_avx2", wrap_sz(sz_fill_avx2)},
#endif
#if SZ_USE_ARM_SVE
{"sz_fill_sve", wrap_sz(sz_fill_sve), true},
{"sz_fill_sve", wrap_sz(sz_fill_sve)},
#endif
#if SZ_USE_ARM_NEON
{"sz_fill_neon", wrap_sz(sz_fill_neon)},
Expand Down Expand Up @@ -197,6 +197,9 @@ tracked_unary_functions_t transform_functions() {
#endif
#if SZ_USE_X86_AVX2
{"sz_look_up_transform_avx2", wrap_sz(sz_look_up_transform_avx2)},
#endif
#if SZ_USE_ARM_NEON
{"sz_look_up_transform_neon", wrap_sz(sz_look_up_transform_neon)},
#endif
};
return result;
Expand All @@ -223,15 +226,17 @@ void bench_memory(std::vector<std::string_view> const &slices, sz_cptr_t dataset
sz_ptr_t output_buffer_ptr) {

if (slices.size() == 0) return;
(void)dataset_start_ptr;
(void)output_buffer_ptr;

bench_memory(slices, copy_functions<true>(dataset_start_ptr, output_buffer_ptr));
bench_memory(slices, copy_functions<false>(dataset_start_ptr, output_buffer_ptr));
bench_memory(slices, fill_functions(dataset_start_ptr, output_buffer_ptr));
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 1));
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 8));
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, SZ_CACHE_LINE_WIDTH));
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, max_shift_length));
bench_memory(slices, transform_functions());
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 1));
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 8));
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, SZ_CACHE_LINE_WIDTH));
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, max_shift_length));
// bench_memory(slices, transform_functions());
}

int main(int argc, char const **argv) {
Expand Down

0 comments on commit 3898481

Please sign in to comment.