diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index d0538d6ccca8..0e4de6baa050 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -1212,50 +1212,42 @@ void CodeGen_ARM::visit(const Add *op) { Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc"); Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc"); - // clang-format off + Expr ma_i8 = widening_mul(a_i8, ac_i8); + Expr mb_i8 = widening_mul(b_i8, bc_i8); + Expr mc_i8 = widening_mul(c_i8, cc_i8); + Expr md_i8 = widening_mul(d_i8, dc_i8); + + Expr ma_u8 = widening_mul(a_u8, ac_u8); + Expr mb_u8 = widening_mul(b_u8, bc_u8); + Expr mc_u8 = widening_mul(c_u8, cc_u8); + Expr md_u8 = widening_mul(d_u8, dc_u8); + static const Pattern patterns[] = { - // If we had better normalization, we could drastically reduce the number of patterns here. // Signed variants. - {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"}, - {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)}, - {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)}, - {init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)}, - {init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)}, - // Signed variants (associative). - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"}, - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)}, - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)}, - {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)}, - {init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)}, + {(init_i32 + widening_add(ma_i8, mb_i8)) + widening_add(mc_i8, md_i8), "dot_product"}, + {init_i32 + (widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8)), "dot_product"}, + {widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8), "dot_product"}, + // Unsigned variants. - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"}, - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)}, - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)}, - {init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)}, - {init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)}, - // Unsigned variants (associative). - {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"}, - {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)}, - {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)}, - {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)}, - {init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)}, + {(init_u32 + widening_add(ma_u8, mb_u8)) + widening_add(mc_u8, md_u8), "dot_product"}, + {init_u32 + (widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8)), "dot_product"}, + {widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8), "dot_product"}, }; - // clang-format on std::map matches; for (const Pattern &p : patterns) { if (expr_match(p.pattern, op, matches)) { - Expr init = matches["init"]; - Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]}); - // Coefficients can be 1 if not in the pattern. - Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes())); - // This hideous code pattern implements fetching a - // default value if the map doesn't contain a key. - Expr _ac = matches.try_emplace("ac", one).first->second; - Expr _bc = matches.try_emplace("bc", one).first->second; - Expr _cc = matches.try_emplace("cc", one).first->second; - Expr _dc = matches.try_emplace("dc", one).first->second; - Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc}); + Expr init; + auto it = matches.find("init"); + if (it == matches.end()) { + init = make_zero(op->type); + } else { + init = it->second; + } + Expr values = Shuffle::make_interleave({matches["a"], matches["b"], + matches["c"], matches["d"]}); + Expr coeffs = Shuffle::make_interleave({matches["ac"], matches["bc"], + matches["cc"], matches["dc"]}); value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs}); if (value) { return; diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 5dd6a17e02d2..7a8d1c720098 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -538,8 +538,8 @@ void CodeGen_X86::visit(const Cast *op) { }; // clang-format off - static const Pattern patterns[] = { - // This isn't rounding_multiply_quantzied(i16, i16, 15) because it doesn't + static Pattern patterns[] = { + // This isn't rounding_mul_shift_right(i16, i16, 15) because it doesn't // saturate the result. {"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))}, @@ -736,7 +736,12 @@ void CodeGen_X86::visit(const Call *op) { // Handle edge case of possible overflow. // See https://github.com/halide/Halide/pull/7129/files#r1008331426 // On AVX512 (and with enough lanes) we can use a mask register. - if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) { + ConstantInterval ca = constant_integer_bounds(a); + ConstantInterval cb = constant_integer_bounds(b); + if (!ca.contains(-32768) || !cb.contains(-32768)) { + // Overflow isn't possible + pmulhrs.accept(this); + } else if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) { Expr expr = select((a == i16_min) && (b == i16_min), i16_max, pmulhrs); expr.accept(this); } else { diff --git a/src/Expr.cpp b/src/Expr.cpp index c3a7deb483aa..d73bd72660fa 100644 --- a/src/Expr.cpp +++ b/src/Expr.cpp @@ -8,7 +8,7 @@ const IntImm *IntImm::make(Type t, int64_t value) { internal_assert(t.is_int() && t.is_scalar()) << "IntImm must be a scalar Int\n"; internal_assert(t.bits() >= 1 && t.bits() <= 64) - << "IntImm must have between 1 and 64 bits\n"; + << "IntImm must have between 1 and 64 bits: " << t << "\n"; // Normalize the value by dropping the high bits. // Since left-shift of negative value is UB in C++, cast to uint64 first; @@ -28,7 +28,7 @@ const UIntImm *UIntImm::make(Type t, uint64_t value) { internal_assert(t.is_uint() && t.is_scalar()) << "UIntImm must be a scalar UInt\n"; internal_assert(t.bits() >= 1 && t.bits() <= 64) - << "UIntImm must have between 1 and 64 bits\n"; + << "UIntImm must have between 1 and 64 bits " << t << "\n"; // Normalize the value by dropping the high bits value <<= (64 - t.bits()); diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 793234c8b3ff..b72122460706 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -2,6 +2,7 @@ #include "CSE.h" #include "CodeGen_Internal.h" #include "ConciseCasts.h" +#include "ConstantBounds.h" #include "IRMatch.h" #include "IRMutator.h" #include "Simplify.h" @@ -45,23 +46,6 @@ bool can_narrow(const Type &t) { t.bits() >= 8; } -Expr lossless_narrow(const Expr &x) { - return can_narrow(x.type()) ? lossless_cast(x.type().narrow(), x) : Expr(); -} - -// Remove a widening cast even if it changes the sign of the result. -Expr strip_widening_cast(const Expr &x) { - if (can_narrow(x.type())) { - Expr narrow = lossless_narrow(x); - if (narrow.defined()) { - return narrow; - } - return lossless_cast(x.type().narrow().with_code(halide_type_uint), x); - } else { - return Expr(); - } -} - Expr saturating_narrow(const Expr &a) { Type narrow = a.type().narrow(); return saturating_cast(narrow, a); @@ -77,34 +61,6 @@ bool no_overflow(Type t) { return t.is_float() || no_overflow_int(t); } -// If there's a widening add or subtract in the first e.type().bits() / 2 - 1 -// levels down a tree of adds or subtracts, we know there's enough headroom for -// another add without overflow. For example, it is safe to add to -// (widening_add(x, y) - z) without overflow. -bool is_safe_for_add(const Expr &e, int max_depth) { - if (max_depth-- <= 0) { - return false; - } - if (const Add *add = e.as()) { - return is_safe_for_add(add->a, max_depth) || is_safe_for_add(add->b, max_depth); - } else if (const Sub *sub = e.as()) { - return is_safe_for_add(sub->a, max_depth) || is_safe_for_add(sub->b, max_depth); - } else if (const Cast *cast = e.as()) { - if (cast->type.bits() > cast->value.type().bits()) { - return true; - } else if (cast->type.bits() == cast->value.type().bits()) { - return is_safe_for_add(cast->value, max_depth); - } - } else if (Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub, Call::widen_right_add, Call::widen_right_sub})) { - return true; - } - return false; -} - -bool is_safe_for_add(const Expr &e) { - return is_safe_for_add(e, e.type().bits() / 2 - 1); -} - // We want to find and remove an add of 'round' from e. This is not // the same thing as just subtracting round, we specifically want // to remove an addition of exactly round. @@ -130,103 +86,129 @@ Expr find_and_subtract(const Expr &e, const Expr &round) { return Expr(); } -Expr to_rounding_shift(const Call *c) { - if (c->is_intrinsic(Call::shift_left) || c->is_intrinsic(Call::shift_right)) { - internal_assert(c->args.size() == 2); - Expr a = c->args[0]; - Expr b = c->args[1]; +class FindIntrinsics : public IRMutator { +protected: + using IRMutator::visit; - // Helper to make the appropriate shift. - auto rounding_shift = [&](const Expr &a, const Expr &b) { - if (c->is_intrinsic(Call::shift_right)) { - return rounding_shift_right(a, b); - } else { - return rounding_shift_left(a, b); - } - }; + IRMatcher::Wild<0> x; + IRMatcher::Wild<1> y; + IRMatcher::Wild<2> z; + IRMatcher::Wild<3> w; + IRMatcher::WildConst<0> c0; + IRMatcher::WildConst<1> c1; - // The rounding offset for the shift we have. - Type round_type = a.type().with_lanes(1); - if (Call::as_intrinsic(a, {Call::widening_add})) { - round_type = round_type.narrow(); - } - Expr round; - if (c->is_intrinsic(Call::shift_right)) { - round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + std::map bounds_cache; + Scope let_var_bounds; + + Expr lossless_cast(Type t, const Expr &e) { + return Halide::Internal::lossless_cast(t, e, &bounds_cache); + } + + ConstantInterval constant_integer_bounds(const Expr &e) { + // TODO: Use the scope - add let visitors + return Halide::Internal::constant_integer_bounds(e, let_var_bounds, &bounds_cache); + } + + Expr lossless_narrow(const Expr &x) { + return can_narrow(x.type()) ? lossless_cast(x.type().narrow(), x) : Expr(); + } + + // Remove a widening cast even if it changes the sign of the result. + Expr strip_widening_cast(const Expr &x) { + if (can_narrow(x.type())) { + Expr narrow = lossless_narrow(x); + if (narrow.defined()) { + return narrow; + } + return lossless_cast(x.type().narrow().with_code(halide_type_uint), x); } else { - round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + return Expr(); } - // Input expressions are simplified before running find_intrinsics, but b - // has been lifted here so we need to lower_intrinsics before simplifying - // and re-lifting. Should we move this code into the FindIntrinsics class - // to make it easier to lift round? - round = lower_intrinsics(round); - round = simplify(round); - round = find_intrinsics(round); - - // We can always handle widening adds. - if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) { - if (can_prove(lower_intrinsics(add->args[0] == round))) { - return rounding_shift(cast(add->type, add->args[1]), b); - } else if (can_prove(lower_intrinsics(add->args[1] == round))) { - return rounding_shift(cast(add->type, add->args[0]), b); + } + + Expr to_rounding_shift(const Call *c) { + if (c->is_intrinsic(Call::shift_left) || c->is_intrinsic(Call::shift_right)) { + internal_assert(c->args.size() == 2); + Expr a = c->args[0]; + Expr b = c->args[1]; + + // Helper to make the appropriate shift. + auto rounding_shift = [&](const Expr &a, const Expr &b) { + if (c->is_intrinsic(Call::shift_right)) { + return rounding_shift_right(a, b); + } else { + return rounding_shift_left(a, b); + } + }; + + // The rounding offset for the shift we have. + Type round_type = a.type().with_lanes(1); + if (Call::as_intrinsic(a, {Call::widening_add})) { + round_type = round_type.narrow(); + } + Expr round; + if (c->is_intrinsic(Call::shift_right)) { + round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + } else { + round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2; + } + // Input expressions are simplified before running find_intrinsics, but b + // has been lifted here so we need to lower_intrinsics before simplifying + // and re-lifting. Should we move this code into the FindIntrinsics class + // to make it easier to lift round? + round = lower_intrinsics(round); + round = simplify(round); + round = find_intrinsics(round); + + // We can always handle widening adds. + if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) { + if (can_prove(lower_intrinsics(add->args[0] == round))) { + return rounding_shift(cast(add->type, add->args[1]), b); + } else if (can_prove(lower_intrinsics(add->args[1] == round))) { + return rounding_shift(cast(add->type, add->args[0]), b); + } } - } - if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) { - if (can_prove(lower_intrinsics(add->args[1] == round))) { - return rounding_shift(cast(add->type, add->args[0]), b); + if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) { + if (can_prove(lower_intrinsics(add->args[1] == round))) { + return rounding_shift(cast(add->type, add->args[0]), b); + } } - } - // Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add - // TODO: this pattern makes me want to change the semantics of this op. - if (const Cast *cast = a.as()) { - if (cast->is_reinterpret()) { - if (const Call *add = Call::as_intrinsic(cast->value, {Call::widen_right_add})) { - if (can_prove(lower_intrinsics(add->args[1] == round))) { - // We expect the first operand to be a reinterpet cast. - if (const Cast *cast_a = add->args[0].as()) { - if (cast_a->is_reinterpret()) { - return rounding_shift(cast_a->value, b); + // Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add + if (const Cast *cast = a.as()) { + if (cast->is_reinterpret()) { + if (const Call *add = Call::as_intrinsic(cast->value, {Call::widen_right_add})) { + if (can_prove(lower_intrinsics(add->args[1] == round))) { + // We expect the first operand to be a reinterpet cast. + if (const Cast *cast_a = add->args[0].as()) { + if (cast_a->is_reinterpret()) { + return rounding_shift(cast_a->value, b); + } } } } } } - } - // If it wasn't a widening or saturating add, we might still - // be able to safely accept the rounding. - Expr a_less_round = find_and_subtract(a, round); - if (a_less_round.defined()) { - // We found and removed the rounding. However, we may have just changed - // behavior due to overflow. This is still safe if the type is not - // overflowing, or we can find a widening add or subtract in the tree - // of adds/subtracts. This is a common pattern, e.g. - // rounding_halving_add(a, b) = shift_round(widening_add(a, b) + 1, 1). - // TODO: This could be done with bounds inference instead of this hack - // if it supported intrinsics like widening_add and tracked bounds for - // types other than int32. - if (no_overflow(a.type()) || is_safe_for_add(a_less_round)) { - return rounding_shift(simplify(a_less_round), b); + // If it wasn't a widening or saturating add, we might still + // be able to safely accept the rounding. + Expr a_less_round = find_and_subtract(a, round); + if (a_less_round.defined()) { + // We found and removed the rounding. Verify it didn't change + // overflow behavior. + if (no_overflow(a.type()) || + a.type().can_represent(constant_integer_bounds(a_less_round) + + constant_integer_bounds(round))) { + // If we can add the rounding term back on without causing + // overflow, then it must not have overflowed originally. + return rounding_shift(simplify(a_less_round), b); + } } } - } - - return Expr(); -} -class FindIntrinsics : public IRMutator { -protected: - using IRMutator::visit; - - IRMatcher::Wild<0> x; - IRMatcher::Wild<1> y; - IRMatcher::Wild<2> z; - IRMatcher::Wild<3> w; - IRMatcher::WildConst<0> c0; - IRMatcher::WildConst<1> c1; + return Expr(); + } Expr visit(const Add *op) override { if (!find_intrinsics_for_type(op->type)) { @@ -548,6 +530,11 @@ class FindIntrinsics : public IRMutator { } } + // Do we need to worry about this cast overflowing? + ConstantInterval value_bounds = constant_integer_bounds(value); + bool no_overflow = (op->type.can_represent(op->value.type()) || + op->type.can_represent(value_bounds)); + if (op->type.is_int() || op->type.is_uint()) { Expr lower = cast(value.type(), op->type.min()); Expr upper = cast(value.type(), op->type.max()); @@ -565,7 +552,6 @@ class FindIntrinsics : public IRMutator { auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits); auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint; auto x_y_same_sign = (is_int(x) && is_int(y)) || (is_uint(x) && is_uint(y)); - auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2); if ( // Saturating patterns rewrite(max(min(widening_add(x, y), upper), lower), @@ -667,32 +653,16 @@ class FindIntrinsics : public IRMutator { rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int && x_y_same_sign && c0 >= bits - 1) || - rewrite(shift_right(widening_mul(x, y), c0), - mul_shift_right(x, y, cast(unsigned_type, c0)), - is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) || - - rewrite(rounding_shift_right(widening_mul(x, y), c0), - rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), - is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) || - - // We can also match on smaller shifts if one of the args is - // narrow. We don't do this for signed (yet), because the - // saturation issue is tricky. - rewrite(shift_right(widening_mul(x, cast(op->type, y)), c0), - mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || - - rewrite(rounding_shift_right(widening_mul(x, cast(op->type, y)), c0), - rounding_mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + // We can also match whenever the cast can't overflow, so + // questions of saturation are irrelevant. + (no_overflow && + (rewrite(shift_right(widening_mul(x, y), c0), + mul_shift_right(x, y, cast(unsigned_type, c0)), + is_x_same_int_or_uint && x_y_same_sign && c0 >= 0) || - rewrite(shift_right(widening_mul(cast(op->type, y), x), c0), - mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || - - rewrite(rounding_shift_right(widening_mul(cast(op->type, y), x), c0), - rounding_mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)), - is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + rewrite(rounding_shift_right(widening_mul(x, y), c0), + rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), + is_x_same_int_or_uint && x_y_same_sign && c0 >= 0))) || // Halving subtract patterns rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1), @@ -908,13 +878,16 @@ class FindIntrinsics : public IRMutator { } // TODO: do we want versions of widen_right_add here? - if (op->is_intrinsic(Call::shift_right) || op->is_intrinsic(Call::shift_left)) { + if (op->is_intrinsic(Call::shift_right) || + op->is_intrinsic(Call::shift_left)) { // Try to turn this into a widening shift. internal_assert(op->args.size() == 2); Expr a_narrow = lossless_narrow(op->args[0]); Expr b_narrow = lossless_narrow(op->args[1]); if (a_narrow.defined() && b_narrow.defined()) { - Expr result = op->is_intrinsic(Call::shift_left) ? widening_shift_left(a_narrow, b_narrow) : widening_shift_right(a_narrow, b_narrow); + Expr result = op->is_intrinsic(Call::shift_left) ? + widening_shift_left(a_narrow, b_narrow) : + widening_shift_right(a_narrow, b_narrow); if (result.type() != op->type) { result = Cast::make(op->type, result); } @@ -928,7 +901,8 @@ class FindIntrinsics : public IRMutator { } } - if (op->is_intrinsic(Call::rounding_shift_left) || op->is_intrinsic(Call::rounding_shift_right)) { + if (op->is_intrinsic(Call::rounding_shift_left) || + op->is_intrinsic(Call::rounding_shift_right)) { // Try to turn this into a widening shift. internal_assert(op->args.size() == 2); Expr a_narrow = lossless_narrow(op->args[0]); @@ -1490,27 +1464,45 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) // one of the operands and the denominator by a constant. We only do this // if it isn't already full precision. This avoids infinite loops despite // "lowering" this to another mul_shift_right operation. - if (can_prove(q < full_q)) { - Expr missing_q = full_q - q; - internal_assert(missing_q.type().bits() == b.type().bits()); - Expr new_b = simplify(b << missing_q); - if (is_const(new_b) && can_prove(new_b >> missing_q == b)) { - return rounding_mul_shift_right(a, new_b, full_q); + ConstantInterval cq = constant_integer_bounds(q); + if (cq.is_single_point() && cq.max >= 0 && cq.max < full_q) { + int missing_q = full_q - (int)cq.max; + + // Try to scale up the args by factors of two without overflowing + int a_shift = 0, b_shift = 0; + ConstantInterval ca = constant_integer_bounds(a); + while (true) { + ConstantInterval bigger = ca * 2; + if (a.type().can_represent(bigger) && a_shift + b_shift < missing_q) { + ca = bigger; + a_shift++; + } else { + break; + } } - Expr new_a = simplify(a << missing_q); - if (is_const(new_a) && can_prove(new_a >> missing_q == a)) { - return rounding_mul_shift_right(new_a, b, full_q); + ConstantInterval cb = constant_integer_bounds(b); + while (true) { + ConstantInterval bigger = cb * 2; + if (b.type().can_represent(bigger) && a_shift + b_shift < missing_q) { + cb = bigger; + b_shift++; + } else { + break; + } + } + if (a_shift + b_shift == missing_q) { + return rounding_mul_shift_right(simplify(a << a_shift), simplify(b << b_shift), full_q); } } // If all else fails, just widen, shift, and narrow. - Expr result = rounding_shift_right(widening_mul(a, b), q); - if (!can_prove(q >= a.type().bits())) { - result = saturating_narrow(result); + Expr wide_result = rounding_shift_right(widening_mul(a, b), q); + Expr narrowed = lossless_cast(a.type(), wide_result); + if (narrowed.defined()) { + return narrowed; } else { - result = narrow(result); + return saturating_narrow(wide_result); } - return result; } Expr lower_intrinsic(const Call *op) { diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 6834d4abe7f3..13b2b5d24559 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -382,6 +382,7 @@ typedef pair MulExpr; // the number of lanes in Broadcast or indices in a Shuffle // to match the ty lanes before using lossless_cast on it. Expr unbroadcast_lossless_cast(Type ty, Expr x) { + internal_assert(x.defined()); if (x.type().is_vector()) { if (const Broadcast *bc = x.as()) { if (ty.is_scalar()) { @@ -410,56 +411,78 @@ Expr unbroadcast_lossless_cast(Type ty, Expr x) { // multiplies in 'mpys', added to 'rest'. // Difference in mpys.size() - return indicates the number of // expressions where we pretend the op to be multiplied by 1. -int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count, +int find_mpy_ops(const Expr &op, Type result_ty, Type a_ty, Type b_ty, int max_mpy_count, vector &mpys, Expr &rest) { - if ((int)mpys.size() >= max_mpy_count) { - rest = rest.defined() ? Add::make(rest, op) : op; - return 0; - } - // If the add is also widening, remove the cast. - int mpy_bits = std::max(a_ty.bits(), b_ty.bits()) * 2; - Expr maybe_mul = op; - if (op.type().bits() == mpy_bits * 2) { - if (const Cast *cast = op.as()) { - if (cast->value.type().bits() == mpy_bits) { - maybe_mul = cast->value; - } + auto add_to_rest = [&](const Expr &a) { + if (rest.defined()) { + // Just widen to the result type. We run find_intrinsics on rest + // after calling this, to find things like widen_right_add in this + // summation. + rest = Add::make(rest, cast(result_ty, a)); + } else { + rest = cast(result_ty, a); } + }; + + if ((int)mpys.size() >= max_mpy_count) { + add_to_rest(op); + return 0; } - maybe_mul = as_mul(maybe_mul); - if (maybe_mul.defined()) { - const Mul *mul = maybe_mul.as(); - Expr a = unbroadcast_lossless_cast(a_ty, mul->a); - Expr b = unbroadcast_lossless_cast(b_ty, mul->b); + auto handle_mul = [&](const Expr &arg0, const Expr &arg1) -> bool { + Expr a = unbroadcast_lossless_cast(a_ty, arg0); + Expr b = unbroadcast_lossless_cast(b_ty, arg1); if (a.defined() && b.defined()) { mpys.emplace_back(a, b); - return 1; - } else { + return true; + } else if (a_ty != b_ty) { // Try to commute the op. - a = unbroadcast_lossless_cast(a_ty, mul->b); - b = unbroadcast_lossless_cast(b_ty, mul->a); + a = unbroadcast_lossless_cast(a_ty, arg1); + b = unbroadcast_lossless_cast(b_ty, arg0); if (a.defined() && b.defined()) { mpys.emplace_back(a, b); - return 1; + return true; } } + return false; + }; + + if (const Mul *mul = op.as()) { + bool no_overflow = mul->type.can_represent(constant_integer_bounds(mul->a) * + constant_integer_bounds(mul->b)); + if (no_overflow && handle_mul(mul->a, mul->b)) { + return 1; + } + } else if (const Call *mul = Call::as_intrinsic(op, {Call::widening_mul, Call::widen_right_mul})) { + bool no_overflow = (mul->is_intrinsic(Call::widening_mul) || + mul->type.can_represent(constant_integer_bounds(mul->args[0]) * + constant_integer_bounds(mul->args[1]))); + if (no_overflow && handle_mul(mul->args[0], mul->args[1])) { + return 1; + } } else if (const Add *add = op.as()) { - int mpy_count = 0; - mpy_count += find_mpy_ops(add->a, a_ty, b_ty, max_mpy_count, mpys, rest); - mpy_count += find_mpy_ops(add->b, a_ty, b_ty, max_mpy_count, mpys, rest); - return mpy_count; - } else if (const Call *add = Call::as_intrinsic(op, {Call::widening_add})) { - int mpy_count = 0; - mpy_count += find_mpy_ops(cast(op.type(), add->args[0]), a_ty, b_ty, max_mpy_count, mpys, rest); - mpy_count += find_mpy_ops(cast(op.type(), add->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest); - return mpy_count; - } else if (const Call *wadd = Call::as_intrinsic(op, {Call::widen_right_add})) { - int mpy_count = 0; - mpy_count += find_mpy_ops(wadd->args[0], a_ty, b_ty, max_mpy_count, mpys, rest); - mpy_count += find_mpy_ops(cast(op.type(), wadd->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest); - return mpy_count; + bool no_overflow = (add->type == result_ty || + add->type.can_represent(constant_integer_bounds(add->a) + + constant_integer_bounds(add->b))); + if (no_overflow) { + return (find_mpy_ops(add->a, result_ty, a_ty, b_ty, max_mpy_count, mpys, rest) + + find_mpy_ops(add->b, result_ty, a_ty, b_ty, max_mpy_count, mpys, rest)); + } + } else if (const Call *add = Call::as_intrinsic(op, {Call::widening_add, Call::widen_right_add})) { + bool no_overflow = (add->type == result_ty || + add->is_intrinsic(Call::widening_add) || + add->type.can_represent(constant_integer_bounds(add->args[0]) + + constant_integer_bounds(add->args[1]))); + if (no_overflow) { + return (find_mpy_ops(add->args[0], result_ty, a_ty, b_ty, max_mpy_count, mpys, rest) + + find_mpy_ops(add->args[1], result_ty, a_ty, b_ty, max_mpy_count, mpys, rest)); + } + } else if (const Cast *cast = op.as()) { + bool cast_is_lossless = cast->type.can_represent(constant_integer_bounds(cast->value)); + if (cast_is_lossless) { + return find_mpy_ops(cast->value, result_ty, a_ty, b_ty, max_mpy_count, mpys, rest); + } } // Attempt to pretend this op is multiplied by 1. @@ -471,7 +494,7 @@ int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count, } else if (as_b.defined()) { mpys.emplace_back(make_one(a_ty), as_b); } else { - rest = rest.defined() ? Add::make(rest, op) : op; + add_to_rest(op); } return 0; } @@ -554,10 +577,10 @@ class OptimizePatterns : public IRMutator { // match a subset of the expressions that vector*vector // matches. if (op->type.is_uint()) { - mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), UInt(8), 4, mpys, rest); suffix = ".vub.ub"; } else { - mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), Int(8), 4, mpys, rest); suffix = ".vub.b"; } @@ -588,7 +611,7 @@ class OptimizePatterns : public IRMutator { new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern); } if (rest.defined()) { - new_expr = Add::make(new_expr, rest); + new_expr = Add::make(new_expr, find_intrinsics(rest)); } return mutate(new_expr); } @@ -598,10 +621,10 @@ class OptimizePatterns : public IRMutator { mpys.clear(); rest = Expr(); if (op->type.is_uint()) { - mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8, lanes), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), UInt(8, lanes), 4, mpys, rest); suffix = ".vub.vub"; } else { - mpy_count = find_mpy_ops(op, Int(8, lanes), Int(8, lanes), 4, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, Int(8, lanes), Int(8, lanes), 4, mpys, rest); suffix = ".vb.vb"; } @@ -631,7 +654,7 @@ class OptimizePatterns : public IRMutator { new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern); } if (rest.defined()) { - new_expr = Add::make(new_expr, rest); + new_expr = Add::make(new_expr, find_intrinsics(rest)); } return mutate(new_expr); } @@ -650,11 +673,11 @@ class OptimizePatterns : public IRMutator { // Try to find vector*scalar multiplies. if (op->type.bits() == 16) { - mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 2, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, UInt(8, lanes), Int(8), 2, mpys, rest); vmpa_suffix = ".vub.vub.b.b"; vdmpy_suffix = ".vub.b"; } else if (op->type.bits() == 32) { - mpy_count = find_mpy_ops(op, Int(16, lanes), Int(8), 2, mpys, rest); + mpy_count = find_mpy_ops(op, op->type, Int(16, lanes), Int(8), 2, mpys, rest); vmpa_suffix = ".vh.vh.b.b"; vdmpy_suffix = ".vh.b"; } @@ -682,7 +705,7 @@ class OptimizePatterns : public IRMutator { new_expr = halide_hexagon_add_2mpy(op->type, vmpa_suffix, mpys[0].first, mpys[1].first, mpys[0].second, mpys[1].second); } if (rest.defined()) { - new_expr = Add::make(new_expr, rest); + new_expr = Add::make(new_expr, find_intrinsics(rest)); } return mutate(new_expr); } @@ -2271,6 +2294,9 @@ Stmt scatter_gather_generator(Stmt s) { } Stmt optimize_hexagon_instructions(Stmt s, const Target &t) { + debug(4) << "Hexagon: lowering before find_intrinsics\n" + << s << "\n"; + // We need to redo intrinsic matching due to simplification that has // happened after the end of target independent lowering. s = find_intrinsics(s); diff --git a/src/IRMatch.h b/src/IRMatch.h index 4f6dfb13c145..da5a300cfb01 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -2068,6 +2068,53 @@ HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp +struct WidenOp { + struct pattern_tag {}; + A a; + + constexpr static uint32_t binds = bindings::mask; + + constexpr static IRNodeType min_node_type = IRNodeType::Cast; + constexpr static IRNodeType max_node_type = IRNodeType::Cast; + constexpr static bool canonical = A::canonical; + + template + HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { + if (e.node_type != Cast::_node_type) { + return false; + } + const Cast &op = (const Cast &)e; + return (e.type == op.value.type().widen() && + a.template match(*op.value.get(), state)); + } + template + HALIDE_ALWAYS_INLINE bool match(const WidenOp &op, MatcherState &state) const noexcept { + return a.template match(unwrap(op.a), state); + } + + HALIDE_ALWAYS_INLINE + Expr make(MatcherState &state, halide_type_t type_hint) const { + Expr e = a.make(state, {}); + Type w = e.type().widen(); + return cast(w, std::move(e)); + } + + constexpr static bool foldable = false; +}; + +template +std::ostream &operator<<(std::ostream &s, const WidenOp &op) { + s << "widen(" << op.a << ")"; + return s; +} + +template +HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp { + assert_is_lvalue_if_expr(); + return {pattern_arg(a)}; +} + template struct SliceOp { struct pattern_tag {}; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 2011fdfa06bf..9be9fb55396f 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -6,12 +6,14 @@ #include #include "CSE.h" +#include "ConstantBounds.h" #include "Debug.h" #include "Func.h" #include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" #include "IRPrinter.h" +#include "Interval.h" #include "Util.h" #include "Var.h" @@ -434,141 +436,152 @@ Expr const_false(int w) { return make_zero(UInt(1, w)); } -Expr lossless_cast(Type t, Expr e) { +Expr lossless_cast(Type t, Expr e, std::map *cache) { if (!e.defined() || t == e.type()) { return e; } else if (t.can_represent(e.type())) { return cast(t, std::move(e)); - } - - if (const Cast *c = e.as()) { + } else if (const Cast *c = e.as()) { if (c->type.can_represent(c->value.type())) { - // We can recurse into widening casts. - return lossless_cast(t, c->value); - } else { - return Expr(); + return lossless_cast(t, c->value, cache); } - } - - if (const Broadcast *b = e.as()) { - Expr v = lossless_cast(t.element_of(), b->value); + } else if (const Broadcast *b = e.as()) { + Expr v = lossless_cast(t.element_of(), b->value, cache); if (v.defined()) { return Broadcast::make(v, b->lanes); - } else { - return Expr(); } - } - - if (const IntImm *i = e.as()) { + } else if (const IntImm *i = e.as()) { if (t.can_represent(i->value)) { return make_const(t, i->value); - } else { - return Expr(); } - } - - if (const UIntImm *i = e.as()) { + } else if (const UIntImm *i = e.as()) { if (t.can_represent(i->value)) { return make_const(t, i->value); - } else { - return Expr(); } - } - - if (const FloatImm *f = e.as()) { + } else if (const FloatImm *f = e.as()) { if (t.can_represent(f->value)) { return make_const(t, f->value); - } else { - return Expr(); - } - } - - if (t.is_int_or_uint() && t.bits() >= 16) { - if (const Add *add = e.as()) { - // If we can losslessly narrow the args even more - // aggressively, we're good. - // E.g. lossless_cast(uint16, (uint32)(some_u8) + 37) - // = (uint16)(some_u8) + 37 - Expr a = lossless_cast(t.narrow(), add->a); - Expr b = lossless_cast(t.narrow(), add->b); - if (a.defined() && b.defined()) { - return cast(t, a) + cast(t, b); - } else { - return Expr(); - } } - - if (const Sub *sub = e.as()) { - Expr a = lossless_cast(t.narrow(), sub->a); - Expr b = lossless_cast(t.narrow(), sub->b); - if (a.defined() && b.defined()) { - return cast(t, a) - cast(t, b); - } else { - return Expr(); - } - } - - if (const Mul *mul = e.as()) { - Expr a = lossless_cast(t.narrow(), mul->a); - Expr b = lossless_cast(t.narrow(), mul->b); - if (a.defined() && b.defined()) { - return cast(t, a) * cast(t, b); - } else { + } else if (const Shuffle *shuf = e.as()) { + std::vector vecs; + for (const auto &vec : shuf->vectors) { + vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, cache)); + if (!vecs.back().defined()) { return Expr(); } } - - if (const VectorReduce *reduce = e.as()) { - const int factor = reduce->value.type().lanes() / reduce->type.lanes(); - switch (reduce->op) { - case VectorReduce::Add: - // A horizontal add requires one extra bit per factor - // of two in the reduction factor. E.g. a reduction of - // 8 vector lanes down to 2 requires 2 extra bits in - // the output. We only deal with power-of-two types - // though, so just make sure the reduction factor - // isn't so large that it will more than double the - // number of bits required. - if (factor < (1 << (t.bits() / 2))) { - Type narrower = reduce->value.type().with_bits(t.bits() / 2); - Expr val = lossless_cast(narrower, reduce->value); - if (val.defined()) { - val = cast(narrower.with_bits(t.bits()), val); - return VectorReduce::make(reduce->op, val, reduce->type.lanes()); + return Shuffle::make(vecs, shuf->indices); + } else if (t.is_int_or_uint()) { + // Check the bounds. If they're small enough, we can throw narrowing + // casts around e, or subterms. + ConstantInterval ci = constant_integer_bounds(e, Scope::empty_scope(), cache); + + if (t.can_represent(ci)) { + // There are certain IR nodes where if the result is expressible + // using some type, and the args are expressible using that type, + // then the operation can just be done in that type. + if (const Add *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + return Add::make(a, b); + } + } else if (const Sub *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + return Sub::make(a, b); + } + } else if (const Mul *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + return Mul::make(a, b); + } + } else if (const Min *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + debug(0) << a << " " << b << "\n"; + return Min::make(a, b); + } + } else if (const Max *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + return Max::make(a, b); + } + } else if (const Mod *op = e.as()) { + Expr a = lossless_cast(t, op->a, cache); + Expr b = lossless_cast(t, op->b, cache); + if (a.defined() && b.defined()) { + return Mod::make(a, b); + } + } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add, Call::widen_right_add})) { + Expr a = lossless_cast(t, op->args[0], cache); + Expr b = lossless_cast(t, op->args[1], cache); + if (a.defined() && b.defined()) { + return Add::make(a, b); + } + } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_sub, Call::widen_right_sub})) { + Expr a = lossless_cast(t, op->args[0], cache); + Expr b = lossless_cast(t, op->args[1], cache); + if (a.defined() && b.defined()) { + return Sub::make(a, b); + } + } else if (const Call *op = Call::as_intrinsic(e, {Call::widening_mul, Call::widen_right_mul})) { + Expr a = lossless_cast(t, op->args[0], cache); + Expr b = lossless_cast(t, op->args[1], cache); + if (a.defined() && b.defined()) { + return Mul::make(a, b); + } + } else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left, + Call::shift_right, Call::widening_shift_right})) { + Expr a = lossless_cast(t, op->args[0], cache); + Expr b = lossless_cast(t, op->args[1], cache); + if (a.defined() && b.defined()) { + ConstantInterval cb = constant_integer_bounds(b, Scope::empty_scope(), cache); + if (cb > -t.bits() && cb < t.bits()) { + if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) { + return a << b; + } else if (op->is_intrinsic({Call::shift_right, Call::widening_shift_right})) { + return a >> b; + } } } - break; - case VectorReduce::Max: - case VectorReduce::Min: { - Expr val = lossless_cast(t, reduce->value); - if (val.defined()) { - return VectorReduce::make(reduce->op, val, reduce->type.lanes()); + } else if (const VectorReduce *op = e.as()) { + if (op->op == VectorReduce::Add || + op->op == VectorReduce::Min || + op->op == VectorReduce::Max) { + Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, cache); + if (v.defined()) { + return VectorReduce::make(op->op, v, op->type.lanes()); + } } - break; } - default: - break; - } - } - } - if (const Shuffle *shuf = e.as()) { - std::vector vecs; - for (const auto &vec : shuf->vectors) { - vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec)); - if (!vecs.back().defined()) { - return Expr(); + // At this point we know the expression fits in the target type, but + // what we really want is for the expression to be computed in the + // target type. So we can add a cast to the target type if we want + // here, but it only makes sense to do it if the expression type has + // the same or fewer bits than the target type. + if (e.type().bits() <= t.bits()) { + return cast(t, e); } } - return Shuffle::make(vecs, shuf->indices); } return Expr(); } Expr lossless_negate(const Expr &x) { - if (false /* const Mul *m = x.as() */) { // disabled pending #8155 - /* + if (const Mul *m = x.as()) { + // Check the terms can't multiply to produce the most negative value. + if (x.type().is_int() && + !x.type().can_represent(-constant_integer_bounds(x))) { + return Expr(); + } + Expr b = lossless_negate(m->b); if (b.defined()) { return Mul::make(m->a, b); @@ -577,7 +590,7 @@ Expr lossless_negate(const Expr &x) { if (a.defined()) { return Mul::make(a, m->b); } - */ + } else if (const Call *m = Call::as_intrinsic(x, {Call::widening_mul})) { Expr b = lossless_negate(m->args[1]); if (b.defined()) { @@ -596,8 +609,7 @@ Expr lossless_negate(const Expr &x) { } else if (const Cast *c = x.as()) { Expr value = lossless_negate(c->value); if (value.defined()) { - // This works for constants, but not other things that - // could possibly be negated. + // This logic is only sound if we know the cast can't overflow. value = lossless_cast(c->type, value); if (value.defined()) { return value; diff --git a/src/IROperator.h b/src/IROperator.h index a96ef6223c0d..c84a4682152f 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -8,6 +8,7 @@ */ #include +#include #include "Expr.h" #include "Tuple.h" @@ -140,10 +141,16 @@ Expr const_true(int lanes = 1); * falses, if a lanes argument is given. */ Expr const_false(int lanes = 1); -/** Attempt to cast an expression to a smaller type while provably not - * losing information. If it can't be done, return an undefined - * Expr. */ -Expr lossless_cast(Type t, Expr e); +/** Attempt to cast an expression to a smaller type while provably not losing + * information. If it can't be done, return an undefined Expr. + * + * Optionally accepts a map that gives the constant bounds of exprs already + * analyzed to avoid redoing work across many calls to lossless_cast. It is not + * safe to use this optional map in contexts where the same Expr object may + * take on a different value. For example: + * (let x = 4 in some_expr_object) + (let x = 5 in the_same_expr_object)). + * It is safe to use it after uniquify_variable_names has been run. */ +Expr lossless_cast(Type t, Expr e, std::map *cache = nullptr); /** Attempt to negate x without introducing new IR and without overflow. * If it can't be done, return an undefined Expr. */ diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp index 692abc0db7d4..22d3506d7859 100644 --- a/test/correctness/lossless_cast.cpp +++ b/test/correctness/lossless_cast.cpp @@ -81,11 +81,7 @@ int lossless_cast_test() { e = cast(i64, 1024) * cast(i64, 1024) * cast(i64, 1024); res |= check_lossless_cast(i32, e, (cast(i32, 1024) * 1024) * 1024); - if (res) { - std::cout << "Ignoring bugs in lossless_cast for now. Will be fixed in #8155\n"; - } - return 0; - // return res; + return res; } constexpr int size = 1024; @@ -235,6 +231,25 @@ Expr random_expr(std::mt19937 &rng) { } } +bool definitely_has_ub(Expr e) { + e = simplify(e); + + class HasOverflow : public IRVisitor { + void visit(const Call *op) override { + if (op->is_intrinsic({Call::signed_integer_overflow})) { + found = true; + } + IRVisitor::visit(op); + } + + public: + bool found = false; + } has_overflow; + e.accept(&has_overflow); + + return has_overflow.found; +} + bool might_have_ub(Expr e) { class MightOverflow : public IRVisitor { std::map cache; @@ -331,8 +346,11 @@ int test_one(uint32_t seed) { buf_i8.fill(rng); Expr e1 = random_expr(rng); + Expr simplified = simplify(e1); - if (might_have_ub(e1)) { + if (might_have_ub(e1) || + might_have_ub(simplified) || + might_have_ub(lower_intrinsics(simplified))) { return 0; } @@ -348,12 +366,26 @@ int test_one(uint32_t seed) { return 0; } + if (definitely_has_ub(e2)) { + std::cout << "lossless_cast introduced ub:\n" + << "seed = " << seed << "\n" + << "e1 = " << e1 << "\n" + << "e2 = " << e2 << "\n" + << "simplify(e1) = " << simplify(e1) << "\n" + << "simplify(e2) = " << simplify(e2) << "\n"; + return 1; + } + Func f; f(x) = {cast(e1), cast(e2)}; f.vectorize(x, 4, TailStrategy::RoundUp); Buffer out1(size), out2(size); Pipeline p(f); + + // Check for signed integer overflow + // Module m = p.compile_to_module({}, "test"); + p.realize({out1, out2}); for (int x = 0; x < size; x++) { @@ -367,12 +399,8 @@ int test_one(uint32_t seed) { << "out1 = " << out1(x) << "\n" << "out2 = " << out2(x) << "\n" << "Original: " << e1 << "\n" - << "Lossless cast: " << e2 << "\n" - << "Ignoring bug for now. Will be fixed in #8155\n"; - // If lossless_cast has failed on this Expr, it's possible the test - // below will fail as well. - return 0; - // return 1; + << "Lossless cast: " << e2 << "\n"; + return 1; } } @@ -405,7 +433,9 @@ int fuzz_test(uint32_t root_seed) { std::cout << "Fuzz testing with root seed " << root_seed << "\n"; for (int i = 0; i < 1000; i++) { - if (test_one(seed_generator())) { + auto s = seed_generator(); + std::cout << s << "\n"; + if (test_one(s)) { return 1; } } diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp index 3ebf5071569e..50eca65d34f4 100644 --- a/test/correctness/simd_op_check_arm.cpp +++ b/test/correctness/simd_op_check_arm.cpp @@ -561,7 +561,7 @@ class SimdOpCheckARM : public SimdOpCheckTest { // use the forms with an accumulator check(arm32 ? "vpadal.s8" : "sadalp", 16, sum_(i16(in_i8(f * x + r)))); check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(i16(in_u8(f * x + r)))); - check(arm32 ? "vpadal.u8" : "uadalp*", 16, sum_(u16(in_u8(f * x + r)))); + check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(u16(in_u8(f * x + r)))); check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i16(f * x + r)))); check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u16(f * x + r)))); @@ -595,17 +595,10 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(i32(in_u8(f * x + r)))); check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(u32(in_u8(f * x + r)))); if (!arm32) { - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12); - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4)); - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) + i32(i8_4) * 12); - check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) + i32(i8_3) * 9 + i32(i8_4) * 12); - check("sdot", 8, i32_1 + i32(i8_1) + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12); - - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12); - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4)); - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) + u32(u8_4) * 12); - check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) + u32(u8_3) * 9 + u32(u8_4) * 12); - check("udot", 8, u32_1 + u32(u8_1) + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12); + check("udot", 8, u32(u8_1) * 200 + u32(u8_2) * 201 + u32(u8_3) * 202 + u32(u8_4) * 203); + // For signed, mapping the pattern above to sdot + // is a wash, because we can add more products + // of i8s together before they overflow an i16. } } else { check(arm32 ? "vpaddl.s8" : "saddlp", 8, sum_(i32(in_i8(f * x + r)))); @@ -621,15 +614,15 @@ class SimdOpCheckARM : public SimdOpCheckTest { // signed, because the intermediate type is u16 if (target.has_feature(Target::ARMDotProd)) { check(arm32 ? "vpadal.s16" : "sdot", 8, sum_(i32(in_i8(f * x + r)))); - check(arm32 ? "vpadal.u16" : "udot", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.s16" : "udot", 8, sum_(i32(in_u8(f * x + r)))); check(arm32 ? "vpadal.u16" : "udot", 8, sum_(u32(in_u8(f * x + r)))); } else { check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i8(f * x + r)))); - check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_u8(f * x + r)))); check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(u32(in_u8(f * x + r)))); } check(arm32 ? "vpadal.s32" : "sadalp", 4, sum_(i64(in_i16(f * x + r)))); - check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(i64(in_u16(f * x + r)))); + check(arm32 ? "vpadal.s32" : "sadalp", 4, sum_(i64(in_u16(f * x + r)))); check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(u64(in_u16(f * x + r)))); } diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index 8286bc68f9e6..4a81dfbdf926 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -253,6 +253,17 @@ class SimdOpCheckX86 : public SimdOpCheckTest { for (int w = 2; w <= 4; w++) { check("pmulhrsw", 4 * w, i16((i32(i16_1) * i32(i16_2) + 16384) >> 15)); check("pmulhrsw", 4 * w, i16_sat((i32(i16_1) * i32(i16_2) + 16384) >> 15)); + // Should be able to use the non-saturating form of pmulhrsw, + // because the second arg can't be -32768, so the i16_sat + // doesn't actually need to saturate. + check("pmulhrsw", 4 * w, i16_sat((i32(i16_1) * i32(i16_2 / 2) + 16384) >> 15)); + + // Should be able to use pmulhrsw despite the shift being too + // small, because there are enough bits of headroom to shift + // left one of the args: + check("pmulhrsw", 4 * w, i16_sat((i32(i16_1) * i32(i16_2 / 2) + 8192) >> 14)); + check("pmulhrsw", 4 * w, i16((i32(i16_1) * i32(i16_2 / 3) + 8192) >> 14)); + check("pabsb", 8 * w, abs(i8_1)); check("pabsw", 4 * w, abs(i16_1)); check("pabsd", 2 * w, abs(i32_1)); diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index be1421d5e11c..264471294380 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -1313,6 +1313,13 @@ void check_bounds() { check(max(x * 4 + 63, y) - max(y - 3, x * 4), clamp(x * 4 - y, -63, -3) + 66); check(max(x * 4, y - 3) - max(x * 4 + 63, y), clamp(y - x * 4, 3, 63) + -66); check(max(y - 3, x * 4) - max(x * 4 + 63, y), clamp(y - x * 4, 3, 63) + -66); + + // Check we can track bounds correctly through various operations + check(ramp(cast(x) / 2 + 3, cast(1), 16) < broadcast(200, 16), const_true(16)); + check(cast(cast(x)) * 3 >= cast(0), const_true()); + check(cast(cast(x)) * 3 < cast(768), const_true()); + check(cast(abs(cast(x))) >= cast(0), const_true()); + check(cast(abs(cast(x))) - cast(128) <= cast(0), const_true()); } void check_boolean() {