Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strip asserts right at the end of lowering #8094

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ SOURCE_FILES = \
StorageFlattening.cpp \
StorageFolding.cpp \
StrictifyFloat.cpp \
StripAsserts.cpp \
Substitute.cpp \
Target.cpp \
Tracing.cpp \
Expand Down Expand Up @@ -785,6 +786,7 @@ HEADER_FILES = \
StorageFlattening.h \
StorageFolding.h \
StrictifyFloat.h \
StripAsserts.h \
Substitute.h \
Target.h \
Tracing.h \
Expand Down
39 changes: 13 additions & 26 deletions src/AddImageChecks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ Stmt add_image_checks_inner(Stmt s,
const FuncValueBounds &fb,
bool will_inject_host_copies) {

bool no_asserts = t.has_feature(Target::NoAsserts);
bool no_bounds_query = t.has_feature(Target::NoBoundsQuery);

// First hunt for all the referenced buffers
Expand Down Expand Up @@ -618,12 +617,9 @@ Stmt add_image_checks_inner(Stmt s,
replace_with_constrained[name] = constrained_var;
}

Expr error = 0;
if (!no_asserts) {
error = Call::make(Int(32), "halide_error_constraint_violated",
{name, var, constrained_var_str, constrained_var},
Call::Extern);
}
Expr error = Call::make(Int(32), "halide_error_constraint_violated",
{name, var, constrained_var_str, constrained_var},
Call::Extern);

// Check the var passed in equals the constrained version (when not in inference mode)
asserts_constrained.push_back(AssertStmt::make(var == constrained_var, error));
Expand Down Expand Up @@ -679,14 +675,12 @@ Stmt add_image_checks_inner(Stmt s,
}
};

if (!no_asserts) {
// Inject the code that checks the host pointers.
prepend_stmts(&asserts_host_non_null);
prepend_stmts(&asserts_host_alignment);
prepend_stmts(&asserts_device_not_dirty);
prepend_stmts(&dims_no_overflow_asserts);
prepend_lets(&lets_overflow);
}
// Inject the code that checks the host pointers.
prepend_stmts(&asserts_host_non_null);
prepend_stmts(&asserts_host_alignment);
prepend_stmts(&asserts_device_not_dirty);
prepend_stmts(&dims_no_overflow_asserts);
prepend_lets(&lets_overflow);

