diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index de577ae64..34610d90a 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -80,6 +80,21 @@ static inline void convert_s4_s8_16_sse(int8_t* dstptr, int8_t* srcptr) { _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0); } +template +static inline void convert_s4_s8_32_avx(int8_t* dstptr, int8_t* srcptr) { + const __m128i xmm = _mm_loadu_si128((const __m128i*)srcptr); + const __m256i decmp_v = _mm256_inserti128_si256(_mm256_castsi128_si256(xmm), _mm_srli_epi16(xmm, 4), 0x01); + const __m256i lowMask = _mm256_set1_epi8(0xF); + auto dst0 = _mm256_and_si256(decmp_v, lowMask); + if constexpr (S4_T == BTLA_DTYPE::S4_FULLRANGE) { + auto s8 = _mm256_set1_epi8(8); + dst0 = _mm256_sub_epi8(dst0, s8); + } else { + dst0 = _mm256_slli_epi32(dst0, 4); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); +} + template static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) { auto xmm = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr)); @@ -369,9 +384,11 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t elesize = static_cast(row) * col; size_t ele16 = utils::padto_le(elesize, 16); size_t i = 0; - for (; i < ele16; i += 16) { - convert_s4_s8_16_sse(dstptr + i, reinterpret_cast(srcptr + i / 2)); - } + // for (; i < ele16; i += 16) { + // convert_s4_s8_16_sse(dstptr + i, reinterpret_cast(srcptr + i / 2)); + // } + size_t ele32 = utils::padto_le(elesize, 32); + for (; i < ele32; i += 32) convert_s4_s8_32_avx(dstptr + i, reinterpret_cast(srcptr + i / 2)); for (; i < elesize; i += 2) { auto tmp = srcptr[i / 2]; dstptr[i + 0] = kernel::ref::get_s8(tmp.x); diff --git a/bestla/bestla/ut/bestla_ut.h b/bestla/bestla/ut/bestla_ut.h index b9e40f54e..93edf44b1 100644 --- a/bestla/bestla/ut/bestla_ut.h +++ b/bestla/bestla/ut/bestla_ut.h @@ -63,7 +63,8 @@ static inline int auto_batch(size_t memsize) { GetCPUDevice(); auto L3 = _cd->getL3CacheSize(); size_t constexpr Enlarge = 4; - auto batch = L3 * Enlarge / memsize; + size_t constexpr TargetMem = 1LL << 30; + auto batch = TargetMem / memsize; return batch > 1 ? batch : 2; }