From 5baaa03c6b2588a762b7e29fadd6c2330ac4f34c Mon Sep 17 00:00:00 2001
From: jianyizh
Date: Mon, 16 Oct 2023 13:47:50 +0800
Subject: [PATCH] [Graph] Fuse conv bn add activation (#2420)

Co-authored-by: Li, Yifan
---
 itex/core/graph/remapper/remapper.cc     | 137 ++++++++++++++++++
 itex/core/graph/utils/layout_utils.cc    |   4 +
 itex/core/kernels/common/conv_ops.h      |  17 ++-
 .../python/grappler/remapper_test.py     |  23 ++-
 4 files changed, 174 insertions(+), 7 deletions(-)

diff --git a/itex/core/graph/remapper/remapper.cc b/itex/core/graph/remapper/remapper.cc
index c1f833dd9..f02000f63 100644
--- a/itex/core/graph/remapper/remapper.cc
+++ b/itex/core/graph/remapper/remapper.cc
@@ -330,6 +330,27 @@ struct ContractionWithBatchNormAndActivation {
   float epsilon = 0.0;
 };
 
+struct ContractionWithBatchNormAndAddV2AndActivation {
+  ContractionWithBatchNormAndAddV2AndActivation() = default;
+  ContractionWithBatchNormAndAddV2AndActivation(int contraction,
+                                                int fused_batch_norm, int add,
+                                                int activation, int add_port,
+                                                float epsilon = 0.0)
+      : contraction(contraction),
+        fused_batch_norm(fused_batch_norm),
+        add(add),
+        activation(activation),
+        add_port(add_port),
+        epsilon(epsilon) {}
+
+  int contraction = kMissingIndex;
+  int fused_batch_norm = kMissingIndex;
+  int add = kMissingIndex;
+  int activation = kMissingIndex;
+  int add_port = 0;
+  float epsilon = 0.0;
+};
+
 struct ContractionWithBiasAndActivationAdd {
   ContractionWithBiasAndActivationAdd() = default;
   ContractionWithBiasAndActivationAdd(int contraction, int bias_add,
@@ -1747,6 +1768,60 @@ bool FindConv2DWithBatchNormAndActivation(
   return true;
 }
 
+bool FindConv2DWithBatchNormAndAddV2AndActivation(
+    const RemapperContext& ctx, int node_index,
+    ContractionWithBatchNormAndAddV2AndActivation* matched) {
+  const auto* node_view = ctx.graph_view.GetNode(node_index);
+  if (HasControlFaninOrFanout(*node_view)) return false;
+
+  // Root of the pattern must be an activation node.
+  const auto* node_def = node_view->node();
+  if (!IsSupportedActivation(*node_def)) return false;
+
+  // OneDnn activation op only supports float, float16 and bfloat16 data types
+  // on GPU.
+  if (!HasDataType(node_def, DT_FLOAT) &&
+      !HasDataType(node_def, DT_BFLOAT16) && !HasDataType(node_def, DT_HALF))
+    return false;
+
+  // The input to the activation must match the Conv2D+FusedBatchNorm+AddV2
+  // pattern.
+  if (node_view->NumRegularFanins() < 1) return false;
+  const auto* add_node_view = node_view->GetRegularFanin(0).node_view();
+  const auto* add_node_def = add_node_view->node();
+
+  if (!IsAddV2(*add_node_def)) return false;
+  auto* batch_norm_node_view = add_node_view->GetRegularFanin(0).node_view();
+
+  ContractionWithBatchNorm base;
+  if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(),
+                               &base)) {
+    batch_norm_node_view = add_node_view->GetRegularFanin(1).node_view();
+    if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(),
+                                 &base))
+      return false;
+    else
+      matched->add_port = 1;
+  } else {
+    matched->add_port = 0;
+  }
+
+  const auto* fused_batch_norm_node_view =
+      ctx.graph_view.GetNode(base.fused_batch_norm);
+  const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node();
+  if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) ||
+      !HaveSameDataType(node_def, fused_batch_norm_node_def) ||
+      IsInPreserveSet(ctx, fused_batch_norm_node_def))
+    return false;
+
+  // We successfully found a Conv2D+FusedBatchNorm+AddV2+Activation pattern.
+  matched->contraction = base.contraction;
+  matched->fused_batch_norm = base.fused_batch_norm;
+  matched->activation = node_index;
+  matched->add = add_node_view->node_index();
+  matched->epsilon = base.epsilon;
+  return true;
+}
+
 bool FindContractionWithBiasAndActivationInPort(
     const RemapperContext& ctx, const utils::MutableNodeView& add_node_view,
     const NodeDef& add_node_def, int port_id) {
@@ -4764,6 +4839,56 @@ Status AddFusedConv2DNode(RemapperContext* ctx,
   return Status::OK();
 }
 
+Status AddFusedConv2DNode(
+    RemapperContext* ctx,
+    const ContractionWithBatchNormAndAddV2AndActivation& matched,
+    std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
+  const GraphDef* graph = ctx->graph_view.graph();
+  const NodeDef& contraction = graph->node(matched.contraction);
+
+  ITEX_DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
+
+  const NodeDef& activation = graph->node(matched.activation);
+  const NodeDef& add = graph->node(matched.add);
+  const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
+  ITEX_VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op()
+               << ": activation=" << activation.name() << " add=" << add.name()
+               << " batch_norm=" << fused_batch_norm.name()
+               << " conv2d=" << contraction.name();
+
+  NodeDef fused_conv2d;
+  fused_conv2d.set_name(activation.name());
+  fused_conv2d.set_op(kFusedConv2D);
+  fused_conv2d.set_device(contraction.device());
+  fused_conv2d.add_input(contraction.input(0));             // 0: input
+  fused_conv2d.add_input(contraction.input(1));             // 1: filter
+  fused_conv2d.add_input(add.input(1 - matched.add_port));  // 2: AddV2
+  fused_conv2d.add_input(fused_batch_norm.input(1));        // 3: scale
+  fused_conv2d.add_input(fused_batch_norm.input(2));        // 4: offset
+  fused_conv2d.add_input(fused_batch_norm.input(3));        // 5: mean
+  fused_conv2d.add_input(fused_batch_norm.input(4));        // 6: variance
+
+  CopyAllAttrs(contraction, &fused_conv2d);
+  SetFusedOpAttributesWithActivation(&fused_conv2d, &activation,
+                                     {"FusedBatchNorm", "Add"}, 1);
+  auto* attr = fused_conv2d.mutable_attr();
+  SetAttrValue(matched.epsilon, &(*attr)["epsilon"]);
+  SetAttrValue(4, &(*attr)["num_bn_args"]);
+
+  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
+  Status status;
+  mutation->AddNode(std::move(fused_conv2d), &status);
+  TF_ABORT_IF_ERROR(status);
+  TF_ABORT_IF_ERROR(mutation->Apply());
+
+  (*invalidated_nodes)[matched.activation] = true;
+  (*nodes_to_delete)[matched.contraction] = true;
+  (*nodes_to_delete)[matched.fused_batch_norm] = true;
+  (*nodes_to_delete)[matched.add] = true;
+
+  return Status::OK();
+}
+
 // Contraction + Mul(scale).
 // TODO(itex): Try to combine this function with Conv + BiasAdd
 Status AddFusedContractionNode(RemapperContext* ctx,
@@ -6968,6 +7093,18 @@ Status RunRemapper(OptimizerContext* opt_ctx, const GrapplerItem& item,
     // it for MatMul as well, but in practice this pattern does not appear in
     // real Tensorflow graphs.
+    // Remap Conv2D+FusedBatchNorm+AddV2+Activation into the _FusedConv2D;
+    ContractionWithBatchNormAndAddV2AndActivation
+        contract_with_batch_norm_and_addv2_and_activation;
+    if (!is_layout_opt &&
+        FindConv2DWithBatchNormAndAddV2AndActivation(
+            ctx, i, &contract_with_batch_norm_and_addv2_and_activation)) {
+      TF_RETURN_IF_ERROR(AddFusedConv2DNode(
+          &ctx, contract_with_batch_norm_and_addv2_and_activation,
+          &invalidated_nodes, &nodes_to_delete));
+      continue;
+    }
+
     // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
     ContractionWithBatchNormAndActivation
         contract_with_batch_norm_and_activation;
diff --git a/itex/core/graph/utils/layout_utils.cc b/itex/core/graph/utils/layout_utils.cc
index e62c1f7ea..00b47ae23 100644
--- a/itex/core/graph/utils/layout_utils.cc
+++ b/itex/core/graph/utils/layout_utils.cc
@@ -464,6 +464,10 @@ void CopyAttrsForTensorArray(const utils::MutableNodeView* orig_node_view,
 bool IsUnchangingVariable(const utils::MutableNodeView* node_view) {
   const NodeDef* node_def = node_view->node();
 
+  if (IsCast(*node_def) &&
+      IsReadVariableOp(*(node_view->GetRegularFanin(0).node_view()->node())) &&
+      GetOptimizerConfigFlags().enable_optimize_aggressive)
+    return true;
 
   if (!GetOptimizerConfigFlags().enable_optimize_aggressive ||
       !IsReadVariableOp(*node_def))
diff --git a/itex/core/kernels/common/conv_ops.h b/itex/core/kernels/common/conv_ops.h
index 73691aa2e..79cd455f1 100644
--- a/itex/core/kernels/common/conv_ops.h
+++ b/itex/core/kernels/common/conv_ops.h
@@ -1149,10 +1149,11 @@ class ConvOpBase : public OpKernel {
 protected:
   std::vector<int64_t> explicit_paddings_;
   bool is_conv2d_;
-  const int kSrcIndex_ = 0, kFilterIndex_ = 1, kBiasIndex_ = 2, kAddIndex_ = 3;
+  const int kSrcIndex_ = 0, kFilterIndex_ = 1, kBiasIndex_ = 2;
+  int kAddIndex_ = 3;
   // Input indices for FusedBatchNorm
-  const int kInputIndex_BN_Scale_ = 2, kInputIndex_BN_Offset_ = 3;
-  const int kInputIndex_BN_Mean_ = 4, kInputIndex_BN_Variance_ = 5;
+  int kInputIndex_BN_Scale_ = 2, kInputIndex_BN_Offset_ = 3;
+  int kInputIndex_BN_Mean_ = 4, kInputIndex_BN_Variance_ = 5;
   const int kDstIndex_ = 0;
 
   PostOpUtil post_op_util_;
@@ -1310,6 +1311,10 @@ class FusedConvOp : public ConvOpBase
     OP_REQUIRES(
         context, this->post_op_util_.AddOps(fused_ops),
         errors::InvalidArgument("Found unsupported fusion in Fused Conv2D."));
+    if (!this->post_op_util_.HasBias())
+      this->kAddIndex_ = 2;
+    else
+      this->kAddIndex_ = 3;
     if (this->post_op_util_.HasBN()) {
       float epsilon;
       int num_bn_args;
@@ -1321,6 +1326,12 @@ class FusedConvOp : public ConvOpBase
       this->post_op_util_.set_epsilon(epsilon);
       this->bn_epsilon_ = epsilon;
+      if (this->post_op_util_.HasAdd()) {
+        this->kInputIndex_BN_Scale_ = 3;
+        this->kInputIndex_BN_Offset_ = 4;
+        this->kInputIndex_BN_Mean_ = 5;
+        this->kInputIndex_BN_Variance_ = 6;
+      }
     }
 
     // Set alpha if get `LeakyRelu` after adding ops.
     if (this->post_op_util_.HasLeakyRelu()) {
diff --git a/test/tensorflow/python/grappler/remapper_test.py b/test/tensorflow/python/grappler/remapper_test.py
index 2cd53fdd8..3cd0aeac1 100644
--- a/test/tensorflow/python/grappler/remapper_test.py
+++ b/test/tensorflow/python/grappler/remapper_test.py
@@ -618,7 +618,8 @@ def _build_fused_conv2d_batchnorm_activation(self,
                                               strides = [1, 1, 1, 1],
                                               dilations = [1, 1, 1, 1],
                                               data_format='NHWC',
-                                              activation=None):
+                                              activation=None,
+                                              has_add=False):
     os.environ['ITEX_LAYOUT_OPT'] = '0'
 
     is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
@@ -651,6 +652,8 @@ def _build_fused_conv2d_batchnorm_activation(self,
                               padding=padding,
                               strides=strides,
                               dilations=dilations)
+
+        add = tf.ones(conv_out.shape,dtype=conv_out.dtype)
 
         bn_sizes = weight_sizes[3]
         bn_scale = [0.2] * bn_sizes
@@ -661,6 +664,8 @@ def _build_fused_conv2d_batchnorm_activation(self,
         out, _, _ = _batch_norm(conv_out, mean = bn_mean, var = bn_var,
                                 offset=bn_offset, scale=bn_scale,
                                 data_format=data_format)
+        if has_add:
+          out = out + add
 
         if activation == 'GeluExact':
           out = Activation_op_dict[activation](out, approximate=False)
@@ -669,10 +674,11 @@ def _build_fused_conv2d_batchnorm_activation(self,
         out = array_ops.identity(out)
 
         tol = 1e-5 if precision == 'float32' else 1e-2
+        expect_fused_ops = ['FusedBatchNorm']
+        if has_add:
+          expect_fused_ops.append('Add')
        if activation:
-          expect_fused_ops = ['FusedBatchNorm', activation]
-        else:
-          expect_fused_ops = ['FusedBatchNorm']
+          expect_fused_ops.append(activation)
        self._verify_value(out, 'FusedConv2D', expect_fused_ops, tol, tol)
 
   @test_util.run_deprecated_v1
@@ -771,6 +777,15 @@ def test_conv2d_batchnorm_geluexact_fusion(self):
         input_sizes=[1, 3, 6, 1],
         weight_sizes=[2, 2, 1, 1],
         activation='GeluExact')
+
+  @test_util.run_deprecated_v1
+  @test_util.disable_xla('This test does not pass with XLA')
+  def test_conv2d_batchnorm_add_fusion(self):
+    self._build_fused_conv2d_batchnorm_activation(
+        input_sizes=[1, 3, 6, 1],
+        weight_sizes=[2, 2, 1, 1],
+        activation='Relu',
+        has_add=True)
 
 if __name__ == '__main__':
   test.main()
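
For reference, a minimal sketch (outside the patch itself) of the TensorFlow subgraph this new remapper rule targets, mirroring the test above. The shapes, constant values, and the tf.function wrapper are illustrative assumptions, not taken from the patch; with ITEX loaded and ITEX_LAYOUT_OPT=0, a chain like this is expected to be rewritten into a single _FusedConv2D node.

    # Conv2D -> FusedBatchNorm (inference mode) -> AddV2 -> Relu
    import tensorflow as tf

    @tf.function  # graph mode, so the grappler remapper pass can run
    def conv_bn_add_relu(x, w, scale, offset, mean, var):
      conv = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
      bn, _, _ = tf.compat.v1.nn.fused_batch_norm(
          conv, scale, offset, mean=mean, variance=var, is_training=False)
      return tf.nn.relu(bn + tf.ones_like(bn))  # AddV2 feeding the activation

    x = tf.random.normal([1, 3, 6, 1])
    w = tf.random.normal([2, 2, 1, 1])
    one = tf.ones([1])  # one output channel
    print(conv_bn_add_relu(x, w, 0.2 * one, 0.3 * one, 8.0 * one, 9.0 * one))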