This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Enhance unalignment handling
JianpingChen066 committed Jun 24, 2024
1 parent c5314b7 commit b399daf
Showing 30 changed files with 2,134 additions and 665 deletions.
6 changes: 4 additions & 2 deletions examples/02_basic_gemm/basic_gemm.cpp
@@ -114,8 +114,10 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
// wrap the nd_range to XeTLA range

// Performance tuning setting based on different shapes
static constexpr uint32_t periodic_sync_interval = 0;
static constexpr uint32_t prefetch_distance = 1;
static constexpr uint32_t periodic_sync_interval =
(arch_tag == gpu_arch::XeHpc) ? 8 : 0;
static constexpr uint32_t prefetch_distance =
(arch_tag == gpu_arch::XeHpc) ? 3 : 1;
// should be larger than 8
static constexpr uint32_t k_stride = 32;

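For reference, these knobs feed the compute policies changed later in this commit (compute_policy.hpp reads `perf_tuning_knob::k_stride`, `::stages`, and `::sync_freq`). A minimal sketch of how they are bundled; the template-parameter order of `perf_tuning_knob_t` is an assumption here, not confirmed by this diff:

// Sketch only: bundle the tuning knobs for the GEMM compute policy.
// Parameter order (k_stride, stages, sync_freq) is assumed.
using perf_tuning_knob = gpu::xetla::group::perf_tuning_knob_t<
    k_stride, // 32 here; should be larger than 8 per the comment above
    prefetch_distance, // 3 on XeHpc, 1 elsewhere: prefetch pipeline depth
    periodic_sync_interval>; // 8 on XeHpc (named barriers), else 0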
7 changes: 4 additions & 3 deletions include/common/core/arch_config.hpp
@@ -198,12 +198,14 @@ struct mma_attr_t<
m,
std::enable_if_t<arch_has_xmx<arch_tag>>> {
using dpas_attr = dpas_attr_t<arch_tag>;
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
static constexpr uint32_t mma_m_in_elem =
(m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
static constexpr uint32_t blk_m_in_elem = 16;

static constexpr uint32_t mma_n_in_elem = dpas_attr::n_in_elem;
[[maybe_unused]] static constexpr uint32_t blk_n_in_bytes = 64;
[[maybe_unused]] static constexpr uint32_t blk_n_in_bytes =
load_store_attr::max_trans_load_width_in_bytes;

static constexpr uint32_t mma_k_in_bytes = dpas_attr::k_in_bytes;
static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
@@ -224,8 +226,7 @@ struct mma_attr_t<
load_store_attr::max_trans_load_width_in_bytes;

[[maybe_unused]] static constexpr uint32_t mma_n_in_elem = 16;
static constexpr uint32_t blk_n_in_bytes =
register_bytes_t<arch_tag>::reg_in_bytes;
static constexpr uint32_t blk_n_in_bytes = blk_k_in_bytes;
};

template <gpu_arch arch_tag>
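The two hunks above change how the FPU path blocks along N: the transposed-load limit now bounds both the K and N blocking, instead of letting N span a full register. A self-contained illustration, with the attribute value assumed for a hypothetical target:

#include <cstdint>

// Assumed width for illustration; real values come from load_store_attr_t.
struct demo_load_store_attr {
  static constexpr uint32_t max_trans_load_width_in_bytes = 32;
};

struct demo_fpu_mma_attr {
  // K blocking is capped by the widest legal transposed 2D load...
  static constexpr uint32_t blk_k_in_bytes =
      demo_load_store_attr::max_trans_load_width_in_bytes;
  // ...and N blocking now follows K blocking rather than the register
  // width, keeping one block within a single load message.
  static constexpr uint32_t blk_n_in_bytes = blk_k_in_bytes;
};

static_assert(demo_fpu_mma_attr::blk_n_in_bytes == 32);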
2 changes: 2 additions & 0 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
@@ -176,6 +176,8 @@ class gemm_t<

// note: plane format, row-major
// note: 4bit x 2, row-major
static_assert(tile_size_x_b % pack_ratio == 0);
static_assert(block_size_x_b % pack_ratio == 0);
using matB_tile_desc_t = subgroup::tile_desc_t<
tile_size_x_b / pack_ratio,
tile_size_y_b,
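The packing arithmetic behind the new asserts, as a standalone check (the 32-bit storage type is an assumption for illustration):

#include <cstdint>

// 4-bit weights packed into 32-bit storage: 32 / 4 = 8 values per element.
constexpr uint32_t pack_ratio = sizeof(uint32_t) * 8 / 4;

// The B-tile width must map to a whole number of packed columns,
// e.g. a hypothetical 64-wide tile gives 64 / 8 = 8 packed columns.
constexpr uint32_t tile_size_x_b = 64;
static_assert(tile_size_x_b % pack_ratio == 0);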
4 changes: 3 additions & 1 deletion include/group/epilogue/impl/unaligned_xe.hpp
@@ -45,6 +45,8 @@ class epilogue_t<
static constexpr uint32_t slm_size = mem_desc_c_t::is_local
? tile_shape::wg_tile_size_x * tile_shape::wg_tile_size_y
: 0;
static constexpr bool ldc_align16 =
(mem_desc_c_t::alignment_in_bytes % 16 == 0);
/// @brief Epilogue arguments.
struct arguments_t {};

@@ -71,7 +73,7 @@

public:
static constexpr msg_type msg_type_c =
(mem_space_c == mem_space::global ? msg_type::unaligned_2d
(mem_space_c == mem_space::global ? (ldc_align16 ? msg_type::block_2d : msg_type::unaligned_2d)
: msg_type::scatter);

/// @brief Default epilogue.
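With this change, a global C that is at least 16-byte aligned takes the hardware 2D block-store path, an under-aligned C falls back to unaligned_2d, and local memory keeps scatter. A worked check with hypothetical alignments:

#include <cstdint>

// Hypothetical descriptor alignments and the message type each selects.
constexpr uint32_t aligned_c_bytes = 64; // 64 % 16 == 0 -> block_2d
constexpr uint32_t packed_c_bytes = 4;   // 4 % 16 != 0 -> unaligned_2d
static_assert(aligned_c_bytes % 16 == 0);
static_assert(packed_c_bytes % 16 != 0);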
16 changes: 14 additions & 2 deletions include/group/gemm/compute_policy.hpp
@@ -48,6 +48,7 @@ struct compute_policy_default_xmx<
arch_tag_,
std::enable_if_t<arch_has_xmx<arch_tag_>>> {
static constexpr gpu_arch arch_tag = arch_tag_;
static constexpr mma_engine mma_engine = mma_engine::xmx;

using compute_attr = compute_attr_;
using dtype_mma_acc = typename compute_attr::dtype_acc;
@@ -59,7 +60,8 @@
static constexpr int sync_freq = perf_tuning_knob::sync_freq;
static constexpr int k_stride = perf_tuning_knob::k_stride;

using mma_attr = mma_attr_t<arch_tag, mma_engine::xmx, 16>;
using mma_attr = mma_attr_t<arch_tag, mma_engine, 16>;

static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem;
static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
static constexpr uint32_t block_size_x_a =
@@ -107,6 +109,7 @@ struct compute_policy_default_fpu<
arch_tag_,
std::enable_if_t<arch_has_fpu<arch_tag_>>> {
static constexpr gpu_arch arch_tag = arch_tag_;
static constexpr mma_engine mma_engine = mma_engine::fpu;

using compute_attr = compute_attr_;
using dtype_mma_acc = typename compute_attr::dtype_acc;
@@ -118,17 +121,26 @@
static constexpr int sync_freq = perf_tuning_knob::sync_freq;
static constexpr int k_stride = perf_tuning_knob::k_stride;

using mma_attr = mma_attr_t<arch_tag, mma_engine::fpu, 16>;
using mma_attr = mma_attr_t<arch_tag, mma_engine, 16>;
static constexpr uint32_t block_size_y_a = mma_attr::blk_m_in_elem;
static constexpr uint32_t block_bytes_x_a = mma_attr::blk_k_in_bytes;
static constexpr uint32_t block_size_x_a =
block_bytes_x_a / sizeof(dtype_mma_a);

static constexpr uint32_t block_bytes_x_b = mma_attr::blk_n_in_bytes;
static constexpr uint32_t block_size_x_b =
block_bytes_x_b / sizeof(dtype_mma_b);
static constexpr uint32_t block_size_y_b = block_size_x_a;
};

template <
typename compute_attr_,
typename perf_tuning_knob_,
gpu_arch arch_tag_>
struct compute_policy_unaligned_fpu : public compute_policy_default_fpu<
compute_attr_,
perf_tuning_knob_,
arch_tag_> {};
/// @} xetla_gemm

} // namespace gpu::xetla::group
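These policies turn the byte-based blocking from mma_attr_t into element counts for the instantiated dtypes. A worked conversion under assumed values:

#include <cstdint>

using demo_dtype = uint16_t; // stand-in for a 2-byte type such as bf16
constexpr uint32_t block_bytes_x_a = 32; // hypothetical mma_attr value
constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(demo_dtype);
static_assert(block_size_x_a == 16); // 32 bytes -> 16 two-byte elements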
1 change: 1 addition & 0 deletions include/group/gemm/gemm.hpp
@@ -30,5 +30,6 @@
#include <group/gemm/impl/default_xmx_xe.hpp>
#include <group/gemm/impl/pre_processing_xe.hpp>
#include <group/gemm/impl/selector_xe.hpp>
#include <group/gemm/impl/unaligned_fpu_xe.hpp>
#include <group/gemm/impl/unaligned_xmx_xe.hpp>
#include <group/tile_shape.hpp>
80 changes: 44 additions & 36 deletions include/group/gemm/impl/default_fpu_xe.hpp
@@ -160,6 +160,7 @@ class gemm_t<
subgroup::tile_desc_t<tile_size_x_a, tile_size_y_a, 1, 1>,
1,
arch_tag>;
matA_prefetch_payload_t matA_prefetch_payload;

static constexpr reg_layout reg_layout_b = reg_layout::tiled;
using matB_tile_desc_t = subgroup::tile_desc_t<
@@ -180,6 +181,7 @@
subgroup::tile_desc_t<tile_size_x_b, tile_size_y_b, 1, 1>,
1,
arch_tag>;
matB_prefetch_payload_t matB_prefetch_payload;

public:
using matAcc_tile_desc_t = subgroup::tile_desc_t<
@@ -202,16 +204,17 @@
static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;
static constexpr bool need_local_fence =
(mem_space_a == mem_space::local) || (mem_space_b == mem_space::local);

[[maybe_unused]] xetla_nbarrier_t<wg_size, wg_size, arch_tag> barrier_all;
[[maybe_unused]] xetla_nbarrier_t<wg_size_x, wg_size_x, arch_tag> nbarrier_a;
[[maybe_unused]] xetla_nbarrier_t<wg_size_y, wg_size_y, arch_tag> nbarrier_b;

public:
static constexpr uint32_t barrier_count =
(enable_periodic_sync && arch_has_named_barrier<arch_tag>)
? barrier_count_x + barrier_count_y
: 0;

// currently no SLM path
static constexpr uint32_t slm_size = 0;
@@ -293,18 +296,23 @@
};

inline void periodic_sync_init(
[[maybe_unused]] int32_t sg_idx,
[[maybe_unused]] int32_t sg_idy,
uint32_t nbarrier_base) {
if constexpr (enable_periodic_sync) {
if constexpr (arch_has_named_barrier<arch_tag>) {
nbarrier_a.init_nbarrier(
sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
nbarrier_b.init_nbarrier(
sg_idx + barrier_count_y + nbarrier_base,
nbarrier_role::producer_consumer);
} else {
barrier_all.init_nbarrier(nbarrier_base, nbarrier_role::producer_consumer);
if constexpr (wg_size_x > 1) {
nbarrier_a.init_nbarrier(
sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
}
if constexpr (wg_size_y > 1) {
nbarrier_b.init_nbarrier(
sg_idx + barrier_count_y + nbarrier_base,
nbarrier_role::producer_consumer);
}
} else if constexpr (wg_size > 1) {
barrier_all.init_nbarrier(
nbarrier_base, nbarrier_role::producer_consumer);
}
}
}
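The guards added here compile synchronization out entirely when a dimension has only one subgroup. The shape of the pattern, reduced to a generic sketch (not the XeTLA barrier API):

// Generic sketch: emit barrier code only when more than one
// participant exists along the synchronized dimension.
template <uint32_t participants>
inline void sync_if_needed() {
  if constexpr (participants > 1) {
    // init / arrive / wait on the shared barrier
  }
  // participants == 1: the branch is discarded at compile time.
}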
@@ -319,7 +327,7 @@
if constexpr (wg_size_y > 1) {
nbarrier_b.arrive();
}
} else {
} else if constexpr (wg_size > 1) {
barrier_all.arrive();
}
}
@@ -336,13 +344,25 @@
if constexpr (wg_size_y > 1) {
nbarrier_b.wait();
}
} else {
} else if constexpr (wg_size > 1) {
barrier_all.wait();
}
}
}
}

inline void prefetch_and_update_ab() {
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
matA_prefetch_payload);
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
matB_prefetch_payload);
SW_BARRIER();
matA_prefetch_payload.template update_tdesc<update_dir_a>(
matA_t::tile_size_x);
matB_prefetch_payload.template update_tdesc<update_dir_b>(
matB_t::tile_size_y);
}

/// @brief Gets the subgroup-level tile offset x.
/// @param g Is the workgroup of the current tile.
/// @return Subgroup-level tile offset x.
@@ -400,51 +420,39 @@
pre_processing.init(g, args.pre_processing_args);
matA_payload_t matA_payload(args.matA_base_desc);
matB_payload_t matB_payload(args.matB_base_desc);
matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, 0);
matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, 0);
matA_prefetch_payload.init(args.matA_base_desc, 0);
matB_prefetch_payload.init(args.matB_base_desc, 0);

periodic_sync_init(sg_idx, sg_idy, nbarrier_base);

#pragma unroll
for (uint32_t i = 0; i < stages; i++) {
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
matA_prefetch_payload);
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
matB_prefetch_payload);
matA_prefetch_payload.template update_tdesc<update_dir_a>(
matA_t::tile_size_x);
matB_prefetch_payload.template update_tdesc<update_dir_b>(
matB_t::tile_size_y);
prefetch_and_update_ab();
}

for (uint32_t i = 0; i < args.inner_loop_count; i++) {
periodic_sync_arrive(i);
SW_BARRIER();

subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
matA, matA_payload);
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
matB, matB_payload);
SW_BARRIER();
matA_payload.template update_tdesc<update_dir_a>(matA_t::tile_size_x);
matB_payload.template update_tdesc<update_dir_b>(matB_t::tile_size_y);

if constexpr (stages != 0) {
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
matA_prefetch_payload);
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
matB_prefetch_payload);
matA_prefetch_payload.template update_tdesc<update_dir_a>(
matA_t::tile_size_x);
matB_prefetch_payload.template update_tdesc<update_dir_b>(
matB_t::tile_size_y);
prefetch_and_update_ab();
}

SW_BARRIER();
matA_acc_t matA_acc;
matB_acc_t matB_acc;
subgroup::elemwise_cvt(matA_acc, matA);
subgroup::elemwise_cvt(matB_acc, matB);
pre_processing(matA_acc, matB_acc, matA, matB);
SW_BARRIER();
tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
SW_BARRIER();

periodic_sync_wait(i);
}
SW_BARRIER();
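Taken together, the refactored loop is a standard software-prefetch pipeline: `stages` warm-up prefetches fill the window, then each compute iteration issues one more so the prefetch stays `stages` iterations ahead. A compilable sketch of that shape (stub bodies and names are illustrative, not the XeTLA API):

#include <cstdint>

constexpr uint32_t stages = 3; // prefetch distance (3 on XeHpc above)

void prefetch_and_update_ab() { /* issue prefetch, advance descriptors */ }
void load_ab() { /* load tiles prefetched `stages` iterations ago */ }
void compute() { /* run the mma on the loaded tiles */ }

void gemm_pipeline(uint32_t inner_loop_count) {
  for (uint32_t i = 0; i < stages; i++)
    prefetch_and_update_ab(); // warm-up: fill the prefetch window
  for (uint32_t i = 0; i < inner_loop_count; i++) {
    load_ab();
    if constexpr (stages != 0)
      prefetch_and_update_ab(); // steady state: keep the window full
    compute();
  }
}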
