Fix horrifying bug in lossless_cast of a subtract (#8155)
* Fix horrifying bug in lossless_cast of a subtract

* Use constant integer intervals to analyze safety for lossless_cast

TODO:

- Dedup the constant integer code with the same code in the simplifier.
- Move constant interval arithmetic operations out of the class.
- Make the ConstantInterval part of the return type of lossless_cast
(and turn it into an inner helper) so that it isn't constantly
recomputed.
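
As a rough sketch of the interval-based safety analysis described above (illustrative only — the names and structure here are stand-ins, not the actual Halide ConstantInterval API): a subtract of two values that each fit in a narrow type need not fit in that type itself, which is exactly the kind of fact a constant interval per subexpression makes checkable.

// Illustrative sketch, not Halide code: track a constant integer interval per
// subexpression and only allow a cast that the interval proves lossless.
#include <cstdint>

struct Interval {
    int64_t min, max;  // inclusive; assumed finite and non-overflowing here
};

// Interval subtraction: the difference ranges over [a.min - b.max, a.max - b.min].
Interval operator-(const Interval &a, const Interval &b) {
    return {a.min - b.max, a.max - b.min};
}

bool fits_in_uint8(const Interval &i) {
    return i.min >= 0 && i.max <= 255;
}

// Example: a and b each fit in uint8, but a - b ranges over [-255, 255],
// so casting the subtract down to uint8 is not lossless.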

* Fix ARM and HVX instruction selection

Also added more TODOs

* Using constant_integer_bounds to strengthen FindIntrinsics

In particular, we can do better instruction selection for pmulhrsw
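
For context, a scalar model of what one pmulhrsw lane computes (paraphrased semantics, assuming arithmetic right shifts of negative values; not Halide or intrinsic code): the result is roughly round((a * b) / 2^15), and the only input pair that overflows int16 is a == b == -32768, so constant-integer bounds that exclude -32768 from either operand make the instruction safe to emit directly, as the CodeGen_X86.cpp hunk below does.

// Scalar model of one pmulhrsw lane (illustrative; assumes C++20 / arithmetic
// right shift of negative values).
#include <cstdint>

int16_t pmulhrsw_lane(int16_t a, int16_t b) {
    int32_t prod = int32_t(a) * int32_t(b);     // |prod| <= 2^30, fits in int32
    int32_t rounded = ((prod >> 14) + 1) >> 1;  // round(prod / 2^15)
    // rounded lies in [-32767, 32768]; 32768 occurs only for a == b == -32768
    // and wraps to -32768 when truncated back to int16.
    return int16_t(rounded);
}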

* Move new classes to new files

Also fix up Monotonic.cpp

* Make the simplifier use ConstantInterval

* Handle bounds of narrower types in the simplifier too

* Fix * operator. Add min/max/mod

* Add cache for constant bounds queries

* Fix ConstantInterval multiplication
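
For reference, the standard interval-arithmetic rule such a multiplication has to implement (a sketch assuming finite bounds and no int64 overflow in the corner products; the real class also has to cope with unbounded ends):

// Illustrative interval multiplication: x*y is monotonic in each argument once
// the sign of the other is fixed, so the product interval is bounded by the
// min and max of the four corner products.
#include <algorithm>
#include <cstdint>

struct Interval {
    int64_t min, max;  // inclusive, assumed finite
};

Interval multiply(const Interval &a, const Interval &b) {
    int64_t p1 = a.min * b.min, p2 = a.min * b.max;
    int64_t p3 = a.max * b.min, p4 = a.max * b.max;
    return {std::min({p1, p2, p3, p4}), std::max({p1, p2, p3, p4})};
}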

* Add a simplifier rule which is apparently now necessary

* Misc cleanups and test improvements

* Add missing files

* Account for more aggressive simplification in fuse test

* Remove redundant helpers

* Add missing comment

* clear_bounds_info -> clear_expr_info

* Remove bad TODO

I can't think of a single case that could cause this

* It's too late to change the semantics of fixed point intrinsics

* Fix some UB

* Stronger assert in Simplify_Div

* Delete bad rewrite rules

* Fix bad test when lowering mul_shift_right

b_shift + b_shift < missing_q

* Avoid UB in lowering of rounding_shift_right/left

* Add shifts to the lossless cast fuzzer

This required a more careful signed-integer-overflow detection routine
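
A minimal sketch of what such an overflow-detection routine can look like (assuming GCC/Clang builtins; this is illustrative, not the routine added by this commit):

// Overflow-checked 64-bit helpers of the kind a fuzzer's reference
// interpreter needs; they return nullopt instead of invoking UB.
#include <cstdint>
#include <optional>

std::optional<int64_t> checked_add(int64_t a, int64_t b) {
    int64_t r;
    if (__builtin_add_overflow(a, b, &r)) return std::nullopt;
    return r;
}

std::optional<int64_t> checked_shl(int64_t a, int64_t shift) {
    // Left-shifting a negative value or shifting by >= 63 is rejected
    // conservatively rather than risking UB or overflow.
    if (shift < 0 || shift >= 63 || a < 0) return std::nullopt;
    if (a > (INT64_MAX >> shift)) return std::nullopt;
    return a << shift;
}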

* Fix bug in lossless_negate

* Add constant interval test

* Rework find_mpy_ops to handle more structures

* Fix bugs in lossless_cast

* Fix mul_shift_right expansion

* Delete commented-out code

* Don't introduce out-of-range shifts in lossless_cast

* Some constant folding only happens after lowering intrinsics in codegen

---------

Co-authored-by: Steven Johnson <[email protected]>
abadams and steven-johnson authored Jun 26, 2024
1 parent 9b703f3 commit cab27d8
Showing 12 changed files with 514 additions and 392 deletions.
64 changes: 28 additions & 36 deletions src/CodeGen_ARM.cpp
@@ -1212,50 +1212,42 @@ void CodeGen_ARM::visit(const Add *op) {
Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc");
Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc");

// clang-format off
Expr ma_i8 = widening_mul(a_i8, ac_i8);
Expr mb_i8 = widening_mul(b_i8, bc_i8);
Expr mc_i8 = widening_mul(c_i8, cc_i8);
Expr md_i8 = widening_mul(d_i8, dc_i8);

Expr ma_u8 = widening_mul(a_u8, ac_u8);
Expr mb_u8 = widening_mul(b_u8, bc_u8);
Expr mc_u8 = widening_mul(c_u8, cc_u8);
Expr md_u8 = widening_mul(d_u8, dc_u8);

static const Pattern patterns[] = {
// If we had better normalization, we could drastically reduce the number of patterns here.
// Signed variants.
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
// Signed variants (associative).
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{(init_i32 + widening_add(ma_i8, mb_i8)) + widening_add(mc_i8, md_i8), "dot_product"},
{init_i32 + (widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8)), "dot_product"},
{widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8), "dot_product"},

// Unsigned variants.
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
// Unsigned variants (associative).
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{(init_u32 + widening_add(ma_u8, mb_u8)) + widening_add(mc_u8, md_u8), "dot_product"},
{init_u32 + (widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8)), "dot_product"},
{widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8), "dot_product"},
};
// clang-format on

std::map<std::string, Expr> matches;
for (const Pattern &p : patterns) {
if (expr_match(p.pattern, op, matches)) {
Expr init = matches["init"];
Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]});
// Coefficients can be 1 if not in the pattern.
Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes()));
// This hideous code pattern implements fetching a
// default value if the map doesn't contain a key.
Expr _ac = matches.try_emplace("ac", one).first->second;
Expr _bc = matches.try_emplace("bc", one).first->second;
Expr _cc = matches.try_emplace("cc", one).first->second;
Expr _dc = matches.try_emplace("dc", one).first->second;
Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc});
Expr init;
auto it = matches.find("init");
if (it == matches.end()) {
init = make_zero(op->type);
} else {
init = it->second;
}
Expr values = Shuffle::make_interleave({matches["a"], matches["b"],
matches["c"], matches["d"]});
Expr coeffs = Shuffle::make_interleave({matches["ac"], matches["bc"],
matches["cc"], matches["dc"]});
value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs});
if (value) {
return;
11 changes: 8 additions & 3 deletions src/CodeGen_X86.cpp
@@ -538,8 +538,8 @@ void CodeGen_X86::visit(const Cast *op) {
};

// clang-format off
static const Pattern patterns[] = {
// This isn't rounding_multiply_quantzied(i16, i16, 15) because it doesn't
static Pattern patterns[] = {
// This isn't rounding_mul_shift_right(i16, i16, 15) because it doesn't
// saturate the result.
{"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))},

@@ -736,7 +736,12 @@ void CodeGen_X86::visit(const Call *op) {
// Handle edge case of possible overflow.
// See https://github.com/halide/Halide/pull/7129/files#r1008331426
// On AVX512 (and with enough lanes) we can use a mask register.
if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
ConstantInterval ca = constant_integer_bounds(a);
ConstantInterval cb = constant_integer_bounds(b);
if (!ca.contains(-32768) || !cb.contains(-32768)) {
// Overflow isn't possible
pmulhrs.accept(this);
} else if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
Expr expr = select((a == i16_min) && (b == i16_min), i16_max, pmulhrs);
expr.accept(this);
} else {
4 changes: 2 additions & 2 deletions src/Expr.cpp
@@ -8,7 +8,7 @@ const IntImm *IntImm::make(Type t, int64_t value) {
internal_assert(t.is_int() && t.is_scalar())
<< "IntImm must be a scalar Int\n";
internal_assert(t.bits() >= 1 && t.bits() <= 64)
<< "IntImm must have between 1 and 64 bits\n";
<< "IntImm must have between 1 and 64 bits: " << t << "\n";

// Normalize the value by dropping the high bits.
// Since left-shift of negative value is UB in C++, cast to uint64 first;
@@ -28,7 +28,7 @@ const UIntImm *UIntImm::make(Type t, uint64_t value) {
internal_assert(t.is_uint() && t.is_scalar())
<< "UIntImm must be a scalar UInt\n";
internal_assert(t.bits() >= 1 && t.bits() <= 64)
<< "UIntImm must have between 1 and 64 bits\n";
<< "UIntImm must have between 1 and 64 bits " << t << "\n";

// Normalize the value by dropping the high bits
value <<= (64 - t.bits());