Skip to content

Commit

Permalink
Fix crash when using unroll_for! without index/acc type annotation.
Browse files Browse the repository at this point in the history
#1702

Fixes #1702

PiperOrigin-RevId: 694559263
  • Loading branch information
richmckeever authored and copybara-github committed Nov 8, 2024
1 parent 0af7633 commit 777cdd6
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 22 deletions.
17 changes: 17 additions & 0 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,23 @@ fn test() -> u32 {
ExpectIr(converted, TestName());
}

TEST(IrConverterTest, UnrollForWithoutIndexAccTypeAnnotation) {
const char* kProgram = R"(
proc SomeProc {
init { () }
config() { }
next(state: ()) {
unroll_for! (i, a) in u32:0..u32:4 {
a + i
}(u32:0);
}
})";
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertModuleForTest(kProgram, ConvertOptions{.emit_positions = false}));
ExpectIr(converted, TestName());
}

TEST(IrConverterTest, UnrollForNested) {
const char* kProgram = R"(
fn test() -> u32 {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package test_module

file_number 0 "test_module.x"

proc __test_module__SomeProc_0_next(__state: (), init={()}) {
a: bits[32] = literal(value=0, id=4)
literal.5: bits[32] = literal(value=0, id=5)
a__1: bits[32] = add(a, literal.5, id=6)
literal.7: bits[32] = literal(value=1, id=7)
a__2: bits[32] = add(a__1, literal.7, id=8)
literal.9: bits[32] = literal(value=2, id=9)
a__3: bits[32] = add(a__2, literal.9, id=10)
literal.11: bits[32] = literal(value=3, id=11)
__token: token = literal(value=token, id=1)
literal.3: bits[1] = literal(value=1, id=3)
add.12: bits[32] = add(a__3, literal.11, id=12)
tuple.13: () = tuple(id=13)
next (tuple.13)
}
65 changes: 43 additions & 22 deletions xls/dslx/type_system/deduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,24 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceLet(const Let* node,
return Type::MakeUnit();
}

// The types that need to be deduced for `for`-like loops (including
// `unroll_for!`).
struct ForLoopTypes {
// The type of the container the loop iterates through.
std::unique_ptr<Type> iterable_type;

// The element type of the container indicated by `iterable_type`.
std::unique_ptr<Type> iterable_element_type;

// The type of the loop accumulator (which is the same type as the
// init parameter "passed in" to the loop after its body).
std::unique_ptr<Type> accumulator_type;
};

// Deduces and type-checks the init and iterable expressions of a loop,
// returning the init type.
absl::StatusOr<std::unique_ptr<Type>> DeduceLoopInitAndIterable(
const ForLoopBase* node, DeduceCtx* ctx) {
absl::StatusOr<ForLoopTypes> DeduceForLoopTypes(const ForLoopBase* node,
DeduceCtx* ctx) {
// Type of the init value to the for loop (also the accumulator type).
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> init_type,
ctx->DeduceAndResolve(node->init()));
Expand All @@ -453,10 +467,11 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceLoopInitAndIterable(
node->iterable()->span(), iterable_type.get(),
"For loop iterable value is not an array.", ctx->file_table());
}
const Type& iterable_element_type = iterable_array_type->element_type();
std::unique_ptr<Type> iterable_element_type =
iterable_array_type->element_type().CloneToUnique();

std::vector<std::unique_ptr<Type>> target_annotated_type_elems;
target_annotated_type_elems.push_back(iterable_element_type.CloneToUnique());
target_annotated_type_elems.push_back(iterable_element_type->CloneToUnique());
target_annotated_type_elems.push_back(init_type->CloneToUnique());
auto target_annotated_type =
std::make_unique<TupleType>(std::move(target_annotated_type_elems));
Expand Down Expand Up @@ -492,7 +507,7 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceLoopInitAndIterable(
"and a type for the accumulator; got %d types.",
annotated_tuple_members.size()));
}
if (iterable_element_type != *annotated_tuple_members[0]) {
if (*iterable_element_type != *annotated_tuple_members[0]) {
return ctx->TypeMismatchError(
node->span(), node->type_annotation(), *annotated_type, nullptr,
*target_annotated_type,
Expand Down Expand Up @@ -520,7 +535,9 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceLoopInitAndIterable(
XLS_RETURN_IF_ERROR(
BindNames(bindings, *target_annotated_type, ctx, std::nullopt));

return init_type;
return ForLoopTypes{.iterable_type = std::move(iterable_type),
.iterable_element_type = std::move(iterable_element_type),
.accumulator_type = std::move(init_type)};
}

// Type-checks the body of a loop, whose type should match that of the init
Expand All @@ -543,33 +560,35 @@ absl::Status TypecheckLoopBody(const ForLoopBase* node, const Expr* actual_body,
absl::StatusOr<std::unique_ptr<Type>> DeduceFor(const For* node,
DeduceCtx* ctx) {
VLOG(5) << "DeduceFor: " << node->ToString();
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> init_type,
DeduceLoopInitAndIterable(node, ctx));
XLS_RETURN_IF_ERROR(TypecheckLoopBody(node, node->body(), *init_type, ctx));
return init_type;
XLS_ASSIGN_OR_RETURN(ForLoopTypes loop_types, DeduceForLoopTypes(node, ctx));
XLS_RETURN_IF_ERROR(
TypecheckLoopBody(node, node->body(), *loop_types.accumulator_type, ctx));
return std::move(loop_types.accumulator_type);
}

absl::StatusOr<std::unique_ptr<Type>> DeduceUnrollFor(const UnrollFor* node,
DeduceCtx* ctx) {
VLOG(5) << "DeduceUnrollFor: " << node->ToString();

XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> init_type,
DeduceLoopInitAndIterable(node, ctx));
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> iterable_type,
ctx->DeduceAndResolve(node->iterable()));
absl::StatusOr<InterpValue> iterable =
EvaluateConstexprValue(ctx, node->iterable(), iterable_type.get());
XLS_ASSIGN_OR_RETURN(ForLoopTypes loop_types, DeduceForLoopTypes(node, ctx));
absl::StatusOr<InterpValue> iterable = EvaluateConstexprValue(
ctx, node->iterable(), loop_types.iterable_type.get());
if (!iterable.ok() || !iterable->HasValues()) {
return absl::InvalidArgumentError(absl::StrCat(
"unroll_for! must use a constexpr iterable expression at: ",
node->iterable()->span().ToString(ctx->file_table())));
}
const auto* types =
dynamic_cast<const TupleTypeAnnotation*>(node->type_annotation());
CHECK(types);
CHECK_EQ(types->members().size(), 2);
TypeAnnotation* index_type_annot = types->members()[0];
TypeAnnotation* acc_type_annot = types->members()[1];
TypeAnnotation* index_type_annot = nullptr;
TypeAnnotation* acc_type_annot = nullptr;
if (types) {
// Deducing the `ForLoopTypes` should have errored gracefully if this was
// not the case.
CHECK_EQ(types->members().size(), 2);
index_type_annot = types->members()[0];
acc_type_annot = types->members()[1];
}
CHECK_EQ(node->names()->nodes().size(), 2);
const NameDefTree& index_name = *node->names()->nodes()[0];
std::optional<NameDef*> index_def;
Expand Down Expand Up @@ -616,6 +635,7 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceUnrollFor(const UnrollFor* node,
Number* index = node->owner()->Make<Number>(
node->iterable()->span(), element.ToString(/*humanize=*/true),
NumberKind::kOther, index_type_annot);
ctx->type_info()->SetItem(index, *loop_types.iterable_element_type);
ctx->type_info()->NoteConstExpr(index, element);
index_replacer = NameRefReplacer(*index_def, index);
}
Expand All @@ -639,8 +659,9 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceUnrollFor(const UnrollFor* node,
unrolled->SetParentNonLexical(node->parent());
ctx->type_info()->NoteUnrolledLoop(node, ctx->GetCurrentParametricEnv(),
unrolled);
XLS_RETURN_IF_ERROR(TypecheckLoopBody(node, unrolled, *init_type, ctx));
return init_type;
XLS_RETURN_IF_ERROR(
TypecheckLoopBody(node, unrolled, *loop_types.accumulator_type, ctx));
return std::move(loop_types.accumulator_type);
}

// Returns true if the cast-conversion from "from" to "to" is acceptable (i.e.
Expand Down
13 changes: 13 additions & 0 deletions xls/dslx/type_system/typecheck_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,19 @@ fn f(x: u32) -> (u32, u8) {
"for the accumulator.")));
}

TEST(TypecheckTest, UnrollForWithoutIndexAccTypeAnnotation) {
XLS_EXPECT_OK(Typecheck(R"(
proc SomeProc {
init { () }
config() { }
next(state: ()) {
unroll_for! (i, a) in u32:0..u32:4 {
a
}(u32:0);
}
})"));
}

TEST(TypecheckTest, UnrollForWithWrongResultType) {
EXPECT_THAT(Typecheck(R"(
fn f(x: u32) -> (u32, u8) {
Expand Down

0 comments on commit 777cdd6

Please sign in to comment.