From 7a370ea22887e16d1d07380f12f225d9c169d07f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 6 Jan 2019 17:17:25 -0800 Subject: [PATCH] Add descriptive whole module errors to Relay This PR adds descriptive errors to Relay. It is now possible to annotate and visulize errors annotated at certain expressions in a relay::Module. Extract module changes from other branch Fix me Get first version working Improve color support Update error rendering, and comment about rang license Remove uncessary type checking from InferType Update more error reporting to use new machinery Update more error reporting Hacking Negate checks Remove debugging Address MK's comments Fix CPP lint Disable linting for rang Document new relay::Module APIs Add another test case Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update include/tvm/relay/module.h Co-Authored-By: jroesch Update src/relay/ir/module.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Update include/tvm/relay/error.h Co-Authored-By: jroesch Apply suggestions from code review Co-Authored-By: jroesch Fix comment in type_infer.cc Remove copy of relay::Module Fix lambda capture Refactor to remove error reporter from module Update include/tvm/relay/module.h Co-Authored-By: jroesch Fix small issue Fix doc comment Fix doc comment typo Fix build error Fix implicit ctor not working Fix macro Remove duplicate macro Try and fix issue with stringstream Fix old type inference invariant WIP repairing from rebase Rebase fixup Another fix --- .gitmodules | 3 + 3rdparty/rang | 1 + CMakeLists.txt | 1 + include/tvm/relay/error.h | 29 +++- include/tvm/relay/error_reporter.h | 104 ++++++++++++++ include/tvm/relay/module.h | 26 ++++ src/relay/ir/error_reporter.cc | 126 +++++++++++++++++ src/relay/ir/module.cc | 21 +++ src/relay/pass/type_infer.cc | 155 +++++++++++++++------ src/relay/util/CPPLINT.cfg | 1 + tests/python/relay/test_error_reporting.py | 27 ++++ 11 files changed, 448 insertions(+), 46 deletions(-) create mode 160000 3rdparty/rang create mode 100644 include/tvm/relay/error_reporter.h create mode 100644 src/relay/ir/error_reporter.cc create mode 100644 src/relay/util/CPPLINT.cfg create mode 100644 tests/python/relay/test_error_reporting.py diff --git a/.gitmodules b/.gitmodules index 8011ec12d24b5..984326434c3fb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "dlpack"] path = 3rdparty/dlpack url = https://github.com/dmlc/dlpack +[submodule "3rdparty/rang"] + path = 3rdparty/rang + url = https://github.com/agauniyal/rang diff --git a/3rdparty/rang b/3rdparty/rang new file mode 160000 index 0000000000000..cabe04d6d6b05 --- /dev/null +++ b/3rdparty/rang @@ -0,0 +1 @@ +Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762 diff --git a/CMakeLists.txt b/CMakeLists.txt index 363b2056a87ad..922cda3beff6b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) include_directories("include") include_directories("3rdparty/dlpack/include") include_directories("3rdparty/dmlc-core/include") +include_directories("3rdparty/rang/include") include_directories("3rdparty/compiler-rt") # initial variables diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 1c2b90611bbd3..d367ac27f0b14 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -7,13 +7,40 @@ #define TVM_RELAY_ERROR_H_ #include +#include +#include #include "./base.h" namespace tvm { namespace relay { +/*! \brief A wrapper around std::stringstream. + * + * This is designed to avoid platform specific + * issues compiling and using std::stringstream + * for error reporting. + */ +struct StringStream { + std::stringstream ss; + + template + StringStream& operator<<(const T& t) { + ss << t; + return *this; + } + + std::string str() const { + return ss.str(); + } +}; + +#define RELAY_ERROR(msg) (StringStream() << msg) + struct Error : public dmlc::Error { - explicit Error(const std::string &msg) : dmlc::Error(msg) {} + Span sp; + explicit Error(const std::string &msg) : dmlc::Error(msg), sp() {} + Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) + Error(const StringStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*) }; struct InternalError : public Error { diff --git a/include/tvm/relay/error_reporter.h b/include/tvm/relay/error_reporter.h new file mode 100644 index 0000000000000..e34d24b8b2219 --- /dev/null +++ b/include/tvm/relay/error_reporter.h @@ -0,0 +1,104 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error_reporter.h + * \brief The set of errors raised by Relay. + */ +#ifndef TVM_RELAY_ERROR_REPORTER_H_ +#define TVM_RELAY_ERROR_REPORTER_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief An abstraction around how errors are stored and reported. + * Designed to be opaque to users, so we can support a robust and simpler + * error reporting mode, as well as a more complex mode. + * + * The first mode is the most accurate: we report a Relay error at a specific + * Span, and then render the error message directly against a textual representation + * of the program, highlighting the exact lines in which it occurs. This mode is not + * implemented in this PR and will not work. + * + * The second mode is a general-purpose mode, which attempts to annotate the program's + * textual format with errors. + * + * The final mode represents the old mode, if we report an error that has no span or + * expression, we will default to throwing an exception with a textual representation + * of the error and no indication of where it occured in the original program. + * + * The latter mode is not ideal, and the goal of the new error reporting machinery is + * to avoid ever reporting errors in this style. + */ +class ErrorReporter { + public: + ErrorReporter() : errors_(), node_to_error_() {} + + /*! \brief Report a tvm::relay::Error. + * + * This API is useful for reporting spanned errors. + * + * \param err The error to report. + */ + void Report(const Error& err) { + if (!err.sp.defined()) { + throw err; + } + + this->errors_.push_back(err); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error message to report. + */ + inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { + this->ReportAt(global, node, Error(err)); + } + + /*! \brief Report an error against a program, using the full program + * error reporting strategy. + * + * This error reporting method requires the global function in which + * to report an error, the expression to report the error on, + * and the error object. + * + * \param global The global function in which the expression is contained. + * \param node The expression or type to report the error at. + * \param err The error to report. + */ + void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); + + /*! \brief Render all reported errors and exit the program. + * + * This function should be used after executing a pass to render reported errors. + * + * It will build an error message from the set of errors, depending on the error + * reporting strategy. + * + * \param module The module to report errors on. + * \param use_color Controls whether to colorize the output. + */ + [[noreturn]] void RenderErrors(const Module& module, bool use_color = true); + + private: + std::vector errors_; + std::unordered_map, NodeHash, NodeEqual> node_to_error_; + std::unordered_map node_to_gv_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ERROR_REPORTER_H_ diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8d302c09d959c..8585323d1628e 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -43,11 +43,15 @@ class ModuleNode : public RelayNode { /*! \brief A map from ids to all global functions. */ tvm::Map functions; + /*! \brief The entry function (i.e. "main"). */ + GlobalVar entry_func; + ModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); v->Visit("global_var_map_", &global_var_map_); + v->Visit("entry_func", &entry_func); } TVM_DLL static Module make(tvm::Map global_funcs); @@ -111,6 +115,27 @@ class ModuleNode : public RelayNode { */ void Update(const Module& other); + /*! + * \brief Get the entry point of the module. + * + * \returns The entry point function, (i.e. main). + */ + Expr EntryPoint(); + + /*! \brief Construct a module from a standalone expression. + * + * Allows one to optionally pass a global function map as + * well. + * + * \param expr The expression to set as the entry point to the module. + * \param global_funcs The global function map. + * + * \returns A module with expr set as the entry point. + */ + static Module FromExpr( + const Expr& expr, + const tvm::Map& global_funcs = {}); + static constexpr const char* _type_key = "relay.Module"; TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); @@ -132,6 +157,7 @@ struct Module : public NodeRef { using ContainerType = ModuleNode; }; + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/error_reporter.cc b/src/relay/ir/error_reporter.cc new file mode 100644 index 0000000000000..b501806999608 --- /dev/null +++ b/src/relay/ir/error_reporter.cc @@ -0,0 +1,126 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error_reporter.h + * \brief The set of errors raised by Relay. + */ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +template +using NodeMap = std::unordered_map; + +[[noreturn]] void ErrorReporter::RenderErrors(const Module& module, bool use_color) { + // First we pick an error reporting strategy for each error. + // TODO(@jroesch): Spanned errors are currently not supported. + for (auto err : this->errors_) { + CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; + } + + NodeMap> error_maps; + + // Set control mode in order to produce colors; + if (use_color) { + rang::setControlMode(rang::control::Force); + } + + for (auto pair : this->node_to_gv_) { + auto node = pair.first; + auto global = Downcast(pair.second); + + auto has_errs = this->node_to_error_.find(node); + + CHECK(has_errs != this->node_to_error_.end()); + + const auto& error_indicies = has_errs->second; + + std::stringstream err_msg; + + err_msg << rang::fg::red; + for (auto index : error_indicies) { + err_msg << this->errors_[index].what() << "; "; + } + err_msg << rang::fg::reset; + + // Setup error map. + auto it = error_maps.find(global); + if (it != error_maps.end()) { + it->second.insert({ node, err_msg.str() }); + } else { + error_maps.insert({ global, { { node, err_msg.str() }}}); + } + } + + // Now we will construct the fully-annotated program to display to + // the user. + std::stringstream annotated_prog; + + // First we output a header for the errors. + annotated_prog << + rang::style::bold << std::endl << + "Error(s) have occurred. We have annotated the program with them:" + << std::endl << std::endl << rang::style::reset; + + // For each global function which contains errors, we will + // construct an annotated function. + for (auto pair : error_maps) { + auto global = pair.first; + auto err_map = pair.second; + auto func = module->Lookup(global); + + // We output the name of the function before displaying + // the annotated program. + annotated_prog << + rang::style::bold << + "In `" << global->name_hint << "`: " << + std::endl << + rang::style::reset; + + // We then call into the Relay printer to generate the program. + // + // The annotation callback will annotate the error messages + // contained in the map. + annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) { + auto it = err_map.find(expr); + if (it != err_map.end()) { + return it->second; + } else { + return std::string(""); + } + }); + } + + auto msg = annotated_prog.str(); + + if (use_color) { + rang::setControlMode(rang::control::Auto); + } + + // Finally we report the error, currently we do so to LOG(FATAL), + // it may be good to instead report it to std::cout. + LOG(FATAL) << annotated_prog.str() << std::endl; + + exit(1); +} + +void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { + size_t index_to_insert = this->errors_.size(); + this->errors_.push_back(err); + auto it = this->node_to_error_.find(node); + if (it != this->node_to_error_.end()) { + it->second.push_back(index_to_insert); + } else { + this->node_to_error_.insert({ node, { index_to_insert }}); + } + this->node_to_gv_.insert({ node, global }); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index cbb0b77680043..02318658e4273 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map global_funcs) { << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } + + n->entry_func = GlobalVarNode::make("main"); return Module(n); } @@ -96,6 +98,25 @@ void ModuleNode::Update(const Module& mod) { } } +Expr ModuleNode::EntryPoint() { + return this->Lookup(this->entry_func); +} + +Module ModuleNode::FromExpr( + const Expr& expr, + const tvm::Map& global_funcs) { + auto mod = ModuleNode::make({}); + auto func_node = expr.as(); + Function func; + if (func_node) { + func = GetRef(func_node); + } else { + func = FunctionNode::make({}, expr, Type(), {}, {}); + } + mod->Add(mod->entry_func, func); + return mod; +} + TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index af4cc6607a44a..4e9d27f92fca2 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -21,6 +21,7 @@ */ #include +#include #include #include #include "type_solver.h" @@ -84,8 +85,8 @@ class TypeInferencer : private ExprFunctor { // constructors TypeInferencer() { } - explicit TypeInferencer(Module mod) - : mod_(mod) { + explicit TypeInferencer(Module mod, GlobalVar current_func) + : mod_(mod), current_func_(current_func), err_reporter() { } // inference the type of expr. @@ -96,6 +97,13 @@ class TypeInferencer : private ExprFunctor { class Resolver; // internal environment Module mod_; + + // The current function being type checked. + GlobalVar current_func_; + + // The error reporter. + ErrorReporter err_reporter; + // map from expression to checked type // type inferencer will populate it up std::unordered_map type_map_; @@ -109,18 +117,21 @@ class TypeInferencer : private ExprFunctor { // relation function TypeRelationFn tuple_getitem_rel_; TypeRelationFn make_tuple_rel_; - // Unify two types - Type Unify(const Type& t1, const Type& t2, const Span& span) { + + // Perform unification on two types and report the error at the expression + // or the span of the expression. + Type Unify(const Type& t1, const Type& t2, const Expr& expr) { // TODO(tqchen, jroesch): propagate span to solver try { return solver_.Unify(t1, t2); } catch (const dmlc::Error &e) { - LOG(FATAL) - << "Error unifying `" + this->ReportFatalError( + expr, + RELAY_ERROR("Error unifying `" << t1 << "` and `" << t2 - << "`: " << e.what(); + << "`: " << e.what())); return Type(); } } @@ -151,7 +162,7 @@ class TypeInferencer : private ExprFunctor { } // Lazily get type for expr - // will call visit to deduce it if it is not in the type_map_ + // expression, we will populate it now, and return the result. Type GetType(const Expr &expr) { auto it = type_map_.find(expr); if (it != type_map_.end() && it->second.checked_type.defined()) { @@ -163,7 +174,19 @@ class TypeInferencer : private ExprFunctor { return ret; } - // Visitor logics + [[noreturn]] void ReportFatalError(const Expr& expr, const std::stringstream& err) { + CHECK(this->current_func_.defined()); + this->err_reporter.ReportAt(this->current_func_, expr, Error(err)); + this->err_reporter.RenderErrors(this->mod_); + } + + [[noreturn]] void ReportFatalError(const Expr& expr, const Error& err) { + CHECK(this->current_func_.defined()); + this->err_reporter.ReportAt(this->current_func_, expr, err); + this->err_reporter.RenderErrors(this->mod_); + } + + // Visitor Logic Type VisitExpr_(const VarNode* op) final { if (op->type_annotation.defined()) { return op->type_annotation; @@ -174,8 +197,13 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); - CHECK(mod_.defined()) - << "Cannot do type inference without a global variable"; + if (!mod_.defined()) { + this->ReportFatalError( + GetRef(op), + RELAY_ERROR( + "Cannot do type inference on global variables " \ + "without a module")); + } Expr e = mod_->Lookup(var); return e->checked_type(); } @@ -210,37 +238,39 @@ class TypeInferencer : private ExprFunctor { return op->op_type; } - Type VisitExpr_(const LetNode* op) final { + Type VisitExpr_(const LetNode* let) final { // if the definition is a function literal, permit recursion - bool is_functional_literal = op->value.as() != nullptr; + bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { - type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + type_map_[let->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } - Type vtype = GetType(op->value); - if (op->var->type_annotation.defined()) { - vtype = Unify(vtype, op->var->type_annotation, op->span); + Type vtype = GetType(let->value); + if (let->var->type_annotation.defined()) { + vtype = Unify(vtype, let->var->type_annotation, GetRef(let)); } - CHECK(is_functional_literal || !type_map_.count(op->var)); + CHECK(is_functional_literal || !type_map_.count(let->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var].checked_type = vtype; - return GetType(op->body); + type_map_[let->var].checked_type = vtype; + return GetType(let->body); } - Type VisitExpr_(const IfNode* op) final { + Type VisitExpr_(const IfNode* ite) final { // Ensure the type of the guard is of Tensor[Bool, ()], // that is a rank-0 boolean tensor. - Type cond_type = this->GetType(op->cond); + Type cond_type = this->GetType(ite->cond); this->Unify(cond_type, TensorTypeNode::Scalar(tvm::Bool()), - op->cond->span); - Type checked_true = this->GetType(op->true_branch); - Type checked_false = this->GetType(op->false_branch); - return this->Unify(checked_true, checked_false, op->span); + ite->cond); + Type checked_true = this->GetType(ite->true_branch); + Type checked_false = this->GetType(ite->false_branch); + return this->Unify(checked_true, checked_false, GetRef(ite)); } - // Handle special case basic primitive operator, - // if successful return the return type + // This code is special-cased for primitive operators, + // which are registered in the style defined in src/relay/op/*. + // + // The result will be the return type of the operator. Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs) { @@ -304,16 +334,19 @@ class TypeInferencer : private ExprFunctor { auto* fn_ty_node = ftype.as(); auto* inc_ty_node = ftype.as(); - CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr) - << "only expressions with function types can be called, found " - << ftype << " at " << call->span; + if (fn_ty_node == nullptr && inc_ty_node == nullptr) { + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("only expressions with function types can be called, found " + << ftype)); + } // incomplete type => it must be a function taking the arg types // with an unknown return type if (inc_ty_node != nullptr) { Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); - Type unified = this->Unify(ftype, func_type, call->span); + Type unified = this->Unify(ftype, func_type, GetRef(call)); fn_ty_node = unified.as(); } @@ -323,10 +356,16 @@ class TypeInferencer : private ExprFunctor { type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); } } - CHECK(type_args.size() == fn_ty_node->type_params.size()) - << "Incorrect number of type args in " << call->span << ": " - << "Expected " << fn_ty_node->type_params.size() - << "but got " << type_args.size(); + + if (type_args.size() != fn_ty_node->type_params.size()) { + this->ReportFatalError(GetRef(call), + RELAY_ERROR("Incorrect number of type args in " + << call->span << ": " + << "Expected " + << fn_ty_node->type_params.size() + << "but got " << type_args.size())); + } + FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); AddTypeArgs(GetRef(call), type_args); @@ -336,14 +375,20 @@ class TypeInferencer : private ExprFunctor { if (type_arity != number_of_args) { if (type_arity < number_of_args) { - LOG(FATAL) << "the function is provided too many arguments " << call->span; + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args)); } else { - LOG(FATAL) << "the function is provided too few arguments" << call->span; + this->ReportFatalError( + GetRef(call), + RELAY_ERROR("the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args)); } } for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span); + this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]); } for (auto cs : fn_ty->type_constraints) { @@ -385,7 +430,7 @@ class TypeInferencer : private ExprFunctor { } Type rtype = GetType(f->body); if (f->ret_type.defined()) { - rtype = this->Unify(f->ret_type, rtype, f->span); + rtype = this->Unify(f->ret_type, rtype, GetRef(f)); } auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); @@ -445,6 +490,9 @@ class TypeInferencer::Resolver : public ExprMutator { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); Type checked_type = solver_->Resolve(it->second.checked_type); + + // TODO(@jroesch): it would be nice if we would report resolution + // errors directly on the program. CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; @@ -549,10 +597,27 @@ Expr TypeInferencer::Infer(Expr expr) { } -Expr InferType(const Expr& expr, const Module& mod) { - auto e = TypeInferencer(mod).Infer(expr); - CHECK(WellFormed(e)); - return e; +Expr InferType(const Expr& expr, const Module& mod_ref) { + if (!mod_ref.defined()) { + Module mod = ModuleNode::FromExpr(expr); + // NB(@jroesch): By adding the expression to the module we will + // type check it anyway; afterwards we can just recover type + // from the type-checked function to avoid doing unnecessary work. + + Function e = mod->Lookup(mod->entry_func); + + // FromExpr wraps a naked expression as a function, we will unbox + // it here. + if (auto func = expr.as()) { + return e; + } else { + return e->body; + } + } else { + auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); + CHECK(WellFormed(e)); + return e; + } } Function InferType(const Function& func, @@ -561,7 +626,7 @@ Function InferType(const Function& func, Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); - Expr func_ret = TypeInferencer(mod).Infer(func_copy); + Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); mod->Remove(var); CHECK(WellFormed(func_ret)); return Downcast(func_ret); diff --git a/src/relay/util/CPPLINT.cfg b/src/relay/util/CPPLINT.cfg new file mode 100644 index 0000000000000..610884f9397e4 --- /dev/null +++ b/src/relay/util/CPPLINT.cfg @@ -0,0 +1 @@ +exclude_files=rang.h diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py new file mode 100644 index 0000000000000..481c19482da08 --- /dev/null +++ b/tests/python/relay/test_error_reporting.py @@ -0,0 +1,27 @@ +import tvm +from tvm import relay + +def check_type_err(expr, msg): + try: + expr = relay.ir_pass.infer_type(expr) + assert False + except tvm.TVMError as err: + assert msg in str(err) + +def test_too_many_args(): + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x) + y = relay.var('y', shape=(10, 10)) + check_type_err( + f(x, y), + "the function is provided too many arguments expected 1, found 2;") + +def test_too_few_args(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(10, 10)) + f = relay.Function([x, y], x) + check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;") + +if __name__ == "__main__": + test_too_many_args() + test_too_few_args()