linear_int4_kernel for XPU #1130
base: main
Changes from all commits
49436fb
985dd94
f1135fd
2843d7d
f1e8faa
b88702b
13c4aef
3ca48f7
13d9772
97a552e
455c4c8
31e97cf
@@ -0,0 +1,67 @@
#include <ATen/core/op_registration/adaption.h>
#include <ATen/div_rtn.h>
#include <ATen/native/TensorIterator.h>
#include <torch/library.h>

#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/xpu_aten.h>

namespace at::native {
// int4 weight-only linear: input [M, K] x packed uint8 weight [N, K/2] -> output [M, N]
Tensor linear_int4_xpu(
    const Tensor& input,
    const Tensor& weight,
    int qGroupSize,
    const Tensor& weight_scale_zero_point) {
  auto M = input.size(0);
  auto N = weight.size(0);
  auto K = input.size(1);
  TORCH_CHECK(
      input.dtype() == kBFloat16 || input.dtype() == kHalf ||
          input.dtype() == kFloat,
      __func__,
      " : expect input to be either 32-bit or 16-bit float tensor.");

  TORCH_CHECK(
      weight.dtype() == kByte, __func__, " : expect B to be uint8 tensor.");
  TORCH_CHECK(
      weight.is_contiguous(), __func__, " : expect B to be contiguous.");
  TORCH_CHECK(
      weight.size(1) == K / 2,
      __func__,
      " : expect B.size(1) to be K/2, got ",
      weight.size(1));

  TORCH_CHECK(
      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
          qGroupSize == 256,
      __func__,
      ": expect qGroupSize to be 32, 64, 128 or 256, got ",
      qGroupSize);

  TORCH_CHECK(
      weight_scale_zero_point.dim() == 3 &&
          weight_scale_zero_point.size(1) == N &&
          weight_scale_zero_point.size(2) == 2,
      __func__,
      ": expect weight_scale_zero_point to be 3d tensor with sizes [:, ",
      N,
      ", 2]");

  std::optional<Device> common_device = std::nullopt;
  c10::impl::check_and_update_common_device(
      common_device, input, "xpu::linear_int4", "input");
  c10::impl::check_and_update_common_device(
      common_device, weight, "xpu::linear_int4", "weight");
  c10::impl::check_and_update_common_device(
      common_device,
      weight_scale_zero_point,
      "xpu::linear_int4",
      "weight_scale_zero_point");
  // Return by value: `output` is a function-local tensor.
  Tensor output = at::empty({M, N}, input.options());

  at::native::xpu::linear_int4_kernel(
      input, weight, qGroupSize, weight_scale_zero_point, output);
  return output;
}
} // namespace at::native
@@ -0,0 +1,167 @@
#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/SYCLContext.h>

namespace at::native::xpu {

template <typename scalar_t = at::Half, int block_size = 16>
struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
  LinearInt4KernelFunctor(
      const scalar_t* A,
      const uint32_t* B,
      scalar_t* C,
      const scalar_t* B_scale,
      const scalar_t* B_zero_point,
      int m,
      int n,
      int k,
      int lda,
      int ldb,
      int ldc)
      : A(A),
        B(B),
        C(C),
        B_scale(B_scale),
        B_zero_point(B_zero_point),
        m(m),
        n(n),
        k(k),
        lda(lda),
        ldb(ldb),
        ldc(ldc) {}
  void sycl_ker_config_convention(sycl::handler& cgh) {
    // local_scan_ = sycl_local_acc_t<T>(N_, cgh);
  }

  void operator()(sycl::nd_item<1> item) const {
    int constexpr Unroll = 2;
    int constexpr SgSize = 16;
    int constexpr TileK = 32;
    int constexpr GroupK = SgSize * TileK;
    int constexpr blocksize = 16;

    int g_idx = item.get_group(0);
    auto sg = item.get_sub_group();
    int sg_id = sg.get_local_id()[0];
    int g_n = g_idx;
    auto sptr = B_scale + g_n * ldb;
    auto bptr = B + g_n * k / 2;
    auto aptr = A;
    auto cptr = C + g_n;
    if constexpr (std::is_same_v<scalar_t, sycl::half>) {
      sycl::half2 tmpAcc = {0.f, 0.f};
      uint8_t tmps8[TileK / 2];
      for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
          scalar_t scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
            sycl::half2 tmpB = {
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
            tmpAcc += tmpA * tmpB * scale;
Comment on lines +60 to +65:

Is it possible to do a vectorized load and shift with SYCL? I don't know.
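For reference, a rough sketch (not from this PR) of how the nibble unpacking could be expressed with sycl::vec element-wise operators, assuming TileK / 2 == 16 packed bytes per work-item; whether this compiles to anything better than the scalar loop is untested:

// Sketch only: sycl::vec defines element-wise bitwise AND and shift operators,
// so the low/high nibbles can be extracted as vector operations.
using u8x16 = sycl::vec<uint8_t, 16>;
u8x16 packed = *reinterpret_cast<const u8x16*>(bptr + sg_id * TileK / 2);
u8x16 lo = packed & (uint8_t)0x0f; // low nibbles
u8x16 hi = packed >> (uint8_t)4;   // high nibbles
// convert to signed and subtract the int4 zero offset of 8
auto lo_s = lo.convert<int8_t>() - (int8_t)8;
auto hi_s = hi.convert<int8_t>() - (int8_t)8;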
          }
          sptr += GroupK / blocksize;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      sycl::half2 sum = {0.f, 0.f};
      for (int i = 0; i < SgSize; i += 1) {
        sum += sg.shuffle(tmpAcc, i);
      }
      if (sg_id == 0) {
        *cptr = sum[0] + sum[1];
      }
    } else {
      scalar_t tmpAcc = 0.f;
You need to be VERY careful about the acc type. A slight difference from CUDA may lead to accuracy errors that are very, very difficult to debug in a final e2e model, especially in an LLM.
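A minimal sketch of the direction this comment points in (an assumption on my part, not code from this PR): keep the per-work-item accumulator in float regardless of scalar_t and down-cast once when storing:

// Sketch: accumulate in float even for half/bfloat16 inputs.
float tmpAcc = 0.f;
// ... inner loops as below, with the operands promoted to float, e.g.
//   tmpAcc += static_cast<float>(aptr[sg_id * TileK + ikk]) *
//       static_cast<float>(static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8)) *
//       static_cast<float>(scale);
// ... and a single down-cast when the reduced result is written:
//   if (sg_id == 0) *cptr = static_cast<scalar_t>(sum);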
      int constexpr Unroll = 2;
      for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
          scalar_t scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) *
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
            tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) *
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
          }
          sptr += GroupK / blocksize;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      float sum = 0.f;
      for (int i = 0; i < SgSize; i += 1) {
        sum += sg.shuffle(tmpAcc, i);
      }
      if (sg_id == 0) {
        *cptr = sum;
      }
    }
  }

 private:
  const scalar_t* A;
  const uint32_t* B;
  scalar_t* C;
  const scalar_t* B_scale;
  const scalar_t* B_zero_point;
  int m;
  int n;
  int k;
  int lda;
  int ldb;
  int ldc;
};

void linear_int4_kernel(
The naming of this func should try to blend in with the style of pytorch; if someone from Meta or Nvidia wrote this, what would they do?
    const Tensor& input,
    const Tensor& weight,
    int qGroupSize,
    const Tensor& weight_scale_zero_point,
    Tensor& output) {
  auto& sycl_queue = at::xpu::getCurrentSYCLQueue();
  int64_t m = input.size(0);
  int64_t n = output.size(1); // N: number of output features
  int64_t k = input.size(1); // K: reduction dimension
  int constexpr Unroll = 2;
Unused variable. Also, is there a routine for format checking in this project?
  int constexpr SgSize = 16;
  sycl::range<1> local_range{SgSize};
  sycl::range<1> global_range{static_cast<size_t>(n) * SgSize};
  int lda = k;
  int ldb = n;
  int ldc = n;
Comment on lines +139 to +141:

Basically you need to add a bunch of checks here. You can take the CPU counterpart as an example: https://github.com/pytorch/pytorch/blob/795f28ac552eb61d02ea02fd64637ba814133bd8/aten/src/ATen/native/LinearAlgebra.cpp#L3457. One assumption made here is that the tensors are contiguous, which may not always be true. I know this won't cause any trouble 99% of the time without any checks, but it is going to save you tons of trouble if it is actually used in the wild. If we don't think ahead, the troubles that we bury here and there will eventually become a disaster.
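A sketch of the kind of checks being asked for, loosely modeled on the CPU counterpart linked above; the exact conditions and messages are illustrative, not part of this PR:

// Illustrative checks only; adjust to the final packing/layout contract.
TORCH_CHECK(
    input.dim() == 2 && weight.dim() == 2,
    "linear_int4_kernel: expect 2D input and weight");
TORCH_CHECK(
    input.is_contiguous(), "linear_int4_kernel: expect input to be contiguous");
TORCH_CHECK(
    weight.is_contiguous(), "linear_int4_kernel: expect weight to be contiguous");
TORCH_CHECK(
    weight_scale_zero_point.is_contiguous(),
    "linear_int4_kernel: expect weight_scale_zero_point to be contiguous");
TORCH_CHECK(
    input.size(1) % qGroupSize == 0,
    "linear_int4_kernel: expect K to be divisible by qGroupSize");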

  if (input.scalar_type() == at::kHalf) {
What if the input is not float16? Reporting a runtime error is a doable approach, or a better routine is to do the following:
    using scalar_t = at::Half;
This is not LIKE a pytorch coding style; aten uses
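A minimal sketch of what the dispatch could look like, assuming the comment is referring to ATen's AT_DISPATCH macros rather than a hand-written if/else on scalar_type(); the name string is illustrative:

// Sketch: the macro defines scalar_t for each dispatched dtype.
AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    input.scalar_type(),
    "linear_int4_xpu",
    [&]() {
      const scalar_t* input_data = input.data_ptr<scalar_t>();
      scalar_t* output_data = output.data_ptr<scalar_t>();
      // ... build LinearInt4KernelFunctor<scalar_t, 16> and submit as below
    });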
    // const auto scalar_t = input.scalar_type();
    const scalar_t* input_data = input.data_ptr<scalar_t>();
    // weight is checked to be a uint8 tensor by the caller; reinterpret the
    // packed data as 32-bit words to match the kernel's B pointer type.
    uint32_t* weight_data =
        reinterpret_cast<uint32_t*>(weight.data_ptr<uint8_t>()); // int4x8

    scalar_t* output_data = output.data_ptr<scalar_t>();
    scalar_t* weight_scale_data = weight_scale_zero_point.data_ptr<scalar_t>();
    LinearInt4KernelFunctor<scalar_t, 16> kfn(
        input_data,
        weight_data,
        output_data,
        weight_scale_data,
        nullptr,
        m,
        n,
        k,
        k,
        n,
        n);
Comment on lines +156 to +161:

I think having lda, ldb, ldc in
    sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
  }
}

} // namespace at::native::xpu
@@ -0,0 +1,13 @@
#pragma once
#include <comm/xpu_aten.h>

namespace at::native::xpu {

TORCH_XPU_API void linear_int4_kernel(
    const Tensor& input,
    const Tensor& weight,
    int qGroupSize,
    const Tensor& weight_scale_zero_point,
    Tensor& output);

} // namespace at::native::xpu
Is it safe to use half as the acc type? Usually, the acc type for both float16 and bfloat16 is float32.
ref: https://github.com/pytorch/pytorch/blob/795f28ac552eb61d02ea02fd64637ba814133bd8/aten/src/ATen/native/cuda/int4mm.cu#L727
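A small sketch of the accumulation-type convention used in ATen, via the opmath_type trait (which maps Half and BFloat16 to float); how it would be threaded through this kernel is left open:

#include <ATen/OpMathType.h>

// opmath_t<scalar_t> is float when scalar_t is at::Half or at::BFloat16.
template <typename scalar_t>
using opmath_t = at::opmath_type<scalar_t>;

// e.g. inside the kernel functor:
//   opmath_t<scalar_t> acc = 0;
//   acc += static_cast<opmath_t<scalar_t>>(a) * static_cast<opmath_t<scalar_t>>(w);
//   *cptr = static_cast<scalar_t>(acc);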