diff --git a/src/Simplify.cpp b/src/Simplify.cpp index bc0c0964cf81..29535d36255b 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -15,31 +15,29 @@ using std::pair; using std::string; using std::vector; -#if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS) -int Simplify::debug_indent = 0; -#endif - Simplify::Simplify(bool r, const Scope *bi, const Scope *ai) : remove_dead_code(r) { // Only respect the constant bounds from the containing scope. for (auto iter = bi->cbegin(); iter != bi->cend(); ++iter) { - ExprInfo bounds; + ExprInfo info; if (const int64_t *i_min = as_const_int(iter.value().min)) { - bounds.min_defined = true; - bounds.min = *i_min; + info.bounds.min_defined = true; + info.bounds.min = *i_min; } if (const int64_t *i_max = as_const_int(iter.value().max)) { - bounds.max_defined = true; - bounds.max = *i_max; + info.bounds.max_defined = true; + info.bounds.max = *i_max; } if (const auto *a = ai->find(iter.name())) { - bounds.alignment = *a; + info.alignment = *a; } - if (bounds.min_defined || bounds.max_defined || bounds.alignment.modulus != 1) { - bounds_and_alignment_info.push(iter.name(), bounds); + if (info.bounds.min_defined || + info.bounds.max_defined || + info.alignment.modulus != 1) { + bounds_and_alignment_info.push(iter.name(), info); } } @@ -48,20 +46,20 @@ Simplify::Simplify(bool r, const Scope *bi, const Scope, bool> Simplify::mutate_with_changes(const std::vector &old_exprs, ExprInfo *bounds) { +std::pair, bool> Simplify::mutate_with_changes(const std::vector &old_exprs) { vector new_exprs(old_exprs.size()); bool changed = false; // Mutate the args for (size_t i = 0; i < old_exprs.size(); i++) { const Expr &old_e = old_exprs[i]; - Expr new_e = mutate(old_e, bounds); + Expr new_e = mutate(old_e, nullptr); if (!new_e.same_as(old_e)) { changed = true; } @@ -135,17 +133,17 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { Simplify::ExprInfo i; if (v) { simplify->mutate(lt->b, &i); - if (i.min_defined) { + if (i.bounds.min_defined) { // !(v < i) - learn_lower_bound(v, i.min); + learn_lower_bound(v, i.bounds.min); } } v = lt->b.as(); if (v) { simplify->mutate(lt->a, &i); - if (i.max_defined) { + if (i.bounds.max_defined) { // !(i < v) - learn_upper_bound(v, i.max); + learn_upper_bound(v, i.bounds.max); } } } else if (const LE *le = fact.as()) { @@ -153,17 +151,17 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { Simplify::ExprInfo i; if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(le->b, &i); - if (i.min_defined) { + if (i.bounds.min_defined) { // !(v <= i) - learn_lower_bound(v, i.min + 1); + learn_lower_bound(v, i.bounds.min + 1); } } v = le->b.as(); if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(le->a, &i); - if (i.max_defined) { + if (i.bounds.max_defined) { // !(i <= v) - learn_upper_bound(v, i.max - 1); + learn_upper_bound(v, i.bounds.max - 1); } } } else if (const Call *c = Call::as_tag(fact)) { @@ -185,8 +183,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) { ExprInfo b; - b.max_defined = true; - b.max = val; + b.bounds = ConstantInterval::bounded_above(val); if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { b.intersect(*info); } @@ -196,8 +193,7 @@ void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) { void Simplify::ScopedFact::learn_lower_bound(const Variable *v, int64_t val) { ExprInfo b; - b.min_defined = true; - b.min = val; + b.bounds = ConstantInterval::bounded_below(val); if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { b.intersect(*info); } @@ -267,17 +263,17 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo i; if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(lt->b, &i); - if (i.max_defined) { + if (i.bounds.max_defined) { // v < i - learn_upper_bound(v, i.max - 1); + learn_upper_bound(v, i.bounds.max - 1); } } v = lt->b.as(); if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(lt->a, &i); - if (i.min_defined) { + if (i.bounds.min_defined) { // i < v - learn_lower_bound(v, i.min + 1); + learn_lower_bound(v, i.bounds.min + 1); } } } else if (const LE *le = fact.as()) { @@ -285,17 +281,17 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo i; if (v) { simplify->mutate(le->b, &i); - if (i.max_defined) { + if (i.bounds.max_defined) { // v <= i - learn_upper_bound(v, i.max); + learn_upper_bound(v, i.bounds.max); } } v = le->b.as(); if (v) { simplify->mutate(le->a, &i); - if (i.min_defined) { + if (i.bounds.min_defined) { // i <= v - learn_lower_bound(v, i.min); + learn_lower_bound(v, i.bounds.min); } } } else if (const Call *c = Call::as_tag(fact)) { diff --git a/src/Simplify.h b/src/Simplify.h index b9335c0c3de9..61ca847d7a27 100644 --- a/src/Simplify.h +++ b/src/Simplify.h @@ -21,11 +21,13 @@ namespace Internal { * Exprs that should be assumed to be true. */ // @{ -Stmt simplify(const Stmt &, bool remove_dead_code = true, +Stmt simplify(const Stmt &, + bool remove_dead_code = true, const Scope &bounds = Scope::empty_scope(), const Scope &alignment = Scope::empty_scope(), const std::vector &assumptions = std::vector()); -Expr simplify(const Expr &, bool remove_dead_code = true, +Expr simplify(const Expr &, + bool remove_dead_code = true, const Scope &bounds = Scope::empty_scope(), const Scope &alignment = Scope::empty_scope(), const std::vector &assumptions = std::vector()); diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp index fb9238dd9a6a..e4cccf131b5e 100644 --- a/src/Simplify_Add.cpp +++ b/src/Simplify_Add.cpp @@ -3,20 +3,16 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Add *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds && no_overflow_int(op->type)) { - bounds->min_defined = a_bounds.min_defined && - b_bounds.min_defined && - add_with_overflow(64, a_bounds.min, b_bounds.min, &(bounds->min)); - bounds->max_defined = a_bounds.max_defined && - b_bounds.max_defined && - add_with_overflow(64, a_bounds.max, b_bounds.max, &(bounds->max)); - bounds->alignment = a_bounds.alignment + b_bounds.alignment; - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Add *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info) { + info->bounds = a_info.bounds + b_info.bounds; + info->alignment = a_info.alignment + b_info.alignment; + info->trim_bounds_using_alignment(); + info->cast_to(op->type); } if (may_simplify(op->type)) { @@ -24,7 +20,7 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } auto rewrite = IRMatcher::rewriter(IRMatcher::add(a, b), op->type); @@ -194,7 +190,7 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { rewrite(x + (y + (c0 - x)/c1)*c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) || false)))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_And.cpp b/src/Simplify_And.cpp index 35bbd5f7f747..a6f7e82c9095 100644 --- a/src/Simplify_And.cpp +++ b/src/Simplify_And.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const And *op, ExprInfo *bounds) { +Expr Simplify::visit(const And *op, ExprInfo *info) { if (falsehoods.count(op)) { return const_false(op->type.lanes()); } @@ -109,7 +109,7 @@ Expr Simplify::visit(const And *op, ExprInfo *bounds) { rewrite(x <= y && x <= z, x <= min(y, z)) || rewrite(y <= x && z <= x, max(y, z) <= x)) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (a.same_as(op->a) && diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index 29bc75aa2bb2..db3fe526418c 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -49,7 +49,7 @@ Expr lift_elementwise_broadcasts(Type type, const std::string &name, std::vector } // namespace -Expr Simplify::visit(const Call *op, ExprInfo *bounds) { +Expr Simplify::visit(const Call *op, ExprInfo *info) { // Calls implicitly depend on host, dev, mins, and strides of the buffer referenced if (op->call_type == Call::Image || op->call_type == Call::Halide) { found_buffer_reference(op->name, op->args.size()); @@ -79,7 +79,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } uint64_t ua = 0; @@ -123,7 +123,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } const Type t = op->type; @@ -132,9 +132,9 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { std::string result_op = op->name; // If we know the sign of this shift, change it to an unsigned shift. - if (b_info.min_defined && b_info.min >= 0) { + if (b_info.bounds >= 0) { b = mutate(cast(b.type().with_code(halide_type_uint), b), nullptr); - } else if (b.type().is_int() && b_info.max_defined && b_info.max <= 0) { + } else if (b.type().is_int() && b_info.bounds <= 0) { result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); b = mutate(cast(b.type().with_code(halide_type_uint), -b), nullptr); } @@ -145,24 +145,24 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { // LLVM shl and shr instructions produce poison for // shifts >= typesize, so we will follow suit in our simplifier. if (ub >= (uint64_t)(t.bits())) { - clear_bounds_info(bounds); + clear_expr_info(info); return make_signed_integer_overflow(t); } if (a.type().is_uint() || ub < ((uint64_t)t.bits() - 1)) { b = make_const(t, ((int64_t)1LL) << ub); if (result_op == Call::get_intrinsic_name(Call::shift_left)) { - return mutate(Mul::make(a, b), bounds); + return mutate(Mul::make(a, b), info); } else { - return mutate(Div::make(a, b), bounds); + return mutate(Div::make(a, b), info); } } else { // For signed types, (1 << (t.bits() - 1)) will overflow into the sign bit while // (-32768 >> (t.bits() - 1)) propagates the sign bit, making decomposition // into mul or div problematic, so just special-case them here. if (result_op == Call::get_intrinsic_name(Call::shift_left)) { - return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1LL) << ub), make_zero(t)), bounds); + return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1LL) << ub), make_zero(t)), info); } else { - return mutate(select(a < 0, make_const(t, -1), make_zero(t)), bounds); + return mutate(select(a < 0, make_const(t, -1), make_zero(t)), info); } } } @@ -173,7 +173,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { if (is_const_zero(sub->a)) { result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); b = sub->b; - return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), bounds); + return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), info); } } } @@ -190,7 +190,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia, ib = 0; @@ -227,7 +227,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia, ib; @@ -248,7 +248,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia; @@ -268,7 +268,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia, ib; @@ -286,12 +286,17 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } else if (op->is_intrinsic(Call::abs)) { // Constant evaluate abs(x). - ExprInfo a_bounds; - Expr a = mutate(op->args[0], &a_bounds); + ExprInfo a_info; + Expr a = mutate(op->args[0], &a_info); Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); + } + + if (info) { + info->bounds = abs(a_info.bounds); + info->cast_to(op->type); } Type ta = a.type(); @@ -310,9 +315,9 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { fa = -fa; } return make_const(a.type(), fa); - } else if (a.type().is_int() && a_bounds.min_defined && a_bounds.min >= 0) { + } else if (a.type().is_int() && a_info.bounds >= 0) { return cast(op->type, a); - } else if (a.type().is_int() && a_bounds.max_defined && a_bounds.max <= 0) { + } else if (a.type().is_int() && a_info.bounds <= 0) { return cast(op->type, -a); } else if (a.same_as(op->args[0])) { return op; @@ -321,13 +326,13 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } else if (op->is_intrinsic(Call::absd)) { // Constant evaluate absd(a, b). - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->args[0], &a_bounds); - Expr b = mutate(op->args[1], &b_bounds); + ExprInfo a_info, b_info; + Expr a = mutate(op->args[0], &a_info); + Expr b = mutate(op->args[1], &b_info); Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } Type ta = a.type(); @@ -355,14 +360,17 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } else if (op->is_intrinsic(Call::saturating_cast)) { internal_assert(op->args.size() == 1); - ExprInfo a_bounds; - Expr a = mutate(op->args[0], &a_bounds); + ExprInfo a_info; + Expr a = mutate(op->args[0], &a_info); - // TODO(rootjalex): We could be intelligent about using a_bounds to remove saturating_casts; + // In principle we could use constant bounds here to convert saturating + // casts to casts, but it's probably a bad idea. Saturating casts only + // show up if the user asks for them, and they're faster than a cast on + // some platforms. We should leave them be. if (is_const(a)) { a = lower_saturating_cast(op->type, a); - return mutate(a, bounds); + return mutate(a, info); } else if (!a.same_as(op->args[0])) { return saturating_cast(op->type, a); } else { @@ -424,7 +432,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { internal_assert(op->args.size() % 2 == 0); // Prefetch: {base, offset, extent0, stride0, ...} - auto [args, changed] = mutate_with_changes(op->args, nullptr); + auto [args, changed] = mutate_with_changes(op->args); // The {extent, stride} args in the prefetch call are sorted // based on the storage dimension in ascending order (i.e. innermost @@ -478,7 +486,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { { // Can assume the condition is true when evaluating the value. auto t = scoped_truth(cond); - result = mutate(op->args[1], bounds); + result = mutate(op->args[1], info); } if (is_const_one(cond)) { @@ -511,12 +519,8 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { const Broadcast *b_lower = lower.as(); const Broadcast *b_upper = upper.as(); - if (arg_info.min_defined && - arg_info.max_defined && - lower_info.max_defined && - upper_info.min_defined && - arg_info.min >= lower_info.max && - arg_info.max <= upper_info.min) { + if (arg_info.bounds >= lower_info.bounds && + arg_info.bounds <= upper_info.bounds) { return arg; } else if (b_arg && b_lower && b_upper) { // Move broadcasts outwards @@ -537,7 +541,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } else if (Call::as_tag(op)) { // The bounds of the result are the bounds of the arg internal_assert(op->args.size() == 1); - Expr arg = mutate(op->args[0], bounds); + Expr arg = mutate(op->args[0], info); if (arg.same_as(op->args[0])) { return op; } else { @@ -557,12 +561,12 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } if (is_const_one(cond)) { - return mutate(op->args[1], bounds); + return mutate(op->args[1], info); } else if (is_const_zero(cond)) { if (op->args.size() == 3) { - return mutate(op->args[2], bounds); + return mutate(op->args[2], info); } else { - return mutate(make_zero(op->type), bounds); + return mutate(make_zero(op->type), info); } } else { Expr true_value = mutate(op->args[1], nullptr); @@ -576,11 +580,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } in_unreachable = false; if (true_unreachable) { - if (false_value.defined()) { - return false_value; - } else { - return make_zero(op->type); - } + return false_value; } else if (false_unreachable) { return true_value; } @@ -602,21 +602,20 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { int num_values = (int)op->args.size() - 1; if (num_values == 1) { // Mux of a single value - return mutate(op->args[1], bounds); + return mutate(op->args[1], info); } ExprInfo index_info; Expr index = mutate(op->args[0], &index_info); // Check if the mux has statically resolved - if (index_info.min_defined && - index_info.max_defined && - index_info.min == index_info.max) { - if (index_info.min >= 0 && index_info.min < num_values) { + if (index_info.bounds.is_single_point()) { + int64_t v = index_info.bounds.min; + if (v >= 0 && v < num_values) { // In-range, return the (simplified) corresponding value. - return mutate(op->args[index_info.min + 1], bounds); + return mutate(op->args[v + 1], info); } else { // It's out-of-range, so return the last value. - return mutate(op->args.back(), bounds); + return mutate(op->args.back(), info); } } @@ -782,16 +781,16 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { // There are other PureExterns we don't bother with (e.g. fast_inverse_f32)... // just fall thru and take the general case. - debug(2) << "Simplifier: unhandled PureExtern: " << op->name; + debug(2) << "Simplifier: unhandled PureExtern: " << op->name << "\n"; } else if (op->is_intrinsic(Call::signed_integer_overflow)) { - clear_bounds_info(bounds); + clear_expr_info(info); } else if (op->is_intrinsic(Call::concat_bits) && op->args.size() == 1) { - return mutate(op->args[0], bounds); + return mutate(op->args[0], info); } // No else: we want to fall thru from the PureExtern clause. { - auto [new_args, changed] = mutate_with_changes(op->args, nullptr); + auto [new_args, changed] = mutate_with_changes(op->args); if (!changed) { return op; } else { diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index 4e689212aaa0..985707ce2cfb 100644 --- a/src/Simplify_Cast.cpp +++ b/src/Simplify_Cast.cpp @@ -1,33 +1,25 @@ #include "Simplify_Internal.h" +#include "IRPrinter.h" + namespace Halide { namespace Internal { -Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { - Expr value = mutate(op->value, bounds); +Expr Simplify::visit(const Cast *op, ExprInfo *info) { - if (bounds) { - // If either the min value or the max value can't be represented - // in the destination type, or the min/max value is undefined, - // the bounds need to be cleared (one-sided for no_overflow, - // both sides for overflow types). - if ((bounds->min_defined && !op->type.can_represent(bounds->min)) || !bounds->min_defined) { - bounds->min_defined = false; - if (!no_overflow(op->type)) { - // If the type overflows, this invalidates the max too. - bounds->max_defined = false; - } - } - if ((bounds->max_defined && !op->type.can_represent(bounds->max)) || !bounds->max_defined) { - if (!no_overflow(op->type)) { - // If the type overflows, this invalidates the min too. - bounds->min_defined = false; - } - bounds->max_defined = false; - } - if (!op->type.can_represent(bounds->alignment.modulus) || - !op->type.can_represent(bounds->alignment.remainder)) { - bounds->alignment = ModulusRemainder(); + ExprInfo value_info; + Expr value = mutate(op->value, &value_info); + + if (info) { + if (no_overflow(op->type) && !op->type.can_represent(value_info.bounds)) { + // If there's overflow in a no-overflow type (e.g. due to casting + // from a UInt(64) to an Int(32)), then forget everything we know + // about the Expr. The expression may or may not overflow. We don't + // know. + *info = ExprInfo{}; + } else { + *info = value_info; + info->cast_to(op->type); } } @@ -39,7 +31,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { int64_t i = 0; uint64_t u = 0; if (Call::as_intrinsic(value, {Call::signed_integer_overflow})) { - clear_bounds_info(bounds); + clear_expr_info(info); return make_signed_integer_overflow(op->type); } else if (value.type() == op->type) { return value; @@ -48,12 +40,13 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { std::isfinite(f)) { // float -> int // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(f)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(f)), info); } else if (op->type.is_uint() && const_float(value, &f) && std::isfinite(f)) { // float -> uint - return make_const(op->type, safe_numeric_cast(f)); + // Recursively call mutate just to set the bounds + return mutate(make_const(op->type, safe_numeric_cast(f)), info); } else if (op->type.is_float() && const_float(value, &f)) { // float -> float @@ -62,7 +55,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { const_int(value, &i)) { // int -> int // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, i), bounds); + return mutate(make_const(op->type, i), info); } else if (op->type.is_uint() && const_int(value, &i)) { // int -> uint @@ -70,19 +63,19 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { } else if (op->type.is_float() && const_int(value, &i)) { // int -> float - return make_const(op->type, safe_numeric_cast(i)); + return mutate(make_const(op->type, safe_numeric_cast(i)), info); } else if (op->type.is_int() && const_uint(value, &u) && op->type.bits() < value.type().bits()) { // uint -> int narrowing // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(u)), info); } else if (op->type.is_int() && const_uint(value, &u) && op->type.bits() == value.type().bits()) { // uint -> int reinterpret // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(u)), info); } else if (op->type.is_int() && const_uint(value, &u) && op->type.bits() > value.type().bits()) { @@ -90,14 +83,14 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { if (op->type.can_represent(u) || op->type.bits() < 32) { // If the type can represent the value or overflow is well-defined. // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(u)), info); } else { return make_signed_integer_overflow(op->type); } } else if (op->type.is_uint() && const_uint(value, &u)) { // uint -> uint - return make_const(op->type, u); + return mutate(make_const(op->type, u), info); } else if (op->type.is_float() && const_uint(value, &u)) { // uint -> float @@ -108,7 +101,18 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { // If this is a cast of a cast of the same type, where the // outer cast is narrower, the inner cast can be // eliminated. - return mutate(Cast::make(op->type, cast->value), bounds); + return mutate(Cast::make(op->type, cast->value), info); + } else if (cast && + op->type.is_int_or_uint() && + cast->type.is_int() && + cast->value.type().is_int() && + op->type.bits() >= cast->type.bits() && + cast->type.bits() >= cast->value.type().bits()) { + // Casting from a signed type always sign-extends, so widening + // partway to a signed type and the rest of the way to some other + // integer type is the same as just widening to that integer type + // directly. + return mutate(Cast::make(op->type, cast->value), info); } else if (cast && (op->type.is_int() || op->type.is_uint()) && (cast->type.is_int() || cast->type.is_uint()) && @@ -119,10 +123,10 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { // inner cast's argument, the inner cast can be // eliminated. The inner cast is either a sign extend // or a zero extend, and the outer cast truncates the extended bits - return mutate(Cast::make(op->type, cast->value), bounds); + return mutate(Cast::make(op->type, cast->value), info); } else if (broadcast_value) { // cast(broadcast(x)) -> broadcast(cast(x)) - return mutate(Broadcast::make(Cast::make(op->type.with_lanes(broadcast_value->value.type().lanes()), broadcast_value->value), broadcast_value->lanes), bounds); + return mutate(Broadcast::make(Cast::make(op->type.with_lanes(broadcast_value->value.type().lanes()), broadcast_value->value), broadcast_value->lanes), info); } else if (ramp_value && op->type.element_of() == Int(64) && op->value.type().element_of() == Int(32)) { @@ -132,7 +136,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { Cast::make(op->type.with_lanes(ramp_value->stride.type().lanes()), ramp_value->stride), ramp_value->lanes), - bounds); + info); } } diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp index 49f98837404c..92487eddecc2 100644 --- a/src/Simplify_Div.cpp +++ b/src/Simplify_Div.cpp @@ -3,112 +3,51 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Div *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds && no_overflow_int(op->type)) { - bounds->min = INT64_MAX; - bounds->max = INT64_MIN; - - // Enumerate all possible values for the min and max and take the extreme values. - if (a_bounds.min_defined && b_bounds.min_defined && b_bounds.min != 0) { - int64_t v = div_imp(a_bounds.min, b_bounds.min); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - if (a_bounds.min_defined && b_bounds.max_defined && b_bounds.max != 0) { - int64_t v = div_imp(a_bounds.min, b_bounds.max); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - if (a_bounds.max_defined && b_bounds.max_defined && b_bounds.max != 0) { - int64_t v = div_imp(a_bounds.max, b_bounds.max); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - if (a_bounds.max_defined && b_bounds.min_defined && b_bounds.min != 0) { - int64_t v = div_imp(a_bounds.max, b_bounds.min); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - const bool b_positive = b_bounds.min_defined && b_bounds.min > 0; - const bool b_negative = b_bounds.max_defined && b_bounds.max < 0; - - if ((b_positive && !b_bounds.max_defined) || - (b_negative && !b_bounds.min_defined)) { - // Take limit as b -> +/- infinity - int64_t v = 0; - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - bounds->min_defined = ((a_bounds.min_defined && b_positive) || - (a_bounds.max_defined && b_negative)); - bounds->max_defined = ((a_bounds.max_defined && b_positive) || - (a_bounds.min_defined && b_negative)); - - // That's as far as we can get knowing the sign of the - // denominator. For bounded numerators, we additionally know - // that div can't make anything larger in magnitude, so we can - // take the intersection with that. - if (a_bounds.max_defined && a_bounds.min_defined) { - int64_t v = std::max(a_bounds.max, -a_bounds.min); - if (bounds->min_defined) { - bounds->min = std::max(bounds->min, -v); - } else { - bounds->min = -v; - } - if (bounds->max_defined) { - bounds->max = std::min(bounds->max, v); - } else { - bounds->max = v; - } - bounds->min_defined = bounds->max_defined = true; - } - - // Bounded numerator divided by constantish - // denominator can sometimes collapse things to a - // constant at this point - if (bounds->min_defined && - bounds->max_defined && - bounds->max == bounds->min) { - if (op->type.can_represent(bounds->min)) { - return make_const(op->type, bounds->min); - } else { - // Even though this is 'no-overflow-int', if the result - // we calculate can't fit into the destination type, - // we're better off returning an overflow condition than - // a known-wrong value. (Note that no_overflow_int() should - // only be true for signed integers.) - internal_assert(op->type.is_int()); - clear_bounds_info(bounds); - return make_signed_integer_overflow(op->type); +Expr Simplify::visit(const Div *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info) { + if (op->type.is_int_or_uint()) { + // ConstantInterval division is integer division, so we can't use + // this code path for floats. + info->bounds = a_info.bounds / b_info.bounds; + info->alignment = a_info.alignment / b_info.alignment; + info->trim_bounds_using_alignment(); + info->cast_to(op->type); + + // Bounded numerator divided by constantish bounded denominator can + // sometimes collapse things to a constant at this point. This + // mostly happens when the denominator is a constant and the + // numerator span is small (e.g. [23, 29]/10 = 2), but there are + // also cases with a bounded denominator (e.g. [5, 7]/[4, 5] = 1). + if (info->bounds.is_single_point()) { + if (op->type.can_represent(info->bounds.min)) { + return make_const(op->type, info->bounds.min); + } else { + // Even though this is 'no-overflow-int', if the result + // we calculate can't fit into the destination type, + // we're better off returning an overflow condition than + // a known-wrong value. (Note that no_overflow_int() should + // only be true for signed integers.) + internal_assert(no_overflow_int(op->type)); + clear_expr_info(info); + return make_signed_integer_overflow(op->type); + } } + } else { + // TODO: Tracking constant integer bounds of floating point values + // isn't so useful right now, but if we want integer bounds for + // floating point division later, here's the place to put it. + clear_expr_info(info); } - // Code downstream can use min/max in calculated-but-unused arithmetic - // that can lead to UB (and thus, flaky failures under ASAN/UBSAN) - // if we leave them set to INT64_MAX/INT64_MIN; normalize to zero to avoid this. - if (!bounds->min_defined) { - bounds->min = 0; - } - if (!bounds->max_defined) { - bounds->max = 0; - } - bounds->alignment = a_bounds.alignment / b_bounds.alignment; - bounds->trim_bounds_using_alignment(); } bool denominator_non_zero = (no_overflow_int(op->type) && - ((b_bounds.min_defined && b_bounds.min > 0) || - (b_bounds.max_defined && b_bounds.max < 0) || - (b_bounds.alignment.remainder != 0))); + (!b_info.bounds.contains(0) || + b_info.alignment.remainder != 0)); if (may_simplify(op->type)) { @@ -126,8 +65,8 @@ Expr Simplify::visit(const Div *op, ExprInfo *bounds) { return rewrite.result; } - int a_mod = a_bounds.alignment.modulus; - int a_rem = a_bounds.alignment.remainder; + int a_mod = a_info.alignment.modulus; + int a_rem = a_info.alignment.remainder; // clang-format off if (EVAL_IN_LAMBDA @@ -272,7 +211,7 @@ Expr Simplify::visit(const Div *op, ExprInfo *bounds) { c2 > 0 && c0 % c2 == 0) || // A very specific pattern that comes up in bounds in upsampling code. rewrite((x % 2 + c0) / 2, x % 2 + fold(c0 / 2), c0 % 2 == 1))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_EQ.cpp b/src/Simplify_EQ.cpp index 13b49a90886c..97c32814e03d 100644 --- a/src/Simplify_EQ.cpp +++ b/src/Simplify_EQ.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { +Expr Simplify::visit(const EQ *op, ExprInfo *info) { if (truths.count(op)) { return const_true(op->type.lanes()); } else if (falsehoods.count(op)) { @@ -31,7 +31,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { if (rewrite(x == 1, x)) { return rewrite.result; } else if (rewrite(x == 0, !x)) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } else if (rewrite(x == x, const_true(lanes))) { return rewrite.result; } else if (a.same_as(op->a) && b.same_as(op->b)) { @@ -41,8 +41,8 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { } } - ExprInfo delta_bounds; - Expr delta = mutate(op->a - op->b, &delta_bounds); + ExprInfo delta_info; + Expr delta = mutate(op->a - op->b, &delta_info); const int lanes = op->type.lanes(); // If the delta is 0, then it's just x == x @@ -51,16 +51,12 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { } // Attempt to disprove using bounds analysis - if (delta_bounds.min_defined && delta_bounds.min > 0) { - return const_false(lanes); - } - - if (delta_bounds.max_defined && delta_bounds.max < 0) { + if (!delta_info.bounds.contains(0)) { return const_false(lanes); } // Attempt to disprove using modulus remainder analysis - if (delta_bounds.alignment.remainder != 0) { + if (delta_info.alignment.remainder != 0) { return const_false(lanes); } @@ -109,7 +105,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { rewrite(min(x, 0) == 0, 0 <= x) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (rewrite(c0 == 0, fold(c0 == 0)) || @@ -121,7 +117,9 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { const EQ *eq = rewrite.result.as(); if (eq && eq->a.same_as(op->a) && - eq->b.same_as(op->b)) { + equal(eq->b, op->b)) { + // Note we don't use same_as for b, because the shuffling of the RHS + // to the LHS and back might mutate it and then mutate it back. return op; } else { return rewrite.result; @@ -134,7 +132,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { } // ne redirects to not eq -Expr Simplify::visit(const NE *op, ExprInfo *bounds) { +Expr Simplify::visit(const NE *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -145,7 +143,7 @@ Expr Simplify::visit(const NE *op, ExprInfo *bounds) { } } - Expr mutated = mutate(Not::make(EQ::make(op->a, op->b)), bounds); + Expr mutated = mutate(Not::make(EQ::make(op->a, op->b)), info); if (const NE *ne = mutated.as()) { if (ne->a.same_as(op->a) && ne->b.same_as(op->b)) { return op; diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index b5fcc96ac0cd..02f19ae13a6a 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -7,49 +7,48 @@ namespace Internal { // Miscellaneous expression visitors that are too small to bother putting in their own files -Expr Simplify::visit(const IntImm *op, ExprInfo *bounds) { - if (bounds && no_overflow_int(op->type)) { - bounds->min_defined = bounds->max_defined = true; - bounds->min = bounds->max = op->value; - bounds->alignment.remainder = op->value; - bounds->alignment.modulus = 0; +Expr Simplify::visit(const IntImm *op, ExprInfo *info) { + if (info) { + info->bounds = ConstantInterval::single_point(op->value); + info->alignment = ModulusRemainder(0, op->value); + info->cast_to(op->type); } else { - clear_bounds_info(bounds); + clear_expr_info(info); } return op; } -Expr Simplify::visit(const UIntImm *op, ExprInfo *bounds) { - if (bounds && Int(64).can_represent(op->value)) { - bounds->min_defined = bounds->max_defined = true; - bounds->min = bounds->max = (int64_t)(op->value); - bounds->alignment.remainder = op->value; - bounds->alignment.modulus = 0; +Expr Simplify::visit(const UIntImm *op, ExprInfo *info) { + if (info && Int(64).can_represent(op->value)) { + int64_t v = (int64_t)(op->value); + info->bounds = ConstantInterval::single_point(v); + info->alignment = ModulusRemainder(0, v); + info->cast_to(op->type); } else { - clear_bounds_info(bounds); + clear_expr_info(info); } return op; } -Expr Simplify::visit(const FloatImm *op, ExprInfo *bounds) { - clear_bounds_info(bounds); +Expr Simplify::visit(const FloatImm *op, ExprInfo *info) { + clear_expr_info(info); return op; } -Expr Simplify::visit(const StringImm *op, ExprInfo *bounds) { - clear_bounds_info(bounds); +Expr Simplify::visit(const StringImm *op, ExprInfo *info) { + clear_expr_info(info); return op; } -Expr Simplify::visit(const Broadcast *op, ExprInfo *bounds) { - Expr value = mutate(op->value, bounds); +Expr Simplify::visit(const Broadcast *op, ExprInfo *info) { + Expr value = mutate(op->value, info); const int lanes = op->lanes; auto rewrite = IRMatcher::rewriter(IRMatcher::broadcast(value, lanes), op->type); if (rewrite(broadcast(broadcast(x, c0), lanes), broadcast(x, c0 * lanes)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (value.same_as(op->value)) { @@ -59,8 +58,8 @@ Expr Simplify::visit(const Broadcast *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { - Expr value = mutate(op->value, bounds); +Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { + Expr value = mutate(op->value, info); const int lanes = op->type.lanes(); const int arg_lanes = op->value.type().lanes(); @@ -69,32 +68,22 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { return value; } - if (bounds && op->type.is_int()) { + if (info && op->type.is_int()) { switch (op->op) { case VectorReduce::Add: // Alignment of result is the alignment of the arg. Bounds // of the result can grow according to the reduction // factor. - if (bounds->min_defined) { - bounds->min *= factor; - } - if (bounds->max_defined) { - bounds->max *= factor; - } + info->bounds = cast(op->type, info->bounds * factor); break; case VectorReduce::SaturatingAdd: - if (bounds->min_defined) { - bounds->min = saturating_mul(bounds->min, factor); - } - if (bounds->max_defined) { - bounds->max = saturating_mul(bounds->max, factor); - } + info->bounds = saturating_cast(op->type, info->bounds * factor); break; case VectorReduce::Mul: // Don't try to infer anything about bounds. Leave the // alignment unchanged even though we could theoretically // upgrade it. - bounds->min_defined = bounds->max_defined = false; + info->bounds = ConstantInterval{}; break; case VectorReduce::Min: case VectorReduce::Max: @@ -104,8 +93,8 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { case VectorReduce::Or: // For integer types this is a bitwise operator. Don't try // to infer anything for now. - bounds->min_defined = bounds->max_defined = false; - bounds->alignment = ModulusRemainder{}; + info->bounds = ConstantInterval{}; + info->alignment = ModulusRemainder{}; break; } } @@ -134,7 +123,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) || rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -148,7 +137,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_min(broadcast(x, c0), lanes), h_min(x, lanes), factor % c0 == 0) || rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -162,7 +151,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) || rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -183,7 +172,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), x <= y + min(z * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -205,7 +194,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), x <= y + max(z * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -220,33 +209,35 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const Variable *op, ExprInfo *bounds) { +Expr Simplify::visit(const Variable *op, ExprInfo *info) { if (const ExprInfo *b = bounds_and_alignment_info.find(op->name)) { - if (bounds) { - *bounds = *b; + if (info) { + *info = *b; } - if (b->min_defined && b->max_defined && b->min == b->max) { - return make_const(op->type, b->min); + if (b->bounds.is_single_point()) { + return make_const(op->type, b->bounds.min); } + } else if (info && !no_overflow_int(op->type)) { + info->bounds = ConstantInterval::bounds_of_type(op->type); } - if (auto *info = var_info.shallow_find(op->name)) { + if (auto *v_info = var_info.shallow_find(op->name)) { // if replacement is defined, we should substitute it in (unless // it's a var that has been hidden by a nested scope). - if (info->replacement.defined()) { - internal_assert(info->replacement.type() == op->type) + if (v_info->replacement.defined()) { + internal_assert(v_info->replacement.type() == op->type) << "Cannot replace variable " << op->name << " of type " << op->type - << " with expression of type " << info->replacement.type() << "\n"; - info->new_uses++; + << " with expression of type " << v_info->replacement.type() << "\n"; + v_info->new_uses++; // We want to remutate the replacement, because we may be // injecting it into a context where it is known to be a // constant (e.g. due to an if). - return mutate(info->replacement, bounds); + return mutate(v_info->replacement, info); } else { // This expression was not something deemed // substitutable - no replacement is defined. - info->old_uses++; + v_info->old_uses++; return op; } } else { @@ -256,29 +247,28 @@ Expr Simplify::visit(const Variable *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const Ramp *op, ExprInfo *bounds) { - ExprInfo base_bounds, stride_bounds; - Expr base = mutate(op->base, &base_bounds); - Expr stride = mutate(op->stride, &stride_bounds); +Expr Simplify::visit(const Ramp *op, ExprInfo *info) { + ExprInfo base_info, stride_info; + Expr base = mutate(op->base, &base_info); + Expr stride = mutate(op->stride, &stride_info); const int lanes = op->lanes; - if (bounds && no_overflow_int(op->type)) { - bounds->min_defined = base_bounds.min_defined && stride_bounds.min_defined; - bounds->max_defined = base_bounds.max_defined && stride_bounds.max_defined; - bounds->min = std::min(base_bounds.min, base_bounds.min + (lanes - 1) * stride_bounds.min); - bounds->max = std::max(base_bounds.max, base_bounds.max + (lanes - 1) * stride_bounds.max); + if (info) { + info->bounds = base_info.bounds + stride_info.bounds * ConstantInterval(0, lanes - 1); // A ramp lane is b + l * s. Expanding b into mb * x + rb and s into ms * y + rs, we get: // mb * x + rb + l * (ms * y + rs) // = mb * x + ms * l * y + rs * l + rb // = gcd(rs, ms, mb) * z + rb - int64_t m = stride_bounds.alignment.modulus; - m = gcd(m, stride_bounds.alignment.remainder); - m = gcd(m, base_bounds.alignment.modulus); - int64_t r = base_bounds.alignment.remainder; + int64_t m = stride_info.alignment.modulus; + m = gcd(m, stride_info.alignment.remainder); + m = gcd(m, base_info.alignment.modulus); + int64_t r = base_info.alignment.remainder; if (m != 0) { - r = mod_imp(base_bounds.alignment.remainder, m); + r = mod_imp(base_info.alignment.remainder, m); } - bounds->alignment = {m, r}; + info->alignment = {m, r}; + info->trim_bounds_using_alignment(); + info->cast_to(op->type); } // A somewhat torturous way to check if the stride is zero, @@ -303,9 +293,13 @@ Expr Simplify::visit(const Ramp *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const Load *op, ExprInfo *bounds) { +Expr Simplify::visit(const Load *op, ExprInfo *info) { found_buffer_reference(op->name); + if (info) { + info->bounds = ConstantInterval::bounds_of_type(op->type); + } + Expr predicate = mutate(op->predicate, nullptr); ExprInfo index_info; @@ -319,17 +313,11 @@ Expr Simplify::visit(const Load *op, ExprInfo *bounds) { if (is_const_one(op->predicate)) { string alloc_extent_name = op->name + ".total_extent_bytes"; if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { - if (index_info.max_defined && index_info.max < 0) { + if (index_info.bounds < 0 || + index_info.bounds * op->type.bytes() > alloc_info->bounds) { in_unreachable = true; return unreachable(op->type); } - if (alloc_info->max_defined && index_info.min_defined) { - int index_min_bytes = index_info.min * op->type.bytes(); - if (index_min_bytes > alloc_info->max) { - in_unreachable = true; - return unreachable(op->type); - } - } } } diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 92f012926091..19666cc77294 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -7,7 +7,9 @@ * exported in Halide.h. */ #include "Bounds.h" +#include "ConstantInterval.h" #include "IRMatch.h" +#include "IRPrinter.h" #include "IRVisitor.h" #include "Scope.h" @@ -28,17 +30,6 @@ namespace Halide { namespace Internal { -inline int64_t saturating_mul(int64_t a, int64_t b) { - int64_t result; - if (mul_with_overflow(64, a, b, &result)) { - return result; - } else if ((a > 0) == (b > 0)) { - return INT64_MAX; - } else { - return INT64_MIN; - } -} - class Simplify : public VariadicVisitor { using Super = VariadicVisitor; @@ -47,80 +38,109 @@ class Simplify : public VariadicVisitor { struct ExprInfo { // We track constant integer bounds when they exist - // TODO: Use ConstantInterval? - int64_t min = 0, max = 0; - bool min_defined = false, max_defined = false; + ConstantInterval bounds; // And the alignment of integer variables ModulusRemainder alignment; void trim_bounds_using_alignment() { if (alignment.modulus == 0) { - min_defined = max_defined = true; - min = max = alignment.remainder; + bounds = ConstantInterval::single_point(alignment.remainder); } else if (alignment.modulus > 1) { - if (min_defined) { + if (bounds.min_defined) { int64_t adjustment; - bool no_overflow = sub_with_overflow(64, alignment.remainder, mod_imp(min, alignment.modulus), &adjustment); + bool no_overflow = sub_with_overflow(64, alignment.remainder, mod_imp(bounds.min, alignment.modulus), &adjustment); adjustment = mod_imp(adjustment, alignment.modulus); int64_t new_min; - no_overflow &= add_with_overflow(64, min, adjustment, &new_min); + no_overflow &= add_with_overflow(64, bounds.min, adjustment, &new_min); if (no_overflow) { - min = new_min; + bounds.min = new_min; } } - if (max_defined) { + if (bounds.max_defined) { int64_t adjustment; - bool no_overflow = sub_with_overflow(64, mod_imp(max, alignment.modulus), alignment.remainder, &adjustment); + bool no_overflow = sub_with_overflow(64, mod_imp(bounds.max, alignment.modulus), alignment.remainder, &adjustment); adjustment = mod_imp(adjustment, alignment.modulus); int64_t new_max; - no_overflow &= sub_with_overflow(64, max, adjustment, &new_max); + no_overflow &= sub_with_overflow(64, bounds.max, adjustment, &new_max); if (no_overflow) { - max = new_max; + bounds.max = new_max; } } } - if (min_defined && max_defined && min == max) { + if (bounds.is_single_point()) { alignment.modulus = 0; - alignment.remainder = min; + alignment.remainder = bounds.min; + } + + if (bounds.is_bounded() && bounds.min > bounds.max) { + // Impossible, we must be in unreachable code. TODO: surface + // this to the simplify instance's in_unreachable flag. + bounds.max = bounds.min; } } - // Mix in existing knowledge about this Expr - void intersect(const ExprInfo &other) { - if (min_defined && other.min_defined) { - min = std::max(min, other.min); - } else if (other.min_defined) { - min_defined = true; - min = other.min; + void cast_to(Type t) { + if ((!t.is_int() && !t.is_uint()) || (t.is_int() && t.bits() >= 32)) { + return; } - if (max_defined && other.max_defined) { - max = std::min(max, other.max); - } else if (other.max_defined) { - max_defined = true; - max = other.max; + // We've just done some infinite-integer operation on a bounded + // integer type, and we need to project the bounds and alignment + // back in-range. + + if (!t.can_represent(bounds)) { + if (t.bits() >= 64) { + // Just preserve any power-of-two factor in the modulus. When + // alignment.modulus == 0, the value is some positive constant + // representable as any 64-bit integer type, so there's no + // wraparound. + if (alignment.modulus > 0) { + // This masks off all bits except for the lowest set one, + // giving the largest power-of-two factor of a number. + alignment.modulus &= -alignment.modulus; + alignment.remainder = mod_imp(alignment.remainder, alignment.modulus); + } + } else { + // A narrowing integer cast that could possibly overflow adds + // some unknown multiple of 2^bits + alignment = alignment + ModulusRemainder(((int64_t)1 << t.bits()), 0); + } } - alignment = ModulusRemainder::intersect(alignment, other.alignment); + // Truncate the bounds to the new type. + bounds.cast_to(t); + } + // Mix in existing knowledge about this Expr + void intersect(const ExprInfo &other) { + if (bounds < other.bounds || other.bounds < bounds) { + // Impossible. We must be in unreachable code. TODO: It might + // be nice to surface this to the simplify instance's + // in_unreachable flag, but we'd have to be sure that it's going + // to be caught at the right place. + return; + } + bounds = ConstantInterval::make_intersection(bounds, other.bounds); + alignment = ModulusRemainder::intersect(alignment, other.alignment); trim_bounds_using_alignment(); } }; HALIDE_ALWAYS_INLINE - void clear_bounds_info(ExprInfo *b) { + void clear_expr_info(ExprInfo *b) { if (b) { *b = ExprInfo{}; } } #if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS) - static int debug_indent; + int debug_indent = 0; #endif #if LOG_EXPR_MUTATIONS Expr mutate(const Expr &e, ExprInfo *b) { + internal_assert(debug_indent >= 0); const std::string spaces(debug_indent, ' '); debug(1) << spaces << "Simplifying Expr: " << e << "\n"; debug_indent++; @@ -130,6 +150,19 @@ class Simplify : public VariadicVisitor { debug(1) << spaces << "Before: " << e << "\n" << spaces << "After: " << new_e << "\n"; + if (b) { + debug(1) + << spaces << "Bounds: " << b->bounds << " " << b->alignment << "\n"; + if (const int64_t *i = as_const_int(new_e)) { + internal_assert(b->bounds.contains(*i)) << e << "\n" + << new_e << "\n" + << b->bounds; + } else if (const uint64_t *i = as_const_uint(new_e)) { + internal_assert(b->bounds.contains(*i)) << e << "\n" + << new_e << "\n" + << b->bounds; + } + } } internal_assert(e.type() == new_e.type()); return new_e; @@ -298,45 +331,45 @@ class Simplify : public VariadicVisitor { Stmt mutate_let_body(const Stmt &s, ExprInfo *) { return mutate(s); } - Expr mutate_let_body(const Expr &e, ExprInfo *bounds) { - return mutate(e, bounds); + Expr mutate_let_body(const Expr &e, ExprInfo *info) { + return mutate(e, info); } template - Body simplify_let(const T *op, ExprInfo *bounds); - - Expr visit(const IntImm *op, ExprInfo *bounds); - Expr visit(const UIntImm *op, ExprInfo *bounds); - Expr visit(const FloatImm *op, ExprInfo *bounds); - Expr visit(const StringImm *op, ExprInfo *bounds); - Expr visit(const Broadcast *op, ExprInfo *bounds); - Expr visit(const Cast *op, ExprInfo *bounds); - Expr visit(const Reinterpret *op, ExprInfo *bounds); - Expr visit(const Variable *op, ExprInfo *bounds); - Expr visit(const Add *op, ExprInfo *bounds); - Expr visit(const Sub *op, ExprInfo *bounds); - Expr visit(const Mul *op, ExprInfo *bounds); - Expr visit(const Div *op, ExprInfo *bounds); - Expr visit(const Mod *op, ExprInfo *bounds); - Expr visit(const Min *op, ExprInfo *bounds); - Expr visit(const Max *op, ExprInfo *bounds); - Expr visit(const EQ *op, ExprInfo *bounds); - Expr visit(const NE *op, ExprInfo *bounds); - Expr visit(const LT *op, ExprInfo *bounds); - Expr visit(const LE *op, ExprInfo *bounds); - Expr visit(const GT *op, ExprInfo *bounds); - Expr visit(const GE *op, ExprInfo *bounds); - Expr visit(const And *op, ExprInfo *bounds); - Expr visit(const Or *op, ExprInfo *bounds); - Expr visit(const Not *op, ExprInfo *bounds); - Expr visit(const Select *op, ExprInfo *bounds); - Expr visit(const Ramp *op, ExprInfo *bounds); + Body simplify_let(const T *op, ExprInfo *info); + + Expr visit(const IntImm *op, ExprInfo *info); + Expr visit(const UIntImm *op, ExprInfo *info); + Expr visit(const FloatImm *op, ExprInfo *info); + Expr visit(const StringImm *op, ExprInfo *info); + Expr visit(const Broadcast *op, ExprInfo *info); + Expr visit(const Cast *op, ExprInfo *info); + Expr visit(const Reinterpret *op, ExprInfo *info); + Expr visit(const Variable *op, ExprInfo *info); + Expr visit(const Add *op, ExprInfo *info); + Expr visit(const Sub *op, ExprInfo *info); + Expr visit(const Mul *op, ExprInfo *info); + Expr visit(const Div *op, ExprInfo *info); + Expr visit(const Mod *op, ExprInfo *info); + Expr visit(const Min *op, ExprInfo *info); + Expr visit(const Max *op, ExprInfo *info); + Expr visit(const EQ *op, ExprInfo *info); + Expr visit(const NE *op, ExprInfo *info); + Expr visit(const LT *op, ExprInfo *info); + Expr visit(const LE *op, ExprInfo *info); + Expr visit(const GT *op, ExprInfo *info); + Expr visit(const GE *op, ExprInfo *info); + Expr visit(const And *op, ExprInfo *info); + Expr visit(const Or *op, ExprInfo *info); + Expr visit(const Not *op, ExprInfo *info); + Expr visit(const Select *op, ExprInfo *info); + Expr visit(const Ramp *op, ExprInfo *info); Stmt visit(const IfThenElse *op); - Expr visit(const Load *op, ExprInfo *bounds); - Expr visit(const Call *op, ExprInfo *bounds); - Expr visit(const Shuffle *op, ExprInfo *bounds); - Expr visit(const VectorReduce *op, ExprInfo *bounds); - Expr visit(const Let *op, ExprInfo *bounds); + Expr visit(const Load *op, ExprInfo *info); + Expr visit(const Call *op, ExprInfo *info); + Expr visit(const Shuffle *op, ExprInfo *info); + Expr visit(const VectorReduce *op, ExprInfo *info); + Expr visit(const Let *op, ExprInfo *info); Stmt visit(const LetStmt *op); Stmt visit(const AssertStmt *op); Stmt visit(const For *op); @@ -354,7 +387,7 @@ class Simplify : public VariadicVisitor { Stmt visit(const Atomic *op); Stmt visit(const HoistedStorage *op); - std::pair, bool> mutate_with_changes(const std::vector &old_exprs, ExprInfo *bounds); + std::pair, bool> mutate_with_changes(const std::vector &old_exprs); }; } // namespace Internal diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 58c6d4d27ab3..c9ac45c349d7 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ -3,10 +3,10 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const LT *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); +Expr Simplify::visit(const LT *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); const int lanes = op->type.lanes(); Type ty = a.type(); @@ -20,11 +20,9 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { if (may_simplify(ty)) { // Prove or disprove using bounds analysis - if (a_bounds.max_defined && b_bounds.min_defined && a_bounds.max < b_bounds.min) { + if (a_info.bounds < b_info.bounds) { return const_true(lanes); - } - - if (a_bounds.min_defined && b_bounds.max_defined && a_bounds.min >= b_bounds.max) { + } else if (a_info.bounds >= b_info.bounds) { return const_false(lanes); } @@ -499,7 +497,7 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { c1 * (lanes - 1) < c0 && c1 * (lanes - 1) >= 0) ))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } @@ -512,7 +510,7 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { } // The other comparison operators redirect to the less-than operator -Expr Simplify::visit(const LE *op, ExprInfo *bounds) { +Expr Simplify::visit(const LE *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -523,7 +521,7 @@ Expr Simplify::visit(const LE *op, ExprInfo *bounds) { } } - Expr mutated = mutate(!(op->b < op->a), bounds); + Expr mutated = mutate(!(op->b < op->a), info); if (const LE *le = mutated.as()) { if (le->a.same_as(op->a) && le->b.same_as(op->b)) { return op; @@ -532,7 +530,7 @@ Expr Simplify::visit(const LE *op, ExprInfo *bounds) { return mutated; } -Expr Simplify::visit(const GT *op, ExprInfo *bounds) { +Expr Simplify::visit(const GT *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -543,10 +541,10 @@ Expr Simplify::visit(const GT *op, ExprInfo *bounds) { } } - return mutate(op->b < op->a, bounds); + return mutate(op->b < op->a, info); } -Expr Simplify::visit(const GE *op, ExprInfo *bounds) { +Expr Simplify::visit(const GE *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -557,7 +555,7 @@ Expr Simplify::visit(const GE *op, ExprInfo *bounds) { } } - return mutate(!(op->a < op->b), bounds); + return mutate(!(op->a < op->b), info); } } // namespace Internal diff --git a/src/Simplify_Let.cpp b/src/Simplify_Let.cpp index 342281fa6639..13fbd575d75f 100644 --- a/src/Simplify_Let.cpp +++ b/src/Simplify_Let.cpp @@ -61,7 +61,7 @@ void find_var_uses(StmtOrExpr x, std::unordered_set &unused_vars) { } // namespace template -Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { +Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { // Lets are often deeply nested. Get the intermediate state off // the call stack where it could overflow onto an explicit stack. @@ -89,8 +89,8 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { // If the value is trivial, make a note of it in the scope so // we can subs it in later - ExprInfo value_bounds; - f.value = mutate(op->value, &value_bounds); + ExprInfo value_info; + f.value = mutate(op->value, &value_info); // Iteratively peel off certain operations from the let value and push them inside. f.new_value = f.value; @@ -222,21 +222,24 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { var_info.push(op->name, info); // Before we enter the body, track the alignment info - if (f.new_value.defined() && no_overflow_scalar_int(f.new_value.type())) { // Remutate new_value to get updated bounds - ExprInfo new_value_bounds; - f.new_value = mutate(f.new_value, &new_value_bounds); - if (new_value_bounds.min_defined || new_value_bounds.max_defined || new_value_bounds.alignment.modulus != 1) { + ExprInfo new_value_info; + f.new_value = mutate(f.new_value, &new_value_info); + if (new_value_info.bounds.min_defined || + new_value_info.bounds.max_defined || + new_value_info.alignment.modulus != 1) { // There is some useful information - bounds_and_alignment_info.push(f.new_name, new_value_bounds); + bounds_and_alignment_info.push(f.new_name, new_value_info); f.new_value_bounds_tracked = true; } } if (no_overflow_scalar_int(f.value.type())) { - if (value_bounds.min_defined || value_bounds.max_defined || value_bounds.alignment.modulus != 1) { - bounds_and_alignment_info.push(op->name, value_bounds); + if (value_info.bounds.min_defined || + value_info.bounds.max_defined || + value_info.alignment.modulus != 1) { + bounds_and_alignment_info.push(op->name, value_info); f.value_bounds_tracked = true; } } @@ -245,7 +248,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { op = result.template as(); } - result = mutate_let_body(result, bounds); + result = mutate_let_body(result, info); // TODO: var_info and unused_vars are pretty redundant; however, at the time // of writing, both cover cases that the other does not: @@ -310,8 +313,8 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { return result; } -Expr Simplify::visit(const Let *op, ExprInfo *bounds) { - return simplify_let(op, bounds); +Expr Simplify::visit(const Let *op, ExprInfo *info) { + return simplify_let(op, info); } Stmt Simplify::visit(const LetStmt *op) { diff --git a/src/Simplify_Max.cpp b/src/Simplify_Max.cpp index 1a79aef962fa..6f3ecc1999f7 100644 --- a/src/Simplify_Max.cpp +++ b/src/Simplify_Max.cpp @@ -3,44 +3,33 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Max *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds) { - bounds->min_defined = a_bounds.min_defined || b_bounds.min_defined; - bounds->max_defined = a_bounds.max_defined && b_bounds.max_defined; - bounds->max = std::max(a_bounds.max, b_bounds.max); - if (a_bounds.min_defined && b_bounds.min_defined) { - bounds->min = std::max(a_bounds.min, b_bounds.min); - } else if (a_bounds.min_defined) { - bounds->min = a_bounds.min; - } else { - bounds->min = b_bounds.min; - } - bounds->alignment = ModulusRemainder::unify(a_bounds.alignment, b_bounds.alignment); - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Max *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info) { + info->bounds = max(a_info.bounds, b_info.bounds); + info->alignment = ModulusRemainder::unify(a_info.alignment, b_info.alignment); + info->trim_bounds_using_alignment(); } - // Early out when the bounds tells us one side or the other is smaller - if (a_bounds.max_defined && b_bounds.min_defined && a_bounds.max <= b_bounds.min) { - if (const Call *call = b.as()) { + auto strip_likely = [](const Expr &e) { + if (const Call *call = e.as()) { if (call->is_intrinsic(Call::likely) || call->is_intrinsic(Call::likely_if_innermost)) { return call->args[0]; } } - return b; + return e; + }; + + // Early out when the bounds tells us one side or the other is smaller + if (a_info.bounds <= b_info.bounds) { + return strip_likely(b); } - if (b_bounds.max_defined && a_bounds.min_defined && b_bounds.max <= a_bounds.min) { - if (const Call *call = a.as()) { - if (call->is_intrinsic(Call::likely) || - call->is_intrinsic(Call::likely_if_innermost)) { - return call->args[0]; - } - } - return a; + if (b_info.bounds <= a_info.bounds) { + return strip_likely(a); } if (may_simplify(op->type)) { @@ -48,7 +37,7 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } int lanes = op->type.lanes(); @@ -301,7 +290,7 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { rewrite(max(c0 - x, c1), c0 - min(x, fold(c0 - c1))))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Min.cpp b/src/Simplify_Min.cpp index 214ed09374d3..41e455174351 100644 --- a/src/Simplify_Min.cpp +++ b/src/Simplify_Min.cpp @@ -3,44 +3,34 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Min *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds) { - bounds->min_defined = a_bounds.min_defined && b_bounds.min_defined; - bounds->max_defined = a_bounds.max_defined || b_bounds.max_defined; - bounds->min = std::min(a_bounds.min, b_bounds.min); - if (a_bounds.max_defined && b_bounds.max_defined) { - bounds->max = std::min(a_bounds.max, b_bounds.max); - } else if (a_bounds.max_defined) { - bounds->max = a_bounds.max; - } else { - bounds->max = b_bounds.max; - } - bounds->alignment = ModulusRemainder::unify(a_bounds.alignment, b_bounds.alignment); - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Min *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info) { + info->bounds = min(a_info.bounds, b_info.bounds); + info->alignment = ModulusRemainder::unify(a_info.alignment, b_info.alignment); + info->trim_bounds_using_alignment(); } // Early out when the bounds tells us one side or the other is smaller - if (a_bounds.max_defined && b_bounds.min_defined && a_bounds.max <= b_bounds.min) { - if (const Call *call = a.as()) { + auto strip_likely = [](const Expr &e) { + if (const Call *call = e.as()) { if (call->is_intrinsic(Call::likely) || call->is_intrinsic(Call::likely_if_innermost)) { return call->args[0]; } } - return a; + return e; + }; + + // Early out when the bounds tells us one side or the other is smaller + if (a_info.bounds >= b_info.bounds) { + return strip_likely(b); } - if (b_bounds.max_defined && a_bounds.min_defined && b_bounds.max <= a_bounds.min) { - if (const Call *call = b.as()) { - if (call->is_intrinsic(Call::likely) || - call->is_intrinsic(Call::likely_if_innermost)) { - return call->args[0]; - } - } - return b; + if (b_info.bounds >= a_info.bounds) { + return strip_likely(a); } if (may_simplify(op->type)) { @@ -48,7 +38,7 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } int lanes = op->type.lanes(); @@ -312,7 +302,7 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { false )))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Mod.cpp b/src/Simplify_Mod.cpp index fcd4021b759f..dbfcfcb14f81 100644 --- a/src/Simplify_Mod.cpp +++ b/src/Simplify_Mod.cpp @@ -3,60 +3,33 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Mod *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); +Expr Simplify::visit(const Mod *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); // We always combine bounds here, even if not requested, because // we can use them to simplify down to a constant if the bounds // are tight enough. - ExprInfo mod_bounds; - - if (no_overflow_int(op->type)) { - // The result is at least zero. - mod_bounds.min_defined = true; - mod_bounds.min = 0; - - // Mod by produces a result between 0 - // and max(0, abs(modulus) - 1). However, if b is unbounded in - // either direction, abs(modulus) could be arbitrarily - // large. - if (b_bounds.max_defined && b_bounds.min_defined) { - mod_bounds.max_defined = true; - mod_bounds.max = 0; // When b == 0 - mod_bounds.max = std::max(mod_bounds.max, b_bounds.max - 1); // When b > 0 - mod_bounds.max = std::max(mod_bounds.max, -1 - b_bounds.min); // When b < 0 - } - - // If a is positive, mod can't make it larger - if (a_bounds.min_defined && a_bounds.min >= 0 && a_bounds.max_defined) { - if (mod_bounds.max_defined) { - mod_bounds.max = std::min(mod_bounds.max, a_bounds.max); - } else { - mod_bounds.max_defined = true; - mod_bounds.max = a_bounds.max; - } - } - - mod_bounds.alignment = a_bounds.alignment % b_bounds.alignment; - mod_bounds.trim_bounds_using_alignment(); - if (bounds) { - *bounds = mod_bounds; - } + ExprInfo mod_info; + if (op->type.is_int_or_uint()) { + mod_info.bounds = a_info.bounds % b_info.bounds; + mod_info.alignment = a_info.alignment % b_info.alignment; + mod_info.trim_bounds_using_alignment(); + // Modulo can't overflow, so no mod_info.cast_to(op->type) + } + // TODO: Modulo bounds for floating-point modulo + if (info) { + *info = mod_info; } if (may_simplify(op->type)) { - if (a_bounds.min_defined && a_bounds.min >= 0 && - a_bounds.max_defined && b_bounds.min_defined && a_bounds.max < b_bounds.min) { - if (bounds) { - *bounds = a_bounds; - } + if (a_info.bounds >= 0 && a_info.bounds < b_info.bounds) { return a; } - if (mod_bounds.min_defined && mod_bounds.max_defined && mod_bounds.min == mod_bounds.max) { - return make_const(op->type, mod_bounds.min); + if (mod_info.bounds.is_single_point()) { + return make_const(op->type, mod_info.bounds.min); } int lanes = op->type.lanes(); @@ -94,7 +67,7 @@ Expr Simplify::visit(const Mod *op, ExprInfo *bounds) { rewrite(ramp(x + c0, c2, c3) % broadcast(c1, c3), ramp(x + fold(c0 % c1), fold(c2 % c1), c3) % c1, c1 > 0 && (c0 >= c1 || c0 < 0)) || rewrite(ramp(x * c0 + y, c2, c3) % broadcast(c1, c3), ramp(y, fold(c2 % c1), c3) % c1, c0 % c1 == 0) || rewrite(ramp(y + x * c0, c2, c3) % broadcast(c1, c3), ramp(y, fold(c2 % c1), c3) % c1, c0 % c1 == 0))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Mul.cpp b/src/Simplify_Mul.cpp index 881d09112f7d..446f420c6c91 100644 --- a/src/Simplify_Mul.cpp +++ b/src/Simplify_Mul.cpp @@ -3,49 +3,16 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds && no_overflow_int(op->type)) { - bool a_positive = a_bounds.min_defined && a_bounds.min > 0; - bool b_positive = b_bounds.min_defined && b_bounds.min > 0; - bool a_bounded = a_bounds.min_defined && a_bounds.max_defined; - bool b_bounded = b_bounds.min_defined && b_bounds.max_defined; - - if (a_bounded && b_bounded) { - bounds->min_defined = bounds->max_defined = true; - int64_t v1 = saturating_mul(a_bounds.min, b_bounds.min); - int64_t v2 = saturating_mul(a_bounds.min, b_bounds.max); - int64_t v3 = saturating_mul(a_bounds.max, b_bounds.min); - int64_t v4 = saturating_mul(a_bounds.max, b_bounds.max); - bounds->min = std::min(std::min(v1, v2), std::min(v3, v4)); - bounds->max = std::max(std::max(v1, v2), std::max(v3, v4)); - } else if ((a_bounds.max_defined && b_bounded && b_positive) || - (b_bounds.max_defined && a_bounded && a_positive)) { - bounds->max_defined = true; - bounds->max = saturating_mul(a_bounds.max, b_bounds.max); - } else if ((a_bounds.min_defined && b_bounded && b_positive) || - (b_bounds.min_defined && a_bounded && a_positive)) { - bounds->min_defined = true; - bounds->min = saturating_mul(a_bounds.min, b_bounds.min); - } - - if (bounds->max_defined && bounds->max == INT64_MAX) { - // Assume it saturated to avoid overflow. This gives up a - // single representable value at the top end of the range - // to represent infinity. - bounds->max_defined = false; - bounds->max = 0; - } - if (bounds->min_defined && bounds->min == INT64_MIN) { - bounds->min_defined = false; - bounds->min = 0; - } - - bounds->alignment = a_bounds.alignment * b_bounds.alignment; - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Mul *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info) { + info->bounds = a_info.bounds * b_info.bounds; + info->alignment = a_info.alignment * b_info.alignment; + info->trim_bounds_using_alignment(); + info->cast_to(op->type); } if (may_simplify(op->type)) { @@ -53,7 +20,7 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } auto rewrite = IRMatcher::rewriter(IRMatcher::mul(a, b), op->type); @@ -103,7 +70,7 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { rewrite(slice(x, c0, c1, c2) * (z * slice(y, c0, c1, c2)), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } } diff --git a/src/Simplify_Not.cpp b/src/Simplify_Not.cpp index 70b4b234ddef..47b74661fd2c 100644 --- a/src/Simplify_Not.cpp +++ b/src/Simplify_Not.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Not *op, ExprInfo *bounds) { +Expr Simplify::visit(const Not *op, ExprInfo *info) { Expr a = mutate(op->a, nullptr); auto rewrite = IRMatcher::rewriter(IRMatcher::not_op(a), op->type); @@ -25,7 +25,7 @@ Expr Simplify::visit(const Not *op, ExprInfo *bounds) { rewrite(!(x && !y), !x || y) || rewrite(!(x || !y), !x && y) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (a.same_as(op->a)) { diff --git a/src/Simplify_Or.cpp b/src/Simplify_Or.cpp index 274d66435ffb..083af6d5bc88 100644 --- a/src/Simplify_Or.cpp +++ b/src/Simplify_Or.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Or *op, ExprInfo *bounds) { +Expr Simplify::visit(const Or *op, ExprInfo *info) { if (truths.count(op)) { return const_true(op->type.lanes()); } @@ -101,7 +101,7 @@ Expr Simplify::visit(const Or *op, ExprInfo *bounds) { rewrite(x <= y || x <= z, x <= max(y, z)) || rewrite(y <= x || z <= x, min(y, z) <= x)) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (a.same_as(op->a) && diff --git a/src/Simplify_Reinterpret.cpp b/src/Simplify_Reinterpret.cpp index d5a8c1361fbe..51289aac9b87 100644 --- a/src/Simplify_Reinterpret.cpp +++ b/src/Simplify_Reinterpret.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Reinterpret *op, ExprInfo *bounds) { +Expr Simplify::visit(const Reinterpret *op, ExprInfo *info) { Expr a = mutate(op->value, nullptr); int64_t ia; @@ -19,7 +19,7 @@ Expr Simplify::visit(const Reinterpret *op, ExprInfo *bounds) { return make_const(op->type, (int64_t)ua); } else if (const Reinterpret *as_r = a.as()) { // Fold double-reinterprets. - return mutate(reinterpret(op->type, as_r->value), bounds); + return mutate(reinterpret(op->type, as_r->value), info); } else if ((op->type.bits() == a.type().bits()) && op->type.is_int_or_uint() && a.type().is_int_or_uint()) { diff --git a/src/Simplify_Select.cpp b/src/Simplify_Select.cpp index 0233be61724d..63be8d64718e 100644 --- a/src/Simplify_Select.cpp +++ b/src/Simplify_Select.cpp @@ -3,20 +3,17 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Select *op, ExprInfo *bounds) { +Expr Simplify::visit(const Select *op, ExprInfo *info) { - ExprInfo t_bounds, f_bounds; + ExprInfo t_info, f_info; Expr condition = mutate(op->condition, nullptr); - Expr true_value = mutate(op->true_value, &t_bounds); - Expr false_value = mutate(op->false_value, &f_bounds); - - if (bounds) { - bounds->min_defined = t_bounds.min_defined && f_bounds.min_defined; - bounds->max_defined = t_bounds.max_defined && f_bounds.max_defined; - bounds->min = std::min(t_bounds.min, f_bounds.min); - bounds->max = std::max(t_bounds.max, f_bounds.max); - bounds->alignment = ModulusRemainder::unify(t_bounds.alignment, f_bounds.alignment); - bounds->trim_bounds_using_alignment(); + Expr true_value = mutate(op->true_value, &t_info); + Expr false_value = mutate(op->false_value, &f_info); + + if (info) { + info->bounds = ConstantInterval::make_union(t_info.bounds, f_info.bounds); + info->alignment = ModulusRemainder::unify(t_info.alignment, f_info.alignment); + info->trim_bounds_using_alignment(); } if (may_simplify(op->type)) { @@ -230,7 +227,7 @@ Expr Simplify::visit(const Select *op, ExprInfo *bounds) { rewrite(select(x, y, true), !x || y) || rewrite(select(x, false, y), !x && y) || rewrite(select(x, true, y), x || y))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 7da4f6699ab7..348289ab0c83 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -7,7 +7,7 @@ namespace Internal { using std::vector; -Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { +Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { if (op->is_extract_element()) { int index = op->indices[0]; internal_assert(index >= 0); @@ -18,7 +18,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { // the same shuffle back. break; } else { - return extract_lane(mutate(vector, bounds), index); + return extract_lane(mutate(vector, info), index); } } index -= vector.type().lanes(); @@ -29,20 +29,17 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { vector new_vectors; bool changed = false; for (const Expr &vector : op->vectors) { - ExprInfo v_bounds; - Expr new_vector = mutate(vector, &v_bounds); + ExprInfo v_info; + Expr new_vector = mutate(vector, &v_info); if (!vector.same_as(new_vector)) { changed = true; } - if (bounds) { + if (info) { if (new_vectors.empty()) { - *bounds = v_bounds; + *info = v_info; } else { - bounds->min_defined &= v_bounds.min_defined; - bounds->max_defined &= v_bounds.max_defined; - bounds->min = std::min(bounds->min, v_bounds.min); - bounds->max = std::max(bounds->max, v_bounds.max); - bounds->alignment = ModulusRemainder::unify(bounds->alignment, v_bounds.alignment); + info->bounds = ConstantInterval::make_union(info->bounds, v_info.bounds); + info->alignment = ModulusRemainder::unify(info->alignment, v_info.alignment); } } new_vectors.push_back(new_vector); @@ -141,7 +138,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { } } if (can_collapse) { - return mutate(Ramp::make(r->base, r->stride / terms, r->lanes * terms), bounds); + return mutate(Ramp::make(r->base, r->stride / terms, r->lanes * terms), info); } } @@ -272,7 +269,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { if (cast->type.bits() > cast->value.type().bits()) { return mutate(Cast::make(cast->type.with_lanes(op->type.lanes()), Shuffle::make({cast->value}, op->indices)), - bounds); + info); } } } diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index f6cb81345961..3645ebbf4369 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -203,12 +203,12 @@ Stmt Simplify::visit(const AssertStmt *op) { } Stmt Simplify::visit(const For *op) { - ExprInfo min_bounds, extent_bounds; - Expr new_min = mutate(op->min, &min_bounds); + ExprInfo min_info, extent_info; + Expr new_min = mutate(op->min, &min_info); if (in_unreachable) { return Evaluate::make(new_min); } - Expr new_extent = mutate(op->extent, &extent_bounds); + Expr new_extent = mutate(op->extent, &extent_info); if (in_unreachable) { return Evaluate::make(new_extent); } @@ -217,34 +217,43 @@ Stmt Simplify::visit(const For *op) { (in_vector_loop || op->for_type == ForType::Vectorized)); - bool bounds_tracked = false; - if (min_bounds.min_defined || (min_bounds.max_defined && extent_bounds.max_defined)) { - min_bounds.max += extent_bounds.max - 1; - min_bounds.max_defined &= extent_bounds.max_defined; - min_bounds.alignment = ModulusRemainder{}; - bounds_tracked = true; - bounds_and_alignment_info.push(op->name, min_bounds); + Expr extent_positive = mutate(0 < new_extent, nullptr); + if (is_const_zero(extent_positive)) { + // This loop never runs + return Evaluate::make(0); } + ExprInfo loop_var_info; + // Deduce bounds for the loop var that are true for any code than runs + // inside the loop body. Code in the inner loop only runs if the extent is + // at least one, so we can throw a max around the extent bounds. + + loop_var_info.bounds = + ConstantInterval::make_union(min_info.bounds, + min_info.bounds + max(extent_info.bounds, 1) - 1); Stmt new_body; { + ScopedBinding bind_if((loop_var_info.bounds.max_defined || + loop_var_info.bounds.min_defined), + bounds_and_alignment_info, + op->name, + loop_var_info); + // If we're in the loop, the extent must be greater than 0. - ScopedFact fact = scoped_truth(0 < new_extent); + ScopedFact fact = scoped_truth(extent_positive); new_body = mutate(op->body); } + if (in_unreachable) { - if (extent_bounds.min_defined && extent_bounds.min >= 1) { - // If we know the loop executes once, the code that runs this loop is unreachable. - return new_body; - } - in_unreachable = false; + // We found that the body of this loop is unreachable when recursively + // mutating it, so we can remove the loop. Additionally, if we know the + // extent is greater than zero, then the code *outside* the loop must be + // unreachable too, because if it weren't, it'd run the unreachable body + // at least once. + in_unreachable = extent_info.bounds > 0; return Evaluate::make(0); } - if (bounds_tracked) { - bounds_and_alignment_info.pop(op->name); - } - if (const Acquire *acquire = new_body.as()) { if (is_no_op(acquire->body)) { // Rewrite iterated no-op acquires as a single acquire. @@ -254,14 +263,14 @@ Stmt Simplify::visit(const For *op) { if (is_no_op(new_body)) { return new_body; - } else if (extent_bounds.max_defined && - extent_bounds.max <= 0) { + } else if (extent_info.bounds <= 0) { return Evaluate::make(0); - } else if (extent_bounds.max_defined && - extent_bounds.max <= 1 && + } else if (extent_info.bounds <= 1 && op->device_api == DeviceAPI::None) { + // Loop body runs at most once Stmt s = LetStmt::make(op->name, new_min, new_body); - if (extent_bounds.min < 1) { + if (extent_info.bounds.contains(0)) { + // Loop body might not run at all s = IfThenElse::make(0 < new_extent, s); } return mutate(s); @@ -280,8 +289,8 @@ Stmt Simplify::visit(const Provide *op) { found_buffer_reference(op->name, op->args.size()); // Mutate the args - auto [new_args, changed_args] = mutate_with_changes(op->args, nullptr); - auto [new_values, changed_values] = mutate_with_changes(op->values, nullptr); + auto [new_args, changed_args] = mutate_with_changes(op->args); + auto [new_values, changed_values] = mutate_with_changes(op->values); Expr new_predicate = mutate(op->predicate, nullptr); if (!(changed_args || changed_values) && new_predicate.same_as(op->predicate)) { @@ -307,17 +316,11 @@ Stmt Simplify::visit(const Store *op) { string alloc_extent_name = op->name + ".total_extent_bytes"; if (is_const_one(op->predicate)) { if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { - if (index_info.max_defined && index_info.max < 0) { + if (index_info.bounds < 0 || + index_info.bounds * op->value.type().bytes() > alloc_info->bounds) { in_unreachable = true; return Evaluate::make(unreachable()); } - if (alloc_info->max_defined && index_info.min_defined) { - int index_min_bytes = index_info.min * op->value.type().bytes(); - if (index_min_bytes > alloc_info->max) { - in_unreachable = true; - return Evaluate::make(unreachable()); - } - } } } @@ -356,33 +359,14 @@ Stmt Simplify::visit(const Allocate *op) { std::vector new_extents; bool all_extents_unmodified = true; ExprInfo total_extent_info; - total_extent_info.min_defined = true; - total_extent_info.max_defined = true; - total_extent_info.min = 1; - total_extent_info.max = 1; + total_extent_info.bounds = ConstantInterval::single_point(op->type.bytes()); for (size_t i = 0; i < op->extents.size(); i++) { ExprInfo extent_info; new_extents.push_back(mutate(op->extents[i], &extent_info)); all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); - if (extent_info.min_defined) { - total_extent_info.min *= extent_info.min; - } else { - total_extent_info.min_defined = false; - } - if (extent_info.max_defined) { - total_extent_info.max *= extent_info.max; - } else { - total_extent_info.max_defined = false; - } - } - if (total_extent_info.min_defined) { - total_extent_info.min *= op->type.bytes(); - total_extent_info.min -= 1; - } - if (total_extent_info.max_defined) { - total_extent_info.max *= op->type.bytes(); - total_extent_info.max -= 1; + total_extent_info.bounds *= extent_info.bounds; } + total_extent_info.bounds -= 1; ScopedBinding b(bounds_and_alignment_info, op->name + ".total_extent_bytes", total_extent_info); diff --git a/src/Simplify_Sub.cpp b/src/Simplify_Sub.cpp index f3a06ca28949..cf21205f13d1 100644 --- a/src/Simplify_Sub.cpp +++ b/src/Simplify_Sub.cpp @@ -3,23 +3,19 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Sub *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); +Expr Simplify::visit(const Sub *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); - if (bounds && no_overflow_int(op->type)) { + if (info) { // Doesn't account for correlated a, b, so any // cancellation rule that exploits that should always // remutate to recalculate the bounds. - bounds->min_defined = a_bounds.min_defined && - b_bounds.max_defined && - sub_with_overflow(64, a_bounds.min, b_bounds.max, &(bounds->min)); - bounds->max_defined = a_bounds.max_defined && - b_bounds.min_defined && - sub_with_overflow(64, a_bounds.max, b_bounds.min, &(bounds->max)); - bounds->alignment = a_bounds.alignment - b_bounds.alignment; - bounds->trim_bounds_using_alignment(); + info->bounds = a_info.bounds - b_info.bounds; + info->alignment = a_info.alignment - b_info.alignment; + info->trim_bounds_using_alignment(); + info->cast_to(op->type); } if (may_simplify(op->type)) { @@ -446,7 +442,7 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) { rewrite((min(z, x*c0 + y) + w) / c1 - x*c2, (min(z - x*c0, y) + w) / c1, c0 == c1 * c2) || false)))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } } // clang-format on diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index ae4a6776ac72..4246ba807220 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -122,6 +122,7 @@ tests(GROUPS correctness fused_where_inner_extent_is_zero.cpp fuzz_float_stores.cpp fuzz_schedule.cpp + fuzz_simplify.cpp gameoflife.cpp gather.cpp gpu_allocation_cache.cpp diff --git a/test/correctness/fuse.cpp b/test/correctness/fuse.cpp index 87ebcba3dbc4..d644e6fb741e 100644 --- a/test/correctness/fuse.cpp +++ b/test/correctness/fuse.cpp @@ -72,7 +72,7 @@ int main(int argc, char **argv) { Var xy("xy"); f.compute_root() .fuse(x, y, xy) - .vectorize(xy, 16); + .vectorize(xy, 16, TailStrategy::RoundUp); f.add_custom_lowering_pass(new CheckForMod); f.compile_jit(); diff --git a/test/fuzz/simplify.cpp b/test/correctness/fuzz_simplify.cpp similarity index 53% rename from test/fuzz/simplify.cpp rename to test/correctness/fuzz_simplify.cpp index ae4f8ee46a11..fd09316a5887 100644 --- a/test/fuzz/simplify.cpp +++ b/test/correctness/fuzz_simplify.cpp @@ -1,8 +1,6 @@ #include "Halide.h" -#include "fuzz_helpers.h" #include #include -#include #include #include #include @@ -25,21 +23,35 @@ std::string fuzz_var(int i) { return std::string(1, 'a' + i); } -Expr random_var(FuzzedDataProvider &fdp) { - int fuzz_count = fdp.ConsumeIntegralInRange(0, fuzz_var_count - 1); - return Variable::make(Int(0), fuzz_var(fuzz_count)); +Expr random_var(std::mt19937 &rng, Type t) { + int fuzz_count = rng() % (fuzz_var_count - 1); + return cast(t, Variable::make(Int(32), fuzz_var(fuzz_count))); } -Type random_type(FuzzedDataProvider &fdp, int width) { - Type t = fdp.PickValueInArray(fuzz_types); +template +T random_choice(std::mt19937 &rng, const T (&choices)[N]) { + return choices[rng() % N]; +} + +template +T random_choice(std::mt19937 &rng, const std::vector &choices) { + return choices[rng() % choices.size()]; +} + +template +T random_choice(std::mt19937 &rng, const std::array &choices) { + return choices[rng() % N]; +} +Type random_type(std::mt19937 &rng, int width) { + Type t = random_choice(rng, fuzz_types); if (width > 1) { t = t.with_lanes(width); } return t; } -int get_random_divisor(FuzzedDataProvider &fdp, Type t) { +int get_random_divisor(std::mt19937 &rng, Type t) { std::vector divisors = {t.lanes()}; for (int dd = 2; dd < t.lanes(); dd++) { if (t.lanes() % dd == 0) { @@ -47,43 +59,42 @@ int get_random_divisor(FuzzedDataProvider &fdp, Type t) { } } - return pick_value_in_vector(fdp, divisors); + return random_choice(rng, divisors); } -Expr random_leaf(FuzzedDataProvider &fdp, Type t, bool overflow_undef = false, bool imm_only = false) { +Expr random_leaf(std::mt19937 &rng, Type t, bool overflow_undef = false, bool imm_only = false) { if (t.is_int() && t.bits() == 32) { overflow_undef = true; } if (t.is_scalar()) { - if (!imm_only && fdp.ConsumeBool()) { - auto v1 = random_var(fdp); - return cast(t, v1); + if (!imm_only && (rng() & 1)) { + return random_var(rng, t); } else { if (overflow_undef) { // For Int(32), we don't care about correctness during // overflow, so just use numbers that are unlikely to // overflow. - return cast(t, fdp.ConsumeIntegralInRange(-128, 127)); + return cast(t, (int32_t)((int8_t)(rng() & 255))); } else { - return cast(t, fdp.ConsumeIntegral()); + return cast(t, (int32_t)(rng())); } } } else { - int lanes = get_random_divisor(fdp, t); - if (fdp.ConsumeBool()) { - auto e1 = random_leaf(fdp, t.with_lanes(t.lanes() / lanes), overflow_undef); - auto e2 = random_leaf(fdp, t.with_lanes(t.lanes() / lanes), overflow_undef); + int lanes = get_random_divisor(rng, t); + if (rng() & 1) { + auto e1 = random_leaf(rng, t.with_lanes(t.lanes() / lanes), overflow_undef); + auto e2 = random_leaf(rng, t.with_lanes(t.lanes() / lanes), overflow_undef); return Ramp::make(e1, e2, lanes); } else { - auto e1 = random_leaf(fdp, t.with_lanes(t.lanes() / lanes), overflow_undef); + auto e1 = random_leaf(rng, t.with_lanes(t.lanes() / lanes), overflow_undef); return Broadcast::make(e1, lanes); } } } -Expr random_expr(FuzzedDataProvider &fdp, Type t, int depth, bool overflow_undef = false); +Expr random_expr(std::mt19937 &rng, Type t, int depth, bool overflow_undef = false); -Expr random_condition(FuzzedDataProvider &fdp, Type t, int depth, bool maybe_scalar) { +Expr random_condition(std::mt19937 &rng, Type t, int depth, bool maybe_scalar) { static make_bin_op_fn make_bin_op[] = { EQ::make, NE::make, @@ -93,13 +104,13 @@ Expr random_condition(FuzzedDataProvider &fdp, Type t, int depth, bool maybe_sca GE::make, }; - if (maybe_scalar && fdp.ConsumeBool()) { + if (maybe_scalar && (rng() & 1)) { t = t.element_of(); } - Expr a = random_expr(fdp, t, depth); - Expr b = random_expr(fdp, t, depth); - return fdp.PickValueInArray(make_bin_op)(a, b); + Expr a = random_expr(rng, t, depth); + Expr b = random_expr(rng, t, depth); + return random_choice(rng, make_bin_op)(a, b); } Expr make_absd(Expr a, Expr b) { @@ -108,67 +119,67 @@ Expr make_absd(Expr a, Expr b) { return cast(a.type(), absd(a, b)); } -Expr random_expr(FuzzedDataProvider &fdp, Type t, int depth, bool overflow_undef) { +Expr random_expr(std::mt19937 &rng, Type t, int depth, bool overflow_undef) { if (t.is_int() && t.bits() == 32) { overflow_undef = true; } if (depth-- <= 0) { - return random_leaf(fdp, t, overflow_undef); + return random_leaf(rng, t, overflow_undef); } std::function operations[] = { [&]() { - return random_leaf(fdp, t); + return random_leaf(rng, t); }, [&]() { - auto c = random_condition(fdp, t, depth, true); - auto e1 = random_expr(fdp, t, depth, overflow_undef); - auto e2 = random_expr(fdp, t, depth, overflow_undef); + auto c = random_condition(rng, t, depth, true); + auto e1 = random_expr(rng, t, depth, overflow_undef); + auto e2 = random_expr(rng, t, depth, overflow_undef); return Select::make(c, e1, e2); }, [&]() { if (t.lanes() != 1) { - int lanes = get_random_divisor(fdp, t); - auto e1 = random_expr(fdp, t.with_lanes(t.lanes() / lanes), depth, overflow_undef); + int lanes = get_random_divisor(rng, t); + auto e1 = random_expr(rng, t.with_lanes(t.lanes() / lanes), depth, overflow_undef); return Broadcast::make(e1, lanes); } - return random_expr(fdp, t, depth, overflow_undef); + return random_expr(rng, t, depth, overflow_undef); }, [&]() { if (t.lanes() != 1) { - int lanes = get_random_divisor(fdp, t); - auto e1 = random_expr(fdp, t.with_lanes(t.lanes() / lanes), depth, overflow_undef); - auto e2 = random_expr(fdp, t.with_lanes(t.lanes() / lanes), depth, overflow_undef); + int lanes = get_random_divisor(rng, t); + auto e1 = random_expr(rng, t.with_lanes(t.lanes() / lanes), depth, overflow_undef); + auto e2 = random_expr(rng, t.with_lanes(t.lanes() / lanes), depth, overflow_undef); return Ramp::make(e1, e2, lanes); } - return random_expr(fdp, t, depth, overflow_undef); + return random_expr(rng, t, depth, overflow_undef); }, [&]() { if (t.is_bool()) { - auto e1 = random_expr(fdp, t, depth); + auto e1 = random_expr(rng, t, depth); return Not::make(e1); } - return random_expr(fdp, t, depth, overflow_undef); + return random_expr(rng, t, depth, overflow_undef); }, [&]() { // When generating boolean expressions, maybe throw in a condition on non-bool types. if (t.is_bool()) { - return random_condition(fdp, random_type(fdp, t.lanes()), depth, false); + return random_condition(rng, random_type(rng, t.lanes()), depth, false); } - return random_expr(fdp, t, depth, overflow_undef); + return random_expr(rng, t, depth, overflow_undef); }, [&]() { // Get a random type that isn't t or int32 (int32 can overflow and we don't care about that). - // Note also that the FuzzedDataProvider doesn't actually promise to return a random distribution -- + // Note also that the std::mt19937 doesn't actually promise to return a random distribution -- // it can (e.g.) decide to just return 0 for all data, forever -- so this loop has no guarantee // of eventually finding a different type. To remedy this, we'll just put a limit on the retries. int count = 0; Type subtype; do { - subtype = random_type(fdp, t.lanes()); + subtype = random_type(rng, t.lanes()); } while (++count < 10 && (subtype == t || (subtype.is_int() && subtype.bits() == 32))); - auto e1 = random_expr(fdp, subtype, depth, overflow_undef); + auto e1 = random_expr(rng, subtype, depth, overflow_undef); return Cast::make(t, e1); }, [&]() { @@ -184,9 +195,9 @@ Expr random_expr(FuzzedDataProvider &fdp, Type t, int depth, bool overflow_undef make_absd, }; - Expr a = random_expr(fdp, t, depth, overflow_undef); - Expr b = random_expr(fdp, t, depth, overflow_undef); - return fdp.PickValueInArray(make_bin_op)(a, b); + Expr a = random_expr(rng, t, depth, overflow_undef); + Expr b = random_expr(rng, t, depth, overflow_undef); + return random_choice(rng, make_bin_op)(a, b); }, [&]() { static make_bin_op_fn make_bin_op[] = { @@ -196,14 +207,14 @@ Expr random_expr(FuzzedDataProvider &fdp, Type t, int depth, bool overflow_undef // Boolean operations -- both sides must be cast to booleans, // and then we must cast the result back to 't'. - Expr a = random_expr(fdp, t, depth, overflow_undef); - Expr b = random_expr(fdp, t, depth, overflow_undef); + Expr a = random_expr(rng, t, depth, overflow_undef); + Expr b = random_expr(rng, t, depth, overflow_undef); Type bool_with_lanes = Bool(t.lanes()); a = cast(bool_with_lanes, a); b = cast(bool_with_lanes, b); - return cast(t, fdp.PickValueInArray(make_bin_op)(a, b)); + return cast(t, random_choice(rng, make_bin_op)(a, b)); }}; - return fdp.PickValueInArray(operations)(); + return random_choice(rng, operations)(); } bool test_simplification(Expr a, Expr b, Type t, const map &vars) { @@ -240,7 +251,7 @@ bool test_simplification(Expr a, Expr b, Type t, const map &vars) return true; } -bool test_expression(FuzzedDataProvider &fdp, Expr test, int samples) { +bool test_expression(std::mt19937 &rng, Expr test, int samples) { Expr simplified = simplify(test); map vars; @@ -254,7 +265,7 @@ bool test_expression(FuzzedDataProvider &fdp, Expr test, int samples) { // Don't let the random leaf depend on v itself. size_t iterations = 0; do { - v->second = random_leaf(fdp, test.type().element_of(), true); + v->second = random_leaf(rng, Int(32), true); iterations++; } while (expr_uses_var(v->second, v->first) && iterations < kMaxLeafIterations); } @@ -266,96 +277,62 @@ bool test_expression(FuzzedDataProvider &fdp, Expr test, int samples) { return true; } -// These are here to enable copy of failed output expressions -// and pasting them into the test for debugging; they are commented out -// to avoid "unused function" warnings in some build environments. -#if 0 -Expr ramp(Expr b, Expr s, int w) { - return Ramp::make(b, s, w); -} -Expr x1(Expr x) { - return Broadcast::make(x, 2); -} -Expr x2(Expr x) { - return Broadcast::make(x, 2); -} -Expr x3(Expr x) { - return Broadcast::make(x, 3); -} -Expr x4(Expr x) { - return Broadcast::make(x, 4); -} -Expr x6(Expr x) { - return Broadcast::make(x, 6); -} -Expr x8(Expr x) { - return Broadcast::make(x, 8); -} -Expr uint1(Expr x) { - return Cast::make(UInt(1), x); -} -Expr uint8(Expr x) { - return Cast::make(UInt(8), x); -} -Expr uint16(Expr x) { - return Cast::make(UInt(16), x); -} -Expr uint32(Expr x) { - return Cast::make(UInt(32), x); -} -Expr int8(Expr x) { - return Cast::make(Int(8), x); -} -Expr int16(Expr x) { - return Cast::make(Int(16), x); -} -Expr int32(Expr x) { - return Cast::make(Int(32), x); -} -Expr uint1x2(Expr x) { - return Cast::make(UInt(1).with_lanes(2), x); -} -Expr uint8x2(Expr x) { - return Cast::make(UInt(8).with_lanes(2), x); -} -Expr uint16x2(Expr x) { - return Cast::make(UInt(16).with_lanes(2), x); -} -Expr uint32x2(Expr x) { - return Cast::make(UInt(32).with_lanes(2), x); -} -Expr int8x2(Expr x) { - return Cast::make(Int(8).with_lanes(2), x); -} -Expr int16x2(Expr x) { - return Cast::make(Int(16).with_lanes(2), x); -} -Expr int32x2(Expr x) { - return Cast::make(Int(32).with_lanes(2), x); -} -#endif - -Expr a(Variable::make(Int(0), fuzz_var(0))); -Expr b(Variable::make(Int(0), fuzz_var(1))); -Expr c(Variable::make(Int(0), fuzz_var(2))); -Expr d(Variable::make(Int(0), fuzz_var(3))); -Expr e(Variable::make(Int(0), fuzz_var(4))); - } // namespace -extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { +int main(int argc, char **argv) { // Depth of the randomly generated expression trees. const int depth = 5; // Number of samples to test the generated expressions for. const int samples = 3; - FuzzedDataProvider fdp(data, size); + std::mt19937 seed_generator{(uint32_t)time(NULL)}; + + for (int i = 0; i < ((argc == 1) ? 10000 : 1); i++) { + uint32_t seed = seed_generator(); + if (argc > 1) { + seed = atoi(argv[1]); + } + // Print the seed on every iteration so that if the simplifier crashes + // (rather than the check failing), we can reproduce. + printf("Seed: %d\n", seed); + std::mt19937 rng{seed}; + std::array vector_widths = {1, 2, 3, 4, 6, 8}; + int width = random_choice(rng, vector_widths); + Type VT = random_type(rng, width); + // Generate a random expr... + Expr test = random_expr(rng, VT, depth); + if (!test_expression(rng, test, samples)) { + + // Failure. Find the minimal subexpression that failed. + printf("Testing subexpressions...\n"); + class TestSubexpressions : public IRMutator { + std::mt19937 &rng; + bool found_failure = false; + + public: + using IRMutator::mutate; + Expr mutate(const Expr &e) override { + // We know there's a failure here somewhere, so test + // subexpressions more aggressively. + IRMutator::mutate(e); + if (e.type().bits() && !found_failure) { + const int samples = 100; + found_failure = !test_expression(rng, e, samples); + } + return e; + } + + TestSubexpressions(std::mt19937 &rng) + : rng(rng) { + } + } tester(rng); + tester.mutate(test); + + printf("Failed with seed %d\n", seed); + return 1; + } + } - std::array vector_widths = {1, 2, 3, 4, 6, 8}; - int width = fdp.PickValueInArray(vector_widths); - Type VT = random_type(fdp, width); - // Generate a random expr... - Expr test = random_expr(fdp, VT, depth); - assert(test_expression(fdp, test, samples)); + printf("Success!\n"); return 0; } diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 6f497531da94..6f51d65f59a6 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -2124,7 +2124,9 @@ void check_invariant() { Expr w = Variable::make(t, "w"); check_inv(x + y); check_inv(x - y); - check_inv(x % y); + if (t != UInt(1)) { + check_inv(x % y); + } check_inv(x * y); check_inv(x / y); check_inv(min(x, y)); @@ -2214,7 +2216,7 @@ int main(int argc, char **argv) { // This expression used to cause infinite recursion. check(Broadcast::make(-16, 2) < (ramp(Cast::make(UInt(16), 7), Cast::make(UInt(16), 11), 2) - Broadcast::make(1, 2)), - Broadcast::make(-15, 2) < (ramp(make_const(UInt(16), 7), make_const(UInt(16), 11), 2))); + Broadcast::make(make_const(UInt(1), 1), 2)); { // Verify that integer types passed to min() and max() are coerced to match diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index 4cd4000cb72c..18bdcaf1d42e 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt @@ -2,7 +2,6 @@ tests(GROUPS fuzz SOURCES bounds.cpp cse.cpp - simplify.cpp # By default, the libfuzzer harness runs with a timeout of 1200 seconds. # Let's dial that back: # - Do 1000 fuzz runs for each test. @@ -26,7 +25,7 @@ tests(GROUPS fuzz set(LIB_FUZZING_ENGINE "$ENV{LIB_FUZZING_ENGINE}" CACHE STRING "Compiler flags necessary to link the fuzzing engine of choice e.g. libfuzzer, afl etc.") -foreach(fuzzer "fuzz_bounds" "fuzz_cse" "fuzz_simplify") +foreach(fuzzer "fuzz_bounds" "fuzz_cse") target_link_libraries(${fuzzer} PRIVATE Halide::Halide) # Allow OSS-fuzz to manage flags directly