Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
bugfix for update offset_x/offset_y
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Aug 29, 2024
1 parent 4f481df commit 90662a8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ tile_load(tile_t& tile, payload_t& payload) {
static constexpr gpu_arch arch_tag = payload_t::arch_tag;

static constexpr reg_layout reg_layout_ = tile_desc::register_layout;
// In the case of pack, tranpose is in vnni format
static constexpr bool is_vnni_reverse =
payload_t::mem_transpose_dtype_less4bytes &&
((reg_layout_ == reg_layout::tiled) ||
Expand Down Expand Up @@ -188,14 +189,13 @@ tile_load(tile_t& tile, payload_t& payload) {
((block_size_y * sizeof(dtype)) % sizeof(load_dtype) == 0),
"check vnni limitation for DW transpose");

// auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
#pragma unroll
for (uint32_t i = 0; i < num_block_y; ++i) {
constexpr uint32_t load_block_elems = block_elems * arr_len;
int offset_y = i * block_size_y;
#pragma unroll
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
int32_t offset_x = j * block_size_x;
constexpr uint32_t load_block_elems = block_elems * arr_len;
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
(i * num_block_x + j) * block_elems);
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
Expand Down
2 changes: 2 additions & 0 deletions include/subgroup/tile/impl/payload_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,14 @@ struct mem_payload_t<
__XETLA_API void update_tdesc(int offset) {
auto payloads_2d = payloads.xetla_format<uint32_t, num_block, 16>();
if constexpr (update_dir == tdesc_update_dir::x_dir) {
offset_x += offset / scale_factor;
#pragma unroll
for (uint32_t i = 0; i < num_block; i++) {
xetla_update_tdesc_offsetx(
payloads_2d.row(i), offset / int32_t(scale_factor));
}
} else {
offset_y += offset;
#pragma unroll
for (uint32_t i = 0; i < num_block; i++) {
xetla_update_tdesc_offsety(payloads_2d.row(i), offset);
Expand Down

0 comments on commit 90662a8

Please sign in to comment.