From 7d80f8b865e3e3cff6f666977cd5972bfcece534 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 14 Mar 2024 14:37:38 -0700 Subject: [PATCH 01/33] Fix horrifying bug in lossless_cast of a subtract --- src/IROperator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 0f318f777561..d27d10126278 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -502,7 +502,7 @@ Expr lossless_cast(Type t, Expr e) { Expr a = lossless_cast(t.narrow(), sub->a); Expr b = lossless_cast(t.narrow(), sub->b); if (a.defined() && b.defined()) { - return cast(t, a) + cast(t, b); + return cast(t, a) - cast(t, b); } else { return Expr(); } From 9c33c94fdd2ead5420374945b6297ff9dc383b5b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 18 Mar 2024 15:43:41 -0700 Subject: [PATCH 02/33] Use constant integer intervals to analyze safety for lossless_cast TODO: - Dedup the constant integer code with the same code in the simplifier. - Move constant interval arithmetic operations out of the class. - Make the ConstantInterval part of the return type of lossless_cast (and turn it into an inner helper) so that it isn't constantly recomputed. --- src/Expr.cpp | 4 +- src/IROperator.cpp | 269 ++++++++++++++++-------- src/Interval.cpp | 327 +++++++++++++++++++++++++++++ src/Interval.h | 48 ++++- src/Simplify_Call.cpp | 2 +- test/correctness/lossless_cast.cpp | 231 +++++++++++++++++++- 6 files changed, 781 insertions(+), 100 deletions(-) diff --git a/src/Expr.cpp b/src/Expr.cpp index c3a7deb483aa..d73bd72660fa 100644 --- a/src/Expr.cpp +++ b/src/Expr.cpp @@ -8,7 +8,7 @@ const IntImm *IntImm::make(Type t, int64_t value) { internal_assert(t.is_int() && t.is_scalar()) << "IntImm must be a scalar Int\n"; internal_assert(t.bits() >= 1 && t.bits() <= 64) - << "IntImm must have between 1 and 64 bits\n"; + << "IntImm must have between 1 and 64 bits: " << t << "\n"; // Normalize the value by dropping the high bits. // Since left-shift of negative value is UB in C++, cast to uint64 first; @@ -28,7 +28,7 @@ const UIntImm *UIntImm::make(Type t, uint64_t value) { internal_assert(t.is_uint() && t.is_scalar()) << "UIntImm must be a scalar UInt\n"; internal_assert(t.bits() >= 1 && t.bits() <= 64) - << "UIntImm must have between 1 and 64 bits\n"; + << "UIntImm must have between 1 and 64 bits " << t << "\n"; // Normalize the value by dropping the high bits value <<= (64 - t.bits()); diff --git a/src/IROperator.cpp b/src/IROperator.cpp index d27d10126278..d27885a80d03 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -12,6 +12,7 @@ #include "IRMutator.h" #include "IROperator.h" #include "IRPrinter.h" +#include "Interval.h" #include "Util.h" #include "Var.h" @@ -434,125 +435,175 @@ Expr const_false(int w) { return make_zero(UInt(1, w)); } +namespace { + +ConstantInterval constant_integer_bounds(const Expr &e) { + auto ret = [&]() { + // Compute the bounds of each IR node from the bounds of its args. Math + // on ConstantInterval is in terms of infinite integers, so any op that + // can overflow needs to cast the resulting interval back to the output + // type. 
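        // A minimal worked example of why the cast back is needed, using the
        // ConstantInterval arithmetic added later in this patch (values are
        // illustrative): bounding a UInt(8) addition. Each operand lies in
        // [0, 255], so the sum over the infinite integers lies in [0, 510].
        // That does not fit in a u8, so the cast back to the result type has
        // to widen to the whole type range to model possible wrap-around:
        //
        //   ConstantInterval a(0, 255), b(0, 255);
        //   ConstantInterval sum = a + b;                  // [0, 510]
        //   ConstantInterval u8_sum = cast(UInt(8), sum);  // [0, 255], i.e. all of u8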
+ if (const UIntImm *op = e.as()) { + if (Int(64).can_represent(op->value)) { + return ConstantInterval::single_point((int64_t)(op->value)); + } else { + return ConstantInterval::everything(); + } + } else if (const IntImm *op = e.as()) { + return ConstantInterval::single_point(op->value); + } else if (const Add *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) + constant_integer_bounds(op->b)); + } else if (const Sub *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) - constant_integer_bounds(op->b)); + } else if (const Mul *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) * constant_integer_bounds(op->b)); + } else if (const Div *op = e.as
()) { + // Can overflow when dividing type.min() by -1 + return cast(op->type, constant_integer_bounds(op->a) / constant_integer_bounds(op->b)); + } else if (const Min *op = e.as()) { + return min(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); + } else if (const Max *op = e.as()) { + return max(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); + } else if (const Cast *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->value)); + } else if (const Broadcast *op = e.as()) { + return constant_integer_bounds(op->value); + } else if (const VectorReduce *op = e.as()) { + int f = op->value.type().lanes() / op->type.lanes(); + ConstantInterval factor(f, f); + ConstantInterval arg_bounds = constant_integer_bounds(op->value); + switch (op->op) { + case VectorReduce::Add: + return cast(op->type, arg_bounds * factor); + case VectorReduce::SaturatingAdd: + return saturating_cast(op->type, arg_bounds * factor); + case VectorReduce::Min: + case VectorReduce::Max: + case VectorReduce::And: + case VectorReduce::Or: + return arg_bounds; + default:; + } + } else if (const Shuffle *op = e.as()) { + ConstantInterval arg_bounds = constant_integer_bounds(op->vectors[0]); + for (size_t i = 1; i < op->vectors.size(); i++) { + arg_bounds.include(constant_integer_bounds(op->vectors[i])); + } + return arg_bounds; + } else if (const Call *op = e.as()) { + // For all intrinsics that can't possibly overflow, we don't need the + // final cast. + if (op->is_intrinsic(Call::abs)) { + return abs(constant_integer_bounds(op->args[0])); + } else if (op->is_intrinsic(Call::absd)) { + return abs(constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1])); + } else if (op->is_intrinsic(Call::count_leading_zeros) || + op->is_intrinsic(Call::count_trailing_zeros)) { + // Conservatively just say it's the potential number of zeros in the type. + return ConstantInterval(0, op->args[0].type().bits()); + } else if (op->is_intrinsic(Call::halving_add)) { + return (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1])) / + ConstantInterval(2, 2); + } else if (op->is_intrinsic(Call::halving_sub)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1])) / + ConstantInterval(2, 2)); + } else if (op->is_intrinsic(Call::rounding_halving_add)) { + return (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]) + + ConstantInterval(1, 1)) / + ConstantInterval(2, 2); + } else if (op->is_intrinsic(Call::saturating_add)) { + return saturating_cast(op->type, + (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::saturating_sub)) { + return saturating_cast(op->type, + (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widening_add)) { + return constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widening_sub)) { + // widening ops can't overflow ... 
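                // For example: widening_sub of two UInt(8) args produces an
                // Int(16). The operand intervals are subsets of [0, 255], so
                // the difference lies in [-255, 255], which an i16 can always
                // represent -- hence no cast back to op->type is needed here.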
+ return constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widening_mul)) { + return constant_integer_bounds(op->args[0]) * + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widen_right_add)) { + // but the widen_right versions can overflow + return cast(op->type, (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widen_right_sub)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widen_right_mul)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) * + constant_integer_bounds(op->args[1]))); + } + // We could include the various shifting intrinsics here too, but we'd + // have to check for the sign on the second argument. + } + + return ConstantInterval::bounds_of_type(e.type()); + }(); + + // debug(0) << e << " -> " << ret.min_defined << " " << ret.min << " " << ret.max_defined << " " << ret.max << "\n"; + + if (ret.min_defined) { + internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) && + (!ret.max_defined || e.type().can_represent(ret.max))) + << "Expr: " << e << "\n" + << " min_defined = " << ret.min_defined << "\n" + << " max_defined = " << ret.max_defined << "\n" + << " min = " << ret.min << "\n" + << " max = " << ret.max << "\n"; + } + + return ret; +} +} // namespace + Expr lossless_cast(Type t, Expr e) { if (!e.defined() || t == e.type()) { return e; } else if (t.can_represent(e.type())) { return cast(t, std::move(e)); - } - - if (const Cast *c = e.as()) { + } else if (const Cast *c = e.as()) { if (c->type.can_represent(c->value.type())) { - // We can recurse into widening casts. return lossless_cast(t, c->value); } else { return Expr(); } - } - - if (const Broadcast *b = e.as()) { + } else if (const Broadcast *b = e.as()) { Expr v = lossless_cast(t.element_of(), b->value); if (v.defined()) { return Broadcast::make(v, b->lanes); } else { return Expr(); } - } - - if (const IntImm *i = e.as()) { + } else if (const IntImm *i = e.as()) { if (t.can_represent(i->value)) { return make_const(t, i->value); } else { return Expr(); } - } - - if (const UIntImm *i = e.as()) { + } else if (const UIntImm *i = e.as()) { if (t.can_represent(i->value)) { return make_const(t, i->value); } else { return Expr(); } - } - - if (const FloatImm *f = e.as()) { + } else if (const FloatImm *f = e.as()) { if (t.can_represent(f->value)) { return make_const(t, f->value); } else { return Expr(); } - } - - if (t.is_int_or_uint() && t.bits() >= 16) { - if (const Add *add = e.as()) { - // If we can losslessly narrow the args even more - // aggressively, we're good. - // E.g. 
lossless_cast(uint16, (uint32)(some_u8) + 37) - // = (uint16)(some_u8) + 37 - Expr a = lossless_cast(t.narrow(), add->a); - Expr b = lossless_cast(t.narrow(), add->b); - if (a.defined() && b.defined()) { - return cast(t, a) + cast(t, b); - } else { - return Expr(); - } - } - - if (const Sub *sub = e.as()) { - Expr a = lossless_cast(t.narrow(), sub->a); - Expr b = lossless_cast(t.narrow(), sub->b); - if (a.defined() && b.defined()) { - return cast(t, a) - cast(t, b); - } else { - return Expr(); - } - } - - if (const Mul *mul = e.as()) { - Expr a = lossless_cast(t.narrow(), mul->a); - Expr b = lossless_cast(t.narrow(), mul->b); - if (a.defined() && b.defined()) { - return cast(t, a) * cast(t, b); - } else { - return Expr(); - } - } - - if (const VectorReduce *reduce = e.as()) { - const int factor = reduce->value.type().lanes() / reduce->type.lanes(); - switch (reduce->op) { - case VectorReduce::Add: - // A horizontal add requires one extra bit per factor - // of two in the reduction factor. E.g. a reduction of - // 8 vector lanes down to 2 requires 2 extra bits in - // the output. We only deal with power-of-two types - // though, so just make sure the reduction factor - // isn't so large that it will more than double the - // number of bits required. - if (factor < (1 << (t.bits() / 2))) { - Type narrower = reduce->value.type().with_bits(t.bits() / 2); - Expr val = lossless_cast(narrower, reduce->value); - if (val.defined()) { - val = cast(narrower.with_bits(t.bits()), val); - return VectorReduce::make(reduce->op, val, reduce->type.lanes()); - } - } - break; - case VectorReduce::Max: - case VectorReduce::Min: { - Expr val = lossless_cast(t, reduce->value); - if (val.defined()) { - return VectorReduce::make(reduce->op, val, reduce->type.lanes()); - } - break; - } - default: - break; - } - } - } - - if (const Shuffle *shuf = e.as()) { + } else if (const Shuffle *shuf = e.as()) { std::vector vecs; for (const auto &vec : shuf->vectors) { vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec)); @@ -561,6 +612,48 @@ Expr lossless_cast(Type t, Expr e) { } } return Shuffle::make(vecs, shuf->indices); + } else if (t.is_int_or_uint()) { + // We'll just throw a cast around something, if the bounds are small + // enough. + ConstantInterval ci = constant_integer_bounds(e); + if (ci.is_bounded() && + t.can_represent(ci.max) && + t.can_represent(ci.min)) { + + // There are certain IR nodes where if the result is expressible + // using some type, and the args are expressible using that type, + // then the operation can just be done in that type. 
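            // A hedged example mirroring the new test cases added at the end
            // of this patch ("var_u8" stands for a UInt(8) variable):
            //
            //   lossless_cast(UInt(16), cast(UInt(32), var_u8) + 16)
            //     -> cast(UInt(16), var_u8) + 16   // bounds [16, 271] fit in u16
            //   lossless_cast(UInt(16), cast(UInt(32), var_u8) - 16)
            //     -> Expr()                        // result may be negative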
+ if (const Add *op = e.as()) { + Expr a = lossless_cast(t, op->a); + Expr b = lossless_cast(t, op->b); + if (a.defined() && b.defined()) { + return a + b; + } + } else if (const Sub *op = e.as()) { + Expr a = lossless_cast(t, op->a); + Expr b = lossless_cast(t, op->b); + if (a.defined() && b.defined()) { + return a - b; + } + } else if (const Mul *op = e.as()) { + Expr a = lossless_cast(t, op->a); + Expr b = lossless_cast(t, op->b); + if (a.defined() && b.defined()) { + return a * b; + } + } else if (const VectorReduce *op = e.as()) { + if (op->op == VectorReduce::Add || + op->op == VectorReduce::Min || + op->op == VectorReduce::Max) { + Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value); + if (v.defined()) { + return VectorReduce::make(op->op, v, op->type.lanes()); + } + } + } + + return cast(t, e); + } } return Expr(); diff --git a/src/Interval.cpp b/src/Interval.cpp index 10550f7ed48b..acdb4562f030 100644 --- a/src/Interval.cpp +++ b/src/Interval.cpp @@ -3,6 +3,8 @@ #include "IRMatch.h" #include "IROperator.h" +using namespace Halide::Internal; + namespace Halide { namespace Internal { @@ -237,11 +239,336 @@ void ConstantInterval::include(int64_t x) { } } +bool ConstantInterval::contains(int64_t x) const { + return !((min_defined && x < min) || + (max_defined && x > max)); +} + ConstantInterval ConstantInterval::make_union(const ConstantInterval &a, const ConstantInterval &b) { ConstantInterval result = a; result.include(b); return result; } +// TODO: These were taken directly from the simplifier, so change the simplifier +// to use these instead of duplicating the code. +void ConstantInterval::operator+=(const ConstantInterval &other) { + min_defined = min_defined && + other.min_defined && + add_with_overflow(64, min, other.min, &min); + max_defined = max_defined && + other.max_defined && + add_with_overflow(64, max, other.max, &max); +} + +void ConstantInterval::operator-=(const ConstantInterval &other) { + min_defined = min_defined && + other.max_defined && + sub_with_overflow(64, min, other.max, &min); + max_defined = max_defined && + other.min_defined && + sub_with_overflow(64, max, other.min, &max); +} + +void ConstantInterval::operator*=(const ConstantInterval &other) { + ConstantInterval result; + + // Compute a possible extreme value of the product, setting the min/max + // defined flags if it's unbounded. 
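    // Worked example (illustrative): for [-3, 2] * [-5, 4] the four corner
    // products are 15, -12, -10 and 8, so the product interval is [-12, 15].
    // The extreme can come from any corner, which is why all four products
    // are computed below and the min/max of them taken.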
+ auto saturating_mul = [&](int64_t a, int64_t b) -> int64_t { + int64_t c; + if (mul_with_overflow(64, a, b, &c)) { + return c; + } else if ((a > 0) == (b > 0)) { + result.max_defined = false; + return INT64_MAX; + } else { + result.min_defined = false; + return INT64_MIN; + } + }; + + bool positive = min_defined && min > 0; + bool other_positive = other.min_defined && other.min > 0; + bool bounded = min_defined && max_defined; + bool other_bounded = other.min_defined && other.max_defined; + + if (bounded && other_bounded) { + // Both are bounded + result.min_defined = result.max_defined = true; + int64_t v1 = saturating_mul(min, other.min); + int64_t v2 = saturating_mul(min, other.max); + int64_t v3 = saturating_mul(max, other.min); + int64_t v4 = saturating_mul(max, other.max); + if (result.min_defined) { + result.min = std::min(std::min(v1, v2), std::min(v3, v4)); + } else { + result.min = 0; + } + if (result.max_defined) { + result.max = std::max(std::max(v1, v2), std::max(v3, v4)); + } else { + result.max = 0; + } + } else if ((max_defined && other_bounded && other_positive) || + (other.max_defined && bounded && positive)) { + // One side has a max, and the other side is bounded and positive + // (e.g. a constant). + result.max = saturating_mul(max, other.max); + if (!result.max_defined) { + result.max = 0; + } + } else if ((min_defined && other_bounded && other_positive) || + (other.min_defined && bounded && positive)) { + // One side has a min, and the other side is bounded and positive + // (e.g. a constant). + min = saturating_mul(min, other.min); + if (!result.min_defined) { + result.min = 0; + } + } + // TODO: what about the above two cases, but for multiplication by bounded + // and negative intervals? + + *this = result; +} + +void ConstantInterval::operator/=(const ConstantInterval &other) { + ConstantInterval result; + + result.min = INT64_MAX; + result.max = INT64_MIN; + + // Enumerate all possible values for the min and max and take the extreme values. + if (min_defined && other.min_defined && other.min != 0) { + int64_t v = div_imp(min, other.min); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + if (min_defined && other.max_defined && other.max != 0) { + int64_t v = div_imp(min, other.max); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + if (max_defined && other.max_defined && other.max != 0) { + int64_t v = div_imp(max, other.max); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + if (max_defined && other.min_defined && other.min != 0) { + int64_t v = div_imp(max, other.min); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + // Define an int64_t zero just to pacify std::min and std::max + constexpr int64_t zero = 0; + + const bool other_positive = other.min_defined && other.min > 0; + const bool other_negative = other.max_defined && other.max < 0; + if ((other_positive && !other.max_defined) || + (other_negative && !other.min_defined)) { + // Take limit as other -> +/- infinity + result.min = std::min(result.min, zero); + result.max = std::max(result.max, zero); + } + + bool bounded_numerator = min_defined && max_defined; + + result.min_defined = ((min_defined && other_positive) || + (max_defined && other_negative)); + result.max_defined = ((max_defined && other_positive) || + (min_defined && other_negative)); + + // That's as far as we can get knowing the sign of the + // denominator. 
For bounded numerators, we additionally know + // that div can't make anything larger in magnitude, so we can + // take the intersection with that. + if (bounded_numerator && min != INT64_MIN) { + int64_t magnitude = std::max(max, -min); + if (result.min_defined) { + result.min = std::max(result.min, -magnitude); + } else { + result.min = -magnitude; + } + if (result.max_defined) { + result.max = std::min(result.max, magnitude); + } else { + result.max = magnitude; + } + result.min_defined = result.max_defined = true; + } + + // Finally we can provide a bound if the numerator and denominator are + // non-positive or non-negative. + bool numerator_non_negative = min_defined && min >= 0; + bool denominator_non_negative = other.min_defined && other.min >= 0; + bool numerator_non_positive = max_defined && max <= 0; + bool denominator_non_positive = other.max_defined && other.max <= 0; + if ((numerator_non_negative && denominator_non_negative) || + (numerator_non_positive && denominator_non_positive)) { + if (result.min_defined) { + result.min = std::max(result.min, zero); + } else { + result.min_defined = true; + result.min = 0; + } + } + if ((numerator_non_negative && denominator_non_positive) || + (numerator_non_positive && denominator_non_negative)) { + if (result.max_defined) { + result.max = std::min(result.max, zero); + } else { + result.max_defined = true; + result.max = 0; + } + } + + // Normalize the values if it's undefined + if (!result.min_defined) { + result.min = 0; + } + if (!result.max_defined) { + result.max = 0; + } + + *this = result; +} + +void ConstantInterval::cast_to(Type t) { + if (!(max_defined && t.can_represent(max) && + min_defined && t.can_represent(min))) { + // We have potential overflow or underflow, return the entire bounds of + // the type. + ConstantInterval type_bounds; + if (t.is_int()) { + if (t.bits() <= 64) { + type_bounds.min_defined = type_bounds.max_defined = true; + type_bounds.min = ((int64_t)(-1)) << (t.bits() - 1); + type_bounds.max = ~type_bounds.min; + } + } else if (t.is_uint()) { + type_bounds.min_defined = true; + type_bounds.min = 0; + if (t.bits() < 64) { + type_bounds.max_defined = true; + type_bounds.max = (((int64_t)(1)) << t.bits()) - 1; + } + } + // If it's not int or uint, we're setting this to a default-constructed + // ConstantInterval, which is everything. 
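        // For reference, the fallback computed here is what bounds_of_type()
        // returns (illustrative values):
        //   Int(8)   -> [-128, 127]
        //   UInt(8)  -> [0, 255]
        //   UInt(64) -> [0, +inf)  (the max doesn't fit in an int64_t, so the
        //                           interval stays unbounded above)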
+ *this = type_bounds; + } +} + +ConstantInterval ConstantInterval::bounds_of_type(Type t) { + return cast(t, ConstantInterval::everything()); +} + +ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result += b; + return result; +} + +ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result -= b; + return result; +} + +ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result /= b; + return result; +} + +ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result *= b; + return result; +} + +ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + if (a.min_defined && b.min_defined && b.min < a.min) { + result.min = b.min; + } + if (a.max_defined && b.max_defined && b.max < a.max) { + result.max = b.max; + } + return result; +} + +ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + if (a.min_defined && b.min_defined && b.min > a.min) { + result.min = b.min; + } + if (a.max_defined && b.max_defined && b.max > a.max) { + result.max = b.max; + } + return result; +} + +ConstantInterval abs(const ConstantInterval &a) { + ConstantInterval result; + if (a.min_defined && a.max_defined && a.min != INT64_MIN) { + result.max_defined = true; + result.max = std::max(-a.min, a.max); + } + result.min_defined = true; + if (a.min_defined && a.min > 0) { + result.min = a.min; + } else { + result.min = 0; + } + + return result; +} + } // namespace Internal + +ConstantInterval cast(Type t, const ConstantInterval &a) { + ConstantInterval result = a; + result.cast_to(t); + return result; +} + +ConstantInterval saturating_cast(Type t, const ConstantInterval &a) { + ConstantInterval b = ConstantInterval::bounds_of_type(t); + + if (b.max_defined && a.min_defined && a.min > b.max) { + return ConstantInterval(b.max, b.max); + } else if (b.min_defined && a.max_defined && a.max < b.min) { + return ConstantInterval(b.min, b.min); + } + + ConstantInterval result = a; + result.max_defined = a.max_defined || b.max_defined; + if (a.max_defined) { + if (b.max_defined) { + result.max = std::min(a.max, b.max); + } else { + result.max = a.max; + } + } else if (b.max_defined) { + result.max = b.max; + } + result.min_defined = a.min_defined || b.min_defined; + if (a.min_defined) { + if (b.min_defined) { + result.min = std::max(a.min, b.min); + } else { + result.min = a.min; + } + } else if (b.min_defined) { + result.min = b.min; + } + return result; +} + } // namespace Halide diff --git a/src/Interval.h b/src/Interval.h index 1d90d4a29b55..6fbb8b81c0e1 100644 --- a/src/Interval.h +++ b/src/Interval.h @@ -87,7 +87,7 @@ struct Interval { /** Construct the smallest interval containing two intervals. */ static Interval make_union(const Interval &a, const Interval &b); - /** Construct the largest interval contained within two intervals. */ + /** Construct the largest interval contained within two other intervals. */ static Interval make_intersection(const Interval &a, const Interval &b); /** An eagerly-simplifying max of two Exprs that respects infinities. */ @@ -110,8 +110,8 @@ struct Interval { static Expr neg_inf_noinline(); }; -/** A class to represent ranges of integers. Can be unbounded above or below, but - * they cannot be empty. 
*/ +/** A class to represent ranges of integers. Can be unbounded above or below, + * but they cannot be empty. */ struct ConstantInterval { /** The lower and upper bound of the interval. They are included * in the interval. */ @@ -158,6 +158,9 @@ struct ConstantInterval { /** Expand the interval to include a point */ void include(int64_t x); + /** Test if the interval contains a particular value */ + bool contains(int64_t x) const; + /** Construct the smallest interval containing two intervals. */ static ConstantInterval make_union(const ConstantInterval &a, const ConstantInterval &b); @@ -165,9 +168,48 @@ struct ConstantInterval { * compare two map for equality in order to * cache computations. */ bool operator==(const ConstantInterval &other) const; + + /** In-place versions of the arithmetic operators below. */ + // @{ + void operator+=(const ConstantInterval &other); + void operator-=(const ConstantInterval &other); + void operator*=(const ConstantInterval &other); + void operator/=(const ConstantInterval &other); + // @} + + /** Track what happens if a constant integer interval is forced to fit into + * a concrete integer type. */ + void cast_to(Type t); + + /** Get constant integer bounds on a type. */ + static ConstantInterval bounds_of_type(Type); }; +/** Arithmetic operators on ConstantIntervals. The resulting interval contains + * all possible values of the operator applied to any two elements of the + * argument intervals. Note that these operator on unbounded integers. If you + * are applying this to concrete small integer types, you will need to manually + * cast the constant interval back to the desired type to model the effect of + * overflow. */ +// @{ +ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval abs(const ConstantInterval &a); +// @} } // namespace Internal + +/** Cast operators for ConstantIntervals. These ones have to live out in + * Halide::, to avoid C++ name lookup confusion with the Halide::cast variants + * that take Exprs. */ +// @{ +Internal::ConstantInterval cast(Type t, const Internal::ConstantInterval &a); +Internal::ConstantInterval saturating_cast(Type t, const Internal::ConstantInterval &a); +// @} + } // namespace Halide #endif diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index 33d11ccb8d06..66d12f0efc4f 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -778,7 +778,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { // There are other PureExterns we don't bother with (e.g. fast_inverse_f32)... // just fall thru and take the general case. 
- debug(2) << "Simplifier: unhandled PureExtern: " << op->name; + debug(2) << "Simplifier: unhandled PureExtern: " << op->name << "\n"; } else if (op->is_intrinsic(Call::signed_integer_overflow)) { clear_bounds_info(bounds); } else if (op->is_intrinsic(Call::concat_bits) && op->args.size() == 1) { diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp index abdbaa9502c3..ffd5cf008716 100644 --- a/test/correctness/lossless_cast.cpp +++ b/test/correctness/lossless_cast.cpp @@ -7,8 +7,8 @@ int check_lossless_cast(const Type &t, const Expr &in, const Expr &correct) { Expr result = lossless_cast(t, in); if (!equal(result, correct)) { std::cout << "Incorrect lossless_cast result:\nlossless_cast(" - << t << ", " << in << ") gave: " << result - << " but expected was: " << correct << "\n"; + << t << ", " << in << ") gave:\n " << result + << " but expected was:\n " << correct << "\n"; return 1; } return 0; @@ -19,9 +19,11 @@ int lossless_cast_test() { Type u8 = UInt(8); Type u16 = UInt(16); Type u32 = UInt(32); + // Type u64 = UInt(64); Type i8 = Int(8); Type i16 = Int(16); Type i32 = Int(32); + Type i64 = Int(64); Type u8x = UInt(8, 4); Type u16x = UInt(16, 4); Type u32x = UInt(32, 4); @@ -52,14 +54,231 @@ int lossless_cast_test() { e = VectorReduce::make(VectorReduce::Add, cast(u32x, var_u8x), 1); res |= check_lossless_cast(u16, e, VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1)); - return res; + e = cast(u32, var_u8) - 16; + res |= check_lossless_cast(u16, e, Expr()); + + e = cast(u32, var_u8) + 16; + res |= check_lossless_cast(u16, e, cast(u16, var_u8) + 16); + + e = 16 - cast(u32, var_u8); + res |= check_lossless_cast(u16, e, Expr()); + + e = 16 + cast(u32, var_u8); + res |= check_lossless_cast(u16, e, 16 + cast(u16, var_u8)); + + // Check one where the target type is unsigned but there's a signed addition + // (that can't overflow) + e = cast(i64, cast(u16, var_u8) + cast(i32, 17)); + res |= check_lossless_cast(u32, e, cast(u32, cast(u16, var_u8)) + cast(u32, 17)); + + // Check one where the target type is unsigned but there's a signed subtract + // (that can overflow). 
It's not safe to enter the i16 sub + e = cast(i64, cast(i16, 10) - cast(i16, 17)); + res |= check_lossless_cast(u32, e, Expr()); + + e = cast(i64, 1024) * cast(i64, 1024) * cast(i64, 1024); + res |= check_lossless_cast(i32, e, (cast(i32, 1024) * 1024) * 1024); + + return 0; + + // return res; } -int main() { +constexpr int size = 1024; +Buffer buf_u8(size, "buf_u8"); +Buffer buf_i8(size, "buf_i8"); +Var x{"x"}; + +Expr random_expr(std::mt19937 &rng) { + std::vector exprs; + // Add some atoms + exprs.push_back(cast((uint8_t)rng())); + exprs.push_back(cast((int8_t)rng())); + exprs.push_back(cast((uint8_t)rng())); + exprs.push_back(cast((int8_t)rng())); + exprs.push_back(buf_u8(x)); + exprs.push_back(buf_i8(x)); + + // Make random combinations of them + while (true) { + Expr e; + int i1 = rng() % exprs.size(); + int i2 = rng() % exprs.size(); + int op = rng() % 7; + Expr e1 = exprs[i1]; + Expr e2 = cast(e1.type(), exprs[i2]); + bool may_widen = e1.type().bits() < 64; + switch (op) { + case 0: + if (may_widen) { + e = cast(e1.type().widen(), e1); + } + break; + case 1: + if (may_widen) { + e = cast(Int(e1.type().bits() * 2), e1); + } + break; + case 2: + e = e1 + e2; + break; + case 3: + e = e1 - e2; + break; + case 4: + e = e1 * e2; + break; + case 5: + e = e1 / e2; + break; + case 6: + switch (rng() % 10) { + case 0: + if (may_widen) { + e = widening_add(e1, e2); + } + break; + case 1: + if (may_widen) { + e = widening_sub(e1, e2); + } + break; + case 2: + if (may_widen) { + e = widening_mul(e1, e2); + } + break; + case 3: + e = halving_add(e1, e2); + break; + case 4: + e = rounding_halving_add(e1, e2); + break; + case 5: + e = halving_sub(e1, e2); + break; + case 6: + e = saturating_add(e1, e2); + break; + case 7: + e = saturating_sub(e1, e2); + break; + case 8: + e = count_leading_zeros(e1); + break; + case 9: + e = count_trailing_zeros(e1); + break; + } + } + + if (!e.defined()) { + continue; + } + + // Stop when we get to 64 bits, but probably don't stop on a widening + // cast, because that'll just get trivially stripped. + if (e.type().bits() == 64 && (op > 1 || ((rng() & 7) == 0))) { + return e; + } + + exprs.push_back(e); + } +} + +class CheckForIntOverflow : public IRMutator { + using IRMutator::visit; + + Expr visit(const Call *op) override { + if (op->is_intrinsic(Call::signed_integer_overflow)) { + found_overflow = true; + return make_zero(op->type); + } else { + return IRMutator::visit(op); + } + } + +public: + bool found_overflow = false; +}; + +bool found_error = false; + +int test_one(uint32_t seed) { + std::mt19937 rng{seed}; + + buf_u8.fill(rng); + buf_i8.fill(rng); + + Expr e1 = random_expr(rng); + Type target; + std::vector target_types = {UInt(32), Int(32), UInt(16), Int(16)}; + target = target_types[rng() % target_types.size()]; + Expr e2 = lossless_cast(target, e1); + if (!e2.defined()) { + return 0; + } + + Func f; + f(x) = {cast(e1), cast(e2)}; + f.vectorize(x, 4, TailStrategy::RoundUp); + + // std::cout << e1 << " to " << target << "\n -> " << e2 << "\n -> " << simplify(e2) << "\n"; + // std::cout << "\n\n\n--------------------\n\n\n"; + Buffer out1(size), out2(size); + Pipeline p(f); + CheckForIntOverflow checker; + p.add_custom_lowering_pass(&checker, nullptr); + p.realize({out1, out2}); + + if (checker.found_overflow) { + // We don't do anything in the expression generator to avoid signed + // integer overflow, so just skip anything with signed integer overflow. 
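        // (Overflow shows up here because, by the time this custom lowering
        // pass runs, arithmetic that provably overflows a signed type has been
        // replaced with a signed_integer_overflow() intrinsic -- e.g.,
        // hypothetically, an expression that constant-folds to Int(32) max
        // plus one -- which is exactly what CheckForIntOverflow looks for.)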
+ return 0; + } + + for (int x = 0; x < size; x++) { + if (out1(x) != out2(x)) { + std::cout + << "seed = " << seed << "\n" + << "x = " << x << "\n" + << "buf_u8 = " << (int)buf_u8(x) << "\n" + << "buf_i8 = " << (int)buf_i8(x) << "\n" + << "out1 = " << out1(x) << "\n" + << "out2 = " << out2(x) << "\n" + << "Original: " << e1 << "\n" + << "Lossless cast: " << e2 << "\n"; + return 1; + } + } + + return 0; +} + +int fuzz_test(uint32_t root_seed) { + std::mt19937 seed_generator(root_seed); + + std::cout << "Fuzz testing with root seed " << root_seed << "\n"; + for (int i = 0; i < 1000; i++) { + if (test_one(seed_generator())) { + return 1; + } + } + return 0; +} + +int main(int argc, char **argv) { + if (argc == 2) { + return test_one(atoi(argv[1])); + } if (lossless_cast_test()) { - printf("lossless_cast test failed!\n"); + std::cout << "lossless_cast test failed!\n"; + return 1; + } + if (fuzz_test(time(NULL))) { + std::cout << "lossless_cast fuzz test failed!\n"; return 1; } - printf("Success!\n"); + std::cout << "Success!\n"; return 0; } From e0f9f8e28fe1484c0774537eb9123dbb10ea43af Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 21 Mar 2024 10:27:52 -0700 Subject: [PATCH 03/33] Fix ARM and HVX instruction selection Also added more TODOs --- src/Bounds.cpp | 133 +++++++++++++++++++++++ src/Bounds.h | 3 + src/CodeGen_ARM.cpp | 64 +++++------ src/FindIntrinsics.cpp | 18 +++- src/HexagonOptimize.cpp | 15 +-- src/IRMatch.h | 46 ++++++++ src/IROperator.cpp | 144 ++----------------------- src/Interval.cpp | 10 +- src/Interval.h | 3 + test/correctness/simd_op_check_arm.cpp | 23 ++-- 10 files changed, 260 insertions(+), 199 deletions(-) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 16fd69f3e8fb..1bb7271babdd 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1854,6 +1854,139 @@ Interval bounds_of_expr_in_scope(const Expr &expr, const Scope &scope, return bounds_of_expr_in_scope_with_indent(expr, scope, fb, const_bound, 0); } +// TODO: This is not the best place for this code. Also it should be a visitor. Maybe ConstantBounds.cpp.h +ConstantInterval constant_integer_bounds(const Expr &e) { + internal_assert(e.defined()); + + auto ret = [&]() { + // Compute the bounds of each IR node from the bounds of its args. Math + // on ConstantInterval is in terms of infinite integers, so any op that + // can overflow needs to cast the resulting interval back to the output + // type. + if (const UIntImm *op = e.as()) { + if (Int(64).can_represent(op->value)) { + return ConstantInterval::single_point((int64_t)(op->value)); + } else { + return ConstantInterval::everything(); + } + } else if (const IntImm *op = e.as()) { + return ConstantInterval::single_point(op->value); + } else if (const Add *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) + constant_integer_bounds(op->b)); + } else if (const Sub *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) - constant_integer_bounds(op->b)); + } else if (const Mul *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) * constant_integer_bounds(op->b)); + } else if (const Div *op = e.as
()) { + // Can overflow when dividing type.min() by -1 + return cast(op->type, constant_integer_bounds(op->a) / constant_integer_bounds(op->b)); + } else if (const Min *op = e.as()) { + return min(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); + } else if (const Max *op = e.as()) { + return max(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); + } else if (const Cast *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->value)); + } else if (const Broadcast *op = e.as()) { + return constant_integer_bounds(op->value); + } else if (const VectorReduce *op = e.as()) { + int f = op->value.type().lanes() / op->type.lanes(); + ConstantInterval factor(f, f); + ConstantInterval arg_bounds = constant_integer_bounds(op->value); + switch (op->op) { + case VectorReduce::Add: + return cast(op->type, arg_bounds * factor); + case VectorReduce::SaturatingAdd: + return saturating_cast(op->type, arg_bounds * factor); + case VectorReduce::Min: + case VectorReduce::Max: + case VectorReduce::And: + case VectorReduce::Or: + return arg_bounds; + default:; + } + } else if (const Shuffle *op = e.as()) { + ConstantInterval arg_bounds = constant_integer_bounds(op->vectors[0]); + for (size_t i = 1; i < op->vectors.size(); i++) { + arg_bounds.include(constant_integer_bounds(op->vectors[i])); + } + return arg_bounds; + } else if (const Call *op = e.as()) { + // For all intrinsics that can't possibly overflow, we don't need the + // final cast. + if (op->is_intrinsic(Call::abs)) { + return abs(constant_integer_bounds(op->args[0])); + } else if (op->is_intrinsic(Call::absd)) { + return abs(constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1])); + } else if (op->is_intrinsic(Call::count_leading_zeros) || + op->is_intrinsic(Call::count_trailing_zeros)) { + // Conservatively just say it's the potential number of zeros in the type. + return ConstantInterval(0, op->args[0].type().bits()); + } else if (op->is_intrinsic(Call::halving_add)) { + return (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1])) / + ConstantInterval(2, 2); + } else if (op->is_intrinsic(Call::halving_sub)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1])) / + ConstantInterval(2, 2)); + } else if (op->is_intrinsic(Call::rounding_halving_add)) { + return (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]) + + ConstantInterval(1, 1)) / + ConstantInterval(2, 2); + } else if (op->is_intrinsic(Call::saturating_add)) { + return saturating_cast(op->type, + (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::saturating_sub)) { + return saturating_cast(op->type, + (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widening_add)) { + return constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widening_sub)) { + // widening ops can't overflow ... 
+ return constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widening_mul)) { + return constant_integer_bounds(op->args[0]) * + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widen_right_add)) { + // but the widen_right versions can overflow + return cast(op->type, (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widen_right_sub)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widen_right_mul)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) * + constant_integer_bounds(op->args[1]))); + } + // We could include the various shifting intrinsics here too, but we'd + // have to check for the sign on the second argument. + // TODO: widening_shift_left is important + } + + return ConstantInterval::bounds_of_type(e.type()); + }(); + + // debug(0) << e << " -> " << ret.min_defined << " " << ret.min << " " << ret.max_defined << " " << ret.max << "\n"; + + if (ret.min_defined) { + internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) && + (!ret.max_defined || e.type().can_represent(ret.max))) + << "Expr: " << e << "\n" + << " min_defined = " << ret.min_defined << "\n" + << " max_defined = " << ret.max_defined << "\n" + << " min = " << ret.min << "\n" + << " max = " << ret.max << "\n"; + } + + return ret; +} + Region region_union(const Region &a, const Region &b) { internal_assert(a.size() == b.size()) << "Mismatched dimensionality in region union\n"; Region result; diff --git a/src/Bounds.h b/src/Bounds.h index bafa42ecda1a..a0655bcba636 100644 --- a/src/Bounds.h +++ b/src/Bounds.h @@ -48,6 +48,9 @@ Expr find_constant_bound(const Expr &e, Direction d, * +/-inf. */ Interval find_constant_bounds(const Expr &e, const Scope &scope); +// TODO: comment +ConstantInterval constant_integer_bounds(const Expr &e); + /** Represents the bounds of a region of arbitrary dimension. Zero * dimensions corresponds to a scalar region. */ struct Box { diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 7852532183bf..a3a836305371 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -942,50 +942,42 @@ void CodeGen_ARM::visit(const Add *op) { Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc"); Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc"); - // clang-format off + Expr ma_i8 = widening_mul(a_i8, ac_i8); + Expr mb_i8 = widening_mul(b_i8, bc_i8); + Expr mc_i8 = widening_mul(c_i8, cc_i8); + Expr md_i8 = widening_mul(d_i8, dc_i8); + + Expr ma_u8 = widening_mul(a_u8, ac_u8); + Expr mb_u8 = widening_mul(b_u8, bc_u8); + Expr mc_u8 = widening_mul(c_u8, cc_u8); + Expr md_u8 = widening_mul(d_u8, dc_u8); + static const Pattern patterns[] = { - // If we had better normalization, we could drastically reduce the number of patterns here. // Signed variants. 
- {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"}, - {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)}, - {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)}, - {init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)}, - {init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)}, - // Signed variants (associative). - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"}, - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)}, - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)}, - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)}, - {init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)}, + {(init_i32 + widening_add(ma_i8, mb_i8)) + widening_add(mc_i8, md_i8), "dot_product"}, + {init_i32 + (widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8)), "dot_product"}, + {widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8), "dot_product"}, + // Unsigned variants. - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"}, - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)}, - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)}, - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)}, - {init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)}, - // Unsigned variants (associative). 
- {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"}, - {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)}, - {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)}, - {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)}, - {init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)}, + {(init_u32 + widening_add(ma_u8, mb_u8)) + widening_add(mc_u8, md_u8), "dot_product"}, + {init_u32 + (widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8)), "dot_product"}, + {widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8), "dot_product"}, }; - // clang-format on std::map matches; for (const Pattern &p : patterns) { if (expr_match(p.pattern, op, matches)) { - Expr init = matches["init"]; - Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]}); - // Coefficients can be 1 if not in the pattern. - Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes())); - // This hideous code pattern implements fetching a - // default value if the map doesn't contain a key. - Expr _ac = matches.try_emplace("ac", one).first->second; - Expr _bc = matches.try_emplace("bc", one).first->second; - Expr _cc = matches.try_emplace("cc", one).first->second; - Expr _dc = matches.try_emplace("dc", one).first->second; - Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc}); + Expr init; + auto it = matches.find("init"); + if (it == matches.end()) { + init = make_zero(op->type); + } else { + init = it->second; + } + Expr values = Shuffle::make_interleave({matches["a"], matches["b"], + matches["c"], matches["d"]}); + Expr coeffs = Shuffle::make_interleave({matches["ac"], matches["bc"], + matches["cc"], matches["dc"]}); value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs}); if (value) { return; diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index d453d0134c29..e8e8b10b2f09 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -77,6 +77,8 @@ bool no_overflow(Type t) { return t.is_float() || no_overflow_int(t); } +// TODO: Can I delete this now and just rely on lossless cast? + // If there's a widening add or subtract in the first e.type().bits() / 2 - 1 // levels down a tree of adds or subtracts, we know there's enough headroom for // another add without overflow. For example, it is safe to add to @@ -810,6 +812,12 @@ class FindIntrinsics : public IRMutator { // We only care about integers, this should be trivially true. is_x_same_int_or_uint) || + // widening_add(x + widen(y), widen(z)) -> widening_add(x, widening_add(y, z)) + rewrite(widening_add(widen_right_add(x, y), widen(z)), + widening_add(x, widening_add(y, z))) || + rewrite(widening_add(widen(z), widen_right_add(x, y)), + widening_add(x, widening_add(y, z))) || + // Saturating patterns. rewrite(saturating_cast(op->type, widening_add(x, y)), saturating_add(x, y), @@ -908,13 +916,16 @@ class FindIntrinsics : public IRMutator { } // TODO: do we want versions of widen_right_add here? 
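    // A concrete reading of the widen_right_add rewrites added above (types
    // chosen for illustration): with x an i16 and y, z both i8,
    //   widening_add(widen_right_add(x, y), widen(z))
    // becomes
    //   widening_add(x, widening_add(y, z)),
    // i.e. an i32 equal to x + y + z, provided the original widen_right_add
    // did not overflow its 16-bit type. Regrouping the narrow addends this way
    // presumably helps the dot-product style patterns elsewhere in this patch
    // match them.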
- if (op->is_intrinsic(Call::shift_right) || op->is_intrinsic(Call::shift_left)) { + if (op->is_intrinsic(Call::shift_right) || + op->is_intrinsic(Call::shift_left)) { // Try to turn this into a widening shift. internal_assert(op->args.size() == 2); Expr a_narrow = lossless_narrow(op->args[0]); Expr b_narrow = lossless_narrow(op->args[1]); if (a_narrow.defined() && b_narrow.defined()) { - Expr result = op->is_intrinsic(Call::shift_left) ? widening_shift_left(a_narrow, b_narrow) : widening_shift_right(a_narrow, b_narrow); + Expr result = op->is_intrinsic(Call::shift_left) ? + widening_shift_left(a_narrow, b_narrow) : + widening_shift_right(a_narrow, b_narrow); if (result.type() != op->type) { result = Cast::make(op->type, result); } @@ -928,7 +939,8 @@ class FindIntrinsics : public IRMutator { } } - if (op->is_intrinsic(Call::rounding_shift_left) || op->is_intrinsic(Call::rounding_shift_right)) { + if (op->is_intrinsic(Call::rounding_shift_left) || + op->is_intrinsic(Call::rounding_shift_right)) { // Try to turn this into a widening shift. internal_assert(op->args.size() == 2); Expr a_narrow = lossless_narrow(op->args[0]); diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index f11fa3348399..7aabdb699b83 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -375,23 +375,24 @@ Expr unbroadcast_lossless_cast(Type ty, Expr x) { // expressions where we pretend the op to be multiplied by 1. int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count, vector &mpys, Expr &rest) { + if ((int)mpys.size() >= max_mpy_count) { rest = rest.defined() ? Add::make(rest, op) : op; return 0; } // If the add is also widening, remove the cast. + Expr stripped = op; int mpy_bits = std::max(a_ty.bits(), b_ty.bits()) * 2; - Expr maybe_mul = op; if (op.type().bits() == mpy_bits * 2) { if (const Cast *cast = op.as()) { if (cast->value.type().bits() == mpy_bits) { - maybe_mul = cast->value; + stripped = cast->value; } } } - maybe_mul = as_mul(maybe_mul); + Expr maybe_mul = as_mul(stripped); if (maybe_mul.defined()) { const Mul *mul = maybe_mul.as(); Expr a = unbroadcast_lossless_cast(a_ty, mul->a); @@ -408,17 +409,17 @@ int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count, return 1; } } - } else if (const Add *add = op.as()) { + } else if (const Add *add = stripped.as()) { int mpy_count = 0; mpy_count += find_mpy_ops(add->a, a_ty, b_ty, max_mpy_count, mpys, rest); mpy_count += find_mpy_ops(add->b, a_ty, b_ty, max_mpy_count, mpys, rest); return mpy_count; - } else if (const Call *add = Call::as_intrinsic(op, {Call::widening_add})) { + } else if (const Call *add = Call::as_intrinsic(stripped, {Call::widening_add})) { int mpy_count = 0; mpy_count += find_mpy_ops(cast(op.type(), add->args[0]), a_ty, b_ty, max_mpy_count, mpys, rest); mpy_count += find_mpy_ops(cast(op.type(), add->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest); return mpy_count; - } else if (const Call *wadd = Call::as_intrinsic(op, {Call::widen_right_add})) { + } else if (const Call *wadd = Call::as_intrinsic(stripped, {Call::widen_right_add})) { int mpy_count = 0; mpy_count += find_mpy_ops(wadd->args[0], a_ty, b_ty, max_mpy_count, mpys, rest); mpy_count += find_mpy_ops(cast(op.type(), wadd->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest); @@ -2262,4 +2263,4 @@ Stmt optimize_hexagon_instructions(Stmt s, const Target &t) { } } // namespace Internal -} // namespace Halide \ No newline at end of file +} // namespace Halide diff --git a/src/IRMatch.h b/src/IRMatch.h index 
a203fec51199..cf9312db0075 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -2080,6 +2080,52 @@ HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp +struct WidenOp { + struct pattern_tag {}; + A a; + + constexpr static uint32_t binds = bindings::mask; + + constexpr static IRNodeType min_node_type = IRNodeType::Cast; + constexpr static IRNodeType max_node_type = IRNodeType::Cast; + constexpr static bool canonical = A::canonical; + + template + HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { + if (e.node_type != Cast::_node_type) { + return false; + } + const Cast &op = (const Cast &)e; + return (e.type == op.value.type().widen() && + a.template match(*op.value.get(), state)); + } + template + HALIDE_ALWAYS_INLINE bool match(const WidenOp &op, MatcherState &state) const noexcept { + return a.template match(unwrap(op.a), state); + } + + HALIDE_ALWAYS_INLINE + Expr make(MatcherState &state, halide_type_t type_hint) const { + Expr e = a.make(state, {}); + return cast(e.type().widen(), std::move(e)); + } + + constexpr static bool foldable = false; +}; + +template +std::ostream &operator<<(std::ostream &s, const WidenOp &op) { + s << "widen(" << op.a << ")"; + return s; +} + +template +HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp { + assert_is_lvalue_if_expr(); + return {pattern_arg(a)}; +} + template struct SliceOp { struct pattern_tag {}; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index d27885a80d03..b27754d20a81 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -5,6 +5,7 @@ #include #include +#include "Bounds.h" #include "CSE.h" #include "Debug.h" #include "Func.h" @@ -435,138 +436,6 @@ Expr const_false(int w) { return make_zero(UInt(1, w)); } -namespace { - -ConstantInterval constant_integer_bounds(const Expr &e) { - auto ret = [&]() { - // Compute the bounds of each IR node from the bounds of its args. Math - // on ConstantInterval is in terms of infinite integers, so any op that - // can overflow needs to cast the resulting interval back to the output - // type. - if (const UIntImm *op = e.as()) { - if (Int(64).can_represent(op->value)) { - return ConstantInterval::single_point((int64_t)(op->value)); - } else { - return ConstantInterval::everything(); - } - } else if (const IntImm *op = e.as()) { - return ConstantInterval::single_point(op->value); - } else if (const Add *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->a) + constant_integer_bounds(op->b)); - } else if (const Sub *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->a) - constant_integer_bounds(op->b)); - } else if (const Mul *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->a) * constant_integer_bounds(op->b)); - } else if (const Div *op = e.as
()) { - // Can overflow when dividing type.min() by -1 - return cast(op->type, constant_integer_bounds(op->a) / constant_integer_bounds(op->b)); - } else if (const Min *op = e.as()) { - return min(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); - } else if (const Max *op = e.as()) { - return max(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); - } else if (const Cast *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->value)); - } else if (const Broadcast *op = e.as()) { - return constant_integer_bounds(op->value); - } else if (const VectorReduce *op = e.as()) { - int f = op->value.type().lanes() / op->type.lanes(); - ConstantInterval factor(f, f); - ConstantInterval arg_bounds = constant_integer_bounds(op->value); - switch (op->op) { - case VectorReduce::Add: - return cast(op->type, arg_bounds * factor); - case VectorReduce::SaturatingAdd: - return saturating_cast(op->type, arg_bounds * factor); - case VectorReduce::Min: - case VectorReduce::Max: - case VectorReduce::And: - case VectorReduce::Or: - return arg_bounds; - default:; - } - } else if (const Shuffle *op = e.as()) { - ConstantInterval arg_bounds = constant_integer_bounds(op->vectors[0]); - for (size_t i = 1; i < op->vectors.size(); i++) { - arg_bounds.include(constant_integer_bounds(op->vectors[i])); - } - return arg_bounds; - } else if (const Call *op = e.as()) { - // For all intrinsics that can't possibly overflow, we don't need the - // final cast. - if (op->is_intrinsic(Call::abs)) { - return abs(constant_integer_bounds(op->args[0])); - } else if (op->is_intrinsic(Call::absd)) { - return abs(constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1])); - } else if (op->is_intrinsic(Call::count_leading_zeros) || - op->is_intrinsic(Call::count_trailing_zeros)) { - // Conservatively just say it's the potential number of zeros in the type. - return ConstantInterval(0, op->args[0].type().bits()); - } else if (op->is_intrinsic(Call::halving_add)) { - return (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1])) / - ConstantInterval(2, 2); - } else if (op->is_intrinsic(Call::halving_sub)) { - return cast(op->type, (constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1])) / - ConstantInterval(2, 2)); - } else if (op->is_intrinsic(Call::rounding_halving_add)) { - return (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]) + - ConstantInterval(1, 1)) / - ConstantInterval(2, 2); - } else if (op->is_intrinsic(Call::saturating_add)) { - return saturating_cast(op->type, - (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::saturating_sub)) { - return saturating_cast(op->type, - (constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::widening_add)) { - return constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]); - } else if (op->is_intrinsic(Call::widening_sub)) { - // widening ops can't overflow ... 
- return constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1]); - } else if (op->is_intrinsic(Call::widening_mul)) { - return constant_integer_bounds(op->args[0]) * - constant_integer_bounds(op->args[1]); - } else if (op->is_intrinsic(Call::widen_right_add)) { - // but the widen_right versions can overflow - return cast(op->type, (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::widen_right_sub)) { - return cast(op->type, (constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::widen_right_mul)) { - return cast(op->type, (constant_integer_bounds(op->args[0]) * - constant_integer_bounds(op->args[1]))); - } - // We could include the various shifting intrinsics here too, but we'd - // have to check for the sign on the second argument. - } - - return ConstantInterval::bounds_of_type(e.type()); - }(); - - // debug(0) << e << " -> " << ret.min_defined << " " << ret.min << " " << ret.max_defined << " " << ret.max << "\n"; - - if (ret.min_defined) { - internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) && - (!ret.max_defined || e.type().can_represent(ret.max))) - << "Expr: " << e << "\n" - << " min_defined = " << ret.min_defined << "\n" - << " max_defined = " << ret.max_defined << "\n" - << " min = " << ret.min << "\n" - << " max = " << ret.max << "\n"; - } - - return ret; -} -} // namespace - Expr lossless_cast(Type t, Expr e) { if (!e.defined() || t == e.type()) { return e; @@ -616,10 +485,7 @@ Expr lossless_cast(Type t, Expr e) { // We'll just throw a cast around something, if the bounds are small // enough. ConstantInterval ci = constant_integer_bounds(e); - if (ci.is_bounded() && - t.can_represent(ci.max) && - t.can_represent(ci.min)) { - + if (ci.within(t)) { // There are certain IR nodes where if the result is expressible // using some type, and the args are expressible using that type, // then the operation can just be done in that type. @@ -641,6 +507,12 @@ Expr lossless_cast(Type t, Expr e) { if (a.defined() && b.defined()) { return a * b; } + } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add})) { + Expr a = lossless_cast(t, op->args[0]); + Expr b = lossless_cast(t, op->args[1]); + if (a.defined() && b.defined()) { + return a + b; + } } else if (const VectorReduce *op = e.as()) { if (op->op == VectorReduce::Add || op->op == VectorReduce::Min || diff --git a/src/Interval.cpp b/src/Interval.cpp index acdb4562f030..7ac52a036463 100644 --- a/src/Interval.cpp +++ b/src/Interval.cpp @@ -244,6 +244,13 @@ bool ConstantInterval::contains(int64_t x) const { (max_defined && x > max)); } +bool ConstantInterval::within(Type t) const { + return min_defined && + max_defined && + t.can_represent(min) && + t.can_represent(max); +} + ConstantInterval ConstantInterval::make_union(const ConstantInterval &a, const ConstantInterval &b) { ConstantInterval result = a; result.include(b); @@ -439,8 +446,7 @@ void ConstantInterval::operator/=(const ConstantInterval &other) { } void ConstantInterval::cast_to(Type t) { - if (!(max_defined && t.can_represent(max) && - min_defined && t.can_represent(min))) { + if (!within(t)) { // We have potential overflow or underflow, return the entire bounds of // the type. 
ConstantInterval type_bounds; diff --git a/src/Interval.h b/src/Interval.h index 6fbb8b81c0e1..af4942d5b2d7 100644 --- a/src/Interval.h +++ b/src/Interval.h @@ -161,6 +161,9 @@ struct ConstantInterval { /** Test if the interval contains a particular value */ bool contains(int64_t x) const; + /** Test if the interval lies with a particular type. */ + bool within(Type t) const; + /** Construct the smallest interval containing two intervals. */ static ConstantInterval make_union(const ConstantInterval &a, const ConstantInterval &b); diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp index e8762a6ea2d8..7f879fb1f09c 100644 --- a/test/correctness/simd_op_check_arm.cpp +++ b/test/correctness/simd_op_check_arm.cpp @@ -554,7 +554,7 @@ class SimdOpCheckARM : public SimdOpCheckTest { // use the forms with an accumulator check(arm32 ? "vpadal.s8" : "sadalp", 16, sum_(i16(in_i8(f * x + r)))); check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(i16(in_u8(f * x + r)))); - check(arm32 ? "vpadal.u8" : "uadalp*", 16, sum_(u16(in_u8(f * x + r)))); + check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(u16(in_u8(f * x + r)))); check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i16(f * x + r)))); check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u16(f * x + r)))); @@ -588,17 +588,10 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(i32(in_u8(f * x + r)))); check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(u32(in_u8(f * x + r)))); if (!arm32) { - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12); - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4)); - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) + i32(i8_4) * 12); - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) + i32(i8_3) * 9 + i32(i8_4) * 12); - check("sdot", 8, i32_1 + i32(i8_1) + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12); - - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12); - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4)); - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) + u32(u8_4) * 12); - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) + u32(u8_3) * 9 + u32(u8_4) * 12); - check("udot", 8, u32_1 + u32(u8_1) + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12); + check("udot", 8, u32(u8_1) * 200 + u32(u8_2) * 201 + u32(u8_3) * 202 + u32(u8_4) * 203); + // For signed, mapping the pattern above to sdot + // is a wash, because we can add more products + // of i8s together before they overflow an i16. } } else { check(arm32 ? "vpaddl.s8" : "saddlp", 8, sum_(i32(in_i8(f * x + r)))); @@ -614,15 +607,15 @@ class SimdOpCheckARM : public SimdOpCheckTest { // signed, because the intermediate type is u16 if (target.has_feature(Target::ARMDotProd)) { check(arm32 ? "vpadal.s16" : "sdot", 8, sum_(i32(in_i8(f * x + r)))); - check(arm32 ? "vpadal.u16" : "udot", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.s16" : "udot", 8, sum_(i32(in_u8(f * x + r)))); check(arm32 ? "vpadal.u16" : "udot", 8, sum_(u32(in_u8(f * x + r)))); } else { check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i8(f * x + r)))); - check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_u8(f * x + r)))); check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(u32(in_u8(f * x + r)))); } check(arm32 ? 
"vpadal.s32" : "sadalp", 4, sum_(i64(in_i16(f * x + r)))); - check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(i64(in_u16(f * x + r)))); + check(arm32 ? "vpadal.s32" : "sadalp", 4, sum_(i64(in_u16(f * x + r)))); check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(u64(in_u16(f * x + r)))); } From 214f0fd2eeda806138eb6e53d3cecced9d538393 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 22 Mar 2024 11:00:17 -0700 Subject: [PATCH 04/33] Using constant_integer_bounds to strengthen FindIntrinsics In particular, we can do better instruction selection for pmulhrsw --- src/Bounds.cpp | 27 ++++++-- src/CodeGen_X86.cpp | 8 ++- src/FindIntrinsics.cpp | 96 +++++++++++++++----------- src/Interval.cpp | 50 ++++++++++++++ src/Interval.h | 5 ++ test/correctness/simd_op_check_x86.cpp | 11 +++ 6 files changed, 151 insertions(+), 46 deletions(-) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 1bb7271babdd..ab4aaf388516 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1963,18 +1963,37 @@ ConstantInterval constant_integer_bounds(const Expr &e) { } else if (op->is_intrinsic(Call::widen_right_mul)) { return cast(op->type, (constant_integer_bounds(op->args[0]) * constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::shift_right)) { + return cast(op->type, constant_integer_bounds(op->args[0]) >> constant_integer_bounds(op->args[1])); + } else if (op->is_intrinsic(Call::shift_left)) { + return cast(op->type, constant_integer_bounds(op->args[0]) << constant_integer_bounds(op->args[1])); + } else if (op->is_intrinsic(Call::rounding_shift_right)) { + ConstantInterval ca = constant_integer_bounds(op->args[0]); + ConstantInterval cb = constant_integer_bounds(op->args[1]); + ConstantInterval rounding_term; + if (cb.min_defined && cb.min > 0) { + auto rounding_term = ConstantInterval(1, 1) << (cb - ConstantInterval(1, 1)); + // rounding shift right with a positive RHS can't overflow, + // so no cast required. + return (ca + rounding_term) >> cb; + } else if (cb.max_defined && cb.max <= 0) { + return cast(op->type, ca << (-cb)); + } else { + auto rounding_term = ConstantInterval(0, 1) << max(cb - ConstantInterval(1, 1), ConstantInterval(0, 0)); + return cast(op->type, (ca + rounding_term) >> cb); + } } - // We could include the various shifting intrinsics here too, but we'd - // have to check for the sign on the second argument. + // TODO: more intrinsics // TODO: widening_shift_left is important } return ConstantInterval::bounds_of_type(e.type()); }(); - // debug(0) << e << " -> " << ret.min_defined << " " << ret.min << " " << ret.max_defined << " " << ret.max << "\n"; + debug(0) << "constant_integer_bounds(" << e << ") =\n " + << ret.min_defined << " " << ret.min << " " << ret.max_defined << " " << ret.max << "\n"; - if (ret.min_defined) { + if (true) { internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) && (!ret.max_defined || e.type().can_represent(ret.max))) << "Expr: " << e << "\n" diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index b0df27af0f2f..47225d3e21ad 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -1,3 +1,4 @@ +#include "Bounds.h" #include "CodeGen_Internal.h" #include "CodeGen_Posix.h" #include "ConciseCasts.h" @@ -700,7 +701,12 @@ void CodeGen_X86::visit(const Call *op) { // Handle edge case of possible overflow. // See https://github.com/halide/Halide/pull/7129/files#r1008331426 // On AVX512 (and with enough lanes) we can use a mask register. 
- if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) { + ConstantInterval ca = constant_integer_bounds(a); + ConstantInterval cb = constant_integer_bounds(b); + if (!ca.contains(-32768) || !cb.contains(-32768)) { + // Overflow isn't possible + pmulhrs.accept(this); + } else if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) { Expr expr = select((a == i16_min) && (b == i16_min), i16_max, pmulhrs); expr.accept(this); } else { diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index e8e8b10b2f09..34328380ca00 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -1,4 +1,5 @@ #include "FindIntrinsics.h" +#include "Bounds.h" #include "CSE.h" #include "CodeGen_Internal.h" #include "ConciseCasts.h" @@ -550,6 +551,12 @@ class FindIntrinsics : public IRMutator { } } + // Do we need to worry about this cast overflowing? + ConstantInterval value_bounds = constant_integer_bounds(value); + debug(0) << "Bounds of " << Expr(op) << " are " << value_bounds.min << " " << value_bounds.min_defined << " " << value_bounds.max << " " << value_bounds.max_defined << "\n"; + bool no_overflow = (op->type.can_represent(op->value.type()) || + value_bounds.within(op->type)); + if (op->type.is_int() || op->type.is_uint()) { Expr lower = cast(value.type(), op->type.min()); Expr upper = cast(value.type(), op->type.max()); @@ -567,7 +574,7 @@ class FindIntrinsics : public IRMutator { auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits); auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint; auto x_y_same_sign = (is_int(x) && is_int(y)) || (is_uint(x) && is_uint(y)); - auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2); + // auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2); if ( // Saturating patterns rewrite(max(min(widening_add(x, y), upper), lower), @@ -669,32 +676,16 @@ class FindIntrinsics : public IRMutator { rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int && x_y_same_sign && c0 >= bits - 1) || - rewrite(shift_right(widening_mul(x, y), c0), - mul_shift_right(x, y, cast(unsigned_type, c0)), - is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) || - - rewrite(rounding_shift_right(widening_mul(x, y), c0), - rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), - is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) || - - // We can also match on smaller shifts if one of the args is - // narrow. We don't do this for signed (yet), because the - // saturation issue is tricky. - rewrite(shift_right(widening_mul(x, cast(op->type, y)), c0), - mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || - - rewrite(rounding_shift_right(widening_mul(x, cast(op->type, y)), c0), - rounding_mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + // We can also match whenever the cast can't overflow, so + // questions of saturation are irrelevant. 
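The relaxed `c0 >= 0` conditions lean on the `no_overflow` flag computed above: once the final narrowing cast is provably lossless, saturation no longer matters. A scalar model of the 16-bit operation being matched (illustrative, not the Halide lowering itself):

    #include <cstdint>

    // mul_shift_right(x, y, c) on int16: widen, multiply, shift, then narrow.
    int16_t mul_shift_right_i16(int16_t x, int16_t y, int shift) {
        int32_t wide = (int32_t)x * (int32_t)y;  // widening_mul
        return (int16_t)(wide >> shift);         // shift_right + narrowing cast
    }

    // With shift >= 16 the narrowed value always fits (the old c0 >= bits rule).
    // With a smaller shift it fits only when constant_integer_bounds proves the
    // shifted product stays within [-32768, 32767]; that is what no_overflow checks.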
+ (no_overflow && + (rewrite(shift_right(widening_mul(x, y), c0), + mul_shift_right(x, y, cast(unsigned_type, c0)), + is_x_same_int_or_uint && x_y_same_sign && c0 >= 0) || - rewrite(shift_right(widening_mul(cast(op->type, y), x), c0), - mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || - - rewrite(rounding_shift_right(widening_mul(cast(op->type, y), x), c0), - rounding_mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + rewrite(rounding_shift_right(widening_mul(x, y), c0), + rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), + is_x_same_int_or_uint && x_y_same_sign && c0 >= 0))) || // Halving subtract patterns rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1), @@ -1497,27 +1488,50 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) // one of the operands and the denominator by a constant. We only do this // if it isn't already full precision. This avoids infinite loops despite // "lowering" this to another mul_shift_right operation. - if (can_prove(q < full_q)) { - Expr missing_q = full_q - q; - internal_assert(missing_q.type().bits() == b.type().bits()); - Expr new_b = simplify(b << missing_q); - if (is_const(new_b) && can_prove(new_b >> missing_q == b)) { - return rounding_mul_shift_right(a, new_b, full_q); - } - Expr new_a = simplify(a << missing_q); - if (is_const(new_a) && can_prove(new_a >> missing_q == a)) { - return rounding_mul_shift_right(new_a, b, full_q); + ConstantInterval cq = constant_integer_bounds(q); + debug(0) << " cq = " << cq.min << " " << cq.min_defined << " " << cq.max << " " << cq.max_defined << "\n"; + if (cq.is_single_point() && cq.max >= 0 && cq.max < full_q) { + int missing_q = full_q - (int)cq.max; + + // Try to scale up the args by factors of two without overflowing + int a_shift = 0, b_shift = 0; + ConstantInterval ca = constant_integer_bounds(a); + debug(0) << " ca = " << ca.min << " " << ca.min_defined << " " << ca.max << " " << ca.max_defined << "\n"; + do { + ConstantInterval bigger = ca * ConstantInterval::single_point(2); + if (bigger.within(a.type()) && a_shift + b_shift < missing_q) { + ca = bigger; + a_shift++; + continue; + } + } while (false); + ConstantInterval cb = constant_integer_bounds(b); + debug(0) << " cb = " << cb.min << " " << cb.min_defined << " " << cb.max << " " << cb.max_defined << "\n"; + do { + ConstantInterval bigger = cb * ConstantInterval::single_point(2); + if (bigger.within(b.type()) && b_shift + b_shift < missing_q) { + cb = bigger; + b_shift++; + continue; + } + } while (false); + + debug(0) << "a_shift = " << a_shift << " b_shift = " << b_shift << " full_q = " << full_q << "\n"; + if (a_shift + b_shift == missing_q) { + return rounding_mul_shift_right(simplify(a << a_shift), simplify(b << b_shift), full_q); } } // If all else fails, just widen, shift, and narrow. 
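The shift redistribution performed just above, before this fallback, relies on an arithmetic identity: doubling one argument and increasing the shift by one leaves the rounded result unchanged, provided the doubled argument still fits its type. A small self-contained check of that identity for the i16 case with q = 14 and one operand known to fit in 15 bits (values are illustrative, not part of the patch):

    #include <cassert>
    #include <cstdint>

    // Unsaturated reference for rounding_mul_shift_right on widened operands.
    int32_t rmsr(int32_t a, int32_t b, int q) {
        return (a * b + (1 << (q - 1))) >> q;
    }

    int main() {
        for (int32_t a = -32768; a <= 32767; a += 257) {
            for (int32_t b = -16384; b <= 16383; b += 129) {
                // Shifting b left by the missing shift amount and using the
                // full-precision q == 15 form gives the same answer.
                assert(rmsr(a, b, 14) == rmsr(a, 2 * b, 15));
            }
        }
        return 0;
    }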
- Expr result = rounding_shift_right(widening_mul(a, b), q); - if (!can_prove(q >= a.type().bits())) { - result = saturating_narrow(result); + Expr wide_result = rounding_shift_right(widening_mul(a, b), q); + Expr narrowed = lossless_cast(a.type(), wide_result); + if (narrowed.defined()) { + debug(0) << " losslessly narrowed to " << narrowed << "\n"; + return narrowed; } else { - result = narrow(result); + debug(0) << " returning saturating_narrow(" << wide_result << ")\n"; + return saturating_narrow(wide_result); } - return result; } Expr lower_intrinsic(const Call *op) { diff --git a/src/Interval.cpp b/src/Interval.cpp index 7ac52a036463..6b15d77e2377 100644 --- a/src/Interval.cpp +++ b/src/Interval.cpp @@ -470,6 +470,19 @@ void ConstantInterval::cast_to(Type t) { } } +ConstantInterval ConstantInterval::operator-() const { + ConstantInterval result; + if (min_defined && min != INT64_MIN) { + result.max_defined = true; + result.max = -min; + } + if (max_defined) { + result.min_defined = true; + result.min = -max; + } + return result; +} + ConstantInterval ConstantInterval::bounds_of_type(Type t) { return cast(t, ConstantInterval::everything()); } @@ -536,6 +549,43 @@ ConstantInterval abs(const ConstantInterval &a) { return result; } +ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b) { + // Try to map this to a multiplication and a division + ConstantInterval mul, div; + constexpr int64_t one = 1; + if (b.min_defined) { + if (b.min >= 0 && b.min < 63) { + mul.min = one << b.min; + mul.min_defined = true; + div.max = one; + div.max_defined = true; + } else if (b.min > -63 && b.min <= 0) { + mul.min = one; + mul.min_defined = true; + div.max = one << (-b.min); + div.max_defined = true; + } + } + if (b.max_defined) { + if (b.max >= 0 && b.max < 63) { + mul.max = one << b.max; + mul.max_defined = true; + div.min = one; + div.min_defined = true; + } else if (b.max > -63 && b.max <= 0) { + mul.max = one; + mul.max_defined = true; + div.min = one << (-b.max); + div.min_defined = true; + } + } + return (a * mul) / div; +} + +ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b) { + return a << (-b); +} + } // namespace Internal ConstantInterval cast(Type t, const ConstantInterval &a) { diff --git a/src/Interval.h b/src/Interval.h index af4942d5b2d7..7fab81e461ba 100644 --- a/src/Interval.h +++ b/src/Interval.h @@ -180,6 +180,9 @@ struct ConstantInterval { void operator/=(const ConstantInterval &other); // @} + /** Negate an interval. */ + ConstantInterval operator-() const; + /** Track what happens if a constant integer interval is forced to fit into * a concrete integer type. 
*/ void cast_to(Type t); @@ -202,6 +205,8 @@ ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b) ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b); ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b); ConstantInterval abs(const ConstantInterval &a); +ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b); // @} } // namespace Internal diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index 8286bc68f9e6..4a81dfbdf926 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -253,6 +253,17 @@ class SimdOpCheckX86 : public SimdOpCheckTest { for (int w = 2; w <= 4; w++) { check("pmulhrsw", 4 * w, i16((i32(i16_1) * i32(i16_2) + 16384) >> 15)); check("pmulhrsw", 4 * w, i16_sat((i32(i16_1) * i32(i16_2) + 16384) >> 15)); + // Should be able to use the non-saturating form of pmulhrsw, + // because the second arg can't be -32768, so the i16_sat + // doesn't actually need to saturate. + check("pmulhrsw", 4 * w, i16_sat((i32(i16_1) * i32(i16_2 / 2) + 16384) >> 15)); + + // Should be able to use pmulhrsw despite the shift being too + // small, because there are enough bits of headroom to shift + // left one of the args: + check("pmulhrsw", 4 * w, i16_sat((i32(i16_1) * i32(i16_2 / 2) + 8192) >> 14)); + check("pmulhrsw", 4 * w, i16((i32(i16_1) * i32(i16_2 / 3) + 8192) >> 14)); + check("pabsb", 8 * w, abs(i8_1)); check("pabsw", 4 * w, abs(i16_1)); check("pabsd", 2 * w, abs(i32_1)); From 67855a578cddc417a1428bd162690644e6d87e71 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 25 Mar 2024 08:44:24 -0700 Subject: [PATCH 05/33] Move new classes to new files Also fix up Monotonic.cpp --- Makefile | 4 + src/Bounds.cpp | 152 ------------ src/Bounds.h | 3 - src/CMakeLists.txt | 4 + src/CodeGen_X86.cpp | 4 +- src/ConstantBounds.cpp | 161 +++++++++++++ src/ConstantInterval.cpp | 508 +++++++++++++++++++++++++++++++++++++++ src/FindIntrinsics.cpp | 9 +- src/IROperator.cpp | 4 +- src/IRPrinter.cpp | 31 +++ src/IRPrinter.h | 10 +- src/Interval.cpp | 467 ----------------------------------- src/Interval.h | 107 --------- src/Monotonic.cpp | 246 +++++-------------- src/Monotonic.h | 5 +- src/Type.cpp | 5 + src/Type.h | 8 + 17 files changed, 802 insertions(+), 926 deletions(-) create mode 100644 src/ConstantBounds.cpp create mode 100644 src/ConstantInterval.cpp diff --git a/Makefile b/Makefile index 17e8a80e1ca4..440b307a920e 100644 --- a/Makefile +++ b/Makefile @@ -477,6 +477,8 @@ SOURCE_FILES = \ CodeGen_WebGPU_Dev.cpp \ CodeGen_X86.cpp \ CompilerLogger.cpp \ + ConstantBounds.cpp \ + ConstantInterval.cpp \ CPlusPlusMangle.cpp \ CSE.cpp \ Debug.cpp \ @@ -671,6 +673,8 @@ HEADER_FILES = \ CompilerLogger.h \ ConciseCasts.h \ CPlusPlusMangle.h \ + ConstantBounds.h \ + ConstantInterval.h \ CSE.h \ Debug.h \ DebugArguments.h \ diff --git a/src/Bounds.cpp b/src/Bounds.cpp index ab4aaf388516..16fd69f3e8fb 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1854,158 +1854,6 @@ Interval bounds_of_expr_in_scope(const Expr &expr, const Scope &scope, return bounds_of_expr_in_scope_with_indent(expr, scope, fb, const_bound, 0); } -// TODO: This is not the best place for this code. Also it should be a visitor. 
Maybe ConstantBounds.cpp.h -ConstantInterval constant_integer_bounds(const Expr &e) { - internal_assert(e.defined()); - - auto ret = [&]() { - // Compute the bounds of each IR node from the bounds of its args. Math - // on ConstantInterval is in terms of infinite integers, so any op that - // can overflow needs to cast the resulting interval back to the output - // type. - if (const UIntImm *op = e.as()) { - if (Int(64).can_represent(op->value)) { - return ConstantInterval::single_point((int64_t)(op->value)); - } else { - return ConstantInterval::everything(); - } - } else if (const IntImm *op = e.as()) { - return ConstantInterval::single_point(op->value); - } else if (const Add *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->a) + constant_integer_bounds(op->b)); - } else if (const Sub *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->a) - constant_integer_bounds(op->b)); - } else if (const Mul *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->a) * constant_integer_bounds(op->b)); - } else if (const Div *op = e.as
()) { - // Can overflow when dividing type.min() by -1 - return cast(op->type, constant_integer_bounds(op->a) / constant_integer_bounds(op->b)); - } else if (const Min *op = e.as()) { - return min(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); - } else if (const Max *op = e.as()) { - return max(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); - } else if (const Cast *op = e.as()) { - return cast(op->type, constant_integer_bounds(op->value)); - } else if (const Broadcast *op = e.as()) { - return constant_integer_bounds(op->value); - } else if (const VectorReduce *op = e.as()) { - int f = op->value.type().lanes() / op->type.lanes(); - ConstantInterval factor(f, f); - ConstantInterval arg_bounds = constant_integer_bounds(op->value); - switch (op->op) { - case VectorReduce::Add: - return cast(op->type, arg_bounds * factor); - case VectorReduce::SaturatingAdd: - return saturating_cast(op->type, arg_bounds * factor); - case VectorReduce::Min: - case VectorReduce::Max: - case VectorReduce::And: - case VectorReduce::Or: - return arg_bounds; - default:; - } - } else if (const Shuffle *op = e.as()) { - ConstantInterval arg_bounds = constant_integer_bounds(op->vectors[0]); - for (size_t i = 1; i < op->vectors.size(); i++) { - arg_bounds.include(constant_integer_bounds(op->vectors[i])); - } - return arg_bounds; - } else if (const Call *op = e.as()) { - // For all intrinsics that can't possibly overflow, we don't need the - // final cast. - if (op->is_intrinsic(Call::abs)) { - return abs(constant_integer_bounds(op->args[0])); - } else if (op->is_intrinsic(Call::absd)) { - return abs(constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1])); - } else if (op->is_intrinsic(Call::count_leading_zeros) || - op->is_intrinsic(Call::count_trailing_zeros)) { - // Conservatively just say it's the potential number of zeros in the type. - return ConstantInterval(0, op->args[0].type().bits()); - } else if (op->is_intrinsic(Call::halving_add)) { - return (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1])) / - ConstantInterval(2, 2); - } else if (op->is_intrinsic(Call::halving_sub)) { - return cast(op->type, (constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1])) / - ConstantInterval(2, 2)); - } else if (op->is_intrinsic(Call::rounding_halving_add)) { - return (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]) + - ConstantInterval(1, 1)) / - ConstantInterval(2, 2); - } else if (op->is_intrinsic(Call::saturating_add)) { - return saturating_cast(op->type, - (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::saturating_sub)) { - return saturating_cast(op->type, - (constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::widening_add)) { - return constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]); - } else if (op->is_intrinsic(Call::widening_sub)) { - // widening ops can't overflow ... 
- return constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1]); - } else if (op->is_intrinsic(Call::widening_mul)) { - return constant_integer_bounds(op->args[0]) * - constant_integer_bounds(op->args[1]); - } else if (op->is_intrinsic(Call::widen_right_add)) { - // but the widen_right versions can overflow - return cast(op->type, (constant_integer_bounds(op->args[0]) + - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::widen_right_sub)) { - return cast(op->type, (constant_integer_bounds(op->args[0]) - - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::widen_right_mul)) { - return cast(op->type, (constant_integer_bounds(op->args[0]) * - constant_integer_bounds(op->args[1]))); - } else if (op->is_intrinsic(Call::shift_right)) { - return cast(op->type, constant_integer_bounds(op->args[0]) >> constant_integer_bounds(op->args[1])); - } else if (op->is_intrinsic(Call::shift_left)) { - return cast(op->type, constant_integer_bounds(op->args[0]) << constant_integer_bounds(op->args[1])); - } else if (op->is_intrinsic(Call::rounding_shift_right)) { - ConstantInterval ca = constant_integer_bounds(op->args[0]); - ConstantInterval cb = constant_integer_bounds(op->args[1]); - ConstantInterval rounding_term; - if (cb.min_defined && cb.min > 0) { - auto rounding_term = ConstantInterval(1, 1) << (cb - ConstantInterval(1, 1)); - // rounding shift right with a positive RHS can't overflow, - // so no cast required. - return (ca + rounding_term) >> cb; - } else if (cb.max_defined && cb.max <= 0) { - return cast(op->type, ca << (-cb)); - } else { - auto rounding_term = ConstantInterval(0, 1) << max(cb - ConstantInterval(1, 1), ConstantInterval(0, 0)); - return cast(op->type, (ca + rounding_term) >> cb); - } - } - // TODO: more intrinsics - // TODO: widening_shift_left is important - } - - return ConstantInterval::bounds_of_type(e.type()); - }(); - - debug(0) << "constant_integer_bounds(" << e << ") =\n " - << ret.min_defined << " " << ret.min << " " << ret.max_defined << " " << ret.max << "\n"; - - if (true) { - internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) && - (!ret.max_defined || e.type().can_represent(ret.max))) - << "Expr: " << e << "\n" - << " min_defined = " << ret.min_defined << "\n" - << " max_defined = " << ret.max_defined << "\n" - << " min = " << ret.min << "\n" - << " max = " << ret.max << "\n"; - } - - return ret; -} - Region region_union(const Region &a, const Region &b) { internal_assert(a.size() == b.size()) << "Mismatched dimensionality in region union\n"; Region result; diff --git a/src/Bounds.h b/src/Bounds.h index a0655bcba636..bafa42ecda1a 100644 --- a/src/Bounds.h +++ b/src/Bounds.h @@ -48,9 +48,6 @@ Expr find_constant_bound(const Expr &e, Direction d, * +/-inf. */ Interval find_constant_bounds(const Expr &e, const Scope &scope); -// TODO: comment -ConstantInterval constant_integer_bounds(const Expr &e); - /** Represents the bounds of a region of arbitrary dimension. Zero * dimensions corresponds to a scalar region. 
*/ struct Box { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 557574f284c4..2f410244d2b0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -45,6 +45,8 @@ set(HEADER_FILES CompilerLogger.h ConciseCasts.h CPlusPlusMangle.h + ConstantBounds.h + ConstantInterval.h CSE.h Debug.h DebugArguments.h @@ -219,6 +221,8 @@ set(SOURCE_FILES CodeGen_X86.cpp CompilerLogger.cpp CPlusPlusMangle.cpp + ConstantBounds.cpp + ConstantInterval.cpp CSE.cpp Debug.cpp DebugArguments.cpp diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 47225d3e21ad..3b40375342d5 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -1,7 +1,7 @@ -#include "Bounds.h" #include "CodeGen_Internal.h" #include "CodeGen_Posix.h" #include "ConciseCasts.h" +#include "ConstantBounds.h" #include "Debug.h" #include "IRMatch.h" #include "IRMutator.h" @@ -539,7 +539,7 @@ void CodeGen_X86::visit(const Cast *op) { // clang-format off static Pattern patterns[] = { - // This isn't rounding_multiply_quantzied(i16, i16, 15) because it doesn't + // This isn't rounding_mul_shift_right(i16, i16, 15) because it doesn't // saturate the result. {"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))}, diff --git a/src/ConstantBounds.cpp b/src/ConstantBounds.cpp new file mode 100644 index 000000000000..071fc63e5365 --- /dev/null +++ b/src/ConstantBounds.cpp @@ -0,0 +1,161 @@ +#include "ConstantBounds.h" +#include "IR.h" +#include "IROperator.h" +#include "IRPrinter.h" + +namespace Halide { +namespace Internal { + +ConstantInterval constant_integer_bounds(const Expr &e, const Scope &scope) { + internal_assert(e.defined()); + + auto ret = [&]() { + // Compute the bounds of each IR node from the bounds of its args. Math + // on ConstantInterval is in terms of infinite integers, so any op that + // can overflow needs to cast the resulting interval back to the output + // type. + if (const UIntImm *op = e.as()) { + if (Int(64).can_represent(op->value)) { + return ConstantInterval::single_point((int64_t)(op->value)); + } else { + return ConstantInterval::everything(); + } + } else if (const IntImm *op = e.as()) { + return ConstantInterval::single_point(op->value); + } else if (const Variable *op = e.as()) { + if (const auto *in = scope.find(op->name)) { + return *in; + } + } else if (const Add *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) + constant_integer_bounds(op->b)); + } else if (const Sub *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) - constant_integer_bounds(op->b)); + } else if (const Mul *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) * constant_integer_bounds(op->b)); + } else if (const Div *op = e.as
()) { + // Can overflow when dividing type.min() by -1 + return cast(op->type, constant_integer_bounds(op->a) / constant_integer_bounds(op->b)); + } else if (const Min *op = e.as()) { + return min(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); + } else if (const Max *op = e.as()) { + return max(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); + } else if (const Cast *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->value)); + } else if (const Broadcast *op = e.as()) { + return constant_integer_bounds(op->value); + } else if (const VectorReduce *op = e.as()) { + int f = op->value.type().lanes() / op->type.lanes(); + ConstantInterval factor(f, f); + ConstantInterval arg_bounds = constant_integer_bounds(op->value); + switch (op->op) { + case VectorReduce::Add: + return cast(op->type, arg_bounds * factor); + case VectorReduce::SaturatingAdd: + return saturating_cast(op->type, arg_bounds * factor); + case VectorReduce::Min: + case VectorReduce::Max: + case VectorReduce::And: + case VectorReduce::Or: + return arg_bounds; + default:; + } + } else if (const Shuffle *op = e.as()) { + ConstantInterval arg_bounds = constant_integer_bounds(op->vectors[0]); + for (size_t i = 1; i < op->vectors.size(); i++) { + arg_bounds.include(constant_integer_bounds(op->vectors[i])); + } + return arg_bounds; + } else if (const Call *op = e.as()) { + // For all intrinsics that can't possibly overflow, we don't need the + // final cast. + if (op->is_intrinsic(Call::abs)) { + return abs(constant_integer_bounds(op->args[0])); + } else if (op->is_intrinsic(Call::absd)) { + return abs(constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1])); + } else if (op->is_intrinsic(Call::count_leading_zeros) || + op->is_intrinsic(Call::count_trailing_zeros)) { + // Conservatively just say it's the potential number of zeros in the type. + return ConstantInterval(0, op->args[0].type().bits()); + } else if (op->is_intrinsic(Call::halving_add)) { + return (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1])) / + ConstantInterval(2, 2); + } else if (op->is_intrinsic(Call::halving_sub)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1])) / + ConstantInterval(2, 2)); + } else if (op->is_intrinsic(Call::rounding_halving_add)) { + return (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]) + + ConstantInterval(1, 1)) / + ConstantInterval(2, 2); + } else if (op->is_intrinsic(Call::saturating_add)) { + return saturating_cast(op->type, + (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::saturating_sub)) { + return saturating_cast(op->type, + (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widening_add)) { + return constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widening_sub)) { + // widening ops can't overflow ... 
+ return constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widening_mul)) { + return constant_integer_bounds(op->args[0]) * + constant_integer_bounds(op->args[1]); + } else if (op->is_intrinsic(Call::widen_right_add)) { + // but the widen_right versions can overflow + return cast(op->type, (constant_integer_bounds(op->args[0]) + + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widen_right_sub)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) - + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::widen_right_mul)) { + return cast(op->type, (constant_integer_bounds(op->args[0]) * + constant_integer_bounds(op->args[1]))); + } else if (op->is_intrinsic(Call::shift_right)) { + return cast(op->type, constant_integer_bounds(op->args[0]) >> constant_integer_bounds(op->args[1])); + } else if (op->is_intrinsic(Call::shift_left)) { + return cast(op->type, constant_integer_bounds(op->args[0]) << constant_integer_bounds(op->args[1])); + } else if (op->is_intrinsic(Call::rounding_shift_right)) { + ConstantInterval ca = constant_integer_bounds(op->args[0]); + ConstantInterval cb = constant_integer_bounds(op->args[1]); + ConstantInterval rounding_term; + if (cb.min_defined && cb.min > 0) { + auto rounding_term = ConstantInterval(1, 1) << (cb - ConstantInterval(1, 1)); + // rounding shift right with a positive RHS can't overflow, + // so no cast required. + return (ca + rounding_term) >> cb; + } else if (cb.max_defined && cb.max <= 0) { + return cast(op->type, ca << (-cb)); + } else { + auto rounding_term = ConstantInterval(0, 1) << max(cb - ConstantInterval(1, 1), ConstantInterval(0, 0)); + return cast(op->type, (ca + rounding_term) >> cb); + } + } + // TODO: more intrinsics + // TODO: widening_shift_left is important + } + + return ConstantInterval::bounds_of_type(e.type()); + }(); + + debug(0) << "constant_integer_bounds(" << e << ") =\n " << ret << "\n"; + + if (true) { + internal_assert((!ret.has_lower_bound() || e.type().can_represent(ret.min)) && + (!ret.has_upper_bound() || e.type().can_represent(ret.max))) + << "constant_bounds returned defined bounds that are not representable in " + << "the type of the Expr passed in.\n Expr: " << e << "\n Bounds: " << ret; + } + + return ret; +} + +} // namespace Internal +} // namespace Halide diff --git a/src/ConstantInterval.cpp b/src/ConstantInterval.cpp new file mode 100644 index 000000000000..03fbb692f540 --- /dev/null +++ b/src/ConstantInterval.cpp @@ -0,0 +1,508 @@ +#include "ConstantInterval.h" + +#include "Error.h" +#include "IROperator.h" + +namespace Halide { +namespace Internal { + +ConstantInterval::ConstantInterval() = default; + +ConstantInterval::ConstantInterval(int64_t min, int64_t max) + : min(min), max(max), min_defined(true), max_defined(true) { + internal_assert(min <= max); +} + +ConstantInterval ConstantInterval::everything() { + return ConstantInterval(); +} + +ConstantInterval ConstantInterval::single_point(int64_t x) { + return ConstantInterval(x, x); +} + +ConstantInterval ConstantInterval::bounded_below(int64_t min) { + ConstantInterval result(min, min); + result.max_defined = false; + return result; +} + +ConstantInterval ConstantInterval::bounded_above(int64_t max) { + ConstantInterval result(max, max); + result.min_defined = false; + return result; +} + +bool ConstantInterval::is_everything() const { + return !min_defined && !max_defined; +} + +bool ConstantInterval::is_single_point() const { + 
return min_defined && max_defined && min == max; +} + +bool ConstantInterval::is_single_point(int64_t x) const { + return min_defined && max_defined && min == x && max == x; +} + +bool ConstantInterval::has_upper_bound() const { + return max_defined; +} + +bool ConstantInterval::has_lower_bound() const { + return min_defined; +} + +bool ConstantInterval::is_bounded() const { + return has_upper_bound() && has_lower_bound(); +} + +bool ConstantInterval::operator==(const ConstantInterval &other) const { + if (min_defined != other.min_defined || max_defined != other.max_defined) { + return false; + } + return (!min_defined || min == other.min) && (!max_defined || max == other.max); +} + +void ConstantInterval::include(const ConstantInterval &i) { + if (max_defined && i.max_defined) { + max = std::max(max, i.max); + } else { + max_defined = false; + } + if (min_defined && i.min_defined) { + min = std::min(min, i.min); + } else { + min_defined = false; + } +} + +void ConstantInterval::include(int64_t x) { + if (max_defined) { + max = std::max(max, x); + } + if (min_defined) { + min = std::min(min, x); + } +} + +bool ConstantInterval::contains(int64_t x) const { + return !((min_defined && x < min) || + (max_defined && x > max)); +} + +ConstantInterval ConstantInterval::make_union(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result.include(b); + return result; +} + +// TODO: These were taken directly from the simplifier, so change the simplifier +// to use these instead of duplicating the code. +void ConstantInterval::operator+=(const ConstantInterval &other) { + min_defined = min_defined && + other.min_defined && + add_with_overflow(64, min, other.min, &min); + max_defined = max_defined && + other.max_defined && + add_with_overflow(64, max, other.max, &max); +} + +void ConstantInterval::operator-=(const ConstantInterval &other) { + min_defined = min_defined && + other.max_defined && + sub_with_overflow(64, min, other.max, &min); + max_defined = max_defined && + other.min_defined && + sub_with_overflow(64, max, other.min, &max); +} + +void ConstantInterval::operator*=(const ConstantInterval &other) { + ConstantInterval result; + + // Compute a possible extreme value of the product, setting the min/max + // defined flags if it's unbounded. + auto saturating_mul = [&](int64_t a, int64_t b) -> int64_t { + int64_t c; + if (mul_with_overflow(64, a, b, &c)) { + return c; + } else if ((a > 0) == (b > 0)) { + result.max_defined = false; + return INT64_MAX; + } else { + result.min_defined = false; + return INT64_MIN; + } + }; + + bool positive = min_defined && min > 0; + bool other_positive = other.min_defined && other.min > 0; + bool bounded = min_defined && max_defined; + bool other_bounded = other.min_defined && other.max_defined; + + if (bounded && other_bounded) { + // Both are bounded + result.min_defined = result.max_defined = true; + int64_t v1 = saturating_mul(min, other.min); + int64_t v2 = saturating_mul(min, other.max); + int64_t v3 = saturating_mul(max, other.min); + int64_t v4 = saturating_mul(max, other.max); + if (result.min_defined) { + result.min = std::min(std::min(v1, v2), std::min(v3, v4)); + } else { + result.min = 0; + } + if (result.max_defined) { + result.max = std::max(std::max(v1, v2), std::max(v3, v4)); + } else { + result.max = 0; + } + } else if ((max_defined && other_bounded && other_positive) || + (other.max_defined && bounded && positive)) { + // One side has a max, and the other side is bounded and positive + // (e.g. 
a constant). + result.max = saturating_mul(max, other.max); + if (!result.max_defined) { + result.max = 0; + } + } else if ((min_defined && other_bounded && other_positive) || + (other.min_defined && bounded && positive)) { + // One side has a min, and the other side is bounded and positive + // (e.g. a constant). + min = saturating_mul(min, other.min); + if (!result.min_defined) { + result.min = 0; + } + } + // TODO: what about the above two cases, but for multiplication by bounded + // and negative intervals? + + *this = result; +} + +void ConstantInterval::operator/=(const ConstantInterval &other) { + ConstantInterval result; + + result.min = INT64_MAX; + result.max = INT64_MIN; + + // Enumerate all possible values for the min and max and take the extreme values. + if (min_defined && other.min_defined && other.min != 0) { + int64_t v = div_imp(min, other.min); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + if (min_defined && other.max_defined && other.max != 0) { + int64_t v = div_imp(min, other.max); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + if (max_defined && other.max_defined && other.max != 0) { + int64_t v = div_imp(max, other.max); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + if (max_defined && other.min_defined && other.min != 0) { + int64_t v = div_imp(max, other.min); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + } + + // Define an int64_t zero just to pacify std::min and std::max + constexpr int64_t zero = 0; + + const bool other_positive = other.min_defined && other.min > 0; + const bool other_negative = other.max_defined && other.max < 0; + if ((other_positive && !other.max_defined) || + (other_negative && !other.min_defined)) { + // Take limit as other -> +/- infinity + result.min = std::min(result.min, zero); + result.max = std::max(result.max, zero); + } + + bool bounded_numerator = min_defined && max_defined; + + result.min_defined = ((min_defined && other_positive) || + (max_defined && other_negative)); + result.max_defined = ((max_defined && other_positive) || + (min_defined && other_negative)); + + // That's as far as we can get knowing the sign of the + // denominator. For bounded numerators, we additionally know + // that div can't make anything larger in magnitude, so we can + // take the intersection with that. + if (bounded_numerator && min != INT64_MIN) { + int64_t magnitude = std::max(max, -min); + if (result.min_defined) { + result.min = std::max(result.min, -magnitude); + } else { + result.min = -magnitude; + } + if (result.max_defined) { + result.max = std::min(result.max, magnitude); + } else { + result.max = magnitude; + } + result.min_defined = result.max_defined = true; + } + + // Finally we can provide a bound if the numerator and denominator are + // non-positive or non-negative. 
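A small usage sketch of the resulting quotient bounds (values are illustrative; assumes the declarations in the new ConstantInterval.h):

    #include "ConstantInterval.h"
    using Halide::Internal::ConstantInterval;

    void quotient_examples() {
        // Both endpoints known and the denominator strictly positive: the
        // endpoint quotients give [0, 5], and the non-negative / non-negative
        // rule keeps the lower bound at zero.
        ConstantInterval q1 = ConstantInterval(3, 10) / ConstantInterval(2, 5);
        // q1 == ConstantInterval(0, 5)

        // A denominator only known to be >= 1 still yields finite bounds for a
        // bounded non-negative numerator: the limit as the denominator grows is
        // zero, and division never increases the magnitude.
        ConstantInterval q2 = ConstantInterval(0, 100) / ConstantInterval::bounded_below(1);
        // q2 == ConstantInterval(0, 100)
    }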
+ bool numerator_non_negative = min_defined && min >= 0; + bool denominator_non_negative = other.min_defined && other.min >= 0; + bool numerator_non_positive = max_defined && max <= 0; + bool denominator_non_positive = other.max_defined && other.max <= 0; + if ((numerator_non_negative && denominator_non_negative) || + (numerator_non_positive && denominator_non_positive)) { + if (result.min_defined) { + result.min = std::max(result.min, zero); + } else { + result.min_defined = true; + result.min = 0; + } + } + if ((numerator_non_negative && denominator_non_positive) || + (numerator_non_positive && denominator_non_negative)) { + if (result.max_defined) { + result.max = std::min(result.max, zero); + } else { + result.max_defined = true; + result.max = 0; + } + } + + // Normalize the values if it's undefined + if (!result.min_defined) { + result.min = 0; + } + if (!result.max_defined) { + result.max = 0; + } + + *this = result; +} + +void ConstantInterval::operator+=(int64_t x) { + // TODO: Optimize this + *this += ConstantInterval(x, x); +} + +void ConstantInterval::operator-=(int64_t x) { + // TODO: Optimize this + *this -= ConstantInterval(x, x); +} + +void ConstantInterval::operator*=(int64_t x) { + // TODO: Optimize this + *this *= ConstantInterval(x, x); +} + +void ConstantInterval::operator/=(int64_t x) { + // TODO: Optimize this + *this /= ConstantInterval(x, x); +} + +void ConstantInterval::cast_to(const Type &t) { + if (!t.can_represent(*this)) { + // We have potential overflow or underflow, return the entire bounds of + // the type. + ConstantInterval type_bounds; + if (t.is_int()) { + if (t.bits() <= 64) { + type_bounds.min_defined = type_bounds.max_defined = true; + type_bounds.min = ((int64_t)(-1)) << (t.bits() - 1); + type_bounds.max = ~type_bounds.min; + } + } else if (t.is_uint()) { + type_bounds.min_defined = true; + type_bounds.min = 0; + if (t.bits() < 64) { + type_bounds.max_defined = true; + type_bounds.max = (((int64_t)(1)) << t.bits()) - 1; + } + } + // If it's not int or uint, we're setting this to a default-constructed + // ConstantInterval, which is everything. 
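A usage sketch of the overflow handling in cast_to (illustrative values; assumes the Type::can_represent(ConstantInterval) overload added in this commit):

    #include "ConstantInterval.h"
    #include "Type.h"
    using namespace Halide;
    using Halide::Internal::ConstantInterval;

    void cast_to_examples() {
        ConstantInterval a(-5, 300);
        a.cast_to(UInt(8));   // doesn't fit: a becomes the full uint8 range [0, 255]

        ConstantInterval b(-200, 10);
        b.cast_to(Int(8));    // doesn't fit: b becomes the full int8 range [-128, 127]

        ConstantInterval c(10, 20);
        c.cast_to(UInt(8));   // already representable: c is left as [10, 20]
    }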
+ *this = type_bounds; + } +} + +ConstantInterval ConstantInterval::operator-() const { + ConstantInterval result; + if (min_defined && min != INT64_MIN) { + result.max_defined = true; + result.max = -min; + } + if (max_defined) { + result.min_defined = true; + result.min = -max; + } + return result; +} + +ConstantInterval ConstantInterval::bounds_of_type(Type t) { + return cast(t, ConstantInterval::everything()); +} + +ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result += b; + return result; +} + +ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result -= b; + return result; +} + +ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result /= b; + return result; +} + +ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result *= b; + return result; +} + +ConstantInterval operator+(const ConstantInterval &a, int64_t b) { + return a + ConstantInterval(b, b); +} + +ConstantInterval operator-(const ConstantInterval &a, int64_t b) { + return a - ConstantInterval(b, b); +} + +ConstantInterval operator/(const ConstantInterval &a, int64_t b) { + return a / ConstantInterval(b, b); +} + +ConstantInterval operator*(const ConstantInterval &a, int64_t b) { + return a * ConstantInterval(b, b); +} + +ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + if (a.min_defined && b.min_defined && b.min < a.min) { + result.min = b.min; + } + if (a.max_defined && b.max_defined && b.max < a.max) { + result.max = b.max; + } + return result; +} + +ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + if (a.min_defined && b.min_defined && b.min > a.min) { + result.min = b.min; + } + if (a.max_defined && b.max_defined && b.max > a.max) { + result.max = b.max; + } + return result; +} + +ConstantInterval abs(const ConstantInterval &a) { + ConstantInterval result; + if (a.min_defined && a.max_defined && a.min != INT64_MIN) { + result.max_defined = true; + result.max = std::max(-a.min, a.max); + } + result.min_defined = true; + if (a.min_defined && a.min > 0) { + result.min = a.min; + } else { + result.min = 0; + } + + return result; +} + +ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b) { + // Try to map this to a multiplication and a division + ConstantInterval mul, div; + constexpr int64_t one = 1; + if (b.min_defined) { + if (b.min >= 0 && b.min < 63) { + mul.min = one << b.min; + mul.min_defined = true; + div.max = one; + div.max_defined = true; + } else if (b.min > -63 && b.min <= 0) { + mul.min = one; + mul.min_defined = true; + div.max = one << (-b.min); + div.max_defined = true; + } + } + if (b.max_defined) { + if (b.max >= 0 && b.max < 63) { + mul.max = one << b.max; + mul.max_defined = true; + div.min = one; + div.min_defined = true; + } else if (b.max > -63 && b.max <= 0) { + mul.max = one; + mul.max_defined = true; + div.min = one << (-b.max); + div.min_defined = true; + } + } + return (a * mul) / div; +} + +ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b) { + return a << (-b); +} + +} // namespace Internal + +using namespace Internal; + +ConstantInterval cast(Type t, const ConstantInterval &a) { + ConstantInterval result = a; + result.cast_to(t); + return result; +} + 
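The shift operators above reduce shifting to the existing multiply and divide logic. A usage sketch (illustrative values; assumes the declarations in the new ConstantInterval.h):

    #include "ConstantInterval.h"
    using Halide::Internal::ConstantInterval;

    void shift_examples() {
        // A shift count spanning zero splits into a multiply (for the positive
        // part) and a divide (for the negative part):
        ConstantInterval x(3, 5), s(-1, 2);
        ConstantInterval y = x << s;   // [3, 5] * [1, 4] / [1, 2] == [1, 20]
        // The true extremes are 3 >> 1 == 1 and 5 << 2 == 20, so the bound is tight here.

        // Right shift is a left shift by the negated count:
        ConstantInterval z = x >> ConstantInterval(1, 1);   // [1, 2]
    }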
+ConstantInterval saturating_cast(Type t, const ConstantInterval &a) { + ConstantInterval b = ConstantInterval::bounds_of_type(t); + + if (b.max_defined && a.min_defined && a.min > b.max) { + return ConstantInterval(b.max, b.max); + } else if (b.min_defined && a.max_defined && a.max < b.min) { + return ConstantInterval(b.min, b.min); + } + + ConstantInterval result = a; + result.max_defined = a.max_defined || b.max_defined; + if (a.max_defined) { + if (b.max_defined) { + result.max = std::min(a.max, b.max); + } else { + result.max = a.max; + } + } else if (b.max_defined) { + result.max = b.max; + } + result.min_defined = a.min_defined || b.min_defined; + if (a.min_defined) { + if (b.min_defined) { + result.min = std::max(a.min, b.min); + } else { + result.min = a.min; + } + } else if (b.min_defined) { + result.min = b.min; + } + return result; +} + +} // namespace Halide diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 34328380ca00..0008746d81aa 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -1,8 +1,8 @@ #include "FindIntrinsics.h" -#include "Bounds.h" #include "CSE.h" #include "CodeGen_Internal.h" #include "ConciseCasts.h" +#include "ConstantBounds.h" #include "IRMatch.h" #include "IRMutator.h" #include "Simplify.h" @@ -554,8 +554,9 @@ class FindIntrinsics : public IRMutator { // Do we need to worry about this cast overflowing? ConstantInterval value_bounds = constant_integer_bounds(value); debug(0) << "Bounds of " << Expr(op) << " are " << value_bounds.min << " " << value_bounds.min_defined << " " << value_bounds.max << " " << value_bounds.max_defined << "\n"; + bool no_overflow = (op->type.can_represent(op->value.type()) || - value_bounds.within(op->type)); + op->type.can_represent(value_bounds)); if (op->type.is_int() || op->type.is_uint()) { Expr lower = cast(value.type(), op->type.min()); @@ -1499,7 +1500,7 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) debug(0) << " ca = " << ca.min << " " << ca.min_defined << " " << ca.max << " " << ca.max_defined << "\n"; do { ConstantInterval bigger = ca * ConstantInterval::single_point(2); - if (bigger.within(a.type()) && a_shift + b_shift < missing_q) { + if (a.type().can_represent(bigger) && a_shift + b_shift < missing_q) { ca = bigger; a_shift++; continue; @@ -1509,7 +1510,7 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) debug(0) << " cb = " << cb.min << " " << cb.min_defined << " " << cb.max << " " << cb.max_defined << "\n"; do { ConstantInterval bigger = cb * ConstantInterval::single_point(2); - if (bigger.within(b.type()) && b_shift + b_shift < missing_q) { + if (b.type().can_represent(bigger) && b_shift + b_shift < missing_q) { cb = bigger; b_shift++; continue; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index b27754d20a81..b857eb4947b8 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -5,8 +5,8 @@ #include #include -#include "Bounds.h" #include "CSE.h" +#include "ConstantBounds.h" #include "Debug.h" #include "Func.h" #include "IREquality.h" @@ -485,7 +485,7 @@ Expr lossless_cast(Type t, Expr e) { // We'll just throw a cast around something, if the bounds are small // enough. ConstantInterval ci = constant_integer_bounds(e); - if (ci.within(t)) { + if (t.can_represent(ci)) { // There are certain IR nodes where if the result is expressible // using some type, and the args are expressible using that type, // then the operation can just be done in that type. 
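A usage sketch of the rule described in that comment, including the widening_add case added earlier in the series (the expressions here are made up for illustration):

    #include "Halide.h"
    using namespace Halide;
    using namespace Halide::Internal;

    void lossless_cast_example() {
        Var x;
        Expr a = cast<uint8_t>(x) / 4;        // constant integer bounds: [0, 63]
        Expr b = min(cast<uint8_t>(x), 3);    // constant integer bounds: [0, 3]
        Expr wide = widening_add(a, b);       // a UInt(16) expression, bounds [0, 66]

        // [0, 66] fits in UInt(8) and both args are already uint8, so this
        // succeeds and is equivalent to a + b computed in uint8.
        Expr narrow = lossless_cast(UInt(8), wide);

        // With full-range uint8 operands the bounds would be [0, 510], which does
        // not fit in UInt(8), so this particular rule would not apply.
    }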
diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index a186be1874d7..5d4c43304eb0 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -6,7 +6,9 @@ #include "AssociativeOpsTable.h" #include "Associativity.h" #include "Closure.h" +#include "ConstantInterval.h" #include "IROperator.h" +#include "Interval.h" #include "Module.h" #include "Target.h" #include "Util.h" @@ -446,6 +448,35 @@ std::ostream &operator<<(std::ostream &out, const Closure &c) { return out; } +namespace { +template +void emit_interval(std::ostream &out, const T &in) { + out << "["; + if (in.has_lower_bound()) { + out << in.min; + } else { + out << "-inf"; + } + out << ", "; + if (in.has_upper_bound()) { + out << in.max; + } else { + out << "inf"; + } + out << "]"; +} +} // namespace + +std::ostream &operator<<(std::ostream &out, const Interval &c) { + emit_interval(out, c); + return out; +} + +std::ostream &operator<<(std::ostream &out, const ConstantInterval &c) { + emit_interval(out, c); + return out; +} + IRPrinter::IRPrinter(ostream &s) : stream(s) { s.setf(std::ios::fixed, std::ios::floatfield); diff --git a/src/IRPrinter.h b/src/IRPrinter.h index 849e50b816f4..161960077b93 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -58,6 +58,8 @@ namespace Internal { struct AssociativePattern; struct AssociativeOp; class Closure; +struct Interval; +struct ConstantInterval; /** Emit a halide associative pattern on an output stream (such as std::cout) * in a human-readable form */ @@ -90,9 +92,15 @@ std::ostream &operator<<(std::ostream &stream, const LinkageType &); /** Emit a halide dimension type in human-readable format */ std::ostream &operator<<(std::ostream &stream, const DimType &); -/** Emit a Closure in human-readable format */ +/** Emit a Closure in human-readable form */ std::ostream &operator<<(std::ostream &out, const Closure &c); +/** Emit an Interval in human-readable form */ +std::ostream &operator<<(std::ostream &out, const Interval &c); + +/** Emit a ConstantInterval in human-readable form */ +std::ostream &operator<<(std::ostream &out, const ConstantInterval &c); + struct Indentation { int indent; }; diff --git a/src/Interval.cpp b/src/Interval.cpp index 6b15d77e2377..7d0cc41d44b9 100644 --- a/src/Interval.cpp +++ b/src/Interval.cpp @@ -159,472 +159,5 @@ Expr Interval::neg_inf_noinline() { return Interval::neg_inf_expr; } -ConstantInterval::ConstantInterval() = default; - -ConstantInterval::ConstantInterval(int64_t min, int64_t max) - : min(min), max(max), min_defined(true), max_defined(true) { - internal_assert(min <= max); -} - -ConstantInterval ConstantInterval::everything() { - return ConstantInterval(); -} - -ConstantInterval ConstantInterval::single_point(int64_t x) { - return ConstantInterval(x, x); -} - -ConstantInterval ConstantInterval::bounded_below(int64_t min) { - ConstantInterval result(min, min); - result.max_defined = false; - return result; -} - -ConstantInterval ConstantInterval::bounded_above(int64_t max) { - ConstantInterval result(max, max); - result.min_defined = false; - return result; -} - -bool ConstantInterval::is_everything() const { - return !min_defined && !max_defined; -} - -bool ConstantInterval::is_single_point() const { - return min_defined && max_defined && min == max; -} - -bool ConstantInterval::is_single_point(int64_t x) const { - return min_defined && max_defined && min == x && max == x; -} - -bool ConstantInterval::has_upper_bound() const { - return max_defined; -} - -bool ConstantInterval::has_lower_bound() const { - return min_defined; -} - -bool 
ConstantInterval::is_bounded() const { - return has_upper_bound() && has_lower_bound(); -} - -bool ConstantInterval::operator==(const ConstantInterval &other) const { - if (min_defined != other.min_defined || max_defined != other.max_defined) { - return false; - } - return (!min_defined || min == other.min) && (!max_defined || max == other.max); -} - -void ConstantInterval::include(const ConstantInterval &i) { - if (max_defined && i.max_defined) { - max = std::max(max, i.max); - } else { - max_defined = false; - } - if (min_defined && i.min_defined) { - min = std::min(min, i.min); - } else { - min_defined = false; - } -} - -void ConstantInterval::include(int64_t x) { - if (max_defined) { - max = std::max(max, x); - } - if (min_defined) { - min = std::min(min, x); - } -} - -bool ConstantInterval::contains(int64_t x) const { - return !((min_defined && x < min) || - (max_defined && x > max)); -} - -bool ConstantInterval::within(Type t) const { - return min_defined && - max_defined && - t.can_represent(min) && - t.can_represent(max); -} - -ConstantInterval ConstantInterval::make_union(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result.include(b); - return result; -} - -// TODO: These were taken directly from the simplifier, so change the simplifier -// to use these instead of duplicating the code. -void ConstantInterval::operator+=(const ConstantInterval &other) { - min_defined = min_defined && - other.min_defined && - add_with_overflow(64, min, other.min, &min); - max_defined = max_defined && - other.max_defined && - add_with_overflow(64, max, other.max, &max); -} - -void ConstantInterval::operator-=(const ConstantInterval &other) { - min_defined = min_defined && - other.max_defined && - sub_with_overflow(64, min, other.max, &min); - max_defined = max_defined && - other.min_defined && - sub_with_overflow(64, max, other.min, &max); -} - -void ConstantInterval::operator*=(const ConstantInterval &other) { - ConstantInterval result; - - // Compute a possible extreme value of the product, setting the min/max - // defined flags if it's unbounded. - auto saturating_mul = [&](int64_t a, int64_t b) -> int64_t { - int64_t c; - if (mul_with_overflow(64, a, b, &c)) { - return c; - } else if ((a > 0) == (b > 0)) { - result.max_defined = false; - return INT64_MAX; - } else { - result.min_defined = false; - return INT64_MIN; - } - }; - - bool positive = min_defined && min > 0; - bool other_positive = other.min_defined && other.min > 0; - bool bounded = min_defined && max_defined; - bool other_bounded = other.min_defined && other.max_defined; - - if (bounded && other_bounded) { - // Both are bounded - result.min_defined = result.max_defined = true; - int64_t v1 = saturating_mul(min, other.min); - int64_t v2 = saturating_mul(min, other.max); - int64_t v3 = saturating_mul(max, other.min); - int64_t v4 = saturating_mul(max, other.max); - if (result.min_defined) { - result.min = std::min(std::min(v1, v2), std::min(v3, v4)); - } else { - result.min = 0; - } - if (result.max_defined) { - result.max = std::max(std::max(v1, v2), std::max(v3, v4)); - } else { - result.max = 0; - } - } else if ((max_defined && other_bounded && other_positive) || - (other.max_defined && bounded && positive)) { - // One side has a max, and the other side is bounded and positive - // (e.g. a constant). 
- result.max = saturating_mul(max, other.max); - if (!result.max_defined) { - result.max = 0; - } - } else if ((min_defined && other_bounded && other_positive) || - (other.min_defined && bounded && positive)) { - // One side has a min, and the other side is bounded and positive - // (e.g. a constant). - min = saturating_mul(min, other.min); - if (!result.min_defined) { - result.min = 0; - } - } - // TODO: what about the above two cases, but for multiplication by bounded - // and negative intervals? - - *this = result; -} - -void ConstantInterval::operator/=(const ConstantInterval &other) { - ConstantInterval result; - - result.min = INT64_MAX; - result.max = INT64_MIN; - - // Enumerate all possible values for the min and max and take the extreme values. - if (min_defined && other.min_defined && other.min != 0) { - int64_t v = div_imp(min, other.min); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - if (min_defined && other.max_defined && other.max != 0) { - int64_t v = div_imp(min, other.max); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - if (max_defined && other.max_defined && other.max != 0) { - int64_t v = div_imp(max, other.max); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - if (max_defined && other.min_defined && other.min != 0) { - int64_t v = div_imp(max, other.min); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - // Define an int64_t zero just to pacify std::min and std::max - constexpr int64_t zero = 0; - - const bool other_positive = other.min_defined && other.min > 0; - const bool other_negative = other.max_defined && other.max < 0; - if ((other_positive && !other.max_defined) || - (other_negative && !other.min_defined)) { - // Take limit as other -> +/- infinity - result.min = std::min(result.min, zero); - result.max = std::max(result.max, zero); - } - - bool bounded_numerator = min_defined && max_defined; - - result.min_defined = ((min_defined && other_positive) || - (max_defined && other_negative)); - result.max_defined = ((max_defined && other_positive) || - (min_defined && other_negative)); - - // That's as far as we can get knowing the sign of the - // denominator. For bounded numerators, we additionally know - // that div can't make anything larger in magnitude, so we can - // take the intersection with that. - if (bounded_numerator && min != INT64_MIN) { - int64_t magnitude = std::max(max, -min); - if (result.min_defined) { - result.min = std::max(result.min, -magnitude); - } else { - result.min = -magnitude; - } - if (result.max_defined) { - result.max = std::min(result.max, magnitude); - } else { - result.max = magnitude; - } - result.min_defined = result.max_defined = true; - } - - // Finally we can provide a bound if the numerator and denominator are - // non-positive or non-negative. 
- bool numerator_non_negative = min_defined && min >= 0; - bool denominator_non_negative = other.min_defined && other.min >= 0; - bool numerator_non_positive = max_defined && max <= 0; - bool denominator_non_positive = other.max_defined && other.max <= 0; - if ((numerator_non_negative && denominator_non_negative) || - (numerator_non_positive && denominator_non_positive)) { - if (result.min_defined) { - result.min = std::max(result.min, zero); - } else { - result.min_defined = true; - result.min = 0; - } - } - if ((numerator_non_negative && denominator_non_positive) || - (numerator_non_positive && denominator_non_negative)) { - if (result.max_defined) { - result.max = std::min(result.max, zero); - } else { - result.max_defined = true; - result.max = 0; - } - } - - // Normalize the values if it's undefined - if (!result.min_defined) { - result.min = 0; - } - if (!result.max_defined) { - result.max = 0; - } - - *this = result; -} - -void ConstantInterval::cast_to(Type t) { - if (!within(t)) { - // We have potential overflow or underflow, return the entire bounds of - // the type. - ConstantInterval type_bounds; - if (t.is_int()) { - if (t.bits() <= 64) { - type_bounds.min_defined = type_bounds.max_defined = true; - type_bounds.min = ((int64_t)(-1)) << (t.bits() - 1); - type_bounds.max = ~type_bounds.min; - } - } else if (t.is_uint()) { - type_bounds.min_defined = true; - type_bounds.min = 0; - if (t.bits() < 64) { - type_bounds.max_defined = true; - type_bounds.max = (((int64_t)(1)) << t.bits()) - 1; - } - } - // If it's not int or uint, we're setting this to a default-constructed - // ConstantInterval, which is everything. - *this = type_bounds; - } -} - -ConstantInterval ConstantInterval::operator-() const { - ConstantInterval result; - if (min_defined && min != INT64_MIN) { - result.max_defined = true; - result.max = -min; - } - if (max_defined) { - result.min_defined = true; - result.min = -max; - } - return result; -} - -ConstantInterval ConstantInterval::bounds_of_type(Type t) { - return cast(t, ConstantInterval::everything()); -} - -ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result += b; - return result; -} - -ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result -= b; - return result; -} - -ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result /= b; - return result; -} - -ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result *= b; - return result; -} - -ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - if (a.min_defined && b.min_defined && b.min < a.min) { - result.min = b.min; - } - if (a.max_defined && b.max_defined && b.max < a.max) { - result.max = b.max; - } - return result; -} - -ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - if (a.min_defined && b.min_defined && b.min > a.min) { - result.min = b.min; - } - if (a.max_defined && b.max_defined && b.max > a.max) { - result.max = b.max; - } - return result; -} - -ConstantInterval abs(const ConstantInterval &a) { - ConstantInterval result; - if (a.min_defined && a.max_defined && a.min != INT64_MIN) { - result.max_defined = true; - result.max = std::max(-a.min, a.max); - } - result.min_defined = true; - if (a.min_defined && a.min > 0) { - 
result.min = a.min; - } else { - result.min = 0; - } - - return result; -} - -ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b) { - // Try to map this to a multiplication and a division - ConstantInterval mul, div; - constexpr int64_t one = 1; - if (b.min_defined) { - if (b.min >= 0 && b.min < 63) { - mul.min = one << b.min; - mul.min_defined = true; - div.max = one; - div.max_defined = true; - } else if (b.min > -63 && b.min <= 0) { - mul.min = one; - mul.min_defined = true; - div.max = one << (-b.min); - div.max_defined = true; - } - } - if (b.max_defined) { - if (b.max >= 0 && b.max < 63) { - mul.max = one << b.max; - mul.max_defined = true; - div.min = one; - div.min_defined = true; - } else if (b.max > -63 && b.max <= 0) { - mul.max = one; - mul.max_defined = true; - div.min = one << (-b.max); - div.min_defined = true; - } - } - return (a * mul) / div; -} - -ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b) { - return a << (-b); -} - } // namespace Internal - -ConstantInterval cast(Type t, const ConstantInterval &a) { - ConstantInterval result = a; - result.cast_to(t); - return result; -} - -ConstantInterval saturating_cast(Type t, const ConstantInterval &a) { - ConstantInterval b = ConstantInterval::bounds_of_type(t); - - if (b.max_defined && a.min_defined && a.min > b.max) { - return ConstantInterval(b.max, b.max); - } else if (b.min_defined && a.max_defined && a.max < b.min) { - return ConstantInterval(b.min, b.min); - } - - ConstantInterval result = a; - result.max_defined = a.max_defined || b.max_defined; - if (a.max_defined) { - if (b.max_defined) { - result.max = std::min(a.max, b.max); - } else { - result.max = a.max; - } - } else if (b.max_defined) { - result.max = b.max; - } - result.min_defined = a.min_defined || b.min_defined; - if (a.min_defined) { - if (b.min_defined) { - result.min = std::max(a.min, b.min); - } else { - result.min = a.min; - } - } else if (b.min_defined) { - result.min = b.min; - } - return result; -} - } // namespace Halide diff --git a/src/Interval.h b/src/Interval.h index 7fab81e461ba..ccd27341f167 100644 --- a/src/Interval.h +++ b/src/Interval.h @@ -110,114 +110,7 @@ struct Interval { static Expr neg_inf_noinline(); }; -/** A class to represent ranges of integers. Can be unbounded above or below, - * but they cannot be empty. */ -struct ConstantInterval { - /** The lower and upper bound of the interval. They are included - * in the interval. */ - int64_t min = 0, max = 0; - bool min_defined = false, max_defined = false; - - /* A default-constructed Interval is everything */ - ConstantInterval(); - - /** Construct an interval from a lower and upper bound. */ - ConstantInterval(int64_t min, int64_t max); - - /** The interval representing everything. */ - static ConstantInterval everything(); - - /** Construct an interval representing a single point. */ - static ConstantInterval single_point(int64_t x); - - /** Construct intervals bounded above or below. 
*/ - static ConstantInterval bounded_below(int64_t min); - static ConstantInterval bounded_above(int64_t max); - - /** Is the interval the entire range */ - bool is_everything() const; - - /** Is the interval just a single value (min == max) */ - bool is_single_point() const; - - /** Is the interval a particular single value */ - bool is_single_point(int64_t x) const; - - /** Does the interval have a finite least upper bound */ - bool has_upper_bound() const; - - /** Does the interval have a finite greatest lower bound */ - bool has_lower_bound() const; - - /** Does the interval have a finite upper and lower bound */ - bool is_bounded() const; - - /** Expand the interval to include another Interval */ - void include(const ConstantInterval &i); - - /** Expand the interval to include a point */ - void include(int64_t x); - - /** Test if the interval contains a particular value */ - bool contains(int64_t x) const; - - /** Test if the interval lies with a particular type. */ - bool within(Type t) const; - - /** Construct the smallest interval containing two intervals. */ - static ConstantInterval make_union(const ConstantInterval &a, const ConstantInterval &b); - - /** Equivalent to same_as. Exists so that the autoscheduler can - * compare two map for equality in order to - * cache computations. */ - bool operator==(const ConstantInterval &other) const; - - /** In-place versions of the arithmetic operators below. */ - // @{ - void operator+=(const ConstantInterval &other); - void operator-=(const ConstantInterval &other); - void operator*=(const ConstantInterval &other); - void operator/=(const ConstantInterval &other); - // @} - - /** Negate an interval. */ - ConstantInterval operator-() const; - - /** Track what happens if a constant integer interval is forced to fit into - * a concrete integer type. */ - void cast_to(Type t); - - /** Get constant integer bounds on a type. */ - static ConstantInterval bounds_of_type(Type); -}; - -/** Arithmetic operators on ConstantIntervals. The resulting interval contains - * all possible values of the operator applied to any two elements of the - * argument intervals. Note that these operator on unbounded integers. If you - * are applying this to concrete small integer types, you will need to manually - * cast the constant interval back to the desired type to model the effect of - * overflow. */ -// @{ -ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b); -ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b); -ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b); -ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b); -ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b); -ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b); -ConstantInterval abs(const ConstantInterval &a); -ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b); -ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b); -// @} } // namespace Internal - -/** Cast operators for ConstantIntervals. These ones have to live out in - * Halide::, to avoid C++ name lookup confusion with the Halide::cast variants - * that take Exprs. 
*/ -// @{ -Internal::ConstantInterval cast(Type t, const Internal::ConstantInterval &a); -Internal::ConstantInterval saturating_cast(Type t, const Internal::ConstantInterval &a); -// @} - } // namespace Halide #endif diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index fee151f00a22..e09358075dae 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -1,6 +1,7 @@ #include "Monotonic.h" -#include "Bounds.h" +#include "ConstantBounds.h" #include "IROperator.h" +#include "IRPrinter.h" #include "IRVisitor.h" #include "Scope.h" #include "Simplify.h" @@ -35,7 +36,7 @@ const int64_t *as_const_int_or_uint(const Expr &e) { if (const int64_t *i = as_const_int(e)) { return i; } else if (const uint64_t *u = as_const_uint(e)) { - if (*u <= (uint64_t)std::numeric_limits::max()) { + if ((*u >> 63) == 0) { return (const int64_t *)u; } } @@ -46,20 +47,12 @@ bool is_constant(const ConstantInterval &a) { return a.is_single_point(0); } -bool may_be_negative(const ConstantInterval &a) { - return !a.has_lower_bound() || a.min < 0; -} - -bool may_be_positive(const ConstantInterval &a) { - return !a.has_upper_bound() || a.max > 0; -} - bool is_monotonic_increasing(const ConstantInterval &a) { - return !may_be_negative(a); + return a.has_lower_bound() && a.min >= 0; } bool is_monotonic_decreasing(const ConstantInterval &a) { - return !may_be_positive(a); + return a.has_upper_bound() && a.max <= 0; } ConstantInterval to_interval(Monotonic m) { @@ -98,143 +91,11 @@ ConstantInterval unify(const ConstantInterval &a, int64_t b) { return result; } -// Helpers for doing arithmetic on ConstantIntervals that avoid generating -// expressions of pos_inf/neg_inf. -ConstantInterval add(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result; - result.min_defined = a.has_lower_bound() && b.has_lower_bound(); - result.max_defined = a.has_upper_bound() && b.has_upper_bound(); - if (result.has_lower_bound()) { - result.min_defined = add_with_overflow(64, a.min, b.min, &result.min); - } - if (result.has_upper_bound()) { - result.max_defined = add_with_overflow(64, a.max, b.max, &result.max); - } - return result; -} - -ConstantInterval add(const ConstantInterval &a, int64_t b) { - return add(a, ConstantInterval(b, b)); -} - -ConstantInterval negate(const ConstantInterval &r) { - ConstantInterval result; - result.min_defined = r.has_upper_bound(); - if (result.min_defined) { - result.min_defined = sub_with_overflow(64, 0, r.max, &result.min); - } - result.max_defined = r.has_lower_bound(); - if (result.max_defined) { - result.max_defined = sub_with_overflow(64, 0, r.min, &result.max); - } - return result; -} - -ConstantInterval sub(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result; - result.min_defined = a.has_lower_bound() && b.has_lower_bound(); - result.max_defined = a.has_upper_bound() && b.has_upper_bound(); - if (result.has_lower_bound()) { - result.min_defined = sub_with_overflow(64, a.min, b.max, &result.min); - } - if (result.has_upper_bound()) { - result.max_defined = sub_with_overflow(64, a.max, b.min, &result.max); - } - return result; -} - -ConstantInterval sub(const ConstantInterval &a, int64_t b) { - return sub(a, ConstantInterval(b, b)); -} - -ConstantInterval multiply(const ConstantInterval &a, int64_t b) { - ConstantInterval result(a); - if (b < 0) { - result = negate(result); - b = -b; - } - if (result.has_lower_bound()) { - result.min *= b; - } - if (result.has_upper_bound()) { - result.max *= b; - } - return result; -} - -ConstantInterval 
multiply(const ConstantInterval &a, const Expr &b) { - if (const int64_t *bi = as_const_int_or_uint(b)) { - return multiply(a, *bi); - } - return ConstantInterval::everything(); -} - -ConstantInterval multiply(const ConstantInterval &a, const ConstantInterval &b) { - int64_t bounds[4]; - int64_t *bounds_begin = &bounds[0]; - int64_t *bounds_end = &bounds[0]; - bool no_overflow = true; - if (a.has_lower_bound() && b.has_lower_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.min, b.min, bounds_end++); - } - if (a.has_lower_bound() && b.has_upper_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.min, b.max, bounds_end++); - } - if (a.has_upper_bound() && b.has_lower_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.max, b.min, bounds_end++); - } - if (a.has_upper_bound() && b.has_upper_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.max, b.max, bounds_end++); - } - if (no_overflow && (bounds_begin != bounds_end)) { - ConstantInterval result = { - *std::min_element(bounds_begin, bounds_end), - *std::max_element(bounds_begin, bounds_end), - }; - // There *must* be a better way than this... Even - // cutting half the cases with swapping isn't that much help. - if (!a.has_lower_bound()) { - if (may_be_negative(b)) result.max_defined = false; // NOLINT - if (may_be_positive(b)) result.min_defined = false; // NOLINT - } - if (!a.has_upper_bound()) { - if (may_be_negative(b)) result.min_defined = false; // NOLINT - if (may_be_positive(b)) result.max_defined = false; // NOLINT - } - if (!b.has_lower_bound()) { - if (may_be_negative(a)) result.max_defined = false; // NOLINT - if (may_be_positive(a)) result.min_defined = false; // NOLINT - } - if (!b.has_upper_bound()) { - if (may_be_negative(a)) result.min_defined = false; // NOLINT - if (may_be_positive(a)) result.max_defined = false; // NOLINT - } - return result; - } else { - return ConstantInterval::everything(); - } -} - -ConstantInterval divide(const ConstantInterval &a, int64_t b) { - ConstantInterval result(a); - if (b < 0) { - result = negate(result); - b = -b; - } - if (result.has_lower_bound()) { - result.min = div_imp(result.min, b); - } - if (result.has_upper_bound()) { - result.max = div_imp(result.max - 1, b) + 1; - } - return result; -} - class DerivativeBounds : public IRVisitor { const string &var; - Scope scope; - Scope bounds; + // Bounds on the derivatives and values of variables in scope. 
+ Scope derivative_bounds, value_bounds; void visit(const IntImm *) override { result = ConstantInterval::single_point(0); @@ -280,7 +141,7 @@ class DerivativeBounds : public IRVisitor { void visit(const Variable *op) override { if (op->name == var) { result = ConstantInterval::single_point(1); - } else if (const auto *r = scope.find(op->name)) { + } else if (const auto *r = derivative_bounds.find(op->name)) { result = *r; } else { result = ConstantInterval::single_point(0); @@ -291,16 +152,14 @@ class DerivativeBounds : public IRVisitor { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = add(ra, rb); + result += ra; } void visit(const Sub *op) override { - op->a.accept(this); - ConstantInterval ra = result; op->b.accept(this); ConstantInterval rb = result; - result = sub(ra, rb); + op->a.accept(this); + result -= rb; } void visit(const Mul *op) override { @@ -313,9 +172,9 @@ class DerivativeBounds : public IRVisitor { // This is essentially the product rule: a*rb + b*ra // but only implemented for the case where a or b is constant. if (const int64_t *b = as_const_int_or_uint(op->b)) { - result = multiply(ra, *b); + result = ra * (*b); } else if (const int64_t *a = as_const_int_or_uint(op->a)) { - result = multiply(rb, *a); + result = rb * (*a); } else { result = ConstantInterval::everything(); } @@ -326,20 +185,37 @@ class DerivativeBounds : public IRVisitor { void visit(const Div *op) override { if (op->type.is_scalar()) { - op->a.accept(this); - ConstantInterval ra = result; - if (const int64_t *b = as_const_int_or_uint(op->b)) { - result = divide(ra, *b); - } else { - result = ConstantInterval::everything(); + op->a.accept(this); + // We don't just want to divide by b. For the min we want to + // take floor division, and for the max we want to use ceil + // division. + if (*b == 0) { + result = ConstantInterval(0, 0); + } else { + if (result.has_lower_bound()) { + result.min = div_imp(result.min, *b); + } + if (result.has_upper_bound()) { + if (result.max != INT64_MIN) { + result.max = div_imp(result.max - 1, *b) + 1; + } else { + result.max_defined = false; + result.max = 0; + } + } + if (*b < 0) { + result = -result; + } + } + return; } - } else { - result = ConstantInterval::everything(); } + result = ConstantInterval::everything(); } void visit(const Mod *op) override { + // TODO result = ConstantInterval::everything(); } @@ -387,7 +263,7 @@ class DerivativeBounds : public IRVisitor { ConstantInterval ra = result; b.accept(this); ConstantInterval rb = result; - result = unify(negate(ra), rb); + result = unify(-ra, rb); // If the result is bounded, limit it to [-1, 1]. The largest // difference possible is flipping from true to false or false // to true. @@ -433,14 +309,20 @@ class DerivativeBounds : public IRVisitor { void visit(const Not *op) override { op->a.accept(this); - result = negate(result); + result = -result; } void visit(const Select *op) override { - // The result is the unified bounds, added to the "bump" that happens when switching from true to false. + // The result is the unified bounds, added to the "bump" that happens + // when switching from true to false. 
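+        // Illustrative example (added for clarity, not from the original
+        // code): for select(x < 10, x, x + 100) with var == x, rcond is
+        // [-1, 0] (true flips to false once), the unified true/false
+        // derivative is [1, 1], and the bump x - (x + 100) has constant
+        // bounds [-100, -100], so the result is [1, 1] + [-1, 0]*[-100, -100]
+        // = [1, 101], matching the jump from 9 to 110 at x == 10.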
if (op->type.is_scalar()) { op->condition.accept(this); ConstantInterval rcond = result; + // rcond is: + // [ 0 0] if the condition does not depend on the variable + // [-1, 0] if it changes once from true to false + // [ 0 1] if it changes once from false to true + // [-1, 1] if it could change in either direction op->true_value.accept(this); ConstantInterval ra = result; @@ -450,19 +332,11 @@ class DerivativeBounds : public IRVisitor { // If the condition is not constant, we hit a "bump" when the condition changes value. if (!is_constant(rcond)) { - // TODO: How to handle unsigned values? - Expr delta = simplify(op->true_value - op->false_value); - - Interval delta_bounds = find_constant_bounds(delta, bounds); - // TODO: Maybe we can do something with one-sided intervals? - if (delta_bounds.is_bounded()) { - ConstantInterval delta_low = multiply(rcond, delta_bounds.min); - ConstantInterval delta_high = multiply(rcond, delta_bounds.max); - result = add(result, ConstantInterval::make_union(delta_low, delta_high)); - } else { - // The bump is unbounded. - result = ConstantInterval::everything(); - } + // It's very important to have stripped likelies here, or the + // simplification might not cancel things that it should. + Expr bump = simplify(op->true_value - op->false_value); + ConstantInterval bump_bounds = constant_integer_bounds(bump, value_bounds); + result += rcond * bump_bounds; } } else { result = ConstantInterval::everything(); @@ -493,10 +367,9 @@ class DerivativeBounds : public IRVisitor { return; } - if (op->is_intrinsic(Call::unsafe_promise_clamped) || - op->is_intrinsic(Call::promise_clamped) || - op->is_intrinsic(Call::saturating_cast)) { + if (op->is_intrinsic(Call::saturating_cast)) { op->args[0].accept(this); + result.include(0); return; } @@ -526,14 +399,14 @@ class DerivativeBounds : public IRVisitor { void visit(const Let *op) override { op->value.accept(this); - ScopedBinding bounds_binding(bounds, op->name, find_constant_bounds(op->value, bounds)); - + ScopedBinding vb_binding(value_bounds, op->name, + constant_integer_bounds(op->value, value_bounds)); if (is_constant(result)) { // No point pushing it if it's constant w.r.t the var, // because unknown variables are treated as constant. op->body.accept(this); } else { - ScopedBinding scope_binding(scope, op->name, result); + ScopedBinding db_binding(derivative_bounds, op->name, result); op->body.accept(this); } } @@ -554,7 +427,7 @@ class DerivativeBounds : public IRVisitor { switch (op->op) { case VectorReduce::Add: case VectorReduce::SaturatingAdd: - result = multiply(result, op->value.type().lanes() / op->type.lanes()); + result *= op->value.type().lanes() / op->type.lanes(); break; case VectorReduce::Min: case VectorReduce::Max: @@ -643,7 +516,7 @@ class DerivativeBounds : public IRVisitor { DerivativeBounds(const std::string &v, const Scope &parent) : var(v), result(ConstantInterval::everything()) { - scope.set_containing_scope(&parent); + derivative_bounds.set_containing_scope(&parent); } }; @@ -655,6 +528,7 @@ ConstantInterval derivative_bounds(const Expr &e, const std::string &var, const } DerivativeBounds m(var, scope); remove_likelies(remove_promises(e)).accept(&m); + debug(0) << "Derivative bounds of " << e << " w.r.t. 
" << var << ": " << m.result << "\n"; return m.result; } diff --git a/src/Monotonic.h b/src/Monotonic.h index 3d7946a13ed7..4fe6b3aed57f 100644 --- a/src/Monotonic.h +++ b/src/Monotonic.h @@ -8,13 +8,14 @@ #include #include -#include "Interval.h" +#include "ConstantBounds.h" #include "Scope.h" namespace Halide { namespace Internal { -/** Find the bounds of the derivative of an expression. */ +/** Find the bounds of the derivative of an expression. The scope gives the + * bounds on the derivatives of any variables found. */ ConstantInterval derivative_bounds(const Expr &e, const std::string &var, const Scope &scope = Scope::empty_scope()); diff --git a/src/Type.cpp b/src/Type.cpp index 64414fa04eca..1cd95e0a6b01 100644 --- a/src/Type.cpp +++ b/src/Type.cpp @@ -1,3 +1,4 @@ +#include "ConstantBounds.h" #include "IR.h" #include #include @@ -126,6 +127,10 @@ bool Type::can_represent(Type other) const { } } +bool Type::can_represent(const Internal::ConstantInterval &in) const { + return in.is_bounded() && can_represent(in.min) && can_represent(in.max); +} + bool Type::can_represent(int64_t x) const { if (is_int()) { return x >= min_int(bits()) && x <= max_int(bits()); diff --git a/src/Type.h b/src/Type.h index af5447350810..e6cbb554dd77 100644 --- a/src/Type.h +++ b/src/Type.h @@ -266,6 +266,10 @@ struct halide_handle_traits { namespace Halide { +namespace Internal { +struct ConstantInterval; +} + struct Expr; /** Types in the halide type system. They can be ints, unsigned ints, @@ -501,6 +505,10 @@ struct Type { /** Can this type represent all values of another type? */ bool can_represent(Type other) const; + /** Can this type represent exactly all integer values of some constant + * integer range? */ + bool can_represent(const Internal::ConstantInterval &in) const; + /** Can this type represent a particular constant? 
*/ // @{ bool can_represent(double x) const; From bee38ce5c724701dff0445018bcd352f40ca5419 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 25 Mar 2024 14:37:54 -0700 Subject: [PATCH 06/33] Make the simplifier use ConstantInterval --- src/ConstantInterval.cpp | 121 +++++++++++++++++++++++++--- src/IRPrinter.cpp | 6 ++ src/IRPrinter.h | 4 + src/Simplify.cpp | 66 +++++++-------- src/Simplify.h | 9 ++- src/Simplify_Add.cpp | 27 +++---- src/Simplify_And.cpp | 4 +- src/Simplify_Call.cpp | 95 ++++++++++------------ src/Simplify_Cast.cpp | 60 +++++++------- src/Simplify_Div.cpp | 109 +++++-------------------- src/Simplify_EQ.cpp | 22 +++-- src/Simplify_Exprs.cpp | 151 ++++++++++++++++------------------- src/Simplify_Internal.h | 129 ++++++++++++------------------ src/Simplify_LT.cpp | 31 +++---- src/Simplify_Let.cpp | 29 ++++--- src/Simplify_Max.cpp | 57 ++++++------- src/Simplify_Min.cpp | 52 +++++------- src/Simplify_Mod.cpp | 56 ++++--------- src/Simplify_Mul.cpp | 56 +++---------- src/Simplify_Not.cpp | 4 +- src/Simplify_Or.cpp | 4 +- src/Simplify_Reinterpret.cpp | 4 +- src/Simplify_Select.cpp | 23 +++--- src/Simplify_Shuffle.cpp | 23 +++--- src/Simplify_Stmts.cpp | 71 ++++++---------- src/Simplify_Sub.cpp | 23 +++--- 26 files changed, 562 insertions(+), 674 deletions(-) diff --git a/src/ConstantInterval.cpp b/src/ConstantInterval.cpp index 03fbb692f540..7be859e63265 100644 --- a/src/ConstantInterval.cpp +++ b/src/ConstantInterval.cpp @@ -97,6 +97,36 @@ ConstantInterval ConstantInterval::make_union(const ConstantInterval &a, const C return result; } +ConstantInterval ConstantInterval::make_intersection(const ConstantInterval &a, + const ConstantInterval &b) { + ConstantInterval result; + if (a.min_defined) { + if (b.min_defined) { + result.min = std::max(a.min, b.min); + } else { + result.min = a.min; + } + result.min_defined = true; + } else { + result.min_defined = b.min_defined; + result.min = b.min; + } + if (a.max_defined) { + if (b.max_defined) { + result.max = std::min(a.max, b.max); + } else { + result.max = a.max; + } + result.max_defined = true; + } else { + result.max_defined = b.max_defined; + result.max = b.max; + } + internal_assert(!result.is_bounded() || result.min <= result.max) + << "Empty ConstantInterval constructed in make_intersection"; + return result; +} + // TODO: These were taken directly from the simplifier, so change the simplifier // to use these instead of duplicating the code. 
void ConstantInterval::operator+=(const ConstantInterval &other) { @@ -305,6 +335,27 @@ void ConstantInterval::operator/=(int64_t x) { *this /= ConstantInterval(x, x); } +bool operator<=(const ConstantInterval &a, const ConstantInterval &b) { + return a.max_defined && b.min_defined && a.max <= b.min; +} +bool operator<(const ConstantInterval &a, const ConstantInterval &b) { + return a.max_defined && b.min_defined && a.max < b.min; +} + +bool operator<=(const ConstantInterval &a, int64_t b) { + return a.max_defined && a.max <= b; +} +bool operator<(const ConstantInterval &a, int64_t b) { + return a.max_defined && a.max < b; +} + +bool operator<=(int64_t a, const ConstantInterval &b) { + return b.min_defined && a <= b.min; +} +bool operator<(int64_t a, const ConstantInterval &b) { + return b.min_defined && a < b.min; +} + void ConstantInterval::cast_to(const Type &t) { if (!t.can_represent(*this)) { // We have potential overflow or underflow, return the entire bounds of @@ -371,6 +422,42 @@ ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b) return result; } +ConstantInterval operator%(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + + // Maybe the mod won't actually do anything + if (a >= 0 && a < b) { + return a; + } + + // The result is at least zero. + result.min_defined = true; + result.min = 0; + + // Mod by produces a result between 0 + // and max(0, abs(modulus) - 1). However, if b is unbounded in + // either direction, abs(modulus) could be arbitrarily + // large. + if (b.is_bounded()) { + result.max_defined = true; + result.max = 0; // When b == 0 + result.max = std::max(result.max, b.max - 1); // When b > 0 + result.max = std::max(result.max, -1 - b.min); // When b < 0 + } + + // If a is positive, mod can't make it larger + if (a.is_bounded() && a.min >= 0) { + if (result.max_defined) { + result.max = std::min(result.max, a.max); + } else { + result.max_defined = true; + result.max = a.max; + } + } + + return result; +} + ConstantInterval operator+(const ConstantInterval &a, int64_t b) { return a + ConstantInterval(b, b); } @@ -387,24 +474,40 @@ ConstantInterval operator*(const ConstantInterval &a, int64_t b) { return a * ConstantInterval(b, b); } +ConstantInterval operator%(const ConstantInterval &a, int64_t b) { + return a * ConstantInterval(b, b); +} + ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - if (a.min_defined && b.min_defined && b.min < a.min) { - result.min = b.min; - } - if (a.max_defined && b.max_defined && b.max < a.max) { + ConstantInterval result; + result.max_defined = a.max_defined || b.max_defined; + result.min_defined = a.min_defined && b.min_defined; + if (a.max_defined && b.max_defined) { + result.max = std::min(a.max, b.max); + } else if (a.max_defined) { + result.max = a.max; + } else if (b.max_defined) { result.max = b.max; } + if (a.min_defined && b.min_defined) { + result.min = std::min(a.min, b.min); + } return result; } ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - if (a.min_defined && b.min_defined && b.min > a.min) { + ConstantInterval result; + result.min_defined = a.min_defined || b.min_defined; + result.max_defined = a.max_defined && b.max_defined; + if (a.min_defined && b.min_defined) { + result.min = std::max(a.min, b.min); + } else if (a.min_defined) { + result.min = a.min; + } else if (b.min_defined) { result.min = b.min; } - if (a.max_defined && 
b.max_defined && b.max > a.max) { - result.max = b.max; + if (a.max_defined && b.max_defined) { + result.max = std::max(a.max, b.max); } return result; } diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 5d4c43304eb0..7719eccfc489 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -10,6 +10,7 @@ #include "IROperator.h" #include "Interval.h" #include "Module.h" +#include "ModulusRemainder.h" #include "Target.h" #include "Util.h" @@ -477,6 +478,11 @@ std::ostream &operator<<(std::ostream &out, const ConstantInterval &c) { return out; } +std::ostream &operator<<(std::ostream &out, const ModulusRemainder &c) { + out << "(mod: " << c.modulus << " rem: " << c.remainder << ")"; + return out; +} + IRPrinter::IRPrinter(ostream &s) : stream(s) { s.setf(std::ios::fixed, std::ios::floatfield); diff --git a/src/IRPrinter.h b/src/IRPrinter.h index 161960077b93..6addbbd7c771 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -60,6 +60,7 @@ struct AssociativeOp; class Closure; struct Interval; struct ConstantInterval; +struct ModulusRemainder; /** Emit a halide associative pattern on an output stream (such as std::cout) * in a human-readable form */ @@ -101,6 +102,9 @@ std::ostream &operator<<(std::ostream &out, const Interval &c); /** Emit a ConstantInterval in human-readable form */ std::ostream &operator<<(std::ostream &out, const ConstantInterval &c); +/** Emit a ModulusRemainder in human-readable form */ +std::ostream &operator<<(std::ostream &out, const ModulusRemainder &c); + struct Indentation { int indent; }; diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 61cf7886cb70..c73b9b43f4e6 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -24,22 +24,24 @@ Simplify::Simplify(bool r, const Scope *bi, const Scopecbegin(); iter != bi->cend(); ++iter) { - ExprInfo bounds; + ExprInfo info; if (const int64_t *i_min = as_const_int(iter.value().min)) { - bounds.min_defined = true; - bounds.min = *i_min; + info.bounds.min_defined = true; + info.bounds.min = *i_min; } if (const int64_t *i_max = as_const_int(iter.value().max)) { - bounds.max_defined = true; - bounds.max = *i_max; + info.bounds.max_defined = true; + info.bounds.max = *i_max; } if (const auto *a = ai->find(iter.name())) { - bounds.alignment = *a; + info.alignment = *a; } - if (bounds.min_defined || bounds.max_defined || bounds.alignment.modulus != 1) { - bounds_and_alignment_info.push(iter.name(), bounds); + if (info.bounds.has_lower_bound() || + info.bounds.has_upper_bound() || + info.alignment.modulus != 1) { + bounds_and_alignment_info.push(iter.name(), info); } } @@ -48,20 +50,20 @@ Simplify::Simplify(bool r, const Scope *bi, const Scope, bool> Simplify::mutate_with_changes(const std::vector &old_exprs, ExprInfo *bounds) { +std::pair, bool> Simplify::mutate_with_changes(const std::vector &old_exprs) { vector new_exprs(old_exprs.size()); bool changed = false; // Mutate the args for (size_t i = 0; i < old_exprs.size(); i++) { const Expr &old_e = old_exprs[i]; - Expr new_e = mutate(old_e, bounds); + Expr new_e = mutate(old_e, nullptr); if (!new_e.same_as(old_e)) { changed = true; } @@ -135,17 +137,17 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { Simplify::ExprInfo i; if (v) { simplify->mutate(lt->b, &i); - if (i.min_defined) { + if (i.bounds.has_lower_bound()) { // !(v < i) - learn_lower_bound(v, i.min); + learn_lower_bound(v, i.bounds.min); } } v = lt->b.as(); if (v) { simplify->mutate(lt->a, &i); - if (i.max_defined) { + if (i.bounds.has_upper_bound()) { // !(i < v) - learn_upper_bound(v, i.max); + 
learn_upper_bound(v, i.bounds.max); } } } else if (const LE *le = fact.as()) { @@ -153,17 +155,17 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { Simplify::ExprInfo i; if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(le->b, &i); - if (i.min_defined) { + if (i.bounds.has_lower_bound()) { // !(v <= i) - learn_lower_bound(v, i.min + 1); + learn_lower_bound(v, i.bounds.min + 1); } } v = le->b.as(); if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(le->a, &i); - if (i.max_defined) { + if (i.bounds.has_upper_bound()) { // !(i <= v) - learn_upper_bound(v, i.max - 1); + learn_upper_bound(v, i.bounds.max - 1); } } } else if (const Call *c = Call::as_tag(fact)) { @@ -185,8 +187,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) { ExprInfo b; - b.max_defined = true; - b.max = val; + b.bounds = ConstantInterval::bounded_above(val); if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { b.intersect(*info); } @@ -196,8 +197,7 @@ void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) { void Simplify::ScopedFact::learn_lower_bound(const Variable *v, int64_t val) { ExprInfo b; - b.min_defined = true; - b.min = val; + b.bounds = ConstantInterval::bounded_below(val); if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { b.intersect(*info); } @@ -267,17 +267,17 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo i; if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(lt->b, &i); - if (i.max_defined) { + if (i.bounds.has_upper_bound()) { // v < i - learn_upper_bound(v, i.max - 1); + learn_upper_bound(v, i.bounds.max - 1); } } v = lt->b.as(); if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(lt->a, &i); - if (i.min_defined) { + if (i.bounds.has_lower_bound()) { // i < v - learn_lower_bound(v, i.min + 1); + learn_lower_bound(v, i.bounds.min + 1); } } } else if (const LE *le = fact.as()) { @@ -285,17 +285,17 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo i; if (v) { simplify->mutate(le->b, &i); - if (i.max_defined) { + if (i.bounds.has_upper_bound()) { // v <= i - learn_upper_bound(v, i.max); + learn_upper_bound(v, i.bounds.max); } } v = le->b.as(); if (v) { simplify->mutate(le->a, &i); - if (i.min_defined) { + if (i.bounds.has_lower_bound()) { // i <= v - learn_lower_bound(v, i.min); + learn_lower_bound(v, i.bounds.min); } } } else if (const Call *c = Call::as_tag(fact)) { diff --git a/src/Simplify.h b/src/Simplify.h index b9335c0c3de9..d3d7dda7701d 100644 --- a/src/Simplify.h +++ b/src/Simplify.h @@ -21,11 +21,16 @@ namespace Internal { * Exprs that should be assumed to be true. 
*/ // @{ -Stmt simplify(const Stmt &, bool remove_dead_code = true, + +// TODO: Change the interface to accept a scope of ConstantInterval + +Stmt simplify(const Stmt &, + bool remove_dead_code = true, const Scope &bounds = Scope::empty_scope(), const Scope &alignment = Scope::empty_scope(), const std::vector &assumptions = std::vector()); -Expr simplify(const Expr &, bool remove_dead_code = true, +Expr simplify(const Expr &, + bool remove_dead_code = true, const Scope &bounds = Scope::empty_scope(), const Scope &alignment = Scope::empty_scope(), const std::vector &assumptions = std::vector()); diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp index fb9238dd9a6a..4efc7e4b9fcb 100644 --- a/src/Simplify_Add.cpp +++ b/src/Simplify_Add.cpp @@ -3,20 +3,15 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Add *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds && no_overflow_int(op->type)) { - bounds->min_defined = a_bounds.min_defined && - b_bounds.min_defined && - add_with_overflow(64, a_bounds.min, b_bounds.min, &(bounds->min)); - bounds->max_defined = a_bounds.max_defined && - b_bounds.max_defined && - add_with_overflow(64, a_bounds.max, b_bounds.max, &(bounds->max)); - bounds->alignment = a_bounds.alignment + b_bounds.alignment; - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Add *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info && no_overflow_int(op->type)) { + info->bounds = a_info.bounds + b_info.bounds; + info->alignment = a_info.alignment + b_info.alignment; + info->trim_bounds_using_alignment(); } if (may_simplify(op->type)) { @@ -24,7 +19,7 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } auto rewrite = IRMatcher::rewriter(IRMatcher::add(a, b), op->type); @@ -194,7 +189,7 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { rewrite(x + (y + (c0 - x)/c1)*c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) || false)))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_And.cpp b/src/Simplify_And.cpp index 35bbd5f7f747..a6f7e82c9095 100644 --- a/src/Simplify_And.cpp +++ b/src/Simplify_And.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const And *op, ExprInfo *bounds) { +Expr Simplify::visit(const And *op, ExprInfo *info) { if (falsehoods.count(op)) { return const_false(op->type.lanes()); } @@ -109,7 +109,7 @@ Expr Simplify::visit(const And *op, ExprInfo *bounds) { rewrite(x <= y && x <= z, x <= min(y, z)) || rewrite(y <= x && z <= x, max(y, z) <= x)) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (a.same_as(op->a) && diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index 66d12f0efc4f..609a156f9aea 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -49,7 +49,7 @@ Expr lift_elementwise_broadcasts(Type type, const std::string &name, std::vector } // namespace -Expr Simplify::visit(const Call *op, ExprInfo *bounds) { +Expr Simplify::visit(const Call *op, ExprInfo *info) { // Calls implicitly depend on host, dev, mins, and strides of the buffer referenced if (op->call_type == Call::Image || op->call_type == Call::Halide) { 
found_buffer_reference(op->name, op->args.size()); @@ -79,7 +79,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } uint64_t ua = 0; @@ -123,7 +123,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } const Type t = op->type; @@ -132,9 +132,9 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { std::string result_op = op->name; // If we know the sign of this shift, change it to an unsigned shift. - if (b_info.min_defined && b_info.min >= 0) { + if (b_info.bounds >= 0) { b = mutate(cast(b.type().with_code(halide_type_uint), b), nullptr); - } else if (b.type().is_int() && b_info.max_defined && b_info.max <= 0) { + } else if (b.type().is_int() && b_info.bounds <= 0) { result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); b = mutate(cast(b.type().with_code(halide_type_uint), -b), nullptr); } @@ -145,24 +145,24 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { // LLVM shl and shr instructions produce poison for // shifts >= typesize, so we will follow suit in our simplifier. if (ub >= (uint64_t)(t.bits())) { - clear_bounds_info(bounds); + clear_bounds_info(info); return make_signed_integer_overflow(t); } if (a.type().is_uint() || ub < ((uint64_t)t.bits() - 1)) { b = make_const(t, ((int64_t)1LL) << ub); if (result_op == Call::get_intrinsic_name(Call::shift_left)) { - return mutate(Mul::make(a, b), bounds); + return mutate(Mul::make(a, b), info); } else { - return mutate(Div::make(a, b), bounds); + return mutate(Div::make(a, b), info); } } else { // For signed types, (1 << (t.bits() - 1)) will overflow into the sign bit while // (-32768 >> (t.bits() - 1)) propagates the sign bit, making decomposition // into mul or div problematic, so just special-case them here. if (result_op == Call::get_intrinsic_name(Call::shift_left)) { - return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1LL) << ub), make_zero(t)), bounds); + return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1LL) << ub), make_zero(t)), info); } else { - return mutate(select(a < 0, make_const(t, -1), make_zero(t)), bounds); + return mutate(select(a < 0, make_const(t, -1), make_zero(t)), info); } } } @@ -173,7 +173,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { if (is_const_zero(sub->a)) { result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? 
Call::shift_left : Call::shift_right); b = sub->b; - return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), bounds); + return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), info); } } } @@ -190,7 +190,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia, ib = 0; @@ -227,7 +227,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia, ib; @@ -248,7 +248,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia; @@ -268,7 +268,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } int64_t ia, ib; @@ -286,12 +286,12 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } else if (op->is_intrinsic(Call::abs)) { // Constant evaluate abs(x). - ExprInfo a_bounds; - Expr a = mutate(op->args[0], &a_bounds); + ExprInfo a_info; + Expr a = mutate(op->args[0], &a_info); Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } Type ta = a.type(); @@ -310,9 +310,9 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { fa = -fa; } return make_const(a.type(), fa); - } else if (a.type().is_int() && a_bounds.min_defined && a_bounds.min >= 0) { + } else if (a.type().is_int() && a_info.bounds >= 0) { return cast(op->type, a); - } else if (a.type().is_int() && a_bounds.max_defined && a_bounds.max <= 0) { + } else if (a.type().is_int() && a_info.bounds <= 0) { return cast(op->type, -a); } else if (a.same_as(op->args[0])) { return op; @@ -321,13 +321,13 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } else if (op->is_intrinsic(Call::absd)) { // Constant evaluate absd(a, b). 
- ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->args[0], &a_bounds); - Expr b = mutate(op->args[1], &b_bounds); + ExprInfo a_info, b_info; + Expr a = mutate(op->args[0], &a_info); + Expr b = mutate(op->args[1], &b_info); Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); if (unbroadcast.defined()) { - return mutate(unbroadcast, bounds); + return mutate(unbroadcast, info); } Type ta = a.type(); @@ -355,14 +355,12 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } else if (op->is_intrinsic(Call::saturating_cast)) { internal_assert(op->args.size() == 1); - ExprInfo a_bounds; - Expr a = mutate(op->args[0], &a_bounds); - - // TODO(rootjalex): We could be intelligent about using a_bounds to remove saturating_casts; + ExprInfo a_info; + Expr a = mutate(op->args[0], &a_info); if (is_const(a)) { a = lower_saturating_cast(op->type, a); - return mutate(a, bounds); + return mutate(a, info); } else if (!a.same_as(op->args[0])) { return saturating_cast(op->type, a); } else { @@ -424,7 +422,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { internal_assert(op->args.size() % 2 == 0); // Prefetch: {base, offset, extent0, stride0, ...} - auto [args, changed] = mutate_with_changes(op->args, nullptr); + auto [args, changed] = mutate_with_changes(op->args); // The {extent, stride} args in the prefetch call are sorted // based on the storage dimension in ascending order (i.e. innermost @@ -478,7 +476,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { { // Can assume the condition is true when evaluating the value. auto t = scoped_truth(cond); - result = mutate(op->args[1], bounds); + result = mutate(op->args[1], info); } if (is_const_one(cond)) { @@ -511,12 +509,8 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { const Broadcast *b_lower = lower.as(); const Broadcast *b_upper = upper.as(); - if (arg_info.min_defined && - arg_info.max_defined && - lower_info.max_defined && - upper_info.min_defined && - arg_info.min >= lower_info.max && - arg_info.max <= upper_info.min) { + if (arg_info.bounds >= lower_info.bounds && + arg_info.bounds <= upper_info.bounds) { return arg; } else if (b_arg && b_lower && b_upper) { // Move broadcasts outwards @@ -537,7 +531,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } else if (Call::as_tag(op)) { // The bounds of the result are the bounds of the arg internal_assert(op->args.size() == 1); - Expr arg = mutate(op->args[0], bounds); + Expr arg = mutate(op->args[0], info); if (arg.same_as(op->args[0])) { return op; } else { @@ -557,12 +551,12 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } if (is_const_one(cond)) { - return mutate(op->args[1], bounds); + return mutate(op->args[1], info); } else if (is_const_zero(cond)) { if (op->args.size() == 3) { - return mutate(op->args[2], bounds); + return mutate(op->args[2], info); } else { - return mutate(make_zero(op->type), bounds); + return mutate(make_zero(op->type), info); } } else { Expr true_value = mutate(op->args[1], nullptr); @@ -598,21 +592,20 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { int num_values = (int)op->args.size() - 1; if (num_values == 1) { // Mux of a single value - return mutate(op->args[1], bounds); + return mutate(op->args[1], info); } ExprInfo index_info; Expr index = mutate(op->args[0], &index_info); // Check if the mux has statically resolved - if (index_info.min_defined && - index_info.max_defined && - index_info.min == index_info.max) { - if (index_info.min >= 0 && index_info.min < 
num_values) { + if (index_info.bounds.is_single_point()) { + int64_t v = index_info.bounds.min; + if (v >= 0 && v < num_values) { // In-range, return the (simplified) corresponding value. - return mutate(op->args[index_info.min + 1], bounds); + return mutate(op->args[v + 1], info); } else { // It's out-of-range, so return the last value. - return mutate(op->args.back(), bounds); + return mutate(op->args.back(), info); } } @@ -780,14 +773,14 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { // just fall thru and take the general case. debug(2) << "Simplifier: unhandled PureExtern: " << op->name << "\n"; } else if (op->is_intrinsic(Call::signed_integer_overflow)) { - clear_bounds_info(bounds); + clear_bounds_info(info); } else if (op->is_intrinsic(Call::concat_bits) && op->args.size() == 1) { - return mutate(op->args[0], bounds); + return mutate(op->args[0], info); } // No else: we want to fall thru from the PureExtern clause. { - auto [new_args, changed] = mutate_with_changes(op->args, nullptr); + auto [new_args, changed] = mutate_with_changes(op->args); if (!changed) { return op; } else { diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index 4e689212aaa0..f64c55f68640 100644 --- a/src/Simplify_Cast.cpp +++ b/src/Simplify_Cast.cpp @@ -3,31 +3,29 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { - Expr value = mutate(op->value, bounds); +Expr Simplify::visit(const Cast *op, ExprInfo *info) { + Expr value = mutate(op->value, info); - if (bounds) { - // If either the min value or the max value can't be represented - // in the destination type, or the min/max value is undefined, - // the bounds need to be cleared (one-sided for no_overflow, - // both sides for overflow types). - if ((bounds->min_defined && !op->type.can_represent(bounds->min)) || !bounds->min_defined) { - bounds->min_defined = false; - if (!no_overflow(op->type)) { - // If the type overflows, this invalidates the max too. - bounds->max_defined = false; + if (info) { + if (no_overflow(op->type)) { + // If there's overflow in a no-overflow type (e.g. due to casting + // from a UInt(64) to an Int(32), then set the corresponding bound + // to infinity. + if (info->bounds.has_upper_bound() && !op->type.can_represent(info->bounds.max)) { + info->bounds.max_defined = false; + info->bounds.max = 0; } - } - if ((bounds->max_defined && !op->type.can_represent(bounds->max)) || !bounds->max_defined) { - if (!no_overflow(op->type)) { - // If the type overflows, this invalidates the min too. 
- bounds->min_defined = false; + if (info->bounds.has_lower_bound() && !op->type.can_represent(info->bounds.min)) { + info->bounds.min_defined = false; + info->bounds.min = 0; } - bounds->max_defined = false; + } else { + info->bounds.cast_to(op->type); } - if (!op->type.can_represent(bounds->alignment.modulus) || - !op->type.can_represent(bounds->alignment.remainder)) { - bounds->alignment = ModulusRemainder(); + + if (!op->type.can_represent(info->alignment.modulus) || + !op->type.can_represent(info->alignment.remainder)) { + info->alignment = ModulusRemainder(); } } @@ -39,7 +37,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { int64_t i = 0; uint64_t u = 0; if (Call::as_intrinsic(value, {Call::signed_integer_overflow})) { - clear_bounds_info(bounds); + clear_bounds_info(info); return make_signed_integer_overflow(op->type); } else if (value.type() == op->type) { return value; @@ -48,7 +46,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { std::isfinite(f)) { // float -> int // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(f)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(f)), info); } else if (op->type.is_uint() && const_float(value, &f) && std::isfinite(f)) { @@ -62,7 +60,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { const_int(value, &i)) { // int -> int // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, i), bounds); + return mutate(make_const(op->type, i), info); } else if (op->type.is_uint() && const_int(value, &i)) { // int -> uint @@ -76,13 +74,13 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { op->type.bits() < value.type().bits()) { // uint -> int narrowing // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(u)), info); } else if (op->type.is_int() && const_uint(value, &u) && op->type.bits() == value.type().bits()) { // uint -> int reinterpret // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(u)), info); } else if (op->type.is_int() && const_uint(value, &u) && op->type.bits() > value.type().bits()) { @@ -90,7 +88,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { if (op->type.can_represent(u) || op->type.bits() < 32) { // If the type can represent the value or overflow is well-defined. // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), bounds); + return mutate(make_const(op->type, safe_numeric_cast(u)), info); } else { return make_signed_integer_overflow(op->type); } @@ -108,7 +106,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { // If this is a cast of a cast of the same type, where the // outer cast is narrower, the inner cast can be // eliminated. - return mutate(Cast::make(op->type, cast->value), bounds); + return mutate(Cast::make(op->type, cast->value), info); } else if (cast && (op->type.is_int() || op->type.is_uint()) && (cast->type.is_int() || cast->type.is_uint()) && @@ -119,10 +117,10 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { // inner cast's argument, the inner cast can be // eliminated. 
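For the no-overflow branch above, a toy mirror of the new behaviour may help: a bound the destination type cannot represent is dropped on its own, instead of invalidating the whole interval. The struct and function names below are illustrative only; the real code goes through Type::can_represent and ConstantInterval:

    #include <cassert>
    #include <cstdint>

    struct Bounds {
        bool has_min = false, has_max = false;
        int64_t min = 0, max = 0;
    };

    // Signed destination type of the given width (bits <= 32 here).
    Bounds cast_bounds_no_overflow(Bounds b, int bits) {
        const int64_t lo = -(int64_t(1) << (bits - 1));
        const int64_t hi = (int64_t(1) << (bits - 1)) - 1;
        auto representable = [&](int64_t v) { return v >= lo && v <= hi; };
        if (b.has_max && !representable(b.max)) { b.has_max = false; b.max = 0; }
        if (b.has_min && !representable(b.min)) { b.has_min = false; b.min = 0; }
        return b;
    }

    int main() {
        Bounds b{true, true, -5, int64_t(1) << 40};
        b = cast_bounds_no_overflow(b, 32);       // cast to Int(32)
        assert(b.has_min && !b.has_max);          // upper bound lost, lower kept
    }

For wrapping types the visitor instead defers to ConstantInterval::cast_to, which must give up precision (at worst falling back to the full range of the type) whenever the interval does not fit.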
The inner cast is either a sign extend // or a zero extend, and the outer cast truncates the extended bits - return mutate(Cast::make(op->type, cast->value), bounds); + return mutate(Cast::make(op->type, cast->value), info); } else if (broadcast_value) { // cast(broadcast(x)) -> broadcast(cast(x)) - return mutate(Broadcast::make(Cast::make(op->type.with_lanes(broadcast_value->value.type().lanes()), broadcast_value->value), broadcast_value->lanes), bounds); + return mutate(Broadcast::make(Cast::make(op->type.with_lanes(broadcast_value->value.type().lanes()), broadcast_value->value), broadcast_value->lanes), info); } else if (ramp_value && op->type.element_of() == Int(64) && op->value.type().element_of() == Int(32)) { @@ -132,7 +130,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *bounds) { Cast::make(op->type.with_lanes(ramp_value->stride.type().lanes()), ramp_value->stride), ramp_value->lanes), - bounds); + info); } } diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp index 49f98837404c..fecd381545cc 100644 --- a/src/Simplify_Div.cpp +++ b/src/Simplify_Div.cpp @@ -3,83 +3,25 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Div *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds && no_overflow_int(op->type)) { - bounds->min = INT64_MAX; - bounds->max = INT64_MIN; - - // Enumerate all possible values for the min and max and take the extreme values. - if (a_bounds.min_defined && b_bounds.min_defined && b_bounds.min != 0) { - int64_t v = div_imp(a_bounds.min, b_bounds.min); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - if (a_bounds.min_defined && b_bounds.max_defined && b_bounds.max != 0) { - int64_t v = div_imp(a_bounds.min, b_bounds.max); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - if (a_bounds.max_defined && b_bounds.max_defined && b_bounds.max != 0) { - int64_t v = div_imp(a_bounds.max, b_bounds.max); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - if (a_bounds.max_defined && b_bounds.min_defined && b_bounds.min != 0) { - int64_t v = div_imp(a_bounds.max, b_bounds.min); - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } - - const bool b_positive = b_bounds.min_defined && b_bounds.min > 0; - const bool b_negative = b_bounds.max_defined && b_bounds.max < 0; +Expr Simplify::visit(const Div *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); - if ((b_positive && !b_bounds.max_defined) || - (b_negative && !b_bounds.min_defined)) { - // Take limit as b -> +/- infinity - int64_t v = 0; - bounds->min = std::min(bounds->min, v); - bounds->max = std::max(bounds->max, v); - } + if (info && no_overflow_int(op->type)) { + info->bounds = a_info.bounds / b_info.bounds; + info->alignment = a_info.alignment / b_info.alignment; + info->trim_bounds_using_alignment(); - bounds->min_defined = ((a_bounds.min_defined && b_positive) || - (a_bounds.max_defined && b_negative)); - bounds->max_defined = ((a_bounds.max_defined && b_positive) || - (a_bounds.min_defined && b_negative)); - - // That's as far as we can get knowing the sign of the - // denominator. For bounded numerators, we additionally know - // that div can't make anything larger in magnitude, so we can - // take the intersection with that. 
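The Div bounds computation above, and the denominator_non_zero test just below, lean on comparisons between a ConstantInterval and a scalar; these replace the explicit min_defined/max_defined sign checks being removed here. A minimal model of what those comparisons mean (hypothetical Bounds struct, not the real class):

    #include <cstdint>

    struct Bounds {
        bool has_min = false, has_max = false;
        int64_t min = 0, max = 0;
    };

    // "definitely greater than k": requires a known lower bound above k,
    // i.e. the old (min_defined && min > k).
    bool definitely_gt(const Bounds &b, int64_t k) { return b.has_min && b.min > k; }

    // "definitely less than k": the old (max_defined && max < k).
    bool definitely_lt(const Bounds &b, int64_t k) { return b.has_max && b.max < k; }

So b_info.bounds > 0 || b_info.bounds < 0 reads as "the denominator is provably positive or provably negative", and an unknown bound simply makes the test false.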
- if (a_bounds.max_defined && a_bounds.min_defined) { - int64_t v = std::max(a_bounds.max, -a_bounds.min); - if (bounds->min_defined) { - bounds->min = std::max(bounds->min, -v); - } else { - bounds->min = -v; - } - if (bounds->max_defined) { - bounds->max = std::min(bounds->max, v); - } else { - bounds->max = v; - } - bounds->min_defined = bounds->max_defined = true; - } + // TODO: add test case which resolves to a scalar, but only after + // trimming using the alignment. // Bounded numerator divided by constantish // denominator can sometimes collapse things to a // constant at this point - if (bounds->min_defined && - bounds->max_defined && - bounds->max == bounds->min) { - if (op->type.can_represent(bounds->min)) { - return make_const(op->type, bounds->min); + if (info->bounds.is_single_point()) { + if (op->type.can_represent(info->bounds.min)) { + return make_const(op->type, info->bounds.min); } else { // Even though this is 'no-overflow-int', if the result // we calculate can't fit into the destination type, @@ -87,28 +29,17 @@ Expr Simplify::visit(const Div *op, ExprInfo *bounds) { // a known-wrong value. (Note that no_overflow_int() should // only be true for signed integers.) internal_assert(op->type.is_int()); - clear_bounds_info(bounds); + clear_bounds_info(info); return make_signed_integer_overflow(op->type); } } - // Code downstream can use min/max in calculated-but-unused arithmetic - // that can lead to UB (and thus, flaky failures under ASAN/UBSAN) - // if we leave them set to INT64_MAX/INT64_MIN; normalize to zero to avoid this. - if (!bounds->min_defined) { - bounds->min = 0; - } - if (!bounds->max_defined) { - bounds->max = 0; - } - bounds->alignment = a_bounds.alignment / b_bounds.alignment; - bounds->trim_bounds_using_alignment(); } bool denominator_non_zero = (no_overflow_int(op->type) && - ((b_bounds.min_defined && b_bounds.min > 0) || - (b_bounds.max_defined && b_bounds.max < 0) || - (b_bounds.alignment.remainder != 0))); + (b_info.bounds < 0 || + b_info.bounds > 0 || + b_info.alignment.remainder != 0)); if (may_simplify(op->type)) { @@ -126,8 +57,8 @@ Expr Simplify::visit(const Div *op, ExprInfo *bounds) { return rewrite.result; } - int a_mod = a_bounds.alignment.modulus; - int a_rem = a_bounds.alignment.remainder; + int a_mod = a_info.alignment.modulus; + int a_rem = a_info.alignment.remainder; // clang-format off if (EVAL_IN_LAMBDA @@ -272,7 +203,7 @@ Expr Simplify::visit(const Div *op, ExprInfo *bounds) { c2 > 0 && c0 % c2 == 0) || // A very specific pattern that comes up in bounds in upsampling code. 
rewrite((x % 2 + c0) / 2, x % 2 + fold(c0 / 2), c0 % 2 == 1))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_EQ.cpp b/src/Simplify_EQ.cpp index 13b49a90886c..4f9b539e5269 100644 --- a/src/Simplify_EQ.cpp +++ b/src/Simplify_EQ.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { +Expr Simplify::visit(const EQ *op, ExprInfo *info) { if (truths.count(op)) { return const_true(op->type.lanes()); } else if (falsehoods.count(op)) { @@ -31,7 +31,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { if (rewrite(x == 1, x)) { return rewrite.result; } else if (rewrite(x == 0, !x)) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } else if (rewrite(x == x, const_true(lanes))) { return rewrite.result; } else if (a.same_as(op->a) && b.same_as(op->b)) { @@ -41,8 +41,8 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { } } - ExprInfo delta_bounds; - Expr delta = mutate(op->a - op->b, &delta_bounds); + ExprInfo delta_info; + Expr delta = mutate(op->a - op->b, &delta_info); const int lanes = op->type.lanes(); // If the delta is 0, then it's just x == x @@ -51,16 +51,12 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { } // Attempt to disprove using bounds analysis - if (delta_bounds.min_defined && delta_bounds.min > 0) { - return const_false(lanes); - } - - if (delta_bounds.max_defined && delta_bounds.max < 0) { + if (!delta_info.bounds.contains(0)) { return const_false(lanes); } // Attempt to disprove using modulus remainder analysis - if (delta_bounds.alignment.remainder != 0) { + if (delta_info.alignment.remainder != 0) { return const_false(lanes); } @@ -109,7 +105,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { rewrite(min(x, 0) == 0, 0 <= x) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (rewrite(c0 == 0, fold(c0 == 0)) || @@ -134,7 +130,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *bounds) { } // ne redirects to not eq -Expr Simplify::visit(const NE *op, ExprInfo *bounds) { +Expr Simplify::visit(const NE *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -145,7 +141,7 @@ Expr Simplify::visit(const NE *op, ExprInfo *bounds) { } } - Expr mutated = mutate(Not::make(EQ::make(op->a, op->b)), bounds); + Expr mutated = mutate(Not::make(EQ::make(op->a, op->b)), info); if (const NE *ne = mutated.as()) { if (ne->a.same_as(op->a) && ne->b.same_as(op->b)) { return op; diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index b5fcc96ac0cd..d8678ddc5736 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -7,49 +7,46 @@ namespace Internal { // Miscellaneous expression visitors that are too small to bother putting in their own files -Expr Simplify::visit(const IntImm *op, ExprInfo *bounds) { - if (bounds && no_overflow_int(op->type)) { - bounds->min_defined = bounds->max_defined = true; - bounds->min = bounds->max = op->value; - bounds->alignment.remainder = op->value; - bounds->alignment.modulus = 0; +Expr Simplify::visit(const IntImm *op, ExprInfo *info) { + if (info && no_overflow_int(op->type)) { + info->bounds = ConstantInterval::single_point(op->value); + info->alignment = ModulusRemainder(0, op->value); } else { - clear_bounds_info(bounds); + clear_bounds_info(info); } return op; } -Expr Simplify::visit(const UIntImm *op, ExprInfo *bounds) { - if 
(bounds && Int(64).can_represent(op->value)) { - bounds->min_defined = bounds->max_defined = true; - bounds->min = bounds->max = (int64_t)(op->value); - bounds->alignment.remainder = op->value; - bounds->alignment.modulus = 0; +Expr Simplify::visit(const UIntImm *op, ExprInfo *info) { + if (info && Int(64).can_represent(op->value)) { + int64_t v = (int64_t)(op->value); + info->bounds = ConstantInterval::single_point(v); + info->alignment = ModulusRemainder(0, v); } else { - clear_bounds_info(bounds); + clear_bounds_info(info); } return op; } -Expr Simplify::visit(const FloatImm *op, ExprInfo *bounds) { - clear_bounds_info(bounds); +Expr Simplify::visit(const FloatImm *op, ExprInfo *info) { + clear_bounds_info(info); return op; } -Expr Simplify::visit(const StringImm *op, ExprInfo *bounds) { - clear_bounds_info(bounds); +Expr Simplify::visit(const StringImm *op, ExprInfo *info) { + clear_bounds_info(info); return op; } -Expr Simplify::visit(const Broadcast *op, ExprInfo *bounds) { - Expr value = mutate(op->value, bounds); +Expr Simplify::visit(const Broadcast *op, ExprInfo *info) { + Expr value = mutate(op->value, info); const int lanes = op->lanes; auto rewrite = IRMatcher::rewriter(IRMatcher::broadcast(value, lanes), op->type); if (rewrite(broadcast(broadcast(x, c0), lanes), broadcast(x, c0 * lanes)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (value.same_as(op->value)) { @@ -59,8 +56,8 @@ Expr Simplify::visit(const Broadcast *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { - Expr value = mutate(op->value, bounds); +Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { + Expr value = mutate(op->value, info); const int lanes = op->type.lanes(); const int arg_lanes = op->value.type().lanes(); @@ -69,32 +66,22 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { return value; } - if (bounds && op->type.is_int()) { + if (info && op->type.is_int()) { switch (op->op) { case VectorReduce::Add: // Alignment of result is the alignment of the arg. Bounds // of the result can grow according to the reduction // factor. - if (bounds->min_defined) { - bounds->min *= factor; - } - if (bounds->max_defined) { - bounds->max *= factor; - } + info->bounds = cast(op->type, info->bounds * factor); break; case VectorReduce::SaturatingAdd: - if (bounds->min_defined) { - bounds->min = saturating_mul(bounds->min, factor); - } - if (bounds->max_defined) { - bounds->max = saturating_mul(bounds->max, factor); - } + info->bounds = saturating_cast(op->type, info->bounds * factor); break; case VectorReduce::Mul: // Don't try to infer anything about bounds. Leave the // alignment unchanged even though we could theoretically // upgrade it. - bounds->min_defined = bounds->max_defined = false; + info->bounds = ConstantInterval{}; break; case VectorReduce::Min: case VectorReduce::Max: @@ -104,8 +91,8 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { case VectorReduce::Or: // For integer types this is a bitwise operator. Don't try // to infer anything for now. 
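A short worked example of the bounds-based disproof in the EQ visitor above, where the two old one-sided tests collapse into a single containment check on the bounds of a - b:

    #include <cassert>

    int main() {
        // Suppose a is known to lie in [10, 20] and b in [0, 5].
        long a_min = 10, a_max = 20, b_min = 0, b_max = 5;
        // Then a - b lies in [a_min - b_max, a_max - b_min] = [5, 20].
        long d_min = a_min - b_max, d_max = a_max - b_min;
        // That interval excludes 0, so a == b is provably false; this is
        // exactly the !delta_info.bounds.contains(0) test.
        assert(!(d_min <= 0 && 0 <= d_max));
    }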
- bounds->min_defined = bounds->max_defined = false; - bounds->alignment = ModulusRemainder{}; + info->bounds = ConstantInterval{}; + info->alignment = ModulusRemainder{}; break; } } @@ -134,7 +121,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) || rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -148,7 +135,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_min(broadcast(x, c0), lanes), h_min(x, lanes), factor % c0 == 0) || rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -162,7 +149,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) || rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -183,7 +170,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), x <= y + min(z * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -205,7 +192,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), x <= y + max(z * (arg_lanes - 1), 0)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } break; } @@ -220,33 +207,42 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const Variable *op, ExprInfo *bounds) { +Expr Simplify::visit(const Variable *op, ExprInfo *info) { if (const ExprInfo *b = bounds_and_alignment_info.find(op->name)) { - if (bounds) { - *bounds = *b; + if (info) { + *info = *b; } - if (b->min_defined && b->max_defined && b->min == b->max) { - return make_const(op->type, b->min); + if (b->bounds.is_single_point()) { + if (info) { + debug(0) << "Var is single point: " << op->name << ": " << info->bounds << "\n"; + } + return make_const(op->type, b->bounds.min); } + } else if (info && !no_overflow_int(op->type)) { + info->bounds = ConstantInterval::bounds_of_type(op->type); } - if (auto *info = var_info.shallow_find(op->name)) { + if (info) { + debug(0) << "Bounds of var: " << op->name << ": " << info->bounds << "\n"; + } + + if (auto *v_info = var_info.shallow_find(op->name)) { // if replacement is defined, we should substitute it in (unless // it's a var that has been hidden by a nested scope). - if (info->replacement.defined()) { - internal_assert(info->replacement.type() == op->type) + if (v_info->replacement.defined()) { + internal_assert(v_info->replacement.type() == op->type) << "Cannot replace variable " << op->name << " of type " << op->type - << " with expression of type " << info->replacement.type() << "\n"; - info->new_uses++; + << " with expression of type " << v_info->replacement.type() << "\n"; + v_info->new_uses++; // We want to remutate the replacement, because we may be // injecting it into a context where it is known to be a // constant (e.g. 
due to an if). - return mutate(info->replacement, bounds); + return mutate(v_info->replacement, info); } else { // This expression was not something deemed // substitutable - no replacement is defined. - info->old_uses++; + v_info->old_uses++; return op; } } else { @@ -256,29 +252,26 @@ Expr Simplify::visit(const Variable *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const Ramp *op, ExprInfo *bounds) { - ExprInfo base_bounds, stride_bounds; - Expr base = mutate(op->base, &base_bounds); - Expr stride = mutate(op->stride, &stride_bounds); +Expr Simplify::visit(const Ramp *op, ExprInfo *info) { + ExprInfo base_info, stride_info; + Expr base = mutate(op->base, &base_info); + Expr stride = mutate(op->stride, &stride_info); const int lanes = op->lanes; - if (bounds && no_overflow_int(op->type)) { - bounds->min_defined = base_bounds.min_defined && stride_bounds.min_defined; - bounds->max_defined = base_bounds.max_defined && stride_bounds.max_defined; - bounds->min = std::min(base_bounds.min, base_bounds.min + (lanes - 1) * stride_bounds.min); - bounds->max = std::max(base_bounds.max, base_bounds.max + (lanes - 1) * stride_bounds.max); + if (info && no_overflow_int(op->type)) { + info->bounds = base_info.bounds + stride_info.bounds * ConstantInterval(0, lanes - 1); // A ramp lane is b + l * s. Expanding b into mb * x + rb and s into ms * y + rs, we get: // mb * x + rb + l * (ms * y + rs) // = mb * x + ms * l * y + rs * l + rb // = gcd(rs, ms, mb) * z + rb - int64_t m = stride_bounds.alignment.modulus; - m = gcd(m, stride_bounds.alignment.remainder); - m = gcd(m, base_bounds.alignment.modulus); - int64_t r = base_bounds.alignment.remainder; + int64_t m = stride_info.alignment.modulus; + m = gcd(m, stride_info.alignment.remainder); + m = gcd(m, base_info.alignment.modulus); + int64_t r = base_info.alignment.remainder; if (m != 0) { - r = mod_imp(base_bounds.alignment.remainder, m); + r = mod_imp(base_info.alignment.remainder, m); } - bounds->alignment = {m, r}; + info->alignment = {m, r}; } // A somewhat torturous way to check if the stride is zero, @@ -303,9 +296,13 @@ Expr Simplify::visit(const Ramp *op, ExprInfo *bounds) { } } -Expr Simplify::visit(const Load *op, ExprInfo *bounds) { +Expr Simplify::visit(const Load *op, ExprInfo *info) { found_buffer_reference(op->name); + if (info) { + info->bounds = ConstantInterval::bounds_of_type(op->type); + } + Expr predicate = mutate(op->predicate, nullptr); ExprInfo index_info; @@ -319,17 +316,11 @@ Expr Simplify::visit(const Load *op, ExprInfo *bounds) { if (is_const_one(op->predicate)) { string alloc_extent_name = op->name + ".total_extent_bytes"; if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { - if (index_info.max_defined && index_info.max < 0) { + if (index_info.bounds < 0 || + index_info.bounds * op->type.bytes() > alloc_info->bounds) { in_unreachable = true; return unreachable(op->type); } - if (alloc_info->max_defined && index_info.min_defined) { - int index_min_bytes = index_info.min * op->type.bytes(); - if (index_min_bytes > alloc_info->max) { - in_unreachable = true; - return unreachable(op->type); - } - } } } diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index a59a4250cf2b..48d06c717794 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -7,10 +7,13 @@ * exported in Halide.h. 
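A worked example of the ramp bounds computed above: lane i of ramp(base, stride, lanes) is base + i * stride with i in [0, lanes - 1], so the interval is base_bounds + stride_bounds * [0, lanes - 1]:

    #include <cassert>

    int main() {
        // base in [10, 12], stride in [-2, 3], lanes = 4, so i in [0, 3].
        // stride * [0, 3] has corner products {0, -6, 0, 9} -> [-6, 9].
        long scaled_min = -2 * 3, scaled_max = 3 * 3;
        // Adding the base bounds gives [10 - 6, 12 + 9] = [4, 21].
        assert(10 + scaled_min == 4 && 12 + scaled_max == 21);
    }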
*/ #include "Bounds.h" +#include "ConstantInterval.h" #include "IRMatch.h" #include "IRVisitor.h" #include "Scope.h" +#include "IRPrinter.h" + // Because this file is only included by the simplify methods and // doesn't go into Halide.h, we're free to use any old names for our // macros. @@ -28,17 +31,6 @@ namespace Halide { namespace Internal { -inline int64_t saturating_mul(int64_t a, int64_t b) { - int64_t result; - if (mul_with_overflow(64, a, b, &result)) { - return result; - } else if ((a > 0) == (b > 0)) { - return INT64_MAX; - } else { - return INT64_MIN; - } -} - class Simplify : public VariadicVisitor { using Super = VariadicVisitor; @@ -47,63 +39,46 @@ class Simplify : public VariadicVisitor { struct ExprInfo { // We track constant integer bounds when they exist - // TODO: Use ConstantInterval? - int64_t min = 0, max = 0; - bool min_defined = false, max_defined = false; + ConstantInterval bounds; // And the alignment of integer variables ModulusRemainder alignment; void trim_bounds_using_alignment() { if (alignment.modulus == 0) { - min_defined = max_defined = true; - min = max = alignment.remainder; + bounds = ConstantInterval::single_point(alignment.remainder); } else if (alignment.modulus > 1) { - if (min_defined) { + if (bounds.has_lower_bound()) { int64_t adjustment; - bool no_overflow = sub_with_overflow(64, alignment.remainder, mod_imp(min, alignment.modulus), &adjustment); + bool no_overflow = sub_with_overflow(64, alignment.remainder, mod_imp(bounds.min, alignment.modulus), &adjustment); adjustment = mod_imp(adjustment, alignment.modulus); int64_t new_min; - no_overflow &= add_with_overflow(64, min, adjustment, &new_min); + no_overflow &= add_with_overflow(64, bounds.min, adjustment, &new_min); if (no_overflow) { - min = new_min; + bounds.min = new_min; } } - if (max_defined) { + if (bounds.has_upper_bound()) { int64_t adjustment; - bool no_overflow = sub_with_overflow(64, mod_imp(max, alignment.modulus), alignment.remainder, &adjustment); + bool no_overflow = sub_with_overflow(64, mod_imp(bounds.max, alignment.modulus), alignment.remainder, &adjustment); adjustment = mod_imp(adjustment, alignment.modulus); int64_t new_max; - no_overflow &= sub_with_overflow(64, max, adjustment, &new_max); + no_overflow &= sub_with_overflow(64, bounds.max, adjustment, &new_max); if (no_overflow) { - max = new_max; + bounds.max = new_max; } } } - if (min_defined && max_defined && min == max) { + if (bounds.is_single_point()) { alignment.modulus = 0; - alignment.remainder = min; + alignment.remainder = bounds.min; } } // Mix in existing knowledge about this Expr void intersect(const ExprInfo &other) { - if (min_defined && other.min_defined) { - min = std::max(min, other.min); - } else if (other.min_defined) { - min_defined = true; - min = other.min; - } - - if (max_defined && other.max_defined) { - max = std::min(max, other.max); - } else if (other.max_defined) { - max_defined = true; - max = other.max; - } - + bounds = ConstantInterval::make_intersection(bounds, other.bounds); alignment = ModulusRemainder::intersect(alignment, other.alignment); - trim_bounds_using_alignment(); } }; @@ -298,45 +273,45 @@ class Simplify : public VariadicVisitor { Stmt mutate_let_body(const Stmt &s, ExprInfo *) { return mutate(s); } - Expr mutate_let_body(const Expr &e, ExprInfo *bounds) { - return mutate(e, bounds); + Expr mutate_let_body(const Expr &e, ExprInfo *info) { + return mutate(e, info); } template - Body simplify_let(const T *op, ExprInfo *bounds); - - Expr visit(const IntImm *op, ExprInfo 
*bounds); - Expr visit(const UIntImm *op, ExprInfo *bounds); - Expr visit(const FloatImm *op, ExprInfo *bounds); - Expr visit(const StringImm *op, ExprInfo *bounds); - Expr visit(const Broadcast *op, ExprInfo *bounds); - Expr visit(const Cast *op, ExprInfo *bounds); - Expr visit(const Reinterpret *op, ExprInfo *bounds); - Expr visit(const Variable *op, ExprInfo *bounds); - Expr visit(const Add *op, ExprInfo *bounds); - Expr visit(const Sub *op, ExprInfo *bounds); - Expr visit(const Mul *op, ExprInfo *bounds); - Expr visit(const Div *op, ExprInfo *bounds); - Expr visit(const Mod *op, ExprInfo *bounds); - Expr visit(const Min *op, ExprInfo *bounds); - Expr visit(const Max *op, ExprInfo *bounds); - Expr visit(const EQ *op, ExprInfo *bounds); - Expr visit(const NE *op, ExprInfo *bounds); - Expr visit(const LT *op, ExprInfo *bounds); - Expr visit(const LE *op, ExprInfo *bounds); - Expr visit(const GT *op, ExprInfo *bounds); - Expr visit(const GE *op, ExprInfo *bounds); - Expr visit(const And *op, ExprInfo *bounds); - Expr visit(const Or *op, ExprInfo *bounds); - Expr visit(const Not *op, ExprInfo *bounds); - Expr visit(const Select *op, ExprInfo *bounds); - Expr visit(const Ramp *op, ExprInfo *bounds); + Body simplify_let(const T *op, ExprInfo *info); + + Expr visit(const IntImm *op, ExprInfo *info); + Expr visit(const UIntImm *op, ExprInfo *info); + Expr visit(const FloatImm *op, ExprInfo *info); + Expr visit(const StringImm *op, ExprInfo *info); + Expr visit(const Broadcast *op, ExprInfo *info); + Expr visit(const Cast *op, ExprInfo *info); + Expr visit(const Reinterpret *op, ExprInfo *info); + Expr visit(const Variable *op, ExprInfo *info); + Expr visit(const Add *op, ExprInfo *info); + Expr visit(const Sub *op, ExprInfo *info); + Expr visit(const Mul *op, ExprInfo *info); + Expr visit(const Div *op, ExprInfo *info); + Expr visit(const Mod *op, ExprInfo *info); + Expr visit(const Min *op, ExprInfo *info); + Expr visit(const Max *op, ExprInfo *info); + Expr visit(const EQ *op, ExprInfo *info); + Expr visit(const NE *op, ExprInfo *info); + Expr visit(const LT *op, ExprInfo *info); + Expr visit(const LE *op, ExprInfo *info); + Expr visit(const GT *op, ExprInfo *info); + Expr visit(const GE *op, ExprInfo *info); + Expr visit(const And *op, ExprInfo *info); + Expr visit(const Or *op, ExprInfo *info); + Expr visit(const Not *op, ExprInfo *info); + Expr visit(const Select *op, ExprInfo *info); + Expr visit(const Ramp *op, ExprInfo *info); Stmt visit(const IfThenElse *op); - Expr visit(const Load *op, ExprInfo *bounds); - Expr visit(const Call *op, ExprInfo *bounds); - Expr visit(const Shuffle *op, ExprInfo *bounds); - Expr visit(const VectorReduce *op, ExprInfo *bounds); - Expr visit(const Let *op, ExprInfo *bounds); + Expr visit(const Load *op, ExprInfo *info); + Expr visit(const Call *op, ExprInfo *info); + Expr visit(const Shuffle *op, ExprInfo *info); + Expr visit(const VectorReduce *op, ExprInfo *info); + Expr visit(const Let *op, ExprInfo *info); Stmt visit(const LetStmt *op); Stmt visit(const AssertStmt *op); Stmt visit(const For *op); @@ -354,7 +329,7 @@ class Simplify : public VariadicVisitor { Stmt visit(const Atomic *op); Stmt visit(const HoistedStorage *op); - std::pair, bool> mutate_with_changes(const std::vector &old_exprs, ExprInfo *bounds); + std::pair, bool> mutate_with_changes(const std::vector &old_exprs); }; } // namespace Internal diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 58c6d4d27ab3..1069602e24a1 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ 
-3,10 +3,10 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const LT *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); +Expr Simplify::visit(const LT *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); const int lanes = op->type.lanes(); Type ty = a.type(); @@ -20,11 +20,12 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { if (may_simplify(ty)) { // Prove or disprove using bounds analysis - if (a_bounds.max_defined && b_bounds.min_defined && a_bounds.max < b_bounds.min) { + debug(0) << "ELEPHANT: " << Expr(op) << ": " << a_info.bounds << ", " << b_info.bounds << "\n"; + if (a_info.bounds < b_info.bounds) { + debug(0) << "... true\n"; return const_true(lanes); - } - - if (a_bounds.min_defined && b_bounds.max_defined && a_bounds.min >= b_bounds.max) { + } else if (a_info.bounds >= b_info.bounds) { + debug(0) << "... false\n"; return const_false(lanes); } @@ -499,7 +500,7 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { c1 * (lanes - 1) < c0 && c1 * (lanes - 1) >= 0) ))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } @@ -512,7 +513,7 @@ Expr Simplify::visit(const LT *op, ExprInfo *bounds) { } // The other comparison operators redirect to the less-than operator -Expr Simplify::visit(const LE *op, ExprInfo *bounds) { +Expr Simplify::visit(const LE *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -523,7 +524,7 @@ Expr Simplify::visit(const LE *op, ExprInfo *bounds) { } } - Expr mutated = mutate(!(op->b < op->a), bounds); + Expr mutated = mutate(!(op->b < op->a), info); if (const LE *le = mutated.as()) { if (le->a.same_as(op->a) && le->b.same_as(op->b)) { return op; @@ -532,7 +533,7 @@ Expr Simplify::visit(const LE *op, ExprInfo *bounds) { return mutated; } -Expr Simplify::visit(const GT *op, ExprInfo *bounds) { +Expr Simplify::visit(const GT *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -543,10 +544,10 @@ Expr Simplify::visit(const GT *op, ExprInfo *bounds) { } } - return mutate(op->b < op->a, bounds); + return mutate(op->b < op->a, info); } -Expr Simplify::visit(const GE *op, ExprInfo *bounds) { +Expr Simplify::visit(const GE *op, ExprInfo *info) { if (!may_simplify(op->a.type())) { Expr a = mutate(op->a, nullptr); Expr b = mutate(op->b, nullptr); @@ -557,7 +558,7 @@ Expr Simplify::visit(const GE *op, ExprInfo *bounds) { } } - return mutate(!(op->a < op->b), bounds); + return mutate(!(op->a < op->b), info); } } // namespace Internal diff --git a/src/Simplify_Let.cpp b/src/Simplify_Let.cpp index 4f1862abf6ac..481f37be87a4 100644 --- a/src/Simplify_Let.cpp +++ b/src/Simplify_Let.cpp @@ -43,7 +43,7 @@ void count_var_uses(StmtOrExpr x, std::map &var_uses) { } // namespace template -Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { +Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { // Lets are often deeply nested. Get the intermediate state off // the call stack where it could overflow onto an explicit stack. 
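The LT proof above and the Min/Max early-outs below rely on comparisons between two ConstantIntervals, which only hold when they hold for every pair of values the operands can take. A minimal model of that semantics (the Bounds struct and function names are illustrative, not the real API):

    #include <cstdint>

    struct Bounds {
        bool has_min = false, has_max = false;
        int64_t min = 0, max = 0;
    };

    // a < b "definitely": every value of a is below every value of b,
    // which needs a known upper bound on a and a known lower bound on b.
    bool definitely_lt(const Bounds &a, const Bounds &b) {
        return a.has_max && b.has_min && a.max < b.min;
    }

    // a <= b "definitely": used to resolve min/max to one side.
    bool definitely_le(const Bounds &a, const Bounds &b) {
        return a.has_max && b.has_min && a.max <= b.min;
    }

    // Note: neither test firing proves the opposite comparison; overlapping
    // intervals prove nothing either way, so both directions are checked.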
@@ -70,8 +70,8 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { // If the value is trivial, make a note of it in the scope so // we can subs it in later - ExprInfo value_bounds; - f.value = mutate(op->value, &value_bounds); + ExprInfo value_info; + f.value = mutate(op->value, &value_info); // Iteratively peel off certain operations from the let value and push them inside. f.new_value = f.value; @@ -201,21 +201,24 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { var_info.push(op->name, info); // Before we enter the body, track the alignment info - if (f.new_value.defined() && no_overflow_scalar_int(f.new_value.type())) { // Remutate new_value to get updated bounds - ExprInfo new_value_bounds; - f.new_value = mutate(f.new_value, &new_value_bounds); - if (new_value_bounds.min_defined || new_value_bounds.max_defined || new_value_bounds.alignment.modulus != 1) { + ExprInfo new_value_info; + f.new_value = mutate(f.new_value, &new_value_info); + if (new_value_info.bounds.has_lower_bound() || + new_value_info.bounds.has_upper_bound() || + new_value_info.alignment.modulus != 1) { // There is some useful information - bounds_and_alignment_info.push(f.new_name, new_value_bounds); + bounds_and_alignment_info.push(f.new_name, new_value_info); f.new_value_bounds_tracked = true; } } if (no_overflow_scalar_int(f.value.type())) { - if (value_bounds.min_defined || value_bounds.max_defined || value_bounds.alignment.modulus != 1) { - bounds_and_alignment_info.push(op->name, value_bounds); + if (value_info.bounds.has_lower_bound() || + value_info.bounds.has_upper_bound() || + value_info.alignment.modulus != 1) { + bounds_and_alignment_info.push(op->name, value_info); f.value_bounds_tracked = true; } } @@ -224,7 +227,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { op = result.template as(); } - result = mutate_let_body(result, bounds); + result = mutate_let_body(result, info); // TODO: var_info and vars_used are pretty redundant; however, at the time // of writing, both cover cases that the other does not: @@ -271,8 +274,8 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { return result; } -Expr Simplify::visit(const Let *op, ExprInfo *bounds) { - return simplify_let(op, bounds); +Expr Simplify::visit(const Let *op, ExprInfo *info) { + return simplify_let(op, info); } Stmt Simplify::visit(const LetStmt *op) { diff --git a/src/Simplify_Max.cpp b/src/Simplify_Max.cpp index 1a79aef962fa..59bb3b6313e5 100644 --- a/src/Simplify_Max.cpp +++ b/src/Simplify_Max.cpp @@ -3,44 +3,37 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Max *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds) { - bounds->min_defined = a_bounds.min_defined || b_bounds.min_defined; - bounds->max_defined = a_bounds.max_defined && b_bounds.max_defined; - bounds->max = std::max(a_bounds.max, b_bounds.max); - if (a_bounds.min_defined && b_bounds.min_defined) { - bounds->min = std::max(a_bounds.min, b_bounds.min); - } else if (a_bounds.min_defined) { - bounds->min = a_bounds.min; - } else { - bounds->min = b_bounds.min; - } - bounds->alignment = ModulusRemainder::unify(a_bounds.alignment, b_bounds.alignment); - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Max *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info) { + info->bounds = 
max(a_info.bounds, b_info.bounds); + info->alignment = ModulusRemainder::unify(a_info.alignment, b_info.alignment); + info->trim_bounds_using_alignment(); } - // Early out when the bounds tells us one side or the other is smaller - if (a_bounds.max_defined && b_bounds.min_defined && a_bounds.max <= b_bounds.min) { - if (const Call *call = b.as()) { + auto strip_likely = [](const Expr &e) { + if (const Call *call = e.as()) { if (call->is_intrinsic(Call::likely) || call->is_intrinsic(Call::likely_if_innermost)) { return call->args[0]; } } - return b; + return e; + }; + + if (info) { + debug(0) << "Bounds of max: " << Expr(op) << ": " << a_info.bounds << ", " << b_info.bounds << ", " << info->bounds << "\n"; } - if (b_bounds.max_defined && a_bounds.min_defined && b_bounds.max <= a_bounds.min) { - if (const Call *call = a.as()) { - if (call->is_intrinsic(Call::likely) || - call->is_intrinsic(Call::likely_if_innermost)) { - return call->args[0]; - } - } - return a; + + // Early out when the bounds tells us one side or the other is smaller + if (a_info.bounds <= b_info.bounds) { + return strip_likely(b); + } + if (b_info.bounds <= a_info.bounds) { + return strip_likely(a); } if (may_simplify(op->type)) { @@ -48,7 +41,7 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } int lanes = op->type.lanes(); @@ -301,7 +294,7 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { rewrite(max(c0 - x, c1), c0 - min(x, fold(c0 - c1))))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Min.cpp b/src/Simplify_Min.cpp index 214ed09374d3..41e455174351 100644 --- a/src/Simplify_Min.cpp +++ b/src/Simplify_Min.cpp @@ -3,44 +3,34 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Min *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds) { - bounds->min_defined = a_bounds.min_defined && b_bounds.min_defined; - bounds->max_defined = a_bounds.max_defined || b_bounds.max_defined; - bounds->min = std::min(a_bounds.min, b_bounds.min); - if (a_bounds.max_defined && b_bounds.max_defined) { - bounds->max = std::min(a_bounds.max, b_bounds.max); - } else if (a_bounds.max_defined) { - bounds->max = a_bounds.max; - } else { - bounds->max = b_bounds.max; - } - bounds->alignment = ModulusRemainder::unify(a_bounds.alignment, b_bounds.alignment); - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Min *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info) { + info->bounds = min(a_info.bounds, b_info.bounds); + info->alignment = ModulusRemainder::unify(a_info.alignment, b_info.alignment); + info->trim_bounds_using_alignment(); } // Early out when the bounds tells us one side or the other is smaller - if (a_bounds.max_defined && b_bounds.min_defined && a_bounds.max <= b_bounds.min) { - if (const Call *call = a.as()) { + auto strip_likely = [](const Expr &e) { + if (const Call *call = e.as()) { if (call->is_intrinsic(Call::likely) || call->is_intrinsic(Call::likely_if_innermost)) { return call->args[0]; } } - return a; + return e; + }; + + // Early out when the bounds tells us one side or the other is smaller + if (a_info.bounds >= b_info.bounds) { + return strip_likely(b); } - if 
(b_bounds.max_defined && a_bounds.min_defined && b_bounds.max <= a_bounds.min) { - if (const Call *call = b.as()) { - if (call->is_intrinsic(Call::likely) || - call->is_intrinsic(Call::likely_if_innermost)) { - return call->args[0]; - } - } - return b; + if (b_info.bounds >= a_info.bounds) { + return strip_likely(a); } if (may_simplify(op->type)) { @@ -48,7 +38,7 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } int lanes = op->type.lanes(); @@ -312,7 +302,7 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { false )))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Mod.cpp b/src/Simplify_Mod.cpp index fcd4021b759f..6ba14847a3a2 100644 --- a/src/Simplify_Mod.cpp +++ b/src/Simplify_Mod.cpp @@ -3,60 +3,32 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Mod *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); +Expr Simplify::visit(const Mod *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); // We always combine bounds here, even if not requested, because // we can use them to simplify down to a constant if the bounds // are tight enough. - ExprInfo mod_bounds; - + ExprInfo mod_info; if (no_overflow_int(op->type)) { - // The result is at least zero. - mod_bounds.min_defined = true; - mod_bounds.min = 0; - - // Mod by produces a result between 0 - // and max(0, abs(modulus) - 1). However, if b is unbounded in - // either direction, abs(modulus) could be arbitrarily - // large. 
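The reasoning spelled out in the removed comments above carries over directly to interval form; a small worked example (using the Euclidean mod of div_imp/mod_imp, whose result is never negative):

    #include <cassert>

    int main() {
        // a in [-7, 100], b in [-4, 6].
        // |b| <= max(6, 4) = 6, so a % b lies in [0, 6 - 1] = [0, 5]
        // (matching max(0, b_max - 1, -1 - b_min) = max(0, 5, 3) above).
        long mod_min = 0, mod_max = 6 - 1;
        assert(mod_min == 0 && mod_max == 5);
        // If a were also known non-negative and bounded, say [0, 3], the
        // upper bound could be tightened further to min(5, 3) = 3.
    }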
- if (b_bounds.max_defined && b_bounds.min_defined) { - mod_bounds.max_defined = true; - mod_bounds.max = 0; // When b == 0 - mod_bounds.max = std::max(mod_bounds.max, b_bounds.max - 1); // When b > 0 - mod_bounds.max = std::max(mod_bounds.max, -1 - b_bounds.min); // When b < 0 - } - - // If a is positive, mod can't make it larger - if (a_bounds.min_defined && a_bounds.min >= 0 && a_bounds.max_defined) { - if (mod_bounds.max_defined) { - mod_bounds.max = std::min(mod_bounds.max, a_bounds.max); - } else { - mod_bounds.max_defined = true; - mod_bounds.max = a_bounds.max; - } - } + mod_info.bounds = a_info.bounds % b_info.bounds; + mod_info.alignment = a_info.alignment % b_info.alignment; + mod_info.trim_bounds_using_alignment(); - mod_bounds.alignment = a_bounds.alignment % b_bounds.alignment; - mod_bounds.trim_bounds_using_alignment(); - if (bounds) { - *bounds = mod_bounds; + if (info) { + *info = mod_info; } } if (may_simplify(op->type)) { - if (a_bounds.min_defined && a_bounds.min >= 0 && - a_bounds.max_defined && b_bounds.min_defined && a_bounds.max < b_bounds.min) { - if (bounds) { - *bounds = a_bounds; - } + if (a_info.bounds >= 0 && a_info.bounds < b_info.bounds) { return a; } - if (mod_bounds.min_defined && mod_bounds.max_defined && mod_bounds.min == mod_bounds.max) { - return make_const(op->type, mod_bounds.min); + if (mod_info.bounds.is_single_point()) { + return make_const(op->type, mod_info.bounds.min); } int lanes = op->type.lanes(); @@ -94,7 +66,7 @@ Expr Simplify::visit(const Mod *op, ExprInfo *bounds) { rewrite(ramp(x + c0, c2, c3) % broadcast(c1, c3), ramp(x + fold(c0 % c1), fold(c2 % c1), c3) % c1, c1 > 0 && (c0 >= c1 || c0 < 0)) || rewrite(ramp(x * c0 + y, c2, c3) % broadcast(c1, c3), ramp(y, fold(c2 % c1), c3) % c1, c0 % c1 == 0) || rewrite(ramp(y + x * c0, c2, c3) % broadcast(c1, c3), ramp(y, fold(c2 % c1), c3) % c1, c0 % c1 == 0))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Mul.cpp b/src/Simplify_Mul.cpp index 881d09112f7d..e333de5c566b 100644 --- a/src/Simplify_Mul.cpp +++ b/src/Simplify_Mul.cpp @@ -3,49 +3,15 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = mutate(op->b, &b_bounds); - - if (bounds && no_overflow_int(op->type)) { - bool a_positive = a_bounds.min_defined && a_bounds.min > 0; - bool b_positive = b_bounds.min_defined && b_bounds.min > 0; - bool a_bounded = a_bounds.min_defined && a_bounds.max_defined; - bool b_bounded = b_bounds.min_defined && b_bounds.max_defined; - - if (a_bounded && b_bounded) { - bounds->min_defined = bounds->max_defined = true; - int64_t v1 = saturating_mul(a_bounds.min, b_bounds.min); - int64_t v2 = saturating_mul(a_bounds.min, b_bounds.max); - int64_t v3 = saturating_mul(a_bounds.max, b_bounds.min); - int64_t v4 = saturating_mul(a_bounds.max, b_bounds.max); - bounds->min = std::min(std::min(v1, v2), std::min(v3, v4)); - bounds->max = std::max(std::max(v1, v2), std::max(v3, v4)); - } else if ((a_bounds.max_defined && b_bounded && b_positive) || - (b_bounds.max_defined && a_bounded && a_positive)) { - bounds->max_defined = true; - bounds->max = saturating_mul(a_bounds.max, b_bounds.max); - } else if ((a_bounds.min_defined && b_bounded && b_positive) || - (b_bounds.min_defined && a_bounded && a_positive)) { - bounds->min_defined = true; - bounds->min = saturating_mul(a_bounds.min, b_bounds.min); - } - - if 
(bounds->max_defined && bounds->max == INT64_MAX) { - // Assume it saturated to avoid overflow. This gives up a - // single representable value at the top end of the range - // to represent infinity. - bounds->max_defined = false; - bounds->max = 0; - } - if (bounds->min_defined && bounds->min == INT64_MIN) { - bounds->min_defined = false; - bounds->min = 0; - } - - bounds->alignment = a_bounds.alignment * b_bounds.alignment; - bounds->trim_bounds_using_alignment(); +Expr Simplify::visit(const Mul *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); + + if (info && no_overflow_int(op->type)) { + info->bounds = a_info.bounds * b_info.bounds; + info->alignment = a_info.alignment * b_info.alignment; + info->trim_bounds_using_alignment(); } if (may_simplify(op->type)) { @@ -53,7 +19,7 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { // Order commutative operations by node type if (should_commute(a, b)) { std::swap(a, b); - std::swap(a_bounds, b_bounds); + std::swap(a_info, b_info); } auto rewrite = IRMatcher::rewriter(IRMatcher::mul(a, b), op->type); @@ -103,7 +69,7 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { rewrite(slice(x, c0, c1, c2) * (z * slice(y, c0, c1, c2)), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } } diff --git a/src/Simplify_Not.cpp b/src/Simplify_Not.cpp index 70b4b234ddef..47b74661fd2c 100644 --- a/src/Simplify_Not.cpp +++ b/src/Simplify_Not.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Not *op, ExprInfo *bounds) { +Expr Simplify::visit(const Not *op, ExprInfo *info) { Expr a = mutate(op->a, nullptr); auto rewrite = IRMatcher::rewriter(IRMatcher::not_op(a), op->type); @@ -25,7 +25,7 @@ Expr Simplify::visit(const Not *op, ExprInfo *bounds) { rewrite(!(x && !y), !x || y) || rewrite(!(x || !y), !x && y) || false) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (a.same_as(op->a)) { diff --git a/src/Simplify_Or.cpp b/src/Simplify_Or.cpp index 274d66435ffb..083af6d5bc88 100644 --- a/src/Simplify_Or.cpp +++ b/src/Simplify_Or.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Or *op, ExprInfo *bounds) { +Expr Simplify::visit(const Or *op, ExprInfo *info) { if (truths.count(op)) { return const_true(op->type.lanes()); } @@ -101,7 +101,7 @@ Expr Simplify::visit(const Or *op, ExprInfo *bounds) { rewrite(x <= y || x <= z, x <= max(y, z)) || rewrite(y <= x || z <= x, min(y, z) <= x)) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } if (a.same_as(op->a) && diff --git a/src/Simplify_Reinterpret.cpp b/src/Simplify_Reinterpret.cpp index d5a8c1361fbe..51289aac9b87 100644 --- a/src/Simplify_Reinterpret.cpp +++ b/src/Simplify_Reinterpret.cpp @@ -3,7 +3,7 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Reinterpret *op, ExprInfo *bounds) { +Expr Simplify::visit(const Reinterpret *op, ExprInfo *info) { Expr a = mutate(op->value, nullptr); int64_t ia; @@ -19,7 +19,7 @@ Expr Simplify::visit(const Reinterpret *op, ExprInfo *bounds) { return make_const(op->type, (int64_t)ua); } else if (const Reinterpret *as_r = a.as()) { // Fold double-reinterprets. 
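The saturating corner products being removed above become plain interval multiplication, computed as if on unbounded integers, so the INT64_MAX/INT64_MIN saturation bookkeeping is no longer needed at this level. A worked example of the corner-product rule for bounded operands:

    #include <algorithm>
    #include <cassert>

    int main() {
        long a_min = -3, a_max = 4, b_min = -5, b_max = 2;
        long c[4] = {a_min * b_min, a_min * b_max, a_max * b_min, a_max * b_max};
        long lo = *std::min_element(c, c + 4);   // min{15, -6, -20, 8} = -20
        long hi = *std::max_element(c, c + 4);   // max{15, -6, -20, 8} = 15
        assert(lo == -20 && hi == 15);
    }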
- return mutate(reinterpret(op->type, as_r->value), bounds); + return mutate(reinterpret(op->type, as_r->value), info); } else if ((op->type.bits() == a.type().bits()) && op->type.is_int_or_uint() && a.type().is_int_or_uint()) { diff --git a/src/Simplify_Select.cpp b/src/Simplify_Select.cpp index 0233be61724d..63be8d64718e 100644 --- a/src/Simplify_Select.cpp +++ b/src/Simplify_Select.cpp @@ -3,20 +3,17 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Select *op, ExprInfo *bounds) { +Expr Simplify::visit(const Select *op, ExprInfo *info) { - ExprInfo t_bounds, f_bounds; + ExprInfo t_info, f_info; Expr condition = mutate(op->condition, nullptr); - Expr true_value = mutate(op->true_value, &t_bounds); - Expr false_value = mutate(op->false_value, &f_bounds); - - if (bounds) { - bounds->min_defined = t_bounds.min_defined && f_bounds.min_defined; - bounds->max_defined = t_bounds.max_defined && f_bounds.max_defined; - bounds->min = std::min(t_bounds.min, f_bounds.min); - bounds->max = std::max(t_bounds.max, f_bounds.max); - bounds->alignment = ModulusRemainder::unify(t_bounds.alignment, f_bounds.alignment); - bounds->trim_bounds_using_alignment(); + Expr true_value = mutate(op->true_value, &t_info); + Expr false_value = mutate(op->false_value, &f_info); + + if (info) { + info->bounds = ConstantInterval::make_union(t_info.bounds, f_info.bounds); + info->alignment = ModulusRemainder::unify(t_info.alignment, f_info.alignment); + info->trim_bounds_using_alignment(); } if (may_simplify(op->type)) { @@ -230,7 +227,7 @@ Expr Simplify::visit(const Select *op, ExprInfo *bounds) { rewrite(select(x, y, true), !x || y) || rewrite(select(x, false, y), !x && y) || rewrite(select(x, true, y), x || y))))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } // clang-format on } diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 7da4f6699ab7..348289ab0c83 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -7,7 +7,7 @@ namespace Internal { using std::vector; -Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { +Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { if (op->is_extract_element()) { int index = op->indices[0]; internal_assert(index >= 0); @@ -18,7 +18,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { // the same shuffle back. 
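The Select bounds above (and the Shuffle bounds below) combine the two sides with ConstantInterval::make_union: the result can come from either operand, so each bound survives only if it is known on both. A small model of that rule, matching the hand-written min/max logic being replaced (names are illustrative):

    #include <algorithm>
    #include <cstdint>

    struct Bounds {
        bool has_min = false, has_max = false;
        int64_t min = 0, max = 0;
    };

    Bounds union_bounds(const Bounds &a, const Bounds &b) {
        Bounds r;
        r.has_min = a.has_min && b.has_min;
        if (r.has_min) r.min = std::min(a.min, b.min);
        r.has_max = a.has_max && b.has_max;
        if (r.has_max) r.max = std::max(a.max, b.max);
        return r;
    }
    // e.g. select(c, x, y) with x in [0, 7] and y in [3, 10] lies in [0, 10].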
break; } else { - return extract_lane(mutate(vector, bounds), index); + return extract_lane(mutate(vector, info), index); } } index -= vector.type().lanes(); @@ -29,20 +29,17 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { vector new_vectors; bool changed = false; for (const Expr &vector : op->vectors) { - ExprInfo v_bounds; - Expr new_vector = mutate(vector, &v_bounds); + ExprInfo v_info; + Expr new_vector = mutate(vector, &v_info); if (!vector.same_as(new_vector)) { changed = true; } - if (bounds) { + if (info) { if (new_vectors.empty()) { - *bounds = v_bounds; + *info = v_info; } else { - bounds->min_defined &= v_bounds.min_defined; - bounds->max_defined &= v_bounds.max_defined; - bounds->min = std::min(bounds->min, v_bounds.min); - bounds->max = std::max(bounds->max, v_bounds.max); - bounds->alignment = ModulusRemainder::unify(bounds->alignment, v_bounds.alignment); + info->bounds = ConstantInterval::make_union(info->bounds, v_info.bounds); + info->alignment = ModulusRemainder::unify(info->alignment, v_info.alignment); } } new_vectors.push_back(new_vector); @@ -141,7 +138,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { } } if (can_collapse) { - return mutate(Ramp::make(r->base, r->stride / terms, r->lanes * terms), bounds); + return mutate(Ramp::make(r->base, r->stride / terms, r->lanes * terms), info); } } @@ -272,7 +269,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { if (cast->type.bits() > cast->value.type().bits()) { return mutate(Cast::make(cast->type.with_lanes(op->type.lanes()), Shuffle::make({cast->value}, op->indices)), - bounds); + info); } } } diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index f6cb81345961..fac58dcb847f 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -203,12 +203,12 @@ Stmt Simplify::visit(const AssertStmt *op) { } Stmt Simplify::visit(const For *op) { - ExprInfo min_bounds, extent_bounds; - Expr new_min = mutate(op->min, &min_bounds); + ExprInfo min_info, extent_info; + Expr new_min = mutate(op->min, &min_info); if (in_unreachable) { return Evaluate::make(new_min); } - Expr new_extent = mutate(op->extent, &extent_bounds); + Expr new_extent = mutate(op->extent, &extent_info); if (in_unreachable) { return Evaluate::make(new_extent); } @@ -218,12 +218,13 @@ Stmt Simplify::visit(const For *op) { op->for_type == ForType::Vectorized)); bool bounds_tracked = false; - if (min_bounds.min_defined || (min_bounds.max_defined && extent_bounds.max_defined)) { - min_bounds.max += extent_bounds.max - 1; - min_bounds.max_defined &= extent_bounds.max_defined; - min_bounds.alignment = ModulusRemainder{}; + ExprInfo loop_var_info; + loop_var_info.bounds = ConstantInterval::make_union(min_info.bounds, + min_info.bounds + extent_info.bounds - 1); + if (loop_var_info.bounds.has_upper_bound() || + loop_var_info.bounds.has_lower_bound()) { bounds_tracked = true; - bounds_and_alignment_info.push(op->name, min_bounds); + bounds_and_alignment_info.push(op->name, loop_var_info); } Stmt new_body; @@ -233,8 +234,9 @@ Stmt Simplify::visit(const For *op) { new_body = mutate(op->body); } if (in_unreachable) { - if (extent_bounds.min_defined && extent_bounds.min >= 1) { - // If we know the loop executes once, the code that runs this loop is unreachable. + if (extent_info.bounds > 0) { + // If we know the loop executes at least once, the code that runs + // this loop is unreachable. 
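A worked example for the loop-variable bounds introduced above: the variable ranges over [min, min + extent - 1], and taking the union with the min interval keeps the result sound even when the extent could be zero:

    #include <algorithm>
    #include <cassert>

    int main() {
        // min in [0, 10], extent in [0, 8]
        long min_lo = 0, min_hi = 10, ext_lo = 0, ext_hi = 8;
        // min + extent - 1 lies in [0 + 0 - 1, 10 + 8 - 1] = [-1, 17]
        long last_lo = min_lo + ext_lo - 1, last_hi = min_hi + ext_hi - 1;
        // Union with [0, 10] gives [-1, 17] for the loop variable.
        assert(std::min(min_lo, last_lo) == -1 && std::max(min_hi, last_hi) == 17);
    }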
return new_body; } in_unreachable = false; @@ -254,14 +256,14 @@ Stmt Simplify::visit(const For *op) { if (is_no_op(new_body)) { return new_body; - } else if (extent_bounds.max_defined && - extent_bounds.max <= 0) { + } else if (extent_info.bounds <= 0) { return Evaluate::make(0); - } else if (extent_bounds.max_defined && - extent_bounds.max <= 1 && + } else if (extent_info.bounds <= 1 && op->device_api == DeviceAPI::None) { + // Loop body runs at most once Stmt s = LetStmt::make(op->name, new_min, new_body); - if (extent_bounds.min < 1) { + if (extent_info.bounds.contains(0)) { + // Loop body might not run at all s = IfThenElse::make(0 < new_extent, s); } return mutate(s); @@ -280,8 +282,8 @@ Stmt Simplify::visit(const Provide *op) { found_buffer_reference(op->name, op->args.size()); // Mutate the args - auto [new_args, changed_args] = mutate_with_changes(op->args, nullptr); - auto [new_values, changed_values] = mutate_with_changes(op->values, nullptr); + auto [new_args, changed_args] = mutate_with_changes(op->args); + auto [new_values, changed_values] = mutate_with_changes(op->values); Expr new_predicate = mutate(op->predicate, nullptr); if (!(changed_args || changed_values) && new_predicate.same_as(op->predicate)) { @@ -307,17 +309,11 @@ Stmt Simplify::visit(const Store *op) { string alloc_extent_name = op->name + ".total_extent_bytes"; if (is_const_one(op->predicate)) { if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { - if (index_info.max_defined && index_info.max < 0) { + if (index_info.bounds < 0 || + index_info.bounds * op->value.type().bytes() > alloc_info->bounds) { in_unreachable = true; return Evaluate::make(unreachable()); } - if (alloc_info->max_defined && index_info.min_defined) { - int index_min_bytes = index_info.min * op->value.type().bytes(); - if (index_min_bytes > alloc_info->max) { - in_unreachable = true; - return Evaluate::make(unreachable()); - } - } } } @@ -356,33 +352,14 @@ Stmt Simplify::visit(const Allocate *op) { std::vector new_extents; bool all_extents_unmodified = true; ExprInfo total_extent_info; - total_extent_info.min_defined = true; - total_extent_info.max_defined = true; - total_extent_info.min = 1; - total_extent_info.max = 1; + total_extent_info.bounds = ConstantInterval::single_point(op->type.bytes()); for (size_t i = 0; i < op->extents.size(); i++) { ExprInfo extent_info; new_extents.push_back(mutate(op->extents[i], &extent_info)); all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); - if (extent_info.min_defined) { - total_extent_info.min *= extent_info.min; - } else { - total_extent_info.min_defined = false; - } - if (extent_info.max_defined) { - total_extent_info.max *= extent_info.max; - } else { - total_extent_info.max_defined = false; - } - } - if (total_extent_info.min_defined) { - total_extent_info.min *= op->type.bytes(); - total_extent_info.min -= 1; - } - if (total_extent_info.max_defined) { - total_extent_info.max *= op->type.bytes(); - total_extent_info.max -= 1; + total_extent_info.bounds *= extent_info.bounds; } + total_extent_info.bounds -= 1; ScopedBinding b(bounds_and_alignment_info, op->name + ".total_extent_bytes", total_extent_info); diff --git a/src/Simplify_Sub.cpp b/src/Simplify_Sub.cpp index 1ab53b2dea90..d566f42cad9c 100644 --- a/src/Simplify_Sub.cpp +++ b/src/Simplify_Sub.cpp @@ -3,23 +3,18 @@ namespace Halide { namespace Internal { -Expr Simplify::visit(const Sub *op, ExprInfo *bounds) { - ExprInfo a_bounds, b_bounds; - Expr a = mutate(op->a, &a_bounds); - Expr b = 
mutate(op->b, &b_bounds); +Expr Simplify::visit(const Sub *op, ExprInfo *info) { + ExprInfo a_info, b_info; + Expr a = mutate(op->a, &a_info); + Expr b = mutate(op->b, &b_info); - if (bounds && no_overflow_int(op->type)) { + if (info && no_overflow_int(op->type)) { // Doesn't account for correlated a, b, so any // cancellation rule that exploits that should always // remutate to recalculate the bounds. - bounds->min_defined = a_bounds.min_defined && - b_bounds.max_defined && - sub_with_overflow(64, a_bounds.min, b_bounds.max, &(bounds->min)); - bounds->max_defined = a_bounds.max_defined && - b_bounds.min_defined && - sub_with_overflow(64, a_bounds.max, b_bounds.min, &(bounds->max)); - bounds->alignment = a_bounds.alignment - b_bounds.alignment; - bounds->trim_bounds_using_alignment(); + info->bounds = a_info.bounds - b_info.bounds; + info->alignment = a_info.alignment - b_info.alignment; + info->trim_bounds_using_alignment(); } if (may_simplify(op->type)) { @@ -446,7 +441,7 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) { rewrite((min(z, x*c0 + y) + w) / c1 - x*c2, (min(z - x*c0, y) + w) / c1, c0 == c1 * c2) || false)))) { - return mutate(rewrite.result, bounds); + return mutate(rewrite.result, info); } } // clang-format on From 7f4bb38756f01600d320b05617633b97920b542c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 25 Mar 2024 15:29:43 -0700 Subject: [PATCH 07/33] Handle bounds of narrower types in the simplifier too --- src/Simplify_Add.cpp | 3 ++- src/Simplify_Call.cpp | 5 +++++ src/Simplify_Div.cpp | 6 +++--- src/Simplify_Exprs.cpp | 8 ++++++-- src/Simplify_Internal.h | 30 ++++++++++++++++++++++++++++++ src/Simplify_Mul.cpp | 3 ++- src/Simplify_Sub.cpp | 3 ++- test/correctness/simplify.cpp | 9 ++++++++- 8 files changed, 58 insertions(+), 9 deletions(-) diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp index 4efc7e4b9fcb..e4cccf131b5e 100644 --- a/src/Simplify_Add.cpp +++ b/src/Simplify_Add.cpp @@ -8,10 +8,11 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) { Expr a = mutate(op->a, &a_info); Expr b = mutate(op->b, &b_info); - if (info && no_overflow_int(op->type)) { + if (info) { info->bounds = a_info.bounds + b_info.bounds; info->alignment = a_info.alignment + b_info.alignment; info->trim_bounds_using_alignment(); + info->cast_to(op->type); } if (may_simplify(op->type)) { diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index 609a156f9aea..9dd98c3e3e8b 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -294,6 +294,11 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { return mutate(unbroadcast, info); } + if (info) { + info->bounds = abs(a_info.bounds); + info->cast_to(op->type); + } + Type ta = a.type(); int64_t ia = 0; double fa = 0; diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp index fecd381545cc..45b6a2ad8fb7 100644 --- a/src/Simplify_Div.cpp +++ b/src/Simplify_Div.cpp @@ -8,10 +8,11 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { Expr a = mutate(op->a, &a_info); Expr b = mutate(op->b, &b_info); - if (info && no_overflow_int(op->type)) { + if (info) { info->bounds = a_info.bounds / b_info.bounds; info->alignment = a_info.alignment / b_info.alignment; info->trim_bounds_using_alignment(); + info->cast_to(op->type); // TODO: add test case which resolves to a scalar, but only after // trimming using the alignment. 
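
Editorial note on the TODO above: alignment information can shrink a constant interval, sometimes all the way to a single point that the simplifier can then fold to a constant. A rough standalone sketch of the idea, using a hypothetical helper rather than the private ExprInfo member:

    #include "ConstantInterval.h"
    #include "ModulusRemainder.h"
    using namespace Halide::Internal;

    // Hypothetical helper: tighten bounds to the nearest values congruent to
    // a.remainder modulo a.modulus.
    ConstantInterval trim_with_alignment(ConstantInterval b, const ModulusRemainder &a) {
        if (a.modulus > 1 && b.is_bounded()) {
            int64_t lo = b.min + (((a.remainder - b.min) % a.modulus) + a.modulus) % a.modulus;
            int64_t hi = b.max - (((b.max - a.remainder) % a.modulus) + a.modulus) % a.modulus;
            if (lo <= hi) {
                return ConstantInterval(lo, hi);
            }
        }
        return b;
    }

    // e.g. bounds [1, 5] with alignment (modulus 4, remainder 0) trim to the
    // single point [4, 4], so the expression must equal the constant 4.
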
@@ -37,8 +38,7 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { bool denominator_non_zero = (no_overflow_int(op->type) && - (b_info.bounds < 0 || - b_info.bounds > 0 || + (!b_info.bounds.contains(0) || b_info.alignment.remainder != 0)); if (may_simplify(op->type)) { diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index d8678ddc5736..b7800dd6faa5 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -8,9 +8,10 @@ namespace Internal { // Miscellaneous expression visitors that are too small to bother putting in their own files Expr Simplify::visit(const IntImm *op, ExprInfo *info) { - if (info && no_overflow_int(op->type)) { + if (info) { info->bounds = ConstantInterval::single_point(op->value); info->alignment = ModulusRemainder(0, op->value); + info->cast_to(op->type); } else { clear_bounds_info(info); } @@ -22,6 +23,7 @@ Expr Simplify::visit(const UIntImm *op, ExprInfo *info) { int64_t v = (int64_t)(op->value); info->bounds = ConstantInterval::single_point(v); info->alignment = ModulusRemainder(0, v); + info->cast_to(op->type); } else { clear_bounds_info(info); } @@ -258,7 +260,7 @@ Expr Simplify::visit(const Ramp *op, ExprInfo *info) { Expr stride = mutate(op->stride, &stride_info); const int lanes = op->lanes; - if (info && no_overflow_int(op->type)) { + if (info) { info->bounds = base_info.bounds + stride_info.bounds * ConstantInterval(0, lanes - 1); // A ramp lane is b + l * s. Expanding b into mb * x + rb and s into ms * y + rs, we get: // mb * x + rb + l * (ms * y + rs) @@ -272,6 +274,8 @@ Expr Simplify::visit(const Ramp *op, ExprInfo *info) { r = mod_imp(base_info.alignment.remainder, m); } info->alignment = {m, r}; + info->trim_bounds_using_alignment(); + info->cast_to(op->type); } // A somewhat torturous way to check if the stride is zero, diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 48d06c717794..0b87b7aa6774 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -75,6 +75,36 @@ class Simplify : public VariadicVisitor { } } + void cast_to(Type t) { + if ((!t.is_int() && !t.is_uint()) || (t.is_int() && t.bits() >= 32)) { + return; + } + + // We've just done some infinite-integer operation on a bounded + // integer type, and we need to project the bounds and alignment + // back in-range. + + // Bounds: + bounds.cast_to(t); + + if (t.bits() >= 64) { + // Just preserve any power-of-two factor in the modulus. When + // alignment.modulus == 0, the value is some positive constant + // representable as any 64-bit integer type, so there's no + // wraparound. + if (alignment.modulus > 0) { + // This masks off all bits except for the lowest set one, + // giving the largest power-of-two factor of a number. 
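+ // e.g. modulus == 24 (binary 11000) becomes 24 & -24 == 8, its largest
+ // power-of-two factor.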
+ alignment.modulus &= -alignment.modulus; + alignment.remainder = mod_imp(alignment.remainder, alignment.modulus); + } + } else { + // A narrowing integer cast adds some unknown multiple of 2^bits + // TODO: Add += for ModulusRemainder + alignment = alignment + ModulusRemainder(((int64_t)1 << t.bits()), 0); + } + } + // Mix in existing knowledge about this Expr void intersect(const ExprInfo &other) { bounds = ConstantInterval::make_intersection(bounds, other.bounds); diff --git a/src/Simplify_Mul.cpp b/src/Simplify_Mul.cpp index e333de5c566b..446f420c6c91 100644 --- a/src/Simplify_Mul.cpp +++ b/src/Simplify_Mul.cpp @@ -8,10 +8,11 @@ Expr Simplify::visit(const Mul *op, ExprInfo *info) { Expr a = mutate(op->a, &a_info); Expr b = mutate(op->b, &b_info); - if (info && no_overflow_int(op->type)) { + if (info) { info->bounds = a_info.bounds * b_info.bounds; info->alignment = a_info.alignment * b_info.alignment; info->trim_bounds_using_alignment(); + info->cast_to(op->type); } if (may_simplify(op->type)) { diff --git a/src/Simplify_Sub.cpp b/src/Simplify_Sub.cpp index d566f42cad9c..79f0b6ba4344 100644 --- a/src/Simplify_Sub.cpp +++ b/src/Simplify_Sub.cpp @@ -8,13 +8,14 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { Expr a = mutate(op->a, &a_info); Expr b = mutate(op->b, &b_info); - if (info && no_overflow_int(op->type)) { + if (info) { // Doesn't account for correlated a, b, so any // cancellation rule that exploits that should always // remutate to recalculate the bounds. info->bounds = a_info.bounds - b_info.bounds; info->alignment = a_info.alignment - b_info.alignment; info->trim_bounds_using_alignment(); + info->cast_to(op->type); } if (may_simplify(op->type)) { diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 6f497531da94..c94b5b552425 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -1316,6 +1316,13 @@ void check_bounds() { check(max(x * 4 + 63, y) - max(y - 3, x * 4), clamp(x * 4 - y, -63, -3) + 66); check(max(x * 4, y - 3) - max(x * 4 + 63, y), clamp(y - x * 4, 3, 63) + -66); check(max(y - 3, x * 4) - max(x * 4 + 63, y), clamp(y - x * 4, 3, 63) + -66); + + // Check we can track bounds correctly through various operations + check(ramp(cast(x) / 2 + 3, cast(1), 16) < broadcast(200, 16), const_true(16)); + check(cast(cast(x)) * 3 >= cast(0), const_true()); + check(cast(cast(x)) * 3 < cast(768), const_true()); + check(cast(abs(cast(x))) >= cast(0), const_true()); + check(cast(abs(cast(x))) - cast(128) <= cast(0), const_true()); } void check_boolean() { @@ -2214,7 +2221,7 @@ int main(int argc, char **argv) { // This expression used to cause infinite recursion. check(Broadcast::make(-16, 2) < (ramp(Cast::make(UInt(16), 7), Cast::make(UInt(16), 11), 2) - Broadcast::make(1, 2)), - Broadcast::make(-15, 2) < (ramp(make_const(UInt(16), 7), make_const(UInt(16), 11), 2))); + const_true(2)); { // Verify that integer types passed to min() and max() are coerced to match From 6434210b32a1756c2286891c06aa59b4d51bacf0 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 28 Mar 2024 11:08:45 -0700 Subject: [PATCH 08/33] Fix * operator. 
Add min/max/mod --- src/ConstantBounds.cpp | 2 ++ src/ConstantInterval.cpp | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/ConstantBounds.cpp b/src/ConstantBounds.cpp index 071fc63e5365..d67297d1b940 100644 --- a/src/ConstantBounds.cpp +++ b/src/ConstantBounds.cpp @@ -35,6 +35,8 @@ ConstantInterval constant_integer_bounds(const Expr &e, const Scope()) { // Can overflow when dividing type.min() by -1 return cast(op->type, constant_integer_bounds(op->a) / constant_integer_bounds(op->b)); + } else if (const Mod *op = e.as()) { + return cast(op->type, constant_integer_bounds(op->a) % constant_integer_bounds(op->b)); } else if (const Min *op = e.as()) { return min(constant_integer_bounds(op->a), constant_integer_bounds(op->b)); } else if (const Max *op = e.as()) { diff --git a/src/ConstantInterval.cpp b/src/ConstantInterval.cpp index 7be859e63265..6f4459479f62 100644 --- a/src/ConstantInterval.cpp +++ b/src/ConstantInterval.cpp @@ -127,8 +127,6 @@ ConstantInterval ConstantInterval::make_intersection(const ConstantInterval &a, return result; } -// TODO: These were taken directly from the simplifier, so change the simplifier -// to use these instead of duplicating the code. void ConstantInterval::operator+=(const ConstantInterval &other) { min_defined = min_defined && other.min_defined && @@ -191,6 +189,7 @@ void ConstantInterval::operator*=(const ConstantInterval &other) { (other.max_defined && bounded && positive)) { // One side has a max, and the other side is bounded and positive // (e.g. a constant). + result.max_defined = true; result.max = saturating_mul(max, other.max); if (!result.max_defined) { result.max = 0; @@ -199,6 +198,7 @@ void ConstantInterval::operator*=(const ConstantInterval &other) { (other.min_defined && bounded && positive)) { // One side has a min, and the other side is bounded and positive // (e.g. a constant). 
+ result.min_defined = true; min = saturating_mul(min, other.min); if (!result.min_defined) { result.min = 0; @@ -495,6 +495,20 @@ ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { return result; } +ConstantInterval min(const ConstantInterval &a, int64_t b) { + ConstantInterval result = a; + if (result.max_defined) { + result.max = std::min(a.max, b); + } else { + result.max = b; + result.max_defined = true; + } + if (result.min_defined) { + result.min = std::min(a.min, b); + } + return result; +} + ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b) { ConstantInterval result; result.min_defined = a.min_defined || b.min_defined; @@ -512,6 +526,20 @@ ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b) { return result; } +ConstantInterval max(const ConstantInterval &a, int64_t b) { + ConstantInterval result = a; + if (result.min_defined) { + result.min = std::max(a.min, b); + } else { + result.min = b; + result.min_defined = true; + } + if (result.max_defined) { + result.max = std::max(a.max, b); + } + return result; +} + ConstantInterval abs(const ConstantInterval &a) { ConstantInterval result; if (a.min_defined && a.max_defined && a.min != INT64_MIN) { From f308a8ca3971c28865174174e5052d88535390da Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 28 Mar 2024 11:09:48 -0700 Subject: [PATCH 09/33] Add cache for constant bounds queries --- src/ConstantBounds.cpp | 19 +++- src/FindIntrinsics.cpp | 243 ++++++++++++++++++----------------------- src/IROperator.cpp | 42 ++++--- src/IROperator.h | 15 ++- src/Monotonic.cpp | 1 - src/Simplify_Exprs.cpp | 7 -- src/Simplify_LT.cpp | 3 - src/Simplify_Max.cpp | 4 - src/Simplify_Stmts.cpp | 11 +- 9 files changed, 168 insertions(+), 177 deletions(-) diff --git a/src/ConstantBounds.cpp b/src/ConstantBounds.cpp index d67297d1b940..a6d66e623e7e 100644 --- a/src/ConstantBounds.cpp +++ b/src/ConstantBounds.cpp @@ -6,10 +6,12 @@ namespace Halide { namespace Internal { -ConstantInterval constant_integer_bounds(const Expr &e, const Scope &scope) { +ConstantInterval constant_integer_bounds(const Expr &e, + const Scope &scope, + std::map *cache) { internal_assert(e.defined()); - auto ret = [&]() { + auto get_bounds = [&]() { // Compute the bounds of each IR node from the bounds of its args. Math // on ConstantInterval is in terms of infinite integers, so any op that // can overflow needs to cast the resulting interval back to the output @@ -145,9 +147,18 @@ ConstantInterval constant_integer_bounds(const Expr &e, const Scopetry_emplace(e); + if (cache_miss) { + it->second = get_bounds(); + } + ret = it->second; + } else { + ret = get_bounds(); + } if (true) { internal_assert((!ret.has_lower_bound() || e.type().can_represent(ret.min)) && diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 0008746d81aa..4142831cd2be 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -46,23 +46,6 @@ bool can_narrow(const Type &t) { t.bits() >= 8; } -Expr lossless_narrow(const Expr &x) { - return can_narrow(x.type()) ? lossless_cast(x.type().narrow(), x) : Expr(); -} - -// Remove a widening cast even if it changes the sign of the result. 
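
Stepping back to the scalar min and max overloads added to ConstantInterval in the previous patch, a quick sketch of what they compute, with invented intervals:

    #include "ConstantInterval.h"
    using namespace Halide::Internal;

    void scalar_min_max_sketch() {
        ConstantInterval a = ConstantInterval::bounded_below(3);  // [3, +inf)
        ConstantInterval b(0, 100);

        ConstantInterval m1 = min(a, 10);  // [3, 10]: taking min with a constant
                                           // gives the unbounded side an upper bound
        ConstantInterval m2 = max(b, 5);   // [5, 100]: the lower bound is raised,
                                           // the upper bound is unchanged
    }
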
-Expr strip_widening_cast(const Expr &x) { - if (can_narrow(x.type())) { - Expr narrow = lossless_narrow(x); - if (narrow.defined()) { - return narrow; - } - return lossless_cast(x.type().narrow().with_code(halide_type_uint), x); - } else { - return Expr(); - } -} - Expr saturating_narrow(const Expr &a) { Type narrow = a.type().narrow(); return saturating_cast(narrow, a); @@ -78,36 +61,6 @@ bool no_overflow(Type t) { return t.is_float() || no_overflow_int(t); } -// TODO: Can I delete this now and just rely on lossless cast? - -// If there's a widening add or subtract in the first e.type().bits() / 2 - 1 -// levels down a tree of adds or subtracts, we know there's enough headroom for -// another add without overflow. For example, it is safe to add to -// (widening_add(x, y) - z) without overflow. -bool is_safe_for_add(const Expr &e, int max_depth) { - if (max_depth-- <= 0) { - return false; - } - if (const Add *add = e.as()) { - return is_safe_for_add(add->a, max_depth) || is_safe_for_add(add->b, max_depth); - } else if (const Sub *sub = e.as()) { - return is_safe_for_add(sub->a, max_depth) || is_safe_for_add(sub->b, max_depth); - } else if (const Cast *cast = e.as()) { - if (cast->type.bits() > cast->value.type().bits()) { - return true; - } else if (cast->type.bits() == cast->value.type().bits()) { - return is_safe_for_add(cast->value, max_depth); - } - } else if (Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub, Call::widen_right_add, Call::widen_right_sub})) { - return true; - } - return false; -} - -bool is_safe_for_add(const Expr &e) { - return is_safe_for_add(e, e.type().bits() / 2 - 1); -} - // We want to find and remove an add of 'round' from e. This is not // the same thing as just subtracting round, we specifically want // to remove an addition of exactly round. @@ -133,103 +86,130 @@ Expr find_and_subtract(const Expr &e, const Expr &round) { return Expr(); } -Expr to_rounding_shift(const Call *c) { - if (c->is_intrinsic(Call::shift_left) || c->is_intrinsic(Call::shift_right)) { - internal_assert(c->args.size() == 2); - Expr a = c->args[0]; - Expr b = c->args[1]; +class FindIntrinsics : public IRMutator { +protected: + using IRMutator::visit; - // Helper to make the appropriate shift. - auto rounding_shift = [&](const Expr &a, const Expr &b) { - if (c->is_intrinsic(Call::shift_right)) { - return rounding_shift_right(a, b); - } else { - return rounding_shift_left(a, b); - } - }; + IRMatcher::Wild<0> x; + IRMatcher::Wild<1> y; + IRMatcher::Wild<2> z; + IRMatcher::Wild<3> w; + IRMatcher::WildConst<0> c0; + IRMatcher::WildConst<1> c1; - // The rounding offset for the shift we have. - Type round_type = a.type().with_lanes(1); - if (Call::as_intrinsic(a, {Call::widening_add})) { - round_type = round_type.narrow(); - } - Expr round; - if (c->is_intrinsic(Call::shift_right)) { - round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + std::map bounds_cache; + Scope let_var_bounds; + + Expr lossless_cast(Type t, const Expr &e) { + return Halide::Internal::lossless_cast(t, e, &bounds_cache); + } + + ConstantInterval constant_integer_bounds(const Expr &e) { + // TODO: Use the scope - add let visitors + return Halide::Internal::constant_integer_bounds(e, let_var_bounds, &bounds_cache); + } + + Expr lossless_narrow(const Expr &x) { + return can_narrow(x.type()) ? lossless_cast(x.type().narrow(), x) : Expr(); + } + + // Remove a widening cast even if it changes the sign of the result. 
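+ // e.g. an Int(16) subexpression whose value provably lies in [0, 255]
+ // can be narrowed to UInt(8) here even though the source type is signed.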
+ Expr strip_widening_cast(const Expr &x) { + if (can_narrow(x.type())) { + Expr narrow = lossless_narrow(x); + if (narrow.defined()) { + return narrow; + } + return lossless_cast(x.type().narrow().with_code(halide_type_uint), x); } else { - round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + return Expr(); } - // Input expressions are simplified before running find_intrinsics, but b - // has been lifted here so we need to lower_intrinsics before simplifying - // and re-lifting. Should we move this code into the FindIntrinsics class - // to make it easier to lift round? - round = lower_intrinsics(round); - round = simplify(round); - round = find_intrinsics(round); - - // We can always handle widening adds. - if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) { - if (can_prove(lower_intrinsics(add->args[0] == round))) { - return rounding_shift(cast(add->type, add->args[1]), b); - } else if (can_prove(lower_intrinsics(add->args[1] == round))) { - return rounding_shift(cast(add->type, add->args[0]), b); + } + + Expr to_rounding_shift(const Call *c) { + if (c->is_intrinsic(Call::shift_left) || c->is_intrinsic(Call::shift_right)) { + internal_assert(c->args.size() == 2); + Expr a = c->args[0]; + Expr b = c->args[1]; + + // Helper to make the appropriate shift. + auto rounding_shift = [&](const Expr &a, const Expr &b) { + if (c->is_intrinsic(Call::shift_right)) { + return rounding_shift_right(a, b); + } else { + return rounding_shift_left(a, b); + } + }; + + // The rounding offset for the shift we have. + Type round_type = a.type().with_lanes(1); + if (Call::as_intrinsic(a, {Call::widening_add})) { + round_type = round_type.narrow(); + } + Expr round; + if (c->is_intrinsic(Call::shift_right)) { + round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + } else { + round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + } + // Input expressions are simplified before running find_intrinsics, but b + // has been lifted here so we need to lower_intrinsics before simplifying + // and re-lifting. Should we move this code into the FindIntrinsics class + // to make it easier to lift round? + round = lower_intrinsics(round); + round = simplify(round); + round = find_intrinsics(round); + + // We can always handle widening adds. + if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) { + if (can_prove(lower_intrinsics(add->args[0] == round))) { + return rounding_shift(cast(add->type, add->args[1]), b); + } else if (can_prove(lower_intrinsics(add->args[1] == round))) { + return rounding_shift(cast(add->type, add->args[0]), b); + } } - } - if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) { - if (can_prove(lower_intrinsics(add->args[1] == round))) { - return rounding_shift(cast(add->type, add->args[0]), b); + if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) { + if (can_prove(lower_intrinsics(add->args[1] == round))) { + return rounding_shift(cast(add->type, add->args[0]), b); + } } - } - // Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add - // TODO: this pattern makes me want to change the semantics of this op. - if (const Cast *cast = a.as()) { - if (cast->is_reinterpret()) { - if (const Call *add = Call::as_intrinsic(cast->value, {Call::widen_right_add})) { - if (can_prove(lower_intrinsics(add->args[1] == round))) { - // We expect the first operand to be a reinterpet cast. 
- if (const Cast *cast_a = add->args[0].as()) { - if (cast_a->is_reinterpret()) { - return rounding_shift(cast_a->value, b); + // Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add + // TODO: this pattern makes me want to change the semantics of this op. + if (const Cast *cast = a.as()) { + if (cast->is_reinterpret()) { + if (const Call *add = Call::as_intrinsic(cast->value, {Call::widen_right_add})) { + if (can_prove(lower_intrinsics(add->args[1] == round))) { + // We expect the first operand to be a reinterpet cast. + if (const Cast *cast_a = add->args[0].as()) { + if (cast_a->is_reinterpret()) { + return rounding_shift(cast_a->value, b); + } } } } } } - } - // If it wasn't a widening or saturating add, we might still - // be able to safely accept the rounding. - Expr a_less_round = find_and_subtract(a, round); - if (a_less_round.defined()) { - // We found and removed the rounding. However, we may have just changed - // behavior due to overflow. This is still safe if the type is not - // overflowing, or we can find a widening add or subtract in the tree - // of adds/subtracts. This is a common pattern, e.g. - // rounding_halving_add(a, b) = shift_round(widening_add(a, b) + 1, 1). - // TODO: This could be done with bounds inference instead of this hack - // if it supported intrinsics like widening_add and tracked bounds for - // types other than int32. - if (no_overflow(a.type()) || is_safe_for_add(a_less_round)) { - return rounding_shift(simplify(a_less_round), b); + // If it wasn't a widening or saturating add, we might still + // be able to safely accept the rounding. + Expr a_less_round = find_and_subtract(a, round); + if (a_less_round.defined()) { + // We found and removed the rounding. Verify it didn't change + // overflow behavior. + if (no_overflow(a.type()) || + a.type().can_represent(constant_integer_bounds(a_less_round) + + constant_integer_bounds(round))) { + // If we can add the rounding term back on without causing + // overflow, then it must not have overflowed originally. + return rounding_shift(simplify(a_less_round), b); + } } } - } - return Expr(); -} - -class FindIntrinsics : public IRMutator { -protected: - using IRMutator::visit; - - IRMatcher::Wild<0> x; - IRMatcher::Wild<1> y; - IRMatcher::Wild<2> z; - IRMatcher::Wild<3> w; - IRMatcher::WildConst<0> c0; - IRMatcher::WildConst<1> c1; + return Expr(); + } Expr visit(const Add *op) override { if (!find_intrinsics_for_type(op->type)) { @@ -553,8 +533,6 @@ class FindIntrinsics : public IRMutator { // Do we need to worry about this cast overflowing? ConstantInterval value_bounds = constant_integer_bounds(value); - debug(0) << "Bounds of " << Expr(op) << " are " << value_bounds.min << " " << value_bounds.min_defined << " " << value_bounds.max << " " << value_bounds.max_defined << "\n"; - bool no_overflow = (op->type.can_represent(op->value.type()) || op->type.can_represent(value_bounds)); @@ -1490,14 +1468,12 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) // if it isn't already full precision. This avoids infinite loops despite // "lowering" this to another mul_shift_right operation. 
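
The rewritten to_rounding_shift above now relies on constant_integer_bounds rather than the old is_safe_for_add heuristic to decide whether the rounding term can be re-added without wrapping. A small sketch of that style of check; the expression and constants here are invented for illustration:

    #include "ConstantBounds.h"
    #include "IR.h"
    #include "IROperator.h"
    using namespace Halide;
    using namespace Halide::Internal;

    void overflow_check_sketch() {
        // x is a u8 value, so cast(UInt(16), x) * 4 lies in [0, 1020].
        Expr x = Variable::make(UInt(8), "x");
        Expr a_less_round = cast(UInt(16), x) * 4;
        Expr round = make_const(UInt(16), 2);

        ConstantInterval total = constant_integer_bounds(a_less_round) +
                                 constant_integer_bounds(round);
        // total is [2, 1022], and UInt(16).can_represent(total) is true, so
        // adding the rounding term back on provably cannot wrap.
        bool safe = UInt(16).can_represent(total);
        (void)safe;
    }
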
ConstantInterval cq = constant_integer_bounds(q); - debug(0) << " cq = " << cq.min << " " << cq.min_defined << " " << cq.max << " " << cq.max_defined << "\n"; if (cq.is_single_point() && cq.max >= 0 && cq.max < full_q) { int missing_q = full_q - (int)cq.max; // Try to scale up the args by factors of two without overflowing int a_shift = 0, b_shift = 0; ConstantInterval ca = constant_integer_bounds(a); - debug(0) << " ca = " << ca.min << " " << ca.min_defined << " " << ca.max << " " << ca.max_defined << "\n"; do { ConstantInterval bigger = ca * ConstantInterval::single_point(2); if (a.type().can_represent(bigger) && a_shift + b_shift < missing_q) { @@ -1507,7 +1483,6 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) } } while (false); ConstantInterval cb = constant_integer_bounds(b); - debug(0) << " cb = " << cb.min << " " << cb.min_defined << " " << cb.max << " " << cb.max_defined << "\n"; do { ConstantInterval bigger = cb * ConstantInterval::single_point(2); if (b.type().can_represent(bigger) && b_shift + b_shift < missing_q) { @@ -1516,8 +1491,6 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) continue; } } while (false); - - debug(0) << "a_shift = " << a_shift << " b_shift = " << b_shift << " full_q = " << full_q << "\n"; if (a_shift + b_shift == missing_q) { return rounding_mul_shift_right(simplify(a << a_shift), simplify(b << b_shift), full_q); } @@ -1527,10 +1500,8 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) Expr wide_result = rounding_shift_right(widening_mul(a, b), q); Expr narrowed = lossless_cast(a.type(), wide_result); if (narrowed.defined()) { - debug(0) << " losslessly narrowed to " << narrowed << "\n"; return narrowed; } else { - debug(0) << " returning saturating_narrow(" << wide_result << ")\n"; return saturating_narrow(wide_result); } } diff --git a/src/IROperator.cpp b/src/IROperator.cpp index b857eb4947b8..2b0359fb3280 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -436,19 +436,19 @@ Expr const_false(int w) { return make_zero(UInt(1, w)); } -Expr lossless_cast(Type t, Expr e) { +Expr lossless_cast(Type t, Expr e, std::map *cache) { if (!e.defined() || t == e.type()) { return e; } else if (t.can_represent(e.type())) { return cast(t, std::move(e)); } else if (const Cast *c = e.as()) { if (c->type.can_represent(c->value.type())) { - return lossless_cast(t, c->value); + return lossless_cast(t, c->value, cache); } else { return Expr(); } } else if (const Broadcast *b = e.as()) { - Expr v = lossless_cast(t.element_of(), b->value); + Expr v = lossless_cast(t.element_of(), b->value, cache); if (v.defined()) { return Broadcast::make(v, b->lanes); } else { @@ -475,41 +475,51 @@ Expr lossless_cast(Type t, Expr e) { } else if (const Shuffle *shuf = e.as()) { std::vector vecs; for (const auto &vec : shuf->vectors) { - vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec)); + vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, cache)); if (!vecs.back().defined()) { return Expr(); } } return Shuffle::make(vecs, shuf->indices); } else if (t.is_int_or_uint()) { - // We'll just throw a cast around something, if the bounds are small - // enough. - ConstantInterval ci = constant_integer_bounds(e); + // Check the bounds. If they're small enough, we can throw narrowing + // casts around e, or subterms. 
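
For reference, the cache threaded through these calls is meant to be shared across many lossless_cast and constant_integer_bounds queries over the same IR. A usage sketch; it assumes the cache's map type is std::map<Expr, ConstantInterval, ExprCompare>, matching the bounds cache in FindIntrinsics (the template arguments are elided in the diff text as rendered here), and the expression is invented:

    #include <map>
    #include "ConstantInterval.h"
    #include "IR.h"
    #include "IROperator.h"
    using namespace Halide;
    using namespace Halide::Internal;

    void lossless_cast_cache_sketch() {
        std::map<Expr, ConstantInterval, ExprCompare> cache;

        Expr x = Variable::make(UInt(8), "x");
        Expr wide = cast(UInt(32), x) * 3 + 7;  // value lies in [7, 772]

        Expr w16 = lossless_cast(UInt(16), wide, &cache);  // defined: fits in a u16
        Expr w8 = lossless_cast(UInt(8), wide, &cache);    // undefined: 772 > 255
        // The second call reuses the constant bounds of `wide` and its
        // subexpressions computed by the first call.
    }
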
+ ConstantInterval ci; + if (cache) { + auto [it, cache_miss] = cache->try_emplace(e); + if (cache_miss) { + it->second = constant_integer_bounds(e, Scope::empty_scope(), cache); + } + ci = it->second; + } else { + ci = constant_integer_bounds(e); + } + if (t.can_represent(ci)) { // There are certain IR nodes where if the result is expressible // using some type, and the args are expressible using that type, // then the operation can just be done in that type. if (const Add *op = e.as()) { - Expr a = lossless_cast(t, op->a); - Expr b = lossless_cast(t, op->b); + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); if (a.defined() && b.defined()) { return a + b; } } else if (const Sub *op = e.as()) { - Expr a = lossless_cast(t, op->a); - Expr b = lossless_cast(t, op->b); + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); if (a.defined() && b.defined()) { return a - b; } } else if (const Mul *op = e.as()) { - Expr a = lossless_cast(t, op->a); - Expr b = lossless_cast(t, op->b); + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); if (a.defined() && b.defined()) { return a * b; } } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add})) { - Expr a = lossless_cast(t, op->args[0]); - Expr b = lossless_cast(t, op->args[1]); + Expr a = lossless_cast(t, op->args[0], cache); + Expr b = lossless_cast(t, op->args[1], cache); if (a.defined() && b.defined()) { return a + b; } @@ -517,7 +527,7 @@ Expr lossless_cast(Type t, Expr e) { if (op->op == VectorReduce::Add || op->op == VectorReduce::Min || op->op == VectorReduce::Max) { - Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value); + Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, cache); if (v.defined()) { return VectorReduce::make(op->op, v, op->type.lanes()); } diff --git a/src/IROperator.h b/src/IROperator.h index a96ef6223c0d..c84a4682152f 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -8,6 +8,7 @@ */ #include +#include #include "Expr.h" #include "Tuple.h" @@ -140,10 +141,16 @@ Expr const_true(int lanes = 1); * falses, if a lanes argument is given. */ Expr const_false(int lanes = 1); -/** Attempt to cast an expression to a smaller type while provably not - * losing information. If it can't be done, return an undefined - * Expr. */ -Expr lossless_cast(Type t, Expr e); +/** Attempt to cast an expression to a smaller type while provably not losing + * information. If it can't be done, return an undefined Expr. + * + * Optionally accepts a map that gives the constant bounds of exprs already + * analyzed to avoid redoing work across many calls to lossless_cast. It is not + * safe to use this optional map in contexts where the same Expr object may + * take on a different value. For example: + * (let x = 4 in some_expr_object) + (let x = 5 in the_same_expr_object)). + * It is safe to use it after uniquify_variable_names has been run. */ +Expr lossless_cast(Type t, Expr e, std::map *cache = nullptr); /** Attempt to negate x without introducing new IR and without overflow. * If it can't be done, return an undefined Expr. 
*/ diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index e09358075dae..b69746ba2826 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -528,7 +528,6 @@ ConstantInterval derivative_bounds(const Expr &e, const std::string &var, const } DerivativeBounds m(var, scope); remove_likelies(remove_promises(e)).accept(&m); - debug(0) << "Derivative bounds of " << e << " w.r.t. " << var << ": " << m.result << "\n"; return m.result; } diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index b7800dd6faa5..70a63d9d1f4c 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -215,19 +215,12 @@ Expr Simplify::visit(const Variable *op, ExprInfo *info) { *info = *b; } if (b->bounds.is_single_point()) { - if (info) { - debug(0) << "Var is single point: " << op->name << ": " << info->bounds << "\n"; - } return make_const(op->type, b->bounds.min); } } else if (info && !no_overflow_int(op->type)) { info->bounds = ConstantInterval::bounds_of_type(op->type); } - if (info) { - debug(0) << "Bounds of var: " << op->name << ": " << info->bounds << "\n"; - } - if (auto *v_info = var_info.shallow_find(op->name)) { // if replacement is defined, we should substitute it in (unless // it's a var that has been hidden by a nested scope). diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 1069602e24a1..c9ac45c349d7 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ -20,12 +20,9 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) { if (may_simplify(ty)) { // Prove or disprove using bounds analysis - debug(0) << "ELEPHANT: " << Expr(op) << ": " << a_info.bounds << ", " << b_info.bounds << "\n"; if (a_info.bounds < b_info.bounds) { - debug(0) << "... true\n"; return const_true(lanes); } else if (a_info.bounds >= b_info.bounds) { - debug(0) << "... false\n"; return const_false(lanes); } diff --git a/src/Simplify_Max.cpp b/src/Simplify_Max.cpp index 59bb3b6313e5..6f3ecc1999f7 100644 --- a/src/Simplify_Max.cpp +++ b/src/Simplify_Max.cpp @@ -24,10 +24,6 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) { return e; }; - if (info) { - debug(0) << "Bounds of max: " << Expr(op) << ": " << a_info.bounds << ", " << b_info.bounds << ", " << info->bounds << "\n"; - } - // Early out when the bounds tells us one side or the other is smaller if (a_info.bounds <= b_info.bounds) { return strip_likely(b); diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index fac58dcb847f..57cfe74cf7e3 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -219,8 +219,15 @@ Stmt Simplify::visit(const For *op) { bool bounds_tracked = false; ExprInfo loop_var_info; - loop_var_info.bounds = ConstantInterval::make_union(min_info.bounds, - min_info.bounds + extent_info.bounds - 1); + // Deduce bounds for the loop var that are true for any code than runs + // inside the loop body. Code in the inner loop only runs if the extent is + // at least one, so we can throw a max around the extent bounds. 
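
As a worked example of the bound computed just below, with invented constant ranges (loop min in [0, 10], extent in [0, 8]):

    #include "ConstantInterval.h"
    using namespace Halide::Internal;

    void loop_var_bounds_sketch() {
        ConstantInterval min_b(0, 10), extent_b(0, 8);
        ConstantInterval loop_var =
            ConstantInterval::make_union(min_b, min_b + max(extent_b, 1) - 1);
        // max(extent_b, 1) is [1, 8], so the second argument is [0, 17] and the
        // union is [0, 17]. Without the max, a possibly-zero extent would drag
        // the lower bound down to -1, a value the loop variable never takes
        // inside the body.
    }
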
+ + loop_var_info.bounds = + ConstantInterval::make_union(min_info.bounds, + min_info.bounds + max(extent_info.bounds, 1) - 1); + + if (loop_var_info.bounds.has_upper_bound() || loop_var_info.bounds.has_lower_bound()) { bounds_tracked = true; From cffadd8c2fb26843828614e0731ec2922af56813 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 11:08:29 -0700 Subject: [PATCH 10/33] Fix ConstantInterval multiplication --- src/ConstantInterval.cpp | 135 ++++++++++++++++++++------------------- 1 file changed, 70 insertions(+), 65 deletions(-) diff --git a/src/ConstantInterval.cpp b/src/ConstantInterval.cpp index 6f4459479f62..7d049aa9cf05 100644 --- a/src/ConstantInterval.cpp +++ b/src/ConstantInterval.cpp @@ -146,68 +146,7 @@ void ConstantInterval::operator-=(const ConstantInterval &other) { } void ConstantInterval::operator*=(const ConstantInterval &other) { - ConstantInterval result; - - // Compute a possible extreme value of the product, setting the min/max - // defined flags if it's unbounded. - auto saturating_mul = [&](int64_t a, int64_t b) -> int64_t { - int64_t c; - if (mul_with_overflow(64, a, b, &c)) { - return c; - } else if ((a > 0) == (b > 0)) { - result.max_defined = false; - return INT64_MAX; - } else { - result.min_defined = false; - return INT64_MIN; - } - }; - - bool positive = min_defined && min > 0; - bool other_positive = other.min_defined && other.min > 0; - bool bounded = min_defined && max_defined; - bool other_bounded = other.min_defined && other.max_defined; - - if (bounded && other_bounded) { - // Both are bounded - result.min_defined = result.max_defined = true; - int64_t v1 = saturating_mul(min, other.min); - int64_t v2 = saturating_mul(min, other.max); - int64_t v3 = saturating_mul(max, other.min); - int64_t v4 = saturating_mul(max, other.max); - if (result.min_defined) { - result.min = std::min(std::min(v1, v2), std::min(v3, v4)); - } else { - result.min = 0; - } - if (result.max_defined) { - result.max = std::max(std::max(v1, v2), std::max(v3, v4)); - } else { - result.max = 0; - } - } else if ((max_defined && other_bounded && other_positive) || - (other.max_defined && bounded && positive)) { - // One side has a max, and the other side is bounded and positive - // (e.g. a constant). - result.max_defined = true; - result.max = saturating_mul(max, other.max); - if (!result.max_defined) { - result.max = 0; - } - } else if ((min_defined && other_bounded && other_positive) || - (other.min_defined && bounded && positive)) { - // One side has a min, and the other side is bounded and positive - // (e.g. a constant). - result.min_defined = true; - min = saturating_mul(min, other.min); - if (!result.min_defined) { - result.min = 0; - } - } - // TODO: what about the above two cases, but for multiplication by bounded - // and negative intervals? - - *this = result; + (*this) = (*this) * other; } void ConstantInterval::operator/=(const ConstantInterval &other) { @@ -417,8 +356,74 @@ ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b) } ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result *= b; + ConstantInterval result; + + // Compute a possible extreme value of the product, either incorporating it + // into result.min / result.max, or setting the min/max defined flags if it + // overflows. 
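
Two concrete cases the rewritten multiply just below handles, sketched with invented intervals:

    #include "ConstantInterval.h"
    using namespace Halide::Internal;

    void interval_mul_sketch() {
        ConstantInterval a(-3, 5), b(2, 4);
        ConstantInterval p = a * b;  // [-12, 20]: the extremes of the four
                                     // corner products -6, -12, 10, 20

        ConstantInterval c = ConstantInterval::bounded_below(2);  // [2, +inf)
        ConstantInterval q = c * ConstantInterval(3, 4);
        // q is [6, +inf): the lower bound survives because both operands are
        // known non-negative, but there is no finite upper bound.
    }
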
+ auto consider_case = [&](int64_t a, int64_t b) { + int64_t c; + if (mul_with_overflow(64, a, b, &c)) { + result.min = std::min(result.min, c); + result.max = std::max(result.max, c); + } else if ((a > 0) == (b > 0)) { + result.max_defined = false; + } else { + result.min_defined = false; + } + }; + + result.min_defined = result.max_defined = true; + result.min = INT64_MAX; + result.max = INT64_MIN; + if (a.min_defined && b.min_defined) { + consider_case(a.min, b.min); + } + if (a.min_defined && b.max_defined) { + consider_case(a.min, b.max); + } + if (a.max_defined && b.min_defined) { + consider_case(a.max, b.min); + } + if (a.max_defined && b.max_defined) { + consider_case(a.max, b.max); + } + + bool a_bounded_negative = a.min_defined && a <= 0; + bool a_bounded_positive = a.max_defined && a >= 0; + bool b_bounded_negative = b.min_defined && b <= 0; + bool b_bounded_positive = b.max_defined && b >= 0; + + if (result.min_defined) { + result.min_defined = + ((a.is_bounded() && b.is_bounded()) || + (a >= 0 && b >= 0) || + (a <= 0 && b <= 0) || + (a.min_defined && b_bounded_positive) || + (b.min_defined && a_bounded_positive) || + (a.max_defined && b_bounded_negative) || + (b.max_defined && a_bounded_negative)); + } + + if (result.max_defined) { + result.max_defined = + ((a.is_bounded() && b.is_bounded()) || + (a >= 0 && b <= 0) || + (a <= 0 && b >= 0) || + (a.max_defined && b_bounded_positive) || + (b.max_defined && a_bounded_positive) || + (a.min_defined && b_bounded_negative) || + (b.min_defined && a_bounded_negative)); + } + + if (!result.min_defined) { + result.min = 0; + } + + if (!result.max_defined) { + result.max = 0; + } + return result; } @@ -475,7 +480,7 @@ ConstantInterval operator*(const ConstantInterval &a, int64_t b) { } ConstantInterval operator%(const ConstantInterval &a, int64_t b) { - return a * ConstantInterval(b, b); + return a % ConstantInterval(b, b); } ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { From 2f14881b90dcf703781e51a095435ef0a0139e4c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 11:09:06 -0700 Subject: [PATCH 11/33] Add a simplifier rule which is apparently now necessary --- src/Simplify_Cast.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index f64c55f68640..089e46bdc62b 100644 --- a/src/Simplify_Cast.cpp +++ b/src/Simplify_Cast.cpp @@ -107,6 +107,17 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) { // outer cast is narrower, the inner cast can be // eliminated. return mutate(Cast::make(op->type, cast->value), info); + } else if (cast && + op->type.is_int_or_uint() && + cast->type.is_int() && + cast->value.type().is_int() && + op->type.bits() >= cast->type.bits() && + cast->type.bits() >= cast->value.type().bits()) { + // Casting from a signed type always sign-extends, so widening + // partway to a signed type and the rest of the way to some other + // integer type is the same as just widening to that integer type + // directly. 
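+ // e.g. for an int8_t x = -5, (uint32_t)(int16_t)x and (uint32_t)x both
+ // evaluate to 0xFFFFFFFB, so the intermediate cast is redundant.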
+ return mutate(Cast::make(op->type, cast->value), info); } else if (cast && (op->type.is_int() || op->type.is_uint()) && (cast->type.is_int() || cast->type.is_uint()) && From 26efb7c89992dfd79823fc4fb00e0d6d4999f17a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 12:28:57 -0700 Subject: [PATCH 12/33] Misc cleanups and test improvements --- src/ConstantBounds.cpp | 16 +- src/ConstantInterval.cpp | 267 +++++++++++++++-------------- test/correctness/lossless_cast.cpp | 68 ++++++-- 3 files changed, 201 insertions(+), 150 deletions(-) diff --git a/src/ConstantBounds.cpp b/src/ConstantBounds.cpp index a6d66e623e7e..25aca7c6b6d5 100644 --- a/src/ConstantBounds.cpp +++ b/src/ConstantBounds.cpp @@ -84,16 +84,16 @@ ConstantInterval constant_integer_bounds(const Expr &e, } else if (op->is_intrinsic(Call::halving_add)) { return (constant_integer_bounds(op->args[0]) + constant_integer_bounds(op->args[1])) / - ConstantInterval(2, 2); + 2; } else if (op->is_intrinsic(Call::halving_sub)) { return cast(op->type, (constant_integer_bounds(op->args[0]) - constant_integer_bounds(op->args[1])) / - ConstantInterval(2, 2)); + 2); } else if (op->is_intrinsic(Call::rounding_halving_add)) { return (constant_integer_bounds(op->args[0]) + constant_integer_bounds(op->args[1]) + - ConstantInterval(1, 1)) / - ConstantInterval(2, 2); + 1) / + 2; } else if (op->is_intrinsic(Call::saturating_add)) { return saturating_cast(op->type, (constant_integer_bounds(op->args[0]) + @@ -138,12 +138,14 @@ ConstantInterval constant_integer_bounds(const Expr &e, } else if (cb.max_defined && cb.max <= 0) { return cast(op->type, ca << (-cb)); } else { - auto rounding_term = ConstantInterval(0, 1) << max(cb - ConstantInterval(1, 1), ConstantInterval(0, 0)); + auto rounding_term = ConstantInterval(0, 1) << max(cb - 1, 0); return cast(op->type, (ca + rounding_term) >> cb); } } - // TODO: more intrinsics - // TODO: widening_shift_left is important + // If you add a new intrinsic here, also add it to the expression + // generator in test/correctness/lossless_cast.cpp + + // TODO: mul_shift_right, rounding_mul_shift_right, widening_shift_left/right, rounding_shift_left } return ConstantInterval::bounds_of_type(e.type()); diff --git a/src/ConstantInterval.cpp b/src/ConstantInterval.cpp index 7d049aa9cf05..a29f42f5ee9c 100644 --- a/src/ConstantInterval.cpp +++ b/src/ConstantInterval.cpp @@ -2,6 +2,7 @@ #include "Error.h" #include "IROperator.h" +#include "IRPrinter.h" namespace Halide { namespace Internal { @@ -24,12 +25,14 @@ ConstantInterval ConstantInterval::single_point(int64_t x) { ConstantInterval ConstantInterval::bounded_below(int64_t min) { ConstantInterval result(min, min); result.max_defined = false; + result.max = 0; return result; } ConstantInterval ConstantInterval::bounded_above(int64_t max) { ConstantInterval result(max, max); result.min_defined = false; + result.min = 0; return result; } @@ -54,7 +57,7 @@ bool ConstantInterval::has_lower_bound() const { } bool ConstantInterval::is_bounded() const { - return has_upper_bound() && has_lower_bound(); + return max_defined && min_defined; } bool ConstantInterval::operator==(const ConstantInterval &other) const { @@ -128,21 +131,11 @@ ConstantInterval ConstantInterval::make_intersection(const ConstantInterval &a, } void ConstantInterval::operator+=(const ConstantInterval &other) { - min_defined = min_defined && - other.min_defined && - add_with_overflow(64, min, other.min, &min); - max_defined = max_defined && - other.max_defined && - add_with_overflow(64, max, 
other.max, &max); + (*this) = (*this) + other; } void ConstantInterval::operator-=(const ConstantInterval &other) { - min_defined = min_defined && - other.max_defined && - sub_with_overflow(64, min, other.max, &min); - max_defined = max_defined && - other.min_defined && - sub_with_overflow(64, max, other.min, &max); + (*this) = (*this) - other; } void ConstantInterval::operator*=(const ConstantInterval &other) { @@ -150,128 +143,31 @@ void ConstantInterval::operator*=(const ConstantInterval &other) { } void ConstantInterval::operator/=(const ConstantInterval &other) { - ConstantInterval result; - - result.min = INT64_MAX; - result.max = INT64_MIN; - - // Enumerate all possible values for the min and max and take the extreme values. - if (min_defined && other.min_defined && other.min != 0) { - int64_t v = div_imp(min, other.min); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - if (min_defined && other.max_defined && other.max != 0) { - int64_t v = div_imp(min, other.max); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - if (max_defined && other.max_defined && other.max != 0) { - int64_t v = div_imp(max, other.max); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - if (max_defined && other.min_defined && other.min != 0) { - int64_t v = div_imp(max, other.min); - result.min = std::min(result.min, v); - result.max = std::max(result.max, v); - } - - // Define an int64_t zero just to pacify std::min and std::max - constexpr int64_t zero = 0; - - const bool other_positive = other.min_defined && other.min > 0; - const bool other_negative = other.max_defined && other.max < 0; - if ((other_positive && !other.max_defined) || - (other_negative && !other.min_defined)) { - // Take limit as other -> +/- infinity - result.min = std::min(result.min, zero); - result.max = std::max(result.max, zero); - } - - bool bounded_numerator = min_defined && max_defined; - - result.min_defined = ((min_defined && other_positive) || - (max_defined && other_negative)); - result.max_defined = ((max_defined && other_positive) || - (min_defined && other_negative)); - - // That's as far as we can get knowing the sign of the - // denominator. For bounded numerators, we additionally know - // that div can't make anything larger in magnitude, so we can - // take the intersection with that. - if (bounded_numerator && min != INT64_MIN) { - int64_t magnitude = std::max(max, -min); - if (result.min_defined) { - result.min = std::max(result.min, -magnitude); - } else { - result.min = -magnitude; - } - if (result.max_defined) { - result.max = std::min(result.max, magnitude); - } else { - result.max = magnitude; - } - result.min_defined = result.max_defined = true; - } - - // Finally we can provide a bound if the numerator and denominator are - // non-positive or non-negative. 
- bool numerator_non_negative = min_defined && min >= 0; - bool denominator_non_negative = other.min_defined && other.min >= 0; - bool numerator_non_positive = max_defined && max <= 0; - bool denominator_non_positive = other.max_defined && other.max <= 0; - if ((numerator_non_negative && denominator_non_negative) || - (numerator_non_positive && denominator_non_positive)) { - if (result.min_defined) { - result.min = std::max(result.min, zero); - } else { - result.min_defined = true; - result.min = 0; - } - } - if ((numerator_non_negative && denominator_non_positive) || - (numerator_non_positive && denominator_non_negative)) { - if (result.max_defined) { - result.max = std::min(result.max, zero); - } else { - result.max_defined = true; - result.max = 0; - } - } - - // Normalize the values if it's undefined - if (!result.min_defined) { - result.min = 0; - } - if (!result.max_defined) { - result.max = 0; - } + (*this) = (*this) / other; +} - *this = result; +void ConstantInterval::operator%=(const ConstantInterval &other) { + (*this) = (*this) % other; } void ConstantInterval::operator+=(int64_t x) { - // TODO: Optimize this - *this += ConstantInterval(x, x); + (*this) = (*this) + x; } void ConstantInterval::operator-=(int64_t x) { - // TODO: Optimize this - *this -= ConstantInterval(x, x); + (*this) = (*this) - x; } void ConstantInterval::operator*=(int64_t x) { - // TODO: Optimize this - *this *= ConstantInterval(x, x); + (*this) = (*this) * x; } void ConstantInterval::operator/=(int64_t x) { - // TODO: Optimize this - *this /= ConstantInterval(x, x); + (*this) = (*this) / x; +} + +void ConstantInterval::operator%=(int64_t x) { + (*this) = (*this) % x; } bool operator<=(const ConstantInterval &a, const ConstantInterval &b) { @@ -338,20 +234,125 @@ ConstantInterval ConstantInterval::bounds_of_type(Type t) { } ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result += b; + ConstantInterval result; + result.min_defined = a.min_defined && + b.min_defined && + add_with_overflow(64, a.min, b.min, &result.min); + + result.max_defined = a.max_defined && + b.max_defined && + add_with_overflow(64, a.max, b.max, &result.max); return result; } ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result -= b; + ConstantInterval result; + result.min_defined = a.min_defined && + b.max_defined && + sub_with_overflow(64, a.min, b.max, &result.min); + result.max_defined = a.max_defined && + b.min_defined && + sub_with_overflow(64, a.max, b.min, &result.max); return result; } ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result /= b; + ConstantInterval result; + + result.min = INT64_MAX; + result.max = INT64_MIN; + + auto consider_case = [&](int64_t a, int64_t b) { + int64_t v = div_imp(a, b); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + }; + + // Enumerate all possible values for the min and max and take the extreme values. 
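
Two concrete cases for the corner enumeration just below, sketched with invented intervals:

    #include "ConstantInterval.h"
    using namespace Halide::Internal;

    void interval_div_sketch() {
        ConstantInterval n(-10, 30), d(3, 5);
        ConstantInterval q = n / d;  // [-4, 10]: extremes of div_imp(-10, 3),
                                     // div_imp(-10, 5), div_imp(30, 3) and
                                     // div_imp(30, 5), i.e. floor division

        ConstantInterval pos = ConstantInterval::bounded_below(1);  // [1, +inf)
        ConstantInterval r = ConstantInterval(0, 100) / pos;
        // r is [0, 100]: a bounded numerator cannot grow in magnitude, and
        // both operands being non-negative pins the lower bound at zero.
    }
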
+ if (a.min_defined && b.min_defined && b.min != 0) { + consider_case(a.min, b.min); + } + + if (a.min_defined && b.max_defined && b.max != 0) { + consider_case(a.min, b.max); + } + + if (a.max_defined && b.max_defined && b.max != 0) { + consider_case(a.max, b.max); + } + + if (a.max_defined && b.min_defined && b.min != 0) { + consider_case(a.max, b.min); + } + + // Define an int64_t zero just to pacify std::min and std::max + constexpr int64_t zero = 0; + + const bool b_positive = b > 0; + const bool b_negative = b < 0; + if ((b_positive && !b.max_defined) || + (b_negative && !b.min_defined)) { + // Take limit as other -> +/- infinity + result.min = std::min(result.min, zero); + result.max = std::max(result.max, zero); + } + + result.min_defined = ((a.min_defined && b_positive) || + (a.max_defined && b_negative)); + result.max_defined = ((a.max_defined && b_positive) || + (a.min_defined && b_negative)); + + // That's as far as we can get knowing the sign of the + // denominator. For bounded numerators, we additionally know + // that div can't make anything larger in magnitude, so we can + // take the intersection with that. + if (a.is_bounded() && a.min != INT64_MIN) { + int64_t magnitude = std::max(a.max, -a.min); + if (result.min_defined) { + result.min = std::max(result.min, -magnitude); + } else { + result.min = -magnitude; + } + if (result.max_defined) { + result.max = std::min(result.max, magnitude); + } else { + result.max = magnitude; + } + result.min_defined = result.max_defined = true; + } + + // Finally we can deduce the sign if the numerator and denominator are + // non-positive or non-negative. + bool a_non_negative = a >= 0; + bool b_non_negative = b >= 0; + bool a_non_positive = a <= 0; + bool b_non_positive = b <= 0; + if ((a_non_negative && b_non_negative) || + (a_non_positive && b_non_positive)) { + if (result.min_defined) { + result.min = std::max(result.min, zero); + } else { + result.min_defined = true; + result.min = 0; + } + } else if ((a_non_negative && b_non_positive) || + (a_non_positive && b_non_negative)) { + if (result.max_defined) { + result.max = std::min(result.max, zero); + } else { + result.max_defined = true; + result.max = 0; + } + } + + // Normalize the values if it's undefined + if (!result.min_defined) { + result.min = 0; + } + if (!result.max_defined) { + result.max = 0; + } + return result; } @@ -464,23 +465,23 @@ ConstantInterval operator%(const ConstantInterval &a, const ConstantInterval &b) } ConstantInterval operator+(const ConstantInterval &a, int64_t b) { - return a + ConstantInterval(b, b); + return a + ConstantInterval::single_point(b); } ConstantInterval operator-(const ConstantInterval &a, int64_t b) { - return a - ConstantInterval(b, b); + return a - ConstantInterval::single_point(b); } ConstantInterval operator/(const ConstantInterval &a, int64_t b) { - return a / ConstantInterval(b, b); + return a / ConstantInterval::single_point(b); } ConstantInterval operator*(const ConstantInterval &a, int64_t b) { - return a * ConstantInterval(b, b); + return a * ConstantInterval::single_point(b); } ConstantInterval operator%(const ConstantInterval &a, int64_t b) { - return a % ConstantInterval(b, b); + return a % ConstantInterval::single_point(b); } ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp index ffd5cf008716..c692f2ac7200 100644 --- a/test/correctness/lossless_cast.cpp +++ b/test/correctness/lossless_cast.cpp @@ -6,9 +6,11 @@ 
using namespace Halide::Internal; int check_lossless_cast(const Type &t, const Expr &in, const Expr &correct) { Expr result = lossless_cast(t, in); if (!equal(result, correct)) { - std::cout << "Incorrect lossless_cast result:\nlossless_cast(" - << t << ", " << in << ") gave:\n " << result - << " but expected was:\n " << correct << "\n"; + std::cout << "Incorrect lossless_cast result:\n" + << "lossless_cast(" << t << ", " << in << ") gave:\n" + << " " << result + << " but expected was:\n" + << " " << correct << "\n"; return 1; } return 0; @@ -104,10 +106,14 @@ Expr random_expr(std::mt19937 &rng) { Expr e; int i1 = rng() % exprs.size(); int i2 = rng() % exprs.size(); + int i3 = rng() % exprs.size(); int op = rng() % 7; Expr e1 = exprs[i1]; Expr e2 = cast(e1.type(), exprs[i2]); + Expr e3 = exprs[i3]; bool may_widen = e1.type().bits() < 64; + Expr e2_narrow = exprs[i2]; + bool may_widen_right = e2_narrow.type() == e1.type().narrow(); switch (op) { case 0: if (may_widen) { @@ -132,7 +138,7 @@ Expr random_expr(std::mt19937 &rng) { e = e1 / e2; break; case 6: - switch (rng() % 10) { + switch (rng() % 14) { case 0: if (may_widen) { e = widening_add(e1, e2); @@ -169,6 +175,26 @@ Expr random_expr(std::mt19937 &rng) { case 9: e = count_trailing_zeros(e1); break; + case 10: + if (e3.type().is_uint()) { + e = rounding_mul_shift_right(e1, e2, e3); + } + break; + case 11: + if (may_widen_right) { + e = widen_right_add(e1, e2_narrow); + } + break; + case 12: + if (may_widen_right) { + e = widen_right_sub(e1, e2_narrow); + } + break; + case 13: + if (may_widen_right) { + e = widen_right_mul(e1, e2_narrow); + } + break; } } @@ -176,9 +202,9 @@ Expr random_expr(std::mt19937 &rng) { continue; } - // Stop when we get to 64 bits, but probably don't stop on a widening - // cast, because that'll just get trivially stripped. - if (e.type().bits() == 64 && (op > 1 || ((rng() & 7) == 0))) { + // Stop when we get to 64 bits, but probably don't stop on a cast, + // because that'll just get trivially stripped. + if (e.type().bits() == 64 && (e.as() == nullptr || ((rng() & 7) == 0))) { return e; } @@ -211,10 +237,15 @@ int test_one(uint32_t seed) { buf_i8.fill(rng); Expr e1 = random_expr(rng); + + // We're also going to test constant_integer_bounds here. + ConstantInterval bounds = constant_integer_bounds(e1); + Type target; std::vector target_types = {UInt(32), Int(32), UInt(16), Int(16)}; target = target_types[rng() % target_types.size()]; Expr e2 = lossless_cast(target, e1); + if (!e2.defined()) { return 0; } @@ -223,14 +254,18 @@ int test_one(uint32_t seed) { f(x) = {cast(e1), cast(e2)}; f.vectorize(x, 4, TailStrategy::RoundUp); - // std::cout << e1 << " to " << target << "\n -> " << e2 << "\n -> " << simplify(e2) << "\n"; - // std::cout << "\n\n\n--------------------\n\n\n"; Buffer out1(size), out2(size); Pipeline p(f); CheckForIntOverflow checker; + // We don't have constant-folding rules for all intrinsics, so we also need + // to feed the checker the lowered form. + checker.mutate(simplify(lower_intrinsics(e1))); + checker.mutate(simplify(lower_intrinsics(e2))); + if (checker.found_overflow) { + return 0; + } p.add_custom_lowering_pass(&checker, nullptr); p.realize({out1, out2}); - if (checker.found_overflow) { // We don't do anything in the expression generator to avoid signed // integer overflow, so just skip anything with signed integer overflow. 
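
For orientation, the interval checked by bounds.contains above is the compile-time bound of the randomly generated expression, so the test asserts that constant_integer_bounds is sound on concrete data. A minimal sketch of that property on an invented expression:

    #include "ConstantBounds.h"
    #include "IR.h"
    #include "IROperator.h"
    using namespace Halide;
    using namespace Halide::Internal;

    void bounds_soundness_sketch() {
        Expr x = Variable::make(UInt(8), "x");
        Expr e = cast(UInt(16), x) * cast(UInt(16), x) + 7;
        ConstantInterval b = constant_integer_bounds(e);  // [7, 65032]
        // Every value the pipeline computes for e must satisfy
        // b.contains(value); 255 * 255 + 7 == 65032 is the largest reachable.
        (void)b;
    }
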
@@ -240,6 +275,7 @@ int test_one(uint32_t seed) { for (int x = 0; x < size; x++) { if (out1(x) != out2(x)) { std::cout + << "lossless_cast failure\n" << "seed = " << seed << "\n" << "x = " << x << "\n" << "buf_u8 = " << (int)buf_u8(x) << "\n" @@ -250,6 +286,18 @@ int test_one(uint32_t seed) { << "Lossless cast: " << e2 << "\n"; return 1; } + + if (!bounds.contains(out1(x))) { + std::cout + << "constant_integer_bounds failure\n" + << "seed = " << seed << "\n" + << "x = " << x << "\n" + << "buf_u8 = " << (int)buf_u8(x) << "\n" + << "buf_i8 = " << (int)buf_i8(x) << "\n" + << "out1 = " << out1(x) << "\n" + << "Expression: " << e1 << "\n" + << "Bounds: " << bounds << "\n"; + } } return 0; From b053ec6b5fb8e51653ed7a6bf06c29b5243f359b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 12:54:51 -0700 Subject: [PATCH 13/33] Add missing files --- src/ConstantBounds.h | 25 ++++++ src/ConstantInterval.h | 168 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 src/ConstantBounds.h create mode 100644 src/ConstantInterval.h diff --git a/src/ConstantBounds.h b/src/ConstantBounds.h new file mode 100644 index 000000000000..35178bddb4a6 --- /dev/null +++ b/src/ConstantBounds.h @@ -0,0 +1,25 @@ +#ifndef HALIDE_CONSTANT_BOUNDS_H +#define HALIDE_CONSTANT_BOUNDS_H + +#include "ConstantInterval.h" +#include "Expr.h" +#include "Scope.h" + +/** \file + * Methods for computing compile-time constant int64_t upper and lower bounds of + * an expression. Cheaper than symbolic bounds inference, and useful for things + * like instruction selection. + */ + +namespace Halide { +namespace Internal { + +// TODO: comments +ConstantInterval constant_integer_bounds(const Expr &e, + const Scope &scope = Scope::empty_scope(), + std::map *cache = nullptr); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/ConstantInterval.h b/src/ConstantInterval.h new file mode 100644 index 000000000000..9939f89abf7f --- /dev/null +++ b/src/ConstantInterval.h @@ -0,0 +1,168 @@ +#ifndef HALIDE_CONSTANT_INTERVAL_H +#define HALIDE_CONSTANT_INTERVAL_H + +#include + +/** \file + * Defines the ConstantInterval class, and operators on it. + */ + +namespace Halide { + +struct Type; + +namespace Internal { + +/** A class to represent ranges of integers. Can be unbounded above or below, + * but they cannot be empty. */ +struct ConstantInterval { + /** The lower and upper bound of the interval. They are included + * in the interval. */ + int64_t min = 0, max = 0; + bool min_defined = false, max_defined = false; + + /* A default-constructed Interval is everything */ + ConstantInterval(); + + /** Construct an interval from a lower and upper bound. */ + ConstantInterval(int64_t min, int64_t max); + + /** The interval representing everything. */ + static ConstantInterval everything(); + + /** Construct an interval representing a single point. */ + static ConstantInterval single_point(int64_t x); + + /** Construct intervals bounded above or below. 
*/ + static ConstantInterval bounded_below(int64_t min); + static ConstantInterval bounded_above(int64_t max); + + /** Is the interval the entire range */ + bool is_everything() const; + + /** Is the interval just a single value (min == max) */ + bool is_single_point() const; + + /** Is the interval a particular single value */ + bool is_single_point(int64_t x) const; + + /** Does the interval have a finite upper and lower bound */ + bool is_bounded() const; + + /** Expand the interval to include another Interval */ + void include(const ConstantInterval &i); + + /** Expand the interval to include a point */ + void include(int64_t x); + + /** Test if the interval contains a particular value */ + bool contains(int64_t x) const; + + /** Construct the smallest interval containing two intervals. */ + static ConstantInterval make_union(const ConstantInterval &a, const ConstantInterval &b); + + /** Construct the largest interval contained within two intervals. Throws an + * error if the interval is empty. */ + static ConstantInterval make_intersection(const ConstantInterval &a, const ConstantInterval &b); + + /** Equivalent to same_as. Exists so that the autoscheduler can + * compare two map for equality in order to + * cache computations. */ + bool operator==(const ConstantInterval &other) const; + + /** In-place versions of the arithmetic operators below. */ + // @{ + void operator+=(const ConstantInterval &other); + void operator+=(int64_t); + void operator-=(const ConstantInterval &other); + void operator-=(int64_t); + void operator*=(const ConstantInterval &other); + void operator*=(int64_t); + void operator/=(const ConstantInterval &other); + void operator/=(int64_t); + void operator%=(const ConstantInterval &other); + void operator%=(int64_t); + // @} + + /** Negate an interval. */ + ConstantInterval operator-() const; + + /** Track what happens if a constant integer interval is forced to fit into + * a concrete integer type. */ + void cast_to(const Type &t); + + /** Get constant integer bounds on a type. */ + static ConstantInterval bounds_of_type(Type); +}; + +/** Arithmetic operators on ConstantIntervals. The resulting interval contains + * all possible values of the operator applied to any two elements of the + * argument intervals. Note that these operator on unbounded integers. If you + * are applying this to concrete small integer types, you will need to manually + * cast the constant interval back to the desired type to model the effect of + * overflow. 
*/ +// @{ +ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator+(const ConstantInterval &a, int64_t b); +ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator-(const ConstantInterval &a, int64_t b); +ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator/(const ConstantInterval &a, int64_t b); +ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator*(const ConstantInterval &a, int64_t b); +ConstantInterval operator%(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator%(const ConstantInterval &a, int64_t b); +ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval min(const ConstantInterval &a, int64_t b); +ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval max(const ConstantInterval &a, int64_t b); +ConstantInterval abs(const ConstantInterval &a); +ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator<<(const ConstantInterval &a, int64_t b); +ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator>>(const ConstantInterval &a, int64_t b); +// @} + +/** Comparison operators on ConstantIntervals. Returns whether the comparison is + * true for all values of the two intervals. */ +// @{ +bool operator<=(const ConstantInterval &a, const ConstantInterval &b); +bool operator<=(const ConstantInterval &a, int64_t b); +bool operator<=(int64_t a, const ConstantInterval &b); +bool operator<(const ConstantInterval &a, const ConstantInterval &b); +bool operator<(const ConstantInterval &a, int64_t b); +bool operator<(int64_t a, const ConstantInterval &b); + +inline bool operator>=(const ConstantInterval &a, const ConstantInterval &b) { + return b <= a; +} +inline bool operator>(const ConstantInterval &a, const ConstantInterval &b) { + return b < a; +} +inline bool operator>=(const ConstantInterval &a, int64_t b) { + return b <= a; +} +inline bool operator>(const ConstantInterval &a, int64_t b) { + return b < a; +} +inline bool operator>=(int64_t a, const ConstantInterval &b) { + return b <= a; +} +inline bool operator>(int64_t a, const ConstantInterval &b) { + return b < a; +} + +// @} +} // namespace Internal + +/** Cast operators for ConstantIntervals. These ones have to live out in + * Halide::, to avoid C++ name lookup confusion with the Halide::cast variants + * that take Exprs. 
*/ +// @{ +Internal::ConstantInterval cast(Type t, const Internal::ConstantInterval &a); +Internal::ConstantInterval saturating_cast(Type t, const Internal::ConstantInterval &a); +// @} + +} // namespace Halide + +#endif From 413b4a6dd8dd79c39bf9224bc04da907173c5028 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 12:59:33 -0700 Subject: [PATCH 14/33] Account for more aggressive simplification in fuse test --- test/correctness/fuse.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/correctness/fuse.cpp b/test/correctness/fuse.cpp index 87ebcba3dbc4..d644e6fb741e 100644 --- a/test/correctness/fuse.cpp +++ b/test/correctness/fuse.cpp @@ -72,7 +72,7 @@ int main(int argc, char **argv) { Var xy("xy"); f.compute_root() .fuse(x, y, xy) - .vectorize(xy, 16); + .vectorize(xy, 16, TailStrategy::RoundUp); f.add_custom_lowering_pass(new CheckForMod); f.compile_jit(); From 854122f53f089268602af06413552b988ec55015 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 13:01:16 -0700 Subject: [PATCH 15/33] Remove redundant helpers --- src/ConstantBounds.cpp | 10 +++---- src/ConstantInterval.cpp | 8 ------ src/IRPrinter.cpp | 25 +++++++++++------- src/Monotonic.cpp | 56 ++++++++++------------------------------ src/Simplify.cpp | 20 +++++++------- src/Simplify_Cast.cpp | 4 +-- src/Simplify_Internal.h | 4 +-- src/Simplify_Let.cpp | 8 +++--- src/Simplify_Stmts.cpp | 5 ++-- 9 files changed, 53 insertions(+), 87 deletions(-) diff --git a/src/ConstantBounds.cpp b/src/ConstantBounds.cpp index 25aca7c6b6d5..7ecc25ac35b3 100644 --- a/src/ConstantBounds.cpp +++ b/src/ConstantBounds.cpp @@ -162,12 +162,10 @@ ConstantInterval constant_integer_bounds(const Expr &e, ret = get_bounds(); } - if (true) { - internal_assert((!ret.has_lower_bound() || e.type().can_represent(ret.min)) && - (!ret.has_upper_bound() || e.type().can_represent(ret.max))) - << "constant_bounds returned defined bounds that are not representable in " - << "the type of the Expr passed in.\n Expr: " << e << "\n Bounds: " << ret; - } + internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) && + (!ret.max_defined || e.type().can_represent(ret.max))) + << "constant_bounds returned defined bounds that are not representable in " + << "the type of the Expr passed in.\n Expr: " << e << "\n Bounds: " << ret; return ret; } diff --git a/src/ConstantInterval.cpp b/src/ConstantInterval.cpp index a29f42f5ee9c..770f39a3207a 100644 --- a/src/ConstantInterval.cpp +++ b/src/ConstantInterval.cpp @@ -48,14 +48,6 @@ bool ConstantInterval::is_single_point(int64_t x) const { return min_defined && max_defined && min == x && max == x; } -bool ConstantInterval::has_upper_bound() const { - return max_defined; -} - -bool ConstantInterval::has_lower_bound() const { - return min_defined; -} - bool ConstantInterval::is_bounded() const { return max_defined && min_defined; } diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 7719eccfc489..0c29b5c8a749 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -449,9 +449,7 @@ std::ostream &operator<<(std::ostream &out, const Closure &c) { return out; } -namespace { -template -void emit_interval(std::ostream &out, const T &in) { +std::ostream &operator<<(std::ostream &out, const Interval &in) { out << "["; if (in.has_lower_bound()) { out << in.min; @@ -465,16 +463,23 @@ void emit_interval(std::ostream &out, const T &in) { out << "inf"; } out << "]"; -} -} // namespace - -std::ostream &operator<<(std::ostream &out, const Interval &c) { - emit_interval(out, c); 
return out; } -std::ostream &operator<<(std::ostream &out, const ConstantInterval &c) { - emit_interval(out, c); +std::ostream &operator<<(std::ostream &out, const ConstantInterval &in) { + out << "["; + if (in.min_defined) { + out << in.min; + } else { + out << "-inf"; + } + out << ", "; + if (in.max_defined) { + out << in.max; + } else { + out << "inf"; + } + out << "]"; return out; } diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index b69746ba2826..eb3d8fd651f8 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -43,16 +43,8 @@ const int64_t *as_const_int_or_uint(const Expr &e) { return nullptr; } -bool is_constant(const ConstantInterval &a) { - return a.is_single_point(0); -} - -bool is_monotonic_increasing(const ConstantInterval &a) { - return a.has_lower_bound() && a.min >= 0; -} - -bool is_monotonic_decreasing(const ConstantInterval &a) { - return a.has_upper_bound() && a.max <= 0; +bool is_constant(const ConstantInterval &x) { + return x.is_single_point(0); } ConstantInterval to_interval(Monotonic m) { @@ -72,25 +64,15 @@ ConstantInterval to_interval(Monotonic m) { Monotonic to_monotonic(const ConstantInterval &x) { if (is_constant(x)) { return Monotonic::Constant; - } else if (is_monotonic_increasing(x)) { + } else if (x >= 0) { return Monotonic::Increasing; - } else if (is_monotonic_decreasing(x)) { + } else if (x <= 0) { return Monotonic::Decreasing; } else { return Monotonic::Unknown; } } -ConstantInterval unify(const ConstantInterval &a, const ConstantInterval &b) { - return ConstantInterval::make_union(a, b); -} - -ConstantInterval unify(const ConstantInterval &a, int64_t b) { - ConstantInterval result; - result.include(b); - return result; -} - class DerivativeBounds : public IRVisitor { const string &var; @@ -193,10 +175,10 @@ class DerivativeBounds : public IRVisitor { if (*b == 0) { result = ConstantInterval(0, 0); } else { - if (result.has_lower_bound()) { + if (result.min_defined) { result.min = div_imp(result.min, *b); } - if (result.has_upper_bound()) { + if (result.max_defined) { if (result.max != INT64_MIN) { result.max = div_imp(result.max - 1, *b) + 1; } else { @@ -223,16 +205,14 @@ class DerivativeBounds : public IRVisitor { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit(const Max *op) override { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit_eq(const Expr &a, const Expr &b) { @@ -262,17 +242,12 @@ class DerivativeBounds : public IRVisitor { a.accept(this); ConstantInterval ra = result; b.accept(this); - ConstantInterval rb = result; - result = unify(-ra, rb); + result.include(-ra); // If the result is bounded, limit it to [-1, 1]. The largest // difference possible is flipping from true to false or false // to true. 
- if (result.has_lower_bound()) { - result.min = std::min(std::max(result.min, -1), 1); - } - if (result.has_upper_bound()) { - result.max = std::min(std::max(result.max, -1), 1); - } + result.min = std::min(std::max(result.min, -1), 1); + result.max = std::min(std::max(result.max, -1), 1); } void visit(const LT *op) override { @@ -295,16 +270,14 @@ class DerivativeBounds : public IRVisitor { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit(const Or *op) override { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit(const Not *op) override { @@ -327,8 +300,7 @@ class DerivativeBounds : public IRVisitor { op->true_value.accept(this); ConstantInterval ra = result; op->false_value.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); // If the condition is not constant, we hit a "bump" when the condition changes value. if (!is_constant(rcond)) { diff --git a/src/Simplify.cpp b/src/Simplify.cpp index c73b9b43f4e6..2f6acea23969 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -38,8 +38,8 @@ Simplify::Simplify(bool r, const Scope *bi, const Scopemutate(lt->b, &i); - if (i.bounds.has_lower_bound()) { + if (i.bounds.min_defined) { // !(v < i) learn_lower_bound(v, i.bounds.min); } @@ -145,7 +145,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { v = lt->b.as(); if (v) { simplify->mutate(lt->a, &i); - if (i.bounds.has_upper_bound()) { + if (i.bounds.max_defined) { // !(i < v) learn_upper_bound(v, i.bounds.max); } @@ -155,7 +155,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { Simplify::ExprInfo i; if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(le->b, &i); - if (i.bounds.has_lower_bound()) { + if (i.bounds.min_defined) { // !(v <= i) learn_lower_bound(v, i.bounds.min + 1); } @@ -163,7 +163,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { v = le->b.as(); if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(le->a, &i); - if (i.bounds.has_upper_bound()) { + if (i.bounds.max_defined) { // !(i <= v) learn_upper_bound(v, i.bounds.max - 1); } @@ -267,7 +267,7 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo i; if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(lt->b, &i); - if (i.bounds.has_upper_bound()) { + if (i.bounds.max_defined) { // v < i learn_upper_bound(v, i.bounds.max - 1); } @@ -275,7 +275,7 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { v = lt->b.as(); if (v && v->type.is_int() && v->type.bits() >= 32) { simplify->mutate(lt->a, &i); - if (i.bounds.has_lower_bound()) { + if (i.bounds.min_defined) { // i < v learn_lower_bound(v, i.bounds.min + 1); } @@ -285,7 +285,7 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo i; if (v) { simplify->mutate(le->b, &i); - if (i.bounds.has_upper_bound()) { + if (i.bounds.max_defined) { // v <= i learn_upper_bound(v, i.bounds.max); } @@ -293,7 +293,7 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { v = le->b.as(); if (v) { simplify->mutate(le->a, &i); - if (i.bounds.has_lower_bound()) { + if (i.bounds.min_defined) { // i <= v learn_lower_bound(v, i.bounds.min); } diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index 089e46bdc62b..02cc4542399a 100644 --- a/src/Simplify_Cast.cpp +++ 
b/src/Simplify_Cast.cpp @@ -11,11 +11,11 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) { // If there's overflow in a no-overflow type (e.g. due to casting // from a UInt(64) to an Int(32), then set the corresponding bound // to infinity. - if (info->bounds.has_upper_bound() && !op->type.can_represent(info->bounds.max)) { + if (info->bounds.max_defined && !op->type.can_represent(info->bounds.max)) { info->bounds.max_defined = false; info->bounds.max = 0; } - if (info->bounds.has_lower_bound() && !op->type.can_represent(info->bounds.min)) { + if (info->bounds.min_defined && !op->type.can_represent(info->bounds.min)) { info->bounds.min_defined = false; info->bounds.min = 0; } diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 0b87b7aa6774..1b0308d938a1 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -47,7 +47,7 @@ class Simplify : public VariadicVisitor { if (alignment.modulus == 0) { bounds = ConstantInterval::single_point(alignment.remainder); } else if (alignment.modulus > 1) { - if (bounds.has_lower_bound()) { + if (bounds.min_defined) { int64_t adjustment; bool no_overflow = sub_with_overflow(64, alignment.remainder, mod_imp(bounds.min, alignment.modulus), &adjustment); adjustment = mod_imp(adjustment, alignment.modulus); @@ -57,7 +57,7 @@ class Simplify : public VariadicVisitor { bounds.min = new_min; } } - if (bounds.has_upper_bound()) { + if (bounds.max_defined) { int64_t adjustment; bool no_overflow = sub_with_overflow(64, mod_imp(bounds.max, alignment.modulus), alignment.remainder, &adjustment); adjustment = mod_imp(adjustment, alignment.modulus); diff --git a/src/Simplify_Let.cpp b/src/Simplify_Let.cpp index 481f37be87a4..40be942ac805 100644 --- a/src/Simplify_Let.cpp +++ b/src/Simplify_Let.cpp @@ -205,8 +205,8 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { // Remutate new_value to get updated bounds ExprInfo new_value_info; f.new_value = mutate(f.new_value, &new_value_info); - if (new_value_info.bounds.has_lower_bound() || - new_value_info.bounds.has_upper_bound() || + if (new_value_info.bounds.min_defined || + new_value_info.bounds.max_defined || new_value_info.alignment.modulus != 1) { // There is some useful information bounds_and_alignment_info.push(f.new_name, new_value_info); @@ -215,8 +215,8 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { } if (no_overflow_scalar_int(f.value.type())) { - if (value_info.bounds.has_lower_bound() || - value_info.bounds.has_upper_bound() || + if (value_info.bounds.min_defined || + value_info.bounds.max_defined || value_info.alignment.modulus != 1) { bounds_and_alignment_info.push(op->name, value_info); f.value_bounds_tracked = true; diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 57cfe74cf7e3..f63c7d35601f 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -227,9 +227,8 @@ Stmt Simplify::visit(const For *op) { ConstantInterval::make_union(min_info.bounds, min_info.bounds + max(extent_info.bounds, 1) - 1); - - if (loop_var_info.bounds.has_upper_bound() || - loop_var_info.bounds.has_lower_bound()) { + if (loop_var_info.bounds.max_defined || + loop_var_info.bounds.min_defined) { bounds_tracked = true; bounds_and_alignment_info.push(op->name, loop_var_info); } From 4a293b119eba388f63e1316bcdef9d7f9ffc7c27 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 13:08:33 -0700 Subject: [PATCH 16/33] Add missing comment --- src/ConstantBounds.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 
deletion(-) diff --git a/src/ConstantBounds.h b/src/ConstantBounds.h index 35178bddb4a6..26ab114455bb 100644 --- a/src/ConstantBounds.h +++ b/src/ConstantBounds.h @@ -14,7 +14,17 @@ namespace Halide { namespace Internal { -// TODO: comments +/** Deduce constant integer bounds on an expression. This can be useful to + * decide if, for example, the expression can be cast to another type, be + * negated, be incremented, etc without risking overflow. + * + * Also optionally accepts a scope containing the integer bounds of any + * variables that may be referenced, and a cache of constant integer bounds on + * known Exprs, which this function will update. The cache is helpful to + * short-circuit large numbers of redundant queries, but it should not be used + * in contexts where the same Expr object may take on different values within a + * single Expr (i.e. before uniquify_variable_names). + */ ConstantInterval constant_integer_bounds(const Expr &e, const Scope &scope = Scope::empty_scope(), std::map *cache = nullptr); From 0856319ed7366f67526db381a8dd45b7c0b123ef Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 13:13:03 -0700 Subject: [PATCH 17/33] clear_bounds_info -> clear_expr_info --- src/Simplify_Call.cpp | 4 ++-- src/Simplify_Cast.cpp | 2 +- src/Simplify_Div.cpp | 2 +- src/Simplify_Exprs.cpp | 8 ++++---- src/Simplify_Internal.h | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index 9dd98c3e3e8b..99ee0c16ae6a 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -145,7 +145,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { // LLVM shl and shr instructions produce poison for // shifts >= typesize, so we will follow suit in our simplifier. if (ub >= (uint64_t)(t.bits())) { - clear_bounds_info(info); + clear_expr_info(info); return make_signed_integer_overflow(t); } if (a.type().is_uint() || ub < ((uint64_t)t.bits() - 1)) { @@ -778,7 +778,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { // just fall thru and take the general case. debug(2) << "Simplifier: unhandled PureExtern: " << op->name << "\n"; } else if (op->is_intrinsic(Call::signed_integer_overflow)) { - clear_bounds_info(info); + clear_expr_info(info); } else if (op->is_intrinsic(Call::concat_bits) && op->args.size() == 1) { return mutate(op->args[0], info); } diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index 02cc4542399a..3a60f5aa464c 100644 --- a/src/Simplify_Cast.cpp +++ b/src/Simplify_Cast.cpp @@ -37,7 +37,7 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) { int64_t i = 0; uint64_t u = 0; if (Call::as_intrinsic(value, {Call::signed_integer_overflow})) { - clear_bounds_info(info); + clear_expr_info(info); return make_signed_integer_overflow(op->type); } else if (value.type() == op->type) { return value; diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp index 45b6a2ad8fb7..c06337bfb566 100644 --- a/src/Simplify_Div.cpp +++ b/src/Simplify_Div.cpp @@ -30,7 +30,7 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { // a known-wrong value. (Note that no_overflow_int() should // only be true for signed integers.) 
internal_assert(op->type.is_int()); - clear_bounds_info(info); + clear_expr_info(info); return make_signed_integer_overflow(op->type); } } diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 70a63d9d1f4c..02f19ae13a6a 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -13,7 +13,7 @@ Expr Simplify::visit(const IntImm *op, ExprInfo *info) { info->alignment = ModulusRemainder(0, op->value); info->cast_to(op->type); } else { - clear_bounds_info(info); + clear_expr_info(info); } return op; } @@ -25,18 +25,18 @@ Expr Simplify::visit(const UIntImm *op, ExprInfo *info) { info->alignment = ModulusRemainder(0, v); info->cast_to(op->type); } else { - clear_bounds_info(info); + clear_expr_info(info); } return op; } Expr Simplify::visit(const FloatImm *op, ExprInfo *info) { - clear_bounds_info(info); + clear_expr_info(info); return op; } Expr Simplify::visit(const StringImm *op, ExprInfo *info) { - clear_bounds_info(info); + clear_expr_info(info); return op; } diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 1b0308d938a1..0b7e7c7c7049 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -114,7 +114,7 @@ class Simplify : public VariadicVisitor { }; HALIDE_ALWAYS_INLINE - void clear_bounds_info(ExprInfo *b) { + void clear_expr_info(ExprInfo *b) { if (b) { *b = ExprInfo{}; } From 16a706da4606684521fc4ce291e4bacc0fc32a6e Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 13:30:50 -0700 Subject: [PATCH 18/33] Remove bad TODO I can't think of a single case that could cause this --- src/Simplify_Div.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp index c06337bfb566..a43ce9b93048 100644 --- a/src/Simplify_Div.cpp +++ b/src/Simplify_Div.cpp @@ -14,9 +14,6 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { info->trim_bounds_using_alignment(); info->cast_to(op->type); - // TODO: add test case which resolves to a scalar, but only after - // trimming using the alignment. - // Bounded numerator divided by constantish // denominator can sometimes collapse things to a // constant at this point From ecfae44e12c51284b06a0f9ca9bf8d83e91e824e Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 13:32:27 -0700 Subject: [PATCH 19/33] It's too late to change the semantics of fixed point intrinsics --- src/FindIntrinsics.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 4142831cd2be..54358c0ef76f 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -176,7 +176,6 @@ class FindIntrinsics : public IRMutator { } // Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add - // TODO: this pattern makes me want to change the semantics of this op. 
if (const Cast *cast = a.as()) { if (cast->is_reinterpret()) { if (const Call *add = Call::as_intrinsic(cast->value, {Call::widen_right_add})) { From 66c56f1705bc887700ba9157deec5121a1d573db Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 1 Apr 2024 13:35:01 -0700 Subject: [PATCH 20/33] Fix some UB --- src/IRMatch.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/IRMatch.h b/src/IRMatch.h index cf9312db0075..e71d56b0fa55 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -2108,7 +2108,8 @@ struct WidenOp { HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { Expr e = a.make(state, {}); - return cast(e.type().widen(), std::move(e)); + Type w = e.type().widen(); + return cast(w, std::move(e)); } constexpr static bool foldable = false; From 0fb8d3883ef78cfc1ef1502fe9ccd1f7777da37b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 2 Apr 2024 12:10:14 -0700 Subject: [PATCH 21/33] Stronger assert in Simplify_Div --- src/Simplify_Div.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp index a43ce9b93048..c399470eb3dc 100644 --- a/src/Simplify_Div.cpp +++ b/src/Simplify_Div.cpp @@ -26,7 +26,7 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { // we're better off returning an overflow condition than // a known-wrong value. (Note that no_overflow_int() should // only be true for signed integers.) - internal_assert(op->type.is_int()); + internal_assert(no_overflow_int(op->type)); clear_expr_info(info); return make_signed_integer_overflow(op->type); } From c6065ff59f8f097c06971e78f719ba2051af979e Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 2 Apr 2024 12:10:41 -0700 Subject: [PATCH 22/33] Delete bad rewrite rules --- src/FindIntrinsics.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 54358c0ef76f..8b8d104565bb 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -781,11 +781,6 @@ class FindIntrinsics : public IRMutator { // We only care about integers, this should be trivially true. is_x_same_int_or_uint) || - // widening_add(x + widen(y), widen(z)) -> widening_add(x, widening_add(y, z)) - rewrite(widening_add(widen_right_add(x, y), widen(z)), - widening_add(x, widening_add(y, z))) || - rewrite(widening_add(widen(z), widen_right_add(x, y)), - widening_add(x, widening_add(y, z))) || // Saturating patterns. 
rewrite(saturating_cast(op->type, widening_add(x, y)), From 6bcc66a5fb3576eeabb0b50bb52b7a1ee6c3cae6 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 2 Apr 2024 12:11:26 -0700 Subject: [PATCH 23/33] Fix bad test when lowering mul_shift_right b_shift + b_shift < missing_q --- src/FindIntrinsics.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 8b8d104565bb..27ef367ab329 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -1469,7 +1469,7 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) int a_shift = 0, b_shift = 0; ConstantInterval ca = constant_integer_bounds(a); do { - ConstantInterval bigger = ca * ConstantInterval::single_point(2); + ConstantInterval bigger = ca * 2; if (a.type().can_represent(bigger) && a_shift + b_shift < missing_q) { ca = bigger; a_shift++; @@ -1478,8 +1478,8 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) } while (false); ConstantInterval cb = constant_integer_bounds(b); do { - ConstantInterval bigger = cb * ConstantInterval::single_point(2); - if (b.type().can_represent(bigger) && b_shift + b_shift < missing_q) { + ConstantInterval bigger = cb * 2; + if (b.type().can_represent(bigger) && a_shift + b_shift < missing_q) { cb = bigger; b_shift++; continue; From c652667f15bb46a5f06e426b2f530410bff66828 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 2 Apr 2024 12:11:42 -0700 Subject: [PATCH 24/33] Avoid UB in lowering of rounding_shift_right/left --- src/FindIntrinsics.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 27ef367ab329..0c46379bdf77 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -781,7 +781,6 @@ class FindIntrinsics : public IRMutator { // We only care about integers, this should be trivially true. is_x_same_int_or_uint) || - // Saturating patterns. rewrite(saturating_cast(op->type, widening_add(x, y)), saturating_add(x, y), @@ -1250,10 +1249,11 @@ Expr lower_widening_shift_right(const Expr &a, const Expr &b) { } Expr lower_rounding_shift_left(const Expr &a, const Expr &b) { - // Shift left, then add one to the result if bits were dropped - // (because b < 0) and the most significant dropped bit was a one. + // Shift left, then add one to the result if bits were dropped (because b < 0) + // and the most significant dropped bit was a one. We must take care not + // to introduce UB in the shifts, even if the result would be masked off. Expr b_negative = select(b < 0, make_one(a.type()), make_zero(a.type())); - return simplify((a << b) + (b_negative & (a << (b + 1)))); + return simplify((a << b) + (b_negative & (a << saturating_add(b, make_one(b.type()))))); } Expr lower_rounding_shift_right(const Expr &a, const Expr &b) { @@ -1265,10 +1265,11 @@ Expr lower_rounding_shift_right(const Expr &a, const Expr &b) { Expr round = simplify(cast(a.type(), (1 << shift) - 1)); return rounding_halving_add(a, round) >> shift; } - // Shift right, then add one to the result if bits were dropped - // (because b > 0) and the most significant dropped bit was a one. + // Shift right, then add one to the result if bits were dropped (because b > 0) + // and the most significant dropped bit was a one. We must take care not to + // introduce UB in the shifts, even if the result would be masked off. 
Expr b_positive = select(b > 0, make_one(a.type()), make_zero(a.type())); - return simplify((a >> b) + (b_positive & (a >> (b - 1)))); + return simplify((a >> b) + (b_positive & (a >> saturating_sub(b, make_one(b.type()))))); } Expr lower_saturating_add(const Expr &a, const Expr &b) { From 1737a52dfa6bd6da8dae0b1e9b756fc54e93ef0d Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 2 Apr 2024 12:34:37 -0700 Subject: [PATCH 25/33] Add shifts to the lossless cast fuzzer This required a more careful signed-integer-overflow detection routine --- test/correctness/lossless_cast.cpp | 129 ++++++++++++++++++++++------- 1 file changed, 98 insertions(+), 31 deletions(-) diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp index c692f2ac7200..9a31597e1aad 100644 --- a/test/correctness/lossless_cast.cpp +++ b/test/correctness/lossless_cast.cpp @@ -110,7 +110,7 @@ Expr random_expr(std::mt19937 &rng) { int op = rng() % 7; Expr e1 = exprs[i1]; Expr e2 = cast(e1.type(), exprs[i2]); - Expr e3 = exprs[i3]; + Expr e3 = cast(e1.type().with_code(halide_type_uint), exprs[i3]); bool may_widen = e1.type().bits() < 64; Expr e2_narrow = exprs[i2]; bool may_widen_right = e2_narrow.type() == e1.type().narrow(); @@ -138,7 +138,7 @@ Expr random_expr(std::mt19937 &rng) { e = e1 / e2; break; case 6: - switch (rng() % 14) { + switch (rng() % 15) { case 0: if (may_widen) { e = widening_add(e1, e2); @@ -176,21 +176,26 @@ Expr random_expr(std::mt19937 &rng) { e = count_trailing_zeros(e1); break; case 10: - if (e3.type().is_uint()) { + if (may_widen) { e = rounding_mul_shift_right(e1, e2, e3); } break; case 11: + if (may_widen) { + e = mul_shift_right(e1, e2, e3); + } + break; + case 12: if (may_widen_right) { e = widen_right_add(e1, e2_narrow); } break; - case 12: + case 13: if (may_widen_right) { e = widen_right_sub(e1, e2_narrow); } break; - case 13: + case 14: if (may_widen_right) { e = widen_right_mul(e1, e2_narrow); } @@ -212,21 +217,92 @@ Expr random_expr(std::mt19937 &rng) { } } -class CheckForIntOverflow : public IRMutator { - using IRMutator::visit; +bool might_have_ub(Expr e) { + class MightOverflow : public IRVisitor { + std::map cache; + + using IRVisitor::visit; - Expr visit(const Call *op) override { - if (op->is_intrinsic(Call::signed_integer_overflow)) { - found_overflow = true; - return make_zero(op->type); - } else { - return IRMutator::visit(op); + bool no_overflow_int(const Type &t) { + return t.is_int() && t.bits() >= 32; + } + + ConstantInterval bounds(const Expr &e) { + return constant_integer_bounds(e, Scope::empty_scope(), &cache); + } + + void visit(const Add *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->a) + bounds(op->b))) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Sub *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->a) - bounds(op->b))) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Mul *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->a) * bounds(op->b))) { + found = true; + } else { + IRVisitor::visit(op); + } } - } -public: - bool found_overflow = false; -}; + void visit(const Div *op) override { + if (no_overflow_int(op->type) && + (bounds(op->a) / bounds(op->b)).contains(-1)) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Cast *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->value))) { + found = 
true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Call *op) override { + if (op->is_intrinsic({Call::shift_left, + Call::shift_right, + Call::rounding_shift_left, + Call::rounding_shift_right, + Call::widening_shift_left, + Call::widening_shift_right, + Call::mul_shift_right, + Call::rounding_mul_shift_right})) { + auto shift_bounds = bounds(op->args.back()); + if (!(shift_bounds > -op->type.bits() && shift_bounds < op->type.bits())) { + found = true; + } + } else if (op->is_intrinsic({Call::signed_integer_overflow})) { + found = true; + } + IRVisitor::visit(op); + } + + public: + bool found = false; + } checker; + + e.accept(&checker); + + return checker.found; +} bool found_error = false; @@ -238,6 +314,10 @@ int test_one(uint32_t seed) { Expr e1 = random_expr(rng); + if (might_have_ub(e1)) { + return 0; + } + // We're also going to test constant_integer_bounds here. ConstantInterval bounds = constant_integer_bounds(e1); @@ -256,21 +336,7 @@ int test_one(uint32_t seed) { Buffer out1(size), out2(size); Pipeline p(f); - CheckForIntOverflow checker; - // We don't have constant-folding rules for all intrinsics, so we also need - // to feed the checker the lowered form. - checker.mutate(simplify(lower_intrinsics(e1))); - checker.mutate(simplify(lower_intrinsics(e2))); - if (checker.found_overflow) { - return 0; - } - p.add_custom_lowering_pass(&checker, nullptr); p.realize({out1, out2}); - if (checker.found_overflow) { - // We don't do anything in the expression generator to avoid signed - // integer overflow, so just skip anything with signed integer overflow. - return 0; - } for (int x = 0; x < size; x++) { if (out1(x) != out2(x)) { @@ -297,6 +363,7 @@ int test_one(uint32_t seed) { << "out1 = " << out1(x) << "\n" << "Expression: " << e1 << "\n" << "Bounds: " << bounds << "\n"; + return 1; } } From ddab1cf8184a55953b52bd941f2b49c841ff4f37 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 5 Apr 2024 09:30:59 -0700 Subject: [PATCH 26/33] Fix bug in lossless_negate --- src/IROperator.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 2b0359fb3280..39f68d9c3f87 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -543,6 +543,12 @@ Expr lossless_cast(Type t, Expr e, std::map Expr lossless_negate(const Expr &x) { if (const Mul *m = x.as()) { + // Check the terms can't multiply to produce the most negative value. + if (x.type().is_int() && + !x.type().can_represent(-constant_integer_bounds(x))) { + return Expr(); + } + Expr b = lossless_negate(m->b); if (b.defined()) { return Mul::make(m->a, b); @@ -569,8 +575,7 @@ Expr lossless_negate(const Expr &x) { } else if (const Cast *c = x.as()) { Expr value = lossless_negate(c->value); if (value.defined()) { - // This works for constants, but not other things that - // could possibly be negated. + // This logic is only sound if we know the cast can't overflow. 
value = lossless_cast(c->type, value); if (value.defined()) { return value; From a0f1d233273ec3da4a3499256326ea241d202432 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 2 Jun 2024 14:36:57 -0700 Subject: [PATCH 27/33] Add constant interval test --- test/correctness/constant_interval.cpp | 173 +++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 test/correctness/constant_interval.cpp diff --git a/test/correctness/constant_interval.cpp b/test/correctness/constant_interval.cpp new file mode 100644 index 000000000000..0f421d4b8b5a --- /dev/null +++ b/test/correctness/constant_interval.cpp @@ -0,0 +1,173 @@ +#include "Halide.h" + +using namespace Halide; +using namespace Halide::Internal; + +std::mt19937 rng; + +int64_t sample(const ConstantInterval &i) { + int64_t upper = i.max_defined ? i.max : 1024; + int64_t lower = i.min_defined ? i.min : -1024; + return lower + (rng() % (upper - lower + 1)); +} + +ConstantInterval random_interval() { + int64_t a = (rng() % 512) - 256; + int64_t b = (rng() % 512) - 256; + ConstantInterval result; + if (rng() & 1) { + result.max_defined = true; + result.max = std::max(a, b); + } + if (rng() & 1) { + result.min_defined = true; + result.min = std::min(a, b); + } + return result; +} + +int main(int argc, char **argv) { + for (int i = 0; i < 1000; i++) { + std::vector> values; + for (int j = 0; j < 10; j++) { + values.emplace_back(random_interval(), 0); + values.back().second = sample(values.back().first); + } + + for (int j = 0; j < 1000; j++) { + auto a = values[rng() % values.size()]; + auto b = values[rng() % values.size()]; + decltype(a) c; + + auto check = [&](const char *op) { + if (!c.first.contains(c.second)) { + std::cout << "Error for operator " << op << ":\n" + << "a: " << a.second << " in " << a.first << "\n" + << "b: " << b.second << " in " << b.first << "\n" + << "c: " << c.second << " not in " << c.first << "\n"; + exit(1); + } + }; + + auto check_scalar = [&](const char *op) { + if (!c.first.contains(c.second)) { + std::cout << "Error for operator " << op << ":\n" + << "a: " << a.second << " in " << a.first << "\n" + << "b: " << b.second << "\n" + << "c: " << c.second << " not in " << c.first << "\n"; + exit(1); + } + }; + + // Arithmetic + c.first = a.first + b.first; + c.second = a.second + b.second; + check("+"); + + c.first = a.first - b.first; + c.second = a.second - b.second; + check("-"); + + c.first = a.first * b.first; + c.second = a.second * b.second; + check("*"); + + c.first = a.first / b.first; + c.second = div_imp(a.second, b.second); + check("/"); + + c.first = min(a.first, b.first); + c.second = std::min(a.second, b.second); + check("min"); + + c.first = max(a.first, b.first); + c.second = std::max(a.second, b.second); + check("max"); + + c.first = a.first % b.first; + c.second = mod_imp(a.second, b.second); + check("%"); + + // Arithmetic with constant RHS + c.first = a.first + b.second; + c.second = a.second + b.second; + check_scalar("+"); + + c.first = a.first - b.second; + c.second = a.second - b.second; + check_scalar("-"); + + c.first = a.first * b.second; + c.second = a.second * b.second; + check_scalar("*"); + + c.first = a.first / b.second; + c.second = div_imp(a.second, b.second); + check_scalar("/"); + + c.first = min(a.first, b.second); + c.second = std::min(a.second, b.second); + check_scalar("min"); + + c.first = max(a.first, b.second); + c.second = std::max(a.second, b.second); + check_scalar("max"); + + c.first = a.first % b.second; + c.second = mod_imp(a.second, b.second); + 
check_scalar("%"); + + // Some unary operators + c.first = -a.first; + c.second = -a.second; + check("unary -"); + + c.first = cast(UInt(8), a.first); + c.second = (int64_t)(uint8_t)(a.second); + check("cast to uint8"); + + c.first = cast(Int(8), a.first); + c.second = (int64_t)(int8_t)(a.second); + check("cast to uint8"); + + // Comparison + _halide_user_assert(!(a.first < b.first) || a.second < b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.first <= b.first) || a.second <= b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.first > b.first) || a.second > b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.first >= b.first) || a.second >= b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + // Comparison against constants + _halide_user_assert(!(a.first < b.second) || a.second < b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.first <= b.second) || a.second <= b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.first > b.second) || a.second > b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.first >= b.second) || a.second >= b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.second < b.first) || a.second < b.second) + << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.second <= b.first) || a.second <= b.second) + << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.second > b.first) || a.second > b.second) + << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.second >= b.first) || a.second >= b.second) + << a.second << " " << b.first << " " << b.second; + } + } + return 0; +} From ac5b13df24063b37d256c7ad10ff36da6cb6be7c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 3 Jun 2024 11:39:39 -0700 Subject: [PATCH 28/33] Rework find_mpy_ops to handle more structures --- src/HexagonOptimize.cpp | 121 ++++++++++++++++++++++++---------------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 8b01b741a39b..13b2b5d24559 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -382,6 +382,7 @@ typedef pair MulExpr; // the number of lanes in Broadcast or indices in a Shuffle // to match the ty lanes before using lossless_cast on it. Expr unbroadcast_lossless_cast(Type ty, Expr x) { + internal_assert(x.defined()); if (x.type().is_vector()) { if (const Broadcast *bc = x.as()) { if (ty.is_scalar()) { @@ -410,57 +411,78 @@ Expr unbroadcast_lossless_cast(Type ty, Expr x) { // multiplies in 'mpys', added to 'rest'. // Difference in mpys.size() - return indicates the number of // expressions where we pretend the op to be multiplied by 1. -int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count, +int find_mpy_ops(const Expr &op, Type result_ty, Type a_ty, Type b_ty, int max_mpy_count, vector &mpys, Expr &rest) { + auto add_to_rest = [&](const Expr &a) { + if (rest.defined()) { + // Just widen to the result type. We run find_intrinsics on rest + // after calling this, to find things like widen_right_add in this + // summation. 
+ rest = Add::make(rest, cast(result_ty, a)); + } else { + rest = cast(result_ty, a); + } + }; + if ((int)mpys.size() >= max_mpy_count) { - rest = rest.defined() ? Add::make(rest, op) : op; + add_to_rest(op); return 0; } - // If the add is also widening, remove the cast. - Expr stripped = op; - int mpy_bits = std::max(a_ty.bits(), b_ty.bits()) * 2; - if (op.type().bits() == mpy_bits * 2) { - if (const Cast *cast = op.as()) { - if (cast->value.type().bits() == mpy_bits) { - stripped = cast->value; - } - } - } - - Expr maybe_mul = as_mul(stripped); - if (maybe_mul.defined()) { - const Mul *mul = maybe_mul.as(); - Expr a = unbroadcast_lossless_cast(a_ty, mul->a); - Expr b = unbroadcast_lossless_cast(b_ty, mul->b); + auto handle_mul = [&](const Expr &arg0, const Expr &arg1) -> bool { + Expr a = unbroadcast_lossless_cast(a_ty, arg0); + Expr b = unbroadcast_lossless_cast(b_ty, arg1); if (a.defined() && b.defined()) { mpys.emplace_back(a, b); - return 1; - } else { + return true; + } else if (a_ty != b_ty) { // Try to commute the op. - a = unbroadcast_lossless_cast(a_ty, mul->b); - b = unbroadcast_lossless_cast(b_ty, mul->a); + a = unbroadcast_lossless_cast(a_ty, arg1); + b = unbroadcast_lossless_cast(b_ty, arg0); if (a.defined() && b.defined()) { mpys.emplace_back(a, b); - return 1; + return true; } } - } else if (const Add *add = stripped.as()) { - int mpy_count = 0; - mpy_count += find_mpy_ops(add->a, a_ty, b_ty, max_mpy_count, mpys, rest); - mpy_count += find_mpy_ops(add->b, a_ty, b_ty, max_mpy_count, mpys, rest); - return mpy_count; - } else if (const Call *add = Call::as_intrinsic(stripped, {Call::widening_add})) { - int mpy_count = 0; - mpy_count += find_mpy_ops(cast(op.type(), add->args[0]), a_ty, b_ty, max_mpy_count, mpys, rest); - mpy_count += find_mpy_ops(cast(op.type(), add->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest); - return mpy_count; - } else if (const Call *wadd = Call::as_intrinsic(stripped, {Call::widen_right_add})) { - int mpy_count = 0; - mpy_count += find_mpy_ops(wadd->args[0], a_ty, b_ty, max_mpy_count, mpys, rest); - mpy_count += find_mpy_ops(cast(op.type(), wadd->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest); - return mpy_count; + return false; + }; + + if (const Mul *mul = op.as()) { + bool no_overflow = mul->type.can_represent(constant_integer_bounds(mul->a) * + constant_integer_bounds(mul->b)); + if (no_overflow && handle_mul(mul->a, mul->b)) { + return 1; + } + } else if (const Call *mul = Call::as_intrinsic(op, {Call::widening_mul, Call::widen_right_mul})) { + bool no_overflow = (mul->is_intrinsic(Call::widening_mul) || + mul->type.can_represent(constant_integer_bounds(mul->args[0]) * + constant_integer_bounds(mul->args[1]))); + if (no_overflow && handle_mul(mul->args[0], mul->args[1])) { + return 1; + } + } else if (const Add *add = op.as()) { + bool no_overflow = (add->type == result_ty || + add->type.can_represent(constant_integer_bounds(add->a) + + constant_integer_bounds(add->b))); + if (no_overflow) { + return (find_mpy_ops(add->a, result_ty, a_ty, b_ty, max_mpy_count, mpys, rest) + + find_mpy_ops(add->b, result_ty, a_ty, b_ty, max_mpy_count, mpys, rest)); + } + } else if (const Call *add = Call::as_intrinsic(op, {Call::widening_add, Call::widen_right_add})) { + bool no_overflow = (add->type == result_ty || + add->is_intrinsic(Call::widening_add) || + add->type.can_represent(constant_integer_bounds(add->args[0]) + + constant_integer_bounds(add->args[1]))); + if (no_overflow) { + return (find_mpy_ops(add->args[0], result_ty, a_ty, b_ty, max_mpy_count, 
mpys, rest) + + find_mpy_ops(add->args[1], result_ty, a_ty, b_ty, max_mpy_count, mpys, rest)); + } + } else if (const Cast *cast = op.as()) { + bool cast_is_lossless = cast->type.can_represent(constant_integer_bounds(cast->value)); + if (cast_is_lossless) { + return find_mpy_ops(cast->value, result_ty, a_ty, b_ty, max_mpy_count, mpys, rest); + } } // Attempt to pretend this op is multiplied by 1. @@ -472,7 +494,7 @@ int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count, } else if (as_b.defined()) { mpys.emplace_back(make_one(a_ty), as_b); } else { - rest = rest.defined() ? Add::make(rest, op) : op; + add_to_rest(op); } return 0; } @@ -555,10 +577,10 @@ class OptimizePatterns : public IRMutator { // match a subset of the expressions that vector*vector // matches. if (op->type.is_uint()) { - mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), UInt(8), 4, mpys, rest); suffix = ".vub.ub"; } else { - mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), Int(8), 4, mpys, rest); suffix = ".vub.b"; } @@ -589,7 +611,7 @@ class OptimizePatterns : public IRMutator { new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern); } if (rest.defined()) { - new_expr = Add::make(new_expr, rest); + new_expr = Add::make(new_expr, find_intrinsics(rest)); } return mutate(new_expr); } @@ -599,10 +621,10 @@ class OptimizePatterns : public IRMutator { mpys.clear(); rest = Expr(); if (op->type.is_uint()) { - mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8, lanes), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), UInt(8, lanes), 4, mpys, rest); suffix = ".vub.vub"; } else { - mpy_count = find_mpy_ops(op, Int(8, lanes), Int(8, lanes), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, Int(8, lanes), Int(8, lanes), 4, mpys, rest); suffix = ".vb.vb"; } @@ -632,7 +654,7 @@ class OptimizePatterns : public IRMutator { new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern); } if (rest.defined()) { - new_expr = Add::make(new_expr, rest); + new_expr = Add::make(new_expr, find_intrinsics(rest)); } return mutate(new_expr); } @@ -651,11 +673,11 @@ class OptimizePatterns : public IRMutator { // Try to find vector*scalar multiplies. if (op->type.bits() == 16) { - mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 2, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), Int(8), 2, mpys, rest); vmpa_suffix = ".vub.vub.b.b"; vdmpy_suffix = ".vub.b"; } else if (op->type.bits() == 32) { - mpy_count = find_mpy_ops(op, Int(16, lanes), Int(8), 2, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, Int(16, lanes), Int(8), 2, mpys, rest); vmpa_suffix = ".vh.vh.b.b"; vdmpy_suffix = ".vh.b"; } @@ -683,7 +705,7 @@ class OptimizePatterns : public IRMutator { new_expr = halide_hexagon_add_2mpy(op->type, vmpa_suffix, mpys[0].first, mpys[1].first, mpys[0].second, mpys[1].second); } if (rest.defined()) { - new_expr = Add::make(new_expr, rest); + new_expr = Add::make(new_expr, find_intrinsics(rest)); } return mutate(new_expr); } @@ -2272,6 +2294,9 @@ Stmt scatter_gather_generator(Stmt s) { } Stmt optimize_hexagon_instructions(Stmt s, const Target &t) { + debug(4) << "Hexagon: lowering before find_intrinsics\n" + << s << "\n"; + // We need to redo intrinsic matching due to simplification that has // happened after the end of target independent lowering. 
s = find_intrinsics(s); From c8f7e8f36969737e210904c84aad25c63baf29fc Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 3 Jun 2024 11:39:47 -0700 Subject: [PATCH 29/33] Fix bugs in lossless_cast --- src/IROperator.cpp | 77 ++++++++++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index cb3397b8652a..a9ce3fdc59b3 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -444,33 +444,23 @@ Expr lossless_cast(Type t, Expr e, std::map } else if (const Cast *c = e.as()) { if (c->type.can_represent(c->value.type())) { return lossless_cast(t, c->value, cache); - } else { - return Expr(); } } else if (const Broadcast *b = e.as()) { Expr v = lossless_cast(t.element_of(), b->value, cache); if (v.defined()) { return Broadcast::make(v, b->lanes); - } else { - return Expr(); } } else if (const IntImm *i = e.as()) { if (t.can_represent(i->value)) { return make_const(t, i->value); - } else { - return Expr(); } } else if (const UIntImm *i = e.as()) { if (t.can_represent(i->value)) { return make_const(t, i->value); - } else { - return Expr(); } } else if (const FloatImm *f = e.as()) { if (t.can_represent(f->value)) { return make_const(t, f->value); - } else { - return Expr(); } } else if (const Shuffle *shuf = e.as()) { std::vector vecs; @@ -484,16 +474,7 @@ Expr lossless_cast(Type t, Expr e, std::map } else if (t.is_int_or_uint()) { // Check the bounds. If they're small enough, we can throw narrowing // casts around e, or subterms. - ConstantInterval ci; - if (cache) { - auto [it, cache_miss] = cache->try_emplace(e); - if (cache_miss) { - it->second = constant_integer_bounds(e, Scope::empty_scope(), cache); - } - ci = it->second; - } else { - ci = constant_integer_bounds(e); - } + ConstantInterval ci = constant_integer_bounds(e, Scope::empty_scope(), cache); if (t.can_represent(ci)) { // There are certain IR nodes where if the result is expressible @@ -503,25 +484,62 @@ Expr lossless_cast(Type t, Expr e, std::map Expr a = lossless_cast(t, op->a, cache); Expr b = lossless_cast(t, op->b, cache); if (a.defined() && b.defined()) { - return a + b; + return Add::make(a, b); } } else if (const Sub *op = e.as()) { Expr a = lossless_cast(t, op->a, cache); Expr b = lossless_cast(t, op->b, cache); if (a.defined() && b.defined()) { - return a - b; + return Sub::make(a, b); } } else if (const Mul *op = e.as()) { Expr a = lossless_cast(t, op->a, cache); Expr b = lossless_cast(t, op->b, cache); if (a.defined() && b.defined()) { - return a * b; + return Mul::make(a, b); + } + } else if (const Min *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + debug(0) << a << " " << b << "\n"; + return Min::make(a, b); + } + } else if (const Max *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + return Max::make(a, b); + } + } else if (const Mod *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + return Mod::make(a, b); + } + } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add, Call::widen_right_add})) { + Expr a = lossless_cast(t, op->args[0], cache); + Expr b = lossless_cast(t, op->args[1], cache); + if (a.defined() && b.defined()) { + return Add::make(a, b); } - } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add})) { + } else if (const 
                 Expr a = lossless_cast(t, op->args[0], cache);
                 Expr b = lossless_cast(t, op->args[1], cache);
                 if (a.defined() && b.defined()) {
-                    return a + b;
+                    return Sub::make(a, b);
+                }
+            } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_mul, Call::widen_right_mul})) {
+                Expr a = lossless_cast(t, op->args[0], cache);
+                Expr b = lossless_cast(t, op->args[1], cache);
+                if (a.defined() && b.defined()) {
+                    return Mul::make(a, b);
+                }
+            } else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left})) {
+                Expr a = lossless_cast(t, op->args[0], cache);
+                Expr b = lossless_cast(t, op->args[1], cache);
+                if (a.defined() && b.defined()) {
+                    return a << b;
                 }
             } else if (const VectorReduce *op = e.as<VectorReduce>()) {
                 if (op->op == VectorReduce::Add ||
@@ -534,7 +552,14 @@ Expr lossless_cast(Type t, Expr e, std::map
                 }
             }
 
-            return cast(t, e);
+            // At this point we know the expression fits in the target type, but
+            // what we really want is for the expression to be computed in the
+            // target type. So we can add a cast to the target type if we want
+            // here, but it only makes sense to do it if the expression type has
+            // the same or fewer bits than the target type.
+            if (e.type().bits() <= t.bits()) {
+                return cast(t, e);
+            }
         }
     }
 

From 9570818c7e1c0a8daf8fcd3197e411842ac3579a Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Mon, 3 Jun 2024 12:09:56 -0700
Subject: [PATCH 30/33] Fix mul_shift_right expansion

---
 src/FindIntrinsics.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 065498714b84..6f2407fdd4c5 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -1472,23 +1472,25 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q)
     // Try to scale up the args by factors of two without overflowing
     int a_shift = 0, b_shift = 0;
     ConstantInterval ca = constant_integer_bounds(a);
-    do {
+    while (true) {
         ConstantInterval bigger = ca * 2;
         if (a.type().can_represent(bigger) && a_shift + b_shift < missing_q) {
             ca = bigger;
             a_shift++;
-            continue;
+        } else {
+            break;
         }
-    } while (false);
+    }
     ConstantInterval cb = constant_integer_bounds(b);
-    do {
+    while (true) {
         ConstantInterval bigger = cb * 2;
         if (b.type().can_represent(bigger) && a_shift + b_shift < missing_q) {
             cb = bigger;
             b_shift++;
-            continue;
+        } else {
+            break;
         }
-    } while (false);
+    }
     if (a_shift + b_shift == missing_q) {
         return rounding_mul_shift_right(simplify(a << a_shift), simplify(b << b_shift), full_q);
     }

From 7414ee66d7eff79d877eeba682120967f44981e6 Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Mon, 3 Jun 2024 13:37:04 -0700
Subject: [PATCH 31/33] Delete commented-out code

---
 src/FindIntrinsics.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 6f2407fdd4c5..b72122460706 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -552,7 +552,6 @@ class FindIntrinsics : public IRMutator {
         auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
         auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
         auto x_y_same_sign = (is_int(x) && is_int(y)) || (is_uint(x) && is_uint(y));
-        // auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2);
         if (
             // Saturating patterns
             rewrite(max(min(widening_add(x, y), upper), lower),

From c33dbfbb08225661279c31745fd33ac5b73f2011 Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Tue, 4 Jun 2024 16:01:15 -0700
Subject: [PATCH 32/33] Don't introduce out-of-range shifts in lossless_cast

---
 src/IROperator.cpp                 | 12 +++++++--
 test/correctness/lossless_cast.cpp | 39 ++++++++++++++++++++++++++++--
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/src/IROperator.cpp b/src/IROperator.cpp
index a9ce3fdc59b3..05dc329fe7ef 100644
--- a/src/IROperator.cpp
+++ b/src/IROperator.cpp
@@ -535,11 +535,19 @@ Expr lossless_cast(Type t, Expr e, std::map
                 if (a.defined() && b.defined()) {
                     return Mul::make(a, b);
                 }
-            } else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left})) {
+            } else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left,
+                                                               Call::shift_right, Call::widening_shift_right})) {
                 Expr a = lossless_cast(t, op->args[0], cache);
                 Expr b = lossless_cast(t, op->args[1], cache);
                 if (a.defined() && b.defined()) {
-                    return a << b;
+                    ConstantInterval cb = constant_integer_bounds(b, Scope<ConstantInterval>::empty_scope(), cache);
+                    if (cb > -t.bits() && cb < t.bits()) {
+                        if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
+                            return a << b;
+                        } else if (op->is_intrinsic({Call::shift_right, Call::widening_shift_right})) {
+                            return a >> b;
+                        }
+                    }
                 }
             } else if (const VectorReduce *op = e.as<VectorReduce>()) {
                 if (op->op == VectorReduce::Add ||
diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp
index 58d0c9c6ddf3..f9fb492a8e7a 100644
--- a/test/correctness/lossless_cast.cpp
+++ b/test/correctness/lossless_cast.cpp
@@ -231,6 +231,25 @@ Expr random_expr(std::mt19937 &rng) {
     }
 }
 
+bool definitely_has_ub(Expr e) {
+    e = simplify(e);
+
+    class HasOverflow : public IRVisitor {
+        void visit(const Call *op) override {
+            if (op->is_intrinsic({Call::signed_integer_overflow})) {
+                found = true;
+            }
+            IRVisitor::visit(op);
+        }
+
+    public:
+        bool found = false;
+    } has_overflow;
+    e.accept(&has_overflow);
+
+    return has_overflow.found;
+}
+
 bool might_have_ub(Expr e) {
     class MightOverflow : public IRVisitor {
         std::map cache;
@@ -328,7 +347,7 @@ int test_one(uint32_t seed) {
 
     Expr e1 = random_expr(rng);
 
-    if (might_have_ub(e1)) {
+    if (might_have_ub(e1) || might_have_ub(simplify(e1))) {
         return 0;
     }
 
@@ -344,12 +363,26 @@ int test_one(uint32_t seed) {
         return 0;
     }
 
+    if (definitely_has_ub(e2)) {
+        std::cout << "lossless_cast introduced ub:\n"
+                  << "seed = " << seed << "\n"
+                  << "e1 = " << e1 << "\n"
+                  << "e2 = " << e2 << "\n"
+                  << "simplify(e1) = " << simplify(e1) << "\n"
+                  << "simplify(e2) = " << simplify(e2) << "\n";
+        return 1;
+    }
+
     Func f;
     f(x) = {cast(e1), cast(e2)};
     f.vectorize(x, 4, TailStrategy::RoundUp);
 
     Buffer out1(size), out2(size);
     Pipeline p(f);
+
+    // Check for signed integer overflow
+    // Module m = p.compile_to_module({}, "test");
+
     p.realize({out1, out2});
 
     for (int x = 0; x < size; x++) {
@@ -397,7 +430,9 @@ int fuzz_test(uint32_t root_seed) {
     std::cout << "Fuzz testing with root seed " << root_seed << "\n";
 
     for (int i = 0; i < 1000; i++) {
-        if (test_one(seed_generator())) {
+        auto s = seed_generator();
+        std::cout << s << "\n";
+        if (test_one(s)) {
             return 1;
         }
     }

From 0b561c7dcbedc8cd4d7e1285f6894643d7500087 Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Mon, 10 Jun 2024 09:58:11 -0700
Subject: [PATCH 33/33] Some constant folding only happens after lowering intrinsics in codegen

---
 test/correctness/lossless_cast.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp
index f9fb492a8e7a..22d3506d7859 100644
--- a/test/correctness/lossless_cast.cpp
+++ b/test/correctness/lossless_cast.cpp
@@ -346,8 +346,11 @@ int test_one(uint32_t seed) {
     buf_i8.fill(rng);
 
     Expr e1 = random_expr(rng);
+    Expr simplified = simplify(e1);
 
-    if (might_have_ub(e1) || might_have_ub(simplify(e1))) {
+    if (might_have_ub(e1) ||
+        might_have_ub(simplified) ||
+        might_have_ub(lower_intrinsics(simplified))) {
         return 0;
     }
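
The patches above all hinge on one idea: compute conservative constant integer bounds for an expression, and only rebuild an operation in a narrower type when those bounds fit that type. The following is a rough standalone C++ sketch of that check, not Halide code: Bounds, add_bounds, and fits_int16 are invented names standing in for ConstantInterval, its interval arithmetic, and Type::can_represent.

#include <cstdint>
#include <limits>

// Toy interval: conservative bounds on every value an expression can take.
struct Bounds {
    int64_t min, max;
};

// Interval arithmetic for addition (no risk of int64 overflow in this sketch).
Bounds add_bounds(Bounds a, Bounds b) {
    return {a.min + b.min, a.max + b.max};
}

// The "does the whole interval fit?" test for a 16-bit signed target type.
bool fits_int16(Bounds b) {
    return b.min >= std::numeric_limits<int16_t>::min() &&
           b.max <= std::numeric_limits<int16_t>::max();
}

int main() {
    Bounds u8{0, 255};                // e.g. a uint8 load, widened
    Bounds sum = add_bounds(u8, u8);  // [0, 510]
    // 510 fits in int16, so a 32-bit "u8 + u8" could be done as a 16-bit add
    // without changing any result; it could not be narrowed to 8 bits.
    return fits_int16(sum) ? 0 : 1;
}

The real analysis works recursively over Halide IR and caches its results per expression, but the accept/reject decision is this same interval-containment test.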
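
[PATCH 32/33] tightens the shift cases specifically because shifting an N-bit integer by N or more positions, or by a negative amount, is undefined behavior in C++, so a narrowed shift is only kept when the constant bounds of the shift count stay inside the target type's bit width. Below is a minimal standalone sketch of that guard, again not Halide code; ShiftBounds and guarded_shl are invented names, and the check is simplified to non-negative shift counts.

#include <cstdint>
#include <optional>

// Invented stand-in for constant bounds on a shift count.
struct ShiftBounds {
    int64_t min, max;
};

// Emit x << s only when every possible count is a defined shift for a
// `bits`-wide type, i.e. 0 <= s < bits; otherwise decline rather than
// introduce undefined behavior.
std::optional<int32_t> guarded_shl(int32_t x, int64_t s, ShiftBounds sb, int bits = 32) {
    if (sb.min >= 0 && sb.max < bits) {
        return x << s;
    }
    return std::nullopt;
}

int main() {
    auto ok = guarded_shl(5, 3, ShiftBounds{3, 3});      // 40: count provably in [0, 32)
    auto bad = guarded_shl(5, 40, ShiftBounds{31, 40});  // empty: a count of 40 would be UB
    return (ok.has_value() && !bad.has_value()) ? 0 : 1;
}

This mirrors the shape of the new code path: bound the shift count first, and when the bound cannot be proven, fall through to the plain cast-the-whole-expression fallback instead of materializing the shift in the narrow type.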