Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Report useful error to user if the promise_clamp all fails to losslessly cast. #8238

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1587,10 +1587,31 @@ Tuple mux(const Expr &id, const std::vector<Tuple> &values) {
return Tuple{result};
}

namespace {
void cast_bounds_for_promise_clamped(const Expr &value, const Expr &min, const Expr &max, Expr &casted_min, Expr &casted_max, const char *call_name) {
{
Expr n_min_val = lossless_cast(value.type(), min);
if (min.defined()) {
user_assert(n_min_val.defined())
<< call_name << " min argument (type " << min.node_type() << " " << min.type() << ") could not be cast losslessly to " << value.type();
}
casted_min = n_min_val.defined() ? n_min_val : value.type().min();
}
{
Expr n_max_val = lossless_cast(value.type(), max);
if (max.defined()) {
user_assert(n_max_val.defined())
<< call_name << " max argument (type " << max.node_type() << " " << max.type() << ") could not be cast losslessly to " << value.type();
}
casted_max = n_max_val.defined() ? n_max_val : value.type().max();
}
}
} // namespace

Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
user_assert(value.defined()) << "unsafe_promise_clamped with undefined value.\n";
Expr n_min_val = min.defined() ? lossless_cast(value.type(), min) : value.type().min();
Expr n_max_val = max.defined() ? lossless_cast(value.type(), max) : value.type().max();
Expr n_min_val, n_max_val;
cast_bounds_for_promise_clamped(value, min, max, n_min_val, n_max_val, "unsafe_promise_clamped");

// Min and max are allowed to be undefined with the meaning of no bound on that side.
return Call::make(value.type(),
Expand All @@ -1602,8 +1623,8 @@ Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max)
namespace Internal {
Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
internal_assert(value.defined()) << "promise_clamped with undefined value.\n";
Expr n_min_val = min.defined() ? lossless_cast(value.type(), min) : value.type().min();
Expr n_max_val = max.defined() ? lossless_cast(value.type(), max) : value.type().max();
Expr n_min_val, n_max_val;
cast_bounds_for_promise_clamped(value, min, max, n_min_val, n_max_val, "promise_clamped");

// Min and max are allowed to be undefined with the meaning of no bound on that side.
return Call::make(value.type(),
Expand Down
62 changes: 61 additions & 1 deletion src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "Associativity.h"
#include "Closure.h"
#include "ConstantInterval.h"
#include "Expr.h"
#include "IROperator.h"
#include "Interval.h"
#include "Module.h"
Expand Down Expand Up @@ -48,7 +49,6 @@ ostream &operator<<(ostream &out, const Type &type) {
}
return out;
}

ostream &operator<<(ostream &stream, const Expr &ir) {
if (!ir.defined()) {
stream << "(undefined)";
Expand Down Expand Up @@ -270,6 +270,66 @@ void IRPrinter::test() {
std::cout << "IRPrinter test passed\n";
}

std::ostream &operator<<(std::ostream &stream, IRNodeType type) {
#define CASE(e) \
case IRNodeType::e: \
stream << #e; \
break;
switch (type) {
CASE(IntImm)
CASE(UIntImm)
CASE(FloatImm)
CASE(StringImm)
CASE(Broadcast)
CASE(Cast)
CASE(Reinterpret)
CASE(Variable)
CASE(Add)
CASE(Sub)
CASE(Mod)
CASE(Mul)
CASE(Div)
CASE(Min)
CASE(Max)
CASE(EQ)
CASE(NE)
CASE(LT)
CASE(LE)
CASE(GT)
CASE(GE)
CASE(And)
CASE(Or)
CASE(Not)
CASE(Select)
CASE(Load)
CASE(Ramp)
CASE(Call)
CASE(Let)
CASE(Shuffle)
CASE(VectorReduce)
// Stmts
CASE(LetStmt)
CASE(AssertStmt)
CASE(ProducerConsumer)
CASE(For)
CASE(Acquire)
CASE(Store)
CASE(Provide)
CASE(Allocate)
CASE(Free)
CASE(Realize)
CASE(Block)
CASE(Fork)
CASE(IfThenElse)
CASE(Evaluate)
CASE(Prefetch)
CASE(Atomic)
CASE(HoistedStorage)
}
#undef CASE
return stream;
}

ostream &operator<<(ostream &stream, const AssociativePattern &p) {
stream << "{\n";
for (size_t i = 0; i < p.ops.size(); ++i) {
Expand Down
5 changes: 5 additions & 0 deletions src/IRPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class Closure;
struct Interval;
struct ConstantInterval;
struct ModulusRemainder;
enum class IRNodeType;

/** Emit a halide node type on an output stream (such as std::cout) in
* human-readable form */
std::ostream &operator<<(std::ostream &stream, IRNodeType);

/** Emit a halide associative pattern on an output stream (such as std::cout)
* in a human-readable form */
Expand Down
Loading