[CPU] Enable CPU sdp forward kernel (#2412)
yisonzhu authored Oct 16, 2023
1 parent 5baaa03 commit a993d4c
Showing 14 changed files with 657 additions and 19 deletions.
33 changes: 33 additions & 0 deletions itex/core/kernels/cpu/BUILD
@@ -91,6 +91,22 @@ itex_xpu_library(
alwayslink = True,
)

itex_xpu_library(
name = "mha_op",
srcs = ["mha_op.cc"],
hdrs = [
"mha_op.h",
],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":cpu_blas",
"//itex:core",
],
alwayslink = True,
)

itex_xpu_library(
name = "matmul_op",
srcs = ["matmul_op.cc"],
@@ -417,6 +433,21 @@ itex_xpu_library(
alwayslink = True,
)

itex_xpu_library(
name = "cpu_blas",
srcs = ["cpu_blas.cc"],
hdrs = [
"cpu_blas.h",
],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"//itex:core",
],
alwayslink = True,
)

CPU_KERNELS = [
":aggregate_ops",
":binary_op",
@@ -427,6 +458,7 @@ CPU_KERNELS = [
":einsum_op",
":fused_batch_norm_op",
":fused_binary_op",
":mha_op",
":fused_random_op",
":gru_ops",
":instance_norm_ops",
@@ -444,6 +476,7 @@ CPU_KERNELS = [
":slice_op",
":softmax_op",
":transpose_op",
":cpu_blas",
]

itex_xpu_library(
49 changes: 49 additions & 0 deletions itex/core/kernels/cpu/cpu_blas.cc
@@ -0,0 +1,49 @@
/* Copyright (c) 2023 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "itex/core/kernels/cpu/cpu_blas.h"

#include "itex/core/utils/onednn/onednn_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace itex {
namespace cpublas {

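// dnnl_gemm_bf16bf16f32 is declared manually here, presumably because it is
// not exposed through oneDNN's public C header.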
extern "C" {
dnnl_status_t dnnl_gemm_bf16bf16f32(char transa, char transb, dnnl_dim_t M,
dnnl_dim_t N, dnnl_dim_t K, float alpha,
const dnnl_port::bfloat16_t* A,
dnnl_dim_t lda,
const dnnl_port::bfloat16_t* B,
dnnl_dim_t ldb, float beta, float* C,
dnnl_dim_t ldc);
}

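// fp32 path: thin wrapper over oneDNN's dnnl_sgemm (row-major GEMM,
// C = alpha * op(A) * op(B) + beta * C).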
void gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, float* a, int64_t lda, float* b, int64_t ldb, float beta,
float* c, int64_t ldc) {
dnnl_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

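// bf16 path: bf16 inputs with fp32 accumulation and output.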
void gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, Eigen::bfloat16* a, int64_t lda, Eigen::bfloat16* b,
int64_t ldb, float beta, float* c, int64_t ldc) {
dnnl_port::bfloat16_t* dnnl_a = reinterpret_cast<dnnl_port::bfloat16_t*>(a);
dnnl_port::bfloat16_t* dnnl_b = reinterpret_cast<dnnl_port::bfloat16_t*>(b);
dnnl_gemm_bf16bf16f32(transa, transb, m, n, k, alpha, dnnl_a, lda, dnnl_b,
ldb, beta, c, ldc);
}
} // namespace cpublas
} // namespace itex
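
For reference, both wrappers reduce to oneDNN's row-major GEMM, C = alpha * op(A) * op(B) + beta * C. A minimal standalone sketch of the call the fp32 overload forwards to (assuming oneDNN is installed and linked; this snippet is illustrative and not part of the commit):

#include <cstdio>

#include "dnnl.h"

int main() {
  // Row-major 2x3 times 3x2; C = 1.0f * A * B + 0.0f * C.
  float a[6] = {1, 2, 3, 4, 5, 6};
  float b[6] = {1, 0, 0, 1, 1, 1};
  float c[4] = {0, 0, 0, 0};
  // 'N', 'N': no transpose; lda/ldb/ldc are row strides in row-major layout.
  dnnl_sgemm('N', 'N', 2, 2, 3, 1.0f, a, 3, b, 2, 0.0f, c, 2);
  std::printf("%.0f %.0f\n%.0f %.0f\n", c[0], c[1], c[2], c[3]);  // 4 5 / 10 11
  return 0;
}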
34 changes: 34 additions & 0 deletions itex/core/kernels/cpu/cpu_blas.h
@@ -0,0 +1,34 @@
/* Copyright (c) 2023 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ITEX_CORE_KERNELS_CPU_CPU_BLAS_H_
#define ITEX_CORE_KERNELS_CPU_CPU_BLAS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace itex {
namespace cpublas {

void gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, float* a, int64_t lda, float* b, int64_t ldb, float beta,
float* c, int64_t ldc);

void gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, Eigen::bfloat16* a, int64_t lda, Eigen::bfloat16* b,
int64_t ldb, float beta, float* c, int64_t ldc);
} // namespace cpublas
} // namespace itex

#endif // ITEX_CORE_KERNELS_CPU_CPU_BLAS_H_
101 changes: 101 additions & 0 deletions itex/core/kernels/cpu/mha_op.cc
@@ -0,0 +1,101 @@
/* Copyright (c) 2023 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "itex/core/kernels/cpu/mha_op.h"

#include <string>
#include <unordered_map>
#include <vector>

#include "itex/core/utils/env_var.h"
#include "itex/core/utils/errors.h"
#include "itex/core/utils/macros.h"
#include "itex/core/utils/op_kernel.h"
#include "itex/core/utils/op_requires.h"
#include "itex/core/utils/register_types.h"
#include "itex/core/utils/tensor_shape.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace itex {

template <typename T>
class MHAOp : public OpKernel {
public:
explicit MHAOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("is_inference", &is_inference));
if (!is_inference) {
OP_REQUIRES_OK(context, context->GetAttr("use_dropout", &use_dropout));
OP_REQUIRES_OK(context, context->GetAttr("dropout_prob", &dropout_prob));
} else {
OP_REQUIRES_OK(context, context->GetAttr("use_causal", &use_causal));
}
OP_REQUIRES_OK(context, context->GetAttr("use_mask", &use_mask));
}

void Compute(OpKernelContext* context) override {
const Tensor& query = context->input(0);
const Tensor& key = context->input(1);
const Tensor& value = context->input(2);
Tensor atten_mask;
if (use_mask) atten_mask = context->input(3);
Tensor dropout_mask;
if (!is_inference) dropout_mask = context->input(4);

int64_t batch_size = query.dim_size(0);
int64_t num_heads = query.dim_size(1);
int64_t q_seq_len = query.dim_size(2);
int64_t head_size = query.dim_size(3);
int64_t k_seq_len = key.dim_size(2);

Tensor* output = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(
0, {batch_size, q_seq_len, num_heads, head_size}, &output));

#define CALL_FMHA_FUNC(T, qSplitSize, kvSplitSize) \
FmhaFunctor<T, qSplitSize, kvSplitSize>()( \
query, key, value, batch_size, q_seq_len, num_heads, head_size, \
k_seq_len, use_mask, use_causal, use_dropout, atten_mask, dropout_mask, \
dropout_prob, output)

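// Pick the query tile size from the sequence length; the key/value tile
// size is fixed at 512.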
if (q_seq_len >= 768) {
CALL_FMHA_FUNC(T, 256, 512);
} else if (q_seq_len >= 192) {
CALL_FMHA_FUNC(T, 64, 512);
} else {
CALL_FMHA_FUNC(T, 32, 512);
}
}

private:
float dropout_prob = 0;
bool use_mask = false;
bool use_causal = false;
bool use_dropout = false;
bool is_inference = false;
};

#define REGISTER_MHA_INF_CPU(type) \
REGISTER_KERNEL_BUILDER(Name("ScaledDotProductAttentionInference") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
MHAOp<type>);

REGISTER_MHA_INF_CPU(Eigen::bfloat16);
REGISTER_MHA_INF_CPU(float);
#undef REGISTER_MHA_INF_CPU

} // namespace itex
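
For orientation, FmhaFunctor (declared in mha_op.h, whose diff is not shown here) implements a tiled, flash-attention-style forward pass over the split sizes chosen above. Below is a naive, untiled sketch of the same computation for a single (batch, head) slice; the exact scaling and masking conventions live in mha_op.h, so treat this as a sketch only:

#include <cmath>
#include <limits>
#include <vector>

// O = softmax(Q * K^T / sqrt(head_size)) * V for one (batch, head) slice.
// q: [q_len, head_size]; k, v: [k_len, head_size]; out: [q_len, head_size].
void sdp_reference(const float* q, const float* k, const float* v, float* out,
                   int q_len, int k_len, int head_size) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  std::vector<float> row(k_len);
  for (int i = 0; i < q_len; ++i) {
    // Scaled dot-product scores for query row i, tracking the running max
    // for a numerically stable softmax.
    float max_s = -std::numeric_limits<float>::infinity();
    for (int j = 0; j < k_len; ++j) {
      float s = 0.0f;
      for (int d = 0; d < head_size; ++d)
        s += q[i * head_size + d] * k[j * head_size + d];
      row[j] = s * scale;
      if (row[j] > max_s) max_s = row[j];
    }
    float sum = 0.0f;
    for (int j = 0; j < k_len; ++j) {
      row[j] = std::exp(row[j] - max_s);
      sum += row[j];
    }
    // Probability-weighted sum of value rows.
    for (int d = 0; d < head_size; ++d) {
      float acc = 0.0f;
      for (int j = 0; j < k_len; ++j)
        acc += row[j] * v[j * head_size + d];
      out[i * head_size + d] = acc / sum;
    }
  }
}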
(Diffs for the remaining 10 changed files are not shown.)
