Skip to content

Commit

Permalink
Fix type mismatch error messages.
Browse files Browse the repository at this point in the history
* Channel direction is now accounted for in collecting mismatch data for channel type. Array size is similarly accounted for in array type. Added a CHECK to ensure we satisfy this for implementations on all types.
* As such, channel and array types now properly report as mismatched elements within a parent type, e.g. in a return tuple, instead of "Mismatched element types: <empty>".
* Matching aggregate types within tuples/structs now format with comma delimiters, instead of `((uN[1], uN[1])(uN[1], uN[1])`, i.e. without a separator.

Fixes #1392.

PiperOrigin-RevId: 664961917
  • Loading branch information
mikex-oss authored and copybara-github committed Aug 19, 2024
1 parent 24d61a0 commit f685331
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 20 deletions.
2 changes: 2 additions & 0 deletions xls/dslx/type_system/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ cc_library(
deps = [
":type",
":zip_types",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand All @@ -901,6 +902,7 @@ cc_test(
":type",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
"//xls/dslx:channel_direction",
"@com_google_googletest//:gtest",
],
)
40 changes: 21 additions & 19 deletions xls/dslx/type_system/format_type_mismatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -94,6 +95,22 @@ class Callbacks : public ZipTypesCallbacks {
aggregates);
}

absl::Status NoteAggregateNext(const AggregatePair& aggregates) override {
return absl::visit(
Visitor{
[&](auto p) { return absl::OkStatus(); },
[&](std::pair<const TupleType*, const TupleType*>) {
AddMatchedBoth(", ");
return absl::OkStatus();
},
[&](std::pair<const StructType*, const StructType*> p) {
AddMatchedBoth(", ");
return absl::OkStatus();
},
},
aggregates);
}

