diff --git a/c/lib.c b/c/lib.c
index 8ba01a2..e523e37 100644
--- a/c/lib.c
+++ b/c/lib.c
@@ -219,6 +219,7 @@ static void sz_dispatch_table_init(void) {
         impl->copy = sz_copy_neon;
         impl->move = sz_move_neon;
         impl->fill = sz_fill_neon;
+        impl->look_up_transform = sz_look_up_transform_neon;
         impl->find = sz_find_neon;
         impl->rfind = sz_rfind_neon;
diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h
index a125ccf..299ee53 100644
--- a/include/stringzilla/stringzilla.h
+++ b/include/stringzilla/stringzilla.h
@@ -1251,6 +1251,8 @@ SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length)
 SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
 /** @copydoc sz_fill */
 SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value);
+/** @copydoc sz_look_up_transform */
+SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
 /** @copydoc sz_find_byte */
 SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
 /** @copydoc sz_rfind_byte */
@@ -5780,13 +5782,11 @@ SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t
 SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
     sz_u128_vec_t a_vec, b_vec;
-
-    while (length >= 16) {
+    for (; length >= 16; a += 16, b += 16, length -= 16) {
         a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a);
         b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b);
         uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16);
         if (vmaxvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match
-        a += 16, b += 16, length -= 16;
     }
 
     // Handle remaining bytes
@@ -5795,19 +5795,27 @@ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
 }
 
 SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
-    sz_u128_vec_t src_vec;
-
-    while (length >= 16) {
-        src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source);
-        vst1q_u8((sz_u8_t *)target, src_vec.u8x16);
-        target += 16, source += 16, length -= 16;
-    }
-
-    // Handle remaining bytes
+    // In most cases the `source` and the `target` are not aligned, but we should
+    // at least make sure that writes don't touch many cache lines.
+    // NEON has instructions to load and store 64 bytes at once:
+    //
+    //    sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
+    //    sz_size_t tail_length = (sz_size_t)(target + length) % 64;    // 63 or less.
+    //    length -= head_length;
+    //    for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source;
+    //    for (; length >= 64; target += 64, source += 64, length -= 64)
+    //        vst1q_u8_x4((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source));
+    //    for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source;
+    //
+    // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time:
+    for (; length >= 16; target += 16, source += 16, length -= 16)
+        vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source));
     if (length) sz_copy_serial(target, source, length);
 }
 
 SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
+    // When moving small buffers, using a small stack buffer as temporary storage is faster.
+
     if (target < source || target >= source + length) {
         // Non-overlapping, proceed forward
         sz_copy_neon(target, source, length);
@@ -5843,6 +5851,56 @@ SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
     if (length) sz_fill_serial(target, length, value);
 }
 
+SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
+
+    // If the input is tiny (especially smaller than the look-up table itself), we may end up paying
+    // more for organizing the SIMD registers and changing the CPU state than for the actual computation.
+    if (length <= 128) {
+        sz_look_up_transform_serial(source, length, lut, target);
+        return;
+    }
+
+    sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less.
+    sz_size_t tail_length = (sz_size_t)(target + length) % 16;    // 15 or less.
+
+    // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers.
+    // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles and 4x throughput.
+    uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec;
+    lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0));
+    lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64));
+    lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128));
+    lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192));
+
+    sz_u128_vec_t source_vec;
+    // If the top bit is set in a byte of `source_vec`, then we use `lookup_128_to_191_vec` or
+    // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`.
+    sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec;
+    sz_u128_vec_t blended_0_to_255_vec;
+
+    // Process the head with serial code
+    length -= head_length;
+    length -= tail_length;
+    for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source];
+
+    // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction
+    // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position
+    // within each 64-byte range of the table.
+    // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/
+    for (; length >= 16; source += 16, target += 16, length -= 16) {
+        source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source);
+        lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16);
+        lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40)));
+        lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80)));
+        lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0)));
+        blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16),
+                                              vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16));
+        vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16);
+    }
+
+    // Process the tail with serial code
+    for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source];
+}
+
 SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
     sz_u64_t matches;
     sz_u128_vec_t h_vec, n_vec, matches_vec;
@@ -6276,6 +6334,8 @@ SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr
     sz_look_up_transform_avx512(source, length, lut, target);
 #elif SZ_USE_X86_AVX2
     sz_look_up_transform_avx2(source, length, lut, target);
+#elif SZ_USE_ARM_NEON
+    sz_look_up_transform_neon(source, length, lut, target);
 #else
     sz_look_up_transform_serial(source, length, lut, target);
 #endif
diff --git a/scripts/bench_memory.cpp b/scripts/bench_memory.cpp
index b36c092..eb72fc9 100644
--- a/scripts/bench_memory.cpp
+++ b/scripts/bench_memory.cpp
@@ -76,7 +76,7 @@ tracked_unary_functions_t copy_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o
         {"sz_copy_avx2" + suffix, wrap_sz(sz_copy_avx2)},
 #endif
 #if SZ_USE_ARM_SVE
-        {"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve), true},
+        {"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve)},
 #endif
 #if SZ_USE_ARM_NEON
         {"sz_copy_neon" + suffix, wrap_sz(sz_copy_neon)},
@@ -116,7 +116,7 @@ tracked_unary_functions_t fill_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o
         {"sz_fill_avx2", wrap_sz(sz_fill_avx2)},
 #endif
 #if SZ_USE_ARM_SVE
-        {"sz_fill_sve", wrap_sz(sz_fill_sve), true},
+        {"sz_fill_sve", wrap_sz(sz_fill_sve)},
 #endif
 #if SZ_USE_ARM_NEON
         {"sz_fill_neon", wrap_sz(sz_fill_neon)},
@@ -197,6 +197,9 @@ tracked_unary_functions_t transform_functions() {
 #endif
 #if SZ_USE_X86_AVX2
         {"sz_look_up_transform_avx2", wrap_sz(sz_look_up_transform_avx2)},
+#endif
+#if SZ_USE_ARM_NEON
+        {"sz_look_up_transform_neon", wrap_sz(sz_look_up_transform_neon)},
 #endif
     };
     return result;
@@ -223,15 +226,17 @@ void bench_memory(std::vector const &slices, sz_cptr_t dataset
                   sz_ptr_t output_buffer_ptr) {
     if (slices.size() == 0) return;
 
+    (void)dataset_start_ptr;
+    (void)output_buffer_ptr;
     bench_memory(slices, copy_functions(dataset_start_ptr, output_buffer_ptr));
     bench_memory(slices, copy_functions(dataset_start_ptr, output_buffer_ptr));
     bench_memory(slices, fill_functions(dataset_start_ptr, output_buffer_ptr));
-    bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 1));
-    bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 8));
-    bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, SZ_CACHE_LINE_WIDTH));
-    bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, max_shift_length));
-    bench_memory(slices, transform_functions());
+    // bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 1));
+    // bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 8));
+    // bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, SZ_CACHE_LINE_WIDTH));
+    // bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, max_shift_length));
+    // bench_memory(slices, transform_functions());
 }
 
 int main(int argc, char const **argv) {
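
The two copy strategies weighed in the sz_copy_neon comment can be reproduced standalone. Below is a minimal sketch, assuming an AArch64 toolchain with <arm_neon.h>, buffers of at least 128 bytes for the 64-byte variant (so the unaligned head and tail both fit), and illustrative function names that are not part of the library:

    #include <arm_neon.h>
    #include <stddef.h>
    #include <stdint.h>

    // 16 bytes per iteration -- the strategy the patch keeps.
    static void copy_16_sketch(uint8_t *target, uint8_t const *source, size_t length) {
        for (; length >= 16; target += 16, source += 16, length -= 16)
            vst1q_u8(target, vld1q_u8(source));
        for (; length; target += 1, source += 1, length -= 1) *target = *source; // serial tail
    }

    // 64 bytes per iteration with cache-line-aligned stores -- the variant the
    // comment above reports as ~20% slower on the tested hardware.
    static void copy_64_sketch(uint8_t *target, uint8_t const *source, size_t length) {
        size_t head_length = (64 - ((uintptr_t)target % 64)) % 64; // 63 or less
        size_t tail_length = ((uintptr_t)target + length) % 64;    // 63 or less
        length -= head_length; // trim before the head loop consumes `head_length`
        for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source;
        for (; length >= 64; target += 64, source += 64, length -= 64)
            vst1q_u8_x4(target, vld1q_u8_x4(source));
        for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source;
    }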
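
Likewise, the 4-table lookup at the heart of sz_look_up_transform_neon, described in the Lemire post linked in the comments, can be isolated into a self-contained sketch. It assumes `length` is a multiple of 16 and skips the head/tail alignment handling of the real kernel; names are illustrative:

    #include <arm_neon.h>
    #include <stddef.h>
    #include <stdint.h>

    // Maps every byte of `source` through a 256-byte `lut` into `target`.
    // Assumes `length` is a multiple of 16.
    static void lut_blocks_sketch(uint8_t *target, uint8_t const *source, size_t length,
                                  uint8_t const lut[256]) {
        // The whole 256-byte table lives in 16 NEON registers, four per `uint8x16x4_t`.
        uint8x16x4_t lut_0 = vld1q_u8_x4(lut + 0);     // entries 0..63
        uint8x16x4_t lut_64 = vld1q_u8_x4(lut + 64);   // entries 64..127
        uint8x16x4_t lut_128 = vld1q_u8_x4(lut + 128); // entries 128..191
        uint8x16x4_t lut_192 = vld1q_u8_x4(lut + 192); // entries 192..255
        for (; length >= 16; source += 16, target += 16, length -= 16) {
            uint8x16_t in = vld1q_u8(source);
            // XOR-ing with the base of each 64-byte range maps the bytes of that range
            // to valid indices 0..63; all other bytes become indices >= 64, for which
            // `vqtbl4q_u8` returns zero, so the four results can simply be OR-ed.
            uint8x16_t r0 = vqtbl4q_u8(lut_0, in);
            uint8x16_t r1 = vqtbl4q_u8(lut_64, veorq_u8(in, vdupq_n_u8(0x40)));
            uint8x16_t r2 = vqtbl4q_u8(lut_128, veorq_u8(in, vdupq_n_u8(0x80)));
            uint8x16_t r3 = vqtbl4q_u8(lut_192, veorq_u8(in, vdupq_n_u8(0xc0)));
            vst1q_u8(target, vorrq_u8(vorrq_u8(r0, r1), vorrq_u8(r2, r3)));
        }
    }

The trick relies on TBL returning zero for out-of-range indices, so for every input byte exactly one of the four lookups produces the mapped value and the other three contribute zeros to the OR-blend.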