From 8ed999d0f2dd82fa41a0c812ae3e9694574290cc Mon Sep 17 00:00:00 2001
From: luoyu-intel
Date: Thu, 6 Jun 2024 15:59:34 +0800
Subject: [PATCH] support llama shapes, add new UT case, update new api of
 dpcpp

---
 bestla/bestla/bestla_prologue_b.h    |   3 +-
 bestla/bestla/sycl/sycl_device.h     |  12 +-
 bestla/bestla/sycl/sycl_gemm.h       |   6 +-
 bestla/bestla/sycl/sycl_prologue_b.h | 223 ++++++++++++++++++---------
 bestla/bestla/sycl/sycl_wrapper.h    |   4 +-
 bestla/bestla/ut/sycl_benchmark.cpp  |  72 ++++-----
 bestla/bestla/ut/sycl_gemm.cpp       |  46 +++---
 bestla/bestla/ut/sycl_misc.cpp       |   4 +-
 8 files changed, 224 insertions(+), 146 deletions(-)

diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h
index 2926699d7..fa05c7019 100644
--- a/bestla/bestla/bestla_prologue_b.h
+++ b/bestla/bestla/bestla_prologue_b.h
@@ -126,8 +126,7 @@ class WeightKBlockNInteger {
     return tmp;
   }
 
-  AUTOCALL void convertTransStorage(StorageWeight& srcstor, StorageWeight& dststor,
-                                    parallel::IThreading* threading) {
+  AUTOCALL void convertTransStorage(StorageWeight& srcstor, StorageWeight& dststor, parallel::IThreading* threading) {
     auto s8buf = utils::amalloc((size_t)srcstor.mK * srcstor.mN);
     auto s8transbuf = utils::amalloc((size_t)srcstor.mKPad * srcstor.mNPad);
     unpackWeight(srcstor.mN, srcstor.mK, &srcstor, s8buf, srcstor.mN, threading);
diff --git a/bestla/bestla/sycl/sycl_device.h b/bestla/bestla/sycl/sycl_device.h
index 048d84bcf..eca3adb7d 100644
--- a/bestla/bestla/sycl/sycl_device.h
+++ b/bestla/bestla/sycl/sycl_device.h
@@ -67,16 +67,16 @@ class SyclDevice {
   void print() {
     std::cout << "Running on device: " << mQueue.get_device().get_info() << "\n";
     if (is_gpu(mQueue.get_device())) {
-      std::cout << "EU count:" << mQueue.get_device().get_info() << "\n";
+      std::cout << "EU count:" << mQueue.get_device().get_info() << "\n";
       std::cout << "EU count per subslice:"
-                << mQueue.get_device().get_info() << "\n";
-      std::cout << "EU SIMD width:" << mQueue.get_device().get_info()
+                << mQueue.get_device().get_info() << "\n";
+      std::cout << "EU SIMD width:" << mQueue.get_device().get_info()
                 << "\n";
       std::cout << "HW threads per EU:"
-                << mQueue.get_device().get_info() << "\n";
-      std::cout << "GPU slices:" << mQueue.get_device().get_info() << "\n";
+                << mQueue.get_device().get_info() << "\n";
+      std::cout << "GPU slices:" << mQueue.get_device().get_info() << "\n";
       std::cout << "Subslice per slice:"
-                << mQueue.get_device().get_info() << "\n";
+                << mQueue.get_device().get_info() << "\n";
     }
     std::cout << "Global Memory size: " << getGlobalMemSizeGB() << "\n";
   }
diff --git a/bestla/bestla/sycl/sycl_gemm.h b/bestla/bestla/sycl/sycl_gemm.h
index 4bf50ed91..e98f01aa0 100644
--- a/bestla/bestla/sycl/sycl_gemm.h
+++ b/bestla/bestla/sycl/sycl_gemm.h
@@ -64,9 +64,9 @@ class SGemmCoreSharedB {
 
   using SLM_B_Acc = sycl::local_accessor;
 
-  using AType = typename TA;
-  using BType = typename TB;
-  using CType = typename TC;
+  using AType = TA;
+  using BType = TB;
+  using CType = TC;
   static auto constexpr NTILE = WgNEle;
   static auto constexpr MTILE = WgMEle;
   static auto constexpr KTILE = TileK;
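Note on the sycl_gemm.h hunk above: `typename` is only required (and only well-formed) in front of qualified names that depend on a template parameter; a bare template parameter such as `TA` is already a type, which is why newer DPC++/clang builds reject `using AType = typename TA;`. A minimal standalone illustration of the distinction (the `Core` struct here is hypothetical, not BesTLA code):

#include <vector>

// Hypothetical example of when `typename` is and is not needed.
template <class TA>
struct Core {
  using AType = TA;                                   // OK: TA is already a type
  using AVec = typename std::vector<TA>::value_type;  // `typename` required: dependent qualified name
};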
diff --git a/bestla/bestla/sycl/sycl_prologue_b.h b/bestla/bestla/sycl/sycl_prologue_b.h
index 6d864de5f..947007466 100644
--- a/bestla/bestla/sycl/sycl_prologue_b.h
+++ b/bestla/bestla/sycl/sycl_prologue_b.h
@@ -224,7 +224,6 @@ class WeightS4Trans {
         auto sptr = S_d + sg_k / blocksize + g_n * ldb;
         auto bptr = B_d + (sg_k + g_n * ldbn) / 2;
         auto dbptr = outptr + sg_k + g_n * k;
-        float tmp[TileK];
         int constexpr Unroll = 4;
 #pragma unroll
         for (int ik = 0; ik < TileK; ik += Unroll) {
@@ -369,85 +368,163 @@ class WeightS4Trans {
     auto B = paramB.B;
     auto B_scale = paramB.scale;
     int ldb = paramB.ldb;
+    int constexpr Unroll = 2;
     int constexpr SgSize = 16;
-    int constexpr TileK = 32;
-    int constexpr GroupK = SgSize * TileK;
     sycl::range<1> group{SgSize};
     sycl::range<1> problem{n * SgSize};
-
-    auto ev = q->submit([&](sycl::handler& cgh) {
-      cgh.parallel_for(
-          sycl::nd_range<1>(problem, group),
-          [=](sycl::nd_item<1> it) [[cl::reqd_work_group_size(
-              1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
-            int g_idx = it.get_group(0);
-            auto sg = it.get_sub_group();
-            int sg_id = sg.get_local_id()[0];
-            int g_n = g_idx;
-            auto sptr = B_scale + g_n * ldb;
-            auto bptr = B + g_n * k / 2;
-            auto aptr = A;
-            auto cptr = C + g_n;
-            if constexpr (std::is_same_v) {
-              sycl::half2 tmpAcc = {0.f, 0.f};
-              int constexpr Unroll = 2;
-              for (int i = 0; i < k; i += GroupK * Unroll) {
+    if (k % (SgSize * 32 * Unroll) == 0) {
+      int constexpr TileK = 32;
+      int constexpr GroupK = SgSize * TileK;
+      auto ev = q->submit([&](sycl::handler& cgh) {
+        cgh.parallel_for(sycl::nd_range<1>(problem, group),
+            [=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
+                1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
+              int g_idx = it.get_group(0);
+              auto sg = it.get_sub_group();
+              int sg_id = sg.get_local_id()[0];
+              int g_n = g_idx;
+              auto sptr = B_scale + g_n * ldb;
+              auto bptr = B + g_n * k / 2;
+              auto aptr = A;
+              auto cptr = C + g_n;
+              if constexpr (std::is_same_v) {
+                sycl::half2 tmpAcc = {0.f, 0.f};
+                for (int i = 0; i < k; i += GroupK * Unroll) {
 #pragma unroll
-                for (int iu = 0; iu < Unroll; iu++) {
-                  uint8_t tmps8[TileK / 2];
-                  *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2);
-                  CType scale = *(sptr + sg_id * TileK / blocksize);
+                  for (int iu = 0; iu < Unroll; iu++) {
+                    uint8_t tmps8[TileK / 2];
+                    *(sycl::vec*)tmps8 =
+                        *(sycl::vec*)(bptr + sg_id * TileK / 2);
+                    CType scale = *(sptr + sg_id * TileK / blocksize);
 #pragma unroll
-                  for (int ikk = 0; ikk < TileK; ikk += 2) {
-                    sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
-                    sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8),
-                                        static_cast((tmps8[ikk / 2] >> 4) - 8)};
-                    tmpAcc += tmpA * tmpB * scale;
-                  }
-                  sptr += GroupK / blocksize;
-                  aptr += GroupK;
-                  bptr += GroupK / 2;
-                }
-              }
-              sycl::half2 sum = {0.f, 0.f};
-              for (int i = 0; i < SgSize; i += 1) {
-                sum += sg.shuffle(tmpAcc, i);
-              }
-              if (sg_id == 0) {
-                *cptr = sum[0] + sum[1];
-              }
-            } else {
-              CType tmpAcc = 0.f;
-              int constexpr Unroll = 2;
-              for (int i = 0; i < k; i += GroupK * Unroll) {
+                    for (int ikk = 0; ikk < TileK; ikk += 2) {
+                      sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
+                      sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8),
+                                          static_cast((tmps8[ikk / 2] >> 4) - 8)};
+                      tmpAcc += tmpA * tmpB * scale;
+                    }
+                    sptr += GroupK / blocksize;
+                    aptr += GroupK;
+                    bptr += GroupK / 2;
+                  }
+                }
+                sycl::half2 sum = {0.f, 0.f};
+                for (int i = 0; i < SgSize; i += 1) {
+                  sum += sg.shuffle(tmpAcc, i);
+                }
+                if (sg_id == 0) {
+                  *cptr = sum[0] + sum[1];
+                }
+              } else {
+                CType tmpAcc = 0.f;
+                int constexpr Unroll = 2;
+                for (int i = 0; i < k; i += GroupK * Unroll) {
 #pragma unroll
-                for (int iu = 0; iu < Unroll; iu++) {
-                  uint8_t tmps8[TileK / 2];
-                  *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2);
-                  CType scale = *(sptr + sg_id * TileK / blocksize);
+                  for (int iu = 0; iu < Unroll; iu++) {
+                    uint8_t tmps8[TileK / 2];
+                    *(sycl::vec*)tmps8 =
+                        *(sycl::vec*)(bptr + sg_id * TileK / 2);
+                    CType scale = *(sptr + sg_id * TileK / blocksize);
 #pragma unroll
-                  for (int ikk = 0; ikk < TileK; ikk += 2) {
-                    tmpAcc +=
-                        CType(aptr[sg_id * TileK + ikk]) * static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale;
-                    tmpAcc +=
-                        CType(aptr[sg_id * TileK + ikk + 1]) * static_cast((tmps8[ikk / 2] >> 4) - 8) * scale;
-                  }
-                  sptr += GroupK / blocksize;
-                  aptr += GroupK;
-                  bptr += GroupK / 2;
-                }
-              }
-              float sum = 0.f;
-              for (int i = 0; i < SgSize; i += 1) {
-                sum += sg.shuffle(tmpAcc, i);
-              }
-              if (sg_id == 0) {
-                *cptr = sum;
-              }
-            }
-          });
-    });
-    return ev;
+                    for (int ikk = 0; ikk < TileK; ikk += 2) {
+                      tmpAcc += CType(aptr[sg_id * TileK + ikk]) *
+                                static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale;
+                      tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) *
+                                static_cast((tmps8[ikk / 2] >> 4) - 8) * scale;
+                    }
+                    sptr += GroupK / blocksize;
+                    aptr += GroupK;
+                    bptr += GroupK / 2;
+                  }
+                }
+                float sum = 0.f;
+                for (int i = 0; i < SgSize; i += 1) {
+                  sum += sg.shuffle(tmpAcc, i);
+                }
+                if (sg_id == 0) {
+                  *cptr = sum;
+                }
+              }
+            });
+      });
+      return ev;
+    } else {
+      int constexpr TileK = 8;
+      int constexpr GroupK = SgSize * TileK;
+      auto ev = q->submit([&](sycl::handler& cgh) {
+        cgh.parallel_for(sycl::nd_range<1>(problem, group),
+            [=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
+                1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
+              int g_idx = it.get_group(0);
+              auto sg = it.get_sub_group();
+              int sg_id = sg.get_local_id()[0];
+              int g_n = g_idx;
+              auto sptr = B_scale + g_n * ldb;
+              auto bptr = B + g_n * k / 2;
+              auto aptr = A;
+              auto cptr = C + g_n;
+              if constexpr (std::is_same_v) {
+                sycl::half2 tmpAcc = {0.f, 0.f};
+                for (int i = 0; i < k; i += GroupK * Unroll) {
+#pragma unroll
+                  for (int iu = 0; iu < Unroll; iu++) {
+                    uint8_t tmps8[TileK / 2];
+                    *(sycl::vec*)tmps8 =
+                        *(sycl::vec*)(bptr + sg_id * TileK / 2);
+                    CType scale = *(sptr + sg_id * TileK / blocksize);
+#pragma unroll
+                    for (int ikk = 0; ikk < TileK; ikk += 2) {
+                      sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
+                      sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8),
+                                          static_cast((tmps8[ikk / 2] >> 4) - 8)};
+                      tmpAcc += tmpA * tmpB * scale;
+                    }
+                    sptr += GroupK / blocksize;
+                    aptr += GroupK;
+                    bptr += GroupK / 2;
+                  }
+                }
+                sycl::half2 sum = {0.f, 0.f};
+                for (int i = 0; i < SgSize; i += 1) {
+                  sum += sg.shuffle(tmpAcc, i);
+                }
+                if (sg_id == 0) {
+                  *cptr = sum[0] + sum[1];
+                }
+              } else {
+                CType tmpAcc = 0.f;
+                int constexpr Unroll = 2;
+                for (int i = 0; i < k; i += GroupK * Unroll) {
+#pragma unroll
+                  for (int iu = 0; iu < Unroll; iu++) {
+                    uint8_t tmps8[TileK / 2];
+                    *(sycl::vec*)tmps8 =
+                        *(sycl::vec*)(bptr + sg_id * TileK / 2);
+                    CType scale = *(sptr + sg_id * TileK / blocksize);
+#pragma unroll
+                    for (int ikk = 0; ikk < TileK; ikk += 2) {
+                      tmpAcc += CType(aptr[sg_id * TileK + ikk]) *
+                                static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale;
+                      tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) *
+                                static_cast((tmps8[ikk / 2] >> 4) - 8) * scale;
+                    }
+                    sptr += GroupK / blocksize;
+                    aptr += GroupK;
+                    bptr += GroupK / 2;
+                  }
+                }
+                float sum = 0.f;
+                for (int i = 0; i < SgSize; i += 1) {
+                  sum += sg.shuffle(tmpAcc, i);
+                }
+                if (sg_id == 0) {
+                  *cptr = sum;
+                }
+              }
+            });
+      });
+      return ev;
+    }
   }
 };
 }  // namespace sycl_prologue_b
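The WeightS4Trans::gemv rewrite above is the "support llama shapes" part of this patch: the old kernel assumed k was a multiple of SgSize * TileK * Unroll = 16 * 32 * 2 = 1024, which holds for k = 4096 but not for the llama-2 FFN dimension k = 11008. The new code keeps the TileK = 32 fast path when that divisibility check passes and otherwise falls back to TileK = 8, whose per-iteration stride of 16 * 8 * 2 = 256 does divide 11008. A standalone host-side sketch of the same dispatch (the helper below is illustrative, not part of BesTLA; note the fallback still assumes k is a multiple of 256):

#include <cstdio>

constexpr int SgSize = 16;  // sub-group size used by the kernel
constexpr int Unroll = 2;   // unroll factor hoisted out of the kernel in this patch

// Mirrors the k % (SgSize * 32 * Unroll) == 0 check added to WeightS4Trans::gemv.
int pick_tile_k(int k) {
  if (k % (SgSize * 32 * Unroll) == 0) return 32;  // stride 1024 per iteration
  return 8;                                        // stride 256; covers k = 11008 (256 * 43)
}

int main() {
  for (int k : {1024, 4096, 11008}) std::printf("k=%d -> TileK=%d\n", k, pick_tile_k(k));
}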
diff --git a/bestla/bestla/sycl/sycl_wrapper.h b/bestla/bestla/sycl/sycl_wrapper.h
index 1ed4b8212..b5ecadbd8 100644
--- a/bestla/bestla/sycl/sycl_wrapper.h
+++ b/bestla/bestla/sycl/sycl_wrapper.h
@@ -61,7 +61,7 @@ class Launcher {
     sycl::local_accessor slm_b(sycl::range(GemmCore::SLM_B_Size), cgh);
     cgh.parallel_for(
         sycl::nd_range<2>(problem, group),
-        [=](sycl::nd_item<2> it) [[cl::reqd_work_group_size(
+        [=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size(
             1, GemmCore::WgM, GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] {
           sycl_utils::nd_item_helper helper(it);
@@ -153,7 +153,7 @@ class LauncherWOQ {
     sycl::local_accessor slm_b(sycl::range(GemmCore::SLM_B_Size), cgh);
     cgh.parallel_for(
         sycl::nd_range<2>(problem, group),
-        [=](sycl::nd_item<2> it) [[cl::reqd_work_group_size(
+        [=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size(
             1, GemmCore::WgM, GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] {
           sycl_utils::nd_item_helper helper(it);
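The sycl_wrapper.h change (and the matching lines in sycl_prologue_b.h) is the "update new api of dpcpp" part of the subject: `[[cl::reqd_work_group_size(...)]]` is the deprecated spelling, while `[[sycl::reqd_work_group_size(...)]]` is the SYCL 2020 attribute that current oneAPI compilers expect. A minimal standalone kernel using the modern spelling (illustrative only; the sizes are arbitrary):

#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;
  // The attribute arguments must match the local range used at launch (16 here).
  q.parallel_for(sycl::nd_range<1>{sycl::range<1>{64}, sycl::range<1>{16}},
                 [=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(16)]]
                 [[intel::reqd_sub_group_size(16)]] {
                   // empty body; a real kernel would index via it.get_global_id(0)
                 });
  q.wait();
}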
diff --git a/bestla/bestla/ut/sycl_benchmark.cpp b/bestla/bestla/ut/sycl_benchmark.cpp
index bd6266ab3..fa970a740 100644
--- a/bestla/bestla/ut/sycl_benchmark.cpp
+++ b/bestla/bestla/ut/sycl_benchmark.cpp
@@ -47,9 +47,9 @@ class Benchmark_Fp32Fp32 {
     tm.start();
     while (tm.stop() < timems) {
       for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}});
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}});
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -122,9 +122,9 @@ class Benchmark_Fp16Fp16 {
     tm.start();
     while (tm.stop() < timems) {
       for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}});
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = KernelLauncher::compute(q, m, n, k, {{A, k}, {B, n}, {C, n}});
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -197,9 +197,9 @@ class Benchmark_S4Fp32Fp32 {
     tm.start();
     while (tm.stop() < timems) {
       for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = KernelLauncher::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, n}, {C, n}});
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = KernelLauncher::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, n}, {C, n}});
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -224,9 +224,9 @@ class Benchmark_S4Fp32Fp32 {
     tm.start();
     while (tm.stop() < timems) {
       for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = KernelLauncherT::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, blks}, {C, n}});
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = KernelLauncherT::compute(q, m, n, k, 128, {{A, k}, {B, B_scale, blks}, {C, n}});
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -258,9 +258,9 @@ class Benchmark_S4Fp32Fp32 {
         int constexpr GroupK = SgSize * TileK;
         sycl::range<1> group{SgSize};
         sycl::range<1> problem{n * SgSize};
-        auto e_esimd = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q);
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q);
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -352,9 +352,9 @@ class Benchmark_S4Fp16Fp16 {
     tm.start();
     while (tm.stop() < timems) {
       for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {B_d, S_d, n}, {C_d, n}});
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {B_d, S_d, n}, {C_d, n}});
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -381,9 +381,9 @@ class Benchmark_S4Fp16Fp16 {
     tm.start();
     while (tm.stop() < timems) {
      for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q);
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q);
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -410,9 +410,9 @@ class Benchmark_S4Fp16Fp16 {
     tm.start();
     while (tm.stop() < timems) {
       for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q);
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = KernelTLauncher::compute({m, n, k, blocksize, {A_d, k}, {B_d, S_d, blks}, {C_d, n}}, q);
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -438,9 +438,9 @@ class Benchmark_S4Fp16Fp16 {
     tm.start();
     while (tm.stop() < timems) {
       for (size_t i = 0; i < batch; i++) {
-        auto e_esimd = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q);
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q);
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= timems) {
           break;
         }
@@ -524,9 +524,9 @@ class Benchmark_DequantS4 {
     tm.start();
     while (tm.stop() < TestMs) {
       for (size_t i = 0; i < 1; i++) {
-        auto e_esimd =
+        auto ev =
             ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= TestMs) {
           break;
         }
@@ -576,9 +576,9 @@ class Benchmark_DequantS4 {
     tm.start();
     while (tm.stop() < TestMs) {
       for (size_t i = 0; i < 1; i++) {
-        auto e_esimd =
+        auto ev =
             ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= TestMs) {
           break;
         }
@@ -627,8 +627,8 @@ class Benchmark_DequantS4 {
     tm.start();
     while (tm.stop() < TestMs) {
       for (size_t i = 0; i < 1; i++) {
-        auto e_esimd = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        auto ev = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= TestMs) {
           break;
         }
@@ -673,10 +673,10 @@ class Benchmark_DequantS4 {
     tm.start();
     while (tm.stop() < TestMs) {
       for (size_t i = 0; i < 1; i++) {
-        auto e_esimd = ProB::dequant_s4(n, k, blocksize, {dB.data(), dS.data(), n},
+        auto ev = ProB::dequant_s4(n, k, blocksize, {dB.data(), dS.data(), n},
                                    dequantB.data(), q);
-        e_esimd.wait();
-        log.add(event_helper::execute_time(e_esimd) * 1000);
+        ev.wait();
+        log.add(event_helper::execute_time(ev) * 1000);
         if (tm.stop() >= TestMs) {
           break;
         }
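The sycl_benchmark.cpp changes are a mechanical rename of the event variable from e_esimd to ev (these kernels are plain SYCL, not ESIMD, so the old name was misleading); the timing logic is untouched. For context, a helper like event_helper::execute_time is typically built on SYCL event profiling — the sketch below is an assumption about its shape, not code from this repository:

#include <sycl/sycl.hpp>

// Assumed implementation sketch: read the command_start/command_end timestamps
// (nanoseconds) and return seconds, so the benchmarks' `* 1000` yields milliseconds.
static double execute_time_seconds(const sycl::event& ev) {
  auto t0 = ev.get_profiling_info<sycl::info::event_profiling::command_start>();
  auto t1 = ev.get_profiling_info<sycl::info::event_profiling::command_end>();
  return static_cast<double>(t1 - t0) * 1e-9;
}

int main() {
  // Profiling timestamps are only recorded when the queue enables profiling.
  sycl::queue q{sycl::property::queue::enable_profiling{}};
  auto ev = q.parallel_for(sycl::range<1>{1024}, [=](sycl::id<1>) {});
  ev.wait();
  (void)execute_time_seconds(ev);
}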
diff --git a/bestla/bestla/ut/sycl_gemm.cpp b/bestla/bestla/ut/sycl_gemm.cpp
index 64fd731e7..92c0f578e 100644
--- a/bestla/bestla/ut/sycl_gemm.cpp
+++ b/bestla/bestla/ut/sycl_gemm.cpp
@@ -40,8 +40,8 @@ class UT_SyclSGemm {
     auto A_d = dA.data();
     auto B_d = dB.data();
     auto C_d = dC.data();
-    auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}});
-    e_esimd.wait();
+    auto ev = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}});
+    ev.wait();
     q->memcpy(matC.data(), C_d, matC.size() * 4).wait();
     buffer_error(ref.data(), matC.data(), ref.size(), 0.001f);
   }
@@ -83,8 +83,8 @@ class UT_SyclHGemm {
     auto A_d = dA.data();
     auto B_d = dB.data();
     auto C_d = dC.data();
-    auto e_esimd = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}});
-    e_esimd.wait();
+    auto ev = KernelLauncher::compute(q, m, n, k, {{A_d, k}, {B_d, n}, {C_d, n}});
+    ev.wait();
     q->memcpy(matC.data(), C_d, matC.size() * 2).wait();
     buffer_error(ref.data(), matC.data(), ref.size(), utils::fp16(0.2f));
   }
@@ -97,6 +97,7 @@ class UT_SyclS4SGemm {
  public:
   UT_SyclS4SGemm() {
     UT_START();
+    utT(6, 4096, 11008, 128);
     ut(6, 32000, 4096, 128);
     utT(6, 32000, 4096, 128);
     ut(300, 1024, 1024, 32);
@@ -148,8 +149,8 @@ class UT_SyclS4SGemm {
     auto B_scale_d = dB_scale.data();
     auto C_d = dC.data();
     utils::GemmProblem gp(1, m, n, k);
-    auto e_esimd = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}});
-    e_esimd.wait();
+    auto ev = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}});
+    ev.wait();
     q->memcpy(matC.data(), C_d, matC.size() * 4).wait();
     buffer_error(ref.data(), matC.data(), ref.size(), 0.001f);
   }
@@ -189,8 +190,8 @@ class UT_SyclS4SGemm {
     auto B_scale_d = dB_scale.data();
     auto C_d = dC.data();
     utils::GemmProblem gp(1, m, n, k);
-    auto e_esimd = KernelTLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}});
-    e_esimd.wait();
+    auto ev = KernelTLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}});
+    ev.wait();
     q->memcpy(matC.data(), C_d, matC.size() * 4).wait();
     buffer_error(ref.data(), matC.data(), ref.size(), 0.001f);
   }
@@ -249,8 +250,8 @@ class UT_SyclS4HGemm {
     auto Bs8_d = dBs8.data();
     auto B_scale_d = dB_scale.data();
     auto C_d = dC.data();
-    auto e_esimd = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}});
-    e_esimd.wait();
+    auto ev = KernelLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, n}, {C_d, n}});
+    ev.wait();
     q->memcpy(matC.data(), C_d, matC.size() * 2).wait();
     buffer_error(ref.data(), matC.data(), ref.size(), utils::fp16(0.2f));
   }
@@ -287,8 +288,8 @@ class UT_SyclS4HGemm {
     auto Bs8_d = dBs8.data();
     auto B_scale_d = dB_scale.data();
     auto C_d = dC.data();
-    auto e_esimd = KernelTLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}});
-    e_esimd.wait();
+    auto ev = KernelTLauncher::compute(q, m, n, k, blocksize, {{A_d, k}, {Bs8_d, B_scale_d, blks}, {C_d, n}});
+    ev.wait();
     q->memcpy(matC.data(), C_d, matC.size() * 2).wait();
     buffer_error(ref.data(), matC.data(), ref.size(), utils::fp16(0.2f));
   }
@@ -331,8 +332,8 @@ class UT_SyclInt4Dequant {
     auto S_d = dS.data();
     auto B_d = dB.data();
     auto DB_d = dequantB.data();
-    auto e_esimd = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, n}, DB_d, q);
-    e_esimd.wait();
+    auto ev = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, n}, DB_d, q);
+    ev.wait();
     q->memcpy(dequant.data(), DB_d, dequant.size() * 4).wait();
     buffer_error(ref.data(), dequant.data(), dequant.size(), 0.001f);
   }
@@ -363,15 +364,15 @@ class UT_SyclInt4Dequant {
     auto S_d = dS.data();
     auto B_d = dB.data();
     auto DB_d = dequantB.data();
-    auto e_esimd = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
-    e_esimd.wait();
+    auto ev = ProB::dequant_s4(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
+    ev.wait();
     q->memcpy(dequant.data(), DB_d, dequant.size() * 4).wait();
     buffer_error(ref.data(), dequant.data(), dequant.size(), 0.001f);
 
     avector refNT(k * n);
     kernel::wrapper::Transpose2D::forward(ref.data(), refNT.data(), n, k, k, n);
-    e_esimd = ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
-    e_esimd.wait();
+    ev = ProB::dequant_s4_trans(n, k, blocksize, {B_d, S_d, blks}, DB_d, q);
+    ev.wait();
     q->memcpy(dequant.data(), DB_d, dequant.size() * 4).wait();
     buffer_error(refNT.data(), dequant.data(), dequant.size(), 0.001f);
   }
@@ -384,6 +385,7 @@ class UT_SyclS4Gemv {
  public:
   UT_SyclS4Gemv() {
     UT_START();
+    ut_T(1024, 11008, 32);
     ut_T(1024, 1024, 32);
     ut_half(1024, 1024, 32);
   }
@@ -432,8 +434,8 @@ class UT_SyclS4Gemv {
     auto A_d = dA.data();
     auto B_d = dB.data();
     auto C_d = dC.data();
-    auto e_esimd = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q);
-    e_esimd.wait();
+    auto ev = ProBTransT::gemv(A_d, {B_d, S_d, blks}, C_d, n, k, blocksize, q);
+    ev.wait();
     q->memcpy(C.data(), C_d, C.size() * 4).wait();
     buffer_error(refC.data(), C.data(), C.size(), 0.001f);
   }
@@ -473,9 +475,9 @@ class UT_SyclS4Gemv {
     auto A_d = dA.data();
     auto B_d = dB.data();
     auto C_d = dC.data();
-    auto e_esimd = sycl_prologue_b::WeightS4Trans::gemv(A_d, {B_d, S_d, blks}, C_d,
+    auto ev = sycl_prologue_b::WeightS4Trans::gemv(A_d, {B_d, S_d, blks}, C_d,
                                                    n, k, blocksize, q);
-    e_esimd.wait();
+    ev.wait();
     q->memcpy(C.data(), C_d, C.size() * 2).wait();
     buffer_error(refC.data(), C.data(), C.size(), utils::fp16(0.1f));
   }
diff --git a/bestla/bestla/ut/sycl_misc.cpp b/bestla/bestla/ut/sycl_misc.cpp
index e346f2212..21e655fa0 100644
--- a/bestla/bestla/ut/sycl_misc.cpp
+++ b/bestla/bestla/ut/sycl_misc.cpp
@@ -157,9 +157,9 @@ class UT_CompFp32 {
     sycl_stor.assign(dbuf.data());
     sycl_stor.fromHost(transtor, q);
     int blks = updiv(k, blocksize);
-    auto e_esimd = ProBTransT::gemv(dA.data(), {(uint8_t*)sycl_stor.mQBuf, (float*)sycl_stor.mSBuf, blks}, dC.data(), n,
+    auto ev = ProBTransT::gemv(dA.data(), {(uint8_t*)sycl_stor.mQBuf, (float*)sycl_stor.mSBuf, blks}, dC.data(), n,
                                k, blocksize, q);
-    e_esimd.wait();
+    ev.wait();
     q->memcpy(matC.data(), dC.data(), matC.size() * 4).wait();
     auto err = get_ut_err(qtype);
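The new UT cases utT(6, 4096, 11008, 128) and ut_T(1024, 11008, 32) exercise the llama-2 7B FFN dimension k = 11008 against the existing CPU reference: symmetric int4 weights are stored two nibbles per byte and dequantized as (nibble - 8) * scale before an ordinary gemv. A small standalone sketch of that reference computation (function and layout names are illustrative, not the BesTLA helpers; it assumes the same low-nibble-first packing and per-block, per-column scales as the kernel above):

#include <cstdint>
#include <vector>

// C[in] = sum over k of A[ik] * ((nibble - 8) * scale), two k-adjacent nibbles per byte.
void ref_s4_gemv(const float* A, const uint8_t* B_packed, const float* scales, float* C,
                 int n, int k, int blocksize) {
  for (int in = 0; in < n; in++) {
    float acc = 0.f;
    for (int ik = 0; ik < k; ik += 2) {
      uint8_t byte = B_packed[(in * k + ik) / 2];
      float scale = scales[in * (k / blocksize) + ik / blocksize];
      acc += A[ik] * ((byte & 0x0f) - 8) * scale;    // low nibble
      acc += A[ik + 1] * ((byte >> 4) - 8) * scale;  // high nibble
    }
    C[in] = acc;
  }
}

int main() {
  const int n = 4, k = 64, blocksize = 32;
  std::vector<float> A(k, 1.f), C(n), scales(n * (k / blocksize), 0.5f);
  std::vector<uint8_t> B(n * k / 2, 0x99);  // every nibble = 9, dequantizes to (9 - 8) * 0.5
  ref_s4_gemv(A.data(), B.data(), scales.data(), C.data(), n, k, blocksize);
}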