diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 4bfd153cca49..0451a9826cde 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -125,7 +125,7 @@ class ErrorReporter { * \param module The module to report errors on. * \param use_color Controls whether to colorize the output. */ - [[noreturn]] void RenderErrors(const Module& module, bool use_color = true); + void RenderErrors(const Module& module, bool use_color = true); inline bool AnyErrors() { return errors_.size() != 0; diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index e5f1b50764ef..f3bcf2c0a1d9 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -295,11 +295,15 @@ class TypeReporterNode : public Node { */ TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0; + /*! + * \brief Set the location at which to report unification errors. + * \param ref The program node to report the error. + */ + TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; + // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} - mutable NodeRef location; - static constexpr const char* _type_key = "relay.TypeReporter"; TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); }; diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index d56321286338..24f8d1c49b6b 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -21,7 +21,7 @@ void RelayErrorStream::Raise() const { template using NodeMap = std::unordered_map; -[[noreturn]] void ErrorReporter::RenderErrors(const Module& module, bool use_color) { +void ErrorReporter::RenderErrors(const Module& module, bool use_color) { // First we pick an error reporting strategy for each error. // TODO(@jroesch): Spanned errors are currently not supported. for (auto err : this->errors_) { @@ -110,9 +110,6 @@ using NodeMap = std::unordered_map; // Finally we report the error, currently we do so to LOG(FATAL), // it may be good to instead report it to std::cout. LOG(FATAL) << annotated_prog.str() << std::endl; - - // NB(@jroesch): this is to ensure that the function does not return. - exit(1); } void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 20b6e9ac730c..3135715f7691 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -173,7 +173,7 @@ class TypeInferencer : private ExprFunctor { return ret; } - [[noreturn]] void ReportFatalError(const Expr& expr, const Error& err) { + void ReportFatalError(const Expr& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); this->err_reporter.RenderErrors(this->mod_); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 63bd012a813e..dafcaf56015a 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -35,7 +35,14 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } + TVM_DLL void SetLocation(const NodeRef& ref) final { + location = ref; + } + private: + /*! \brief The location to report unification errors at. */ + mutable NodeRef location; + TypeSolver* solver_; }; @@ -422,7 +429,7 @@ bool TypeSolver::Solve() { // We need to set this in order to understand where unification // errors generated by the error reporting are coming from. - reporter_->location = rnode->location; + reporter_->SetLocation(rnode->location); try { // Call the Type Relation's function.