diff --git a/src/Func.cpp b/src/Func.cpp index 8f46e7316531..00beef98c8ea 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -788,6 +788,17 @@ Func Stage::rfactor(vector> preserved) { vector &args = definition.args(); vector &values = definition.values(); + // Figure out which pure vars were used in this update definition. + std::set pure_vars_used; + internal_assert(args.size() == dim_vars.size()); + for (size_t i = 0; i < args.size(); i++) { + if (const Internal::Variable *var = args[i].as()) { + if (var->name == dim_vars[i].name()) { + pure_vars_used.insert(var->name); + } + } + } + // Check whether the operator is associative and determine the operator and // its identity for each value in the definition if it is a Tuple const auto &prover_result = prove_associativity(func_name, args, values); @@ -1012,16 +1023,20 @@ Func Stage::rfactor(vector> preserved) { // Determine the dims of the new update definition + // The new update definition needs all the pure vars of the Func, but the + // one we're rfactoring may not have used them all. Add any missing ones to + // the dims list. + // Add pure Vars from the original init definition to the dims list // if they are not already in the list for (const Var &v : dim_vars) { - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&v](const Dim &dim) { return var_name_match(dim.var, v.name()); }); - if (iter == dims.end()) { + if (!pure_vars_used.count(v.name())) { Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; + // Insert it just before Var::outermost dims.insert(dims.end() - 1, d); } } + // Then, we need to remove lifted RVars from the dims list for (const string &rv : rvars_removed) { remove(rv); @@ -1888,6 +1903,11 @@ Stage &Stage::reorder(const std::vector &vars) { dims_old.swap(dims); + // We're not allowed to reorder Var::outermost inwards (rfactor assumes it's + // the last one). + user_assert(dims.back().var == Var::outermost().name()) + << "Var::outermost() may not be reordered inside any other var.\n"; + return *this; } diff --git a/src/Generator.h b/src/Generator.h index 4d00a0fec574..5f95c586be99 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -2280,6 +2280,8 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, align_bounds) HALIDE_FORWARD_METHOD(Func, align_extent) HALIDE_FORWARD_METHOD(Func, align_storage) + HALIDE_FORWARD_METHOD(Func, always_partition) + HALIDE_FORWARD_METHOD(Func, always_partition_all) HALIDE_FORWARD_METHOD_CONST(Func, args) HALIDE_FORWARD_METHOD(Func, bound) HALIDE_FORWARD_METHOD(Func, bound_extent) @@ -2303,9 +2305,12 @@ class GeneratorOutputBase : public GIOBase { HALIDE_FORWARD_METHOD(Func, hexagon) HALIDE_FORWARD_METHOD(Func, in) HALIDE_FORWARD_METHOD(Func, memoize) + HALIDE_FORWARD_METHOD(Func, never_partition) + HALIDE_FORWARD_METHOD(Func, never_partition_all) HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions) HALIDE_FORWARD_METHOD_CONST(Func, outputs) HALIDE_FORWARD_METHOD(Func, parallel) + HALIDE_FORWARD_METHOD(Func, partition) HALIDE_FORWARD_METHOD(Func, prefetch) HALIDE_FORWARD_METHOD(Func, print_loop_nest) HALIDE_FORWARD_METHOD(Func, rename) diff --git a/src/Solve.cpp b/src/Solve.cpp index 22bd14e44412..b25719cff8c7 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -394,7 +394,7 @@ class SolveExpression : public IRMutator { if (a_uses_var && !b_uses_var) { const int64_t *ib = as_const_int(b); auto is_multiple_of_b = [&](const Expr &e) { - if (ib) { + if (ib && op->type.is_scalar()) { int64_t r = 0; return reduce_expr_modulo(e, *ib, &r) && r == 0; } else { @@ -1478,6 +1478,9 @@ void solve_test() { check_solve(min(x + y, x - z), x + min(y, 0 - z)); check_solve(max(x + y, x - z), x + max(y, 0 - z)); + check_solve((5 * Broadcast::make(x, 4) + y) / 5, + Broadcast::make(x, 4) + (Broadcast::make(y, 4) / 5)); + debug(0) << "Solve test passed\n"; } diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index a774335a07bf..78fe9e0cb757 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -202,6 +202,31 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } + // https://github.com/halide/Halide/issues/7890 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + RVar yryf; + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5, "rdom_r"); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + + Var yo, yi, xo, xi, u; + blurry.split(y, yo, yi, 2, TailStrategy::Auto); + local_sum.split(x, xo, xi, 4, TailStrategy::Auto); + local_sum.update(0).split(x, xo, xi, 1, TailStrategy::Auto); + local_sum.update(0).rfactor(r.x, u); + blurry.store_root(); + local_sum.compute_root(); + Pipeline p({blurry}); + auto buf = p.realize({32, 32}); + check_blur_output(buf, correct); + } + // https://github.com/halide/Halide/issues/8054 { ImageParam input(Float(32), 2, "input");