Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Int4 dequantize kernel (#313)
Browse files Browse the repository at this point in the history
* add int4 dequantize kernel

* Sync ipex

* xetla int4 dequantize kernel remove sw barrier

---------

Co-authored-by: Ding, Yi1 <yi1.ding@intel.com>
  • Loading branch information
zhewang1-intc and DDEle authored Aug 27, 2024
1 parent f7712e0 commit b0efdf4
Show file tree
Hide file tree
Showing 14 changed files with 986 additions and 99 deletions.
8 changes: 6 additions & 2 deletions include/common/core/math_general.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ __XETLA_API T xetla_rsqrt(T src, Sat sat = {}) {
template <typename T, int SZ>
__XETLA_API xetla_vector<T, SZ> xetla_tanh(xetla_vector<T, SZ> src) {
static_assert(
std::is_same<remove_const_t<T>, float>::value, "Only support fp32! ");
(std::is_same<remove_const_t<T>, float>::value) ||
(std::is_same<remove_const_t<T>, fp16>::value),
"Only support fp32 and fp16");
constexpr uint32_t flag_elems = 8 * 16;
xetla_vector<T, SZ> ret;
if constexpr (SZ / flag_elems > 0) {
Expand Down Expand Up @@ -502,7 +504,9 @@ __XETLA_API xetla_vector<T, SZ> xetla_tanh(xetla_vector<T, SZ> src) {
template <typename T>
__XETLA_API T xetla_tanh(T src) {
static_assert(
std::is_same<remove_const_t<T>, float>::value, "Only support fp32! ");
(std::is_same<remove_const_t<T>, float>::value) ||
(std::is_same<remove_const_t<T>, fp16>::value),
"Only support fp32 and fp16");
T exp2x = xetla_exp<T>(src * 2.f);
T ret = (exp2x - 1.f) / (exp2x + 1.f);
return (src >= 10) ? 1 : ret;
Expand Down
24 changes: 20 additions & 4 deletions include/common/utils/memory_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,34 @@ struct mem_base_t<dtype_, mem_space::local> {
}
};

// Memory descriptor
template <
typename dtype_,
mem_layout layout_,
mem_space space_,
uint32_t alignment_ = 16,
uint32_t alignment_ = 8,
bool use_mask_ = false,
int dim_ = 2>
struct mem_desc_t {};

// Alias of mem_desc_t with the use_mask_ template argument fixed to true.
// Intended for data whose shape is not evenly divisible and which therefore
// requires load primitives that apply masks in order to load correctly.
// All other template arguments are forwarded to mem_desc_t unchanged.
// NOTE(review): alignment_ defaults to 16 in this alias; confirm that this is
// the intended default relative to mem_desc_t's own default alignment.
template <
typename dtype_,
mem_layout layout_,
mem_space space_,
uint32_t alignment_ = 16,
int dim_ = 2>
using mem_mask_desc_t =
mem_desc_t<dtype_, layout_, space_, alignment_, true, dim_>;

template <
typename dtype_,
mem_layout layout_,
mem_space space_,
uint32_t alignment_>
struct mem_desc_t<dtype_, layout_, space_, alignment_, 2> {
uint32_t alignment_,
bool use_mask_>
struct mem_desc_t<dtype_, layout_, space_, alignment_, use_mask_, 2> {
using dtype = dtype_;
static constexpr mem_layout layout = layout_;
static constexpr mem_space space = space_;
Expand All @@ -165,11 +179,13 @@ struct mem_desc_t<dtype_, layout_, space_, alignment_, 2> {

static constexpr bool is_col_major = layout == mem_layout::col_major;
static constexpr bool is_local = space == mem_space::local;
static constexpr bool use_mask = use_mask_;
using shape_t = mem_shape_t<dim>;
using coord_t = mem_coord_t<dim>;
using base_t = mem_base_t<dtype, space>;

using this_type_t = mem_desc_t<dtype, layout_, space_, alignment, 2>;
using this_type_t =
mem_desc_t<dtype, layout_, space_, alignment, use_mask_, 2>;

inline mem_desc_t() = default;
inline mem_desc_t(base_t base_, shape_t shape_, coord_t coord_)
Expand Down
3 changes: 1 addition & 2 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,7 @@ class gemm_t<
wg_start_m = args.matA_base_desc.coord.y;
wg_start_n = args.scale_base_desc.coord.x;
wg_start_k = args.matA_base_desc.coord.x;
typename dequantize_t::arguments_t dequantize_args{
wg_start_m, wg_start_n, wg_start_k};
typename dequantize_t::arguments_t dequantize_args{wg_start_n, wg_start_k};
dequantize_t dequantize;

xetla_nbarrier_t<wg_size_x, wg_size_x, arch_tag> nbarrier_a;
Expand Down
51 changes: 51 additions & 0 deletions include/experimental/kernel/int4_dequantize/api.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C++ API

#pragma once

#include <experimental/kernel/int4_dequantize/config.hpp>

namespace gpu::xetla::kernel {

/// @brief Primary template of the int4 dequantize kernel.
///
/// The primary template is an empty struct and exposes no members.
/// NOTE(review): the usable implementation is presumably provided by
/// specializations declared elsewhere (e.g. in int4_dequantize_xe_impl.hpp)
/// -- confirm.
///
/// @tparam dtype_qweight_ qweight data type.
/// @tparam dtype_scale_ scale data type.
/// @tparam dtype_zp_ zero point data type.
/// @tparam dtype_dequant_weight_ dequant_weight data type.
/// @tparam mem_layout_qweight_ qweight memory layout.
/// @tparam mem_layout_scale_ scale memory layout.
/// @tparam mem_layout_zp_ zero point memory layout.
/// @tparam mem_layout_dequant_weight_ dequant_weight memory layout.
/// @tparam quant_info_ quant_mode, blocksize, qweight_layout info.
/// @tparam int4_dequantize_attr_ parallel-related attribute.
/// @tparam arch_ HW architecture.
/// @tparam enable Defaults to void. NOTE(review): presumably a hook for
/// conditionally enabled specializations -- confirm.
template <
typename dtype_qweight_,
typename dtype_scale_,
typename dtype_zp_,
typename dtype_dequant_weight_,
mem_layout mem_layout_qweight_,
mem_layout mem_layout_scale_,
mem_layout mem_layout_zp_,
mem_layout mem_layout_dequant_weight_,
quant_info quant_info_,
typename int4_dequantize_attr_,
gpu_arch arch_,
typename enable = void>
struct int4_dequantize_t {};

} // namespace gpu::xetla::kernel
50 changes: 50 additions & 0 deletions include/experimental/kernel/int4_dequantize/config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C++ API

#pragma once

#include <common/common.hpp>
#include <group/group.hpp>
#include <subgroup/subgroup.hpp>

namespace gpu::xetla::kernel {

/// @brief Sets up the attributes of the int4 dequantize kernel.
///
/// Every template argument is re-exported unchanged as a static constexpr
/// member with the same name minus the trailing underscore.
///
/// @tparam wg_tile_n_ Is the N-dim of KxN weight processed by one workgroup.
/// @tparam wg_tile_k_ Is the K-dim of KxN weight processed by one workgroup.
/// @tparam sg_tile_n_ Is the N-dim of KxN weight processed by one subgroup.
/// @tparam sg_tile_k_ Is the K-dim of KxN weight processed by one subgroup.
/// @tparam k_stride_ Exposed as k_stride. NOTE(review): presumably the step
/// along the K dimension handled per iteration -- confirm against the kernel
/// implementation.
template <
uint32_t wg_tile_n_,
uint32_t wg_tile_k_,
uint32_t sg_tile_n_,
uint32_t sg_tile_k_,
uint32_t k_stride_>
struct int4_dequantize_attr_t {
// Workgroup-level tile extents.
static constexpr uint32_t wg_tile_n = wg_tile_n_;
static constexpr uint32_t wg_tile_k = wg_tile_k_;
// Subgroup-level tile extents.
static constexpr uint32_t sg_tile_n = sg_tile_n_;
static constexpr uint32_t sg_tile_k = sg_tile_k_;
static constexpr uint32_t k_stride = k_stride_;
};

} // namespace gpu::xetla::kernel
24 changes: 24 additions & 0 deletions include/experimental/kernel/int4_dequantize/int4_dequantize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*******************************************************************************
* Copyright (c) 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C++ API

#pragma once

#include <experimental/kernel/int4_dequantize/api.hpp>
#include <experimental/kernel/int4_dequantize/config.hpp>
#include <experimental/kernel/int4_dequantize/int4_dequantize_xe_impl.hpp>
Loading

0 comments on commit b0efdf4

Please sign in to comment.