This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[XETLA] Add dpas attr, refine mma, load, store attr (#242)
JianpingChen066 authored May 10, 2024
1 parent 4929d80 commit b13e02f
Showing 13 changed files with 178 additions and 191 deletions.
122 changes: 84 additions & 38 deletions include/common/core/arch_config.hpp
@@ -27,11 +27,14 @@ namespace gpu::xetla {
/// @{

template <msg_type message_type, gpu_arch arch_tag>
struct load_store_attr_t {};
struct load_store_attr_t {
static constexpr bool has_hw_block_2d = false;
};

template <>
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
static constexpr bool has_hw_block_2d = true;
static constexpr uint32_t max_load_height_in_elem = 32;
static constexpr uint32_t max_load_width_in_bytes = 64;
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
@@ -53,6 +56,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
template <msg_type message_type, gpu_arch arg_tag>
struct client_load_store_attr_base_t {
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
static constexpr bool has_hw_block_2d = false;
static constexpr uint32_t max_load_height_in_elem = 32;
static constexpr uint32_t max_load_width_in_bytes = 64;
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
@@ -83,74 +87,116 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
msg_type::block_2d,
gpu_arch::XeLpg> {};

template <gpu_arch arch_tag>
inline constexpr bool arch_has_2d_load_store =
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;

template <gpu_arch arch_tag>
struct load_store_attr_t<msg_type::block_1d, arch_tag> {
static constexpr uint32_t max_load_vec_len = 32;
static constexpr uint32_t max_store_vec_len = 32;
static constexpr uint32_t max_prefetch_vec_len = 32;
};

template <>
struct load_store_attr_t<msg_type::block_1d, gpu_arch::XeHpc> {
static constexpr uint32_t max_load_vec_len = 64;
static constexpr uint32_t max_store_vec_len = 64;
static constexpr uint32_t max_prefetch_vec_len = 64;
};

template <gpu_arch arch_tag>
struct mma_attr_t {};
struct dpas_attr_base_t {
static constexpr bool has_xmx = true;
static constexpr uint32_t systolic_depth = 8;
static constexpr uint32_t rcount_max = 8;
static constexpr uint32_t op_per_channel_bits = 32;
static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3);
static constexpr uint32_t op_per_channel_max = 8;
};

template <gpu_arch arch_tag>
struct client_mma_atr_base_t {
static constexpr uint32_t mma_m_in_elem = 8;
static constexpr uint32_t mma_n_in_elem = 8;
static constexpr uint32_t mma_k_in_bytes = 32;
struct dpas_attr_t {
static constexpr bool has_xmx = false;
};

template <>
struct mma_attr_t<gpu_arch::XeHpc> {
static constexpr uint32_t mma_m_in_elem = 8;
static constexpr uint32_t mma_n_in_elem = 16;
static constexpr uint32_t mma_k_in_bytes = 32;
struct dpas_attr_t<gpu_arch::XeHpc> : public dpas_attr_base_t {
static constexpr uint32_t n_fixed_limit = 16;
};

template <>
struct mma_attr_t<gpu_arch::XeHpg>
: public client_mma_atr_base_t<gpu_arch::XeHpg> {};
struct dpas_attr_t<gpu_arch::XeHpg> : public dpas_attr_base_t {
static constexpr uint32_t n_fixed_limit = 8;
};

template <grf_mode grf_num_mode, gpu_arch arch_tag>
struct register_attr_t {};
template <gpu_arch arch_tag>
inline constexpr bool arch_has_xmx = dpas_attr_t<arch_tag>::has_xmx;

template <grf_mode grf_num_mode, gpu_arch arch_tag>
struct client_register_attr_base_t {
static constexpr uint32_t acc_reg_in_bytes =
(grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64;
static constexpr uint32_t grf_in_bytes =
(grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64;
static constexpr uint32_t reg_in_bytes = 64;
template <gpu_arch arch_tag>
struct fpu_attr_t {
static constexpr bool has_fpu = true;
};

template <gpu_arch arch_tag>
inline constexpr bool arch_has_fpu = fpu_attr_t<arch_tag>::has_fpu;

template <grf_mode grf_num_mode>
struct register_attr_t<grf_num_mode, gpu_arch::XeHpc> {
static constexpr uint32_t acc_reg_in_bytes =
(grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64;
static constexpr uint32_t grf_in_bytes =
(grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64;
struct register_nums_t {
static constexpr uint32_t register_nums =
(grf_num_mode == grf_mode::normal) ? 128 : 256;
static constexpr uint32_t acc_register_nums =
(grf_num_mode == grf_mode::normal) ? 4 : 8;
};

template <gpu_arch arch_tag>
struct register_bytes_t {
static constexpr uint32_t reg_in_bytes = 64;
};

template <grf_mode grf_num_mode>
struct register_attr_t<grf_num_mode, gpu_arch::XeHpg>
: public client_register_attr_base_t<grf_num_mode, gpu_arch::XeHpg> {};
template <grf_mode grf_num_mode, gpu_arch arch_tag>
struct register_attr_t {
static constexpr uint32_t reg_in_bytes =
register_bytes_t<arch_tag>::reg_in_bytes;
static constexpr uint32_t register_nums =
register_nums_t<grf_num_mode>::register_nums;
static constexpr uint32_t acc_register_nums =
register_nums_t<grf_num_mode>::acc_register_nums;
static constexpr uint32_t acc_reg_in_bytes = acc_register_nums * reg_in_bytes;
static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes;
};

template <grf_mode grf_num_mode>
struct register_attr_t<grf_num_mode, gpu_arch::XeLpg>
: public client_register_attr_base_t<grf_num_mode, gpu_arch::XeLpg> {};
template <gpu_arch arch_tag, uint32_t m, class enable = void>
struct mma_attr_t {};

template <gpu_arch arch_tag, uint32_t m>
struct mma_attr_t<arch_tag, m, std::enable_if_t<arch_has_xmx<arch_tag>>> {
using dpas_attr = dpas_attr_t<arch_tag>;
static constexpr uint32_t mma_m_in_elem =
(m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
static constexpr uint32_t mma_k_in_bytes =
dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;
};

template <gpu_arch arch_tag, uint32_t m>
struct mma_attr_t<arch_tag, m, std::enable_if_t<!arch_has_xmx<arch_tag>>> {
static constexpr uint32_t mma_m_in_elem = (m > 8) ? 8 : m;
static constexpr uint32_t mma_n_in_elem = 16;
static constexpr uint32_t mma_k_in_bytes = 32;
};

template <gpu_arch arch_tag>
struct arch_attr_t {};

template <gpu_arch arch_tag>
struct client_arch_attr_base_t {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpg>;
using load_store_attr = load_store_attr_t<message_type, arch_tag>;

template <grf_mode grf_num_mode = grf_mode::double_grf>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpg>;
template <grf_mode grf_num_mode = grf_mode::normal>
using register_attr = register_attr_t<grf_num_mode, arch_tag>;

using mma_attr = mma_attr_t<gpu_arch::XeHpg>;
using dpas_attr = dpas_attr_t<arch_tag>;

static constexpr uint32_t max_wg_num = 64;
static constexpr uint32_t local_mem_size = 64 * 1024;
Expand All @@ -164,7 +210,7 @@ struct arch_attr_t<gpu_arch::XeHpc> {
template <grf_mode grf_num_mode = grf_mode::double_grf>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpc>;

using mma_attr = mma_attr_t<gpu_arch::XeHpc>;
using dpas_attr = dpas_attr_t<gpu_arch::XeHpc>;

static constexpr uint32_t max_wg_num = 64;
static constexpr uint32_t local_mem_size = 128 * 1024;
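
A minimal usage sketch of the new traits, assuming the repository's include/ directory is on the include path; every asserted value follows from the definitions in arch_config.hpp above:

#include <cstdint>
#include "common/core/arch_config.hpp"

using namespace gpu::xetla;

// XeHpc exposes both XMX (dpas) and hardware 2D block load/store.
static_assert(arch_has_xmx<gpu_arch::XeHpc>, "XeHpc has dpas");
static_assert(arch_has_2d_load_store<gpu_arch::XeHpc>, "XeHpc has 2D block LD/ST");

// mma_attr_t now takes the tile height m and clamps it to rcount_max (8);
// K in bytes is systolic_depth (8) * op_per_channel_bytes (4).
using mma_hpc = mma_attr_t<gpu_arch::XeHpc, /*m=*/16>;
static_assert(mma_hpc::mma_m_in_elem == 8, "m clamped to rcount_max");
static_assert(mma_hpc::mma_n_in_elem == 16, "n_fixed_limit for XeHpc");
static_assert(mma_hpc::mma_k_in_bytes == 32, "8 deep * 4 bytes per channel");

// register_attr_t is now assembled from register_bytes_t and register_nums_t.
using regs = register_attr_t<grf_mode::double_grf, gpu_arch::XeHpc>;
static_assert(regs::grf_in_bytes == 256 * 64, "double GRF: 256 regs * 64 B");
static_assert(regs::acc_reg_in_bytes == 8 * 64, "8 accumulator regs * 64 B");
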
6 changes: 0 additions & 6 deletions include/common/core/common_types.hpp
@@ -22,12 +22,6 @@

namespace gpu::xetla {
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
inline constexpr bool arch_has_xmx(gpu_arch arch) {
return arch >= gpu_arch::XeHpg;
}
inline constexpr bool arch_has_2d_load_store(gpu_arch arch) {
return arch >= gpu_arch::XeHpc;
}

enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };

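
These function-style helpers are superseded by the arch_has_xmx and arch_has_2d_load_store variable templates added to arch_config.hpp. A sketch of a hypothetical call site under the new form (use_xmx_path is an illustrative name, not part of the commit):

#include "common/core/arch_config.hpp"

using namespace gpu::xetla;

// Before this change a call site would query the enum value directly:
//   if constexpr (arch_has_xmx(arch_tag)) { ... }
// Afterwards the same query reads a per-arch trait specialization instead of
// comparing enum order:
template <gpu_arch arch_tag>
inline constexpr bool use_xmx_path =
    arch_has_xmx<arch_tag> && arch_has_2d_load_store<arch_tag>;

static_assert(use_xmx_path<gpu_arch::XeHpc>, "XeHpc takes the dpas path");
static_assert(!use_xmx_path<gpu_arch::XeLpg>, "XeLpg falls back to fpu");
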
16 changes: 7 additions & 9 deletions include/common/utils/limitation.hpp
@@ -91,7 +91,8 @@ class block_2d {
ret = ((block_width * block_height * element_size) <= (32 * bytes_per_grf));
XETLA_ASSERT(
ret,
"2D Block Loads upto 32 GRFs are can be read but is %u:%u",
"2D Block Loads upto 32 * %u bytes are can be read but is %u:%u",
bytes_per_grf,
block_width,
block_height);
if (!ret) {
@@ -318,7 +319,7 @@ class block_2d {
static constexpr auto element_size = sizeof(T);
static constexpr uint32_t max_24bit = 16 * 1024 * 1024; // 2 ^ 24
static constexpr auto bytes_per_grf =
register_attr_t<grf_mode::double_grf, gpu_arch::XeHpc>::reg_in_bytes;
register_attr_t<grf_mode::double_grf, arch_tag>::reg_in_bytes;

static inline bool check_base_address(uint64_t base) {
bool ret = ((base % 64) == 0);
@@ -746,11 +747,8 @@ struct check_store {
} // namespace subgroup

namespace group {
template <gpu_arch arch = gpu_arch::XeHpc, class enable = void>
struct gemm {};

template <gpu_arch arch>
struct gemm<arch, std::enable_if_t<(arch <= gpu_arch::XeHpc)>> {
template <gpu_arch arch = gpu_arch::XeHpc>
struct gemm {
struct default_fpu {
template <
typename dtype_a,
@@ -802,7 +800,7 @@ struct gemm<arch, std::enable_if_t<(arch <= gpu_arch::XeHpc)>> {
int block_size_y_b>
struct check_tile_size_default {
static constexpr uint32_t reg_in_bytes =
register_attr_t<grf_mode::double_grf, gpu_arch::XeHpc>::reg_in_bytes;
register_attr_t<grf_mode::double_grf, arch>::reg_in_bytes;
static constexpr uint32_t simd_len = reg_in_bytes / sizeof(dtype_mma);

static_assert(
@@ -878,7 +876,7 @@ struct gemm<arch, std::enable_if_t<(arch <= gpu_arch::XeHpc)>> {
int block_size_x_b,
int block_size_y_b>
struct check_tile_size_default {
using mma_attr = mma_attr_t<gpu_arch::XeHpc>;
using mma_attr = mma_attr_t<arch, block_size_y_a>;
static constexpr int32_t mma_m = mma_attr::mma_m_in_elem;
static constexpr int32_t mma_n = mma_attr::mma_n_in_elem;
static constexpr int32_t mma_k =
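
The 2D block-load size check above now scales with the architecture's GRF width via register_attr_t. A minimal sketch of the limit it enforces, assuming gpu_arch::XeHpc and an illustrative 32 x 32 uint16_t block:

#include <cstdint>
#include "common/core/arch_config.hpp"

using namespace gpu::xetla;

// One GRF holds reg_in_bytes bytes (64 on the architectures listed above), and
// a single 2D block load may cover at most 32 GRFs of data.
constexpr uint32_t bytes_per_grf =
    register_attr_t<grf_mode::double_grf, gpu_arch::XeHpc>::reg_in_bytes;
constexpr uint32_t max_2d_load_bytes = 32 * bytes_per_grf; // 2048 B

// A 32 x 32 block of uint16_t elements fits exactly within that limit.
constexpr uint32_t block_width = 32, block_height = 32;
constexpr uint32_t block_bytes = block_width * block_height * sizeof(uint16_t);
static_assert(block_bytes <= max_2d_load_bytes, "block fits in 32 GRFs");
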
25 changes: 10 additions & 15 deletions include/experimental/group/gemm/compute_policy.hpp
@@ -62,33 +62,28 @@ struct compute_policy_int4_dequantize<
arch_tag_,
std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> {
using compute_attr = compute_attr_;
using dtype_mma_acc = typename compute_attr::dtype_acc;
using dtype_mma_a = typename compute_attr::dtype_a;
using dtype_mma_b = typename compute_attr::dtype_b;

using perf_tuning_knob = perf_tuning_knob_;
static constexpr int k_stride = perf_tuning_knob::k_stride;
static constexpr int stages = perf_tuning_knob::stages;
static constexpr int sync_freq = perf_tuning_knob::sync_freq;
static constexpr int k_stride = perf_tuning_knob::k_stride;
static constexpr mma_engine mma_engine = mma_engine_;
static constexpr gpu_arch arch_tag = arch_tag_;

static_assert(
!(mma_engine == mma_engine::xmx && arch_tag == gpu_arch::XeLpg),
"XeLpg does not support xmx");

using dtype_mma_acc = typename compute_attr::dtype_acc;
using dtype_mma_a = typename compute_attr::dtype_a;
using dtype_mma_b = typename compute_attr::dtype_b;

static constexpr uint32_t block_bytes_x_a = 32;
static constexpr uint32_t block_size_y_a = 16;

static constexpr bool is_int4_matB_policy = true;

static constexpr uint32_t block_size_x_b = (mma_engine == mma_engine::xmx)
? arch_attr_t<arch_tag>::mma_attr::mma_n_in_elem
: 32;
static constexpr uint32_t block_bytes_y_b = 32;
static_assert(
block_bytes_x_a == block_bytes_y_b,
"mat_a x need to match with mat_b y");
static constexpr uint32_t block_size_y_a = 16;
using mma_attr = mma_attr_t<arch_tag_, block_size_y_a>;
static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
static constexpr uint32_t block_bytes_y_b = block_bytes_x_a;

static constexpr uint32_t dequant_s = dequant_s_;
static_assert(
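
Blocking for the int4-dequantize policy is now derived from mma_attr_t rather than hard-coded constants. A minimal sketch of the resulting values, assuming arch_tag_ is gpu_arch::XeHpc and block_size_y_a = 16 as set above:

#include <cstdint>
#include "common/core/arch_config.hpp"

using namespace gpu::xetla;

constexpr uint32_t block_size_y_a = 16;
using mma = mma_attr_t<gpu_arch::XeHpc, block_size_y_a>;

static_assert(mma::mma_k_in_bytes == 32, "becomes block_bytes_x_a");
static_assert(mma::mma_n_in_elem == 16, "becomes block_size_x_b");
// block_bytes_y_b is defined as block_bytes_x_a, so mat_a's K bytes and
// mat_b's Y bytes stay matched by construction.
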