Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hoist_storage not handling condition correctly. #8123

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions src/StorageFlattening.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,23 +293,37 @@ class FlattenDimensions : public IRMutator {
stmt = LetStmt::make(op->name + ".buffer", builder.build(), stmt);
if (hoisted_storages_map.count(op->name) > 0) {
HoistedStorageData &hoisted_storage_data = hoisted_storages[hoisted_storages_map[op->name]];
vector<Expr> bounded_extents;
for (const auto &e : allocation_extents) {
Expr expanded_extent = e;

auto expand_and_bound = [&](Expr e) {
// Iterate from innermost outwards
for (auto it = hoisted_storages.rbegin(); it != hoisted_storages.rend(); it++) {
expanded_extent = expand_expr(expanded_extent, it->scope);
e = expand_expr(e, it->scope);
if (it->name == op->name) {
break;
}
}
expanded_extent = simplify(common_subexpression_elimination(expanded_extent));
Interval bounds = bounds_of_expr_in_scope(expanded_extent, hoisted_storage_data.loop_vars);
user_assert(bounds.max.defined()) << "Couldn't infer the upper bound for the storage size of " << op->name << ", consider using bound_storage.\n";
bounded_extents.push_back(bounds.max);

e = simplify(common_subexpression_elimination(e));
Interval bounds = bounds_of_expr_in_scope(e, hoisted_storage_data.loop_vars);
return bounds.max;
};

vector<Expr> bounded_extents;
for (const auto &e : allocation_extents) {
Expr expanded_extent = expand_and_bound(e);
user_assert(expanded_extent.defined() &&
!expanded_extent.same_as(Interval::pos_inf()))
<< "Couldn't infer the upper bound for the storage size of " << op->name << ", consider using bound_storage.\n";
bounded_extents.push_back(expanded_extent);
}

Expr expanded_condition = expand_and_bound(condition);
if (!expanded_condition.defined() ||
expanded_condition.same_as(Interval::pos_inf())) {
expanded_condition = const_true();
}

HoistedAllocationInfo hoisted_alloc(op->name, op->types[0], op->memory_type, bounded_extents, condition);
HoistedAllocationInfo hoisted_alloc(op->name, op->types[0], op->memory_type, bounded_extents, expanded_condition);

hoisted_storage_data.hoisted_allocations.push_back(hoisted_alloc);
} else {
Expand Down
26 changes: 25 additions & 1 deletion test/correctness/skip_stages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void check_counts(int a = 0, int b = 0, int c = 0, int d = 0) {
}

int main(int argc, char **argv) {
Var x;
Var x, y;
Param<bool> toggle1, toggle2;

{
Expand Down Expand Up @@ -201,6 +201,30 @@ int main(int argc, char **argv) {
check_counts(11);
}

{
// Check the interation with storage hoisting

// This Func may or may not be loaded, depending on y
Func maybe_loaded("maybe_loaded");
maybe_loaded(x, y) = x + y;

// This Func may or may not be used, depending on y
Func maybe_used("maybe_used");
maybe_used(x, y) = maybe_loaded(x, y);

Func output("output");
output(x, y) = select(y % 100 == 37, 0, maybe_used(x, y));

// The allocation condition depends on y, but the actual allocation
// happens at the root level.
maybe_loaded.compute_at(output, y).hoist_storage_root();
maybe_used.compute_at(output, y).hoist_storage_root();

// This will fail to compile with an undefined symbol if we haven't
// handled the condition correctly.
output.realize({100, 100});
}

printf("Success!\n");
return 0;
}
Loading