From 70db70220aa8e75056c69b0a4421742fd846f969 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Thu, 29 Aug 2024 02:41:27 +0000 Subject: [PATCH] save --- CMakeLists.txt | 2 +- include/subgroup/tile/impl/load_xe.hpp | 65 +++++++++++--------------- tests/integration/gemm/fp32/main.cpp | 18 +++---- 3 files changed, 38 insertions(+), 47 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8717a0212..e3f7da9ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,7 @@ else() endif() add_compile_options(-fsycl -fsycl-device-code-split=per_kernel) -add_compile_options(-Wall -Wextra -Werror) +add_compile_options(-Wall -Wextra ) include(ProcessorCount) ProcessorCount(nproc) diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 8d958ea9d..c35478e80 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -119,8 +119,9 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr uint32_t max_load_width_in_elem = load_store_attr::max_load_width_in_bytes / sizeof(dtype); - // static constexpr uint32_t max_trans_load_height_in_elem = - // load_store_attr::max_trans_load_height_in_elem; + static constexpr uint32_t max_trans_load_height_in_elem = + load_store_attr::max_trans_load_height_in_elem; + static constexpr uint32_t max_load_height_in_elem = load_store_attr::max_load_height_in_elem; @@ -130,11 +131,25 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr uint32_t elems_per_reg = register_bytes_t::reg_in_bytes / sizeof(dtype); + static constexpr uint32_t max_ld_blk_width_in_elem = + trans ? max_trans_load_width_in_elem : max_load_width_in_elem; + + static constexpr uint32_t max_ld_blk_height_in_elem = + trans ? max_trans_load_height_in_elem : max_load_height_in_elem; + + static constexpr uint32_t ld_blk_width = std::min( + mem_transpose ? block_size_y : block_size_x, max_ld_blk_width_in_elem); + static constexpr uint32_t ld_blk_height = std::min( + mem_transpose ? block_size_x : block_size_y, max_ld_blk_height_in_elem); + + static constexpr uint32_t ld_blk_size_y = + mem_transpose ? ld_blk_width : ld_blk_height; + static constexpr uint32_t ld_blk_size_y_limit = mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem; - static constexpr uint32_t ld_blk_size_y = reg_transpose - ? block_size_y - : std::min(ld_blk_size_y_limit, block_size_y); + // static constexpr uint32_t ld_blk_size_y = reg_transpose + // ? block_size_y + // : std::min(ld_blk_size_y_limit, block_size_y); // array len is used to make sure memory load is cache line aligned // disabled while register or memory transpose @@ -198,10 +213,10 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t load_block_elems = block_elems * arr_len; auto reg_blk = tile.reg.xetla_select( (i * num_block_x + j) * block_elems); - constexpr uint32_t ld_blk_height = (reg_transpose && trans) + constexpr uint32_t ld_blk_size_y_pad = (reg_transpose && trans) ? detail::getNextPowerOf2() : ld_blk_size_y; - constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len; + constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len; xetla_vector reg_tmp; #pragma unroll for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) { @@ -213,10 +228,8 @@ tile_load(tile_t& tile, payload_t& payload) { mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y); reg_tmp.xetla_format>() = xetla_load_global< native_type_t, - (trans ? ld_blk_size_y : block_size_x) / scale_factor, - (trans ? block_size_x : ld_blk_size_y), - // block_size_x / scale_factor, - // ld_blk_size_y, + ld_blk_width / scale_factor, + ld_blk_height, arr_len, trans, mem_transform, @@ -261,11 +274,6 @@ tile_load(tile_t& tile, payload_t& payload) { (mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor; constexpr uint8_t block_height = mem_transpose ? block_size_x : remained_blk_size_y; - // constexpr uint32_t block_widthx_widthy_arrlen = - // (block_width - 1) | ((block_height - 1) << 8); - // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( - // tdesc.xetla_format(), block_widthx_widthy_arrlen); - reg_blk.xetla_select(remained_start) .xetla_format>() = xetla_load_global< native_type_t, @@ -283,15 +291,6 @@ tile_load(tile_t& tile, payload_t& payload) { payload.surface_pitch, payload.offset_x + offset_x / scale_factor, payload.offset_y + offset_y + remained_start_y); - - // xetla_tload_global< - // load_dtype, - // (load_elems / scale_factor), - // L1, - // L2, - // trans, - // mem_transform, - // arch_tag>(tdesc); } } } @@ -304,24 +303,16 @@ tile_load(tile_t& tile, payload_t& payload) { (!reg_transpose && (remained_size_y > ld_blk_size_y_limit)) ? ld_blk_size_y_limit : remained_size_y; - // auto payload_row = payload_2d.xetla_select( - // num_block_y * num_block_x, 0); - // detail::reset_tile_desc_core< - // num_block_x, - // block_size_x, - // remained_ld_blk_size_y, - // scale_factor, - // arr_len, - // mem_transpose>(payload_row); + #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { int32_t offset_x = j * block_size_x; // xetla_tdescriptor tdesc = payload_row.row(j); auto reg_blk = tile.reg.xetla_select( processed_elems + j * remained_block_elems); - constexpr uint32_t ld_blk_height = (reg_transpose && trans) - ? detail::getNextPowerOf2() - : remained_ld_blk_size_y; + // constexpr uint32_t ld_blk_height = (reg_transpose && trans) + // ? detail::getNextPowerOf2() + // : remained_ld_blk_size_y; constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len; xetla_vector reg_tmp; #pragma unroll diff --git a/tests/integration/gemm/fp32/main.cpp b/tests/integration/gemm/fp32/main.cpp index a05393399..3a9badc9f 100644 --- a/tests/integration/gemm/fp32/main.cpp +++ b/tests/integration/gemm/fp32/main.cpp @@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) { REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd); using tests = ::testing::Types< - // Test1, - // Test2, - // Test3, - // Test4, - // Test5, - // Test6, - // Test7, - // Test8, - // Test9, + Test1, + Test2, + Test3, + Test4, + Test5, + Test6, + Test7, + Test8, + Test9, Test10, Test11>; INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests);