[Fix][GPU] fix xetla kernel (#2366)
yitingw1 authored Sep 6, 2023
1 parent 4673c67 commit 17a3754
Showing 11 changed files with 72 additions and 31 deletions.
8 changes: 8 additions & 0 deletions itex/core/graph/config_util.cc
@@ -29,4 +29,12 @@ void itex_set_config(const ConfigProto& config) { Configs() = config; }

ConfigProto itex_get_config() { return Configs(); }

bool isxehpc_value;
ConfigProto itex_get_isxehpc() {
ConfigProto isxehpc_proto;
GraphOptions* isxehpc_graph = isxehpc_proto.mutable_graph_options();
isxehpc_graph->set_device_isxehpc(isxehpc_value);
return isxehpc_proto;
}

} // namespace itex
2 changes: 2 additions & 0 deletions itex/core/graph/config_util.h
@@ -23,6 +23,8 @@ limitations under the License.
namespace itex {
void itex_set_config(const ConfigProto& config);
ConfigProto itex_get_config();
extern bool isxehpc_value;
ConfigProto itex_get_isxehpc();
} // namespace itex

#endif // ITEX_CORE_GRAPH_CONFIG_UTIL_H_
6 changes: 6 additions & 0 deletions itex/core/graph/optimizer_config.cc
@@ -41,6 +41,8 @@ void HelperSetEnvOptimzerConfig(std::string new_name, std::string old_name,

} // namespace

extern bool itex::isxehpc_value;

void SetOptimizerConfigFlags(OptimizerConfigFlags* opt_config_flags) {
bool sharding_flag;
bool onednn_graph_flag;
@@ -149,6 +151,10 @@ void SetOptimizerConfigFlags(OptimizerConfigFlags* opt_config_flags) {
ITEX_CHECK_OK(itex::ReadBoolFromEnvVar(
"_ITEX_TEST_MODE", enable_itex_test_mode, &test_mode_flag));

#ifndef INTEL_CPU_ONLY
itex::isxehpc_value = IsXeHPC();
#endif

#undef USER_IS_ON
#undef USER_IS_OFF
#undef USER_IS_SET
5 changes: 3 additions & 2 deletions itex/core/kernels/gpu/xetla/fmha_forward.h
@@ -348,7 +348,8 @@ class fmha_forward_t {
using tile_mask = tile_mask_t<matAccSij_t>;

uint32_t sg_startT = startT + ctx.sg_idx * kSgBc;
- uint32_t remainT = (args.uT < sg_startT) ? 0 : (args.uT - sg_startT);
+ uint32_t remainT =
+     std::max(static_cast<int>(args.uT) - static_cast<int>(sg_startT), 0);
if (remainT < kSgBc) {
tile_mask::padding_mask(matAccSij, remainT);
}
@@ -582,7 +583,7 @@ class fmha_forward_t {
matAccOi_t matAccOi(0);

uint32_t startF = ei.get_group(1) * kBr;
- uint32_t endF = (startF + kBr) > args.uF ? args.uF : (startF + kBr);
+ uint32_t endF = std::min(startF + kBr, args.uF);

// iterate through the keys
for (uint32_t startT = 0; startT < args.uT; startT += kBc) {
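Note (not part of the commit): both the removed ternary and the new std::max form clamp remainT at zero when a subgroup's start offset lands past the end of the key sequence; the rewrite only moves the subtraction into signed arithmetic before clamping. As a rough illustration of why some clamp is needed at all, unsigned 32-bit subtraction wraps around instead of going negative — the NumPy sketch below (illustration only; the kernel itself is C++/SYCL) reproduces that behavior:

import numpy as np

# args.uT and sg_startT are uint32 in the kernel; an unguarded uint32
# subtraction would wrap around when sg_startT > uT instead of yielding 0.
uT = np.array(100, dtype=np.uint32)
sg_startT = np.array(128, dtype=np.uint32)

wrapped = uT - sg_startT                    # 4294967268, not -28
clamped = max(int(uT) - int(sg_startT), 0)  # 0, mirroring the std::max form above
print(wrapped, clamped)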
9 changes: 5 additions & 4 deletions itex/core/kernels/gpu/xetla/fmha_utils.h
@@ -17,6 +17,7 @@
#ifndef ITEX_CORE_KERNELS_GPU_XETLA_FMHA_UTILS_H_
#define ITEX_CORE_KERNELS_GPU_XETLA_FMHA_UTILS_H_

#include <algorithm>
#include <xetla.hpp>

namespace gpu::xetla {
@@ -96,8 +97,8 @@ struct tile_mask_t {
for (int i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
for (int j = 0; j < num_block_x; j++) {
- uint32_t start_x = j * block_size_x;
- uint32_t num_keep_blk = (start_x > num_keep) ? 0 : (num_keep - start_x);
+ int start_x = j * block_size_x;
+ int num_keep_blk = std::max(static_cast<int>(num_keep) - start_x, 0);

if (num_keep_blk < block_size_x) {
xetla_mask<block_size_x> mask =
@@ -122,8 +123,8 @@
constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
for (int j = 0; j < num_block_x; j++) {
- uint32_t start_x = j * block_size_x;
- uint32_t num_keep_blk = (start_x > num_keep) ? 0 : (num_keep - start_x);
+ int start_x = j * block_size_x;
+ int num_keep_blk = std::max(static_cast<int>(num_keep) - start_x, 0);

if (num_keep_blk < block_size_x) {
xetla_mask<block_size_x> mask =
3 changes: 2 additions & 1 deletion itex/core/kernels/gpu/xetla/non_flash_sdp/mha_forward.h
@@ -269,7 +269,8 @@ class mha_forward_t {
using tile_mask = tile_mask_t<matAcc_S_t>;

uint32_t sg_startT = ctx.sg_idx * kSgTm;
- uint32_t remainT = (args.uT < sg_startT) ? 0 : (args.uT - sg_startT);
+ uint32_t remainT =
+     std::max(static_cast<int>(args.uT) - static_cast<int>(sg_startT), 0);
if (remainT < kSgTm) {
tile_mask::padding_mask(matAcc_S, remainT);
}
9 changes: 5 additions & 4 deletions itex/core/kernels/gpu/xetla/non_flash_sdp/mha_util.h
@@ -17,6 +17,7 @@
#ifndef ITEX_CORE_KERNELS_GPU_XETLA_NON_FLASH_SDP_MHA_UTIL_H_
#define ITEX_CORE_KERNELS_GPU_XETLA_NON_FLASH_SDP_MHA_UTIL_H_

#include <algorithm>
#include <xetla.hpp>

namespace gpu::xetla {
@@ -96,8 +97,8 @@ struct tile_mask_t {
for (int i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
for (int j = 0; j < num_block_x; j++) {
- uint32_t start_x = j * block_size_x;
- uint32_t num_keep_blk = (start_x > num_keep) ? 0 : (num_keep - start_x);
+ int start_x = j * block_size_x;
+ int num_keep_blk = std::max(static_cast<int>(num_keep) - start_x, 0);

if (num_keep_blk < block_size_x) {
xetla_mask<block_size_x> mask =
@@ -122,8 +123,8 @@
constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
for (int j = 0; j < num_block_x; j++) {
- uint32_t start_x = j * block_size_x;
- uint32_t num_keep_blk = (start_x > num_keep) ? 0 : (num_keep - start_x);
+ int start_x = j * block_size_x;
+ int num_keep_blk = std::max(static_cast<int>(num_keep) - start_x, 0);

if (num_keep_blk < block_size_x) {
xetla_mask<block_size_x> mask =
2 changes: 2 additions & 0 deletions itex/core/utils/protobuf/config.proto
@@ -49,6 +49,8 @@ message GraphOptions {
// Shard single device graph on multiple devices.
Toggle sharding = 6;
ShardingConfig sharding_config = 7;
// Get device isXeHPC()
bool device_isxehpc = 8;
}

message ConfigProto {
6 changes: 5 additions & 1 deletion itex/python/device.py
@@ -19,9 +19,13 @@
from __future__ import print_function

from intel_extension_for_tensorflow.python._pywrap_itex import *
from intel_extension_for_tensorflow.core.utils.protobuf import config_pb2

def get_backend():
return ITEX_GetBackend()

def is_xehpc():
- return ITEX_IsXeHPC()
+ isxehpc = ITEX_IsXeHPC()
+ isxehpc_proto = config_pb2.ConfigProto()
+ isxehpc_proto.ParseFromString(isxehpc)
+ return isxehpc_proto.graph_options.device_isxehpc
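Note (not part of the commit): the reworked is_xehpc() path serializes a ConfigProto on the C++ side and parses it in Python, with graph_options.device_isxehpc (field 8, added in config.proto above) carrying the flag. A minimal sketch of that round-trip using only the generated config_pb2 module — building the proto by hand here stands in for the bytes ITEX_IsXeHPC() returns:

from intel_extension_for_tensorflow.core.utils.protobuf import config_pb2

# Mimic what the pybind ITEX_IsXeHPC() binding hands back to Python.
proto = config_pb2.ConfigProto()
proto.graph_options.device_isxehpc = True
wire_bytes = proto.SerializeToString()

# is_xehpc() parses those bytes and reads the flag.
parsed = config_pb2.ConfigProto()
parsed.ParseFromString(wire_bytes)
print(parsed.graph_options.device_isxehpc)  # True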
18 changes: 5 additions & 13 deletions itex/python/itex_wrapper.cc
@@ -40,19 +40,11 @@ static py::bytes ITEX_GetConfig() {
return py::bytes(config_str);
}

- static bool ITEX_IsXeHPC() {
- // TODO(itex): __LIBSYCL_MINOR_VERSION == 1 is to limit compiler version as
- // there is bug for __LIBSYCL_MINOR_VERSION == 2 remove this once the bug is
- // fixed
- #ifndef INTEL_CPU_ONLY
- #if __LIBSYCL_MINOR_VERSION == 1
- return IsXeHPC(nullptr);
- #else
- return false;
- #endif
- #else
- return false;
- #endif
+ static py::bytes ITEX_IsXeHPC() {
+ std::string config_str;
+ ConfigProto config_proto = itex_get_isxehpc();
+ config_proto.SerializeToString(&config_str);
+ return py::bytes(config_str);
}

PYBIND11_MODULE(_pywrap_itex, m) {
35 changes: 29 additions & 6 deletions itex/python/ops/multi_head_attention.py
@@ -46,8 +46,25 @@ def _stateless_dropout(input_tensor, dropout_prob, seed):
output = tf.nn.experimental.stateless_dropout(input_tensor, rate=dropout_prob, seed=seed)
return output

+ def _dropout(input_tensor, dropout_prob, seed):
+ """Perform dropout.
+ Args:
+ input_tensor: float Tensor.
+ dropout_prob: Python float. The probability of dropping out a value (NOT of
+ *keeping* a dimension as in `tf.nn.dropout`).
+ Returns:
+ A version of `input_tensor` with dropout applied.
+ """
+ if dropout_prob is None or dropout_prob == 0.0:
+ return input_tensor
+
+ output = tf.nn.dropout(input_tensor, rate=dropout_prob, seed=seed)
+ return output

- def scaled_dot_product_attention(query, key, value, atten_mask=None, dropout_p=0.0, seed=(2,3), is_causal=False, use_fast_attention=True):
+ def scaled_dot_product_attention(query, key, value, atten_mask=None, dropout_p=0.0, seed=(2,3), is_causal=False, use_fast_attention=True, use_stateless_randomuniform=True):
"""Applies Dot-product attention with query, key, value tensors.
Args:
@@ -88,7 +105,11 @@ def sdp():
atten_probs = tf.nn.softmax(atten_scores, -1)

if dropout_p != 0.0:
- atten_probs = _stateless_dropout(atten_probs, dropout_p, seed)
+ if use_stateless_randomuniform:
+ atten_probs = _stateless_dropout(atten_probs, dropout_p, seed)
+ else:
+ atten_probs = _dropout(atten_probs, dropout_p, seed[0])


# `atten_output` =[B, N, F, H]
atten_output = tf.matmul(atten_probs, value)
@@ -98,7 +119,7 @@ def sdp():
return output

def fast_sdp():
- batch_size = query.shape[0]
+ batch_size = tf.shape(query)[0]
num_heads = query.shape[1]
from_seq_len = query.shape[2]
head_size = query.shape[3]
@@ -108,9 +129,11 @@
use_dropout = (dropout_p != 0.0)
use_mask = (atten_mask is not None)
if use_dropout:
- uniform_sampler = functools.partial(stateless_random_ops.stateless_random_uniform, seed=seed)
- uniform_sampler_input = tf.compat.v1.placeholder(i_dtype, shape=[batch_size, num_heads, from_seq_len, to_seq_len])
- random_tensor = uniform_sampler(shape=tf.shape(uniform_sampler_input), dtype=i_dtype)
+ if use_stateless_randomuniform:
+ uniform_sampler = functools.partial(stateless_random_ops.stateless_random_uniform, seed=seed)
+ else:
+ uniform_sampler = functools.partial(random_ops.random_uniform, seed=seed[0])
+ random_tensor = uniform_sampler(shape=[batch_size, num_heads, from_seq_len, to_seq_len], dtype=i_dtype)
dropout_mask = math_ops.greater_equal(random_tensor, dropout_p)
else:
dropout_mask = 0
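Note (not part of the commit): the new use_stateless_randomuniform argument selects between TensorFlow's stateless RNG path (a two-element seed, deterministic for a given seed pair) and the stateful path (a single integer seed drawn against the global RNG). A minimal sketch of the two dropout variants the flag switches between, assuming a TF 2.x release where tf.nn.experimental.stateless_dropout is available:

import tensorflow as tf

x = tf.ones([2, 4, 8, 8])
rate, seed_pair = 0.1, (2, 3)

# use_stateless_randomuniform=True: same seed pair -> same dropout mask.
y_stateless = tf.nn.experimental.stateless_dropout(x, rate=rate, seed=seed_pair)

# use_stateless_randomuniform=False: single integer seed via the global RNG.
y_stateful = tf.nn.dropout(x, rate=rate, seed=seed_pair[0])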
