From fd8f8d181f01109b0fcda9694b3b9b98386b674b Mon Sep 17 00:00:00 2001
From: "Chen, Jian Ping"
Date: Mon, 29 Jul 2024 21:49:19 -0700
Subject: [PATCH] Remove hardcoded XeHpc and enhance multi-target support for
 barriers

---
 examples/02_basic_gemm/basic_gemm.cpp         |   6 +-
 examples/06_gemm_softmax/gemm_softmax.cpp     |   8 +-
 .../softmax.hpp                               |   3 +-
 include/common/core/arch_config.hpp           |  80 ++-
 include/common/core/common_types.hpp          |   7 +-
 include/common/core/math_mma.hpp              |   8 +-
 include/common/core/memory.hpp                |  17 -
 include/common/utils/common.hpp               |  11 +-
 include/common/utils/limitation.hpp           |   8 +-
 include/common/utils/memory_descriptor.hpp    |   2 +-
 include/common/utils/raw_send_load_store.hpp  |  39 +-
 include/common/utils/raw_send_nbarrier.hpp    |  27 +-
 .../experimental/group/dropout_mask_gen.hpp   |   9 +-
 .../fused_op/layer_norm_fused_op_api.hpp      |  37 -
 .../fused_op/layer_norm_fused_op_bwd_xe.hpp   |  36 +-
 .../fused_op/layer_norm_fused_op_fwd_xe.hpp   |  51 +-
 .../fused_op/row_reduction_fused_op_api.hpp   |  20 -
 .../fused_op/row_reduction_fused_op_xe.hpp    |  29 +-
 .../group/gemm/compute_policy.hpp             |  21 +-
 .../group/gemm/impl/int4_dequantize_xe.hpp    |   8 +-
 .../group/reduction/reduction_api.hpp         |  30 -
 .../group/reduction/row_reduce_store_xe.hpp   |  29 +-
 .../kernel/data_transformer/api.hpp           |  23 -
 .../data_transformer/data_transformer_xe.hpp  |  25 +-
 .../experimental/kernel/layer_norm/api.hpp    |  56 --
 .../kernel/layer_norm/layer_norm_bwd_xe.hpp   |  33 +-
 .../kernel/layer_norm/layer_norm_fwd_xe.hpp   |  30 +-
 .../mha_core_attention/mha_attn_reg.hpp       | 139 ++--
 .../mha_core_attention/mha_core_attn.hpp      |  78 +-
 include/experimental/kernel/reduction/api.hpp |  27 -
 .../kernel/reduction/row_reduction_xe.hpp     |  28 +-
 include/group/cooperative_reduction.hpp       |   7 +-
 include/group/epilogue/impl/default_xe.hpp    |   4 +-
 .../group/epilogue/impl/quant_tile_op_xe.hpp  |   2 +-
 .../group/epilogue/impl/stream_k_op_xe.hpp    |   2 +-
 include/group/epilogue/impl/tile_op_xe.hpp    |   2 +-
 include/group/epilogue/impl/unaligned_xe.hpp  |  12 +-
 include/group/gemm/compute_policy.hpp         |  29 +-
 include/group/gemm/gemm.hpp                   |   1 +
 include/group/gemm/impl/default_fpu_xe.hpp    | 157 +++--
 include/group/gemm/impl/default_xmx_xe.hpp    | 168 +++--
 include/group/gemm/impl/pre_processing_xe.hpp |   4 +-
 include/group/gemm/impl/selector_xe.hpp       | 111 ++-
 include/group/gemm/impl/unaligned_fpu_xe.hpp  | 666 ++++++++++++++++++
 include/group/gemm/impl/unaligned_xmx_xe.hpp  | 282 +++++---
 include/group/global_reduction.hpp            |   4 +-
 include/group/reduction/reduction.hpp         |   1 -
 include/group/reduction/reduction_api.hpp     |   4 +-
 include/group/reduction/reduction_xe.hpp      |  28 +-
 include/group/softmax/impl/softmax_bwd_xe.hpp |  14 +-
 include/group/softmax/impl/softmax_fwd_xe.hpp |  10 +-
 include/group/softmax/softmax_policy.hpp      |   7 +-
 include/group/tile_shape.hpp                  |  13 +-
 include/kernel/gemm/default_gemm.hpp          |  54 +-
 include/kernel/gemm/dispatch_policy.hpp       |   2 +-
 include/kernel/gemm/gemm_preset.hpp           |  13 +-
 include/kernel/gemm/impl/default_xe.hpp       |   2 +-
 include/kernel/gemm/impl/kslicing_xe.hpp      |   2 +-
 include/subgroup/cooperative_load_helper.hpp  |   4 +-
 include/subgroup/tile/api.hpp                 |   6 +-
 include/subgroup/tile/common.hpp              |  22 +-
 include/subgroup/tile/impl/blk_mma.hpp        | 406 +++++++++++
 include/subgroup/tile/impl/fma_xe.hpp         | 219 +++---
 include/subgroup/tile/impl/load_xe.hpp        |  65 +-
 include/subgroup/tile/impl/mma_xe.hpp         | 118 ++--
 include/subgroup/tile/impl/op_function.hpp    |   1 +
 include/subgroup/tile/impl/payload_xe.hpp     | 106 ++-
 .../subgroup/tile/impl/quant_op_functor.hpp   |   4 +-
 include/subgroup/tile/impl/store_xe.hpp       |  59 +-
 .../subgroup/tile/impl/tile_op_functor.hpp    |  24 +-
 tests/integration/gemm/bf16/common.hpp        | 222 ++++--
 tests/integration/gemm/bf16/kernel_func.hpp   |  18 +-
 tests/integration/gemm/bf16/main.cpp          |  34 +-
 tests/integration/gemm/fp16/common.hpp        | 302 ++++----
 tests/integration/gemm/fp16/kernel_func.hpp   |   9 +-
 tests/integration/gemm/fp16/main.cpp          |  24 +-
 tests/integration/gemm/fp32/common.hpp        |  11 +-
 tests/integration/gemm/fp32/main.cpp          |   2 +-
 .../int4_dequantization_bias/main_client.cpp  |  85 ++-
 tests/integration/gemm/int8/kernel_func.hpp   |   4 +-
 .../gemm/int8_quantization/kernel_func.hpp    |   4 +-
 .../gemm/unaligned_bf16/common.hpp            | 473 +++++++++++--
 .../gemm/unaligned_bf16/kernel_func.hpp       |  25 +-
 .../integration/gemm/unaligned_bf16/main.cpp  |  40 +-
 tests/integration/gemv/int4/main.cpp          |   7 +-
 .../mlp/int4/int4_mlp_gate_mul_up_fwd.hpp     |   4 +-
 tests/integration/mlp/int4/mlp.cpp            |  12 +-
 .../softmax/softmax_bwd_kernel.hpp            |   9 +-
 .../softmax/softmax_fwd_kernel.hpp            |  14 +-
 tests/utils/execution.hpp                     |   1 -
 90 files changed, 3273 insertions(+), 1587 deletions(-)
 create mode 100644 include/group/gemm/impl/unaligned_fpu_xe.hpp
 create mode 100644 include/subgroup/tile/impl/blk_mma.hpp
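The recurring pattern across these diffs: every hardcoded gpu_arch::XeHpc
becomes an arch_tag template parameter, and arch-specific behavior is selected
through constexpr traits or comparisons on that tag. A minimal sketch of the
idea (my_kernel_t is a hypothetical kernel; only the gpu_arch enum and the
XeHpc-vs-other constants mirror the hunks below):

    #include <cstdint>
    using namespace gpu::xetla;

    template <gpu_arch arch_tag>
    struct my_kernel_t {
      // Tuning constants become functions of the arch tag rather than
      // XeHpc-only literals (cf. the basic_gemm.cpp hunk below).
      static constexpr uint32_t periodic_sync_interval =
          (arch_tag == gpu_arch::XeHpc) ? 8 : 0;
      static constexpr uint32_t prefetch_distance =
          (arch_tag == gpu_arch::XeHpc) ? 3 : 1;
    };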
diff --git a/examples/02_basic_gemm/basic_gemm.cpp b/examples/02_basic_gemm/basic_gemm.cpp
index 8be0f6d7e..838bab930 100644
--- a/examples/02_basic_gemm/basic_gemm.cpp
+++ b/examples/02_basic_gemm/basic_gemm.cpp
@@ -114,8 +114,10 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
   // wrap the nd_range to XeTLA range
 
   // Performance tuning setting based on different shapes
-  static constexpr uint32_t periodic_sync_interval = 0;
-  static constexpr uint32_t prefetch_distance = 1;
+  static constexpr uint32_t periodic_sync_interval =
+      (arch_tag == gpu_arch::XeHpc) ? 8 : 0;
+  static constexpr uint32_t prefetch_distance =
+      (arch_tag == gpu_arch::XeHpc) ? 3 : 1;
   // should be larger than 8
   static constexpr uint32_t k_stride = 32;
diff --git a/examples/06_gemm_softmax/gemm_softmax.cpp b/examples/06_gemm_softmax/gemm_softmax.cpp
index f74a421f6..ed4b4630e 100644
--- a/examples/06_gemm_softmax/gemm_softmax.cpp
+++ b/examples/06_gemm_softmax/gemm_softmax.cpp
@@ -200,6 +200,8 @@ void gemm_softmax_run(uint32_t iter) {
   static constexpr uint32_t prefetch_distance = 3;
   // should be larger than 8
   static constexpr uint32_t k_iter_num = 16;
+  static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
+  // static constexpr gpu_arch arch_tag = gpu_arch::XeHpg;
 
   // Step 1: define Micro-kernel's configuration
   using wg_shape = shape;
@@ -227,7 +229,7 @@ void gemm_softmax_run(uint32_t iter) {
       data_type_sfx, // accumulator data type for intermediate results
       wg_shape, // computation tile shape
       k_iter_num, // elements in each iteration
-      gpu_arch::XeHpc, // GPU arch
+      arch_tag, // GPU arch
       tune_option>;
   using gemm_args_t = gemm_op_t::arguments_t;
@@ -239,14 +241,14 @@ void gemm_softmax_run(uint32_t iter) {
       mem_space::global, // memory writing to global mem for C
       wg_shape, // computation tile shape
       k_iter_num, // elements in each iteration
-      gpu_arch::XeHpc, // GPU arch
+      arch_tag, // GPU arch
       tune_option>;
 
   // using experimental::group::softmax
   // define softmax forward op
   using tile_shape = typename gemm_op_t::tile_shape;
   using softmax_fwd_t = softmax_t<
-      softmax_policy_fwd,
+      softmax_policy_fwd,
       tile_shape>;
   using softmax_fwd_args_t = typename softmax_fwd_t::arguments_t;
diff --git a/examples/08_scaled_dot_product_attention/softmax.hpp b/examples/08_scaled_dot_product_attention/softmax.hpp
index 40f986a70..93dcf1157 100644
--- a/examples/08_scaled_dot_product_attention/softmax.hpp
+++ b/examples/08_scaled_dot_product_attention/softmax.hpp
@@ -60,7 +60,8 @@ struct xetla_softmax_fwd_t {
   using softmax_tile_desc_t = subgroup::tile_desc_t;
   using softmax_load_t = subgroup::tile_t;
-  using mem_desc_in_t = mem_desc_t;
+  using mem_desc_in_t =
+      mem_desc_t;
   using softmax_load_payload_t = subgroup::mem_payload_t<
       mem_desc_in_t,
       softmax_tile_desc_t,
diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp
index 5a8d37e88..67532e355 100644
--- a/include/common/core/arch_config.hpp
+++ b/include/common/core/arch_config.hpp
@@ -50,7 +50,7 @@ struct load_store_attr_t {
   static constexpr uint32_t special_prefetch_width_in_bytes = 64;
 
   static constexpr uint32_t cache_line_size_in_bytes = 64;
-  static constexpr uint32_t alignment_in_bytes = 8;
+  static constexpr uint32_t alignment_in_bytes = 16;
 };
 
 template
@@ -72,7 +72,7 @@ struct client_load_store_attr_base_t {
   static constexpr uint32_t special_prefetch_width_in_bytes = 64;
 
   static constexpr uint32_t cache_line_size_in_bytes = 64;
-  static constexpr uint32_t alignment_in_bytes = 8;
+  static constexpr uint32_t alignment_in_bytes = 4;
 };
 
 template <>
@@ -94,15 +94,21 @@ inline constexpr bool arch_has_2d_load_store =
 
 template
 struct load_store_attr_t {
   static constexpr uint32_t max_load_vec_len = 256;
+  static constexpr uint32_t max_aligned_load_vec_len = 256;
   static constexpr uint32_t max_store_vec_len = 256;
+  static constexpr uint32_t max_aligned_store_vec_len = 256;
   static constexpr uint32_t max_prefetch_vec_len = 32;
+  static constexpr uint32_t max_channel_num = 16;
 };
 
 template <>
 struct load_store_attr_t {
-  static constexpr uint32_t max_load_vec_len = 512;
-  static constexpr uint32_t max_store_vec_len = 512;
+  static constexpr uint32_t max_load_vec_len = 256;
+  static constexpr uint32_t max_aligned_load_vec_len = 512;
+  static constexpr uint32_t max_store_vec_len = 256;
+  static constexpr uint32_t max_aligned_store_vec_len = 512;
   static constexpr uint32_t max_prefetch_vec_len = 64;
+  static constexpr uint32_t max_channel_num = 32;
 };
 
 struct dpas_attr_base_t {
@@ -112,6 +118,7 @@ struct dpas_attr_base_t {
   static constexpr uint32_t op_per_channel_bits = 32;
   static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3);
   static constexpr uint32_t op_per_channel_max = 8;
+  static constexpr uint32_t k_in_bytes = systolic_depth * op_per_channel_bytes;
 };
 
 template
@@ -121,12 +128,12 @@ struct dpas_attr_t {
 
 template <>
 struct dpas_attr_t : public dpas_attr_base_t {
-  static constexpr uint32_t n_fixed_limit = 16;
+  static constexpr uint32_t n_in_elem = 16;
 };
 
 template <>
 struct dpas_attr_t : public dpas_attr_base_t {
-  static constexpr uint32_t n_fixed_limit = 8;
+  static constexpr uint32_t n_in_elem = 8;
 };
 
 template
@@ -140,9 +147,10 @@ struct fpu_attr_t {
 template
 inline constexpr bool arch_has_fpu = fpu_attr_t::has_fpu;
 
-#define GRF grf_mode::double_grf
 #ifdef NORMAL_GRF
 #define GRF grf_mode::normal_grf
+#else
+#define GRF grf_mode::double_grf
 #endif
 
 template
@@ -155,6 +163,7 @@ struct register_nums_t {
 
 template
 struct register_bytes_t;
+
 template <>
 struct register_bytes_t {
   static constexpr uint32_t reg_in_bytes = 64;
@@ -180,24 +189,49 @@ struct register_attr_t {
   static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes;
 };
 
-template
+template <
+    gpu_arch arch_tag,
+    mma_engine engine_type,
+    uint32_t m,
+    class enable = void>
 struct mma_attr_t {};
 
 template
-struct mma_attr_t>> {
+struct mma_attr_t<
+    arch_tag,
+    mma_engine::xmx,
+    m,
+    std::enable_if_t>> {
   using dpas_attr = dpas_attr_t;
+  using load_store_attr = load_store_attr_t;
   static constexpr uint32_t mma_m_in_elem =
       (m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
-  static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
-  static constexpr uint32_t mma_k_in_bytes =
-      dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;
+  static constexpr uint32_t blk_m_in_elem = 16;
+
+  static constexpr uint32_t mma_n_in_elem = dpas_attr::n_in_elem;
+  [[maybe_unused]] static constexpr uint32_t blk_n_in_bytes =
+      load_store_attr::max_trans_load_width_in_bytes;
+
+  static constexpr uint32_t mma_k_in_bytes = dpas_attr::k_in_bytes;
+  static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
 };
 
 template
-struct mma_attr_t>> {
+struct mma_attr_t<
+    arch_tag,
+    mma_engine::fpu,
+    m,
+    std::enable_if_t>> {
+  using load_store_attr = load_store_attr_t;
   static constexpr uint32_t mma_m_in_elem = (m > 8) ? 8 : m;
-  static constexpr uint32_t mma_n_in_elem = 16;
+  static constexpr uint32_t blk_m_in_elem = 16;
+  static constexpr uint32_t mma_k_in_bytes = 32;
+  static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
+
+  [[maybe_unused]] static constexpr uint32_t mma_n_in_elem = 16;
+  static constexpr uint32_t blk_n_in_bytes =
+      register_bytes_t::reg_in_bytes;
 };
 
 template
@@ -208,13 +242,14 @@ struct arch_attr_t {
   template
   using load_store_attr = load_store_attr_t;
 
-  template
+  template
   using register_attr = register_attr_t;
 
   using dpas_attr = dpas_attr_t;
 
   static constexpr uint32_t max_wg_num = 64;
   static constexpr uint32_t local_mem_size = 128 * 1024;
+  static constexpr bool has_named_barrier = true;
 };
 
 template <>
@@ -222,13 +257,15 @@ struct arch_attr_t {
   template
   using load_store_attr = load_store_attr_t;
 
-  template
+  template
   using register_attr = register_attr_t;
 
   using dpas_attr = dpas_attr_t;
 
-  static constexpr uint32_t max_wg_num = 64;
+  static constexpr uint32_t max_wg_num = 32;
   static constexpr uint32_t local_mem_size = 64 * 1024;
+
+  static constexpr bool has_named_barrier = false;
 };
 
 template <>
@@ -236,15 +273,20 @@ struct arch_attr_t {
   template
   using load_store_attr = load_store_attr_t;
 
-  template
+  template
   using register_attr = register_attr_t;
 
   using dpas_attr = dpas_attr_t;
 
-  static constexpr uint32_t max_wg_num = 64;
+  static constexpr uint32_t max_wg_num = 32;
   static constexpr uint32_t local_mem_size = 64 * 1024;
+  static constexpr bool has_named_barrier = false;
 };
 
+template
+inline constexpr bool arch_has_named_barrier =
+    arch_attr_t::has_named_barrier;
+
 /// @} xetla_core_arch_config
 
 } // namespace gpu::xetla
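With has_named_barrier now exposed per architecture, callers can gate
named-barrier code paths on the arch_has_named_barrier trait instead of
comparing arch_tag against gpu_arch::XeHpc. A sketch under the assumption of a
gemm-like group kernel (the barrier_count shape follows the
int4_dequantize_xe.hpp hunks further down; the count of 2 is illustrative):

    template <gpu_arch arch_tag, bool enable_periodic_sync>
    struct barrier_budget_sketch_t {
      // Named barriers are only reserved where the hardware has them;
      // other targets fall back to split barriers and reserve none.
      static constexpr uint32_t barrier_count =
          (enable_periodic_sync && arch_has_named_barrier<arch_tag>) ? 2 : 0;
    };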
diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp
index 30c3cc04d..44cf9f403 100644
--- a/include/common/core/common_types.hpp
+++ b/include/common/core/common_types.hpp
@@ -21,7 +21,12 @@
 #include
 
 namespace gpu::xetla {
-enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
+enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2, XeLast };
+
+template
+inline constexpr bool valid_xe_arch_tag = (arch_tag < gpu_arch::XeLast);
+
+enum class mma_engine : uint8_t { xmx = 0, fpu = 1 };
 
 enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
diff --git a/include/common/core/math_mma.hpp b/include/common/core/math_mma.hpp
index 7b7e7cbe4..bb3618ddc 100644
--- a/include/common/core/math_mma.hpp
+++ b/include/common/core/math_mma.hpp
@@ -72,9 +72,7 @@ constexpr gpu::xetla::argument_type mma_argument_type() {
 template
 constexpr __ESIMD_NS::xmx::dpas_argument_type get_argument_type() {
   static_assert(
-      arg_type == gpu::xetla::argument_type::U1 ||
-      arg_type == gpu::xetla::argument_type::S1 ||
-      arg_type == gpu::xetla::argument_type::U2 ||
+      arg_type == gpu::xetla::argument_type::U2 ||
       arg_type == gpu::xetla::argument_type::S2 ||
       arg_type == gpu::xetla::argument_type::U4 ||
       arg_type == gpu::xetla::argument_type::S4 ||
@@ -85,10 +83,6 @@ constexpr __ESIMD_NS::xmx::dpas_argument_type get_argument_type() {
       arg_type == gpu::xetla::argument_type::TF32,
       "Unsupported argument type");
   switch (arg_type) {
-    case gpu::xetla::argument_type::U1:
-      return __ESIMD_NS::xmx::dpas_argument_type::u1;
-    case gpu::xetla::argument_type::S1:
-      return __ESIMD_NS::xmx::dpas_argument_type::s1;
     case gpu::xetla::argument_type::U2:
       return __ESIMD_NS::xmx::dpas_argument_type::u2;
     case gpu::xetla::argument_type::S2:
diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp
index b1b031226..0d764dd1b 100644
--- a/include/common/core/memory.hpp
+++ b/include/common/core/memory.hpp
@@ -400,10 +400,6 @@ __XETLA_API xetla_vector xetla_load_global(
     xetla_vector offsets,
     xetla_mask pred = 1) {
   using T = native_type_t;
-  DEBUG_INVOKE(
-      dbg_level::core,
-      core::general_1d::
-          template check_restriction(offsets, (uint64_t)p));
 
   return __ESIMD_ENS::lsc_gather<
       T,
@@ -666,10 +662,6 @@ __XETLA_API xetla_vector xetla_load_local(
     xetla_vector offsets,
     xetla_mask pred = 1) {
   using T = native_type_t;
-  DEBUG_INVOKE(
-      dbg_level::core,
-      core::general_1d::
-          template check_restriction(offsets));
 
   return __ESIMD_ENS::
       lsc_slm_gather(
@@ -694,11 +686,6 @@ __XETLA_API xetla_vector xetla_load_local(uint32_t offset) {
   using T = native_type_t;
-  // DEBUG_INVOKE(
-  //     dbg_level::core,
-  //     core::general_1d::template
-  //     check_restriction(
-  //         (uint64_t)offset));
 
   return __ESIMD_NS::slm_block_load(offset);
 }
@@ -729,10 +716,6 @@ __XETLA_API void xetla_store_local(
     xetla_vector vals,
     xetla_mask pred = 1) {
   using T = native_type_t;
-  DEBUG_INVOKE(
-      dbg_level::core,
-      core::general_1d::
-          template check_restriction(offsets));
 
   __ESIMD_ENS::
       lsc_slm_scatter(
diff --git a/include/common/utils/common.hpp b/include/common/utils/common.hpp
index 2fd47ad89..7d3b8fe15 100644
--- a/include/common/utils/common.hpp
+++ b/include/common/utils/common.hpp
@@ -51,7 +51,7 @@ constexpr uint32_t get_element_size_code() {
 enum class lsc_action : uint8_t { prefetch, load, store, atomic };
 
 template
-constexpr std::enable_if_t
+constexpr std::enable_if_t, void>
 check_lsc_cache_hint() {
   if constexpr (Action == lsc_action::prefetch) {
     // https://gfxspecs.intel.com/Predator/Home/Index/53560
@@ -94,7 +94,7 @@ check_lsc_cache_hint() {
 }
 
 template
-constexpr std::enable_if_t
+constexpr std::enable_if_t, uint32_t>
 get_load_cache_hint_code() {
   check_lsc_cache_hint();
   if (L1H == cache_hint::none && L2H == cache_hint::none) {
@@ -126,7 +126,7 @@ get_load_cache_hint_code() {
 }
 
 template
-constexpr std::enable_if_t
+constexpr std::enable_if_t, uint32_t>
 get_prefetch_cache_hint_code() {
   check_lsc_cache_hint();
   if (L2H == cache_hint::uncached) {
@@ -153,7 +153,7 @@ get_prefetch_cache_hint_code() {
 }
 
 template
-constexpr std::enable_if_t
+constexpr std::enable_if_t, uint32_t>
 get_store_cache_hint_code() {
   check_lsc_cache_hint();
   if (L1H == cache_hint::none && L2H == cache_hint::none) {
@@ -185,7 +185,7 @@ get_store_cache_hint_code() {
 }
 
 template
-constexpr std::enable_if_t
+constexpr std::enable_if_t, uint32_t>
 get_atomic_cache_hint_code() {
   check_lsc_cache_hint();
   if (L1H == cache_hint::none && L2H == cache_hint::none) {
@@ -286,7 +286,6 @@ enum class store_op : uint8_t {
   scattered_transpose = 3,
   block_1d = 4
 };
-enum class mma_engine : uint8_t { xmx = 0, fpu = 1 };
 // enum class trans_mode : uint8_t { none = 0, transpose = 1 };
 enum class memory_op : uint8_t { load = 0, store = 1 };
 enum class tdesc_update_dir : uint8_t { x_dir = 0, y_dir = 1 };
diff --git a/include/common/utils/limitation.hpp b/include/common/utils/limitation.hpp
index 041dd4111..4ec0aed1c 100644
--- a/include/common/utils/limitation.hpp
+++ b/include/common/utils/limitation.hpp
@@ -319,7 +319,7 @@ class block_2d {
   static constexpr auto element_size = sizeof(T);
   static constexpr uint32_t max_24bit = 16 * 1024 * 1024; // 2 ^ 24
   static constexpr auto bytes_per_grf =
-      register_attr_t::reg_in_bytes;
+      register_bytes_t::reg_in_bytes;
 
   static inline bool check_base_address(uint64_t base) {
     bool ret = ((base % 64) == 0);
@@ -747,7 +747,7 @@ struct check_store {
 } // namespace subgroup
 
 namespace group {
-template
+template
 struct gemm {
   struct default_fpu {
     template <
@@ -800,7 +800,7 @@ struct gemm {
         int block_size_y_b>
     struct check_tile_size_default {
       static constexpr uint32_t reg_in_bytes =
-          register_attr_t::reg_in_bytes;
+          register_bytes_t::reg_in_bytes;
       static constexpr uint32_t simd_len = reg_in_bytes / sizeof(dtype_mma);
 
       static_assert(
@@ -876,7 +876,7 @@ struct gemm {
         int block_size_x_b,
         int block_size_y_b>
     struct check_tile_size_default {
-      using mma_attr = mma_attr_t;
+      using mma_attr = mma_attr_t;
       static constexpr int32_t mma_m = mma_attr::mma_m_in_elem;
      static constexpr int32_t mma_n = mma_attr::mma_n_in_elem;
       static constexpr int32_t mma_k =
diff --git a/include/common/utils/memory_descriptor.hpp b/include/common/utils/memory_descriptor.hpp
index 59d762565..de2fc4f3c 100644
--- a/include/common/utils/memory_descriptor.hpp
+++ b/include/common/utils/memory_descriptor.hpp
@@ -146,7 +146,7 @@ template <
     typename dtype_,
     mem_layout layout_,
     mem_space space_,
-    uint32_t alignment_ = 8,
+    uint32_t alignment_ = 16,
     int dim_ = 2>
 struct mem_desc_t {};
diff --git a/include/common/utils/raw_send_load_store.hpp b/include/common/utils/raw_send_load_store.hpp
index 4c85a7fd7..b40f62ee9 100644
--- a/include/common/utils/raw_send_load_store.hpp
+++ b/include/common/utils/raw_send_load_store.hpp
@@ -219,13 +219,14 @@ __XETLA_API void xetla_update_tdesc_offsety(
 template <
     typename Ty,
     uint32_t N,
-    cache_hint L1H = cache_hint::none,
-    cache_hint L2H = cache_hint::none,
-    bool transpose = false,
-    bool transform = false,
-    gpu_arch arch_tag = gpu_arch::XeHpc>
-__XETLA_API std::enable_if_t>
-xetla_tload_global(xetla_tdescriptor tdesc) {
+    cache_hint L1H,
+    cache_hint L2H,
+    bool transpose,
+    bool transform,
+    gpu_arch arch_tag>
+__XETLA_API std::
+    enable_if_t, xetla_vector>
+    xetla_tload_global(xetla_tdescriptor tdesc) {
   DEBUG_INVOKE(
       dbg_level::core,
       core::block_2d::template check_load(
@@ -273,10 +274,10 @@ xetla_tload_global(xetla_tdescriptor tdesc) {
 template <
     typename Ty,
     uint32_t N,
-    cache_hint L1H = cache_hint::none,
-    cache_hint L2H = cache_hint::none,
-    gpu_arch arch_tag = gpu_arch::XeHpc>
-__XETLA_API std::enable_if_t
+    cache_hint L1H,
+    cache_hint L2H,
+    gpu_arch arch_tag>
+__XETLA_API std::enable_if_t, void>
 xetla_tstore_global(xetla_tdescriptor tdesc, xetla_vector data) {
   DEBUG_INVOKE(
       dbg_level::core, core::block_2d::check_store(tdesc));
@@ -308,12 +309,8 @@ xetla_tstore_global(xetla_tdescriptor tdesc, xetla_vector data) {
 /// dimensions, block size, etc.
 /// @return none.
 ///
-template <
-    typename Ty,
-    cache_hint L1H = cache_hint::cached,
-    cache_hint L2H = cache_hint::cached,
-    gpu_arch arch_tag = gpu_arch::XeHpc>
-__XETLA_API std::enable_if_t
+template
+__XETLA_API std::enable_if_t, void>
 xetla_tprefetch_global(xetla_tdescriptor tdesc) {
   uint32_t msg_desc = 3;
   msg_desc |= 0 << 7;
@@ -350,12 +347,12 @@ xetla_tprefetch_global(xetla_tdescriptor tdesc) {
 template <
     typename Ty,
     uint32_t N,
-    cache_hint L1H = cache_hint::none,
-    cache_hint L2H = cache_hint::none,
+    cache_hint L1H,
+    cache_hint L2H,
     atomic_op Op,
-    gpu_arch arch_tag = gpu_arch::XeHpc,
+    gpu_arch arch_tag,
     typename Toffset = uint32_t>
-__XETLA_API std::enable_if_t
+__XETLA_API std::enable_if_t, void>
 xetla_tatomic_store_global(
     uint64_t base_address,
     xetla_vector offset,
diff --git a/include/common/utils/raw_send_nbarrier.hpp b/include/common/utils/raw_send_nbarrier.hpp
index a1b17be02..8c40d2480 100644
--- a/include/common/utils/raw_send_nbarrier.hpp
+++ b/include/common/utils/raw_send_nbarrier.hpp
@@ -41,9 +41,9 @@ enum class nbarrier_role : uint8_t {
 /// as consumer.
 ///
 template <
-    uint8_t num_producers = 1,
-    uint8_t num_consumers = 1,
-    gpu_arch arch_tag = gpu_arch::XeHpc,
+    uint8_t num_producers,
+    uint8_t num_consumers,
+    gpu_arch arch_tag,
     typename enable = void>
 struct xetla_nbarrier_t;
 
@@ -52,7 +52,7 @@ struct xetla_nbarrier_t<
     num_producers,
     num_consumers,
     arch_tag,
-    std::enable_if_t> {
+    std::enable_if_t>> {
   ///
   /// @brief Description of named barrier object.
   /// Structure is defined in
   /// [here](https://gfxspecs.intel.com/Predator/Home/Index/57499).
   ///
@@ -105,20 +105,7 @@ struct xetla_nbarrier_t<
     num_producers,
     num_consumers,
     arch_tag,
-    std::enable_if_t> {
-  ///
-  /// @brief Description of named barrier objection.
-  /// Structure is defined in
-  /// [here](https://gfxspecs.intel.com/Predator/Home/Index/57499).
-  ///
-  // xetla_vector nbar;
-  // uint32_t barrier_id;
-
-  /// @param role is the role of subgroup when participating the barrier.
-  /// @param nbarrier_id [in] is the id of the barrier.
-  /// note: all subgroups participating the barrier should have the same
-  /// barrier_id. Here is the bspec link
-  /// https://gfxspecs.intel.com/Predator/Home/Index/54006
+    std::enable_if_t>> {
   __XETLA_API void init_nbarrier(uint8_t, nbarrier_role) {}
 
   /// @brief Generic work-group split barrier.
@@ -127,14 +114,10 @@ struct xetla_nbarrier_t<
     __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::signal>();
   }
 
-  /// @brief named barrier wait within subgroup.
-  ///
   __XETLA_API void wait() {
     __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::wait>();
   }
 
-  /// @brief named barrier signal from subgroup.
-  ///
   __XETLA_API void arrive_wait() {
     arrive();
     wait();
   }
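Both specializations keep the same surface, so calling code stays
arch-agnostic: on targets without named barriers, init_nbarrier() is a no-op
and arrive()/wait() lower to split-barrier signal/wait. A usage sketch
(producer/consumer counts of 8 and barrier id 0 are illustrative only):

    template <gpu_arch arch_tag>
    __XETLA_API void sync_tile_rows() {
      xetla_nbarrier_t<8, 8, arch_tag> nbarrier;
      nbarrier.init_nbarrier(0, nbarrier_role::producer_consumer);
      // ... compute on the local tile ...
      nbarrier.arrive_wait(); // named barrier on XeHpc, split barrier elsewhere
    }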
diff --git a/include/experimental/group/dropout_mask_gen.hpp b/include/experimental/group/dropout_mask_gen.hpp
index fce6d2366..b8bb03c54 100644
--- a/include/experimental/group/dropout_mask_gen.hpp
+++ b/include/experimental/group/dropout_mask_gen.hpp
@@ -39,8 +39,8 @@ template <
     uint32_t wg_tile_m_,
     uint32_t sg_tile_n_,
     uint32_t sg_tile_m_,
-    uint32_t random_simd_ = 16,
-    gpu_arch arch_ = gpu_arch::XeHpc>
+    uint32_t random_simd_,
+    gpu_arch arch_>
 struct mask_gen_t {
   using dtype_mask = dtype_mask_;
   static constexpr uint32_t wg_tile_n = wg_tile_n_;
@@ -64,8 +64,7 @@ struct mask_gen_t {
     float dropout_prob;
   };
 
-  using load_store_attr =
-      typename arch_attr_t::template load_store_attr;
+  using load_store_attr = load_store_attr_t;
   static constexpr uint32_t max_store_width_in_bytes =
       load_store_attr::max_store_width_in_bytes;
   static constexpr uint32_t max_store_width_in_elem =
@@ -99,7 +98,7 @@ struct mask_gen_t {
       mem_desc_t,
       mask_out_tile_desc_t,
       (sg_tile_m == 1) ? msg_type::block_1d : msg_type::block_2d,
-      gpu_arch::XeHpc>;
+      arch_>;
   static constexpr uint32_t tile_size = tile_size_x * tile_size_y;
 
   /// @brief
diff --git a/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp b/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp
index 8f5d7da1e..5c2ba7fdb 100644
--- a/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp
+++ b/include/experimental/group/fused_op/layer_norm_fused_op_api.hpp
@@ -44,41 +44,4 @@ enum class ln_bwd_fused_kind : uint8_t {
   ln_dropout = 3,
 };
 
-namespace group {
-
-/// @brief
-///
-/// @tparam fused_op_kind_
-/// @tparam dtype_in_
-/// @tparam dtype_out_
-/// @tparam dtype_acc_
-/// @tparam layer_norm_attr_
-/// @tparam arch_
-template <
-    ln_fwd_fused_kind fused_op_kind_,
-    typename dtype_in_,
-    typename dtype_out_,
-    typename dtype_acc_,
-    typename layer_norm_attr_,
-    gpu_arch arch_ = gpu_arch::XeHpc>
-struct ln_fwd_fused_op_t {};
-
-/// @brief
-///
-/// @tparam fused_op_kind_
-/// @tparam dtype_in_
-/// @tparam dtype_out_
-/// @tparam dtype_acc_
-/// @tparam layer_norm_attr_
-/// @tparam arch_
-template <
-    ln_bwd_fused_kind fused_op_kind_,
-    typename dtype_in_,
-    typename dtype_out_,
-    typename dtype_acc_,
-    typename layer_norm_attr_,
-    gpu_arch arch_ = gpu_arch::XeHpc>
-struct ln_bwd_fused_op_t {};
-
-} // namespace group
 } // namespace gpu::xetla
diff --git a/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp b/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp
index 04e621447..e6b24c952 100644
--- a/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp
+++ b/include/experimental/group/fused_op/layer_norm_fused_op_bwd_xe.hpp
@@ -55,14 +55,9 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
-struct ln_bwd_fused_op_t<
-    ln_fused_op_kind_,
-    dtype_in_,
-    dtype_out_,
-    dtype_acc_,
-    layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
+struct ln_bwd_fused_op_t {
   static constexpr ln_bwd_fused_kind fused_op_kind = ln_fused_op_kind_;
   using dtype_acc = dtype_acc_;
   using dtype_in = dtype_in_;
@@ -128,14 +123,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
 struct ln_bwd_fused_op_t<
     ln_bwd_fused_kind::bias_dropout_resAdd_ln,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr ln_bwd_fused_kind fused_op_kind =
       ln_bwd_fused_kind::bias_dropout_resAdd_ln;
   using dtype_acc = dtype_acc_;
@@ -164,13 +160,13 @@ struct ln_bwd_fused_op_t<
       mem_desc_t,
       ln_bwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mask_in_t = subgroup::tile_t;
   using mask_in_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_bwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   dx_resAdd_out_t dx_resAdd_out;
   dx_resAdd_out_payload_t dx_resAdd_out_payload;
   mask_in_t mask_in;
@@ -271,14 +267,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
 struct ln_bwd_fused_op_t<
     ln_bwd_fused_kind::ln_dropout_gradAdd,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr ln_bwd_fused_kind fused_op_kind =
       ln_bwd_fused_kind::ln_dropout_gradAdd;
   using dtype_acc = dtype_acc_;
@@ -307,13 +304,13 @@ struct ln_bwd_fused_op_t<
       mem_desc_t,
       ln_bwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mask_in_t = subgroup::tile_t;
   using mask_in_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_bwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   grad_in_t grad_in;
   grad_in_payload_t grad_in_payload;
@@ -407,14 +404,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
 struct ln_bwd_fused_op_t<
     ln_bwd_fused_kind::ln_dropout,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr ln_bwd_fused_kind fused_op_kind =
       ln_bwd_fused_kind::ln_dropout;
   using dtype_acc = dtype_acc_;
@@ -439,7 +437,7 @@ struct ln_bwd_fused_op_t<
       mem_desc_t,
       ln_bwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   mask_in_t mask_in;
   mask_in_payload_t mask_in_payload;
diff --git a/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp b/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp
index 3357dac5d..db6f70c10 100644
--- a/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp
+++ b/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp
@@ -57,14 +57,9 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
-struct ln_fwd_fused_op_t<
-    ln_fused_op_kind_,
-    dtype_in_,
-    dtype_out_,
-    dtype_acc_,
-    layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
+struct ln_fwd_fused_op_t {
   static constexpr ln_fwd_fused_kind fused_op_kind = ln_fused_op_kind_;
   using dtype_acc = dtype_acc_;
   using dtype_in = dtype_in_;
@@ -122,14 +117,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
 struct ln_fwd_fused_op_t<
     ln_fwd_fused_kind::bias_dropout_resAdd_ln,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr ln_fwd_fused_kind fused_op_kind =
       ln_fwd_fused_kind::bias_dropout_resAdd_ln;
   using dtype_acc = dtype_acc_;
@@ -161,26 +157,26 @@ struct ln_fwd_fused_op_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using res_in_t = subgroup::tile_t;
   using res_in_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mask_in_t = subgroup::tile_t;
   using mask_in_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using bias_dropout_res_out_t = subgroup::tile_t;
   using bias_dropout_res_out_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   bias_in_t bias_in;
   bias_in_payload_t bias_in_payload;
   bias_dropout_res_out_t bias_dropout_res_out;
@@ -304,14 +300,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
 struct ln_fwd_fused_op_t<
     ln_fwd_fused_kind::ln_dropout,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr ln_fwd_fused_kind fused_op_kind =
       ln_fwd_fused_kind::ln_dropout;
   using dtype_acc = dtype_acc_;
@@ -343,7 +340,7 @@ struct ln_fwd_fused_op_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   mask_in_t mask_in;
   mask_in_payload_t mask_in_payload;
   uint32_t mask_ld;
@@ -414,14 +411,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
 struct ln_fwd_fused_op_t<
     ln_fwd_fused_kind::bias_rng_dropout_resAdd_ln,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr ln_fwd_fused_kind fused_op_kind =
       ln_fwd_fused_kind::bias_rng_dropout_resAdd_ln;
   using dtype_acc = dtype_acc_;
@@ -449,26 +447,26 @@ struct ln_fwd_fused_op_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using res_in_t = subgroup::tile_t;
   using res_in_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mask_out_t = subgroup::tile_t;
   using mask_out_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using bias_dropout_res_out_t = subgroup::tile_t;
   using bias_dropout_res_out_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   bias_in_t bias_in;
   bias_in_payload_t bias_in_payload;
@@ -603,14 +601,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename layer_norm_attr_>
+    typename layer_norm_attr_,
+    gpu_arch arch_tag>
 struct ln_fwd_fused_op_t<
     ln_fwd_fused_kind::ln_rng_dropout,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     layer_norm_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr ln_fwd_fused_kind fused_op_kind =
       ln_fwd_fused_kind::ln_rng_dropout;
   using dtype_acc = dtype_acc_;
@@ -641,7 +640,7 @@ struct ln_fwd_fused_op_t<
       mem_desc_t,
       mask_out_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   mask_out_t mask_out;
   mask_out_payload_t mask_out_payload;
   dropout_fwd_t dropout_fwd;
diff --git a/include/experimental/group/fused_op/row_reduction_fused_op_api.hpp b/include/experimental/group/fused_op/row_reduction_fused_op_api.hpp
index e5ccb1948..c49037d71 100644
--- a/include/experimental/group/fused_op/row_reduction_fused_op_api.hpp
+++ b/include/experimental/group/fused_op/row_reduction_fused_op_api.hpp
@@ -31,24 +31,4 @@ enum class reduction_fused_kind : uint8_t {
   bias_dropout_bwd = 2
 };
 
-namespace group {
-
-/// @brief Additional Ops that can be fused with row reduction processing flow.
-///
-/// @tparam fused_op_kind_ Is the type of the fused op.
-/// @tparam dtype_in_ Is the data type of input.
-/// @tparam dtype_out_ Is the data type of output.
-/// @tparam dtype_acc_ Is the accumulation data type.
-/// @tparam reduction_attr_ Is the tile size for each group to do the reduction.
-/// @tparam arch_ Is the HW generation.
-template <
-    reduction_fused_kind fused_op_kind_,
-    typename dtype_in_,
-    typename dtype_out_,
-    typename dtype_acc_,
-    typename reduction_attr_,
-    gpu_arch arch_ = gpu_arch::XeHpc>
-struct row_reduction_fused_op_t {};
-
-} // namespace group
 } // namespace gpu::xetla
diff --git a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp
index 142c55e18..8cdd56b19 100644
--- a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp
+++ b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp
@@ -54,14 +54,9 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename reduction_attr_>
-struct row_reduction_fused_op_t<
-    fused_op_kind_,
-    dtype_in_,
-    dtype_out_,
-    dtype_acc_,
-    reduction_attr_,
-    gpu_arch::XeHpc> {
+    typename reduction_attr_,
+    gpu_arch arch_tag_ = gpu_arch::XeHpc>
+struct row_reduction_fused_op_t {
   static constexpr reduction_fused_kind fused_op_kind = fused_op_kind_;
   using dtype_in = dtype_in_;
   using dtype_out = dtype_out_;
@@ -83,14 +78,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename reduction_attr_>
+    typename reduction_attr_,
+    gpu_arch arch_tag>
 struct row_reduction_fused_op_t<
     reduction_fused_kind::bias_gelu_w_bwd,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     reduction_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr reduction_fused_kind fused_op_kind =
       reduction_fused_kind::bias_gelu_w_bwd;
   using dtype_in = dtype_in_;
@@ -145,13 +141,13 @@ struct row_reduction_fused_op_t<
         mem_desc_in_t,
         dgelu_tile_desc_t,
         subgroup::msg_type_v,
-        gpu_arch::XeHpc>;
+        arch_tag>;
     using dgelu_x_out_t = subgroup::tile_t;
     using dgelu_x_out_payload_t = subgroup::mem_payload_t<
         mem_desc_t,
         dgelu_tile_desc_t,
        msg_type::block_2d,
-        gpu_arch::XeHpc>;
+        arch_tag>;
     dgelu_w_in_t dgelu_w_in;
     dgelu_w_in_payload_t dgelu_w_in_payload(w_load_base_desc);
     subgroup::tile_load(dgelu_w_in, dgelu_w_in_payload);
@@ -180,14 +176,15 @@ template <
     typename dtype_in_,
     typename dtype_out_,
     typename dtype_acc_,
-    typename reduction_attr_>
+    typename reduction_attr_,
+    gpu_arch arch_tag>
 struct row_reduction_fused_op_t<
     reduction_fused_kind::bias_dropout_bwd,
     dtype_in_,
     dtype_out_,
     dtype_acc_,
     reduction_attr_,
-    gpu_arch::XeHpc> {
+    arch_tag> {
   static constexpr reduction_fused_kind fused_op_kind =
       reduction_fused_kind::bias_dropout_bwd;
   using dtype_in = dtype_in_;
@@ -242,7 +239,7 @@ struct row_reduction_fused_op_t<
         mem_desc_mask_t,
         reduction_tile_desc_t,
         subgroup::msg_type_v,
-        gpu_arch::XeHpc>;
+        arch_tag>;
     using dropout_bwd_out_t = subgroup::tile_t;
     using mem_desc_out_t =
@@ -251,7 +248,7 @@ struct row_reduction_fused_op_t<
         mem_desc_out_t,
         reduction_tile_desc_t,
         subgroup::msg_type_v,
-        gpu_arch::XeHpc>;
+        arch_tag>;
     if (dropout_prob != 0) {
       mask_in_t mask_in;
       mask_in_payload_t mask_in_payload(mask_load_base_desc);
diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp
index 59bfd8f54..ceeb1f67f 100644
--- a/include/experimental/group/gemm/compute_policy.hpp
+++ b/include/experimental/group/gemm/compute_policy.hpp
@@ -32,8 +32,8 @@ template <
     typename dtype_scale_,
     typename dtype_zero_pt_,
     quant_info quant_info_,
-    mma_engine mma_engine_ = mma_engine::xmx,
-    gpu_arch arch_tag_ = gpu_arch::XeHpc,
+    mma_engine mma_engine_,
+    gpu_arch arch_tag_,
     typename enable = void>
 struct compute_policy_int4_dequantize {};
 
@@ -54,7 +54,8 @@ struct compute_policy_int4_dequantize<
     quant_info_,
     mma_engine_,
     arch_tag_,
-    std::enable_if_t> {
+    std::enable_if_t<
+        (mma_engine_ == mma_engine::xmx) && arch_has_xmx>> {
   using compute_attr = compute_attr_;
   using dtype_mma_acc = typename compute_attr::dtype_acc;
   using dtype_mma_a = typename compute_attr::dtype_a;
@@ -67,8 +68,6 @@ struct compute_policy_int4_dequantize<
   static constexpr mma_engine mma_engine = mma_engine_;
   static constexpr gpu_arch arch_tag = arch_tag_;
 
-  static_assert(arch_has_xmx, "XeLpg does not support xmx");
-
   static constexpr bool is_int4_matB_policy = true;
 
   static constexpr uint32_t dequant_s = quant_info_.dequant_s;
@@ -80,18 +79,14 @@ struct compute_policy_int4_dequantize<
   static constexpr quant_mode quant_mode = quant_info_.quant_mode;
 
   static constexpr uint32_t block_size_y_a = 16;
-  using mma_attr = mma_attr_t;
+  using mma_attr = mma_attr_t;
   static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
   static constexpr uint32_t block_size_x_a =
       block_bytes_x_a / sizeof(dtype_mma_a);
   static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
-  static constexpr uint32_t block_bytes_y_b = mma_attr::mma_k_in_bytes;
+  static constexpr uint32_t block_bytes_y_b = block_bytes_x_a;
   static constexpr uint32_t block_size_y_b =
       block_bytes_y_b / sizeof(dtype_mma_b);
-
-  static_assert(
-      block_bytes_x_a == block_bytes_y_b,
-      "mat_a x need to match with mat_b y");
 };
 
 /// @brief Specialized for fpu engine.
@@ -147,10 +142,6 @@ struct compute_policy_int4_dequantize<
       is_col_major_b ? reg_nums_t::register_nums : 32;
   static constexpr uint32_t block_size_y_b =
       block_bytes_y_b / sizeof(dtype_mma_b);
-
-  static_assert(
-      block_bytes_x_a == block_bytes_y_b,
-      "mat_a x need to match with mat_b y");
 };
 
 } // namespace gpu::xetla::group
diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
index 71168e2f6..06db3218e 100644
--- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
+++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
@@ -328,7 +328,9 @@ class gemm_t<
  public:
   static constexpr uint32_t barrier_count =
-      enable_periodic_sync ? barrier_count_x + barrier_count_y : 0;
+      enable_periodic_sync && arch_has_named_barrier
+      ? barrier_count_x + barrier_count_y
+      : 0;
 
   // currently only supports matA from slm
   static constexpr uint32_t slm_size =
       is_local_a ? sg_tile_m * wg_size_y * k_stride * sizeof(dtype_a) : 0;
@@ -547,7 +549,7 @@ class gemm_t<
       if constexpr (wg_size_x > 1) {
         nbarrier_a.arrive();
       }
-      if constexpr (arch_tag >= gpu_arch::XeHpc) {
+      if constexpr (arch_has_named_barrier) {
         if constexpr (wg_size_y > 1) {
           nbarrier_b.arrive();
         }
@@ -640,7 +642,7 @@ class gemm_t<
       if constexpr (wg_size_x > 1) {
         nbarrier_a.wait();
       }
-      if constexpr (arch_tag >= gpu_arch::XeHpc) {
+      if constexpr (arch_has_named_barrier) {
         if constexpr (wg_size_y > 1) {
           nbarrier_b.wait();
         }
       }
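The compute-policy hunks above lean on the reworked mma_attr_t keyed by
(arch, engine, m): A's K width in bytes now comes from dpas_attr::k_in_bytes,
and B's K depth is defined as the same quantity, which is why the old
static_assert comparing the two could be dropped. A sketch of the derivation
with stand-in compute types (block_size_sketch_t and the dtype parameters are
illustrative names, not part of the patch):

    template <gpu_arch arch_tag, mma_engine engine, uint32_t tile_m,
              typename dtype_mma_a, typename dtype_mma_b>
    struct block_size_sketch_t {
      using mma_attr = mma_attr_t<arch_tag, engine, tile_m>;
      static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
      static constexpr uint32_t block_size_x_a =
          block_bytes_x_a / sizeof(dtype_mma_a);
      static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
      // B's K depth mirrors A's K width by construction.
      static constexpr uint32_t block_size_y_b =
          block_bytes_x_a / sizeof(dtype_mma_b);
    };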
diff --git a/include/experimental/group/reduction/reduction_api.hpp b/include/experimental/group/reduction/reduction_api.hpp
index 5e487cedf..c5bfd61db 100644
--- a/include/experimental/group/reduction/reduction_api.hpp
+++ b/include/experimental/group/reduction/reduction_api.hpp
@@ -20,33 +20,3 @@
 #pragma once
 
 #include
-
-namespace gpu::xetla::group {
-
-/// @brief This is the group row reduction(reduce_sum) + cooperative write out.
-/// Use slm to exchange the data. For wg_size_y threads, at the beginning,
-/// everyone will keep one row of data; Then, they compose a wg_size_y *
-/// row_size 2D block in SLM; After that, each thread will load a small
-/// wg_size_y * block_size block, do the local reduction and write to global
-/// memory
-/// @tparam dtype_acc Is the data type to do the reduction
-/// @tparam dtype_out Is the data type to write out
-/// @tparam row_size Is the vector size per row
-/// @tparam wg_size_x Is the wg size in x direction, is the number of parallel
-/// reductions in the wg.
-/// @tparam wg_size_y Is the wg size in y direction, i.e. is the number of
-/// threads that participate in this reduction.
-/// @tparam max_simd_len Is the max SIMD for scatter load. The limitation comes
-/// from the scattered load from local memory.
-/// @tparam arch_ Is the HW generation.
-template <
-    typename dtype_acc,
-    typename dtype_out,
-    uint32_t row_size,
-    uint32_t wg_size_x,
-    uint32_t wg_size_y,
-    uint32_t max_simd_len = 32,
-    gpu_arch arch_ = gpu_arch::XeHpc>
-struct group_row_reduce_store_t {};
-
-} // namespace gpu::xetla::group
diff --git a/include/experimental/group/reduction/row_reduce_store_xe.hpp b/include/experimental/group/reduction/row_reduce_store_xe.hpp
index 29a2474ef..631f7e2d0 100644
--- a/include/experimental/group/reduction/row_reduce_store_xe.hpp
+++ b/include/experimental/group/reduction/row_reduce_store_xe.hpp
@@ -29,15 +29,10 @@ template <
     uint32_t row_size,
     uint32_t wg_size_x,
     uint32_t wg_size_y,
-    uint32_t max_simd_len>
-struct group_row_reduce_store_t<
-    dtype_acc,
-    dtype_out,
-    row_size,
-    wg_size_x,
-    wg_size_y,
-    max_simd_len,
-    gpu_arch::XeHpc> {
+    uint32_t max_simd_len,
+    gpu_arch arch_tag_ = gpu_arch::XeHpc>
+struct group_row_reduce_store_t {
+  static constexpr gpu_arch arch_tag = arch_tag_;
   static constexpr uint32_t block_size_x =
       gpu::xetla::subgroup::detail::gcd::value;
   static_assert(
@@ -64,7 +59,7 @@ struct group_row_reduce_store_t {
       mem_desc_acc,
       local_st_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using local_ld_tile_desc_t = subgroup::tile_desc_t<
       local_tile_size_x,
       wg_size_y,
@@ -78,7 +73,7 @@ struct group_row_reduce_store_t {
       mem_desc_ld_t,
       local_ld_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   // If the local tile size is small, we still can use 2D block store
   using global_st_tile_desc_t =
       subgroup::tile_desc_t;
@@ -89,8 +84,8 @@ struct group_row_reduce_store_t {
       global_st_tile_desc_t,
       (local_tile_size_x * sizeof(dtype_out) > 64) ? msg_type::block_1d
                                                    : msg_type::block_2d,
-      gpu_arch::XeHpc>;
-  xetla_nbarrier_t nbarrier;
+      arch_tag>;
+  xetla_nbarrier_t nbarrier;
   local_st_t local_st;
   local_st_payload_t local_st_payload;
   local_ld_t local_ld;
@@ -160,7 +155,8 @@ template <
     typename dtype_out,
     uint32_t row_size,
     uint32_t wg_size_x,
-    uint32_t max_simd_len>
+    uint32_t max_simd_len,
+    gpu_arch arch_tag_>
 struct group_row_reduce_store_t<
     dtype_acc,
     dtype_out,
@@ -168,7 +164,8 @@ struct group_row_reduce_store_t<
     wg_size_x,
     1,
     max_simd_len,
-    gpu_arch::XeHpc> {
+    arch_tag_> {
+  static constexpr gpu_arch arch_tag = arch_tag_;
   static constexpr uint32_t block_size_x =
       gpu::xetla::subgroup::detail::gcd::value;
 
@@ -180,7 +177,7 @@ struct group_row_reduce_store_t<
       global_st_tile_desc_t,
      (row_size * sizeof(dtype_out) > 64) ? msg_type::block_1d
                                           : msg_type::block_2d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   inline void init(
       [[maybe_unused]] uint32_t sg_idx_ = 0,
       [[maybe_unused]] uint32_t sg_idy_ = 0,
diff --git a/include/experimental/kernel/data_transformer/api.hpp b/include/experimental/kernel/data_transformer/api.hpp
index 984149967..bb3f6c8e3 100644
--- a/include/experimental/kernel/data_transformer/api.hpp
+++ b/include/experimental/kernel/data_transformer/api.hpp
@@ -20,26 +20,3 @@
 #pragma once
 
 #include
-
-namespace gpu::xetla::kernel {
-
-/// @brief Is the data_transformer functor.
-///
-/// @tparam dtype_in_ Is the data type of input.
-/// @tparam dtype_out_ Is the data type of output.
-/// @tparam dtype_acc_
-/// @tparam data_transformer_config_
-/// @tparam mem_layout_in_ Indicates the input data col major or row major.
-/// @tparam need_fp8_op
-/// @tparam arch_ Is the HW generation.
-template <
-    typename dtype_in_,
-    typename dtype_out_,
-    typename dtype_acc_,
-    typename data_transformer_config_,
-    mem_layout mem_layout_in_,
-    int need_fp8_op,
-    gpu_arch arch_>
-struct xetla_data_transformer {};
-
-} // namespace gpu::xetla::kernel
diff --git a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp
index 2ccf069d1..807cbb314 100644
--- a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp
+++ b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp
@@ -43,20 +43,15 @@ template <
     typename dtype_compute_,
     typename data_transformer_attr_,
     mem_layout mem_layout_in_,
-    int need_fp8_op>
-struct xetla_data_transformer<
-    dtype_in_,
-    dtype_out_,
-    dtype_compute_,
-    data_transformer_attr_,
-    mem_layout_in_,
-    need_fp8_op,
-    gpu_arch::XeHpc> {
+    int need_fp8_op,
+    gpu_arch arch_tag_ = gpu_arch::XeHpc>
+struct xetla_data_transformer {
   using dtype_in = dtype_in_;
   using dtype_out = dtype_out_;
   using dtype_compute = dtype_compute_;
   using data_transformer_attr = data_transformer_attr_;
 
+  static constexpr gpu_arch arch_tag = arch_tag_;
   static constexpr mem_layout mem_layout_in = mem_layout_in_;
 
   static constexpr bool is_col_major_in =
@@ -73,8 +68,7 @@ struct xetla_data_transformer {
   static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
   static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
 
-  using load_store_attr = typename arch_attr_t<
-      gpu_arch::XeHpc>::template load_store_attr;
+  using load_store_attr = load_store_attr_t;
   static constexpr uint32_t max_load_height_in_elem =
       load_store_attr::max_load_height_in_elem;
   static constexpr uint32_t max_load_width_in_bytes =
@@ -127,7 +121,7 @@ struct xetla_data_transformer {
       mem_desc_ld_t,
       global_ld_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   using global_st_tile_desc_t = subgroup::tile_desc_t<
       tile_size_x,
@@ -140,7 +134,7 @@ struct xetla_data_transformer {
       mem_desc_t,
       global_st_tile_desc_t,
       msg_type::block_2d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using global_compute_tile_desc = subgroup::tile_desc_t<
       tile_size_x,
       tile_size_y,
@@ -157,7 +151,7 @@ struct xetla_data_transformer {
       reduce_op::max,
       wg_size_x * wg_size_y,
       true,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   /// @brief Arguments for gemm::run.
   /// User should prepare mat_in_ptr, mat_out_ptr, matrix_m, matrix_n,
@@ -292,7 +286,8 @@ struct xetla_data_transformer {
         simd,
         cache_hint::uncached,
         cache_hint::write_back,
-        atomic_op::fmax>(
+        atomic_op::fmax,
+        arch_tag>(
         (uint64_t)args->amax_ptr,
         offsets * sizeof(dtype_compute),
         local_max,
diff --git a/include/experimental/kernel/layer_norm/api.hpp b/include/experimental/kernel/layer_norm/api.hpp
index d1fe6b361..bf3148ea6 100644
--- a/include/experimental/kernel/layer_norm/api.hpp
+++ b/include/experimental/kernel/layer_norm/api.hpp
@@ -21,59 +21,3 @@
 
 #include
 #include
-
-namespace gpu::xetla::kernel {
-
-/// @brief
-///
-/// @tparam dtype_x_
-/// @tparam dtype_y_
-/// @tparam dtype_weight_
-/// @tparam dtype_acc_
-/// @tparam layer_norm_attr_
-/// @tparam store_for_bwd_
-/// @tparam arch_
-/// @tparam ln_fwd_fused_op_
-template <
-    typename dtype_x_,
-    typename dtype_y_,
-    typename dtype_weight_,
-    typename dtype_acc_,
-    typename layer_norm_attr_,
-    bool store_for_bwd_ = true,
-    gpu_arch arch_ = gpu_arch::XeHpc,
-    typename ln_fwd_fused_op_ = group::ln_fwd_fused_op_t<
-        ln_fwd_fused_kind::none,
-        dtype_x_,
-        dtype_y_,
-        dtype_acc_,
-        layer_norm_attr_,
-        arch_>>
-struct layer_norm_fwd_t {};
-
-/// @brief
-///
-/// @tparam dtype_x_
-/// @tparam dtype_y_
-/// @tparam dtype_weight_
-/// @tparam dtype_acc_
-/// @tparam layer_norm_attr_
-/// @tparam arch_
-/// @tparam ln_bwd_fused_op_
-template <
-    typename dtype_x_,
-    typename dtype_y_,
-    typename dtype_weight_,
-    typename dtype_acc_,
-    typename layer_norm_attr_,
-    gpu_arch arch_ = gpu_arch::XeHpc,
-    typename ln_bwd_fused_op_ = group::ln_bwd_fused_op_t<
-        ln_bwd_fused_kind::none,
-        dtype_y_,
-        dtype_x_,
-        /*in bwd, y is input, x is output*/ dtype_acc_,
-        layer_norm_attr_,
-        arch_>>
-struct layer_norm_bwd_t {};
-
-} // namespace gpu::xetla::kernel
diff --git a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp
index b67d7f997..fd49a24a3 100644
--- a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp
+++ b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp
@@ -40,15 +40,15 @@ template <
     typename dtype_weight_,
     typename dtype_acc_,
     typename layer_norm_attr_,
-    typename ln_bwd_fused_op_>
-struct layer_norm_bwd_t<
-    dtype_x_,
-    dtype_y_,
-    dtype_weight_,
-    dtype_acc_,
-    layer_norm_attr_,
-    gpu_arch::XeHpc,
-    ln_bwd_fused_op_> {
+    gpu_arch arch_tag_,
+    typename ln_bwd_fused_op_ = group::ln_bwd_fused_op_t<
+        ln_bwd_fused_kind::none,
+        dtype_y_,
+        dtype_x_,
+        /*in bwd, y is input, x is output*/ dtype_acc_,
+        layer_norm_attr_,
+        arch_tag_>>
+struct layer_norm_bwd_t {
   using dtype_x = dtype_x_;
   using dtype_y = dtype_y_;
   using dtype_weight = dtype_weight_;
@@ -56,6 +56,7 @@ struct layer_norm_bwd_t {
   using layer_norm_attr = layer_norm_attr_;
   using ln_bwd_fused_op = ln_bwd_fused_op_;
   using ln_fused_op_arguments_t = typename ln_bwd_fused_op::arguments_t;
+  static constexpr gpu_arch arch_tag = arch_tag_;
   static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m;
   static constexpr uint32_t wg_tile_n = layer_norm_attr::wg_tile_n;
   static constexpr uint32_t sg_tile_m = layer_norm_attr::sg_tile_m;
@@ -98,26 +99,26 @@ struct layer_norm_bwd_t {
       mem_desc_y_t,
       ln_bwd_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mem_desc_x_t = mem_desc_t;
   using x_in_payload_t = subgroup::mem_payload_t<
       mem_desc_x_t,
       ln_bwd_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mem_desc_weight_t = mem_desc_t;
   using gamma_in_payload_t = subgroup::mem_payload_t<
       mem_desc_weight_t,
       ln_bwd_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using dx_out_payload_t = subgroup::mem_payload_t<
       mem_desc_t,
       ln_bwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   using ln_group_row_reduce_store_t = group::group_row_reduce_store_t<
       dtype_acc,
@@ -126,7 +127,7 @@ struct layer_norm_bwd_t {
       wg_size_x,
       wg_size_y,
       32,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   /// @brief
   ///
@@ -168,7 +169,7 @@ struct layer_norm_bwd_t {
       reduce_op Op,
       uint32_t wg_size_x,
       uint32_t wg_size_y,
-      gpu_arch arch_ = gpu_arch::XeHpc>
+      gpu_arch arch_>
   struct ln_group_all_reduce_t {
     uint32_t itr_count;
     uint32_t slm_base_0;
@@ -272,7 +273,7 @@ struct layer_norm_bwd_t {
       reduce_op::sum,
       wg_size_x,
       wg_size_y,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
  public:
   __XETLA_API static void call(
diff --git a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp
index ecd6bc25b..1d52705e1 100644
--- a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp
+++ b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp
@@ -42,16 +42,15 @@ template <
     typename dtype_acc_,
     typename layer_norm_attr_,
     bool store_for_bwd_,
-    typename ln_fwd_fused_op_>
-struct layer_norm_fwd_t<
-    dtype_x_,
-    dtype_y_,
-    dtype_weight_,
-    dtype_acc_,
-    layer_norm_attr_,
-    store_for_bwd_,
-    gpu_arch::XeHpc,
-    ln_fwd_fused_op_> {
+    gpu_arch arch_tag_,
+    typename ln_fwd_fused_op_ = group::ln_fwd_fused_op_t<
+        ln_fwd_fused_kind::none,
+        dtype_x_,
+        dtype_y_,
+        dtype_acc_,
+        layer_norm_attr_,
+        arch_tag_>>
+struct layer_norm_fwd_t {
   using dtype_x = dtype_x_;
   using dtype_y = dtype_y_;
   using dtype_weight = dtype_weight_;
@@ -59,6 +58,7 @@ struct layer_norm_fwd_t {
   using layer_norm_attr = layer_norm_attr_;
   using ln_fwd_fused_op = ln_fwd_fused_op_;
   using ln_fused_op_arguments_t = typename ln_fwd_fused_op::arguments_t;
+  static constexpr gpu_arch arch_tag = arch_tag_;
   static constexpr bool store_for_bwd = store_for_bwd_;
 
   static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m;
@@ -107,26 +107,26 @@ struct layer_norm_fwd_t {
       mem_desc_x_t,
       ln_fwd_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mem_desc_weight_t = mem_desc_t;
   using gamma_in_payload_t = subgroup::mem_payload_t<
       mem_desc_weight_t,
       ln_fwd_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using beta_in_payload_t = subgroup::mem_payload_t<
       mem_desc_weight_t,
       ln_fwd_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mem_desc_y_t = mem_desc_t;
   using y_out_payload_t = subgroup::mem_payload_t<
       mem_desc_y_t,
       ln_fwd_tile_desc_t,
       msg_type::block_1d,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   /// @brief
   ///
@@ -208,7 +208,7 @@ struct layer_norm_fwd_t {
     int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
     int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
 
-    xetla_nbarrier_t nbarrier;
+    xetla_nbarrier_t nbarrier;
     nbarrier.init_nbarrier(
         sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
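Since the fused-op default is now derived from the same arch_tag_ parameter,
instantiating these kernels for a non-XeHpc target is just another template
argument. A hypothetical instantiation (the dtype_* and layer_norm_attr
aliases are assumed declared by the caller):

    using ln_fwd_xehpg_t = gpu::xetla::kernel::layer_norm_fwd_t<
        dtype_x, dtype_y, dtype_weight, dtype_acc,
        layer_norm_attr,
        /*store_for_bwd_=*/true,
        gpu_arch::XeHpg>; // ln_fwd_fused_op_ defaults to the none-fused op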
diff --git a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp
index 398e6df96..1721a8ee8 100644
--- a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp
+++ b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp
@@ -36,7 +36,8 @@ template <
     int HWThreadNum,
     bool Dopt_RandGenflag = true,
     uint16_t RandSIMD = 16,
-    int Max_SeqLen = 2048>
+    int Max_SeqLen = 2048,
+    gpu_arch arch_tag = gpu_arch::XeHpc>
 struct xetla_mha_attn_reg_fwd_t {
   using dtype_bin = dtype_bin_;
   using dtype_bot = dtype_bot_;
@@ -84,7 +85,7 @@ struct xetla_mha_attn_reg_fwd_t {
   using compute_policy_QKT = group::compute_policy_default_xmx<
       group::compute_attr_t,
       bgm_perf_tuning_knob,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   using mem_desc_a_out = mem_desc_t;
@@ -93,7 +94,7 @@ struct xetla_mha_attn_reg_fwd_t {
   using compute_policy_out = group::compute_policy_default_xmx<
       group::compute_attr_t,
       bgm_perf_tuning_knob,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   static constexpr uint32_t global_kslicing = 1;
   static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx);
@@ -103,19 +104,19 @@ struct xetla_mha_attn_reg_fwd_t {
 
   using work_group_t = work_group_t;
   using pre_processing_128x128 =
      group::pre_processing_default_t;
   using pre_processing_128x256 =
      group::pre_processing_default_t;
   using pre_processing_64x384 =
      group::pre_processing_default_t;
   using pre_processing_64x512 =
      group::pre_processing_default_t;
   using pre_processing_32x1024 =
      group::pre_processing_default_t;
   using pre_processing_16x2048 =
      group::pre_processing_default_t;
-  using pre_processing_128x64 = group::
-      pre_processing_matA_neg_filter_t;
+  using pre_processing_128x64 =
+      group::pre_processing_matA_neg_filter_t;
 
   using gemm_op_128x128_t = group::gemm_t<
       compute_policy_QKT,
@@ -233,49 +234,49 @@ struct xetla_mha_attn_reg_fwd_t {
       (global_kslicing > 1)
           ? msg_type::atomic_add
           : subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matC_128x256_payload_t = subgroup::mem_payload_t<
       mem_desc_c_t,
       mat_128x256_tile_desc_t,
       (global_kslicing > 1)
           ? msg_type::atomic_add
          : subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matC_64x384_payload_t = subgroup::mem_payload_t<
       mem_desc_c_t,
       mat_64x384_tile_desc_t,
       (global_kslicing > 1)
           ? msg_type::atomic_add
           : subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matC_64x512_payload_t = subgroup::mem_payload_t<
       mem_desc_c_t,
       mat_64x512_tile_desc_t,
       (global_kslicing > 1)
           ? msg_type::atomic_add
           : subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matC_32x1024_payload_t = subgroup::mem_payload_t<
       mem_desc_c_t,
       mat_32x1024_tile_desc_t,
       (global_kslicing > 1)
          ? msg_type::atomic_add
           : subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matC_16x2048_payload_t = subgroup::mem_payload_t<
       mem_desc_c_t,
       mat_16x2048_tile_desc_t,
       (global_kslicing > 1)
           ? msg_type::atomic_add
           : subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matC_128x64_payload_t = subgroup::mem_payload_t<
       mem_desc_c_t,
       mat_128x64_tile_desc_t,
       (global_kslicing > 1)
           ? msg_type::atomic_add
           : subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matDpotMk_128x128_t =
       subgroup::tile_t;
@@ -294,37 +295,37 @@ struct xetla_mha_attn_reg_fwd_t {
       mem_desc_dpot_t,
       mat_128x128_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matDpotMk_128x256_payload_t = subgroup::mem_payload_t<
       mem_desc_dpot_t,
       mat_128x256_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matDpotMk_64x384_payload_t = subgroup::mem_payload_t<
       mem_desc_dpot_t,
       mat_64x384_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matDpotMk_64x512_payload_t = subgroup::mem_payload_t<
       mem_desc_dpot_t,
       mat_64x512_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matDpotMk_32x1024_payload_t = subgroup::mem_payload_t<
       mem_desc_dpot_t,
       mat_32x1024_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matDpotMk_16x2048_payload_t = subgroup::mem_payload_t<
       mem_desc_dpot_t,
       mat_16x2048_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using matDpotMk_128x64_payload_t = subgroup::mem_payload_t<
       mem_desc_dpot_t,
       mat_128x64_tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   /// @brief Arguments for xetla_softmax_fwd_t::run.
   /// User should prepare matQ_ptr, matK_ptr, matQKT_ptr, ...
@@ -452,9 +453,9 @@ struct xetla_mha_attn_reg_fwd_t {
     int tid_x = tid_linear & ((1 << tid_x_shift) - 1);
     int tid_y = tid_linear >> tid_x_shift;
 
-    xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> first_nbarr;
-    xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> second_nbarr;
-    xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> third_nbarr;
+    xetla_nbarrier_t<32, 32, arch_tag> first_nbarr;
+    xetla_nbarrier_t<32, 32, arch_tag> second_nbarr;
+    xetla_nbarrier_t<32, 32, arch_tag> third_nbarr;
     first_nbarr.init_nbarrier(31, nbarrier_role::producer_consumer);
     second_nbarr.init_nbarrier(30, nbarrier_role::producer_consumer);
     third_nbarr.init_nbarrier(29, nbarrier_role::producer_consumer);
@@ -859,7 +860,8 @@ struct xetla_mha_attn_reg_fwd_t {
             16,
             cache_hint::none,
             cache_hint::none,
-            atomic_op::fmax>(
+            atomic_op::fmax,
+            arch_tag>(
             (uint64_t)args->Max_ptr,
             address_fmax,
             matElem_reg_max_local.xetla_select<16, 1>(0),
@@ -1003,7 +1005,8 @@ struct xetla_mha_attn_reg_fwd_t {
             16,
             cache_hint::none,
             cache_hint::none,
-            atomic_op::fadd>(
+            atomic_op::fadd,
+            arch_tag>(
             (uint64_t)args->Sum_ptr,
             address_fmax,
             matElem_reg_Sum_1.xetla_select<16, 1>(0),
@@ -1541,7 +1544,8 @@ template <
     int HWThreadNum,
     bool Dopt_RandGenflag = true,
     bool Mkin_flag = false,
-    int Max_SeqLen = 512>
+    int Max_SeqLen = 512,
+    gpu_arch arch_tag = gpu_arch::XeHpc>
 struct xetla_mha_attn_reg_bwd_t {
   using dtype_bin = dtype_bwd_bin_;
   using dtype_bot = dtype_bwd_bot_;
@@ -1593,7 +1597,7 @@ struct xetla_mha_attn_reg_bwd_t {
   using compute_policy_QKT = group::compute_policy_default_xmx<
       group::compute_attr_t,
       bgm_perf_tuning_knob,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   using mem_desc_a_out = mem_desc_t;
@@ -1602,7 +1606,7 @@ struct xetla_mha_attn_reg_bwd_t {
   using compute_policy_out = group::compute_policy_default_xmx<
       group::compute_attr_t,
       bgm_perf_tuning_knob,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   using mem_desc_a_out_b_trnp_a = mem_desc_t;
@@ -1611,7 +1615,7 @@ struct xetla_mha_attn_reg_bwd_t {
   using compute_policy_out_b_trnp_a = group::compute_policy_default_xmx<
       group::compute_attr_t,
       bgm_perf_tuning_knob,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   static constexpr uint32_t global_kslicing = 1;
   static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx);
@@ -1621,25 +1625,25 @@
struct xetla_mha_attn_reg_bwd_t { using work_group_t = work_group_t; using pre_processing_128x128 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x256 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_64x384 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_64x512 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_32x1024 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_16x2048 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x64 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_256x64 = - group::pre_processing_default_t; - using pre_processing_128x64_af = group:: - pre_processing_matA_neg_filter_t; - using pre_processing_256x64_af = group:: - pre_processing_matA_neg_filter_t; + group::pre_processing_default_t; + using pre_processing_128x64_af = + group::pre_processing_matA_neg_filter_t; + using pre_processing_256x64_af = + group::pre_processing_matA_neg_filter_t; using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, @@ -1789,42 +1793,42 @@ struct xetla_mha_attn_reg_bwd_t { (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_c_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_c_t, matC_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_c_t, matC_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_c_t, matC_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_c_t, matC_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, @@ -1872,7 +1876,7 @@ struct xetla_mha_attn_reg_bwd_t { (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_128x64_trnp_a_tile_desc_t, @@ -1880,12 +1884,12 @@ struct xetla_mha_attn_reg_bwd_t { ? msg_type::atomic_add : subgroup:: msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_256x64_trnp_a_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_128x64_trnp_af_tile_desc_t, @@ -1893,7 +1897,7 @@ struct xetla_mha_attn_reg_bwd_t { ? msg_type::atomic_add : subgroup:: msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_256x64_trnp_af_tile_desc_t, @@ -1901,7 +1905,7 @@ struct xetla_mha_attn_reg_bwd_t { ? 
msg_type::atomic_add : subgroup:: msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_128x128_t = subgroup::tile_t; using matW_128x256_t = subgroup::tile_t; @@ -1915,32 +1919,32 @@ struct xetla_mha_attn_reg_bwd_t { mem_desc_w_t, matC_128x128_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_128x256_payload_t = subgroup::mem_payload_t< mem_desc_w_t, matC_128x256_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_64x384_payload_t = subgroup::mem_payload_t< mem_desc_w_t, matC_64x384_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_64x512_payload_t = subgroup::mem_payload_t< mem_desc_w_t, matC_64x512_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_w_t, matC_32x1024_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matW_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_w_t, matC_16x2048_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; #if 0 //512 = 16x32 or 8x64 @@ -2090,15 +2094,15 @@ struct xetla_mha_attn_reg_bwd_t { int tid_y = tid_linear >> tid_x_shift; static_assert(ThreadNum == 32, "All Thread Sync"); - xetla_nbarrier_t first_nbarr; - xetla_nbarrier_t second_nbarr; + xetla_nbarrier_t first_nbarr; + xetla_nbarrier_t second_nbarr; int max_2d_nbar_id = ThreadNum >> 1; first_nbarr.init_nbarrier(max_2d_nbar_id, nbarrier_role::producer_consumer); second_nbarr.init_nbarrier( max_2d_nbar_id + 1, nbarrier_role::producer_consumer); - xetla_nbarrier_t all_nbarr; + xetla_nbarrier_t all_nbarr; all_nbarr.init_nbarrier(ThreadNum - 1, nbarrier_role::producer_consumer); for (int transp128_loop = 0; transp128_loop < transp128_loop_num; @@ -2750,7 +2754,8 @@ struct xetla_mha_attn_reg_bwd_t { 16, cache_hint::none, cache_hint::none, - atomic_op::fadd>( + atomic_op::fadd, + arch_tag>( (uint64_t)args->matSum_ptr, address_fsum, matElem_reg_Sum_1.xetla_select<16, 1>(0), diff --git a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp index 41dfa7606..11a926e59 100644 --- a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp +++ b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp @@ -46,7 +46,8 @@ template < int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, - int Max_SeqLen = 512> + int Max_SeqLen = 512, + gpu_arch arch_tag = gpu_arch::XeHpc> struct xetla_mha_core_attn_fwd_t { using dtype_bin = dtype_bin_; using dtype_bot = dtype_bot_; @@ -90,7 +91,7 @@ struct xetla_mha_core_attn_fwd_t { using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out = mem_desc_t; @@ -99,7 +100,7 @@ struct xetla_mha_core_attn_fwd_t { using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; static constexpr uint32_t global_kslicing = 1; static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx); @@ -109,11 +110,11 @@ struct xetla_mha_core_attn_fwd_t { using work_group_t = work_group_t; using pre_processing_128x128 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x256 = - group::pre_processing_default_t; - using pre_processing_128x64 = group:: - pre_processing_matA_neg_filter_t; + group::pre_processing_default_t; + using pre_processing_128x64 = + 
group::pre_processing_matA_neg_filter_t; using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, @@ -167,17 +168,17 @@ struct xetla_mha_core_attn_fwd_t { mem_desc_t, matC_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; // 512 = 16x32 or 8x64 using matElem_tile_desc_t = gpu::xetla::subgroup::tile_desc_t< @@ -194,14 +195,14 @@ struct xetla_mha_core_attn_fwd_t { mem_desc_elem_ld_t, matElem_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matElem_st_t = gpu::xetla::subgroup::tile_t; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< mem_desc_elem_ld_t, matElem_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matElem_reg_t = gpu::xetla::subgroup::tile_t< float, gpu::xetla::subgroup::tile_desc_t<32, 16, 32, 16, reg_layout::tiled>>; @@ -313,8 +314,8 @@ struct xetla_mha_core_attn_fwd_t { blk_128x256_loop_num = 2; } - xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> first_nbarr; - xetla_nbarrier_t<32, 32, gpu_arch::XeHpc> second_nbarr; + xetla_nbarrier_t<32, 32, arch_tag> first_nbarr; + xetla_nbarrier_t<32, 32, arch_tag> second_nbarr; first_nbarr.init_nbarrier(0, nbarrier_role::producer_consumer); second_nbarr.init_nbarrier(1, nbarrier_role::producer_consumer); @@ -864,7 +865,8 @@ template < typename dtype_bwd_acc_, int HWThreadNum, bool Dopt_RandGenflag = true, - bool Mkin_flag = false> + bool Mkin_flag = false, + gpu_arch arch_tag = gpu_arch::XeHpc> struct xetla_mha_core_attn_bwd_t { using dtype_bin = dtype_bwd_bin_; using dtype_bot = dtype_bwd_bot_; @@ -911,7 +913,7 @@ struct xetla_mha_core_attn_bwd_t { using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out = mem_desc_t; @@ -920,7 +922,7 @@ struct xetla_mha_core_attn_bwd_t { using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_a_out_b_trnp_a = mem_desc_t; @@ -929,7 +931,7 @@ struct xetla_mha_core_attn_bwd_t { using compute_policy_out_b_trnp_a = group::compute_policy_default_xmx< group::compute_attr_t, bgm_perf_tuning_knob, - gpu_arch::XeHpc>; + arch_tag>; static constexpr uint32_t global_kslicing = 1; static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx); @@ -939,17 +941,17 @@ struct xetla_mha_core_attn_bwd_t { using work_group_t = work_group_t; using pre_processing_128x128 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x256 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_128x64 = - group::pre_processing_default_t; + group::pre_processing_default_t; using pre_processing_256x64 = - group::pre_processing_default_t; - using pre_processing_128x64_af = group:: - pre_processing_matA_neg_filter_t; - using pre_processing_256x64_af = group:: - pre_processing_matA_neg_filter_t; + group::pre_processing_default_t; + using pre_processing_128x64_af = + group::pre_processing_matA_neg_filter_t; + using pre_processing_256x64_af = + 
group::pre_processing_matA_neg_filter_t; using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, @@ -1074,14 +1076,14 @@ struct xetla_mha_core_attn_bwd_t { (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_c_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_bot_t = mem_desc_t; using matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, @@ -1089,7 +1091,7 @@ struct xetla_mha_core_attn_bwd_t { (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_128x64_trnp_a_tile_desc_t, @@ -1097,7 +1099,7 @@ struct xetla_mha_core_attn_bwd_t { ? msg_type::atomic_add : subgroup:: msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_256x64_trnp_a_tile_desc_t, @@ -1105,7 +1107,7 @@ struct xetla_mha_core_attn_bwd_t { ? msg_type::atomic_add : subgroup:: msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_128x64_trnp_af_tile_desc_t, @@ -1113,7 +1115,7 @@ struct xetla_mha_core_attn_bwd_t { ? msg_type::atomic_add : subgroup:: msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_bot_t, matC_256x64_trnp_af_tile_desc_t, @@ -1121,7 +1123,7 @@ struct xetla_mha_core_attn_bwd_t { ? msg_type::atomic_add : subgroup:: msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; // 512 = 16x32 or 8x64 using matElem_tile_desc_t = gpu::xetla::subgroup::tile_desc_t< @@ -1141,12 +1143,12 @@ struct xetla_mha_core_attn_bwd_t { mem_desc_elem_ld_t, matElem_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< mem_desc_elem_ld_t, matElem_tile_desc_t, msg_type::block_2d, - gpu_arch::XeHpc>; + arch_tag>; using matElem_reg_t = gpu::xetla::subgroup::tile_t< float, gpu::xetla::subgroup::tile_desc_t<32, 16, 32, 16, reg_layout::tiled>>; @@ -1253,15 +1255,15 @@ struct xetla_mha_core_attn_bwd_t { g_thd32_tid.init(tid_linear); static_assert(ThreadNum == 32, "All Thread Sync"); - xetla_nbarrier_t first_nbarr; - xetla_nbarrier_t second_nbarr; + xetla_nbarrier_t first_nbarr; + xetla_nbarrier_t second_nbarr; int max_2d_nbar_id = ThreadNum >> 1; first_nbarr.init_nbarrier(max_2d_nbar_id, nbarrier_role::producer_consumer); second_nbarr.init_nbarrier( max_2d_nbar_id + 1, nbarrier_role::producer_consumer); - xetla_nbarrier_t all_nbarr; + xetla_nbarrier_t all_nbarr; all_nbarr.init_nbarrier(ThreadNum - 1, nbarrier_role::producer_consumer); for (int transp128_loop = 0; transp128_loop < transp128_loop_num; diff --git a/include/experimental/kernel/reduction/api.hpp b/include/experimental/kernel/reduction/api.hpp index fc7fdfc61..76308e778 100644 --- a/include/experimental/kernel/reduction/api.hpp +++ b/include/experimental/kernel/reduction/api.hpp @@ -20,30 +20,3 @@ #pragma once #include - -namespace gpu::xetla::kernel { - -/// @brief Is the row_reduction functor. -/// -/// @tparam dtype_in_ Is the data type of input. -/// @tparam dtype_out_ Is the data type of output. -/// @tparam dtype_acc_ Is the accumulation data type. -/// @tparam reduction_attr_ Is the tile size for each group to do the reduction. 
-/// @tparam arch_ Is the HW generation. -/// @tparam fused_op_t_ -template < - typename dtype_in_, - typename dtype_out_, - typename dtype_acc_, - typename reduction_attr_, - gpu_arch arch_, - typename fused_op_t_ = group::row_reduction_fused_op_t< - reduction_fused_kind::none, - dtype_in_, - dtype_out_, - dtype_acc_, - reduction_attr_, - arch_>> -struct xetla_row_reduction_t {}; - -} // namespace gpu::xetla::kernel diff --git a/include/experimental/kernel/reduction/row_reduction_xe.hpp b/include/experimental/kernel/reduction/row_reduction_xe.hpp index c2a4a11c9..8f094f0c8 100644 --- a/include/experimental/kernel/reduction/row_reduction_xe.hpp +++ b/include/experimental/kernel/reduction/row_reduction_xe.hpp @@ -43,14 +43,15 @@ template < typename dtype_out_, typename dtype_acc_, typename reduction_attr_, - typename fused_op_t_> -struct xetla_row_reduction_t< - dtype_in_, - dtype_out_, - dtype_acc_, - reduction_attr_, - gpu_arch::XeHpc, - fused_op_t_> { + gpu_arch arch_tag_, + typename fused_op_t_ = group::row_reduction_fused_op_t< + reduction_fused_kind::none, + dtype_in_, + dtype_out_, + dtype_acc_, + reduction_attr_, + arch_tag_>> +struct xetla_row_reduction_t { using dtype_in = dtype_in_; using dtype_out = dtype_out_; using dtype_acc = dtype_acc_; @@ -58,6 +59,7 @@ struct xetla_row_reduction_t< using fused_op_t = fused_op_t_; using fused_op_arguments_t = typename fused_op_t::arguments_t; + static constexpr gpu_arch arch_tag = arch_tag_; static constexpr uint32_t wg_tile_m = reduction_attr::wg_tile_m; static constexpr uint32_t wg_tile_n = reduction_attr::wg_tile_n; static constexpr uint32_t sg_tile_m = reduction_attr::sg_tile_m; @@ -67,8 +69,7 @@ struct xetla_row_reduction_t< static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m; using work_group_t = work_group_t; static constexpr bool use_dynamic_job = is_dynamic_job && (wg_size_y > 1); - using load_store_attr = typename arch_attr_t< - gpu_arch::XeHpc>::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_height_in_elem = load_store_attr::max_load_height_in_elem; static constexpr uint32_t max_load_width_in_bytes = @@ -114,7 +115,7 @@ struct xetla_row_reduction_t< mem_desc_in_t, global_ld_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using mat_buffer_t = subgroup::tile_t< dtype_acc, subgroup:: @@ -126,7 +127,8 @@ struct xetla_row_reduction_t< sg_tile_n, wg_size_x, wg_size_y, - max_simd_len>; + max_simd_len, + arch_tag>; /// @brief /// @@ -179,7 +181,7 @@ struct xetla_row_reduction_t< int global_start_x_in = item.get_group(2) * wg_tile_n + sg_idx * sg_tile_n; int global_start_y_in = sg_idy * sg_tile_m; - xetla_nbarrier_t nbarrier; + xetla_nbarrier_t nbarrier; nbarrier.init_nbarrier( nbarrier_base + sg_idx, nbarrier_role::producer_consumer); if constexpr (use_dynamic_job) { diff --git a/include/group/cooperative_reduction.hpp b/include/group/cooperative_reduction.hpp index 25682c0e2..2ffe0b0f9 100644 --- a/include/group/cooperative_reduction.hpp +++ b/include/group/cooperative_reduction.hpp @@ -54,7 +54,7 @@ class cooperative_reduce_t< matAcc_t, num_cooperative_wg, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape = tile_shape_; @@ -129,7 +129,8 @@ class cooperative_reduce_t< public: using mat_slice_t = subgroup::tile_t; - static constexpr uint32_t barrier_count = work_group_size; + static constexpr uint32_t barrier_count = + 
arch_has_named_barrier ? work_group_size : 0; static constexpr uint32_t slm_size = wg_tile_size * num_cooperative_wg; uint32_t coop_id; @@ -225,7 +226,7 @@ class cooperative_reduce_t< matAcc_t, 1, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape = tile_shape_; diff --git a/include/group/epilogue/impl/default_xe.hpp b/include/group/epilogue/impl/default_xe.hpp index b25397dcd..03daf05fb 100644 --- a/include/group/epilogue/impl/default_xe.hpp +++ b/include/group/epilogue/impl/default_xe.hpp @@ -35,7 +35,7 @@ class epilogue_t< epilogue_policy_default, tile_shape_, mem_desc_c_t_, - std::enable_if_t<((arch_tag_ <= gpu_arch::XeHpc))>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_default; using tile_shape = tile_shape_; @@ -117,7 +117,7 @@ class epilogue_t< tile_shape_, mem_desc_c_t_, std::enable_if_t<( - (arch_tag_ <= gpu_arch::XeHpc) && (mem_desc_c_t_::dim == 4))>> { + valid_xe_arch_tag && (mem_desc_c_t_::dim == 4))>> { public: using epilogue_policy = epilogue_policy_default; using tile_shape = tile_shape_; diff --git a/include/group/epilogue/impl/quant_tile_op_xe.hpp b/include/group/epilogue/impl/quant_tile_op_xe.hpp index 2648fd77a..4cefae147 100644 --- a/include/group/epilogue/impl/quant_tile_op_xe.hpp +++ b/include/group/epilogue/impl/quant_tile_op_xe.hpp @@ -47,7 +47,7 @@ class epilogue_t< dtype_dequant_>, tile_shape_, mem_desc_c_t_, - std::enable_if_t<(arch_tag_ == gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_quant_op< dequant_op_t_, diff --git a/include/group/epilogue/impl/stream_k_op_xe.hpp b/include/group/epilogue/impl/stream_k_op_xe.hpp index ea60bee2f..b109d854f 100644 --- a/include/group/epilogue/impl/stream_k_op_xe.hpp +++ b/include/group/epilogue/impl/stream_k_op_xe.hpp @@ -34,7 +34,7 @@ template < typename epilogue_t_, typename mem_desc_d_t_, typename mem_desc_atomic_sync_t_, - gpu_arch arch_tag_ = gpu_arch::XeHpc> + gpu_arch arch_tag_> struct epilogue_stream_k_t { static constexpr gpu_arch arch_tag = arch_tag_; using epilogue_t = epilogue_t_; diff --git a/include/group/epilogue/impl/tile_op_xe.hpp b/include/group/epilogue/impl/tile_op_xe.hpp index 3afaec107..ec2a615b9 100644 --- a/include/group/epilogue/impl/tile_op_xe.hpp +++ b/include/group/epilogue/impl/tile_op_xe.hpp @@ -39,7 +39,7 @@ class epilogue_t< epilogue_policy_tile_op, tile_shape_, mem_desc_c_t_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_tile_op; using tile_op_t = typename epilogue_policy::tile_op_t; diff --git a/include/group/epilogue/impl/unaligned_xe.hpp b/include/group/epilogue/impl/unaligned_xe.hpp index b0a0eb8cc..eb0a6c44b 100644 --- a/include/group/epilogue/impl/unaligned_xe.hpp +++ b/include/group/epilogue/impl/unaligned_xe.hpp @@ -35,7 +35,7 @@ class epilogue_t< epilogue_policy_unaligned, tile_shape_, mem_desc_c_t_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_unaligned; using tile_shape = tile_shape_; @@ -45,6 +45,11 @@ class epilogue_t< static constexpr uint32_t slm_size = mem_desc_c_t::is_local ? tile_shape::wg_tile_size_x * tile_shape::wg_tile_size_y : 0; + + using load_store_attr = load_store_attr_t; + static constexpr bool ldc_align = + (mem_desc_c_t::alignment_in_bytes % load_store_attr::alignment_in_bytes == + 0); /// @brief Epilogue arguments. 
struct arguments_t {}; @@ -71,8 +76,9 @@ class epilogue_t< public: static constexpr msg_type msg_type_c = - (mem_space_c == mem_space::global ? msg_type::unaligned_2d - : msg_type::scatter); + (mem_space_c == mem_space::global + ? (ldc_align ? msg_type::block_2d : msg_type::unaligned_2d) + : msg_type::scatter); /// @brief Default epilogue. /// 1) Convert dtype_acc to dtype_c 2) Overwrite to memory. diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index 0a0cd1c91..2dcf7e756 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -33,7 +33,7 @@ namespace gpu::xetla::group { template < typename compute_attr_, typename perf_tuning_knob_, - gpu_arch arch_tag_ = gpu_arch::XeHpc, + gpu_arch arch_tag_, typename enable = void> struct compute_policy_default_xmx {}; @@ -48,6 +48,7 @@ struct compute_policy_default_xmx< arch_tag_, std::enable_if_t>> { static constexpr gpu_arch arch_tag = arch_tag_; + static constexpr mma_engine mma_engine = mma_engine::xmx; using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; @@ -59,8 +60,9 @@ struct compute_policy_default_xmx< static constexpr int sync_freq = perf_tuning_knob::sync_freq; static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr uint32_t block_size_y_a = 16; - using mma_attr = mma_attr_t; + using mma_attr = mma_attr_t; + + static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem; static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); @@ -92,7 +94,7 @@ struct compute_policy_unaligned_xmx : public compute_policy_default_xmx< template < typename compute_attr_, typename perf_tuning_knob_, - gpu_arch arch_tag_ = gpu_arch::XeHpc, + gpu_arch arch_tag_, typename enable = void> struct compute_policy_default_fpu {}; @@ -107,6 +109,7 @@ struct compute_policy_default_fpu< arch_tag_, std::enable_if_t>> { static constexpr gpu_arch arch_tag = arch_tag_; + static constexpr mma_engine mma_engine = mma_engine::fpu; using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; @@ -118,18 +121,26 @@ struct compute_policy_default_fpu< static constexpr int sync_freq = perf_tuning_knob::sync_freq; static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr uint32_t block_size_y_a = - arch_tag_ == gpu_arch::XeLpg ? 
8 : 16; - static constexpr uint32_t block_bytes_x_a = 32; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem; + static constexpr uint32_t block_bytes_x_a = mma_attr::blk_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_bytes_x_b = - arch_attr_t::template register_attr<>::reg_in_bytes; + + static constexpr uint32_t block_bytes_x_b = mma_attr::blk_n_in_bytes; static constexpr uint32_t block_size_x_b = block_bytes_x_b / sizeof(dtype_mma_b); static constexpr uint32_t block_size_y_b = block_size_x_a; }; +template < + typename compute_attr_, + typename perf_tuning_knob_, + gpu_arch arch_tag_> +struct compute_policy_unaligned_fpu : public compute_policy_default_fpu< + compute_attr_, + perf_tuning_knob_, + arch_tag_> {}; /// @} xetla_gemm } // namespace gpu::xetla::group diff --git a/include/group/gemm/gemm.hpp b/include/group/gemm/gemm.hpp index ac5d43f16..db2b4b3a4 100644 --- a/include/group/gemm/gemm.hpp +++ b/include/group/gemm/gemm.hpp @@ -30,5 +30,6 @@ #include #include #include +#include #include #include diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp index 7345ea322..c2d45f24d 100644 --- a/include/group/gemm/impl/default_fpu_xe.hpp +++ b/include/group/gemm/impl/default_fpu_xe.hpp @@ -55,8 +55,9 @@ class gemm_t< static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; using work_group_t = typename tile_shape::work_group_t; - constexpr static gpu_arch arch_tag = compute_policy::arch_tag; + constexpr static gpu_arch arch_tag = arch_tag_; static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; @@ -159,6 +160,7 @@ class gemm_t< subgroup::tile_desc_t, 1, arch_tag>; + matA_prefetch_payload_t matA_prefetch_payload; static constexpr reg_layout reg_layout_b = reg_layout::tiled; using matB_tile_desc_t = subgroup::tile_desc_t< @@ -179,6 +181,7 @@ class gemm_t< subgroup::tile_desc_t, 1, arch_tag>; + matB_prefetch_payload_t matB_prefetch_payload; public: using matAcc_tile_desc_t = subgroup::tile_desc_t< @@ -200,10 +203,19 @@ class gemm_t< static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + static constexpr bool need_local_fence = + (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + + [[maybe_unused]] xetla_nbarrier_t barrier_all; + [[maybe_unused]] xetla_nbarrier_t nbarrier_a; + [[maybe_unused]] xetla_nbarrier_t nbarrier_b; public: static constexpr uint32_t barrier_count = - enable_periodic_sync ? barrier_count_x + barrier_count_y : 0; + (enable_periodic_sync && arch_has_named_barrier) + ? 
barrier_count_x + barrier_count_y + : 0; + // current no slm path static constexpr uint32_t slm_size = 0; @@ -283,6 +295,74 @@ class gemm_t< } }; + inline void periodic_sync_init( + [[maybe_unused]] int32_t sg_idx, + [[maybe_unused]] int32_t sg_idy, + uint32_t nbarrier_base) { + if constexpr (enable_periodic_sync) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.init_nbarrier( + sg_idy + nbarrier_base, nbarrier_role::producer_consumer); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.init_nbarrier( + sg_idx + barrier_count_y + nbarrier_base, + nbarrier_role::producer_consumer); + } + } else if constexpr (wg_size > 1) { + barrier_all.init_nbarrier( + nbarrier_base, nbarrier_role::producer_consumer); + } + } + } + + inline void periodic_sync_arrive(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.arrive(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.arrive(); + } + } else if constexpr (wg_size > 1) { + barrier_all.arrive(); + } + } + } + } + + inline void periodic_sync_wait(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.wait(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.wait(); + } + } else if constexpr (wg_size > 1) { + barrier_all.wait(); + } + } + } + } + + inline void prefetch_and_update_ab() { + subgroup::tile_prefetch( + matA_prefetch_payload); + subgroup::tile_prefetch( + matB_prefetch_payload); + SW_BARRIER(); + matA_prefetch_payload.template update_tdesc( + matA_t::tile_size_x); + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + } + /// @brief Gets the subgroup-level tile offset x. /// @param g Is the workgroup of the current tile. /// @return Subgroup-level tile offset x. 
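The periodic_sync_init/arrive/wait helpers introduced above centralize the multi-target barrier policy that was previously inlined in the main loop: targets with named barriers split the synchronization into per-row (nbarrier_a) and per-column (nbarrier_b) barriers, while all other targets fall back to one workgroup-wide barrier. A minimal compile-time sketch of that dispatch, assuming a simplified arch_has_named_barrier trait (the patch's real trait is not shown here) and printf stand-ins for the xetla_nbarrier_t arrive/wait calls:

    #include <cstdint>
    #include <cstdio>

    enum class gpu_arch { XeLpg, XeHpg, XeHpc };

    // Simplified stand-in for the arch_has_named_barrier trait used throughout
    // this patch; only XeHpc is assumed to expose HW named barriers here.
    template <gpu_arch arch>
    constexpr bool arch_has_named_barrier = (arch == gpu_arch::XeHpc);

    // Mirrors the shape of periodic_sync_arrive/periodic_sync_wait: named
    // barriers sync per tile row/column, other targets use one workgroup barrier.
    template <gpu_arch arch, uint32_t wg_size_x, uint32_t wg_size_y>
    void periodic_sync(uint32_t iter, uint32_t sync_freq) {
      if (sync_freq == 0 || iter % sync_freq != 0) return;
      if constexpr (arch_has_named_barrier<arch>) {
        if constexpr (wg_size_x > 1) std::printf("nbarrier_a.arrive_wait()\n");
        if constexpr (wg_size_y > 1) std::printf("nbarrier_b.arrive_wait()\n");
      } else if constexpr (wg_size_x * wg_size_y > 1) {
        std::printf("barrier_all.arrive_wait()\n");
      }
    }

    int main() {
      periodic_sync<gpu_arch::XeHpc, 4, 8>(0, 8); // two named barriers
      periodic_sync<gpu_arch::XeHpg, 4, 8>(0, 8); // single fallback barrier
      return 0;
    }

Hoisting the three barriers into [[maybe_unused]] members keeps a single gemm body compiling for every target without scattering arch checks through the inner loop.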
@@ -306,18 +386,14 @@ class gemm_t< "If you call this function, please set a free barrier id or make " "sure barrier_id 0 is not being occupied and you need to allocate " "one more barrier count in addition to the gemm barrier counts.") - __XETLA_API static void release(uint8_t nbarrier_id = 0) { - static constexpr bool need_local_fence = - (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + __XETLA_API void release(uint8_t nbarrier_id = 0) { if constexpr (need_local_fence) { xetla_fence(); } xetla_fence(); - static constexpr uint32_t wg_size = wg_size_x * wg_size_y; if constexpr (wg_size > 1) { - xetla_nbarrier_t nbarrier; - nbarrier.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); - nbarrier.arrive_wait(); + barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); + barrier_all.arrive_wait(); } } @@ -344,75 +420,40 @@ class gemm_t< pre_processing.init(g, args.pre_processing_args); matA_payload_t matA_payload(args.matA_base_desc); matB_payload_t matB_payload(args.matB_base_desc); - matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, 0); - matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, 0); - xetla_nbarrier_t nbarrier_a; - nbarrier_a.init_nbarrier( - sg_idy + nbarrier_base, nbarrier_role::producer_consumer); - xetla_nbarrier_t nbarrier_b; - nbarrier_b.init_nbarrier( - sg_idx + barrier_count_y + nbarrier_base, - nbarrier_role::producer_consumer); + matA_prefetch_payload.init(args.matA_base_desc, 0); + matB_prefetch_payload.init(args.matB_base_desc, 0); + + periodic_sync_init(sg_idx, sg_idy, nbarrier_base); + #pragma unroll for (uint32_t i = 0; i < stages; i++) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } for (uint32_t i = 0; i < args.inner_loop_count; i++) { - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - nbarrier_a.arrive(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.arrive(); - } - } - } - SW_BARRIER(); + periodic_sync_arrive(i); subgroup::tile_load( matA, matA_payload); subgroup::tile_load( matB, matB_payload); + SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); + if constexpr (stages != 0) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } + + SW_BARRIER(); matA_acc_t matA_acc; matB_acc_t matB_acc; subgroup::elemwise_cvt(matA_acc, matA); subgroup::elemwise_cvt(matB_acc, matB); pre_processing(matA_acc, matB_acc, matA, matB); - SW_BARRIER(); tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); SW_BARRIER(); - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - nbarrier_a.wait(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.wait(); - } - } - } + + periodic_sync_wait(i); } SW_BARRIER(); } diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp index 0626f2310..70bdc7610 100644 --- a/include/group/gemm/impl/default_xmx_xe.hpp 
+++ b/include/group/gemm/impl/default_xmx_xe.hpp @@ -57,7 +57,7 @@ class gemm_t< static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; using work_group_t = typename tile_shape::work_group_t; - constexpr static gpu_arch arch_tag = compute_policy::arch_tag; + constexpr static gpu_arch arch_tag = arch_tag_; static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; @@ -65,6 +65,7 @@ class gemm_t< static constexpr bool is_col_major_b = mem_layout_b == mem_layout::col_major; private: + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; /******** set data type **********/ using dtype_a = typename mem_desc_a_t::dtype; using dtype_b = typename mem_desc_b_t::dtype; @@ -126,6 +127,11 @@ class gemm_t< /******** set tile **********/ static constexpr reg_layout reg_layout_a = reg_layout::tiled; + static constexpr reg_layout reg_layout_b = sizeof(dtype_b) < sizeof(uint32_t) + ? reg_layout::vnni_tiled + : reg_layout::tiled; + + public: using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, tile_size_y_a, @@ -144,9 +150,7 @@ class gemm_t< subgroup::tile_desc_t, wg_size_x, arch_tag>; - static constexpr reg_layout reg_layout_b = sizeof(dtype_b) < sizeof(uint32_t) - ? reg_layout::vnni_tiled - : reg_layout::tiled; + matA_prefetch_payload_t matA_prefetch_payload; using matB_tile_desc_t = subgroup::tile_desc_t< tile_size_x_b, tile_size_y_b, @@ -165,8 +169,8 @@ class gemm_t< subgroup::tile_desc_t, wg_size_y, arch_tag>; + matB_prefetch_payload_t matB_prefetch_payload; - public: using matAcc_tile_desc_t = subgroup::tile_desc_t< tile_size_x_c, tile_size_y_c, @@ -186,10 +190,18 @@ class gemm_t< static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + static constexpr bool need_local_fence = + (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + + [[maybe_unused]] xetla_nbarrier_t barrier_all; + [[maybe_unused]] xetla_nbarrier_t nbarrier_a; + [[maybe_unused]] xetla_nbarrier_t nbarrier_b; public: static constexpr uint32_t barrier_count = - enable_periodic_sync ? barrier_count_x + barrier_count_y : 0; + (enable_periodic_sync && arch_has_named_barrier) + ? 
barrier_count_x + barrier_count_y + : 0; static constexpr uint32_t slm_size = 0; @@ -269,6 +281,74 @@ class gemm_t< } }; + inline void periodic_sync_init( + [[maybe_unused]] int32_t sg_idx, + [[maybe_unused]] int32_t sg_idy, + uint32_t nbarrier_base) { + if constexpr (enable_periodic_sync) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.init_nbarrier( + sg_idy + nbarrier_base, nbarrier_role::producer_consumer); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.init_nbarrier( + sg_idx + barrier_count_y + nbarrier_base, + nbarrier_role::producer_consumer); + } + } else if constexpr (wg_size > 1) { + barrier_all.init_nbarrier( + nbarrier_base, nbarrier_role::producer_consumer); + } + } + } + + inline void periodic_sync_arrive(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.arrive(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.arrive(); + } + } else if constexpr (wg_size > 1) { + barrier_all.arrive(); + } + } + } + } + + inline void periodic_sync_wait(uint32_t iter_num) { + if constexpr (enable_periodic_sync) { + if ((iter_num % sync_freq) == 0) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.wait(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.wait(); + } + } else if constexpr (wg_size > 1) { + barrier_all.wait(); + } + } + } + } + + inline void prefetch_and_update_ab() { + subgroup::tile_prefetch( + matA_prefetch_payload); + subgroup::tile_prefetch( + matB_prefetch_payload); + SW_BARRIER(); + matA_prefetch_payload.template update_tdesc( + matA_t::tile_size_x); + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + } + /// @brief Gets the subgroup-level tile offset x. /// @param g Is the workgroup of the current tile. /// @return Subgroup-level tile offset x. 
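The guarded barrier_count above also changes the resource contract: on targets without named barriers the gemm now reserves zero barrier ids, presumably because the barrier_all fallback degrades to a plain workgroup barrier there. Callers should therefore read gemm_t::barrier_count rather than assume wg_size_x + wg_size_y. A small sketch of the accounting, reusing the same assumed arch_has_named_barrier trait as in the sketch above:

    #include <cstdint>
    #include <cstdio>

    enum class gpu_arch { XeLpg, XeHpg, XeHpc };

    // Assumed trait, as before; only XeHpc has HW named barriers here.
    template <gpu_arch arch>
    constexpr bool arch_has_named_barrier = (arch == gpu_arch::XeHpc);

    // Mirrors the barrier_count expression guarded by arch_has_named_barrier
    // in default_xmx_xe.hpp / default_fpu_xe.hpp.
    template <gpu_arch arch, uint32_t wg_size_x, uint32_t wg_size_y, int sync_freq>
    constexpr uint32_t gemm_barrier_count() {
      constexpr bool enable_periodic_sync = (sync_freq != 0);
      constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
      constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;
      return (enable_periodic_sync && arch_has_named_barrier<arch>)
          ? barrier_count_x + barrier_count_y
          : 0;
    }

    int main() {
      // A 4x8 workgroup with periodic sync reserves 4 + 8 ids on XeHpc, none on XeHpg.
      static_assert(gemm_barrier_count<gpu_arch::XeHpc, 4, 8, 8>() == 12);
      static_assert(gemm_barrier_count<gpu_arch::XeHpg, 4, 8, 8>() == 0);
      std::printf("XeHpc: %u, XeHpg: %u\n",
                  gemm_barrier_count<gpu_arch::XeHpc, 4, 8, 8>(),
                  gemm_barrier_count<gpu_arch::XeHpg, 4, 8, 8>());
      return 0;
    }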
@@ -292,18 +372,14 @@ class gemm_t< "If you call this function, please set a free barrier id or make " "sure barrier_id 0 is not being occupied and you need to allocate " "one more barrier count in addition to the gemm barrier counts.") - __XETLA_API static void release(uint8_t nbarrier_id = 0) { - static constexpr bool need_local_fence = - (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local); + __XETLA_API void release(uint8_t nbarrier_id = 0) { if constexpr (need_local_fence) { xetla_fence(); } xetla_fence(); - static constexpr uint32_t wg_size = wg_size_x * wg_size_y; if constexpr (wg_size > 1) { - xetla_nbarrier_t nbarrier; - nbarrier.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); - nbarrier.arrive_wait(); + barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer); + barrier_all.arrive_wait(); } } @@ -324,10 +400,10 @@ class gemm_t< int32_t sg_idy = g.get_id() / wg_size_x; XETLA_ASSERT( - g.get_id() < (wg_size_x * wg_size_y), + g.get_id() < wg_size, "Thread id(%d) should less than wg_size(%d)", g.get_id(), - wg_size_x * wg_size_y); + wg_size); update_sg_tile_tdesc(args, sg_idx, sg_idy); pre_processing_t pre_processing; @@ -337,79 +413,41 @@ class gemm_t< pre_processing.init(g, args.pre_processing_args); matA_payload_t matA_payload(args.matA_base_desc); matB_payload_t matB_payload(args.matB_base_desc); - matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, sg_idx); - matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, sg_idy); - xetla_nbarrier_t nbarrier_a; - nbarrier_a.init_nbarrier( - sg_idy + nbarrier_base, nbarrier_role::producer_consumer); - xetla_nbarrier_t nbarrier_b; - nbarrier_b.init_nbarrier( - sg_idx + barrier_count_y + nbarrier_base, - nbarrier_role::producer_consumer); + matA_prefetch_payload.init(args.matA_base_desc, sg_idx); + matB_prefetch_payload.init(args.matB_base_desc, sg_idy); + + periodic_sync_init(sg_idx, sg_idy, nbarrier_base); #pragma unroll for (uint32_t i = 0; i < stages; i++) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } for (uint32_t i = 0; i < args.inner_loop_count; i++) { - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - nbarrier_a.arrive(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.arrive(); - } - } - } + periodic_sync_arrive(i); + subgroup::tile_load( matB, matB_payload); subgroup::tile_load( matA, matA_payload); - if constexpr (stages != 0) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - } SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); + if constexpr (stages != 0) { - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } + SW_BARRIER(); matA_acc_t matA_acc; matB_acc_t matB_acc; subgroup::elemwise_cvt(matA_acc, matA); subgroup::vnni_transform(matB_acc, matB); pre_processing(matA_acc, matB_acc, matA, matB); - SW_BARRIER(); tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); SW_BARRIER(); - if constexpr (enable_periodic_sync) { - if ((i % sync_freq) == 0) { - if constexpr (wg_size_x > 1) { - 
nbarrier_a.wait(); - } - if constexpr (arch_tag >= gpu_arch::XeHpc) - if constexpr (wg_size_y > 1) { - nbarrier_b.wait(); - } - } - } + + periodic_sync_wait(i); } SW_BARRIER(); } diff --git a/include/group/gemm/impl/pre_processing_xe.hpp b/include/group/gemm/impl/pre_processing_xe.hpp index 053cf9aeb..588622ef6 100644 --- a/include/group/gemm/impl/pre_processing_xe.hpp +++ b/include/group/gemm/impl/pre_processing_xe.hpp @@ -32,7 +32,7 @@ template class pre_processing_default_t< tile_shape_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using tile_shape = tile_shape_; using work_group_t = typename tile_shape::work_group_t; @@ -67,7 +67,7 @@ template class pre_processing_matA_neg_filter_t< tile_shape_, arch_tag, - std::enable_if_t<(arch_tag == gpu_arch::XeHpc)>> { + std::enable_if_t>> { using tile_shape = tile_shape_; using work_group_t = typename tile_shape::work_group_t; diff --git a/include/group/gemm/impl/selector_xe.hpp b/include/group/gemm/impl/selector_xe.hpp index 6f32fb67b..68560ae7b 100644 --- a/include/group/gemm/impl/selector_xe.hpp +++ b/include/group/gemm/impl/selector_xe.hpp @@ -25,26 +25,33 @@ namespace gpu::xetla::group { namespace detail { +template +class check_block_2d_pitch_alignment { + using load_store_attr = load_store_attr_t; + static constexpr int alignment_in_bytes = load_store_attr::alignment_in_bytes; + static constexpr int alignment_bytes = alignment * sizeof(dtype); + + public: + static constexpr bool value = (alignment_bytes % alignment_in_bytes == 0); +}; + +} // namespace detail + template < typename dtype_a, typename dtype_b, int alignment_a, int alignment_b, gpu_arch arch_tag> -class check_2d_block_pitch_alignment { - using load_store_attr = typename arch_attr_t< - arch_tag>::template load_store_attr; - static constexpr int alignment_bytes = load_store_attr::alignment_in_bytes; - static constexpr int alignment_bytes_a = alignment_a * sizeof(dtype_a); - static constexpr int alignment_bytes_b = alignment_b * sizeof(dtype_b); - +class check_block_2d_pitch_alignment { public: - static constexpr bool value = (alignment_bytes_a % alignment_bytes == 0) && - (alignment_bytes_b % alignment_bytes == 0); + static constexpr int a_align = detail:: + check_block_2d_pitch_alignment::value; + static constexpr int b_align = detail:: + check_block_2d_pitch_alignment::value; + static constexpr bool value = a_align && b_align; }; -} // namespace detail - /// @addtogroup xetla_gemm /// @{ @@ -80,7 +87,7 @@ class gemm_selector_t< arch_tag, stages, sync_freq, - std::enable_if_t; using mem_desc_b = mem_desc_t; + using ld_align_attr = check_block_2d_pitch_alignment< + dtype_a, + dtype_b, + alignment_a, + alignment_b, + arch_tag>; using compute_attr = compute_attr_t; using perf_tuning_knob = perf_tuning_knob_t; using compute_policy = @@ -194,17 +207,12 @@ class gemm_selector_t< arch_tag, stages, sync_freq, - std::enable_if_t::value>> { - static_assert( - std::is_same::value && - std::is_same::value, - "When use gemm_selector, dtype_a and dtype_b in fpu based gemm" - "should be the same as dtype_acc"); using mem_desc_a = mem_desc_t; using mem_desc_b = @@ -224,5 +232,68 @@ class gemm_selector_t< pre_processing>; }; +/// @brief Selects unaligned 2d block && fpu based gemm.
+template < + typename dtype_a, + typename dtype_b, + mem_layout mem_layout_a, + mem_layout mem_layout_b, + mem_space mem_space_a, + mem_space mem_space_b, + int alignment_a, + int alignment_b, + typename dtype_acc, + typename tile_shape, + int k_stride, + gpu_arch arch_tag, + int stages, + int sync_freq> +class gemm_selector_t< + dtype_a, + dtype_b, + mem_layout_a, + mem_layout_b, + mem_space_a, + mem_space_b, + alignment_a, + alignment_b, + dtype_acc, + tile_shape, + k_stride, + mma_engine::fpu, + arch_tag, + stages, + sync_freq, + std::enable_if_t::value>> { + using mem_desc_a = + mem_desc_t; + using mem_desc_b = + mem_desc_t; + using ld_align_attr = check_block_2d_pitch_alignment< + dtype_a, + dtype_b, + alignment_a, + alignment_b, + arch_tag>; + using compute_attr = compute_attr_t; + using perf_tuning_knob = perf_tuning_knob_t; + using compute_policy = + compute_policy_unaligned_fpu; + using pre_processing = pre_processing_default_t; + + public: + using gemm = gemm_t< + compute_policy, + tile_shape, + mem_desc_a, + mem_desc_b, + pre_processing>; +}; + /// @} xetla_gemm -} // namespace gpu::xetla::group \ No newline at end of file +} // namespace gpu::xetla::group diff --git a/include/group/gemm/impl/unaligned_fpu_xe.hpp b/include/group/gemm/impl/unaligned_fpu_xe.hpp new file mode 100644 index 000000000..4d68fdf28 --- /dev/null +++ b/include/group/gemm/impl/unaligned_fpu_xe.hpp @@ -0,0 +1,666 @@ +/******************************************************************************* + * Copyright (c) 2022-2023 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +/// @file +/// C++ API + +#pragma once + +#include +#include + +namespace gpu::xetla::group { + +/// @addtogroup xetla_gemm +/// @{ + +/// @brief Is the gemm functor for unaligned input, Xe architecture and fpu +/// engine.
+template < + typename compute_attr_, + typename perf_tuning_knob_, + typename tile_shape_, + typename mem_desc_a_t_, + typename mem_desc_b_t_, + typename pre_processing_t_, + gpu_arch arch_tag_> +class gemm_t< + compute_policy_unaligned_fpu, + tile_shape_, // tile shape of workgroup-level gemm + mem_desc_a_t_, // memory attribute of matA + mem_desc_b_t_, // memory attribute of matB + pre_processing_t_, // pre_processing functor + std::enable_if_t>> { + public: + using mem_desc_a_t = mem_desc_a_t_; + using mem_desc_b_t = mem_desc_b_t_; + using tile_shape = tile_shape_; + using pre_processing_t = pre_processing_t_; + using compute_policy = + compute_policy_unaligned_fpu; + + static constexpr uint32_t num_cyclic = 2; + + static constexpr uint32_t k_stride = compute_policy::k_stride; + static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y; + static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; + static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; + static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; + using work_group_t = typename tile_shape::work_group_t; + + constexpr static gpu_arch arch_tag = arch_tag_; + + static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; + static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; + static constexpr bool is_col_major_a = + (mem_layout_a == mem_layout::col_major); + static constexpr bool is_col_major_b = + (mem_layout_b == mem_layout::col_major); + + using load_store_attr = load_store_attr_t; + static constexpr bool lda_align = + (mem_desc_a_t::alignment_in_bytes % load_store_attr::alignment_in_bytes == + 0); + static constexpr bool ldb_align = + (mem_desc_b_t::alignment_in_bytes % load_store_attr::alignment_in_bytes == + 0); + + private: + /******** set data type **********/ + using dtype_a = typename mem_desc_a_t::dtype; + using dtype_b = typename mem_desc_b_t::dtype; + using dtype_mma_acc = typename compute_policy::dtype_mma_acc; + using dtype_mma_a = typename compute_policy::dtype_mma_a; + using dtype_mma_b = typename compute_policy::dtype_mma_b; + + using check_dtype = + group::gemm::default_fpu::template check_dtype_default< + dtype_a, + dtype_b, + dtype_mma_a, + dtype_mma_b, + dtype_mma_acc>; + + /******** set memory attribute **********/ + static constexpr mem_space mem_space_a = mem_desc_a_t::space; + static constexpr mem_space mem_space_b = mem_desc_b_t::space; + + static constexpr bool is_local_a = mem_space_a == mem_space::local; + static constexpr bool is_local_b = mem_space_b == mem_space::local; + static constexpr tdesc_update_dir update_dir_a = + is_col_major_a ? tdesc_update_dir::y_dir : tdesc_update_dir::x_dir; + static constexpr tdesc_update_dir update_dir_b = + is_col_major_b ? 
tdesc_update_dir::x_dir : tdesc_update_dir::y_dir; + + using check_memory = + group::gemm::default_fpu::template check_memory_default< + mem_layout_a, + mem_layout_b, + mem_space_a, + mem_space_b>; + + static constexpr uint32_t stages = compute_policy::stages; + static constexpr uint32_t sync_freq = compute_policy::sync_freq; + + /******** set tile layout && worker scope **********/ + static constexpr uint32_t tile_size_x_a = k_stride; + static constexpr uint32_t tile_size_y_a = sg_tile_m; + static constexpr uint32_t tile_size_x_b = sg_tile_n; + static constexpr uint32_t tile_size_y_b = k_stride; + static constexpr uint32_t tile_size_x_c = sg_tile_n; + static constexpr uint32_t tile_size_y_c = sg_tile_m; + + static constexpr uint32_t block_size_x_a = + (compute_policy::block_size_x_a > tile_size_x_a) + ? tile_size_x_a + : compute_policy::block_size_x_a; + static constexpr uint32_t block_size_y_a = + (compute_policy::block_size_y_a > tile_size_y_a) + ? tile_size_y_a + : compute_policy::block_size_y_a; + static constexpr uint32_t block_size_x_b = + (compute_policy::block_size_x_b > tile_size_x_b) + ? tile_size_x_b + : compute_policy::block_size_x_b; + static constexpr uint32_t block_size_y_b = + (compute_policy::block_size_y_b > tile_size_y_b) + ? tile_size_y_b + : compute_policy::block_size_y_b; + + using check_tile_size = + group::gemm::default_fpu::template check_tile_size_default< + dtype_mma_a, + tile_size_x_a, + tile_size_y_a, + block_size_x_a, + block_size_y_a, + tile_size_x_b, + tile_size_y_b, + block_size_x_b, + block_size_y_b>; + + /******** set tile **********/ + static constexpr reg_layout reg_layout_a = reg_layout::tiled; + + [[maybe_unused]] xetla_nbarrier_t barrier_all; + [[maybe_unused]] xetla_nbarrier_t nbarrier_a; + [[maybe_unused]] xetla_nbarrier_t nbarrier_b; + + using matA_tile_desc_t = subgroup::tile_desc_t< + tile_size_x_a, + tile_size_y_a, + block_size_x_a, + block_size_y_a, + reg_layout_a>; + + using matA_t = subgroup::tile_t; + + using cooperative_helper_A_t = subgroup::cooperative_load_helper_t< + matA_t, + mem_layout_a, + tile_shape::wg_size_x, + arch_tag>; + using cooperative_tile_desc_A_t = + typename cooperative_helper_A_t::co_tile_desc_t; + using partial_matA_t = subgroup::tile_t; + using matA_payload_t = subgroup::mem_payload_t< + mem_desc_a_t, + cooperative_tile_desc_A_t, + is_local_a ? msg_type::scatter + : lda_align ? 
msg_type::block_2d + : msg_type::unaligned_2d, + arch_tag>; + + using matA_payload_local_st_t = subgroup::mem_payload_t< + mem_desc_t, + cooperative_tile_desc_A_t, + msg_type::scatter, + arch_tag>; + using matA_payload_local_ld_t = subgroup::mem_payload_t< + mem_desc_t, + matA_tile_desc_t, + msg_type::scatter, + arch_tag>; + + using matA_acc_t = subgroup::tile_t; + using matA_prefetch_payload_t = subgroup::prefetch_payload_t< + mem_desc_a_t, + subgroup::tile_desc_t, + wg_size_x, + arch_tag>; + matA_prefetch_payload_t matA_prefetch_payload; + + static constexpr reg_layout reg_layout_b = reg_layout::tiled; + using matB_tile_desc_t = subgroup::tile_desc_t< + tile_size_x_b, + tile_size_y_b, + block_size_x_b, + block_size_y_b, + reg_layout_b>; + using matB_t = subgroup::tile_t; + + using cooperative_helper_B_t = subgroup::cooperative_load_helper_t< + matB_t, + mem_layout_b, + tile_shape::wg_size_y, + arch_tag>; + using cooperative_tile_desc_B_t = + typename cooperative_helper_B_t::co_tile_desc_t; + + using partial_matB_t = subgroup::tile_t; + + using matB_payload_t = subgroup::mem_payload_t< + mem_desc_b_t, + cooperative_tile_desc_B_t, + is_local_b ? msg_type::scatter + : ldb_align ? msg_type::block_2d + : msg_type::unaligned_2d, + arch_tag>; + + using matB_payload_local_st_t = subgroup::mem_payload_t< + mem_desc_t, + cooperative_tile_desc_B_t, + msg_type::scatter, + arch_tag>; + using matB_payload_local_ld_t = subgroup::mem_payload_t< + mem_desc_t, + matB_tile_desc_t, + msg_type::scatter, + arch_tag>; + + using matB_acc_t = subgroup::tile_t; + using matB_prefetch_payload_t = subgroup::prefetch_payload_t< + mem_desc_b_t, + subgroup::tile_desc_t, + wg_size_y, + arch_tag>; + matB_prefetch_payload_t matB_prefetch_payload; + + public: + using matAcc_tile_desc_t = subgroup::tile_desc_t< + tile_size_x_c, + tile_size_y_c, + block_size_x_b, + block_size_y_a, + reg_layout::tiled>; + using matAcc_t = subgroup::tile_t; + + private: + using tile_mma = subgroup::tile_mma_t< + matAcc_t, + matAcc_t, + matB_acc_t, + matA_acc_t, + mma_engine::fpu, + arch_tag>; + // static constexpr bool enable_periodic_sync = (sync_freq != 0); + static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; + static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + static constexpr uint32_t tile_size_a = + tile_size_x_a * tile_size_y_a * sizeof(dtype_a); + static constexpr uint32_t tile_size_b = + tile_size_x_b * tile_size_y_b * sizeof(dtype_b); + static constexpr uint32_t slm_size_a = wg_size_y * tile_size_a; + static constexpr uint32_t slm_size_b = wg_size_x * tile_size_b; + + public: + static constexpr uint32_t barrier_count = + arch_has_named_barrier ? barrier_count_x + barrier_count_y : 0; + + static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic; + static_assert(slm_size <= arch_attr_t::local_mem_size); + static constexpr uint32_t slm_base_a = 0; + static constexpr uint32_t slm_base_b = slm_size_a * num_cyclic; + + static constexpr msg_type msg_type_a = matA_payload_t::message_type; + static constexpr msg_type msg_type_b = matB_payload_t::message_type; + + using pre_processing_arg_t = typename pre_processing_t::arguments_t; + + /// @brief Arguments for gemm. + /// User should prepare matA_base_desc, matB_base_desc, inner_loop_count... + struct arguments_t { + /// @brief Is the memory description of matA, including base, shape and + /// coordinate. + mem_desc_a_t matA_base_desc; + /// @brief Is the memory description of matB, including base, shape and + /// coordinate. 
+    mem_desc_b_t matB_base_desc;
+    /// @brief Is the total inner loop count required to compute the entire
+    /// K-dim.
+    uint32_t inner_loop_count;
+    /// @brief Is the arguments for pre-processing functor.
+    pre_processing_arg_t pre_processing_args;
+
+    /// @brief Default constructor.
+    inline arguments_t() = default;
+    // Be aware of the risks: Rule of three (copy constructor, copy assignment,
+    // destructor). Please check if you need to add a self-defined destructor:
+    // ~arguments_t(){}
+
+    /// @brief Constructs a new arguments_t object.
+    /// @param matA_desc Is the memory description of matA, including base,
+    /// shape and coordinate.
+    /// @param matB_desc Is the memory description of matB, including base,
+    /// shape and coordinate.
+    /// @param loop_count Is the total inner loop count required to compute the
+    /// entire K-dim.
+    /// @param args Is the arguments for pre-processing functor.
+    inline arguments_t(
+        mem_desc_a_t matA_desc,
+        mem_desc_b_t matB_desc,
+        uint32_t loop_count,
+        pre_processing_arg_t args = {})
+        : matA_base_desc(matA_desc),
+          matB_base_desc(matB_desc),
+          inner_loop_count(loop_count),
+          pre_processing_args(args) {}
+    // Be aware of the risks: Rule of three (copy constructor, copy assignment,
+    // destructor). Please check if you need to add a self-defined destructor:
+    // inline ~arguments_t(){}
+    inline arguments_t(const arguments_t& args)
+        : matA_base_desc(args.matA_base_desc),
+          matB_base_desc(args.matB_base_desc),
+          inner_loop_count(args.inner_loop_count),
+          pre_processing_args(args.pre_processing_args) {}
+    inline arguments_t& operator=(const arguments_t& args) {
+      this->matA_base_desc = args.matA_base_desc;
+      this->matB_base_desc = args.matB_base_desc;
+      this->inner_loop_count = args.inner_loop_count;
+      this->pre_processing_args = args.pre_processing_args;
+      return *this;
+    }
+
+    /// @brief Explicit initialization function.
+    /// @param matA_desc Is the memory description of matA, including base,
+    /// shape and coordinate.
+    /// @param matB_desc Is the memory description of matB, including base,
+    /// shape and coordinate.
+    /// @param loop_count Is the total inner loop count required to compute the
+    /// entire K-dim.
+    /// @param args Is the arguments for pre-processing functor.
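+    ///
+    /// A minimal usage sketch (symbols assumed for illustration): for a GEMM
+    /// whose K-dim extent is K elements, with k_stride elements consumed per
+    /// inner iteration, the caller would compute
+    ///   loop_count = (K + k_stride - 1) / k_stride;
+    /// and pass it here as inner_loop_count.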
+    inline void init(
+        mem_desc_a_t matA_desc,
+        mem_desc_b_t matB_desc,
+        uint32_t loop_count,
+        pre_processing_arg_t args = {}) {
+      matA_base_desc = matA_desc;
+      matB_base_desc = matB_desc;
+      inner_loop_count = loop_count;
+      pre_processing_args = args;
+    }
+  };
+
+  inline void sync_init(
+      [[maybe_unused]] int32_t sg_idx,
+      [[maybe_unused]] int32_t sg_idy,
+      uint32_t nbarrier_base) {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.init_nbarrier(
+            sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.init_nbarrier(
+            sg_idx + barrier_count_y + nbarrier_base,
+            nbarrier_role::producer_consumer);
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.init_nbarrier(
+          nbarrier_base, nbarrier_role::producer_consumer);
+    }
+  }
+
+  inline void sync_arrive() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.arrive();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.arrive();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.arrive();
+    }
+  }
+
+  inline void sync_wait() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.wait();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.wait();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.wait();
+    }
+  }
+
+  inline void sync_arrive_wait() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.arrive_wait();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.arrive_wait();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.arrive_wait();
+    }
+  }
+
+  inline void prefetch_and_update_ab() {
+    subgroup::tile_prefetch(
+        matA_prefetch_payload);
+    subgroup::tile_prefetch(
+        matB_prefetch_payload);
+    SW_BARRIER();
+    matA_prefetch_payload.template update_tdesc(
+        matA_t::tile_size_x);
+    matB_prefetch_payload.template update_tdesc(
+        matB_t::tile_size_y);
+  }
+
+  /// @brief Gets the subgroup-level tile offset x.
+  /// @param g Is the workgroup of the current tile.
+  /// @return Subgroup-level tile offset x.
+  __XETLA_API static int get_matC_offset_x(work_group_t& g) {
+    int32_t sg_idx = g.get_id() % wg_size_x;
+    return sg_idx * sg_tile_n;
+  }
+
+  /// @brief Gets the subgroup-level tile offset y.
+  /// @param g Is the workgroup of the current tile.
+  /// @return Subgroup-level tile offset y.
+  __XETLA_API static int get_matC_offset_y(work_group_t& g) {
+    int32_t sg_idy = g.get_id() / wg_size_x;
+    return sg_idy * sg_tile_m;
+  }
+
+  XETLA_MARKER(
+      "This release function waits until all the r/w operations and "
+      "nbarrier ids used in this gemm have been committed. By default, "
+      "it uses barrier_id 0 to sync the entire workgroup if wg_size > 1. "
+      "If you call this function, either pass a free barrier id or make "
+      "sure barrier_id 0 is not occupied, and allocate one more barrier "
+      "count in addition to the gemm barrier counts.")
+  __XETLA_API void release(uint8_t nbarrier_id = 0) {
+    static constexpr bool need_local_fence =
+        (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local);
+    if constexpr (need_local_fence) {
+      xetla_fence();
+    }
+    xetla_fence();
+    if constexpr (wg_size > 1) {
+      barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer);
+      barrier_all.arrive_wait();
+    }
+  }
+
+  /// @brief Main execution function for gemm.
+  /// The basic process is load data -> matrix multiply.
+  /// @param g Is the workgroup of the current tile.
+  /// @param matAcc Is the reference of the accumulation buffer.
+  /// @param args Is the gemm::arguments_t.
+  /// @param slm_base Is the slm base address.
+  /// @param nbarrier_base Is the named barrier base.
+  __XETLA_API KERNEL_FUNC void operator()(
+      work_group_t& g,
+      matAcc_t& matAcc,
+      arguments_t args,
+      uint32_t slm_base = 0,
+      uint32_t nbarrier_base = 0) {
+    int32_t sg_idx = g.get_id() % wg_size_x;
+    int32_t sg_idy = g.get_id() / wg_size_x;
+
+    XETLA_ASSERT(
+        g.get_id() < wg_size,
+        "Thread id(%d) should be less than wg_size(%d)",
+        g.get_id(),
+        wg_size);
+
+    update_sg_tile_tdesc(args, sg_idx, sg_idy);
+    pre_processing_t pre_processing;
+    matA_t matA;
+    matB_t matB;
+    matA_acc_t matA_acc;
+    matB_acc_t matB_acc;
+    partial_matA_t partial_matA;
+    partial_matB_t partial_matB;
+    // >>>>>>>>>>>>>>>>>> pre_processing init
+    pre_processing.init(g, args.pre_processing_args);
+    uint32_t base_A = slm_base + slm_base_a + sg_idy * tile_size_a;
+    uint32_t base_B = slm_base + slm_base_b + sg_idx * tile_size_b;
+
+    uint32_t store_idx = 0;
+    uint32_t load_idx = 0;
+
+    matA_payload_t matA_payload(args.matA_base_desc);
+    matA_payload_local_st_t matA_local_st_payload(
+        base_A,
+        tile_size_x_a,
+        tile_size_y_a,
+        tile_size_x_a,
+        cooperative_helper_A_t::get_offset_x(sg_idx),
+        cooperative_helper_A_t::get_offset_y(sg_idx));
+    matA_payload_local_ld_t matA_local_ld_payload(
+        base_A, tile_size_x_a, tile_size_y_a, tile_size_x_a, 0, 0);
+
+    matB_payload_t matB_payload(args.matB_base_desc);
+    matB_payload_local_st_t matB_local_st_payload(
+        base_B,
+        tile_size_x_b,
+        tile_size_y_b,
+        tile_size_x_b,
+        cooperative_helper_B_t::get_offset_x(sg_idy),
+        cooperative_helper_B_t::get_offset_y(sg_idy));
+    matB_payload_local_ld_t matB_local_ld_payload(
+        base_B, tile_size_x_b, tile_size_y_b, tile_size_x_b, 0, 0);
+
+    matA_prefetch_payload.init(args.matA_base_desc, sg_idx);
+    matB_prefetch_payload.init(args.matB_base_desc, sg_idy);
+
+    sync_init(sg_idx, sg_idy, nbarrier_base);
+
+#pragma unroll
+    for (uint32_t i = 0; i < stages; i++) {
+      prefetch_and_update_ab();
+    }
+
+    tile_load(partial_matA, matA_payload);
+    tile_load(partial_matB, matB_payload);
+    matA_payload.template update_tdesc(matA_t::tile_size_x);
+    matB_payload.template update_tdesc(matB_t::tile_size_y);
+
+    tile_store(partial_matA, matA_local_st_payload);
+    tile_store(partial_matB, matB_local_st_payload);
+    store_idx++;
+
+#pragma unroll
+    for (uint32_t i = 1; i < num_cyclic - 1; i++) {
+      tile_load(partial_matA, matA_payload);
+      tile_load(partial_matB, matB_payload);
+      matA_payload.template update_tdesc(matA_t::tile_size_x);
+      matB_payload.template update_tdesc(matB_t::tile_size_y);
+
+      matA_local_st_payload.template update_tdesc(
+          wg_size_y * matA_t::tile_size_y);
+      matB_local_st_payload.template update_tdesc(
+          wg_size_x * matB_t::tile_size_y);
+
+      if constexpr (stages != 0) {
+        prefetch_and_update_ab();
+      }
+
+      tile_store(partial_matA, matA_local_st_payload);
+      tile_store(partial_matB, matB_local_st_payload);
+      store_idx++;
+    }
+
+    xetla_fence();
+    sync_arrive_wait();
+
+    for (uint32_t i = 0; i < args.inner_loop_count - 1; i++) {
+      tile_load(matA, matA_local_ld_payload);
+      tile_load(matB, matB_local_ld_payload);
+
+      load_idx = (load_idx < num_cyclic - 1) ?
(load_idx + 1) : 0; + if (load_idx != 0) { + matA_local_ld_payload.template update_tdesc( + wg_size_y * matA_t::tile_size_y); + matB_local_ld_payload.template update_tdesc( + wg_size_x * matB_t::tile_size_y); + } else { + matA_local_ld_payload.template update_tdesc( + (1 - num_cyclic) * wg_size_y * matA_t::tile_size_y); + matB_local_ld_payload.template update_tdesc( + (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y); + } + + SW_BARRIER(); + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::elemwise_cvt(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + SW_BARRIER(); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + SW_BARRIER(); + + tile_load(partial_matA, matA_payload); + tile_load(partial_matB, matB_payload); + matA_payload.template update_tdesc(matA_t::tile_size_x); + matB_payload.template update_tdesc(matB_t::tile_size_y); + + if (store_idx != 0) { + matA_local_st_payload.template update_tdesc( + wg_size_y * matA_t::tile_size_y); + matB_local_st_payload.template update_tdesc( + wg_size_x * matB_t::tile_size_y); + } else { + matA_local_st_payload.template update_tdesc( + (1 - num_cyclic) * wg_size_y * matA_t::tile_size_y); + matB_local_st_payload.template update_tdesc( + (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y); + } + + if constexpr (stages != 0) { + prefetch_and_update_ab(); + } + + tile_store(partial_matA, matA_local_st_payload); + tile_store(partial_matB, matB_local_st_payload); + store_idx = (store_idx < num_cyclic - 1) ? (store_idx + 1) : 0; + + xetla_fence(); + sync_arrive_wait(); + } + + SW_BARRIER(); + tile_load(matA, matA_local_ld_payload); + tile_load(matB, matB_local_ld_payload); + SW_BARRIER(); + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::elemwise_cvt(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + SW_BARRIER(); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + SW_BARRIER(); + } + + private: + /// @brief Updates tile base descriptor based on the tid. 
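+  /// As a minimal illustration (values assumed): with wg_size_x = 4,
+  /// sg_tile_m = 32 and sg_tile_n = 64, the subgroup with g.get_id() == 5
+  /// decomposes to sg_idx = 5 % 4 = 1 and sg_idy = 5 / 4 = 1, so matA is
+  /// offset by tile_offset_m = 32 rows and matB by tile_offset_n = 64
+  /// columns before the cooperative-load offsets are added.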
+ __XETLA_API static void update_sg_tile_tdesc( + arguments_t& args, + int32_t sg_idx, + int32_t sg_idy) { + int32_t tile_offset_n = sg_idx * sg_tile_n; + int32_t tile_offset_m = sg_idy * sg_tile_m; + + args.matA_base_desc.update_coord_y( + tile_offset_m + cooperative_helper_A_t::get_offset_y(sg_idx)); + args.matA_base_desc.update_coord_x( + cooperative_helper_A_t::get_offset_x(sg_idx)); + args.matB_base_desc.update_coord_x( + tile_offset_n + cooperative_helper_B_t::get_offset_x(sg_idy)); + args.matB_base_desc.update_coord_y( + cooperative_helper_B_t::get_offset_y(sg_idy)); + } +}; + +/// @} xetla_gemm + +} // namespace gpu::xetla::group diff --git a/include/group/gemm/impl/unaligned_xmx_xe.hpp b/include/group/gemm/impl/unaligned_xmx_xe.hpp index 19e4c89b3..38d3ec837 100644 --- a/include/group/gemm/impl/unaligned_xmx_xe.hpp +++ b/include/group/gemm/impl/unaligned_xmx_xe.hpp @@ -52,21 +52,32 @@ class gemm_t< using compute_policy = compute_policy_unaligned_xmx; - static constexpr uint32_t num_cyclic = 3; + static constexpr uint32_t num_cyclic = 2; static constexpr uint32_t k_stride = compute_policy::k_stride; static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y; static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; using work_group_t = typename tile_shape::work_group_t; constexpr static gpu_arch arch_tag = arch_tag_; static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; - static constexpr bool is_col_major_a = mem_layout_a == mem_layout::col_major; - static constexpr bool is_col_major_b = mem_layout_b == mem_layout::col_major; + static constexpr bool is_col_major_a = + (mem_layout_a == mem_layout::col_major); + static constexpr bool is_col_major_b = + (mem_layout_b == mem_layout::col_major); + + using load_store_attr = load_store_attr_t; + static constexpr bool lda_align = + (mem_desc_a_t::alignment_in_bytes % load_store_attr::alignment_in_bytes == + 0); + static constexpr bool ldb_align = + (mem_desc_b_t::alignment_in_bytes % load_store_attr::alignment_in_bytes == + 0); private: /******** set data type **********/ @@ -76,7 +87,7 @@ class gemm_t< using dtype_mma_a = typename compute_policy::dtype_mma_a; using dtype_mma_b = typename compute_policy::dtype_mma_b; - using check_dtype = group::gemm::default_xmx:: + using check_dtype = group::gemm::default_xmx:: template check_dtype_default; /******** set memory attribute **********/ @@ -91,7 +102,7 @@ class gemm_t< is_col_major_b ? 
tdesc_update_dir::x_dir : tdesc_update_dir::y_dir; using check_memory = - group::gemm::default_xmx::template check_memory_default< + group::gemm::default_xmx::template check_memory_default< mem_layout_a, mem_layout_b, mem_space_a, @@ -116,7 +127,7 @@ class gemm_t< static constexpr uint32_t block_size_y_b = compute_policy::block_size_y_b; using check_tile_size = - group::gemm::default_xmx::template check_tile_size_default< + group::gemm::default_xmx::template check_tile_size_default< dtype_mma_a, tile_size_x_a, tile_size_y_a, @@ -129,6 +140,11 @@ class gemm_t< /******** set tile **********/ static constexpr reg_layout reg_layout_a = reg_layout::tiled; + + [[maybe_unused]] xetla_nbarrier_t barrier_all; + [[maybe_unused]] xetla_nbarrier_t nbarrier_a; + [[maybe_unused]] xetla_nbarrier_t nbarrier_b; + using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, tile_size_y_a, @@ -142,14 +158,16 @@ class gemm_t< matA_t, mem_layout_a, tile_shape::wg_size_x, - gpu_arch::XeHpc>; + arch_tag>; using cooperative_tile_desc_A_t = typename cooperative_helper_A_t::co_tile_desc_t; using partial_matA_t = subgroup::tile_t; using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, cooperative_tile_desc_A_t, - is_local_a ? msg_type::scatter : msg_type::unaligned_2d, + is_local_a ? msg_type::scatter + : lda_align ? msg_type::block_2d + : msg_type::unaligned_2d, arch_tag>; using matA_payload_local_st_t = subgroup::mem_payload_t< @@ -169,6 +187,7 @@ class gemm_t< subgroup::tile_desc_t, wg_size_x, arch_tag>; + static constexpr reg_layout reg_layout_b = sizeof(dtype_b) < sizeof(uint32_t) ? reg_layout::vnni_tiled : reg_layout::tiled; @@ -184,7 +203,7 @@ class gemm_t< matB_t, mem_layout_b, tile_shape::wg_size_y, - gpu_arch::XeHpc>; + arch_tag>; using cooperative_tile_desc_B_t = typename cooperative_helper_B_t::co_tile_desc_t; @@ -193,7 +212,9 @@ class gemm_t< using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, cooperative_tile_desc_B_t, - is_local_b ? msg_type::scatter : msg_type::unaligned_2d, + is_local_b ? msg_type::scatter + : ldb_align ? msg_type::block_2d + : msg_type::unaligned_2d, arch_tag>; using matB_payload_local_st_t = subgroup::mem_payload_t< @@ -242,15 +263,26 @@ class gemm_t< static constexpr uint32_t slm_size_b = wg_size_x * tile_size_b; public: - static constexpr uint32_t barrier_count = barrier_count_x + barrier_count_y; + static constexpr uint32_t barrier_count = + arch_has_named_barrier ? barrier_count_x + barrier_count_y : 0; static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic; + static_assert(slm_size <= arch_attr_t::local_mem_size); static constexpr uint32_t slm_base_a = 0; - static constexpr uint32_t slm_base_b = 0 + slm_size_a * num_cyclic; + static constexpr uint32_t slm_base_b = slm_size_a * num_cyclic; static constexpr msg_type msg_type_a = matA_payload_t::message_type; static constexpr msg_type msg_type_b = matB_payload_t::message_type; + matA_payload_t matA_payload; + matA_payload_local_st_t matA_local_st_payload; + matA_payload_local_ld_t matA_local_ld_payload; + matA_prefetch_payload_t matA_prefetch_payload; + matB_payload_t matB_payload; + matB_payload_local_st_t matB_local_st_payload; + matB_payload_local_ld_t matB_local_ld_payload; + matB_prefetch_payload_t matB_prefetch_payload; + using pre_processing_arg_t = typename pre_processing_t::arguments_t; /// @brief Arguments for gemm. 
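For reference, the cyclic SLM ring implied by the constants in the hunk above can be sketched as follows; the byte sizes and the 128 KB local-memory bound are assumed values for illustration only, not taken from the patch:

#include <cstdint>
// num_cyclic slots per input share SLM: all A slots first, then all B slots.
constexpr uint32_t num_cyclic = 2;                       // ring depth after this change
constexpr uint32_t slm_size_a = 4 * 1024;                // wg_size_y * tile_size_a (assumed)
constexpr uint32_t slm_size_b = 8 * 1024;                // wg_size_x * tile_size_b (assumed)
constexpr uint32_t slm_base_a = 0;                       // A ring occupies [0, slm_size_a * num_cyclic)
constexpr uint32_t slm_base_b = slm_size_a * num_cyclic; // B ring starts after all A slots
constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic;
static_assert(slm_size <= 128 * 1024, "must fit local memory (assumed 128 KB)");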
@@ -327,6 +359,77 @@ class gemm_t< } }; + inline void sync_init( + [[maybe_unused]] int32_t sg_idx, + [[maybe_unused]] int32_t sg_idy, + uint32_t nbarrier_base) { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.init_nbarrier( + sg_idy + nbarrier_base, nbarrier_role::producer_consumer); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.init_nbarrier( + sg_idx + barrier_count_y + nbarrier_base, + nbarrier_role::producer_consumer); + } + } else if constexpr (wg_size > 1) { + barrier_all.init_nbarrier( + nbarrier_base, nbarrier_role::producer_consumer); + } + } + + inline void sync_arrive() { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.arrive(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.arrive(); + } + } else if constexpr (wg_size > 1) { + barrier_all.arrive(); + } + } + + inline void sync_wait() { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.wait(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.wait(); + } + } else if constexpr (wg_size > 1) { + barrier_all.wait(); + } + } + + inline void sync_arrive_wait() { + if constexpr (arch_has_named_barrier) { + if constexpr (wg_size_x > 1) { + nbarrier_a.arrive_wait(); + } + if constexpr (wg_size_y > 1) { + nbarrier_b.arrive_wait(); + } + } else if constexpr (wg_size > 1) { + barrier_all.arrive_wait(); + } + } + + inline void prefetch_and_update_ab() { + subgroup::tile_prefetch( + matA_prefetch_payload); + subgroup::tile_prefetch( + matB_prefetch_payload); + SW_BARRIER(); + matA_prefetch_payload.template update_tdesc( + matA_t::tile_size_x); + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + } + /// @brief Gets the subgroup-level tile offset x. /// @param g Is the workgroup of the current tile. /// @return Subgroup-level tile offset x. 
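Per the XETLA_MARKER on release() (its text is kept as context in the hunk below), a caller must budget one barrier id beyond the gemm's own count. A hedged caller-side sketch, reusing the gemm_t type from this file; the instance name is assumed:

// gemm_t::barrier_count ids are consumed by the gemm itself; pass the
// next free id to release() so barrier id 0 is not silently reused.
uint8_t free_id = static_cast<uint8_t>(gemm_t::barrier_count);
gemm_op.release(free_id); // gemm_op: an instance of gemm_t (assumed)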
@@ -350,18 +453,16 @@ class gemm_t<
       "If you call this function, either pass a free barrier id or make "
       "sure barrier_id 0 is not occupied, and allocate one more barrier "
       "count in addition to the gemm barrier counts.")
-  __XETLA_API static void release(uint8_t nbarrier_id = 0) {
+  __XETLA_API void release(uint8_t nbarrier_id = 0) {
     static constexpr bool need_local_fence =
         (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local);
     if constexpr (need_local_fence) {
       xetla_fence();
     }
     xetla_fence();
-    static constexpr uint32_t wg_size = wg_size_x * wg_size_y;
     if constexpr (wg_size > 1) {
-      xetla_nbarrier_t nbarrier;
-      nbarrier.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer);
-      nbarrier.arrive_wait();
+      barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer);
+      barrier_all.arrive_wait();
     }
   }
@@ -376,134 +477,113 @@ class gemm_t<
       work_group_t& g,
       matAcc_t& matAcc,
       arguments_t args,
-      [[maybe_unused]] uint32_t slm_base = 0,
+      uint32_t slm_base = 0,
       uint32_t nbarrier_base = 0) {
     int32_t sg_idx = g.get_id() % wg_size_x;
     int32_t sg_idy = g.get_id() / wg_size_x;
     XETLA_ASSERT(
-        g.get_id() < (wg_size_x * wg_size_y),
+        g.get_id() < wg_size,
         "Thread id(%d) should be less than wg_size(%d)",
         g.get_id(),
-        wg_size_x * wg_size_y);
+        wg_size);
     update_sg_tile_tdesc(args, sg_idx, sg_idy);
     pre_processing_t pre_processing;
     matA_t matA;
     matB_t matB;
+    matA_acc_t matA_acc;
+    matB_acc_t matB_acc;
     partial_matA_t partial_matA;
     partial_matB_t partial_matB;
     // >>>>>>>>>>>>>>>>>> pre_processing init
     pre_processing.init(g, args.pre_processing_args);
-    uint32_t base_A = slm_base_a + sg_idy * tile_size_a;
-    uint32_t base_B = slm_base_b + sg_idx * tile_size_b;
-
-    uint32_t store_idx = 0;
-    uint32_t load_idx = 0;
-    matA_payload_t matA_payload(args.matA_base_desc);
-    matA_payload_local_st_t matA_local_st_payload(
+    uint32_t base_A = slm_base + slm_base_a + sg_idy * tile_size_a;
+    matA_payload.init(args.matA_base_desc);
+    matA_local_st_payload.init(
         base_A,
         tile_size_x_a,
         tile_size_y_a,
         tile_size_x_a,
         cooperative_helper_A_t::get_offset_x(sg_idx),
         cooperative_helper_A_t::get_offset_y(sg_idx));
-    matA_payload_local_ld_t matA_local_ld_payload(
+    matA_local_ld_payload.init(
         base_A, tile_size_x_a, tile_size_y_a, tile_size_x_a, 0, 0);
+    matA_prefetch_payload.init(args.matA_base_desc, sg_idx);
-    matB_payload_t matB_payload(args.matB_base_desc);
-    matB_payload_local_st_t matB_local_st_payload(
+    uint32_t base_B = slm_base + slm_base_b + sg_idx * tile_size_b;
+    matB_payload.init(args.matB_base_desc);
+    matB_local_st_payload.init(
         base_B,
         tile_size_x_b,
         tile_size_y_b,
         tile_size_x_b,
         cooperative_helper_B_t::get_offset_x(sg_idy),
         cooperative_helper_B_t::get_offset_y(sg_idy));
-    matB_payload_local_ld_t matB_local_ld_payload(
+    matB_local_ld_payload.init(
         base_B, tile_size_x_b, tile_size_y_b, tile_size_x_b, 0, 0);
+    matB_prefetch_payload.init(args.matB_base_desc, sg_idy);
-    matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, sg_idx);
-    matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, sg_idy);
+    sync_init(sg_idx, sg_idy, nbarrier_base);
-    xetla_nbarrier_t nbarrier_a;
-    nbarrier_a.init_nbarrier(
-        sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
-    xetla_nbarrier_t nbarrier_b;
-    nbarrier_b.init_nbarrier(
-        sg_idx + barrier_count_y + nbarrier_base,
-        nbarrier_role::producer_consumer);
+    uint32_t store_idx = 0;
+    uint32_t load_idx = 0;
     tile_load(partial_matA, matA_payload);
     tile_load(partial_matB, matB_payload);
+    SW_BARRIER();
+
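+    // The loads above stage the first A/B chunk in registers; advance the
+    // global descriptors one k-step before spilling the chunk into SLM.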
matA_payload.template update_tdesc(matA_t::tile_size_x); + matB_payload.template update_tdesc(matB_t::tile_size_y); tile_store(partial_matA, matA_local_st_payload); tile_store(partial_matB, matB_local_st_payload); - store_idx++; - matA_payload.template update_tdesc(matA_t::tile_size_x); - matB_payload.template update_tdesc(matB_t::tile_size_y); - xetla_fence(); - nbarrier_a.arrive(); - if (arch_tag >= gpu_arch::XeHpc) - nbarrier_b.arrive(); #pragma unroll for (uint32_t i = 1; i < num_cyclic - 1; i++) { tile_load(partial_matA, matA_payload); tile_load(partial_matB, matB_payload); - + SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); matA_local_st_payload.template update_tdesc( wg_size_y * matA_t::tile_size_y); + SW_BARRIER(); + tile_store(partial_matA, matA_local_st_payload); matB_local_st_payload.template update_tdesc( wg_size_x * matB_t::tile_size_y); - - tile_store(partial_matA, matA_local_st_payload); + SW_BARRIER(); tile_store(partial_matB, matB_local_st_payload); store_idx++; } + xetla_fence(); + sync_arrive_wait(); + matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x * (num_cyclic - 1)); + (num_cyclic - 1) * matA_t::tile_size_x); matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y * (num_cyclic - 1)); + (num_cyclic - 1) * matB_t::tile_size_y); + #pragma unroll for (uint32_t i = 0; i < stages; i++) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } - for (uint32_t i = 0; i < args.inner_loop_count; i++) { - tile_load(partial_matA, matA_payload); - tile_load(partial_matB, matB_payload); - - matA_payload.template update_tdesc(matA_t::tile_size_x); - matB_payload.template update_tdesc(matB_t::tile_size_y); - - if constexpr (stages != 0) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - } - - nbarrier_a.wait(); - if (arch_tag >= gpu_arch::XeHpc) - nbarrier_b.wait(); - + for (uint32_t i = 0; i < args.inner_loop_count - 1; i++) { tile_load(matA, matA_local_ld_payload); tile_load(matB, matB_local_ld_payload); - load_idx = (load_idx < num_cyclic - 1) ? (load_idx + 1) : 0; + SW_BARRIER(); + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::vnni_transform(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + SW_BARRIER(); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + SW_BARRIER(); + load_idx = (load_idx < num_cyclic - 1) ? 
(load_idx + 1) : 0; if (load_idx != 0) { matA_local_ld_payload.template update_tdesc( wg_size_y * matA_t::tile_size_y); @@ -515,28 +595,18 @@ class gemm_t< matB_local_ld_payload.template update_tdesc( (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y); } - xetla_fence(); + + tile_load(partial_matA, matA_payload); + tile_load(partial_matB, matB_payload); + SW_BARRIER(); + matA_payload.template update_tdesc(matA_t::tile_size_x); + matB_payload.template update_tdesc(matB_t::tile_size_y); if constexpr (stages != 0) { - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } - nbarrier_a.arrive(); - if (arch_tag >= gpu_arch::XeHpc) - nbarrier_b.arrive(); - SW_BARRIER(); - matA_acc_t matA_acc; - matB_acc_t matB_acc; - subgroup::elemwise_cvt(matA_acc, matA); - subgroup::vnni_transform(matB_acc, matB); - pre_processing(matA_acc, matB_acc, matA, matB); - SW_BARRIER(); - tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); - SW_BARRIER(); - + store_idx = (store_idx < num_cyclic - 1) ? (store_idx + 1) : 0; if (store_idx != 0) { matA_local_st_payload.template update_tdesc( wg_size_y * matA_t::tile_size_y); @@ -548,15 +618,23 @@ class gemm_t< matB_local_st_payload.template update_tdesc( (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y); } - tile_store(partial_matA, matA_local_st_payload); tile_store(partial_matB, matB_local_st_payload); - store_idx = (store_idx < num_cyclic - 1) ? (store_idx + 1) : 0; + + xetla_fence(); + sync_arrive_wait(); } + + SW_BARRIER(); + tile_load(matA, matA_local_ld_payload); + tile_load(matB, matB_local_ld_payload); + SW_BARRIER(); + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::vnni_transform(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + SW_BARRIER(); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); SW_BARRIER(); - nbarrier_a.wait(); - if (arch_tag >= gpu_arch::XeHpc) - nbarrier_b.wait(); } private: diff --git a/include/group/global_reduction.hpp b/include/group/global_reduction.hpp index 0fe83d6a8..8289ab42a 100644 --- a/include/group/global_reduction.hpp +++ b/include/group/global_reduction.hpp @@ -64,7 +64,7 @@ class global_reduce_t< num_group_reduction, counter_size, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape_acc = tile_shape_acc_; @@ -225,7 +225,7 @@ class global_reduce_t< 1, counter_size_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using tile_shape_acc = tile_shape_acc_; diff --git a/include/group/reduction/reduction.hpp b/include/group/reduction/reduction.hpp index 2f8c34586..cde184373 100644 --- a/include/group/reduction/reduction.hpp +++ b/include/group/reduction/reduction.hpp @@ -19,5 +19,4 @@ #pragma once -#include #include diff --git a/include/group/reduction/reduction_api.hpp b/include/group/reduction/reduction_api.hpp index 4b0d1865b..6d76baa21 100644 --- a/include/group/reduction/reduction_api.hpp +++ b/include/group/reduction/reduction_api.hpp @@ -40,8 +40,8 @@ template < uint32_t N, reduce_op Op, uint32_t N_SG, - bool is_all_reduce = true, - gpu_arch arch_ = gpu_arch::XeHpc> + bool is_all_reduce, + gpu_arch arch_> struct group_reduce_t {}; } // namespace gpu::xetla::group diff --git a/include/group/reduction/reduction_xe.hpp b/include/group/reduction/reduction_xe.hpp index d07873551..479d7c50e 
100644 --- a/include/group/reduction/reduction_xe.hpp +++ b/include/group/reduction/reduction_xe.hpp @@ -19,8 +19,6 @@ #pragma once -#include - namespace gpu::xetla::group { template < @@ -29,10 +27,12 @@ template < uint32_t N, reduce_op Op, uint32_t N_SG, - bool is_all_reduce> -struct group_reduce_t { - group_reduce_t sg_reduce{}; - xetla_nbarrier_t nbarrier; + bool is_all_reduce, + gpu_arch arch_tag_> +struct group_reduce_t { + static constexpr gpu_arch arch_tag = arch_tag_; + group_reduce_t sg_reduce{}; + xetla_nbarrier_t nbarrier; uint32_t slm_base; uint32_t sg_id; using local_st_tile_desc = @@ -46,14 +46,14 @@ struct group_reduce_t { mem_desc_ld_t, local_ld_tile_desc, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; using mem_desc_st_t = mem_desc_t; using local_st_payload_t = subgroup::mem_payload_t< mem_desc_st_t, local_st_tile_desc, // subgroup::msg_type_v, msg_type::block_1d, - gpu_arch::XeHpc>; + arch_tag>; inline group_reduce_t() = default; inline group_reduce_t( uint32_t sg_id_, @@ -105,8 +105,14 @@ struct group_reduce_t { } }; -template -struct group_reduce_t { +template < + typename T, + uint32_t SZ, + uint32_t N, + reduce_op Op, + bool is_all_reduce, + gpu_arch arch_tag_> +struct group_reduce_t { inline group_reduce_t() = default; inline group_reduce_t( [[maybe_unused]] uint32_t sg_id_, @@ -119,6 +125,8 @@ struct group_reduce_t { inline void set_slm_base([[maybe_unused]] uint32_t slm_base_ = 0) {} inline KERNEL_FUNC xetla_vector operator()( xetla_vector buffer) { + if constexpr (SZ == 1) + return buffer; auto buffer_2d = buffer.xetla_format(); xetla_vector ret; #pragma unroll diff --git a/include/group/softmax/impl/softmax_bwd_xe.hpp b/include/group/softmax/impl/softmax_bwd_xe.hpp index 84d41643e..4f03cf8cc 100644 --- a/include/group/softmax/impl/softmax_bwd_xe.hpp +++ b/include/group/softmax/impl/softmax_bwd_xe.hpp @@ -26,15 +26,19 @@ namespace gpu::xetla::group { -template +template < + typename dtype_in_, + typename dtype_acc_, + typename tile_shape_, + gpu_arch arch_tag_> class softmax_t< - softmax_policy_bwd, + softmax_policy_bwd, tile_shape_> { public: using tile_shape = tile_shape_; using dtype_in = dtype_in_; using dtype_acc = dtype_acc_; - static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + static constexpr gpu_arch arch_tag = arch_tag_; private: using mem_desc_in_t = @@ -55,7 +59,7 @@ class softmax_t< reduce_op::sum, wg_size_x, true, - gpu_arch::XeHpc>; + arch_tag>; public: struct arguments_t { @@ -106,7 +110,7 @@ class softmax_t< mem_desc_in_t, mat_in_tile_desc_t, subgroup::msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; int32_t sg_idx = g.get_id() % wg_size_x; int32_t sg_idy = g.get_id() / wg_size_x; diff --git a/include/group/softmax/impl/softmax_fwd_xe.hpp b/include/group/softmax/impl/softmax_fwd_xe.hpp index c5c175855..0857317bd 100644 --- a/include/group/softmax/impl/softmax_fwd_xe.hpp +++ b/include/group/softmax/impl/softmax_fwd_xe.hpp @@ -26,12 +26,12 @@ namespace gpu::xetla::group { -template -class softmax_t, tile_shape_> { +template +class softmax_t, tile_shape_> { public: using tile_shape = tile_shape_; using dtype_acc = dtype_acc_; - static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + static constexpr gpu_arch arch_tag = arch_tag_; struct arguments_t { dtype_acc sqrt_dk_inv; inline arguments_t() = default; @@ -53,7 +53,7 @@ class softmax_t, tile_shape_> { reduce_op::max, wg_size_x, true, - gpu_arch::XeHpc>; + arch_tag>; using wg_reduce_sum_t = group_reduce_t< dtype_acc, 1, @@ -61,7 +61,7 @@ class softmax_t, tile_shape_> { reduce_op::sum, 
wg_size_x, true, - gpu_arch::XeHpc>; + arch_tag>; public: struct get_barrier_count { diff --git a/include/group/softmax/softmax_policy.hpp b/include/group/softmax/softmax_policy.hpp index db3123523..278e2879d 100644 --- a/include/group/softmax/softmax_policy.hpp +++ b/include/group/softmax/softmax_policy.hpp @@ -23,13 +23,10 @@ namespace gpu::xetla::group { -template +template struct softmax_policy_fwd {}; -template < - typename dtype_in, - typename dtype_acc, - gpu_arch arch_tag_ = gpu_arch::XeHpc> +template struct softmax_policy_bwd {}; } // namespace gpu::xetla::group diff --git a/include/group/tile_shape.hpp b/include/group/tile_shape.hpp index 1adf68da4..79d03b10c 100644 --- a/include/group/tile_shape.hpp +++ b/include/group/tile_shape.hpp @@ -35,15 +35,18 @@ template < uint32_t sg_tile_size_x_, uint32_t sg_tile_size_y_> struct tile_shape_t { - static constexpr uint32_t wg_tile_size_x = wg_tile_size_x_; - static constexpr uint32_t wg_tile_size_y = wg_tile_size_y_; static constexpr uint32_t sg_tile_size_x = sg_tile_size_x_; static constexpr uint32_t sg_tile_size_y = sg_tile_size_y_; - static constexpr uint32_t wg_size_x = - (wg_tile_size_x + sg_tile_size_x - 1) / sg_tile_size_x; + (wg_tile_size_x_ + sg_tile_size_x - 1) / sg_tile_size_x; static constexpr uint32_t wg_size_y = - (wg_tile_size_y + sg_tile_size_y - 1) / sg_tile_size_y; + (wg_tile_size_y_ + sg_tile_size_y - 1) / sg_tile_size_y; + + static constexpr uint32_t wg_tile_size_x = wg_size_x * sg_tile_size_x; + static constexpr uint32_t wg_tile_size_y = wg_size_y * sg_tile_size_y; + + static_assert(wg_tile_size_x % sg_tile_size_x == 0); + static_assert(wg_tile_size_y % sg_tile_size_y == 0); using work_group_t = work_group_t; }; diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 0635aaec6..c5349d38e 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -37,7 +37,7 @@ template < mem_layout mem_layout_c, uint32_t alignment_c, typename dtype_acc, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> struct default_gemm_config_t : param_adaptor< @@ -69,7 +69,7 @@ template < mem_layout mem_layout_c, uint32_t alignment_c, typename dtype_acc, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> using default_gemm_t = typename default_gemm_config_t< dtype_a, @@ -93,11 +93,11 @@ struct param_optimizer { tune_key::param_optimizer_type> != dict_t_::impl::key_not_found) && (dict_t_::template find_elem_v == tune_key_value::param_optimizer_decision_tree); + static_assert( + dict_t_::impl::template find_elem_index != + dict_t_::impl::key_not_found); static constexpr auto arch_tag = - (dict_t_::impl::template find_elem_index != - dict_t_::impl::key_not_found) - ? 
dict_t_::template find_elem_v - : gpu_arch::XeHpc; + dict_t_::template find_elem_v; static constexpr auto optimizer_level = dict_t_::template find_elem_v; using type = typename std::conditional< @@ -170,7 +170,7 @@ template < typename dtype_acc, typename wg_shape, uint32_t wg_tile_k, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> struct default_gemm_selector_config_t : param_adaptor< @@ -204,7 +204,7 @@ template < typename dtype_acc, typename wg_shape, uint32_t wg_tile_k, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> using default_gemm_selector_t = typename default_gemm_selector_config_t< dtype_a, @@ -228,7 +228,7 @@ template < mem_space mem_space_c, typename wg_shape, uint32_t wg_tile_k, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> struct default_epilogue_selector_config_t : param_adaptor< @@ -252,7 +252,7 @@ template < mem_space mem_space_c, typename wg_shape, uint32_t wg_tile_k, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, typename tune_option = dict_t<>> using default_epilogue_selector_t = typename default_epilogue_selector_config_t< dtype_c, @@ -274,11 +274,11 @@ struct param_optimizer { tune_key_value::param_optimizer_decision_tree); static constexpr auto optimizer_level = dict_t_::template find_elem_v; + static_assert( + dict_t_::impl::template find_elem_index != + dict_t_::impl::key_not_found); static constexpr auto arch_tag = - (dict_t_::impl::template find_elem_index != - dict_t_::impl::key_not_found) - ? dict_t_::template find_elem_v - : gpu_arch::XeHpc; + dict_t_::template find_elem_v; using type = typename std::conditional< use_rule, decision_tree_optimizer< @@ -314,6 +314,12 @@ struct param_adaptor static constexpr auto mem_alignment_b = param::template find_elem_v; + using ld_align_attr = group::check_block_2d_pitch_alignment< + dtype_a, + dtype_b, + mem_alignment_a, + mem_alignment_b, + base_t::gpu_arch_tag>; using compute_attr = group::compute_attr_t; @@ -327,12 +333,7 @@ struct param_adaptor elem_t_t< mma_engine::xmx, typename std::conditional< - (group::detail::check_2d_block_pitch_alignment< - dtype_a, - dtype_b, - mem_alignment_a, - mem_alignment_b, - base_t::gpu_arch_tag>::value), + (ld_align_attr::value), group::compute_policy_default_xmx< compute_attr, perf_tuning_knob, @@ -344,17 +345,16 @@ struct param_adaptor elem_t_t< mma_engine::fpu, typename std::conditional< - (group::detail::check_2d_block_pitch_alignment< - dtype_a, - dtype_b, - mem_alignment_a, - mem_alignment_b, - base_t::gpu_arch_tag>::value), + (ld_align_attr::value), group::compute_policy_default_fpu< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag>, - void>::type>>::template find_elem_t::type; + group::compute_policy_unaligned_fpu< + compute_attr, + perf_tuning_knob, + base_t::gpu_arch_tag>>::type>>:: + template find_elem_t::type; using mem_desc_input_a = mem_desc_t; diff --git a/include/kernel/gemm/dispatch_policy.hpp b/include/kernel/gemm/dispatch_policy.hpp index 28eaf7840..6432082ec 100644 --- a/include/kernel/gemm/dispatch_policy.hpp +++ b/include/kernel/gemm/dispatch_policy.hpp @@ -148,7 +148,7 @@ struct dispatch_policy_kslicing { /// and performs inter-group reduction. Implementation loosely based on this /// paper - https://arxiv.org/pdf/2301.03598.pdf /// @tparam arch_tag_ Is the HW architecture. 
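/// A hedged instantiation sketch (the architecture value is illustrative):
///   using stream_k_policy = dispatch_policy_stream_k<gpu_arch::XeHpc>;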
-template +template struct dispatch_policy_stream_k { static constexpr gpu_arch arch_tag = arch_tag_; diff --git a/include/kernel/gemm/gemm_preset.hpp b/include/kernel/gemm/gemm_preset.hpp index f24af3dba..585f46924 100644 --- a/include/kernel/gemm/gemm_preset.hpp +++ b/include/kernel/gemm/gemm_preset.hpp @@ -46,7 +46,7 @@ using param_performance_default = dict_t< elem_v_t, elem_v_t>; -template +template using param_runtime_default = dict_t< elem_v_t, elem_v_t, @@ -61,7 +61,8 @@ using param_runtime_default = dict_t< tune_key::group_swizzle_policy, kernel::group_swizzle_default>>; } // namespace detail -template + +template using default_param_t = dict_t<>::template update_dict_t< detail::param_dtype_bf16_bf16_bf16>:: template update_dict_t::template update_dict_t< @@ -88,7 +89,7 @@ using default_param_t = dict_t<>::template update_dict_t< param_optimizer_level>>; namespace kernel { -template +template using param_kslicing_g1l1_t = default_param_t::template update_t< elem_v_t, elem_v_t, @@ -99,7 +100,7 @@ using param_kslicing_g1l1_t = default_param_t::template update_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing>>; -template +template using param_kslicing_g2l1_t = default_param_t::template update_t< elem_v_t, elem_v_t, @@ -110,7 +111,7 @@ using param_kslicing_g2l1_t = default_param_t::template update_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing>>; -template +template using param_kslicing_g1l2_t = default_param_t::template update_t< elem_v_t, elem_v_t, @@ -124,7 +125,7 @@ using param_kslicing_g1l2_t = default_param_t::template update_t< } // namespace kernel namespace group { -template +template using param_dict1_wg_t = default_param_t::template update_t< elem_t_t, elem_t_t>, diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index 5d9ed02c7..6bec4a295 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -37,7 +37,7 @@ class gemm_universal_t< dispatch_policy_default, gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index ecd2afb38..e07f0b8ef 100644 --- a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -48,7 +48,7 @@ class gemm_universal_t< num_local_kslicing_>, gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; diff --git a/include/subgroup/cooperative_load_helper.hpp b/include/subgroup/cooperative_load_helper.hpp index ab69e5bc1..79ddc69b9 100644 --- a/include/subgroup/cooperative_load_helper.hpp +++ b/include/subgroup/cooperative_load_helper.hpp @@ -46,7 +46,7 @@ class cooperative_load_helper_t< mem_layout::row_major, num_cooperative_wg, arch_tag_, - std::enable_if_t> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using matAcc_t = matAcc_t_; @@ -112,7 +112,7 @@ class cooperative_load_helper_t< mem_layout::col_major, num_cooperative_wg, arch_tag_, - std::enable_if_t> { + std::enable_if_t>> { public: static constexpr gpu_arch arch_tag = arch_tag_; using matAcc_t = matAcc_t_; diff --git a/include/subgroup/tile/api.hpp 
b/include/subgroup/tile/api.hpp index 32019bac1..1a428ad3d 100644 --- a/include/subgroup/tile/api.hpp +++ b/include/subgroup/tile/api.hpp @@ -82,8 +82,10 @@ struct tile_desc_t { static constexpr uint32_t tile_size_x = tile_size_x_; static constexpr uint32_t tile_size_y = tile_size_y_; - static constexpr uint32_t block_size_x = block_size_x_; - static constexpr uint32_t block_size_y = block_size_y_; + static constexpr uint32_t block_size_x = + (tile_size_x > block_size_x_) ? block_size_x_ : tile_size_x; + static constexpr uint32_t block_size_y = + (tile_size_y > block_size_y_) ? block_size_y_ : tile_size_y; static constexpr uint32_t remained_size_y = tile_size_y % block_size_y; static constexpr reg_layout register_layout = reg_layout_; diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 0a30a1416..a0ef642b5 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -336,17 +336,20 @@ template < reg_layout reg_layout_ = reg_layout::tiled> struct get_load_block_size_auto {}; -template +template < + typename dtype, + uint32_t tile_size_x, + uint32_t tile_size_y, + gpu_arch arch_tag> struct get_load_block_size_auto< dtype, tile_size_x, tile_size_y, - gpu_arch::XeHpc, + arch_tag, mem_layout::row_major, reg_layout::tiled> { private: - using load_store_attr = arch_attr_t< - gpu_arch::XeHpc>::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_height_in_elem = load_store_attr::max_load_height_in_elem; static constexpr uint32_t max_load_width_in_bytes = @@ -373,17 +376,20 @@ template < reg_layout reg_layout_ = reg_layout::tiled> struct get_store_block_size_auto {}; -template +template < + typename dtype, + uint32_t tile_size_x, + uint32_t tile_size_y, + gpu_arch arch_tag> struct get_store_block_size_auto< dtype, tile_size_x, tile_size_y, - gpu_arch::XeHpc, + arch_tag, mem_layout::row_major, reg_layout::tiled> { private: - using load_store_attr = arch_attr_t< - gpu_arch::XeHpc>::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_store_height_in_elem = load_store_attr::max_store_height_in_elem; static constexpr uint32_t max_store_width_in_bytes = diff --git a/include/subgroup/tile/impl/blk_mma.hpp b/include/subgroup/tile/impl/blk_mma.hpp new file mode 100644 index 000000000..1bc945509 --- /dev/null +++ b/include/subgroup/tile/impl/blk_mma.hpp @@ -0,0 +1,406 @@ +/******************************************************************************* + * Copyright (c) 2022-2023 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ + +/// @file +/// C++ API + +#pragma once + +#include +#include + +namespace gpu::xetla::subgroup { + +template < + typename dtype_dst, + typename dtype_src, + typename dtype_b, + typename dtype_a, + int blk_m, + int blk_n, + int blk_k, + int mma_m, + reg_layout a_reg_layout, + reg_layout b_reg_layout, + mma_engine engine_tag> +struct blk_mma_t {}; + +template < + typename dtype_dst, + typename dtype_src, + typename dtype_b, + typename dtype_a, + int blk_m, + int blk_n, + int blk_k, + int mma_m> +struct blk_mma_t< + dtype_dst, + dtype_src, + dtype_b, + dtype_a, + blk_m, + blk_n, + blk_k, + mma_m, + reg_layout::transpose_tiled, + reg_layout::tiled, + mma_engine::fpu> { + static constexpr uint32_t blk_m_iters = blk_m / mma_m; + static constexpr uint32_t tail_m = blk_m % mma_m; + static constexpr uint32_t tail_start_m = blk_m_iters * mma_m; + static constexpr uint32_t a_block_elems = blk_m * blk_k; + + __XETLA_API static void mma_core( + xetla_vector_ref __REF__ dst, + xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { + auto dst_blk_2d = dst.xetla_format(); + auto src_blk_2d = src.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + xetla_vector new_a_block = + xetla_cvt( + a_block.xetla_select(0)); + + if constexpr (blk_m_iters > 0) { +#pragma unroll + for (uint32_t i = 0; i < blk_m_iters; i++) { + xetla_vector dst_tmp; + auto dst_tmp_2d = dst_tmp.xetla_format(); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + dst_tmp_2d.row(i_acc) = + new_a_block[i_acc + i * mma_m] * b_blk_2d.row(0) + + src_blk_2d.row(i_acc + i * mma_m); + } +#pragma unroll + for (uint32_t k = 1; k < blk_k - 1; k++) { + auto b_blk_k = b_blk_2d.row(k); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + int a_offset = k * blk_m + i_acc + i * mma_m; + dst_tmp_2d.row(i_acc) += new_a_block[a_offset] * b_blk_k; + } + } +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + int a_offset = (blk_k - 1) * blk_m + i_acc + i * mma_m; + dst_blk_2d.row(i_acc + i * mma_m) = + new_a_block[a_offset] * b_blk_2d.row(blk_k - 1) + + dst_tmp_2d.row(i_acc); + } + SW_BARRIER(); + } + } + + if constexpr (tail_m != 0) { + xetla_vector dst_tmp; + auto dst_tmp_2d = dst_tmp.xetla_format(); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + dst_tmp_2d.row(i_acc) = + new_a_block[i_acc + tail_start_m] * b_blk_2d.row(0) + + src_blk_2d.row(i_acc + tail_start_m); + } +#pragma unroll + for (uint32_t k = 1; k < blk_k - 1; k++) { + auto b_blk_k = b_blk_2d.row(k); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + int a_offset = k * blk_m + i_acc + tail_start_m; + dst_tmp_2d.row(i_acc) += new_a_block[a_offset] * b_blk_k; + } + } +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + int a_offset = (blk_k - 1) * blk_m + i_acc + tail_start_m; + dst_blk_2d.row(i_acc + tail_start_m) = + new_a_block[a_offset] * b_blk_2d.row(blk_k - 1) + + dst_tmp_2d.row(i_acc); + } + } + } +}; + +template < + typename dtype_dst, + typename dtype_src, + typename dtype_b, + typename dtype_a, + int blk_m, + int blk_n, + int blk_k, + int mma_m> +struct blk_mma_t< + dtype_dst, + dtype_src, + dtype_b, + dtype_a, + blk_m, + blk_n, + blk_k, + mma_m, + reg_layout::tiled, + reg_layout::tiled, + mma_engine::fpu> { + static constexpr uint32_t blk_m_iters = blk_m / mma_m; + static constexpr uint32_t tail_m = blk_m % mma_m; + static 
constexpr uint32_t tail_start_m = blk_m_iters * mma_m; + static constexpr uint32_t a_block_elems = blk_m * blk_k; + + __XETLA_API static void mma_core( + xetla_vector_ref __REF__ dst, + xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { + auto dst_blk_2d = dst.xetla_format(); + auto src_blk_2d = src.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + xetla_vector new_a_block = + xetla_cvt( + a_block.xetla_select(0)); + + if constexpr (blk_m_iters > 0) { +#pragma unroll + for (uint32_t i = 0; i < blk_m_iters; i++) { + auto b_blk_k0 = b_blk_2d.row(0); + int32_t a_start_off = i * mma_m * blk_k; +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + dst_blk_2d.row(i_acc + i * mma_m) = + src_blk_2d.row(i_acc + i * mma_m) + + new_a_block[a_start_off + i_acc * blk_k] * b_blk_k0; + } + +#pragma unroll + for (uint32_t k = 1; k < blk_k; k++) { + auto b_blk_k = b_blk_2d.row(k); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + dst_blk_2d.row(i_acc + i * mma_m) += + new_a_block[a_start_off + i_acc * blk_k + k] * b_blk_k; + } + } + SW_BARRIER(); + } + } + + if constexpr (tail_m != 0) { + auto b_blk_k0 = b_blk_2d.row(0); + int32_t a_start_off = tail_start_m * blk_k; +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + dst_blk_2d.row(i_acc + tail_start_m) = + src_blk_2d.row(i_acc + tail_start_m) + + new_a_block[a_start_off + i_acc * blk_k] * b_blk_k0; + } +#pragma unroll + for (uint32_t k = 1; k < blk_k; k++) { + auto b_blk_k = b_blk_2d.row(k); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + dst_blk_2d.row(i_acc + tail_start_m) += + new_a_block[a_start_off + i_acc * blk_k + k] * b_blk_k; + } + } + } + } +}; + +template < + typename dtype_dst, + typename dtype_src, + typename dtype_b, + typename dtype_a, + int blk_m, + int blk_n, + int blk_k, + int mma_m> +struct blk_mma_t< + dtype_dst, + dtype_src, + dtype_b, + dtype_a, + blk_m, + blk_n, + blk_k, + mma_m, + reg_layout::tiled, + reg_layout::transpose_tiled, + mma_engine::fpu> { + static constexpr uint32_t blk_m_iters = blk_m / mma_m; + static constexpr uint32_t tail_m = blk_m % mma_m; + static constexpr uint32_t a_block_elems = blk_m * blk_k; + static constexpr uint32_t tail_start_m = blk_m_iters * mma_m; + + __XETLA_API static void mma_core( + xetla_vector_ref __REF__ dst, + xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { + auto dst_blk_2d = dst.xetla_format(); + auto src_blk_2d = src.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + xetla_vector new_a_block = + xetla_cvt( + a_block.xetla_select(0)); + auto a_blk_2d = new_a_block.xetla_format(); + + if constexpr (blk_m_iters > 0) { +#pragma unroll + for (uint32_t i = 0; i < blk_m_iters; i++) { + auto i_start_m = i * mma_m; +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + auto moffset = i_acc + i_start_m; + auto dst_row = dst_blk_2d.row(moffset); + auto src_row = src_blk_2d.row(moffset); + auto a_blk_m = a_blk_2d.row(moffset); + xetla_vector tmp_k = a_blk_m * b_blk_2d.row(0); + dst_row[0] = src_row[0] + + xetla_reduce(tmp_k); +#pragma unroll + for (uint32_t j = 1; j < blk_n; j++) { + tmp_k = a_blk_m * b_blk_2d.row(j); + dst_row[j] = src_row[j] + + xetla_reduce( + tmp_k); + } + } + } + } + + if constexpr (tail_m != 0) { +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + auto moffset = i_acc + tail_start_m; + auto dst_row = dst_blk_2d.row(moffset); + auto 
src_row = src_blk_2d.row(moffset); + auto a_blk_m = a_blk_2d.row(moffset); + xetla_vector tmp_k = a_blk_m * b_blk_2d.row(0); + dst_row[0] = src_row[0] + + xetla_reduce(tmp_k); +#pragma unroll + for (uint32_t j = 1; j < blk_n; j++) { + tmp_k = a_blk_m * b_blk_2d.row(j); + dst_row[j] = src_row[j] + + xetla_reduce(tmp_k); + } + } + } + } +}; + +template < + typename dtype_dst, + typename dtype_src, + typename dtype_b, + typename dtype_a, + int blk_m, + int blk_n, + int blk_k, + int mma_m> +struct blk_mma_t< + dtype_dst, + dtype_src, + dtype_b, + dtype_a, + blk_m, + blk_n, + blk_k, + mma_m, + reg_layout::transpose_tiled, + reg_layout::transpose_tiled, + mma_engine::fpu> { + static constexpr uint32_t blk_m_iters = blk_m / mma_m; + static constexpr uint32_t tail_m = blk_m % mma_m; + static constexpr uint32_t a_block_elems = blk_m * blk_k; + static constexpr uint32_t tail_start_m = blk_m_iters * mma_m; + + __XETLA_API static void mma_core( + xetla_vector_ref __REF__ dst, + xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { + auto dst_blk_2d = dst.xetla_format(); + auto src_blk_2d = src.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + xetla_vector new_a_block = + xetla_cvt( + a_block.xetla_select(0)); + + if constexpr (blk_m_iters > 0) { +#pragma unroll + for (uint32_t i = 0; i < blk_m_iters; i++) { + xetla_vector dst_tmp; + auto dst_tmp_2d = dst_tmp.xetla_format(); + auto b_blk_k0 = b_blk_2d.column(0); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + dst_tmp_2d.row(i_acc) = new_a_block[i_acc + i * mma_m] * b_blk_k0 + + src_blk_2d.row(i_acc + i * mma_m); + } +#pragma unroll + for (uint32_t k = 1; k < blk_k - 1; k++) { + auto b_blk_k = b_blk_2d.column(k); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + int a_offset = k * blk_m + i_acc + i * mma_m; + dst_tmp_2d.row(i_acc) += new_a_block[a_offset] * b_blk_k; + } + } + auto b_blk_k_last = b_blk_2d.column(blk_k - 1); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + int a_offset = (blk_k - 1) * blk_m + i_acc + i * mma_m; + dst_blk_2d.row(i_acc + i * mma_m) = + new_a_block[a_offset] * b_blk_k_last + dst_tmp_2d.row(i_acc); + } + SW_BARRIER(); + } + } + + if constexpr (tail_m != 0) { + xetla_vector dst_tmp; + auto dst_tmp_2d = dst_tmp.xetla_format(); + auto b_blk_k0 = b_blk_2d.column(0); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + dst_tmp_2d.row(i_acc) = new_a_block[i_acc + tail_start_m] * b_blk_k0 + + src_blk_2d.row(i_acc + tail_start_m); + } +#pragma unroll + for (uint32_t k = 1; k < blk_k - 1; k++) { + auto b_blk_k = b_blk_2d.column(k); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + int a_offset = k * blk_m + i_acc + tail_start_m; + dst_tmp_2d.row(i_acc) += new_a_block[a_offset] * b_blk_k; + } + } + auto b_blk_k_last = b_blk_2d.column(blk_k - 1); +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + int a_offset = (blk_k - 1) * blk_m + i_acc + tail_start_m; + dst_blk_2d.row(i_acc + tail_start_m) = + new_a_block[a_offset] * b_blk_k_last + dst_tmp_2d.row(i_acc); + } + } + } +}; + +} // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index e81d8d7df..ad79dc47c 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -19,7 +19,7 @@ #pragma once -#include +#include "./blk_mma.hpp" namespace gpu::xetla::subgroup { @@ -39,9 +39,6 @@ struct 
tile_fma_t { using dtype_b = typename matB_t::dtype; using dtype_acc = typename matAcc_t_::dtype; - using register_attr = - typename arch_attr_t::template register_attr<>; - static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; static constexpr uint32_t a_tile_elems = matA_t::tile_elems; @@ -191,16 +188,6 @@ struct tile_mma_t< using dtype_src = typename matSrc_t::dtype; using dtype_dst = typename matDst_t::dtype; - using register_attr = - typename arch_attr_t::template register_attr<>; - - static_assert( - matA_t::reg_transpose, - "For FMAOp GEMM, the register layout of matA should be col-major"); - static_assert( - !matB_t::reg_transpose, - "For FMAOp GEMM, the register layout of matB should be row-major"); - static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; static constexpr uint32_t a_tile_elems = matA_t::tile_elems; @@ -223,6 +210,11 @@ struct tile_mma_t< static constexpr uint32_t block_size_k = a_block_size_x; static constexpr uint32_t block_size_m = matDst_t::block_size_y; static constexpr uint32_t block_elems = block_size_m * block_size_n; + static constexpr uint32_t blk_m_iters = tile_size_m / block_size_m; + static constexpr uint32_t tail_start_m = blk_m_iters * block_size_m; + static constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; + static constexpr uint32_t tail_start_offset = tail_start_m * tile_size_n; + static constexpr uint32_t a_tail_start_offset = tail_start_m * a_tile_size_x; static_assert( tile_size_m == matA_t::tile_size_y, @@ -244,70 +236,30 @@ struct tile_mma_t< "matAcc tile_size_k should be a multiple of block_size_k"); static constexpr int32_t num_block_n = matDst_t::num_block_x; - static constexpr int32_t num_block_m = matDst_t::num_block_y; static constexpr int32_t num_block_k = tile_size_k / block_size_k; - static constexpr int32_t mma_m = - register_attr::acc_reg_in_bytes / (block_size_n * sizeof(dtype_dst)); + static constexpr auto b_reg_sizes = b_block_size_y * b_tile_size_x; + static constexpr auto tile_n_elems = num_block_n * block_elems; + static constexpr auto a_tile_k_elems = num_block_k * a_block_elems; + static constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; + static constexpr uint32_t a_tail_blk_elems = a_block_size_x * a_tail_blk_w; + static constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; - template - __XETLA_API static void mma_core( - xetla_vector_ref __REF__ dst, - xetla_vector_ref __REF__ src, - xetla_vector_ref __REF__ b_block, - xetla_vector_ref __REF__ a_block) { - auto dst_blk_2d = dst.xetla_format(); - auto src_blk_2d = src.xetla_format(); - auto b_blk_2d = b_block.xetla_format(); -#pragma unroll - for (uint32_t i = 0; i < blk_m / mma_m; i++) { - xetla_vector dst_tmp; - auto dst_tmp_2d = dst_tmp.xetla_format(); -#pragma unroll - for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { - dst_tmp_2d.row(i_acc) = a_block[i_acc + i * mma_m] * b_blk_2d.row(0) + - src_blk_2d.row(i_acc + i * mma_m); - } -#pragma unroll - for (uint32_t k = 1; k < blk_k - 1; k++) { - for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { - int a_offset = k * blk_m + i_acc + i * mma_m; - dst_tmp_2d.row(i_acc) += a_block[a_offset] * b_blk_2d.row(k); - } - } - for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { - int a_offset = (blk_k - 1) * blk_m + i_acc + i * mma_m; - dst_blk_2d.row(i_acc + i * mma_m) = - a_block[a_offset] * b_blk_2d.row(blk_k - 1) + dst_tmp_2d.row(i_acc); - } - 
SW_BARRIER(); - } + using mma_attr = mma_attr_t; + static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; - if constexpr ((blk_m % mma_m) != 0) { - constexpr uint32_t tail_start_m = blk_m / mma_m * mma_m; - constexpr uint32_t tail_m = blk_m % mma_m; - xetla_vector dst_tmp; - auto dst_tmp_2d = dst_tmp.xetla_format(); -#pragma unroll - for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { - dst_tmp_2d.row(i_acc) = - a_block[i_acc + tail_start_m] * b_blk_2d.row(0) + - src_blk_2d.row(i_acc + tail_start_m); - } -#pragma unroll - for (uint32_t k = 1; k < blk_k - 1; k++) { - for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { - int a_offset = k * blk_m + i_acc + tail_start_m; - dst_tmp_2d.row(i_acc) += a_block[a_offset] * b_blk_2d.row(k); - } - } - for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { - int a_offset = (blk_k - 1) * blk_m + i_acc + tail_start_m; - dst_blk_2d.row(i_acc + tail_start_m) = - a_block[a_offset] * b_blk_2d.row(blk_k - 1) + dst_tmp_2d.row(i_acc); - } - } - } + using blk_mma = blk_mma_t< + dtype_dst, + dtype_src, + dtype_src, + dtype_a, + block_size_m, + block_size_n, + block_size_k, + mma_m, + matA_t::register_layout, + matB_t::register_layout, + mma_engine::fpu>; __XETLA_API static void mma( matDst_t& dst, @@ -315,84 +267,105 @@ struct tile_mma_t< matB_t& b, matA_t& a) { { // k_blk=0 - auto b_reg = b.reg.xetla_select(0); + auto b_reg = xetla_cvt( + b.reg.xetla_select(0)); + if constexpr (blk_m_iters >= 1) { #pragma unroll - for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { - auto a_block = a.reg.xetla_select( - i * num_block_k * a_block_elems); + for (uint32_t i = 0; i < blk_m_iters; i++) { + auto a_block = + a.reg.xetla_select(i * a_tile_k_elems); + auto i_start_off = i * tile_n_elems; #pragma unroll - for (uint32_t j = 0; j < num_block_n; j++) { - auto b_block = - b_reg.xetla_select(j * b_block_elems); - auto src_block = src.reg.xetla_select( - (i * num_block_n + j) * block_elems); - auto dst_block = dst.reg.xetla_select( - (i * num_block_n + j) * block_elems); - mma_core( - dst_block, src_block, b_block, a_block); + for (uint32_t j = 0; j < num_block_n; j++) { + auto b_block = + b_reg.xetla_select(j * b_block_elems); + auto src_block = src.reg.xetla_select( + i_start_off + j * block_elems); + auto dst_block = dst.reg.xetla_select( + i_start_off + j * block_elems); + blk_mma::mma_core(dst_block, src_block, b_block, a_block); + } } } // process the tail - if constexpr ((tile_size_m % block_size_m) != 0) { - constexpr uint32_t tail_start_m = - tile_size_m / block_size_m * block_size_m; - constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; - constexpr uint32_t a_tail_blk_elems = a_block_size_x * a_tail_blk_w; - constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; - constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; - auto a_block = a.reg.xetla_select( - a_tile_size_x * tail_start_m); + if constexpr (tail_size_m != 0) { + using tail_blk_mma = blk_mma_t< + dtype_dst, + dtype_src, + dtype_src, + dtype_a, + tail_size_m, + block_size_n, + block_size_k, + mma_m, + matA_t::register_layout, + matB_t::register_layout, + mma_engine::fpu>; + + auto a_block = + a.reg.xetla_select(a_tail_start_offset); #pragma unroll for (uint32_t j = 0; j < num_block_n; j++) { auto b_block = b_reg.xetla_select(j * b_block_elems); auto src_block = src.reg.xetla_select( - (tail_start_m * tile_size_n) + j * acc_tail_blk_elems); + tail_start_offset + j * acc_tail_blk_elems); auto dst_block = dst.reg.xetla_select( - (tail_start_m * tile_size_n) + j * 
acc_tail_blk_elems); - mma_core( - dst_block, src_block, b_block, a_block); + tail_start_offset + j * acc_tail_blk_elems); + tail_blk_mma::mma_core(dst_block, src_block, b_block, a_block); } } } // different K block #pragma unroll for (uint32_t k_i = 1; k_i < num_block_k; k_i++) { - auto b_reg = b.reg.xetla_select( - k_i * b_block_size_y * b_tile_size_x); + xetla_vector b_reg = + xetla_cvt( + b.reg.xetla_select( + k_i * b_block_size_y * b_tile_size_x)); + + if constexpr (blk_m_iters >= 1) { #pragma unroll - for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { - auto a_block = a.reg.xetla_select( - (i * num_block_k + k_i) * a_block_elems); + for (uint32_t i = 0; i < blk_m_iters; i++) { + auto a_block = a.reg.xetla_select( + i * a_tile_k_elems + k_i * a_block_elems); + auto i_start_off = i * tile_n_elems; #pragma unroll - for (uint32_t j = 0; j < num_block_n; j++) { - auto b_block = - b_reg.xetla_select(j * b_block_elems); - auto dst_block = dst.reg.xetla_select( - (i * num_block_n + j) * block_elems); - mma_core( - dst_block, dst_block, b_block, a_block); + for (uint32_t j = 0; j < num_block_n; j++) { + auto b_block = + b_reg.xetla_select(j * b_block_elems); + auto dst_block = dst.reg.xetla_select( + i_start_off + j * block_elems); + blk_mma::mma_core(dst_block, dst_block, b_block, a_block); + } } } + // process the tail - if constexpr ((tile_size_m % block_size_m) != 0) { - constexpr uint32_t tail_start_m = - tile_size_m / block_size_m * block_size_m; - constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; - constexpr uint32_t a_tail_blk_elems = a_block_size_x * a_tail_blk_w; - constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; - constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; + if constexpr (tail_size_m != 0) { + using tail_blk_mma = blk_mma_t< + dtype_dst, + dtype_src, + dtype_src, + dtype_a, + tail_size_m, + block_size_n, + block_size_k, + mma_m, + matA_t::register_layout, + matB_t::register_layout, + mma_engine::fpu>; + auto a_block = a.reg.xetla_select( - a_tile_size_x * tail_start_m + k_i * a_tail_blk_elems); + a_tail_start_offset + k_i * a_tail_blk_elems); #pragma unroll for (uint32_t j = 0; j < num_block_n; j++) { auto b_block = b_reg.xetla_select(j * b_block_elems); auto dst_block = dst.reg.xetla_select( - (tail_start_m * tile_size_n) + j * acc_tail_blk_elems); - mma_core( - dst_block, dst_block, b_block, a_block); + tail_start_offset + j * acc_tail_blk_elems); + tail_blk_mma::mma_core(dst_block, dst_block, b_block, a_block); } } } diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 6f9a80b4a..fe5cd194d 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -105,13 +105,11 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr bool mem_transform = payload_t::mem_transform; - using load_store_attr = typename arch_attr_t< - arch_tag>::template load_store_attr; + using load_store_attr = load_store_attr_t; static constexpr uint32_t elems_per_CL = load_store_attr::cache_line_size_in_bytes / sizeof(dtype); static constexpr uint32_t elems_per_reg = - arch_attr_t::template register_attr<>::reg_in_bytes / - sizeof(dtype); + register_bytes_t::reg_in_bytes / sizeof(dtype); static constexpr int32_t max_load_block_height = load_store_attr::max_load_height_in_elem; static constexpr int32_t max_block_width = @@ -396,11 +394,12 @@ tile_load(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; static constexpr uint32_t 
load_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; + static constexpr uint32_t power2_block_elems = + detail::getNextPowerOf2(); using load_store_attr = load_store_attr_t; - static constexpr uint32_t max_load_vec_len = std::min( - uint32_t(tile_t::block_elems * sizeof(dtype)), - load_store_attr::max_load_vec_len); + static constexpr uint32_t max_load_vec_len = + std::min(power2_block_elems, load_store_attr::max_aligned_load_vec_len); static constexpr uint32_t max_load_vec_elems = max_load_vec_len / sizeof(dtype); @@ -459,6 +458,7 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t load_elems = num_channel * payload_t::simd_exec_size; constexpr uint32_t pack_factor = payload_t::pack_factor; + auto channel_offset = payload.channel_offset + payload.base_offset; #pragma unroll for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { uint32_t offset_y = i * tile_desc::block_size_y; @@ -503,9 +503,7 @@ tile_load(tile_t& tile, payload_t& payload) { L1, L2, num_channel>( - payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, - pred); + payload.base_ptr, channel_offset + address_offset, pred); if constexpr ( payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) { @@ -646,26 +644,32 @@ tile_load( constexpr bool oob_check = std::is_same::value; using dtype = typename payload_t::dtype; - using tile_desc = typename payload_t::tile_desc; using load_dtype = typename payload_t::mem_dtype; constexpr uint32_t num_channel_y = payload_t::num_channel_y; constexpr uint32_t load_elems = num_channel_y * payload_t::num_channel_x; constexpr uint32_t scale_factor = payload_t::scale_factor; + using tile_desc = typename tile_t::tile_desc; + static constexpr uint32_t block_elems = tile_desc::block_elems; + static constexpr uint32_t block_size_x = tile_desc::block_size_x; + static constexpr uint32_t num_block_x = tile_desc::num_block_x; + static constexpr uint32_t block_size_y = tile_desc::block_size_y; + static constexpr uint32_t num_block_y = tile_desc::num_block_y; + + auto channel_offset = payload.channel_offset + payload.base_offset; #pragma unroll - for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; - i++) { - uint32_t offset_y = i * tile_desc::block_size_y; + for (uint32_t i = 0; i < num_block_y; i++) { + uint32_t offset_y = i * block_size_y; #pragma unroll - for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { - uint32_t offset_x = j * tile_desc::block_size_x; - auto reg_sub = tile.reg.xetla_select( - (i * tile_desc::num_block_x + j) * tile_desc::block_elems); + for (uint32_t j = 0; j < num_block_x; j++) { + uint32_t offset_x = j * block_size_x; + auto reg_sub = tile.reg.xetla_select( + (i * num_block_x + j) * block_elems); xetla_mask pred_x = oob_check ? payload.step_x + payload.base_x + offset_x < payload.width_in_elems : 1; #pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; + for (uint32_t sub_block_y = 0; sub_block_y < block_size_y; sub_block_y += num_channel_y) { xetla_vector reg_tmp; xetla_mask pred_y = oob_check @@ -673,7 +677,7 @@ tile_load( payload.height_in_elems : 1; - uint32_t address_offset = payload_t::trans + uint32_t address_offset = payload_t::mem_transpose ? 
offset_x * payload.pitch_in_bytes + (offset_y + sub_block_y) * sizeof(dtype) : offset_x * sizeof(dtype) + @@ -687,27 +691,27 @@ tile_load( L2, load_elems>( payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, + channel_offset + address_offset, pred_x && pred_y); reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y); reg_sub .xetla_select( - sub_block_y * tile_desc::block_size_x) + sub_block_y * block_size_x) .xetla_format() = reg_tmp; } } } // process the tail - if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) { + if constexpr (tile_desc::remained_size_y != 0) { constexpr uint32_t remained_size_y = tile_desc::remained_size_y; constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y; constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x; constexpr uint32_t remain_block_elems = remained_size_y * tile_desc::block_size_x; #pragma unroll - for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { - uint32_t offset_x = j * tile_desc::block_size_x; + for (uint32_t j = 0; j < num_block_x; j++) { + uint32_t offset_x = j * block_size_x; auto reg_sub = tile.reg.xetla_select( processed_elems + j * remain_block_elems); xetla_mask pred_x = oob_check @@ -722,7 +726,7 @@ tile_load( payload.height_in_elems : 1; - uint32_t address_offset = payload_t::trans + uint32_t address_offset = payload_t::mem_transpose ? offset_x * payload.pitch_in_bytes + (offset_y + sub_block_y) * sizeof(dtype) : offset_x * sizeof(dtype) + @@ -736,7 +740,7 @@ tile_load( L2, load_elems>( payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, + channel_offset + address_offset, pred_x && pred_y); reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y); @@ -749,6 +753,11 @@ tile_load( } } + if constexpr (payload_t::reg_transpose) { + SW_BARRIER(); + tile_transpose(tile); + } + if constexpr (payload_t::mem_transform) { SW_BARRIER(); vnni_convert(tile); @@ -868,7 +877,7 @@ tile_load(tile_t& tile, payload_t& payload) { using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_vec_len = - load_store_attr::max_load_vec_len; + load_store_attr::max_aligned_load_vec_len; static constexpr uint32_t max_load_vec_elems = max_load_vec_len / sizeof(dtype); diff --git a/include/subgroup/tile/impl/mma_xe.hpp b/include/subgroup/tile/impl/mma_xe.hpp index 371040397..37eb392a1 100644 --- a/include/subgroup/tile/impl/mma_xe.hpp +++ b/include/subgroup/tile/impl/mma_xe.hpp @@ -105,7 +105,7 @@ struct tile_mma_t< static constexpr int32_t num_block_mma_b = b_block_size_y / block_size_k; static constexpr uint32_t b_block_mma_elems = b_block_elems / num_block_mma_b; - using mma_attr = mma_attr_t; + using mma_attr = mma_attr_t; static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; static constexpr int32_t mma_k = mma_attr::mma_k_in_bytes / sizeof(uint32_t); @@ -120,74 +120,78 @@ struct tile_mma_t< matA_t& a) { constexpr int32_t a_mma_elems = mma_m * a_block_size_x; constexpr int32_t c_mma_elems = mma_m * block_size_n; + constexpr uint32_t blk_m_iters = tile_size_m / block_size_m; + constexpr uint32_t tail_block_size_m = tile_size_m % block_size_m; #pragma unroll for (uint32_t j = 0; j < num_block_n; j++) { + if constexpr (blk_m_iters > 0) { #pragma unroll - for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { - auto src_block = src.reg.xetla_select( - (i * num_block_n + j) * block_elems); - auto dst_block = dst.reg.xetla_select( - (i * num_block_n + j) * block_elems); + for (uint32_t i = 0; i < blk_m_iters; i++) { + auto src_block = 
src.reg.xetla_select( + (i * num_block_n + j) * block_elems); + auto dst_block = dst.reg.xetla_select( + (i * num_block_n + j) * block_elems); #pragma unroll - for (uint32_t mma_i = 0; mma_i < block_size_m / mma_m; mma_i++) { - auto src_sub_blk = - src_block.xetla_select(mma_i * c_mma_elems); - auto dst_sub_blk = - dst_block.xetla_select(mma_i * c_mma_elems); - { // k=0 - auto a_block = a.reg.xetla_select( - (i * num_block_k) * a_block_elems); - auto a_sub_blk = - a_block.xetla_select(mma_i * a_mma_elems); - auto b_blk = - b.reg.xetla_select(j * b_block_elems); - auto b_sub_blk = b_blk.xetla_select(0); - dst_sub_blk = xetla_mma< - gpu::xetla::detail::mma_argument_type(), - gpu::xetla::detail::mma_argument_type(), - mma_k, - mma_m, - dtype_src, - dtype_b, - dtype_a, - c_mma_elems, - b_block_mma_elems, - a_mma_elems>(src_sub_blk, b_sub_blk, a_sub_blk); - } + for (uint32_t mma_i = 0; mma_i < block_size_m / mma_m; mma_i++) { + auto src_sub_blk = + src_block.xetla_select(mma_i * c_mma_elems); + auto dst_sub_blk = + dst_block.xetla_select(mma_i * c_mma_elems); + + { // k=0 + auto a_block = a.reg.xetla_select( + (i * num_block_k) * a_block_elems); + auto a_sub_blk = + a_block.xetla_select(mma_i * a_mma_elems); + auto b_blk = + b.reg.xetla_select(j * b_block_elems); + auto b_sub_blk = b_blk.xetla_select(0); + dst_sub_blk = xetla_mma< + gpu::xetla::detail::mma_argument_type(), + gpu::xetla::detail::mma_argument_type(), + mma_k, + mma_m, + dtype_src, + dtype_b, + dtype_a, + c_mma_elems, + b_block_mma_elems, + a_mma_elems>(src_sub_blk, b_sub_blk, a_sub_blk); + } #pragma unroll - for (uint32_t k = 1; k < num_block_k; k++) { - auto a_block = a.reg.xetla_select( - (i * num_block_k + k) * a_block_elems); - auto a_sub_blk = - a_block.xetla_select(mma_i * a_mma_elems); - int inter_k_b = k / num_block_mma_b; - int inner_k_b = k % num_block_mma_b; - auto b_blk = b.reg.xetla_select( - (j + inter_k_b * num_block_n) * b_block_elems); - auto b_sub_blk = b_blk.xetla_select( - inner_k_b * b_block_mma_elems); - dst_sub_blk = xetla_mma< - gpu::xetla::detail::mma_argument_type(), - gpu::xetla::detail::mma_argument_type(), - mma_k, - mma_m, - dtype_src, - dtype_b, - dtype_a, - c_mma_elems, - b_block_mma_elems, - a_mma_elems>(dst_sub_blk, b_sub_blk, a_sub_blk); + for (uint32_t k = 1; k < num_block_k; k++) { + auto a_block = a.reg.xetla_select( + (i * num_block_k + k) * a_block_elems); + auto a_sub_blk = + a_block.xetla_select(mma_i * a_mma_elems); + int inter_k_b = k / num_block_mma_b; + int inner_k_b = k % num_block_mma_b; + auto b_blk = b.reg.xetla_select( + (j + inter_k_b * num_block_n) * b_block_elems); + auto b_sub_blk = b_blk.xetla_select( + inner_k_b * b_block_mma_elems); + dst_sub_blk = xetla_mma< + gpu::xetla::detail::mma_argument_type(), + gpu::xetla::detail::mma_argument_type(), + mma_k, + mma_m, + dtype_src, + dtype_b, + dtype_a, + c_mma_elems, + b_block_mma_elems, + a_mma_elems>(dst_sub_blk, b_sub_blk, a_sub_blk); + } } } } - if constexpr ((tile_size_m % block_size_m) != 0) { - constexpr uint32_t tail_block_size_m = tile_size_m % block_size_m; + + if constexpr (tail_block_size_m != 0) { constexpr uint32_t tail_block_elems = block_size_n * tail_block_size_m; constexpr uint32_t a_tail_block_elems = tail_block_size_m * a_block_size_x; - constexpr uint32_t tail_m_start = - tile_size_m / block_size_m * block_size_m; + constexpr uint32_t tail_m_start = blk_m_iters * block_size_m; constexpr uint32_t tail_elems_start = tail_m_start * tile_size_n; constexpr uint32_t a_tail_elems_start = tail_m_start * 
a_tile_size_x; auto src_block = src.reg.xetla_select( diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 8c43b5a32..44d2f6569 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -197,6 +197,7 @@ __XETLA_API native_type_t, remain_move_rows, remain_move_cols>(); +#pragma unroll for (int vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { reg_dst_2d.xetla_select( 0, vnni_i) = diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index e6843561f..2e234a4aa 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -44,7 +44,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_2d, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::XeHpc)>> { + std::enable_if_t>> { using tile_desc = tile_desc_; using mem_desc_t = mem_desc_t; @@ -406,7 +406,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_1d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = native_type_t; @@ -560,7 +560,7 @@ struct mem_payload_t< tile_desc_, msg_type::atomic_add, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t= sizeof(uint16_t)>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -742,7 +742,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_1d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -873,7 +873,7 @@ struct mem_payload_t< tile_desc_, msg_type::unaligned_2d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -960,7 +960,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -986,7 +986,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -1006,7 +1006,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -1032,7 +1032,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? 
step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -1100,7 +1100,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_2d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpg)>> { + std::enable_if_t>> { using dtype = native_type_t; using mem_desc_t = mem_desc_t; @@ -1158,17 +1158,20 @@ struct mem_payload_t< // for pvc, we can use simd16 or simd32 using load_store_attr = load_store_attr_t; - static constexpr uint32_t max_bytes = - std::min(load_store_attr::max_load_vec_len, block_bytes); + static constexpr uint32_t max_load_vec_len = (alignment_in_bytes == 8) + ? load_store_attr::max_aligned_load_vec_len + : load_store_attr::max_load_vec_len; + static constexpr uint32_t max_bytes = std::min(max_load_vec_len, block_bytes); static constexpr uint32_t max_channel = max_bytes / (simd_exec_size * sizeof(mem_dtype)); static constexpr uint32_t select_channel(const uint32_t channel) { - return (channel >= 32 && arch_tag == gpu_arch::XeHpc) ? 32 - : channel >= 16 ? 16 - : channel >= 8 ? 8 - : 1; + return (channel >= load_store_attr::max_channel_num) + ? load_store_attr::max_channel_num + : channel >= 16 ? 16 + : channel >= 8 ? 8 + : 1; } static constexpr uint32_t num_channel = select_channel( @@ -1325,7 +1328,7 @@ struct mem_payload_t< tile_desc_, msg_type::scatter, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -1504,7 +1507,7 @@ struct mem_payload_t< reg_layout::vnni_tiled_col_major>, msg_type::scatter, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_t = mem_desc_t; using dtype = dtype_; @@ -1656,12 +1659,11 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t< - arch_tag_ <= gpu_arch::XeHpg && - (((block_size_y_ != 1 || tile_size_y_ != 1) && - mem_layout_ == mem_layout::row_major) || - ((block_size_x_ != 1 || tile_size_x_ != 1) && - mem_layout_ == mem_layout::col_major))>> { + std::enable_if_t<(!arch_has_2d_load_store)&&( + ((block_size_y_ != 1 || tile_size_y_ != 1) && + mem_layout_ == mem_layout::row_major) || + ((block_size_x_ != 1 || tile_size_x_ != 1) && + mem_layout_ == mem_layout::col_major))>> { using dtype = native_type_t; using mem_desc_t = mem_desc_t; @@ -1841,6 +1843,26 @@ struct prefetch_payload_t< xetla_vector_gen(0, 1); channel_offset = channel_index * pitch_in_bytes; } + + inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) { + uint32_t coop_id_x = coop_id % num_coop_sg_w; + uint32_t coop_id_y = coop_id / num_coop_sg_w; + + pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype); + base_x = mem_desc.coord.x + coop_id_x * tile_size_w; + base_y = mem_desc.coord.y + coop_id_y * tile_size_h; + width_in_elems = mem_desc.shape.x; + height_in_elems = mem_desc.shape.y; + base_offset = mem_transpose + ? 
base_x * pitch_in_bytes + base_y * sizeof(dtype) + : base_y * pitch_in_bytes + base_x * sizeof(dtype); + base_ptr = reinterpret_cast(mem_desc.base.base); + + xetla_vector channel_index = + xetla_vector_gen(0, 1); + channel_offset = channel_index * pitch_in_bytes; + } + // Be aware of the risks: Rule of three (copy constructor, copy // assignment, destructor) Please check if you need to add self-define // destructor ~prefetch_payload_t(){} @@ -1884,10 +1906,9 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t< - (arch_tag_ == gpu_arch::XeHpc) && - (((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || - ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { + std::enable_if_t<(arch_has_2d_load_store)&&( + ((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || + ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1990,6 +2011,17 @@ struct prefetch_payload_t< prepare_tdesc(base_tdesc); } + inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) { + xetla_tdescriptor base_tdesc = mem_desc.get_tdesc(); + uint32_t coop_id_x = coop_id % num_coop_sg_w; + uint32_t coop_id_y = coop_id / num_coop_sg_w; + xetla_update_tdesc_offsetx( + base_tdesc.xetla_format(), coop_id_x * tile_size_w); + xetla_update_tdesc_offsety( + base_tdesc.xetla_format(), coop_id_y * tile_size_h); + prepare_tdesc(base_tdesc); + } + inline void init(xetla_tdescriptor base_tdesc, uint32_t coop_id = 0) { uint32_t coop_id_x = coop_id % num_coop_sg_w; uint32_t coop_id_y = coop_id / num_coop_sg_w; @@ -2258,6 +2290,16 @@ struct prefetch_payload_t< base_ptr = (prefetch_dtype*)p + (coop_id % num_coop_sg) * mem_tile_size_x; } + inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) { + pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype); + uint32_t offset_x = mem_desc.coord.x; + uint32_t offset_y = mem_desc.coord.y; + base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + uint64_t ptr_temp = (uint64_t)mem_desc.base.base; + base_ptr = + (prefetch_dtype*)ptr_temp + (coop_id % num_coop_sg) * mem_tile_size_x; + } + template __XETLA_API void update_tdesc(int offset) { if constexpr (update_dir == tdesc_update_dir::x_dir) { @@ -2286,7 +2328,7 @@ struct prefetch_payload_t< tile_desc_, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -2295,6 +2337,8 @@ struct prefetch_payload_t< static constexpr mem_layout memory_layout = mem_layout_; static constexpr gpu_arch arch_tag = arch_tag_; + inline prefetch_payload_t() = default; + inline prefetch_payload_t( [[maybe_unused]] mem_desc_t& mem_desc, [[maybe_unused]] uint32_t coop_id = 0) {} @@ -2308,6 +2352,10 @@ struct prefetch_payload_t< [[maybe_unused]] int surface_offset_y, [[maybe_unused]] uint32_t coop_id = 0) {} + inline void init( + [[maybe_unused]] mem_desc_t& mem_desc, + [[maybe_unused]] uint32_t coop_id = 0) {} + template __XETLA_API void update_tdesc([[maybe_unused]] int offset) {} }; diff --git a/include/subgroup/tile/impl/quant_op_functor.hpp b/include/subgroup/tile/impl/quant_op_functor.hpp index ffb3dab54..ce10dbad7 100644 --- a/include/subgroup/tile/impl/quant_op_functor.hpp +++ b/include/subgroup/tile/impl/quant_op_functor.hpp @@ -40,7 +40,7 @@ template struct dequant_op_t< tile_op_t_, arch_tag, - std::enable_if_t<(arch_tag == gpu_arch::XeHpc)>> { + std::enable_if_t>> { // may need to add some limitations to tile_op 
used in dequant_op using tile_op_t = tile_op_t_; struct arguments_t { @@ -72,7 +72,7 @@ template struct quant_op_t< tile_op_t_, arch_tag, - std::enable_if_t<(arch_tag == gpu_arch::XeHpc)>> { + std::enable_if_t>> { // may need to add some limitations to tile_op used in dequant_op using tile_op_t = tile_op_t_; diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index d5345b06b..d5ba9e34a 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -352,9 +352,10 @@ tile_store( constexpr uint32_t num_channel_y = payload_t::num_channel_y; constexpr uint32_t store_elems = num_channel_y * payload_t::num_channel_x; constexpr uint32_t scale_factor = payload_t::scale_factor; + + auto channel_offset = payload.channel_offset + payload.base_offset; #pragma unroll - for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; - i++) { + for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { uint32_t offset_y = i * tile_desc::block_size_y; #pragma unroll for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { @@ -382,7 +383,7 @@ tile_store( L3, store_elems>( payload.base_ptr, - (payload.base_offset + address_offset + payload.channel_offset), + (address_offset + channel_offset), reg_sub .xetla_select( sub_block_y * tile_desc::block_size_x) @@ -392,7 +393,7 @@ tile_store( } } // process the tail - if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) { + if constexpr (tile_desc::remained_size_y != 0) { constexpr uint32_t remained_size_y = tile_desc::remained_size_y; constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y; constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x; @@ -424,7 +425,7 @@ tile_store( L3, store_elems>( payload.base_ptr, - (payload.base_offset + address_offset + payload.channel_offset), + (address_offset + channel_offset), reg_sub .xetla_select( sub_block_y * tile_desc::block_size_x) @@ -466,6 +467,7 @@ tile_store(tile_t& tile, payload_t& payload) { constexpr uint32_t store_elems = num_channel * payload_t::simd_exec_size; constexpr uint32_t pack_factor = payload_t::pack_factor; + auto channel_offset = payload.channel_offset + payload.base_offset; #pragma unroll for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; i++) { @@ -521,7 +523,7 @@ tile_store(tile_t& tile, payload_t& payload) { L3, num_channel>( payload.base_ptr, - (payload.base_offset + address_offset + payload.channel_offset), + (address_offset + channel_offset), reg_tmp, pred_y); } else if ( @@ -533,9 +535,7 @@ tile_store(tile_t& tile, payload_t& payload) { L1, L3, num_channel>( - payload.base_ptr, - (payload.base_offset + address_offset + payload.channel_offset), - reg_tmp); + payload.base_ptr, (address_offset + channel_offset), reg_tmp); } else { break; } @@ -719,19 +719,34 @@ tile_store( uint64_t address_offset = offset_x * sizeof(dtype) + (sub_block_y + offset_y) * payload.pitch_in_bytes; - xetla_tatomic_store_global< - dtype, - payload_t::num_channel, - L1, - L2, - op_kind, - payload_t::arch_tag, - typename payload_t::Toffset>( - (uint64_t)payload.base_pointer + address_offset, - payload.channel_offset, - reg_sub.xetla_select( - sub_block_y * block_size_x), - pred_x & pred_y); + if constexpr (arch_has_2d_load_store) { + xetla_tatomic_store_global< + dtype, + payload_t::num_channel, + L1, + L2, + op_kind, + payload_t::arch_tag, + typename payload_t::Toffset>( + (uint64_t)payload.base_pointer + address_offset, + payload.channel_offset, + 
reg_sub.xetla_select( + sub_block_y * block_size_x), + pred_x & pred_y); + } else { + xetla_atomic_global< + op_kind, + dtype, + payload_t::num_channel, + data_size::default_size, + L1, + L2>( + reinterpret_cast(payload.base_pointer + address_offset), + payload.channel_offset, + reg_sub.xetla_select( + sub_block_y * block_size_x), + pred_x & pred_y); + } } } } diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index f5a41e931..6f2ffdc37 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -331,7 +331,7 @@ template struct gelu_fwd_w_op_t< dtype_out_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_out = dtype_out_; using mem_desc_w_t = mem_desc_t; @@ -463,7 +463,7 @@ template struct gelu_bwd_op_t< dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_x_t = mem_desc_t; @@ -570,7 +570,7 @@ template struct bias_add_op_t< mem_desc_bias_t_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using mem_desc_bias_t = mem_desc_bias_t_; using dtype_bias = typename mem_desc_bias_t::dtype; using shape_t = typename mem_desc_bias_t::shape_t; @@ -674,7 +674,7 @@ struct scale_v_offset_v_op_t< scale_dtype_, offset_dtype_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using scale_dtype = scale_dtype_; using offset_dtype = offset_dtype_; @@ -809,7 +809,7 @@ template struct scale_v_op_t< scale_dtype_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using scale_dtype = scale_dtype_; using scale_mem_desc_t = @@ -916,7 +916,7 @@ struct elemwise_reduce_op_t< reduce_kind_, dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; @@ -1034,7 +1034,7 @@ struct elemwise_reduce_op_t< template < reduce_op reduce_kind, typename dtype_in, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, class enable = void> struct elemwise_reduce_op_stream_k_t {}; /// @brief Is the element-wise reduce op functor, specialized for Xe @@ -1044,7 +1044,7 @@ struct elemwise_reduce_op_stream_k_t< reduce_kind_, dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; @@ -1156,7 +1156,7 @@ template struct dropout_op_t< dtype_mask_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1243,7 +1243,7 @@ template struct rng_dropout_op_t< dtype_mask_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1357,7 +1357,7 @@ template struct scalar_mul_op_t< dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; @@ -1396,7 +1396,7 @@ template struct linear_op_t< dtype_in_, arch_tag, - std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { + std::enable_if_t>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; diff --git a/tests/integration/gemm/bf16/common.hpp b/tests/integration/gemm/bf16/common.hpp index 8110d6d16..6ddf014ff 100644 --- 
a/tests/integration/gemm/bf16/common.hpp +++ b/tests/integration/gemm/bf16/common.hpp @@ -45,11 +45,33 @@ class TestBase { mem_layout_a_str + "_" + mem_layout_b_str; return name; } - static constexpr mma_engine engine = mma_engine::xmx; + static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; + //static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; + static constexpr uint32_t global_kslicing = 1; + static constexpr uint32_t local_kslicing = 1; + static constexpr uint32_t wg_num_n = 64; +}; + +class TestBaseBF16x : public TestBase { + public: + using data_type_a = bf16; + using data_type_b = bf16; + using data_type_c = bf16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::xmx; }; -class Test0 : public TestBase { +class TestBaseBF16f : public TestBase { + public: + using data_type_a = bf16; + using data_type_b = bf16; + using data_type_c = bf16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::fpu; +}; + +class Test0x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -59,17 +81,11 @@ class Test0 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; -class Test1 : public TestBase { +class Test1x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -79,17 +95,11 @@ class Test1 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::col_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; -class Test2 : public TestBase { +class Test2x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -99,17 +109,11 @@ class Test2 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; -class Test3 : public TestBase { +class Test3x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -119,17 +123,11 @@ class Test3 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::col_major; static constexpr mem_layout layout_b = mem_layout::col_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; 
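+// NOTE (editorial sketch, not part of the original patch): the new
+// TestBaseBF16x / TestBaseBF16f bases above factor the shared bf16 dtypes
+// and the mma_engine choice out of the individual cases, so a concrete
+// test only overrides shapes and layouts. A hypothetical additional XMX
+// case would look like:
+//
+//   class TestNewx : public TestBaseBF16x {
+//    public:
+//     static constexpr size_t mat_m = 512, mat_n = 512, mat_k = 512;
+//     static constexpr size_t wg_m = 64, wg_n = 128;
+//     static constexpr size_t sg_m = 32, sg_n = 64, sg_k = 32;
+//     static constexpr mem_layout layout_a = mem_layout::row_major;
+//     static constexpr mem_layout layout_b = mem_layout::row_major;
+//   };
+//
+// gpu_arch, the kslicing factors, and wg_num_n are inherited from TestBase
+// unless a case overrides them (as Test8x does with global_kslicing = 2).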
-class Test4 : public TestBase { +class Test4x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -139,8 +137,6 @@ class Test4 : public TestBase { static constexpr size_t sg_m = 8; static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 16; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -148,18 +144,22 @@ class Test4 : public TestBase { using data_type_c = float; using data_type_acc = float; }; -class Test5 : public TestBase { + +class Test5x : public TestBaseBF16x { public: static constexpr size_t mat_m = 192; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 48; - static constexpr size_t wg_n = 80; - static constexpr size_t sg_m = 24; + // static constexpr size_t wg_m = 48; + // To allow arbitrary wg_m and wg_n instead of only powers of 2, + // DG2 still needs workgroup OOB checks in both directions by using block_1d loads + static constexpr size_t wg_m = 64; + // static constexpr size_t wg_n = 80; + static constexpr size_t wg_n = 128; + // static constexpr size_t sg_m = 24; + static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::col_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -168,18 +168,18 @@ class Test5 : public TestBase { using data_type_acc = float; }; -class Test6 : public TestBase { +class Test6x : public TestBaseBF16x { public: static constexpr size_t mat_m = 96; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 40; + // static constexpr size_t wg_m = 40; + static constexpr size_t wg_m = 64; static constexpr size_t wg_n = 256; - static constexpr size_t sg_m = 24; + // static constexpr size_t sg_m = 24; + static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -187,7 +187,8 @@ class Test6 : public TestBase { using data_type_c = float; using data_type_acc = float; }; -class Test7 : public TestBase { + +class Test7x : public TestBaseBF16x { public: static constexpr size_t mat_m = 80; static constexpr size_t mat_n = 256; @@ -197,17 +198,11 @@ class Test7 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = float; - using data_type_acc = float; }; -class Test8 : public TestBase { +class Test8x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -217,7 +212,6 @@ class Test8 : public TestBase { static constexpr size_t sg_m = 32; static 
constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; static constexpr uint32_t global_kslicing = 2; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; @@ -227,7 +221,7 @@ using data_type_acc = float; }; -class Test9 : public TestBase { +class Test9x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -237,7 +231,10 @@ class Test9 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 2; + // static constexpr uint32_t local_kslicing = 2; + // It looks like local_kslicing fails on DG2 + static constexpr uint32_t local_kslicing = 1; + // global_kslicing works for the aligned case on DG2 static constexpr uint32_t global_kslicing = 4; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; @@ -247,7 +244,7 @@ class Test9 : public TestBase { using data_type_acc = float; }; -class Test10 : public TestBase { +class Test10x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -257,8 +254,8 @@ class Test10 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 4; - static constexpr uint32_t global_kslicing = 1; + // static constexpr uint32_t local_kslicing = 4; + static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -267,7 +264,7 @@ class Test10 : public TestBase { using data_type_acc = float; }; -class Test11 : public TestBase { +class Test11x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -278,13 +275,108 @@ class Test11 : public TestBase { static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; static constexpr uint32_t local_kslicing = 16; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; +}; + +class Test12x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 13824; + static constexpr size_t mat_k = 5120; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 32; + static constexpr uint32_t local_kslicing = 4; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test13x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 8192; + static constexpr size_t mat_n = 8192; + static constexpr size_t mat_k = 8192; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = 
mem_layout::row_major; +}; + +class Test14x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 4096; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test15x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 1024; + static constexpr size_t mat_n = 28672; + static constexpr size_t mat_k = 8192; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test16x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 4; + static constexpr size_t mat_n = 12288; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 8; // wg_m = 4 will fail on DG2 + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 8; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 32; + static constexpr uint32_t local_kslicing = 4; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test17x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 3072; + static constexpr size_t mat_n = 3072; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test18x : public TestBaseBF16x { // Gets better perf on DG2, ~15.48 TFlops + public: + static constexpr size_t mat_m = 4096; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 32; + static constexpr size_t wg_n = 512; + static constexpr size_t sg_m = 16; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 16; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; }; template @@ -325,4 +417,6 @@ using bf16_gemm_func = bf16_gemm_test_func< Test::layout_b, Test::global_kslicing, Test::local_kslicing, - Test::engine>; + Test::wg_num_n, + Test::engine, + Test::gpu_arch>; diff --git a/tests/integration/gemm/bf16/kernel_func.hpp b/tests/integration/gemm/bf16/kernel_func.hpp index 432172f9a..6345047df 100644 --- a/tests/integration/gemm/bf16/kernel_func.hpp +++ b/tests/integration/gemm/bf16/kernel_func.hpp @@ -37,11 +37,15 @@ template < mem_layout layout_b, uint32_t global_kslicing, uint32_t local_kslicing, - mma_engine engine> + uint32_t wg_num_n, + mma_engine engine, + gpu_arch arch_tag> struct bf16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 8; - static constexpr uint32_t prefetch_distance = 3; + static constexpr uint32_t periodic_sync_interval = + (arch_tag == gpu_arch::XeHpc ? 
8 : 0); + static constexpr uint32_t prefetch_distance = + (arch_tag == gpu_arch::XeHpc ? 3 : 0); using gemm_t = typename gemm_selector_t< dtype_a, dtype_b, @@ -55,17 +59,17 @@ struct bf16_gemm_test_func { tile_shape, sg_k, engine, - gpu_arch::XeHpc, + arch_tag, prefetch_distance, periodic_sync_interval>::gemm; using epilogue_t = epilogue_t< - epilogue_policy_default, + epilogue_policy_default, tile_shape, mem_desc_t>; - using group_swizzle = - gpu::xetla::kernel::group_swizzle_default; + using group_swizzle = gpu::xetla::kernel::group_swizzle_default; + // using group_swizzle = kernel::group_swizzle_snake; using dispatch_policy = dispatch_policy_kslicing; diff --git a/tests/integration/gemm/bf16/main.cpp b/tests/integration/gemm/bf16/main.cpp index e383ab205..45ffec38e 100644 --- a/tests/integration/gemm/bf16/main.cpp +++ b/tests/integration/gemm/bf16/main.cpp @@ -33,16 +33,24 @@ TYPED_TEST_P(bf16_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(bf16_gemm_test, esimd); using tests = ::testing::Types< - Test0, - Test1, - Test2, - Test3, - Test4, - Test5, - Test6, - Test7, - Test8, - Test9, - Test10, - Test11>; -INSTANTIATE_TYPED_TEST_SUITE_P(bf16_gemm_test_suite, bf16_gemm_test, tests); \ No newline at end of file + Test0x, + Test1x, + Test2x, + Test3x, + Test4x, + Test5x, + Test6x, + Test7x, + Test8x, + Test9x, + Test10x, + Test11x, + Test12x, + Test13x, + Test14x, + Test15x, + Test16x, + Test17x, + Test18x>; + +INSTANTIATE_TYPED_TEST_SUITE_P(bf16_gemm_test_suite, bf16_gemm_test, tests); diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index a675ab0b9..0432fa349 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -45,12 +45,34 @@ class TestBase { mem_layout_a_str + "_" + mem_layout_b_str; return name; } - static constexpr mma_engine engine = mma_engine::xmx; + static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; + //static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; + static constexpr uint32_t global_kslicing = 1; + static constexpr uint32_t local_kslicing = 1; + static constexpr uint32_t wg_num_n = 64; +}; + +class TestBaseFP16f : public TestBase { + public: + using data_type_a = fp16; + using data_type_b = fp16; + using data_type_c = fp16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::fpu; }; -class Test0 : public TestBase { +class TestBaseFP16x : public TestBase { public: + using data_type_a = fp16; + using data_type_b = fp16; + using data_type_c = fp16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::xmx; +}; + +class Test0 : public TestBaseFP16f { + public: static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 64; static constexpr size_t mat_k = 8192; @@ -63,112 +85,155 @@ class Test0 : public TestBase { static constexpr uint32_t local_kslicing = 8; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; }; -class Test1 : public TestBase { +class Test0x : public TestBaseFP16x { public: static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32; + static constexpr size_t wg_n = 16; static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; - static constexpr size_t sg_k = 16; - static constexpr 
uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 32;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
 };
 
-class Test2 : public TestBase {
+
+class Test0f : public TestBaseFP16f {
  public:
-  static constexpr size_t mat_m = 256;
+  static constexpr size_t mat_m = 1;
   static constexpr size_t mat_n = 256;
   static constexpr size_t mat_k = 256;
   static constexpr size_t wg_m = 1;
-  static constexpr size_t wg_n = 32;
+  static constexpr size_t wg_n = 16;
   static constexpr size_t sg_m = 1;
   static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
+  static constexpr size_t sg_k = 32;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
 };
 
-class Test3 : public TestBase {
+
+class Test1f : public TestBaseFP16f {
  public:
-  static constexpr size_t mat_m = 256;
-  static constexpr size_t mat_n = 256;
-  static constexpr size_t mat_k = 256;
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 250880;
+  static constexpr size_t mat_k = 1792;
   static constexpr size_t wg_m = 1;
-  static constexpr size_t wg_n = 32;
+  static constexpr size_t wg_n = 2048;
   static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+};
+
+class Test2f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096 * 3;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 128;
+  static constexpr size_t sg_m = 4;
   static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
+  static constexpr size_t sg_k = 32;
+  static constexpr uint32_t local_kslicing = 4;
   static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::col_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+};
+
+class Test2fx1 : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 32;
+  static constexpr size_t mat_n = 4096 * 3;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 32;
+  static constexpr size_t wg_n = 128;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 32;
+  static constexpr uint32_t local_kslicing = 4;
+  // static constexpr uint32_t global_kslicing = 2;
+  // global_kslicing > 1 fails to compile on DG2
+  static constexpr mem_layout layout_a = mem_layout::col_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+};
+
+class Test3f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 16384;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 64;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+};
+
+class Test4f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 1024;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
 };
 
-class Test4 : public TestBase {
+class Test4x : public TestBaseFP16x {
  public:
   static constexpr size_t mat_m = 1024;
   static constexpr size_t mat_n = 4096;
   static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 16 * 1;
-  static constexpr size_t wg_n = 32 * 32;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+};
+
+class Test4x1 : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 1024;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 16 * 2;
+  static constexpr size_t wg_n = 32 * 16;
   static constexpr size_t sg_m = 16;
   static constexpr size_t sg_n = 32;
-  static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
+  static constexpr size_t sg_k = 16;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
-  static constexpr mma_engine engine = mma_engine::xmx;
 };
 
-class Test5 : public TestBase {
+class Test5f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 1024;
   static constexpr size_t mat_n = 4096;
   static constexpr size_t mat_k = 4096;
   static constexpr size_t wg_m = 32;
-  static constexpr size_t wg_n = 32 * 4;
-  static constexpr size_t sg_m = 1;
+  static constexpr size_t wg_n = 32 * 8;
+  static constexpr size_t sg_m = 16;
   static constexpr size_t sg_n = 32;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
-  static constexpr mma_engine engine = mma_engine::fpu;
 };
 
-class Test6 : public TestBase {
+
+class Test6f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 96;
   static constexpr size_t mat_n = 256;
@@ -178,16 +243,11 @@ class Test6 : public TestBase {
   static constexpr size_t sg_m = 24;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test7 : public TestBase {
+
+class Test7f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 80;
   static constexpr size_t mat_n = 256;
@@ -197,17 +257,11 @@ class Test7 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test8 : public TestBase {
+class Test8f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 256;
   static constexpr size_t mat_n = 256;
@@ -218,16 +272,11 @@ class Test8 : public TestBase {
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
   static constexpr uint32_t global_kslicing = 2;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test9 : public TestBase {
+class Test9f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 256;
   static constexpr size_t mat_n = 256;
@@ -238,16 +287,11 @@ class Test9 : public TestBase {
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
   static constexpr uint32_t global_kslicing = 4;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test10 : public TestBase {
+class Test10f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 4;
   static constexpr size_t mat_n = 4096;
@@ -257,17 +301,11 @@ class Test10 : public TestBase {
   static constexpr size_t sg_m = 8;
   static constexpr size_t sg_n = 16;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test11 : public TestBase {
+class Test11f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 4;
   static constexpr size_t mat_n = 4096;
@@ -277,17 +315,11 @@ class Test11 : public TestBase {
   static constexpr size_t sg_m = 8;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test12 : public TestBase {
+class Test12f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 4;
   static constexpr size_t mat_n = 16384;
@@ -297,17 +329,11 @@ class Test12 : public TestBase {
   static constexpr size_t sg_m = 8;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test13 : public TestBase {
+class Test13f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 128;
   static constexpr size_t mat_n = 4096;
@@ -317,17 +343,11 @@ class Test13 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test14 : public TestBase {
+class Test14f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 4;
   static constexpr size_t mat_n = 50400;
@@ -337,17 +357,11 @@ class Test14 : public TestBase {
   static constexpr size_t sg_m = 8;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test15 : public TestBase {
+class Test15f : public TestBaseFP16f {
  public:
   static constexpr size_t mat_m = 128;
   static constexpr size_t mat_n = 4096;
@@ -357,17 +371,11 @@ class Test15 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test16 : public TestBase {
+class Test16x : public TestBaseFP16x {
  public:
   static constexpr size_t mat_m = 128;
   static constexpr size_t mat_n = 50400;
@@ -377,17 +385,11 @@ class Test16 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test17 : public TestBase {
+class Test17x : public TestBaseFP16x {
  public:
   static constexpr size_t mat_m = 128;
   static constexpr size_t mat_n = 256;
@@ -397,18 +399,11 @@ class Test17 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 32;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
-  static constexpr mma_engine engine = mma_engine::fpu;
 };
 
-class Test18 : public TestBase {
+class Test18x : public TestBaseFP16x {
  public:
   static constexpr size_t mat_m = 128;
   static constexpr size_t mat_n = 256;
@@ -418,18 +413,11 @@ class Test18 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 32;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
-  static constexpr mma_engine engine = mma_engine::fpu;
 };
 
-class Test19 : public TestBase {
+class Test19x : public TestBaseFP16x {
  public:
   static constexpr size_t mat_m = 128;
   static constexpr size_t mat_n = 256;
@@ -439,15 +427,8 @@ class Test19 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 32;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::col_major;
-  using data_type_a = fp16;
-  using data_type_b = fp16;
-  using data_type_c = fp16;
-  using data_type_acc = float;
-  static constexpr mma_engine engine = mma_engine::fpu;
 };
 
 template
@@ -488,5 +469,6 @@ using fp16_gemm_func = fp16_gemm_test_func<
     Test::layout_b,
     Test::global_kslicing,
    Test::local_kslicing,
+    Test::wg_num_n,
     Test::engine,
     Test::gpu_arch>;
diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp
index 191213794..30877c12b 100644
--- a/tests/integration/gemm/fp16/kernel_func.hpp
+++ b/tests/integration/gemm/fp16/kernel_func.hpp
@@ -37,12 +37,15 @@ template <
     mem_layout layout_b,
     uint32_t global_kslicing,
     uint32_t local_kslicing,
+    uint32_t wg_num_n,
     mma_engine engine,
     gpu_arch gpu_arch>
 struct fp16_gemm_test_func {
   using tile_shape = tile_shape_t;
-  static constexpr uint32_t periodic_sync_interval = 1 ; //8;
-  static constexpr uint32_t prefetch_distance = 3 ;//256 / (sg_k * sizeof(dtype_a));
+  static constexpr uint32_t periodic_sync_interval =
+      (gpu_arch == gpu_arch::XeHpc ? 8 : 0);
+  static constexpr uint32_t prefetch_distance =
+      (gpu_arch == gpu_arch::XeHpc ? 3 : 0);
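Review note (not part of the patch): this hunk is the heart of the multi-target barrier change. Periodic K-loop re-synchronization and deep software prefetch stay enabled only on XeHpc, which presumably has the named-barrier support the sync relies on, and are disabled (0) on other targets. Restated standalone — only the values 8/3/0 and the XeHpc gating come from the hunk; the enum below is an illustrative stand-in, not the library's definition:

    // Reviewer sketch of the gating above (illustrative enum, real values).
    enum class gpu_arch { XeLpg, XeHpg, XeHpc };
    constexpr unsigned periodic_sync_interval(gpu_arch arch) {
      return arch == gpu_arch::XeHpc ? 8 : 0;  // 0 = no periodic barrier sync
    }
    constexpr unsigned prefetch_distance(gpu_arch arch) {
      return arch == gpu_arch::XeHpc ? 3 : 0;  // 0 = software prefetch off
    }
    static_assert(prefetch_distance(gpu_arch::XeHpg) == 0);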
 
   using compute_attr = typename std::conditional<
       (engine == mma_engine::fpu),
@@ -70,6 +73,8 @@ struct fp16_gemm_test_func {
       mem_desc_output_c>;
 
   using group_swizzle = gpu::xetla::kernel::group_swizzle_default;
+  // using group_swizzle = gpu::xetla::kernel::group_swizzle_snake;
   using dispatch_policy = dispatch_policy_kslicing;
diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp
index eabc8915d..32ef12461 100644
--- a/tests/integration/gemm/fp16/main.cpp
+++ b/tests/integration/gemm/fp16/main.cpp
@@ -32,25 +32,7 @@ TYPED_TEST_P(fp16_gemm_test, esimd) {
       esimd_compile_string);
 }
 REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd);
-using tests = ::testing::Types<
-    Test0>;
-    // Test1,
-    // Test2,
-    // Test3>;
-    // Test4,
-    // Test5,
-    // Test6,
-    // Test7,
-    // Test8,
-    // Test9,
-    // Test10,
-    // Test11,
-    // Test12,
-    // Test13,
-    // Test14,
-    // Test15,
-    // Test16,
-    // Test17,
-    // Test18,
-    // Test19>;
+using tests =
+    ::testing::Types;
+
 INSTANTIATE_TYPED_TEST_SUITE_P(fp16_gemm_test_suite, fp16_gemm_test, tests);
diff --git a/tests/integration/gemm/fp32/common.hpp b/tests/integration/gemm/fp32/common.hpp
index 645253415..7ce5098fc 100644
--- a/tests/integration/gemm/fp32/common.hpp
+++ b/tests/integration/gemm/fp32/common.hpp
@@ -205,7 +205,7 @@ class Test8 : public TestBase {
   static constexpr uint32_t global_kslicing = 2;
   static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
   using data_type_a = float;
   using data_type_b = float;
   using data_type_c = float;
@@ -248,7 +248,7 @@ class Test10 : public TestBase {
   static constexpr uint32_t global_kslicing = 2;
   static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
   using data_type_a = float;
   using data_type_b = float;
   using data_type_c = float;
@@ -258,9 +258,9 @@ class Test11 : public TestBase {
 class Test11 : public TestBase {
  public:
   static constexpr size_t batch_size = 35;
-  static constexpr size_t mat_m = 4193;
-  static constexpr size_t mat_k = 1134;
-  static constexpr size_t mat_n = 686;
+  static constexpr size_t mat_m = 4192;
+  static constexpr size_t mat_k = 1136;
+  static constexpr size_t mat_n = 688;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
@@ -270,7 +270,6 @@ class Test11 : public TestBase {
   static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  static constexpr mma_engine engine = mma_engine::xmx;
   using data_type_a = float;
   using data_type_b = float;
   using data_type_c = float;
diff --git a/tests/integration/gemm/fp32/main.cpp b/tests/integration/gemm/fp32/main.cpp
index a04d82e01..3a9badc9f 100644
--- a/tests/integration/gemm/fp32/main.cpp
+++ b/tests/integration/gemm/fp32/main.cpp
@@ -36,7 +36,7 @@ REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd);
 using tests = ::testing::Types<
     Test1,
     Test2,
-    // Test3, // TODO(Yi): Fix this case
+    Test3,
     Test4,
     Test5,
     Test6,
diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp
index 0597af758..218a8248d 100644
--- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp
+++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp
@@ -259,16 +259,16 @@ class test9_xehpg {
   using data_type_c = fp16;
 };
 
-class test1_xelpg {
+class test1_xelpg_1x12288x4096 {
  public:
   // Extract the parameters required by different test cases
   static constexpr size_t mat_m = 1;
   static constexpr size_t mat_n = 4096 * 3;
   static constexpr size_t mat_k = 4096 * 1;
   static constexpr size_t wg_m = 1;
-  static constexpr size_t wg_n = 32 * 4;
+  static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 1;
-  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 16;
   static constexpr size_t dequant_s = 128;
 
@@ -282,6 +282,79 @@
   using data_type_b = int4x2;
   using data_type_c = fp16;
 };
+
+class test1_xelpg_1x4096x11008 {
+ public:
+  // Extract the parameters required by different test cases
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 11008;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr size_t dequant_s = 128;
+
+  static constexpr size_t local_kslicing = 8;
+  static constexpr size_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr mma_engine mma_eng = mma_engine::fpu;
+  static constexpr gpu_arch arch = gpu_arch::XeLpg;
+  using data_type_a = fp16;
+  using data_type_b = int4x2;
+  using data_type_c = fp16;
+};
+
+class test1_xelpg_1x4096x4096 {
+ public:
+  // Extract the parameters required by different test cases
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 32;
+  static constexpr size_t dequant_s = 128;
+
+  static constexpr size_t local_kslicing = 8;
+  static constexpr size_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr mma_engine mma_eng = mma_engine::fpu;
+  static constexpr gpu_arch arch = gpu_arch::XeLpg;
+  using data_type_a = fp16;
+  using data_type_b = int4x2;
+  using data_type_c = fp16;
+};
+
+class test1_xelpg_4x4096x4096 {
+ public:
+  // Extract the parameters required by different test cases
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096 * 1;
+  static constexpr size_t mat_k = 4096 * 1;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 64;
+  static constexpr size_t dequant_s = 128;
+
+  static constexpr size_t local_kslicing = 8;
+  static constexpr size_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr mma_engine mma_eng = mma_engine::fpu;
+  static constexpr gpu_arch arch = gpu_arch::XeLpg;
+  using data_type_a = fp16;
+  using data_type_b = int4x2;
+  using data_type_c = fp16;
+};
+
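Review note (not part of the patch): the new XeLpg configurations encode their m-n-k shape in the class name and split the work-group along N only (wg_m == sg_m); e.g. test1_xelpg_1x4096x11008 runs wg_n / sg_n = 256 / 64 = 4 subgroups, plus local_kslicing = 8 cooperating over K. A hypothetical compile-time check of that tiling invariant, for illustration only:

    // Hypothetical helper: a work-group tile must split evenly into
    // subgroup tiles (it is not part of the test suite).
    template <typename Test>
    inline constexpr bool tiles_evenly =
        (Test::wg_m % Test::sg_m == 0) && (Test::wg_n % Test::sg_n == 0);
    // Usage: static_assert(tiles_evenly<test1_xelpg_1x4096x11008>);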
 class test2_xelpg {
  public:
   // Extract the parameters required by different test cases
@@ -1019,7 +1092,11 @@ TYPED_TEST_P(dequantize_gemm_test, esimd) {
 }
 REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd);
 
-using tests = ::testing::Types;
+using tests = ::testing::Types<
+    test1_xelpg_1x12288x4096,
+    test1_xelpg_1x4096x11008,
+    test1_xelpg_1x4096x4096,
+    test1_xelpg_4x4096x4096>;
 
 INSTANTIATE_TYPED_TEST_SUITE_P(
     dequantize_gemm_test_suite,
diff --git a/tests/integration/gemm/int8/kernel_func.hpp b/tests/integration/gemm/int8/kernel_func.hpp
index 934e82a50..3a9e595ca 100644
--- a/tests/integration/gemm/int8/kernel_func.hpp
+++ b/tests/integration/gemm/int8/kernel_func.hpp
@@ -49,8 +49,8 @@ struct int8gemm_test_func {
       layout_b,
       mem_space::global,
       mem_space::global,
-      8,
-      8,
+      16,
+      16,
       dtype_acc,
       tile_shape,
       sg_k,
diff --git a/tests/integration/gemm/int8_quantization/kernel_func.hpp b/tests/integration/gemm/int8_quantization/kernel_func.hpp
index ef2ccd219..e708ead73 100644
--- a/tests/integration/gemm/int8_quantization/kernel_func.hpp
+++ b/tests/integration/gemm/int8_quantization/kernel_func.hpp
@@ -46,8 +46,8 @@ struct igemm_quantize_func {
       mem_layout_b,
       mem_space::global,
       mem_space::global,
-      8,
-      8,
+      16,
+      16,
       dtype_acc,
       tile_shape,
       sg_k,
diff --git a/tests/integration/gemm/unaligned_bf16/common.hpp b/tests/integration/gemm/unaligned_bf16/common.hpp
index d1b45c169..d141f50ad 100644
--- a/tests/integration/gemm/unaligned_bf16/common.hpp
+++ b/tests/integration/gemm/unaligned_bf16/common.hpp
@@ -45,11 +45,53 @@ class TestBase {
         mem_layout_a_str + "_" + mem_layout_b_str;
     return name;
   }
+
+  static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
+  // static constexpr gpu_arch arch_tag = gpu_arch::XeHpg;
+  static constexpr uint32_t global_kslicing = 1;
+  static constexpr uint32_t local_kslicing = 1;
+  static constexpr uint32_t lda_alignment = 1;
+  static constexpr uint32_t ldb_alignment = 1;
+  static constexpr uint32_t ldc_alignment = 1;
+};
+
+class TestBaseBF16x : public TestBase {
+ public:
+  using data_type_a = bf16;
+  using data_type_b = bf16;
+  using data_type_c = bf16;
+  using data_type_acc = float;
+  static constexpr mma_engine engine = mma_engine::xmx;
+};
+
+class TestBaseBF16f : public TestBase {
+ public:
+  using data_type_a = bf16;
+  using data_type_b = bf16;
+  using data_type_c = bf16;
+  using data_type_acc = float;
+  static constexpr mma_engine engine = mma_engine::fpu;
+};
+
+class TestBaseFP16x : public TestBase {
+ public:
+  using data_type_a = fp16;
+  using data_type_b = fp16;
+  using data_type_c = fp16;
+  using data_type_acc = float;
   static constexpr mma_engine engine = mma_engine::xmx;
-  static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc;
 };
 
-class Test0 : public TestBase {
+class TestBaseFP16f : public TestBase {
+ public:
+  using data_type_a = fp16;
+  using data_type_b = fp16;
+  using data_type_c = fp16;
+  using data_type_acc = float;
+  static constexpr mma_engine engine = mma_engine::fpu;
+};
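Review note (not part of the patch): the *x/*f bases factor out the data types and MMA engine, while the alignment members added to TestBase default to 1 (fully unaligned) and are tightened per test by shadowing the constant; ordinary member lookup then resolves to the most-derived declaration. A standalone illustration of that pattern (hypothetical class names):

    // Sketch of the override pattern used by Test4x, Test11x, etc.
    struct Base {
      static constexpr unsigned lda_alignment = 1;  // unaligned default
      static constexpr unsigned ldb_alignment = 1;
    };
    struct AlignedA : Base {
      static constexpr unsigned lda_alignment = 8;  // shadows Base's value
    };
    static_assert(AlignedA::lda_alignment == 8);
    static_assert(AlignedA::ldb_alignment == 1);  // base default still applies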
+
+class Test0x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
   static constexpr size_t mat_n = 257;
@@ -59,96 +101,68 @@
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
 
-class Test1 : public TestBase {
+class Test1x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
-  static constexpr size_t mat_n = 255;
-  static constexpr size_t mat_k = 255;
+  static constexpr size_t mat_n = 1023;
+  static constexpr size_t mat_k = 767;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
 
-class Test2 : public TestBase {
+class Test2x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
-  static constexpr size_t mat_n = 251;
-  static constexpr size_t mat_k = 253;
+  static constexpr size_t mat_n = 1011;
+  static constexpr size_t mat_k = 511;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::col_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
 
-class Test3 : public TestBase {
+class Test3x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
-  static constexpr size_t mat_n = 251;
-  static constexpr size_t mat_k = 253;
+  static constexpr size_t mat_n = 767;
+  static constexpr size_t mat_k = 1023;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::col_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
 
-class Test4 : public TestBase {
+class Test4x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 257;
   static constexpr size_t mat_n = 257;
-  static constexpr size_t mat_k = 259;
+  static constexpr size_t mat_k = 256;
   static constexpr size_t wg_m = 16;
   static constexpr size_t wg_n = 32;
   static constexpr size_t sg_m = 8;
   static constexpr size_t sg_n = 16;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = float;
-  using data_type_acc = float;
+  static constexpr uint32_t lda_alignment = 8;
 };
 
-class Test5 : public TestBase {
+
+class Test5x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 191;
   static constexpr size_t mat_n = 251;
@@ -158,17 +172,12 @@ class Test5 : public TestBase {
   static constexpr size_t sg_m = 24;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
   using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test6 : public TestBase {
+class Test6x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 93;
   static constexpr size_t mat_n = 253;
@@ -178,16 +187,12 @@ class Test6 : public TestBase {
   static constexpr size_t sg_m = 24;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
   using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test7 : public TestBase {
+
+class Test7x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 80;
   static constexpr size_t mat_n = 251;
@@ -197,17 +202,12 @@ class Test7 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
   using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test8 : public TestBase {
+class Test8x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 257;
   static constexpr size_t mat_n = 255;
@@ -217,17 +217,14 @@ class Test8 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 2;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
+  // static constexpr uint32_t global_kslicing = 2; // fails to compile on DG2
+  static constexpr uint32_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
   using data_type_c = float;
-  using data_type_acc = float;
 };
 
-class Test9 : public TestBase {
+class Test9x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 251;
   static constexpr size_t mat_n = 253;
@@ -237,14 +234,332 @@ class Test9 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 32;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 2;
-  static constexpr uint32_t global_kslicing = 4;
+  // static constexpr uint32_t global_kslicing = 4; // fails to compile on DG2
+  static constexpr uint32_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::col_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+};
+
+class Test10x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 253;
+  static constexpr size_t mat_k = 259;
+  static constexpr size_t wg_m = 32;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 8;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
+};
+
+class Test11x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 1025;
+  static constexpr size_t mat_k = 256;
+  static constexpr size_t wg_m = 8;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 8;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+};
+
+class Test12x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4095;
+  static constexpr size_t mat_n = 4097;
+  static constexpr size_t mat_k = 4091;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+};
+
+class Test13x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4095;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test14x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4097;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+};
+
+class Test15x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test16x : public TestBaseBF16x { // Gets better perf on DG2
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 32;
+  static constexpr size_t wg_n = 512;
+  static constexpr size_t sg_m = 16;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test17x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test18x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 1024;
+  static constexpr size_t mat_n = 2560;
+  static constexpr size_t mat_k = 5120;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::col_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test19x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 12288;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 8; // DG2 will fail on wg_m = 4
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 8;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 4;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test19f : public TestBaseBF16f {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 12288;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 4;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test20f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 4;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test20x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test21x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test21f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test22f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test22x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test23f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test23x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
 };
 
 template
@@ -283,6 +598,10 @@ using unaligned_gemm_func = unaligned_gemm_test_func<
     Test::sg_k,
     Test::layout_a,
     Test::layout_b,
+    Test::lda_alignment,
+    Test::ldb_alignment,
+    Test::ldc_alignment,
     Test::global_kslicing,
     Test::local_kslicing,
-    Test::engine>;
+    Test::engine,
+    Test::arch_tag>;
diff --git a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
index 780c092e9..d45ddc0b7 100644
--- a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
+++ b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
@@ -35,13 +35,19 @@ template <
     uint32_t sg_k,
     mem_layout layout_a,
     mem_layout layout_b,
+    uint32_t lda_alignment,
+    uint32_t ldb_alignment,
+    uint32_t ldc_alignment,
     uint32_t global_kslicing,
     uint32_t local_kslicing,
-    mma_engine engine>
+    mma_engine engine,
+    gpu_arch arch_tag>
 struct unaligned_gemm_test_func {
   using tile_shape = tile_shape_t;
-  static constexpr uint32_t periodic_sync_interval = 8;
-  static constexpr uint32_t prefetch_distance = 3;
+  static constexpr uint32_t periodic_sync_interval =
+      (arch_tag == gpu_arch::XeHpc ? 8 : 0);
+  static constexpr uint32_t prefetch_distance =
+      (arch_tag == gpu_arch::XeHpc ? 3 : 0);
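Review note (not part of the patch): in the hunk just below, the per-test lda/ldb alignments flow into gemm_selector_t in place of the old hardcoded 1, 1. The presumable intent is that an alignment of 8 elements lets the selector keep fast block-style loads, while alignment 1 forces the unaligned-safe path. Schematically — the tag types here are hypothetical, not the library's API:

    #include <type_traits>
    struct block_load_tag {};       // hypothetical tags for illustration;
    struct unaligned_load_tag {};   // the real choice is inside gemm_selector_t
    template <unsigned alignment>
    using load_path_t = std::conditional_t<
        (alignment >= 8), block_load_tag, unaligned_load_tag>;
    static_assert(std::is_same_v<load_path_t<8>, block_load_tag>);
    static_assert(std::is_same_v<load_path_t<1>, unaligned_load_tag>);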
 
   using gemm_t = typename gemm_selector_t<
       dtype_a,
       dtype_b,
@@ -49,23 +55,22 @@ struct unaligned_gemm_test_func {
       layout_b,
       mem_space::global,
       mem_space::global,
-      1,
-      1,
+      lda_alignment,
+      ldb_alignment,
       dtype_acc,
       tile_shape,
       sg_k,
       engine,
-      gpu_arch::XeHpc,
+      arch_tag,
       prefetch_distance,
       periodic_sync_interval>::gemm;
 
   using epilogue_t = epilogue_t<
-      epilogue_policy_unaligned,
+      epilogue_policy_unaligned,
       tile_shape,
-      mem_desc_t>;
+      mem_desc_t>;
 
-  using group_swizzle =
-      gpu::xetla::kernel::group_swizzle_default;
+  using group_swizzle = gpu::xetla::kernel::group_swizzle_default;
   using dispatch_policy = dispatch_policy_kslicing;
 
   using gemm_op_t = gemm_universal_t;
diff --git a/tests/integration/gemm/unaligned_bf16/main.cpp b/tests/integration/gemm/unaligned_bf16/main.cpp
index 3d8778a3c..9ceaf695e 100644
--- a/tests/integration/gemm/unaligned_bf16/main.cpp
+++ b/tests/integration/gemm/unaligned_bf16/main.cpp
@@ -38,17 +38,37 @@ TYPED_TEST_P(unaligned_gemm_test, esimd) {
 }
 REGISTER_TYPED_TEST_SUITE_P(unaligned_gemm_test, esimd);
 using tests = ::testing::Types<
-    Test0,
-    Test1,
-    Test2,
-    Test3,
-    // Test4,
-    Test5,
-    Test6,
-    Test7,
-    Test8>;
+    Test0x,
+    Test1x,
+    Test2x,
+    Test3x,
+    Test4x,
+    Test5x,
+    Test6x,
+    Test7x,
+    Test8x,
+    Test9x,
+    Test10x,
+    Test11x,
+    Test12x,
+    Test13x,
+    Test14x,
+    Test15x,
+    Test16x,
+    Test17x,
+    Test18x,
+    Test19f,
+    Test19x,
+    Test20x,
+    Test20f,
+    Test21x,
+    Test21f,
+    Test22x,
+    Test22f,
+    Test23x,
+    Test23f>;
 INSTANTIATE_TYPED_TEST_SUITE_P(
     unaligned_gemm_test_suite,
     unaligned_gemm_test,
-    tests);
\ No newline at end of file
+    tests);
diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp
index 5edf0a0d6..d43f90815 100644
--- a/tests/integration/gemv/int4/main.cpp
+++ b/tests/integration/gemv/int4/main.cpp
@@ -40,7 +40,6 @@ class test_col_major_1 {
   static constexpr size_t sg_n = 1;
   static constexpr size_t sg_k = 512 / sg_m;
   static constexpr size_t dequant_s = 128;
-  // static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
   static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
 
   static constexpr size_t local_kslicing = 1;
@@ -53,6 +52,7 @@ class test_col_major_1 {
   using data_type_b = int4x8;
   using data_type_c = scalar_t;
 };
+
 class test_col_major_2 {
  public:
   // Extract the parameters required by different test cases
@@ -63,8 +63,9 @@ class test_col_major_2 {
   static constexpr size_t wg_n = 1;
   static constexpr size_t sg_m = 4;
   static constexpr size_t sg_n = 1;
-  static constexpr size_t sg_k = 1024;
-  static constexpr size_t dequant_s = 4096;
+  static constexpr size_t sg_k = 512 / sg_m;
+  static constexpr size_t dequant_s = 128;
+  static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
 
   static constexpr size_t local_kslicing = 1;
   static constexpr size_t global_kslicing = 1;
diff --git a/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp b/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp
index ff5ef1074..bc7f63af9 100644
--- a/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp
+++ b/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp
@@ -586,13 +586,13 @@ class int4_mlp_gate_mul_up_fwd_t {
   gemm_args_t up_proj_args, gate_proj_args;
 
   if constexpr (
-      gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
+      gemm_t::compute_policy::quant_mode == quant_mode::I4_SYM) {
     ASSIGN_SYM_GEMM_ARG(
         up_proj_args, mem_desc_up_proj, mem_desc_up_proj_scale)
     ASSIGN_SYM_GEMM_ARG(
         gate_proj_args, mem_desc_gate_proj, mem_desc_gate_proj_scale)
   } else if constexpr (
-      gemm_t::compute_policy::quant_mode == quant_mode::S4_ASYM) {
+      gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
     DEF_ZP_MEM_DESC(mem_desc_up_zero_pt, up_proj_zero_pt_base)
     DEF_ZP_MEM_DESC(mem_desc_gate_zero_pt, gate_proj_zero_pt_base)
     ASSIGN_ASYM_GEMM_ARG(
diff --git a/tests/integration/mlp/int4/mlp.cpp b/tests/integration/mlp/int4/mlp.cpp
index 1cd765af8..15e4422ce 100644
--- a/tests/integration/mlp/int4/mlp.cpp
+++ b/tests/integration/mlp/int4/mlp.cpp
@@ -41,8 +41,7 @@ class test_col_major_1 {
   static constexpr size_t sg_n = 1;
   static constexpr size_t sg_k = 512;
   static constexpr size_t dequant_s = 128;
-  static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
-  // static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
+  static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
 
   static constexpr size_t local_kslicing = 1;
   static constexpr size_t global_kslicing = 1;
@@ -67,8 +66,7 @@ class test_col_major_2 {
   static constexpr size_t sg_n = 1;
   static constexpr size_t sg_k = 1024 / 4;
   static constexpr size_t dequant_s = 128;
-  // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
-  static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
+  static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
 
   static constexpr size_t local_kslicing = 1;
   static constexpr size_t global_kslicing = 1;
@@ -144,7 +142,7 @@ int int4_mlp_result_validate(
 }
 
 template <
-    quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
+    quant_mode quant_mode = quant_mode::I4_SYM,
     typename data_type_acc_in = fp16,
     typename data_type_b,
     typename data_type_scale,
@@ -158,7 +156,7 @@ std::vector convert_int4(
   int8_t zero_pt_i8 = zero_pt & 0xf;
   for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
     int8_t dequant_8bit = data_b & 0xf;
-    if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
+    if constexpr (quant_mode == quant_mode::I4_SYM) {
       dequant_fp16[i] = scale * (dequant_8bit - 8);
     } else {
       dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
@@ -171,7 +169,7 @@ template <
     size_t dequant_s,
     mem_layout layout_b = mem_layout::col_major,
-    quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
+    quant_mode quant_mode = quant_mode::I4_SYM,
     typename data_type_acc_in = fp16,
     typename data_type_b,
     typename data_type_scale,
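Review note (not part of the patch): in the I4_SYM branch above, the stored nibble q in [0, 15] is treated as offset-binary, so dequantization is scale * (q - 8) with a signed range of [-8, 7]; the asymmetric branch subtracts the stored zero point instead. A standalone worked example with made-up values:

    #include <cstdint>
    // I4_SYM dequantization: maps nibble q in [0, 15] to q - 8 in [-8, 7].
    inline float dequant_i4_sym(std::uint8_t nibble, float scale) {
      return scale * (static_cast<int>(nibble & 0xF) - 8);
    }
    // e.g. with scale = 0.5f and packed byte 0x2B:
    //   low nibble  0xB (11) -> 0.5f * (11 - 8) =  1.5f
    //   high nibble 0x2 ( 2) -> 0.5f * ( 2 - 8) = -3.0f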
diff --git a/tests/integration/softmax/softmax_bwd_kernel.hpp b/tests/integration/softmax/softmax_bwd_kernel.hpp
index e556c67c4..6341f48ac 100644
--- a/tests/integration/softmax/softmax_bwd_kernel.hpp
+++ b/tests/integration/softmax/softmax_bwd_kernel.hpp
@@ -34,12 +34,13 @@ struct softmax_bwd_test_func {
   using work_group_t = typename tile_shape::work_group_t;
   static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
   static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
+  static constexpr gpu_arch arch_tag = gpu_arch::XeHpg;
 
   using in_block_size = subgroup::get_load_block_size_auto<
       dtype_in,
       sg_n,
       sg_m,
-      gpu_arch::XeHpc,
+      arch_tag,
       mem_layout::row_major,
       reg_layout::tiled>;
   static constexpr uint32_t tile_size_x = sg_n;
@@ -62,7 +63,7 @@ struct softmax_bwd_test_func {
       mem_desc_in_t,
       tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mat_out_t = subgroup::tile_t;
 
   using mem_desc_out_t =
@@ -71,10 +72,10 @@ struct softmax_bwd_test_func {
       mem_desc_out_t,
       tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
   using softmax_bwd_t = group::softmax_t<
-      group::softmax_policy_bwd,
+      group::softmax_policy_bwd,
       tile_shape>;
   static constexpr uint32_t barrier_count =
       softmax_bwd_t::get_barrier_count::count;
diff --git a/tests/integration/softmax/softmax_fwd_kernel.hpp b/tests/integration/softmax/softmax_fwd_kernel.hpp
index abbd14696..144a9e90b 100644
--- a/tests/integration/softmax/softmax_fwd_kernel.hpp
+++ b/tests/integration/softmax/softmax_fwd_kernel.hpp
@@ -38,12 +38,13 @@ struct softmax_fwd_test_func {
   using work_group_t = typename tile_shape::work_group_t;
   static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
   static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
+  static constexpr gpu_arch arch_tag = gpu_arch::XeHpg;
 
   using in_block_size = subgroup::get_load_block_size_auto<
       dtype_in,
       sg_n,
       sg_m,
-      gpu_arch::XeHpc,
+      arch_tag,
       mem_layout::row_major,
       reg_layout::tiled>;
   static constexpr uint32_t tile_size_x = sg_n;
@@ -60,22 +61,21 @@ struct softmax_fwd_test_func {
       reg_layout::tiled>;
   using matAcc_t = subgroup::tile_t;
   using mat_in_t = subgroup::tile_t;
-
+
   using mat_in_payload_t = subgroup::mem_payload_t<
       mem_desc_in_t,
       tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
   using mat_out_t = subgroup::tile_t;
   using mat_out_payload_t = subgroup::mem_payload_t<
       mem_desc_in_t,
       tile_desc_t,
       subgroup::msg_type_v,
-      gpu_arch::XeHpc>;
+      arch_tag>;
 
-  using softmax_fwd_t = group::softmax_t<
-      group::softmax_policy_fwd,
-      tile_shape>;
+  using softmax_fwd_t = group::
+      softmax_t, tile_shape>;
   static constexpr uint32_t barrier_count =
       softmax_fwd_t::get_barrier_count::count;
   static constexpr uint32_t slm_size = softmax_fwd_t::get_slm_size::size;
diff --git a/tests/utils/execution.hpp b/tests/utils/execution.hpp
index 46040dddd..85d722ce9 100644
--- a/tests/utils/execution.hpp
+++ b/tests/utils/execution.hpp
@@ -315,7 +315,6 @@ class dispatch_arch {
     switch (deviceArch) {
       case ENS::architecture::intel_gpu_pvc:
         return F::exec(std::forward(args)...);
-        return;
       case ENS::architecture::intel_gpu_dg2_g10:
       case ENS::architecture::intel_gpu_dg2_g11:
      case ENS::architecture::intel_gpu_dg2_g12: