Skip to content

Commit

Permalink
Report useful error to user if the promise_clamp all fails to lossles…
Browse files Browse the repository at this point in the history
…sly cast. (#8238)

Co-authored-by: Steven Johnson <[email protected]>
  • Loading branch information
mcourteaux and steven-johnson authored Jun 4, 2024
1 parent 775bfbf commit 46e866d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 5 deletions.
29 changes: 25 additions & 4 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1587,10 +1587,31 @@ Tuple mux(const Expr &id, const std::vector<Tuple> &values) {
return Tuple{result};
}

namespace {
void cast_bounds_for_promise_clamped(const Expr &value, const Expr &min, const Expr &max, Expr &casted_min, Expr &casted_max, const char *call_name) {
{
Expr n_min_val = lossless_cast(value.type(), min);
if (min.defined()) {
user_assert(n_min_val.defined())
<< call_name << " min argument (type " << min.node_type() << " " << min.type() << ") could not be cast losslessly to " << value.type();
}
casted_min = n_min_val.defined() ? n_min_val : value.type().min();
}
{
Expr n_max_val = lossless_cast(value.type(), max);
if (max.defined()) {
user_assert(n_max_val.defined())
<< call_name << " max argument (type " << max.node_type() << " " << max.type() << ") could not be cast losslessly to " << value.type();
}
casted_max = n_max_val.defined() ? n_max_val : value.type().max();
}
}
} // namespace

Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
user_assert(value.defined()) << "unsafe_promise_clamped with undefined value.\n";
Expr n_min_val = min.defined() ? lossless_cast(value.type(), min) : value.type().min();
Expr n_max_val = max.defined() ? lossless_cast(value.type(), max) : value.type().max();
Expr n_min_val, n_max_val;
cast_bounds_for_promise_clamped(value, min, max, n_min_val, n_max_val, "unsafe_promise_clamped");

// Min and max are allowed to be undefined with the meaning of no bound on that side.
return Call::make(value.type(),
Expand All @@ -1602,8 +1623,8 @@ Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max)
namespace Internal {
Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
internal_assert(value.defined()) << "promise_clamped with undefined value.\n";
Expr n_min_val = min.defined() ? lossless_cast(value.type(), min) : value.type().min();
Expr n_max_val = max.defined() ? lossless_cast(value.type(), max) : value.type().max();
Expr n_min_val, n_max_val;
cast_bounds_for_promise_clamped(value, min, max, n_min_val, n_max_val, "promise_clamped");

// Min and max are allowed to be undefined with the meaning of no bound on that side.
return Call::make(value.type(),
Expand Down
62 changes: 61 additions & 1 deletion src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "Associativity.h"
#include "Closure.h"
#include "ConstantInterval.h"
#include "Expr.h"
#include "IROperator.h"
#include "Interval.h"
#include "Module.h"
Expand Down Expand Up @@ -48,7 +49,6 @@ ostream &operator<<(ostream &out, const Type &type) {
}
return out;
}

ostream &operator<<(ostream &stream, const Expr &ir) {
if (!ir.defined()) {
stream << "(undefined)";
Expand Down Expand Up @@ -270,6 +270,66 @@ void IRPrinter::test() {
std::cout << "IRPrinter test passed\n";
}

std::ostream &operator<<(std::ostream &stream, IRNodeType type) {
#define CASE(e) \
case IRNodeType::e: \
stream << #e; \
break;
switch (type) {
CASE(IntImm)
CASE(UIntImm)
CASE(FloatImm)
CASE(StringImm)
CASE(Broadcast)
CASE(Cast)
CASE(Reinterpret)
CASE(Variable)
CASE(Add)
CASE(Sub)
CASE(Mod)
CASE(Mul)
CASE(Div)
CASE(Min)
CASE(Max)
CASE(EQ)
CASE(NE)
CASE(LT)
CASE(LE)
CASE(GT)
CASE(GE)
CASE(And)
CASE(Or)
CASE(Not)
CASE(Select)
CASE(Load)
CASE(Ramp)
CASE(Call)
CASE(Let)
CASE(Shuffle)
CASE(VectorReduce)
// Stmts
CASE(LetStmt)
CASE(AssertStmt)
CASE(ProducerConsumer)
CASE(For)
CASE(Acquire)
CASE(Store)
CASE(Provide)
CASE(Allocate)
CASE(Free)
CASE(Realize)
CASE(Block)
CASE(Fork)
CASE(IfThenElse)
CASE(Evaluate)
CASE(Prefetch)
CASE(Atomic)
CASE(HoistedStorage)
}
#undef CASE
return stream;
}

ostream &operator<<(ostream &stream, const AssociativePattern &p) {
stream << "{\n";
for (size_t i = 0; i < p.ops.size(); ++i) {
Expand Down
5 changes: 5 additions & 0 deletions src/IRPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class Closure;
struct Interval;
struct ConstantInterval;
struct ModulusRemainder;
enum class IRNodeType;

/** Emit a halide node type on an output stream (such as std::cout) in
* human-readable form */
std::ostream &operator<<(std::ostream &stream, IRNodeType);

/** Emit a halide associative pattern on an output stream (such as std::cout)
* in a human-readable form */
Expand Down

0 comments on commit 46e866d

Please sign in to comment.