diff --git a/examples/02_basic_gemm/basic_gemm.cpp b/examples/02_basic_gemm/basic_gemm.cpp
index 8be0f6d7e..838bab930 100644
--- a/examples/02_basic_gemm/basic_gemm.cpp
+++ b/examples/02_basic_gemm/basic_gemm.cpp
@@ -114,8 +114,10 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
   // wrap the nd_range to XeTLA range

   // Performance tuning setting based on different shapes
-  static constexpr uint32_t periodic_sync_interval = 0;
-  static constexpr uint32_t prefetch_distance = 1;
+  static constexpr uint32_t periodic_sync_interval =
+      (arch_tag == gpu_arch::XeHpc) ? 8 : 0;
+  static constexpr uint32_t prefetch_distance =
+      (arch_tag == gpu_arch::XeHpc) ? 3 : 1;

   // should be larger than 8
   static constexpr uint32_t k_stride = 32;
diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp
index a4d487ff1..66bba66c6 100644
--- a/include/common/core/arch_config.hpp
+++ b/include/common/core/arch_config.hpp
@@ -198,12 +198,14 @@ struct mma_attr_t<
     m,
     std::enable_if_t>> {
   using dpas_attr = dpas_attr_t;
+  using load_store_attr = load_store_attr_t;
   static constexpr uint32_t mma_m_in_elem =
       (m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
   static constexpr uint32_t blk_m_in_elem = 16;

   static constexpr uint32_t mma_n_in_elem = dpas_attr::n_in_elem;
-  [[maybe_unused]] static constexpr uint32_t blk_n_in_bytes = 64;
+  [[maybe_unused]] static constexpr uint32_t blk_n_in_bytes =
+      load_store_attr::max_trans_load_width_in_bytes;

   static constexpr uint32_t mma_k_in_bytes = dpas_attr::k_in_bytes;
   static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
@@ -224,8 +226,7 @@ struct mma_attr_t<
       load_store_attr::max_trans_load_width_in_bytes;

   [[maybe_unused]] static constexpr uint32_t mma_n_in_elem = 16;
-  static constexpr uint32_t blk_n_in_bytes =
-      register_bytes_t::reg_in_bytes;
+  static constexpr uint32_t blk_n_in_bytes = blk_k_in_bytes;
 };

 template
diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
index cf3ab0173..3f52569ab 100644
--- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
+++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
@@ -176,6 +176,8 @@ class gemm_t<
   // note: plane format, row-major
   // note: 4bit x 2, row-major
+  static_assert(tile_size_x_b % pack_ratio == 0);
+  static_assert(block_size_x_b % pack_ratio == 0);
   using matB_tile_desc_t = subgroup::tile_desc_t<
       tile_size_x_b / pack_ratio,
       tile_size_y_b,
diff --git a/include/group/epilogue/impl/unaligned_xe.hpp b/include/group/epilogue/impl/unaligned_xe.hpp
index c8ca5385a..de350b804 100644
--- a/include/group/epilogue/impl/unaligned_xe.hpp
+++ b/include/group/epilogue/impl/unaligned_xe.hpp
@@ -45,6 +45,8 @@ class epilogue_t<
   static constexpr uint32_t slm_size = mem_desc_c_t::is_local
       ? tile_shape::wg_tile_size_x * tile_shape::wg_tile_size_y
       : 0;
+  static constexpr bool ldc_align16 =
+      (mem_desc_c_t::alignment_in_bytes % 16 == 0);

   /// @brief Epilogue arguments.
   struct arguments_t {};
@@ -71,7 +73,7 @@ class epilogue_t<
  public:
   static constexpr msg_type msg_type_c =
-      (mem_space_c == mem_space::global ? msg_type::unaligned_2d
+      (mem_space_c == mem_space::global
+           ? (ldc_align16 ? msg_type::block_2d : msg_type::unaligned_2d)
           : msg_type::scatter);

   /// @brief Default epilogue.
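The epilogue change above keys the store message type off the alignment of the C surface: a 16-byte-aligned global pointer can use the faster 2D block messages, anything else falls back to the unaligned path, and SLM always uses scatter. A minimal standalone sketch of that selection logic follows; the enum values mirror the XeTLA names, but `pick_msg_type` itself is a hypothetical helper, not part of the library.

#include <cstdint>

// Stand-in enums mirroring the XeTLA names used in the patch.
enum class mem_space { global, local };
enum class msg_type { block_2d, unaligned_2d, scatter };

// Pick the message type the way the patched epilogue does: 2D block
// loads/stores require 16-byte alignment, so an aligned global surface gets
// block_2d, an unaligned one falls back to unaligned_2d, and local (SLM)
// memory always uses scatter.
template <mem_space Space, uint32_t AlignmentInBytes>
constexpr msg_type pick_msg_type() {
  constexpr bool align16 = (AlignmentInBytes % 16 == 0);
  return Space == mem_space::global
      ? (align16 ? msg_type::block_2d : msg_type::unaligned_2d)
      : msg_type::scatter;
}

static_assert(pick_msg_type<mem_space::global, 64>() == msg_type::block_2d);
static_assert(pick_msg_type<mem_space::global, 4>() == msg_type::unaligned_2d);
static_assert(pick_msg_type<mem_space::local, 64>() == msg_type::scatter);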
diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp
index f089f84d6..2dcf7e756 100644
--- a/include/group/gemm/compute_policy.hpp
+++ b/include/group/gemm/compute_policy.hpp
@@ -48,6 +48,7 @@ struct compute_policy_default_xmx<
     arch_tag_,
     std::enable_if_t>> {
   static constexpr gpu_arch arch_tag = arch_tag_;
+  static constexpr mma_engine mma_engine = mma_engine::xmx;

   using compute_attr = compute_attr_;
   using dtype_mma_acc = typename compute_attr::dtype_acc;
@@ -59,7 +60,8 @@ struct compute_policy_default_xmx<
   static constexpr int sync_freq = perf_tuning_knob::sync_freq;
   static constexpr int k_stride = perf_tuning_knob::k_stride;

-  using mma_attr = mma_attr_t;
+  using mma_attr = mma_attr_t;
+
   static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem;
   static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
   static constexpr uint32_t block_size_x_a =
@@ -107,6 +109,7 @@ struct compute_policy_default_fpu<
     arch_tag_,
     std::enable_if_t>> {
   static constexpr gpu_arch arch_tag = arch_tag_;
+  static constexpr mma_engine mma_engine = mma_engine::fpu;

   using compute_attr = compute_attr_;
   using dtype_mma_acc = typename compute_attr::dtype_acc;
@@ -118,17 +121,26 @@ struct compute_policy_default_fpu<
   static constexpr int sync_freq = perf_tuning_knob::sync_freq;
   static constexpr int k_stride = perf_tuning_knob::k_stride;

-  using mma_attr = mma_attr_t;
+  using mma_attr = mma_attr_t;

   static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem;
   static constexpr uint32_t block_bytes_x_a = mma_attr::blk_k_in_bytes;
   static constexpr uint32_t block_size_x_a =
       block_bytes_x_a / sizeof(dtype_mma_a);
+
   static constexpr uint32_t block_bytes_x_b = mma_attr::blk_n_in_bytes;
   static constexpr uint32_t block_size_x_b =
       block_bytes_x_b / sizeof(dtype_mma_b);
   static constexpr uint32_t block_size_y_b = block_size_x_a;
 };

+template <
+    typename compute_attr_,
+    typename perf_tuning_knob_,
+    gpu_arch arch_tag_>
+struct compute_policy_unaligned_fpu : public compute_policy_default_fpu<
+                                          compute_attr_,
+                                          perf_tuning_knob_,
+                                          arch_tag_> {};

 /// @} xetla_gemm

 } // namespace gpu::xetla::group
diff --git a/include/group/gemm/gemm.hpp b/include/group/gemm/gemm.hpp
index ac5d43f16..db2b4b3a4 100644
--- a/include/group/gemm/gemm.hpp
+++ b/include/group/gemm/gemm.hpp
@@ -30,5 +30,6 @@
 #include
 #include
 #include
+#include
 #include
 #include
diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp
index 9af4e366f..a18b2434d 100644
--- a/include/group/gemm/impl/default_fpu_xe.hpp
+++ b/include/group/gemm/impl/default_fpu_xe.hpp
@@ -160,6 +160,7 @@ class gemm_t<
       subgroup::tile_desc_t,
       1,
       arch_tag>;
+  matA_prefetch_payload_t matA_prefetch_payload;

   static constexpr reg_layout reg_layout_b = reg_layout::tiled;
   using matB_tile_desc_t = subgroup::tile_desc_t<
@@ -180,6 +181,7 @@ class gemm_t<
       subgroup::tile_desc_t,
       1,
       arch_tag>;
+  matB_prefetch_payload_t matB_prefetch_payload;

  public:
   using matAcc_tile_desc_t = subgroup::tile_desc_t<
@@ -202,7 +204,7 @@ class gemm_t<
   static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
   static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;

   static constexpr bool need_local_fence =
-      (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local);
+      (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local);

   [[maybe_unused]] xetla_nbarrier_t barrier_all;
   [[maybe_unused]] xetla_nbarrier_t nbarrier_a;
@@ -210,8 +212,9 @@ class gemm_t<
  public:
   static constexpr uint32_t barrier_count =
-      (enable_periodic_sync && arch_has_named_barrier) ?
-      barrier_count_x + barrier_count_y : 0;
+      (enable_periodic_sync && arch_has_named_barrier)
+          ? barrier_count_x + barrier_count_y
+          : 0;

   // current no slm path
   static constexpr uint32_t slm_size = 0;
@@ -293,18 +296,23 @@ class gemm_t<
   };

   inline void periodic_sync_init(
-      [[maybe_unused]] int32_t sg_idx,
-      [[maybe_unused]] int32_t sg_idy,
-      uint32_t nbarrier_base) {
+      [[maybe_unused]] int32_t sg_idx,
+      [[maybe_unused]] int32_t sg_idy,
+      uint32_t nbarrier_base) {
     if constexpr (enable_periodic_sync) {
       if constexpr (arch_has_named_barrier) {
-        nbarrier_a.init_nbarrier(
-            sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
-        nbarrier_b.init_nbarrier(
-            sg_idx + barrier_count_y + nbarrier_base,
-            nbarrier_role::producer_consumer);
-      } else {
-        barrier_all.init_nbarrier(nbarrier_base, nbarrier_role::producer_consumer);
+        if constexpr (wg_size_x > 1) {
+          nbarrier_a.init_nbarrier(
+              sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
+        }
+        if constexpr (wg_size_y > 1) {
+          nbarrier_b.init_nbarrier(
+              sg_idx + barrier_count_y + nbarrier_base,
+              nbarrier_role::producer_consumer);
+        }
+      } else if constexpr (wg_size > 1) {
+        barrier_all.init_nbarrier(
+            nbarrier_base, nbarrier_role::producer_consumer);
       }
     }
   }
@@ -319,7 +327,7 @@ class gemm_t<
       if constexpr (wg_size_y > 1) {
         nbarrier_b.arrive();
       }
-    } else {
+    } else if constexpr (wg_size > 1) {
       barrier_all.arrive();
     }
   }
@@ -336,13 +344,25 @@ class gemm_t<
       if constexpr (wg_size_y > 1) {
         nbarrier_b.wait();
       }
-      } else {
+      } else if constexpr (wg_size > 1) {
         barrier_all.wait();
       }
     }
   }

+  inline void prefetch_and_update_ab() {
+    subgroup::tile_prefetch(
+        matA_prefetch_payload);
+    subgroup::tile_prefetch(
+        matB_prefetch_payload);
+    SW_BARRIER();
+    matA_prefetch_payload.template update_tdesc(
+        matA_t::tile_size_x);
+    matB_prefetch_payload.template update_tdesc(
+        matB_t::tile_size_y);
+  }
+
   /// @brief Gets the subgroup-level tile offset x.
   /// @param g Is the workgroup of the current tile.
   /// @return Subgroup-level tile offset x.
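The sync helpers above now gate every barrier behind a compile-time check, so a 1x1 workgroup compiles to no synchronization at all, and architectures without named barriers fall back to one workgroup-wide barrier only when the workgroup actually has more than one subgroup. A hedged sketch of the guard pattern, with `nbarrier_stub` standing in for `xetla_nbarrier_t` (the stub and its no-op methods are invented for illustration):

#include <cstdint>

// Stub standing in for xetla_nbarrier_t; only the calls the sketch needs.
struct nbarrier_stub {
  uint32_t id = 0;
  void init_nbarrier(uint32_t i) { id = i; }
  void arrive() {}
  void wait() {}
};

// Every barrier path is compile-time gated: a barrier exists only if more
// than one subgroup can race on the shared data it protects.
template <bool has_named_barrier, uint32_t wg_size_x, uint32_t wg_size_y>
struct sync_stub {
  static constexpr uint32_t wg_size = wg_size_x * wg_size_y;
  nbarrier_stub nbarrier_a, nbarrier_b, barrier_all;

  void init(uint32_t base) {
    if constexpr (has_named_barrier) {
      if constexpr (wg_size_x > 1) nbarrier_a.init_nbarrier(base);
      if constexpr (wg_size_y > 1) nbarrier_b.init_nbarrier(base + wg_size_y);
    } else if constexpr (wg_size > 1) {
      barrier_all.init_nbarrier(base);  // single all-thread barrier fallback
    }
  }
  void arrive_wait() {
    if constexpr (has_named_barrier) {
      if constexpr (wg_size_x > 1) { nbarrier_a.arrive(); nbarrier_a.wait(); }
      if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); nbarrier_b.wait(); }
    } else if constexpr (wg_size > 1) {
      barrier_all.arrive();
      barrier_all.wait();
    }
  }
};

int main() {
  sync_stub<false, 1, 1> s;  // 1x1 workgroup: both calls compile to no-ops
  s.init(0);
  s.arrive_wait();
}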
@@ -400,51 +420,39 @@ class gemm_t<
     pre_processing.init(g, args.pre_processing_args);
     matA_payload_t matA_payload(args.matA_base_desc);
     matB_payload_t matB_payload(args.matB_base_desc);
-    matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, 0);
-    matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, 0);
+    matA_prefetch_payload.init(args.matA_base_desc, 0);
+    matB_prefetch_payload.init(args.matB_base_desc, 0);

     periodic_sync_init(sg_idx, sg_idy, nbarrier_base);

 #pragma unroll
     for (uint32_t i = 0; i < stages; i++) {
-      subgroup::tile_prefetch(
-          matA_prefetch_payload);
-      subgroup::tile_prefetch(
-          matB_prefetch_payload);
-      matA_prefetch_payload.template update_tdesc(
-          matA_t::tile_size_x);
-      matB_prefetch_payload.template update_tdesc(
-          matB_t::tile_size_y);
+      prefetch_and_update_ab();
     }

     for (uint32_t i = 0; i < args.inner_loop_count; i++) {
       periodic_sync_arrive(i);
-      SW_BARRIER();
-
       subgroup::tile_load(
           matA, matA_payload);
       subgroup::tile_load(
           matB, matB_payload);
+      SW_BARRIER();
       matA_payload.template update_tdesc(matA_t::tile_size_x);
       matB_payload.template update_tdesc(matB_t::tile_size_y);
+
       if constexpr (stages != 0) {
-        subgroup::tile_prefetch(
-            matA_prefetch_payload);
-        subgroup::tile_prefetch(
-            matB_prefetch_payload);
-        matA_prefetch_payload.template update_tdesc(
-            matA_t::tile_size_x);
-        matB_prefetch_payload.template update_tdesc(
-            matB_t::tile_size_y);
+        prefetch_and_update_ab();
       }
+
+      SW_BARRIER();
       matA_acc_t matA_acc;
       matB_acc_t matB_acc;
       subgroup::elemwise_cvt(matA_acc, matA);
       subgroup::elemwise_cvt(matB_acc, matB);
       pre_processing(matA_acc, matB_acc, matA, matB);
-      SW_BARRIER();
       tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
       SW_BARRIER();
+
+      periodic_sync_wait(i);
     }
     SW_BARRIER();
diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp
index c7e7856e7..c05aad31a 100644
--- a/include/group/gemm/impl/default_xmx_xe.hpp
+++ b/include/group/gemm/impl/default_xmx_xe.hpp
@@ -145,6 +145,7 @@ class gemm_t<
       subgroup::tile_desc_t,
       wg_size_x,
       arch_tag>;
+  matA_prefetch_payload_t matA_prefetch_payload;

   static constexpr reg_layout reg_layout_b = sizeof(dtype_b) < sizeof(uint32_t)
       ? reg_layout::vnni_tiled
       : reg_layout::tiled;
@@ -166,6 +167,7 @@ class gemm_t<
       subgroup::tile_desc_t,
       wg_size_y,
       arch_tag>;
+  matB_prefetch_payload_t matB_prefetch_payload;

  public:
   using matAcc_tile_desc_t = subgroup::tile_desc_t<
@@ -284,12 +286,16 @@ class gemm_t<
       uint32_t nbarrier_base) {
     if constexpr (enable_periodic_sync) {
       if constexpr (arch_has_named_barrier) {
-        nbarrier_a.init_nbarrier(
-            sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
-        nbarrier_b.init_nbarrier(
-            sg_idx + barrier_count_y + nbarrier_base,
-            nbarrier_role::producer_consumer);
-      } else {
+        if constexpr (wg_size_x > 1) {
+          nbarrier_a.init_nbarrier(
+              sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
+        }
+        if constexpr (wg_size_y > 1) {
+          nbarrier_b.init_nbarrier(
+              sg_idx + barrier_count_y + nbarrier_base,
+              nbarrier_role::producer_consumer);
+        }
+      } else if constexpr (wg_size > 1) {
         barrier_all.init_nbarrier(
             nbarrier_base, nbarrier_role::producer_consumer);
       }
@@ -306,7 +312,7 @@ class gemm_t<
       if constexpr (wg_size_y > 1) {
         nbarrier_b.arrive();
       }
-    } else {
+    } else if constexpr (wg_size > 1) {
       barrier_all.arrive();
     }
   }
@@ -323,13 +329,25 @@ class gemm_t<
       if constexpr (wg_size_y > 1) {
         nbarrier_b.wait();
       }
-      } else {
+      } else if constexpr (wg_size > 1) {
         barrier_all.wait();
       }
     }
   }

+  inline void prefetch_and_update_ab() {
+    subgroup::tile_prefetch(
+        matA_prefetch_payload);
+    subgroup::tile_prefetch(
+        matB_prefetch_payload);
+    SW_BARRIER();
+    matA_prefetch_payload.template update_tdesc(
+        matA_t::tile_size_x);
+    matB_prefetch_payload.template update_tdesc(
+        matB_t::tile_size_y);
+  }
+
   /// @brief Gets the subgroup-level tile offset x.
   /// @param g Is the workgroup of the current tile.
   /// @return Subgroup-level tile offset x.
@@ -394,21 +412,14 @@ class gemm_t<
     pre_processing.init(g, args.pre_processing_args);
     matA_payload_t matA_payload(args.matA_base_desc);
     matB_payload_t matB_payload(args.matB_base_desc);
-    matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, sg_idx);
-    matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, sg_idy);
+    matA_prefetch_payload.init(args.matA_base_desc, sg_idx);
+    matB_prefetch_payload.init(args.matB_base_desc, sg_idy);

     periodic_sync_init(sg_idx, sg_idy, nbarrier_base);

 #pragma unroll
     for (uint32_t i = 0; i < stages; i++) {
-      subgroup::tile_prefetch(
-          matA_prefetch_payload);
-      subgroup::tile_prefetch(
-          matB_prefetch_payload);
-      matA_prefetch_payload.template update_tdesc(
-          matA_t::tile_size_x);
-      matB_prefetch_payload.template update_tdesc(
-          matB_t::tile_size_y);
+      prefetch_and_update_ab();
     }

     for (uint32_t i = 0; i < args.inner_loop_count; i++) {
@@ -418,31 +429,23 @@ class gemm_t<
           matB, matB_payload);
       subgroup::tile_load(
           matA, matA_payload);
-      if constexpr (stages != 0) {
-        subgroup::tile_prefetch(
-            matA_prefetch_payload);
-        subgroup::tile_prefetch(
-            matB_prefetch_payload);
-      }
       SW_BARRIER();
       matA_payload.template update_tdesc(matA_t::tile_size_x);
       matB_payload.template update_tdesc(matB_t::tile_size_y);
+
       if constexpr (stages != 0) {
-        matA_prefetch_payload.template update_tdesc(
-            matA_t::tile_size_x);
-        matB_prefetch_payload.template update_tdesc(
-            matB_t::tile_size_y);
+        prefetch_and_update_ab();
       }
+
       SW_BARRIER();
       matA_acc_t matA_acc;
       matB_acc_t matB_acc;
       subgroup::elemwise_cvt(matA_acc, matA);
       subgroup::vnni_transform(matB_acc, matB);
       pre_processing(matA_acc, matB_acc, matA, matB);
-      SW_BARRIER();
       tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
-      SW_BARRIER();
+
+      periodic_sync_wait(i);
     }
     SW_BARRIER();
diff --git a/include/group/gemm/impl/selector_xe.hpp b/include/group/gemm/impl/selector_xe.hpp
index 6f32fb67b..68560ae7b 100644
--- a/include/group/gemm/impl/selector_xe.hpp
+++ b/include/group/gemm/impl/selector_xe.hpp
@@ -25,26 +25,33 @@ namespace gpu::xetla::group {

 namespace detail {

+template
+class check_block_2d_pitch_alignment {
+  using load_store_attr = load_store_attr_t;
+  static constexpr int alignment_in_bytes = load_store_attr::alignment_in_bytes;
+  static constexpr int alignment_bytes = alignment * sizeof(dtype);
+
+ public:
+  static constexpr bool value = (alignment_bytes % alignment_in_bytes == 0);
+};
+
+} // namespace detail
+
 template <
     typename dtype_a,
     typename dtype_b,
     int alignment_a,
     int alignment_b,
     gpu_arch arch_tag>
-class check_2d_block_pitch_alignment {
-  using load_store_attr = typename arch_attr_t<
-      arch_tag>::template load_store_attr;
-  static constexpr int alignment_bytes = load_store_attr::alignment_in_bytes;
-  static constexpr int alignment_bytes_a = alignment_a * sizeof(dtype_a);
-  static constexpr int alignment_bytes_b = alignment_b * sizeof(dtype_b);
-
+class check_block_2d_pitch_alignment {
  public:
-  static constexpr bool value = (alignment_bytes_a % alignment_bytes == 0) &&
-      (alignment_bytes_b % alignment_bytes == 0);
+  static constexpr int a_align = detail::
+      check_block_2d_pitch_alignment::value;
+  static constexpr int b_align = detail::
+      check_block_2d_pitch_alignment::value;
+  static constexpr bool value = a_align && b_align;
 };

-} // namespace detail
-
 /// @addtogroup xetla_gemm
 /// @{

@@ -80,7 +87,7 @@ class gemm_selector_t<
     arch_tag,
     stages,
     sync_freq,
-    std::enable_if_t;
     using mem_desc_b = mem_desc_t;
+    using ld_align_attr = check_block_2d_pitch_alignment<
+        dtype_a,
+        dtype_b,
+        alignment_a,
+        alignment_b,
+        arch_tag>;
   using compute_attr = compute_attr_t;
   using perf_tuning_knob = perf_tuning_knob_t;
   using compute_policy =
@@ -194,17 +207,12 @@ class gemm_selector_t<
     arch_tag,
     stages,
     sync_freq,
-    std::enable_if_t::value>> {
-  static_assert(
-      std::is_same::value &&
-          std::is_same::value,
-      "When use gemm_selector, dtype_a and dtype_b in fpu based gemm"
-      "should be the same as dtype_acc");
   using mem_desc_a =
       mem_desc_t;
   using mem_desc_b =
@@ -224,5 +232,68 @@ class gemm_selector_t<
       pre_processing>;
 };
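The refactored check above splits the per-matrix alignment test into a reusable `detail::` helper and then ANDs the results for A and B. The arithmetic is simple enough to verify standalone: a 2D block access is legal only when the row pitch, in bytes, is a multiple of the hardware's required alignment. A minimal sketch, where `required_align_bytes` stands in for `load_store_attr::alignment_in_bytes` (the default of 8 here is an assumption, not a confirmed value for any particular arch):

#include <cstdint>

// Standalone version of the pitch-alignment test: alignment is given in
// elements, so convert to bytes first, then check divisibility.
template <typename dtype, int alignment_in_elems, int required_align_bytes = 8>
struct pitch_alignment_check {
  static constexpr int alignment_bytes = alignment_in_elems * sizeof(dtype);
  static constexpr bool value = (alignment_bytes % required_align_bytes == 0);
};

// e.g. a half-precision matrix with 4-element alignment has an 8-byte pitch
// alignment and qualifies for 2D block messages; 3-element alignment does not.
static_assert(pitch_alignment_check<uint16_t, 4>::value);
static_assert(!pitch_alignment_check<uint16_t, 3>::value);

When the combined check fails for an FPU gemm, the selector below no longer rejects the configuration; it resolves to `compute_policy_unaligned_fpu` instead.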
+/// @brief Selects unaligned 2d block && fpu based gemm.
+template <
+    typename dtype_a,
+    typename dtype_b,
+    mem_layout mem_layout_a,
+    mem_layout mem_layout_b,
+    mem_space mem_space_a,
+    mem_space mem_space_b,
+    int alignment_a,
+    int alignment_b,
+    typename dtype_acc,
+    typename tile_shape,
+    int k_stride,
+    gpu_arch arch_tag,
+    int stages,
+    int sync_freq>
+class gemm_selector_t<
+    dtype_a,
+    dtype_b,
+    mem_layout_a,
+    mem_layout_b,
+    mem_space_a,
+    mem_space_b,
+    alignment_a,
+    alignment_b,
+    dtype_acc,
+    tile_shape,
+    k_stride,
+    mma_engine::fpu,
+    arch_tag,
+    stages,
+    sync_freq,
+    std::enable_if_t::value>> {
+  using mem_desc_a =
+      mem_desc_t;
+  using mem_desc_b =
+      mem_desc_t;
+  using ld_align_attr = check_block_2d_pitch_alignment<
+      dtype_a,
+      dtype_b,
+      alignment_a,
+      alignment_b,
+      arch_tag>;
+  using compute_attr = compute_attr_t;
+  using perf_tuning_knob = perf_tuning_knob_t;
+  using compute_policy =
+      compute_policy_unaligned_fpu;
+  using pre_processing = pre_processing_default_t;
+
+ public:
+  using gemm = gemm_t<
+      compute_policy,
+      tile_shape,
+      mem_desc_a,
+      mem_desc_b,
+      pre_processing>;
+};
+
 /// @} xetla_gemm
-} // namespace gpu::xetla::group
\ No newline at end of file
+} // namespace gpu::xetla::group
diff --git a/include/group/gemm/impl/unaligned_fpu_xe.hpp b/include/group/gemm/impl/unaligned_fpu_xe.hpp
new file mode 100644
index 000000000..b53aa02ad
--- /dev/null
+++ b/include/group/gemm/impl/unaligned_fpu_xe.hpp
@@ -0,0 +1,662 @@
+/*******************************************************************************
+ * Copyright (c) 2022-2023 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+/// @file
+/// C++ API
+
+#pragma once
+
+#include
+#include
+
+namespace gpu::xetla::group {
+
+/// @addtogroup xetla_gemm
+/// @{
+
+/// @brief Is the gemm functor for unaligned input, Xe architecture and fpu
+/// engine.
+template <
+    typename compute_attr_,
+    typename perf_tuning_knob_,
+    typename tile_shape_,
+    typename mem_desc_a_t_,
+    typename mem_desc_b_t_,
+    typename pre_processing_t_,
+    gpu_arch arch_tag_>
+class gemm_t<
+    compute_policy_unaligned_fpu,
+    tile_shape_, // tile shape of workgroup-level gemm
+    mem_desc_a_t_, // memory attribute of matA
+    mem_desc_b_t_, // memory attribute of matB
+    pre_processing_t_, // pre_processing functor
+    std::enable_if_t>> {
+ public:
+  using mem_desc_a_t = mem_desc_a_t_;
+  using mem_desc_b_t = mem_desc_b_t_;
+  using tile_shape = tile_shape_;
+  using pre_processing_t = pre_processing_t_;
+  using compute_policy =
+      compute_policy_unaligned_fpu;
+
+  static constexpr uint32_t num_cyclic = 2;
+
+  static constexpr uint32_t k_stride = compute_policy::k_stride;
+  static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
+  static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
+  static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
+  static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
+  static constexpr uint32_t wg_size = wg_size_x * wg_size_y;
+  using work_group_t = typename tile_shape::work_group_t;
+
+  constexpr static gpu_arch arch_tag = arch_tag_;
+
+  static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout;
+  static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout;
+  static constexpr bool is_col_major_a =
+      (mem_layout_a == mem_layout::col_major);
+  static constexpr bool is_col_major_b =
+      (mem_layout_b == mem_layout::col_major);
+  static constexpr bool lda_align16 =
+      (mem_desc_a_t::alignment_in_bytes % 16 == 0);
+  static constexpr bool ldb_align16 =
+      (mem_desc_b_t::alignment_in_bytes % 16 == 0);
+
+ private:
+  /******** set data type **********/
+  using dtype_a = typename mem_desc_a_t::dtype;
+  using dtype_b = typename mem_desc_b_t::dtype;
+  using dtype_mma_acc = typename compute_policy::dtype_mma_acc;
+  using dtype_mma_a = typename compute_policy::dtype_mma_a;
+  using dtype_mma_b = typename compute_policy::dtype_mma_b;
+
+  using check_dtype =
+      group::gemm::default_fpu::template check_dtype_default<
+          dtype_a,
+          dtype_b,
+          dtype_mma_a,
+          dtype_mma_b,
+          dtype_mma_acc>;
+
+  /******** set memory attribute **********/
+  static constexpr mem_space mem_space_a = mem_desc_a_t::space;
+  static constexpr mem_space mem_space_b = mem_desc_b_t::space;
+
+  static constexpr bool is_local_a = mem_space_a == mem_space::local;
+  static constexpr bool is_local_b = mem_space_b == mem_space::local;
+  static constexpr tdesc_update_dir update_dir_a =
+      is_col_major_a ? tdesc_update_dir::y_dir : tdesc_update_dir::x_dir;
+  static constexpr tdesc_update_dir update_dir_b =
+      is_col_major_b ? tdesc_update_dir::x_dir : tdesc_update_dir::y_dir;
+
+  using check_memory =
+      group::gemm::default_fpu::template check_memory_default<
+          mem_layout_a,
+          mem_layout_b,
+          mem_space_a,
+          mem_space_b>;
+
+  static constexpr uint32_t stages = compute_policy::stages;
+  static constexpr uint32_t sync_freq = compute_policy::sync_freq;
+
+  /******** set tile layout && worker scope **********/
+  static constexpr uint32_t tile_size_x_a = k_stride;
+  static constexpr uint32_t tile_size_y_a = sg_tile_m;
+  static constexpr uint32_t tile_size_x_b = sg_tile_n;
+  static constexpr uint32_t tile_size_y_b = k_stride;
+  static constexpr uint32_t tile_size_x_c = sg_tile_n;
+  static constexpr uint32_t tile_size_y_c = sg_tile_m;
+
+  static constexpr uint32_t block_size_x_a =
+      (compute_policy::block_size_x_a > tile_size_x_a)
+          ? tile_size_x_a
+          : compute_policy::block_size_x_a;
+  static constexpr uint32_t block_size_y_a =
+      (compute_policy::block_size_y_a > tile_size_y_a)
+          ? tile_size_y_a
+          : compute_policy::block_size_y_a;
+  static constexpr uint32_t block_size_x_b =
+      (compute_policy::block_size_x_b > tile_size_x_b)
+          ? tile_size_x_b
+          : compute_policy::block_size_x_b;
+  static constexpr uint32_t block_size_y_b =
+      (compute_policy::block_size_y_b > tile_size_y_b)
+          ? tile_size_y_b
+          : compute_policy::block_size_y_b;
+
+  using check_tile_size =
+      group::gemm::default_fpu::template check_tile_size_default<
+          dtype_mma_a,
+          tile_size_x_a,
+          tile_size_y_a,
+          block_size_x_a,
+          block_size_y_a,
+          tile_size_x_b,
+          tile_size_y_b,
+          block_size_x_b,
+          block_size_y_b>;
+
+  /******** set tile **********/
+  static constexpr reg_layout reg_layout_a = reg_layout::tiled;
+
+  [[maybe_unused]] xetla_nbarrier_t barrier_all;
+  [[maybe_unused]] xetla_nbarrier_t nbarrier_a;
+  [[maybe_unused]] xetla_nbarrier_t nbarrier_b;
+
+  using matA_tile_desc_t = subgroup::tile_desc_t<
+      tile_size_x_a,
+      tile_size_y_a,
+      block_size_x_a,
+      block_size_y_a,
+      reg_layout_a>;
+
+  using matA_t = subgroup::tile_t;
+
+  using cooperative_helper_A_t = subgroup::cooperative_load_helper_t<
+      matA_t,
+      mem_layout_a,
+      tile_shape::wg_size_x,
+      arch_tag>;
+  using cooperative_tile_desc_A_t =
+      typename cooperative_helper_A_t::co_tile_desc_t;
+  using partial_matA_t = subgroup::tile_t;
+  using matA_payload_t = subgroup::mem_payload_t<
+      mem_desc_a_t,
+      cooperative_tile_desc_A_t,
+      is_local_a         ? msg_type::scatter
+          : lda_align16  ? msg_type::block_2d
+                         : msg_type::unaligned_2d,
+      arch_tag>;
+
+  using matA_payload_local_st_t = subgroup::mem_payload_t<
+      mem_desc_t,
+      cooperative_tile_desc_A_t,
+      msg_type::scatter,
+      arch_tag>;
+  using matA_payload_local_ld_t = subgroup::mem_payload_t<
+      mem_desc_t,
+      matA_tile_desc_t,
+      msg_type::scatter,
+      arch_tag>;
+
+  using matA_acc_t = subgroup::tile_t;
+  using matA_prefetch_payload_t = subgroup::prefetch_payload_t<
+      mem_desc_a_t,
+      subgroup::tile_desc_t,
+      wg_size_x,
+      arch_tag>;
+  matA_prefetch_payload_t matA_prefetch_payload;
+
+  static constexpr reg_layout reg_layout_b = reg_layout::tiled;
+  using matB_tile_desc_t = subgroup::tile_desc_t<
+      tile_size_x_b,
+      tile_size_y_b,
+      block_size_x_b,
+      block_size_y_b,
+      reg_layout_b>;
+  using matB_t = subgroup::tile_t;
+
+  using cooperative_helper_B_t = subgroup::cooperative_load_helper_t<
+      matB_t,
+      mem_layout_b,
+      tile_shape::wg_size_y,
+      arch_tag>;
+  using cooperative_tile_desc_B_t =
+      typename cooperative_helper_B_t::co_tile_desc_t;
+
+  using partial_matB_t = subgroup::tile_t;
+
+  using matB_payload_t = subgroup::mem_payload_t<
+      mem_desc_b_t,
+      cooperative_tile_desc_B_t,
+      is_local_b         ? msg_type::scatter
+          : ldb_align16  ? msg_type::block_2d
+                         : msg_type::unaligned_2d,
+      arch_tag>;
+
+  using matB_payload_local_st_t = subgroup::mem_payload_t<
+      mem_desc_t,
+      cooperative_tile_desc_B_t,
+      msg_type::scatter,
+      arch_tag>;
+  using matB_payload_local_ld_t = subgroup::mem_payload_t<
+      mem_desc_t,
+      matB_tile_desc_t,
+      msg_type::scatter,
+      arch_tag>;
+
+  using matB_acc_t = subgroup::tile_t;
+  using matB_prefetch_payload_t = subgroup::prefetch_payload_t<
+      mem_desc_b_t,
+      subgroup::tile_desc_t,
+      wg_size_y,
+      arch_tag>;
+  matB_prefetch_payload_t matB_prefetch_payload;
+
+ public:
+  using matAcc_tile_desc_t = subgroup::tile_desc_t<
+      tile_size_x_c,
+      tile_size_y_c,
+      block_size_x_b,
+      block_size_y_a,
+      reg_layout::tiled>;
+  using matAcc_t = subgroup::tile_t;
+
+ private:
+  using tile_mma = subgroup::tile_mma_t<
+      matAcc_t,
+      matAcc_t,
+      matB_acc_t,
+      matA_acc_t,
+      mma_engine::fpu,
+      arch_tag>;
+  // static constexpr bool enable_periodic_sync = (sync_freq != 0);
+  static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
+  static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;
+  static constexpr uint32_t tile_size_a =
+      tile_size_x_a * tile_size_y_a * sizeof(dtype_a);
+  static constexpr uint32_t tile_size_b =
+      tile_size_x_b * tile_size_y_b * sizeof(dtype_b);
+  static constexpr uint32_t slm_size_a = wg_size_y * tile_size_a;
+  static constexpr uint32_t slm_size_b = wg_size_x * tile_size_b;
+
+ public:
+  static constexpr uint32_t barrier_count =
+      arch_has_named_barrier ? barrier_count_x + barrier_count_y : 0;
+
+  static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic;
+  static_assert(slm_size <= arch_attr_t::local_mem_size);
+  static constexpr uint32_t slm_base_a = 0;
+  static constexpr uint32_t slm_base_b = slm_size_a * num_cyclic;
+
+  static constexpr msg_type msg_type_a = matA_payload_t::message_type;
+  static constexpr msg_type msg_type_b = matB_payload_t::message_type;
+
+  using pre_processing_arg_t = typename pre_processing_t::arguments_t;
+
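The SLM layout above is a cyclic double buffer: each subgroup row stages one A tile and each subgroup column one B tile, and `num_cyclic` copies of both live in SLM at once so the loads for iteration i+1 can land while iteration i is being consumed. The budget arithmetic checks out standalone; in this sketch the 128 KB figure stands in for `arch_attr_t<arch_tag>::local_mem_size` and the tile sizes are arbitrary example values:

#include <cstdint>

// Double-buffered SLM budget, following the constants in the class above.
template <
    uint32_t wg_size_x, uint32_t wg_size_y,
    uint32_t tile_size_a_bytes, uint32_t tile_size_b_bytes,
    uint32_t num_cyclic = 2, uint32_t local_mem_size = 128 * 1024>
struct slm_budget {
  static constexpr uint32_t slm_size_a = wg_size_y * tile_size_a_bytes;
  static constexpr uint32_t slm_size_b = wg_size_x * tile_size_b_bytes;
  static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic;
  static constexpr uint32_t slm_base_a = 0;
  // all num_cyclic copies of the A region precede the B region
  static constexpr uint32_t slm_base_b = slm_size_a * num_cyclic;
  static_assert(slm_size <= local_mem_size, "workgroup tiles overflow SLM");
};

// 32x32 fp32 A tiles and 32x64 fp32 B tiles on a 4x8 subgroup grid:
using budget = slm_budget<4, 8, 32 * 32 * 4, 32 * 64 * 4>;
static_assert(budget::slm_base_b == 2 * budget::slm_size_a);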
+  /// @brief Arguments for gemm.
+  /// User should prepare matA_base_desc, matB_base_desc, inner_loop_count...
+  struct arguments_t {
+    /// @brief Is the memory description of matA, including base, shape and
+    /// coordinate.
+    mem_desc_a_t matA_base_desc;
+    /// @brief Is the memory description of matB, including base, shape and
+    /// coordinate.
+    mem_desc_b_t matB_base_desc;
+    /// @brief Is the total inner loop count required to compute the entire
+    /// K-dim.
+    uint32_t inner_loop_count;
+    /// @brief Is the arguments for pre-processing functor.
+    pre_processing_arg_t pre_processing_args;
+
+    /// @brief Default construct.
+    inline arguments_t() = default;
+    // Be aware of the risks: Rule of three (copy constructor, copy assignment,
+    // destructor). Please check if you need to add a self-defined destructor
+    // ~arguments_t(){}
+
+    /// @brief Constructs a new arguments_t object.
+    /// @param matA_desc Is the memory description of matA, including base,
+    /// shape and coordinate.
+    /// @param matB_desc Is the memory description of matB, including base,
+    /// shape and coordinate.
+    /// @param loop_count Is the total inner loop count required to compute the
+    /// entire K-dim.
+    /// @param args Is the arguments for pre-processing functor.
+    inline arguments_t(
+        mem_desc_a_t matA_desc,
+        mem_desc_b_t matB_desc,
+        uint32_t loop_count,
+        pre_processing_arg_t args = {})
+        : matA_base_desc(matA_desc),
+          matB_base_desc(matB_desc),
+          inner_loop_count(loop_count),
+          pre_processing_args(args) {}
+    // Be aware of the risks: Rule of three (copy constructor, copy assignment,
+    // destructor). Please check if you need to add a self-defined destructor
+    // inline ~arguments_t(){}
+    inline arguments_t(const arguments_t& args)
+        : matA_base_desc(args.matA_base_desc),
+          matB_base_desc(args.matB_base_desc),
+          inner_loop_count(args.inner_loop_count),
+          pre_processing_args(args.pre_processing_args) {}
+    inline arguments_t& operator=(const arguments_t& args) {
+      this->matA_base_desc = args.matA_base_desc;
+      this->matB_base_desc = args.matB_base_desc;
+      this->inner_loop_count = args.inner_loop_count;
+      this->pre_processing_args = args.pre_processing_args;
+      return *this;
+    }
+
+    /// @brief Explicit initialization function.
+    /// @param matA_desc Is the memory description of matA, including base,
+    /// shape and coordinate.
+    /// @param matB_desc Is the memory description of matB, including base,
+    /// shape and coordinate.
+    /// @param loop_count Is the total inner loop count required to compute the
+    /// entire K-dim.
+    /// @param args Is the arguments for pre-processing functor.
+    inline void init(
+        mem_desc_a_t matA_desc,
+        mem_desc_b_t matB_desc,
+        uint32_t loop_count,
+        pre_processing_arg_t args = {}) {
+      matA_base_desc = matA_desc;
+      matB_base_desc = matB_desc;
+      inner_loop_count = loop_count;
+      pre_processing_args = args;
+    }
+  };
+
+  inline void sync_init(
+      [[maybe_unused]] int32_t sg_idx,
+      [[maybe_unused]] int32_t sg_idy,
+      uint32_t nbarrier_base) {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.init_nbarrier(
+            sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.init_nbarrier(
+            sg_idx + barrier_count_y + nbarrier_base,
+            nbarrier_role::producer_consumer);
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.init_nbarrier(
+          nbarrier_base, nbarrier_role::producer_consumer);
+    }
+  }
+
+  inline void sync_arrive() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.arrive();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.arrive();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.arrive();
+    }
+  }
+
+  inline void sync_wait() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.wait();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.wait();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.wait();
+    }
+  }
+
+  inline void sync_arrive_wait() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.arrive_wait();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.arrive_wait();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.arrive_wait();
+    }
+  }
+
+  inline void prefetch_and_update_ab() {
+    subgroup::tile_prefetch(
+        matA_prefetch_payload);
+    subgroup::tile_prefetch(
+        matB_prefetch_payload);
+    SW_BARRIER();
+    matA_prefetch_payload.template update_tdesc(
+        matA_t::tile_size_x);
+    matB_prefetch_payload.template update_tdesc(
+        matB_t::tile_size_y);
+  }
+
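`prefetch_and_update_ab` is the building block of a classic software prefetch pipeline: a warm-up loop issues `stages` prefetches ahead of the consumer (the `prefetch_distance` knob from the example at the top of this patch), and the steady state replaces each consumed K-slice with one new prefetch so the cache always holds the next `stages` slices. The shape of that pipeline, as a hedged plain-C++ sketch where `prefetch()` stands in for `subgroup::tile_prefetch` plus `update_tdesc`:

#include <cstddef>
#include <vector>

constexpr size_t stages = 3;  // prefetch distance

void prefetch(const float* p) { (void)p; /* cache hint only; no side effects */ }

float consume(const float* p, size_t n) {
  float acc = 0.f;
  for (size_t i = 0; i < n; i++) acc += p[i];
  return acc;
}

float pipelined_sum(const std::vector<float>& data, size_t chunk) {
  const size_t chunks = data.size() / chunk;
  size_t ahead = 0;
  for (; ahead < stages && ahead < chunks; ahead++)  // warm-up: run ahead
    prefetch(data.data() + ahead * chunk);
  float acc = 0.f;
  for (size_t i = 0; i < chunks; i++) {
    // steady state: one new prefetch per chunk consumed, `stages` ahead
    if (ahead < chunks) prefetch(data.data() + (ahead++) * chunk);
    acc += consume(data.data() + i * chunk, chunk);
  }
  return acc;
}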
+  /// @brief Gets the subgroup-level tile offset x.
+  /// @param g Is the workgroup of the current tile.
+  /// @return Subgroup-level tile offset x.
+  __XETLA_API static int get_matC_offset_x(work_group_t& g) {
+    int32_t sg_idx = g.get_id() % wg_size_x;
+    return sg_idx * sg_tile_n;
+  }
+
+  /// @brief Gets the subgroup-level tile offset y.
+  /// @param g Is the workgroup of the current tile.
+  /// @return Subgroup-level tile offset y.
+  __XETLA_API static int get_matC_offset_y(work_group_t& g) {
+    int32_t sg_idy = g.get_id() / wg_size_x;
+    return sg_idy * sg_tile_m;
+  }
+
+  XETLA_MARKER(
+      "This release function will wait until all the r/w and nbarrier "
+      "id used in this gemm have been committed. By default, it will "
+      "use barrier_id 0 to do the entire workgroup sync if wg_size > 1. "
+      "If you call this function, please set a free barrier id or make "
+      "sure barrier_id 0 is not being occupied and you need to allocate "
+      "one more barrier count in addition to the gemm barrier counts.")
+  __XETLA_API void release(uint8_t nbarrier_id = 0) {
+    static constexpr bool need_local_fence =
+        (mem_space_a == mem_space::local) || (mem_space_b == mem_space::local);
+    if constexpr (need_local_fence) {
+      xetla_fence();
+    }
+    xetla_fence();
+    if constexpr (wg_size > 1) {
+      barrier_all.init_nbarrier(nbarrier_id, nbarrier_role::producer_consumer);
+      barrier_all.arrive_wait();
+    }
+  }
+
+  /// @brief Main execution function for gemm.
+  /// The basic process is load data -> matrix multiply.
+  /// @param g Is the workgroup of the current tile.
+  /// @param matAcc Is the reference of the accumulation buffer.
+  /// @param args Is the gemm::arguments_t.
+  /// @param slm_base Is the slm base address.
+  /// @param nbarrier_base Is the named barrier base.
+  __XETLA_API KERNEL_FUNC void operator()(
+      work_group_t& g,
+      matAcc_t& matAcc,
+      arguments_t args,
+      uint32_t slm_base = 0,
+      uint32_t nbarrier_base = 0) {
+    int32_t sg_idx = g.get_id() % wg_size_x;
+    int32_t sg_idy = g.get_id() / wg_size_x;
+
+    XETLA_ASSERT(
+        g.get_id() < wg_size,
+        "Thread id(%d) should be less than wg_size(%d)",
+        g.get_id(),
+        wg_size);
+
+    update_sg_tile_tdesc(args, sg_idx, sg_idy);
+    pre_processing_t pre_processing;
+    matA_t matA;
+    matB_t matB;
+    matA_acc_t matA_acc;
+    matB_acc_t matB_acc;
+    partial_matA_t partial_matA;
+    partial_matB_t partial_matB;
+    // >>>>>>>>>>>>>>>>>> pre_processing init
+    pre_processing.init(g, args.pre_processing_args);
+    uint32_t base_A = slm_base + slm_base_a + sg_idy * tile_size_a;
+    uint32_t base_B = slm_base + slm_base_b + sg_idx * tile_size_b;
+
+    uint32_t store_idx = 0;
+    uint32_t load_idx = 0;
+
+    matA_payload_t matA_payload(args.matA_base_desc);
+    matA_payload_local_st_t matA_local_st_payload(
+        base_A,
+        tile_size_x_a,
+        tile_size_y_a,
+        tile_size_x_a,
+        cooperative_helper_A_t::get_offset_x(sg_idx),
+        cooperative_helper_A_t::get_offset_y(sg_idx));
+    matA_payload_local_ld_t matA_local_ld_payload(
+        base_A, tile_size_x_a, tile_size_y_a, tile_size_x_a, 0, 0);
+
+    matB_payload_t matB_payload(args.matB_base_desc);
+    matB_payload_local_st_t matB_local_st_payload(
+        base_B,
+        tile_size_x_b,
+        tile_size_y_b,
+        tile_size_x_b,
+        cooperative_helper_B_t::get_offset_x(sg_idy),
+        cooperative_helper_B_t::get_offset_y(sg_idy));
+    matB_payload_local_ld_t matB_local_ld_payload(
+        base_B, tile_size_x_b, tile_size_y_b, tile_size_x_b, 0, 0);
+
+    matA_prefetch_payload.init(args.matA_base_desc, sg_idx);
+    matB_prefetch_payload.init(args.matB_base_desc, sg_idy);
+
+    sync_init(sg_idx, sg_idy, nbarrier_base);
+
+#pragma unroll
+    for (uint32_t i = 0; i < stages; i++) {
+      prefetch_and_update_ab();
+    }
+
+    tile_load(partial_matA, matA_payload);
+    tile_load(partial_matB, matB_payload);
+    matA_payload.template update_tdesc(matA_t::tile_size_x);
+    matB_payload.template update_tdesc(matB_t::tile_size_y);
+
+    tile_store(partial_matA, matA_local_st_payload);
+    tile_store(partial_matB, matB_local_st_payload);
+    store_idx++;
+
+#pragma unroll
+    for (uint32_t i = 1; i < num_cyclic - 1; i++) {
+      tile_load(partial_matA, matA_payload);
+      tile_load(partial_matB, matB_payload);
+      matA_payload.template update_tdesc(matA_t::tile_size_x);
+      matB_payload.template update_tdesc(matB_t::tile_size_y);
+
+      matA_local_st_payload.template update_tdesc(
+          wg_size_y * matA_t::tile_size_y);
+      matB_local_st_payload.template update_tdesc(
+          wg_size_x * matB_t::tile_size_y);
+
+      if constexpr (stages != 0) {
+        prefetch_and_update_ab();
+      }
+
+      tile_store(partial_matA, matA_local_st_payload);
+      tile_store(partial_matB, matB_local_st_payload);
+      store_idx++;
+    }
+
+    xetla_fence();
+    sync_arrive_wait();
+
+    for (uint32_t i = 0; i < args.inner_loop_count - 1; i++) {
+      tile_load(matA, matA_local_ld_payload);
+      tile_load(matB, matB_local_ld_payload);
+
+      load_idx = (load_idx < num_cyclic - 1) ? (load_idx + 1) : 0;
+      if (load_idx != 0) {
+        matA_local_ld_payload.template update_tdesc(
+            wg_size_y * matA_t::tile_size_y);
+        matB_local_ld_payload.template update_tdesc(
+            wg_size_x * matB_t::tile_size_y);
+      } else {
+        matA_local_ld_payload.template update_tdesc(
+            (1 - num_cyclic) * wg_size_y * matA_t::tile_size_y);
+        matB_local_ld_payload.template update_tdesc(
+            (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y);
+      }
+
+      SW_BARRIER();
+      subgroup::elemwise_cvt(matA_acc, matA);
+      subgroup::elemwise_cvt(matB_acc, matB);
+      pre_processing(matA_acc, matB_acc, matA, matB);
+      SW_BARRIER();
+      tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
+      SW_BARRIER();
+
+      tile_load(partial_matA, matA_payload);
+      tile_load(partial_matB, matB_payload);
+      matA_payload.template update_tdesc(matA_t::tile_size_x);
+      matB_payload.template update_tdesc(matB_t::tile_size_y);
+
+      if (store_idx != 0) {
+        matA_local_st_payload.template update_tdesc(
+            wg_size_y * matA_t::tile_size_y);
+        matB_local_st_payload.template update_tdesc(
+            wg_size_x * matB_t::tile_size_y);
+      } else {
+        matA_local_st_payload.template update_tdesc(
+            (1 - num_cyclic) * wg_size_y * matA_t::tile_size_y);
+        matB_local_st_payload.template update_tdesc(
+            (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y);
+      }
+
+      if constexpr (stages != 0) {
+        prefetch_and_update_ab();
+      }
+
+      tile_store(partial_matA, matA_local_st_payload);
+      tile_store(partial_matB, matB_local_st_payload);
+      store_idx = (store_idx < num_cyclic - 1) ? (store_idx + 1) : 0;
+
+      xetla_fence();
+      sync_arrive_wait();
+    }
+
+    SW_BARRIER();
+    tile_load(matA, matA_local_ld_payload);
+    tile_load(matB, matB_local_ld_payload);
+    SW_BARRIER();
+    subgroup::elemwise_cvt(matA_acc, matA);
+    subgroup::elemwise_cvt(matB_acc, matB);
+    pre_processing(matA_acc, matB_acc, matA, matB);
+    SW_BARRIER();
+    tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
+    SW_BARRIER();
+  }
+
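The `(1 - num_cyclic) * ... * tile_size_y` updates above are the rewind step of the cyclic buffer walk: instead of recomputing `base + idx * slot` each iteration, the payload keeps a running offset that steps forward one slot per iteration and jumps back when the index wraps, landing exactly on slot 0. A small self-checking sketch of that arithmetic (the `slot` size is an arbitrary example value):

#include <cassert>

int main() {
  constexpr int num_cyclic = 2;
  constexpr int slot = 4096;  // bytes per buffered tile slot (example value)
  int offset = 0, idx = 0;
  for (int i = 0; i < 10; i++) {
    assert(offset == idx * slot);  // running offset always equals idx * slot
    idx = (idx < num_cyclic - 1) ? (idx + 1) : 0;
    // forward one slot, or rewind (num_cyclic - 1) slots on wraparound
    offset += (idx != 0) ? slot : (1 - num_cyclic) * slot;
  }
  return 0;
}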
+ private:
+  /// @brief Updates tile base descriptor based on the tid.
+  __XETLA_API static void update_sg_tile_tdesc(
+      arguments_t& args,
+      int32_t sg_idx,
+      int32_t sg_idy) {
+    int32_t tile_offset_n = sg_idx * sg_tile_n;
+    int32_t tile_offset_m = sg_idy * sg_tile_m;
+
+    args.matA_base_desc.update_coord_y(
+        tile_offset_m + cooperative_helper_A_t::get_offset_y(sg_idx));
+    args.matA_base_desc.update_coord_x(
+        cooperative_helper_A_t::get_offset_x(sg_idx));
+    args.matB_base_desc.update_coord_x(
+        tile_offset_n + cooperative_helper_B_t::get_offset_x(sg_idy));
+    args.matB_base_desc.update_coord_y(
+        cooperative_helper_B_t::get_offset_y(sg_idy));
+  }
+};
+
+/// @} xetla_gemm
+
+} // namespace gpu::xetla::group
diff --git a/include/group/gemm/impl/unaligned_xmx_xe.hpp b/include/group/gemm/impl/unaligned_xmx_xe.hpp
index 1e2f134a6..0246b2fb3 100644
--- a/include/group/gemm/impl/unaligned_xmx_xe.hpp
+++ b/include/group/gemm/impl/unaligned_xmx_xe.hpp
@@ -52,7 +52,7 @@ class gemm_t<
   using compute_policy =
       compute_policy_unaligned_xmx;

-  static constexpr uint32_t num_cyclic = 3;
+  static constexpr uint32_t num_cyclic = 2;

   static constexpr uint32_t k_stride = compute_policy::k_stride;
   static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
@@ -66,8 +66,14 @@ class gemm_t<

   static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout;
   static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout;
-  static constexpr bool is_col_major_a = mem_layout_a == mem_layout::col_major;
-  static constexpr bool is_col_major_b = mem_layout_b == mem_layout::col_major;
+  static constexpr bool is_col_major_a =
+      (mem_layout_a == mem_layout::col_major);
+  static constexpr bool is_col_major_b =
+      (mem_layout_b == mem_layout::col_major);
+  static constexpr bool lda_align16 =
+      (mem_desc_a_t::alignment_in_bytes % 16 == 0);
+  static constexpr bool ldb_align16 =
+      (mem_desc_b_t::alignment_in_bytes % 16 == 0);

  private:
   /******** set data type **********/
@@ -132,6 +138,8 @@ class gemm_t<
   static constexpr reg_layout reg_layout_a = reg_layout::tiled;

   [[maybe_unused]] xetla_nbarrier_t barrier_all;
+  [[maybe_unused]] xetla_nbarrier_t nbarrier_a;
+  [[maybe_unused]] xetla_nbarrier_t nbarrier_b;

   using matA_tile_desc_t = subgroup::tile_desc_t<
       tile_size_x_a,
@@ -153,7 +161,9 @@ class gemm_t<
   using matA_payload_t = subgroup::mem_payload_t<
       mem_desc_a_t,
       cooperative_tile_desc_A_t,
-      is_local_a ? msg_type::scatter : msg_type::unaligned_2d,
+      is_local_a         ? msg_type::scatter
+          : lda_align16  ? msg_type::block_2d
+                         : msg_type::unaligned_2d,
       arch_tag>;

   using matA_payload_local_st_t = subgroup::mem_payload_t<
@@ -173,6 +183,7 @@ class gemm_t<
       subgroup::tile_desc_t,
       wg_size_x,
       arch_tag>;
+
   static constexpr reg_layout reg_layout_b = sizeof(dtype_b) < sizeof(uint32_t)
       ? reg_layout::vnni_tiled
       : reg_layout::tiled;
@@ -197,7 +208,9 @@ class gemm_t<
   using matB_payload_t = subgroup::mem_payload_t<
       mem_desc_b_t,
       cooperative_tile_desc_B_t,
-      is_local_b ? msg_type::scatter : msg_type::unaligned_2d,
+      is_local_b         ? msg_type::scatter
+          : ldb_align16  ? msg_type::block_2d
+                         : msg_type::unaligned_2d,
       arch_tag>;

   using matB_payload_local_st_t = subgroup::mem_payload_t<
@@ -252,11 +265,20 @@ class gemm_t<
   static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic;
   static_assert(slm_size <= arch_attr_t::local_mem_size);
   static constexpr uint32_t slm_base_a = 0;
-  static constexpr uint32_t slm_base_b = 0 + slm_size_a * num_cyclic;
+  static constexpr uint32_t slm_base_b = slm_size_a * num_cyclic;

   static constexpr msg_type msg_type_a = matA_payload_t::message_type;
   static constexpr msg_type msg_type_b = matB_payload_t::message_type;

+  matA_payload_t matA_payload;
+  matA_payload_local_st_t matA_local_st_payload;
+  matA_payload_local_ld_t matA_local_ld_payload;
+  matA_prefetch_payload_t matA_prefetch_payload;
+  matB_payload_t matB_payload;
+  matB_payload_local_st_t matB_local_st_payload;
+  matB_payload_local_ld_t matB_local_ld_payload;
+  matB_prefetch_payload_t matB_prefetch_payload;
+
   using pre_processing_arg_t = typename pre_processing_t::arguments_t;

   /// @brief Arguments for gemm.
@@ -333,6 +355,77 @@ class gemm_t<
     }
   };

+  inline void sync_init(
+      [[maybe_unused]] int32_t sg_idx,
+      [[maybe_unused]] int32_t sg_idy,
+      uint32_t nbarrier_base) {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.init_nbarrier(
+            sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.init_nbarrier(
+            sg_idx + barrier_count_y + nbarrier_base,
+            nbarrier_role::producer_consumer);
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.init_nbarrier(
+          nbarrier_base, nbarrier_role::producer_consumer);
+    }
+  }
+
+  inline void sync_arrive() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.arrive();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.arrive();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.arrive();
+    }
+  }
+
+  inline void sync_wait() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.wait();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.wait();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.wait();
+    }
+  }
+
+  inline void sync_arrive_wait() {
+    if constexpr (arch_has_named_barrier) {
+      if constexpr (wg_size_x > 1) {
+        nbarrier_a.arrive_wait();
+      }
+      if constexpr (wg_size_y > 1) {
+        nbarrier_b.arrive_wait();
+      }
+    } else if constexpr (wg_size > 1) {
+      barrier_all.arrive_wait();
+    }
+  }
+
+  inline void prefetch_and_update_ab() {
+    subgroup::tile_prefetch(
+        matA_prefetch_payload);
+    subgroup::tile_prefetch(
+        matB_prefetch_payload);
+    SW_BARRIER();
+    matA_prefetch_payload.template update_tdesc(
+        matA_t::tile_size_x);
+    matB_prefetch_payload.template update_tdesc(
+        matB_t::tile_size_y);
+  }
+
   /// @brief Gets the subgroup-level tile offset x.
   /// @param g Is the workgroup of the current tile.
   /// @return Subgroup-level tile offset x.
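The `cooperative_load_helper_t` / `cooperative_tile_desc` pairing used by both unaligned gemm variants captures one idea: the subgroups that all need the same A (or B) tile split it into strips, each loads one strip from global memory into SLM, and after a barrier everyone reads the full tile back. A simplified model of the offset bookkeeping, assuming an even row-major strip split (the real XeTLA splitting rule may differ, so treat this purely as an illustration):

#include <cstdint>

template <uint32_t tile_x, uint32_t tile_y, uint32_t num_coop>
struct coop_split {
  static_assert(tile_y % num_coop == 0, "strip split must be even");
  static constexpr uint32_t strip_y = tile_y / num_coop;
  // each cooperating subgroup starts at column 0 of its own strip of rows
  static constexpr uint32_t offset_x(uint32_t coop_id) {
    (void)coop_id;
    return 0;
  }
  static constexpr uint32_t offset_y(uint32_t coop_id) {
    return coop_id * strip_y;
  }
};

// a 32x32 tile shared by 4 subgroups: each loads an 8-row strip
using split = coop_split<32, 32, 4>;
static_assert(split::offset_y(3) == 24);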
@@ -380,138 +473,113 @@ class gemm_t<
       work_group_t& g,
       matAcc_t& matAcc,
       arguments_t args,
-      [[maybe_unused]] uint32_t slm_base = 0,
+      uint32_t slm_base = 0,
       uint32_t nbarrier_base = 0) {
     int32_t sg_idx = g.get_id() % wg_size_x;
     int32_t sg_idy = g.get_id() / wg_size_x;

     XETLA_ASSERT(
-        g.get_id() < (wg_size_x * wg_size_y),
+        g.get_id() < wg_size,
         "Thread id(%d) should be less than wg_size(%d)",
         g.get_id(),
-        wg_size_x * wg_size_y);
+        wg_size);

     update_sg_tile_tdesc(args, sg_idx, sg_idy);
     pre_processing_t pre_processing;
     matA_t matA;
     matB_t matB;
+    matA_acc_t matA_acc;
+    matB_acc_t matB_acc;
     partial_matA_t partial_matA;
     partial_matB_t partial_matB;
     // >>>>>>>>>>>>>>>>>> pre_processing init
     pre_processing.init(g, args.pre_processing_args);
-    uint32_t base_A = slm_base_a + sg_idy * tile_size_a;
-    uint32_t base_B = slm_base_b + sg_idx * tile_size_b;
-
-    uint32_t store_idx = 0;
-    uint32_t load_idx = 0;
-
-    matA_payload_t matA_payload(args.matA_base_desc);
-    matA_payload_local_st_t matA_local_st_payload(
+    uint32_t base_A = slm_base + slm_base_a + sg_idy * tile_size_a;
+    matA_payload.init(args.matA_base_desc);
+    matA_local_st_payload.init(
         base_A,
         tile_size_x_a,
         tile_size_y_a,
         tile_size_x_a,
         cooperative_helper_A_t::get_offset_x(sg_idx),
         cooperative_helper_A_t::get_offset_y(sg_idx));
-    matA_payload_local_ld_t matA_local_ld_payload(
+    matA_local_ld_payload.init(
         base_A, tile_size_x_a, tile_size_y_a, tile_size_x_a, 0, 0);
+    matA_prefetch_payload.init(args.matA_base_desc, sg_idx);

-    matB_payload_t matB_payload(args.matB_base_desc);
-    matB_payload_local_st_t matB_local_st_payload(
+    uint32_t base_B = slm_base + slm_base_b + sg_idx * tile_size_b;
+    matB_payload.init(args.matB_base_desc);
+    matB_local_st_payload.init(
         base_B,
         tile_size_x_b,
         tile_size_y_b,
         tile_size_x_b,
         cooperative_helper_B_t::get_offset_x(sg_idy),
         cooperative_helper_B_t::get_offset_y(sg_idy));
-    matB_payload_local_ld_t matB_local_ld_payload(
+    matB_local_ld_payload.init(
         base_B, tile_size_x_b, tile_size_y_b, tile_size_x_b, 0, 0);
+    matB_prefetch_payload.init(args.matB_base_desc, sg_idy);

-    matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, sg_idx);
-    matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, sg_idy);
-
-    xetla_nbarrier_t nbarrier_a;
-    nbarrier_a.init_nbarrier(
-        sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
+    sync_init(sg_idx, sg_idy, nbarrier_base);

-    xetla_nbarrier_t nbarrier_b;
-    nbarrier_b.init_nbarrier(
-        sg_idx + barrier_count_y + nbarrier_base,
-        nbarrier_role::producer_consumer);
+    uint32_t store_idx = 0;
+    uint32_t load_idx = 0;

     tile_load(partial_matA, matA_payload);
     tile_load(partial_matB, matB_payload);
-
-    tile_store(partial_matA, matA_local_st_payload);
-    tile_store(partial_matB, matB_local_st_payload);
-    store_idx++;
-
+    SW_BARRIER();
     matA_payload.template update_tdesc(matA_t::tile_size_x);
     matB_payload.template update_tdesc(matB_t::tile_size_y);

-    xetla_fence();
-    nbarrier_a.arrive();
-    if constexpr (arch_has_named_barrier)
-      nbarrier_b.arrive();
+    tile_store(partial_matA, matA_local_st_payload);
+    tile_store(partial_matB, matB_local_st_payload);

 #pragma unroll
     for (uint32_t i = 1; i < num_cyclic - 1; i++) {
       tile_load(partial_matA, matA_payload);
       tile_load(partial_matB, matB_payload);
-
+      SW_BARRIER();
       matA_payload.template update_tdesc(matA_t::tile_size_x);
       matB_payload.template update_tdesc(matB_t::tile_size_y);

       matA_local_st_payload.template update_tdesc(
           wg_size_y * matA_t::tile_size_y);
+      SW_BARRIER();
+      tile_store(partial_matA, matA_local_st_payload);
matB_local_st_payload.template update_tdesc( wg_size_x * matB_t::tile_size_y); - - tile_store(partial_matA, matA_local_st_payload); + SW_BARRIER(); tile_store(partial_matB, matB_local_st_payload); store_idx++; } + xetla_fence(); + sync_arrive_wait(); + matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x * (num_cyclic - 1)); + (num_cyclic - 1) * matA_t::tile_size_x); matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y * (num_cyclic - 1)); + (num_cyclic - 1) * matB_t::tile_size_y); #pragma unroll for (uint32_t i = 0; i < stages; i++) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } - for (uint32_t i = 0; i < args.inner_loop_count; i++) { - tile_load(partial_matA, matA_payload); - tile_load(partial_matB, matB_payload); - - matA_payload.template update_tdesc(matA_t::tile_size_x); - matB_payload.template update_tdesc(matB_t::tile_size_y); - - if constexpr (stages != 0) { - subgroup::tile_prefetch( - matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - } - - nbarrier_a.wait(); - if constexpr (arch_has_named_barrier) - nbarrier_b.wait(); - + for (uint32_t i = 0; i < args.inner_loop_count - 1; i++) { tile_load(matA, matA_local_ld_payload); tile_load(matB, matB_local_ld_payload); - load_idx = (load_idx < num_cyclic - 1) ? (load_idx + 1) : 0; + SW_BARRIER(); + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::vnni_transform(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + SW_BARRIER(); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + SW_BARRIER(); + load_idx = (load_idx < num_cyclic - 1) ? (load_idx + 1) : 0; if (load_idx != 0) { matA_local_ld_payload.template update_tdesc( wg_size_y * matA_t::tile_size_y); @@ -523,28 +591,18 @@ class gemm_t< matB_local_ld_payload.template update_tdesc( (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y); } - xetla_fence(); + + tile_load(partial_matA, matA_payload); + tile_load(partial_matB, matB_payload); + SW_BARRIER(); + matA_payload.template update_tdesc(matA_t::tile_size_x); + matB_payload.template update_tdesc(matB_t::tile_size_y); if constexpr (stages != 0) { - matA_prefetch_payload.template update_tdesc( - matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); + prefetch_and_update_ab(); } - nbarrier_a.arrive(); - if constexpr (arch_has_named_barrier) - nbarrier_b.arrive(); - SW_BARRIER(); - matA_acc_t matA_acc; - matB_acc_t matB_acc; - subgroup::elemwise_cvt(matA_acc, matA); - subgroup::vnni_transform(matB_acc, matB); - pre_processing(matA_acc, matB_acc, matA, matB); - SW_BARRIER(); - tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); - SW_BARRIER(); - + store_idx = (store_idx < num_cyclic - 1) ? (store_idx + 1) : 0; if (store_idx != 0) { matA_local_st_payload.template update_tdesc( wg_size_y * matA_t::tile_size_y); @@ -556,15 +614,23 @@ class gemm_t< matB_local_st_payload.template update_tdesc( (1 - num_cyclic) * wg_size_x * matB_t::tile_size_y); } - tile_store(partial_matA, matA_local_st_payload); tile_store(partial_matB, matB_local_st_payload); - store_idx = (store_idx < num_cyclic - 1) ? 
(store_idx + 1) : 0; + + xetla_fence(); + sync_arrive_wait(); } + + SW_BARRIER(); + tile_load(matA, matA_local_ld_payload); + tile_load(matB, matB_local_ld_payload); + SW_BARRIER(); + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::vnni_transform(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + SW_BARRIER(); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); SW_BARRIER(); - nbarrier_a.wait(); - if constexpr (arch_has_named_barrier) - nbarrier_b.wait(); } private: diff --git a/include/group/tile_shape.hpp b/include/group/tile_shape.hpp index 1adf68da4..79d03b10c 100644 --- a/include/group/tile_shape.hpp +++ b/include/group/tile_shape.hpp @@ -35,15 +35,18 @@ template < uint32_t sg_tile_size_x_, uint32_t sg_tile_size_y_> struct tile_shape_t { - static constexpr uint32_t wg_tile_size_x = wg_tile_size_x_; - static constexpr uint32_t wg_tile_size_y = wg_tile_size_y_; static constexpr uint32_t sg_tile_size_x = sg_tile_size_x_; static constexpr uint32_t sg_tile_size_y = sg_tile_size_y_; - static constexpr uint32_t wg_size_x = - (wg_tile_size_x + sg_tile_size_x - 1) / sg_tile_size_x; + (wg_tile_size_x_ + sg_tile_size_x - 1) / sg_tile_size_x; static constexpr uint32_t wg_size_y = - (wg_tile_size_y + sg_tile_size_y - 1) / sg_tile_size_y; + (wg_tile_size_y_ + sg_tile_size_y - 1) / sg_tile_size_y; + + static constexpr uint32_t wg_tile_size_x = wg_size_x * sg_tile_size_x; + static constexpr uint32_t wg_tile_size_y = wg_size_y * sg_tile_size_y; + + static_assert(wg_tile_size_x % sg_tile_size_x == 0); + static_assert(wg_tile_size_y % sg_tile_size_y == 0); using work_group_t = work_group_t; }; diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 189514808..44b9c3a18 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -314,6 +314,12 @@ struct param_adaptor static constexpr auto mem_alignment_b = param::template find_elem_v; + using ld_align_attr = group::check_block_2d_pitch_alignment< + dtype_a, + dtype_b, + mem_alignment_a, + mem_alignment_b, + base_t::gpu_arch_tag>; using compute_attr = group::compute_attr_t; @@ -327,12 +333,7 @@ struct param_adaptor elem_t_t< mma_engine::xmx, typename std::conditional< - (group::detail::check_2d_block_pitch_alignment< - dtype_a, - dtype_b, - mem_alignment_a, - mem_alignment_b, - base_t::gpu_arch_tag>::value), + (ld_align_attr::value), group::compute_policy_default_xmx< compute_attr, perf_tuning_knob, @@ -344,17 +345,16 @@ struct param_adaptor elem_t_t< mma_engine::fpu, typename std::conditional< - (group::detail::check_2d_block_pitch_alignment< - dtype_a, - dtype_b, - mem_alignment_a, - mem_alignment_b, - base_t::gpu_arch_tag>::value), + (ld_align_attr::value), group::compute_policy_default_fpu< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag>, - void>::type>>::template find_elem_t::type; + group::compute_policy_unaligned_fpu< + compute_attr, + perf_tuning_knob, + base_t::gpu_arch_tag>>::type>>:: + template find_elem_t::type; using mem_desc_input_a = mem_desc_t; diff --git a/include/subgroup/tile/api.hpp b/include/subgroup/tile/api.hpp index 1e654f7f0..a57c045b0 100644 --- a/include/subgroup/tile/api.hpp +++ b/include/subgroup/tile/api.hpp @@ -82,8 +82,8 @@ struct tile_desc_t { static constexpr uint32_t tile_size_x = tile_size_x_; static constexpr uint32_t tile_size_y = tile_size_y_; - static constexpr uint32_t block_size_x = block_size_x_; - static constexpr uint32_t block_size_y = block_size_y_; + static constexpr uint32_t block_size_x = (tile_size_x > 
block_size_x_) ? block_size_x_ : tile_size_x; + static constexpr uint32_t block_size_y = (tile_size_y > block_size_y_) ? block_size_y_ : tile_size_y; static constexpr uint32_t remained_size_y = tile_size_y % block_size_y; static constexpr reg_layout register_layout = reg_layout_; diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index 9cf308bb0..c0a557f83 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -19,6 +19,7 @@ #pragma once +#include #include namespace gpu::xetla::subgroup { @@ -37,7 +38,10 @@ struct tile_mma_t< matA_t_, mma_engine::fpu, arch_tag_, - std::enable_if_t>> { + std::enable_if_t< + arch_has_fpu && + matA_t_::register_layout == reg_layout::transpose_tiled && + matB_t_::register_layout == reg_layout::tiled>> { using matA_t = matA_t_; using matB_t = matB_t_; using matSrc_t = matAcc_src_t_; @@ -47,16 +51,6 @@ struct tile_mma_t< using dtype_src = typename matSrc_t::dtype; using dtype_dst = typename matDst_t::dtype; - using register_attr = - typename arch_attr_t::template register_attr<>; - - static_assert( - matA_t::reg_transpose, - "For FMAOp GEMM, the register layout of matA should be col-major"); - static_assert( - !matB_t::reg_transpose, - "For FMAOp GEMM, the register layout of matB should be row-major"); - static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; static constexpr uint32_t a_tile_elems = matA_t::tile_elems; @@ -110,13 +104,17 @@ struct tile_mma_t< __XETLA_API static void mma_core( xetla_vector_ref __REF__ dst, xetla_vector_ref __REF__ src, - xetla_vector_ref __REF__ b_block, - xetla_vector_ref __REF__ a_block) { + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { constexpr uint32_t blk_m_iters = blk_m / mma_m; constexpr uint32_t tail_m = blk_m % mma_m; auto dst_blk_2d = dst.xetla_format(); auto src_blk_2d = src.xetla_format(); - auto b_blk_2d = b_block.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + xetla_vector new_a_block = + xetla_cvt( + a_block.xetla_select(0)); + if constexpr (blk_m_iters > 0) { #pragma unroll for (uint32_t i = 0; i < blk_m_iters; i++) { @@ -124,20 +122,22 @@ struct tile_mma_t< auto dst_tmp_2d = dst_tmp.xetla_format(); #pragma unroll for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { - dst_tmp_2d.row(i_acc) = a_block[i_acc + i * mma_m] * b_blk_2d.row(0) + - src_blk_2d.row(i_acc + i * mma_m); + dst_tmp_2d.row(i_acc) = + new_a_block[i_acc + i * mma_m] * b_blk_2d.row(0) + + src_blk_2d.row(i_acc + i * mma_m); } #pragma unroll for (uint32_t k = 1; k < blk_k - 1; k++) { for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { int a_offset = k * blk_m + i_acc + i * mma_m; - dst_tmp_2d.row(i_acc) += a_block[a_offset] * b_blk_2d.row(k); + dst_tmp_2d.row(i_acc) += new_a_block[a_offset] * b_blk_2d.row(k); } } for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { int a_offset = (blk_k - 1) * blk_m + i_acc + i * mma_m; dst_blk_2d.row(i_acc + i * mma_m) = - a_block[a_offset] * b_blk_2d.row(blk_k - 1) + dst_tmp_2d.row(i_acc); + new_a_block[a_offset] * b_blk_2d.row(blk_k - 1) + + dst_tmp_2d.row(i_acc); } SW_BARRIER(); } @@ -150,20 +150,21 @@ struct tile_mma_t< #pragma unroll for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { dst_tmp_2d.row(i_acc) = - a_block[i_acc + tail_start_m] * b_blk_2d.row(0) + + new_a_block[i_acc + tail_start_m] * b_blk_2d.row(0) + src_blk_2d.row(i_acc + tail_start_m); } #pragma unroll for (uint32_t k = 1; k < blk_k - 1; k++) { for (uint32_t 
i_acc = 0; i_acc < tail_m; i_acc++) { int a_offset = k * blk_m + i_acc + tail_start_m; - dst_tmp_2d.row(i_acc) += a_block[a_offset] * b_blk_2d.row(k); + dst_tmp_2d.row(i_acc) += new_a_block[a_offset] * b_blk_2d.row(k); } } for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { int a_offset = (blk_k - 1) * blk_m + i_acc + tail_start_m; dst_blk_2d.row(i_acc + tail_start_m) = - a_block[a_offset] * b_blk_2d.row(blk_k - 1) + dst_tmp_2d.row(i_acc); + new_a_block[a_offset] * b_blk_2d.row(blk_k - 1) + + dst_tmp_2d.row(i_acc); } } } @@ -173,23 +174,258 @@ struct tile_mma_t< matSrc_t& src, matB_t& b, matA_t& a) { + constexpr auto b_reg_sizes = b_block_size_y * b_tile_size_x; + { // k_blk=0 - auto b_reg = b.reg.xetla_select(0); + xetla_vector b_reg = + xetla_cvt( + b.reg.xetla_select(0)); + if constexpr (tile_size_m >= block_size_m) { #pragma unroll for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { auto a_block = a.reg.xetla_select( - i * num_block_k * a_block_elems); + i * num_block_k * a_block_elems); #pragma unroll for (uint32_t j = 0; j < num_block_n; j++) { auto b_block = - b_reg.xetla_select(j * b_block_elems); + b_reg.xetla_select(j * b_block_elems); auto src_block = src.reg.xetla_select( - (i * num_block_n + j) * block_elems); + (i * num_block_n + j) * block_elems); auto dst_block = dst.reg.xetla_select( - (i * num_block_n + j) * block_elems); + (i * num_block_n + j) * block_elems); mma_core( + dst_block, src_block, b_block, a_block); + } + } + } + + // process the tail + if constexpr ((tile_size_m % block_size_m) != 0) { + constexpr uint32_t tail_start_m = + tile_size_m / block_size_m * block_size_m; + constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; + constexpr uint32_t a_tail_blk_elems = a_block_size_h * a_tail_blk_w; + constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; + constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; + auto a_block = a.reg.xetla_select( + a_tile_size_x * tail_start_m); +#pragma unroll + for (uint32_t j = 0; j < num_block_n; j++) { + auto b_block = + b_reg.xetla_select(j * b_block_elems); + auto src_block = src.reg.xetla_select( + (tail_start_m * tile_size_n) + j * acc_tail_blk_elems); + auto dst_block = dst.reg.xetla_select( + (tail_start_m * tile_size_n) + j * acc_tail_blk_elems); + mma_core( dst_block, src_block, b_block, a_block); + } + } + } + // different K block +#pragma unroll + for (uint32_t k_i = 1; k_i < num_block_k; k_i++) { + xetla_vector b_reg = + xetla_cvt( + b.reg.xetla_select( + k_i * b_block_size_y * b_tile_size_x)); +#pragma unroll + for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { + auto a_block = a.reg.xetla_select( + (i * num_block_k + k_i) * a_block_elems); +#pragma unroll + for (uint32_t j = 0; j < num_block_n; j++) { + auto b_block = + b_reg.xetla_select(j * b_block_elems); + auto dst_block = dst.reg.xetla_select( + (i * num_block_n + j) * block_elems); + mma_core( + dst_block, dst_block, b_block, a_block); + } + } + // process the tail + if constexpr ((tile_size_m % block_size_m) != 0) { + constexpr uint32_t tail_start_m = + tile_size_m / block_size_m * block_size_m; + constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; + constexpr uint32_t a_tail_blk_elems = a_block_size_h * a_tail_blk_w; + constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; + constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; + auto a_block = a.reg.xetla_select( + a_tile_size_x * tail_start_m + k_i * a_tail_blk_elems); +#pragma unroll + for (uint32_t j = 0; j < num_block_n; j++) 
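// In the transposed-A specialization above, A is walked column-major
// within each block (a_offset = k * blk_m + row) and the A fragment is
// widened to the accumulator type once, ahead of the k-loop. A scalar
// sketch of one output element under those conventions (hypothetical
// standalone helper, float throughout for simplicity):

#include <cstdint>
inline float fma_ref_col_major_a(
    const float* a_blk,  // blk_k x blk_m block of A, stored column-major
    const float* b_col,  // blk_k elements of one column of B
    float src,           // incoming accumulator for this (m, n) element
    uint32_t m, uint32_t blk_m, uint32_t blk_k) {
  float acc = src;
  for (uint32_t k = 0; k < blk_k; ++k)
    acc += a_blk[k * blk_m + m] * b_col[k];
  return acc;
}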
{ + auto b_block = + b_reg.xetla_select(j * b_block_elems); + auto dst_block = dst.reg.xetla_select( + (tail_start_m * tile_size_n) + j * acc_tail_blk_elems); + mma_core( + dst_block, dst_block, b_block, a_block); + } + } + } + } +}; + +template < + typename matAcc_dst_t_, + typename matAcc_src_t_, + typename matB_t_, + typename matA_t_, + gpu_arch arch_tag_> +struct tile_mma_t< + matAcc_dst_t_, + matAcc_src_t_, + matB_t_, + matA_t_, + mma_engine::fpu, + arch_tag_, + std::enable_if_t< + arch_has_fpu && + matA_t_::register_layout == reg_layout::tiled && + matB_t_::register_layout == reg_layout::tiled>> { + using matA_t = matA_t_; + using matB_t = matB_t_; + using matSrc_t = matAcc_src_t_; + using matDst_t = matAcc_dst_t_; + using dtype_a = typename matA_t::dtype; + using dtype_b = typename matB_t::dtype; + using dtype_src = typename matSrc_t::dtype; + using dtype_dst = typename matDst_t::dtype; + + static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; + static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; + static constexpr uint32_t a_tile_elems = matA_t::tile_elems; + static constexpr uint32_t a_block_size_w = matA_t::block_size_y; + static constexpr uint32_t a_block_size_h = matA_t::block_size_x; + static constexpr uint32_t a_block_elems = matA_t::block_elems; + + static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x; + static constexpr uint32_t b_tile_size_y = matB_t::tile_size_y; + static constexpr uint32_t b_tile_elems = matB_t::tile_elems; + static constexpr uint32_t b_block_size_x = matB_t::block_size_x; + static constexpr uint32_t b_block_size_y = matB_t::block_size_y; + static constexpr uint32_t b_block_elems = matB_t::block_elems; + + static constexpr uint32_t tile_size_m = matDst_t::tile_size_y; + static constexpr uint32_t tile_size_k = a_tile_size_x; + static constexpr uint32_t tile_size_n = matDst_t::tile_size_x; + static constexpr uint32_t tile_elems = tile_size_m * tile_size_n; + static constexpr uint32_t block_size_n = matDst_t::block_size_x; + static constexpr uint32_t block_size_k = a_block_size_h; + static constexpr uint32_t block_size_m = matDst_t::block_size_y; + static constexpr uint32_t block_elems = block_size_m * block_size_n; + + static_assert( + tile_size_m == matA_t::tile_size_y, + "matAcc tile m should match with matA tile m"); + static_assert( + a_tile_size_x == b_tile_size_y, + "matA tile k should match with matB tile k"); + static_assert( + tile_size_n == matB_t::tile_size_x, + "matAcc tile n should match with matB tile n"); + static_assert( + block_size_m == a_block_size_w, + "matAcc block m should match with matA block m"); + static_assert( + block_size_n == b_block_size_x, + "matAcc block n should match with matB block n"); + static_assert( + (tile_size_k % block_size_k) == 0, + "matAcc tile_size_k should be a multiple of block_size_k"); + + static constexpr int32_t num_block_n = matDst_t::num_block_x; + static constexpr int32_t num_block_m = matDst_t::num_block_y; + static constexpr int32_t num_block_k = tile_size_k / block_size_k; + + using mma_attr = mma_attr_t; + static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; + + template + __XETLA_API static void mma_core( + xetla_vector_ref __REF__ dst, + xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { + constexpr uint32_t blk_m_iters = blk_m / mma_m; + constexpr uint32_t tail_m = blk_m % mma_m; + auto dst_blk_2d = dst.xetla_format(); + auto src_blk_2d = src.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + 
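// Shape contracts aside, the new tiled-A/tiled-B FPU specialization's
// mma() is an ordinary accumulate-into-destination GEMM. As a plain
// scalar reference (all row-major, float for simplicity; illustrative
// only, not the library API):

#include <cstdint>
inline void gemm_ref(
    const float* a, const float* b, const float* src, float* dst,
    uint32_t m, uint32_t n, uint32_t k) {
  for (uint32_t i = 0; i < m; ++i)
    for (uint32_t j = 0; j < n; ++j) {
      float acc = src[i * n + j];  // dst = src + A x B
      for (uint32_t kk = 0; kk < k; ++kk)
        acc += a[i * k + kk] * b[kk * n + j];
      dst[i * n + j] = acc;
    }
}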
xetla_vector new_a_block = + xetla_cvt( + a_block.xetla_select(0)); + + if constexpr (blk_m_iters > 0) { +#pragma unroll + for (uint32_t i = 0; i < blk_m_iters; i++) { +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + dst_blk_2d.row(i_acc + i * mma_m) = src_blk_2d.row(i_acc + i * mma_m); + } + int32_t a_start_off = i * mma_m * blk_k; +#pragma unroll + for (uint32_t k = 0; k < blk_k; k++) { +#pragma unroll + for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) { + int a_offset = a_start_off + i_acc * blk_k + k; + dst_blk_2d.row(i_acc + i * mma_m) += + new_a_block[a_offset] * b_blk_2d.row(k); + } + } + SW_BARRIER(); + } + } + + if constexpr (tail_m != 0) { + constexpr uint32_t tail_start_m = blk_m_iters * mma_m; +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + dst_blk_2d.row(i_acc + tail_start_m) = + src_blk_2d.row(i_acc + tail_start_m); + } + int32_t a_start_off = tail_start_m * blk_k; +#pragma unroll + for (uint32_t k = 0; k < blk_k; k++) { +#pragma unroll + for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) { + int a_offset = a_start_off + i_acc * blk_k + k; + dst_blk_2d.row(i_acc + tail_start_m) += + new_a_block[a_offset] * b_blk_2d.row(k); + } + } + } + } + + __XETLA_API static void mma( + matDst_t& dst, + matSrc_t& src, + matB_t& b, + matA_t& a) { + constexpr auto b_reg_sizes = b_block_size_y * b_tile_size_x; + + { // k_blk=0 + xetla_vector b_reg = + xetla_cvt( + b.reg.xetla_select(0)); + + if constexpr (tile_size_m >= block_size_m) { +#pragma unroll + for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { + auto a_block = a.reg.xetla_select( + i * num_block_k * a_block_elems); +#pragma unroll + for (uint32_t j = 0; j < num_block_n; j++) { + auto b_block = + b_reg.xetla_select(j * b_block_elems); + auto src_block = src.reg.xetla_select( + (i * num_block_n + j) * block_elems); + auto dst_block = dst.reg.xetla_select( + (i * num_block_n + j) * block_elems); + mma_core( + dst_block, src_block, b_block, a_block); } } } @@ -220,8 +456,10 @@ struct tile_mma_t< // different K block #pragma unroll for (uint32_t k_i = 1; k_i < num_block_k; k_i++) { - auto b_reg = b.reg.xetla_select( - k_i * b_block_size_y * b_tile_size_x); + xetla_vector b_reg = + xetla_cvt( + b.reg.xetla_select( + k_i * b_block_size_y * b_tile_size_x)); #pragma unroll for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) { auto a_block = a.reg.xetla_select( diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 8e059d3b3..27f45a92a 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -110,8 +110,7 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr uint32_t elems_per_CL = load_store_attr::cache_line_size_in_bytes / sizeof(dtype); static constexpr uint32_t elems_per_reg = - register_bytes_t::reg_in_bytes / - sizeof(dtype); + register_bytes_t::reg_in_bytes / sizeof(dtype); static constexpr int32_t max_load_block_height = load_store_attr::max_load_height_in_elem; static constexpr int32_t max_block_width = @@ -630,26 +629,32 @@ tile_load( constexpr bool oob_check = std::is_same::value; using dtype = typename payload_t::dtype; - using tile_desc = typename payload_t::tile_desc; using load_dtype = typename payload_t::mem_dtype; constexpr uint32_t num_channel_y = payload_t::num_channel_y; constexpr uint32_t load_elems = num_channel_y * payload_t::num_channel_x; constexpr uint32_t scale_factor = payload_t::scale_factor; + using tile_desc = typename tile_t::tile_desc; + static constexpr 
uint32_t block_elems = tile_desc::block_elems; + static constexpr uint32_t block_size_x = tile_desc::block_size_x; + static constexpr uint32_t num_block_x = tile_desc::num_block_x; + static constexpr uint32_t block_size_y = tile_desc::block_size_y; + static constexpr uint32_t num_block_y = tile_desc::num_block_y; + static constexpr bool reg_transpose = tile_desc::reg_transpose; + #pragma unroll - for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; - i++) { - uint32_t offset_y = i * tile_desc::block_size_y; + for (uint32_t i = 0; i < num_block_y; i++) { + uint32_t offset_y = i * block_size_y; #pragma unroll - for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { - uint32_t offset_x = j * tile_desc::block_size_x; - auto reg_sub = tile.reg.xetla_select( - (i * tile_desc::num_block_x + j) * tile_desc::block_elems); + for (uint32_t j = 0; j < num_block_x; j++) { + uint32_t offset_x = j * block_size_x; + auto reg_sub = tile.reg.xetla_select( + (i * num_block_x + j) * block_elems); xetla_mask pred_x = oob_check ? payload.step_x + payload.base_x + offset_x < payload.width_in_elems : 1; #pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; + for (uint32_t sub_block_y = 0; sub_block_y < block_size_y; sub_block_y += num_channel_y) { xetla_vector reg_tmp; xetla_mask pred_y = oob_check @@ -657,7 +662,7 @@ tile_load( payload.height_in_elems : 1; - uint32_t address_offset = payload_t::trans + uint32_t address_offset = payload_t::mem_transpose ? offset_x * payload.pitch_in_bytes + (offset_y + sub_block_y) * sizeof(dtype) : offset_x * sizeof(dtype) + @@ -677,21 +682,30 @@ tile_load( reg_sub .xetla_select( - sub_block_y * tile_desc::block_size_x) + sub_block_y * block_size_x) .xetla_format() = reg_tmp; } + if constexpr (reg_transpose) { + xetla_vector trans_blk; +#pragma unroll + for (uint32_t y = 0; y < block_size_y; y++) { + trans_blk.xetla_select(y) = + reg_sub.xetla_select(y * block_size_x); + } + reg_sub = trans_blk; + } } } // process the tail - if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) { + if constexpr (tile_desc::remained_size_y != 0) { constexpr uint32_t remained_size_y = tile_desc::remained_size_y; constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y; constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x; constexpr uint32_t remain_block_elems = remained_size_y * tile_desc::block_size_x; #pragma unroll - for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { - uint32_t offset_x = j * tile_desc::block_size_x; + for (uint32_t j = 0; j < num_block_x; j++) { + uint32_t offset_x = j * block_size_x; auto reg_sub = tile.reg.xetla_select( processed_elems + j * remain_block_elems); xetla_mask pred_x = oob_check @@ -706,7 +720,7 @@ tile_load( payload.height_in_elems : 1; - uint32_t address_offset = payload_t::trans + uint32_t address_offset = payload_t::mem_transpose ? 
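// The reg_transpose branch added to tile_load rearranges each freshly
// loaded block from row-major to column-major order inside the register
// file, scattering one row per iteration. A plain-array sketch of the
// permutation it appears to perform (hypothetical helper):

#include <cstdint>
inline void transpose_block_ref(
    const float* src, float* dst, uint32_t rows, uint32_t cols) {
  // src is rows x cols, row-major; dst ends up column-major, i.e.
  // dst[c * rows + r] == src[r * cols + c].
  for (uint32_t r = 0; r < rows; ++r)
    for (uint32_t c = 0; c < cols; ++c)
      dst[c * rows + r] = src[r * cols + c];
}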
offset_x * payload.pitch_in_bytes + (offset_y + sub_block_y) * sizeof(dtype) : offset_x * sizeof(dtype) + @@ -730,6 +744,15 @@ tile_load( sub_block_y * tile_desc::block_size_x) .xetla_format() = reg_tmp; } + if constexpr (reg_transpose) { + xetla_vector trans_blk; +#pragma unroll + for (uint32_t y = 0; y < remained_size_y; y++) { + trans_blk.xetla_select(y) = + reg_sub.xetla_select(y * block_size_x); + } + reg_sub = trans_blk; + } } } diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 85e83b45b..1e848a6dd 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -197,6 +197,7 @@ __XETLA_API native_type_t, remain_move_rows, remain_move_cols>(); +#pragma unroll for (int vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { reg_dst_2d.xetla_select( 0, vnni_i) = diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index e990bfaf6..040d055c9 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -896,7 +896,9 @@ struct mem_payload_t< static constexpr uint32_t num_channel_x = block_size_x * sizeof(dtype) / sizeof(mem_dtype); + static_assert(num_channel_x <= num_channel); static constexpr uint32_t num_channel_y = num_channel / num_channel_x; + static_assert(num_channel_x * num_channel_y == num_channel); xetla_vector channel_offset; xetla_vector step_x; @@ -926,7 +928,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -952,7 +954,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -972,7 +974,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -998,7 +1000,7 @@ struct mem_payload_t< xetla_vector_gen(0, 1); step_x = channel_index % num_channel_x; step_y = channel_index / num_channel_x; - channel_offset = trans + channel_offset = mem_transpose ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes; } @@ -1792,6 +1794,25 @@ struct prefetch_payload_t< xetla_vector_gen(0, 1); channel_offset = channel_index * pitch_in_bytes; } + + inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) { + uint32_t coop_id_x = coop_id % num_coop_sg_w; + uint32_t coop_id_y = coop_id / num_coop_sg_w; + + pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype); + base_x = mem_desc.coord.x + coop_id_x * tile_size_w; + base_y = mem_desc.coord.y + coop_id_y * tile_size_h; + width_in_elems = mem_desc.shape.x; + height_in_elems = mem_desc.shape.y; + base_offset = mem_transpose + ? 
base_x * pitch_in_bytes + base_y * sizeof(dtype) + : base_y * pitch_in_bytes + base_x * sizeof(dtype); + base_ptr = reinterpret_cast(mem_desc.base.base); + + xetla_vector channel_index = + xetla_vector_gen(0, 1); + channel_offset = channel_index * pitch_in_bytes; + } // Be aware of the risks: Rule of three (copy constructor, copy // assignment, destructor) Please check if you need to add self-define // destructor ~prefetch_payload_t(){} @@ -1940,6 +1961,17 @@ struct prefetch_payload_t< prepare_tdesc(base_tdesc); } + inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) { + xetla_tdescriptor base_tdesc = mem_desc.get_tdesc(); + uint32_t coop_id_x = coop_id % num_coop_sg_w; + uint32_t coop_id_y = coop_id / num_coop_sg_w; + xetla_update_tdesc_offsetx( + base_tdesc.xetla_format(), coop_id_x * tile_size_w); + xetla_update_tdesc_offsety( + base_tdesc.xetla_format(), coop_id_y * tile_size_h); + prepare_tdesc(base_tdesc); + } + inline void init(xetla_tdescriptor base_tdesc, uint32_t coop_id = 0) { uint32_t coop_id_x = coop_id % num_coop_sg_w; uint32_t coop_id_y = coop_id / num_coop_sg_w; @@ -2198,6 +2230,16 @@ struct prefetch_payload_t< base_ptr = (prefetch_dtype*)p + (coop_id % num_coop_sg) * mem_tile_size_x; } + inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) { + pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype); + uint32_t offset_x = mem_desc.coord.x; + uint32_t offset_y = mem_desc.coord.y; + base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + uint64_t ptr_temp = (uint64_t)mem_desc.base.base; + base_ptr = + (prefetch_dtype*)ptr_temp + (coop_id % num_coop_sg) * mem_tile_size_x; + } + template __XETLA_API void update_tdesc(int offset) { if constexpr (update_dir == tdesc_update_dir::x_dir) { @@ -2235,6 +2277,8 @@ struct prefetch_payload_t< static constexpr mem_layout memory_layout = mem_layout_; static constexpr gpu_arch arch_tag = arch_tag_; + inline prefetch_payload_t() = default; + inline prefetch_payload_t( [[maybe_unused]] mem_desc_t& mem_desc, [[maybe_unused]] uint32_t coop_id = 0) {} @@ -2248,6 +2292,10 @@ struct prefetch_payload_t< [[maybe_unused]] int surface_offset_y, [[maybe_unused]] uint32_t coop_id = 0) {} + inline void init( + [[maybe_unused]] mem_desc_t& mem_desc, + [[maybe_unused]] uint32_t coop_id = 0) {} + template __XETLA_API void update_tdesc([[maybe_unused]] int offset) {} }; diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 324fd57de..8103db73f 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -355,8 +355,7 @@ tile_store( constexpr uint32_t scale_factor = payload_t::scale_factor; #pragma unroll - for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; - i++) { + for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { uint32_t offset_y = i * tile_desc::block_size_y; #pragma unroll for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { @@ -394,7 +393,7 @@ tile_store( } } // process the tail - if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) { + if constexpr (tile_desc::remained_size_y != 0) { constexpr uint32_t remained_size_y = tile_desc::remained_size_y; constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y; constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x; @@ -722,19 +721,34 @@ tile_store( uint64_t address_offset = offset_x * sizeof(dtype) + (sub_block_y + offset_y) * payload.pitch_in_bytes; - xetla_tatomic_store_global< - dtype, - 
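// Each of the new init() overloads recomputes the payload's base byte
// offset from the memory descriptor; the only layout-dependent part is
// which coordinate walks the pitch. The same address math, factored out
// as a hypothetical standalone helper:

#include <cstdint>
inline uint64_t base_offset_bytes(
    bool mem_transpose, uint64_t x, uint64_t y,
    uint64_t pitch_in_bytes, uint64_t elem_bytes) {
  return mem_transpose ? x * pitch_in_bytes + y * elem_bytes
                       : y * pitch_in_bytes + x * elem_bytes;
}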
payload_t::num_channel, - L1, - L2, - op_kind, - payload_t::arch_tag, - typename payload_t::Toffset>( - (uint64_t)payload.base_pointer + address_offset, - payload.channel_offset, - reg_sub.xetla_select( - sub_block_y * block_size_x), - pred_x & pred_y); + if constexpr (arch_has_2d_load_store) { + xetla_tatomic_store_global< + dtype, + payload_t::num_channel, + L1, + L2, + op_kind, + payload_t::arch_tag, + typename payload_t::Toffset>( + (uint64_t)payload.base_pointer + address_offset, + payload.channel_offset, + reg_sub.xetla_select( + sub_block_y * block_size_x), + pred_x & pred_y); + } else { + xetla_atomic_global< + op_kind, + dtype, + payload_t::num_channel, + data_size::default_size, + L1, + L2>( + reinterpret_cast(payload.base_pointer + address_offset), + payload.channel_offset, + reg_sub.xetla_select( + sub_block_y * block_size_x), + pred_x & pred_y); + } } } } @@ -1020,7 +1034,8 @@ tile_store(tile_t& tile, payload_t& payload) { #pragma unroll for (uint32_t j = 0; j < store_iter_steps; j++) { uint32_t offset_x = j * max_store_vec_len * scale_factor; - auto reg_sub = tile.reg.xetla_select(offset_x); + auto reg_sub = + tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); xetla_store_local( payload.address + address_offset, diff --git a/tests/integration/gemm/bf16/common.hpp b/tests/integration/gemm/bf16/common.hpp index 8110d6d16..73f117bde 100644 --- a/tests/integration/gemm/bf16/common.hpp +++ b/tests/integration/gemm/bf16/common.hpp @@ -45,11 +45,32 @@ class TestBase { mem_layout_a_str + "_" + mem_layout_b_str; return name; } + static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; + // static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; + static constexpr uint32_t global_kslicing = 1; + static constexpr uint32_t local_kslicing = 1; + static constexpr uint32_t wg_num_n = 64; +}; + +class TestBaseBF16x : public TestBase { + public: + using data_type_a = bf16; + using data_type_b = bf16; + using data_type_c = bf16; + using data_type_acc = float; static constexpr mma_engine engine = mma_engine::xmx; - static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; }; -class Test0 : public TestBase { +class TestBaseBF16f : public TestBase { + public: + using data_type_a = bf16; + using data_type_b = bf16; + using data_type_c = bf16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::fpu; +}; + +class Test0x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -59,17 +80,11 @@ class Test0 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; -class Test1 : public TestBase { +class Test1x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -79,17 +94,11 @@ class Test1 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::col_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using 
data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; -class Test2 : public TestBase { +class Test2x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -99,17 +108,11 @@ class Test2 : public TestBase { static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; -class Test3 : public TestBase { +class Test3x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -119,17 +122,11 @@ class Test3 : public TestBase { static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::col_major; static constexpr mem_layout layout_b = mem_layout::col_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; }; -class Test4 : public TestBase { +class Test4x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -139,8 +136,6 @@ class Test4 : public TestBase { static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 16; static constexpr size_t wg_n = 32; static constexpr size_t sg_m = 8; static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 16; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -148,18 +143,22 @@ class Test4 : public TestBase { using data_type_c = float; using data_type_acc = float; }; -class Test5 : public TestBase { + +class Test5x : public TestBaseBF16x { public: static constexpr size_t mat_m = 192; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 48; - static constexpr size_t wg_n = 80; - static constexpr size_t sg_m = 24; + // static constexpr size_t wg_m = 48; + // If we want to allow arbitrary wg_m and wg_n instead of powers of 2, + // DG2 still needs to check workgroup OOB in both directions using block_1d loads + static constexpr size_t wg_m = 64; + // static constexpr size_t wg_n = 80; + static constexpr size_t wg_n = 128; + // static constexpr size_t sg_m = 24; + static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::col_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -168,18 +167,18 @@ class Test5 : public TestBase { using data_type_acc = float; }; -class Test6 : public TestBase { +class Test6x : public TestBaseBF16x { public: static constexpr size_t mat_m = 96; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 40; + // static constexpr size_t wg_m = 40; + static constexpr size_t wg_m = 64; static constexpr size_t wg_n = 256; - static constexpr size_t sg_m = 24; 
+ // static constexpr size_t sg_m = 24; + static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -187,7 +186,8 @@ class Test6 : public TestBase { using data_type_c = float; using data_type_acc = float; }; -class Test7 : public TestBase { + +class Test7x : public TestBaseBF16x { public: static constexpr size_t mat_m = 80; static constexpr size_t mat_n = 256; @@ -197,17 +197,11 @@ class Test7 : public TestBase { static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = float; - using data_type_acc = float; }; -class Test8 : public TestBase { +class Test8x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -217,7 +211,6 @@ class Test8 : public TestBase { static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 1; static constexpr uint32_t global_kslicing = 2; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; @@ -227,7 +220,7 @@ class Test8 : public TestBase { using data_type_acc = float; }; -class Test9 : public TestBase { +class Test9x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -237,7 +230,10 @@ class Test9 : public TestBase { static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 2; + // static constexpr uint32_t local_kslicing = 2; + // It looks like local_kslicing fails on DG2 + static constexpr uint32_t local_kslicing = 1; + // global_kslicing works for the aligned case on DG2 static constexpr uint32_t global_kslicing = 4; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; @@ -247,7 +243,7 @@ class Test9 : public TestBase { using data_type_acc = float; }; -class Test10 : public TestBase { +class Test10x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -257,8 +253,8 @@ class Test10 : public TestBase { static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; - static constexpr uint32_t local_kslicing = 4; - static constexpr uint32_t global_kslicing = 1; + // static constexpr uint32_t local_kslicing = 4; + static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = bf16; @@ -267,7 +263,7 @@ class Test10 : public TestBase { using data_type_acc = float; }; -class Test11 : public TestBase { +class Test11x : public TestBaseBF16x { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -278,13 +274,108 @@ class Test11 : public TestBase { static 
constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; static constexpr uint32_t local_kslicing = 16; - static constexpr uint32_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; - using data_type_acc = float; +}; + +class Test12x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 13824; + static constexpr size_t mat_k = 5120; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 32; + static constexpr uint32_t local_kslicing = 4; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test13x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 8192; + static constexpr size_t mat_n = 8192; + static constexpr size_t mat_k = 8192; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test14x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 4096; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test15x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 1024; + static constexpr size_t mat_n = 28672; + static constexpr size_t mat_k = 8192; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test16x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 4; + static constexpr size_t mat_n = 12288; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 8; // wg_m = 4 will fail on DG2 + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 8; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 32; + static constexpr uint32_t local_kslicing = 4; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test17x : public TestBaseBF16x { + public: + static constexpr size_t mat_m = 3072; + static constexpr size_t mat_n = 3072; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 256; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test18x : public TestBaseBF16x { // Get better perf on DG2, ~15.48 TFlops + 
public: + static constexpr size_t mat_m = 4096; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 32; + static constexpr size_t wg_n = 512; + static constexpr size_t sg_m = 16; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 16; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; }; template @@ -325,4 +416,6 @@ using bf16_gemm_func = bf16_gemm_test_func< Test::layout_b, Test::global_kslicing, Test::local_kslicing, - Test::engine>; + Test::wg_num_n, + Test::engine, + Test::gpu_arch>; diff --git a/tests/integration/gemm/bf16/kernel_func.hpp b/tests/integration/gemm/bf16/kernel_func.hpp index 432172f9a..6345047df 100644 --- a/tests/integration/gemm/bf16/kernel_func.hpp +++ b/tests/integration/gemm/bf16/kernel_func.hpp @@ -37,11 +37,15 @@ template < mem_layout layout_b, uint32_t global_kslicing, uint32_t local_kslicing, - mma_engine engine> + uint32_t wg_num_n, + mma_engine engine, + gpu_arch arch_tag> struct bf16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 8; - static constexpr uint32_t prefetch_distance = 3; + static constexpr uint32_t periodic_sync_interval = + (arch_tag == gpu_arch::XeHpc ? 8 : 0); + static constexpr uint32_t prefetch_distance = + (arch_tag == gpu_arch::XeHpc ? 3 : 0); using gemm_t = typename gemm_selector_t< dtype_a, dtype_b, @@ -55,17 +59,17 @@ struct bf16_gemm_test_func { tile_shape, sg_k, engine, - gpu_arch::XeHpc, + arch_tag, prefetch_distance, periodic_sync_interval>::gemm; using epilogue_t = epilogue_t< - epilogue_policy_default, + epilogue_policy_default, tile_shape, mem_desc_t>; - using group_swizzle = - gpu::xetla::kernel::group_swizzle_default; + using group_swizzle = gpu::xetla::kernel::group_swizzle_default; + // using group_swizzle = kernel::group_swizzle_snake; using dispatch_policy = dispatch_policy_kslicing; diff --git a/tests/integration/gemm/bf16/main.cpp b/tests/integration/gemm/bf16/main.cpp index e383ab205..45ffec38e 100644 --- a/tests/integration/gemm/bf16/main.cpp +++ b/tests/integration/gemm/bf16/main.cpp @@ -33,16 +33,24 @@ TYPED_TEST_P(bf16_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(bf16_gemm_test, esimd); using tests = ::testing::Types< - Test0, - Test1, - Test2, - Test3, - Test4, - Test5, - Test6, - Test7, - Test8, - Test9, - Test10, - Test11>; -INSTANTIATE_TYPED_TEST_SUITE_P(bf16_gemm_test_suite, bf16_gemm_test, tests); \ No newline at end of file + Test0x, + Test1x, + Test2x, + Test3x, + Test4x, + Test5x, + Test6x, + Test7x, + Test8x, + Test9x, + Test10x, + Test11x, + Test12x, + Test13x, + Test14x, + Test15x, + Test16x, + Test17x, + Test18x>; + +INSTANTIATE_TYPED_TEST_SUITE_P(bf16_gemm_test_suite, bf16_gemm_test, tests); diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 3674ee922..cc7318f7f 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -45,24 +45,24 @@ class TestBase { mem_layout_a_str + "_" + mem_layout_b_str; return name; } - static constexpr mma_engine engine = mma_engine::fpu; static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; + // static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; + static constexpr uint32_t global_kslicing = 1; + static constexpr uint32_t local_kslicing = 1; + static constexpr uint32_t wg_num_n = 64; }; -class Test0x : public TestBase { +class TestBaseFP16f : public TestBase { + public: + 
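// The bf16 test wrapper now derives its pipeline knobs from the target
// architecture: XeHpc keeps periodic synchronization every 8 k-steps and
// a prefetch distance of 3, while other targets disable both. The same
// selection as hypothetical free functions, mirroring the ternaries in
// kernel_func.hpp above:

constexpr uint32_t pick_sync_interval(gpu_arch a) {
  return a == gpu_arch::XeHpc ? 8 : 0;
}
constexpr uint32_t pick_prefetch_distance(gpu_arch a) {
  return a == gpu_arch::XeHpc ? 3 : 0;
}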
using data_type_a = fp16; + using data_type_b = fp16; + using data_type_c = fp16; + using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::fpu; +}; + +class TestBaseFP16x : public TestBase { public: - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 256; - static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 16; - static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; using data_type_a = fp16; using data_type_b = fp16; using data_type_c = fp16; @@ -70,7 +70,7 @@ class Test0x : public TestBase { static constexpr mma_engine engine = mma_engine::xmx; }; -class Test0f : public TestBase { +class Test0x : public TestBaseFP16x { public: static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 256; @@ -80,18 +80,11 @@ class Test0f : public TestBase { static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; }; -class Test1 : public TestBase { +class Test0f : public TestBaseFP16f { public: static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 256; @@ -100,56 +93,72 @@ class Test1 : public TestBase { static constexpr size_t wg_n = 16; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; + static constexpr size_t sg_k = 32; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; + static constexpr mem_layout layout_b = mem_layout::col_major; }; -class Test2 : public TestBase { + +class Test1f : public TestBaseFP16f { public: - static constexpr size_t mat_m = 256; - static constexpr size_t mat_n = 256; - static constexpr size_t mat_k = 256; + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 250880; + static constexpr size_t mat_k = 1792; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32; + static constexpr size_t wg_n = 2048; static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test2f : public TestBaseFP16f { + public: + static constexpr size_t mat_m = 4; + static constexpr size_t mat_n = 4096 * 3; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 4; + static constexpr size_t wg_n = 128; + static constexpr size_t sg_m = 4; static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; + static constexpr size_t sg_k = 32; + static constexpr uint32_t local_kslicing = 4; static constexpr 
uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; }; -class Test3 : public TestBase { + +class Test2fx1 : public TestBaseFP16f { public: - static constexpr size_t mat_m = 256; - static constexpr size_t mat_n = 256; - static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32; - static constexpr size_t sg_m = 1; + static constexpr size_t mat_m = 48; // mat_m = 32 will fail + static constexpr size_t mat_n = 4096 * 3; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 64; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; + static constexpr size_t sg_k = 32; + // static constexpr uint32_t local_kslicing = 4; + static constexpr mem_layout layout_a = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test3f : public TestBaseFP16f { + public: + static constexpr size_t mat_m = 4; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 16384; + static constexpr size_t wg_m = 4; + static constexpr size_t wg_n = 64; + static constexpr size_t sg_m = 4; + static constexpr size_t sg_n = 16; + static constexpr size_t sg_k = 64; + static constexpr uint32_t local_kslicing = 8; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; + static constexpr mem_layout layout_b = mem_layout::row_major; }; -class Test4f : public TestBase { +class Test4f : public TestBaseFP16f { public: static constexpr size_t mat_m = 1024; static constexpr size_t mat_n = 4096; @@ -159,18 +168,11 @@ class Test4f : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; }; -class Test4x : public TestBase { +class Test4x : public TestBaseFP16x { public: static constexpr size_t mat_m = 1024; static constexpr size_t mat_n = 4096; @@ -180,38 +182,39 @@ class Test4x : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::xmx; }; -class Test5 : public TestBase { +class Test4x1 : public TestBaseFP16x { + public: + static constexpr size_t mat_m = 1024; + static constexpr 
size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 16 * 2; + static constexpr size_t wg_n = 32 * 16; + static constexpr size_t sg_m = 16; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 16; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; +}; + +class Test5f : public TestBaseFP16f { public: static constexpr size_t mat_m = 1024; static constexpr size_t mat_n = 4096; static constexpr size_t mat_k = 4096; static constexpr size_t wg_m = 32; - static constexpr size_t wg_n = 32 * 4; - static constexpr size_t sg_m = 1; + static constexpr size_t wg_n = 32 * 8; + static constexpr size_t sg_m = 16; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; }; -class Test6 : public TestBase { + +class Test6f : public TestBaseFP16f { public: static constexpr size_t mat_m = 96; static constexpr size_t mat_n = 256; @@ -221,16 +224,11 @@ class Test6 : public TestBase { static constexpr size_t sg_m = 24; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test7 : public TestBase { + +class Test7f : public TestBaseFP16f { public: static constexpr size_t mat_m = 80; static constexpr size_t mat_n = 256; @@ -240,17 +238,11 @@ class Test7 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test8 : public TestBase { +class Test8f : public TestBaseFP16f { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -261,16 +253,11 @@ class Test8 : public TestBase { static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; static constexpr uint32_t global_kslicing = 2; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test9 : public TestBase { +class Test9f : public TestBaseFP16f { public: static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; @@ -281,16 +268,11 @@ class Test9 : public TestBase { static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; static constexpr uint32_t global_kslicing = 4; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = 
mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test10 : public TestBase { +class Test10f : public TestBaseFP16f { public: static constexpr size_t mat_m = 4; static constexpr size_t mat_n = 4096; @@ -300,17 +282,11 @@ class Test10 : public TestBase { static constexpr size_t sg_m = 8; static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test11 : public TestBase { +class Test11f : public TestBaseFP16f { public: static constexpr size_t mat_m = 4; static constexpr size_t mat_n = 4096; @@ -320,17 +296,11 @@ class Test11 : public TestBase { static constexpr size_t sg_m = 8; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test12 : public TestBase { +class Test12f : public TestBaseFP16f { public: static constexpr size_t mat_m = 4; static constexpr size_t mat_n = 16384; @@ -340,17 +310,11 @@ class Test12 : public TestBase { static constexpr size_t sg_m = 8; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test13 : public TestBase { +class Test13f : public TestBaseFP16f { public: static constexpr size_t mat_m = 128; static constexpr size_t mat_n = 4096; @@ -360,17 +324,11 @@ class Test13 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test14 : public TestBase { +class Test14f : public TestBaseFP16f { public: static constexpr size_t mat_m = 4; static constexpr size_t mat_n = 50400; @@ -380,17 +338,11 @@ class Test14 : public TestBase { static constexpr size_t sg_m = 8; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test15 : public TestBase { +class Test15f : public TestBaseFP16f { 
public: static constexpr size_t mat_m = 128; static constexpr size_t mat_n = 4096; @@ -400,17 +352,11 @@ class Test15 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test16 : public TestBase { +class Test16x : public TestBaseFP16x { public: static constexpr size_t mat_m = 128; static constexpr size_t mat_n = 50400; @@ -420,17 +366,11 @@ class Test16 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = float; - using data_type_acc = float; }; -class Test17 : public TestBase { +class Test17x : public TestBaseFP16x { public: static constexpr size_t mat_m = 128; static constexpr size_t mat_n = 256; @@ -440,18 +380,11 @@ class Test17 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; }; -class Test18 : public TestBase { +class Test18x : public TestBaseFP16x { public: static constexpr size_t mat_m = 128; static constexpr size_t mat_n = 256; @@ -461,18 +394,11 @@ class Test18 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::col_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; }; -class Test19 : public TestBase { +class Test19x : public TestBaseFP16x { public: static constexpr size_t mat_m = 128; static constexpr size_t mat_n = 256; @@ -482,15 +408,8 @@ class Test19 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; }; template @@ -531,5 +450,6 @@ using fp16_gemm_func = fp16_gemm_test_func< Test::layout_b, Test::global_kslicing, Test::local_kslicing, + Test::wg_num_n, Test::engine, 
Test::gpu_arch>; diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp index f13956cdc..30877c12b 100644 --- a/tests/integration/gemm/fp16/kernel_func.hpp +++ b/tests/integration/gemm/fp16/kernel_func.hpp @@ -37,13 +37,15 @@ template < mem_layout layout_b, uint32_t global_kslicing, uint32_t local_kslicing, + uint32_t wg_num_n, mma_engine engine, gpu_arch gpu_arch> struct fp16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 0; // 8; + static constexpr uint32_t periodic_sync_interval = + (gpu_arch == gpu_arch::XeHpc ? 8 : 0); static constexpr uint32_t prefetch_distance = - 1; // 256 / (sg_k * sizeof(dtype_a)); + (gpu_arch == gpu_arch::XeHpc ? 3 : 0); using compute_attr = typename std::conditional< (engine == mma_engine::fpu), @@ -71,6 +73,8 @@ struct fp16_gemm_test_func { mem_desc_output_c>; using group_swizzle = gpu::xetla::kernel::group_swizzle_default; + // using group_swizzle = gpu::xetla::kernel::group_swizzle_snake; using dispatch_policy = dispatch_policy_kslicing; diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index 6f9d0d490..f7e0e0c6f 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -32,24 +32,7 @@ TYPED_TEST_P(fp16_gemm_test, esimd) { esimd_compile_string); } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); -using tests = ::testing::Types; -// Test1, -// Test2, -// Test3>; -// Test4, -// Test5, -// Test6, -// Test7, -// Test8, -// Test9, -// Test10, -// Test11, -// Test12, -// Test13, -// Test14, -// Test15, -// Test16, -// Test17, -// Test18, -// Test19>; +using tests = + ::testing::Types; + INSTANTIATE_TYPED_TEST_SUITE_P(fp16_gemm_test_suite, fp16_gemm_test, tests); diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index a8e4da602..c1775d5ac 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -259,16 +259,16 @@ class test9_xehpg { using data_type_c = fp16; }; -class test1_xelpg { +class test1_xelpg_1x12288x4096 { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 4096 * 3; static constexpr size_t mat_k = 4096 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 4; + static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; + static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; static constexpr size_t dequant_s = 128; @@ -282,6 +282,55 @@ class test1_xelpg { using data_type_b = int4x2; using data_type_c = fp16; }; + +class test1_xelpg_1x4096x11008 { + public: + // Extract the parameters required by different test cases + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 11008; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 256; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 64; + static constexpr size_t sg_k = 32; + static constexpr size_t dequant_s = 128; + + static constexpr size_t local_kslicing = 8; + static constexpr size_t global_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; 
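// Rough occupancy math for test1_xelpg_1x4096x11008 above, assuming (as
// the kslicing dispatch policy suggests) that local_kslicing multiplies
// the subgroup count along k: (wg_m / sg_m) * (wg_n / sg_n) = 1 * 4
// subgroups per output tile, times 8 cooperating k-slices:

static_assert((1 / 1) * (256 / 64) * 8 == 32, "32 subgroups per workgroup");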
diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp
index 6f9d0d490..f7e0e0c6f 100644
--- a/tests/integration/gemm/fp16/main.cpp
+++ b/tests/integration/gemm/fp16/main.cpp
@@ -32,24 +32,7 @@ TYPED_TEST_P(fp16_gemm_test, esimd) {
       esimd_compile_string);
 }
 REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd);
-using tests = ::testing::Types;
-// Test1,
-// Test2,
-// Test3>;
-// Test4,
-// Test5,
-// Test6,
-// Test7,
-// Test8,
-// Test9,
-// Test10,
-// Test11,
-// Test12,
-// Test13,
-// Test14,
-// Test15,
-// Test16,
-// Test17,
-// Test18,
-// Test19>;
+using tests =
+    ::testing::Types;
+
 INSTANTIATE_TYPED_TEST_SUITE_P(fp16_gemm_test_suite, fp16_gemm_test, tests);
diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp
index a8e4da602..c1775d5ac 100644
--- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp
+++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp
@@ -259,16 +259,16 @@ class test9_xehpg {
   using data_type_c = fp16;
 };
 
-class test1_xelpg {
+class test1_xelpg_1x12288x4096 {
  public:
   // Extract the parameters required by different test cases
   static constexpr size_t mat_m = 1;
   static constexpr size_t mat_n = 4096 * 3;
   static constexpr size_t mat_k = 4096 * 1;
   static constexpr size_t wg_m = 1;
-  static constexpr size_t wg_n = 32 * 4;
+  static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 1;
-  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 16;
   static constexpr size_t dequant_s = 128;
@@ -282,6 +282,55 @@ class test1_xelpg {
   using data_type_b = int4x2;
   using data_type_c = fp16;
 };
+
+class test1_xelpg_1x4096x11008 {
+ public:
+  // Extract the parameters required by different test cases
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 11008;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr size_t dequant_s = 128;
+
+  static constexpr size_t local_kslicing = 8;
+  static constexpr size_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr mma_engine mma_eng = mma_engine::fpu;
+  static constexpr gpu_arch arch = gpu_arch::XeLpg;
+  using data_type_a = fp16;
+  using data_type_b = int4x2;
+  using data_type_c = fp16;
+};
+
+class test1_xelpg_4x4096x4096 {
+ public:
+  // Extract the parameters required by different test cases
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 4096 * 1;
+  static constexpr size_t mat_k = 4096 * 1;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 64;
+  static constexpr size_t dequant_s = 128;
+
+  static constexpr size_t local_kslicing = 8;
+  static constexpr size_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr mma_engine mma_eng = mma_engine::fpu;
+  static constexpr gpu_arch arch = gpu_arch::XeLpg;
+  using data_type_a = fp16;
+  using data_type_b = int4x2;
+  using data_type_c = fp16;
+};
+
 class test2_xelpg {
  public:
   // Extract the parameters required by different test cases
@@ -1017,7 +1066,10 @@ TYPED_TEST_P(dequantize_gemm_test, esimd) {
 }
 REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd);
 
-using tests = ::testing::Types;
+using tests = ::testing::Types<
+    test1_xelpg_1x12288x4096,
+    test1_xelpg_1x4096x11008,
+    test1_xelpg_4x4096x4096>;
 
 INSTANTIATE_TYPED_TEST_SUITE_P(
     dequantize_gemm_test_suite,
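The new test class names encode the GEMM shape as mat_m x mat_n x mat_k. For a feel of the launch geometry these wg_* settings imply, here is a small, self-contained sketch (plain C++, no XeTLA dependency); the ceil-division group-count math is the standard output-tile mapping, not code from the patch:

#include <cstdio>

// Shape and workgroup tile of test1_xelpg_1x4096x11008 above.
int main() {
  constexpr unsigned mat_m = 1, mat_n = 11008, mat_k = 4096;
  constexpr unsigned wg_m = 1, wg_n = 256;
  // One workgroup per wg_m x wg_n output tile, rounded up at the edges.
  constexpr unsigned group_m = (mat_m + wg_m - 1) / wg_m; // 1
  constexpr unsigned group_n = (mat_n + wg_n - 1) / wg_n; // 43
  std::printf("launch %u x %u workgroups; each reduces over K = %u\n",
              group_m, group_n, mat_k);
  return 0;
}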
diff --git a/tests/integration/gemm/unaligned_bf16/common.hpp b/tests/integration/gemm/unaligned_bf16/common.hpp
index d1b45c169..213200b3a 100644
--- a/tests/integration/gemm/unaligned_bf16/common.hpp
+++ b/tests/integration/gemm/unaligned_bf16/common.hpp
@@ -45,11 +45,52 @@ class TestBase {
         mem_layout_a_str + "_" + mem_layout_b_str;
     return name;
   }
+  static constexpr gpu_arch arch_tag = gpu_arch::XeHpg;
+  // static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
+  static constexpr uint32_t global_kslicing = 1;
+  static constexpr uint32_t local_kslicing = 1;
+  static constexpr uint32_t lda_alignment = 1;
+  static constexpr uint32_t ldb_alignment = 1;
+  static constexpr uint32_t ldc_alignment = 1;
+};
+
+class TestBaseBF16x : public TestBase {
+ public:
+  using data_type_a = bf16;
+  using data_type_b = bf16;
+  using data_type_c = bf16;
+  using data_type_acc = float;
+  static constexpr mma_engine engine = mma_engine::xmx;
+};
+
+class TestBaseBF16f : public TestBase {
+ public:
+  using data_type_a = bf16;
+  using data_type_b = bf16;
+  using data_type_c = bf16;
+  using data_type_acc = float;
+  static constexpr mma_engine engine = mma_engine::fpu;
+};
+
+class TestBaseFP16x : public TestBase {
+ public:
+  using data_type_a = fp16;
+  using data_type_b = fp16;
+  using data_type_c = fp16;
+  using data_type_acc = float;
   static constexpr mma_engine engine = mma_engine::xmx;
-  static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc;
 };
 
-class Test0 : public TestBase {
+class TestBaseFP16f : public TestBase {
+ public:
+  using data_type_a = fp16;
+  using data_type_b = fp16;
+  using data_type_c = fp16;
+  using data_type_acc = float;
+  static constexpr mma_engine engine = mma_engine::fpu;
+};
+
+class Test0x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
   static constexpr size_t mat_n = 257;
@@ -59,96 +100,68 @@
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t global_kslicing = 1;
-  static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
-class Test1 : public TestBase {
+class Test1x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
-  static constexpr size_t mat_n = 255;
-  static constexpr size_t mat_k = 255;
+  static constexpr size_t mat_n = 1023;
+  static constexpr size_t mat_k = 767;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
-class Test2 : public TestBase {
+class Test2x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
-  static constexpr size_t mat_n = 251;
-  static constexpr size_t mat_k = 253;
+  static constexpr size_t mat_n = 1011;
+  static constexpr size_t mat_k = 511;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::col_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
-class Test3 : public TestBase {
+class Test3x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 253;
-  static constexpr size_t mat_n = 251;
-  static constexpr size_t mat_k = 253;
+  static constexpr size_t mat_n = 767;
+  static constexpr size_t mat_k = 1023;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::col_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
 };
-class Test4 : public TestBase {
+class Test4x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 257;
   static constexpr size_t mat_n = 257;
-  static constexpr size_t mat_k = 259;
+  static constexpr size_t mat_k = 256;
   static constexpr size_t wg_m = 16;
   static constexpr size_t wg_n = 32;
   static constexpr size_t sg_m = 8;
   static constexpr size_t sg_n = 16;
   static constexpr size_t sg_k = 16;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = float;
-  using data_type_acc = float;
+  static constexpr uint32_t lda_alignment = 8;
 };
-class Test5 : public TestBase {
+
+class Test5x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 191;
   static constexpr size_t mat_n = 251;
@@ -158,17 +171,12 @@ class Test5 : public TestBase {
   static constexpr size_t sg_m = 24;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
   using data_type_c = float;
-  using data_type_acc = float;
 };
-class Test6 : public TestBase {
+class Test6x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 93;
   static constexpr size_t mat_n = 253;
@@ -178,16 +186,12 @@ class Test6 : public TestBase {
   static constexpr size_t sg_m = 24;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
   using data_type_c = float;
-  using data_type_acc = float;
 };
-class Test7 : public TestBase {
+
+class Test7x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 80;
   static constexpr size_t mat_n = 251;
@@ -197,17 +201,12 @@ class Test7 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
   using data_type_c = float;
-  using data_type_acc = float;
 };
-class Test8 : public TestBase {
+class Test8x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 257;
   static constexpr size_t mat_n = 255;
@@ -217,17 +216,14 @@ class Test8 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 1;
-  static constexpr uint32_t global_kslicing = 2;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
+  // static constexpr uint32_t global_kslicing = 2; // will fail to compile on DG2
+  static constexpr uint32_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::col_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
   using data_type_c = float;
-  using data_type_acc = float;
 };
-class Test9 : public TestBase {
+class Test9x : public TestBaseBF16x {
  public:
   static constexpr size_t mat_m = 251;
   static constexpr size_t mat_n = 253;
@@ -237,14 +233,242 @@ class Test9 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 32;
   static constexpr size_t sg_k = 32;
-  static constexpr uint32_t local_kslicing = 2;
-  static constexpr uint32_t global_kslicing = 4;
+  // static constexpr uint32_t global_kslicing = 4; // will fail to compile on DG2
+  static constexpr uint32_t global_kslicing = 1;
+  static constexpr mem_layout layout_a = mem_layout::col_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+};
+
+class Test10x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 253;
+  static constexpr size_t mat_k = 259;
+  static constexpr size_t wg_m = 32;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 8;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 16;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = bf16;
-  using data_type_b = bf16;
-  using data_type_c = bf16;
-  using data_type_acc = float;
+};
+
+class Test11x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 1025;
+  static constexpr size_t mat_k = 256;
+  static constexpr size_t wg_m = 8;
+  static constexpr size_t wg_n = 64;
+  static constexpr size_t sg_m = 8;
+  static constexpr size_t sg_n = 16;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+};
+
+class Test12x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4095;
+  static constexpr size_t mat_n = 4097;
+  static constexpr size_t mat_k = 4091;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+};
+
+class Test13x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4095;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test14x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4097;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 32;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+};
+
+class Test15x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test16x : public TestBaseBF16x { // Gets better perf on DG2
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 32;
+  static constexpr size_t wg_n = 512;
+  static constexpr size_t sg_m = 16;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test17x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 4096;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test18x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 1024;
+  static constexpr size_t mat_n = 2560;
+  static constexpr size_t mat_k = 5120;
+  static constexpr size_t wg_m = 256;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 32;
+  static constexpr size_t sg_n = 64;
+  static constexpr size_t sg_k = 16;
+  static constexpr mem_layout layout_a = mem_layout::col_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test19x : public TestBaseBF16x {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 12288;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 8; // DG2 will fail with wg_m = 4
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 8;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 4;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test19f : public TestBaseBF16f {
+ public:
+  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_n = 12288;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 4;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test20f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 256;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 4;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test21x : public TestBaseFP16x {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 8;
+  static constexpr size_t wg_n = 128;
+  static constexpr size_t sg_m = 8;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
+};
+
+class Test21f : public TestBaseFP16f {
+ public:
+  static constexpr size_t mat_m = 1;
+  static constexpr size_t mat_n = 4096;
+  static constexpr size_t mat_k = 4096;
+  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_n = 128;
+  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_n = 32;
+  static constexpr size_t sg_k = 16;
+  static constexpr uint32_t local_kslicing = 8;
+  static constexpr mem_layout layout_a = mem_layout::row_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
+  static constexpr uint32_t lda_alignment = 8;
+  static constexpr uint32_t ldb_alignment = 8;
+  static constexpr uint32_t ldc_alignment = 8;
 };
 
 template <typename Test>
@@ -283,6 +507,10 @@ using unaligned_gemm_func = unaligned_gemm_test_func<
     Test::sg_k,
     Test::layout_a,
     Test::layout_b,
+    Test::lda_alignment,
+    Test::ldb_alignment,
+    Test::ldc_alignment,
     Test::global_kslicing,
     Test::local_kslicing,
-    Test::engine>;
+    Test::engine,
+    Test::arch_tag>;
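With the refactored bases above, adding an unaligned-GEMM case now means inheriting one of the dtype/engine bases and overriding only what differs from the TestBase defaults (kslicing = 1, all alignments = 1). A hypothetical example — class name and shape invented for illustration, not part of the patch:

class TestMyShapeF : public TestBaseFP16f {
 public:
  static constexpr size_t mat_m = 33;
  static constexpr size_t mat_n = 4093;
  static constexpr size_t mat_k = 4096;
  static constexpr size_t wg_m = 8;
  static constexpr size_t wg_n = 128;
  static constexpr size_t sg_m = 8;
  static constexpr size_t sg_n = 32;
  static constexpr size_t sg_k = 16;
  static constexpr mem_layout layout_a = mem_layout::row_major;
  static constexpr mem_layout layout_b = mem_layout::row_major;
  // lda/ldb/ldc_alignment stay at the TestBase default of 1 (fully
  // unaligned path); override them to 8 to exercise the aligned path.
};

The only remaining step would be appending the class to the ::testing::Types<...> list in main.cpp.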
diff --git a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
index 780c092e9..d45ddc0b7 100644
--- a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
+++ b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
@@ -35,13 +35,19 @@ template <
     uint32_t sg_k,
     mem_layout layout_a,
     mem_layout layout_b,
+    uint32_t lda_alignment,
+    uint32_t ldb_alignment,
+    uint32_t ldc_alignment,
     uint32_t global_kslicing,
     uint32_t local_kslicing,
-    mma_engine engine>
+    mma_engine engine,
+    gpu_arch arch_tag>
 struct unaligned_gemm_test_func {
   using tile_shape = tile_shape_t<wg_n, wg_m, sg_n, sg_m>;
-  static constexpr uint32_t periodic_sync_interval = 8;
-  static constexpr uint32_t prefetch_distance = 3;
+  static constexpr uint32_t periodic_sync_interval =
+      (arch_tag == gpu_arch::XeHpc ? 8 : 0);
+  static constexpr uint32_t prefetch_distance =
+      (arch_tag == gpu_arch::XeHpc ? 3 : 0);
   using gemm_t = typename gemm_selector_t<
       dtype_a,
       dtype_b,
@@ -49,23 +55,22 @@ struct unaligned_gemm_test_func {
       layout_b,
       mem_space::global,
       mem_space::global,
-      1,
-      1,
+      lda_alignment,
+      ldb_alignment,
       dtype_acc,
       tile_shape,
       sg_k,
       engine,
-      gpu_arch::XeHpc,
+      arch_tag,
       prefetch_distance,
       periodic_sync_interval>::gemm;
 
   using epilogue_t = epilogue_t<
-      epilogue_policy_unaligned<gpu_arch::XeHpc>,
+      epilogue_policy_unaligned<arch_tag>,
       tile_shape,
-      mem_desc_t<dtype_c, mem_layout::row_major, mem_space::global>>;
+      mem_desc_t<dtype_c, mem_layout::row_major, mem_space::global, ldc_alignment>>;
 
-  using group_swizzle =
-      gpu::xetla::kernel::group_swizzle_default<gpu_arch::XeHpc>;
+  using group_swizzle = gpu::xetla::kernel::group_swizzle_default<arch_tag>;
   using dispatch_policy = dispatch_policy_kslicing;
 
   using gemm_op_t = gemm_universal_t;
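What the three new alignment parameters buy: they are forwarded into the mem_desc_t types (argument positions as in the patched kernel_func.hpp above), so the selected GEMM and epilogue can take faster load/store paths when a test promises better-aligned leading dimensions. A sketch of the resulting C-matrix descriptor; treat the unit of the alignment argument (elements, not bytes) as an assumption inferred from this patch's ldc_alignment = 8 with 2-byte bf16/fp16 data:

// ldc_alignment = 8 with bf16 (2 bytes/element) => 16-byte row alignment.
using mem_desc_c_aligned = mem_desc_t<
    bf16,
    mem_layout::row_major,
    mem_space::global,
    8 /* alignment in elements; assumption based on this patch's usage */>;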
diff --git a/tests/integration/gemm/unaligned_bf16/main.cpp b/tests/integration/gemm/unaligned_bf16/main.cpp
index 3d8778a3c..a91f35c6c 100644
--- a/tests/integration/gemm/unaligned_bf16/main.cpp
+++ b/tests/integration/gemm/unaligned_bf16/main.cpp
@@ -38,17 +38,32 @@ TYPED_TEST_P(unaligned_gemm_test, esimd) {
 }
 REGISTER_TYPED_TEST_SUITE_P(unaligned_gemm_test, esimd);
 using tests = ::testing::Types<
-    Test0,
-    Test1,
-    Test2,
-    Test3,
-    // Test4,
-    Test5,
-    Test6,
-    Test7,
-    Test8>;
+    Test0x,
+    Test1x,
+    Test2x,
+    Test3x,
+    Test4x,
+    Test5x,
+    Test6x,
+    Test7x,
+    Test8x,
+    Test9x,
+    Test10x,
+    Test11x,
+    Test12x,
+    Test13x,
+    Test14x,
+    Test15x,
+    Test16x,
+    Test17x,
+    Test18x,
+    Test19f,
+    Test19x,
+    Test20f,
+    Test21x,
+    Test21f>;
 
 INSTANTIATE_TYPED_TEST_SUITE_P(
     unaligned_gemm_test_suite,
     unaligned_gemm_test,
-    tests);
\ No newline at end of file
+    tests);
diff --git a/tests/utils/execution.hpp b/tests/utils/execution.hpp
index 46040dddd..c89fdf532 100644
--- a/tests/utils/execution.hpp
+++ b/tests/utils/execution.hpp
@@ -132,7 +132,7 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
   cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(arg);
 
   int constexpr warm_up = 10;
-  int constexpr iters = 100;
+  int constexpr iters = 10;
   for (size_t i = 0; i < batch; i++) {
     auto A_ptr = A + i * size_a;
     auto B_ptr = B + i * size_b;
@@ -315,7 +315,6 @@ class dispatch_arch {
     switch (deviceArch) {
       case ENS::architecture::intel_gpu_pvc:
         return F::exec(std::forward<Args>(args)...);
-        return;
      case ENS::architecture::intel_gpu_dg2_g10:
      case ENS::architecture::intel_gpu_dg2_g11:
      case ENS::architecture::intel_gpu_dg2_g12:
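The dispatch_arch hunk is the runtime counterpart of all the arch_tag plumbing: the test binary queries the device architecture and routes to the matching instantiation (the hunk also deletes an unreachable return; after the PVC case). For reference, a minimal sketch of that query — assuming a DPC++ compiler with the sycl_ext_oneapi_device_architecture extension; the helper function itself is illustrative, not code from the patch:

#include <sycl/sycl.hpp>
namespace ENS = sycl::ext::oneapi::experimental;

// Returns true when the selected device is a DG2 variant, mirroring the
// case labels in the tests/utils/execution.hpp switch above.
bool is_dg2(const sycl::device& dev) {
  const auto arch = dev.get_info<ENS::info::device::architecture>();
  return arch == ENS::architecture::intel_gpu_dg2_g10 ||
         arch == ENS::architecture::intel_gpu_dg2_g11 ||
         arch == ENS::architecture::intel_gpu_dg2_g12;
}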