Skip to content

Commit

Permalink
Add descriptive whole module errors to Relay
Browse files Browse the repository at this point in the history
This PR adds descriptive errors to Relay.
It is now possible to annotate and visulize
errors annotated at certain expressions
in a relay::Module.

Extract module changes from other branch

Fix me

Get first version working

Improve color support

Update error rendering, and comment about rang license

Remove uncessary type checking from InferType

Update more error reporting to use new machinery

Update more error reporting

Hacking

Negate checks

Remove debugging

Address MK's comments

Fix CPP lint

Disable linting for rang

Document new relay::Module APIs

Add another test case

Update include/tvm/relay/error.h

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/error.h

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/module.h

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/module.h

Co-Authored-By: jroesch <[email protected]>

Update src/relay/pass/type_infer.cc

Co-Authored-By: jroesch <[email protected]>

Update src/relay/pass/type_infer.cc

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/module.h

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/error.h

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/error.h

Co-Authored-By: jroesch <[email protected]>

Update src/relay/pass/type_infer.cc

Co-Authored-By: jroesch <[email protected]>

Update src/relay/pass/type_infer.cc

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/module.h

Co-Authored-By: jroesch <[email protected]>

Update src/relay/ir/module.cc

Co-Authored-By: jroesch <[email protected]>

Update src/relay/pass/type_infer.cc

Co-Authored-By: jroesch <[email protected]>

Update src/relay/pass/type_infer.cc

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/error.h

Co-Authored-By: jroesch <[email protected]>

Update include/tvm/relay/error.h

Co-Authored-By: jroesch <[email protected]>

Apply suggestions from code review

Co-Authored-By: jroesch <[email protected]>

Fix comment in type_infer.cc

Remove copy of relay::Module

Fix lambda capture

Refactor to remove error reporter from module

Update include/tvm/relay/module.h

Co-Authored-By: jroesch <[email protected]>

Fix small issue

Fix doc comment

Fix doc comment typo

Fix build error

Fix implicit ctor not working

Fix macro

Remove duplicate macro

Try and fix issue with stringstream

Fix old type inference invariant

WIP repairing from rebase

Rebase fixup

Another fix
  • Loading branch information
jroesch committed Jan 16, 2019
1 parent 6783d37 commit 7a370ea
Show file tree
Hide file tree
Showing 11 changed files with 448 additions and 46 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "dlpack"]
path = 3rdparty/dlpack
url = https://github.com/dmlc/dlpack
[submodule "3rdparty/rang"]
path = 3rdparty/rang
url = https://github.com/agauniyal/rang
1 change: 1 addition & 0 deletions 3rdparty/rang
Submodule rang added at cabe04
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
include_directories("include")
include_directories("3rdparty/dlpack/include")
include_directories("3rdparty/dmlc-core/include")
include_directories("3rdparty/rang/include")
include_directories("3rdparty/compiler-rt")

# initial variables
Expand Down
29 changes: 28 additions & 1 deletion include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,40 @@
#define TVM_RELAY_ERROR_H_

#include <string>
#include <vector>
#include <sstream>
#include "./base.h"

namespace tvm {
namespace relay {

/*! \brief A wrapper around std::stringstream.
*
* This is designed to avoid platform specific
* issues compiling and using std::stringstream
* for error reporting.
*/
struct StringStream {
std::stringstream ss;

template<typename T>
StringStream& operator<<(const T& t) {
ss << t;
return *this;
}

std::string str() const {
return ss.str();
}
};

#define RELAY_ERROR(msg) (StringStream() << msg)

struct Error : public dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
Span sp;
explicit Error(const std::string &msg) : dmlc::Error(msg), sp() {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
Error(const StringStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
};

struct InternalError : public Error {
Expand Down
104 changes: 104 additions & 0 deletions include/tvm/relay/error_reporter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*!
* Copyright (c) 2018 by Contributors
* \file error_reporter.h
* \brief The set of errors raised by Relay.
*/
#ifndef TVM_RELAY_ERROR_REPORTER_H_
#define TVM_RELAY_ERROR_REPORTER_H_

#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <string>
#include <sstream>
#include <vector>

namespace tvm {
namespace relay {

/*! \brief An abstraction around how errors are stored and reported.
* Designed to be opaque to users, so we can support a robust and simpler
* error reporting mode, as well as a more complex mode.
*
* The first mode is the most accurate: we report a Relay error at a specific
* Span, and then render the error message directly against a textual representation
* of the program, highlighting the exact lines in which it occurs. This mode is not
* implemented in this PR and will not work.
*
* The second mode is a general-purpose mode, which attempts to annotate the program's
* textual format with errors.
*
* The final mode represents the old mode, if we report an error that has no span or
* expression, we will default to throwing an exception with a textual representation
* of the error and no indication of where it occured in the original program.
*
* The latter mode is not ideal, and the goal of the new error reporting machinery is
* to avoid ever reporting errors in this style.
*/
class ErrorReporter {
public:
ErrorReporter() : errors_(), node_to_error_() {}

/*! \brief Report a tvm::relay::Error.
*
* This API is useful for reporting spanned errors.
*
* \param err The error to report.
*/
void Report(const Error& err) {
if (!err.sp.defined()) {
throw err;
}

this->errors_.push_back(err);
}

/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err));
}

/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error to report.
*/
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);

/*! \brief Render all reported errors and exit the program.
*
* This function should be used after executing a pass to render reported errors.
*
* It will build an error message from the set of errors, depending on the error
* reporting strategy.
*
* \param module The module to report errors on.
* \param use_color Controls whether to colorize the output.
*/
[[noreturn]] void RenderErrors(const Module& module, bool use_color = true);

private:
std::vector<Error> errors_;
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
};

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_ERROR_REPORTER_H_
26 changes: 26 additions & 0 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@ class ModuleNode : public RelayNode {
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;

/*! \brief The entry function (i.e. "main"). */
GlobalVar entry_func;

ModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func);
}

TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
Expand Down Expand Up @@ -111,6 +115,27 @@ class ModuleNode : public RelayNode {
*/
void Update(const Module& other);

/*!
* \brief Get the entry point of the module.
*
* \returns The entry point function, (i.e. main).
*/
Expr EntryPoint();

/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map as
* well.
*
* \param expr The expression to set as the entry point to the module.
* \param global_funcs The global function map.
*
* \returns A module with expr set as the entry point.
*/
static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});

static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);

Expand All @@ -132,6 +157,7 @@ struct Module : public NodeRef {
using ContainerType = ModuleNode;
};


} // namespace relay
} // namespace tvm

Expand Down
126 changes: 126 additions & 0 deletions src/relay/ir/error_reporter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*!
* Copyright (c) 2018 by Contributors
* \file error_reporter.h
* \brief The set of errors raised by Relay.
*/

#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/error_reporter.h>
#include <string>
#include <vector>
#include <rang.hpp>

namespace tvm {
namespace relay {

template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;

[[noreturn]] void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
// First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) {
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
}

NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;

// Set control mode in order to produce colors;
if (use_color) {
rang::setControlMode(rang::control::Force);
}

for (auto pair : this->node_to_gv_) {
auto node = pair.first;
auto global = Downcast<GlobalVar>(pair.second);

auto has_errs = this->node_to_error_.find(node);

CHECK(has_errs != this->node_to_error_.end());

const auto& error_indicies = has_errs->second;

std::stringstream err_msg;

err_msg << rang::fg::red;
for (auto index : error_indicies) {
err_msg << this->errors_[index].what() << "; ";
}
err_msg << rang::fg::reset;

// Setup error map.
auto it = error_maps.find(global);
if (it != error_maps.end()) {
it->second.insert({ node, err_msg.str() });
} else {
error_maps.insert({ global, { { node, err_msg.str() }}});
}
}

// Now we will construct the fully-annotated program to display to
// the user.
std::stringstream annotated_prog;

// First we output a header for the errors.
annotated_prog <<
rang::style::bold << std::endl <<
"Error(s) have occurred. We have annotated the program with them:"
<< std::endl << std::endl << rang::style::reset;

// For each global function which contains errors, we will
// construct an annotated function.
for (auto pair : error_maps) {
auto global = pair.first;
auto err_map = pair.second;
auto func = module->Lookup(global);

// We output the name of the function before displaying
// the annotated program.
annotated_prog <<
rang::style::bold <<
"In `" << global->name_hint << "`: " <<
std::endl <<
rang::style::reset;

// We then call into the Relay printer to generate the program.
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
return it->second;
} else {
return std::string("");
}
});
}

auto msg = annotated_prog.str();

if (use_color) {
rang::setControlMode(rang::control::Auto);
}

// Finally we report the error, currently we do so to LOG(FATAL),
// it may be good to instead report it to std::cout.
LOG(FATAL) << annotated_prog.str() << std::endl;

exit(1);
}

void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
size_t index_to_insert = this->errors_.size();
this->errors_.push_back(err);
auto it = this->node_to_error_.find(node);
if (it != this->node_to_error_.end()) {
it->second.push_back(index_to_insert);
} else {
this->node_to_error_.insert({ node, { index_to_insert }});
}
this->node_to_gv_.insert({ node, global });
}

} // namespace relay
} // namespace tvm
21 changes: 21 additions & 0 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}

n->entry_func = GlobalVarNode::make("main");
return Module(n);
}

Expand Down Expand Up @@ -96,6 +98,25 @@ void ModuleNode::Update(const Module& mod) {
}
}

Expr ModuleNode::EntryPoint() {
return this->Lookup(this->entry_func);
}

Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) {
auto mod = ModuleNode::make({});
auto func_node = expr.as<FunctionNode>();
Function func;
if (func_node) {
func = GetRef<Function>(func_node);
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
}
mod->Add(mod->entry_func, func);
return mod;
}

TVM_REGISTER_NODE_TYPE(ModuleNode);

TVM_REGISTER_API("relay._make.Module")
Expand Down
Loading

0 comments on commit 7a370ea

Please sign in to comment.