Skip to content

Commit

Permalink
Extended the select lifting pass to handle select nodes with no cases.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699255258
  • Loading branch information
scampanoni authored and copybara-github committed Nov 22, 2024
1 parent a9e6172 commit e7373a8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
22 changes: 18 additions & 4 deletions xls/passes/select_lifting_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,16 @@ absl::StatusOr<std::optional<LiftableSelectOperandInfo>> CanLiftSelect(
// Only "select" nodes with specific properties can be optimized by this
// transformation.
//
// Shared property that must hold for all cases:
// Only "select" nodes with the same node type for all its inputs can
// be optimized.
// Shared properties that must hold for all cases:
//
// Property A:
// Only "select" nodes with at least one input case can be optimized.
//
// Property B:
// Only "select" nodes with the same node type for all its inputs can
// be optimized.
//
//
//
// There are more properties that must hold for the transformation to be
// applicable. Such properties are specific to the node type of the inputs of
Expand All @@ -229,7 +236,14 @@ absl::StatusOr<std::optional<LiftableSelectOperandInfo>> CanLiftSelect(
absl::Span<Node *const> select_cases = GetCases(select_to_optimize);
std::optional<Node *> default_value = GetDefaultValue(select_to_optimize);

// Check the shared property
// Check the shared property A
if (select_cases.empty()) {
VLOG(3) << " The transformation is not applicable: the select does not "
"have input cases";
return std::nullopt;
}

// Check the shared property B
std::optional<Op> shared_input_op =
SharedOperation(select_cases, default_value);
if (!shared_input_op) {
Expand Down
28 changes: 28 additions & 0 deletions xls/passes/select_lifting_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "xls/passes/select_lifting_pass.h"

#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/log/log.h"
Expand Down Expand Up @@ -214,6 +216,32 @@ TEST_F(SelectLiftingPassTest, LiftSingleSelectWithIndicesOfDifferentBitwidth) {
EXPECT_EQ(f->node_count(), 9);
}

TEST_F(SelectLiftingPassTest, LiftSingleSelectWithNoCases) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

// Fetch the types
Type* u32_type = p->GetBitsType(32);

// Create the parameters of the IR function
BValue a = fb.Param("array", p->GetArrayType(16, u32_type));
BValue c = fb.Param("condition", u32_type);
BValue i = fb.Param("first_index", u32_type);

// Create the body of the IR function
BValue condition_constant = fb.Literal(UBits(10, 32));
BValue selector = fb.AddCompareOp(Op::kUGt, c, condition_constant);
BValue array_index_i = fb.ArrayIndex(a, {i});
std::vector<BValue> cases;
BValue select_node = fb.Select(selector, cases, array_index_i);

// Build the function
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(select_node));

// Set the expected outputs
EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false));
}

} // namespace

} // namespace xls

0 comments on commit e7373a8

Please sign in to comment.