[Graph] Fuse conv bn add activation (#2420)
Co-authored-by: Li, Yifan <yifan4.li@intel.com>
jianyizh and LIONEFAN authored Oct 16, 2023
1 parent 87e2b66 commit 5baaa03
Showing 4 changed files with 174 additions and 7 deletions.
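For orientation, a minimal sketch (not part of this commit; shapes and constants are arbitrary) of the subgraph the new remapper pattern targets — a Conv2D feeding FusedBatchNorm, followed by an AddV2 with a side tensor and an activation — written with standard TensorFlow Python APIs:

import tensorflow as tf

x = tf.random.normal([1, 3, 6, 1])
w = tf.random.normal([2, 2, 1, 1])
conv = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
# Inference-mode FusedBatchNorm over the conv output (single channel here).
bn, _, _ = tf.compat.v1.nn.fused_batch_norm(
    conv, scale=[0.2], offset=[0.3], mean=[0.1], variance=[4.0],
    is_training=False)
# AddV2 with a side tensor, then the activation. With this commit (and the
# layout optimizer disabled, as in the test below), ITEX can remap the whole
# Conv2D+FusedBatchNorm+AddV2+Relu chain into a single _FusedConv2D node.
out = tf.nn.relu(bn + tf.ones_like(bn))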
137 changes: 137 additions & 0 deletions itex/core/graph/remapper/remapper.cc
@@ -330,6 +330,27 @@ struct ContractionWithBatchNormAndActivation {
float epsilon = 0.0;
};

struct ContractionWithBatchNormAndAddV2AndActivation {
ContractionWithBatchNormAndAddV2AndActivation() = default;
ContractionWithBatchNormAndAddV2AndActivation(int contraction,
int fused_batch_norm, int add,
int activation, int add_port,
float epsilon = 0.0)
: contraction(contraction),
fused_batch_norm(fused_batch_norm),
add(add),
activation(activation),
add_port(add_port),
epsilon(epsilon) {}

int contraction = kMissingIndex;
int fused_batch_norm = kMissingIndex;
int add = kMissingIndex;
int activation = kMissingIndex;
int add_port = 0;
float epsilon = 0.0;
};

struct ContractionWithBiasAndActivationAdd {
ContractionWithBiasAndActivationAdd() = default;
ContractionWithBiasAndActivationAdd(int contraction, int bias_add,
@@ -1747,6 +1768,60 @@ bool FindConv2DWithBatchNormAndActivation(
return true;
}

bool FindConv2DWithBatchNormAndAddV2AndActivation(
const RemapperContext& ctx, int node_index,
ContractionWithBatchNormAndAddV2AndActivation* matched) {
const auto* node_view = ctx.graph_view.GetNode(node_index);
if (HasControlFaninOrFanout(*node_view)) return false;

// Root of the pattern must be an activation node.
const auto* node_def = node_view->node();
if (!IsSupportedActivation(*node_def)) return false;

// oneDNN activation ops only support float, float16, and bfloat16 data types
// on GPU.
if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16) &&
!HasDataType(node_def, DT_HALF))
return false;

// The input to the activation must be an AddV2 fed by a Conv2D+FusedBatchNorm
// pattern.
if (node_view->NumRegularFanins() < 1) return false;
const auto* add_node_view = node_view->GetRegularFanin(0).node_view();
const auto* add_node_def = add_node_view->node();

if (!IsAddV2(*add_node_def)) return false;
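// The Conv2D+FusedBatchNorm subgraph may feed either operand of the AddV2:
// try fanin 0 first and fall back to fanin 1, recording in add_port which
// port carries that branch.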
auto* batch_norm_node_view = add_node_view->GetRegularFanin(0).node_view();

ContractionWithBatchNorm base;
if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(),
&base)) {
batch_norm_node_view = add_node_view->GetRegularFanin(1).node_view();
if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(),
&base))
return false;
else
matched->add_port = 1;
} else {
matched->add_port = 0;
}

const auto* fused_batch_norm_node_view =
ctx.graph_view.GetNode(base.fused_batch_norm);
const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node();
if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) ||
!HaveSameDataType(node_def, fused_batch_norm_node_def) ||
IsInPreserveSet(ctx, fused_batch_norm_node_def))
return false;

// We successfully found a Conv2D+FusedBatchNorm+AddV2+Activation pattern.
matched->contraction = base.contraction;
matched->fused_batch_norm = base.fused_batch_norm;
matched->activation = node_index;
matched->add = add_node_view->node_index();
matched->epsilon = base.epsilon;
return true;
}

bool FindContractionWithBiasAndActivationInPort(
const RemapperContext& ctx, const utils::MutableNodeView& add_node_view,
const NodeDef& add_node_def, int port_id) {
@@ -4764,6 +4839,56 @@ Status AddFusedConv2DNode(RemapperContext* ctx,
return Status::OK();
}

Status AddFusedConv2DNode(
RemapperContext* ctx,
const ContractionWithBatchNormAndAddV2AndActivation& matched,
std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
const GraphDef* graph = ctx->graph_view.graph();
const NodeDef& contraction = graph->node(matched.contraction);

ITEX_DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";

const NodeDef& activation = graph->node(matched.activation);
const NodeDef& add = graph->node(matched.add);
const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
ITEX_VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op()
<< ": activation=" << activation.name() << " add=" << add.name()
<< " batch_norm=" << fused_batch_norm.name()
<< " conv2d=" << contraction.name();

NodeDef fused_conv2d;
fused_conv2d.set_name(activation.name());
fused_conv2d.set_op(kFusedConv2D);
fused_conv2d.set_device(contraction.device());
fused_conv2d.add_input(contraction.input(0));             // 0: input
fused_conv2d.add_input(contraction.input(1));             // 1: filter
fused_conv2d.add_input(add.input(1 - matched.add_port));  // 2: add (AddV2 side input)
fused_conv2d.add_input(fused_batch_norm.input(1));        // 3: scale
fused_conv2d.add_input(fused_batch_norm.input(2));        // 4: offset
fused_conv2d.add_input(fused_batch_norm.input(3));        // 5: mean
fused_conv2d.add_input(fused_batch_norm.input(4));        // 6: variance

CopyAllAttrs(contraction, &fused_conv2d);
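// Record the fusion metadata on the new node: the post-op list
// ({FusedBatchNorm, Add} plus the activation), the BatchNorm epsilon, and the
// number of BatchNorm arguments.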
SetFusedOpAttributesWithActivation(&fused_conv2d, &activation,
{"FusedBatchNorm", "Add"}, 1);
auto* attr = fused_conv2d.mutable_attr();
SetAttrValue(matched.epsilon, &(*attr)["epsilon"]);
SetAttrValue(4, &(*attr)["num_bn_args"]);

utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;
mutation->AddNode(std::move(fused_conv2d), &status);
TF_ABORT_IF_ERROR(status);
TF_ABORT_IF_ERROR(mutation->Apply());

(*invalidated_nodes)[matched.activation] = true;
(*nodes_to_delete)[matched.contraction] = true;
(*nodes_to_delete)[matched.fused_batch_norm] = true;
(*nodes_to_delete)[matched.add] = true;

return Status::OK();
}

// Contraction + Mul(scale).
// TODO(itex): Try to combine this function with Conv + BiasAdd
Status AddFusedContractionNode(RemapperContext* ctx,
@@ -6968,6 +7093,18 @@ Status RunRemapper(OptimizerContext* opt_ctx, const GrapplerItem& item,
// it for MatMul as well, but in practice this pattern does not appear in
// real Tensorflow graphs.

// Remap Conv2D+FusedBatchNorm+AddV2+Activation into the _FusedConv2D;
ContractionWithBatchNormAndAddV2AndActivation
contract_with_batch_norm_and_addv2_and_activation;
if (!is_layout_opt &&
FindConv2DWithBatchNormAndAddV2AndActivation(
ctx, i, &contract_with_batch_norm_and_addv2_and_activation)) {
TF_RETURN_IF_ERROR(AddFusedConv2DNode(
&ctx, contract_with_batch_norm_and_addv2_and_activation,
&invalidated_nodes, &nodes_to_delete));
continue;
}

// Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
ContractionWithBatchNormAndActivation
contract_with_batch_norm_and_activation;
4 changes: 4 additions & 0 deletions itex/core/graph/utils/layout_utils.cc
@@ -464,6 +464,10 @@ void CopyAttrsForTensorArray(const utils::MutableNodeView* orig_node_view,

bool IsUnchangingVariable(const utils::MutableNodeView* node_view) {
const NodeDef* node_def = node_view->node();
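// Under aggressive optimization, a Cast fed directly by a ReadVariableOp is
// also treated as an unchanging variable.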
if (IsCast(*node_def) &&
IsReadVariableOp(*(node_view->GetRegularFanin(0).node_view()->node())) &&
GetOptimizerConfigFlags().enable_optimize_aggressive)
return true;

if (!GetOptimizerConfigFlags().enable_optimize_aggressive ||
!IsReadVariableOp(*node_def))
17 changes: 14 additions & 3 deletions itex/core/kernels/common/conv_ops.h
@@ -1149,10 +1149,11 @@ class ConvOpBase : public OpKernel {
protected:
std::vector<int64_t> explicit_paddings_;
bool is_conv2d_;
const int kSrcIndex_ = 0, kFilterIndex_ = 1, kBiasIndex_ = 2;
int kAddIndex_ = 3;
// Input indices for FusedBatchNorm
int kInputIndex_BN_Scale_ = 2, kInputIndex_BN_Offset_ = 3;
int kInputIndex_BN_Mean_ = 4, kInputIndex_BN_Variance_ = 5;

const int kDstIndex_ = 0;
PostOpUtil post_op_util_;
@@ -1310,6 +1311,10 @@ class FusedConvOp : public ConvOpBase<Device, Tinput, Tfilter, Tbias, Toutput,
OP_REQUIRES(
context, this->post_op_util_.AddOps(fused_ops),
errors::InvalidArgument("Found unsupported fusion in Fused Conv2D."));
if (!this->post_op_util_.HasBias())
this->kAddIndex_ = 2;
else
this->kAddIndex_ = 3;
if (this->post_op_util_.HasBN()) {
float epsilon;
int num_bn_args;
@@ -1321,6 +1326,12 @@ class FusedConvOp : public ConvOpBase<Device, Tinput, Tfilter, Tbias, Toutput,
"Fused Conv2D with batchnorm must have 4 extra argument"));
this->post_op_util_.set_epsilon(epsilon);
this->bn_epsilon_ = epsilon;
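// When an Add is fused as well, the add tensor occupies input 2, so the
// four BatchNorm inputs shift from slots 2-5 to 3-6.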
if (this->post_op_util_.HasAdd()) {
this->kInputIndex_BN_Scale_ = 3;
this->kInputIndex_BN_Offset_ = 4;
this->kInputIndex_BN_Mean_ = 5;
this->kInputIndex_BN_Variance_ = 6;
}
}
// Set alpha if get `LeakyRelu` after adding ops.
if (this->post_op_util_.HasLeakyRelu()) {
23 changes: 19 additions & 4 deletions test/tensorflow/python/grappler/remapper_test.py
@@ -618,7 +618,8 @@ def _build_fused_conv2d_batchnorm_activation(self,
strides = [1, 1, 1, 1],
dilations = [1, 1, 1, 1],
data_format='NHWC',
activation=None,
has_add=False):
os.environ['ITEX_LAYOUT_OPT'] = '0'
is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()

@@ -651,6 +652,8 @@ def _build_fused_conv2d_batchnorm_activation(self,
padding=padding,
strides=strides,
dilations=dilations)

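# Constant side tensor used as the second AddV2 operand when has_add is set.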
add = tf.ones(conv_out.shape, dtype=conv_out.dtype)

bn_sizes = weight_sizes[3]
bn_scale = [0.2] * bn_sizes
@@ -661,6 +664,8 @@ def _build_fused_conv2d_batchnorm_activation(self,
out, _, _ = _batch_norm(conv_out, mean = bn_mean,
var = bn_var, offset=bn_offset,
scale=bn_scale, data_format=data_format)
if has_add:
out = out + add

if activation == 'GeluExact':
out = Activation_op_dict[activation](out, approximate=False)
@@ -669,10 +674,11 @@ def _build_fused_conv2d_batchnorm_activation(self,
out = array_ops.identity(out)

tol = 1e-5 if precision == 'float32' else 1e-2
expect_fused_ops = ['FusedBatchNorm']
if has_add:
expect_fused_ops.append('Add')
if activation:
expect_fused_ops.append(activation)
self._verify_value(out, 'FusedConv2D', expect_fused_ops, tol, tol)

@test_util.run_deprecated_v1
@@ -771,6 +777,15 @@ def test_conv2d_batchnorm_geluexact_fusion(self):
input_sizes=[1, 3, 6, 1],
weight_sizes=[2, 2, 1, 1],
activation='GeluExact')

@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_conv2d_batchnorm_add_fusion(self):
self._build_fused_conv2d_batchnorm_activation(
input_sizes=[1, 3, 6, 1],
weight_sizes=[2, 2, 1, 1],
activation='Relu',
has_add=True)

if __name__ == '__main__':
test.main()
