Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework the simplifier to use ConstantInterval for bounds #8222

Merged
merged 21 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
2741cde
Fix saturating add matching in associativity checking
abadams May 14, 2024
375c73b
Fix give-up case in ModulusRemainder
abadams May 14, 2024
3c52ab6
Update the simplifier to use ConstantInterval
abadams May 14, 2024
456da0c
Merge remote-tracking branch 'origin/main' into abadams/constant_inte…
abadams May 14, 2024
743621e
Merge remote-tracking branch 'origin/abadams/fix_modulus_remainder_ov…
abadams May 14, 2024
8ea558b
Merge remote-tracking branch 'origin/abadams/fix_saturating_add_assoc…
abadams May 14, 2024
c2ea452
Move the simplify fuzzer back to a correctness test
abadams May 15, 2024
40e8b07
Make debug_indent not static
abadams May 15, 2024
19e7ec7
Track expr info on non-overflowing casts to int
abadams May 15, 2024
12c3830
Delete commented-out code
abadams May 15, 2024
4b245f1
clang-tidy
abadams May 15, 2024
b1ee0a3
Delete unused member
abadams May 15, 2024
40f3cd1
Merge remote-tracking branch 'origin/main' into abadams/constant_inte…
abadams May 15, 2024
459e8bd
Fix cmakelists for the fuzzer removal
abadams May 22, 2024
d78dcec
Handle contradictions more gracefully in learn_true
abadams May 23, 2024
2f14594
Merge remote-tracking branch 'origin/main' into abadams/constant_inte…
abadams May 24, 2024
0f51b87
Better comments
abadams May 27, 2024
405c8ce
Address review comments
abadams Jun 1, 2024
e42dfbe
Merge remote-tracking branch 'origin/main' into abadams/constant_inte…
abadams Jun 1, 2024
883a9da
Fix failure to pop loop var info
abadams Jun 1, 2024
b413bba
Merge remote-tracking branch 'origin/main' into abadams/constant_inte…
abadams Jun 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 33 additions & 37 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,29 @@ using std::pair;
using std::string;
using std::vector;

#if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
int Simplify::debug_indent = 0;
#endif

Simplify::Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai)
: remove_dead_code(r) {

// Only respect the constant bounds from the containing scope.
for (auto iter = bi->cbegin(); iter != bi->cend(); ++iter) {
ExprInfo bounds;
ExprInfo info;
if (const int64_t *i_min = as_const_int(iter.value().min)) {
bounds.min_defined = true;
bounds.min = *i_min;
info.bounds.min_defined = true;
info.bounds.min = *i_min;
}
if (const int64_t *i_max = as_const_int(iter.value().max)) {
bounds.max_defined = true;
bounds.max = *i_max;
info.bounds.max_defined = true;
info.bounds.max = *i_max;
}

if (const auto *a = ai->find(iter.name())) {
bounds.alignment = *a;
info.alignment = *a;
}

if (bounds.min_defined || bounds.max_defined || bounds.alignment.modulus != 1) {
bounds_and_alignment_info.push(iter.name(), bounds);
if (info.bounds.min_defined ||
info.bounds.max_defined ||
info.alignment.modulus != 1) {
bounds_and_alignment_info.push(iter.name(), info);
}
}

Expand All @@ -48,20 +46,20 @@ Simplify::Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemaind
// Already handled
continue;
}
ExprInfo bounds;
bounds.alignment = iter.value();
bounds_and_alignment_info.push(iter.name(), bounds);
ExprInfo info;
info.alignment = iter.value();
bounds_and_alignment_info.push(iter.name(), info);
}
}

