Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add generic & informative Relay error reporting #2408

Merged
merged 13 commits into from
Jan 25, 2019
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "dlpack"]
path = 3rdparty/dlpack
url = https://github.com/dmlc/dlpack
[submodule "3rdparty/rang"]
path = 3rdparty/rang
url = https://github.com/agauniyal/rang
1 change: 1 addition & 0 deletions 3rdparty/rang
Submodule rang added at cabe04
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
include_directories("include")
include_directories("3rdparty/dlpack/include")
include_directories("3rdparty/dmlc-core/include")
include_directories("3rdparty/rang/include")
include_directories("3rdparty/compiler-rt")

# initial variables
Expand Down
127 changes: 118 additions & 9 deletions include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,134 @@
#define TVM_RELAY_ERROR_H_

#include <string>
#include <vector>
#include <sstream>
#include "./base.h"
#include "./expr.h"
#include "./module.h"

namespace tvm {
namespace relay {

struct Error : public dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
};
#define RELAY_ERROR(msg) (RelayErrorStream() << msg)

// Forward declaratio for RelayErrorStream.
struct Error;

/*! \brief A wrapper around std::stringstream.
*
* This is designed to avoid platform specific
* issues compiling and using std::stringstream
* for error reporting.
*/
struct RelayErrorStream {
std::stringstream ss;

template<typename T>
RelayErrorStream& operator<<(const T& t) {
ss << t;
return *this;
}

struct InternalError : public Error {
explicit InternalError(const std::string &msg) : Error(msg) {}
std::string str() const {
return ss.str();
}

void Raise() const;
};

struct FatalTypeError : public Error {
explicit FatalTypeError(const std::string &s) : Error(s) {}
struct Error : public dmlc::Error {
Span sp;
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
};

struct TypecheckerError : public Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
/*! \brief An abstraction around how errors are stored and reported.
* Designed to be opaque to users, so we can support a robust and simpler
* error reporting mode, as well as a more complex mode.
*
* The first mode is the most accurate: we report a Relay error at a specific
* Span, and then render the error message directly against a textual representation
* of the program, highlighting the exact lines in which it occurs. This mode is not
* implemented in this PR and will not work.
*
* The second mode is a general-purpose mode, which attempts to annotate the program's
* textual format with errors.
*
* The final mode represents the old mode, if we report an error that has no span or
* expression, we will default to throwing an exception with a textual representation
* of the error and no indication of where it occured in the original program.
*
* The latter mode is not ideal, and the goal of the new error reporting machinery is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I misread this doc comment initially, as only describing two modes. It should be "The final mode" here. My bad.

* to avoid ever reporting errors in this style.
*/
class ErrorReporter {
public:
ErrorReporter() : errors_(), node_to_error_() {}
jroesch marked this conversation as resolved.
Show resolved Hide resolved

/*! \brief Report a tvm::relay::Error.
*
* This API is useful for reporting spanned errors.
*
* \param err The error to report.
*/
void Report(const Error& err) {
if (!err.sp.defined()) {
throw err;
}

this->errors_.push_back(err);
}

/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err));
}

/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error to report.
*/
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);

/*! \brief Render all reported errors and exit the program.
*
* This function should be used after executing a pass to render reported errors.
*
* It will build an error message from the set of errors, depending on the error
* reporting strategy.
*
* \param module The module to report errors on.
* \param use_color Controls whether to colorize the output.
*/
void RenderErrors(const Module& module, bool use_color = true);

inline bool AnyErrors() {
return errors_.size() != 0;
}

private:
std::vector<Error> errors_;
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
};

} // namespace relay
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@ class ModuleNode : public RelayNode {
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;

/*! \brief The entry function (i.e. "main"). */
GlobalVar entry_func;
jroesch marked this conversation as resolved.
Show resolved Hide resolved

ModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func);
}

TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
Expand Down Expand Up @@ -111,6 +115,20 @@ class ModuleNode : public RelayNode {
*/
void Update(const Module& other);

/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map as
* well.
*
* \param expr The expression to set as the entry point to the module.
* \param global_funcs The global function map.
*
* \returns A module with expr set as the entry point.
*/
static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});

static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);

Expand All @@ -132,6 +150,7 @@ struct Module : public NodeRef {
using ContainerType = ModuleNode;
};


} // namespace relay
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <string>

Expand Down
6 changes: 6 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ class TypeReporterNode : public Node {
*/
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;

/*!
* \brief Set the location at which to report unification errors.
* \param ref The program node to report the error.
*/
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;

// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}

Expand Down
128 changes: 128 additions & 0 deletions src/relay/ir/error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*!
* Copyright (c) 2018 by Contributors
* \file error_reporter.h
* \brief The set of errors raised by Relay.
*/

#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/error.h>
#include <string>
#include <vector>
#include <rang.hpp>

namespace tvm {
namespace relay {

void RelayErrorStream::Raise() const {
throw Error(*this);
}

template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;

void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
// First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) {
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
}

NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;

// Set control mode in order to produce colors;
if (use_color) {
rang::setControlMode(rang::control::Force);
}

for (auto pair : this->node_to_gv_) {
auto node = pair.first;
auto global = Downcast<GlobalVar>(pair.second);

auto has_errs = this->node_to_error_.find(node);

CHECK(has_errs != this->node_to_error_.end());

const auto& error_indicies = has_errs->second;

std::stringstream err_msg;

err_msg << rang::fg::red;
for (auto index : error_indicies) {
err_msg << this->errors_[index].what() << "; ";
}
err_msg << rang::fg::reset;

// Setup error map.
auto it = error_maps.find(global);
if (it != error_maps.end()) {
it->second.insert({ node, err_msg.str() });
} else {
error_maps.insert({ global, { { node, err_msg.str() }}});
}
}

// Now we will construct the fully-annotated program to display to
// the user.
std::stringstream annotated_prog;

// First we output a header for the errors.
annotated_prog <<
rang::style::bold << std::endl <<
"Error(s) have occurred. We have annotated the program with them:"
<< std::endl << std::endl << rang::style::reset;

// For each global function which contains errors, we will
// construct an annotated function.
for (auto pair : error_maps) {
auto global = pair.first;
auto err_map = pair.second;
auto func = module->Lookup(global);

// We output the name of the function before displaying
// the annotated program.
annotated_prog <<
rang::style::bold <<
"In `" << global->name_hint << "`: " <<
std::endl <<
rang::style::reset;

// We then call into the Relay printer to generate the program.
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
return it->second;
} else {
return std::string("");
}
});
}

auto msg = annotated_prog.str();

if (use_color) {
rang::setControlMode(rang::control::Auto);
}

// Finally we report the error, currently we do so to LOG(FATAL),
// it may be good to instead report it to std::cout.
LOG(FATAL) << annotated_prog.str() << std::endl;
}

void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
size_t index_to_insert = this->errors_.size();
this->errors_.push_back(err);
auto it = this->node_to_error_.find(node);
if (it != this->node_to_error_.end()) {
it->second.push_back(index_to_insert);
} else {
this->node_to_error_.insert({ node, { index_to_insert }});
}
this->node_to_gv_.insert({ node, global });
}

} // namespace relay
} // namespace tvm
17 changes: 17 additions & 0 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}

n->entry_func = GlobalVarNode::make("main");
return Module(n);
}

Expand Down Expand Up @@ -96,6 +98,21 @@ void ModuleNode::Update(const Module& mod) {
}
}

Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
auto mod = ModuleNode::make(global_funcs);
auto func_node = expr.as<FunctionNode>();
Function func;
if (func_node) {
func = GetRef<Function>(func_node);
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
}
mod->Add(mod->entry_func, func);
return mod;
}

TVM_REGISTER_NODE_TYPE(ModuleNode);

TVM_REGISTER_API("relay._make.Module")
Expand Down
Loading