diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 8c7c56463..a4d487ff1 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -29,6 +29,7 @@ namespace gpu::xetla { template struct load_store_attr_t { static constexpr bool has_hw_block_2d = false; + static constexpr bool has_block_1d = false; }; template <> @@ -93,6 +94,7 @@ inline constexpr bool arch_has_2d_load_store = template struct load_store_attr_t { + static constexpr bool has_block_1d = true; static constexpr uint32_t max_load_vec_len = 32; static constexpr uint32_t max_store_vec_len = 32; static constexpr uint32_t max_prefetch_vec_len = 32; @@ -100,11 +102,16 @@ struct load_store_attr_t { template <> struct load_store_attr_t { + static constexpr bool has_block_1d = true; static constexpr uint32_t max_load_vec_len = 64; static constexpr uint32_t max_store_vec_len = 64; static constexpr uint32_t max_prefetch_vec_len = 64; }; +template +inline constexpr bool arch_has_1d_load_store = + load_store_attr_t::has_block_1d; + struct dpas_attr_base_t { static constexpr bool has_xmx = true; static constexpr uint32_t systolic_depth = 8; @@ -112,6 +119,7 @@ struct dpas_attr_base_t { static constexpr uint32_t op_per_channel_bits = 32; static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3); static constexpr uint32_t op_per_channel_max = 8; + static constexpr uint32_t k_in_bytes = systolic_depth * op_per_channel_bytes; }; template @@ -121,12 +129,12 @@ struct dpas_attr_t { template <> struct dpas_attr_t : public dpas_attr_base_t { - static constexpr uint32_t n_fixed_limit = 16; + static constexpr uint32_t n_in_elem = 16; }; template <> struct dpas_attr_t : public dpas_attr_base_t { - static constexpr uint32_t n_fixed_limit = 8; + static constexpr uint32_t n_in_elem = 8; }; template @@ -140,16 +148,9 @@ struct fpu_attr_t { template inline constexpr bool arch_has_fpu = fpu_attr_t::has_fpu; -template -struct register_nums_t { - static constexpr uint32_t register_nums = - (grf_num_mode == grf_mode::normal) ? 128 : 256; - static constexpr uint32_t acc_register_nums = - (grf_num_mode == grf_mode::normal) ? 4 : 8; -}; - template struct register_bytes_t; + template <> struct register_bytes_t { static constexpr uint32_t reg_in_bytes = 64; @@ -163,6 +164,14 @@ struct register_bytes_t { static constexpr uint32_t reg_in_bytes = 32; }; +template +struct register_nums_t { + static constexpr uint32_t register_nums = + (grf_num_mode == grf_mode::normal) ? 128 : 256; + static constexpr uint32_t acc_register_nums = + (grf_num_mode == grf_mode::normal) ? 4 : 8; +}; + template struct register_attr_t { static constexpr uint32_t reg_in_bytes = @@ -175,24 +184,48 @@ struct register_attr_t { static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes; }; -template +template < + gpu_arch arch_tag, + mma_engine engine_type, + uint32_t m, + class enable = void> struct mma_attr_t {}; template -struct mma_attr_t>> { +struct mma_attr_t< + arch_tag, + mma_engine::xmx, + m, + std::enable_if_t>> { using dpas_attr = dpas_attr_t; static constexpr uint32_t mma_m_in_elem = (m > dpas_attr::rcount_max) ? 
dpas_attr::rcount_max : m; - static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit; - static constexpr uint32_t mma_k_in_bytes = - dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes; + static constexpr uint32_t blk_m_in_elem = 16; + + static constexpr uint32_t mma_n_in_elem = dpas_attr::n_in_elem; + [[maybe_unused]] static constexpr uint32_t blk_n_in_bytes = 64; + + static constexpr uint32_t mma_k_in_bytes = dpas_attr::k_in_bytes; + static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes; }; template -struct mma_attr_t>> { +struct mma_attr_t< + arch_tag, + mma_engine::fpu, + m, + std::enable_if_t>> { + using load_store_attr = load_store_attr_t; static constexpr uint32_t mma_m_in_elem = (m > 8) ? 8 : m; - static constexpr uint32_t mma_n_in_elem = 16; + static constexpr uint32_t blk_m_in_elem = 16; + static constexpr uint32_t mma_k_in_bytes = 32; + static constexpr uint32_t blk_k_in_bytes = + load_store_attr::max_trans_load_width_in_bytes; + + [[maybe_unused]] static constexpr uint32_t mma_n_in_elem = 16; + static constexpr uint32_t blk_n_in_bytes = + register_bytes_t::reg_in_bytes; }; template @@ -210,6 +243,8 @@ struct arch_attr_t { static constexpr uint32_t max_wg_num = 64; static constexpr uint32_t local_mem_size = 128 * 1024; + static constexpr bool has_named_barrier = true; + static constexpr bool has_atomic_add = true; }; template <> @@ -222,8 +257,11 @@ struct arch_attr_t { using dpas_attr = dpas_attr_t; - static constexpr uint32_t max_wg_num = 64; + static constexpr uint32_t max_wg_num = 32; static constexpr uint32_t local_mem_size = 64 * 1024; + + static constexpr bool has_named_barrier = false; + static constexpr bool has_atomic_add = true; }; template <> @@ -236,10 +274,20 @@ struct arch_attr_t { using dpas_attr = dpas_attr_t; - static constexpr uint32_t max_wg_num = 64; + static constexpr uint32_t max_wg_num = 32; static constexpr uint32_t local_mem_size = 64 * 1024; + static constexpr bool has_named_barrier = false; + static constexpr bool has_atomic_add = true; }; +template +inline constexpr bool arch_has_named_barrier = + arch_attr_t::has_named_barrier; + +template +inline constexpr bool arch_has_atomic_add = + arch_attr_t::has_atomic_add; + /// @} xetla_core_arch_config } // namespace gpu::xetla diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp index 2a23a9e5e..07c725b28 100644 --- a/include/common/core/common_types.hpp +++ b/include/common/core/common_types.hpp @@ -23,7 +23,12 @@ namespace gpu::xetla { enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 }; +template +inline constexpr bool valid_xe_arch_tag = (arch_tag <= gpu_arch::XeHpc); + enum class grf_mode : uint8_t { normal = 0, double_grf = 1 }; enum class mem_layout : uint8_t { row_major = 0, col_major = 1 }; + +enum class mma_engine : uint8_t { xmx = 0, fpu = 1 }; } // namespace gpu::xetla diff --git a/include/common/utils/common.hpp b/include/common/utils/common.hpp index 2fd47ad89..2da30af31 100644 --- a/include/common/utils/common.hpp +++ b/include/common/utils/common.hpp @@ -51,7 +51,7 @@ constexpr uint32_t get_element_size_code() { enum class lsc_action : uint8_t { prefetch, load, store, atomic }; template -constexpr std::enable_if_t +constexpr std::enable_if_t, void> check_lsc_cache_hint() { if constexpr (Action == lsc_action::prefetch) { // https://gfxspecs.intel.com/Predator/Home/Index/53560 @@ -153,7 +153,7 @@ get_prefetch_cache_hint_code() { } template -constexpr std::enable_if_t +constexpr std::enable_if_t 
get_store_cache_hint_code() { check_lsc_cache_hint(); if (L1H == cache_hint::none && L2H == cache_hint::none) { @@ -286,7 +286,6 @@ enum class store_op : uint8_t { scattered_transpose = 3, block_1d = 4 }; -enum class mma_engine : uint8_t { xmx = 0, fpu = 1 }; // enum class trans_mode : uint8_t { none = 0, transpose = 1 }; enum class memory_op : uint8_t { load = 0, store = 1 }; enum class tdesc_update_dir : uint8_t { x_dir = 0, y_dir = 1 }; diff --git a/include/common/utils/limitation.hpp b/include/common/utils/limitation.hpp index 041dd4111..46f840ed8 100644 --- a/include/common/utils/limitation.hpp +++ b/include/common/utils/limitation.hpp @@ -747,7 +747,7 @@ struct check_store { } // namespace subgroup namespace group { -template +template struct gemm { struct default_fpu { template < @@ -876,7 +876,7 @@ struct gemm { int block_size_x_b, int block_size_y_b> struct check_tile_size_default { - using mma_attr = mma_attr_t; + using mma_attr = mma_attr_t; static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; static constexpr int32_t mma_n = mma_attr::mma_n_in_elem; static constexpr int32_t mma_k = diff --git a/include/common/utils/raw_send_load_store.hpp b/include/common/utils/raw_send_load_store.hpp index 4c85a7fd7..06d55ece2 100644 --- a/include/common/utils/raw_send_load_store.hpp +++ b/include/common/utils/raw_send_load_store.hpp @@ -219,12 +219,12 @@ __XETLA_API void xetla_update_tdesc_offsety( template < typename Ty, uint32_t N, - cache_hint L1H = cache_hint::none, - cache_hint L2H = cache_hint::none, - bool transpose = false, - bool transform = false, - gpu_arch arch_tag = gpu_arch::XeHpc> -__XETLA_API std::enable_if_t> + cache_hint L1H, + cache_hint L2H, + bool transpose, + bool transform, + gpu_arch arch_tag > +__XETLA_API std::enable_if_t, xetla_vector> xetla_tload_global(xetla_tdescriptor tdesc) { DEBUG_INVOKE( dbg_level::core, @@ -273,10 +273,10 @@ xetla_tload_global(xetla_tdescriptor tdesc) { template < typename Ty, uint32_t N, - cache_hint L1H = cache_hint::none, - cache_hint L2H = cache_hint::none, - gpu_arch arch_tag = gpu_arch::XeHpc> -__XETLA_API std::enable_if_t + cache_hint L1H, + cache_hint L2H, + gpu_arch arch_tag> +__XETLA_API std::enable_if_t, void> xetla_tstore_global(xetla_tdescriptor tdesc, xetla_vector data) { DEBUG_INVOKE( dbg_level::core, core::block_2d::check_store(tdesc)); @@ -310,10 +310,10 @@ xetla_tstore_global(xetla_tdescriptor tdesc, xetla_vector data) { /// template < typename Ty, - cache_hint L1H = cache_hint::cached, - cache_hint L2H = cache_hint::cached, - gpu_arch arch_tag = gpu_arch::XeHpc> -__XETLA_API std::enable_if_t + cache_hint L1H, + cache_hint L2H, + gpu_arch arch_tag> +__XETLA_API std::enable_if_t, void> xetla_tprefetch_global(xetla_tdescriptor tdesc) { uint32_t msg_desc = 3; msg_desc |= 0 << 7; @@ -350,12 +350,12 @@ xetla_tprefetch_global(xetla_tdescriptor tdesc) { template < typename Ty, uint32_t N, - cache_hint L1H = cache_hint::none, - cache_hint L2H = cache_hint::none, + cache_hint L1H, + cache_hint L2H, atomic_op Op, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename Toffset = uint32_t> -__XETLA_API std::enable_if_t +__XETLA_API std::enable_if_t, void> xetla_tatomic_store_global( uint64_t base_address, xetla_vector offset, diff --git a/include/common/utils/raw_send_nbarrier.hpp b/include/common/utils/raw_send_nbarrier.hpp index a1b17be02..8c40d2480 100644 --- a/include/common/utils/raw_send_nbarrier.hpp +++ b/include/common/utils/raw_send_nbarrier.hpp @@ -41,9 +41,9 @@ enum class nbarrier_role : uint8_t { /// as 
consumer. /// template < - uint8_t num_producers = 1, - uint8_t num_consumers = 1, - gpu_arch arch_tag = gpu_arch::XeHpc, + uint8_t num_producers, + uint8_t num_consumers, + gpu_arch arch_tag, typename enable = void> struct xetla_nbarrier_t; @@ -52,7 +52,7 @@ struct xetla_nbarrier_t< num_producers, num_consumers, arch_tag, - std::enable_if_t> { + std::enable_if_t>> { /// /// @brief Description of named barrier objection. /// Structure is defined in @@ -105,20 +105,7 @@ struct xetla_nbarrier_t< num_producers, num_consumers, arch_tag, - std::enable_if_t> { - /// - /// @brief Description of named barrier objection. - /// Structure is defined in - /// [here](https://gfxspecs.intel.com/Predator/Home/Index/57499). - /// - // xetla_vector nbar; - // uint32_t barrier_id; - - /// @param role is the role of subgroup when participating the barrier. - /// @param nbarrier_id [in] is the id of the barrier. - /// note: all subgroups participating the barrier should have the same - /// barrier_id. Here is the bspec link - /// https://gfxspecs.intel.com/Predator/Home/Index/54006 + std::enable_if_t>> { __XETLA_API void init_nbarrier(uint8_t, nbarrier_role) {} /// @brief Generic work-group split barrier. @@ -127,14 +114,10 @@ struct xetla_nbarrier_t< __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::signal>(); } - /// @brief named barrier wait within subgroup. - /// __XETLA_API void wait() { __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::wait>(); } - /// @brief named barrier signal from subgroup. - /// __XETLA_API void arrive_wait() { arrive(); wait(); diff --git a/include/experimental/group/dropout_mask_gen.hpp b/include/experimental/group/dropout_mask_gen.hpp index fce6d2366..b8bb03c54 100644 --- a/include/experimental/group/dropout_mask_gen.hpp +++ b/include/experimental/group/dropout_mask_gen.hpp @@ -39,8 +39,8 @@ template < uint32_t wg_tile_m_, uint32_t sg_tile_n_, uint32_t sg_tile_m_, - uint32_t random_simd_ = 16, - gpu_arch arch_ = gpu_arch::XeHpc> + uint32_t random_simd_, + gpu_arch arch_> struct mask_gen_t { using dtype_mask = dtype_mask_; static constexpr uint32_t wg_tile_n = wg_tile_n_; @@ -64,8 +64,7 @@ struct mask_gen_t { float dropout_prob; }; - using load_store_attr = - typename arch_attr_t::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_store_width_in_bytes = load_store_attr::max_store_width_in_bytes; static constexpr uint32_t max_store_width_in_elem = @@ -99,7 +98,7 @@ struct mask_gen_t { mem_desc_t, mask_out_tile_desc_t, (sg_tile_m == 1) ? 
msg_type::block_1d : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_>; static constexpr uint32_t tile_size = tile_size_x * tile_size_y; /// @brief diff --git a/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp b/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp index 8f5d7da1e..eb13d418f 100644 --- a/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp +++ b/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp @@ -60,7 +60,7 @@ template < typename dtype_out_, typename dtype_acc_, typename layer_norm_attr_, - gpu_arch arch_ = gpu_arch::XeHpc> + gpu_arch arch_> struct ln_fwd_fused_op_t {}; /// @brief @@ -77,7 +77,7 @@ template < typename dtype_out_, typename dtype_acc_, typename layer_norm_attr_, - gpu_arch arch_ = gpu_arch::XeHpc> + gpu_arch arch_> struct ln_bwd_fused_op_t {}; } // namespace group diff --git a/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp b/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp index 04e621447..de88736bf 100644 --- a/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp +++ b/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp @@ -136,6 +136,7 @@ struct ln_bwd_fused_op_t< dtype_acc_, layer_norm_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr ln_bwd_fused_kind fused_op_kind = ln_bwd_fused_kind::bias_dropout_resAdd_ln; using dtype_acc = dtype_acc_; @@ -164,13 +165,13 @@ struct ln_bwd_fused_op_t< mem_desc_t, ln_bwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using mask_in_t = subgroup::tile_t; using mask_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_bwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; dx_resAdd_out_t dx_resAdd_out; dx_resAdd_out_payload_t dx_resAdd_out_payload; mask_in_t mask_in; @@ -279,6 +280,7 @@ struct ln_bwd_fused_op_t< dtype_acc_, layer_norm_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr ln_bwd_fused_kind fused_op_kind = ln_bwd_fused_kind::ln_dropout_gradAdd; using dtype_acc = dtype_acc_; @@ -307,7 +309,7 @@ struct ln_bwd_fused_op_t< mem_desc_t, ln_bwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using mask_in_t = subgroup::tile_t; using mask_in_payload_t = subgroup::mem_payload_t< mem_desc_t, @@ -415,6 +417,7 @@ struct ln_bwd_fused_op_t< dtype_acc_, layer_norm_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr ln_bwd_fused_kind fused_op_kind = ln_bwd_fused_kind::ln_dropout; using dtype_acc = dtype_acc_; @@ -439,7 +442,7 @@ struct ln_bwd_fused_op_t< mem_desc_t, ln_bwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; mask_in_t mask_in; mask_in_payload_t mask_in_payload; diff --git a/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp b/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp index 3357dac5d..8e0aa8dd4 100644 --- a/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp +++ b/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp @@ -130,6 +130,7 @@ struct ln_fwd_fused_op_t< dtype_acc_, layer_norm_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr ln_fwd_fused_kind fused_op_kind = ln_fwd_fused_kind::bias_dropout_resAdd_ln; using dtype_acc = dtype_acc_; @@ -161,26 +162,26 @@ struct ln_fwd_fused_op_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; 
using res_in_t = subgroup::tile_t; using res_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using mask_in_t = subgroup::tile_t; using mask_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using bias_dropout_res_out_t = subgroup::tile_t; using bias_dropout_res_out_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; bias_in_t bias_in; bias_in_payload_t bias_in_payload; bias_dropout_res_out_t bias_dropout_res_out; @@ -312,6 +313,7 @@ struct ln_fwd_fused_op_t< dtype_acc_, layer_norm_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr ln_fwd_fused_kind fused_op_kind = ln_fwd_fused_kind::ln_dropout; using dtype_acc = dtype_acc_; @@ -343,7 +345,7 @@ struct ln_fwd_fused_op_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; mask_in_t mask_in; mask_in_payload_t mask_in_payload; uint32_t mask_ld; @@ -422,6 +424,7 @@ struct ln_fwd_fused_op_t< dtype_acc_, layer_norm_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr ln_fwd_fused_kind fused_op_kind = ln_fwd_fused_kind::bias_rng_dropout_resAdd_ln; using dtype_acc = dtype_acc_; @@ -449,26 +452,26 @@ struct ln_fwd_fused_op_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using res_in_t = subgroup::tile_t; using res_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using mask_out_t = subgroup::tile_t; using mask_out_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using bias_dropout_res_out_t = subgroup::tile_t; using bias_dropout_res_out_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; bias_in_t bias_in; bias_in_payload_t bias_in_payload; @@ -611,6 +614,7 @@ struct ln_fwd_fused_op_t< dtype_acc_, layer_norm_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr ln_fwd_fused_kind fused_op_kind = ln_fwd_fused_kind::ln_rng_dropout; using dtype_acc = dtype_acc_; @@ -641,7 +645,7 @@ struct ln_fwd_fused_op_t< mem_desc_t, mask_out_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; mask_out_t mask_out; mask_out_payload_t mask_out_payload; dropout_fwd_t dropout_fwd; diff --git a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp index 611a09b8c..7f14bcc06 100644 --- a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp +++ b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp @@ -91,6 +91,7 @@ struct row_reduction_fused_op_t< dtype_acc_, reduction_attr_, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr reduction_fused_kind fused_op_kind = reduction_fused_kind::bias_gelu_w_bwd; using dtype_in = dtype_in_; @@ -143,13 +144,13 @@ struct row_reduction_fused_op_t< mem_desc_t, dgelu_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using dgelu_x_out_t = subgroup::tile_t; using dgelu_x_out_payload_t = subgroup::mem_payload_t< mem_desc_t, dgelu_tile_desc_t, msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; dgelu_w_in_t dgelu_w_in; 
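// Illustrative sketch, not part of the patch: the recurring change in these
// fused-op hunks. Each XeHpc specialization now names its architecture once
// and threads that constant through every payload type instead of repeating
// gpu_arch::XeHpc. my_mem_desc_t / my_tile_desc_t are placeholder names.
struct example_fused_op_t {
  static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
  using in_payload_t = subgroup::mem_payload_t<
      my_mem_desc_t,
      my_tile_desc_t,
      msg_type::block_1d,
      arch_tag>;  // was: gpu_arch::XeHpc hard-coded at every use site
};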
dgelu_w_in_payload_t dgelu_w_in_payload(w_load_base_desc); subgroup::tile_load(dgelu_w_in, dgelu_w_in_payload); @@ -188,6 +189,7 @@ struct row_reduction_fused_op_t< gpu_arch::XeHpc> { static constexpr reduction_fused_kind fused_op_kind = reduction_fused_kind::bias_dropout_bwd; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; using dtype_in = dtype_in_; using dtype_out = dtype_out_; using dtype_acc = dtype_acc_; @@ -238,14 +240,14 @@ struct row_reduction_fused_op_t< mem_desc_t, reduction_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using dropout_bwd_out_t = subgroup::tile_t; using dropout_bwd_out_payload_t = subgroup::mem_payload_t< mem_desc_t, reduction_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; if (dropout_prob != 0) { mask_in_t mask_in; mask_in_payload_t mask_in_payload(mask_load_base_desc); diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 8d7ffe33d..55bb1e03b 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -36,8 +36,8 @@ template < typename dtype_zero_pt_, quant_mode quant_type_, int dequant_s_, - mma_engine mma_engine_ = mma_engine::xmx, - gpu_arch arch_tag_ = gpu_arch::XeHpc, + mma_engine mma_engine_, + gpu_arch arch_tag_, typename enable = void> struct compute_policy_int4_dequantize {}; @@ -60,7 +60,7 @@ struct compute_policy_int4_dequantize< dequant_s_, mma_engine_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; @@ -74,8 +74,9 @@ struct compute_policy_int4_dequantize< static constexpr gpu_arch arch_tag = arch_tag_; static_assert( - !(mma_engine == mma_engine::xmx && arch_tag == gpu_arch::XeLpg), - "XeLpg does not support xmx"); + ((mma_engine == mma_engine::xmx && arch_has_xmx) || + (mma_engine == mma_engine::fpu && arch_has_fpu)), + "arch does not support mma_engine specified"); static constexpr bool is_int4_matB_policy = true; @@ -88,15 +89,17 @@ struct compute_policy_int4_dequantize< static constexpr quant_mode quant_type = quant_type_; static constexpr uint32_t block_size_y_a = 16; - using mma_attr = mma_attr_t; - static constexpr uint32_t block_bytes_x_a = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 32; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_bytes_x_a = (mma_engine == mma_engine::xmx) + ? mma_attr::mma_k_in_bytes + : mma_attr::blk_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_size_x_b = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_n_in_elem : 32; - static constexpr uint32_t block_bytes_y_b = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 32; + + static constexpr uint32_t block_size_x_b = (mma_engine == mma_engine::xmx) + ? 
mma_attr::mma_n_in_elem + : mma_attr::blk_n_in_bytes / sizeof(dtype_mma_b); + static constexpr uint32_t block_bytes_y_b = block_bytes_x_a; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 528911fcf..cf3ab0173 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -285,7 +285,8 @@ class gemm_t< public: static constexpr uint32_t barrier_count = - enable_periodic_sync ? barrier_count_x + barrier_count_y : 0; + enable_periodic_sync && arch_has_named_barrier ? + barrier_count_x + barrier_count_y : 0; // current only support matA from slm static constexpr uint32_t slm_size = is_local_a ? sg_tile_m * wg_size_y * k_stride * sizeof(dtype_a) : 0; @@ -482,7 +483,7 @@ class gemm_t< if constexpr (wg_size_x > 1) { nbarrier_a.arrive(); } - if constexpr (arch_tag >= gpu_arch::XeHpc) { + if constexpr (arch_has_named_barrier) { if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); } @@ -556,7 +557,7 @@ class gemm_t< if constexpr (wg_size_x > 1) { nbarrier_a.wait(); } - if constexpr (arch_tag >= gpu_arch::XeHpc) { + if constexpr (arch_has_named_barrier) { if constexpr (wg_size_y > 1) { nbarrier_b.wait(); } diff --git a/include/experimental/group/reduction/reduction_api.hpp b/include/experimental/group/reduction/reduction_api.hpp index 5e487cedf..8e155d591 100644 --- a/include/experimental/group/reduction/reduction_api.hpp +++ b/include/experimental/group/reduction/reduction_api.hpp @@ -45,8 +45,8 @@ template < uint32_t row_size, uint32_t wg_size_x, uint32_t wg_size_y, - uint32_t max_simd_len = 32, - gpu_arch arch_ = gpu_arch::XeHpc> + uint32_t max_simd_len, + gpu_arch arch_> struct group_row_reduce_store_t {}; } // namespace gpu::xetla::group diff --git a/include/experimental/group/reduction/row_reduce_store_xe.hpp b/include/experimental/group/reduction/row_reduce_store_xe.hpp index a38d4e5bb..f0f9671f1 100644 --- a/include/experimental/group/reduction/row_reduce_store_xe.hpp +++ b/include/experimental/group/reduction/row_reduce_store_xe.hpp @@ -38,6 +38,7 @@ struct group_row_reduce_store_t< wg_size_y, max_simd_len, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr uint32_t block_size_x = gpu::xetla::subgroup::detail::gcd::value; static_assert( @@ -62,7 +63,7 @@ struct group_row_reduce_store_t< mem_desc_t, local_st_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using local_ld_tile_desc_t = subgroup::tile_desc_t< local_tile_size_x, wg_size_y, @@ -74,7 +75,7 @@ struct group_row_reduce_store_t< mem_desc_t, local_ld_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; // If the local tile size is small, we still can use 2D block store using global_st_tile_desc_t = subgroup:: @@ -85,8 +86,8 @@ struct group_row_reduce_store_t< global_st_tile_desc_t, (local_tile_size_x * sizeof(dtype_out) > 64) ? 
msg_type::block_1d : msg_type::block_2d, - gpu_arch::XeHpc>; - xetla_nbarrier_t nbarrier; + arch_tag>; + xetla_nbarrier_t nbarrier; local_st_t local_st; local_st_payload_t local_st_payload; local_ld_t local_ld; @@ -165,6 +166,7 @@ struct group_row_reduce_store_t< 1, max_simd_len, gpu_arch::XeHpc> { + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr uint32_t block_size_x = gpu::xetla::subgroup::detail::gcd::value; @@ -176,7 +178,7 @@ struct group_row_reduce_store_t< global_st_tile_desc_t, (row_size * sizeof(dtype_out) > 64) ? msg_type::block_1d : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; inline void init( [[maybe_unused]] uint32_t sg_idx_ = 0, [[maybe_unused]] uint32_t sg_idy_ = 0, diff --git a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp index 4cbd1498a..4a1e7bd5f 100644 --- a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp +++ b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp @@ -57,6 +57,7 @@ struct xetla_data_transformer< using dtype_compute = dtype_compute_; using data_transformer_attr = data_transformer_attr_; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr mem_layout mem_layout_in = mem_layout_in_; static constexpr bool is_col_major_in = @@ -73,8 +74,7 @@ struct xetla_data_transformer< static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n; static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m; - using load_store_attr = typename arch_attr_t< - gpu_arch::XeHpc>::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_height_in_elem = load_store_attr::max_load_height_in_elem; static constexpr uint32_t max_load_width_in_bytes = @@ -126,7 +126,7 @@ struct xetla_data_transformer< mem_desc_t, global_ld_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using global_st_tile_desc_t = subgroup::tile_desc_t< tile_size_x, @@ -139,7 +139,7 @@ struct xetla_data_transformer< mem_desc_t, global_st_tile_desc_t, msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; using global_compute_tile_desc = subgroup::tile_desc_t< tile_size_x, tile_size_y, @@ -156,7 +156,7 @@ struct xetla_data_transformer< reduce_op::max, wg_size_x * wg_size_y, true, - gpu_arch::XeHpc>; + arch_tag>; /// @brief Arguments for gemm::run. 
/// User should prepare mat_in_ptr, mat_out_ptr, matrix_m, matrix_n, @@ -291,7 +291,8 @@ struct xetla_data_transformer< simd, cache_hint::uncached, cache_hint::write_back, - atomic_op::fmax>( + atomic_op::fmax, + arch_tag>( (uint64_t)args->amax_ptr, offsets * sizeof(dtype_compute), local_max, diff --git a/include/experimental/kernel/layer_norm/api.hpp b/include/experimental/kernel/layer_norm/api.hpp index d1fe6b361..5c3cfc979 100644 --- a/include/experimental/kernel/layer_norm/api.hpp +++ b/include/experimental/kernel/layer_norm/api.hpp @@ -40,8 +40,8 @@ template < typename dtype_weight_, typename dtype_acc_, typename layer_norm_attr_, - bool store_for_bwd_ = true, - gpu_arch arch_ = gpu_arch::XeHpc, + bool store_for_bwd_, + gpu_arch arch_, typename ln_fwd_fused_op_ = group::ln_fwd_fused_op_t< ln_fwd_fused_kind::none, dtype_x_, @@ -66,7 +66,7 @@ template < typename dtype_weight_, typename dtype_acc_, typename layer_norm_attr_, - gpu_arch arch_ = gpu_arch::XeHpc, + gpu_arch arch_, typename ln_bwd_fused_op_ = group::ln_bwd_fused_op_t< ln_bwd_fused_kind::none, dtype_y_, diff --git a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp index f4c2fff61..6ed6c0aa0 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp @@ -56,6 +56,7 @@ struct layer_norm_bwd_t< using layer_norm_attr = layer_norm_attr_; using ln_bwd_fused_op = ln_bwd_fused_op_; using ln_fused_op_arguments_t = typename ln_bwd_fused_op::arguments_t; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m; static constexpr uint32_t wg_tile_n = layer_norm_attr::wg_tile_n; static constexpr uint32_t sg_tile_m = layer_norm_attr::sg_tile_m; @@ -96,22 +97,22 @@ struct layer_norm_bwd_t< mem_desc_t, ln_bwd_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using x_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_bwd_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using gamma_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_bwd_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using dx_out_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_bwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; using ln_group_row_reduce_store_t = group::group_row_reduce_store_t< dtype_acc, @@ -120,7 +121,7 @@ struct layer_norm_bwd_t< wg_size_x, wg_size_y, 32, - gpu_arch::XeHpc>; + arch_tag>; /// @brief /// @@ -162,7 +163,7 @@ struct layer_norm_bwd_t< reduce_op Op, uint32_t wg_size_x, uint32_t wg_size_y, - gpu_arch arch_ = gpu_arch::XeHpc> + gpu_arch arch_> struct ln_group_all_reduce_t { uint32_t itr_count; uint32_t slm_base_0; @@ -266,7 +267,7 @@ struct layer_norm_bwd_t< reduce_op::sum, wg_size_x, wg_size_y, - gpu_arch::XeHpc>; + arch_tag>; public: __XETLA_API static void call( diff --git a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp index 08ad1eaf3..413ae88fd 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp @@ -59,6 +59,7 @@ struct layer_norm_fwd_t< using layer_norm_attr = layer_norm_attr_; using ln_fwd_fused_op = ln_fwd_fused_op_; using ln_fused_op_arguments_t = typename ln_fwd_fused_op::arguments_t; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr bool store_for_bwd 
= store_for_bwd_; static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m; @@ -105,22 +106,22 @@ struct layer_norm_fwd_t< mem_desc_t, ln_fwd_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using gamma_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using beta_in_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using y_out_payload_t = subgroup::mem_payload_t< mem_desc_t, ln_fwd_tile_desc_t, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; /// @brief /// @@ -202,7 +203,7 @@ struct layer_norm_fwd_t< int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n; int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m; - xetla_nbarrier_t nbarrier; + xetla_nbarrier_t nbarrier; nbarrier.init_nbarrier( sg_idy + nbarrier_base, nbarrier_role::producer_consumer); diff --git a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp index 0dac16a9d..8e26418e1 100644 --- a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp +++ b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp @@ -43,6 +43,7 @@ struct xetla_mha_attn_reg_fwd_t { using dtype_sfx = dtype_sfx_; using dtype_acc = dtype_acc_; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr int ThreadNum = HWThreadNum; static constexpr int max_seqlen = Max_SeqLen; static constexpr mem_space mem_space_a = mem_space::global; @@ -84,7 +85,7 @@ struct xetla_mha_attn_reg_fwd_t { using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out = mem_desc_t; @@ -93,7 +94,7 @@ struct xetla_mha_attn_reg_fwd_t { using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; static constexpr uint32_t global_kslicing = 1; static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx); @@ -103,19 +104,19 @@ struct xetla_mha_attn_reg_fwd_t { using work_group_t = work_group_t; using pre_processing_128x128 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x256 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_64x384 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_64x512 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_32x1024 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_16x2048 = - group::pre_processing_default_t; - using pre_processing_128x64 = group:: - pre_processing_matA_neg_filter_t; + group::pre_processing_default_t; + using pre_processing_128x64 = + group::pre_processing_matA_neg_filter_t; using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, @@ -232,49 +233,49 @@ struct xetla_mha_attn_reg_fwd_t { (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_64x384_tile_desc_t, (global_kslicing > 1) ? 
msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matDpotMk_128x128_t = subgroup::tile_t; @@ -292,37 +293,37 @@ struct xetla_mha_attn_reg_fwd_t { mem_desc_t, mat_128x128_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matDpotMk_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_128x256_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matDpotMk_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_64x384_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matDpotMk_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_64x512_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matDpotMk_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_32x1024_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matDpotMk_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_16x2048_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matDpotMk_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t, mat_128x64_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; /// @brief Arguments for xetla_softmax_fwd_t::run. /// User should prepare matQ_ptr, matK_ptr, matQKT_ptr, ... 
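// Illustrative call-site sketch, not part of the patch: the raw-send helpers
// (xetla_tload_global / xetla_tstore_global / xetla_tprefetch_global /
// xetla_tatomic_store_global) lost their defaulted template parameters in the
// raw_send_load_store.hpp hunk above, so callers now spell out the cache
// hints, the atomic op and the target architecture, as the fmax/fadd call
// sites in this kernel do:
xetla_tatomic_store_global<
    float,             // Ty
    16,                // N
    cache_hint::none,  // L1H
    cache_hint::none,  // L2H
    atomic_op::fadd,   // Op
    gpu_arch::XeHpc>(/* base address, byte offsets, values, ... */);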
@@ -450,9 +451,9 @@ struct xetla_mha_attn_reg_fwd_t { int tid_x = tid_linear & ((1 << tid_x_shift) - 1); int tid_y = tid_linear >> tid_x_shift; - xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> first_nbarr; - xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> second_nbarr; - xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> third_nbarr; + xetla_nbarrier_t<32, 32, arch_tag> first_nbarr; + xetla_nbarrier_t<32, 32, arch_tag> second_nbarr; + xetla_nbarrier_t<32, 32, arch_tag> third_nbarr; first_nbarr.init_nbarrier(31, nbarrier_role::producer_consumer); second_nbarr.init_nbarrier(30, nbarrier_role::producer_consumer); third_nbarr.init_nbarrier(29, nbarrier_role::producer_consumer); @@ -857,7 +858,8 @@ struct xetla_mha_attn_reg_fwd_t { 16, cache_hint::none, cache_hint::none, - atomic_op::fmax>( + atomic_op::fmax, + arch_tag>( (uint64_t)args->Max_ptr, address_fmax, matElem_reg_max_local.xetla_select<16, 1>(0), @@ -1001,7 +1003,8 @@ struct xetla_mha_attn_reg_fwd_t { 16, cache_hint::none, cache_hint::none, - atomic_op::fadd>( + atomic_op::fadd, + arch_tag>( (uint64_t)args->Sum_ptr, address_fmax, matElem_reg_Sum_1.xetla_select<16, 1>(0), @@ -1546,6 +1549,7 @@ struct xetla_mha_attn_reg_bwd_t { using dtype_sfx = dtype_bwd_sfx_; using dtype_acc = dtype_bwd_acc_; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr int ThreadNum = HWThreadNum; static_assert(ThreadNum == 32); static constexpr mem_space mem_space_a = mem_space::global; @@ -1591,7 +1595,7 @@ struct xetla_mha_attn_reg_bwd_t { using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out = mem_desc_t; @@ -1600,7 +1604,7 @@ struct xetla_mha_attn_reg_bwd_t { using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out_b_trnp_a = mem_desc_t; @@ -1609,7 +1613,7 @@ struct xetla_mha_attn_reg_bwd_t { using compute_policy_out_b_trnp_a = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; static constexpr uint32_t global_kslicing = 1; static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx); @@ -1619,25 +1623,25 @@ struct xetla_mha_attn_reg_bwd_t { using work_group_t = work_group_t; using pre_processing_128x128 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x256 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_64x384 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_64x512 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_32x1024 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_16x2048 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x64 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_256x64 = - group::pre_processing_default_t; - using pre_processing_128x64_af = group:: - pre_processing_matA_neg_filter_t; - using pre_processing_256x64_af = group:: - pre_processing_matA_neg_filter_t; + group::pre_processing_default_t; + using pre_processing_128x64_af = + group::pre_processing_matA_neg_filter_t; + using pre_processing_256x64_af = + group::pre_processing_matA_neg_filter_t; using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, @@ -1786,42 +1790,42 @@ struct xetla_mha_attn_reg_bwd_t { 
(global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, @@ -1869,33 +1873,33 @@ struct xetla_mha_attn_reg_bwd_t { (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_256x64_trnp_a_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? 
msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_128x128_t = subgroup::tile_t; using matW_128x256_t = subgroup::tile_t; @@ -1908,32 +1912,32 @@ struct xetla_mha_attn_reg_bwd_t { mem_desc_t, matC_128x128_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x256_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_64x384_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_64x512_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_32x1024_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_16x2048_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; #if 0 //512 = 16x32 or 8x64 @@ -2083,15 +2087,15 @@ struct xetla_mha_attn_reg_bwd_t { int tid_y = tid_linear >> tid_x_shift; static_assert(ThreadNum == 32, "All Thread Sync"); - xetla_nbarrier_t first_nbarr; - xetla_nbarrier_t second_nbarr; + xetla_nbarrier_t first_nbarr; + xetla_nbarrier_t second_nbarr; int max_2d_nbar_id = ThreadNum >> 1; first_nbarr.init_nbarrier(max_2d_nbar_id, nbarrier_role::producer_consumer); second_nbarr.init_nbarrier( max_2d_nbar_id + 1, nbarrier_role::producer_consumer); - xetla_nbarrier_t all_nbarr; + xetla_nbarrier_t all_nbarr; all_nbarr.init_nbarrier(ThreadNum - 1, nbarrier_role::producer_consumer); for (int transp128_loop = 0; transp128_loop < transp128_loop_num; @@ -2743,7 +2747,8 @@ struct xetla_mha_attn_reg_bwd_t { 16, cache_hint::none, cache_hint::none, - atomic_op::fadd>( + atomic_op::fadd, + arch_tag>( (uint64_t)args->matSum_ptr, address_fsum, matElem_reg_Sum_1.xetla_select<16, 1>(0), diff --git a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp index e6764ff3c..c116e437b 100644 --- a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp +++ b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp @@ -53,6 +53,7 @@ struct xetla_mha_core_attn_fwd_t { using dtype_sfx = dtype_sfx_; using dtype_acc = dtype_acc_; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr int ThreadNum = HWThreadNum; static constexpr int max_seqlen = Max_SeqLen; static constexpr mem_space mem_space_a = mem_space::global; @@ -90,7 +91,7 @@ struct xetla_mha_core_attn_fwd_t { using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out = mem_desc_t; @@ -99,7 +100,7 @@ struct xetla_mha_core_attn_fwd_t { using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; static constexpr uint32_t global_kslicing = 1; static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx); @@ -109,11 +110,11 @@ struct xetla_mha_core_attn_fwd_t { using work_group_t = work_group_t; using pre_processing_128x128 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x256 = - group::pre_processing_default_t; - using pre_processing_128x64 = group:: - pre_processing_matA_neg_filter_t; + group::pre_processing_default_t; + using 
pre_processing_128x64 = + group::pre_processing_matA_neg_filter_t; using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, @@ -167,17 +168,17 @@ struct xetla_mha_core_attn_fwd_t { mem_desc_t, matC_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; // 512 = 16x32 or 8x64 using matElem_tile_desc_t = gpu::xetla::subgroup::tile_desc_t< @@ -192,14 +193,14 @@ struct xetla_mha_core_attn_fwd_t { mem_desc_t, matElem_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matElem_st_t = gpu::xetla::subgroup::tile_t; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< mem_desc_t, matElem_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matElem_reg_t = gpu::xetla::subgroup::tile_t< float, gpu::xetla::subgroup::tile_desc_t<32, 16, 32, 16, reg_layout::tiled>>; @@ -311,8 +312,8 @@ struct xetla_mha_core_attn_fwd_t { blk_128x256_loop_num = 2; } - xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> first_nbarr; - xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> second_nbarr; + xetla_nbarrier_t<32, 32, arch_tag> first_nbarr; + xetla_nbarrier_t<32, 32, arch_tag> second_nbarr; first_nbarr.init_nbarrier(0, nbarrier_role::producer_consumer); second_nbarr.init_nbarrier(1, nbarrier_role::producer_consumer); @@ -869,6 +870,7 @@ struct xetla_mha_core_attn_bwd_t { using dtype_sfx = dtype_bwd_sfx_; using dtype_acc = dtype_bwd_acc_; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr int ThreadNum = HWThreadNum; static_assert(ThreadNum == 32); static constexpr mem_space mem_space_a = mem_space::global; @@ -909,7 +911,7 @@ struct xetla_mha_core_attn_bwd_t { using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out = mem_desc_t; @@ -918,7 +920,7 @@ struct xetla_mha_core_attn_bwd_t { using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out_b_trnp_a = mem_desc_t; @@ -927,7 +929,7 @@ struct xetla_mha_core_attn_bwd_t { using compute_policy_out_b_trnp_a = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; static constexpr uint32_t global_kslicing = 1; static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx); @@ -937,17 +939,17 @@ struct xetla_mha_core_attn_bwd_t { using work_group_t = work_group_t; using pre_processing_128x128 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x256 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x64 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_256x64 = - group::pre_processing_default_t; - using pre_processing_128x64_af = group:: - pre_processing_matA_neg_filter_t; - using pre_processing_256x64_af = group:: - pre_processing_matA_neg_filter_t; + group::pre_processing_default_t; + using pre_processing_128x64_af = + group::pre_processing_matA_neg_filter_t; + using 
pre_processing_256x64_af = + group::pre_processing_matA_neg_filter_t; using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, @@ -1072,49 +1074,49 @@ struct xetla_mha_core_attn_bwd_t { (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_256x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; // 512 = 16x32 or 8x64 using matElem_tile_desc_t = gpu::xetla::subgroup::tile_desc_t< @@ -1131,12 +1133,12 @@ struct xetla_mha_core_attn_bwd_t { mem_desc_t, matElem_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< mem_desc_t, matElem_tile_desc_t, msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; using matElem_reg_t = gpu::xetla::subgroup::tile_t< float, gpu::xetla::subgroup::tile_desc_t<32, 16, 32, 16, reg_layout::tiled>>; @@ -1243,15 +1245,15 @@ struct xetla_mha_core_attn_bwd_t { g_thd32_tid.init(tid_linear); static_assert(ThreadNum == 32, "All Thread Sync"); - xetla_nbarrier_t first_nbarr; - xetla_nbarrier_t second_nbarr; + xetla_nbarrier_t first_nbarr; + xetla_nbarrier_t second_nbarr; int max_2d_nbar_id = ThreadNum >> 1; first_nbarr.init_nbarrier(max_2d_nbar_id, nbarrier_role::producer_consumer); second_nbarr.init_nbarrier( max_2d_nbar_id + 1, nbarrier_role::producer_consumer); - xetla_nbarrier_t all_nbarr; + xetla_nbarrier_t all_nbarr; all_nbarr.init_nbarrier(ThreadNum - 1, nbarrier_role::producer_consumer); for (int transp128_loop = 0; transp128_loop < transp128_loop_num; diff --git a/include/experimental/kernel/reduction/row_reduction_xe.hpp b/include/experimental/kernel/reduction/row_reduction_xe.hpp index 1d24d50b0..7e1676007 100644 --- a/include/experimental/kernel/reduction/row_reduction_xe.hpp +++ b/include/experimental/kernel/reduction/row_reduction_xe.hpp @@ -58,6 +58,7 @@ struct xetla_row_reduction_t< using fused_op_t = fused_op_t_; using fused_op_arguments_t = typename fused_op_t::arguments_t; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static constexpr uint32_t wg_tile_m = reduction_attr::wg_tile_m; static constexpr uint32_t wg_tile_n = reduction_attr::wg_tile_n; static constexpr uint32_t sg_tile_m = reduction_attr::sg_tile_m; @@ -67,8 +68,7 @@ struct xetla_row_reduction_t< static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m; using work_group_t = work_group_t; 
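// Illustrative sketch, not part of the patch: xetla_nbarrier_t likewise lost
// its defaulted template parameters (raw_send_nbarrier.hpp hunk above), so the
// barrier objects updated throughout this patch are declared with explicit
// producer/consumer counts and an explicit architecture tag:
xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> sync_barrier;
sync_barrier.init_nbarrier(0, nbarrier_role::producer_consumer);
sync_barrier.arrive_wait();  // named barrier on XeHpc; the fallback
                             // specialization maps arrive/wait to split barriers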
static constexpr bool use_dynamic_job = is_dynamic_job && (wg_size_y > 1); - using load_store_attr = typename arch_attr_t< - gpu_arch::XeHpc>::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_height_in_elem = load_store_attr::max_load_height_in_elem; static constexpr uint32_t max_load_width_in_bytes = @@ -112,7 +112,7 @@ struct xetla_row_reduction_t< mem_desc_t, global_ld_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using mat_buffer_t = subgroup::tile_t< dtype_acc, subgroup:: @@ -124,7 +124,8 @@ struct xetla_row_reduction_t< sg_tile_n, wg_size_x, wg_size_y, - max_simd_len>; + max_simd_len, + arch_tag>; /// @brief /// @@ -177,7 +178,7 @@ struct xetla_row_reduction_t< int global_start_x_in = item.get_group(2) * wg_tile_n + sg_idx * sg_tile_n; int global_start_y_in = sg_idy * sg_tile_m; - xetla_nbarrier_t nbarrier; + xetla_nbarrier_t nbarrier; nbarrier.init_nbarrier( nbarrier_base + sg_idx, nbarrier_role::producer_consumer); if constexpr (use_dynamic_job) { diff --git a/include/group/cooperative_reduction.hpp b/include/group/cooperative_reduction.hpp index b5ab96d23..66784c413 100644 --- a/include/group/cooperative_reduction.hpp +++ b/include/group/cooperative_reduction.hpp @@ -54,7 +54,7 @@ class cooperative_reduce_t< matAcc_t, num_cooperative_wg, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape = tile_shape_; @@ -125,7 +125,8 @@ class cooperative_reduce_t< public: using mat_slice_t = subgroup::tile_t; - static constexpr uint32_t barrier_count = work_group_size; + static constexpr uint32_t barrier_count = + arch_has_named_barrier ? work_group_size : 0; static constexpr uint32_t slm_size = wg_tile_size * num_cooperative_wg; uint32_t coop_id; @@ -221,7 +222,7 @@ class cooperative_reduce_t< matAcc_t, 1, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape = tile_shape_; diff --git a/include/group/epilogue/impl/default_xe.hpp b/include/group/epilogue/impl/default_xe.hpp index ab149396a..41fae537d 100644 --- a/include/group/epilogue/impl/default_xe.hpp +++ b/include/group/epilogue/impl/default_xe.hpp @@ -35,7 +35,7 @@ class epilogue_t< epilogue_policy_default, tile_shape_, mem_desc_c_t_, - std::enable_if_t<((arch_tag_ <= gpu_arch::XeHpc))>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_default; using tile_shape = tile_shape_; @@ -113,7 +113,7 @@ class epilogue_t< tile_shape_, mem_desc_c_t_, std::enable_if_t<( - (arch_tag_ <= gpu_arch::XeHpc) && (mem_desc_c_t_::dim == 4))>> { + valid_xe_arch_tag && (mem_desc_c_t_::dim == 4))>> { public: using epilogue_policy = epilogue_policy_default; using tile_shape = tile_shape_; diff --git a/include/group/epilogue/impl/quant_tile_op_xe.hpp b/include/group/epilogue/impl/quant_tile_op_xe.hpp index 2648fd77a..4cefae147 100644 --- a/include/group/epilogue/impl/quant_tile_op_xe.hpp +++ b/include/group/epilogue/impl/quant_tile_op_xe.hpp @@ -47,7 +47,7 @@ class epilogue_t< dtype_dequant_>, tile_shape_, mem_desc_c_t_, - std::enable_if_t<(arch_tag_ == gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_quant_op< dequant_op_t_, diff --git a/include/group/epilogue/impl/stream_k_op_xe.hpp b/include/group/epilogue/impl/stream_k_op_xe.hpp index ea60bee2f..b109d854f 100644 --- 
a/include/group/epilogue/impl/stream_k_op_xe.hpp +++ b/include/group/epilogue/impl/stream_k_op_xe.hpp @@ -34,7 +34,7 @@ template < typename epilogue_t_, typename mem_desc_d_t_, typename mem_desc_atomic_sync_t_, - gpu_arch arch_tag_ = gpu_arch::XeHpc> + gpu_arch arch_tag_> struct epilogue_stream_k_t { static constexpr gpu_arch arch_tag = arch_tag_; using epilogue_t = epilogue_t_; diff --git a/include/group/epilogue/impl/tile_op_xe.hpp b/include/group/epilogue/impl/tile_op_xe.hpp index 656cdabde..b4de45ce7 100644 --- a/include/group/epilogue/impl/tile_op_xe.hpp +++ b/include/group/epilogue/impl/tile_op_xe.hpp @@ -39,7 +39,7 @@ class epilogue_t< epilogue_policy_tile_op, tile_shape_, mem_desc_c_t_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_tile_op; using tile_op_t = typename epilogue_policy::tile_op_t; diff --git a/include/group/epilogue/impl/unaligned_xe.hpp b/include/group/epilogue/impl/unaligned_xe.hpp index b0a0eb8cc..c8ca5385a 100644 --- a/include/group/epilogue/impl/unaligned_xe.hpp +++ b/include/group/epilogue/impl/unaligned_xe.hpp @@ -35,7 +35,7 @@ class epilogue_t< epilogue_policy_unaligned, tile_shape_, mem_desc_c_t_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_unaligned; using tile_shape = tile_shape_; diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index 0a0cd1c91..f089f84d6 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -33,7 +33,7 @@ namespace gpu::xetla::group { template < typename compute_attr_, typename perf_tuning_knob_, - gpu_arch arch_tag_ = gpu_arch::XeHpc, + gpu_arch arch_tag_, typename enable = void> struct compute_policy_default_xmx {}; @@ -59,8 +59,8 @@ struct compute_policy_default_xmx< static constexpr int sync_freq = perf_tuning_knob::sync_freq; static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr uint32_t block_size_y_a = 16; - using mma_attr = mma_attr_t; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem; static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); @@ -92,7 +92,7 @@ struct compute_policy_unaligned_xmx : public compute_policy_default_xmx< template < typename compute_attr_, typename perf_tuning_knob_, - gpu_arch arch_tag_ = gpu_arch::XeHpc, + gpu_arch arch_tag_, typename enable = void> struct compute_policy_default_fpu {}; @@ -118,13 +118,12 @@ struct compute_policy_default_fpu< static constexpr int sync_freq = perf_tuning_knob::sync_freq; static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr uint32_t block_size_y_a = - arch_tag_ == gpu_arch::XeLpg ? 
8 : 16; - static constexpr uint32_t block_bytes_x_a = 32; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem; + static constexpr uint32_t block_bytes_x_a = mma_attr::blk_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_bytes_x_b = - arch_attr_t::template register_attr<>::reg_in_bytes; + static constexpr uint32_t block_bytes_x_b = mma_attr::blk_n_in_bytes; static constexpr uint32_t block_size_x_b = block_bytes_x_b / sizeof(dtype_mma_b); static constexpr uint32_t block_size_y_b = block_size_x_a; diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp index add8e6790..9af4e366f 100644 --- a/include/group/gemm/impl/default_fpu_xe.hpp +++ b/include/group/gemm/impl/default_fpu_xe.hpp @@ -55,8 +55,9 @@ class gemm_t< static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; using work_group_t = typename tile_shape::work_group_t; - constexpr static gpu_arch arch_tag = compute_policy::arch_tag; + constexpr static gpu_arch arch_tag = arch_tag_; static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; @@ -200,10 +201,18 @@ class gemm_t< static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + static constexpr bool need_local_fence = + (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + + [[maybe_unused]] xetla_nbarrier_t barrier_all; + [[maybe_unused]] xetla_nbarrier_t nbarrier_a; + [[maybe_unused]] xetla_nbarrier_t nbarrier_b; public: static constexpr uint32_t barrier_count = - enable_periodic_sync ? barrier_count_x + barrier_count_y : 0; + (enable_periodic_sync && arch_has_named_barrier) ? + barrier_count_x + barrier_count_y : 0; + // current no slm path static constexpr uint32_t slm_size = 0; @@ -283,6 +292,57 @@ class gemm_t< } }; + inline void periodic_sync_init( + [[maybe_unused]] int32_t sg_idx, + [[maybe_unused]] int32_t sg_idy, + uint32_t nbarrier_base) { + if constexpr (enable_periodic_sync) { + if constexpr (arch_has_named_barrier) { + nbarrier_a.init_nbarrier( + sg_idy + nbarrier_base, nbarrier_role::producer_consumer); + nbarrier_b.init_nbarrier( + sg_idx + barrier_count_y + nbarrier_base, + nbarrier_role::producer_consumer); + } else { + barrier_all.init_nbarrier(nbarrier_base, nbarrier_role::producer_consumer); + } + } + } + + inline void periodic_sync_arrive(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.arrive(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.arrive(); + } + } else { + barrier_all.arrive(); + } + } + } + } + + inline void periodic_sync_wait(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.wait(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.wait(); + } + } else { + barrier_all.wait(); + } + } + } + } + /// @brief Gets the subgroup-level tile offset x. 
/// @param g Is the workgroup of the current tile. /// @return Subgroup-level tile offset x. @@ -306,18 +366,14 @@ class gemm_t< "If you call this function, please set a free barrier id or make " "sure barrier_id 0 is not being occupied and you need to allocate " "one more barrier count in addition to the gemm barrier counts.") - __XETLA_API static void release(uint8_t nbarrier_id = 0) { - static constexpr bool need_local_fence = - (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + __XETLA_API void release(uint8_t nbarrier_id = 0) { if constexpr (need_local_fence) { xetla_fence(); } xetla_fence(); - static constexpr uint32_t wg_size = wg_size_x * wg_size_y; if constexpr (wg_size > 1) { - xetla_nbarrier_t nbarrier; - nbarrier.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); - nbarrier.arrive_wait(); + barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); + barrier_all.arrive_wait(); } } @@ -346,13 +402,9 @@ class gemm_t< matB_payload_t matB_payload(args.matB_base_desc); matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, 0); matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, 0); - xetla_nbarrier_t nbarrier_a; - nbarrier_a.init_nbarrier( - sg_idy + nbarrier_base, nbarrier_role::producer_consumer); - xetla_nbarrier_t nbarrier_b; - nbarrier_b.init_nbarrier( - sg_idx + barrier_count_y + nbarrier_base, - nbarrier_role::producer_consumer); + + periodic_sync_init(sg_idx, sg_idy, nbarrier_base); + #pragma unroll for (uint32_t i = 0; i < stages; i++) { subgroup::tile_prefetch( @@ -366,18 +418,9 @@ class gemm_t< } for (uint32_t i = 0; i < args.inner_loop_count; i++) { - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - nbarrier_a.arrive(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.arrive(); - } - } - } + periodic_sync_arrive(i); SW_BARRIER(); + subgroup::tile_load( matA, matA_payload); subgroup::tile_load( @@ -402,17 +445,7 @@ class gemm_t< SW_BARRIER(); tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); SW_BARRIER(); - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - nbarrier_a.wait(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.wait(); - } - } - } + periodic_sync_wait(i); } SW_BARRIER(); } diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp index 75a0ef79c..c7e7856e7 100644 --- a/include/group/gemm/impl/default_xmx_xe.hpp +++ b/include/group/gemm/impl/default_xmx_xe.hpp @@ -57,7 +57,7 @@ class gemm_t< static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; using work_group_t = typename tile_shape::work_group_t; - constexpr static gpu_arch arch_tag = compute_policy::arch_tag; + constexpr static gpu_arch arch_tag = arch_tag_; static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; @@ -65,6 +65,7 @@ class gemm_t< static constexpr bool is_col_major_b = mem_layout_b == mem_layout::col_major; private: + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; /******** set data type **********/ using dtype_a = typename mem_desc_a_t::dtype; using dtype_b = typename mem_desc_b_t::dtype; @@ -186,10 +187,18 @@ class gemm_t< static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? 
wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + static constexpr bool need_local_fence = + (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + + [[maybe_unused]] xetla_nbarrier_t barrier_all; + [[maybe_unused]] xetla_nbarrier_t nbarrier_a; + [[maybe_unused]] xetla_nbarrier_t nbarrier_b; public: static constexpr uint32_t barrier_count = - enable_periodic_sync ? barrier_count_x + barrier_count_y : 0; + (enable_periodic_sync && arch_has_named_barrier) + ? barrier_count_x + barrier_count_y + : 0; static constexpr uint32_t slm_size = 0; @@ -269,6 +278,58 @@ class gemm_t< } }; + inline void periodic_sync_init( + [[maybe_unused]] int32_t sg_idx, + [[maybe_unused]] int32_t sg_idy, + uint32_t nbarrier_base) { + if constexpr (enable_periodic_sync) { + if constexpr (arch_has_named_barrier) { + nbarrier_a.init_nbarrier( + sg_idy + nbarrier_base, nbarrier_role::producer_consumer); + nbarrier_b.init_nbarrier( + sg_idx + barrier_count_y + nbarrier_base, + nbarrier_role::producer_consumer); + } else { + barrier_all.init_nbarrier( + nbarrier_base, nbarrier_role::producer_consumer); + } + } + } + + inline void periodic_sync_arrive(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.arrive(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.arrive(); + } + } else { + barrier_all.arrive(); + } + } + } + } + + inline void periodic_sync_wait(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.wait(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.wait(); + } + } else { + barrier_all.wait(); + } + } + } + } + /// @brief Gets the subgroup-level tile offset x. /// @param g Is the workgroup of the current tile. /// @return Subgroup-level tile offset x. 
@@ -292,18 +353,14 @@ class gemm_t< "If you call this function, please set a free barrier id or make " "sure barrier_id 0 is not being occupied and you need to allocate " "one more barrier count in addition to the gemm barrier counts.") - __XETLA_API static void release(uint8_t nbarrier_id = 0) { - static constexpr bool need_local_fence = - (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + __XETLA_API void release(uint8_t nbarrier_id = 0) { if constexpr (need_local_fence) { xetla_fence(); } xetla_fence(); - static constexpr uint32_t wg_size = wg_size_x * wg_size_y; if constexpr (wg_size > 1) { - xetla_nbarrier_t nbarrier; - nbarrier.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); - nbarrier.arrive_wait(); + barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); + barrier_all.arrive_wait(); } } @@ -324,10 +381,10 @@ class gemm_t< int32_t sg_idy = g.get_id() / wg_size_x; XETLA_ASSERT( - g.get_id() < (wg_size_x * wg_size_y), + g.get_id() < wg_size, "Thread id(%d) should less than wg_size(%d)", g.get_id(), - wg_size_x * wg_size_y); + wg_size); update_sg_tile_tdesc(args, sg_idx, sg_idy); pre_processing_t pre_processing; @@ -339,13 +396,8 @@ class gemm_t< matB_payload_t matB_payload(args.matB_base_desc); matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, sg_idx); matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, sg_idy); - xetla_nbarrier_t nbarrier_a; - nbarrier_a.init_nbarrier( - sg_idy + nbarrier_base, nbarrier_role::producer_consumer); - xetla_nbarrier_t nbarrier_b; - nbarrier_b.init_nbarrier( - sg_idx + barrier_count_y + nbarrier_base, - nbarrier_role::producer_consumer); + + periodic_sync_init(sg_idx, sg_idy, nbarrier_base); #pragma unroll for (uint32_t i = 0; i < stages; i++) { @@ -360,17 +412,8 @@ class gemm_t< } for (uint32_t i = 0; i < args.inner_loop_count; i++) { - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - nbarrier_a.arrive(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.arrive(); - } - } - } + periodic_sync_arrive(i); + subgroup::tile_load( matB, matB_payload); subgroup::tile_load( @@ -398,18 +441,9 @@ class gemm_t< pre_processing(matA_acc, matB_acc, matA, matB); SW_BARRIER(); tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + SW_BARRIER(); - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - nbarrier_a.wait(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.wait(); - } - } - } + periodic_sync_wait(i); } SW_BARRIER(); } diff --git a/include/group/gemm/impl/pre_processing_xe.hpp b/include/group/gemm/impl/pre_processing_xe.hpp index 053cf9aeb..588622ef6 100644 --- a/include/group/gemm/impl/pre_processing_xe.hpp +++ b/include/group/gemm/impl/pre_processing_xe.hpp @@ -32,7 +32,7 @@ template class pre_processing_default_t< tile_shape_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using tile_shape = tile_shape_; using work_group_t = typename tile_shape::work_group_t; @@ -67,7 +67,7 @@ template class pre_processing_matA_neg_filter_t< tile_shape_, arch_tag, - std::enable_if_t<(arch_tag == gpu_arch::XeHpc)>> { + std::enable_if_t>> { using tile_shape = tile_shape_; using work_group_t = typename tile_shape::work_group_t; diff --git a/include/group/gemm/impl/unaligned_xmx_xe.hpp b/include/group/gemm/impl/unaligned_xmx_xe.hpp index 19e4c89b3..1e2f134a6 100644 
--- a/include/group/gemm/impl/unaligned_xmx_xe.hpp +++ b/include/group/gemm/impl/unaligned_xmx_xe.hpp @@ -59,6 +59,7 @@ class gemm_t< static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; using work_group_t = typename tile_shape::work_group_t; constexpr static gpu_arch arch_tag = arch_tag_; @@ -76,7 +77,7 @@ class gemm_t< using dtype_mma_a = typename compute_policy::dtype_mma_a; using dtype_mma_b = typename compute_policy::dtype_mma_b; - using check_dtype = group::gemm::default_xmx:: + using check_dtype = group::gemm::default_xmx:: template check_dtype_default; /******** set memory attribute **********/ @@ -91,7 +92,7 @@ class gemm_t< is_col_major_b ? tdesc_update_dir::x_dir : tdesc_update_dir::y_dir; using check_memory = - group::gemm::default_xmx::template check_memory_default< + group::gemm::default_xmx::template check_memory_default< mem_layout_a, mem_layout_b, mem_space_a, @@ -116,7 +117,7 @@ class gemm_t< static constexpr uint32_t block_size_y_b = compute_policy::block_size_y_b; using check_tile_size = - group::gemm::default_xmx::template check_tile_size_default< + group::gemm::default_xmx::template check_tile_size_default< dtype_mma_a, tile_size_x_a, tile_size_y_a, @@ -129,6 +130,9 @@ class gemm_t< /******** set tile **********/ static constexpr reg_layout reg_layout_a = reg_layout::tiled; + + [[maybe_unused]] xetla_nbarrier_t barrier_all; + using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, tile_size_y_a, @@ -142,7 +146,7 @@ class gemm_t< matA_t, mem_layout_a, tile_shape::wg_size_x, - gpu_arch::XeHpc>; + arch_tag>; using cooperative_tile_desc_A_t = typename cooperative_helper_A_t::co_tile_desc_t; using partial_matA_t = subgroup::tile_t; @@ -184,7 +188,7 @@ class gemm_t< matB_t, mem_layout_b, tile_shape::wg_size_y, - gpu_arch::XeHpc>; + arch_tag>; using cooperative_tile_desc_B_t = typename cooperative_helper_B_t::co_tile_desc_t; @@ -242,9 +246,11 @@ class gemm_t< static constexpr uint32_t slm_size_b = wg_size_x * tile_size_b; public: - static constexpr uint32_t barrier_count = barrier_count_x + barrier_count_y; + static constexpr uint32_t barrier_count = + arch_has_named_barrier ? 
barrier_count_x + barrier_count_y : 0; static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic; + static_assert(slm_size <= arch_attr_t::local_mem_size); static constexpr uint32_t slm_base_a = 0; static constexpr uint32_t slm_base_b = 0 + slm_size_a * num_cyclic; @@ -350,18 +356,16 @@ class gemm_t< "If you call this function, please set a free barrier id or make " "sure barrier_id 0 is not being occupied and you need to allocate " "one more barrier count in addition to the gemm barrier counts.") - __XETLA_API static void release(uint8_t nbarrier_id = 0) { + __XETLA_API void release(uint8_t nbarrier_id = 0) { static constexpr bool need_local_fence = (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); if constexpr (need_local_fence) { xetla_fence(); } xetla_fence(); - static constexpr uint32_t wg_size = wg_size_x * wg_size_y; if constexpr (wg_size > 1) { - xetla_nbarrier_t nbarrier; - nbarrier.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); - nbarrier.arrive_wait(); + barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); + barrier_all.arrive_wait(); } } @@ -426,10 +430,11 @@ class gemm_t< matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, sg_idx); matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, sg_idy); - xetla_nbarrier_t nbarrier_a; + xetla_nbarrier_t nbarrier_a; nbarrier_a.init_nbarrier( sg_idy + nbarrier_base, nbarrier_role::producer_consumer); - xetla_nbarrier_t nbarrier_b; + + xetla_nbarrier_t nbarrier_b; nbarrier_b.init_nbarrier( sg_idx + barrier_count_y + nbarrier_base, nbarrier_role::producer_consumer); @@ -444,9 +449,11 @@ class gemm_t< matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); xetla_fence(); + nbarrier_a.arrive(); - if (arch_tag >= gpu_arch::XeHpc) + if constexpr (arch_has_named_barrier) nbarrier_b.arrive(); + #pragma unroll for (uint32_t i = 1; i < num_cyclic - 1; i++) { tile_load(partial_matA, matA_payload); @@ -469,6 +476,7 @@ class gemm_t< matA_t::tile_size_x * (num_cyclic - 1)); matB_prefetch_payload.template update_tdesc( matB_t::tile_size_y * (num_cyclic - 1)); + #pragma unroll for (uint32_t i = 0; i < stages; i++) { subgroup::tile_prefetch( @@ -496,7 +504,7 @@ class gemm_t< } nbarrier_a.wait(); - if (arch_tag >= gpu_arch::XeHpc) + if constexpr (arch_has_named_barrier) nbarrier_b.wait(); tile_load(matA, matA_local_ld_payload); @@ -525,7 +533,7 @@ class gemm_t< } nbarrier_a.arrive(); - if (arch_tag >= gpu_arch::XeHpc) + if constexpr (arch_has_named_barrier) nbarrier_b.arrive(); SW_BARRIER(); matA_acc_t matA_acc; @@ -555,7 +563,7 @@ class gemm_t< } SW_BARRIER(); nbarrier_a.wait(); - if (arch_tag >= gpu_arch::XeHpc) + if constexpr (arch_has_named_barrier) nbarrier_b.wait(); } diff --git a/include/group/global_reduction.hpp b/include/group/global_reduction.hpp index 0fe83d6a8..8289ab42a 100644 --- a/include/group/global_reduction.hpp +++ b/include/group/global_reduction.hpp @@ -64,7 +64,7 @@ class global_reduce_t< num_group_reduction, counter_size, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape_acc = tile_shape_acc_; @@ -225,7 +225,7 @@ class global_reduce_t< 1, counter_size_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape_acc = tile_shape_acc_; diff --git 
a/include/group/reduction/reduction_xe.hpp b/include/group/reduction/reduction_xe.hpp index ee39acc80..1f6592487 100644 --- a/include/group/reduction/reduction_xe.hpp +++ b/include/group/reduction/reduction_xe.hpp @@ -31,8 +31,9 @@ template < uint32_t N_SG, bool is_all_reduce> struct group_reduce_t { - group_reduce_t sg_reduce{}; - xetla_nbarrier_t nbarrier; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + group_reduce_t sg_reduce{}; + xetla_nbarrier_t nbarrier; uint32_t slm_base; uint32_t sg_id; using local_st_tile_desc = @@ -45,12 +46,12 @@ struct group_reduce_t { mem_desc_t, local_ld_tile_desc, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using local_st_payload_t = subgroup::mem_payload_t< mem_desc_t, local_st_tile_desc, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; inline group_reduce_t() = default; inline group_reduce_t( uint32_t sg_id_, diff --git a/include/group/softmax/impl/softmax_bwd_xe.hpp b/include/group/softmax/impl/softmax_bwd_xe.hpp index dd31f9f1e..d4b419a2b 100644 --- a/include/group/softmax/impl/softmax_bwd_xe.hpp +++ b/include/group/softmax/impl/softmax_bwd_xe.hpp @@ -55,7 +55,7 @@ class softmax_t< reduce_op::sum, wg_size_x, true, - gpu_arch::XeHpc>; + arch_tag>; public: struct arguments_t { @@ -106,7 +106,7 @@ class softmax_t< mem_desc_in_t, mat_in_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; int32_t sg_idx = g.get_id() % wg_size_x; int32_t sg_idy = g.get_id() / wg_size_x; diff --git a/include/group/softmax/impl/softmax_fwd_xe.hpp b/include/group/softmax/impl/softmax_fwd_xe.hpp index c5c175855..15632f94c 100644 --- a/include/group/softmax/impl/softmax_fwd_xe.hpp +++ b/include/group/softmax/impl/softmax_fwd_xe.hpp @@ -53,7 +53,7 @@ class softmax_t, tile_shape_> { reduce_op::max, wg_size_x, true, - gpu_arch::XeHpc>; + arch_tag>; using wg_reduce_sum_t = group_reduce_t< dtype_acc, 1, @@ -61,7 +61,7 @@ class softmax_t, tile_shape_> { reduce_op::sum, wg_size_x, true, - gpu_arch::XeHpc>; + arch_tag>; public: struct get_barrier_count { diff --git a/include/group/softmax/softmax_policy.hpp b/include/group/softmax/softmax_policy.hpp index db3123523..278e2879d 100644 --- a/include/group/softmax/softmax_policy.hpp +++ b/include/group/softmax/softmax_policy.hpp @@ -23,13 +23,10 @@ namespace gpu::xetla::group { -template +template struct softmax_policy_fwd {}; -template < - typename dtype_in, - typename dtype_acc, - gpu_arch arch_tag_ = gpu_arch::XeHpc> +template struct softmax_policy_bwd {}; } // namespace gpu::xetla::group diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 0635aaec6..189514808 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -37,7 +37,7 @@ template < mem_layout mem_layout_c, uint32_t alignment_c, typename dtype_acc, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> struct default_gemm_config_t : param_adaptor< @@ -69,7 +69,7 @@ template < mem_layout mem_layout_c, uint32_t alignment_c, typename dtype_acc, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> using default_gemm_t = typename default_gemm_config_t< dtype_a, @@ -170,7 +170,7 @@ template < typename dtype_acc, typename wg_shape, uint32_t wg_tile_k, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> struct default_gemm_selector_config_t : param_adaptor< @@ -204,7 +204,7 @@ template < typename dtype_acc, typename wg_shape, uint32_t wg_tile_k, - 
gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> using default_gemm_selector_t = typename default_gemm_selector_config_t< dtype_a, @@ -228,7 +228,7 @@ template < mem_space mem_space_c, typename wg_shape, uint32_t wg_tile_k, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> struct default_epilogue_selector_config_t : param_adaptor< @@ -252,7 +252,7 @@ template < mem_space mem_space_c, typename wg_shape, uint32_t wg_tile_k, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> using default_epilogue_selector_t = typename default_epilogue_selector_config_t< dtype_c, diff --git a/include/kernel/gemm/dispatch_policy.hpp b/include/kernel/gemm/dispatch_policy.hpp index 28eaf7840..6432082ec 100644 --- a/include/kernel/gemm/dispatch_policy.hpp +++ b/include/kernel/gemm/dispatch_policy.hpp @@ -148,7 +148,7 @@ struct dispatch_policy_kslicing { /// and performs inter-group reduction. Implementation loosely based on this /// paper - https://arxiv.org/pdf/2301.03598.pdf /// @tparam arch_tag_ Is the HW architecture. -template +template struct dispatch_policy_stream_k { static constexpr gpu_arch arch_tag = arch_tag_; diff --git a/include/kernel/gemm/gemm_preset.hpp b/include/kernel/gemm/gemm_preset.hpp index f24af3dba..585f46924 100644 --- a/include/kernel/gemm/gemm_preset.hpp +++ b/include/kernel/gemm/gemm_preset.hpp @@ -46,7 +46,7 @@ using param_performance_default = dict_t< elem_v_t, elem_v_t>; -template +template using param_runtime_default = dict_t< elem_v_t, elem_v_t, @@ -61,7 +61,8 @@ using param_runtime_default = dict_t< tune_key::group_swizzle_policy, kernel::group_swizzle_default>>; } // namespace detail -template + +template using default_param_t = dict_t<>::template update_dict_t< detail::param_dtype_bf16_bf16_bf16>:: template update_dict_t::template update_dict_t< @@ -88,7 +89,7 @@ using default_param_t = dict_t<>::template update_dict_t< param_optimizer_level>>; namespace kernel { -template +template using param_kslicing_g1l1_t = default_param_t::template update_t< elem_v_t, elem_v_t, @@ -99,7 +100,7 @@ using param_kslicing_g1l1_t = default_param_t::template update_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing>>; -template +template using param_kslicing_g2l1_t = default_param_t::template update_t< elem_v_t, elem_v_t, @@ -110,7 +111,7 @@ using param_kslicing_g2l1_t = default_param_t::template update_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing>>; -template +template using param_kslicing_g1l2_t = default_param_t::template update_t< elem_v_t, elem_v_t, @@ -124,7 +125,7 @@ using param_kslicing_g1l2_t = default_param_t::template update_t< } // namespace kernel namespace group { -template +template using param_dict1_wg_t = default_param_t::template update_t< elem_t_t, elem_t_t>, diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index cb6c5270b..2f51ccc39 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -37,7 +37,7 @@ class gemm_universal_t< dispatch_policy_default, gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index 7b74226e5..3dc37e731 100644 --- 
a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -48,7 +48,7 @@ class gemm_universal_t< num_local_kslicing_>, gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; diff --git a/include/subgroup/cooperative_load_helper.hpp b/include/subgroup/cooperative_load_helper.hpp index ab69e5bc1..79ddc69b9 100644 --- a/include/subgroup/cooperative_load_helper.hpp +++ b/include/subgroup/cooperative_load_helper.hpp @@ -46,7 +46,7 @@ class cooperative_load_helper_t< mem_layout::row_major, num_cooperative_wg, arch_tag_, - std::enable_if_t> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using matAcc_t = matAcc_t_; @@ -112,7 +112,7 @@ class cooperative_load_helper_t< mem_layout::col_major, num_cooperative_wg, arch_tag_, - std::enable_if_t> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using matAcc_t = matAcc_t_; diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 9385c700f..254750b92 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -352,8 +352,8 @@ struct get_load_block_size_auto< mem_layout::row_major, reg_layout::tiled> { private: - using load_store_attr = arch_attr_t< - gpu_arch::XeHpc>::template load_store_attr; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_height_in_elem = load_store_attr::max_load_height_in_elem; static constexpr uint32_t max_load_width_in_bytes = @@ -389,8 +389,8 @@ struct get_store_block_size_auto< mem_layout::row_major, reg_layout::tiled> { private: - using load_store_attr = arch_attr_t< - gpu_arch::XeHpc>::template load_store_attr; + static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_store_height_in_elem = load_store_attr::max_store_height_in_elem; static constexpr uint32_t max_store_width_in_bytes = diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index c1ca0c6ff..9cf308bb0 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -103,8 +103,8 @@ struct tile_mma_t< static constexpr int32_t num_block_m = matDst_t::num_block_y; static constexpr int32_t num_block_k = tile_size_k / block_size_k; - static constexpr int32_t mma_m = - register_attr::acc_reg_in_bytes / (block_size_n * sizeof(dtype_dst)); + using mma_attr = mma_attr_t; + static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; template __XETLA_API static void mma_core( @@ -112,36 +112,39 @@ struct tile_mma_t< xetla_vector_ref __REF__ src, xetla_vector_ref __REF__ b_block, xetla_vector_ref __REF__ a_block) { + constexpr uint32_t blk_m_iters = blk_m / mma_m; + constexpr uint32_t tail_m = blk_m % mma_m; auto dst_blk_2d = dst.xetla_format(); auto src_blk_2d = src.xetla_format(); auto b_blk_2d = b_block.xetla_format(); + if constexpr (blk_m_iters > 0) { #pragma unroll - for (uint32_t i = 0; i < blk_m / mma_m; i++) { - xetla_vector dst_tmp; - auto dst_tmp_2d = dst_tmp.xetla_format(); + for (uint32_t i = 0; i < blk_m_iters; i++) { + xetla_vector dst_tmp; + auto dst_tmp_2d = dst_tmp.xetla_format(); #pragma unroll - for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { - dst_tmp_2d.row(i_acc) = a_block[i_acc + i * mma_m] * b_blk_2d.row(0) + 
+ for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + dst_tmp_2d.row(i_acc) = a_block[i_acc + i * mma_m] * b_blk_2d.row(0) + src_blk_2d.row(i_acc + i * mma_m); - } + } #pragma unroll - for (uint32_t k = 1; k < blk_k - 1; k++) { - for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { - int a_offset = k * blk_m + i_acc + i * mma_m; - dst_tmp_2d.row(i_acc) += a_block[a_offset] * b_blk_2d.row(k); + for (uint32_t k = 1; k < blk_k - 1; k++) { + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + int a_offset = k * blk_m + i_acc + i * mma_m; + dst_tmp_2d.row(i_acc) += a_block[a_offset] * b_blk_2d.row(k); + } } - } - for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { - int a_offset = (blk_k - 1) * blk_m + i_acc + i * mma_m; - dst_blk_2d.row(i_acc + i * mma_m) = + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + int a_offset = (blk_k - 1) * blk_m + i_acc + i * mma_m; + dst_blk_2d.row(i_acc + i * mma_m) = a_block[a_offset] * b_blk_2d.row(blk_k - 1) + dst_tmp_2d.row(i_acc); + } + SW_BARRIER(); } - SW_BARRIER(); } - if constexpr ((blk_m % mma_m) != 0) { - constexpr uint32_t tail_start_m = blk_m / mma_m * mma_m; - constexpr uint32_t tail_m = blk_m % mma_m; + if constexpr (tail_m != 0) { + constexpr uint32_t tail_start_m = blk_m_iters * mma_m; xetla_vector dst_tmp; auto dst_tmp_2d = dst_tmp.xetla_format(); #pragma unroll @@ -172,20 +175,22 @@ struct tile_mma_t< matA_t& a) { { // k_blk=0 auto b_reg = b.reg.xetla_select(0); + if constexpr (tile_size_m >= block_size_m) { #pragma unroll - for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { - auto a_block = a.reg.xetla_select( + for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { + auto a_block = a.reg.xetla_select( i * num_block_k * a_block_elems); #pragma unroll - for (uint32_t j = 0; j < num_block_n; j++) { - auto b_block = + for (uint32_t j = 0; j < num_block_n; j++) { + auto b_block = b_reg.xetla_select(j * b_block_elems); - auto src_block = src.reg.xetla_select( + auto src_block = src.reg.xetla_select( (i * num_block_n + j) * block_elems); - auto dst_block = dst.reg.xetla_select( + auto dst_block = dst.reg.xetla_select( (i * num_block_n + j) * block_elems); - mma_core( + mma_core( dst_block, src_block, b_block, a_block); + } } } diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 216a57d96..8e059d3b3 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -106,12 +106,11 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr bool mem_transform = payload_t::mem_transform; - using load_store_attr = typename arch_attr_t< - arch_tag>::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t elems_per_CL = load_store_attr::cache_line_size_in_bytes / sizeof(dtype); static constexpr uint32_t elems_per_reg = - arch_attr_t::template register_attr<>::reg_in_bytes / + register_bytes_t::reg_in_bytes / sizeof(dtype); static constexpr int32_t max_load_block_height = load_store_attr::max_load_height_in_elem; diff --git a/include/subgroup/tile/impl/mma_xe.hpp b/include/subgroup/tile/impl/mma_xe.hpp index 371040397..37eb392a1 100644 --- a/include/subgroup/tile/impl/mma_xe.hpp +++ b/include/subgroup/tile/impl/mma_xe.hpp @@ -105,7 +105,7 @@ struct tile_mma_t< static constexpr int32_t num_block_mma_b = b_block_size_y / block_size_k; static constexpr uint32_t b_block_mma_elems = b_block_elems / num_block_mma_b; - using mma_attr = mma_attr_t; + using mma_attr = mma_attr_t; static constexpr int32_t mma_m = 
mma_attr::mma_m_in_elem; static constexpr int32_t mma_k = mma_attr::mma_k_in_bytes / sizeof(uint32_t); @@ -120,74 +120,78 @@ struct tile_mma_t< matA_t& a) { constexpr int32_t a_mma_elems = mma_m * a_block_size_x; constexpr int32_t c_mma_elems = mma_m * block_size_n; + constexpr uint32_t blk_m_iters = tile_size_m / block_size_m; + constexpr uint32_t tail_block_size_m = tile_size_m % block_size_m; #pragma unroll for (uint32_t j = 0; j < num_block_n; j++) { + if constexpr (blk_m_iters > 0) { #pragma unroll - for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { - auto src_block = src.reg.xetla_select( - (i * num_block_n + j) * block_elems); - auto dst_block = dst.reg.xetla_select( - (i * num_block_n + j) * block_elems); + for (uint32_t i = 0; i < blk_m_iters; i++) { + auto src_block = src.reg.xetla_select( + (i * num_block_n + j) * block_elems); + auto dst_block = dst.reg.xetla_select( + (i * num_block_n + j) * block_elems); #pragma unroll - for (uint32_t mma_i = 0; mma_i < block_size_m / mma_m; mma_i++) { - auto src_sub_blk = - src_block.xetla_select(mma_i * c_mma_elems); - auto dst_sub_blk = - dst_block.xetla_select(mma_i * c_mma_elems); - { // k=0 - auto a_block = a.reg.xetla_select( - (i * num_block_k) * a_block_elems); - auto a_sub_blk = - a_block.xetla_select(mma_i * a_mma_elems); - auto b_blk = - b.reg.xetla_select(j * b_block_elems); - auto b_sub_blk = b_blk.xetla_select(0); - dst_sub_blk = xetla_mma< - gpu::xetla::detail::mma_argument_type(), - gpu::xetla::detail::mma_argument_type(), - mma_k, - mma_m, - dtype_src, - dtype_b, - dtype_a, - c_mma_elems, - b_block_mma_elems, - a_mma_elems>(src_sub_blk, b_sub_blk, a_sub_blk); - } + for (uint32_t mma_i = 0; mma_i < block_size_m / mma_m; mma_i++) { + auto src_sub_blk = + src_block.xetla_select(mma_i * c_mma_elems); + auto dst_sub_blk = + dst_block.xetla_select(mma_i * c_mma_elems); + + { // k=0 + auto a_block = a.reg.xetla_select( + (i * num_block_k) * a_block_elems); + auto a_sub_blk = + a_block.xetla_select(mma_i * a_mma_elems); + auto b_blk = + b.reg.xetla_select(j * b_block_elems); + auto b_sub_blk = b_blk.xetla_select(0); + dst_sub_blk = xetla_mma< + gpu::xetla::detail::mma_argument_type(), + gpu::xetla::detail::mma_argument_type(), + mma_k, + mma_m, + dtype_src, + dtype_b, + dtype_a, + c_mma_elems, + b_block_mma_elems, + a_mma_elems>(src_sub_blk, b_sub_blk, a_sub_blk); + } #pragma unroll - for (uint32_t k = 1; k < num_block_k; k++) { - auto a_block = a.reg.xetla_select( - (i * num_block_k + k) * a_block_elems); - auto a_sub_blk = - a_block.xetla_select(mma_i * a_mma_elems); - int inter_k_b = k / num_block_mma_b; - int inner_k_b = k % num_block_mma_b; - auto b_blk = b.reg.xetla_select( - (j + inter_k_b * num_block_n) * b_block_elems); - auto b_sub_blk = b_blk.xetla_select( - inner_k_b * b_block_mma_elems); - dst_sub_blk = xetla_mma< - gpu::xetla::detail::mma_argument_type(), - gpu::xetla::detail::mma_argument_type(), - mma_k, - mma_m, - dtype_src, - dtype_b, - dtype_a, - c_mma_elems, - b_block_mma_elems, - a_mma_elems>(dst_sub_blk, b_sub_blk, a_sub_blk); + for (uint32_t k = 1; k < num_block_k; k++) { + auto a_block = a.reg.xetla_select( + (i * num_block_k + k) * a_block_elems); + auto a_sub_blk = + a_block.xetla_select(mma_i * a_mma_elems); + int inter_k_b = k / num_block_mma_b; + int inner_k_b = k % num_block_mma_b; + auto b_blk = b.reg.xetla_select( + (j + inter_k_b * num_block_n) * b_block_elems); + auto b_sub_blk = b_blk.xetla_select( + inner_k_b * b_block_mma_elems); + dst_sub_blk = xetla_mma< + 
gpu::xetla::detail::mma_argument_type(), + gpu::xetla::detail::mma_argument_type(), + mma_k, + mma_m, + dtype_src, + dtype_b, + dtype_a, + c_mma_elems, + b_block_mma_elems, + a_mma_elems>(dst_sub_blk, b_sub_blk, a_sub_blk); + } } } } - if constexpr ((tile_size_m % block_size_m) != 0) { - constexpr uint32_t tail_block_size_m = tile_size_m % block_size_m; + + if constexpr (tail_block_size_m != 0) { constexpr uint32_t tail_block_elems = block_size_n * tail_block_size_m; constexpr uint32_t a_tail_block_elems = tail_block_size_m * a_block_size_x; - constexpr uint32_t tail_m_start = - tile_size_m / block_size_m * block_size_m; + constexpr uint32_t tail_m_start = blk_m_iters * block_size_m; constexpr uint32_t tail_elems_start = tail_m_start * tile_size_n; constexpr uint32_t a_tail_elems_start = tail_m_start * a_tile_size_x; auto src_block = src.reg.xetla_select( diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index c895614e0..e990bfaf6 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -44,7 +44,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_2d, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::XeHpc)>> { + std::enable_if_t>> { using tile_desc = tile_desc_; using mem_desc_t = mem_desc_t; @@ -403,7 +403,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_1d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -526,7 +526,7 @@ struct mem_payload_t< tile_desc_, msg_type::atomic_add, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -708,7 +708,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_1d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -839,7 +839,7 @@ struct mem_payload_t< tile_desc_, msg_type::unaligned_2d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1066,7 +1066,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_2d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpg)>> { + std::enable_if_t>> { using dtype = std::conditional_t, uint8_t, dtype_>; using mem_desc_t = @@ -1285,7 +1285,7 @@ struct mem_payload_t< tile_desc_, msg_type::scatter, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -1464,7 +1464,7 @@ struct mem_payload_t< reg_layout::vnni_tiled_col_major>, msg_type::scatter, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -1617,7 +1617,7 @@ struct prefetch_payload_t< num_coop_sg_, arch_tag_, std::enable_if_t<( - arch_tag_ <= gpu_arch::XeHpg && + (!arch_has_2d_load_store) && (tile_size_y_ != 1 || block_size_y_ != 1))>> { using dtype = dtype_; using mem_desc_t = @@ -1836,7 +1836,7 @@ struct prefetch_payload_t< num_coop_sg_, arch_tag_, std::enable_if_t< - (arch_tag_ == gpu_arch::XeHpc) && + (arch_has_2d_load_store) && (tile_size_y_ != 1 || block_size_y_ != 1)>> { using dtype = dtype_; using mem_desc_t = @@ -2119,7 +2119,7 @@ struct prefetch_payload_t< tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype = dtype_; 
using mem_desc_t = mem_desc_t; @@ -2226,7 +2226,7 @@ struct prefetch_payload_t< tile_desc_, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; diff --git a/include/subgroup/tile/impl/quant_op_functor.hpp b/include/subgroup/tile/impl/quant_op_functor.hpp index ffb3dab54..ce10dbad7 100644 --- a/include/subgroup/tile/impl/quant_op_functor.hpp +++ b/include/subgroup/tile/impl/quant_op_functor.hpp @@ -40,7 +40,7 @@ template struct dequant_op_t< tile_op_t_, arch_tag, - std::enable_if_t<(arch_tag == gpu_arch::XeHpc)>> { + std::enable_if_t>> { // may need to add some limitations to tile_op used in dequant_op using tile_op_t = tile_op_t_; struct arguments_t { @@ -72,7 +72,7 @@ template struct quant_op_t< tile_op_t_, arch_tag, - std::enable_if_t<(arch_tag == gpu_arch::XeHpc)>> { + std::enable_if_t>> { // may need to add some limitations to tile_op used in dequant_op using tile_op_t = tile_op_t_; diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 56196da6d..324fd57de 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -1020,7 +1020,7 @@ tile_store(tile_t& tile, payload_t& payload) { #pragma unroll for (uint32_t j = 0; j < store_iter_steps; j++) { uint32_t offset_x = j * max_store_vec_len * scale_factor; - auto reg_sub = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x); + auto reg_sub = tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); xetla_store_local( payload.address + address_offset, diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 644717df8..879a6f7b8 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -213,7 +213,7 @@ template struct gelu_fwd_w_op_t< dtype_out_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_out = dtype_out_; using mem_desc_w_t = mem_desc_t; @@ -345,7 +345,7 @@ template struct gelu_bwd_op_t< dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_x_t = mem_desc_t; @@ -452,7 +452,7 @@ template struct bias_add_op_t< mem_desc_bias_t_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_bias_t = mem_desc_bias_t_; using dtype_bias = typename mem_desc_bias_t::dtype; using shape_t = typename mem_desc_bias_t::shape_t; @@ -556,7 +556,7 @@ struct scale_v_offset_v_op_t< scale_dtype_, offset_dtype_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using scale_dtype = scale_dtype_; using offset_dtype = offset_dtype_; @@ -691,7 +691,7 @@ template struct scale_v_op_t< scale_dtype_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using scale_dtype = scale_dtype_; using scale_mem_desc_t = @@ -798,7 +798,7 @@ struct elemwise_reduce_op_t< reduce_kind_, dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; @@ -916,7 +916,7 @@ struct elemwise_reduce_op_t< template < reduce_op reduce_kind, typename dtype_in, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, class enable = void> struct elemwise_reduce_op_stream_k_t {}; /// @brief Is the element-wise reduce op functor, 
specialized for Xe @@ -926,7 +926,7 @@ struct elemwise_reduce_op_stream_k_t< reduce_kind_, dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; @@ -1038,7 +1038,7 @@ template struct dropout_op_t< dtype_mask_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1125,7 +1125,7 @@ template struct rng_dropout_op_t< dtype_mask_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1239,7 +1239,7 @@ template struct scalar_mul_op_t< dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; @@ -1278,7 +1278,7 @@ template struct linear_op_t< dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 7e7896f3e..3674ee922 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -49,16 +49,16 @@ class TestBase { static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; }; -class Test0 : public TestBase { +class Test0x : public TestBase { public: static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32; + static constexpr size_t wg_n = 16; static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; - static constexpr size_t sg_k = 16; + static constexpr size_t sg_n = 16; + static constexpr size_t sg_k = 32; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; @@ -67,6 +67,28 @@ class Test0 : public TestBase { using data_type_b = fp16; using data_type_c = fp16; using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::xmx; +}; + +class Test0f : public TestBase { + public: + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 256; + static constexpr size_t mat_k = 256; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 16; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 16; + static constexpr size_t sg_k = 32; + static constexpr uint32_t global_kslicing = 1; + static constexpr uint32_t local_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; + using data_type_a = fp16; + using data_type_b = fp16; + using data_type_c = fp16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::fpu; }; class Test1 : public TestBase { @@ -75,9 +97,9 @@ class Test1 : public TestBase { static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32; + static constexpr size_t wg_n = 16; static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; + static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 16; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; @@ -127,15 +149,36 @@ class Test3 : public TestBase { using 
data_type_acc = float; }; -class Test4 : public TestBase { +class Test4f : public TestBase { public: static constexpr size_t mat_m = 1024; static constexpr size_t mat_n = 4096; static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 16 * 1; - static constexpr size_t wg_n = 32 * 32; - static constexpr size_t sg_m = 16; - static constexpr size_t sg_n = 32; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr uint32_t global_kslicing = 1; + static constexpr uint32_t local_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; + using data_type_a = fp16; + using data_type_b = fp16; + using data_type_c = fp16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::fpu; +}; + +class Test4x : public TestBase { + public: + static constexpr size_t mat_m = 1024; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp index 26a78ff91..f13956cdc 100644 --- a/tests/integration/gemm/fp16/kernel_func.hpp +++ b/tests/integration/gemm/fp16/kernel_func.hpp @@ -41,8 +41,9 @@ template < gpu_arch gpu_arch> struct fp16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 0 ; //8; - static constexpr uint32_t prefetch_distance = 0 ;//256 / (sg_k * sizeof(dtype_a)); + static constexpr uint32_t periodic_sync_interval = 0; // 8; + static constexpr uint32_t prefetch_distance = + 1; // 256 / (sg_k * sizeof(dtype_a)); using compute_attr = typename std::conditional< (engine == mma_engine::fpu), diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index 400d13276..6f9d0d490 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -32,25 +32,24 @@ TYPED_TEST_P(fp16_gemm_test, esimd) { esimd_compile_string); } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); -using tests = ::testing::Types< - Test4>; - // Test1, - // Test2, - // Test3>; - // Test4, - // Test5, - // Test6, - // Test7, - // Test8, - // Test9, - // Test10, - // Test11, - // Test12, - // Test13, - // Test14, - // Test15, - // Test16, - // Test17, - // Test18, - // Test19>; +using tests = ::testing::Types; +// Test1, +// Test2, +// Test3>; +// Test4, +// Test5, +// Test6, +// Test7, +// Test8, +// Test9, +// Test10, +// Test11, +// Test12, +// Test13, +// Test14, +// Test15, +// Test16, +// Test17, +// Test18, +// Test19>; INSTANTIATE_TYPED_TEST_SUITE_P(fp16_gemm_test_suite, fp16_gemm_test, tests); diff --git a/tools/clang-format/clang-format.hook b/tools/clang-format/clang-format.hook index 16566e46d..c84f26595 100644 --- a/tools/clang-format/clang-format.hook +++ b/tools/clang-format/clang-format.hook @@ -21,7 +21,7 @@ format_file() { if [ -f $file ]; then - clang-format-12 -i ${STYLEARG} ${1} + clang-format -i ${STYLEARG} ${1} git add ${1}
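
The recurring pattern across these hunks is compile-time capability dispatch: instead of hard-coding gpu_arch::XeHpc, each GEMM, epilogue, and payload specialization now queries a per-architecture trait (arch_has_named_barrier, arch_has_2d_load_store, arch_has_1d_load_store, ...) and branches with if constexpr, e.g. the new periodic_sync_init/arrive/wait helpers pick named producer/consumer barriers where the hardware has them and fall back to a single all-thread barrier otherwise. Below is a minimal standalone sketch of that dispatch pattern, not part of the patch and not the real XeTLA API: arch_attr_mock_t, fake_barrier, and sync_helper_t are hypothetical stand-ins, and only the shape of the logic mirrors the helpers introduced above.

    // Illustrative only: compile-time arch-capability dispatch (C++17).
    #include <cstdint>
    #include <iostream>

    enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };

    // Per-architecture capability table, analogous in spirit to
    // arch_attr_t<arch_tag>::has_named_barrier in the patch.
    template <gpu_arch arch_tag>
    struct arch_attr_mock_t;

    template <>
    struct arch_attr_mock_t<gpu_arch::XeHpc> {
      static constexpr bool has_named_barrier = true;
    };
    template <>
    struct arch_attr_mock_t<gpu_arch::XeHpg> {
      static constexpr bool has_named_barrier = false;
    };

    template <gpu_arch arch_tag>
    inline constexpr bool has_named_barrier_v =
        arch_attr_mock_t<arch_tag>::has_named_barrier;

    // Stand-in for a hardware barrier object: just reports what it would do.
    struct fake_barrier {
      void arrive_wait(const char* who) const {
        std::cout << "barrier arrive+wait on " << who << '\n';
      }
    };

    // Named producer/consumer barriers where supported, otherwise one
    // barrier shared by the whole workgroup.
    template <gpu_arch arch_tag>
    struct sync_helper_t {
      fake_barrier nbarrier_a, nbarrier_b, barrier_all;

      void periodic_sync() const {
        if constexpr (has_named_barrier_v<arch_tag>) {
          nbarrier_a.arrive_wait("named barrier A (row sync)");
          nbarrier_b.arrive_wait("named barrier B (column sync)");
        } else {
          barrier_all.arrive_wait("whole workgroup");
        }
      }
    };

    int main() {
      sync_helper_t<gpu_arch::XeHpc>{}.periodic_sync();  // named-barrier path
      sync_helper_t<gpu_arch::XeHpg>{}.periodic_sync();  // fallback path
    }

Because the branch is if constexpr, the untaken path is discarded at compile time, which is what allows barrier_count in the gemm_t specializations above to collapse to zero on architectures without named barriers while the XeHpc code path is unchanged.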