Skip to content

Commit

Permalink
Add IR matcher for blocks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605101409
  • Loading branch information
grebe authored and copybara-github committed Feb 7, 2024
1 parent f81ce0a commit adb92be
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 19 deletions.
1 change: 0 additions & 1 deletion xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,6 @@ cc_library(
":type",
":value",
"//xls/common/logging",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
Expand Down
22 changes: 22 additions & 0 deletions xls/ir/ir_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,28 @@ void FunctionMatcher::DescribeNegationTo(std::ostream* os) const {
*os << absl::StreamFormat("FunctionBase was not a function%s.", name_str);
}

void BlockMatcher::DescribeTo(::std::ostream* os) const {
std::stringstream ss;
std::optional<std::string> name_str;
if (name_.has_value()) {
name_->DescribeTo(&ss);
name_str = ss.str();
}
*os << absl::StreamFormat("block %s { ... }",
name_str.value_or("<unspecified>"));
}

void BlockMatcher::DescribeNegationTo(std::ostream* os) const {
std::string name_str;
if (name_.has_value()) {
std::stringstream ss;
ss << " named ";
name_->DescribeTo(&ss);
name_str = ss.str();
}
*os << absl::StreamFormat("FunctionBase was not a block%s.", name_str);
}

bool MinDelayMatcher::MatchAndExplain(
const Node* node, ::testing::MatchResultListener* listener) const {
if (!NodeMatcher::MatchAndExplain(node, listener)) {
Expand Down
91 changes: 78 additions & 13 deletions xls/ir/ir_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/meta/type_traits.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -1253,10 +1252,11 @@ class FunctionMatcher {
std::optional<::testing::Matcher<const std::string>> name)
: name_(std::move(name)) {}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
template <typename T>
bool MatchAndExplain(const T* fb,
::testing::MatchResultListener* listener) const {
::testing::MatchResultListener* listener) const
requires(std::is_convertible_v<T*, ::xls::FunctionBase*>)
{
if (fb == nullptr) {
return false;
}
Expand All @@ -1273,10 +1273,11 @@ class FunctionMatcher {
return true;
}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
template <typename T>
bool MatchAndExplain(const std::unique_ptr<T>& fb,
::testing::MatchResultListener* listener) const {
::testing::MatchResultListener* listener) const
requires(std::is_convertible_v<T*, ::xls::FunctionBase*>)
{
return MatchAndExplain(fb.get(), listener);
}

Expand Down Expand Up @@ -1313,10 +1314,11 @@ class ProcMatcher {
std::optional<::testing::Matcher<const std::string>> name)
: name_(std::move(name)) {}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
template <typename T>
bool MatchAndExplain(const T* fb,
::testing::MatchResultListener* listener) const {
::testing::MatchResultListener* listener) const
requires(std::is_convertible_v<T*, ::xls::FunctionBase*>)
{
if (fb == nullptr) {
return false;
}
Expand All @@ -1333,10 +1335,11 @@ class ProcMatcher {
return true;
}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
template <typename T>
bool MatchAndExplain(const std::unique_ptr<T>& fb,
::testing::MatchResultListener* listener) const {
::testing::MatchResultListener* listener) const
requires(std::is_convertible_v<T*, ::xls::FunctionBase*>)
{
return MatchAndExplain(fb.get(), listener);
}

Expand All @@ -1359,6 +1362,68 @@ inline ::testing::PolymorphicMatcher<ProcMatcher> Proc(
::xls::op_matchers::ProcMatcher(std::move(name)));
}

// Matcher for blocks. Supported forms:
//
// m::Block();
// m::Block(/*name=*/"foo");
// m::Block(/*name=*/HasSubstr("substr"));
//
class BlockMatcher {
public:
using is_gtest_matcher = void;

explicit BlockMatcher(
std::optional<::testing::Matcher<const std::string>> name)
: name_(std::move(name)) {}

template <typename T>
bool MatchAndExplain(const T* fb,
::testing::MatchResultListener* listener) const
requires(std::is_convertible_v<T*, ::xls::FunctionBase*>)
{
if (fb == nullptr) {
return false;
}
*listener << fb->name();
if (!fb->IsBlock()) {
*listener << " is not a block.";
return false;
}
// Now, match on FunctionBase.
if (!FunctionBase(name_).MatchAndExplain(fb, listener)) {
return false;
}

return true;
}

template <typename T>
bool MatchAndExplain(const std::unique_ptr<T>& fb,
::testing::MatchResultListener* listener) const
requires(std::is_convertible_v<T*, ::xls::FunctionBase*>)
{
return MatchAndExplain(fb.get(), listener);
}

void DescribeTo(::std::ostream* os) const;
void DescribeNegationTo(std::ostream* os) const;

protected:
std::optional<::testing::Matcher<const std::string>> name_;
};

inline ::testing::PolymorphicMatcher<BlockMatcher> Block(
std::optional<std::string> name = std::nullopt) {
return testing::MakePolymorphicMatcher(
::xls::op_matchers::BlockMatcher(std::move(name)));
}

inline ::testing::PolymorphicMatcher<BlockMatcher> Block(
::testing::Matcher<const std::string> name) {
return testing::MakePolymorphicMatcher(
::xls::op_matchers::BlockMatcher(std::move(name)));
}

// Matcher for instances. Supported forms:
//
// m::Instantiation()
Expand Down
30 changes: 25 additions & 5 deletions xls/ir/ir_matcher_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,32 +535,52 @@ TEST(IrMatchersTest, FunctionBaseMatcher) {
BValue send_token = pb.Send(ch1, rcv_token, f_of_data);
XLS_ASSERT_OK(pb.Build(send_token, {}).status());

BlockBuilder bb("test_block", &p);
BValue a = bb.InputPort("x", p.GetBitsType(32));
BValue b = bb.InputPort("y", p.GetBitsType(32));
bb.OutputPort("out", bb.Add(a, b));
XLS_ASSERT_OK(bb.Build());

// Match FunctionBases.
EXPECT_THAT(
p.GetFunctionBases(),
UnorderedElementsAre(m::FunctionBase("f"), m::FunctionBase("test_proc")));
UnorderedElementsAre(m::FunctionBase("f"), m::FunctionBase("test_proc"),
m::FunctionBase("test_block")));
EXPECT_THAT(p.GetFunctionBases(),
UnorderedElementsAre(m::FunctionBase(HasSubstr("f")),
m::FunctionBase(HasSubstr("test"))));
m::FunctionBase(HasSubstr("test_pr")),
m::FunctionBase(HasSubstr("test_b"))));
EXPECT_THAT(p.GetFunctionBases(),
::testing::Not(Contains(m::FunctionBase("foobar"))));

// Match Function and Proc.
// Match Function, Proc and Block.
EXPECT_THAT(p.GetFunctionBases(),
UnorderedElementsAre(m::Function("f"), m::Proc("test_proc")));
UnorderedElementsAre(m::Function("f"), m::Proc("test_proc"),
m::Block("test_block")));
EXPECT_THAT(p.GetFunctionBases(),
UnorderedElementsAre(m::Function(HasSubstr("f")),
m::Proc(HasSubstr("test"))));
m::Proc(HasSubstr("test_p")),
m::Block(HasSubstr("test_b"))));
EXPECT_THAT(p.GetFunctionBases(),
::testing::Not(Contains(m::Function("test_proc"))));
EXPECT_THAT(p.GetFunctionBases(),
Not(Contains(m::Function(HasSubstr("proc")))));
EXPECT_THAT(p.GetFunctionBases(),
Not(Contains(m::Function(HasSubstr("block")))));
EXPECT_THAT(p.GetFunctionBases(), ::testing::Not(Contains(m::Proc("f"))));
EXPECT_THAT(p.GetFunctionBases(),
::testing::Not(Contains(m::Proc(HasSubstr("f")))));
EXPECT_THAT(p.GetFunctionBases(),
::testing::Not(Contains(m::Proc(HasSubstr("block")))));
EXPECT_THAT(p.GetFunctionBases(), ::testing::Not(Contains(m::Block("f"))));
EXPECT_THAT(p.GetFunctionBases(),
::testing::Not(Contains(m::Block(HasSubstr("f")))));
EXPECT_THAT(p.GetFunctionBases(),
::testing::Not(Contains(m::Block(HasSubstr("proc")))));

EXPECT_THAT(p.procs(), UnorderedElementsAre(m::Proc("test_proc")));
EXPECT_THAT(p.functions(), UnorderedElementsAre(m::Function("f")));
EXPECT_THAT(p.blocks(), UnorderedElementsAre(m::Block("test_block")));
}

TEST(IrMatchersTest, MinDelayMatcher) {
Expand Down

0 comments on commit adb92be

Please sign in to comment.