This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit 65f012b: Sync ipex
DDEle committed Aug 26, 2024
Parent: 2ad313a

Showing 7 changed files with 16 additions and 13 deletions.
8 changes: 6 additions & 2 deletions include/common/core/math_general.hpp

@@ -460,7 +460,9 @@ __XETLA_API T xetla_rsqrt(T src, Sat sat = {}) {
 template <typename T, int SZ>
 __XETLA_API xetla_vector<T, SZ> xetla_tanh(xetla_vector<T, SZ> src) {
   static_assert(
-      std::is_same<remove_const_t<T>, float>::value, "Only support fp32! ");
+      (std::is_same<remove_const_t<T>, float>::value) ||
+          (std::is_same<remove_const_t<T>, fp16>::value),
+      "Only support fp32 and fp16");
   constexpr uint32_t flag_elems = 8 * 16;
   xetla_vector<T, SZ> ret;
   if constexpr (SZ / flag_elems > 0) {
@@ -502,7 +504,9 @@ __XETLA_API xetla_vector<T, SZ> xetla_tanh(xetla_vector<T, SZ> src) {
 template <typename T>
 __XETLA_API T xetla_tanh(T src) {
   static_assert(
-      std::is_same<remove_const_t<T>, float>::value, "Only support fp32! ");
+      (std::is_same<remove_const_t<T>, float>::value) ||
+          (std::is_same<remove_const_t<T>, fp16>::value),
+      "Only support fp32 and fp16");
   T exp2x = xetla_exp<T>(src * 2.f);
   T ret = (exp2x - 1.f) / (exp2x + 1.f);
   return (src >= 10) ? 1 : ret;
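Side note (not part of the commit): a minimal sketch of how the widened overloads might be exercised. xetla_vector, fp16, and xetla_tanh come from the XeTLA header above; the wrapper function and input values are made up for illustration, and the calls are assumed to sit in ESIMD device code as usual for XeTLA.

// Hypothetical usage of the fp16 support added by this commit.
#include <common/core/math_general.hpp>

using namespace gpu::xetla;

inline void tanh_fp16_sketch() {
  // Vector overload: element-wise tanh over 16 fp16 lanes
  // (previously the static_assert rejected anything but fp32).
  xetla_vector<fp16, 16> x(fp16(0.5f));
  xetla_vector<fp16, 16> y = xetla_tanh<fp16, 16>(x);

  // Scalar overload, same relaxation.
  fp16 s = xetla_tanh<fp16>(fp16(0.5f));
  (void)y;
  (void)s;
}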
1 change: 1 addition & 0 deletions include/experimental/kernel/col_major_shuf/api.hpp

@@ -19,6 +19,7 @@

 #pragma once

+#include <experimental/kernel/col_major_shuf/common.hpp>
 #include <experimental/kernel/col_major_shuf/config.hpp>

 namespace gpu::xetla::kernel {
1 change: 1 addition & 0 deletions

@@ -21,4 +21,5 @@

 #include <experimental/kernel/col_major_shuf/api.hpp>
 #include <experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp>
+#include <experimental/kernel/col_major_shuf/common.hpp>
 #include <experimental/kernel/col_major_shuf/config.hpp>
1 change: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 #pragma once

 #include <experimental/kernel/col_major_shuf/api.hpp>
+#include <experimental/kernel/col_major_shuf/common.hpp>
 #include <experimental/kernel/col_major_shuf/config.hpp>

 namespace gpu::xetla::kernel {
6 changes: 2 additions & 4 deletions include/experimental/kernel/col_major_shuf/config.hpp

@@ -19,13 +19,11 @@

 #pragma once

-#include <common/common.hpp>
-#include <group/group.hpp>
-#include <subgroup/subgroup.hpp>
+#include <experimental/kernel/layer_norm/common.hpp>

 namespace gpu::xetla::kernel {

-/// @brief Sets up attribute of the col-major-shuf.
+/// @brief Sets up attribute of the layer norm.
 ///
 /// @tparam wg_tile_x_ Is the num of cols processed by one workgroup.
 /// @tparam wg_tile_y_ Is the num of rows processed by one workgroup.
2 changes: 1 addition & 1 deletion include/experimental/kernel/kernel.hpp

@@ -20,9 +20,9 @@
 #pragma once

 #include <experimental/kernel/col_major_shuf/col_major_shuf.hpp>
-#include <experimental/kernel/int4_dequantize/int4_dequantize.hpp>
 #include <experimental/kernel/data_transformer/data_transformer.hpp>
 #include <experimental/kernel/gemm/gemm.hpp>
+#include <experimental/kernel/int4_dequantize/int4_dequantize.hpp>
 #include <experimental/kernel/layer_norm/layer_norm.hpp>
 #include <experimental/kernel/mha_core_attention/mha_attn_reg.hpp>
 #include <experimental/kernel/mha_core_attention/mha_core_attn.hpp>
10 changes: 4 additions & 6 deletions include/subgroup/tile/impl/op_function.hpp

@@ -713,12 +713,10 @@ void dump_mat(
   for (size_t row = 0; row < tile_y; row++) {
 #pragma unroll
     for (size_t col = 0; col < tile_x; col++) {
-      sycl::ext::oneapi::experimental::printf(
-          "%d ", (int)(sycl::half)mat.reg[row * tile_x + col]);
-      // sycl::ext::oneapi::experimental::printf(
-      //     "%x(%d) ",
-      //     int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])),
-      //     int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])));
+      sycl::ext::oneapi::experimental::printf(
+          "%x(%d) ",
+          int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])),
+          int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])));
     }
     sycl::ext::oneapi::experimental::printf("\n");
   }
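For reference, a host-side sketch (an assumption, not code from this repo) of the "%x(%d) " element format that dump_mat switches to above. The real code prints from device code via sycl::ext::oneapi::experimental::printf; this stand-in uses std::printf and made-up tile data.

// Host-only illustration of the hex(decimal) per-element output format.
#include <cstdio>

int main() {
  constexpr int tile_y = 2; // made-up tile shape
  constexpr int tile_x = 4;
  int reg[tile_y * tile_x] = {0, 1, 2, 10, 15, 16, 255, 4096}; // stand-in for mat.reg
  for (int row = 0; row < tile_y; row++) {
    for (int col = 0; col < tile_x; col++) {
      int v = reg[row * tile_x + col];
      std::printf("%x(%d) ", v, v); // raw value in hex, then decimal
    }
    std::printf("\n");
  }
  return 0;
}

With the data above this prints "0(0) 1(1) 2(2) a(10)" on the first row and "f(15) 10(16) ff(255) 1000(4096)" on the second, showing each element's bit pattern and value side by side.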
