From 65f012b271d67b80fd36d1056c4a8119159f0577 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Mon, 26 Aug 2024 16:12:16 +0000 Subject: [PATCH] Sync ipex --- include/common/core/math_general.hpp | 8 ++++++-- include/experimental/kernel/col_major_shuf/api.hpp | 1 + .../kernel/col_major_shuf/col_major_shuf.hpp | 1 + .../kernel/col_major_shuf/col_major_shuf_xe.hpp | 1 + include/experimental/kernel/col_major_shuf/config.hpp | 6 ++---- include/experimental/kernel/kernel.hpp | 2 +- include/subgroup/tile/impl/op_function.hpp | 10 ++++------ 7 files changed, 16 insertions(+), 13 deletions(-) diff --git a/include/common/core/math_general.hpp b/include/common/core/math_general.hpp index 54f4e1a2f..013ac017f 100644 --- a/include/common/core/math_general.hpp +++ b/include/common/core/math_general.hpp @@ -460,7 +460,9 @@ __XETLA_API T xetla_rsqrt(T src, Sat sat = {}) { template __XETLA_API xetla_vector xetla_tanh(xetla_vector src) { static_assert( - std::is_same, float>::value, "Only support fp32! "); + (std::is_same, float>::value) || + (std::is_same, fp16>::value), + "Only support fp32 and fp16"); constexpr uint32_t flag_elems = 8 * 16; xetla_vector ret; if constexpr (SZ / flag_elems > 0) { @@ -502,7 +504,9 @@ __XETLA_API xetla_vector xetla_tanh(xetla_vector src) { template __XETLA_API T xetla_tanh(T src) { static_assert( - std::is_same, float>::value, "Only support fp32! "); + (std::is_same, float>::value) || + (std::is_same, fp16>::value), + "Only support fp32 and fp16"); T exp2x = xetla_exp(src * 2.f); T ret = (exp2x - 1.f) / (exp2x + 1.f); return (src >= 10) ? 1 : ret; diff --git a/include/experimental/kernel/col_major_shuf/api.hpp b/include/experimental/kernel/col_major_shuf/api.hpp index f7e2036c8..5eb3db9f7 100644 --- a/include/experimental/kernel/col_major_shuf/api.hpp +++ b/include/experimental/kernel/col_major_shuf/api.hpp @@ -19,6 +19,7 @@ #pragma once +#include #include namespace gpu::xetla::kernel { diff --git a/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp b/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp index a77ba06de..1cd3b7153 100644 --- a/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp +++ b/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp @@ -21,4 +21,5 @@ #include #include +#include #include diff --git a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp index 6eb671d3e..3c6cb8dda 100644 --- a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp +++ b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp @@ -20,6 +20,7 @@ #pragma once #include +#include #include namespace gpu::xetla::kernel { diff --git a/include/experimental/kernel/col_major_shuf/config.hpp b/include/experimental/kernel/col_major_shuf/config.hpp index 2c6a7c502..d5fbe5543 100644 --- a/include/experimental/kernel/col_major_shuf/config.hpp +++ b/include/experimental/kernel/col_major_shuf/config.hpp @@ -19,13 +19,11 @@ #pragma once -#include -#include -#include +#include namespace gpu::xetla::kernel { -/// @brief Sets up attribute of the col-major-shuf. +/// @brief Sets up attribute of the layer norm. /// /// @tparam wg_tile_x_ Is the num of cols processed by one workgroup. /// @tparam wg_tile_y_ Is the num of rows processed by one workgroup. diff --git a/include/experimental/kernel/kernel.hpp b/include/experimental/kernel/kernel.hpp index 0354a8108..fedc6b271 100644 --- a/include/experimental/kernel/kernel.hpp +++ b/include/experimental/kernel/kernel.hpp @@ -20,9 +20,9 @@ #pragma once #include -#include #include #include +#include #include #include #include diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 144149d18..44d2f6569 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -713,12 +713,10 @@ void dump_mat( for (size_t row = 0; row < tile_y; row++) { #pragma unroll for (size_t col = 0; col < tile_x; col++) { - sycl::ext::oneapi::experimental::printf( - "%d ", (int)(sycl::half)mat.reg[row * tile_x + col]); - // sycl::ext::oneapi::experimental::printf( - // "%x(%d) ", - // int(native_type_t(mat.reg[row * tile_x + col])), - // int(native_type_t(mat.reg[row * tile_x + col]))); + sycl::ext::oneapi::experimental::printf( + "%x(%d) ", + int(native_type_t(mat.reg[row * tile_x + col])), + int(native_type_t(mat.reg[row * tile_x + col]))); } sycl::ext::oneapi::experimental::printf("\n"); }