From e44b4dbc931750eb5acfae842a0490500f9d96cc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 6 Jan 2019 17:17:25 -0800 Subject: [PATCH 01/13] Add descriptive whole module errors to Relay This PR adds descriptive errors to Relay. It is now possible to annotate and visulize errors annotated at certain expressions in a relay::Module. Extract module changes from other branch Fix me Get first version working Improve color support Update error rendering, and comment about rang license Remove uncessary type checking from InferType Update more error reporting to use new machinery Update more error reporting Hacking Negate checks Remove debugging Address MK's comments Fix CPP lint Disable linting for rang Document new relay::Module APIs Add another test case Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update src/relay/ir/module.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Apply suggestions from code review Co-Authored-By: jroesch Fix comment in type_infer.cc Remove copy of relay::Module Fix lambda capture Refactor to remove error reporter from module Update include/tvm/relay/module.h Co-Authored-By: jroesch Fix small issue Fix doc comment Fix doc comment typo Fix build error Fix implicit ctor not working Fix macro Remove duplicate macro Try and fix issue with stringstream Fix old type inference invariant WIP repairing from rebase Rebase fixup Another fix --- .gitmodules | 3 + 3rdparty/rang | 1 + CMakeLists.txt | 1 + include/tvm/relay/error.h | 29 +++- include/tvm/relay/error_reporter.h | 104 ++++++++++++++ include/tvm/relay/module.h | 26 ++++ src/relay/ir/error_reporter.cc | 126 +++++++++++++++++ src/relay/ir/module.cc | 21 +++ src/relay/pass/type_infer.cc | 155 +++++++++++++++------ src/relay/util/CPPLINT.cfg | 1 + tests/python/relay/test_error_reporting.py | 27 ++++ 11 files changed, 448 insertions(+), 46 deletions(-) create mode 160000 3rdparty/rang create mode 100644 include/tvm/relay/error_reporter.h create mode 100644 src/relay/ir/error_reporter.cc create mode 100644 src/relay/util/CPPLINT.cfg create mode 100644 tests/python/relay/test_error_reporting.py diff --git a/.gitmodules b/.gitmodules index 8011ec12d24b..984326434c3f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "dlpack"] path = 3rdparty/dlpack url = https://github.com/dmlc/dlpack +[submodule "3rdparty/rang"] + path = 3rdparty/rang + url = https://github.com/agauniyal/rang diff --git a/3rdparty/rang b/3rdparty/rang new file mode 160000 index 000000000000..cabe04d6d6b0 --- /dev/null +++ b/3rdparty/rang @@ -0,0 +1 @@ +Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762 diff --git a/CMakeLists.txt b/CMakeLists.txt index 8765a3346069..23dd58a2cd26 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) include_directories("include") include_directories("3rdparty/dlpack/include") include_directories("3rdparty/dmlc-core/include") +include_directories("3rdparty/rang/include") include_directories("3rdparty/compiler-rt") # initial variables diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 1c2b90611bbd..d367ac27f0b1 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -7,13 +7,40 @@ #define TVM_RELAY_ERROR_H_ #include +#include +#include #include "./base.h" namespace tvm { namespace relay { +/*! \brief A wrapper around std::stringstream. + * + * This is designed to avoid platform specific + * issues compiling and using std::stringstream + * for error reporting. + */ +struct StringStream { + std::stringstream ss; + + template + StringStream& operator<<(const T& t) { + ss << t; + return *this; + } + + std::string str() const { + return ss.str(); + } +}; + +#define RELAY_ERROR(msg) (StringStream() << msg) + struct Error : public dmlc::Error { - explicit Error(const std::string &msg) : dmlc::Error(msg) {} + Span sp; + explicit Error(const std::string &msg) : dmlc::Error(msg), sp() {} + Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) + Error(const StringStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) }; struct InternalError : public Error { diff --git a/include/tvm/relay/error_reporter.h b/include/tvm/relay/error_reporter.h new file mode 100644 index 000000000000..e34d24b8b221 --- /dev/null +++ b/include/tvm/relay/error_reporter.h @@ -0,0 +1,104 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error_reporter.h + * \brief The set of errors raised by Relay. + */ +#ifndef TVM_RELAY_ERROR_REPORTER_H_ +#define TVM_RELAY_ERROR_REPORTER_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief An abstraction around how errors are stored and reported. + * Designed to be opaque to users, so we can support a robust and simpler + * error reporting mode, as well as a more complex mode. + * + * The first mode is the most accurate: we report a Relay error at a specific + * Span, and then render the error message directly against a textual representation + * of the program, highlighting the exact lines in which it occurs. This mode is not + * implemented in this PR and will not work. + * + * The second mode is a general-purpose mode, which attempts to annotate the program's + * textual format with errors. + * + * The final mode represents the old mode, if we report an error that has no span or + * expression, we will default to throwing an exception with a textual representation + * of the error and no indication of where it occured in the original program. + * + * The latter mode is not ideal, and the goal of the new error reporting machinery is + * to avoid ever reporting errors in this style. + */ +class ErrorReporter { + public: + ErrorReporter() : errors_(), node_to_error_() {} + + /*! \brief Report a tvm::relay::Error. + * + * This API is useful for reporting spanned errors. + * + * \param err The error to report. + */ + void Report(const Error& err) { + if (!err.sp.defined()) { + throw err; + } + + this->errors_.push_back(err); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error message to report. + */ + inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { + this->ReportAt(global, node, Error(err)); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error to report. + */ + void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); + + /*! \brief Render all reported errors and exit the program. + * + * This function should be used after executing a pass to render reported errors. + * + * It will build an error message from the set of errors, depending on the error + * reporting strategy. + * + * \param module The module to report errors on. + * \param use_color Controls whether to colorize the output. + */ + [[noreturn]] void RenderErrors(const Module& module, bool use_color = true); + + private: + std::vector errors_; + std::unordered_map, NodeHash, NodeEqual> node_to_error_; + std::unordered_map node_to_gv_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ERROR_REPORTER_H_ diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8d302c09d959..8585323d1628 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -43,11 +43,15 @@ class ModuleNode : public RelayNode { /*! \brief A map from ids to all global functions. */ tvm::Map functions; + /*! \brief The entry function (i.e. "main"). */ + GlobalVar entry_func; + ModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); v->Visit("global_var_map_", &global_var_map_); + v->Visit("entry_func", &entry_func); } TVM_DLL static Module make(tvm::Map global_funcs); @@ -111,6 +115,27 @@ class ModuleNode : public RelayNode { */ void Update(const Module& other); + /*! + * \brief Get the entry point of the module. + * + * \returns The entry point function, (i.e. main). + */ + Expr EntryPoint(); + + /*! \brief Construct a module from a standalone expression. + * + * Allows one to optionally pass a global function map as + * well. + * + * \param expr The expression to set as the entry point to the module. + * \param global_funcs The global function map. + * + * \returns A module with expr set as the entry point. + */ + static Module FromExpr( + const Expr& expr, + const tvm::Map& global_funcs = {}); + static constexpr const char* _type_key = "relay.Module"; TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); @@ -132,6 +157,7 @@ struct Module : public NodeRef { using ContainerType = ModuleNode; }; + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/error_reporter.cc b/src/relay/ir/error_reporter.cc new file mode 100644 index 000000000000..b50180699960 --- /dev/null +++ b/src/relay/ir/error_reporter.cc @@ -0,0 +1,126 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error_reporter.h + * \brief The set of errors raised by Relay. + */ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +template +using NodeMap = std::unordered_map; + +[[noreturn]] void ErrorReporter::RenderErrors(const Module& module, bool use_color) { + // First we pick an error reporting strategy for each error. + // TODO(@jroesch): Spanned errors are currently not supported. + for (auto err : this->errors_) { + CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; + } + + NodeMap> error_maps; + + // Set control mode in order to produce colors; + if (use_color) { + rang::setControlMode(rang::control::Force); + } + + for (auto pair : this->node_to_gv_) { + auto node = pair.first; + auto global = Downcast(pair.second); + + auto has_errs = this->node_to_error_.find(node); + + CHECK(has_errs != this->node_to_error_.end()); + + const auto& error_indicies = has_errs->second; + + std::stringstream err_msg; + + err_msg << rang::fg::red; + for (auto index : error_indicies) { + err_msg << this->errors_[index].what() << "; "; + } + err_msg << rang::fg::reset; + + // Setup error map. + auto it = error_maps.find(global); + if (it != error_maps.end()) { + it->second.insert({ node, err_msg.str() }); + } else { + error_maps.insert({ global, { { node, err_msg.str() }}}); + } + } + + // Now we will construct the fully-annotated program to display to + // the user. + std::stringstream annotated_prog; + + // First we output a header for the errors. + annotated_prog << + rang::style::bold << std::endl << + "Error(s) have occurred. We have annotated the program with them:" + << std::endl << std::endl << rang::style::reset; + + // For each global function which contains errors, we will + // construct an annotated function. + for (auto pair : error_maps) { + auto global = pair.first; + auto err_map = pair.second; + auto func = module->Lookup(global); + + // We output the name of the function before displaying + // the annotated program. + annotated_prog << + rang::style::bold << + "In `" << global->name_hint << "`: " << + std::endl << + rang::style::reset; + + // We then call into the Relay printer to generate the program. + // + // The annotation callback will annotate the error messages + // contained in the map. + annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) { + auto it = err_map.find(expr); + if (it != err_map.end()) { + return it->second; + } else { + return std::string(""); + } + }); + } + + auto msg = annotated_prog.str(); + + if (use_color) { + rang::setControlMode(rang::control::Auto); + } + + // Finally we report the error, currently we do so to LOG(FATAL), + // it may be good to instead report it to std::cout. + LOG(FATAL) << annotated_prog.str() << std::endl; + + exit(1); +} + +void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { + size_t index_to_insert = this->errors_.size(); + this->errors_.push_back(err); + auto it = this->node_to_error_.find(node); + if (it != this->node_to_error_.end()) { + it->second.push_back(index_to_insert); + } else { + this->node_to_error_.insert({ node, { index_to_insert }}); + } + this->node_to_gv_.insert({ node, global }); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index cbb0b7768004..02318658e427 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map global_funcs) { << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } + + n->entry_func = GlobalVarNode::make("main"); return Module(n); } @@ -96,6 +98,25 @@ void ModuleNode::Update(const Module& mod) { } } +Expr ModuleNode::EntryPoint() { + return this->Lookup(this->entry_func); +} + +Module ModuleNode::FromExpr( + const Expr& expr, + const tvm::Map& global_funcs) { + auto mod = ModuleNode::make({}); + auto func_node = expr.as(); + Function func; + if (func_node) { + func = GetRef(func_node); + } else { + func = FunctionNode::make({}, expr, Type(), {}, {}); + } + mod->Add(mod->entry_func, func); + return mod; +} + TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index af4cc6607a44..4e9d27f92fca 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -21,6 +21,7 @@ */ #include +#include #include #include #include "type_solver.h" @@ -84,8 +85,8 @@ class TypeInferencer : private ExprFunctor { // constructors TypeInferencer() { } - explicit TypeInferencer(Module mod) - : mod_(mod) { + explicit TypeInferencer(Module mod, GlobalVar current_func) + : mod_(mod), current_func_(current_func), err_reporter() { } // inference the type of expr. @@ -96,6 +97,13 @@ class TypeInferencer : private ExprFunctor { class Resolver; // internal environment Module mod_; + + // The current function being type checked. + GlobalVar current_func_; + + // The error reporter. + ErrorReporter err_reporter; + // map from expression to checked type // type inferencer will populate it up std::unordered_map type_map_; @@ -109,18 +117,21 @@ class TypeInferencer : private ExprFunctor { // relation function TypeRelationFn tuple_getitem_rel_; TypeRelationFn make_tuple_rel_; - // Unify two types - Type Unify(const Type& t1, const Type& t2, const Span& span) { + + // Perform unification on two types and report the error at the expression + // or the span of the expression. + Type Unify(const Type& t1, const Type& t2, const Expr& expr) { // TODO(tqchen, jroesch): propagate span to solver try { return solver_.Unify(t1, t2); } catch (const dmlc::Error &e) { - LOG(FATAL) - << "Error unifying `" + this->ReportFatalError( + expr, + RELAY_ERROR("Error unifying `" << t1 << "` and `" << t2 - << "`: " << e.what(); + << "`: " << e.what())); return Type(); } } @@ -151,7 +162,7 @@ class TypeInferencer : private ExprFunctor { } // Lazily get type for expr - // will call visit to deduce it if it is not in the type_map_ + // expression, we will populate it now, and return the result. Type GetType(const Expr &expr) { auto it = type_map_.find(expr); if (it != type_map_.end() && it->second.checked_type.defined()) { @@ -163,7 +174,19 @@ class TypeInferencer : private ExprFunctor { return ret; } - // Visitor logics + [[noreturn]] void ReportFatalError(const Expr& expr, const std::stringstream& err) { + CHECK(this->current_func_.defined()); + this->err_reporter.ReportAt(this->current_func_, expr, Error(err)); + this->err_reporter.RenderErrors(this->mod_); + } + + [[noreturn]] void ReportFatalError(const Expr& expr, const Error& err) { + CHECK(this->current_func_.defined()); + this->err_reporter.ReportAt(this->current_func_, expr, err); + this->err_reporter.RenderErrors(this->mod_); + } + + // Visitor Logic Type VisitExpr_(const VarNode* op) final { if (op->type_annotation.defined()) { return op->type_annotation; @@ -174,8 +197,13 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); - CHECK(mod_.defined()) - << "Cannot do type inference without a global variable"; + if (!mod_.defined()) { + this->ReportFatalError( + GetRef(op), + RELAY_ERROR( + "Cannot do type inference on global variables " \ + "without a module")); + } Expr e = mod_->Lookup(var); return e->checked_type(); } @@ -210,37 +238,39 @@ class TypeInferencer : private ExprFunctor { return op->op_type; } - Type VisitExpr_(const LetNode* op) final { + Type VisitExpr_(const LetNode* let) final { // if the definition is a function literal, permit recursion - bool is_functional_literal = op->value.as() != nullptr; + bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { - type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + type_map_[let->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } - Type vtype = GetType(op->value); - if (op->var->type_annotation.defined()) { - vtype = Unify(vtype, op->var->type_annotation, op->span); + Type vtype = GetType(let->value); + if (let->var->type_annotation.defined()) { + vtype = Unify(vtype, let->var->type_annotation, GetRef(let)); } - CHECK(is_functional_literal || !type_map_.count(op->var)); + CHECK(is_functional_literal || !type_map_.count(let->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var].checked_type = vtype; - return GetType(op->body); + type_map_[let->var].checked_type = vtype; + return GetType(let->body); } - Type VisitExpr_(const IfNode* op) final { + Type VisitExpr_(const IfNode* ite) final { // Ensure the type of the guard is of Tensor[Bool, ()], // that is a rank-0 boolean tensor. - Type cond_type = this->GetType(op->cond); + Type cond_type = this->GetType(ite->cond); this->Unify(cond_type, TensorTypeNode::Scalar(tvm::Bool()), - op->cond->span); - Type checked_true = this->GetType(op->true_branch); - Type checked_false = this->GetType(op->false_branch); - return this->Unify(checked_true, checked_false, op->span); + ite->cond); + Type checked_true = this->GetType(ite->true_branch); + Type checked_false = this->GetType(ite->false_branch); + return this->Unify(checked_true, checked_false, GetRef(ite)); } - // Handle special case basic primitive operator, - // if successful return the return type + // This code is special-cased for primitive operators, + // which are registered in the style defined in src/relay/op/*. + // + // The result will be the return type of the operator. Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs) { @@ -304,16 +334,19 @@ class TypeInferencer : private ExprFunctor { auto* fn_ty_node = ftype.as(); auto* inc_ty_node = ftype.as(); - CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr) - << "only expressions with function types can be called, found " - << ftype << " at " << call->span; + if (fn_ty_node == nullptr && inc_ty_node == nullptr) { + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("only expressions with function types can be called, found " + << ftype)); + } // incomplete type => it must be a function taking the arg types // with an unknown return type if (inc_ty_node != nullptr) { Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); - Type unified = this->Unify(ftype, func_type, call->span); + Type unified = this->Unify(ftype, func_type, GetRef(call)); fn_ty_node = unified.as(); } @@ -323,10 +356,16 @@ class TypeInferencer : private ExprFunctor { type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); } } - CHECK(type_args.size() == fn_ty_node->type_params.size()) - << "Incorrect number of type args in " << call->span << ": " - << "Expected " << fn_ty_node->type_params.size() - << "but got " << type_args.size(); + + if (type_args.size() != fn_ty_node->type_params.size()) { + this->ReportFatalError(GetRef(call), + RELAY_ERROR("Incorrect number of type args in " + << call->span << ": " + << "Expected " + << fn_ty_node->type_params.size() + << "but got " << type_args.size())); + } + FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); AddTypeArgs(GetRef(call), type_args); @@ -336,14 +375,20 @@ class TypeInferencer : private ExprFunctor { if (type_arity != number_of_args) { if (type_arity < number_of_args) { - LOG(FATAL) << "the function is provided too many arguments " << call->span; + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args)); } else { - LOG(FATAL) << "the function is provided too few arguments" << call->span; + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args)); } } for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span); + this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]); } for (auto cs : fn_ty->type_constraints) { @@ -385,7 +430,7 @@ class TypeInferencer : private ExprFunctor { } Type rtype = GetType(f->body); if (f->ret_type.defined()) { - rtype = this->Unify(f->ret_type, rtype, f->span); + rtype = this->Unify(f->ret_type, rtype, GetRef(f)); } auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); @@ -445,6 +490,9 @@ class TypeInferencer::Resolver : public ExprMutator { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); Type checked_type = solver_->Resolve(it->second.checked_type); + + // TODO(@jroesch): it would be nice if we would report resolution + // errors directly on the program. CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; @@ -549,10 +597,27 @@ Expr TypeInferencer::Infer(Expr expr) { } -Expr InferType(const Expr& expr, const Module& mod) { - auto e = TypeInferencer(mod).Infer(expr); - CHECK(WellFormed(e)); - return e; +Expr InferType(const Expr& expr, const Module& mod_ref) { + if (!mod_ref.defined()) { + Module mod = ModuleNode::FromExpr(expr); + // NB(@jroesch): By adding the expression to the module we will + // type check it anyway; afterwards we can just recover type + // from the type-checked function to avoid doing unnecessary work. + + Function e = mod->Lookup(mod->entry_func); + + // FromExpr wraps a naked expression as a function, we will unbox + // it here. + if (auto func = expr.as()) { + return e; + } else { + return e->body; + } + } else { + auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); + CHECK(WellFormed(e)); + return e; + } } Function InferType(const Function& func, @@ -561,7 +626,7 @@ Function InferType(const Function& func, Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); - Expr func_ret = TypeInferencer(mod).Infer(func_copy); + Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); mod->Remove(var); CHECK(WellFormed(func_ret)); return Downcast(func_ret); diff --git a/src/relay/util/CPPLINT.cfg b/src/relay/util/CPPLINT.cfg new file mode 100644 index 000000000000..610884f9397e --- /dev/null +++ b/src/relay/util/CPPLINT.cfg @@ -0,0 +1 @@ +exclude_files=rang.h diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py new file mode 100644 index 000000000000..481c19482da0 --- /dev/null +++ b/tests/python/relay/test_error_reporting.py @@ -0,0 +1,27 @@ +import tvm +from tvm import relay + +def check_type_err(expr, msg): + try: + expr = relay.ir_pass.infer_type(expr) + assert False + except tvm.TVMError as err: + assert msg in str(err) + +def test_too_many_args(): + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x) + y = relay.var('y', shape=(10, 10)) + check_type_err( + f(x, y), + "the function is provided too many arguments expected 1, found 2;") + +def test_too_few_args(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(10, 10)) + f = relay.Function([x, y], x) + check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;") + +if __name__ == "__main__": + test_too_many_args() + test_too_few_args() From 27c21cf3d174bc54172d98149eff0d285a3b4c30 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 16 Jan 2019 18:23:05 -0800 Subject: [PATCH 02/13] Fix warning --- src/relay/pass/type_infer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 4e9d27f92fca..8424cbcbec95 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -608,7 +608,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { // FromExpr wraps a naked expression as a function, we will unbox // it here. - if (auto func = expr.as()) { + if (expr.as()) { return e; } else { return e->body; From 91853fe265852d58140257cb383efbd4e4570616 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 17 Jan 2019 20:04:52 -0800 Subject: [PATCH 03/13] Address code review --- include/tvm/relay/error.h | 8 ++++---- src/relay/ir/error_reporter.cc | 2 -- src/relay/ir/module.cc | 2 +- src/relay/pass/type_infer.cc | 14 ++++---------- src/relay/util/CPPLINT.cfg | 1 - 5 files changed, 9 insertions(+), 18 deletions(-) delete mode 100644 src/relay/util/CPPLINT.cfg diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index d367ac27f0b1..7f8712a75283 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -20,11 +20,11 @@ namespace relay { * issues compiling and using std::stringstream * for error reporting. */ -struct StringStream { +struct RelayErrorStream { std::stringstream ss; template - StringStream& operator<<(const T& t) { + RelayErrorStream& operator<<(const T& t) { ss << t; return *this; } @@ -34,13 +34,13 @@ struct StringStream { } }; -#define RELAY_ERROR(msg) (StringStream() << msg) +#define RELAY_ERROR(msg) (RelayErrorStream() << msg) struct Error : public dmlc::Error { Span sp; explicit Error(const std::string &msg) : dmlc::Error(msg), sp() {} Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) - Error(const StringStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) + Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) }; struct InternalError : public Error { diff --git a/src/relay/ir/error_reporter.cc b/src/relay/ir/error_reporter.cc index b50180699960..2082e67eb055 100644 --- a/src/relay/ir/error_reporter.cc +++ b/src/relay/ir/error_reporter.cc @@ -106,8 +106,6 @@ using NodeMap = std::unordered_map; // Finally we report the error, currently we do so to LOG(FATAL), // it may be good to instead report it to std::cout. LOG(FATAL) << annotated_prog.str() << std::endl; - - exit(1); } void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 02318658e427..cfdf0ea5a5d7 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -105,7 +105,7 @@ Expr ModuleNode::EntryPoint() { Module ModuleNode::FromExpr( const Expr& expr, const tvm::Map& global_funcs) { - auto mod = ModuleNode::make({}); + auto mod = ModuleNode::make(global_funcs); auto func_node = expr.as(); Function func; if (func_node) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 8424cbcbec95..ef87b04442c4 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -174,12 +174,6 @@ class TypeInferencer : private ExprFunctor { return ret; } - [[noreturn]] void ReportFatalError(const Expr& expr, const std::stringstream& err) { - CHECK(this->current_func_.defined()); - this->err_reporter.ReportAt(this->current_func_, expr, Error(err)); - this->err_reporter.RenderErrors(this->mod_); - } - [[noreturn]] void ReportFatalError(const Expr& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); @@ -604,19 +598,19 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { // type check it anyway; afterwards we can just recover type // from the type-checked function to avoid doing unnecessary work. - Function e = mod->Lookup(mod->entry_func); + Function func = mod->Lookup(mod->entry_func); // FromExpr wraps a naked expression as a function, we will unbox // it here. if (expr.as()) { - return e; + return func; } else { - return e->body; + return func->body; } } else { auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); CHECK(WellFormed(e)); - return e; + return func; } } diff --git a/src/relay/util/CPPLINT.cfg b/src/relay/util/CPPLINT.cfg deleted file mode 100644 index 610884f9397e..000000000000 --- a/src/relay/util/CPPLINT.cfg +++ /dev/null @@ -1 +0,0 @@ -exclude_files=rang.h From fa6eb72b08b4a968f75866dbf70b95d1bc669082 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 17 Jan 2019 20:11:22 -0800 Subject: [PATCH 04/13] More code review feedback --- include/tvm/relay/error.h | 2 +- src/relay/ir/error_reporter.cc | 3 +++ src/relay/pass/type_infer.cc | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 7f8712a75283..2e62efc715e3 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -38,7 +38,7 @@ struct RelayErrorStream { struct Error : public dmlc::Error { Span sp; - explicit Error(const std::string &msg) : dmlc::Error(msg), sp() {} + explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {} Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) }; diff --git a/src/relay/ir/error_reporter.cc b/src/relay/ir/error_reporter.cc index 2082e67eb055..acd8ff46e886 100644 --- a/src/relay/ir/error_reporter.cc +++ b/src/relay/ir/error_reporter.cc @@ -106,6 +106,9 @@ using NodeMap = std::unordered_map; // Finally we report the error, currently we do so to LOG(FATAL), // it may be good to instead report it to std::cout. LOG(FATAL) << annotated_prog.str() << std::endl; + + // NB(@jroesch): this is to ensure that the function does not return. + exit(1); } void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ef87b04442c4..9031712995e6 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -610,7 +610,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { } else { auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); CHECK(WellFormed(e)); - return func; + return e; } } From 2110b985e35ca40b0fefca1e4740173acc972b69 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 18 Jan 2019 15:53:09 -0800 Subject: [PATCH 05/13] Add support for rendering type relation errors in new style --- include/tvm/relay/error_reporter.h | 4 ++ include/tvm/relay/type.h | 6 +++ src/relay/op/type_relations.cc | 14 +++++-- src/relay/pass/type_infer.cc | 27 +++++++----- src/relay/pass/type_solver.cc | 48 ++++++++++++++++++---- src/relay/pass/type_solver.h | 17 ++++++-- tests/python/relay/test_error_reporting.py | 11 ++++- 7 files changed, 99 insertions(+), 28 deletions(-) diff --git a/include/tvm/relay/error_reporter.h b/include/tvm/relay/error_reporter.h index e34d24b8b221..fc0deef43c99 100644 --- a/include/tvm/relay/error_reporter.h +++ b/include/tvm/relay/error_reporter.h @@ -92,6 +92,10 @@ class ErrorReporter { */ [[noreturn]] void RenderErrors(const Module& module, bool use_color = true); + inline bool AnyErrors() { + return errors_.size() != 0; + } + private: std::vector errors_; std::unordered_map, NodeHash, NodeEqual> node_to_error_; diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 69a8a4fb0bd7..140b5e8aab75 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -12,6 +12,7 @@ #include #include "base.h" +#include "error.h" #include "../attrs.h" namespace tvm { @@ -295,9 +296,14 @@ class TypeReporterNode : public Node { */ TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0; + TVM_DLL virtual void ReportError(const Error& err) = 0; + // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} + // Not sure if best design, if not we should recreate a reporter for each relation. + mutable NodeRef location; + static constexpr const char* _type_key = "relay.TypeReporter"; TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); }; diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 467c0fcde860..1f335cf837fc 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -55,7 +55,8 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) { Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, - DataType output_dtype) { + DataType output_dtype, + const TypeReporter& reporter) { std::vector oshape; size_t ndim1 = t1->shape.size(); size_t ndim2 = t2->shape.size(); @@ -70,9 +71,13 @@ Type ConcreteBroadcast(const TensorType& t1, } else if (EqualConstInt(s2, 1)) { oshape.push_back(s1); } else { - LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2; + reporter->ReportError( + RELAY_ERROR( + "Incompatible broadcast type " + << t1 << " and " << t2)); } } + size_t max_ndim = std::max(ndim1, ndim2); auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape; for (; i <= max_ndim; ++i) { @@ -92,7 +97,8 @@ bool BroadcastRel(const Array& types, if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], ConcreteBroadcast(t0, t1, t0->dtype)); + reporter->Assign(types[2], + ConcreteBroadcast(t0, t1, t0->dtype, reporter)); return true; } } @@ -109,7 +115,7 @@ bool BroadcastCompRel(const Array& types, if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool())); + reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool(), reporter)); return true; } } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 9031712995e6..15466de966c0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -83,10 +83,10 @@ struct ResolvedTypeInfo { class TypeInferencer : private ExprFunctor { public: // constructors - TypeInferencer() { - } + explicit TypeInferencer(Module mod, GlobalVar current_func) - : mod_(mod), current_func_(current_func), err_reporter() { + : mod_(mod), current_func_(current_func), + err_reporter(), solver_(current_func, &this->err_reporter) { } // inference the type of expr. @@ -123,7 +123,7 @@ class TypeInferencer : private ExprFunctor { Type Unify(const Type& t1, const Type& t2, const Expr& expr) { // TODO(tqchen, jroesch): propagate span to solver try { - return solver_.Unify(t1, t2); + return solver_.Unify(t1, t2, expr); } catch (const dmlc::Error &e) { this->ReportFatalError( expr, @@ -224,7 +224,7 @@ class TypeInferencer : private ExprFunctor { auto attrs = make_node(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( - tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs))); + tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op)); return rtype; } @@ -267,7 +267,8 @@ class TypeInferencer : private ExprFunctor { // The result will be the return type of the operator. Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, - const Attrs& attrs) { + const Attrs& attrs, + const NodeRef& loc) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); const TypeRelationNode* rel = op->type_constraints[0].as(); @@ -280,7 +281,7 @@ class TypeInferencer : private ExprFunctor { arg_types.push_back(rtype); // we can do simple replacement here solver_.AddConstraint(TypeRelationNode::make( - rel->func, arg_types, arg_types.size() - 1, attrs)); + rel->func, arg_types, arg_types.size() - 1, attrs), loc); return rtype; } @@ -388,9 +389,10 @@ class TypeInferencer : private ExprFunctor { for (auto cs : fn_ty->type_constraints) { if (auto tr = cs.as()) { solver_.AddConstraint( - TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs)); + TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs), + GetRef(call)); } else { - solver_.AddConstraint(cs); + solver_.AddConstraint(cs, GetRef(call)); } } @@ -406,7 +408,8 @@ class TypeInferencer : private ExprFunctor { if (const OpNode* opnode = call->op.as()) { Type rtype = PrimitiveCall(opnode->op_type.as(), arg_types, - call->attrs); + call->attrs, + GetRef(call)); if (rtype.defined()) { AddTypeArgs(GetRef(call), arg_types); return rtype; @@ -584,6 +587,10 @@ Expr TypeInferencer::Infer(Expr expr) { // Step 1: Solve the constraints. solver_.Solve(); + if (err_reporter.AnyErrors()) { + err_reporter.RenderErrors(mod_); + } + // Step 2: Attach resolved types to checked_type field. auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); CHECK(WellFormed(resolved_expr)); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index caea3755b8f9..0cdcab904ac3 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -16,7 +16,7 @@ class TypeSolver::Reporter : public TypeReporterNode { : solver_(solver) {} void Assign(const Type& dst, const Type& src) final { - solver_->Unify(dst, src); + solver_->Unify(dst, src, location); } bool Assert(const IndexExpr& cond) final { @@ -35,6 +35,27 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } + void ReportError(const Error& err) final { + std::cout + << "Current Function: " + << solver_->current_func + << std::endl; + + std::cout + << "Location: " + << location + << std::endl; + + std::cout + << err.what() + << std::endl; + + solver_->err_reporter_->ReportAt( + solver_->current_func, + location, + err); + } + private: TypeSolver* solver_; }; @@ -329,8 +350,10 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver() - : reporter_(make_node(this)) { +TypeSolver::TypeSolver(const GlobalVar ¤t_func, ErrorReporter* err_reporter) + : reporter_(make_node(this)), + current_func(current_func), + err_reporter_(err_reporter) { } // destructor @@ -351,16 +374,19 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { } // Add equality constraint -Type TypeSolver::Unify(const Type& dst, const Type& src) { +Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) { + // NB(@jroesch): we should probably pass location into the unifier to do better + // error reporting as well. Unifier unifier(this); return unifier.Unify(dst, src); } // Add type constraint to the solver. -void TypeSolver::AddConstraint(const TypeConstraint& constraint) { +void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) { if (auto *op = constraint.as()) { // create a new relation node. RelationNode* rnode = arena_.make(); + rnode->location = loc; rnode->rel = GetRef(op); rel_nodes_.push_back(rnode); // populate the type information. @@ -404,6 +430,10 @@ bool TypeSolver::Solve() { args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); CHECK_LE(args.size(), rel->args.size()); } + + CHECK(rnode->location.defined()) << "undefined location"; + + reporter_->location = rnode->location; // call the function bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); // mark inqueue as false after the function call @@ -420,13 +450,13 @@ bool TypeSolver::Solve() { return num_resolved_rels_ == rel_nodes_.size(); } - // Expose type solver only for debugging purposes. TVM_REGISTER_API("relay._ir_pass._test_type_solver") .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; - auto solver = std::make_shared(); + ErrorReporter err_reporter; + auto solver = std::make_shared(GlobalVarNode::make("test"), &err_reporter); auto mod = [solver](std::string name) -> PackedFunc { if (name == "Solve") { @@ -435,7 +465,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") }); } else if (name == "Unify") { return TypedPackedFunc([solver](Type lhs, Type rhs) { - return solver->Unify(lhs, rhs); + return solver->Unify(lhs, rhs, NodeRef()); }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { @@ -443,7 +473,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") }); } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { - return solver->AddConstraint(c); + return solver->AddConstraint(c, NodeRef()); }); } else { return PackedFunc(); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index b4635fdec331..02e8bb786c61 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -6,8 +6,10 @@ #ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_ #define TVM_RELAY_PASS_TYPE_SOLVER_H_ +#include #include #include +#include #include #include #include "../../common/arena.h" @@ -40,13 +42,13 @@ using common::LinkedList; */ class TypeSolver { public: - TypeSolver(); + TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter); ~TypeSolver(); /*! * \brief Add a type constraint to the solver. * \param constraint The constraint to be added. */ - void AddConstraint(const TypeConstraint& constraint); + void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation); /*! * \brief Resolve type to the solution type in the solver. * \param type The type to be resolved. @@ -63,7 +65,7 @@ class TypeSolver { * \param lhs The left operand. * \param rhs The right operand */ - Type Unify(const Type& lhs, const Type& rhs); + Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location); private: class OccursChecker; @@ -112,6 +114,7 @@ class TypeSolver { return root; } }; + /*! \brief relation node */ struct RelationNode { /*! \brief Whether the relation is in the queue to be solved */ @@ -122,7 +125,10 @@ class TypeSolver { TypeRelation rel; /*! \brief list types to this relation */ LinkedList type_list; + /*! \brief The location this type relation originated from. */ + NodeRef location; }; + /*! \brief List of all allocated type nodes */ std::vector type_nodes_; /*! \brief List of all allocated relation nodes */ @@ -137,6 +143,11 @@ class TypeSolver { common::Arena arena_; /*! \brief Reporter that reports back to self */ TypeReporter reporter_; + /*! \brief The global representing the current function. */ + GlobalVar current_func; + /*! \brief Error reporting. */ + ErrorReporter* err_reporter_; + /*! * \brief GetTypeNode that is corresponds to t. * if it do not exist, create a new one. diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index 481c19482da0..8d5fbe451d9e 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -22,6 +22,13 @@ def test_too_few_args(): f = relay.Function([x, y], x) check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;") +def test_rel_fail(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(11, 10)) + f = relay.Function([x, y], x + y) + check_type_err(f(x, y), "the function is provided too few arguments expected 2, found 1;") + if __name__ == "__main__": - test_too_many_args() - test_too_few_args() + # test_too_many_args() + # test_too_few_args() + test_rel_fail() From f178806955a134d849b65743bcd538728e3cf3d2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 18 Jan 2019 15:59:45 -0800 Subject: [PATCH 06/13] Fix up error reporting --- src/relay/pass/type_solver.cc | 14 -------------- tests/python/relay/test_error_reporting.py | 6 +++--- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 0cdcab904ac3..8dc8c20435bb 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -36,20 +36,6 @@ class TypeSolver::Reporter : public TypeReporterNode { } void ReportError(const Error& err) final { - std::cout - << "Current Function: " - << solver_->current_func - << std::endl; - - std::cout - << "Location: " - << location - << std::endl; - - std::cout - << err.what() - << std::endl; - solver_->err_reporter_->ReportAt( solver_->current_func, location, diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index 8d5fbe451d9e..1720af21afea 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -26,9 +26,9 @@ def test_rel_fail(): x = relay.var('x', shape=(10, 10)) y = relay.var('y', shape=(11, 10)) f = relay.Function([x, y], x + y) - check_type_err(f(x, y), "the function is provided too few arguments expected 2, found 1;") + check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);") if __name__ == "__main__": - # test_too_many_args() - # test_too_few_args() + test_too_many_args() + test_too_few_args() test_rel_fail() From 8d7171812b788cbbf2af56fd1323600713b1e0e1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 20 Jan 2019 16:25:05 -0800 Subject: [PATCH 07/13] Refactor to allow type relations to throw exceptions --- include/tvm/relay/error.h | 98 ++++++++++++++++- include/tvm/relay/error_reporter.h | 108 ------------------- include/tvm/relay/type.h | 3 - src/relay/ir/{error_reporter.cc => error.cc} | 6 +- src/relay/op/type_relations.cc | 5 +- src/relay/pass/type_infer.cc | 1 - src/relay/pass/type_solver.cc | 44 +++++--- src/relay/pass/type_solver.h | 9 +- 8 files changed, 139 insertions(+), 135 deletions(-) delete mode 100644 include/tvm/relay/error_reporter.h rename src/relay/ir/{error_reporter.cc => error.cc} (97%) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 2e62efc715e3..892d0349581c 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -14,6 +14,13 @@ namespace tvm { namespace relay { +#define RELAY_ERROR(msg) (RelayErrorStream() << msg) + +// Forward declaration for error reporting. +struct Error; +class GlobalVar; +struct Module; + /*! \brief A wrapper around std::stringstream. * * This is designed to avoid platform specific @@ -32,9 +39,9 @@ struct RelayErrorStream { std::string str() const { return ss.str(); } -}; -#define RELAY_ERROR(msg) (RelayErrorStream() << msg) + [[noreturn]] void Raise() const; +}; struct Error : public dmlc::Error { Span sp; @@ -55,6 +62,93 @@ struct TypecheckerError : public Error { explicit TypecheckerError(const std::string &msg) : Error(msg) {} }; +/*! \brief An abstraction around how errors are stored and reported. + * Designed to be opaque to users, so we can support a robust and simpler + * error reporting mode, as well as a more complex mode. + * + * The first mode is the most accurate: we report a Relay error at a specific + * Span, and then render the error message directly against a textual representation + * of the program, highlighting the exact lines in which it occurs. This mode is not + * implemented in this PR and will not work. + * + * The second mode is a general-purpose mode, which attempts to annotate the program's + * textual format with errors. + * + * The final mode represents the old mode, if we report an error that has no span or + * expression, we will default to throwing an exception with a textual representation + * of the error and no indication of where it occured in the original program. + * + * The latter mode is not ideal, and the goal of the new error reporting machinery is + * to avoid ever reporting errors in this style. + */ +class ErrorReporter { + public: + ErrorReporter() : errors_(), node_to_error_() {} + + /*! \brief Report a tvm::relay::Error. + * + * This API is useful for reporting spanned errors. + * + * \param err The error to report. + */ + void Report(const Error& err) { + if (!err.sp.defined()) { + throw err; + } + + this->errors_.push_back(err); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error message to report. + */ + inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { + this->ReportAt(global, node, Error(err)); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error to report. + */ + void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); + + /*! \brief Render all reported errors and exit the program. + * + * This function should be used after executing a pass to render reported errors. + * + * It will build an error message from the set of errors, depending on the error + * reporting strategy. + * + * \param module The module to report errors on. + * \param use_color Controls whether to colorize the output. + */ + [[noreturn]] void RenderErrors(const Module& module, bool use_color = true); + + inline bool AnyErrors() { + return errors_.size() != 0; + } + + private: + std::vector errors_; + std::unordered_map, NodeHash, NodeEqual> node_to_error_; + std::unordered_map node_to_gv_; +}; + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/error_reporter.h b/include/tvm/relay/error_reporter.h deleted file mode 100644 index fc0deef43c99..000000000000 --- a/include/tvm/relay/error_reporter.h +++ /dev/null @@ -1,108 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file error_reporter.h - * \brief The set of errors raised by Relay. - */ -#ifndef TVM_RELAY_ERROR_REPORTER_H_ -#define TVM_RELAY_ERROR_REPORTER_H_ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -/*! \brief An abstraction around how errors are stored and reported. - * Designed to be opaque to users, so we can support a robust and simpler - * error reporting mode, as well as a more complex mode. - * - * The first mode is the most accurate: we report a Relay error at a specific - * Span, and then render the error message directly against a textual representation - * of the program, highlighting the exact lines in which it occurs. This mode is not - * implemented in this PR and will not work. - * - * The second mode is a general-purpose mode, which attempts to annotate the program's - * textual format with errors. - * - * The final mode represents the old mode, if we report an error that has no span or - * expression, we will default to throwing an exception with a textual representation - * of the error and no indication of where it occured in the original program. - * - * The latter mode is not ideal, and the goal of the new error reporting machinery is - * to avoid ever reporting errors in this style. - */ -class ErrorReporter { - public: - ErrorReporter() : errors_(), node_to_error_() {} - - /*! \brief Report a tvm::relay::Error. - * - * This API is useful for reporting spanned errors. - * - * \param err The error to report. - */ - void Report(const Error& err) { - if (!err.sp.defined()) { - throw err; - } - - this->errors_.push_back(err); - } - - /*! \brief Report an error against a program, using the full program - * error reporting strategy. - * - * This error reporting method requires the global function in which - * to report an error, the expression to report the error on, - * and the error object. - * - * \param global The global function in which the expression is contained. - * \param node The expression or type to report the error at. - * \param err The error message to report. - */ - inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { - this->ReportAt(global, node, Error(err)); - } - - /*! \brief Report an error against a program, using the full program - * error reporting strategy. - * - * This error reporting method requires the global function in which - * to report an error, the expression to report the error on, - * and the error object. - * - * \param global The global function in which the expression is contained. - * \param node The expression or type to report the error at. - * \param err The error to report. - */ - void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); - - /*! \brief Render all reported errors and exit the program. - * - * This function should be used after executing a pass to render reported errors. - * - * It will build an error message from the set of errors, depending on the error - * reporting strategy. - * - * \param module The module to report errors on. - * \param use_color Controls whether to colorize the output. - */ - [[noreturn]] void RenderErrors(const Module& module, bool use_color = true); - - inline bool AnyErrors() { - return errors_.size() != 0; - } - - private: - std::vector errors_; - std::unordered_map, NodeHash, NodeEqual> node_to_error_; - std::unordered_map node_to_gv_; -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_ERROR_REPORTER_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 140b5e8aab75..2ee65a2121bc 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -12,7 +12,6 @@ #include #include "base.h" -#include "error.h" #include "../attrs.h" namespace tvm { @@ -296,8 +295,6 @@ class TypeReporterNode : public Node { */ TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0; - TVM_DLL virtual void ReportError(const Error& err) = 0; - // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} diff --git a/src/relay/ir/error_reporter.cc b/src/relay/ir/error.cc similarity index 97% rename from src/relay/ir/error_reporter.cc rename to src/relay/ir/error.cc index acd8ff46e886..d56321286338 100644 --- a/src/relay/ir/error_reporter.cc +++ b/src/relay/ir/error.cc @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include #include @@ -14,6 +14,10 @@ namespace tvm { namespace relay { +void RelayErrorStream::Raise() const { + throw Error(*this); +} + template using NodeMap = std::unordered_map; diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 1f335cf837fc..09ec610de5a5 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -71,10 +71,9 @@ Type ConcreteBroadcast(const TensorType& t1, } else if (EqualConstInt(s2, 1)) { oshape.push_back(s1); } else { - reporter->ReportError( - RELAY_ERROR( + RELAY_ERROR( "Incompatible broadcast type " - << t1 << " and " << t2)); + << t1 << " and " << t2).Raise(); } } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 15466de966c0..20b6e9ac730c 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -21,7 +21,6 @@ */ #include -#include #include #include #include "type_solver.h" diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 8dc8c20435bb..db81eadc23ab 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -35,13 +35,6 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } - void ReportError(const Error& err) final { - solver_->err_reporter_->ReportAt( - solver_->current_func, - location, - err); - } - private: TypeSolver* solver_; }; @@ -367,6 +360,13 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) { return unifier.Unify(dst, src); } +void TypeSolver::ReportError(const Error& err, const NodeRef& location) { + this->err_reporter_->ReportAt( + this->current_func, + location, + err); + } + // Add type constraint to the solver. void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) { if (auto *op = constraint.as()) { @@ -418,19 +418,31 @@ bool TypeSolver::Solve() { } CHECK(rnode->location.defined()) << "undefined location"; - reporter_->location = rnode->location; - // call the function - bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); - // mark inqueue as false after the function call - // so that rnode itself won't get enqueued again. - rnode->inqueue = false; - if (resolved) { - ++num_resolved_rels_; + try { + // Call the Type Relation's function. + bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); + + if (resolved) { + ++num_resolved_rels_; + } + + rnode->resolved = resolved; + } catch (const Error& err) { + this->ReportError(err, rnode->location); + rnode->resolved = false; + } catch (const dmlc::Error& err) { + rnode->resolved = false; + std::cout << err.what() << std::endl; + exit(1); } - rnode->resolved = resolved; + + // Mark inqueue as false after the function call + // so that rnode itself won't get enqueued again. + rnode->inqueue = false; } + // This criterion is not necessarily right for all the possible cases // TODO(tqchen): We should also count the number of in-complete types. return num_resolved_rels_ == rel_nodes_.size(); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 02e8bb786c61..84ea43d42b1f 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include "../../common/arena.h" @@ -67,6 +67,13 @@ class TypeSolver { */ Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location); + /*! + * \brief Report an error at the provided location. + * \param err The error to report. + * \param loc The location to report the error. + */ + void ReportError(const Error& err, const NodeRef& location); + private: class OccursChecker; class Unifier; From d96316a93518315eaa0776ede03944f160f04988 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 20 Jan 2019 16:30:21 -0800 Subject: [PATCH 08/13] Fix CI failure --- src/relay/pass/type_solver.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index db81eadc23ab..e2f812cebea3 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -463,7 +463,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") }); } else if (name == "Unify") { return TypedPackedFunc([solver](Type lhs, Type rhs) { - return solver->Unify(lhs, rhs, NodeRef()); + return solver->Unify(lhs, rhs, lhs); }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { @@ -471,7 +471,9 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") }); } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { - return solver->AddConstraint(c, NodeRef()); + Expr e = VarNode::make("dummy_var", + IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + return solver->AddConstraint(c, e); }); } else { return PackedFunc(); From 8fcc15e46ba060ecb987e0a7172fc06c745b03bc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 20 Jan 2019 16:47:34 -0800 Subject: [PATCH 09/13] Clean up a few impl. details and add docs --- include/tvm/relay/type.h | 1 - src/relay/op/type_relations.cc | 7 +++---- src/relay/pass/type_solver.cc | 6 +++++- src/relay/pass/type_solver.h | 4 +++- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 2ee65a2121bc..e5f1b50764ef 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -298,7 +298,6 @@ class TypeReporterNode : public Node { // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} - // Not sure if best design, if not we should recreate a reporter for each relation. mutable NodeRef location; static constexpr const char* _type_key = "relay.TypeReporter"; diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 09ec610de5a5..2618054a663d 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -55,8 +55,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) { Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, - DataType output_dtype, - const TypeReporter& reporter) { + DataType output_dtype) { std::vector oshape; size_t ndim1 = t1->shape.size(); size_t ndim2 = t2->shape.size(); @@ -97,7 +96,7 @@ bool BroadcastRel(const Array& types, if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); reporter->Assign(types[2], - ConcreteBroadcast(t0, t1, t0->dtype, reporter)); + ConcreteBroadcast(t0, t1, t0->dtype)); return true; } } @@ -114,7 +113,7 @@ bool BroadcastCompRel(const Array& types, if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool(), reporter)); + reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool())); return true; } } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index e2f812cebea3..502dae4aa1e0 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -417,7 +417,11 @@ bool TypeSolver::Solve() { CHECK_LE(args.size(), rel->args.size()); } - CHECK(rnode->location.defined()) << "undefined location"; + CHECK(rnode->location.defined()) + << "undefined location, should be set when constructing relation node"; + + // We need to set this in order to understand where unification + // errors generated by the error reporting are coming from. reporter_->location = rnode->location; try { diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 84ea43d42b1f..b56d45c3b685 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -47,6 +47,7 @@ class TypeSolver { /*! * \brief Add a type constraint to the solver. * \param constraint The constraint to be added. + * \param location The location at which the constraint was incurred. */ void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation); /*! @@ -64,13 +65,14 @@ class TypeSolver { * \brief Unify lhs and rhs. * \param lhs The left operand. * \param rhs The right operand + * \param location The location at which the unification problem arose. */ Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location); /*! * \brief Report an error at the provided location. * \param err The error to report. - * \param loc The location to report the error. + * \param loc The location at which to report the error. */ void ReportError(const Error& err, const NodeRef& location); From 52616a59d27c35fc992d3e4ba186204ab59616d6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 20 Jan 2019 17:12:04 -0800 Subject: [PATCH 10/13] Fix circular dep issue --- include/tvm/relay/error.h | 20 ++++---------------- include/tvm/relay/pass.h | 2 +- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 892d0349581c..4bfd153cca49 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -10,16 +10,16 @@ #include #include #include "./base.h" +#include "./expr.h" +#include "./module.h" namespace tvm { namespace relay { #define RELAY_ERROR(msg) (RelayErrorStream() << msg) -// Forward declaration for error reporting. +// Forward declaratio for RelayErrorStream. struct Error; -class GlobalVar; -struct Module; /*! \brief A wrapper around std::stringstream. * @@ -40,7 +40,7 @@ struct RelayErrorStream { return ss.str(); } - [[noreturn]] void Raise() const; + void Raise() const; }; struct Error : public dmlc::Error { @@ -50,18 +50,6 @@ struct Error : public dmlc::Error { Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) }; -struct InternalError : public Error { - explicit InternalError(const std::string &msg) : Error(msg) {} -}; - -struct FatalTypeError : public Error { - explicit FatalTypeError(const std::string &s) : Error(s) {} -}; - -struct TypecheckerError : public Error { - explicit TypecheckerError(const std::string &msg) : Error(msg) {} -}; - /*! \brief An abstraction around how errors are stored and reported. * Designed to be opaque to users, so we can support a robust and simpler * error reporting mode, as well as a more complex mode. diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8527ab7a2cb5..38f6a805f131 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,8 +6,8 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include #include +#include #include #include From 07d1674ce9e3029321a4d5a9e0879adcfcf67127 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 23 Jan 2019 14:32:59 -0800 Subject: [PATCH 11/13] Fix error reporting for dmlc::Error --- src/relay/pass/type_solver.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 502dae4aa1e0..63bd012a813e 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -438,8 +438,11 @@ bool TypeSolver::Solve() { rnode->resolved = false; } catch (const dmlc::Error& err) { rnode->resolved = false; - std::cout << err.what() << std::endl; - exit(1); + this->ReportError( + RELAY_ERROR( + "an internal invariant was violdated while" \ + "typechecking your program" << + err.what()), rnode->location); } // Mark inqueue as false after the function call From c98bb22a33a6bcb0f6a9fbfc40196a4f2a6aab5b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 24 Jan 2019 16:07:10 -0800 Subject: [PATCH 12/13] Address final comments --- include/tvm/relay/error.h | 2 +- include/tvm/relay/type.h | 8 ++++++-- src/relay/ir/error.cc | 5 +---- src/relay/pass/type_infer.cc | 2 +- src/relay/pass/type_solver.cc | 9 ++++++++- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 4bfd153cca49..0451a9826cde 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -125,7 +125,7 @@ class ErrorReporter { * \param module The module to report errors on. * \param use_color Controls whether to colorize the output. */ - [[noreturn]] void RenderErrors(const Module& module, bool use_color = true); + void RenderErrors(const Module& module, bool use_color = true); inline bool AnyErrors() { return errors_.size() != 0; diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index e5f1b50764ef..f3bcf2c0a1d9 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -295,11 +295,15 @@ class TypeReporterNode : public Node { */ TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0; + /*! + * \brief Set the location at which to report unification errors. + * \param ref The program node to report the error. + */ + TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; + // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} - mutable NodeRef location; - static constexpr const char* _type_key = "relay.TypeReporter"; TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); }; diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index d56321286338..24f8d1c49b6b 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -21,7 +21,7 @@ void RelayErrorStream::Raise() const { template using NodeMap = std::unordered_map; -[[noreturn]] void ErrorReporter::RenderErrors(const Module& module, bool use_color) { +void ErrorReporter::RenderErrors(const Module& module, bool use_color) { // First we pick an error reporting strategy for each error. // TODO(@jroesch): Spanned errors are currently not supported. for (auto err : this->errors_) { @@ -110,9 +110,6 @@ using NodeMap = std::unordered_map; // Finally we report the error, currently we do so to LOG(FATAL), // it may be good to instead report it to std::cout. LOG(FATAL) << annotated_prog.str() << std::endl; - - // NB(@jroesch): this is to ensure that the function does not return. - exit(1); } void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 20b6e9ac730c..3135715f7691 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -173,7 +173,7 @@ class TypeInferencer : private ExprFunctor { return ret; } - [[noreturn]] void ReportFatalError(const Expr& expr, const Error& err) { + void ReportFatalError(const Expr& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); this->err_reporter.RenderErrors(this->mod_); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 63bd012a813e..dafcaf56015a 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -35,7 +35,14 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } + TVM_DLL void SetLocation(const NodeRef& ref) final { + location = ref; + } + private: + /*! \brief The location to report unification errors at. */ + mutable NodeRef location; + TypeSolver* solver_; }; @@ -422,7 +429,7 @@ bool TypeSolver::Solve() { // We need to set this in order to understand where unification // errors generated by the error reporting are coming from. - reporter_->location = rnode->location; + reporter_->SetLocation(rnode->location); try { // Call the Type Relation's function. From 1263614389c0f49ce93198d0eb228c180c7b0185 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 24 Jan 2019 20:30:36 -0800 Subject: [PATCH 13/13] Remove EntryPoint --- include/tvm/relay/module.h | 7 ------- src/relay/ir/module.cc | 4 ---- 2 files changed, 11 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8585323d1628..45ccfe3a8089 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -115,13 +115,6 @@ class ModuleNode : public RelayNode { */ void Update(const Module& other); - /*! - * \brief Get the entry point of the module. - * - * \returns The entry point function, (i.e. main). - */ - Expr EntryPoint(); - /*! \brief Construct a module from a standalone expression. * * Allows one to optionally pass a global function map as diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index cfdf0ea5a5d7..9ba5efecec80 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -98,10 +98,6 @@ void ModuleNode::Update(const Module& mod) { } } -Expr ModuleNode::EntryPoint() { - return this->Lookup(this->entry_func); -} - Module ModuleNode::FromExpr( const Expr& expr, const tvm::Map& global_funcs) {