Skip to content

Commit

Permalink
Add support for rendering type relation errors in new style
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Jan 18, 2019
1 parent 230bf3e commit e1ec636
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 28 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/error_reporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class ErrorReporter {
*/
[[noreturn]] void RenderErrors(const Module& module, bool use_color = true);

inline bool AnyErrors() {
return errors_.size() != 0;
}

private:
std::vector<Error> errors_;
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>

#include "base.h"
#include "error.h"
#include "../attrs.h"

namespace tvm {
Expand Down Expand Up @@ -295,9 +296,14 @@ class TypeReporterNode : public Node {
*/
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;

TVM_DLL virtual void ReportError(const Error& err) = 0;

// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}

// Not sure if best design, if not we should recreate a reporter for each relation.
mutable NodeRef location;

static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
};
Expand Down
14 changes: 10 additions & 4 deletions src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) {

Type ConcreteBroadcast(const TensorType& t1,
const TensorType& t2,
DataType output_dtype) {
DataType output_dtype,
const TypeReporter& reporter) {
std::vector<IndexExpr> oshape;
size_t ndim1 = t1->shape.size();
size_t ndim2 = t2->shape.size();
Expand All @@ -70,9 +71,13 @@ Type ConcreteBroadcast(const TensorType& t1,
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
} else {
LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2;
reporter->ReportError(
RELAY_ERROR(
"Incompatible broadcast type "
<< t1 << " and " << t2));
}
}

size_t max_ndim = std::max(ndim1, ndim2);
auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape;
for (; i <= max_ndim; ++i) {
Expand All @@ -92,7 +97,8 @@ bool BroadcastRel(const Array<Type>& types,
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, t0->dtype));
reporter->Assign(types[2],
ConcreteBroadcast(t0, t1, t0->dtype, reporter));
return true;
}
}
Expand All @@ -109,7 +115,7 @@ bool BroadcastCompRel(const Array<Type>& types,
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool()));
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool(), reporter));
return true;
}
}
Expand Down
27 changes: 17 additions & 10 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ struct ResolvedTypeInfo {
class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
public:
// constructors
TypeInferencer() {
}

explicit TypeInferencer(Module mod, GlobalVar current_func)
: mod_(mod), current_func_(current_func), err_reporter() {
: mod_(mod), current_func_(current_func),
err_reporter(), solver_(current_func, &this->err_reporter) {
}

// inference the type of expr.
Expand Down Expand Up @@ -123,7 +123,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type Unify(const Type& t1, const Type& t2, const Expr& expr) {
// TODO(tqchen, jroesch): propagate span to solver
try {
return solver_.Unify(t1, t2);
return solver_.Unify(t1, t2, expr);
} catch (const dmlc::Error &e) {
this->ReportFatalError(
expr,
Expand Down Expand Up @@ -224,7 +224,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
auto attrs = make_node<TupleGetItemAttrs>();
attrs->index = op->index;
solver_.AddConstraint(TypeRelationNode::make(
tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)));
tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef<TupleGetItem>(op));
return rtype;
}

Expand Down Expand Up @@ -267,7 +267,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// The result will be the return type of the operator.
Type PrimitiveCall(const FuncTypeNode* op,
Array<Type> arg_types,
const Attrs& attrs) {
const Attrs& attrs,
const NodeRef& loc) {
if (op->type_params.size() != arg_types.size() + 1) return Type();
if (op->type_constraints.size() != 1) return Type();
const TypeRelationNode* rel = op->type_constraints[0].as<TypeRelationNode>();
Expand All @@ -280,7 +281,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
arg_types.push_back(rtype);
// we can do simple replacement here
solver_.AddConstraint(TypeRelationNode::make(
rel->func, arg_types, arg_types.size() - 1, attrs));
rel->func, arg_types, arg_types.size() - 1, attrs), loc);
return rtype;
}

Expand Down Expand Up @@ -388,9 +389,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
for (auto cs : fn_ty->type_constraints) {
if (auto tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs));
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
GetRef<Call>(call));
} else {
solver_.AddConstraint(cs);
solver_.AddConstraint(cs, GetRef<Call>(call));
}
}

