Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
sync fmha
Browse files Browse the repository at this point in the history
  • Loading branch information
DDEle committed Aug 6, 2024
1 parent f3b453d commit d7bbf6b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 48 deletions.
69 changes: 31 additions & 38 deletions tests/integration/fmha/fmha_forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ class fmha_forward_t {
using comp_attr = group::compute_attr_t<scalar_t, scalar_t, accum_t>;
using knobs = group::perf_tuning_knob_t<accum_step, stages, sync_freq>;
using compute_policy_BrBc = std::conditional_t<
(arch_tag >= gpu_arch::XeHpg),
(arch_has_xmx<arch_tag>),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// TODO: add k slicing
using compute_policy_BrBm = std::conditional_t<
(arch_tag >= gpu_arch::XeHpg),
(arch_has_xmx<arch_tag>),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// ---------------- // Tile shape and Threads // ---------------- //
Expand Down Expand Up @@ -688,7 +688,7 @@ class fmha_forward_t {
uint8_t,
mem_desc_Dp_Mask_t::layout,
mem_desc_Dp_Mask_t::space>>,
gpu_arch::XeHpc>;
arch_tag>;
load_payload_mask_t load_payload_mask(ctx.mem_desc_Dpij);
subgroup::tile_load(mask_in, load_payload_mask);
matAccSij.reg = matAccSij.reg * mask_in.reg * args.dp_scale;
Expand Down Expand Up @@ -771,7 +771,7 @@ class fmha_forward_t {
uint32_t height = args.uB * args.uN * args.uF;
uint32_t offset_height = b * args.uN * args.uF + f * args.uN + n;

if constexpr (arch_tag != gpu_arch::XeHpc) {
if constexpr (!arch_has_2d_load_store<arch_tag>) {
// offset for curr work item
const uint32_t O_offset = offset_height * args.uH + h;
const auto ld_c = args.uN * args.uH;
Expand All @@ -798,30 +798,30 @@ class fmha_forward_t {
matOi_store_t matOi_store(mem_desc_Oi);
subgroup::tile_store<cache_hint::write_back, cache_hint::write_back>(
matOi, matOi_store);
return;
}

xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
transpose_tdecs.xetla_format<uint32_t>(),
args.O_ptr,
args.uH,
height,
args.uH,
h,
offset_height);

for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
// load data from matAccOi
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);

xetla_tstore_global<
scalar_t,
kSgHm,
cache_hint::write_back,
cache_hint::write_back>(transpose_tdecs, v_out);
xetla_update_tdesc_offsety(
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
} else {
xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
transpose_tdecs.xetla_format<uint32_t>(),
args.O_ptr,
args.uH,
height,
args.uH,
h,
offset_height);

for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
// load data from matAccOi
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);

xetla_tstore_global<
scalar_t,
kSgHm,
cache_hint::write_back,
cache_hint::write_back,
arch_tag>(transpose_tdecs, v_out);
xetla_update_tdesc_offsety(
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
}
}
}
// ====================== // preload_Qi // ====================== //
Expand Down Expand Up @@ -888,16 +888,9 @@ class fmha_forward_t {
/// @return The size of local memory required.
inline static constexpr uint32_t get_slm_size() {
// Total requirement is the sum of the three slm_size_* constants.
constexpr uint32_t size = slm_size_Qi + slm_size_Pij + slm_size_softmax;
// Compile-time check against hard-coded limits: 128KB when arch_tag is
// XeHpc, 64KB for every other arch_tag.
// NOTE(review): this text comes from a diff view without +/- markers; this
// if/else looks like the pre-change check that the static_assert further
// down replaces -- confirm against the commit before relying on it.
if constexpr (arch_tag == gpu_arch::XeHpc) {
static_assert(
size <= (128 * 1024),
"The local memory size should be less than 128KB!");

} else {
static_assert(
size <= (64 * 1024),
"The local memory size should be less than 64KB!");
}
// Compile-time check against the per-architecture value
// arch_attr_t<arch_tag>::local_mem_size instead of a literal. The condition
// permits equality even though the message says "less than".
static_assert(
size <= (arch_attr_t<arch_tag>::local_mem_size),
"The local memory size should be less than arch total local memory size");
return size;
};

Expand Down
23 changes: 13 additions & 10 deletions tests/integration/fmha/fmha_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ template <
typename mat_t,
uint32_t kNumSg,
reduce_op reduce_kind,
gpu_arch arch_tag = gpu_arch::XeHpc>
gpu_arch arch_tag>
struct group_row_reduce_t {
using T = typename mat_t::dtype;
static constexpr uint32_t kNum = mat_t::tile_desc::tile_size_y;
Expand Down Expand Up @@ -215,7 +215,7 @@ enum class add_type : uint8_t {
/// @tparam arch_tag Is the hardware architecture tag.
template <
typename dtype_bias_,
gpu_arch arch_tag = gpu_arch::XeHpc,
gpu_arch arch_tag,
add_type add_tag = add_type::single_line>
struct bias_add_op_t {};

Expand Down Expand Up @@ -324,8 +324,8 @@ struct bias_add_op_t<dtype_bias_, arch_tag, add_type::single_element> {
using base_t = typename mem_desc_bias_t::base_t;

struct arguments_t {
shape_t shape;
base_t base;
shape_t shape;
inline arguments_t() = default;
inline arguments_t(base_t base_, shape_t shape_)
: base(base_), shape(shape_) {}
Expand All @@ -351,11 +351,10 @@ struct bias_add_op_t<dtype_bias_, arch_tag, add_type::single_element> {
uint32_t offset = (pos_y + pos_x * args.shape.stride) * sizeof(dtype_bias);
auto bias_data_vector = xetla_load_global<
dtype_bias,
16,
1,
data_size::default_size,
cache_hint::cached,
cache_hint::cached,
16>(ptr, offset);
cache_hint::cached>(ptr, offset);
dtype_acc bias_data =
xetla_cvt<dtype_acc, dtype_bias, 16>(bias_data_vector)[0];

Expand Down Expand Up @@ -418,15 +417,19 @@ template <
typename mem_desc_c_t_>
class epilogue_transp_t {};

template <typename tile_op_t_, typename tile_shape_, typename mem_desc_c_t_>
template <
typename tile_op_t_,
typename tile_shape_,
typename mem_desc_c_t_,
gpu_arch arch_tag_>
class epilogue_transp_t<
epilogue_policy_tile_op<tile_op_t_, gpu_arch::XeHpc>,
epilogue_policy_tile_op<tile_op_t_, arch_tag_>,
tile_shape_,
mem_desc_c_t_> {
public:
using tile_shape = tile_shape_;
using mem_desc_c_t = mem_desc_c_t_;
static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
static constexpr gpu_arch arch_tag = arch_tag_;
static constexpr uint32_t barrier_count = 0;
static constexpr uint32_t slm_size = 0;

Expand Down Expand Up @@ -505,7 +508,7 @@ class epilogue_write_back_t<
epilogue_policy_default<arch_tag_>,
tile_shape_,
mem_desc_c_t_,
std::enable_if_t<((arch_tag_ <= gpu_arch::XeHpc))>> {
std::enable_if_t<valid_xe_arch_tag<arch_tag_>>> {
public:
using epilogue_policy = epilogue_policy_default<arch_tag_>;
using tile_shape = tile_shape_;
Expand Down

0 comments on commit d7bbf6b

Please sign in to comment.