Skip to content

Commit

Permalink
Address final comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Jan 25, 2019
1 parent 07d1674 commit c98bb22
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class ErrorReporter {
* \param module The module to report errors on.
* \param use_color Controls whether to colorize the output.
*/
[[noreturn]] void RenderErrors(const Module& module, bool use_color = true);
void RenderErrors(const Module& module, bool use_color = true);

inline bool AnyErrors() {
return errors_.size() != 0;
Expand Down
8 changes: 6 additions & 2 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,15 @@ class TypeReporterNode : public Node {
*/
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;

/*!
* \brief Set the location at which to report unification errors.
* \param ref The program node to report the error.
*/
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;

// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}

mutable NodeRef location;

static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
};
Expand Down
5 changes: 1 addition & 4 deletions src/relay/ir/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ void RelayErrorStream::Raise() const {
template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;

[[noreturn]] void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
// First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) {
Expand Down Expand Up @@ -110,9 +110,6 @@ using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
// Finally we report the error, currently we do so to LOG(FATAL),
// it may be good to instead report it to std::cout.
LOG(FATAL) << annotated_prog.str() << std::endl;

// NB(@jroesch): this is to ensure that the function does not return.
exit(1);
}

void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return ret;
}

[[noreturn]] void ReportFatalError(const Expr& expr, const Error& err) {
void ReportFatalError(const Expr& expr, const Error& err) {
CHECK(this->current_func_.defined());
this->err_reporter.ReportAt(this->current_func_, expr, err);
this->err_reporter.RenderErrors(this->mod_);
Expand Down
9 changes: 8 additions & 1 deletion src/relay/pass/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@ class TypeSolver::Reporter : public TypeReporterNode {
return true;
}

TVM_DLL void SetLocation(const NodeRef& ref) final {
location = ref;
}

private:
/*! \brief The location to report unification errors at. */
mutable NodeRef location;

TypeSolver* solver_;
};

Expand Down Expand Up @@ -422,7 +429,7 @@ bool TypeSolver::Solve() {

// We need to set this in order to understand where unification
// errors generated by the error reporting are coming from.
reporter_->location = rnode->location;
reporter_->SetLocation(rnode->location);

try {
// Call the Type Relation's function.
Expand Down

0 comments on commit c98bb22

Please sign in to comment.