-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
linear_int4_kernel for XPU #1130
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the biggest question should be why we need post op fusion here? does pytorch have it with cuda?
int64_t m = input.size(0); | ||
int64_t n = input.size(1); | ||
int64_t k = output.size(1); | ||
int constexpr Unroll = 2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unused variable for Unroll
.
and is there routine for gramma formating in this project:
constepxr int
or int constexpr
?
int lda = k; | ||
int ldb = n; | ||
int ldc = n; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
basically you need to add a bunch of checks here, including:
- whether the input(s) are contiguous
- dtype checks
- shape checks
you can take example from cpu counterpart: https://github.com/pytorch/pytorch/blob/795f28ac552eb61d02ea02fd64637ba814133bd8/aten/src/ATen/native/LinearAlgebra.cpp#L3457
one assumption that you can get
int lda = k;
int ldb = n;
int ldc = n;
is that they are contiguous, which may not always be true.
i know this won't cause any trouble 99% of the times without any checks. But this is going to save you tons of trouble if it is actually used in the wild world. if we don't think ahead, the troubles that we buried here and there would finally be disaster.
int lda = k; | ||
int ldb = n; | ||
int ldc = n; | ||
if (input.scalar_type() == at::kHalf) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what if the input are not float16, report an runtime error is a doable approach. Or a better rountine is to do the following:
// TODO: this kernels supports only xxx dtypes for now, we will support yyy, zzz later on.
TORCH_CHECK(input.scalar_type() == kHalf, "linear_int4_kernel: expect input to be Half tensors!")
int ldb = n; | ||
int ldc = n; | ||
if (input.scalar_type() == at::kHalf) { | ||
using scalar_t = at::Half; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not LIKE a pytorch coding stype, aten uses AT_DISPATCH_xxx
to do this job, I assume you plan to do both float16 and blfoat16? in that case, the one you need is AT_DISPATCH_REDUCED_FLOATING_TYPES
if (input.scalar_type() == at::kHalf) { | ||
using scalar_t = at::Half; | ||
// const auto scalar_t = input.scalar_type(); | ||
scalar_t* input_data = input.data_ptr<scalar_t>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using const scalar_t* input_data = input.data_ptr<scalar_t>()';
would be better practice.
auto aptr = A; | ||
auto cptr = C + g_n; | ||
if constexpr (std::is_same_v<scalar_t, sycl::half>) { | ||
sycl::half2 tmpAcc = {0.f, 0.f}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it safe to use half as acc type?
usually, the acc type for both float16 and bfloat16 are float32
*cptr = sum[0] + sum[1]; | ||
} | ||
} else { | ||
scalar_t tmpAcc = 0.f; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you need to be VERY careful about the acc type. Slight difference between CUDA may lead to accuracy errors that are very very difficult to debug in a finla e2e model, especially in LLM
for (int ikk = 0; ikk < TileK; ikk += 2) { | ||
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; | ||
sycl::half2 tmpB = { | ||
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8), | ||
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)}; | ||
tmpAcc += tmpA * tmpB * scale; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it possible to do vectorized load and shift with sycl? i don't know.
if not, i guess this is best perf that we can get so far. this line should be the major bottlenecks.
zero_points = torch.Tensor([8]).to(torch.int8).to("xpu") | ||
weight_ba = weight.transpose(0, 1).contiguous() | ||
|
||
out_onednn =torch._weight_int4pack_mm_with_scales_and_zeros( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a more general question is that where are we placing _weight_int4pack_mm_with_scales_and_zeros
, pytorch does not have this right now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will be added in pytorch/pytorch#137566
) | ||
|
||
# check gemm + bias + gelu | ||
out_onednn_gelu = torch._weight_int4pack_mm_with_scales_and_zeros( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where was the signature with "tanh" defined?
does pytorch has a packed int4 gemm with post op?
@liangan1 CC |
@sunjiweiswift for the perf benchmarking, please include other configs expect M=1. This would serve as a reference of final decision making. I expect that big M would have worse perf, but that's fine, we still need to know the numbers. |
Pure SYCL path for. int4 gemm
Benchmark results on PVC-1100. The remaining gaps are lack of usage of 2D load.
Besides PVC, the kernel can achieve
92.7% bandwidth usage on MTL
84.7% bandwidth usage on A750