absl::Status NoteAggregateEnd(const AggregatePair& aggregates) override {
return absl::visit(
Visitor{
Expand Down Expand Up @@ -135,7 +152,6 @@ class Callbacks : public ZipTypesCallbacks {
BeforeType(lhs, lhs_parent, rhs, rhs_parent);
AddMatched(lhs.ToString(), &colorized_lhs_);
AddMatched(rhs.ToString(), &colorized_rhs_);
AfterType(lhs, lhs_parent, rhs, rhs_parent);
return absl::OkStatus();
}

Expand Down Expand Up @@ -166,7 +182,6 @@ class Callbacks : public ZipTypesCallbacks {
mismatches_.mismatches.push_back({&lhs, &rhs});
BeforeType(lhs, lhs_parent, rhs, rhs_parent);
AddMismatched(lhs.ToString(), rhs.ToString());
AfterType(lhs, lhs_parent, rhs, rhs_parent);
return absl::OkStatus();
}

Expand All @@ -189,23 +204,6 @@ class Callbacks : public ZipTypesCallbacks {
}
}

void AfterType(const Type& lhs, const Type* lhs_parent, const Type& rhs,
const Type* rhs_parent) {
if (lhs_parent == nullptr) {
return;
}
if (auto* parent_struct = dynamic_cast<const StructType*>(lhs_parent);
parent_struct != nullptr &&
parent_struct->IndexOf(lhs).value() + 1 != parent_struct->size()) {
AddMatchedBoth(", ");
}
if (auto* parent_tuple = dynamic_cast<const TupleType*>(lhs_parent);
parent_tuple != nullptr &&
parent_tuple->IndexOf(lhs).value() + 1 != parent_tuple->size()) {
AddMatchedBoth(", ");
}
}

void AddMismatched(std::string_view lhs, std::string_view rhs) {
absl::StrAppend(&colorized_lhs_, kAnsiRed, lhs, kAnsiReset);
absl::StrAppend(&colorized_rhs_, kAnsiRed, rhs, kAnsiReset);
Expand Down Expand Up @@ -238,6 +236,10 @@ absl::StatusOr<std::string> FormatTypeMismatch(const Type& lhs,

XLS_RETURN_IF_ERROR(ZipTypes(lhs, rhs, callbacks));

CHECK(!data.mismatches.empty())
<< "type mismatch info not constructed correctly for types "
<< lhs.GetDebugTypeName() << " vs. " << rhs.GetDebugTypeName();

std::vector<std::string> lines;

if (!data.tuple_missing.empty()) {
Expand Down
99 changes: 99 additions & 0 deletions xls/dslx/type_system/format_type_mismatch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "gtest/gtest.h"
#include "xls/common/status/matchers.h"
#include "xls/dslx/channel_direction.h"
#include "xls/dslx/type_system/type.h"

namespace xls::dslx {
Expand Down Expand Up @@ -54,6 +55,27 @@ TEST(FormatTypeMismatchTest, ElementInTuple) {
);
}

TEST(FormatTypeMismatchTest, NestedTuple) {
std::unique_ptr<TupleType> t0 = TupleType::Create2(
TupleType::Create2(BitsType::MakeU8(), BitsType::MakeU1()),
TupleType::Create2(BitsType::MakeU1(), BitsType::MakeU1()));
std::unique_ptr<TupleType> t1 = TupleType::Create2(
TupleType::Create2(BitsType::MakeU1(), BitsType::MakeU1()),
TupleType::Create2(BitsType::MakeU1(), BitsType::MakeU1()));

XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*t0, *t1));

EXPECT_EQ(got,
ANSI_RESET
"Mismatched elements " ANSI_BOLD "within" ANSI_UNBOLD
" type:\n" //
" uN[8]\n" //
"vs uN[1]\n" ANSI_BOLD "Overall" ANSI_UNBOLD " type mismatch:\n" //
ANSI_RESET " ((" ANSI_RED "uN[8]" ANSI_RESET
", uN[1]), (uN[1], uN[1]))\n"
"vs ((" ANSI_RED "uN[1]" ANSI_RESET ", uN[1]), (uN[1], uN[1]))");
}

TEST(FormatTypeMismatchTest, ElementTypeInArrayInTuple) {
auto t0 = TupleType::Create2(
BitsType::MakeU1(),
Expand All @@ -76,6 +98,28 @@ TEST(FormatTypeMismatchTest, ElementTypeInArrayInTuple) {
);
}

TEST(FormatTypeMismatchTest, MismatchedArraySizeInTuple) {
auto t0 = TupleType::Create2(
BitsType::MakeU1(),
std::make_unique<ArrayType>(BitsType::MakeU32(), TypeDim::CreateU32(4)));
auto t1 = TupleType::Create2(
BitsType::MakeU1(),
std::make_unique<ArrayType>(BitsType::MakeU32(), TypeDim::CreateU32(2)));

XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*t0, *t1));

EXPECT_EQ(got,
ANSI_RESET "Mismatched elements " ANSI_BOLD "within" ANSI_UNBOLD
" type:\n" //
" uN[32][4]\n" //
"vs uN[32][2]\n" ANSI_BOLD //
"Overall" ANSI_UNBOLD " type mismatch:\n" //
ANSI_RESET " (uN[1], " ANSI_RED "uN[32][4]" ANSI_RESET
")\n" //
"vs (uN[1], " ANSI_RED "uN[32][2]" ANSI_RESET ")" //
);
}

TEST(FormatTypeMismatchTest, TotallyDifferentTuples) {
auto t0 = TupleType::Create2(BitsType::MakeU8(), BitsType::MakeU32());
auto t1 = TupleType::Create2(BitsType::MakeU1(), BitsType::MakeU64());
Expand Down Expand Up @@ -111,5 +155,60 @@ TEST(FormatTypeMismatchTest, TuplesWithSharedPrefixDifferentLength) {
"vs (uN[1], uN[8], uN[32])");
}

TEST(FormatTypeMismatchTest, ChannelTypeMismatch) {
std::unique_ptr<ChannelType> ch0 =
std::make_unique<ChannelType>(BitsType::MakeU8(), ChannelDirection::kIn);
std::unique_ptr<ChannelType> ch1 =
std::make_unique<ChannelType>(BitsType::MakeU32(), ChannelDirection::kIn);

XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*ch0, *ch1));

EXPECT_EQ(got,
"Type mismatch:\n"
" chan(uN[8], dir=in)\n"
"vs chan(uN[32], dir=in)");
}

TEST(FormatTypeMismatchTest, ChannelTypeDirectionMismatch) {
std::unique_ptr<ChannelType> ch0 =
std::make_unique<ChannelType>(BitsType::MakeU8(), ChannelDirection::kIn);
std::unique_ptr<ChannelType> ch1 =
std::make_unique<ChannelType>(BitsType::MakeU8(), ChannelDirection::kOut);

XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*ch0, *ch1));

