From 17a37549ef60dc5f1adfce25a1dced1845c4b053 Mon Sep 17 00:00:00 2001 From: yitingw1 <106734399+yitingw1@users.noreply.github.com> Date: Wed, 6 Sep 2023 13:16:09 +0800 Subject: [PATCH] [Fix][GPU] fix xetla kernel (#2366) --- itex/core/graph/config_util.cc | 8 +++++ itex/core/graph/config_util.h | 2 ++ itex/core/graph/optimizer_config.cc | 6 ++++ itex/core/kernels/gpu/xetla/fmha_forward.h | 5 +-- itex/core/kernels/gpu/xetla/fmha_utils.h | 9 ++--- .../gpu/xetla/non_flash_sdp/mha_forward.h | 3 +- .../gpu/xetla/non_flash_sdp/mha_util.h | 9 ++--- itex/core/utils/protobuf/config.proto | 2 ++ itex/python/device.py | 6 +++- itex/python/itex_wrapper.cc | 18 +++------- itex/python/ops/multi_head_attention.py | 35 +++++++++++++++---- 11 files changed, 72 insertions(+), 31 deletions(-) diff --git a/itex/core/graph/config_util.cc b/itex/core/graph/config_util.cc index e4da99f26..cc589b529 100644 --- a/itex/core/graph/config_util.cc +++ b/itex/core/graph/config_util.cc @@ -29,4 +29,12 @@ void itex_set_config(const ConfigProto& config) { Configs() = config; } ConfigProto itex_get_config() { return Configs(); } +bool isxehpc_value; +ConfigProto itex_get_isxehpc() { + ConfigProto isxehpc_proto; + GraphOptions* isxehpc_graph = isxehpc_proto.mutable_graph_options(); + isxehpc_graph->set_device_isxehpc(isxehpc_value); + return isxehpc_proto; +} + } // namespace itex diff --git a/itex/core/graph/config_util.h b/itex/core/graph/config_util.h index c3cc319fb..93ab78889 100644 --- a/itex/core/graph/config_util.h +++ b/itex/core/graph/config_util.h @@ -23,6 +23,8 @@ limitations under the License. namespace itex { void itex_set_config(const ConfigProto& config); ConfigProto itex_get_config(); +extern bool isxehpc_value; +ConfigProto itex_get_isxehpc(); } // namespace itex #endif // ITEX_CORE_GRAPH_CONFIG_UTIL_H_ diff --git a/itex/core/graph/optimizer_config.cc b/itex/core/graph/optimizer_config.cc index e83577d15..9dccb57ce 100644 --- a/itex/core/graph/optimizer_config.cc +++ b/itex/core/graph/optimizer_config.cc @@ -41,6 +41,8 @@ void HelperSetEnvOptimzerConfig(std::string new_name, std::string old_name, } // namespace +extern bool itex::isxehpc_value; + void SetOptimizerConfigFlags(OptimizerConfigFlags* opt_config_flags) { bool sharding_flag; bool onednn_graph_flag; @@ -149,6 +151,10 @@ void SetOptimizerConfigFlags(OptimizerConfigFlags* opt_config_flags) { ITEX_CHECK_OK(itex::ReadBoolFromEnvVar( "_ITEX_TEST_MODE", enable_itex_test_mode, &test_mode_flag)); +#ifndef INTEL_CPU_ONLY + itex::isxehpc_value = IsXeHPC(); +#endif + #undef USER_IS_ON #undef USER_IS_OFF #undef USER_IS_SET diff --git a/itex/core/kernels/gpu/xetla/fmha_forward.h b/itex/core/kernels/gpu/xetla/fmha_forward.h index 34f238e71..e185ebeb1 100644 --- a/itex/core/kernels/gpu/xetla/fmha_forward.h +++ b/itex/core/kernels/gpu/xetla/fmha_forward.h @@ -348,7 +348,8 @@ class fmha_forward_t { using tile_mask = tile_mask_t; uint32_t sg_startT = startT + ctx.sg_idx * kSgBc; - uint32_t remainT = (args.uT < sg_startT) ? 0 : (args.uT - sg_startT); + uint32_t remainT = + std::max(static_cast(args.uT) - static_cast(sg_startT), 0); if (remainT < kSgBc) { tile_mask::padding_mask(matAccSij, remainT); } @@ -582,7 +583,7 @@ class fmha_forward_t { matAccOi_t matAccOi(0); uint32_t startF = ei.get_group(1) * kBr; - uint32_t endF = (startF + kBr) > args.uF ? 
args.uF : (startF + kBr); + uint32_t endF = std::min(startF + kBr, args.uF); // iterate through the keys for (uint32_t startT = 0; startT < args.uT; startT += kBc) { diff --git a/itex/core/kernels/gpu/xetla/fmha_utils.h b/itex/core/kernels/gpu/xetla/fmha_utils.h index d34268c6d..11df0150f 100644 --- a/itex/core/kernels/gpu/xetla/fmha_utils.h +++ b/itex/core/kernels/gpu/xetla/fmha_utils.h @@ -17,6 +17,7 @@ #ifndef ITEX_CORE_KERNELS_GPU_XETLA_FMHA_UTILS_H_ #define ITEX_CORE_KERNELS_GPU_XETLA_FMHA_UTILS_H_ +#include #include namespace gpu::xetla { @@ -96,8 +97,8 @@ struct tile_mask_t { for (int i = 0; i < tile_size_y / block_size_y; i++) { #pragma unroll for (int j = 0; j < num_block_x; j++) { - uint32_t start_x = j * block_size_x; - uint32_t num_keep_blk = (start_x > num_keep) ? 0 : (num_keep - start_x); + int start_x = j * block_size_x; + int num_keep_blk = std::max(static_cast(num_keep) - start_x, 0); if (num_keep_blk < block_size_x) { xetla_mask mask = @@ -122,8 +123,8 @@ struct tile_mask_t { constexpr uint32_t tail_block_elems = tail_size_y * block_size_x; #pragma unroll for (int j = 0; j < num_block_x; j++) { - uint32_t start_x = j * block_size_x; - uint32_t num_keep_blk = (start_x > num_keep) ? 0 : (num_keep - start_x); + int start_x = j * block_size_x; + int num_keep_blk = std::max(static_cast(num_keep) - start_x, 0); if (num_keep_blk < block_size_x) { xetla_mask mask = diff --git a/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_forward.h b/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_forward.h index 1b267b051..b95798171 100644 --- a/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_forward.h +++ b/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_forward.h @@ -269,7 +269,8 @@ class mha_forward_t { using tile_mask = tile_mask_t; uint32_t sg_startT = ctx.sg_idx * kSgTm; - uint32_t remainT = (args.uT < sg_startT) ? 0 : (args.uT - sg_startT); + uint32_t remainT = + std::max(static_cast(args.uT) - static_cast(sg_startT), 0); if (remainT < kSgTm) { tile_mask::padding_mask(matAcc_S, remainT); } diff --git a/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_util.h b/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_util.h index 4a1d4b4a3..0d3360d52 100644 --- a/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_util.h +++ b/itex/core/kernels/gpu/xetla/non_flash_sdp/mha_util.h @@ -17,6 +17,7 @@ #ifndef ITEX_CORE_KERNELS_GPU_XETLA_NON_FLASH_SDP_MHA_UTIL_H_ #define ITEX_CORE_KERNELS_GPU_XETLA_NON_FLASH_SDP_MHA_UTIL_H_ +#include #include namespace gpu::xetla { @@ -96,8 +97,8 @@ struct tile_mask_t { for (int i = 0; i < tile_size_y / block_size_y; i++) { #pragma unroll for (int j = 0; j < num_block_x; j++) { - uint32_t start_x = j * block_size_x; - uint32_t num_keep_blk = (start_x > num_keep) ? 0 : (num_keep - start_x); + int start_x = j * block_size_x; + int num_keep_blk = std::max(static_cast(num_keep) - start_x, 0); if (num_keep_blk < block_size_x) { xetla_mask mask = @@ -122,8 +123,8 @@ struct tile_mask_t { constexpr uint32_t tail_block_elems = tail_size_y * block_size_x; #pragma unroll for (int j = 0; j < num_block_x; j++) { - uint32_t start_x = j * block_size_x; - uint32_t num_keep_blk = (start_x > num_keep) ? 
0 : (num_keep - start_x); + int start_x = j * block_size_x; + int num_keep_blk = std::max(static_cast(num_keep) - start_x, 0); if (num_keep_blk < block_size_x) { xetla_mask mask = diff --git a/itex/core/utils/protobuf/config.proto b/itex/core/utils/protobuf/config.proto index 1d94e22ef..c6326804b 100644 --- a/itex/core/utils/protobuf/config.proto +++ b/itex/core/utils/protobuf/config.proto @@ -49,6 +49,8 @@ message GraphOptions { // Shard single device graph on multiple devices. Toggle sharding = 6; ShardingConfig sharding_config = 7; + // Get device isXeHPC() + bool device_isxehpc = 8; } message ConfigProto { diff --git a/itex/python/device.py b/itex/python/device.py index 4a69ee896..17241b012 100644 --- a/itex/python/device.py +++ b/itex/python/device.py @@ -19,9 +19,13 @@ from __future__ import print_function from intel_extension_for_tensorflow.python._pywrap_itex import * +from intel_extension_for_tensorflow.core.utils.protobuf import config_pb2 def get_backend(): return ITEX_GetBackend() def is_xehpc(): - return ITEX_IsXeHPC() \ No newline at end of file + isxehpc = ITEX_IsXeHPC() + isxehpc_proto = config_pb2.ConfigProto() + isxehpc_proto.ParseFromString(isxehpc) + return isxehpc_proto.graph_options.device_isxehpc \ No newline at end of file diff --git a/itex/python/itex_wrapper.cc b/itex/python/itex_wrapper.cc index 8b02f7875..6a9b7c1b6 100644 --- a/itex/python/itex_wrapper.cc +++ b/itex/python/itex_wrapper.cc @@ -40,19 +40,11 @@ static py::bytes ITEX_GetConfig() { return py::bytes(config_str); } -static bool ITEX_IsXeHPC() { - // TODO(itex): __LIBSYCL_MINOR_VERSION == 1 is to limit compiler version as - // there is bug for __LIBSYCL_MINOR_VERSION == 2 remove this once the bug is - // fixed -#ifndef INTEL_CPU_ONLY -#if __LIBSYCL_MINOR_VERSION == 1 - return IsXeHPC(nullptr); -#else - return false; -#endif -#else - return false; -#endif +static py::bytes ITEX_IsXeHPC() { + std::string config_str; + ConfigProto config_proto = itex_get_isxehpc(); + config_proto.SerializeToString(&config_str); + return py::bytes(config_str); } PYBIND11_MODULE(_pywrap_itex, m) { diff --git a/itex/python/ops/multi_head_attention.py b/itex/python/ops/multi_head_attention.py index 71209fb12..8070bf843 100644 --- a/itex/python/ops/multi_head_attention.py +++ b/itex/python/ops/multi_head_attention.py @@ -46,8 +46,25 @@ def _stateless_dropout(input_tensor, dropout_prob, seed): output = tf.nn.experimental.stateless_dropout(input_tensor, rate=dropout_prob, seed=seed) return output +def _dropout(input_tensor, dropout_prob, seed): + """Perform dropout. + + Args: + input_tensor: float Tensor. + dropout_prob: Python float. The probability of dropping out a value (NOT of + *keeping* a dimension as in `tf.nn.dropout`). + + Returns: + A version of `input_tensor` with dropout applied. + """ + if dropout_prob is None or dropout_prob == 0.0: + return input_tensor -def scaled_dot_product_attention(query, key, value, atten_mask=None, dropout_p=0.0, seed=(2,3), is_causal=False, use_fast_attention=True): + output = tf.nn.dropout(input_tensor, rate=dropout_prob, seed=seed) + return output + + +def scaled_dot_product_attention(query, key, value, atten_mask=None, dropout_p=0.0, seed=(2,3), is_causal=False, use_fast_attention=True, use_stateless_randomuniform=True): """Applies Dot-product attention with query, key, value tensors. 
Args: @@ -88,7 +105,11 @@ def sdp(): atten_probs = tf.nn.softmax(atten_scores, -1) if dropout_p != 0.0: - atten_probs = _stateless_dropout(atten_probs, dropout_p, seed) + if use_stateless_randomuniform: + atten_probs = _stateless_dropout(atten_probs, dropout_p, seed) + else: + atten_probs = _dropout(atten_probs, dropout_p, seed[0]) + # `atten_output` =[B, N, F, H] atten_output = tf.matmul(atten_probs, value) @@ -98,7 +119,7 @@ def sdp(): return output def fast_sdp(): - batch_size = query.shape[0] + batch_size = tf.shape(query)[0] num_heads = query.shape[1] from_seq_len = query.shape[2] head_size = query.shape[3] @@ -108,9 +129,11 @@ def fast_sdp(): use_dropout = (dropout_p != 0.0) use_mask = (atten_mask is not None) if use_dropout: - uniform_sampler = functools.partial(stateless_random_ops.stateless_random_uniform, seed=seed) - uniform_sampler_input = tf.compat.v1.placeholder(i_dtype, shape=[batch_size, num_heads, from_seq_len, to_seq_len]) - random_tensor = uniform_sampler(shape=tf.shape(uniform_sampler_input), dtype=i_dtype) + if use_stateless_randomuniform: + uniform_sampler = functools.partial(stateless_random_ops.stateless_random_uniform, seed=seed) + else: + uniform_sampler = functools.partial(random_ops.random_uniform, seed=seed[0]) + random_tensor = uniform_sampler(shape=[batch_size, num_heads, from_seq_len, to_seq_len], dtype=i_dtype) dropout_mask = math_ops.greater_equal(random_tensor, dropout_p) else: dropout_mask = 0
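
Usage sketch for the Python API change above: with use_stateless_randomuniform=False, dropout inside scaled_dot_product_attention falls back to the stateful samplers (tf.nn.dropout on the sdp() path, random_ops.random_uniform on the fast_sdp() path), seeded with seed[0]. The itex.ops export path, tensor shapes, and dtype below are assumptions for illustration and are not part of this patch.

# Usage sketch only -- the itex.ops export path, shapes, and dtype are
# assumptions for illustration, not part of this patch.
import tensorflow as tf
import intel_extension_for_tensorflow as itex

# Inputs follow the [batch, num_heads, seq_len, head_size] layout used by
# fast_sdp() above.
q = tf.random.normal([2, 8, 128, 64], dtype=tf.float16)
k = tf.random.normal([2, 8, 128, 64], dtype=tf.float16)
v = tf.random.normal([2, 8, 128, 64], dtype=tf.float16)

out = itex.ops.scaled_dot_product_attention(
    q, k, v,
    atten_mask=None,
    dropout_p=0.1,
    seed=(2, 3),  # only seed[0] is consumed when the stateless sampler is off
    is_causal=False,
    use_fast_attention=True,
    use_stateless_randomuniform=False,  # new flag: dropout uses tf.nn.dropout
                                        # or random_ops.random_uniform with seed[0]
)

A dynamic batch dimension also works on the fast path now that fast_sdp() reads batch_size via tf.shape(query)[0] instead of the static query.shape[0].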