From a873845f2e8ce83b77182433dbcb1aa5c95f9a45 Mon Sep 17 00:00:00 2001 From: jianyizh Date: Mon, 14 Aug 2023 09:41:55 +0800 Subject: [PATCH] [graph] conditionally add is_filter_const for readvariable (#2359) --- itex/core/graph/utils/layout_utils.cc | 47 ++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/itex/core/graph/utils/layout_utils.cc b/itex/core/graph/utils/layout_utils.cc index 8b49ee6f3..3330cf973 100644 --- a/itex/core/graph/utils/layout_utils.cc +++ b/itex/core/graph/utils/layout_utils.cc @@ -466,11 +466,50 @@ void CopyAttrsAllCheckConstFilter(const utils::MutableNodeView* orig_node_view, bool is_filter_const = true; for (int index = 0; index < checklist.size(); index++) { - const NodeDef* filter_node = - orig_node_view->GetRegularFanin(checklist[index]).node_view()->node(); + const auto* filter_node_view = + orig_node_view->GetRegularFanin(checklist[index]).node_view(); + const NodeDef* filter_node = filter_node_view->node(); if (!IsConstant(*filter_node)) { - is_filter_const = false; - break; + if (GetOptimizerConfigFlags().enable_optimize_aggressive && + IsReadVariableOp(*filter_node)) { + bool freeze = false; + auto* arg_node_view = filter_node_view->GetRegularFanin(0).node_view(); + auto* arg_node_def = arg_node_view->node(); + if (IsEnter(*arg_node_def)) { + arg_node_view = arg_node_view->GetRegularFanin(0).node_view(); + arg_node_def = arg_node_view->node(); + } + if (IsArg(*arg_node_def)) { + if (arg_node_view->NumRegularFanouts() == 1) { + freeze = true; + } else { + // read variable inside while loop + // _Arg + // / \ + // Enter ReadVariable + // | + // ReadVariable + bool is_legal_arg = true; + for (const auto& fanout_i : arg_node_view->GetRegularFanouts()) { + for (const auto fanout : fanout_i) { + if (!IsEnter(*(fanout.node_view()->node())) && + !IsReadVariableOp(*(fanout.node_view()->node()))) { + is_legal_arg = false; + } + } + } + freeze = is_legal_arg; + } + } + + if (!freeze) { + is_filter_const = false; + break; + } + } else { + is_filter_const = false; + break; + } } }