Remove hardcoded XeHpc and enhance multi target support on barrier
JianpingChen066 committed Jul 30, 2024
1 parent c3d779a commit fd8f8d1
Showing 90 changed files with 3,273 additions and 1,587 deletions.
6 changes: 4 additions & 2 deletions examples/02_basic_gemm/basic_gemm.cpp
@@ -114,8 +114,10 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
// wrap the nd_range to XeTLA range

// Performance tuning setting based on different shapes
- static constexpr uint32_t periodic_sync_interval = 0;
- static constexpr uint32_t prefetch_distance = 1;
+ static constexpr uint32_t periodic_sync_interval =
+     (arch_tag == gpu_arch::XeHpc) ? 8 : 0;
+ static constexpr uint32_t prefetch_distance =
+     (arch_tag == gpu_arch::XeHpc) ? 3 : 1;
// should larger than 8
static constexpr uint32_t k_stride = 32;

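The tuning constants above are now derived from the target architecture at compile time instead of being pinned to XeHpc. A minimal sketch of the pattern outside of XeTLA (the pick_sync_interval helper is illustrative, not library API):

    #include <cstdint>

    enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };

    template <gpu_arch arch_tag>
    constexpr uint32_t pick_sync_interval() {
      // XeHpc can afford periodic cross-subgroup syncs; other targets disable them.
      return (arch_tag == gpu_arch::XeHpc) ? 8 : 0;
    }

    static_assert(pick_sync_interval<gpu_arch::XeHpc>() == 8);
    static_assert(pick_sync_interval<gpu_arch::XeHpg>() == 0);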
8 changes: 5 additions & 3 deletions examples/06_gemm_softmax/gemm_softmax.cpp
@@ -200,6 +200,8 @@ void gemm_softmax_run(uint32_t iter) {
static constexpr uint32_t prefetch_distance = 3;
// should larger than 8
static constexpr uint32_t k_iter_num = 16;
+ static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
+ //static constexpr gpu_arch arch_tag = gpu_arch::XeHpg;

// Step 1: define Micro-kernel's configuration
using wg_shape = shape<wg_tile_n, wg_tile_m>;
@@ -227,7 +229,7 @@ void gemm_softmax_run(uint32_t iter) {
data_type_sfx, // accumulator data type for intermediate results
wg_shape, // computation tile shape
k_iter_num, // elements in each iteration
- gpu_arch::XeHpc, // GPU arch
+ arch_tag, // GPU arch
tune_option>;

using gemm_args_t = gemm_op_t::arguments_t;
@@ -239,14 +241,14 @@
mem_space::global, // memory writing to global mem for C
wg_shape, // computation tile shape
k_iter_num, // elements in each iteration
- gpu_arch::XeHpc, // GPU arch
+ arch_tag, // GPU arch
tune_option>;

// using experimental::group::softmax
// define softmax forward op
using tile_shape = typename gemm_op_t::tile_shape;
using softmax_fwd_t = softmax_t<
- softmax_policy_fwd<data_type_sfx, gpu_arch::XeHpc>,
+ softmax_policy_fwd<data_type_sfx, arch_tag>,
tile_shape>;
using softmax_fwd_args_t = typename softmax_fwd_t::arguments_t;

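Hoisting the architecture into a single arch_tag constant makes retargeting the example a one-line edit, and the GEMM, epilogue, and softmax ops can never disagree about the device. A sketch of the idea with placeholder traits (not the real XeTLA types):

    enum class gpu_arch : uint8_t { XeLpg, XeHpg, XeHpc };

    static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;  // flip here to retarget

    template <gpu_arch tag> struct gemm_config { /* blocking, barriers, ... */ };
    template <gpu_arch tag> struct softmax_config { /* ... */ };

    // Every op is instantiated from the same tag.
    using gemm_cfg = gemm_config<arch_tag>;
    using softmax_cfg = softmax_config<arch_tag>;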
3 changes: 2 additions & 1 deletion examples/08_scaled_dot_product_attention/softmax.hpp
@@ -60,7 +60,8 @@ struct xetla_softmax_fwd_t {
using softmax_tile_desc_t = subgroup::
tile_desc_t<SIMD, block_height, SIMD, block_height, reg_layout::tiled>;
using softmax_load_t = subgroup::tile_t<dtype_in, softmax_tile_desc_t>;
- using mem_desc_in_t = mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>;
+ using mem_desc_in_t =
+     mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>;
using softmax_load_payload_t = subgroup::mem_payload_t<
mem_desc_in_t,
softmax_tile_desc_t,
80 changes: 61 additions & 19 deletions include/common/core/arch_config.hpp
@@ -50,7 +50,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
static constexpr uint32_t special_prefetch_width_in_bytes = 64;

static constexpr uint32_t cache_line_size_in_bytes = 64;
- static constexpr uint32_t alignment_in_bytes = 8;
+ static constexpr uint32_t alignment_in_bytes = 16;
};

template <msg_type message_type, gpu_arch arg_tag>
@@ -72,7 +72,7 @@ struct client_load_store_attr_base_t {
static constexpr uint32_t special_prefetch_width_in_bytes = 64;

static constexpr uint32_t cache_line_size_in_bytes = 64;
- static constexpr uint32_t alignment_in_bytes = 8;
+ static constexpr uint32_t alignment_in_bytes = 4;
};

template <>
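These constants feed the base-pointer alignment checks on the block load/store paths: XeHpc's 2D block messages now require 16-byte alignment, while the client (XeHpg/XeLpg) base profile only requires 4. A minimal sketch of how such a constant is typically consumed (the helper below is illustrative, not the library's actual check):

    #include <cstdint>

    template <uint32_t alignment_in_bytes>
    bool base_is_aligned(const void* p) {
      // The base address must be a multiple of the per-arch requirement.
      return reinterpret_cast<std::uintptr_t>(p) % alignment_in_bytes == 0;
    }
    // e.g. base_is_aligned<16>(ptr) on XeHpc, base_is_aligned<4>(ptr) on client parts.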
@@ -94,15 +94,21 @@ inline constexpr bool arch_has_2d_load_store =
template <gpu_arch arch_tag>
struct load_store_attr_t<msg_type::block_1d, arch_tag> {
static constexpr uint32_t max_load_vec_len = 256;
+ static constexpr uint32_t max_aligned_load_vec_len = 256;
static constexpr uint32_t max_store_vec_len = 256;
+ static constexpr uint32_t max_aligned_store_vec_len = 256;
static constexpr uint32_t max_prefetch_vec_len = 32;
+ static constexpr uint32_t max_channel_num = 16;
};

template <>
struct load_store_attr_t<msg_type::block_1d, gpu_arch::XeHpc> {
- static constexpr uint32_t max_load_vec_len = 512;
- static constexpr uint32_t max_store_vec_len = 512;
+ static constexpr uint32_t max_load_vec_len = 256;
+ static constexpr uint32_t max_aligned_load_vec_len = 512;
+ static constexpr uint32_t max_store_vec_len = 256;
+ static constexpr uint32_t max_aligned_store_vec_len = 512;
static constexpr uint32_t max_prefetch_vec_len = 64;
+ static constexpr uint32_t max_channel_num = 32;
};
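The new max_aligned_* fields split the 1D vector-length limit in two: the larger value applies only when the buffer meets the alignment requirement, and unaligned accesses must fall back to the smaller one. A sketch of the intended selection (hypothetical helper):

    template <typename attr, bool aligned>
    constexpr uint32_t usable_load_vec_len() {
      // On XeHpc: 512 bytes when aligned, 256 otherwise; pre-XeHpc: 256 either way.
      return aligned ? attr::max_aligned_load_vec_len : attr::max_load_vec_len;
    }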

struct dpas_attr_base_t {
@@ -112,6 +118,7 @@ struct dpas_attr_base_t {
static constexpr uint32_t op_per_channel_bits = 32;
static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3);
static constexpr uint32_t op_per_channel_max = 8;
+ static constexpr uint32_t k_in_bytes = systolic_depth * op_per_channel_bytes;
};

template <gpu_arch arch_tag>
@@ -121,12 +128,12 @@ struct dpas_attr_t {

template <>
struct dpas_attr_t<gpu_arch::XeHpc> : public dpas_attr_base_t {
- static constexpr uint32_t n_fixed_limit = 16;
+ static constexpr uint32_t n_in_elem = 16;
};

template <>
struct dpas_attr_t<gpu_arch::XeHpg> : public dpas_attr_base_t {
- static constexpr uint32_t n_fixed_limit = 8;
+ static constexpr uint32_t n_in_elem = 8;
};

template <gpu_arch arch_tag>
@@ -140,9 +147,10 @@ struct fpu_attr_t {
template <gpu_arch arch_tag>
inline constexpr bool arch_has_fpu = fpu_attr_t<arch_tag>::has_fpu;

- #define GRF grf_mode::double_grf
+ #ifdef NORMAL_GRF
+ #define GRF grf_mode::normal_grf
+ #else
+ #define GRF grf_mode::double_grf
+ #endif
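GRF mode becomes a build-time choice: compiling with -DNORMAL_GRF selects the normal register file, and the large-GRF mode stays the default otherwise. Downstream code simply instantiates with the macro, e.g. (sketch):

    template <grf_mode grf_num_mode> struct register_nums_t;  // as declared in this header
    using default_register_nums = register_nums_t<GRF>;       // picks up the build flag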

template <grf_mode grf_num_mode>
@@ -155,6 +163,7 @@ struct register_nums_t {

template <gpu_arch arch_tag>
struct register_bytes_t;

template <>
struct register_bytes_t<gpu_arch::XeHpc> {
static constexpr uint32_t reg_in_bytes = 64;
@@ -180,24 +189,49 @@ struct register_attr_t {
static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes;
};

- template <gpu_arch arch_tag, uint32_t m, class enable = void>
+ template <
+     gpu_arch arch_tag,
+     mma_engine engine_type,
+     uint32_t m,
+     class enable = void>
struct mma_attr_t {};

template <gpu_arch arch_tag, uint32_t m>
- struct mma_attr_t<arch_tag, m, std::enable_if_t<arch_has_xmx<arch_tag>>> {
+ struct mma_attr_t<
+     arch_tag,
+     mma_engine::xmx,
+     m,
+     std::enable_if_t<arch_has_xmx<arch_tag>>> {
using dpas_attr = dpas_attr_t<arch_tag>;
+ using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
static constexpr uint32_t mma_m_in_elem =
(m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
- static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
- static constexpr uint32_t mma_k_in_bytes =
-     dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;
+ static constexpr uint32_t blk_m_in_elem = 16;
+
+ static constexpr uint32_t mma_n_in_elem = dpas_attr::n_in_elem;
+ [[maybe_unused]] static constexpr uint32_t blk_n_in_bytes =
+     load_store_attr::max_trans_load_width_in_bytes;
+
+ static constexpr uint32_t mma_k_in_bytes = dpas_attr::k_in_bytes;
+ static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
};

template <gpu_arch arch_tag, uint32_t m>
- struct mma_attr_t<arch_tag, m, std::enable_if_t<!arch_has_xmx<arch_tag>>> {
+ struct mma_attr_t<
+     arch_tag,
+     mma_engine::fpu,
+     m,
+     std::enable_if_t<arch_has_fpu<arch_tag>>> {
+ using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
static constexpr uint32_t mma_m_in_elem = (m > 8) ? 8 : m;
- static constexpr uint32_t mma_n_in_elem = 16;
+ static constexpr uint32_t blk_m_in_elem = 16;
+
static constexpr uint32_t mma_k_in_bytes = 32;
+ static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
+
+ [[maybe_unused]] static constexpr uint32_t mma_n_in_elem = 16;
+ static constexpr uint32_t blk_n_in_bytes =
+     register_bytes_t<arch_tag>::reg_in_bytes;
};
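mma_attr_t is now keyed on the compute engine as well as the architecture, so the XMX (DPAS) and FPU paths carry separate blocking attributes instead of being inferred from the arch alone. Selection would look like this (illustrative instantiations; m is the workgroup's m-blocking):

    // DPAS blocking on XeHpc:
    using xmx_attr = mma_attr_t<gpu_arch::XeHpc, mma_engine::xmx, 32>;
    // FPU blocking on XeLpg, which has no XMX unit:
    using fpu_attr = mma_attr_t<gpu_arch::XeLpg, mma_engine::fpu, 32>;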

template <gpu_arch arch_tag>
@@ -208,43 +242,51 @@ struct arch_attr_t<gpu_arch::XeHpc> {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpc>;

- template <grf_mode grf_num_mode = grf_mode::double_grf>
+ template <grf_mode grf_num_mode = GRF>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpc>;

using dpas_attr = dpas_attr_t<gpu_arch::XeHpc>;

static constexpr uint32_t max_wg_num = 64;
static constexpr uint32_t local_mem_size = 128 * 1024;
+ static constexpr bool has_named_barrier = true;
};

template <>
struct arch_attr_t<gpu_arch::XeHpg> {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpg>;

- template <grf_mode grf_num_mode = grf_mode::double_grf>
+ template <grf_mode grf_num_mode = GRF>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpg>;

using dpas_attr = dpas_attr_t<gpu_arch::XeHpg>;

- static constexpr uint32_t max_wg_num = 64;
+ static constexpr uint32_t max_wg_num = 32;
static constexpr uint32_t local_mem_size = 64 * 1024;

+ static constexpr bool has_named_barrier = false;
};

template <>
struct arch_attr_t<gpu_arch::XeLpg> {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeLpg>;

- template <grf_mode grf_num_mode = grf_mode::double_grf>
+ template <grf_mode grf_num_mode = GRF>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeLpg>;

using dpas_attr = dpas_attr_t<gpu_arch::XeLpg>;

- static constexpr uint32_t max_wg_num = 64;
+ static constexpr uint32_t max_wg_num = 32;
static constexpr uint32_t local_mem_size = 64 * 1024;
+ static constexpr bool has_named_barrier = false;
};

+ template <gpu_arch arch_tag>
+ inline constexpr bool arch_has_named_barrier =
+     arch_attr_t<arch_tag>::has_named_barrier;

/// @} xetla_core_arch_config

} // namespace gpu::xetla
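The per-arch has_named_barrier flag, surfaced as arch_has_named_barrier, is what lets barrier code replace hardcoded XeHpc checks with a capability query — the "multi target support on barrier" of the commit title. A hedged sketch of the dispatch this enables (both barrier primitives are placeholders, not XeTLA API):

    // Placeholder primitives standing in for the two barrier kinds.
    void named_barrier_signal_and_wait(uint8_t barrier_id);
    void workgroup_barrier();

    template <gpu_arch arch_tag>
    void barrier_sync(uint8_t barrier_id) {
      if constexpr (arch_has_named_barrier<arch_tag>) {
        // XeHpc exposes several independent named barriers per workgroup.
        named_barrier_signal_and_wait(barrier_id);
      } else {
        // XeHpg/XeLpg only have the single implicit workgroup barrier.
        workgroup_barrier();
      }
    }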
7 changes: 6 additions & 1 deletion include/common/core/common_types.hpp
@@ -21,7 +21,12 @@
#include <cstdint>

namespace gpu::xetla {
- enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
+ enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2, XeLast };
+
+ template <gpu_arch arch_tag>
+ inline constexpr bool valid_xe_arch_tag = (arch_tag < gpu_arch::XeLast);
+
+ enum class mma_engine : uint8_t { xmx = 0, fpu = 1 };

enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };

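XeLast is a sentinel that only bounds the enum, so valid_xe_arch_tag can replace open-ended comparisons such as arch_tag <= gpu_arch::XeHpc (see common.hpp below); mma_engine also moves into this header and is deleted from common.hpp. For example:

    #include <type_traits>

    static_assert(valid_xe_arch_tag<gpu_arch::XeHpg>, "XeHpg is a real target");
    static_assert(!valid_xe_arch_tag<gpu_arch::XeLast>, "the sentinel is not");

    // Constraining a template to real architectures (illustrative type):
    template <gpu_arch arch_tag,
              typename = std::enable_if_t<valid_xe_arch_tag<arch_tag>>>
    struct arch_aware_op {};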
8 changes: 1 addition & 7 deletions include/common/core/math_mma.hpp
@@ -72,9 +72,7 @@ constexpr gpu::xetla::argument_type mma_argument_type<fp16>() {
template <gpu::xetla::argument_type arg_type>
constexpr __ESIMD_NS::xmx::dpas_argument_type get_argument_type() {
static_assert(
- arg_type == gpu::xetla::argument_type::U1 ||
- arg_type == gpu::xetla::argument_type::S1 ||
- arg_type == gpu::xetla::argument_type::U2 ||
+ arg_type == gpu::xetla::argument_type::U2 ||
arg_type == gpu::xetla::argument_type::S2 ||
arg_type == gpu::xetla::argument_type::U4 ||
arg_type == gpu::xetla::argument_type::S4 ||
@@ -85,10 +83,6 @@ constexpr __ESIMD_NS::xmx::dpas_argument_type get_argument_type() {
arg_type == gpu::xetla::argument_type::TF32,
"Unsupported argument type");
switch (arg_type) {
- case gpu::xetla::argument_type::U1:
-     return __ESIMD_NS::xmx::dpas_argument_type::u1;
- case gpu::xetla::argument_type::S1:
-     return __ESIMD_NS::xmx::dpas_argument_type::s1;
case gpu::xetla::argument_type::U2:
return __ESIMD_NS::xmx::dpas_argument_type::u2;
case gpu::xetla::argument_type::S2:
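With the 1-bit cases gone, the narrowest operand type get_argument_type() accepts is 2-bit; requesting U1 or S1 now trips the static_assert instead of mapping to a dpas type. For example:

    constexpr auto ok = get_argument_type<gpu::xetla::argument_type::U2>();   // maps to u2
    // get_argument_type<gpu::xetla::argument_type::U1>();  // now fails to compile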
17 changes: 0 additions & 17 deletions include/common/core/memory.hpp
@@ -400,10 +400,6 @@ __XETLA_API xetla_vector<Ty, N * NElts> xetla_load_global(
xetla_vector<Toffset, N> offsets,
xetla_mask<N> pred = 1) {
using T = native_type_t<Ty>;
- DEBUG_INVOKE(
-     dbg_level::core,
-     core::general_1d<gpu_arch::XeHpc, Ty>::
-         template check_restriction<NElts, N>(offsets, (uint64_t)p));

return __ESIMD_ENS::lsc_gather<
T,
@@ -666,10 +662,6 @@ __XETLA_API xetla_vector<Ty, N * NElts> xetla_load_local(
xetla_vector<uint32_t, N> offsets,
xetla_mask<N> pred = 1) {
using T = native_type_t<Ty>;
- DEBUG_INVOKE(
-     dbg_level::core,
-     core::general_1d<gpu_arch::XeHpc, Ty>::
-         template check_restriction<NElts, N>(offsets));

return __ESIMD_ENS::
lsc_slm_gather<T, NElts, gpu::xetla::detail::get_data_size(DS), N>(
@@ -694,11 +686,6 @@ __XETLA_API xetla_vector<Ty, NElts> xetla_load_local(uint32_t offset) {
template <typename Ty, int NElts = 1, data_size DS = data_size::default_size>
__XETLA_API xetla_vector<Ty, NElts> xetla_load_local(uint32_t offset) {
using T = native_type_t<Ty>;
- // DEBUG_INVOKE(
- //     dbg_level::core,
- //     core::general_1d<gpu_arch::XeHpc, Ty>::template
- //     check_restriction<NElts>(
- //         (uint64_t)offset));

return __ESIMD_NS::slm_block_load<T, NElts>(offset);
}
@@ -729,10 +716,6 @@ __XETLA_API void xetla_store_local(
xetla_vector<Ty, N * NElts> vals,
xetla_mask<N> pred = 1) {
using T = native_type_t<Ty>;
- DEBUG_INVOKE(
-     dbg_level::core,
-     core::general_1d<gpu_arch::XeHpc, Ty>::
-         template check_restriction<NElts, N, uint32_t>(offsets));

__ESIMD_ENS::
lsc_slm_scatter<T, NElts, gpu::xetla::detail::get_data_size(DS), N>(
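The deleted DEBUG_INVOKE guards validated offsets against XeHpc-specific 1D-message restrictions, which would be wrong on other targets; the entry points themselves are unchanged. A typical call still looks like this (sketch, assuming a float SLM buffer and per-lane byte offsets):

    // Gather 4 floats per lane from shared local memory.
    xetla_vector<uint32_t, 16> offsets(0);            // per-lane byte offsets into SLM
    auto vals = xetla_load_local<float, 4>(offsets);  // xetla_vector<float, 64>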
11 changes: 5 additions & 6 deletions include/common/utils/common.hpp
@@ -51,7 +51,7 @@ constexpr uint32_t get_element_size_code() {
enum class lsc_action : uint8_t { prefetch, load, store, atomic };

template <lsc_action Action, cache_hint L1H, cache_hint L2H, gpu_arch arch_tag>
- constexpr std::enable_if_t<arch_tag <= gpu_arch::XeHpc, void>
+ constexpr std::enable_if_t<valid_xe_arch_tag<arch_tag>, void>
check_lsc_cache_hint() {
if constexpr (Action == lsc_action::prefetch) {
// https://gfxspecs.intel.com/Predator/Home/Index/53560
@@ -94,7 +94,7 @@ check_lsc_cache_hint() {
}

template <cache_hint L1H, cache_hint L2H, gpu_arch arch_tag>
- constexpr std::enable_if_t<arch_tag == gpu_arch::XeHpc, uint32_t>
+ constexpr std::enable_if_t<arch_has_2d_load_store<arch_tag>, uint32_t>
get_load_cache_hint_code() {
check_lsc_cache_hint<lsc_action::load, L1H, L2H, arch_tag>();
if (L1H == cache_hint::none && L2H == cache_hint::none) {
@@ -126,7 +126,7 @@ get_load_cache_hint_code() {
}

template <cache_hint L1H, cache_hint L2H, gpu_arch arch_tag>
- constexpr std::enable_if_t<arch_tag == gpu_arch::XeHpc, uint32_t>
+ constexpr std::enable_if_t<arch_has_2d_load_store<arch_tag>, uint32_t>
get_prefetch_cache_hint_code() {
check_lsc_cache_hint<lsc_action::prefetch, L1H, L2H, arch_tag>();
if (L2H == cache_hint::uncached) {
@@ -153,7 +153,7 @@ get_prefetch_cache_hint_code() {
}

template <cache_hint L1H, cache_hint L2H, gpu_arch arch_tag>
- constexpr std::enable_if_t<arch_tag <= gpu_arch::XeHpc, uint32_t>
+ constexpr std::enable_if_t<arch_has_2d_load_store<arch_tag>, uint32_t>
get_store_cache_hint_code() {
check_lsc_cache_hint<lsc_action::store, L1H, L2H, arch_tag>();
if (L1H == cache_hint::none && L2H == cache_hint::none) {
@@ -185,7 +185,7 @@ get_store_cache_hint_code() {
}

template <cache_hint L1H, cache_hint L2H, gpu_arch arch_tag>
- constexpr std::enable_if_t<arch_tag == gpu_arch::XeHpc, uint32_t>
+ constexpr std::enable_if_t<arch_has_2d_load_store<arch_tag>, uint32_t>
get_atomic_cache_hint_code() {
check_lsc_cache_hint<lsc_action::atomic, L1H, L2H, arch_tag>();
if (L1H == cache_hint::none && L2H == cache_hint::none) {
@@ -286,7 +286,6 @@ enum class store_op : uint8_t {
scattered_transpose = 3,
block_1d = 4
};
- enum class mma_engine : uint8_t { xmx = 0, fpu = 1 };
// enum class trans_mode : uint8_t { none = 0, transpose = 1 };
enum class memory_op : uint8_t { load = 0, store = 1 };
enum class tdesc_update_dir : uint8_t { x_dir = 0, y_dir = 1 };
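Gating these helpers on arch_has_2d_load_store / valid_xe_arch_tag rather than arch_tag == gpu_arch::XeHpc makes them available exactly on architectures that implement the corresponding LSC paths, with no per-arch special cases at call sites. The effect at a call site (illustrative values):

    // Previously only instantiable for XeHpc; now for any arch with 2D block LSC.
    constexpr uint32_t code = get_load_cache_hint_code<
        cache_hint::cached, cache_hint::cached, gpu_arch::XeHpc>();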