Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
modify rvalue reference for store_2d
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Aug 26, 2024
1 parent 4e45885 commit 09bd497
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 19 deletions.
8 changes: 4 additions & 4 deletions include/common/core/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,24 +769,24 @@ __XETLA_API void xetla_store_global(
unsigned SurfacePitch,
int X,
int Y,
xetla_vector<T, N> Vals) {
auto&& Vals) {
if constexpr (std::is_same_v<T, bf16>) {
xetla_vector<fp16, N> Vals_fp16 = Vals.xetla_format<fp16>();
xetla_store_global<fp16, BlockWidth, BlockHeight, L1H, L2H>(
reinterpret_cast<fp16*>(Ptr),
SurfaceWidth,
SurfaceHeight,
SurfacePitch,
X,
Y,
Vals_fp16);
Vals.xetla_format<fp16>());
} else {
__ESIMD_ENS::lsc_store_2d<
T,
BlockWidth,
BlockHeight,
gpu::xetla::detail::get_cache_hint(L1H),
gpu::xetla::detail::get_cache_hint(L2H)>(
gpu::xetla::detail::get_cache_hint(L2H),
N>(
Ptr, SurfaceWidth - 1, SurfaceHeight - 1, SurfacePitch - 1, X, Y, Vals);
}
}
Expand Down
4 changes: 2 additions & 2 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ tile_load(tile_t& tile, payload_t& payload) {

static constexpr uint32_t num_block_x = tile_desc::num_block_x;
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
// static constexpr uint32_t num_block = tile_desc::num_block;
// static constexpr uint32_t num_block = tile_desc::num_block;

static constexpr gpu_arch arch_tag = payload_t::arch_tag;

Expand Down Expand Up @@ -329,7 +329,7 @@ tile_load(tile_t& tile, payload_t& payload) {
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
native_type_t<load_dtype>,
block_size_x / scale_factor,
block_size_y,
ld_blk_height,
arr_len,
trans,
mem_transform,
Expand Down
16 changes: 8 additions & 8 deletions include/subgroup/tile/impl/store_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ tile_store(tile_t& tile, payload_t& payload) {

static constexpr uint32_t num_block_x = tile_desc::num_block_x;
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
// static constexpr uint32_t num_block = tile_desc::num_block;
// static constexpr uint32_t num_block = tile_desc::num_block;

using load_store_attr = typename arch_attr_t<
payload_t::arch_tag>::template load_store_attr<msg_type::block_2d>;
Expand Down Expand Up @@ -145,7 +145,7 @@ tile_store(tile_t& tile, payload_t& payload) {
#pragma unroll
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
int32_t offset_x = j * block_size_x;
// xetla_tdescriptor tdesc = payload_row.row(j);
// xetla_tdescriptor tdesc = payload_row.row(j);
auto reg_blk = tile.reg.xetla_select<store_block_elems, 1>(
(i * num_block_x + j) * block_elems);
xetla_vector<dtype, store_block_elems> combine_blk;
Expand All @@ -163,7 +163,7 @@ tile_store(tile_t& tile, payload_t& payload) {
for (uint32_t ii = 0; ii < block_size_y / st_block_size_y; ++ii) {
constexpr uint32_t store_elems =
st_block_size_y * block_size_x * arr_len;
xetla_vector<dtype, store_elems> st_blk =
auto st_blk =
combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
// xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
// tdesc, st_blk);
Expand All @@ -173,7 +173,7 @@ tile_store(tile_t& tile, payload_t& payload) {
st_block_size_y,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<dtype*>(payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down Expand Up @@ -210,7 +210,7 @@ tile_store(tile_t& tile, payload_t& payload) {
blk_remained_y,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<dtype*>(payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down Expand Up @@ -240,7 +240,7 @@ tile_store(tile_t& tile, payload_t& payload) {
#pragma unroll
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
int offset_x = j * block_size_x;
// xetla_tdescriptor tdesc = payload_row.row(j);
// xetla_tdescriptor tdesc = payload_row.row(j);
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
processed_elems + j * remained_block_elems);
// Do combination
Expand Down Expand Up @@ -271,7 +271,7 @@ tile_store(tile_t& tile, payload_t& payload) {
remained_st_blk_size_y,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<dtype*>(payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down Expand Up @@ -308,7 +308,7 @@ tile_store(tile_t& tile, payload_t& payload) {
final_st_blk_size_y,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<dtype*>(payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/default_config/group_gemm/kernel_func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ struct default_config_group_gemm_test_func {

using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

static const char* func_name() {
return "default_config_group_gemm_test_func";
}
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/default_config/kernel_gemm/kernel_func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct default_config_kernel_gemm_test_func {
gpu_arch::XeHpc, // GPU arch
tune_option>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

static const char* func_name() {
return "default_config_kernel_gemm_test_func";
}
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/gemm/bf16/kernel_func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ struct bf16_gemm_test_func {

using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

static const char* func_name() {
return "bf16_gemm_test_func";
}
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/gemm/fp32/kernel_func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ struct fp32_gemm_test_func {

using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

static const char* func_name() {
return "fp32_gemm_test_func";
}
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/gemm/int8/kernel_func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ struct int8gemm_test_func {

using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

static const char* func_name() {
return "int8gemm_test_func";
}
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/gemm/tf32/kernel_func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ struct tf32_gemm_test_func {

using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

static const char* func_name() {
return "tf32_gemm_test_func";
}
Expand Down
9 changes: 8 additions & 1 deletion tests/integration/gemm/unaligned_bf16/kernel_func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,20 @@ struct unaligned_gemm_test_func {
using epilogue_t = epilogue_t<
epilogue_policy_unaligned<arch_tag>,
tile_shape,
mem_desc_t<dtype_c, mem_layout::row_major, mem_space::global, ldc_alignment>>;
mem_desc_t<
dtype_c,
mem_layout::row_major,
mem_space::global,
ldc_alignment>>;

using group_swizzle = gpu::xetla::kernel::group_swizzle_default<arch_tag>;
using dispatch_policy =
dispatch_policy_kslicing<group_swizzle, global_kslicing, local_kslicing>;
using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

static const char* func_name() {
return "unaligned_gemm_test_func";
}
Expand Down
5 changes: 1 addition & 4 deletions tests/integration/gemm/unaligned_bf16/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ TYPED_TEST_P(unaligned_gemm_test, esimd) {
gemm_exec<
TypeParam,
result_validate<TypeParam>,
unaligned_gemm_func<TypeParam>,
unaligned_gemm_func<TypeParam>::gemm_op_t::get_slm_size(),
unaligned_gemm_func<TypeParam>::gemm_op_t::get_barrier_count()>(
esimd_compile_string);
unaligned_gemm_func<TypeParam>>(esimd_compile_string);
}
REGISTER_TYPED_TEST_SUITE_P(unaligned_gemm_test, esimd);
using tests = ::testing::Types<
Expand Down

0 comments on commit 09bd497

Please sign in to comment.