Skip to content

Commit

Permalink
Bound allocation extents for hoist_storage using loop variables one-b…
Browse files Browse the repository at this point in the history
…y-one (#8154)

* Bound allocation extents using loop variable one-by-one

* Use emplace_back
  • Loading branch information
vksnk authored Mar 14, 2024
1 parent 83616f2 commit f841a27
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/StorageFlattening.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class FlattenDimensions : public IRMutator {
struct HoistedStorageData {
string name;
vector<HoistedAllocationInfo> hoisted_allocations;
Scope<Interval> loop_vars;
vector<pair<string, Interval>> loop_vars;
Scope<Expr> scope;

HoistedStorageData(const string &n)
Expand Down Expand Up @@ -304,8 +304,17 @@ class FlattenDimensions : public IRMutator {
}

e = simplify(common_subexpression_elimination(e));
Interval bounds = bounds_of_expr_in_scope(e, hoisted_storage_data.loop_vars);
return bounds.max;
// Find bounds of expression using the intervals of the loop variables. The loop variables may depend on
// the other loop variables, so we just call bounds_of_expr_in_scope for each loop variable separately
// in a reverse order.
for (auto it = hoisted_storage_data.loop_vars.rbegin(); it != hoisted_storage_data.loop_vars.rend(); ++it) {
Scope<Interval> one_loop_var;
one_loop_var.push(it->first, it->second);
Interval bounds = bounds_of_expr_in_scope(e, one_loop_var);
e = bounds.max;
}

return e;
};

vector<Expr> bounded_extents;
Expand Down Expand Up @@ -533,14 +542,14 @@ class FlattenDimensions : public IRMutator {
expanded_min = simplify(expand_expr(expanded_min, it->scope));
expanded_extent = expand_expr(expanded_extent, it->scope);
Interval loop_bounds = Interval(expanded_min, simplify(expanded_min + expanded_extent - 1));
it->loop_vars.push(op->name, loop_bounds);
it->loop_vars.emplace_back(op->name, loop_bounds);
}

ScopedValue<bool> old_in_gpu(in_gpu, in_gpu || is_gpu(op->for_type));
Stmt stmt = IRMutator::visit(op);

for (auto &p : hoisted_storages) {
p.loop_vars.pop(op->name);
p.loop_vars.pop_back();
}

return stmt;
Expand Down
14 changes: 14 additions & 0 deletions test/correctness/hoist_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,20 @@ int main(int argc, char **argv) {
});
}

{
ImageParam input(UInt(8), 2);
Var x{"x"}, y{"y"}, yo{"yo"}, yi{"yi"};
Func f[3];
f[0] = BoundaryConditions::repeat_edge(input);
f[1](x, y) = ((f[0]((x / 2) + 2, (y / 2) + 2)) + (f[0](x + 1, y)));
f[2](x, y) = ((f[1](x * 2, (y * 2) + -2)) + (f[1](x + -1, y + -1)));
f[2].split(y, yo, yi, 16);
f[0].hoist_storage(f[2], yo).compute_at(f[1], x);
f[1].hoist_storage_root().compute_at(f[2], yi);

f[2].compile_jit();
}

printf("Success!\n");
return 0;
}

0 comments on commit f841a27

Please sign in to comment.