From 04258a57c4f8e7f46796e3ee5148a5101a92f7b4 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 13 Feb 2024 13:33:54 -0800 Subject: [PATCH] Strip asserts right at the end of lowering The simplifier exploits asserts to perform simplifications. When compiling with NoAsserts, certain assertions aren't ever introduced, which means that the simplifier can't exploit certain things that we know to be true. Mostly this has a negative effect on code size. E.g. tail cases get generated even though they are actually dead code. This PR keeps all the assertions right until the end of lowering, when it strips them in a dedicated pass. This reduces object file size for a large production blob of Halide code by ~10%, without measurably affecting runtime. --- Makefile | 2 + src/AddImageChecks.cpp | 39 ++++-------- src/CMakeLists.txt | 2 + src/Lower.cpp | 7 +++ src/ScheduleFunctions.cpp | 6 +- src/StripAsserts.cpp | 121 ++++++++++++++++++++++++++++++++++++++ src/StripAsserts.h | 18 ++++++ 7 files changed, 164 insertions(+), 31 deletions(-) create mode 100644 src/StripAsserts.cpp create mode 100644 src/StripAsserts.h diff --git a/Makefile b/Makefile index b73b1632a0eb..72c05619e3ea 100644 --- a/Makefile +++ b/Makefile @@ -603,6 +603,7 @@ SOURCE_FILES = \ StorageFlattening.cpp \ StorageFolding.cpp \ StrictifyFloat.cpp \ + StripAsserts.cpp \ Substitute.cpp \ Target.cpp \ Tracing.cpp \ @@ -785,6 +786,7 @@ HEADER_FILES = \ StorageFlattening.h \ StorageFolding.h \ StrictifyFloat.h \ + StripAsserts.h \ Substitute.h \ Target.h \ Tracing.h \ diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index dfe9ae88c85f..77d8015f32b9 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -162,7 +162,6 @@ Stmt add_image_checks_inner(Stmt s, const FuncValueBounds &fb, bool will_inject_host_copies) { - bool no_asserts = t.has_feature(Target::NoAsserts); bool no_bounds_query = t.has_feature(Target::NoBoundsQuery); // First hunt for all the referenced buffers @@ -618,12 +617,9 @@ 
Stmt add_image_checks_inner(Stmt s, replace_with_constrained[name] = constrained_var; } - Expr error = 0; - if (!no_asserts) { - error = Call::make(Int(32), "halide_error_constraint_violated", - {name, var, constrained_var_str, constrained_var}, - Call::Extern); - } + Expr error = Call::make(Int(32), "halide_error_constraint_violated", + {name, var, constrained_var_str, constrained_var}, + Call::Extern); // Check the var passed in equals the constrained version (when not in inference mode) asserts_constrained.push_back(AssertStmt::make(var == constrained_var, error)); @@ -679,14 +675,12 @@ Stmt add_image_checks_inner(Stmt s, } }; - if (!no_asserts) { - // Inject the code that checks the host pointers. - prepend_stmts(&asserts_host_non_null); - prepend_stmts(&asserts_host_alignment); - prepend_stmts(&asserts_device_not_dirty); - prepend_stmts(&dims_no_overflow_asserts); - prepend_lets(&lets_overflow); - } + // Inject the code that checks the host pointers. + prepend_stmts(&asserts_host_non_null); + prepend_stmts(&asserts_host_alignment); + prepend_stmts(&asserts_device_not_dirty); + prepend_stmts(&dims_no_overflow_asserts); + prepend_lets(&lets_overflow); // Replace uses of the var with the constrained versions in the // rest of the program. We also need to respect the existence of @@ -698,15 +692,10 @@ Stmt add_image_checks_inner(Stmt s, // all in reverse order compared to execution, as we incrementally // prepending code. - // Inject the code that checks the constraints are correct. We - // need these regardless of how NoAsserts is set, because they are - // what gets Halide to actually exploit the constraint. + // Inject the code that checks the constraints are correct. prepend_stmts(&asserts_constrained); - - if (!no_asserts) { - prepend_stmts(&asserts_required); - prepend_stmts(&asserts_type_checks); - } + prepend_stmts(&asserts_required); + prepend_stmts(&asserts_type_checks); // Inject the code that returns early for inference mode. 
if (!no_bounds_query) { @@ -714,9 +703,7 @@ Stmt add_image_checks_inner(Stmt s, prepend_stmts(&buffer_rewrites); } - if (!no_asserts) { - prepend_stmts(&asserts_proposed); - } + prepend_stmts(&asserts_proposed); // Inject the code that defines the proposed sizes. prepend_lets(&lets_proposed); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cca681661c35..557574f284c4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -156,6 +156,7 @@ set(HEADER_FILES StorageFlattening.h StorageFolding.h StrictifyFloat.h + StripAsserts.h Substitute.h Target.h Tracing.h @@ -340,6 +341,7 @@ set(SOURCE_FILES StorageFlattening.cpp StorageFolding.cpp StrictifyFloat.cpp + StripAsserts.cpp Substitute.cpp Target.cpp Tracing.cpp diff --git a/src/Lower.cpp b/src/Lower.cpp index ba0918831fc8..560e0353c7a4 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -68,6 +68,7 @@ #include "StorageFlattening.h" #include "StorageFolding.h" #include "StrictifyFloat.h" +#include "StripAsserts.h" #include "Substitute.h" #include "Tracing.h" #include "TrimNoOps.h" @@ -427,6 +428,12 @@ void lower_impl(const vector<Function> &output_funcs, s = hoist_prefetches(s); log("Lowering after hoisting prefetches:", s); + if (t.has_feature(Target::NoAsserts)) { + debug(1) << "Stripping asserts...\n"; + s = strip_asserts(s); + log("Lowering after stripping asserts:", s); + } + debug(1) << "Lowering after final simplification:\n" << s << "\n\n"; diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index c575cd47477d..aa45841253b7 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1368,11 +1368,7 @@ class InjectFunctionRealization : public IRMutator { // This is also the point at which we inject explicit bounds // for this realization. 
- if (target.has_feature(Target::NoAsserts)) { - return s; - } else { - return inject_explicit_bounds(s, func); - } + return inject_explicit_bounds(s, func); } Stmt build_realize_function_from_group(Stmt s, int func_index) { diff --git a/src/StripAsserts.cpp b/src/StripAsserts.cpp new file mode 100644 index 000000000000..9d9c667f4db1 --- /dev/null +++ b/src/StripAsserts.cpp @@ -0,0 +1,121 @@ +#include "StripAsserts.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "IRVisitor.h" +#include <set> + +namespace Halide { +namespace Internal { + +namespace { + +bool may_discard(const Expr &e) { + class MayDiscard : public IRVisitor { + using IRVisitor::visit; + + void visit(const Call *op) override { + // Extern calls that are side-effecty in the sense that you can't + // move them around in the IR, but we're free to discard because + // they're just getters. + static const std::set<std::string> discardable{ + Call::buffer_get_dimensions, + Call::buffer_get_min, + Call::buffer_get_extent, + Call::buffer_get_stride, + Call::buffer_get_max, + Call::buffer_get_host, + Call::buffer_get_device, + Call::buffer_get_device_interface, + Call::buffer_get_shape, + Call::buffer_get_host_dirty, + Call::buffer_get_device_dirty, + Call::buffer_get_type}; + + if (!(op->is_pure() || + discardable.count(op->name))) { + result = false; + } + } + + public: + bool result = true; + } d; + e.accept(&d); + + return d.result; +} + +class StripAsserts : public IRMutator { + using IRMutator::visit; + + // We're going to track which symbols are used so that we can strip lets we + // don't need after removing the asserts. + std::set<std::string> used; + + // Drop all assert stmts. Assumes that you don't want any side-effects from + // the condition. 
+ Stmt visit(const AssertStmt *op) override { + return Evaluate::make(0); + } + + Expr visit(const Variable *op) override { + used.insert(op->name); + return op; + } + + Expr visit(const Load *op) override { + used.insert(op->name); + return IRMutator::visit(op); + } + + Stmt visit(const Store *op) override { + used.insert(op->name); + return IRMutator::visit(op); + } + + // Also dead-code eliminate any let stmts wrapped around asserts + Stmt visit(const LetStmt *op) override { + Stmt body = mutate(op->body); + if (is_no_op(body)) { + if (may_discard(op->value)) { + return body; + } else { + // We visit the value just to keep the used variable set + // accurate. + mutate(op->value); + return Evaluate::make(op->value); + } + } else if (body.same_as(op->body)) { + mutate(op->value); + return op; + } else if (may_discard(op->value) && !used.count(op->name)) { + return body; + } else { + mutate(op->value); + return LetStmt::make(op->name, op->value, body); + } + } + + Stmt visit(const Block *op) override { + Stmt first = mutate(op->first); + Stmt rest = mutate(op->rest); + if (first.same_as(op->first) && rest.same_as(op->rest)) { + return op; + } else if (is_no_op(rest)) { + return first; + } else if (is_no_op(first)) { + return rest; + } else { + return Block::make(first, rest); + } + } +}; + +} // namespace + +Stmt strip_asserts(const Stmt &s) { + return StripAsserts().mutate(s); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/StripAsserts.h b/src/StripAsserts.h new file mode 100644 index 000000000000..48b22b3a5218 --- /dev/null +++ b/src/StripAsserts.h @@ -0,0 +1,18 @@ +#ifndef HALIDE_STRIP_ASSERTS_H +#define HALIDE_STRIP_ASSERTS_H + +/** \file + * Defines the lowering pass that strips asserts when NoAsserts is set. + */ + +#include "Expr.h" + +namespace Halide { +namespace Internal { + +Stmt strip_asserts(const Stmt &s); + +} // namespace Internal +} // namespace Halide + +#endif