Skip to content

Commit

Permalink
[Relay] Add generic & informative Relay error reporting (apache#2408)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and AWS Neo committed Feb 20, 2019
1 parent 1ad84f8 commit 01ba738
Show file tree
Hide file tree
Showing 14 changed files with 537 additions and 86 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "dlpack"]
path = 3rdparty/dlpack
url = https://github.com/dmlc/dlpack
[submodule "3rdparty/rang"]
path = 3rdparty/rang
url = https://github.com/agauniyal/rang
1 change: 1 addition & 0 deletions 3rdparty/rang
Submodule rang added at cabe04
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
include_directories("include")
include_directories("3rdparty/dlpack/include")
include_directories("3rdparty/dmlc-core/include")
include_directories("3rdparty/rang/include")
include_directories("3rdparty/compiler-rt")

# initial variables
Expand Down
127 changes: 118 additions & 9 deletions include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,134 @@
#define TVM_RELAY_ERROR_H_

#include <string>
#include <vector>
#include <sstream>
#include "./base.h"
#include "./expr.h"
#include "./module.h"

namespace tvm {
namespace relay {

struct Error : public dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
};
#define RELAY_ERROR(msg) (RelayErrorStream() << msg)

// Forward declaratio for RelayErrorStream.
struct Error;

/*! \brief A wrapper around std::stringstream.
*
* This is designed to avoid platform specific
* issues compiling and using std::stringstream
* for error reporting.
*/
struct RelayErrorStream {
std::stringstream ss;

template<typename T>
RelayErrorStream& operator<<(const T& t) {
ss << t;
return *this;
}

struct InternalError : public Error {
explicit InternalError(const std::string &msg) : Error(msg) {}
std::string str() const {
return ss.str();
}

void Raise() const;
};

struct FatalTypeError : public Error {
explicit FatalTypeError(const std::string &s) : Error(s) {}
struct Error : public dmlc::Error {
Span sp;
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
};

struct TypecheckerError : public Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
/*! \brief An abstraction around how errors are stored and reported.
* Designed to be opaque to users, so we can support a robust and simpler
* error reporting mode, as well as a more complex mode.
*
* The first mode is the most accurate: we report a Relay error at a specific
* Span, and then render the error message directly against a textual representation
* of the program, highlighting the exact lines in which it occurs. This mode is not
* implemented in this PR and will not work.
*
* The second mode is a general-purpose mode, which attempts to annotate the program's
* textual format with errors.
*
* The final mode represents the old mode, if we report an error that has no span or
* expression, we will default to throwing an exception with a textual representation
* of the error and no indication of where it occured in the original program.
*
* The latter mode is not ideal, and the goal of the new error reporting machinery is
* to avoid ever reporting errors in this style.
*/
class ErrorReporter {
public:
ErrorReporter() : errors_(), node_to_error_() {}

/*! \brief Report a tvm::relay::Error.
*
* This API is useful for reporting spanned errors.
*
* \param err The error to report.
*/
void Report(const Error& err) {
if (!err.sp.defined()) {
throw err;
}

this->errors_.push_back(err);
}

/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err));
}

/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error to report.
*/
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);

/*! \brief Render all reported errors and exit the program.
*
* This function should be used after executing a pass to render reported errors.
*
* It will build an error message from the set of errors, depending on the error
* reporting strategy.
*
* \param module The module to report errors on.
* \param use_color Controls whether to colorize the output.
*/
void RenderErrors(const Module& module, bool use_color = true);

inline bool AnyErrors() {
return errors_.size() != 0;
}

private:
std::vector<Error> errors_;
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
};

} // namespace relay
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@ class ModuleNode : public RelayNode {
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;

/*! \brief The entry function (i.e. "main"). */
GlobalVar entry_func;

ModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func);
}

TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
Expand Down Expand Up @@ -111,6 +115,20 @@ class ModuleNode : public RelayNode {
*/
void Update(const Module& other);

/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map as
* well.
*
* \param expr The expression to set as the entry point to the module.
* \param global_funcs The global function map.
*
* \returns A module with expr set as the entry point.
*/
static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});

