From b0efdf40ba60ac4821e2648c53c4baac1b2299ee Mon Sep 17 00:00:00 2001
From: "Wang, Zhe"
Date: Tue, 27 Aug 2024 07:36:52 +0000
Subject: [PATCH] Int4 dequantize kernel (#313)

* add int4 dequantize kernel

* Sync ipex

* xetla int4 dequantize kernel remove sw barrier

---------

Co-authored-by: Ding, Yi1
---
 include/common/core/math_general.hpp          |   8 +-
 include/common/utils/memory_descriptor.hpp    |  24 +-
 .../group/gemm/impl/int4_dequantize_xe.hpp    |   3 +-
 .../kernel/int4_dequantize/api.hpp            |  51 +++
 .../kernel/int4_dequantize/config.hpp         |  50 +++
 .../int4_dequantize/int4_dequantize.hpp       |  24 ++
 .../int4_dequantize_xe_impl.hpp               | 285 +++++++++++++
 include/experimental/kernel/kernel.hpp        |   1 +
 include/subgroup/tile/impl/load_xe.hpp        |  34 +-
 include/subgroup/tile/impl/payload_xe.hpp     | 195 ++++++---
 .../subgroup/tile/impl/tile_op_functor.hpp    |   9 +-
 tests/integration/CMakeLists.txt              |   1 +
 .../int4_dequantize/CMakeLists.txt            |   6 +
 tests/integration/int4_dequantize/main.cpp    | 394 ++++++++++++++++++
 14 files changed, 986 insertions(+), 99 deletions(-)
 create mode 100644 include/experimental/kernel/int4_dequantize/api.hpp
 create mode 100644 include/experimental/kernel/int4_dequantize/config.hpp
 create mode 100644 include/experimental/kernel/int4_dequantize/int4_dequantize.hpp
 create mode 100644 include/experimental/kernel/int4_dequantize/int4_dequantize_xe_impl.hpp
 create mode 100644 tests/integration/int4_dequantize/CMakeLists.txt
 create mode 100644 tests/integration/int4_dequantize/main.cpp

diff --git a/include/common/core/math_general.hpp b/include/common/core/math_general.hpp
index 54f4e1a2f..013ac017f 100644
--- a/include/common/core/math_general.hpp
+++ b/include/common/core/math_general.hpp
@@ -460,7 +460,9 @@ __XETLA_API T xetla_rsqrt(T src, Sat sat = {}) {
 template <typename T, int SZ>
 __XETLA_API xetla_vector<T, SZ> xetla_tanh(xetla_vector<T, SZ> src) {
   static_assert(
-      std::is_same<remove_const_t<T>, float>::value, "Only support fp32! ");
+      (std::is_same<remove_const_t<T>, float>::value) ||
+          (std::is_same<remove_const_t<T>, fp16>::value),
+      "Only support fp32 and fp16");
   constexpr uint32_t flag_elems = 8 * 16;
   xetla_vector<T, SZ> ret;
   if constexpr (SZ / flag_elems > 0) {
@@ -502,7 +504,9 @@ __XETLA_API xetla_vector<T, SZ> xetla_tanh(xetla_vector<T, SZ> src) {
 template <typename T>
 __XETLA_API T xetla_tanh(T src) {
   static_assert(
-      std::is_same<remove_const_t<T>, float>::value, "Only support fp32! ");
+      (std::is_same<remove_const_t<T>, float>::value) ||
+          (std::is_same<remove_const_t<T>, fp16>::value),
+      "Only support fp32 and fp16");
   T exp2x = xetla_exp(src * 2.f);
   T ret = (exp2x - 1.f) / (exp2x + 1.f);
   return (src >= 10) ? 1 : ret;

diff --git a/include/common/utils/memory_descriptor.hpp b/include/common/utils/memory_descriptor.hpp
index de2fc4f3c..06b9cb2db 100644
--- a/include/common/utils/memory_descriptor.hpp
+++ b/include/common/utils/memory_descriptor.hpp
@@ -142,20 +142,34 @@ struct mem_base_t {
   }
 };
 
+// Memory descriptor
 template <
     typename dtype_,
     mem_layout layout_,
     mem_space space_,
-    uint32_t alignment_ = 16,
+    uint32_t alignment_ = 8,
+    bool use_mask_ = false,
     int dim_ = 2>
 struct mem_desc_t {};
 
+// Alias of mem_desc_t for data with a non-divisible shape that requires
+// masked primitives to load correctly
+template <
+    typename dtype_,
+    mem_layout layout_,
+    mem_space space_,
+    uint32_t alignment_ = 16,
+    int dim_ = 2>
+using mem_mask_desc_t =
+    mem_desc_t<dtype_, layout_, space_, alignment_, true, dim_>;
+
 template <
     typename dtype_,
     mem_layout layout_,
     mem_space space_,
-    uint32_t alignment_>
-struct mem_desc_t<dtype_, layout_, space_, alignment_, 2> {
+    uint32_t alignment_,
+    bool use_mask_>
+struct mem_desc_t<dtype_, layout_, space_, alignment_, use_mask_, 2> {
   using dtype = dtype_;
   static constexpr mem_layout layout = layout_;
   static constexpr mem_space space = space_;
@@ -165,11 +179,13 @@ struct mem_desc_t {
   static constexpr bool is_col_major = layout == mem_layout::col_major;
   static constexpr bool is_local = space == mem_space::local;
+  static constexpr bool use_mask = use_mask_;
 
   using shape_t = mem_shape_t;
   using coord_t = mem_coord_t;
   using base_t = mem_base_t;
 
-  using this_type_t = mem_desc_t<dtype_, layout_, space_, alignment_>;
+  using this_type_t =
+      mem_desc_t<dtype_, layout_, space_, alignment_, use_mask_>;
 
   inline mem_desc_t() = default;
   inline mem_desc_t(base_t base_, shape_t shape_, coord_t coord_)
diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
index b078301eb..7e566038a 100644
--- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
+++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
@@ -498,8 +498,7 @@ class gemm_t<
     wg_start_m = args.matA_base_desc.coord.y;
     wg_start_n = args.scale_base_desc.coord.x;
     wg_start_k = args.matA_base_desc.coord.x;
-    typename dequantize_t::arguments_t dequantize_args{
-        wg_start_m, wg_start_n, wg_start_k};
+    typename dequantize_t::arguments_t dequantize_args{wg_start_n, wg_start_k};
     dequantize_t dequantize;
 
     xetla_nbarrier_t nbarrier_a;
diff --git a/include/experimental/kernel/int4_dequantize/api.hpp b/include/experimental/kernel/int4_dequantize/api.hpp
new file mode 100644
index 000000000..1ad1b2db1
--- /dev/null
+++ b/include/experimental/kernel/int4_dequantize/api.hpp
@@ -0,0 +1,51 @@
+/*******************************************************************************
+ * Copyright (c) 2023-2024 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+/// @file
+/// C++ API
+
+#pragma once
+
+#include
+
+namespace gpu::xetla::kernel {
+
+/// @brief Int4 dequantize kernel functor.
+///
+/// @tparam dtype_qweight_ qweight data type.
+/// @tparam dtype_scale_ scale data type.
+/// @tparam dtype_zp_ zero point data type.
+/// @tparam dtype_dequant_weight_ dequant_weight data type.
+/// @tparam mem_layout_qweight_ qweight memory layout.
+/// @tparam mem_layout_scale_ scale memory layout.
+/// @tparam mem_layout_zp_ zero point memory layout.
+/// @tparam mem_layout_dequant_weight_ dequant_weight memory layout.
+/// @tparam quant_info_ quant_mode, blocksize, qweight_layout info.
+/// @tparam int4_dequantize_attr_ parallel-related attribute.
+/// @tparam arch_ HW architecture.
+template <
+    typename dtype_qweight_,
+    typename dtype_scale_,
+    typename dtype_zp_,
+    typename dtype_dequant_weight_,
+    mem_layout mem_layout_qweight_,
+    mem_layout mem_layout_scale_,
+    mem_layout mem_layout_zp_,
+    mem_layout mem_layout_dequant_weight_,
+    quant_info quant_info_,
+    typename int4_dequantize_attr_,
+    gpu_arch arch_,
+    typename enable = void>
+struct int4_dequantize_t {};
+
+} // namespace gpu::xetla::kernel
diff --git a/include/experimental/kernel/int4_dequantize/config.hpp b/include/experimental/kernel/int4_dequantize/config.hpp
new file mode 100644
index 000000000..819d546ee
--- /dev/null
+++ b/include/experimental/kernel/int4_dequantize/config.hpp
@@ -0,0 +1,50 @@
+/*******************************************************************************
+ * Copyright (c) 2023-2024 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+/// @file
+/// C++ API
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace gpu::xetla::kernel {
+
+/// @brief Sets up attribute of the int4 dequantize.
+///
+/// @tparam wg_tile_n_ Is the N-dim of KxN weight processed by one workgroup.
+/// @tparam wg_tile_k_ Is the K-dim of KxN weight processed by one workgroup.
+/// @tparam sg_tile_n_ Is the N-dim of KxN weight processed by one subgroup.
+/// @tparam sg_tile_k_ Is the K-dim of KxN weight processed by one subgroup.
+/// @tparam k_stride_ Is the K-dim size of the block processed by each
+/// inner-loop load; too large a value may make kernels have spills.
+template <
+    uint32_t wg_tile_n_,
+    uint32_t wg_tile_k_,
+    uint32_t sg_tile_n_,
+    uint32_t sg_tile_k_,
+    uint32_t k_stride_>
+struct int4_dequantize_attr_t {
+  static constexpr uint32_t wg_tile_n = wg_tile_n_;
+  static constexpr uint32_t wg_tile_k = wg_tile_k_;
+  static constexpr uint32_t sg_tile_n = sg_tile_n_;
+  static constexpr uint32_t sg_tile_k = sg_tile_k_;
+  static constexpr uint32_t k_stride = k_stride_;
+};
+
+} // namespace gpu::xetla::kernel
diff --git a/include/experimental/kernel/int4_dequantize/int4_dequantize.hpp b/include/experimental/kernel/int4_dequantize/int4_dequantize.hpp
new file mode 100644
index 000000000..b09e2e7fa
--- /dev/null
+++ b/include/experimental/kernel/int4_dequantize/int4_dequantize.hpp
@@ -0,0 +1,24 @@
+/*******************************************************************************
+ * Copyright (c) 2023-2024 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+/// @file
+/// C++ API
+
+#pragma once
+
+#include <experimental/kernel/int4_dequantize/api.hpp>
+#include <experimental/kernel/int4_dequantize/config.hpp>
+#include <experimental/kernel/int4_dequantize/int4_dequantize_xe_impl.hpp>
diff --git a/include/experimental/kernel/int4_dequantize/int4_dequantize_xe_impl.hpp b/include/experimental/kernel/int4_dequantize/int4_dequantize_xe_impl.hpp
new file mode 100644
index 000000000..2ff0523fa
--- /dev/null
+++ b/include/experimental/kernel/int4_dequantize/int4_dequantize_xe_impl.hpp
@@ -0,0 +1,285 @@
+/*******************************************************************************
+ * Copyright (c) 2023-2024 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+/// @file
+/// C++ API
+
+#pragma once
+
+#include <experimental/kernel/int4_dequantize/api.hpp>
+#include <experimental/kernel/int4_dequantize/config.hpp>
+
+namespace gpu::xetla::kernel {
+template <
+    typename dtype_qweight_,
+    typename dtype_scale_,
+    typename dtype_zp_,
+    typename dtype_dequant_weight_,
+    mem_layout mem_layout_qweight_,
+    mem_layout mem_layout_scale_,
+    mem_layout mem_layout_zp_,
+    mem_layout mem_layout_dequant_weight_,
+    quant_info quant_info_,
+    typename int4_dequantize_attr_,
+    gpu_arch arch_>
+struct int4_dequantize_t<
+    dtype_qweight_,
+    dtype_scale_,
+    dtype_zp_,
+    dtype_dequant_weight_,
+    mem_layout_qweight_,
+    mem_layout_scale_,
+    mem_layout_zp_,
+    mem_layout_dequant_weight_,
+    quant_info_,
+    int4_dequantize_attr_,
+    arch_> {
+  static_assert(
+      mem_layout_qweight_ == mem_layout::col_major,
+      "only support col_major qweight now.");
+  static_assert(
+      mem_layout_scale_ == mem_layout::col_major,
+      "only support col_major scale now.");
+  static_assert(
+      mem_layout_zp_ == mem_layout::row_major,
+      "only support row_major zp now.");
+  static_assert(
+      mem_layout_dequant_weight_ == mem_layout::row_major,
+      "only support row_major dequant_weight now.");
+
+  static constexpr uint32_t dequant_s = quant_info_.dequant_s;
+  static constexpr uint32_t pack_ratio = sizeof(dtype_qweight_) * 2;
+  static constexpr uint32_t wg_tile_n = int4_dequantize_attr_::wg_tile_n;
+  static constexpr uint32_t wg_tile_k = int4_dequantize_attr_::wg_tile_k;
+  static constexpr uint32_t sg_tile_n = int4_dequantize_attr_::sg_tile_n;
+  static constexpr uint32_t sg_tile_k = int4_dequantize_attr_::sg_tile_k;
+  static constexpr uint32_t k_stride = int4_dequantize_attr_::k_stride;
+
+  static_assert(
+      wg_tile_n % sg_tile_n == 0,
+      "wg_tile_n must be multiple of sg_tile_n");
+  static_assert(
+      wg_tile_k % sg_tile_k == 0,
+      "wg_tile_k must be multiple of sg_tile_k");
+  static_assert(
+      sg_tile_k % k_stride == 0,
+      "sg_tile_k must be multiple of k_stride");
k_stride"); + + using mem_desc_qweight_t = mem_desc_t< + dtype_qweight_, + mem_layout_qweight_, + mem_space::global, + 64 / sizeof(dtype_qweight_)>; + using mem_desc_scale_t = mem_desc_t< + dtype_scale_, + mem_layout_scale_, + mem_space::global, + 64 / sizeof(dtype_scale_)>; + using mem_desc_zp_t = mem_desc_t< + dtype_zp_, + mem_layout_zp_, + mem_space::global, + 64 / sizeof(dtype_zp_)>; + using mem_desc_dequant_weight_t = mem_desc_t< + dtype_dequant_weight_, + mem_layout_dequant_weight_, + mem_space::global, + 64 / sizeof(dtype_dequant_weight_)>; + + struct arguments_t { + uint32_t matrix_k; + uint32_t matrix_n; + dtype_qweight_* qweight_base; + dtype_scale_* scale_base; + dtype_zp_* zp_base; + dtype_dequant_weight_* dequant_weight_base; + uint32_t qweight_ld; + uint32_t dequant_weight_ld; + uint32_t scale_ld; + uint32_t zp_ld; + }; + + static cl::sycl::range<3> get_local_range() { + uint32_t local_range_k = (wg_tile_k + sg_tile_k - 1) / sg_tile_k; + uint32_t local_range_n = (wg_tile_n + sg_tile_n - 1) / sg_tile_n; + XETLA_PRINTF("Local range: {%d, %d, %d}", 1, local_range_k, local_range_n); + return cl::sycl::range<3>{1, local_range_k, local_range_n}; + }; + + static cl::sycl::range<3> get_group_range( + uint32_t matrix_k, + uint32_t matrix_n) { + uint32_t group_range_k = (matrix_k + wg_tile_k - 1) / wg_tile_k; + uint32_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n; + XETLA_PRINTF("Group range: {%d, %d, %d}", 1, group_range_k, group_range_n); + return cl::sycl::range<3>{1, group_range_k, group_range_n}; + }; + + static cl::sycl::nd_range<3> get_nd_range(arguments_t& args) { + cl::sycl::range<3> local_range = get_local_range(); + cl::sycl::range<3> group_range = + get_group_range(args.matrix_k, args.matrix_n); + return cl::sycl::nd_range<3>{group_range * local_range, local_range}; + }; + + using mat_qweight_tile_desc_t = subgroup::tile_desc_t< + sg_tile_n, // always N-dim + k_stride / pack_ratio, // always K-dim + sg_tile_n, // will be y-tile-dim in col-major qweight. + k_stride / pack_ratio, // will be x-tile-dim in col-major qweight. + reg_layout::tiled>; + + using mat_dequant_weight_tile_desc_t = subgroup:: + tile_desc_t; + + static constexpr uint32_t block_size_y_scale = + (k_stride + dequant_s - 1) / dequant_s; + + using scale_tile_desc_t = subgroup::tile_desc_t< + sg_tile_n, + block_size_y_scale, + sg_tile_n, + block_size_y_scale, + reg_layout::transpose_tiled>; + using zp_tile_desc_t = subgroup::tile_desc_t< + (sg_tile_n + pack_ratio - 1) / pack_ratio, + block_size_y_scale, + (sg_tile_n + pack_ratio - 1) / pack_ratio, + block_size_y_scale>; + + using mat_qweight_t = + subgroup::tile_t; + using mat_dequant_weight_t = + subgroup::tile_t; + using scale_t = subgroup::tile_t; + using zp_t = subgroup::tile_t; + + // block-wise load, will trade block_size_y as bytes per row block with + // col-major weight. 
+ using mat_qweight_payload_t = subgroup::mem_payload_t< + mem_desc_qweight_t, + mat_qweight_tile_desc_t, + subgroup::msg_type_v, + arch_>; + using mat_dequant_weight_payload_t = subgroup::mem_payload_t< + mem_desc_dequant_weight_t, + mat_dequant_weight_tile_desc_t, + subgroup:: + msg_type_v, + arch_>; + using scale_payload_t = subgroup::mem_payload_t< + mem_desc_scale_t, + scale_tile_desc_t, + subgroup::msg_type_v, + arch_>; + using zp_payload_t = subgroup::mem_payload_t< + mem_desc_zp_t, + zp_tile_desc_t, + subgroup::msg_type_v, + arch_>; + using dequantize_t = subgroup::dequant_int4_weight_t< + mat_dequant_weight_t, + mat_qweight_t, + scale_t, + zp_t, + dequant_s, + quant_info_.quant_mode>; + + static constexpr uint32_t quant_factor_update_freq = + (k_stride < dequant_s) ? dequant_s / k_stride : 1; + __XETLA_API static void call( + sycl::nd_item<3>& item, + const arguments_t& args) { + int wg_id_n = item.get_group(2); + int wg_id_k = item.get_group(1); + int sg_id_n = item.get_local_id(2); + int sg_id_k = item.get_local_id(1); + int start_k = wg_id_k * wg_tile_k + sg_id_k * sg_tile_k; + int start_n = wg_id_n * wg_tile_n + sg_id_n * sg_tile_n; + int start_x_scale = start_n; + int start_y_scale = start_k / dequant_s; + int start_x_zp = start_n / pack_ratio; + int start_y_zp = start_k / dequant_s; + + mem_desc_qweight_t mem_desc_qweight( + args.qweight_base, + {start_n + sg_tile_n, // compressed KxN weight width(N) + start_k + sg_tile_k, // compressed KxN weight height(K) + args.qweight_ld / pack_ratio}, // compressed weight pitch + {start_n, + int(start_k / + pack_ratio)}); // compressed KxN weight offset_x, offset_y + mem_desc_dequant_weight_t mem_desc_dequant_weight( + args.dequant_weight_base, + {start_n + sg_tile_n, start_k + sg_tile_k, args.dequant_weight_ld}, + {start_n, start_k}); + uint32_t scale_size_y = ((args.matrix_k + dequant_s - 1) / dequant_s); + mem_desc_scale_t mem_desc_scale( + args.scale_base, + {args.matrix_n, scale_size_y, args.scale_ld}, + {start_x_scale, start_y_scale}); + mem_desc_zp_t mem_desc_zp( + args.zp_base, + {(args.matrix_n + pack_ratio - 1) / pack_ratio, + (args.matrix_k + dequant_s - 1) / dequant_s, + args.zp_ld / pack_ratio}, + {start_x_zp, start_y_zp}); + uint32_t k_dim_loop = sg_tile_k / k_stride; + + mat_qweight_t mat_qweight; + mat_dequant_weight_t mat_dequant_weight; + scale_t scale; + zp_t zp; + + mat_qweight_payload_t mat_qweight_payload(mem_desc_qweight); + mat_dequant_weight_payload_t mat_dequant_weight_payload( + mem_desc_dequant_weight); + scale_payload_t scale_payload(mem_desc_scale); + zp_payload_t zp_payload(mem_desc_zp); + typename dequantize_t::arguments_t dequantize_args(start_n, start_k); + dequantize_t dequantize; + int tile_k_idx = (start_k + k_stride - 1) / k_stride; +#pragma unroll + for (uint32_t i = 0; i < k_dim_loop; i++) { + subgroup::tile_load( + mat_qweight, mat_qweight_payload); + subgroup::tile_load( + scale, scale_payload); + if constexpr (quant_info_.quant_mode == quant_mode::I4_ASYM) { + subgroup::tile_load( + zp, zp_payload); + } + tile_k_idx++; + mat_qweight_payload.template update_tdesc( + mat_qweight_t::tile_size_y); + + if (tile_k_idx % quant_factor_update_freq == 0) { + scale_payload.template update_tdesc( + scale_t::tile_size_y); + if constexpr (quant_info_.quant_mode == quant_mode::I4_ASYM) { + zp_payload.template update_tdesc( + zp_t::tile_size_y); + } + } + dequantize(mat_dequant_weight, mat_qweight, scale, zp, dequantize_args); + tile_transpose(mat_dequant_weight); + subgroup::tile_store(mat_dequant_weight, 
mat_dequant_weight_payload); + mat_dequant_weight_payload.template update_tdesc( + mat_dequant_weight_t::tile_size_y); + } + }; +}; +} // namespace gpu::xetla::kernel diff --git a/include/experimental/kernel/kernel.hpp b/include/experimental/kernel/kernel.hpp index 00aea2dbb..fedc6b271 100644 --- a/include/experimental/kernel/kernel.hpp +++ b/include/experimental/kernel/kernel.hpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 01b1ac65c..c793e2acb 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -455,6 +455,9 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t load_elems = num_channel * payload_t::vector_size; constexpr uint32_t pack_factor = payload_t::pack_factor; const xetla_vector reg_zeros(0); + constexpr uint32_t block_height = payload_t::mem_transpose + ? tile_desc::block_size_x + : tile_desc::block_size_y; auto channel_offset = payload.channel_offset + payload.base_offset; #pragma unroll @@ -466,9 +469,7 @@ tile_load(tile_t& tile, payload_t& payload) { auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); #pragma unroll - for (uint32_t sub_block_offset = 0; sub_block_offset < - (payload_t::mem_transpose ? tile_desc::block_size_x - : tile_desc::block_size_y); + for (uint32_t sub_block_offset = 0; sub_block_offset < block_height; sub_block_offset += num_channel) { xetla_vector reg_tmp = 0; uint32_t address_offset = payload_t::mem_transpose @@ -477,20 +478,19 @@ tile_load(tile_t& tile, payload_t& payload) { : offset_x * sizeof(dtype) + (offset_y + sub_block_offset) * payload.pitch_in_bytes; xetla_mask mask = 1; - if constexpr (num_channel > 1) { - // For SDP load, need pred + if constexpr (payload_t::use_mask) { + // For SDP load, need mask const uint32_t sub_block_offset_x = payload.base_x + offset_x + (payload_t::mem_transpose ? sub_block_offset : 0); const uint32_t sub_block_offset_y = payload.base_y + offset_y + (payload_t::mem_transpose ? 0 : sub_block_offset); - const auto offset_ch_dim = - payload_t::trans ? sub_block_offset_x : sub_block_offset_y; - const auto size_ch_dim = payload_t::trans ? payload.width_in_elems - : payload.height_in_elems; + const auto offset_ch_dim = payload_t::mem_transpose + ? sub_block_offset_x + : sub_block_offset_y; - mask = offset_ch_dim + num_channel > size_ch_dim + mask = offset_ch_dim + num_channel > payload.height_in_elems ? (xetla_vector_gen(offset_ch_dim, 1) < - size_ch_dim) + payload.height_in_elems) : 1; reg_tmp = xetla_load_global< load_dtype, @@ -577,6 +577,9 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t load_elems = payload_t::mem_transpose ? tile_desc::block_size_y : tile_desc::block_size_x; + constexpr uint32_t block_height = payload_t::mem_transpose + ? tile_desc::block_size_x + : tile_desc::block_size_y; #pragma unroll for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { @@ -587,18 +590,15 @@ tile_load(tile_t& tile, payload_t& payload) { auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); #pragma unroll - for (uint32_t sub_block_y = 0; - sub_block_y < (payload_t::mem_transpose ? tile_desc::block_size_x - : tile_desc::block_size_y); - sub_block_y += 1) { + for (uint32_t sub_block_y = 0; sub_block_y < block_height; + sub_block_y++) { uint32_t address_offset = payload_t::mem_transpose ? 
(offset_x + sub_block_y) * payload.pitch_in_bytes + offset_y * sizeof(dtype) : offset_x * sizeof(dtype) + (offset_y + sub_block_y) * payload.pitch_in_bytes; - reg_sub.xetla_select( - sub_block_y * tile_desc::block_size_x) = + reg_sub.xetla_select(sub_block_y * load_elems) = xetla_load_global( (dtype*)payload.base_ptr, payload.base_offset + address_offset); } diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 2316ef23f..697aba49e 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -38,16 +38,17 @@ template < typename tile_desc_, mem_layout mem_layout_, gpu_arch arch_tag_, - uint32_t alignment_> + uint32_t alignment_, + bool use_mask_> struct mem_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_, msg_type::block_2d, arch_tag_, std::enable_if_t>> { using tile_desc = tile_desc_; using mem_desc_t = - mem_desc_t; + mem_desc_t; using dtype = dtype_; static constexpr msg_type message_type = msg_type::block_2d; static constexpr mem_space memory_space = mem_space::global; @@ -400,15 +401,25 @@ template < typename tile_desc_, gpu_arch arch_tag_, uint32_t alignment_, - mem_layout memory_layout_> + mem_layout memory_layout_, + bool use_mask_> struct mem_payload_t< - mem_desc_t, + mem_desc_t< + dtype_, + memory_layout_, + mem_space::global, + alignment_, + use_mask_>, tile_desc_, msg_type::block_1d, arch_tag_, std::enable_if_t>> { - using mem_desc_t = - mem_desc_t; + using mem_desc_t = mem_desc_t< + dtype_, + memory_layout_, + mem_space::global, + alignment_, + use_mask_>; using dtype = native_type_t; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::global; @@ -441,8 +452,8 @@ struct mem_payload_t< inline mem_payload_t(mem_desc_t& mem_tdesc) { pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); - width_in_elems = mem_tdesc.shape.x; - height_in_elems = mem_tdesc.shape.y; + width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x; + height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y; payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes + mem_tdesc.shape.y * sizeof(dtype) : (mem_tdesc.shape.y - 1) * pitch_in_bytes + @@ -480,8 +491,8 @@ struct mem_payload_t< pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; - width_in_elems = mem_tdesc.shape.x; - height_in_elems = mem_tdesc.shape.y; + width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x; + height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y; payload_bytes = mem_transpose ? 
(mem_tdesc.shape.x - 1) * pitch_in_bytes + mem_tdesc.shape.y * sizeof(dtype) : (mem_tdesc.shape.y - 1) * pitch_in_bytes + @@ -554,15 +565,25 @@ template < typename dtype_, typename tile_desc_, gpu_arch arch_tag_, - uint32_t alignment_> + uint32_t alignment_, + bool use_mask_> struct mem_payload_t< - mem_desc_t, + mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::global, + alignment_, + use_mask_>, tile_desc_, msg_type::atomic_add, arch_tag_, std::enable_if_t= sizeof(uint16_t)>> { - using mem_desc_t = - mem_desc_t; + using mem_desc_t = mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::global, + alignment_, + use_mask_>; using dtype = dtype_; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::global; @@ -736,15 +757,25 @@ template < typename dtype_, typename tile_desc_, gpu_arch arch_tag_, - uint32_t alignment_> + uint32_t alignment_, + bool use_mask_> struct mem_payload_t< - mem_desc_t, + mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::local, + alignment_, + use_mask_>, tile_desc_, msg_type::block_1d, arch_tag_, std::enable_if_t>> { - using mem_desc_t = - mem_desc_t; + using mem_desc_t = mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::local, + alignment_, + use_mask_>; using dtype = dtype_; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::local; @@ -867,16 +898,17 @@ template < typename tile_desc_, mem_layout mem_layout_, uint32_t alignment_, + bool use_mask_, gpu_arch arch_tag_> struct mem_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_, msg_type::unaligned_2d, arch_tag_, std::enable_if_t>> { using dtype = dtype_; using mem_desc_t = - mem_desc_t; + mem_desc_t; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; @@ -949,8 +981,8 @@ struct mem_payload_t< pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); base_x = mem_tdesc.coord.x; base_y = mem_tdesc.coord.y; - width_in_elems = mem_tdesc.shape.x; - height_in_elems = mem_tdesc.shape.y; + width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x; + height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y; base_offset = mem_transpose ? base_x * pitch_in_bytes + base_y * sizeof(dtype) : base_y * pitch_in_bytes + base_x * sizeof(dtype); @@ -995,8 +1027,8 @@ struct mem_payload_t< pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); base_x = mem_tdesc.coord.x; base_y = mem_tdesc.coord.y; - width_in_elems = mem_tdesc.shape.x; - height_in_elems = mem_tdesc.shape.y; + width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x; + height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y; base_offset = mem_transpose ? 
base_x * pitch_in_bytes + base_y * sizeof(dtype) : base_y * pitch_in_bytes + base_x * sizeof(dtype); @@ -1094,22 +1126,24 @@ template < typename tile_desc_, mem_layout mem_layout_, uint32_t alignment_, + bool use_mask_, gpu_arch arch_tag_> struct mem_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_, msg_type::block_2d, arch_tag_, std::enable_if_t>> { using dtype = native_type_t; using mem_desc_t = - mem_desc_t; + mem_desc_t; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; static constexpr msg_type message_type = msg_type::block_2d; static constexpr uint32_t alignment_in_bytes = mem_desc_t::alignment_in_bytes; static constexpr gpu_arch arch_tag = arch_tag_; + static constexpr bool use_mask = mem_desc_t::use_mask; private: static constexpr uint32_t tile_size_x = tile_desc::tile_size_x; @@ -1140,13 +1174,13 @@ struct mem_payload_t< alignment_in_bytes); // using mem_dtype = uint32_t; - using mem_dtype = typename std::conditional< + using mem_dtype = typename std::conditional_t< (block_per_row_bytes % sizeof(uint64_t) == 0), uint64_t, - typename std::conditional< + typename std::conditional_t< (block_per_row_bytes % sizeof(uint32_t) == 0), uint32_t, - dtype>::type>::type; + dtype>>; static constexpr uint32_t pack_factor = sizeof(mem_dtype) / sizeof(dtype); static constexpr uint32_t vector_size = @@ -1191,8 +1225,8 @@ struct mem_payload_t< pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); base_x = mem_tdesc.coord.x; base_y = mem_tdesc.coord.y; - width_in_elems = mem_tdesc.shape.x; - height_in_elems = mem_tdesc.shape.y; + width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x; + height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y; base_offset = mem_transpose ? base_x * pitch_in_bytes + base_y * sizeof(dtype) : base_y * pitch_in_bytes + base_x * sizeof(dtype); @@ -1230,8 +1264,8 @@ struct mem_payload_t< base_x = mem_tdesc.coord.x; base_y = mem_tdesc.coord.y; - width_in_elems = mem_tdesc.shape.x; - height_in_elems = mem_tdesc.shape.y; + width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x; + height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y; base_offset = mem_transpose ? 
base_x * pitch_in_bytes + base_y * sizeof(dtype) : base_y * pitch_in_bytes + base_x * sizeof(dtype); @@ -1318,15 +1352,25 @@ template < typename dtype_, typename tile_desc_, gpu_arch arch_tag_, - uint32_t alignment_> + uint32_t alignment_, + bool use_mask_> struct mem_payload_t< - mem_desc_t, + mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::local, + alignment_, + use_mask_>, tile_desc_, msg_type::scatter, arch_tag_, std::enable_if_t>> { - using mem_desc_t = - mem_desc_t; + using mem_desc_t = mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::local, + alignment_, + use_mask_>; using dtype = dtype_; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::local; @@ -1492,9 +1536,15 @@ template < uint32_t block_size_x_, uint32_t block_size_y_, gpu_arch arch_tag_, - uint32_t alignment_> + uint32_t alignment_, + bool use_mask_> struct mem_payload_t< - mem_desc_t, + mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::local, + alignment_, + use_mask_>, tile_desc_t< tile_size_x_, tile_size_y_, @@ -1504,8 +1554,12 @@ struct mem_payload_t< msg_type::scatter, arch_tag_, std::enable_if_t>> { - using mem_desc_t = - mem_desc_t; + using mem_desc_t = mem_desc_t< + dtype_, + mem_layout::row_major, + mem_space::local, + alignment_, + use_mask_>; using dtype = dtype_; using tile_desc = tile_desc_t< tile_size_x_, @@ -1642,11 +1696,12 @@ template < uint32_t block_size_y_, mem_layout mem_layout_, uint32_t alignment_, + bool use_mask_, uint32_t num_coop_sg_, reg_layout reg_layout_, gpu_arch arch_tag_> struct prefetch_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_t< tile_size_x_, tile_size_y_, @@ -1662,7 +1717,7 @@ struct prefetch_payload_t< mem_layout_ == mem_layout::col_major))>> { using dtype = native_type_t; using mem_desc_t = - mem_desc_t; + mem_desc_t; using tile_desc = tile_desc_t< tile_size_x_, tile_size_y_, @@ -1697,31 +1752,34 @@ struct prefetch_payload_t< static constexpr bool trans = (mem_transpose ^ reg_transpose) && !(std::is_same_v || std::is_same_v); - using prefetch_dtype = typename std::conditional< + using prefetch_dtype = typename std::conditional_t< (alignment_in_bytes % (sizeof(uint64_t)) == 0), uint64_t, - typename std::conditional< + typename std::conditional_t< (alignment_in_bytes % (sizeof(uint32_t)) == 0), uint32_t, - dtype>::type>::type; + dtype>>; static constexpr uint32_t pack_factor = sizeof(prefetch_dtype) / sizeof(dtype); - static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype); - static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype); - static constexpr uint32_t simd_channel = - ((tile_bytes % max_store_bytes) == 0 && - (block_bytes % max_store_bytes) == 0) - ? 32 - : 16; - static constexpr uint32_t num_channel = mem_transpose - ? (simd_channel >= block_size_x) ? block_size_x : simd_channel - : (simd_channel >= block_size_y) ? block_size_y - : simd_channel; + static constexpr uint32_t vector_size = + ((mem_transpose ? block_size_y : block_size_x) + pack_factor - 1) / + pack_factor; + + // for pvc, we can use simd16 or simd32 + static constexpr uint32_t max_bytes = + load_store_attr_t::max_prefetch_vec_len; + // std::min(load_store_attr::max_load_vec_len, block_bytes); + + static constexpr uint32_t max_channel = + max_bytes / (vector_size * sizeof(prefetch_dtype)); - static constexpr uint32_t vector_size = mem_transpose - ? 
(block_size_y + pack_factor - 1) / pack_factor - : (block_size_x + pack_factor - 1) / pack_factor; + static constexpr uint32_t select_channel(const uint32_t channel) { + return channel >= 32 ? 32 : channel >= 16 ? 16 : channel >= 8 ? 8 : 1; + } + + static constexpr uint32_t num_channel = select_channel( + std::min(mem_transpose ? block_size_x : block_size_y, max_channel)); static constexpr uint32_t mem_tile_size_w = mem_transpose ? tile_size_y : tile_size_x; @@ -1888,11 +1946,12 @@ template < uint32_t block_size_y_, mem_layout mem_layout_, uint32_t alignment_, + bool use_mask_, uint32_t num_coop_sg_, reg_layout reg_layout_, gpu_arch arch_tag_> struct prefetch_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_t< tile_size_x_, tile_size_y_, @@ -1906,7 +1965,7 @@ struct prefetch_payload_t< ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = - mem_desc_t; + mem_desc_t; using tile_desc = tile_desc_t< tile_size_x_, tile_size_y_, @@ -2190,11 +2249,12 @@ template < uint32_t block_size_y_, mem_layout mem_layout_, uint32_t alignment_, + bool use_mask_, uint32_t num_coop_sg_, reg_layout reg_layout_, gpu_arch arch_tag_> struct prefetch_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_t< tile_size_x_, tile_size_y_, @@ -2208,7 +2268,7 @@ struct prefetch_payload_t< ((tile_size_x_ == 1) && mem_layout_ == mem_layout::col_major)>> { using dtype = dtype_; using mem_desc_t = - mem_desc_t; + mem_desc_t; // CL aligned, so we can use uint64_t using prefetch_dtype = uint64_t; static constexpr msg_type message_type = msg_type::block_1d; @@ -2326,18 +2386,19 @@ template < typename dtype_, typename tile_desc_, mem_layout mem_layout_, + bool use_mask_, uint32_t alignment_, uint32_t num_coop_sg_, gpu_arch arch_tag_> struct prefetch_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_, num_coop_sg_, arch_tag_, std::enable_if_t>> { using dtype = dtype_; using mem_desc_t = - mem_desc_t; + mem_desc_t; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::local; static constexpr mem_layout memory_layout = mem_layout_; diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 6f2ffdc37..ab1f0038e 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -64,13 +64,8 @@ struct dequant_int4_weight_t { uint32_t wg_start_n; uint32_t wg_start_k; inline arguments_t() = default; - inline arguments_t( - uint32_t wg_start_m_, - uint32_t wg_start_n_, - uint32_t wg_start_k_) - : wg_start_m(wg_start_m_), - wg_start_n(wg_start_n_), - wg_start_k(wg_start_k_) {} + inline arguments_t(uint32_t wg_start_n_, uint32_t wg_start_k_) + : wg_start_n(wg_start_n_), wg_start_k(wg_start_k_) {} }; __XETLA_API KERNEL_FUNC void operator()( matB_acc_t& matB_acc, diff --git a/tests/integration/CMakeLists.txt b/tests/integration/CMakeLists.txt index 47a93c980..c70f6e887 100644 --- a/tests/integration/CMakeLists.txt +++ b/tests/integration/CMakeLists.txt @@ -32,3 +32,4 @@ add_subdirectory(softmax) add_subdirectory(fmha) add_subdirectory(col_major_shuf) add_subdirectory(mlp) +add_subdirectory(int4_dequantize) diff --git a/tests/integration/int4_dequantize/CMakeLists.txt b/tests/integration/int4_dequantize/CMakeLists.txt new file mode 100644 index 000000000..af0970745 --- /dev/null +++ b/tests/integration/int4_dequantize/CMakeLists.txt @@ -0,0 +1,6 @@ +get_filename_component(ProjectId ${CMAKE_CURRENT_SOURCE_DIR} NAME) +string(REPLACE " " "_" ProjectId ${ProjectId}) + 
+FILE(GLOB src main.cpp)
+add_integration_test(${ProjectId} ${src})
+
diff --git a/tests/integration/int4_dequantize/main.cpp b/tests/integration/int4_dequantize/main.cpp
new file mode 100644
index 000000000..a0ccbd329
--- /dev/null
+++ b/tests/integration/int4_dequantize/main.cpp
@@ -0,0 +1,394 @@
+/*******************************************************************************
+ * Copyright (c) 2022-2023 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include
+#include "xetla.hpp"
+// #define UT_DEBUG
+using namespace gpu::xetla;
+using namespace gpu::xetla::group;
+// The number of times the kernel is executed
+#ifdef UT_DEBUG
+constexpr int ITER = 1;
+#else
+constexpr int ITER = 200;
+#endif
+constexpr size_t UNDEFINED_DATA_SIZE = 512;
+
+class test_col_major_1 {
+ public:
+  // Extract the parameters required by different test cases
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_k = 128;
+  static constexpr size_t wg_n = 16;
+  static constexpr size_t sg_k = 128;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t k_stride = 32;
+  static constexpr size_t dequant_s = 128;
+  // static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
+  static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
+
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr gpu_arch arch = gpu_arch::XeHpg;
+  using data_type_b = int4x8;
+  using data_type_c = fp16;
+};
+
+template <
+    quant_mode quant_mode = quant_mode::I4_SYM,
+    typename data_type_acc_in = fp16,
+    typename data_type_b,
+    typename data_type_scale,
+    typename data_type_zero_pt>
+std::vector<data_type_acc_in> convert_int4(
+    data_type_b data_b,
+    data_type_scale scale,
+    data_type_zero_pt zero_pt) {
+  std::vector<data_type_acc_in> dequant_fp16(sizeof(data_type_b) * 2);
+
+  int8_t zero_pt_i8 = zero_pt & 0xf;
+  for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
+    int8_t dequant_8bit = data_b & 0xf;
+    if constexpr (quant_mode == quant_mode::I4_SYM) {
+      dequant_fp16[i] = scale * (dequant_8bit - 8);
+    } else {
+      dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
+    }
+    data_b = data_b >> 4;
+  }
+  return dequant_fp16;
+}
+
+template <
+    size_t dequant_s,
+    mem_layout layout_b = mem_layout::col_major,
+    quant_mode quant_mode = quant_mode::I4_SYM,
+    typename data_type_acc_in = fp16,
+    typename data_type_b,
+    typename data_type_scale,
+    typename data_type_zero_pt>
+std::vector<data_type_acc_in> dequantize_weight(
+    size_t matrix_k,
+    size_t matrix_n,
+    data_type_b* b,
+    data_type_scale* scale,
+    data_type_zero_pt* zero_pt) {
+  std::vector<data_type_acc_in> b_out(matrix_k * matrix_n, 0);
+  constexpr size_t pack_ratio = 2 * sizeof(data_type_b);
+  size_t width = layout_b == mem_layout::row_major ? matrix_n / pack_ratio
+                                                   : matrix_k / pack_ratio;
+  size_t height = layout_b == mem_layout::row_major ? matrix_k : matrix_n;
+  size_t step = layout_b == mem_layout::row_major ? 1 : dequant_s / pack_ratio;
+
+  for (uint32_t i = 0; i < height; i++) {
+    for (uint32_t j = 0; j < width; j += step) {
+      int start_b_in = i * width + j;
+      int start_scale_in = start_b_in / step;
+      int start_zero_pt_in =
+          (j / step) * (matrix_n / pack_ratio) + i / pack_ratio;
+      int start_out =
+          layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_ratio;
+      for (uint32_t jj = 0; jj < step; jj++) {
+        std::vector<data_type_acc_in> dequant_fp16 = convert_int4<quant_mode>(
+            b[start_b_in + jj],
+            scale[start_scale_in],
+            zero_pt[start_zero_pt_in] >> (4 * (i % pack_ratio)));
+        for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) {
+          b_out[start_out + pack_ratio * jj + jjj] = dequant_fp16[jjj];
+        }
+      }
+    }
+  }
+#ifdef UT_DEBUG
+  // for (uint32_t i = 0; i < matrix_n; i++) {
+  //   for (uint32_t j = 0; j < matrix_k; j++) {
+  //     std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
+  //   }
+  //   std::cout << std::endl;
+  // }
+#endif
+  return b_out;
+}
+
+template <typename T>
+int int4_dequantize_result_validate(T* gold, T* out, size_t k, size_t n) {
+  int err_num = 0;
+  for (uint32_t i = 0; i < k; i++) {
+    for (uint32_t j = 0; j < n; j++) {
+      if (gold[i * n + j] != out[i * n + j]) {
+        if (err_num < 10)
+          std::cout << i * n + j << " " << gold[i * n + j] << " "
+                    << out[i * n + j] << std::endl;
+        err_num++;
+      }
+    }
+  }
+  if (err_num == 0) {
+    std::cout << "Test Passed!!!" << std::endl;
+  }
+  return err_num;
+}
+
+template <typename Test>
+void dequantize_run(int iter) {
+  using namespace gpu;
+  // Accept incoming parameters
+  constexpr size_t matrix_n = Test::mat_n;
+  constexpr size_t matrix_k = Test::mat_k;
+
+  constexpr size_t wg_tile_n = Test::wg_n;
+  constexpr size_t wg_tile_k = Test::wg_k;
+  constexpr size_t sg_tile_n = Test::sg_n;
+  constexpr size_t sg_tile_k = Test::sg_k;
+  constexpr size_t k_stride = Test::k_stride;
+  constexpr size_t dequant_s = std::min(Test::dequant_s, matrix_k);
+  constexpr quant_mode quant_mode = Test::quant_mode;
+  using data_type_b = typename Test::data_type_b;
+  using data_type_c = typename Test::data_type_c;
+  using data_type_zero_pt = data_type_b;
+  using data_type_scale = fp16;
+
+  constexpr mem_layout layout_b = Test::layout_b;
+
+  constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b));
+
+  constexpr size_t size_scale_k = matrix_k / dequant_s;
+  constexpr size_t size_scale_n = matrix_n;
+  constexpr size_t size_scale = size_scale_k * size_scale_n;
+
+  constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
+  constexpr size_t size_zero_pt_n = matrix_n;
+  constexpr size_t size_zero_pt =
+      size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b));
+
+  constexpr size_t size_c = matrix_k * matrix_n;
+
+  uint32_t ldb = layout_b == mem_layout::row_major ? matrix_n : matrix_k;
+  uint32_t ldc = matrix_n;
+  uint32_t ld_scale =
+      layout_b == mem_layout::row_major ?
size_scale_n : size_scale_k; + uint32_t ld_zero_pt = size_zero_pt_n; + + // Turn on the enable_profiling property to facilitate subsequent profiling + sycl::property_list properties{ + sycl::property::queue::enable_profiling(), + sycl::property::queue::in_order()}; + auto queue = sycl::queue(properties); + auto context = queue.get_info(); + auto device = queue.get_info(); + + std::cout << "Running on " << device.get_info() << "\n"; + + using int4_dequantize_attr = gpu::xetla::kernel::int4_dequantize_attr_t< + wg_tile_n, + wg_tile_k, + sg_tile_n, + sg_tile_k, + k_stride>; + static constexpr quant_info q_info{quant_mode, Test::dequant_s, layout_b}; + using int4_dequantize_kernel = gpu::xetla::kernel::int4_dequantize_t< + data_type_b, + data_type_scale, + data_type_zero_pt, + data_type_c, + layout_b, + layout_b, + mem_layout::row_major, + mem_layout::row_major, + q_info, + int4_dequantize_attr, + Test::arch>; + + // Define and initialize the data required for the calculation + auto* B_h = static_cast(malloc_host( + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b), context)); + auto* C_h = static_cast( + malloc_host(size_c * sizeof(data_type_c), context)); + auto* scale_h = static_cast(malloc_host( + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale), context)); + auto* zero_pt_h = static_cast(malloc_host( + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt), + context)); + + auto* B_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b), + device, + context)); + auto* C_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_c * sizeof(data_type_c), device, context)); + auto* scale_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale), + device, + context)); + auto* zero_pt_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt), + device, + context)); + + for (unsigned i = 0; i < size_b + UNDEFINED_DATA_SIZE; ++i) { + if constexpr (std::is_same_v) { + B_h[i] = random_uint8(); +#ifdef UT_DEBUG + B_h[i] = 0x77; +#endif + } else if constexpr (std::is_same_v) { + B_h[i] = random_uint32(); +#ifdef UT_DEBUG + B_h[i] = i < size_b / 2 ? 
0x77777777 : 0x66666666;
+#endif
+    }
+  }
+
+  for (unsigned i = 0; i < size_scale; ++i) {
+    scale_h[i] = random_float() + 1.f;
+#ifdef UT_DEBUG
+    scale_h[i] = 1.f;
+#endif
+  }
+  for (unsigned i = size_scale; i < size_scale + UNDEFINED_DATA_SIZE; ++i) {
+    scale_h[i] = INFINITY;
+  }
+  for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
+    if constexpr (std::is_same_v) {
+      zero_pt_h[i] = random_uint8();
+#ifdef UT_DEBUG
+      zero_pt_h[i] = 0x12 << i;
+#endif
+    } else if constexpr (std::is_same_v) {
+      zero_pt_h[i] = random_uint32();
+#ifdef UT_DEBUG
+      zero_pt_h[i] = 0x33333333;
+#endif
+    }
+  }
+
+  for (unsigned i = 0; i < size_c; ++i) {
+    C_h[i] = random_float();
+  }
+
+  queue
+      .memcpy(
+          (void*)B_d,
+          (void*)B_h,
+          (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b))
+      .wait();
+  queue.memcpy((void*)C_d, (void*)C_h, size_c * sizeof(data_type_c)).wait();
+  queue
+      .memcpy(
+          (void*)scale_d,
+          (void*)scale_h,
+          (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale))
+      .wait();
+  queue
+      .memcpy(
+          (void*)zero_pt_d,
+          (void*)zero_pt_h,
+          (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt))
+      .wait();
+
+  typename int4_dequantize_kernel::arguments_t args(
+      matrix_k,
+      matrix_n,
+      B_d,
+      scale_d,
+      zero_pt_d,
+      C_d,
+      ldb,
+      ldc,
+      ld_scale,
+      ld_zero_pt);
+  cl::sycl::nd_range<3> nd_range = int4_dequantize_kernel::get_nd_range(args);
+
+  size_t bytes = matrix_n * matrix_k / 2 +
+      matrix_k * matrix_n * sizeof(data_type_c) +
+      size_scale * sizeof(data_type_scale);
+  if (Test::quant_mode == quant_mode::I4_ASYM)
+    bytes += size_zero_pt * sizeof(data_type_zero_pt);
+  profiling_helper prof("int4_dequantize kernel bandwidth", bytes, "GB/s");
+#ifdef UT_DEBUG
+  int constexpr warm = 0;
+#else
+  int constexpr warm = 100;
+#endif
+  try {
+    for (int i = 0; i < iter + warm; i++) {
+      if (i >= warm)
+        prof.cpu_start();
+      auto e_esimd = queue.submit([&](handler& cgh) {
+        cgh.parallel_for(nd_range, [=](nd_item<3> item) SYCL_ESIMD_KERNEL {
+          // launch the int4 dequantize kernel
+          int4_dequantize_kernel::call(item, args);
+        });
+      });
+      if (i >= warm) {
+        e_esimd.wait();
+        prof.cpu_end();
+        prof.add_gpu_event(e_esimd);
+      }
+    }
+  } catch (cl::sycl::exception const& e) {
+    std::cout << "SYCL exception caught: " << e.what() << '\n';
+    FAIL();
+  }
+
+  // performance
+  prof.print_profiling_result(profiling_selector::GPU);
+  // check result
+  std::vector<data_type_c> dequantize_b =
+      dequantize_weight<dequant_s, layout_b, quant_mode>(
+          matrix_k, matrix_n, B_h, scale_h, zero_pt_h);
+  std::vector<data_type_c> trans_dq_b;
+  trans_dq_b.resize(matrix_n * matrix_k);
+  // transpose dq b
+  for (size_t i = 0; i < matrix_n; i++) {
+    for (size_t j = 0; j < matrix_k; j++) {
+      trans_dq_b[j * matrix_n + i] = dequantize_b[i * matrix_k + j];
+    }
+  }
+
+  queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait();
+
+  ASSERT_EQ(
+      0,
+      (int4_dequantize_result_validate(
+          trans_dq_b.data(), C_h, Test::mat_k, Test::mat_n)));
+
+  free(B_h, context);
+  free(C_h, context);
+  free(scale_h, context);
+  free(zero_pt_h, context);
+  free(B_d, context);
+  free(C_d, context);
+  free(scale_d, context);
+  free(zero_pt_d, context);
+}
+
+template <typename T>
+class dequantize_test : public ::testing::Test {};
+TYPED_TEST_SUITE_P(dequantize_test);
+
+TYPED_TEST_P(dequantize_test, esimd) {
+  dequantize_run<TypeParam>(ITER);
+}
+
+REGISTER_TYPED_TEST_SUITE_P(dequantize_test, esimd);
+using tests = ::testing::Types<test_col_major_1>;
+
+INSTANTIATE_TYPED_TEST_SUITE_P(dequantize_test_suite, dequantize_test, tests);
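
Note on the dequantization math: per 4-bit lane, the kernel and the test
reference above compute the same thing as the scalar sketch below. It is
distilled from convert_int4(); dequant_int4_scalar is an illustrative name,
not an API added by this patch.

#include <cstdint>

// Minimal scalar reference for one packed int4 lane.
// q4: the quantized weight, already shifted into the low nibble.
// zero_pt: the packed zero point (low nibble), only read in I4_ASYM mode.
// I4_SYM subtracts the fixed midpoint 8 instead of a stored zero point.
inline float dequant_int4_scalar(
    uint8_t q4,
    float scale,
    uint8_t zero_pt,
    bool asym) {
  int8_t q = q4 & 0xf;
  int8_t zp = asym ? int8_t(zero_pt & 0xf) : int8_t(8);
  return scale * float(q - zp);
}

Per convert_int4(), one packed 32-bit int4x8 element yields eight such lanes,
consuming 4 bits per lane from the low end to the high end of the word.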