From 09bd4976c2e8b5c83b22e349021b0d52c039d565 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Mon, 26 Aug 2024 04:19:52 +0000 Subject: [PATCH] modify rvalue reference for store_2d --- include/common/core/memory.hpp | 8 ++++---- include/subgroup/tile/impl/load_xe.hpp | 4 ++-- include/subgroup/tile/impl/store_xe.hpp | 16 ++++++++-------- .../default_config/group_gemm/kernel_func.hpp | 3 +++ .../default_config/kernel_gemm/kernel_func.hpp | 3 +++ tests/integration/gemm/bf16/kernel_func.hpp | 3 +++ tests/integration/gemm/fp32/kernel_func.hpp | 3 +++ tests/integration/gemm/int8/kernel_func.hpp | 3 +++ tests/integration/gemm/tf32/kernel_func.hpp | 3 +++ .../gemm/unaligned_bf16/kernel_func.hpp | 9 ++++++++- tests/integration/gemm/unaligned_bf16/main.cpp | 5 +---- 11 files changed, 41 insertions(+), 19 deletions(-) diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp index ed75cf935..c152f187c 100644 --- a/include/common/core/memory.hpp +++ b/include/common/core/memory.hpp @@ -769,9 +769,8 @@ __XETLA_API void xetla_store_global( unsigned SurfacePitch, int X, int Y, - xetla_vector Vals) { + auto&& Vals) { if constexpr (std::is_same_v) { - xetla_vector Vals_fp16 = Vals.xetla_format(); xetla_store_global( reinterpret_cast(Ptr), SurfaceWidth, @@ -779,14 +778,15 @@ __XETLA_API void xetla_store_global( SurfacePitch, X, Y, - Vals_fp16); + Vals.xetla_format()); } else { __ESIMD_ENS::lsc_store_2d< T, BlockWidth, BlockHeight, gpu::xetla::detail::get_cache_hint(L1H), - gpu::xetla::detail::get_cache_hint(L2H)>( + gpu::xetla::detail::get_cache_hint(L2H), + N>( Ptr, SurfaceWidth - 1, SurfaceHeight - 1, SurfacePitch - 1, X, Y, Vals); } } diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 844c94a27..999233108 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -89,7 +89,7 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr uint32_t num_block_x = tile_desc::num_block_x; static constexpr uint32_t num_block_y = tile_desc::num_block_y; -// static constexpr uint32_t num_block = tile_desc::num_block; + // static constexpr uint32_t num_block = tile_desc::num_block; static constexpr gpu_arch arch_tag = payload_t::arch_tag; @@ -329,7 +329,7 @@ tile_load(tile_t& tile, payload_t& payload) { reg_tmp.xetla_format>() = xetla_load_global< native_type_t, block_size_x / scale_factor, - block_size_y, + ld_blk_height, arr_len, trans, mem_transform, diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 30671491b..ae0cce2b9 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -98,7 +98,7 @@ tile_store(tile_t& tile, payload_t& payload) { static constexpr uint32_t num_block_x = tile_desc::num_block_x; static constexpr uint32_t num_block_y = tile_desc::num_block_y; -// static constexpr uint32_t num_block = tile_desc::num_block; + // static constexpr uint32_t num_block = tile_desc::num_block; using load_store_attr = typename arch_attr_t< payload_t::arch_tag>::template load_store_attr; @@ -145,7 +145,7 @@ tile_store(tile_t& tile, payload_t& payload) { #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { int32_t offset_x = j * block_size_x; - // xetla_tdescriptor tdesc = payload_row.row(j); + // xetla_tdescriptor tdesc = payload_row.row(j); auto reg_blk = tile.reg.xetla_select( (i * num_block_x + j) * block_elems); xetla_vector combine_blk; @@ -163,7 +163,7 @@ tile_store(tile_t& tile, payload_t& payload) { for (uint32_t ii = 0; ii < block_size_y / st_block_size_y; ++ii) { constexpr uint32_t store_elems = st_block_size_y * block_size_x * arr_len; - xetla_vector st_blk = + auto st_blk = combine_blk.xetla_select(ii * store_elems); // xetla_tstore_global( // tdesc, st_blk); @@ -173,7 +173,7 @@ tile_store(tile_t& tile, payload_t& payload) { st_block_size_y, L1, L2>( - payload.base_ptr, + reinterpret_cast(payload.base_ptr), payload.surface_width, payload.surface_height, payload.surface_pitch, @@ -210,7 +210,7 @@ tile_store(tile_t& tile, payload_t& payload) { blk_remained_y, L1, L2>( - payload.base_ptr, + reinterpret_cast(payload.base_ptr), payload.surface_width, payload.surface_height, payload.surface_pitch, @@ -240,7 +240,7 @@ tile_store(tile_t& tile, payload_t& payload) { #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { int offset_x = j * block_size_x; - // xetla_tdescriptor tdesc = payload_row.row(j); + // xetla_tdescriptor tdesc = payload_row.row(j); auto reg_blk = tile.reg.xetla_select( processed_elems + j * remained_block_elems); // Do combination @@ -271,7 +271,7 @@ tile_store(tile_t& tile, payload_t& payload) { remained_st_blk_size_y, L1, L2>( - payload.base_ptr, + reinterpret_cast(payload.base_ptr), payload.surface_width, payload.surface_height, payload.surface_pitch, @@ -308,7 +308,7 @@ tile_store(tile_t& tile, payload_t& payload) { final_st_blk_size_y, L1, L2>( - payload.base_ptr, + reinterpret_cast(payload.base_ptr), payload.surface_width, payload.surface_height, payload.surface_pitch, diff --git a/tests/integration/default_config/group_gemm/kernel_func.hpp b/tests/integration/default_config/group_gemm/kernel_func.hpp index 26b2fb181..7b7b31404 100644 --- a/tests/integration/default_config/group_gemm/kernel_func.hpp +++ b/tests/integration/default_config/group_gemm/kernel_func.hpp @@ -108,6 +108,9 @@ struct default_config_group_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "default_config_group_gemm_test_func"; } diff --git a/tests/integration/default_config/kernel_gemm/kernel_func.hpp b/tests/integration/default_config/kernel_gemm/kernel_func.hpp index 3745343d0..24b82bb94 100644 --- a/tests/integration/default_config/kernel_gemm/kernel_func.hpp +++ b/tests/integration/default_config/kernel_gemm/kernel_func.hpp @@ -65,6 +65,9 @@ struct default_config_kernel_gemm_test_func { gpu_arch::XeHpc, // GPU arch tune_option>; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "default_config_kernel_gemm_test_func"; } diff --git a/tests/integration/gemm/bf16/kernel_func.hpp b/tests/integration/gemm/bf16/kernel_func.hpp index 6345047df..286e53dbb 100644 --- a/tests/integration/gemm/bf16/kernel_func.hpp +++ b/tests/integration/gemm/bf16/kernel_func.hpp @@ -76,6 +76,9 @@ struct bf16_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "bf16_gemm_test_func"; } diff --git a/tests/integration/gemm/fp32/kernel_func.hpp b/tests/integration/gemm/fp32/kernel_func.hpp index 962acfe96..a83518076 100644 --- a/tests/integration/gemm/fp32/kernel_func.hpp +++ b/tests/integration/gemm/fp32/kernel_func.hpp @@ -77,6 +77,9 @@ struct fp32_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "fp32_gemm_test_func"; } diff --git a/tests/integration/gemm/int8/kernel_func.hpp b/tests/integration/gemm/int8/kernel_func.hpp index 3a9e595ca..da2eab012 100644 --- a/tests/integration/gemm/int8/kernel_func.hpp +++ b/tests/integration/gemm/int8/kernel_func.hpp @@ -72,6 +72,9 @@ struct int8gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "int8gemm_test_func"; } diff --git a/tests/integration/gemm/tf32/kernel_func.hpp b/tests/integration/gemm/tf32/kernel_func.hpp index 8c2850b9b..42d69eb51 100644 --- a/tests/integration/gemm/tf32/kernel_func.hpp +++ b/tests/integration/gemm/tf32/kernel_func.hpp @@ -71,6 +71,9 @@ struct tf32_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "tf32_gemm_test_func"; } diff --git a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp index d45ddc0b7..d5534ebdf 100644 --- a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp +++ b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp @@ -68,13 +68,20 @@ struct unaligned_gemm_test_func { using epilogue_t = epilogue_t< epilogue_policy_unaligned, tile_shape, - mem_desc_t>; + mem_desc_t< + dtype_c, + mem_layout::row_major, + mem_space::global, + ldc_alignment>>; using group_swizzle = gpu::xetla::kernel::group_swizzle_default; using dispatch_policy = dispatch_policy_kslicing; using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "unaligned_gemm_test_func"; } diff --git a/tests/integration/gemm/unaligned_bf16/main.cpp b/tests/integration/gemm/unaligned_bf16/main.cpp index 9ceaf695e..3d23c78ce 100644 --- a/tests/integration/gemm/unaligned_bf16/main.cpp +++ b/tests/integration/gemm/unaligned_bf16/main.cpp @@ -31,10 +31,7 @@ TYPED_TEST_P(unaligned_gemm_test, esimd) { gemm_exec< TypeParam, result_validate, - unaligned_gemm_func, - unaligned_gemm_func::gemm_op_t::get_slm_size(), - unaligned_gemm_func::gemm_op_t::get_barrier_count()>( - esimd_compile_string); + unaligned_gemm_func>(esimd_compile_string); } REGISTER_TYPED_TEST_SUITE_P(unaligned_gemm_test, esimd); using tests = ::testing::Types<