// Replace uses of the var with the constrained versions in the
// rest of the program. We also need to respect the existence of
Expand All @@ -698,25 +692,18 @@ Stmt add_image_checks_inner(Stmt s,
// all in reverse order compared to execution, as we incrementally
// prepending code.

// Inject the code that checks the constraints are correct. We
// need these regardless of how NoAsserts is set, because they are
// what gets Halide to actually exploit the constraint.
// Inject the code that checks the constraints are correct.
prepend_stmts(&asserts_constrained);

if (!no_asserts) {
prepend_stmts(&asserts_required);
prepend_stmts(&asserts_type_checks);
}
prepend_stmts(&asserts_required);
prepend_stmts(&asserts_type_checks);

// Inject the code that returns early for inference mode.
if (!no_bounds_query) {
s = IfThenElse::make(!maybe_return_condition, s);
prepend_stmts(&buffer_rewrites);
}

if (!no_asserts) {
prepend_stmts(&asserts_proposed);
}
prepend_stmts(&asserts_proposed);

// Inject the code that defines the proposed sizes.
prepend_lets(&lets_proposed);
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ set(HEADER_FILES
StorageFlattening.h
StorageFolding.h
StrictifyFloat.h
StripAsserts.h
Substitute.h
Target.h
Tracing.h
Expand Down Expand Up @@ -340,6 +341,7 @@ set(SOURCE_FILES
StorageFlattening.cpp
StorageFolding.cpp
StrictifyFloat.cpp
StripAsserts.cpp
Substitute.cpp
Target.cpp
Tracing.cpp
Expand Down
7 changes: 7 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
#include "StorageFlattening.h"
#include "StorageFolding.h"
#include "StrictifyFloat.h"
#include "StripAsserts.h"
#include "Substitute.h"
#include "Tracing.h"
#include "TrimNoOps.h"
Expand Down Expand Up @@ -427,6 +428,12 @@ void lower_impl(const vector<Function> &output_funcs,
s = hoist_prefetches(s);
log("Lowering after hoisting prefetches:", s);

if (t.has_feature(Target::NoAsserts)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth adding comment from the PR about why we do this at the end

debug(1) << "Stripping asserts...\n";
s = strip_asserts(s);
log("Lowering after stripping asserts:", s);
}

debug(1) << "Lowering after final simplification:\n"
<< s << "\n\n";

Expand Down
6 changes: 1 addition & 5 deletions src/ScheduleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1368,11 +1368,7 @@ class InjectFunctionRealization : public IRMutator {

// This is also the point at which we inject explicit bounds
// for this realization.
if (target.has_feature(Target::NoAsserts)) {
return s;
} else {
return inject_explicit_bounds(s, func);
}
return inject_explicit_bounds(s, func);
}

Stmt build_realize_function_from_group(Stmt s, int func_index) {
Expand Down
121 changes: 121 additions & 0 deletions src/StripAsserts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#include "StripAsserts.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include <set>

namespace Halide {
namespace Internal {

namespace {

bool may_discard(const Expr &e) {
class MayDiscard : public IRVisitor {
using IRVisitor::visit;

void visit(const Call *op) override {
// Extern calls that are side-effecty in the sense that you can't
// move them around in the IR, but we're free to discard because
// they're just getters.
static const std::set<std::string> discardable{
Call::buffer_get_dimensions,
Call::buffer_get_min,
Call::buffer_get_extent,
Call::buffer_get_stride,
Call::buffer_get_max,
Call::buffer_get_host,
Call::buffer_get_device,
Call::buffer_get_device_interface,
Call::buffer_get_shape,
Call::buffer_get_host_dirty,
Call::buffer_get_device_dirty,
Call::buffer_get_type};

if (!(op->is_pure() ||
discardable.count(op->name))) {
result = false;
}
}

public:
bool result = true;
} d;
e.accept(&d);

return d.result;
}

class StripAsserts : public IRMutator {
using IRMutator::visit;

// We're going to track which symbols are used so that we can strip lets we
// don't need after removing the asserts.
std::set<std::string> used;

// Drop all assert stmts. Assumes that you don't want any side-effects from
// the condition.
Stmt visit(const AssertStmt *op) override {
return Evaluate::make(0);
}

Expr visit(const Variable *op) override {
used.insert(op->name);
return op;
}

Expr visit(const Load *op) override {
used.insert(op->name);
return IRMutator::visit(op);
}

Stmt visit(const Store *op) override {
used.insert(op->name);
return IRMutator::visit(op);
}

// Also dead-code eliminate any let stmts wrapped around asserts
Stmt visit(const LetStmt *op) override {
Stmt body = mutate(op->body);
if (is_no_op(body)) {
if (may_discard(op->value)) {
return body;
} else {
// We visit the value just to keep the used variable set
// accurate.
mutate(op->value);
return Evaluate::make(op->value);
}
} else if (body.same_as(op->body)) {
mutate(op->value);
return op;
} else if (may_discard(op->value) && !used.count(op->name)) {
return body;
} else {
mutate(op->value);
return LetStmt::make(op->name, op->value, body);
}
}

Stmt visit(const Block *op) override {
Stmt first = mutate(op->first);
Stmt rest = mutate(op->rest);
if (first.same_as(op->first) && rest.same_as(op->rest)) {
return op;
} else if (is_no_op(rest)) {
return first;
} else if (is_no_op(first)) {
return rest;
} else {
return Block::make(first, rest);
}
}
};

} // namespace

Stmt strip_asserts(const Stmt &s) {
return StripAsserts().mutate(s);
}

} // namespace Internal
} // namespace Halide
18 changes: 18 additions & 0 deletions src/StripAsserts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef HALIDE_STRIP_ASSERTS_H
#define HALIDE_STRIP_ASSERTS_H

/** \file
* Defines the lowering pass that strips asserts when NoAsserts is set.
*/

#include "Expr.h"

namespace Halide {
namespace Internal {

Stmt strip_asserts(const Stmt &s);

} // namespace Internal
} // namespace Halide

#endif
Loading