Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
support llama shapes, add new UT case, update new api of dpcpp
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Jun 6, 2024
1 parent 1f07556 commit 8ed999d
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 146 deletions.
3 changes: 1 addition & 2 deletions bestla/bestla/bestla_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ class WeightKBlockNInteger {
return tmp;
}

AUTOCALL void convertTransStorage(StorageWeight& srcstor, StorageWeight& dststor,
parallel::IThreading* threading) {
AUTOCALL void convertTransStorage(StorageWeight& srcstor, StorageWeight& dststor, parallel::IThreading* threading) {
auto s8buf = utils::amalloc<int8_t>((size_t)srcstor.mK * srcstor.mN);
auto s8transbuf = utils::amalloc<int8_t>((size_t)srcstor.mKPad * srcstor.mNPad);
unpackWeight(srcstor.mN, srcstor.mK, &srcstor, s8buf, srcstor.mN, threading);
Expand Down
12 changes: 6 additions & 6 deletions bestla/bestla/sycl/sycl_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ class SyclDevice {
// Prints the identity of the device bound to mQueue and, when that device is a
// GPU, its Intel execution-unit topology (EU counts, SIMD width, HW threads,
// slice/subslice layout), followed by the global memory size in GB.
//
// NOTE(review): the diff residue here contained BOTH the legacy
// sycl::info::device::ext_intel_gpu_* descriptors and their replacements; the
// legacy aliases were removed from newer DPC++ releases, so only the current
// sycl::ext::intel::info::device::* descriptors are kept.
void print() {
  std::cout << "Running on device: " << mQueue.get_device().get_info<sycl::info::device::name>() << "\n";
  // The ext::intel descriptors are only valid on Intel GPU devices; querying
  // them on CPU/accelerator devices would throw, hence the guard.
  if (is_gpu(mQueue.get_device())) {
    std::cout << "EU count:" << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_eu_count>() << "\n";
    std::cout << "EU count per subslice:"
              << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_eu_count_per_subslice>() << "\n";
    std::cout << "EU SIMD width:" << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_eu_simd_width>()
              << "\n";
    std::cout << "HW threads per EU:"
              << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_hw_threads_per_eu>() << "\n";
    std::cout << "GPU slices:" << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_slices>() << "\n";
    std::cout << "Subslice per slice:"
              << mQueue.get_device().get_info<sycl::ext::intel::info::device::gpu_subslices_per_slice>() << "\n";
  }
  std::cout << "Global Memory size: " << getGlobalMemSizeGB() << "\n";
}
Expand Down
6 changes: 3 additions & 3 deletions bestla/bestla/sycl/sycl_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ class SGemmCoreSharedB {

using SLM_B_Acc = sycl::local_accessor<TB, 1>;

using AType = typename TA;
using BType = typename TB;
using CType = typename TC;
using AType = TA;
using BType = TB;
using CType = TC;
static auto constexpr NTILE = WgNEle;
static auto constexpr MTILE = WgMEle;
static auto constexpr KTILE = TileK;
Expand Down
223 changes: 150 additions & 73 deletions bestla/bestla/sycl/sycl_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ class WeightS4Trans {
auto sptr = S_d + sg_k / blocksize + g_n * ldb;
auto bptr = B_d + (sg_k + g_n * ldbn) / 2;
auto dbptr = outptr + sg_k + g_n * k;
float tmp[TileK];
int constexpr Unroll = 4;
#pragma unroll
for (int ik = 0; ik < TileK; ik += Unroll) {
Expand Down Expand Up @@ -369,85 +368,163 @@ class WeightS4Trans {
auto B = paramB.B;
auto B_scale = paramB.scale;
int ldb = paramB.ldb;
int constexpr Unroll = 2;
int constexpr SgSize = 16;
int constexpr TileK = 32;
int constexpr GroupK = SgSize * TileK;
sycl::range<1> group{SgSize};
sycl::range<1> problem{n * SgSize};

auto ev = q->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[cl::reqd_work_group_size(
1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<CType, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
if (k % (SgSize * 32 * Unroll) == 0) {
int constexpr TileK = 32;
int constexpr GroupK = SgSize * TileK;
auto ev = q->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<CType, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 = *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
CType tmpAcc = 0.f;
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
CType tmpAcc = 0.f;
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 = *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc +=
CType(aptr[sg_id * TileK + ikk]) * static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc +=
CType(aptr[sg_id * TileK + ikk + 1]) * static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
});
});
return ev;
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc += CType(aptr[sg_id * TileK + ikk]) *
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) *
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
});
});
return ev;
} else {
int constexpr TileK = 8;
int constexpr GroupK = SgSize * TileK;
auto ev = q->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<CType, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
CType tmpAcc = 0.f;
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc += CType(aptr[sg_id * TileK + ikk]) *
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) *
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
});
});
return ev;
}
}
};
} // namespace sycl_prologue_b
Expand Down
4 changes: 2 additions & 2 deletions bestla/bestla/sycl/sycl_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Launcher {
sycl::local_accessor<BType, 1> slm_b(sycl::range(GemmCore::SLM_B_Size), cgh);
cgh.parallel_for(
sycl::nd_range<2>(problem, group),
[=](sycl::nd_item<2> it) [[cl::reqd_work_group_size(
[=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size(
1, GemmCore::WgM,
GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] {
sycl_utils::nd_item_helper<GemmCore> helper(it);
Expand Down Expand Up @@ -153,7 +153,7 @@ class LauncherWOQ {
sycl::local_accessor<BType, 1> slm_b(sycl::range(GemmCore::SLM_B_Size), cgh);
cgh.parallel_for(
sycl::nd_range<2>(problem, group),
[=](sycl::nd_item<2> it) [[cl::reqd_work_group_size(
[=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size(
1, GemmCore::WgM,
GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] {
sycl_utils::nd_item_helper<GemmCore> helper(it);
Expand Down
Loading

0 comments on commit 8ed999d

Please sign in to comment.