From b13e02ff1fc89991293de345fb8a113e158b7ab2 Mon Sep 17 00:00:00 2001 From: Jianping Chen Date: Fri, 10 May 2024 08:48:22 +0800 Subject: [PATCH] [XETLA] Add dpas attr, refine mma, load, store attr (#242) --- include/common/core/arch_config.hpp | 122 ++++++++++++------ include/common/core/common_types.hpp | 6 - include/common/utils/limitation.hpp | 16 +-- .../group/gemm/compute_policy.hpp | 25 ++-- include/group/gemm/compute_policy.hpp | 89 +++++-------- include/group/gemm/impl/default_fpu_xe.hpp | 2 +- include/group/gemm/impl/default_xmx_xe.hpp | 11 +- include/group/gemm/impl/unaligned_xmx_xe.hpp | 4 +- include/subgroup/tile/impl/fma_xe.hpp | 2 +- include/subgroup/tile/impl/load_xe.hpp | 23 ++-- include/subgroup/tile/impl/mma_xe.hpp | 6 +- include/subgroup/tile/impl/prefetch_xe.hpp | 42 +++--- include/subgroup/tile/impl/store_xe.hpp | 21 +-- 13 files changed, 178 insertions(+), 191 deletions(-) diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 352a87512..1b302b91a 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -27,11 +27,14 @@ namespace gpu::xetla { /// @{ template -struct load_store_attr_t {}; +struct load_store_attr_t { + static constexpr bool has_hw_block_2d = false; +}; template <> struct load_store_attr_t { /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490 + static constexpr bool has_hw_block_2d = true; static constexpr uint32_t max_load_height_in_elem = 32; static constexpr uint32_t max_load_width_in_bytes = 64; static constexpr uint32_t max_trans_load_width_in_bytes = 32; @@ -53,6 +56,7 @@ struct load_store_attr_t { template struct client_load_store_attr_base_t { /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490 + static constexpr bool has_hw_block_2d = false; static constexpr uint32_t max_load_height_in_elem = 32; static constexpr uint32_t max_load_width_in_bytes = 64; static constexpr uint32_t max_trans_load_width_in_bytes = 32; @@ -83,61 +87,103 @@ struct load_store_attr_t msg_type::block_2d, gpu_arch::XeLpg> {}; +template +inline constexpr bool arch_has_2d_load_store = + load_store_attr_t::has_hw_block_2d; + template struct load_store_attr_t { + static constexpr uint32_t max_load_vec_len = 32; + static constexpr uint32_t max_store_vec_len = 32; + static constexpr uint32_t max_prefetch_vec_len = 32; +}; + +template <> +struct load_store_attr_t { static constexpr uint32_t max_load_vec_len = 64; static constexpr uint32_t max_store_vec_len = 64; + static constexpr uint32_t max_prefetch_vec_len = 64; }; -template -struct mma_attr_t {}; +struct dpas_attr_base_t { + static constexpr bool has_xmx = true; + static constexpr uint32_t systolic_depth = 8; + static constexpr uint32_t rcount_max = 8; + static constexpr uint32_t op_per_channel_bits = 32; + static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3); + static constexpr uint32_t op_per_channel_max = 8; +}; template -struct client_mma_atr_base_t { - static constexpr uint32_t mma_m_in_elem = 8; - static constexpr uint32_t mma_n_in_elem = 8; - static constexpr uint32_t mma_k_in_bytes = 32; +struct dpas_attr_t { + static constexpr bool has_xmx = false; }; template <> -struct mma_attr_t { - static constexpr uint32_t mma_m_in_elem = 8; - static constexpr uint32_t mma_n_in_elem = 16; - static constexpr uint32_t mma_k_in_bytes = 32; +struct dpas_attr_t : public dpas_attr_base_t { + static constexpr uint32_t n_fixed_limit = 16; }; template <> -struct mma_attr_t - : public 
client_mma_atr_base_t {}; +struct dpas_attr_t : public dpas_attr_base_t { + static constexpr uint32_t n_fixed_limit = 8; +}; -template -struct register_attr_t {}; +template +inline constexpr bool arch_has_xmx = dpas_attr_t::has_xmx; -template -struct client_register_attr_base_t { - static constexpr uint32_t acc_reg_in_bytes = - (grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64; - static constexpr uint32_t grf_in_bytes = - (grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64; - static constexpr uint32_t reg_in_bytes = 64; +template +struct fpu_attr_t { + static constexpr bool has_fpu = true; }; +template +inline constexpr bool arch_has_fpu = fpu_attr_t::has_fpu; + template -struct register_attr_t { - static constexpr uint32_t acc_reg_in_bytes = - (grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64; - static constexpr uint32_t grf_in_bytes = - (grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64; +struct register_nums_t { + static constexpr uint32_t register_nums = + (grf_num_mode == grf_mode::normal) ? 128 : 256; + static constexpr uint32_t acc_register_nums = + (grf_num_mode == grf_mode::normal) ? 4 : 8; +}; + +template +struct register_bytes_t { static constexpr uint32_t reg_in_bytes = 64; }; -template -struct register_attr_t - : public client_register_attr_base_t {}; +template +struct register_attr_t { + static constexpr uint32_t reg_in_bytes = + register_bytes_t::reg_in_bytes; + static constexpr uint32_t register_nums = + register_nums_t::register_nums; + static constexpr uint32_t acc_register_nums = + register_nums_t::acc_register_nums; + static constexpr uint32_t acc_reg_in_bytes = acc_register_nums * reg_in_bytes; + static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes; +}; -template -struct register_attr_t - : public client_register_attr_base_t {}; +template +struct mma_attr_t {}; + +template +struct mma_attr_t>> { + using dpas_attr = dpas_attr_t; + static constexpr uint32_t mma_m_in_elem = + (m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m; + static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit; + static constexpr uint32_t mma_k_in_bytes = + dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes; +}; + +template +struct mma_attr_t>> { + static constexpr uint32_t mma_m_in_elem = (m > 8) ? 
8 : m; + static constexpr uint32_t mma_n_in_elem = 16; + static constexpr uint32_t mma_k_in_bytes = 32; +}; template struct arch_attr_t {}; @@ -145,12 +191,12 @@ struct arch_attr_t {}; template struct client_arch_attr_base_t { template - using load_store_attr = load_store_attr_t; + using load_store_attr = load_store_attr_t; - template - using register_attr = register_attr_t; + template + using register_attr = register_attr_t; - using mma_attr = mma_attr_t; + using dpas_attr = dpas_attr_t; static constexpr uint32_t max_wg_num = 64; static constexpr uint32_t local_mem_size = 64 * 1024; @@ -164,7 +210,7 @@ struct arch_attr_t { template using register_attr = register_attr_t; - using mma_attr = mma_attr_t; + using dpas_attr = dpas_attr_t; static constexpr uint32_t max_wg_num = 64; static constexpr uint32_t local_mem_size = 128 * 1024; diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp index a971fc86b..2a23a9e5e 100644 --- a/include/common/core/common_types.hpp +++ b/include/common/core/common_types.hpp @@ -22,12 +22,6 @@ namespace gpu::xetla { enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 }; -inline constexpr bool arch_has_xmx(gpu_arch arch) { - return arch >= gpu_arch::XeHpg; -} -inline constexpr bool arch_has_2d_load_store(gpu_arch arch) { - return arch >= gpu_arch::XeHpc; -} enum class grf_mode : uint8_t { normal = 0, double_grf = 1 }; diff --git a/include/common/utils/limitation.hpp b/include/common/utils/limitation.hpp index dcb3b185c..041dd4111 100644 --- a/include/common/utils/limitation.hpp +++ b/include/common/utils/limitation.hpp @@ -91,7 +91,8 @@ class block_2d { ret = ((block_width * block_height * element_size) <= (32 * bytes_per_grf)); XETLA_ASSERT( ret, - "2D Block Loads upto 32 GRFs are can be read but is %u:%u", + "2D Block Loads upto 32 * %u bytes are can be read but is %u:%u", + bytes_per_grf, block_width, block_height); if (!ret) { @@ -318,7 +319,7 @@ class block_2d { static constexpr auto element_size = sizeof(T); static constexpr uint32_t max_24bit = 16 * 1024 * 1024; // 2 ^ 24 static constexpr auto bytes_per_grf = - register_attr_t::reg_in_bytes; + register_attr_t::reg_in_bytes; static inline bool check_base_address(uint64_t base) { bool ret = ((base % 64) == 0); @@ -746,11 +747,8 @@ struct check_store { } // namespace subgroup namespace group { -template -struct gemm {}; - -template -struct gemm> { +template +struct gemm { struct default_fpu { template < typename dtype_a, @@ -802,7 +800,7 @@ struct gemm> { int block_size_y_b> struct check_tile_size_default { static constexpr uint32_t reg_in_bytes = - register_attr_t::reg_in_bytes; + register_attr_t::reg_in_bytes; static constexpr uint32_t simd_len = reg_in_bytes / sizeof(dtype_mma); static_assert( @@ -878,7 +876,7 @@ struct gemm> { int block_size_x_b, int block_size_y_b> struct check_tile_size_default { - using mma_attr = mma_attr_t; + using mma_attr = mma_attr_t; static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; static constexpr int32_t mma_n = mma_attr::mma_n_in_elem; static constexpr int32_t mma_k = diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 2a616877e..79793ebf8 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -62,10 +62,14 @@ struct compute_policy_int4_dequantize< arch_tag_, std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { using compute_attr = compute_attr_; + using dtype_mma_acc = typename 
compute_attr::dtype_acc; + using dtype_mma_a = typename compute_attr::dtype_a; + using dtype_mma_b = typename compute_attr::dtype_b; + using perf_tuning_knob = perf_tuning_knob_; - static constexpr int k_stride = perf_tuning_knob::k_stride; static constexpr int stages = perf_tuning_knob::stages; static constexpr int sync_freq = perf_tuning_knob::sync_freq; + static constexpr int k_stride = perf_tuning_knob::k_stride; static constexpr mma_engine mma_engine = mma_engine_; static constexpr gpu_arch arch_tag = arch_tag_; @@ -73,22 +77,13 @@ struct compute_policy_int4_dequantize< !(mma_engine == mma_engine::xmx && arch_tag == gpu_arch::XeLpg), "XeLpg does not support xmx"); - using dtype_mma_acc = typename compute_attr::dtype_acc; - using dtype_mma_a = typename compute_attr::dtype_a; - using dtype_mma_b = typename compute_attr::dtype_b; - - static constexpr uint32_t block_bytes_x_a = 32; - static constexpr uint32_t block_size_y_a = 16; - static constexpr bool is_int4_matB_policy = true; - static constexpr uint32_t block_size_x_b = (mma_engine == mma_engine::xmx) - ? arch_attr_t::mma_attr::mma_n_in_elem - : 32; - static constexpr uint32_t block_bytes_y_b = 32; - static_assert( - block_bytes_x_a == block_bytes_y_b, - "mat_a x need to match with mat_b y"); + static constexpr uint32_t block_size_y_a = 16; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; + static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem; + static constexpr uint32_t block_bytes_y_b = block_bytes_x_a; static constexpr uint32_t dequant_s = dequant_s_; static_assert( diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index 2d536c3e6..fe02cb758 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -46,77 +46,44 @@ struct compute_policy_default_xmx< compute_attr_, perf_tuning_knob_, arch_tag_, - std::enable_if_t> { - using compute_attr = compute_attr_; - using perf_tuning_knob = perf_tuning_knob_; - static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr int stages = perf_tuning_knob::stages; - static constexpr int sync_freq = perf_tuning_knob::sync_freq; + std::enable_if_t>> { static constexpr gpu_arch arch_tag = arch_tag_; + + using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; using dtype_mma_b = typename compute_attr::dtype_b; - static constexpr uint32_t block_bytes_x_a = 32; + using perf_tuning_knob = perf_tuning_knob_; + static constexpr int stages = perf_tuning_knob::stages; + static constexpr int sync_freq = perf_tuning_knob::sync_freq; + static constexpr int k_stride = perf_tuning_knob::k_stride; + + static constexpr uint32_t block_size_y_a = 16; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_size_y_a = 16; - static constexpr uint32_t block_size_x_b = - arch_attr_t::mma_attr::mma_n_in_elem; - static constexpr uint32_t block_bytes_y_b = 32; - static constexpr uint32_t block_size_y_b = - block_bytes_y_b / sizeof(dtype_mma_b); - static_assert( - block_size_x_a == block_size_y_b, - "mat_a x need to match with mat_b y"); + static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem; + static constexpr uint32_t block_size_y_b = block_size_x_a; + static constexpr uint32_t block_bytes_y_b = + 
block_size_x_a * sizeof(dtype_mma_b); }; /// @brief Compute policy for unaligned shape and xmx engine. /// @tparam compute_attr_ Is compute-related attributes. /// @tparam perf_tuning_knob_ Is performance-related knobs. /// @tparam arch_tag_ Is the HW architecture. -template < - typename compute_attr_, - typename perf_tuning_knob_, - gpu_arch arch_tag_ = gpu_arch::XeHpc, - typename enable = void> -struct compute_policy_unaligned_xmx {}; - /// @brief Specialized for Xe architecture. template < typename compute_attr_, typename perf_tuning_knob_, gpu_arch arch_tag_> -struct compute_policy_unaligned_xmx< - compute_attr_, - perf_tuning_knob_, - arch_tag_, - std::enable_if_t> { - using compute_attr = compute_attr_; - using perf_tuning_knob = perf_tuning_knob_; - static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr int stages = perf_tuning_knob::stages; - static constexpr int sync_freq = perf_tuning_knob::sync_freq; - static constexpr gpu_arch arch_tag = arch_tag_; - using dtype_mma_acc = typename compute_attr::dtype_acc; - using dtype_mma_a = typename compute_attr::dtype_a; - using dtype_mma_b = typename compute_attr::dtype_b; - - static constexpr uint32_t block_bytes_x_a = 32; - static constexpr uint32_t block_size_x_a = - block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_size_y_a = 16; - - static constexpr uint32_t block_size_x_b = - arch_attr_t::mma_attr::mma_n_in_elem; - static constexpr uint32_t block_bytes_y_b = 32; - static constexpr uint32_t block_size_y_b = - block_bytes_y_b / sizeof(dtype_mma_b); - static_assert( - block_size_x_a == block_size_y_b, - "mat_a x need to match with mat_b y"); -}; +struct compute_policy_unaligned_xmx : public compute_policy_default_xmx< + compute_attr_, + perf_tuning_knob_, + arch_tag_> {}; /// @brief Compute policy for fpu engine. /// @tparam compute_attr_ Is compute-related attributes. @@ -138,23 +105,25 @@ struct compute_policy_default_fpu< compute_attr_, perf_tuning_knob_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { - using compute_attr = compute_attr_; - using perf_tuning_knob = perf_tuning_knob_; - static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr int stages = perf_tuning_knob::stages; - static constexpr int sync_freq = perf_tuning_knob::sync_freq; + std::enable_if_t>> { static constexpr gpu_arch arch_tag = arch_tag_; + using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; using dtype_mma_b = typename compute_attr::dtype_b; + using perf_tuning_knob = perf_tuning_knob_; + static constexpr int stages = perf_tuning_knob::stages; + static constexpr int sync_freq = perf_tuning_knob::sync_freq; + static constexpr int k_stride = perf_tuning_knob::k_stride; + + static constexpr uint32_t block_size_y_a = + arch_tag_ == gpu_arch::XeLpg ? 8 : 16; static constexpr uint32_t block_bytes_x_a = 32; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_size_y_a = - arch_tag_ == gpu_arch::XeLpg ? 8 : 16; + static constexpr uint32_t block_bytes_x_b = arch_tag_ == gpu_arch::XeLpg ? 
32 : 64; static constexpr uint32_t block_size_x_b = diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp index 950f016f2..34956a328 100644 --- a/include/group/gemm/impl/default_fpu_xe.hpp +++ b/include/group/gemm/impl/default_fpu_xe.hpp @@ -42,7 +42,7 @@ class gemm_t< mem_desc_a_t_, // memory attribute of matA mem_desc_b_t_, // memory attribute of matB pre_processing_t_, // pre_processing functor - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using mem_desc_a_t = mem_desc_a_t_; using mem_desc_b_t = mem_desc_b_t_; diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp index 50f2db529..ee3bf3c2b 100644 --- a/include/group/gemm/impl/default_xmx_xe.hpp +++ b/include/group/gemm/impl/default_xmx_xe.hpp @@ -42,7 +42,7 @@ class gemm_t< mem_desc_a_t_, // memory attribute of matA mem_desc_b_t_, // memory attribute of matB pre_processing_t_, // pre_processing functor - std::enable_if_t> { + std::enable_if_t>> { public: using mem_desc_a_t = mem_desc_a_t_; using mem_desc_b_t = mem_desc_b_t_; @@ -72,8 +72,8 @@ class gemm_t< using dtype_mma_a = typename compute_policy::dtype_mma_a; using dtype_mma_b = typename compute_policy::dtype_mma_b; - using check_dtype = group::gemm::default_xmx:: - check_dtype_default; + using check_dtype = group::gemm::default_xmx:: + template check_dtype_default; /******** set memory attribute **********/ static constexpr mem_space mem_space_a = mem_desc_a_t::space; @@ -87,7 +87,7 @@ class gemm_t< is_col_major_b ? tdesc_update_dir::x_dir : tdesc_update_dir::y_dir; using check_memory = - group::gemm::default_xmx::check_memory_default< + group::gemm::default_xmx::template check_memory_default< mem_layout_a, mem_layout_b, mem_space_a, @@ -103,6 +103,7 @@ class gemm_t< static constexpr uint32_t tile_size_y_b = k_stride; static constexpr uint32_t tile_size_x_c = sg_tile_n; static constexpr uint32_t tile_size_y_c = sg_tile_m; + static constexpr uint32_t block_size_x_a = compute_policy::block_size_x_a; static constexpr uint32_t block_size_y_a = (compute_policy::block_size_y_a > tile_size_y_a) @@ -112,7 +113,7 @@ class gemm_t< static constexpr uint32_t block_size_y_b = compute_policy::block_size_y_b; using check_tile_size = - group::gemm::default_xmx::check_tile_size_default< + group::gemm::default_xmx::template check_tile_size_default< dtype_mma_a, tile_size_x_a, tile_size_y_a, diff --git a/include/group/gemm/impl/unaligned_xmx_xe.hpp b/include/group/gemm/impl/unaligned_xmx_xe.hpp index 9ee9a2f4c..c43899d65 100644 --- a/include/group/gemm/impl/unaligned_xmx_xe.hpp +++ b/include/group/gemm/impl/unaligned_xmx_xe.hpp @@ -43,7 +43,7 @@ class gemm_t< mem_desc_a_t_, // memory attribute of matA mem_desc_b_t_, // memory attribute of matB pre_processing_t_, // pre_processing functor - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using mem_desc_a_t = mem_desc_a_t_; using mem_desc_b_t = mem_desc_b_t_; @@ -61,7 +61,7 @@ class gemm_t< static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; using work_group_t = typename tile_shape::work_group_t; - constexpr static gpu_arch arch_tag = compute_policy::arch_tag; + constexpr static gpu_arch arch_tag = arch_tag_; static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index 147306cc4..c1ca0c6ff 100644 --- 
a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -37,7 +37,7 @@ struct tile_mma_t< matA_t_, mma_engine::fpu, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using matA_t = matA_t_; using matB_t = matB_t_; using matSrc_t = matAcc_src_t_; diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 62acec7b1..70ddd6d74 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -31,29 +31,24 @@ struct check_load_type { static constexpr bool is_lsc_gather = true; static constexpr bool is_global_block_2d = (payload_t::memory_space == mem_space::global && - (payload_t::message_type == msg_type::block_2d) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::block_2d)); static constexpr bool is_global_block_1d = ((payload_t::memory_space == mem_space::global) && (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1) && - (payload_t::message_type == msg_type::block_1d) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::block_1d)); static constexpr bool is_global_unaligned_2d_xe = ((payload_t::memory_space == mem_space::global) && - (payload_t::message_type == msg_type::unaligned_2d) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::unaligned_2d)); static constexpr bool is_local_scatter_xe = ((payload_t::memory_space == mem_space::local) && - (payload_t::message_type == msg_type::scatter) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::scatter)); static constexpr bool is_local_block_1d_xe = ((payload_t::memory_space == mem_space::local) && - (payload_t::message_type == msg_type::block_1d) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::block_1d)); }; } // namespace detail @@ -79,7 +74,7 @@ template < typename payload_t> __XETLA_API typename std::enable_if_t< detail::check_load_type::is_global_block_2d && - arch_has_2d_load_store(payload_t::arch_tag)> + arch_has_2d_load_store> tile_load(tile_t& tile, payload_t& payload) { using dtype = typename tile_t::dtype; using load_dtype = typename payload_t::mem_dtype; @@ -460,7 +455,7 @@ template < __XETLA_API typename std::enable_if_t< detail::check_load_type::is_global_block_2d && detail::check_load_type::is_lsc_gather && - payload_t::arch_tag <= gpu_arch::XeHpg> + !arch_has_2d_load_store> tile_load(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; using tile_desc = typename payload_t::tile_desc; @@ -569,7 +564,7 @@ template < __XETLA_API typename std::enable_if_t< detail::check_load_type::is_global_block_2d && !detail::check_load_type::is_lsc_gather && - !arch_has_2d_load_store(payload_t::arch_tag)> + !arch_has_2d_load_store> tile_load(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; using tile_desc = typename payload_t::tile_desc; @@ -617,7 +612,9 @@ tile_load(tile_t& tile, payload_t& payload) { SW_BARRIER(); vnni_convert(tile); } + } + /// @brief This function loads data from unaligned-2D memory surface. /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into /// registers. Each block will be loaded serially by its corresponding payload. 
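Note on the load-path dispatch above: the load_xe.hpp hunks drop the per-overload `payload_t::arch_tag <= gpu_arch::XeHpc` guards and instead gate the block-2d overloads on the `arch_has_2d_load_store` variable template introduced in arch_config.hpp. The following is a minimal, self-contained C++ sketch of that dispatch pattern, not part of the patch: the XeTLA tile/payload machinery is omitted, the overload bodies are placeholder prints, and the trait is simplified to a single arch parameter (in the patch it is also keyed on msg_type).

#include <cstdint>
#include <cstdio>
#include <type_traits>

enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };

// Per-arch capability attribute; mirrors the new has_hw_block_2d flag.
template <gpu_arch arch_tag>
struct load_store_attr_t {
  static constexpr bool has_hw_block_2d = false;
};
template <>
struct load_store_attr_t<gpu_arch::XeHpc> {
  static constexpr bool has_hw_block_2d = true;
};

// Variable template replacing the old arch_has_2d_load_store(gpu_arch) function.
template <gpu_arch arch_tag>
inline constexpr bool arch_has_2d_load_store =
    load_store_attr_t<arch_tag>::has_hw_block_2d;

// Overload taken when the architecture has hardware 2D block load/store.
template <gpu_arch arch_tag>
std::enable_if_t<arch_has_2d_load_store<arch_tag>> tile_load() {
  std::puts("block-2d path");
}

// Fallback overload for architectures without 2D block messages.
template <gpu_arch arch_tag>
std::enable_if_t<!arch_has_2d_load_store<arch_tag>> tile_load() {
  std::puts("scatter/1d fallback path");
}

int main() {
  tile_load<gpu_arch::XeHpc>();  // resolves to the block-2d path
  tile_load<gpu_arch::XeLpg>();  // resolves to the fallback path
}

With this shape, supporting a new architecture only requires a load_store_attr_t specialization; the tile_load/tile_store/tile_prefetch overload sets pick the right path without per-overload arch comparisons.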
diff --git a/include/subgroup/tile/impl/mma_xe.hpp b/include/subgroup/tile/impl/mma_xe.hpp index f0808842c..7292a013f 100644 --- a/include/subgroup/tile/impl/mma_xe.hpp +++ b/include/subgroup/tile/impl/mma_xe.hpp @@ -38,7 +38,7 @@ struct tile_mma_t< matA_t_, mma_engine::xmx, arch_tag_, - std::enable_if_t> { + std::enable_if_t>> { using matA_t = matA_t_; using matB_t = matB_t_; using matSrc_t = matAcc_src_t_; @@ -48,8 +48,6 @@ struct tile_mma_t< using dtype_src = typename matSrc_t::dtype; using dtype_dst = typename matDst_t::dtype; - using mma_attr = typename arch_attr_t::mma_attr; - static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; static constexpr uint32_t a_tile_elems = matA_t::tile_elems; @@ -104,8 +102,10 @@ struct tile_mma_t< static constexpr int32_t num_block_m = matDst_t::num_block_y; static constexpr int32_t num_block_k = tile_size_k / block_size_k; + using mma_attr = mma_attr_t; static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; static constexpr int32_t mma_k = mma_attr::mma_k_in_bytes / sizeof(uint32_t); + static_assert( tile_size_m % mma_m == 0, "tile_size_m shoud be a multiple of mma_m"); diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index c13950411..efec6b698 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -29,23 +29,19 @@ template struct check_prefetch_type { static constexpr bool is_global_2d = ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y != 1) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::tile_desc::tile_size_y != 1)); - static constexpr bool is_global_block_1d_xe = + static constexpr bool is_global_block_1d = ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y == 1) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::tile_desc::tile_size_y == 1)); - static constexpr bool is_global_unaligned_2d_xe = + static constexpr bool is_global_unaligned_2d = ((payload_t::memory_space == mem_space::global) && (payload_t::tile_desc::tile_size_y != 1) && - (payload_t::arch_tag -= gpu_arch::XeHpc) && (payload_t::message_type == msg_type::unaligned_2d)); - static constexpr bool is_local_xe = - ((payload_t::memory_space == mem_space::local) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + static constexpr bool is_local = + (payload_t::memory_space == mem_space::local); }; } // namespace detail @@ -66,7 +62,7 @@ template < typename payload_t> __XETLA_API typename std::enable_if_t< detail::check_prefetch_type::is_global_2d && - payload_t::arch_tag == gpu_arch::XeHpc> + arch_has_2d_load_store> tile_prefetch(payload_t& payload) { using dtype = typename payload_t::dtype; static constexpr uint32_t num_tdesc = payload_t::num_tdesc; @@ -95,7 +91,7 @@ template < typename payload_t> __XETLA_API typename std::enable_if_t< detail::check_prefetch_type::is_global_2d && - payload_t::arch_tag <= gpu_arch::XeHpg> + !arch_has_2d_load_store> tile_prefetch(payload_t& payload) { using dtype = typename payload_t::dtype; using tile_desc = typename payload_t::tile_desc; @@ -153,33 +149,31 @@ template < cache_hint L2 = cache_hint::cached, typename payload_t> __XETLA_API typename std::enable_if_t< - detail::check_prefetch_type::is_global_block_1d_xe> + detail::check_prefetch_type::is_global_block_1d> tile_prefetch(payload_t& payload) { using dtype = typename payload_t::dtype; using tile_desc = typename 
payload_t::tile_desc; using prefetch_dtype = typename payload_t::prefetch_dtype; constexpr uint32_t prefetch_len = tile_desc::tile_size_x / payload_t::scale_factor; - // TODO (read from arch register info) - constexpr uint32_t reg_in_bytes = - payload_t::arch_tag == gpu_arch::XeHpc ? 64 : 32; - if constexpr (prefetch_len >= reg_in_bytes) { + constexpr uint32_t max_prefetch_in_bytes = load_store_attr_t::max_prefetch_vec_len; + if constexpr (prefetch_len >= max_prefetch_in_bytes) { #pragma unroll - for (uint32_t j = 0; j < prefetch_len / reg_in_bytes; j++) { - uint32_t offset_x = j * reg_in_bytes * payload_t::scale_factor; + for (uint32_t j = 0; j < prefetch_len / max_prefetch_in_bytes; j++) { + uint32_t offset_x = j * max_prefetch_in_bytes * payload_t::scale_factor; uint32_t address_offset = offset_x * sizeof(dtype); xetla_prefetch_global< prefetch_dtype, - reg_in_bytes, + max_prefetch_in_bytes, data_size::default_size, L1, L2>(payload.base_ptr, payload.base_offset + address_offset); } } - constexpr uint32_t tail_len = prefetch_len % reg_in_bytes; + constexpr uint32_t tail_len = prefetch_len % max_prefetch_in_bytes; uint32_t tail_offset = - prefetch_len / reg_in_bytes * reg_in_bytes * payload_t::scale_factor; - detail::process_1d_tail( + prefetch_len / max_prefetch_in_bytes * max_prefetch_in_bytes * payload_t::scale_factor; + detail::process_1d_tail( payload, tail_offset); } @@ -196,7 +190,7 @@ template < cache_hint L2 = cache_hint::cached, typename payload_t> __XETLA_API typename std::enable_if_t< - detail::check_prefetch_type::is_local_xe> + detail::check_prefetch_type::is_local> tile_prefetch([[maybe_unused]] payload_t& payload) {} } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 9bc69b2c0..a84469d6e 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -30,43 +30,36 @@ template struct check_store_type { static constexpr bool is_global_block_2d = (payload_t::memory_space == mem_space::global && - (payload_t::message_type == msg_type::block_2d) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::block_2d)); static constexpr bool is_global_block_1d_xe = ((payload_t::memory_space == mem_space::global) && (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1) && - (payload_t::message_type == msg_type::block_1d) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::block_1d)); static constexpr bool is_global_unaligned_2d_xe = (payload_t::memory_space == mem_space::global && - (payload_t::message_type == msg_type::unaligned_2d) && - (payload_t::arch_tag == gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::unaligned_2d)); static constexpr bool is_global_atomic_xe = ((payload_t::memory_space == mem_space::global) && - (payload_t::message_type == msg_type::atomic_add) && - (payload_t::arch_tag <= gpu_arch::XeHpc)); + (payload_t::message_type == msg_type::atomic_add)); static constexpr bool is_local_scatter_xe = ((payload_t::memory_space == mem_space::local) && (payload_t::message_type == msg_type::scatter) && - (payload_t::arch_tag <= gpu_arch::XeHpc) && (payload_t::tile_desc::register_layout == reg_layout::tiled || payload_t::tile_desc::register_layout == reg_layout::vnni_tiled)); static constexpr bool is_local_scatter_vnni_col_xe = ((payload_t::memory_space == mem_space::local) && (payload_t::message_type == msg_type::scatter) && - (payload_t::arch_tag == gpu_arch::XeHpc) && 
        (payload_t::tile_desc::register_layout == reg_layout::vnni_tiled_col_major));

   static constexpr bool is_local_block_1d_xe =
       ((payload_t::memory_space == mem_space::local) &&
        (payload_t::message_type == msg_type::block_1d) &&
-       (payload_t::arch_tag <= gpu_arch::XeHpc) &&
        (payload_t::tile_desc::register_layout == reg_layout::tiled));
 };
@@ -91,7 +84,7 @@ template <
     typename payload_t>
 __XETLA_API typename std::enable_if_t<
     detail::check_store_type::is_global_block_2d &&
-    arch_has_2d_load_store(payload_t::arch_tag)>
+    arch_has_2d_load_store>
 tile_store(tile_t& tile, payload_t& payload) {
   using dtype = typename tile_t::dtype;
   using tile_desc = typename tile_t::tile_desc;
@@ -462,7 +455,7 @@ template <
     typename payload_t>
 __XETLA_API typename std::enable_if_t<
     detail::check_store_type::is_global_block_2d &&
-    payload_t::arch_tag <= gpu_arch::XeHpg>
+    !arch_has_2d_load_store>
 tile_store(tile_t& tile, payload_t& payload) {
   using dtype = typename payload_t::dtype;
   using tile_desc = typename payload_t::tile_desc;
@@ -615,7 +608,7 @@ tile_store(
       : 1;
   uint64_t address_offset = offset_x * sizeof(dtype) +
       (sub_block_y + offset_y) * payload.pitch_in_bytes;
-  if constexpr (payload_t::arch_tag >= gpu_arch::XeHpc) {
+  if constexpr (arch_has_2d_load_store) {
     xetla_tatomic_store_global<
         dtype,
         payload_t::num_channel,
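As a closing illustration of the new XMX attribute plumbing in arch_config.hpp (dpas_attr_base_t, the per-arch dpas_attr_t specializations, the arch_has_xmx variable template, and the mma_attr_t that now derives its shape limits from them), here is a minimal, self-contained sketch. It is not part of the patch: the template heads of mma_attr_t were stripped in the extracted text above, so the <arch_tag, m, enable> parameter list is an assumption based on how mma_attr_t is consumed in compute_policy.hpp and mma_xe.hpp, and the mapping XeHpc -> n_fixed_limit 16 / XeHpg -> 8 follows the old mma_attr_t values.

#include <cstdint>
#include <cstdio>
#include <type_traits>

enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };

// DPAS capabilities shared by all XMX-capable architectures (values from the patch).
struct dpas_attr_base_t {
  static constexpr bool has_xmx = true;
  static constexpr uint32_t systolic_depth = 8;
  static constexpr uint32_t rcount_max = 8;
  static constexpr uint32_t op_per_channel_bits = 32;
  static constexpr uint32_t op_per_channel_bytes = op_per_channel_bits >> 3;
};

template <gpu_arch arch_tag>
struct dpas_attr_t {
  static constexpr bool has_xmx = false;  // e.g. XeLpg has no XMX
};
template <>
struct dpas_attr_t<gpu_arch::XeHpc> : dpas_attr_base_t {
  static constexpr uint32_t n_fixed_limit = 16;
};
template <>
struct dpas_attr_t<gpu_arch::XeHpg> : dpas_attr_base_t {
  static constexpr uint32_t n_fixed_limit = 8;
};

template <gpu_arch arch_tag>
inline constexpr bool arch_has_xmx = dpas_attr_t<arch_tag>::has_xmx;

// mma_attr_t derives its shape limits from dpas_attr_t on XMX architectures.
// The <arch_tag, m, enable> signature is assumed (see note above).
template <gpu_arch arch_tag, uint32_t m, typename enable = void>
struct mma_attr_t {};

template <gpu_arch arch_tag, uint32_t m>
struct mma_attr_t<arch_tag, m, std::enable_if_t<arch_has_xmx<arch_tag>>> {
  using dpas_attr = dpas_attr_t<arch_tag>;
  static constexpr uint32_t mma_m_in_elem =
      (m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
  static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
  static constexpr uint32_t mma_k_in_bytes =
      dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;  // 8 * 4 = 32
};

int main() {
  using hpc = mma_attr_t<gpu_arch::XeHpc, 32>;
  using hpg = mma_attr_t<gpu_arch::XeHpg, 4>;
  std::printf("XeHpc: m=%u n=%u k_bytes=%u\n",
              (unsigned)hpc::mma_m_in_elem, (unsigned)hpc::mma_n_in_elem,
              (unsigned)hpc::mma_k_in_bytes);  // 8 16 32
  std::printf("XeHpg: m=%u n=%u k_bytes=%u\n",
              (unsigned)hpg::mma_m_in_elem, (unsigned)hpg::mma_n_in_elem,
              (unsigned)hpg::mma_k_in_bytes);  // 4 8 32
}

compute_policy_default_xmx and mma_xe.hpp in the patch consume exactly these three limits (mma_m_in_elem, mma_n_in_elem, mma_k_in_bytes) in place of the previously hard-coded 8/16/32 constants.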