Skip to content

Commit

Permalink
[graph] conditionally add is_filter_const for readvariable (#2359)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianyizh authored Aug 14, 2023
1 parent af30556 commit a873845
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions itex/core/graph/utils/layout_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,50 @@ void CopyAttrsAllCheckConstFilter(const utils::MutableNodeView* orig_node_view,

bool is_filter_const = true;
for (int index = 0; index < checklist.size(); index++) {
const NodeDef* filter_node =
orig_node_view->GetRegularFanin(checklist[index]).node_view()->node();
const auto* filter_node_view =
orig_node_view->GetRegularFanin(checklist[index]).node_view();
const NodeDef* filter_node = filter_node_view->node();
if (!IsConstant(*filter_node)) {
is_filter_const = false;
break;
if (GetOptimizerConfigFlags().enable_optimize_aggressive &&
IsReadVariableOp(*filter_node)) {
bool freeze = false;
auto* arg_node_view = filter_node_view->GetRegularFanin(0).node_view();
auto* arg_node_def = arg_node_view->node();
if (IsEnter(*arg_node_def)) {
arg_node_view = arg_node_view->GetRegularFanin(0).node_view();
arg_node_def = arg_node_view->node();
}
if (IsArg(*arg_node_def)) {
if (arg_node_view->NumRegularFanouts() == 1) {
freeze = true;
} else {
// read variable inside while loop
// _Arg
// / \
// Enter ReadVariable
// |
// ReadVariable
bool is_legal_arg = true;
for (const auto& fanout_i : arg_node_view->GetRegularFanouts()) {
for (const auto fanout : fanout_i) {
if (!IsEnter(*(fanout.node_view()->node())) &&
!IsReadVariableOp(*(fanout.node_view()->node()))) {
is_legal_arg = false;
}
}
}
freeze = is_legal_arg;
}
}

if (!freeze) {
is_filter_const = false;
break;
}
} else {
is_filter_const = false;
break;
}
}
}

Expand Down

0 comments on commit a873845

Please sign in to comment.