std::pair<std::vector<Expr>, bool> Simplify::mutate_with_changes(const std::vector<Expr> &old_exprs, ExprInfo *bounds) {
std::pair<std::vector<Expr>, bool> Simplify::mutate_with_changes(const std::vector<Expr> &old_exprs) {
vector<Expr> new_exprs(old_exprs.size());
bool changed = false;

// Mutate the args
for (size_t i = 0; i < old_exprs.size(); i++) {
const Expr &old_e = old_exprs[i];
Expr new_e = mutate(old_e, bounds);
Expr new_e = mutate(old_e, nullptr);
if (!new_e.same_as(old_e)) {
changed = true;
}
Expand Down Expand Up @@ -135,35 +133,35 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) {
Simplify::ExprInfo i;
if (v) {
simplify->mutate(lt->b, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// !(v < i)
learn_lower_bound(v, i.min);
learn_lower_bound(v, i.bounds.min);
}
}
v = lt->b.as<Variable>();
if (v) {
simplify->mutate(lt->a, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// !(i < v)
learn_upper_bound(v, i.max);
learn_upper_bound(v, i.bounds.max);
}
}
} else if (const LE *le = fact.as<LE>()) {
const Variable *v = le->a.as<Variable>();
Simplify::ExprInfo i;
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(le->b, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// !(v <= i)
learn_lower_bound(v, i.min + 1);
learn_lower_bound(v, i.bounds.min + 1);
}
}
v = le->b.as<Variable>();
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(le->a, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// !(i <= v)
learn_upper_bound(v, i.max - 1);
learn_upper_bound(v, i.bounds.max - 1);
}
}
} else if (const Call *c = Call::as_tag(fact)) {
Expand All @@ -185,8 +183,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) {

void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) {
ExprInfo b;
b.max_defined = true;
b.max = val;
b.bounds = ConstantInterval::bounded_above(val);
if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) {
b.intersect(*info);
}
Expand All @@ -196,8 +193,7 @@ void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) {

void Simplify::ScopedFact::learn_lower_bound(const Variable *v, int64_t val) {
ExprInfo b;
b.min_defined = true;
b.min = val;
b.bounds = ConstantInterval::bounded_below(val);
if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) {
b.intersect(*info);
}
Expand Down Expand Up @@ -267,35 +263,35 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) {
Simplify::ExprInfo i;
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(lt->b, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// v < i
learn_upper_bound(v, i.max - 1);
learn_upper_bound(v, i.bounds.max - 1);
}
}
v = lt->b.as<Variable>();
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(lt->a, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// i < v
learn_lower_bound(v, i.min + 1);
learn_lower_bound(v, i.bounds.min + 1);
}
}
} else if (const LE *le = fact.as<LE>()) {
const Variable *v = le->a.as<Variable>();
Simplify::ExprInfo i;
if (v) {
simplify->mutate(le->b, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// v <= i
learn_upper_bound(v, i.max);
learn_upper_bound(v, i.bounds.max);
}
}
v = le->b.as<Variable>();
if (v) {
simplify->mutate(le->a, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// i <= v
learn_lower_bound(v, i.min);
learn_lower_bound(v, i.bounds.min);
}
}
} else if (const Call *c = Call::as_tag(fact)) {
Expand Down
6 changes: 4 additions & 2 deletions src/Simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ namespace Internal {
* Exprs that should be assumed to be true.
*/
// @{
Stmt simplify(const Stmt &, bool remove_dead_code = true,
Stmt simplify(const Stmt &,
bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
Expr simplify(const Expr &, bool remove_dead_code = true,
Expr simplify(const Expr &,
bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
Expand Down
28 changes: 12 additions & 16 deletions src/Simplify_Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,24 @@
namespace Halide {
namespace Internal {

Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
ExprInfo a_bounds, b_bounds;
Expr a = mutate(op->a, &a_bounds);
Expr b = mutate(op->b, &b_bounds);

if (bounds && no_overflow_int(op->type)) {
bounds->min_defined = a_bounds.min_defined &&
b_bounds.min_defined &&
add_with_overflow(64, a_bounds.min, b_bounds.min, &(bounds->min));
bounds->max_defined = a_bounds.max_defined &&
b_bounds.max_defined &&
add_with_overflow(64, a_bounds.max, b_bounds.max, &(bounds->max));
bounds->alignment = a_bounds.alignment + b_bounds.alignment;
bounds->trim_bounds_using_alignment();
Expr Simplify::visit(const Add *op, ExprInfo *info) {
ExprInfo a_info, b_info;
Expr a = mutate(op->a, &a_info);
Expr b = mutate(op->b, &b_info);

if (info) {
info->bounds = a_info.bounds + b_info.bounds;
info->alignment = a_info.alignment + b_info.alignment;
info->trim_bounds_using_alignment();
info->cast_to(op->type);
}

if (may_simplify(op->type)) {

// Order commutative operations by node type
if (should_commute(a, b)) {
std::swap(a, b);
std::swap(a_bounds, b_bounds);
std::swap(a_info, b_info);
}

auto rewrite = IRMatcher::rewriter(IRMatcher::add(a, b), op->type);
Expand Down Expand Up @@ -194,7 +190,7 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
rewrite(x + (y + (c0 - x)/c1)*c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) ||

false)))) {
return mutate(rewrite.result, bounds);
return mutate(rewrite.result, info);
}
// clang-format on
}
Expand Down
4 changes: 2 additions & 2 deletions src/Simplify_And.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace Halide {
namespace Internal {

Expr Simplify::visit(const And *op, ExprInfo *bounds) {
Expr Simplify::visit(const And *op, ExprInfo *info) {
if (falsehoods.count(op)) {
return const_false(op->type.lanes());
}
Expand Down Expand Up @@ -109,7 +109,7 @@ Expr Simplify::visit(const And *op, ExprInfo *bounds) {
rewrite(x <= y && x <= z, x <= min(y, z)) ||
rewrite(y <= x && z <= x, max(y, z) <= x)) {

return mutate(rewrite.result, bounds);
return mutate(rewrite.result, info);
}

if (a.same_as(op->a) &&
Expand Down
Loading
Loading