static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);

Expand All @@ -132,6 +150,7 @@ struct Module : public NodeRef {
using ContainerType = ModuleNode;
};


} // namespace relay
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <string>

Expand Down
6 changes: 6 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ class TypeReporterNode : public Node {
*/
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;

/*!
* \brief Set the location at which to report unification errors.
* \param ref The program node to report the error.
*/
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;

// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}

Expand Down
128 changes: 128 additions & 0 deletions src/relay/ir/error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*!
* Copyright (c) 2018 by Contributors
* \file error_reporter.h
* \brief The set of errors raised by Relay.
*/

#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/error.h>
#include <string>
#include <vector>
#include <rang.hpp>

namespace tvm {
namespace relay {

void RelayErrorStream::Raise() const {
throw Error(*this);
}

template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;

void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
// First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) {
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
}

NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;

// Set control mode in order to produce colors;
if (use_color) {
rang::setControlMode(rang::control::Force);
}

for (auto pair : this->node_to_gv_) {
auto node = pair.first;
auto global = Downcast<GlobalVar>(pair.second);

auto has_errs = this->node_to_error_.find(node);

CHECK(has_errs != this->node_to_error_.end());

const auto& error_indicies = has_errs->second;

std::stringstream err_msg;

err_msg << rang::fg::red;
for (auto index : error_indicies) {
err_msg << this->errors_[index].what() << "; ";
}
err_msg << rang::fg::reset;

// Setup error map.
auto it = error_maps.find(global);
if (it != error_maps.end()) {
it->second.insert({ node, err_msg.str() });
} else {
error_maps.insert({ global, { { node, err_msg.str() }}});
}
}

// Now we will construct the fully-annotated program to display to
// the user.
std::stringstream annotated_prog;

// First we output a header for the errors.
annotated_prog <<
rang::style::bold << std::endl <<
"Error(s) have occurred. We have annotated the program with them:"
<< std::endl << std::endl << rang::style::reset;

// For each global function which contains errors, we will
// construct an annotated function.
for (auto pair : error_maps) {
auto global = pair.first;
auto err_map = pair.second;
auto func = module->Lookup(global);

// We output the name of the function before displaying
// the annotated program.
annotated_prog <<
rang::style::bold <<
"In `" << global->name_hint << "`: " <<
std::endl <<
rang::style::reset;

// We then call into the Relay printer to generate the program.
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
return it->second;
} else {
return std::string("");
}
});
}

auto msg = annotated_prog.str();

if (use_color) {
rang::setControlMode(rang::control::Auto);
}

// Finally we report the error, currently we do so to LOG(FATAL),
// it may be good to instead report it to std::cout.
LOG(FATAL) << annotated_prog.str() << std::endl;
}

void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
size_t index_to_insert = this->errors_.size();
this->errors_.push_back(err);
auto it = this->node_to_error_.find(node);
if (it != this->node_to_error_.end()) {
it->second.push_back(index_to_insert);
} else {
this->node_to_error_.insert({ node, { index_to_insert }});
}
this->node_to_gv_.insert({ node, global });
}

} // namespace relay
} // namespace tvm
17 changes: 17 additions & 0 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}

n->entry_func = GlobalVarNode::make("main");
return Module(n);
}

Expand Down Expand Up @@ -96,6 +98,21 @@ void ModuleNode::Update(const Module& mod) {
}
}

Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) {
auto mod = ModuleNode::make(global_funcs);
auto func_node = expr.as<FunctionNode>();
Function func;
if (func_node) {
func = GetRef<Function>(func_node);
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
}
mod->Add(mod->entry_func, func);
return mod;
}

TVM_REGISTER_NODE_TYPE(ModuleNode);

TVM_REGISTER_API("relay._make.Module")
Expand Down
Loading

0 comments on commit 01ba738

Please sign in to comment.