Skip to content

Commit

Permalink
Cherry-pick some recent bug-fixes into 17.0.1 (#8107)
Browse files Browse the repository at this point in the history
* Fix rfactor adding too many pure loops (#8086)

When you rfactor an update definition, the new update definition must
use all the pure vars of the Func, even though the one you're rfactoring
may not have used them all.

We also want to preserve any scheduling already done to the pure vars,
so we want to preserve the dims list and splits list from the original
definition.

The code accounted for this by checking the dims list for any missing
pure vars and adding them at the end (just before Var::outermost()), but
this didn't account for the fact that they may no longer exist in the
dims list due to splits that didn't reuse the outer name. In these
circumstances we could end up with too many pure loops. E.g. if x has
been split into xo and xi, then the code was adding a loop for x even
though there were already loops for xo and xi, which of course produces
garbage output.

This PR instead checks up front which pure vars are actually used in the
update definition, and then adds loops only for the pure vars that the
update definition does not use.

Fixes #7890

* Forward the partition methods from generator outputs (#8090)

* Fix reduce_expr_modulo of vector in Solve.cpp (#8089)

* Fix reduce_expr_modulo of vector in Solve.cpp

* Fix test
  • Loading branch information
abadams authored Feb 19, 2024
1 parent 8f424e5 commit d15325e
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
26 changes: 23 additions & 3 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,17 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
vector<Expr> &args = definition.args();
vector<Expr> &values = definition.values();

// Figure out which pure vars were used in this update definition.
std::set<string> pure_vars_used;
internal_assert(args.size() == dim_vars.size());
for (size_t i = 0; i < args.size(); i++) {
if (const Internal::Variable *var = args[i].as<Variable>()) {
if (var->name == dim_vars[i].name()) {
pure_vars_used.insert(var->name);
}
}
}

// Check whether the operator is associative and determine the operator and
// its identity for each value in the definition if it is a Tuple
const auto &prover_result = prove_associativity(func_name, args, values);
Expand Down Expand Up @@ -1012,16 +1023,20 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {

// Determine the dims of the new update definition

// The new update definition needs all the pure vars of the Func, but the
// one we're rfactoring may not have used them all. Add any missing ones to
// the dims list.

// Add pure Vars from the original init definition to the dims list
// if they are not already in the list
for (const Var &v : dim_vars) {
const auto &iter = std::find_if(dims.begin(), dims.end(),
[&v](const Dim &dim) { return var_name_match(dim.var, v.name()); });
if (iter == dims.end()) {
if (!pure_vars_used.count(v.name())) {
Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto};
// Insert it just before Var::outermost
dims.insert(dims.end() - 1, d);
}
}

// Then, we need to remove lifted RVars from the dims list
for (const string &rv : rvars_removed) {
remove(rv);
Expand Down Expand Up @@ -1888,6 +1903,11 @@ Stage &Stage::reorder(const std::vector<VarOrRVar> &vars) {

dims_old.swap(dims);

// We're not allowed to reorder Var::outermost inwards (rfactor assumes it's
// the last one).
user_assert(dims.back().var == Var::outermost().name())
<< "Var::outermost() may not be reordered inside any other var.\n";

return *this;
}

Expand Down
5 changes: 5 additions & 0 deletions src/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2280,6 +2280,8 @@ class GeneratorOutputBase : public GIOBase {
HALIDE_FORWARD_METHOD(Func, align_bounds)
HALIDE_FORWARD_METHOD(Func, align_extent)
HALIDE_FORWARD_METHOD(Func, align_storage)
HALIDE_FORWARD_METHOD(Func, always_partition)
HALIDE_FORWARD_METHOD(Func, always_partition_all)
HALIDE_FORWARD_METHOD_CONST(Func, args)
HALIDE_FORWARD_METHOD(Func, bound)
HALIDE_FORWARD_METHOD(Func, bound_extent)
Expand All @@ -2303,9 +2305,12 @@ class GeneratorOutputBase : public GIOBase {
HALIDE_FORWARD_METHOD(Func, hexagon)
HALIDE_FORWARD_METHOD(Func, in)
HALIDE_FORWARD_METHOD(Func, memoize)
HALIDE_FORWARD_METHOD(Func, never_partition)
HALIDE_FORWARD_METHOD(Func, never_partition_all)
HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
HALIDE_FORWARD_METHOD_CONST(Func, outputs)
HALIDE_FORWARD_METHOD(Func, parallel)
HALIDE_FORWARD_METHOD(Func, partition)
HALIDE_FORWARD_METHOD(Func, prefetch)
HALIDE_FORWARD_METHOD(Func, print_loop_nest)
HALIDE_FORWARD_METHOD(Func, rename)
Expand Down
5 changes: 4 additions & 1 deletion src/Solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ class SolveExpression : public IRMutator {
if (a_uses_var && !b_uses_var) {
const int64_t *ib = as_const_int(b);
auto is_multiple_of_b = [&](const Expr &e) {
if (ib) {
if (ib && op->type.is_scalar()) {
int64_t r = 0;
return reduce_expr_modulo(e, *ib, &r) && r == 0;
} else {
Expand Down Expand Up @@ -1478,6 +1478,9 @@ void solve_test() {
check_solve(min(x + y, x - z), x + min(y, 0 - z));
check_solve(max(x + y, x - z), x + max(y, 0 - z));

check_solve((5 * Broadcast::make(x, 4) + y) / 5,
Broadcast::make(x, 4) + (Broadcast::make(y, 4) / 5));

debug(0) << "Solve test passed\n";
}

Expand Down
25 changes: 25 additions & 0 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,31 @@ int main(int argc, char **argv) {
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/7890
{
// Regression case for issue 7890: rfactor on an update definition whose
// pure var x has already been split into new names (xo, xi), so x itself
// is no longer present as a loop in that definition.
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y");
// NOTE(review): yryf is declared but not referenced anywhere else in this
// scope.
RVar yryf;
input(x, y) = 2 * x + 5 * y;
// Two-dimensional reduction domain; presumably a 5x5 window starting at -2
// in each dimension (the sum is divided by 25 below) -- confirm against
// the RDom constructor.
RDom r(-2, 5, -2, 5, "rdom_r");
// Pure definition initializes the sum; the update accumulates input over
// the reduction domain around (x, y).
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);

Var yo, yi, xo, xi, u;
blurry.split(y, yo, yi, 2, TailStrategy::Auto);
local_sum.split(x, xo, xi, 4, TailStrategy::Auto);
// Split x in the update definition without reusing the name x for either
// half, then rfactor r.x into the new pure var u. This is the combination
// that used to produce an extra loop over x in addition to xo and xi.
local_sum.update(0).split(x, xo, xi, 1, TailStrategy::Auto);
local_sum.update(0).rfactor(r.x, u);
blurry.store_root();
local_sum.compute_root();
Pipeline p({blurry});
auto buf = p.realize({32, 32});
// `correct` and check_blur_output are defined outside this scope;
// presumably they hold and verify the expected blur result.
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/8054
{
ImageParam input(Float(32), 2, "input");
Expand Down

0 comments on commit d15325e

Please sign in to comment.