linear_int4_kernel for XPU #1130

Open
wants to merge 12 commits into base: main
67 changes: 67 additions & 0 deletions src/ATen/native/xpu/LinearInt4.cpp
@@ -0,0 +1,67 @@

#include <ATen/core/op_registration/adaption.h>
#include <ATen/div_rtn.h>
#include <ATen/native/TensorIterator.h>
#include <torch/library.h>

#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/xpu_aten.h>

namespace at::native {
Tensor linear_int4_xpu(
const Tensor& input,
const Tensor& weight,
int qGroupSize,
const Tensor& weight_scale_zero_point) {
auto M = input.size(0);
auto N = weight.size(0);
auto K = input.size(1);
TORCH_CHECK(
input.dtype() == kBFloat16 || input.dtype() == kHalf ||
input.dtype() == kFloat,
__func__,
" : expect input to be either 32-bit or 16-bit float tensor.");

TORCH_CHECK(
weight.dtype() == kByte, __func__, " : expect B to be uint8 tensor.");
TORCH_CHECK(
weight.is_contiguous(), __func__, " : expect B to be contiguous.");
TORCH_CHECK(
weight.size(1) == K / 2,
__func__,
" : expect B.size(1) to be K/2, got ",
weight.size(1));

TORCH_CHECK(
qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
qGroupSize == 256,
__func__,
": expect qGroupSize to be 32, 64, 128 or 256, got ",
qGroupSize);

TORCH_CHECK(
weight_scale_zero_point.dim() == 3 &&
weight_scale_zero_point.size(1) == N &&
weight_scale_zero_point.size(2) == 2,
__func__,
": expect weight_scale_zero_point to be 3d tensor with sizes [:, ",
N,
", 2]");

std::optional<Device> common_device = std::nullopt;
c10::impl::check_and_update_common_device(
common_device, input, "xpu::linear_int4", "input");
c10::impl::check_and_update_common_device(
common_device, weight, "xpu::linear_int4", "weight");
c10::impl::check_and_update_common_device(
common_device,
weight_scale_zero_point,
"xpu::linear_int4",
"weight_scale_zero_point");
Tensor output = at::empty({M, N}, input.options());

at::native::xpu::linear_int4_kernel(
input, weight, qGroupSize, weight_scale_zero_point, output);
return output;
}
} // namespace at::native
167 changes: 167 additions & 0 deletions src/ATen/native/xpu/sycl/LinearInt4.cpp
@@ -0,0 +1,167 @@
#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/SYCLContext.h>

namespace at::native::xpu {

template <typename scalar_t = at::Half, int block_size = 16>
struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
LinearInt4KernelFunctor(
const scalar_t* A,
const uint32_t* B,
scalar_t* C,
const scalar_t* B_scale,
const scalar_t* B_zero_point,
int m,
int n,
int k,
int lda,
int ldb,
int ldc)
: A(A),
B(B),
C(C),
B_scale(B_scale),
B_zero_point(B_zero_point),
m(m),
n(n),
k(k),
lda(lda),
ldb(ldb),
ldc(ldc) {}
void sycl_ker_config_convention(sycl::handler& cgh) {
// local_scan_ = sycl_local_acc_t<T>(N_, cgh);
}

void operator()(sycl::nd_item<1> item) const {
int constexpr Unroll = 2;
int constexpr SgSize = 16;
int constexpr TileK = 32;
int constexpr GroupK = SgSize * TileK;
int constexpr blocksize = 16;

int g_idx = item.get_group(0);
auto sg = item.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<scalar_t, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};

is it safe to use half as the acc type?

usually, the acc type for both float16 and bfloat16 is float32

ref: https://github.com/pytorch/pytorch/blob/795f28ac552eb61d02ea02fd64637ba814133bd8/aten/src/ATen/native/cuda/int4mm.cu#L727
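For illustration, a rough fragment of float accumulation in this branch could look like the following (it reuses the PR's variable names and loop structure; a sketch, not a drop-in patch):

    // Sketch: keep the half2 loads, but accumulate in float2,
    sycl::float2 tmpAcc = {0.f, 0.f};
    // and do the multiply-accumulate in float inside the ikk loop:
    tmpAcc[0] += static_cast<float>(tmpA[0]) * static_cast<float>(tmpB[0]) *
        static_cast<float>(scale);
    tmpAcc[1] += static_cast<float>(tmpA[1]) * static_cast<float>(tmpB[1]) *
        static_cast<float>(scale);
    // The sub-group reduction then runs on float2, and the result is cast
    // back to scalar_t only once, when it is written to *cptr.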

uint8_t tmps8[TileK / 2];
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
scalar_t scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
Comment on lines +60 to +65

is it possible to do a vectorized load and shift with SYCL? i don't know.
if not, i guess this is the best perf we can get so far. this line is probably the major bottleneck.
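One possibility worth trying (an untested sketch; sycl::vec defines element-wise bitwise operators for integer element types, so at least the nibble unpack can be expressed vector-wide):

    // Sketch: unpack the packed int4 pairs with vector-wide mask/shift instead
    // of per-byte ops. Reuses the existing load; the "- 8" re-centering and the
    // conversion to half would still happen per element.
    using u8vec = sycl::vec<uint8_t, TileK / 2>;
    u8vec packed = *(const u8vec*)(bptr + sg_id * TileK / 2);
    u8vec lo = packed & u8vec(0x0f); // low nibbles
    u8vec hi = packed >> u8vec(4);   // high nibbles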

}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
scalar_t tmpAcc = 0.f;

you need to be VERY careful about the acc type. A slight difference from CUDA may lead to accuracy errors that are very, very difficult to debug in a final e2e model, especially in an LLM.
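For reference, ATen's convention for this is at::opmath_type<scalar_t>, which is float for both Half and BFloat16. A minimal sketch, assuming <ATen/OpMathType.h> is reachable from this kernel:

    // Sketch: pick the accumulator type the ATen way so that Half/BFloat16
    // inputs accumulate in float, matching the CPU and CUDA int4 kernels.
    using acc_t = at::opmath_type<scalar_t>;
    acc_t tmpAcc = acc_t(0);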

int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
scalar_t scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) *
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) *
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
}

private:
const scalar_t* A;
const uint32_t* B;
scalar_t* C;
const scalar_t* B_scale;
const scalar_t* B_zero_point;
int m;
int n;
int k;
int lda;
int ldb;
int ldc;
};

void linear_int4_kernel(

the naming for this func linear_int4_kernel should better follow the current cpu and cuda impl.
would int4pack_mm_sycl be a better name?

try to blend into the style of pytorch: if someone from Meta or Nvidia wrote this, what would they do?
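If the rename were adopted, the declaration in LinearInt4.h could read as follows (purely illustrative; the signature is unchanged from the PR):

    // Hypothetical name only; everything else stays as in the PR.
    TORCH_XPU_API void int4pack_mm_sycl(
        const Tensor& input,
        const Tensor& weight,
        int qGroupSize,
        const Tensor& weight_scale_zero_point,
        Tensor& output);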

const Tensor& input,
const Tensor& weight,
int qGroupSize,
const Tensor& weight_scale_zero_point,
Tensor& output) {
auto& sycl_queue = at::xpu::getCurrentSYCLQueue();
int64_t m = input.size(0);
int64_t n = output.size(1);
int64_t k = input.size(1);
int constexpr Unroll = 2;

unused variable Unroll.

also, is there a convention for code formatting in this project:
constexpr int or int constexpr?

int constexpr SgSize = 16;
sycl::range<1> local_range{SgSize};
sycl::range<1> global_range{static_cast<size_t>(n) * SgSize};
int lda = k;
int ldb = n;
int ldc = n;
Comment on lines +139 to +141

basically you need to add a bunch of checks here, including:

  • whether the input(s) are contiguous
  • dtype checks
  • shape checks

you can take the cpu counterpart as an example: https://github.com/pytorch/pytorch/blob/795f28ac552eb61d02ea02fd64637ba814133bd8/aten/src/ATen/native/LinearAlgebra.cpp#L3457

one assumption baked into

  int lda = k;
  int ldb = n;
  int ldc = n;

is that the tensors are contiguous, which may not always be true.

i know this won't cause any trouble 99% of the time without any checks. But it is going to save you tons of trouble once this is actually used in the wild. if we don't think ahead, the troubles we bury here and there will eventually add up to a disaster.
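A hedged sketch of what such checks could look like at the top of linear_int4_kernel (illustrative only; the exact messages and shape relations are the author's call, following the wrapper's M/N/K convention):

    // Sketch: basic contiguity / rank / packing checks before launching.
    TORCH_CHECK(input.is_contiguous(), "linear_int4_kernel: input must be contiguous");
    TORCH_CHECK(weight.is_contiguous(), "linear_int4_kernel: weight must be contiguous");
    TORCH_CHECK(output.is_contiguous(), "linear_int4_kernel: output must be contiguous");
    TORCH_CHECK(input.dim() == 2 && weight.dim() == 2,
        "linear_int4_kernel: expect 2d input and weight");
    TORCH_CHECK(weight.size(1) * 2 == input.size(1),
        "linear_int4_kernel: expect weight to pack two int4 values per byte along K");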

if (input.scalar_type() == at::kHalf) {

what if the input is not float16? reporting a runtime error is a doable approach. Or a better routine is to do the following:

// TODO: this kernel supports only xxx dtypes for now, we will support yyy, zzz later on.
TORCH_CHECK(input.scalar_type() == kHalf, "linear_int4_kernel: expect input to be Half tensors!")

using scalar_t = at::Half;

this is not a pytorch-like coding style; aten uses AT_DISPATCH_xxx to do this job. I assume you plan to support both float16 and bfloat16? in that case, the one you need is AT_DISPATCH_REDUCED_FLOATING_TYPES
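A minimal sketch of that dispatch, assuming LinearInt4KernelFunctor also instantiates cleanly for at::BFloat16:

    // Sketch: dispatch over Half and BFloat16 instead of hard-coding at::Half.
    AT_DISPATCH_REDUCED_FLOATING_TYPES(
        input.scalar_type(), "linear_int4_kernel_xpu", [&] {
          const scalar_t* input_data = input.data_ptr<scalar_t>();
          scalar_t* output_data = output.data_ptr<scalar_t>();
          scalar_t* weight_scale_data =
              weight_scale_zero_point.data_ptr<scalar_t>();
          // ... construct LinearInt4KernelFunctor<scalar_t, 16> and submit
          // exactly as the code below does for at::Half ...
        });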

// const auto scalar_t = input.scalar_type();
const scalar_t* input_data = input.data_ptr<scalar_t>();
uint32_t* weight_data =
    reinterpret_cast<uint32_t*>(weight.data_ptr<uint8_t>()); // int4x8, weight is a kByte tensor

scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* weight_scale_data = weight_scale_zero_point.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);
Comment on lines +156 to +161

i think having lda, ldb, ldc in LinearInt4KernelFunctor is useless; normally the input(s) should be contiguous (and you can just report a runtime error if they are not). So ldx would never take unexpected values other than n or k, since you are not dealing with strided tensors.
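For illustration, once contiguity is checked up front the functor could carry just the sizes; a sketch of the slimmed-down constructor (member declarations and the kernel body would shrink accordingly):

    // Sketch: constructor without lda/ldb/ldc; m, n, k fully determine
    // addressing for contiguous row-major tensors.
    LinearInt4KernelFunctor(
        const scalar_t* A,
        const uint32_t* B,
        scalar_t* C,
        const scalar_t* B_scale,
        const scalar_t* B_zero_point,
        int m,
        int n,
        int k)
        : A(A), B(B), C(C), B_scale(B_scale), B_zero_point(B_zero_point),
          m(m), n(n), k(k) {}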


sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
}
}

} // namespace at::native::xpu
13 changes: 13 additions & 0 deletions src/ATen/native/xpu/sycl/LinearInt4.h
@@ -0,0 +1,13 @@
#pragma once
#include <comm/xpu_aten.h>

namespace at::native::xpu {

TORCH_XPU_API void linear_int4_kernel(
const Tensor& input,
const Tensor& weight,
int qGroupSize,
const Tensor& weight_scale_zero_point,
Tensor& output);

} // namespace at::native::xpu