diff --git a/Makefile b/Makefile index 17e8a80e1ca4..440b307a920e 100644 --- a/Makefile +++ b/Makefile @@ -477,6 +477,8 @@ SOURCE_FILES = \ CodeGen_WebGPU_Dev.cpp \ CodeGen_X86.cpp \ CompilerLogger.cpp \ + ConstantBounds.cpp \ + ConstantInterval.cpp \ CPlusPlusMangle.cpp \ CSE.cpp \ Debug.cpp \ @@ -671,6 +673,8 @@ HEADER_FILES = \ CompilerLogger.h \ ConciseCasts.h \ CPlusPlusMangle.h \ + ConstantBounds.h \ + ConstantInterval.h \ CSE.h \ Debug.h \ DebugArguments.h \ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 557574f284c4..2f410244d2b0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -45,6 +45,8 @@ set(HEADER_FILES CompilerLogger.h ConciseCasts.h CPlusPlusMangle.h + ConstantBounds.h + ConstantInterval.h CSE.h Debug.h DebugArguments.h @@ -219,6 +221,8 @@ set(SOURCE_FILES CodeGen_X86.cpp CompilerLogger.cpp CPlusPlusMangle.cpp + ConstantBounds.cpp + ConstantInterval.cpp CSE.cpp Debug.cpp DebugArguments.cpp diff --git a/src/ConstantBounds.cpp b/src/ConstantBounds.cpp new file mode 100644 index 000000000000..11d1a42133a9 --- /dev/null +++ b/src/ConstantBounds.cpp @@ -0,0 +1,170 @@ +#include "ConstantBounds.h" +#include "IR.h" +#include "IROperator.h" +#include "IRPrinter.h" + +namespace Halide { +namespace Internal { + +namespace { +ConstantInterval bounds_helper(const Expr &e, + Scope &scope, + std::map *cache) { + internal_assert(e.defined()); + + auto recurse = [&](const Expr &e) { + return bounds_helper(e, scope, cache); + }; + + auto get_infinite_bounds = [&]() { + // Compute the bounds of each IR node from the bounds of its args. Math + // on ConstantInterval is in terms of infinite integers. + if (const UIntImm *op = e.as()) { + if (Int(64).can_represent(op->value)) { + return ConstantInterval::single_point((int64_t)(op->value)); + } + } else if (const IntImm *op = e.as()) { + return ConstantInterval::single_point(op->value); + } else if (const Variable *op = e.as()) { + if (const auto *in = scope.find(op->name)) { + return *in; + } + } else if (const Add *op = e.as()) { + return recurse(op->a) + recurse(op->b); + } else if (const Sub *op = e.as()) { + return recurse(op->a) - recurse(op->b); + } else if (const Mul *op = e.as()) { + return recurse(op->a) * recurse(op->b); + } else if (const Div *op = e.as
()) { + return recurse(op->a) / recurse(op->b); + } else if (const Mod *op = e.as()) { + return recurse(op->a) % recurse(op->b); + } else if (const Min *op = e.as()) { + return min(recurse(op->a), recurse(op->b)); + } else if (const Max *op = e.as()) { + return max(recurse(op->a), recurse(op->b)); + } else if (const Cast *op = e.as()) { + return recurse(op->value); + } else if (const Broadcast *op = e.as()) { + return recurse(op->value); + } else if (const VectorReduce *op = e.as()) { + int f = op->value.type().lanes() / op->type.lanes(); + ConstantInterval factor(f, f); + ConstantInterval arg_bounds = recurse(op->value); + switch (op->op) { + case VectorReduce::Add: + return arg_bounds * factor; + case VectorReduce::SaturatingAdd: + return saturating_cast(op->type, arg_bounds * factor); + case VectorReduce::Min: + case VectorReduce::Max: + case VectorReduce::And: + case VectorReduce::Or: + return arg_bounds; + default:; + } + } else if (const Shuffle *op = e.as()) { + ConstantInterval arg_bounds = recurse(op->vectors[0]); + for (size_t i = 1; i < op->vectors.size(); i++) { + arg_bounds.include(recurse(op->vectors[i])); + } + return arg_bounds; + } else if (const Let *op = e.as()) { + ScopedBinding bind(scope, op->name, recurse(op->value)); + return recurse(op->body); + } else if (const Call *op = e.as()) { + ConstantInterval result; + if (op->is_intrinsic(Call::abs)) { + return abs(recurse(op->args[0])); + } else if (op->is_intrinsic(Call::absd)) { + return abs(recurse(op->args[0]) - recurse(op->args[1])); + } else if (op->is_intrinsic(Call::count_leading_zeros) || + op->is_intrinsic(Call::count_trailing_zeros)) { + // Conservatively just say it's the potential number of zeros in the type. + return ConstantInterval(0, op->args[0].type().bits()); + } else if (op->is_intrinsic(Call::halving_add)) { + return (recurse(op->args[0]) + recurse(op->args[1])) / 2; + } else if (op->is_intrinsic(Call::halving_sub)) { + return (recurse(op->args[0]) - recurse(op->args[1])) / 2; + } else if (op->is_intrinsic(Call::rounding_halving_add)) { + return (recurse(op->args[0]) + recurse(op->args[1]) + 1) / 2; + } else if (op->is_intrinsic(Call::saturating_add)) { + return saturating_cast(op->type, + (recurse(op->args[0]) + + recurse(op->args[1]))); + } else if (op->is_intrinsic(Call::saturating_sub)) { + return saturating_cast(op->type, + (recurse(op->args[0]) - + recurse(op->args[1]))); + } else if (op->is_intrinsic({Call::widening_add, Call::widen_right_add})) { + return recurse(op->args[0]) + recurse(op->args[1]); + } else if (op->is_intrinsic({Call::widening_sub, Call::widen_right_sub})) { + return recurse(op->args[0]) - recurse(op->args[1]); + } else if (op->is_intrinsic({Call::widening_mul, Call::widen_right_mul})) { + return recurse(op->args[0]) * recurse(op->args[1]); + } else if (op->is_intrinsic({Call::shift_right, Call::widening_shift_right})) { + return recurse(op->args[0]) >> recurse(op->args[1]); + } else if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) { + return recurse(op->args[0]) << recurse(op->args[1]); + } else if (op->is_intrinsic({Call::rounding_shift_right, Call::rounding_shift_left})) { + ConstantInterval ca = recurse(op->args[0]); + ConstantInterval cb = recurse(op->args[1]); + if (op->is_intrinsic(Call::rounding_shift_left)) { + cb = -cb; + } + ConstantInterval rounding_term = 1 << (cb - 1); + // Note if cb is <= 0, rounding_term is zero. + return (ca + rounding_term) >> cb; + } else if (op->is_intrinsic(Call::mul_shift_right)) { + ConstantInterval ca = recurse(op->args[0]); + ConstantInterval cb = recurse(op->args[1]); + ConstantInterval cq = recurse(op->args[2]); + return (ca * cb) >> cq; + } else if (op->is_intrinsic(Call::rounding_mul_shift_right)) { + ConstantInterval ca = recurse(op->args[0]); + ConstantInterval cb = recurse(op->args[1]); + ConstantInterval cq = recurse(op->args[2]); + ConstantInterval rounding_term = 1 << (cq - 1); + return (ca * cb + rounding_term) >> cq; + } + // If you add a new intrinsic here, also add it to the expression + // generator in test/correctness/lossless_cast.cpp + } + + return ConstantInterval::bounds_of_type(e.type()); + }; + + auto get_typed_bounds = [&]() { + return cast(e.type(), get_infinite_bounds()); + }; + + ConstantInterval ret; + if (cache) { + auto [it, cache_miss] = cache->try_emplace(e); + if (cache_miss) { + it->second = get_typed_bounds(); + } + ret = it->second; + } else { + ret = get_typed_bounds(); + } + + internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) && + (!ret.max_defined || e.type().can_represent(ret.max))) + << "constant_bounds returned defined bounds that are not representable in " + << "the type of the Expr passed in.\n Expr: " << e << "\n Bounds: " << ret; + + return ret; +} +} // namespace + +ConstantInterval constant_integer_bounds(const Expr &e, + const Scope &scope, + std::map *cache) { + Scope sub_scope; + sub_scope.set_containing_scope(&scope); + return bounds_helper(e, sub_scope, cache); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/ConstantBounds.h b/src/ConstantBounds.h new file mode 100644 index 000000000000..26ab114455bb --- /dev/null +++ b/src/ConstantBounds.h @@ -0,0 +1,35 @@ +#ifndef HALIDE_CONSTANT_BOUNDS_H +#define HALIDE_CONSTANT_BOUNDS_H + +#include "ConstantInterval.h" +#include "Expr.h" +#include "Scope.h" + +/** \file + * Methods for computing compile-time constant int64_t upper and lower bounds of + * an expression. Cheaper than symbolic bounds inference, and useful for things + * like instruction selection. + */ + +namespace Halide { +namespace Internal { + +/** Deduce constant integer bounds on an expression. This can be useful to + * decide if, for example, the expression can be cast to another type, be + * negated, be incremented, etc without risking overflow. + * + * Also optionally accepts a scope containing the integer bounds of any + * variables that may be referenced, and a cache of constant integer bounds on + * known Exprs, which this function will update. The cache is helpful to + * short-circuit large numbers of redundant queries, but it should not be used + * in contexts where the same Expr object may take on different values within a + * single Expr (i.e. before uniquify_variable_names). + */ +ConstantInterval constant_integer_bounds(const Expr &e, + const Scope &scope = Scope::empty_scope(), + std::map *cache = nullptr); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/ConstantInterval.cpp b/src/ConstantInterval.cpp new file mode 100644 index 000000000000..8bdeec69325c --- /dev/null +++ b/src/ConstantInterval.cpp @@ -0,0 +1,680 @@ +#include "ConstantInterval.h" + +#include "Error.h" +#include "IROperator.h" +#include "IRPrinter.h" + +namespace Halide { +namespace Internal { + +ConstantInterval::ConstantInterval() = default; + +ConstantInterval::ConstantInterval(int64_t min, int64_t max) + : min(min), max(max), min_defined(true), max_defined(true) { + internal_assert(min <= max); +} + +ConstantInterval ConstantInterval::everything() { + return ConstantInterval(); +} + +ConstantInterval ConstantInterval::single_point(int64_t x) { + return ConstantInterval(x, x); +} + +ConstantInterval ConstantInterval::bounded_below(int64_t min) { + ConstantInterval result(min, min); + result.max_defined = false; + result.max = 0; + return result; +} + +ConstantInterval ConstantInterval::bounded_above(int64_t max) { + ConstantInterval result(max, max); + result.min_defined = false; + result.min = 0; + return result; +} + +bool ConstantInterval::is_everything() const { + return !min_defined && !max_defined; +} + +bool ConstantInterval::is_single_point() const { + return min_defined && max_defined && min == max; +} + +bool ConstantInterval::is_single_point(int64_t x) const { + return min_defined && max_defined && min == x && max == x; +} + +bool ConstantInterval::is_bounded() const { + return max_defined && min_defined; +} + +bool ConstantInterval::operator==(const ConstantInterval &other) const { + if (min_defined != other.min_defined || max_defined != other.max_defined) { + return false; + } + return (!min_defined || min == other.min) && (!max_defined || max == other.max); +} + +void ConstantInterval::include(const ConstantInterval &i) { + if (max_defined && i.max_defined) { + max = std::max(max, i.max); + } else { + max_defined = false; + } + if (min_defined && i.min_defined) { + min = std::min(min, i.min); + } else { + min_defined = false; + } +} + +void ConstantInterval::include(int64_t x) { + if (max_defined) { + max = std::max(max, x); + } + if (min_defined) { + min = std::min(min, x); + } +} + +bool ConstantInterval::contains(int64_t x) const { + const bool too_small = min_defined && x < min; + const bool too_large = max_defined && x > max; + return !(too_small || too_large); +} + +bool ConstantInterval::contains(int32_t x) const { + return contains((int64_t)x); +} + +bool ConstantInterval::contains(uint64_t x) const { + if (x <= (uint64_t)std::numeric_limits::max()) { + // Representable as an int64_t, so just defer to that method. + return contains((int64_t)x); + } else { + // This uint64_t is not representable as an int64_t, which means it's + // greater than 2^32 - 1. Given that we can't represent that as a bound, + // the best we can do is checking if the interval is unbounded above. + return !max_defined; + } +} + +ConstantInterval ConstantInterval::make_union(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result = a; + result.include(b); + return result; +} + +ConstantInterval ConstantInterval::make_intersection(const ConstantInterval &a, + const ConstantInterval &b) { + ConstantInterval result; + if (a.min_defined) { + if (b.min_defined) { + result.min = std::max(a.min, b.min); + } else { + result.min = a.min; + } + result.min_defined = true; + } else { + result.min_defined = b.min_defined; + result.min = b.min; + } + if (a.max_defined) { + if (b.max_defined) { + result.max = std::min(a.max, b.max); + } else { + result.max = a.max; + } + result.max_defined = true; + } else { + result.max_defined = b.max_defined; + result.max = b.max; + } + // Our class invariant is that whenever they're both defined, min <= + // max. Intersection is the only method that could break that, and it + // happens when the intersected intervals do not overlap. + internal_assert(!result.is_bounded() || result.min <= result.max) + << "Empty ConstantInterval constructed in make_intersection"; + return result; +} + +void ConstantInterval::operator+=(const ConstantInterval &other) { + (*this) = (*this) + other; +} + +void ConstantInterval::operator-=(const ConstantInterval &other) { + (*this) = (*this) - other; +} + +void ConstantInterval::operator*=(const ConstantInterval &other) { + (*this) = (*this) * other; +} + +void ConstantInterval::operator/=(const ConstantInterval &other) { + (*this) = (*this) / other; +} + +void ConstantInterval::operator%=(const ConstantInterval &other) { + (*this) = (*this) % other; +} + +void ConstantInterval::operator+=(int64_t x) { + (*this) = (*this) + x; +} + +void ConstantInterval::operator-=(int64_t x) { + (*this) = (*this) - x; +} + +void ConstantInterval::operator*=(int64_t x) { + (*this) = (*this) * x; +} + +void ConstantInterval::operator/=(int64_t x) { + (*this) = (*this) / x; +} + +void ConstantInterval::operator%=(int64_t x) { + (*this) = (*this) % x; +} + +bool operator<=(const ConstantInterval &a, const ConstantInterval &b) { + return a.max_defined && b.min_defined && a.max <= b.min; +} +bool operator<(const ConstantInterval &a, const ConstantInterval &b) { + return a.max_defined && b.min_defined && a.max < b.min; +} + +bool operator<=(const ConstantInterval &a, int64_t b) { + return a.max_defined && a.max <= b; +} +bool operator<(const ConstantInterval &a, int64_t b) { + return a.max_defined && a.max < b; +} + +bool operator<=(int64_t a, const ConstantInterval &b) { + return b.min_defined && a <= b.min; +} +bool operator<(int64_t a, const ConstantInterval &b) { + return b.min_defined && a < b.min; +} + +void ConstantInterval::cast_to(const Type &t) { + if (!t.can_represent(*this)) { + // We have potential overflow or underflow, return the entire bounds of + // the type. + ConstantInterval type_bounds; + if (t.is_int()) { + if (t.bits() <= 64) { + type_bounds.min_defined = type_bounds.max_defined = true; + type_bounds.min = ((int64_t)(-1)) << (t.bits() - 1); + type_bounds.max = ~type_bounds.min; + } + } else if (t.is_uint()) { + type_bounds.min_defined = true; + type_bounds.min = 0; + if (t.bits() < 64) { + type_bounds.max_defined = true; + type_bounds.max = (((int64_t)(1)) << t.bits()) - 1; + } + } + // If it's not int or uint, we're setting this to a default-constructed + // ConstantInterval, which is everything. + *this = type_bounds; + } +} + +ConstantInterval ConstantInterval::operator-() const { + ConstantInterval result; + if (min_defined && min != INT64_MIN) { + result.max_defined = true; + result.max = -min; + } + if (max_defined) { + result.min_defined = true; + result.min = -max; + } + return result; +} + +ConstantInterval ConstantInterval::bounds_of_type(Type t) { + return cast(t, ConstantInterval::everything()); +} + +ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + result.min_defined = a.min_defined && + b.min_defined && + add_with_overflow(64, a.min, b.min, &result.min); + + result.max_defined = a.max_defined && + b.max_defined && + add_with_overflow(64, a.max, b.max, &result.max); + return result; +} + +ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + result.min_defined = a.min_defined && + b.max_defined && + sub_with_overflow(64, a.min, b.max, &result.min); + result.max_defined = a.max_defined && + b.min_defined && + sub_with_overflow(64, a.max, b.min, &result.max); + return result; +} + +ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + + result.min = INT64_MAX; + result.max = INT64_MIN; + + auto consider_case = [&](int64_t a, int64_t b) { + int64_t v = div_imp(a, b); + result.min = std::min(result.min, v); + result.max = std::max(result.max, v); + }; + + // Enumerate all possible values for the min and max and take the extreme values. + if (a.min_defined && b.min_defined && b.min != 0) { + consider_case(a.min, b.min); + } + + if (a.min_defined && b.max_defined && b.max != 0) { + consider_case(a.min, b.max); + } + + if (a.max_defined && b.max_defined && b.max != 0) { + consider_case(a.max, b.max); + } + + if (a.max_defined && b.min_defined && b.min != 0) { + consider_case(a.max, b.min); + } + + // Define an int64_t zero just to pacify std::min and std::max + constexpr int64_t zero = 0; + + const bool b_positive = b > 0; + const bool b_negative = b < 0; + if ((b_positive && !b.max_defined) || + (b_negative && !b.min_defined)) { + // Take limit as other -> +/- infinity + result.min = std::min(result.min, zero); + result.max = std::max(result.max, zero); + } + + result.min_defined = ((a.min_defined && b_positive) || + (a.max_defined && b_negative)); + result.max_defined = ((a.max_defined && b_positive) || + (a.min_defined && b_negative)); + + // That's as far as we can get knowing the sign of the + // denominator. For bounded numerators, we additionally know + // that div can't make anything larger in magnitude, so we can + // take the intersection with that. + if (a.is_bounded() && a.min != INT64_MIN) { + int64_t magnitude = std::max(a.max, -a.min); + if (result.min_defined) { + result.min = std::max(result.min, -magnitude); + } else { + result.min = -magnitude; + } + if (result.max_defined) { + result.max = std::min(result.max, magnitude); + } else { + result.max = magnitude; + } + result.min_defined = result.max_defined = true; + } + + // Finally we can deduce the sign if the numerator and denominator are + // non-positive or non-negative. + bool a_non_negative = a >= 0; + bool b_non_negative = b >= 0; + bool a_non_positive = a <= 0; + bool b_non_positive = b <= 0; + if ((a_non_negative && b_non_negative) || + (a_non_positive && b_non_positive)) { + if (result.min_defined) { + result.min = std::max(result.min, zero); + } else { + result.min_defined = true; + result.min = 0; + } + } else if ((a_non_negative && b_non_positive) || + (a_non_positive && b_non_negative)) { + if (result.max_defined) { + result.max = std::min(result.max, zero); + } else { + result.max_defined = true; + result.max = 0; + } + } + + // Normalize the values if it's undefined + if (!result.min_defined) { + result.min = 0; + } + if (!result.max_defined) { + result.max = 0; + } + + // Check the class invariant as a sanity check. + internal_assert(!result.is_bounded() || (result.min <= result.max)); + + return result; +} + +ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + + // Compute a possible extreme value of the product, either incorporating it + // into result.min / result.max, or setting the min/max defined flags if it + // overflows. + auto consider_case = [&](int64_t a, int64_t b) { + int64_t c; + if (mul_with_overflow(64, a, b, &c)) { + result.min = std::min(result.min, c); + result.max = std::max(result.max, c); + } else if ((a > 0) == (b > 0)) { + result.max_defined = false; + } else { + result.min_defined = false; + } + }; + + result.min_defined = result.max_defined = true; + result.min = INT64_MAX; + result.max = INT64_MIN; + if (a.min_defined && b.min_defined) { + consider_case(a.min, b.min); + } + if (a.min_defined && b.max_defined) { + consider_case(a.min, b.max); + } + if (a.max_defined && b.min_defined) { + consider_case(a.max, b.min); + } + if (a.max_defined && b.max_defined) { + consider_case(a.max, b.max); + } + + const bool a_bounded_negative = a.min_defined && a <= 0; + const bool a_bounded_positive = a.max_defined && a >= 0; + const bool b_bounded_negative = b.min_defined && b <= 0; + const bool b_bounded_positive = b.max_defined && b >= 0; + + if (result.min_defined) { + result.min_defined = + ((a.is_bounded() && b.is_bounded()) || + (a >= 0 && b >= 0) || + (a <= 0 && b <= 0) || + (a.min_defined && b_bounded_positive) || + (b.min_defined && a_bounded_positive) || + (a.max_defined && b_bounded_negative) || + (b.max_defined && a_bounded_negative)); + } + + if (result.max_defined) { + result.max_defined = + ((a.is_bounded() && b.is_bounded()) || + (a >= 0 && b <= 0) || + (a <= 0 && b >= 0) || + (a.max_defined && b_bounded_positive) || + (b.max_defined && a_bounded_positive) || + (a.min_defined && b_bounded_negative) || + (b.min_defined && a_bounded_negative)); + } + + if (!result.min_defined) { + result.min = 0; + } + + if (!result.max_defined) { + result.max = 0; + } + + // Check the class invariant as a sanity check. + internal_assert(!result.is_bounded() || (result.min <= result.max)); + + return result; +} + +ConstantInterval operator%(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + + // Maybe the mod won't actually do anything + if (a >= 0 && a < abs(b)) { + return a; + } + + // The result is at least zero. + result.min_defined = true; + result.min = 0; + + // Mod by produces a result between 0 + // and max(0, abs(modulus) - 1). However, if b is unbounded in + // either direction, abs(modulus) could be arbitrarily + // large. + if (b.is_bounded() && b.max != INT64_MIN) { + result.max_defined = true; + result.max = 0; // When b == 0 + result.max = std::max(result.max, b.max - 1); // When b > 0 + result.max = std::max(result.max, ~b.min); // When b < 0 + // Note that ~b.min is equal to (-1 - b.min). It's written as ~b.min to + // make it clear that it can't overflow. + } + + // If a is positive, mod can't make it larger + if (a.is_bounded() && a.min >= 0) { + if (result.max_defined) { + result.max = std::min(result.max, a.max); + } else { + result.max_defined = true; + result.max = a.max; + } + } + + // Check the class invariant as a sanity check. + internal_assert(!result.is_bounded() || (result.min <= result.max)); + + return result; +} + +ConstantInterval operator+(const ConstantInterval &a, int64_t b) { + return a + ConstantInterval::single_point(b); +} + +ConstantInterval operator-(const ConstantInterval &a, int64_t b) { + return a - ConstantInterval::single_point(b); +} + +ConstantInterval operator/(const ConstantInterval &a, int64_t b) { + return a / ConstantInterval::single_point(b); +} + +ConstantInterval operator*(const ConstantInterval &a, int64_t b) { + return a * ConstantInterval::single_point(b); +} + +ConstantInterval operator%(const ConstantInterval &a, int64_t b) { + return a % ConstantInterval::single_point(b); +} + +ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + result.max_defined = a.max_defined || b.max_defined; + result.min_defined = a.min_defined && b.min_defined; + if (a.max_defined && b.max_defined) { + result.max = std::min(a.max, b.max); + } else if (a.max_defined) { + result.max = a.max; + } else if (b.max_defined) { + result.max = b.max; + } + if (a.min_defined && b.min_defined) { + result.min = std::min(a.min, b.min); + } + return result; +} + +ConstantInterval min(const ConstantInterval &a, int64_t b) { + ConstantInterval result = a; + if (result.max_defined) { + result.max = std::min(a.max, b); + } else { + result.max = b; + result.max_defined = true; + } + if (result.min_defined) { + result.min = std::min(a.min, b); + } + return result; +} + +ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b) { + ConstantInterval result; + result.min_defined = a.min_defined || b.min_defined; + result.max_defined = a.max_defined && b.max_defined; + if (a.min_defined && b.min_defined) { + result.min = std::max(a.min, b.min); + } else if (a.min_defined) { + result.min = a.min; + } else if (b.min_defined) { + result.min = b.min; + } + if (a.max_defined && b.max_defined) { + result.max = std::max(a.max, b.max); + } + return result; +} + +ConstantInterval max(const ConstantInterval &a, int64_t b) { + ConstantInterval result = a; + if (result.min_defined) { + result.min = std::max(a.min, b); + } else { + result.min = b; + result.min_defined = true; + } + if (result.max_defined) { + result.max = std::max(a.max, b); + } + return result; +} + +ConstantInterval abs(const ConstantInterval &a) { + ConstantInterval result; + if (a.min_defined && a.max_defined && a.min != INT64_MIN) { + result.max_defined = true; + result.max = std::max(-a.min, a.max); + } + result.min_defined = true; + if (a.min_defined && a.min > 0) { + result.min = a.min; + } else if (a.max_defined && a.max < 0 && a.max != INT64_MIN) { + result.min = -a.max; + } else { + result.min = 0; + } + + return result; +} + +ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b) { + // In infinite integers (with no overflow): + + // a << b == a * 2^b + + // This can't be used directly, because if b is negative then 2^b is not an + // integer. Instead, we'll break b into a difference of two positive values: + // b = b_pos - b_neg + // So + // a * 2^b + // = a * 2^(b_pos - b_neg) + // = (a * 2^b_pos) / 2^b_neg + + // From there we can use the * and / operators. + + ConstantInterval b_pos = max(b, 0), b_neg = max(-b, 0); + + // At this point, we have sliced the interval b into two parts. E.g. + // if b = [10, 12], b_pos = [10, 12] and b_neg = [0, 0] + // if b = [-4, 8], b_pos = [0, 8] and b_neg = [0, 4] + // if b = [-10, -3], b_pos = [0, 0] and b_neg = [3, 10] + // if b = [-3, inf], b_pos = [0, inf] and b_neg = [0, 3] + // In all cases, note that b_pos - b_neg = b by our definition of operator- + // for ConstantIntervals above (ignoring corner cases, for which b_pos - + // b_neg safely over-approximates the bounds of b). + + auto two_to_the = [](const ConstantInterval &i) { + const int64_t one = 1; + ConstantInterval r; + // We should know i is positive at this point. + internal_assert(i.min_defined && i.min >= 0); + r.min_defined = true; + if (i.min >= 63) { + // It's at least a value too large for us to represent, which is not + // the same as min_defined = false. + r.min = INT64_MAX; + } else { + r.min = one << i.min; + } + if (i.max < 63) { + r.max_defined = true; + r.max = one << i.max; + } + return r; + }; + + return (a * two_to_the(b_pos)) / two_to_the(b_neg); +} + +ConstantInterval operator<<(const ConstantInterval &a, int64_t b) { + return a << ConstantInterval::single_point(b); +} + +ConstantInterval operator<<(int64_t a, const ConstantInterval &b) { + return ConstantInterval::single_point(a) << b; +} + +ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b) { + return a << (-b); +} + +ConstantInterval operator>>(const ConstantInterval &a, int64_t b) { + return a >> ConstantInterval::single_point(b); +} + +ConstantInterval operator>>(int64_t a, const ConstantInterval &b) { + return ConstantInterval::single_point(a) >> b; +} + +} // namespace Internal + +using namespace Internal; + +ConstantInterval cast(Type t, const ConstantInterval &a) { + ConstantInterval result = a; + result.cast_to(t); + return result; +} + +ConstantInterval saturating_cast(Type t, const ConstantInterval &a) { + ConstantInterval b = ConstantInterval::bounds_of_type(t); + if (a >= b) { + return ConstantInterval::single_point(b.max); + } else if (a <= b) { + return ConstantInterval::single_point(b.min); + } else { + return ConstantInterval::make_intersection(a, b); + } +} + +} // namespace Halide diff --git a/src/ConstantInterval.h b/src/ConstantInterval.h new file mode 100644 index 000000000000..daa6f0f4dbe0 --- /dev/null +++ b/src/ConstantInterval.h @@ -0,0 +1,176 @@ +#ifndef HALIDE_CONSTANT_INTERVAL_H +#define HALIDE_CONSTANT_INTERVAL_H + +#include + +/** \file + * Defines the ConstantInterval class, and operators on it. + */ + +namespace Halide { + +struct Type; + +namespace Internal { + +/** A class to represent ranges of integers. Can be unbounded above or below, + * but they cannot be empty. */ +struct ConstantInterval { + /** The lower and upper bound of the interval. They are included + * in the interval. */ + int64_t min = 0, max = 0; + bool min_defined = false, max_defined = false; + + /* A default-constructed Interval is everything */ + ConstantInterval(); + + /** Construct an interval from a lower and upper bound. */ + ConstantInterval(int64_t min, int64_t max); + + /** The interval representing everything. */ + static ConstantInterval everything(); + + /** Construct an interval representing a single point. */ + static ConstantInterval single_point(int64_t x); + + /** Construct intervals bounded above or below. */ + static ConstantInterval bounded_below(int64_t min); + static ConstantInterval bounded_above(int64_t max); + + /** Is the interval the entire range */ + bool is_everything() const; + + /** Is the interval just a single value (min == max) */ + bool is_single_point() const; + + /** Is the interval a particular single value */ + bool is_single_point(int64_t x) const; + + /** Does the interval have a finite upper and lower bound */ + bool is_bounded() const; + + /** Expand the interval to include another Interval */ + void include(const ConstantInterval &i); + + /** Expand the interval to include a point */ + void include(int64_t x); + + /** Test if the interval contains a particular value */ + bool contains(int32_t x) const; + + /** Test if the interval contains a particular value */ + bool contains(int64_t x) const; + + /** Test if the interval contains a particular unsigned value */ + bool contains(uint64_t x) const; + + /** Construct the smallest interval containing two intervals. */ + static ConstantInterval make_union(const ConstantInterval &a, const ConstantInterval &b); + + /** Construct the largest interval contained within two intervals. Throws an + * error if the interval is empty. */ + static ConstantInterval make_intersection(const ConstantInterval &a, const ConstantInterval &b); + + /** Equivalent to same_as. Exists so that the autoscheduler can + * compare two map for equality in order to + * cache computations. */ + bool operator==(const ConstantInterval &other) const; + + /** In-place versions of the arithmetic operators below. */ + // @{ + void operator+=(const ConstantInterval &other); + void operator+=(int64_t); + void operator-=(const ConstantInterval &other); + void operator-=(int64_t); + void operator*=(const ConstantInterval &other); + void operator*=(int64_t); + void operator/=(const ConstantInterval &other); + void operator/=(int64_t); + void operator%=(const ConstantInterval &other); + void operator%=(int64_t); + // @} + + /** Negate an interval. */ + ConstantInterval operator-() const; + + /** Track what happens if a constant integer interval is forced to fit into + * a concrete integer type. */ + void cast_to(const Type &t); + + /** Get constant integer bounds on a type. */ + static ConstantInterval bounds_of_type(Type); +}; + +/** Arithmetic operators on ConstantIntervals. The resulting interval contains + * all possible values of the operator applied to any two elements of the + * argument intervals. Note that these operator on unbounded integers. If you + * are applying this to concrete small integer types, you will need to manually + * cast the constant interval back to the desired type to model the effect of + * overflow. */ +// @{ +ConstantInterval operator+(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator+(const ConstantInterval &a, int64_t b); +ConstantInterval operator-(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator-(const ConstantInterval &a, int64_t b); +ConstantInterval operator/(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator/(const ConstantInterval &a, int64_t b); +ConstantInterval operator*(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator*(const ConstantInterval &a, int64_t b); +ConstantInterval operator%(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator%(const ConstantInterval &a, int64_t b); +ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval min(const ConstantInterval &a, int64_t b); +ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval max(const ConstantInterval &a, int64_t b); +ConstantInterval abs(const ConstantInterval &a); +ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator<<(const ConstantInterval &a, int64_t b); +ConstantInterval operator<<(int64_t a, const ConstantInterval &b); +ConstantInterval operator>>(const ConstantInterval &a, const ConstantInterval &b); +ConstantInterval operator>>(const ConstantInterval &a, int64_t b); +ConstantInterval operator>>(int64_t a, const ConstantInterval &b); +// @} + +/** Comparison operators on ConstantIntervals. Returns whether the comparison is + * true for all values of the two intervals. */ +// @{ +bool operator<=(const ConstantInterval &a, const ConstantInterval &b); +bool operator<=(const ConstantInterval &a, int64_t b); +bool operator<=(int64_t a, const ConstantInterval &b); +bool operator<(const ConstantInterval &a, const ConstantInterval &b); +bool operator<(const ConstantInterval &a, int64_t b); +bool operator<(int64_t a, const ConstantInterval &b); + +inline bool operator>=(const ConstantInterval &a, const ConstantInterval &b) { + return b <= a; +} +inline bool operator>(const ConstantInterval &a, const ConstantInterval &b) { + return b < a; +} +inline bool operator>=(const ConstantInterval &a, int64_t b) { + return b <= a; +} +inline bool operator>(const ConstantInterval &a, int64_t b) { + return b < a; +} +inline bool operator>=(int64_t a, const ConstantInterval &b) { + return b <= a; +} +inline bool operator>(int64_t a, const ConstantInterval &b) { + return b < a; +} + +// @} +} // namespace Internal + +/** Cast operators for ConstantIntervals. These ones have to live out in + * Halide::, to avoid C++ name lookup confusion with the Halide::cast variants + * that take Exprs. */ +// @{ +Internal::ConstantInterval cast(Type t, const Internal::ConstantInterval &a); +Internal::ConstantInterval saturating_cast(Type t, const Internal::ConstantInterval &a); +// @} + +} // namespace Halide + +#endif diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index d7b053981ac8..793234c8b3ff 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -935,9 +935,12 @@ class FindIntrinsics : public IRMutator { Expr b_narrow = lossless_narrow(op->args[1]); if (a_narrow.defined() && b_narrow.defined()) { Expr result; - if (op->is_intrinsic(Call::rounding_shift_right) && can_prove(b_narrow > 0)) { + if (op->is_intrinsic(Call::rounding_shift_right) && + can_prove(b_narrow > 0 && b_narrow < a_narrow.type().bits())) { result = rounding_shift_right(a_narrow, b_narrow); - } else if (op->is_intrinsic(Call::rounding_shift_left) && can_prove(b_narrow < 0)) { + } else if (op->is_intrinsic(Call::rounding_shift_left) && + b_narrow.type().is_int() && + can_prove(b_narrow < 0 && b_narrow > -a_narrow.type().bits())) { result = rounding_shift_left(a_narrow, b_narrow); } else { return op; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 0f318f777561..3492c9e828c3 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -502,7 +502,7 @@ Expr lossless_cast(Type t, Expr e) { Expr a = lossless_cast(t.narrow(), sub->a); Expr b = lossless_cast(t.narrow(), sub->b); if (a.defined() && b.defined()) { - return cast(t, a) + cast(t, b); + return cast(t, a) - cast(t, b); } else { return Expr(); } @@ -567,7 +567,8 @@ Expr lossless_cast(Type t, Expr e) { } Expr lossless_negate(const Expr &x) { - if (const Mul *m = x.as()) { + if (false /* const Mul *m = x.as() */) { // disabled pending #8155 + /* Expr b = lossless_negate(m->b); if (b.defined()) { return Mul::make(m->a, b); @@ -576,6 +577,7 @@ Expr lossless_negate(const Expr &x) { if (a.defined()) { return Mul::make(a, m->b); } + */ } else if (const Call *m = Call::as_intrinsic(x, {Call::widening_mul})) { Expr b = lossless_negate(m->args[1]); if (b.defined()) { diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index a186be1874d7..0c29b5c8a749 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -6,8 +6,11 @@ #include "AssociativeOpsTable.h" #include "Associativity.h" #include "Closure.h" +#include "ConstantInterval.h" #include "IROperator.h" +#include "Interval.h" #include "Module.h" +#include "ModulusRemainder.h" #include "Target.h" #include "Util.h" @@ -446,6 +449,45 @@ std::ostream &operator<<(std::ostream &out, const Closure &c) { return out; } +std::ostream &operator<<(std::ostream &out, const Interval &in) { + out << "["; + if (in.has_lower_bound()) { + out << in.min; + } else { + out << "-inf"; + } + out << ", "; + if (in.has_upper_bound()) { + out << in.max; + } else { + out << "inf"; + } + out << "]"; + return out; +} + +std::ostream &operator<<(std::ostream &out, const ConstantInterval &in) { + out << "["; + if (in.min_defined) { + out << in.min; + } else { + out << "-inf"; + } + out << ", "; + if (in.max_defined) { + out << in.max; + } else { + out << "inf"; + } + out << "]"; + return out; +} + +std::ostream &operator<<(std::ostream &out, const ModulusRemainder &c) { + out << "(mod: " << c.modulus << " rem: " << c.remainder << ")"; + return out; +} + IRPrinter::IRPrinter(ostream &s) : stream(s) { s.setf(std::ios::fixed, std::ios::floatfield); diff --git a/src/IRPrinter.h b/src/IRPrinter.h index 849e50b816f4..6addbbd7c771 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -58,6 +58,9 @@ namespace Internal { struct AssociativePattern; struct AssociativeOp; class Closure; +struct Interval; +struct ConstantInterval; +struct ModulusRemainder; /** Emit a halide associative pattern on an output stream (such as std::cout) * in a human-readable form */ @@ -90,9 +93,18 @@ std::ostream &operator<<(std::ostream &stream, const LinkageType &); /** Emit a halide dimension type in human-readable format */ std::ostream &operator<<(std::ostream &stream, const DimType &); -/** Emit a Closure in human-readable format */ +/** Emit a Closure in human-readable form */ std::ostream &operator<<(std::ostream &out, const Closure &c); +/** Emit an Interval in human-readable form */ +std::ostream &operator<<(std::ostream &out, const Interval &c); + +/** Emit a ConstantInterval in human-readable form */ +std::ostream &operator<<(std::ostream &out, const ConstantInterval &c); + +/** Emit a ModulusRemainder in human-readable form */ +std::ostream &operator<<(std::ostream &out, const ModulusRemainder &c); + struct Indentation { int indent; }; diff --git a/src/Interval.cpp b/src/Interval.cpp index 10550f7ed48b..7d0cc41d44b9 100644 --- a/src/Interval.cpp +++ b/src/Interval.cpp @@ -3,6 +3,8 @@ #include "IRMatch.h" #include "IROperator.h" +using namespace Halide::Internal; + namespace Halide { namespace Internal { @@ -157,91 +159,5 @@ Expr Interval::neg_inf_noinline() { return Interval::neg_inf_expr; } -ConstantInterval::ConstantInterval() = default; - -ConstantInterval::ConstantInterval(int64_t min, int64_t max) - : min(min), max(max), min_defined(true), max_defined(true) { - internal_assert(min <= max); -} - -ConstantInterval ConstantInterval::everything() { - return ConstantInterval(); -} - -ConstantInterval ConstantInterval::single_point(int64_t x) { - return ConstantInterval(x, x); -} - -ConstantInterval ConstantInterval::bounded_below(int64_t min) { - ConstantInterval result(min, min); - result.max_defined = false; - return result; -} - -ConstantInterval ConstantInterval::bounded_above(int64_t max) { - ConstantInterval result(max, max); - result.min_defined = false; - return result; -} - -bool ConstantInterval::is_everything() const { - return !min_defined && !max_defined; -} - -bool ConstantInterval::is_single_point() const { - return min_defined && max_defined && min == max; -} - -bool ConstantInterval::is_single_point(int64_t x) const { - return min_defined && max_defined && min == x && max == x; -} - -bool ConstantInterval::has_upper_bound() const { - return max_defined; -} - -bool ConstantInterval::has_lower_bound() const { - return min_defined; -} - -bool ConstantInterval::is_bounded() const { - return has_upper_bound() && has_lower_bound(); -} - -bool ConstantInterval::operator==(const ConstantInterval &other) const { - if (min_defined != other.min_defined || max_defined != other.max_defined) { - return false; - } - return (!min_defined || min == other.min) && (!max_defined || max == other.max); -} - -void ConstantInterval::include(const ConstantInterval &i) { - if (max_defined && i.max_defined) { - max = std::max(max, i.max); - } else { - max_defined = false; - } - if (min_defined && i.min_defined) { - min = std::min(min, i.min); - } else { - min_defined = false; - } -} - -void ConstantInterval::include(int64_t x) { - if (max_defined) { - max = std::max(max, x); - } - if (min_defined) { - min = std::min(min, x); - } -} - -ConstantInterval ConstantInterval::make_union(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result = a; - result.include(b); - return result; -} - } // namespace Internal } // namespace Halide diff --git a/src/Interval.h b/src/Interval.h index 1d90d4a29b55..ccd27341f167 100644 --- a/src/Interval.h +++ b/src/Interval.h @@ -87,7 +87,7 @@ struct Interval { /** Construct the smallest interval containing two intervals. */ static Interval make_union(const Interval &a, const Interval &b); - /** Construct the largest interval contained within two intervals. */ + /** Construct the largest interval contained within two other intervals. */ static Interval make_intersection(const Interval &a, const Interval &b); /** An eagerly-simplifying max of two Exprs that respects infinities. */ @@ -110,63 +110,6 @@ struct Interval { static Expr neg_inf_noinline(); }; -/** A class to represent ranges of integers. Can be unbounded above or below, but - * they cannot be empty. */ -struct ConstantInterval { - /** The lower and upper bound of the interval. They are included - * in the interval. */ - int64_t min = 0, max = 0; - bool min_defined = false, max_defined = false; - - /* A default-constructed Interval is everything */ - ConstantInterval(); - - /** Construct an interval from a lower and upper bound. */ - ConstantInterval(int64_t min, int64_t max); - - /** The interval representing everything. */ - static ConstantInterval everything(); - - /** Construct an interval representing a single point. */ - static ConstantInterval single_point(int64_t x); - - /** Construct intervals bounded above or below. */ - static ConstantInterval bounded_below(int64_t min); - static ConstantInterval bounded_above(int64_t max); - - /** Is the interval the entire range */ - bool is_everything() const; - - /** Is the interval just a single value (min == max) */ - bool is_single_point() const; - - /** Is the interval a particular single value */ - bool is_single_point(int64_t x) const; - - /** Does the interval have a finite least upper bound */ - bool has_upper_bound() const; - - /** Does the interval have a finite greatest lower bound */ - bool has_lower_bound() const; - - /** Does the interval have a finite upper and lower bound */ - bool is_bounded() const; - - /** Expand the interval to include another Interval */ - void include(const ConstantInterval &i); - - /** Expand the interval to include a point */ - void include(int64_t x); - - /** Construct the smallest interval containing two intervals. */ - static ConstantInterval make_union(const ConstantInterval &a, const ConstantInterval &b); - - /** Equivalent to same_as. Exists so that the autoscheduler can - * compare two map for equality in order to - * cache computations. */ - bool operator==(const ConstantInterval &other) const; -}; - } // namespace Internal } // namespace Halide diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index fee151f00a22..1450faade800 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -1,9 +1,11 @@ #include "Monotonic.h" -#include "Bounds.h" +#include "ConstantBounds.h" #include "IROperator.h" +#include "IRPrinter.h" #include "IRVisitor.h" #include "Scope.h" #include "Simplify.h" +#include "SimplifyCorrelatedDifferences.h" #include "Substitute.h" namespace Halide { @@ -42,24 +44,8 @@ const int64_t *as_const_int_or_uint(const Expr &e) { return nullptr; } -bool is_constant(const ConstantInterval &a) { - return a.is_single_point(0); -} - -bool may_be_negative(const ConstantInterval &a) { - return !a.has_lower_bound() || a.min < 0; -} - -bool may_be_positive(const ConstantInterval &a) { - return !a.has_upper_bound() || a.max > 0; -} - -bool is_monotonic_increasing(const ConstantInterval &a) { - return !may_be_negative(a); -} - -bool is_monotonic_decreasing(const ConstantInterval &a) { - return !may_be_positive(a); +bool is_constant(const ConstantInterval &x) { + return x.is_single_point(0); } ConstantInterval to_interval(Monotonic m) { @@ -79,162 +65,20 @@ ConstantInterval to_interval(Monotonic m) { Monotonic to_monotonic(const ConstantInterval &x) { if (is_constant(x)) { return Monotonic::Constant; - } else if (is_monotonic_increasing(x)) { + } else if (x >= 0) { return Monotonic::Increasing; - } else if (is_monotonic_decreasing(x)) { + } else if (x <= 0) { return Monotonic::Decreasing; } else { return Monotonic::Unknown; } } -ConstantInterval unify(const ConstantInterval &a, const ConstantInterval &b) { - return ConstantInterval::make_union(a, b); -} - -ConstantInterval unify(const ConstantInterval &a, int64_t b) { - ConstantInterval result; - result.include(b); - return result; -} - -// Helpers for doing arithmetic on ConstantIntervals that avoid generating -// expressions of pos_inf/neg_inf. -ConstantInterval add(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result; - result.min_defined = a.has_lower_bound() && b.has_lower_bound(); - result.max_defined = a.has_upper_bound() && b.has_upper_bound(); - if (result.has_lower_bound()) { - result.min_defined = add_with_overflow(64, a.min, b.min, &result.min); - } - if (result.has_upper_bound()) { - result.max_defined = add_with_overflow(64, a.max, b.max, &result.max); - } - return result; -} - -ConstantInterval add(const ConstantInterval &a, int64_t b) { - return add(a, ConstantInterval(b, b)); -} - -ConstantInterval negate(const ConstantInterval &r) { - ConstantInterval result; - result.min_defined = r.has_upper_bound(); - if (result.min_defined) { - result.min_defined = sub_with_overflow(64, 0, r.max, &result.min); - } - result.max_defined = r.has_lower_bound(); - if (result.max_defined) { - result.max_defined = sub_with_overflow(64, 0, r.min, &result.max); - } - return result; -} - -ConstantInterval sub(const ConstantInterval &a, const ConstantInterval &b) { - ConstantInterval result; - result.min_defined = a.has_lower_bound() && b.has_lower_bound(); - result.max_defined = a.has_upper_bound() && b.has_upper_bound(); - if (result.has_lower_bound()) { - result.min_defined = sub_with_overflow(64, a.min, b.max, &result.min); - } - if (result.has_upper_bound()) { - result.max_defined = sub_with_overflow(64, a.max, b.min, &result.max); - } - return result; -} - -ConstantInterval sub(const ConstantInterval &a, int64_t b) { - return sub(a, ConstantInterval(b, b)); -} - -ConstantInterval multiply(const ConstantInterval &a, int64_t b) { - ConstantInterval result(a); - if (b < 0) { - result = negate(result); - b = -b; - } - if (result.has_lower_bound()) { - result.min *= b; - } - if (result.has_upper_bound()) { - result.max *= b; - } - return result; -} - -ConstantInterval multiply(const ConstantInterval &a, const Expr &b) { - if (const int64_t *bi = as_const_int_or_uint(b)) { - return multiply(a, *bi); - } - return ConstantInterval::everything(); -} - -ConstantInterval multiply(const ConstantInterval &a, const ConstantInterval &b) { - int64_t bounds[4]; - int64_t *bounds_begin = &bounds[0]; - int64_t *bounds_end = &bounds[0]; - bool no_overflow = true; - if (a.has_lower_bound() && b.has_lower_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.min, b.min, bounds_end++); - } - if (a.has_lower_bound() && b.has_upper_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.min, b.max, bounds_end++); - } - if (a.has_upper_bound() && b.has_lower_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.max, b.min, bounds_end++); - } - if (a.has_upper_bound() && b.has_upper_bound()) { - no_overflow = no_overflow && mul_with_overflow(64, a.max, b.max, bounds_end++); - } - if (no_overflow && (bounds_begin != bounds_end)) { - ConstantInterval result = { - *std::min_element(bounds_begin, bounds_end), - *std::max_element(bounds_begin, bounds_end), - }; - // There *must* be a better way than this... Even - // cutting half the cases with swapping isn't that much help. - if (!a.has_lower_bound()) { - if (may_be_negative(b)) result.max_defined = false; // NOLINT - if (may_be_positive(b)) result.min_defined = false; // NOLINT - } - if (!a.has_upper_bound()) { - if (may_be_negative(b)) result.min_defined = false; // NOLINT - if (may_be_positive(b)) result.max_defined = false; // NOLINT - } - if (!b.has_lower_bound()) { - if (may_be_negative(a)) result.max_defined = false; // NOLINT - if (may_be_positive(a)) result.min_defined = false; // NOLINT - } - if (!b.has_upper_bound()) { - if (may_be_negative(a)) result.min_defined = false; // NOLINT - if (may_be_positive(a)) result.max_defined = false; // NOLINT - } - return result; - } else { - return ConstantInterval::everything(); - } -} - -ConstantInterval divide(const ConstantInterval &a, int64_t b) { - ConstantInterval result(a); - if (b < 0) { - result = negate(result); - b = -b; - } - if (result.has_lower_bound()) { - result.min = div_imp(result.min, b); - } - if (result.has_upper_bound()) { - result.max = div_imp(result.max - 1, b) + 1; - } - return result; -} - class DerivativeBounds : public IRVisitor { const string &var; - Scope scope; - Scope bounds; + // Bounds on the derivatives and values of variables in scope. + Scope derivative_bounds, value_bounds; void visit(const IntImm *) override { result = ConstantInterval::single_point(0); @@ -280,7 +124,7 @@ class DerivativeBounds : public IRVisitor { void visit(const Variable *op) override { if (op->name == var) { result = ConstantInterval::single_point(1); - } else if (const auto *r = scope.find(op->name)) { + } else if (const auto *r = derivative_bounds.find(op->name)) { result = *r; } else { result = ConstantInterval::single_point(0); @@ -291,16 +135,14 @@ class DerivativeBounds : public IRVisitor { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = add(ra, rb); + result += ra; } void visit(const Sub *op) override { - op->a.accept(this); - ConstantInterval ra = result; op->b.accept(this); ConstantInterval rb = result; - result = sub(ra, rb); + op->a.accept(this); + result -= rb; } void visit(const Mul *op) override { @@ -313,9 +155,9 @@ class DerivativeBounds : public IRVisitor { // This is essentially the product rule: a*rb + b*ra // but only implemented for the case where a or b is constant. if (const int64_t *b = as_const_int_or_uint(op->b)) { - result = multiply(ra, *b); + result = ra * (*b); } else if (const int64_t *a = as_const_int_or_uint(op->a)) { - result = multiply(rb, *a); + result = rb * (*a); } else { result = ConstantInterval::everything(); } @@ -326,20 +168,37 @@ class DerivativeBounds : public IRVisitor { void visit(const Div *op) override { if (op->type.is_scalar()) { - op->a.accept(this); - ConstantInterval ra = result; - if (const int64_t *b = as_const_int_or_uint(op->b)) { - result = divide(ra, *b); - } else { - result = ConstantInterval::everything(); + op->a.accept(this); + // We don't just want to divide by b. For the min we want to + // take floor division, and for the max we want to use ceil + // division. + if (*b == 0) { + result = ConstantInterval::single_point(0); + } else { + if (result.min_defined) { + result.min = div_imp(result.min, *b); + } + if (result.max_defined) { + if (result.max != INT64_MIN) { + result.max = div_imp(result.max - 1, *b) + 1; + } else { + result.max_defined = false; + result.max = 0; + } + } + if (*b < 0) { + result = -result; + } + } + return; } - } else { - result = ConstantInterval::everything(); } + result = ConstantInterval::everything(); } void visit(const Mod *op) override { + // TODO: It's possible to get tighter bounds here. What if neither arg uses the var! result = ConstantInterval::everything(); } @@ -347,16 +206,14 @@ class DerivativeBounds : public IRVisitor { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit(const Max *op) override { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit_eq(const Expr &a, const Expr &b) { @@ -386,17 +243,12 @@ class DerivativeBounds : public IRVisitor { a.accept(this); ConstantInterval ra = result; b.accept(this); - ConstantInterval rb = result; - result = unify(negate(ra), rb); + result.include(-ra); // If the result is bounded, limit it to [-1, 1]. The largest // difference possible is flipping from true to false or false // to true. - if (result.has_lower_bound()) { - result.min = std::min(std::max(result.min, -1), 1); - } - if (result.has_upper_bound()) { - result.max = std::min(std::max(result.max, -1), 1); - } + result.min = std::min(std::max(result.min, -1), 1); + result.max = std::min(std::max(result.max, -1), 1); } void visit(const LT *op) override { @@ -419,50 +271,52 @@ class DerivativeBounds : public IRVisitor { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit(const Or *op) override { op->a.accept(this); ConstantInterval ra = result; op->b.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); } void visit(const Not *op) override { op->a.accept(this); - result = negate(result); + result = -result; } void visit(const Select *op) override { - // The result is the unified bounds, added to the "bump" that happens when switching from true to false. + // The result is the unified bounds, added to the "bump" that happens + // when switching from true to false. if (op->type.is_scalar()) { op->condition.accept(this); ConstantInterval rcond = result; + // rcond is: + // [ 0 0] if the condition does not depend on the variable + // [-1, 0] if it changes once from true to false + // [ 0 1] if it changes once from false to true + // [-1, 1] if it could change in either direction op->true_value.accept(this); ConstantInterval ra = result; op->false_value.accept(this); - ConstantInterval rb = result; - result = unify(ra, rb); + result.include(ra); // If the condition is not constant, we hit a "bump" when the condition changes value. if (!is_constant(rcond)) { - // TODO: How to handle unsigned values? - Expr delta = simplify(op->true_value - op->false_value); - - Interval delta_bounds = find_constant_bounds(delta, bounds); - // TODO: Maybe we can do something with one-sided intervals? - if (delta_bounds.is_bounded()) { - ConstantInterval delta_low = multiply(rcond, delta_bounds.min); - ConstantInterval delta_high = multiply(rcond, delta_bounds.max); - result = add(result, ConstantInterval::make_union(delta_low, delta_high)); - } else { - // The bump is unbounded. - result = ConstantInterval::everything(); - } + // It's very important to have stripped likelies here, or the + // simplification might not cancel things that it should. This + // happens below in the top-level derivative_bounds call. + Expr bump = simplify(op->true_value - op->false_value); + + // This is of dubious value, because + // bound_correlated_differences really assumes you've solved for + // a variable that you're trying to cancel first. TODO: try + // removing this. + bump = bound_correlated_differences(bump); + ConstantInterval bump_bounds = constant_integer_bounds(bump, value_bounds); + result += rcond * bump_bounds; } } else { result = ConstantInterval::everything(); @@ -493,10 +347,9 @@ class DerivativeBounds : public IRVisitor { return; } - if (op->is_intrinsic(Call::unsafe_promise_clamped) || - op->is_intrinsic(Call::promise_clamped) || - op->is_intrinsic(Call::saturating_cast)) { + if (op->is_intrinsic(Call::saturating_cast)) { op->args[0].accept(this); + result.include(0); return; } @@ -526,14 +379,16 @@ class DerivativeBounds : public IRVisitor { void visit(const Let *op) override { op->value.accept(this); - ScopedBinding bounds_binding(bounds, op->name, find_constant_bounds(op->value, bounds)); - + // As above, this is of dubious value. TODO: Try removing it. + Expr v = bound_correlated_differences(op->value); + ScopedBinding vb_binding(value_bounds, op->name, + constant_integer_bounds(v, value_bounds)); if (is_constant(result)) { // No point pushing it if it's constant w.r.t the var, // because unknown variables are treated as constant. op->body.accept(this); } else { - ScopedBinding scope_binding(scope, op->name, result); + ScopedBinding db_binding(derivative_bounds, op->name, result); op->body.accept(this); } } @@ -554,7 +409,7 @@ class DerivativeBounds : public IRVisitor { switch (op->op) { case VectorReduce::Add: case VectorReduce::SaturatingAdd: - result = multiply(result, op->value.type().lanes() / op->type.lanes()); + result *= op->value.type().lanes() / op->type.lanes(); break; case VectorReduce::Min: case VectorReduce::Max: @@ -643,7 +498,7 @@ class DerivativeBounds : public IRVisitor { DerivativeBounds(const std::string &v, const Scope &parent) : var(v), result(ConstantInterval::everything()) { - scope.set_containing_scope(&parent); + derivative_bounds.set_containing_scope(&parent); } }; diff --git a/src/Monotonic.h b/src/Monotonic.h index 3d7946a13ed7..c8ba66195961 100644 --- a/src/Monotonic.h +++ b/src/Monotonic.h @@ -8,13 +8,14 @@ #include #include -#include "Interval.h" +#include "ConstantInterval.h" #include "Scope.h" namespace Halide { namespace Internal { -/** Find the bounds of the derivative of an expression. */ +/** Find the bounds of the derivative of an expression. The scope gives the + * bounds on the derivatives of any variables found. */ ConstantInterval derivative_bounds(const Expr &e, const std::string &var, const Scope &scope = Scope::empty_scope()); diff --git a/src/Type.cpp b/src/Type.cpp index 64414fa04eca..1cd95e0a6b01 100644 --- a/src/Type.cpp +++ b/src/Type.cpp @@ -1,3 +1,4 @@ +#include "ConstantBounds.h" #include "IR.h" #include #include @@ -126,6 +127,10 @@ bool Type::can_represent(Type other) const { } } +bool Type::can_represent(const Internal::ConstantInterval &in) const { + return in.is_bounded() && can_represent(in.min) && can_represent(in.max); +} + bool Type::can_represent(int64_t x) const { if (is_int()) { return x >= min_int(bits()) && x <= max_int(bits()); diff --git a/src/Type.h b/src/Type.h index c8a397b3f0a7..1ed52da16c55 100644 --- a/src/Type.h +++ b/src/Type.h @@ -269,6 +269,10 @@ struct halide_handle_traits { namespace Halide { +namespace Internal { +struct ConstantInterval; +} + struct Expr; /** Types in the halide type system. They can be ints, unsigned ints, @@ -504,6 +508,10 @@ struct Type { /** Can this type represent all values of another type? */ bool can_represent(Type other) const; + /** Can this type represent exactly all integer values of some constant + * integer range? */ + bool can_represent(const Internal::ConstantInterval &in) const; + /** Can this type represent a particular constant? */ // @{ bool can_represent(double x) const; diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 604ceda468f5..ae4a6776ac72 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -58,6 +58,7 @@ tests(GROUPS correctness computed_index.cpp concat.cpp constant_expr.cpp + constant_interval.cpp constant_type.cpp constraints.cpp convolution_multiple_kernels.cpp diff --git a/test/correctness/constant_interval.cpp b/test/correctness/constant_interval.cpp new file mode 100644 index 000000000000..ba6fa73fbdcb --- /dev/null +++ b/test/correctness/constant_interval.cpp @@ -0,0 +1,187 @@ +#include "Halide.h" + +using namespace Halide; +using namespace Halide::Internal; + +std::mt19937 rng; + +int64_t sample(const ConstantInterval &i) { + int64_t upper = i.max_defined ? i.max : 1024; + int64_t lower = i.min_defined ? i.min : -1024; + return lower + (rng() % (upper - lower + 1)); +} + +ConstantInterval random_interval() { + int64_t a = (rng() % 512) - 256; + int64_t b = (rng() % 512) - 256; + ConstantInterval result; + if (rng() & 1) { + result.max_defined = true; + result.max = std::max(a, b); + } + if (rng() & 1) { + result.min_defined = true; + result.min = std::min(a, b); + } + return result; +} + +int main(int argc, char **argv) { + for (int i = 0; i < 1000; i++) { + std::vector> values; + for (int j = 0; j < 10; j++) { + values.emplace_back(random_interval(), 0); + values.back().second = sample(values.back().first); + } + + for (int j = 0; j < 1000; j++) { + auto a = values[rng() % values.size()]; + auto b = values[rng() % values.size()]; + decltype(a) c; + + auto check = [&](const char *op) { + if (!c.first.contains(c.second)) { + std::cout << "Error for operator " << op << ":\n" + << "a: " << a.second << " in " << a.first << "\n" + << "b: " << b.second << " in " << b.first << "\n" + << "c: " << c.second << " not in " << c.first << "\n"; + exit(1); + } + }; + + auto check_scalar = [&](const char *op) { + if (!c.first.contains(c.second)) { + std::cout << "Error for operator " << op << ":\n" + << "a: " << a.second << " in " << a.first << "\n" + << "b: " << b.second << "\n" + << "c: " << c.second << " not in " << c.first << "\n"; + exit(1); + } + }; + + // Arithmetic + if (!add_would_overflow(64, a.second, b.second)) { + c.first = a.first + b.first; + c.second = a.second + b.second; + check("+"); + } + + if (!sub_would_overflow(64, a.second, b.second)) { + c.first = a.first - b.first; + c.second = a.second - b.second; + check("-"); + } + + if (!mul_would_overflow(64, a.second, b.second)) { + c.first = a.first * b.first; + c.second = a.second * b.second; + check("*"); + } + + c.first = a.first / b.first; + c.second = div_imp(a.second, b.second); + check("/"); + + c.first = min(a.first, b.first); + c.second = std::min(a.second, b.second); + check("min"); + + c.first = max(a.first, b.first); + c.second = std::max(a.second, b.second); + check("max"); + + c.first = a.first % b.first; + c.second = mod_imp(a.second, b.second); + check("%"); + + // Arithmetic with constant RHS + if (!add_would_overflow(64, a.second, b.second)) { + c.first = a.first + b.second; + c.second = a.second + b.second; + check_scalar("+"); + } + + if (!sub_would_overflow(64, a.second, b.second)) { + c.first = a.first - b.second; + c.second = a.second - b.second; + check_scalar("-"); + } + + if (!mul_would_overflow(64, a.second, b.second)) { + c.first = a.first * b.second; + c.second = a.second * b.second; + check_scalar("*"); + } + + c.first = a.first / b.second; + c.second = div_imp(a.second, b.second); + check_scalar("/"); + + c.first = min(a.first, b.second); + c.second = std::min(a.second, b.second); + check_scalar("min"); + + c.first = max(a.first, b.second); + c.second = std::max(a.second, b.second); + check_scalar("max"); + + c.first = a.first % b.second; + c.second = mod_imp(a.second, b.second); + check_scalar("%"); + + // Some unary operators + c.first = -a.first; + c.second = -a.second; + check("unary -"); + + c.first = cast(UInt(8), a.first); + c.second = (int64_t)(uint8_t)(a.second); + check("cast to uint8"); + + c.first = cast(Int(8), a.first); + c.second = (int64_t)(int8_t)(a.second); + check("cast to uint8"); + + // Comparison + _halide_user_assert(!(a.first < b.first) || a.second < b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.first <= b.first) || a.second <= b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.first > b.first) || a.second > b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.first >= b.first) || a.second >= b.second) + << a.first << " " << a.second << " " << b.first << " " << b.second; + + // Comparison against constants + _halide_user_assert(!(a.first < b.second) || a.second < b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.first <= b.second) || a.second <= b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.first > b.second) || a.second > b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.first >= b.second) || a.second >= b.second) + << a.first << " " << a.second << " " << b.second; + + _halide_user_assert(!(a.second < b.first) || a.second < b.second) + << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.second <= b.first) || a.second <= b.second) + << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.second > b.first) || a.second > b.second) + << a.second << " " << b.first << " " << b.second; + + _halide_user_assert(!(a.second >= b.first) || a.second >= b.second) + << a.second << " " << b.first << " " << b.second; + } + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp index abdbaa9502c3..692abc0db7d4 100644 --- a/test/correctness/lossless_cast.cpp +++ b/test/correctness/lossless_cast.cpp @@ -6,9 +6,11 @@ using namespace Halide::Internal; int check_lossless_cast(const Type &t, const Expr &in, const Expr &correct) { Expr result = lossless_cast(t, in); if (!equal(result, correct)) { - std::cout << "Incorrect lossless_cast result:\nlossless_cast(" - << t << ", " << in << ") gave: " << result - << " but expected was: " << correct << "\n"; + std::cout << "Incorrect lossless_cast result:\n" + << "lossless_cast(" << t << ", " << in << ") gave:\n" + << " " << result + << " but expected was:\n" + << " " << correct << "\n"; return 1; } return 0; @@ -19,9 +21,11 @@ int lossless_cast_test() { Type u8 = UInt(8); Type u16 = UInt(16); Type u32 = UInt(32); + // Type u64 = UInt(64); Type i8 = Int(8); Type i16 = Int(16); Type i32 = Int(32); + Type i64 = Int(64); Type u8x = UInt(8, 4); Type u16x = UInt(16, 4); Type u32x = UInt(32, 4); @@ -52,14 +56,374 @@ int lossless_cast_test() { e = VectorReduce::make(VectorReduce::Add, cast(u32x, var_u8x), 1); res |= check_lossless_cast(u16, e, VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1)); - return res; + e = cast(u32, var_u8) - 16; + res |= check_lossless_cast(u16, e, Expr()); + + e = cast(u32, var_u8) + 16; + res |= check_lossless_cast(u16, e, cast(u16, var_u8) + 16); + + e = 16 - cast(u32, var_u8); + res |= check_lossless_cast(u16, e, Expr()); + + e = 16 + cast(u32, var_u8); + res |= check_lossless_cast(u16, e, 16 + cast(u16, var_u8)); + + // Check one where the target type is unsigned but there's a signed addition + // (that can't overflow) + e = cast(i64, cast(u16, var_u8) + cast(i32, 17)); + res |= check_lossless_cast(u32, e, cast(u32, cast(u16, var_u8)) + cast(u32, 17)); + + // Check one where the target type is unsigned but there's a signed subtract + // (that can overflow). It's not safe to enter the i16 sub + e = cast(i64, cast(i16, 10) - cast(i16, 17)); + res |= check_lossless_cast(u32, e, Expr()); + + e = cast(i64, 1024) * cast(i64, 1024) * cast(i64, 1024); + res |= check_lossless_cast(i32, e, (cast(i32, 1024) * 1024) * 1024); + + if (res) { + std::cout << "Ignoring bugs in lossless_cast for now. Will be fixed in #8155\n"; + } + return 0; + // return res; +} + +constexpr int size = 1024; +Buffer buf_u8(size, "buf_u8"); +Buffer buf_i8(size, "buf_i8"); +Var x{"x"}; + +Expr random_expr(std::mt19937 &rng) { + std::vector exprs; + // Add some atoms + exprs.push_back(cast((uint8_t)rng())); + exprs.push_back(cast((int8_t)rng())); + exprs.push_back(cast((uint8_t)rng())); + exprs.push_back(cast((int8_t)rng())); + exprs.push_back(buf_u8(x)); + exprs.push_back(buf_i8(x)); + + // Make random combinations of them + while (true) { + Expr e; + int i1 = rng() % exprs.size(); + int i2 = rng() % exprs.size(); + int i3 = rng() % exprs.size(); + int op = rng() % 8; + Expr e1 = exprs[i1]; + Expr e2 = cast(e1.type(), exprs[i2]); + Expr e3 = cast(e1.type().with_code(halide_type_uint), exprs[i3]); + bool may_widen = e1.type().bits() < 64; + Expr e2_narrow = exprs[i2]; + bool may_widen_right = e2_narrow.type() == e1.type().narrow(); + switch (op) { + case 0: + if (may_widen) { + e = cast(e1.type().widen(), e1); + } + break; + case 1: + if (may_widen) { + e = cast(Int(e1.type().bits() * 2), e1); + } + break; + case 2: + e = e1 + e2; + break; + case 3: + e = e1 - e2; + break; + case 4: + e = e1 * e2; + break; + case 5: + e = e1 / e2; + break; + case 6: + // Introduce some lets + e = common_subexpression_elimination(e1); + break; + case 7: + switch (rng() % 19) { + case 0: + if (may_widen) { + e = widening_add(e1, e2); + } + break; + case 1: + if (may_widen) { + e = widening_sub(e1, e2); + } + break; + case 2: + if (may_widen) { + e = widening_mul(e1, e2); + } + break; + case 3: + e = halving_add(e1, e2); + break; + case 4: + e = rounding_halving_add(e1, e2); + break; + case 5: + e = halving_sub(e1, e2); + break; + case 6: + e = saturating_add(e1, e2); + break; + case 7: + e = saturating_sub(e1, e2); + break; + case 8: + e = count_leading_zeros(e1); + break; + case 9: + e = count_trailing_zeros(e1); + break; + case 10: + if (may_widen) { + e = rounding_mul_shift_right(e1, e2, e3); + } + break; + case 11: + if (may_widen) { + e = mul_shift_right(e1, e2, e3); + } + break; + case 12: + if (may_widen_right) { + e = widen_right_add(e1, e2_narrow); + } + break; + case 13: + if (may_widen_right) { + e = widen_right_sub(e1, e2_narrow); + } + break; + case 14: + if (may_widen_right) { + e = widen_right_mul(e1, e2_narrow); + } + break; + case 15: + e = e1 << e2; + break; + case 16: + e = e1 >> e2; + break; + case 17: + e = rounding_shift_right(e1, e2); + break; + case 18: + e = rounding_shift_left(e1, e2); + break; + } + } + + if (!e.defined()) { + continue; + } + + // Stop when we get to 64 bits, but probably don't stop on a cast, + // because that'll just get trivially stripped. + if (e.type().bits() == 64 && (e.as() == nullptr || ((rng() & 7) == 0))) { + return e; + } + + exprs.push_back(e); + } +} + +bool might_have_ub(Expr e) { + class MightOverflow : public IRVisitor { + std::map cache; + + using IRVisitor::visit; + + bool no_overflow_int(const Type &t) { + return t.is_int() && t.bits() >= 32; + } + + ConstantInterval bounds(const Expr &e) { + return constant_integer_bounds(e, Scope::empty_scope(), &cache); + } + + void visit(const Add *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->a) + bounds(op->b))) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Sub *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->a) - bounds(op->b))) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Mul *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->a) * bounds(op->b))) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Div *op) override { + if (no_overflow_int(op->type) && + (bounds(op->a) / bounds(op->b)).contains(-1)) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Cast *op) override { + if (no_overflow_int(op->type) && + !op->type.can_represent(bounds(op->value))) { + found = true; + } else { + IRVisitor::visit(op); + } + } + + void visit(const Call *op) override { + if (op->is_intrinsic({Call::shift_left, + Call::shift_right, + Call::rounding_shift_left, + Call::rounding_shift_right, + Call::widening_shift_left, + Call::widening_shift_right, + Call::mul_shift_right, + Call::rounding_mul_shift_right})) { + auto shift_bounds = bounds(op->args.back()); + if (!(shift_bounds > -op->type.bits() && shift_bounds < op->type.bits())) { + found = true; + } + } else if (op->is_intrinsic({Call::signed_integer_overflow})) { + found = true; + } + IRVisitor::visit(op); + } + + public: + bool found = false; + } checker; + + e.accept(&checker); + + return checker.found; +} + +bool found_error = false; + +int test_one(uint32_t seed) { + std::mt19937 rng{seed}; + + buf_u8.fill(rng); + buf_i8.fill(rng); + + Expr e1 = random_expr(rng); + + if (might_have_ub(e1)) { + return 0; + } + + // We're also going to test constant_integer_bounds here. + ConstantInterval bounds = constant_integer_bounds(e1); + + Type target; + std::vector target_types = {UInt(32), Int(32), UInt(16), Int(16)}; + target = target_types[rng() % target_types.size()]; + Expr e2 = lossless_cast(target, e1); + + if (!e2.defined()) { + return 0; + } + + Func f; + f(x) = {cast(e1), cast(e2)}; + f.vectorize(x, 4, TailStrategy::RoundUp); + + Buffer out1(size), out2(size); + Pipeline p(f); + p.realize({out1, out2}); + + for (int x = 0; x < size; x++) { + if (out1(x) != out2(x)) { + std::cout + << "lossless_cast failure\n" + << "seed = " << seed << "\n" + << "x = " << x << "\n" + << "buf_u8 = " << (int)buf_u8(x) << "\n" + << "buf_i8 = " << (int)buf_i8(x) << "\n" + << "out1 = " << out1(x) << "\n" + << "out2 = " << out2(x) << "\n" + << "Original: " << e1 << "\n" + << "Lossless cast: " << e2 << "\n" + << "Ignoring bug for now. Will be fixed in #8155\n"; + // If lossless_cast has failed on this Expr, it's possible the test + // below will fail as well. + return 0; + // return 1; + } + } + + for (int x = 0; x < size; x++) { + if ((e1.type().is_int() && !bounds.contains(out1(x))) || + (e1.type().is_uint() && !bounds.contains((uint64_t)out1(x)))) { + Expr simplified = simplify(e1); + std::cout + << "constant_integer_bounds failure\n" + << "seed = " << seed << "\n" + << "x = " << x << "\n" + << "buf_u8 = " << (int)buf_u8(x) << "\n" + << "buf_i8 = " << (int)buf_i8(x) << "\n" + << "out1 = " << out1(x) << "\n" + << "Expression: " << e1 << "\n" + << "Bounds: " << bounds << "\n" + << "Simplified: " << simplified << "\n" + // If it's still out-of-bounds when the expression is + // simplified, that'll be easier to debug. + << "Bounds: " << constant_integer_bounds(simplified) << "\n"; + return 1; + } + } + + return 0; } -int main() { +int fuzz_test(uint32_t root_seed) { + std::mt19937 seed_generator(root_seed); + + std::cout << "Fuzz testing with root seed " << root_seed << "\n"; + for (int i = 0; i < 1000; i++) { + if (test_one(seed_generator())) { + return 1; + } + } + return 0; +} + +int main(int argc, char **argv) { + if (argc == 2) { + return test_one(atoi(argv[1])); + } if (lossless_cast_test()) { - printf("lossless_cast test failed!\n"); + std::cout << "lossless_cast test failed!\n"; + return 1; + } + if (fuzz_test(time(NULL))) { + std::cout << "lossless_cast fuzz test failed!\n"; return 1; } - printf("Success!\n"); + std::cout << "Success!\n"; return 0; } diff --git a/test/fuzz/bounds.cpp b/test/fuzz/bounds.cpp index df99dcd83b03..d109f4994bbc 100644 --- a/test/fuzz/bounds.cpp +++ b/test/fuzz/bounds.cpp @@ -283,11 +283,6 @@ Expr c(Variable::make(global_var_type, fuzz_var(2))); Expr d(Variable::make(global_var_type, fuzz_var(3))); Expr e(Variable::make(global_var_type, fuzz_var(4))); -std::ostream &operator<<(std::ostream &stream, const Interval &interval) { - stream << "[" << interval.min << ", " << interval.max << "]"; - return stream; -} - Interval random_interval(FuzzedDataProvider &fdp, Type t) { Interval interval;