This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

XeTLA INT4 With BF16 Support (#311)
* int4 with bf16 support

* rename quantmode
DDEle authored Jul 25, 2024
1 parent b733d92 commit 9500cb2
Showing 9 changed files with 65 additions and 42 deletions.
2 changes: 1 addition & 1 deletion include/common/core/common_types.hpp
@@ -27,7 +27,7 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };

enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };

enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 };

struct quant_info {
quant_mode quant_mode;
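Note on the rename (context, not part of the diff): the asymmetric mode keeps a packed zero-point tensor, while the symmetric mode simply re-centers the stored 4-bit value around 8 and needs no zero points. A minimal host-side sketch of the reference math, with illustrative names rather than the XeTLA device code:

// --- Illustrative sketch, not part of this commit ---
// Reference dequantization for the two renamed modes:
//   I4_ASYM: value = scale * (q - zero_point)
//   I4_SYM : value = scale * (q - 8), no zero-point tensor needed.
#include <cstdint>
#include <cstdio>

enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 };

template <quant_mode mode>
float dequant_int4(uint8_t q /* 0..15 */, float scale, int8_t zero_pt = 8) {
  if constexpr (mode == quant_mode::I4_ASYM) {
    return scale * float(int8_t(q) - zero_pt);
  } else {
    return scale * float(int8_t(q) - 8);
  }
}

int main() {
  std::printf("%f\n", dequant_int4<quant_mode::I4_SYM>(0x3, 0.5f));     // -2.5
  std::printf("%f\n", dequant_int4<quant_mode::I4_ASYM>(0x3, 0.5f, 5)); // -1.0
  return 0;
}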
13 changes: 13 additions & 0 deletions include/common/core/explicit_conv.hpp
@@ -62,6 +62,19 @@ xetla_cvt(xetla_vector<T_src, N> src) {
return dst;
}

/// @brief xetla explicit data conversion, bf16->fp16.
/// @tparam T_dst is the float16 data type.
/// @tparam T_src is the bfloat16 data type.
/// @tparam N is the element number in xetla_vector.
template <typename T_dst, typename T_src, int N>
__XETLA_API typename std::enable_if_t<
std::is_same<T_dst, fp16>::value && std::is_same<T_src, bf16>::value,
xetla_vector<T_dst, N>>
xetla_cvt(xetla_vector<T_src, N> src) {
xetla_vector<T_dst, N> dst = src;
return dst;
}

/// @brief xetla explicit data conversion, fp32->bf16.
/// @tparam T_dst is the bfloat16 data type.
/// @tparam T_src is the float32 data type.
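For context (an illustration, not XeTLA code): the new overload lets bf16 data feed an fp16 path; numerically it widens bf16 to fp32 and narrows to fp16. A host-side sketch of that round-trip, assuming the _Float16 compiler extension (GCC/Clang):

// --- Illustrative sketch, not part of this commit ---
// bf16 is the high half of an IEEE-754 fp32, so bf16->fp16 amounts to
// "widen to fp32, then round to fp16". The kernel does this on xetla_vector
// types; this sketch shows the scalar math on the host.
#include <cstdint>
#include <cstring>
#include <cstdio>

float bf16_to_fp32(uint16_t b) {
  uint32_t bits = uint32_t(b) << 16; // place bf16 payload in the top 16 bits
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  uint16_t x_bf16 = 0x3FC0;                                   // 1.5 in bf16
  _Float16 x_fp16 = static_cast<_Float16>(bf16_to_fp32(x_bf16)); // widen, then narrow
  std::printf("%f\n", float(x_fp16));                         // prints 1.500000
  return 0;
}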
18 changes: 6 additions & 12 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
@@ -520,8 +520,7 @@ class gemm_t<
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
scale_prefetch_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
zero_pt_prefetch_payload);
@@ -534,8 +533,7 @@ class gemm_t<
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
scale_prefetch_payload.template update_tdesc<update_dir_b>(
scale_t::tile_size_y);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
zero_pt_prefetch_payload
.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
@@ -564,8 +562,7 @@ class gemm_t<
// matB, matB_payload);
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
scale, scale_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
zero_pt, zero_pt_payload);
}
@@ -579,8 +576,7 @@ class gemm_t<
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
scale_prefetch_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
zero_pt_prefetch_payload);
@@ -593,8 +589,7 @@ class gemm_t<
if (tile_k_idx % scale_addr_update_freq == 0) {
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
}
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
@@ -608,8 +603,7 @@ class gemm_t<
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
scale_t::tile_size_y);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
zero_pt_prefetch_payload
.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
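The recurring edit in this file is the same compile-time guard: every zero-point prefetch, load, and tdesc update is wrapped in if constexpr, so under I4_SYM the zero-point traffic disappears from the generated kernel. A self-contained toy version of the pattern, with illustrative names rather than the real payload types:

// --- Illustrative sketch, not part of this commit ---
// Zero-point handling only exists for I4_ASYM; for I4_SYM the branch is
// discarded at compile time, not skipped at runtime.
#include <cstdint>
#include <cstdio>

enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 };

template <quant_mode mode>
void load_weight_metadata() {
  std::puts("load scales");
  if constexpr (mode != quant_mode::I4_SYM) {
    std::puts("load zero points"); // compiled only for I4_ASYM
  }
}

int main() {
  load_weight_metadata<quant_mode::I4_ASYM>(); // scales + zero points
  load_weight_metadata<quant_mode::I4_SYM>();  // scales only
  return 0;
}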
@@ -159,7 +159,7 @@ class gemm_universal_t<
/// @brief GEMM arguments.
/// This is the interface for users to pass the application-related runtime
/// variables.
template <quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP>
template <quant_mode quant_mode = quant_mode::I4_SYM>
struct arguments_t {
/// @brief Is the size of the m dimension of the matrix multiplication (m x
/// k x n).
@@ -295,7 +295,7 @@ class gemm_universal_t<
}
};
template <>
struct arguments_t<quant_mode::S4_FULLRANGE_NO_ZP> {
struct arguments_t<quant_mode::I4_SYM> {
/// @brief Is the size of the m dimension of the matrix multiplication (m x
/// k x n).
uint32_t matrix_m;
@@ -526,6 +526,10 @@ class gemm_universal_t<
template <quant_mode quant_mode>
static bool can_implement(arguments_t<quant_mode>& args) {
bool implementable = true;
if (arch_tag == gpu_arch::XeLpg) {
implementable &= !std::is_same_v<dtype_a, bf16>; // XeLpg arch doesn't
// have bf16-related ISA.
}
if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
if (gemm_t::msg_type_a == msg_type::block_2d) {
implementable &= kernel::block_2d<arch_tag, dtype_a>::check_tensor(
@@ -566,8 +570,7 @@ class gemm_universal_t<
// check for int4x2
implementable &=
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
if constexpr (
gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (gemm_t::compute_policy::quant_mode != quant_mode::I4_SYM) {
implementable &= (args.zero_pt_ld % pack_ratio == 0);
}

@@ -664,8 +667,7 @@ class gemm_universal_t<
uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride;
uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
gemm_args_t gemm_args;
if constexpr (
gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_SYM) {
gemm_args = gemm_args_t(
mem_desc_a,
mem_desc_b,
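Background for the % pack_ratio checks above (an illustration, not library code): 4-bit weights are stored several to a carrier word, so the weight leading dimension and matrix_n must be multiples of the pack ratio or a carrier word would straddle a row/column boundary. A sketch assuming a uint32_t carrier, i.e. a pack ratio of 8 as with int4x8:

// --- Illustrative sketch, not part of this commit ---
// Eight 4-bit weights share one uint32_t; element i occupies nibble i % 8.
// Dimensions that are not multiples of pack_ratio cannot be addressed in
// whole carrier words, which is what can_implement() rejects above.
#include <cstdint>
#include <cstdio>

constexpr uint32_t pack_ratio = 32 / 4; // nibbles per uint32_t carrier

uint8_t unpack_nibble(uint32_t word, uint32_t i) {
  return (word >> (4 * (i % pack_ratio))) & 0xF;
}

int main() {
  uint32_t word = 0x87654321u; // nibbles 1..8, lowest nibble first
  for (uint32_t i = 0; i < pack_ratio; ++i) {
    std::printf("%u ", unsigned(unpack_nibble(word, i))); // prints: 1 2 3 4 5 6 7 8
  }
  std::printf("\n");
  return 0;
}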
4 changes: 2 additions & 2 deletions include/subgroup/tile/impl/tile_op_functor.hpp
@@ -130,7 +130,7 @@ struct dequant_int4_weight_t {
(offset_y_in_tile) / dequant_s * scale_t::block_size_x +
offset_x_in_tile;

if constexpr (quant_mode == quant_mode::S4_ASYM) {
if constexpr (quant_mode == quant_mode::I4_ASYM) {
uint32_t zero_pt_idx =
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
offset_x_in_tile / pack_ratio;
@@ -149,7 +149,7 @@ struct dequant_int4_weight_t {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
zero_pt_i8;
} else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
} else if constexpr (quant_mode == quant_mode::I4_SYM) {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
int8_t(8);
5 changes: 3 additions & 2 deletions tests/integration/gemm/int4_dequantization/main.cpp
@@ -229,8 +229,9 @@ void dequantize_gemm_run(uint32_t iter) {
compute_attr_t<data_type_acc_in, data_type_acc_in, data_type_acc>;
using perf_tuning_knob = xetla::group::
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;

static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b};

static constexpr quant_info quant_info{
quant_mode::I4_ASYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
@@ -622,7 +622,7 @@ void dequantize_gemm_run(int iter) {
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;

static constexpr quant_info quant_info{
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
quant_mode::I4_SYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
@@ -1043,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd);
INSTANTIATE_TYPED_TEST_SUITE_P(
dequantize_gemm_act_shuf_test_suite,
dequantize_gemm_act_shuf_test,
tests);
tests);
@@ -388,7 +388,7 @@ void dequantize_gemm_run(int iter) {
using perf_tuning_knob = xetla::group::
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
static constexpr quant_info quant_info{
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
quant_mode::I4_SYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
45 changes: 29 additions & 16 deletions tests/integration/gemv/int4/main.cpp
@@ -27,6 +27,7 @@ constexpr int ITER = 200;
#endif
constexpr size_t UNDEFINED_DATA_SIZE = 1024;

template <typename scalar_t>
class test_col_major_1 {
public:
// Extract the parameters required by different test cases
@@ -39,18 +40,18 @@ class test_col_major_1 {
static constexpr size_t sg_n = 1;
static constexpr size_t sg_k = 512 / sg_m;
static constexpr size_t dequant_s = 128;
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
// static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;

static constexpr size_t local_kslicing = 1;
static constexpr size_t global_kslicing = 1;
static constexpr mem_layout layout_a = mem_layout::row_major;
static constexpr mem_layout layout_b = mem_layout::col_major;
static constexpr mma_engine mma_eng = mma_engine::fpu;
static constexpr gpu_arch arch = gpu_arch::XeLpg;
using data_type_a = fp16;
using data_type_a = scalar_t;
using data_type_b = int4x8;
using data_type_c = fp16;
using data_type_c = scalar_t;
};
class test_col_major_2 {
public:
@@ -120,7 +121,7 @@ int gemm_result_validate(
}

template <
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
quant_mode quant_mode = quant_mode::I4_SYM,
typename data_type_acc_in = fp16,
typename data_type_b,
typename data_type_scale,
@@ -134,7 +135,7 @@ std::vector<fp16> convert_int4(
int8_t zero_pt_i8 = zero_pt & 0xf;
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
int8_t dequant_8bit = data_b & 0xf;
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (quant_mode == quant_mode::I4_SYM) {
dequant_fp16[i] = scale * (dequant_8bit - 8);
} else {
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
@@ -147,7 +148,7 @@ template <
template <
size_t dequant_s,
mem_layout layout_b = mem_layout::col_major,
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
quant_mode quant_mode = quant_mode::I4_SYM,
typename data_type_acc_in = fp16,
typename data_type_b,
typename data_type_scale,
@@ -173,11 +174,11 @@ std::vector<data_type_acc_in> dequantize_weight(
(j / step) * (matrix_n / pack_radio) + i / pack_radio;
int start_out =
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
zp_value = zp_value >> (4 * (i % pack_radio));
for (uint32_t jj = 0; jj < step; jj++) {
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
b[start_b_in + jj],
scale[start_scale_in],
zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
b[start_b_in + jj], scale[start_scale_in], zp_value);
for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) {
b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
}
@@ -474,7 +475,7 @@ void dequantize_gemv_run(int iter) {
// It accepts the base pointer to matrix D, and its dimensions
{bias_d, bias_add_shape}});
typename gemm_op_t::template arguments_t<compute_policy::quant_mode> gemm_arg;
if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode == quant_mode::I4_SYM) {
gemm_arg =
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
matrix_m,
@@ -491,7 +492,7 @@ void dequantize_gemv_run(int iter) {
Acc_d,
Cnt_d,
epilogue_args);
} else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
} else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) {
gemm_arg =
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
matrix_m,
@@ -551,9 +552,11 @@ void dequantize_gemv_run(int iter) {
// performance
prof.print_profiling_result(profiling_selector::GPU);
// check result
std::vector<typename Test::data_type_a> dequantize_b =
dequantize_weight<dequant_s, layout_b, compute_policy::quant_mode>(
matrix_k, matrix_n, B_h, scale_h, zero_pt_h);
std::vector<typename Test::data_type_a> dequantize_b = dequantize_weight<
dequant_s,
layout_b,
compute_policy::quant_mode,
data_type_c>(matrix_k, matrix_n, B_h, scale_h, zero_pt_h);

queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait();
ASSERT_EQ(
@@ -585,6 +588,12 @@ void dequantize_gemv_run(int iter) {
free(Cnt_d, context);
}

// Placeholder for void test param
template <>
void dequantize_gemv_run<void>(int) {
GTEST_SKIP();
}

template <typename T>
class dequantize_gemv_test : public ::testing::Test {};
TYPED_TEST_SUITE_P(dequantize_gemv_test);
@@ -594,7 +603,11 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) {
}

REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd);
using tests = ::testing::Types<test_col_major_1>;
using tests = ::testing::Types< //
test_col_major_1<fp16>,
test_col_major_1<bf16>,
// test_col_major_2,
void>;

INSTANTIATE_TYPED_TEST_SUITE_P(
dequantize_gemv_test_suite,
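The void entry in the type list above, paired with the GTEST_SKIP() specialization, keeps ::testing::Types<...> well-formed while some configurations (e.g. test_col_major_2) stay commented out. A minimal standalone sketch of the same pattern, with made-up test names:

// --- Illustrative sketch, not part of this commit ---
// A trailing `void` placeholder keeps the typed-test type list valid when
// real config types are disabled; its specialization just skips the test.
#include <gtest/gtest.h>

template <typename T>
void run_case() {
  // real test body for a config type T
  EXPECT_TRUE(true);
}

template <>
void run_case<void>() {
  GTEST_SKIP(); // placeholder entry: nothing to run
}

template <typename T>
class example_typed_test : public ::testing::Test {};
TYPED_TEST_SUITE_P(example_typed_test);

TYPED_TEST_P(example_typed_test, runs) {
  run_case<TypeParam>();
}

REGISTER_TYPED_TEST_SUITE_P(example_typed_test, runs);
using example_types = ::testing::Types<int, /* disabled configs, */ void>;
INSTANTIATE_TYPED_TEST_SUITE_P(example_suite, example_typed_test, example_types);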
