Skip to content

Commit

Permalink
Add: FMA on Sapphire Rapids
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Oct 19, 2024
1 parent b6fb1d3 commit 5a56a25
Show file tree
Hide file tree
Showing 3 changed files with 527 additions and 47 deletions.
27 changes: 27 additions & 0 deletions include/simsimd/dot.h
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,22 @@ SIMSIMD_INTERNAL __m256 _simsimd_bf16x8_to_f32x8_haswell(__m128i a) {
return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16));
}

SIMSIMD_INTERNAL __m128i _simsimd_f32x8_to_bf16x8_haswell(__m256 a) {
// Pack the 32-bit integers into 16-bit integers.
// This is less trivial than unpacking: https://stackoverflow.com/a/77781241/2766161
// The best approach is to shuffle within lanes first: https://stackoverflow.com/a/49723746/2766161
// Our shuffling mask will drop the low 2-bytes from every 4-byte word.
__m256i trunc_elements = _mm256_shuffle_epi8( //
_mm256_castps_si256(a), //
_mm256_set_epi8( //
-1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 11, 10, 7, 6, 3, 2, //
-1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 11, 10, 7, 6, 3, 2 //
));
__m256i ordered = _mm256_permute4x64_epi64(trunc_elements, 0x58);
__m128i result = _mm256_castsi256_si128(ordered);
return result;
}

SIMSIMD_INTERNAL __m128i _simsimd_partial_load_bf16x8_haswell(simsimd_bf16_t const* a, simsimd_size_t n) {
// In case the software emulation for `bf16` scalars is enabled, the `simsimd_bf16_to_f32`
// function will run. It is extremely slow, so even for the tail, let's combine serial
Expand Down Expand Up @@ -1256,6 +1272,17 @@ SIMSIMD_INTERNAL simsimd_f64_t _simsimd_reduce_f32x16_skylake(__m512 a) {
return _mm_cvtss_f32(_mm_hadd_ps(r, r));
}

SIMSIMD_INTERNAL __m512 _simsimd_bf16x16_to_f32x16_skylake(__m256i a) {
// Upcasting from `bf16` to `f32` is done by shifting the `bf16` values by 16 bits to the left, like:
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}

SIMSIMD_INTERNAL __m256i _simsimd_f32x16_to_bf16x16_skylake(__m512 a) {
// Add 2^15 and right shift 16 to do round-nearest
__m512i x = _mm512_srli_epi32(_mm512_add_epi32(_mm512_castps_si512(a), _mm512_set1_epi32(1 << 15)), 16);
return _mm512_cvtepi32_epi16(x);
}

SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
simsimd_distance_t* result) {
__m512 ab_vec = _mm512_setzero();
Expand Down
Loading

0 comments on commit 5a56a25

Please sign in to comment.