Expand All @@ -406,7 +408,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (const OpNode* opnode = call->op.as<OpNode>()) {
Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
arg_types,
call->attrs);
call->attrs,
GetRef<Call>(call));
if (rtype.defined()) {
AddTypeArgs(GetRef<Call>(call), arg_types);
return rtype;
Expand Down Expand Up @@ -584,6 +587,10 @@ Expr TypeInferencer::Infer(Expr expr) {
// Step 1: Solve the constraints.
solver_.Solve();

if (err_reporter.AnyErrors()) {
err_reporter.RenderErrors(mod_);
}

// Step 2: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
CHECK(WellFormed(resolved_expr));
Expand Down
48 changes: 39 additions & 9 deletions src/relay/pass/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TypeSolver::Reporter : public TypeReporterNode {
: solver_(solver) {}

void Assign(const Type& dst, const Type& src) final {
solver_->Unify(dst, src);
solver_->Unify(dst, src, location);
}

bool Assert(const IndexExpr& cond) final {
Expand All @@ -35,6 +35,27 @@ class TypeSolver::Reporter : public TypeReporterNode {
return true;
}

void ReportError(const Error& err) final {
std::cout
<< "Current Function: "
<< solver_->current_func
<< std::endl;

std::cout
<< "Location: "
<< location
<< std::endl;

std::cout
<< err.what()
<< std::endl;

solver_->err_reporter_->ReportAt(
solver_->current_func,
location,
err);
}

private:
TypeSolver* solver_;
};
Expand Down Expand Up @@ -329,8 +350,10 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
};

// constructor
TypeSolver::TypeSolver()
: reporter_(make_node<Reporter>(this)) {
TypeSolver::TypeSolver(const GlobalVar &current_func, ErrorReporter* err_reporter)
: reporter_(make_node<Reporter>(this)),
current_func(current_func),
err_reporter_(err_reporter) {
}

// destructor
Expand All @@ -351,16 +374,19 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
}

// Add equality constraint
Type TypeSolver::Unify(const Type& dst, const Type& src) {
Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
// NB(@jroesch): we should probably pass location into the unifier to do better
// error reporting as well.
Unifier unifier(this);
return unifier.Unify(dst, src);
}

// Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
if (auto *op = constraint.as<TypeRelationNode>()) {
// create a new relation node.
RelationNode* rnode = arena_.make<RelationNode>();
rnode->location = loc;
rnode->rel = GetRef<TypeRelation>(op);
rel_nodes_.push_back(rnode);
// populate the type information.
Expand Down Expand Up @@ -404,6 +430,10 @@ bool TypeSolver::Solve() {
args.push_back(Resolve(tlink->value->FindRoot()->resolved_type));
CHECK_LE(args.size(), rel->args.size());
}

CHECK(rnode->location.defined()) << "undefined location";

reporter_->location = rnode->location;
// call the function
bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);
// mark inqueue as false after the function call
Expand All @@ -420,13 +450,13 @@ bool TypeSolver::Solve() {
return num_resolved_rels_ == rel_nodes_.size();
}


// Expose type solver only for debugging purposes.
TVM_REGISTER_API("relay._ir_pass._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto solver = std::make_shared<TypeSolver>();
ErrorReporter err_reporter;
auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), &err_reporter);

auto mod = [solver](std::string name) -> PackedFunc {
if (name == "Solve") {
Expand All @@ -435,15 +465,15 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
});
} else if (name == "Unify") {
return TypedPackedFunc<Type(Type, Type)>([solver](Type lhs, Type rhs) {
return solver->Unify(lhs, rhs);
return solver->Unify(lhs, rhs, NodeRef());
});
} else if (name == "Resolve") {
return TypedPackedFunc<Type(Type)>([solver](Type t) {
return solver->Resolve(t);
});
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
return solver->AddConstraint(c);
return solver->AddConstraint(c, NodeRef());
});
} else {
return PackedFunc();
Expand Down
17 changes: 14 additions & 3 deletions src/relay/pass/type_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
#ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_
#define TVM_RELAY_PASS_TYPE_SOLVER_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/error_reporter.h>
#include <vector>
#include <queue>
#include "../../common/arena.h"
Expand Down Expand Up @@ -40,13 +42,13 @@ using common::LinkedList;
*/
class TypeSolver {
public:
TypeSolver();
TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter);
~TypeSolver();
/*!
* \brief Add a type constraint to the solver.
* \param constraint The constraint to be added.
*/
void AddConstraint(const TypeConstraint& constraint);
void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation);
/*!
* \brief Resolve type to the solution type in the solver.
* \param type The type to be resolved.
Expand All @@ -63,7 +65,7 @@ class TypeSolver {
* \param lhs The left operand.
* \param rhs The right operand
*/
Type Unify(const Type& lhs, const Type& rhs);
Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);

private:
class OccursChecker;
Expand Down Expand Up @@ -112,6 +114,7 @@ class TypeSolver {
return root;
}
};

/*! \brief relation node */
struct RelationNode {
/*! \brief Whether the relation is in the queue to be solved */
Expand All @@ -122,7 +125,10 @@ class TypeSolver {
TypeRelation rel;
/*! \brief list types to this relation */
LinkedList<TypeNode*> type_list;
/*! \brief The location this type relation originated from. */
NodeRef location;
};

/*! \brief List of all allocated type nodes */
std::vector<TypeNode*> type_nodes_;
/*! \brief List of all allocated relation nodes */
Expand All @@ -137,6 +143,11 @@ class TypeSolver {
common::Arena arena_;
/*! \brief Reporter that reports back to self */
TypeReporter reporter_;
/*! \brief The global representing the current function. */
GlobalVar current_func;
/*! \brief Error reporting. */
ErrorReporter* err_reporter_;

/*!
* \brief GetTypeNode that is corresponds to t.
* if it do not exist, create a new one.
Expand Down
11 changes: 9 additions & 2 deletions tests/python/relay/test_error_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def test_too_few_args():
f = relay.Function([x, y], x)
check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;")

def test_rel_fail():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(11, 10))
f = relay.Function([x, y], x + y)
check_type_err(f(x, y), "the function is provided too few arguments expected 2, found 1;")

if __name__ == "__main__":
test_too_many_args()
test_too_few_args()
# test_too_many_args()
# test_too_few_args()
test_rel_fail()

0 comments on commit e1ec636

Please sign in to comment.