From e7e6a3a56e99a4835934316275ee2bcc4ed5ea54 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 12 Feb 2024 10:10:00 -0800 Subject: [PATCH] Fix rfactor adding too many pure loops (#8086) When you rfactor an update definition, the new update definition must use all the pure vars of the Func, even though the one you're rfactoring may not have used them all. We also want to preserve any scheduling already done to the pure vars, so we want to preserve the dims list and splits list from the original definition. The code accounted for this by checking the dims list for any missing pure vars and adding them at the end (just before Var::outermost()), but this didn't account for the fact that they may no longer exist in the dims list due to splits that didn't reuse the outer name. In these circumstances we could end up with too many pure loops. E.g. if x has been split into xo and xi, then the code was adding a loop for x even though there were already loops for xo and xi, which of course produces garbage output. This PR instead just checks which pure vars are actually used in the update definition up front, and then uses that to tell which ones should be added. Fixes #7890 --- src/Func.cpp | 26 +++++++++++++++++++++++--- test/correctness/fuzz_schedule.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 8f46e7316531..00beef98c8ea 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -788,6 +788,17 @@ Func Stage::rfactor(vector> preserved) { vector &args = definition.args(); vector &values = definition.values(); + // Figure out which pure vars were used in this update definition. + std::set pure_vars_used; + internal_assert(args.size() == dim_vars.size()); + for (size_t i = 0; i < args.size(); i++) { + if (const Internal::Variable *var = args[i].as()) { + if (var->name == dim_vars[i].name()) { + pure_vars_used.insert(var->name); + } + } + } + // Check whether the operator is associative and determine the operator and // its identity for each value in the definition if it is a Tuple const auto &prover_result = prove_associativity(func_name, args, values); @@ -1012,16 +1023,20 @@ Func Stage::rfactor(vector> preserved) { // Determine the dims of the new update definition + // The new update definition needs all the pure vars of the Func, but the + // one we're rfactoring may not have used them all. Add any missing ones to + // the dims list. + // Add pure Vars from the original init definition to the dims list // if they are not already in the list for (const Var &v : dim_vars) { - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&v](const Dim &dim) { return var_name_match(dim.var, v.name()); }); - if (iter == dims.end()) { + if (!pure_vars_used.count(v.name())) { Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; + // Insert it just before Var::outermost dims.insert(dims.end() - 1, d); } } + // Then, we need to remove lifted RVars from the dims list for (const string &rv : rvars_removed) { remove(rv); @@ -1888,6 +1903,11 @@ Stage &Stage::reorder(const std::vector &vars) { dims_old.swap(dims); + // We're not allowed to reorder Var::outermost inwards (rfactor assumes it's + // the last one). + user_assert(dims.back().var == Var::outermost().name()) + << "Var::outermost() may not be reordered inside any other var.\n"; + return *this; } diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index a774335a07bf..78fe9e0cb757 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -202,6 +202,31 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } + // https://github.com/halide/Halide/issues/7890 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + RVar yryf; + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5, "rdom_r"); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + + Var yo, yi, xo, xi, u; + blurry.split(y, yo, yi, 2, TailStrategy::Auto); + local_sum.split(x, xo, xi, 4, TailStrategy::Auto); + local_sum.update(0).split(x, xo, xi, 1, TailStrategy::Auto); + local_sum.update(0).rfactor(r.x, u); + blurry.store_root(); + local_sum.compute_root(); + Pipeline p({blurry}); + auto buf = p.realize({32, 32}); + check_blur_output(buf, correct); + } + // https://github.com/halide/Halide/issues/8054 { ImageParam input(Float(32), 2, "input");