From 02631f67780a9175ac202e5564e25bc6d93393c2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 25 Jan 2019 10:17:31 -0800 Subject: [PATCH] [Relay] Add generic & informative Relay error reporting (#2408) --- .gitmodules | 3 + 3rdparty/rang | 1 + CMakeLists.txt | 1 + include/tvm/relay/error.h | 127 +++++++++++++-- include/tvm/relay/module.h | 19 +++ include/tvm/relay/pass.h | 2 +- include/tvm/relay/type.h | 6 + src/relay/ir/error.cc | 128 +++++++++++++++ src/relay/ir/module.cc | 17 ++ src/relay/op/type_relations.cc | 8 +- src/relay/pass/type_infer.cc | 173 ++++++++++++++------- src/relay/pass/type_solver.cc | 78 ++++++++-- src/relay/pass/type_solver.h | 26 +++- tests/python/relay/test_error_reporting.py | 34 ++++ 14 files changed, 537 insertions(+), 86 deletions(-) create mode 160000 3rdparty/rang create mode 100644 src/relay/ir/error.cc create mode 100644 tests/python/relay/test_error_reporting.py diff --git a/.gitmodules b/.gitmodules index 8011ec12d24b..984326434c3f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "dlpack"] path = 3rdparty/dlpack url = https://github.com/dmlc/dlpack +[submodule "3rdparty/rang"] + path = 3rdparty/rang + url = https://github.com/agauniyal/rang diff --git a/3rdparty/rang b/3rdparty/rang new file mode 160000 index 000000000000..cabe04d6d6b0 --- /dev/null +++ b/3rdparty/rang @@ -0,0 +1 @@ +Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762 diff --git a/CMakeLists.txt b/CMakeLists.txt index 8765a3346069..23dd58a2cd26 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) include_directories("include") include_directories("3rdparty/dlpack/include") include_directories("3rdparty/dmlc-core/include") +include_directories("3rdparty/rang/include") include_directories("3rdparty/compiler-rt") # initial variables diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 1c2b90611bbd..0451a9826cde 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -7,25 +7,134 @@ #define TVM_RELAY_ERROR_H_ #include +#include +#include #include "./base.h" +#include "./expr.h" +#include "./module.h" namespace tvm { namespace relay { -struct Error : public dmlc::Error { - explicit Error(const std::string &msg) : dmlc::Error(msg) {} -}; +#define RELAY_ERROR(msg) (RelayErrorStream() << msg) + +// Forward declaratio for RelayErrorStream. +struct Error; + +/*! \brief A wrapper around std::stringstream. + * + * This is designed to avoid platform specific + * issues compiling and using std::stringstream + * for error reporting. + */ +struct RelayErrorStream { + std::stringstream ss; + + template + RelayErrorStream& operator<<(const T& t) { + ss << t; + return *this; + } -struct InternalError : public Error { - explicit InternalError(const std::string &msg) : Error(msg) {} + std::string str() const { + return ss.str(); + } + + void Raise() const; }; -struct FatalTypeError : public Error { - explicit FatalTypeError(const std::string &s) : Error(s) {} +struct Error : public dmlc::Error { + Span sp; + explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {} + Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) + Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) }; -struct TypecheckerError : public Error { - explicit TypecheckerError(const std::string &msg) : Error(msg) {} +/*! \brief An abstraction around how errors are stored and reported. + * Designed to be opaque to users, so we can support a robust and simpler + * error reporting mode, as well as a more complex mode. + * + * The first mode is the most accurate: we report a Relay error at a specific + * Span, and then render the error message directly against a textual representation + * of the program, highlighting the exact lines in which it occurs. This mode is not + * implemented in this PR and will not work. + * + * The second mode is a general-purpose mode, which attempts to annotate the program's + * textual format with errors. + * + * The final mode represents the old mode, if we report an error that has no span or + * expression, we will default to throwing an exception with a textual representation + * of the error and no indication of where it occured in the original program. + * + * The latter mode is not ideal, and the goal of the new error reporting machinery is + * to avoid ever reporting errors in this style. + */ +class ErrorReporter { + public: + ErrorReporter() : errors_(), node_to_error_() {} + + /*! \brief Report a tvm::relay::Error. + * + * This API is useful for reporting spanned errors. + * + * \param err The error to report. + */ + void Report(const Error& err) { + if (!err.sp.defined()) { + throw err; + } + + this->errors_.push_back(err); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error message to report. + */ + inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { + this->ReportAt(global, node, Error(err)); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error to report. + */ + void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); + + /*! \brief Render all reported errors and exit the program. + * + * This function should be used after executing a pass to render reported errors. + * + * It will build an error message from the set of errors, depending on the error + * reporting strategy. + * + * \param module The module to report errors on. + * \param use_color Controls whether to colorize the output. + */ + void RenderErrors(const Module& module, bool use_color = true); + + inline bool AnyErrors() { + return errors_.size() != 0; + } + + private: + std::vector errors_; + std::unordered_map, NodeHash, NodeEqual> node_to_error_; + std::unordered_map node_to_gv_; }; } // namespace relay diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8d302c09d959..45ccfe3a8089 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -43,11 +43,15 @@ class ModuleNode : public RelayNode { /*! \brief A map from ids to all global functions. */ tvm::Map functions; + /*! \brief The entry function (i.e. "main"). */ + GlobalVar entry_func; + ModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); v->Visit("global_var_map_", &global_var_map_); + v->Visit("entry_func", &entry_func); } TVM_DLL static Module make(tvm::Map global_funcs); @@ -111,6 +115,20 @@ class ModuleNode : public RelayNode { */ void Update(const Module& other); + /*! \brief Construct a module from a standalone expression. + * + * Allows one to optionally pass a global function map as + * well. + * + * \param expr The expression to set as the entry point to the module. + * \param global_funcs The global function map. + * + * \returns A module with expr set as the entry point. + */ + static Module FromExpr( + const Expr& expr, + const tvm::Map& global_funcs = {}); + static constexpr const char* _type_key = "relay.Module"; TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); @@ -132,6 +150,7 @@ struct Module : public NodeRef { using ContainerType = ModuleNode; }; + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8527ab7a2cb5..38f6a805f131 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,8 +6,8 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include #include +#include #include #include diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 69a8a4fb0bd7..f3bcf2c0a1d9 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -295,6 +295,12 @@ class TypeReporterNode : public Node { */ TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0; + /*! + * \brief Set the location at which to report unification errors. + * \param ref The program node to report the error. + */ + TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; + // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc new file mode 100644 index 000000000000..24f8d1c49b6b --- /dev/null +++ b/src/relay/ir/error.cc @@ -0,0 +1,128 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error_reporter.h + * \brief The set of errors raised by Relay. + */ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +void RelayErrorStream::Raise() const { + throw Error(*this); +} + +template +using NodeMap = std::unordered_map; + +void ErrorReporter::RenderErrors(const Module& module, bool use_color) { + // First we pick an error reporting strategy for each error. + // TODO(@jroesch): Spanned errors are currently not supported. + for (auto err : this->errors_) { + CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; + } + + NodeMap> error_maps; + + // Set control mode in order to produce colors; + if (use_color) { + rang::setControlMode(rang::control::Force); + } + + for (auto pair : this->node_to_gv_) { + auto node = pair.first; + auto global = Downcast(pair.second); + + auto has_errs = this->node_to_error_.find(node); + + CHECK(has_errs != this->node_to_error_.end()); + + const auto& error_indicies = has_errs->second; + + std::stringstream err_msg; + + err_msg << rang::fg::red; + for (auto index : error_indicies) { + err_msg << this->errors_[index].what() << "; "; + } + err_msg << rang::fg::reset; + + // Setup error map. + auto it = error_maps.find(global); + if (it != error_maps.end()) { + it->second.insert({ node, err_msg.str() }); + } else { + error_maps.insert({ global, { { node, err_msg.str() }}}); + } + } + + // Now we will construct the fully-annotated program to display to + // the user. + std::stringstream annotated_prog; + + // First we output a header for the errors. + annotated_prog << + rang::style::bold << std::endl << + "Error(s) have occurred. We have annotated the program with them:" + << std::endl << std::endl << rang::style::reset; + + // For each global function which contains errors, we will + // construct an annotated function. + for (auto pair : error_maps) { + auto global = pair.first; + auto err_map = pair.second; + auto func = module->Lookup(global); + + // We output the name of the function before displaying + // the annotated program. + annotated_prog << + rang::style::bold << + "In `" << global->name_hint << "`: " << + std::endl << + rang::style::reset; + + // We then call into the Relay printer to generate the program. + // + // The annotation callback will annotate the error messages + // contained in the map. + annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) { + auto it = err_map.find(expr); + if (it != err_map.end()) { + return it->second; + } else { + return std::string(""); + } + }); + } + + auto msg = annotated_prog.str(); + + if (use_color) { + rang::setControlMode(rang::control::Auto); + } + + // Finally we report the error, currently we do so to LOG(FATAL), + // it may be good to instead report it to std::cout. + LOG(FATAL) << annotated_prog.str() << std::endl; +} + +void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { + size_t index_to_insert = this->errors_.size(); + this->errors_.push_back(err); + auto it = this->node_to_error_.find(node); + if (it != this->node_to_error_.end()) { + it->second.push_back(index_to_insert); + } else { + this->node_to_error_.insert({ node, { index_to_insert }}); + } + this->node_to_gv_.insert({ node, global }); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index cbb0b7768004..9ba5efecec80 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map global_funcs) { << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } + + n->entry_func = GlobalVarNode::make("main"); return Module(n); } @@ -96,6 +98,21 @@ void ModuleNode::Update(const Module& mod) { } } +Module ModuleNode::FromExpr( + const Expr& expr, + const tvm::Map& global_funcs) { + auto mod = ModuleNode::make(global_funcs); + auto func_node = expr.as(); + Function func; + if (func_node) { + func = GetRef(func_node); + } else { + func = FunctionNode::make({}, expr, Type(), {}, {}); + } + mod->Add(mod->entry_func, func); + return mod; +} + TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 467c0fcde860..2618054a663d 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -70,9 +70,12 @@ Type ConcreteBroadcast(const TensorType& t1, } else if (EqualConstInt(s2, 1)) { oshape.push_back(s1); } else { - LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2; + RELAY_ERROR( + "Incompatible broadcast type " + << t1 << " and " << t2).Raise(); } } + size_t max_ndim = std::max(ndim1, ndim2); auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape; for (; i <= max_ndim; ++i) { @@ -92,7 +95,8 @@ bool BroadcastRel(const Array& types, if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], ConcreteBroadcast(t0, t1, t0->dtype)); + reporter->Assign(types[2], + ConcreteBroadcast(t0, t1, t0->dtype)); return true; } } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index af4cc6607a44..3135715f7691 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -82,10 +82,10 @@ struct ResolvedTypeInfo { class TypeInferencer : private ExprFunctor { public: // constructors - TypeInferencer() { - } - explicit TypeInferencer(Module mod) - : mod_(mod) { + + explicit TypeInferencer(Module mod, GlobalVar current_func) + : mod_(mod), current_func_(current_func), + err_reporter(), solver_(current_func, &this->err_reporter) { } // inference the type of expr. @@ -96,6 +96,13 @@ class TypeInferencer : private ExprFunctor { class Resolver; // internal environment Module mod_; + + // The current function being type checked. + GlobalVar current_func_; + + // The error reporter. + ErrorReporter err_reporter; + // map from expression to checked type // type inferencer will populate it up std::unordered_map type_map_; @@ -109,18 +116,21 @@ class TypeInferencer : private ExprFunctor { // relation function TypeRelationFn tuple_getitem_rel_; TypeRelationFn make_tuple_rel_; - // Unify two types - Type Unify(const Type& t1, const Type& t2, const Span& span) { + + // Perform unification on two types and report the error at the expression + // or the span of the expression. + Type Unify(const Type& t1, const Type& t2, const Expr& expr) { // TODO(tqchen, jroesch): propagate span to solver try { - return solver_.Unify(t1, t2); + return solver_.Unify(t1, t2, expr); } catch (const dmlc::Error &e) { - LOG(FATAL) - << "Error unifying `" + this->ReportFatalError( + expr, + RELAY_ERROR("Error unifying `" << t1 << "` and `" << t2 - << "`: " << e.what(); + << "`: " << e.what())); return Type(); } } @@ -151,7 +161,7 @@ class TypeInferencer : private ExprFunctor { } // Lazily get type for expr - // will call visit to deduce it if it is not in the type_map_ + // expression, we will populate it now, and return the result. Type GetType(const Expr &expr) { auto it = type_map_.find(expr); if (it != type_map_.end() && it->second.checked_type.defined()) { @@ -163,7 +173,13 @@ class TypeInferencer : private ExprFunctor { return ret; } - // Visitor logics + void ReportFatalError(const Expr& expr, const Error& err) { + CHECK(this->current_func_.defined()); + this->err_reporter.ReportAt(this->current_func_, expr, err); + this->err_reporter.RenderErrors(this->mod_); + } + + // Visitor Logic Type VisitExpr_(const VarNode* op) final { if (op->type_annotation.defined()) { return op->type_annotation; @@ -174,8 +190,13 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); - CHECK(mod_.defined()) - << "Cannot do type inference without a global variable"; + if (!mod_.defined()) { + this->ReportFatalError( + GetRef(op), + RELAY_ERROR( + "Cannot do type inference on global variables " \ + "without a module")); + } Expr e = mod_->Lookup(var); return e->checked_type(); } @@ -202,7 +223,7 @@ class TypeInferencer : private ExprFunctor { auto attrs = make_node(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( - tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs))); + tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op)); return rtype; } @@ -210,40 +231,43 @@ class TypeInferencer : private ExprFunctor { return op->op_type; } - Type VisitExpr_(const LetNode* op) final { + Type VisitExpr_(const LetNode* let) final { // if the definition is a function literal, permit recursion - bool is_functional_literal = op->value.as() != nullptr; + bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { - type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + type_map_[let->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } - Type vtype = GetType(op->value); - if (op->var->type_annotation.defined()) { - vtype = Unify(vtype, op->var->type_annotation, op->span); + Type vtype = GetType(let->value); + if (let->var->type_annotation.defined()) { + vtype = Unify(vtype, let->var->type_annotation, GetRef(let)); } - CHECK(is_functional_literal || !type_map_.count(op->var)); + CHECK(is_functional_literal || !type_map_.count(let->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var].checked_type = vtype; - return GetType(op->body); + type_map_[let->var].checked_type = vtype; + return GetType(let->body); } - Type VisitExpr_(const IfNode* op) final { + Type VisitExpr_(const IfNode* ite) final { // Ensure the type of the guard is of Tensor[Bool, ()], // that is a rank-0 boolean tensor. - Type cond_type = this->GetType(op->cond); + Type cond_type = this->GetType(ite->cond); this->Unify(cond_type, TensorTypeNode::Scalar(tvm::Bool()), - op->cond->span); - Type checked_true = this->GetType(op->true_branch); - Type checked_false = this->GetType(op->false_branch); - return this->Unify(checked_true, checked_false, op->span); + ite->cond); + Type checked_true = this->GetType(ite->true_branch); + Type checked_false = this->GetType(ite->false_branch); + return this->Unify(checked_true, checked_false, GetRef(ite)); } - // Handle special case basic primitive operator, - // if successful return the return type + // This code is special-cased for primitive operators, + // which are registered in the style defined in src/relay/op/*. + // + // The result will be the return type of the operator. Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, - const Attrs& attrs) { + const Attrs& attrs, + const NodeRef& loc) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); const TypeRelationNode* rel = op->type_constraints[0].as(); @@ -256,7 +280,7 @@ class TypeInferencer : private ExprFunctor { arg_types.push_back(rtype); // we can do simple replacement here solver_.AddConstraint(TypeRelationNode::make( - rel->func, arg_types, arg_types.size() - 1, attrs)); + rel->func, arg_types, arg_types.size() - 1, attrs), loc); return rtype; } @@ -304,16 +328,19 @@ class TypeInferencer : private ExprFunctor { auto* fn_ty_node = ftype.as(); auto* inc_ty_node = ftype.as(); - CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr) - << "only expressions with function types can be called, found " - << ftype << " at " << call->span; + if (fn_ty_node == nullptr && inc_ty_node == nullptr) { + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("only expressions with function types can be called, found " + << ftype)); + } // incomplete type => it must be a function taking the arg types // with an unknown return type if (inc_ty_node != nullptr) { Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); - Type unified = this->Unify(ftype, func_type, call->span); + Type unified = this->Unify(ftype, func_type, GetRef(call)); fn_ty_node = unified.as(); } @@ -323,10 +350,16 @@ class TypeInferencer : private ExprFunctor { type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); } } - CHECK(type_args.size() == fn_ty_node->type_params.size()) - << "Incorrect number of type args in " << call->span << ": " - << "Expected " << fn_ty_node->type_params.size() - << "but got " << type_args.size(); + + if (type_args.size() != fn_ty_node->type_params.size()) { + this->ReportFatalError(GetRef(call), + RELAY_ERROR("Incorrect number of type args in " + << call->span << ": " + << "Expected " + << fn_ty_node->type_params.size() + << "but got " << type_args.size())); + } + FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); AddTypeArgs(GetRef(call), type_args); @@ -336,22 +369,29 @@ class TypeInferencer : private ExprFunctor { if (type_arity != number_of_args) { if (type_arity < number_of_args) { - LOG(FATAL) << "the function is provided too many arguments " << call->span; + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args)); } else { - LOG(FATAL) << "the function is provided too few arguments" << call->span; + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args)); } } for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span); + this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]); } for (auto cs : fn_ty->type_constraints) { if (auto tr = cs.as()) { solver_.AddConstraint( - TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs)); + TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs), + GetRef(call)); } else { - solver_.AddConstraint(cs); + solver_.AddConstraint(cs, GetRef(call)); } } @@ -367,7 +407,8 @@ class TypeInferencer : private ExprFunctor { if (const OpNode* opnode = call->op.as()) { Type rtype = PrimitiveCall(opnode->op_type.as(), arg_types, - call->attrs); + call->attrs, + GetRef(call)); if (rtype.defined()) { AddTypeArgs(GetRef(call), arg_types); return rtype; @@ -385,7 +426,7 @@ class TypeInferencer : private ExprFunctor { } Type rtype = GetType(f->body); if (f->ret_type.defined()) { - rtype = this->Unify(f->ret_type, rtype, f->span); + rtype = this->Unify(f->ret_type, rtype, GetRef(f)); } auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); @@ -445,6 +486,9 @@ class TypeInferencer::Resolver : public ExprMutator { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); Type checked_type = solver_->Resolve(it->second.checked_type); + + // TODO(@jroesch): it would be nice if we would report resolution + // errors directly on the program. CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; @@ -542,6 +586,10 @@ Expr TypeInferencer::Infer(Expr expr) { // Step 1: Solve the constraints. solver_.Solve(); + if (err_reporter.AnyErrors()) { + err_reporter.RenderErrors(mod_); + } + // Step 2: Attach resolved types to checked_type field. auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); CHECK(WellFormed(resolved_expr)); @@ -549,10 +597,27 @@ Expr TypeInferencer::Infer(Expr expr) { } -Expr InferType(const Expr& expr, const Module& mod) { - auto e = TypeInferencer(mod).Infer(expr); - CHECK(WellFormed(e)); - return e; +Expr InferType(const Expr& expr, const Module& mod_ref) { + if (!mod_ref.defined()) { + Module mod = ModuleNode::FromExpr(expr); + // NB(@jroesch): By adding the expression to the module we will + // type check it anyway; afterwards we can just recover type + // from the type-checked function to avoid doing unnecessary work. + + Function func = mod->Lookup(mod->entry_func); + + // FromExpr wraps a naked expression as a function, we will unbox + // it here. + if (expr.as()) { + return func; + } else { + return func->body; + } + } else { + auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); + CHECK(WellFormed(e)); + return e; + } } Function InferType(const Function& func, @@ -561,7 +626,7 @@ Function InferType(const Function& func, Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); - Expr func_ret = TypeInferencer(mod).Infer(func_copy); + Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); mod->Remove(var); CHECK(WellFormed(func_ret)); return Downcast(func_ret); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index caea3755b8f9..dafcaf56015a 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -16,7 +16,7 @@ class TypeSolver::Reporter : public TypeReporterNode { : solver_(solver) {} void Assign(const Type& dst, const Type& src) final { - solver_->Unify(dst, src); + solver_->Unify(dst, src, location); } bool Assert(const IndexExpr& cond) final { @@ -35,7 +35,14 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } + TVM_DLL void SetLocation(const NodeRef& ref) final { + location = ref; + } + private: + /*! \brief The location to report unification errors at. */ + mutable NodeRef location; + TypeSolver* solver_; }; @@ -329,8 +336,10 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver() - : reporter_(make_node(this)) { +TypeSolver::TypeSolver(const GlobalVar ¤t_func, ErrorReporter* err_reporter) + : reporter_(make_node(this)), + current_func(current_func), + err_reporter_(err_reporter) { } // destructor @@ -351,16 +360,26 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { } // Add equality constraint -Type TypeSolver::Unify(const Type& dst, const Type& src) { +Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) { + // NB(@jroesch): we should probably pass location into the unifier to do better + // error reporting as well. Unifier unifier(this); return unifier.Unify(dst, src); } +void TypeSolver::ReportError(const Error& err, const NodeRef& location) { + this->err_reporter_->ReportAt( + this->current_func, + location, + err); + } + // Add type constraint to the solver. -void TypeSolver::AddConstraint(const TypeConstraint& constraint) { +void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) { if (auto *op = constraint.as()) { // create a new relation node. RelationNode* rnode = arena_.make(); + rnode->location = loc; rnode->rel = GetRef(op); rel_nodes_.push_back(rnode); // populate the type information. @@ -404,29 +423,52 @@ bool TypeSolver::Solve() { args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); CHECK_LE(args.size(), rel->args.size()); } - // call the function - bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); - // mark inqueue as false after the function call - // so that rnode itself won't get enqueued again. - rnode->inqueue = false; - if (resolved) { - ++num_resolved_rels_; + CHECK(rnode->location.defined()) + << "undefined location, should be set when constructing relation node"; + + // We need to set this in order to understand where unification + // errors generated by the error reporting are coming from. + reporter_->SetLocation(rnode->location); + + try { + // Call the Type Relation's function. + bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); + + if (resolved) { + ++num_resolved_rels_; + } + + rnode->resolved = resolved; + } catch (const Error& err) { + this->ReportError(err, rnode->location); + rnode->resolved = false; + } catch (const dmlc::Error& err) { + rnode->resolved = false; + this->ReportError( + RELAY_ERROR( + "an internal invariant was violdated while" \ + "typechecking your program" << + err.what()), rnode->location); } - rnode->resolved = resolved; + + // Mark inqueue as false after the function call + // so that rnode itself won't get enqueued again. + rnode->inqueue = false; } + // This criterion is not necessarily right for all the possible cases // TODO(tqchen): We should also count the number of in-complete types. return num_resolved_rels_ == rel_nodes_.size(); } - // Expose type solver only for debugging purposes. TVM_REGISTER_API("relay._ir_pass._test_type_solver") .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; - auto solver = std::make_shared(); + ErrorReporter err_reporter; + auto solver = std::make_shared(GlobalVarNode::make("test"), &err_reporter); auto mod = [solver](std::string name) -> PackedFunc { if (name == "Solve") { @@ -435,7 +477,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") }); } else if (name == "Unify") { return TypedPackedFunc([solver](Type lhs, Type rhs) { - return solver->Unify(lhs, rhs); + return solver->Unify(lhs, rhs, lhs); }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { @@ -443,7 +485,9 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") }); } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { - return solver->AddConstraint(c); + Expr e = VarNode::make("dummy_var", + IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + return solver->AddConstraint(c, e); }); } else { return PackedFunc(); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index b4635fdec331..b56d45c3b685 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -6,8 +6,10 @@ #ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_ #define TVM_RELAY_PASS_TYPE_SOLVER_H_ +#include #include #include +#include #include #include #include "../../common/arena.h" @@ -40,13 +42,14 @@ using common::LinkedList; */ class TypeSolver { public: - TypeSolver(); + TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter); ~TypeSolver(); /*! * \brief Add a type constraint to the solver. * \param constraint The constraint to be added. + * \param location The location at which the constraint was incurred. */ - void AddConstraint(const TypeConstraint& constraint); + void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation); /*! * \brief Resolve type to the solution type in the solver. * \param type The type to be resolved. @@ -62,8 +65,16 @@ class TypeSolver { * \brief Unify lhs and rhs. * \param lhs The left operand. * \param rhs The right operand + * \param location The location at which the unification problem arose. */ - Type Unify(const Type& lhs, const Type& rhs); + Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location); + + /*! + * \brief Report an error at the provided location. + * \param err The error to report. + * \param loc The location at which to report the error. + */ + void ReportError(const Error& err, const NodeRef& location); private: class OccursChecker; @@ -112,6 +123,7 @@ class TypeSolver { return root; } }; + /*! \brief relation node */ struct RelationNode { /*! \brief Whether the relation is in the queue to be solved */ @@ -122,7 +134,10 @@ class TypeSolver { TypeRelation rel; /*! \brief list types to this relation */ LinkedList type_list; + /*! \brief The location this type relation originated from. */ + NodeRef location; }; + /*! \brief List of all allocated type nodes */ std::vector type_nodes_; /*! \brief List of all allocated relation nodes */ @@ -137,6 +152,11 @@ class TypeSolver { common::Arena arena_; /*! \brief Reporter that reports back to self */ TypeReporter reporter_; + /*! \brief The global representing the current function. */ + GlobalVar current_func; + /*! \brief Error reporting. */ + ErrorReporter* err_reporter_; + /*! * \brief GetTypeNode that is corresponds to t. * if it do not exist, create a new one. diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py new file mode 100644 index 000000000000..1720af21afea --- /dev/null +++ b/tests/python/relay/test_error_reporting.py @@ -0,0 +1,34 @@ +import tvm +from tvm import relay + +def check_type_err(expr, msg): + try: + expr = relay.ir_pass.infer_type(expr) + assert False + except tvm.TVMError as err: + assert msg in str(err) + +def test_too_many_args(): + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x) + y = relay.var('y', shape=(10, 10)) + check_type_err( + f(x, y), + "the function is provided too many arguments expected 1, found 2;") + +def test_too_few_args(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(10, 10)) + f = relay.Function([x, y], x) + check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;") + +def test_rel_fail(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(11, 10)) + f = relay.Function([x, y], x + y) + check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);") + +if __name__ == "__main__": + test_too_many_args() + test_too_few_args() + test_rel_fail()