Skip to content

Commit

Permalink
[CPU] Register Enter/Exit/NextIteration on CPU Device (#2398)
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng authored Sep 18, 2023
1 parent 6a86c74 commit c4eb5fb
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 2 deletions.
8 changes: 8 additions & 0 deletions itex/core/kernels/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ filegroup(
visibility = ["//visibility:public"],
)

# Header for the Enter/Exit/NextIteration kernels, exported so that kernel
# targets in other packages can list it in their `hdrs`.
filegroup(
    name = "control_flow_hdrs",
    srcs = ["control_flow_ops.h"],
    visibility = ["//visibility:public"],
)

filegroup(
name = "conv_hdrs",
srcs = [
Expand Down
65 changes: 65 additions & 0 deletions itex/core/kernels/common/control_flow_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ITEX_CORE_KERNELS_COMMON_CONTROL_FLOW_OPS_H_
#define ITEX_CORE_KERNELS_COMMON_CONTROL_FLOW_OPS_H_

#include "itex/core/utils/op_kernel.h"
#include "itex/core/utils/op_requires.h"
#include "itex/core/utils/register_types.h"

namespace itex {

// Kernel for the Enter op: one input, one output. Enter creates or finds the
// child frame uniquely identified by frame_name and makes its input available
// to that child frame.
class EnterOp : public OpKernel {
 public:
  explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {}
  ~EnterOp() override = default;

  // Declared here only; the definition lives in the kernel source file.
  void Compute(OpKernelContext* context) override;

  // Always reports this kernel as not expensive.
  bool IsExpensive() { return false; }

  TF_DISALLOW_COPY_AND_ASSIGN(EnterOp);
};

// Kernel for the Exit op: one input, one output. Exit leaves the current frame
// for its parent frame and makes its input available to the parent frame.
class ExitOp : public OpKernel {
 public:
  explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {}
  ~ExitOp() override = default;

  // Declared here only; the definition lives in the kernel source file.
  void Compute(OpKernelContext* context) override;

  // Always reports this kernel as not expensive.
  bool IsExpensive() { return false; }

  TF_DISALLOW_COPY_AND_ASSIGN(ExitOp);
};

// Kernel for the NextIteration op: one input, one output. NextIteration makes
// its input available to the next iteration.
class NextIterationOp : public OpKernel {
 public:
  explicit NextIterationOp(OpKernelConstruction* context)
      : OpKernel(context) {}
  ~NextIterationOp() override = default;

  // Declared here only; the definition lives in the kernel source file.
  void Compute(OpKernelContext* context) override;

  // Always reports this kernel as not expensive.
  bool IsExpensive() { return false; }

  TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp);
};

} // namespace itex

#endif // ITEX_CORE_KERNELS_COMMON_CONTROL_FLOW_OPS_H_
19 changes: 19 additions & 0 deletions itex/core/kernels/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,24 @@ package(
licenses = ["notice"], # Apache 2.0
)

# Kernel library built from control_flow_ops.cc, with the shared header pulled
# in from //itex/core/kernels/common.
itex_xpu_library(
    name = "control_flow_ops",
    srcs = ["control_flow_ops.cc"],
    hdrs = ["//itex/core/kernels/common:control_flow_hdrs"],
    copts = tf_copts(),
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [
        "//itex:core",
        "//itex/core/devices:xpu_device_util",
    ],
    alwayslink = True,
)

itex_xpu_library(
name = "conv_ops",
srcs = [
Expand Down Expand Up @@ -403,6 +421,7 @@ CPU_KERNELS = [
":aggregate_ops",
":binary_op",
":batch_matmul_op",
":control_flow_ops",
":conv_ops",
":dequantize_op",
":einsum_op",
Expand Down
60 changes: 60 additions & 0 deletions itex/core/kernels/cpu/control_flow_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "itex/core/kernels/common/control_flow_ops.h"

namespace itex {

// Passes the tensor at input index 0 through to output index 0.
void EnterOp::Compute(OpKernelContext* context) {
  const auto& forwarded = context->input(0);
  context->set_output(0, forwarded);
}

// Registers EnterOp as the CPU kernel for "Enter", once per element type
// expanded by the TF_CALL_* lists below. The helper macro is file-local and is
// undefined again right after use.
#define REGISTER_ENTER_CPU_KERNEL(dtype)                            \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Enter").Device(DEVICE_CPU).TypeConstraint<dtype>("T"),  \
      EnterOp)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_ENTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_ENTER_CPU_KERNEL);

#undef REGISTER_ENTER_CPU_KERNEL

// Passes the tensor at input index 0 through to output index 0.
void ExitOp::Compute(OpKernelContext* context) {
  const auto& forwarded = context->input(0);
  context->set_output(0, forwarded);
}

// Registers ExitOp as the CPU kernel for "Exit", once per element type
// expanded by the TF_CALL_* lists below.
//
// The kernel class must be ExitOp: this registration previously named EnterOp
// (copy-paste from the "Enter" registration above), which left ExitOp unused
// and bound the "Exit" op to the wrong kernel class.
#define REGISTER_CPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Exit").Device(DEVICE_CPU).TypeConstraint<type>("T"), ExitOp)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);

#undef REGISTER_CPU_KERNEL

// Passes the tensor at input index 0 through to output index 0.
void NextIterationOp::Compute(OpKernelContext* context) {
  const auto& forwarded = context->input(0);
  context->set_output(0, forwarded);
}

// Registers NextIterationOp as the CPU kernel for "NextIteration", once per
// element type expanded by the TF_CALL_* lists below.
//
// The kernel class must be NextIterationOp: this registration previously
// named EnterOp (copy-paste from the "Enter" registration), which left
// NextIterationOp unused and bound the op to the wrong kernel class.
#define REGISTER_CPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("NextIteration").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      NextIterationOp)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);

#undef REGISTER_CPU_KERNEL

} // namespace itex
4 changes: 2 additions & 2 deletions itex/core/ops/op_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ void Register_ITEXTensorArrayConcat();
void Register_ITEXTensorArraySplit();
void Register_ITEXTensorArraySize();
void Register_ITEXTensorArrayClose();
void Register_LayerNormOp();
void Register_ITEXGroupNormOp();
void Register_ITEXRMSNormOp();
void Register_LayerNormGradOp();
void Register_ITEXRnnOp();
void Register_ITEXRnnGradOp();
void Register_LayerNormOp();
void Register_LayerNormGradOp();
void Register_OneDnnGraphOp();

// Native kernels
Expand Down

0 comments on commit c4eb5fb

Please sign in to comment.