EXPECT_EQ(got,
"Type mismatch:\n"
" chan(uN[8], dir=in)\n"
"vs chan(uN[8], dir=out)");
}

TEST(FormatTypeMismatchTest, TupleOfChannelTypesElementMismatch) {
std::unique_ptr<ChannelType> ch0 =
std::make_unique<ChannelType>(BitsType::MakeU8(), ChannelDirection::kIn);
std::unique_ptr<ChannelType> ch1 = std::make_unique<ChannelType>(
BitsType::MakeU32(), ChannelDirection::kOut);

std::unique_ptr<TupleType> t0 = TupleType::Create3(
ch0->CloneToUnique(), ch0->CloneToUnique(), ch1->CloneToUnique());
std::unique_ptr<TupleType> t1 = TupleType::Create3(
ch0->CloneToUnique(), ch1->CloneToUnique(), ch1->CloneToUnique());

XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*t0, *t1));

EXPECT_EQ(
got,
ANSI_RESET "Mismatched elements " ANSI_BOLD "within" ANSI_UNBOLD
" type:\n" //
" chan(uN[8], dir=in)\n" //
"vs chan(uN[32], dir=out)\n" ANSI_BOLD "Overall" ANSI_UNBOLD
" type mismatch:\n" //
ANSI_RESET " (chan(uN[8]), " ANSI_RED "chan(uN[8], dir=in)" ANSI_RESET
", chan(uN[32]))\n" //
"vs (chan(uN[8]), " ANSI_RED "chan(uN[32], dir=out)" ANSI_RESET
", chan(uN[32]))" //
);
}

} // namespace
} // namespace xls::dslx
11 changes: 11 additions & 0 deletions xls/dslx/type_system/zip_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class ZipTypeVisitor : public TypeVisitor {
}
absl::Status HandleArray(const ArrayType& lhs) override {
if (auto* rhs = dynamic_cast<const ArrayType*>(&rhs_)) {
if (lhs.size() != rhs->size()) {
return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_);
}
AggregatePair aggregates = std::make_pair(&lhs, rhs);
XLS_RETURN_IF_ERROR(callbacks_.NoteAggregateStart(aggregates));
const Type& lhs_elem = lhs.element_type();
Expand All @@ -88,6 +91,11 @@ class ZipTypeVisitor : public TypeVisitor {
}
absl::Status HandleChannel(const ChannelType& lhs) override {
if (auto* rhs = dynamic_cast<const ChannelType*>(&rhs_)) {
// If channel directions don't match, capture the full channel strings.
if (lhs.direction() != rhs->direction()) {
return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_);
}

AggregatePair aggregates = std::make_pair(&lhs, rhs);
XLS_RETURN_IF_ERROR(callbacks_.NoteAggregateStart(aggregates));
XLS_RETURN_IF_ERROR(
Expand Down Expand Up @@ -136,6 +144,9 @@ class ZipTypeVisitor : public TypeVisitor {
const Type& rhs_elem = rhs.GetMemberType(i);
XLS_RETURN_IF_ERROR(
ZipTypesWithParents(lhs_elem, rhs_elem, &lhs, &rhs, callbacks_));
if (i + 1 != lhs.size()) {
XLS_RETURN_IF_ERROR(callbacks_.NoteAggregateNext(aggregates));
}
}
XLS_RETURN_IF_ERROR(callbacks_.NoteAggregateEnd(aggregates));
return absl::OkStatus();
Expand Down
5 changes: 4 additions & 1 deletion xls/dslx/type_system/zip_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ class ZipTypesCallbacks {
virtual ~ZipTypesCallbacks() = default;

// These are called if the same aggregate type is present on the lhs and rhs,
// to note we're entering/leaving it.
// to note we're traversing it.
virtual absl::Status NoteAggregateStart(const AggregatePair& aggregates) = 0;
virtual absl::Status NoteAggregateNext(const AggregatePair& aggregates) {
return absl::OkStatus();
};
virtual absl::Status NoteAggregateEnd(const AggregatePair& aggregates) = 0;

// Called when there is a leaf type (non aggregate) where the types are
Expand Down

0 comments on commit f685331

Please sign in to comment.