diff --git a/Makefile b/Makefile index b73b1632a0eb..72c05619e3ea 100644 --- a/Makefile +++ b/Makefile @@ -603,6 +603,7 @@ SOURCE_FILES = \ StorageFlattening.cpp \ StorageFolding.cpp \ StrictifyFloat.cpp \ + StripAsserts.cpp \ Substitute.cpp \ Target.cpp \ Tracing.cpp \ @@ -785,6 +786,7 @@ HEADER_FILES = \ StorageFlattening.h \ StorageFolding.h \ StrictifyFloat.h \ + StripAsserts.h \ Substitute.h \ Target.h \ Tracing.h \ diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index dfe9ae88c85f..77d8015f32b9 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -162,7 +162,6 @@ Stmt add_image_checks_inner(Stmt s, const FuncValueBounds &fb, bool will_inject_host_copies) { - bool no_asserts = t.has_feature(Target::NoAsserts); bool no_bounds_query = t.has_feature(Target::NoBoundsQuery); // First hunt for all the referenced buffers @@ -618,12 +617,9 @@ Stmt add_image_checks_inner(Stmt s, replace_with_constrained[name] = constrained_var; } - Expr error = 0; - if (!no_asserts) { - error = Call::make(Int(32), "halide_error_constraint_violated", - {name, var, constrained_var_str, constrained_var}, - Call::Extern); - } + Expr error = Call::make(Int(32), "halide_error_constraint_violated", + {name, var, constrained_var_str, constrained_var}, + Call::Extern); // Check the var passed in equals the constrained version (when not in inference mode) asserts_constrained.push_back(AssertStmt::make(var == constrained_var, error)); @@ -679,14 +675,12 @@ Stmt add_image_checks_inner(Stmt s, } }; - if (!no_asserts) { - // Inject the code that checks the host pointers. - prepend_stmts(&asserts_host_non_null); - prepend_stmts(&asserts_host_alignment); - prepend_stmts(&asserts_device_not_dirty); - prepend_stmts(&dims_no_overflow_asserts); - prepend_lets(&lets_overflow); - } + // Inject the code that checks the host pointers. + prepend_stmts(&asserts_host_non_null); + prepend_stmts(&asserts_host_alignment); + prepend_stmts(&asserts_device_not_dirty); + prepend_stmts(&dims_no_overflow_asserts); + prepend_lets(&lets_overflow); // Replace uses of the var with the constrained versions in the // rest of the program. We also need to respect the existence of @@ -698,15 +692,10 @@ Stmt add_image_checks_inner(Stmt s, // all in reverse order compared to execution, as we incrementally // prepending code. - // Inject the code that checks the constraints are correct. We - // need these regardless of how NoAsserts is set, because they are - // what gets Halide to actually exploit the constraint. + // Inject the code that checks the constraints are correct. prepend_stmts(&asserts_constrained); - - if (!no_asserts) { - prepend_stmts(&asserts_required); - prepend_stmts(&asserts_type_checks); - } + prepend_stmts(&asserts_required); + prepend_stmts(&asserts_type_checks); // Inject the code that returns early for inference mode. if (!no_bounds_query) { @@ -714,9 +703,7 @@ Stmt add_image_checks_inner(Stmt s, prepend_stmts(&buffer_rewrites); } - if (!no_asserts) { - prepend_stmts(&asserts_proposed); - } + prepend_stmts(&asserts_proposed); // Inject the code that defines the proposed sizes. prepend_lets(&lets_proposed); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cca681661c35..557574f284c4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -156,6 +156,7 @@ set(HEADER_FILES StorageFlattening.h StorageFolding.h StrictifyFloat.h + StripAsserts.h Substitute.h Target.h Tracing.h @@ -340,6 +341,7 @@ set(SOURCE_FILES StorageFlattening.cpp StorageFolding.cpp StrictifyFloat.cpp + StripAsserts.cpp Substitute.cpp Target.cpp Tracing.cpp diff --git a/src/Lower.cpp b/src/Lower.cpp index ba0918831fc8..560e0353c7a4 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -68,6 +68,7 @@ #include "StorageFlattening.h" #include "StorageFolding.h" #include "StrictifyFloat.h" +#include "StripAsserts.h" #include "Substitute.h" #include "Tracing.h" #include "TrimNoOps.h" @@ -427,6 +428,12 @@ void lower_impl(const vector &output_funcs, s = hoist_prefetches(s); log("Lowering after hoisting prefetches:", s); + if (t.has_feature(Target::NoAsserts)) { + debug(1) << "Stripping asserts...\n"; + s = strip_asserts(s); + log("Lowering after stripping asserts:", s); + } + debug(1) << "Lowering after final simplification:\n" << s << "\n\n"; diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index c575cd47477d..aa45841253b7 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1368,11 +1368,7 @@ class InjectFunctionRealization : public IRMutator { // This is also the point at which we inject explicit bounds // for this realization. - if (target.has_feature(Target::NoAsserts)) { - return s; - } else { - return inject_explicit_bounds(s, func); - } + return inject_explicit_bounds(s, func); } Stmt build_realize_function_from_group(Stmt s, int func_index) { diff --git a/src/StripAsserts.cpp b/src/StripAsserts.cpp new file mode 100644 index 000000000000..9d9c667f4db1 --- /dev/null +++ b/src/StripAsserts.cpp @@ -0,0 +1,121 @@ +#include "StripAsserts.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "IRVisitor.h" +#include + +namespace Halide { +namespace Internal { + +namespace { + +bool may_discard(const Expr &e) { + class MayDiscard : public IRVisitor { + using IRVisitor::visit; + + void visit(const Call *op) override { + // Extern calls that are side-effecty in the sense that you can't + // move them around in the IR, but we're free to discard because + // they're just getters. + static const std::set discardable{ + Call::buffer_get_dimensions, + Call::buffer_get_min, + Call::buffer_get_extent, + Call::buffer_get_stride, + Call::buffer_get_max, + Call::buffer_get_host, + Call::buffer_get_device, + Call::buffer_get_device_interface, + Call::buffer_get_shape, + Call::buffer_get_host_dirty, + Call::buffer_get_device_dirty, + Call::buffer_get_type}; + + if (!(op->is_pure() || + discardable.count(op->name))) { + result = false; + } + } + + public: + bool result = true; + } d; + e.accept(&d); + + return d.result; +} + +class StripAsserts : public IRMutator { + using IRMutator::visit; + + // We're going to track which symbols are used so that we can strip lets we + // don't need after removing the asserts. + std::set used; + + // Drop all assert stmts. Assumes that you don't want any side-effects from + // the condition. + Stmt visit(const AssertStmt *op) override { + return Evaluate::make(0); + } + + Expr visit(const Variable *op) override { + used.insert(op->name); + return op; + } + + Expr visit(const Load *op) override { + used.insert(op->name); + return IRMutator::visit(op); + } + + Stmt visit(const Store *op) override { + used.insert(op->name); + return IRMutator::visit(op); + } + + // Also dead-code eliminate any let stmts wrapped around asserts + Stmt visit(const LetStmt *op) override { + Stmt body = mutate(op->body); + if (is_no_op(body)) { + if (may_discard(op->value)) { + return body; + } else { + // We visit the value just to keep the used variable set + // accurate. + mutate(op->value); + return Evaluate::make(op->value); + } + } else if (body.same_as(op->body)) { + mutate(op->value); + return op; + } else if (may_discard(op->value) && !used.count(op->name)) { + return body; + } else { + mutate(op->value); + return LetStmt::make(op->name, op->value, body); + } + } + + Stmt visit(const Block *op) override { + Stmt first = mutate(op->first); + Stmt rest = mutate(op->rest); + if (first.same_as(op->first) && rest.same_as(op->rest)) { + return op; + } else if (is_no_op(rest)) { + return first; + } else if (is_no_op(first)) { + return rest; + } else { + return Block::make(first, rest); + } + } +}; + +} // namespace + +Stmt strip_asserts(const Stmt &s) { + return StripAsserts().mutate(s); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/StripAsserts.h b/src/StripAsserts.h new file mode 100644 index 000000000000..48b22b3a5218 --- /dev/null +++ b/src/StripAsserts.h @@ -0,0 +1,18 @@ +#ifndef HALIDE_STRIP_ASSERTS_H +#define HALIDE_STRIP_ASSERTS_H + +/** \file + * Defines the lowering pass that strips asserts when NoAsserts is set. + */ + +#include "Expr.h" + +namespace Halide { +namespace Internal { + +Stmt strip_asserts(const Stmt &s); + +} // namespace Internal +} // namespace Halide + +#endif