-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add descriptive whole module errors to Relay
This PR adds descriptive errors to Relay. It is now possible to annotate and visulize errors annotated at certain expressions in a relay::Module. Extract module changes from other branch Fix me Get first version working Improve color support Update error rendering, and comment about rang license Remove uncessary type checking from InferType Update more error reporting to use new machinery Update more error reporting Hacking Negate checks Remove debugging Address MK's comments Fix CPP lint Disable linting for rang Document new relay::Module APIs Add another test case Update include/tvm/relay/error.h Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/error.h Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/module.h Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/module.h Co-Authored-By: jroesch <[email protected]> Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch <[email protected]> Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/module.h Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/error.h Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/error.h Co-Authored-By: jroesch <[email protected]> Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch <[email protected]> Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/module.h Co-Authored-By: jroesch <[email protected]> Update src/relay/ir/module.cc Co-Authored-By: jroesch <[email protected]> Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch <[email protected]> Update src/relay/pass/type_infer.cc Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/error.h Co-Authored-By: jroesch <[email protected]> Update include/tvm/relay/error.h Co-Authored-By: jroesch <[email protected]> Apply suggestions from code review Co-Authored-By: jroesch <[email protected]> Fix comment in type_infer.cc Remove copy of relay::Module Fix lambda capture Refactor to remove error reporter from module Update include/tvm/relay/module.h Co-Authored-By: jroesch <[email protected]> Fix small issue Fix doc comment Fix doc comment typo Fix build error Fix implicit ctor not working Fix macro Remove duplicate macro Try and fix issue with stringstream Fix old type inference invariant WIP repairing from rebase Rebase fixup Another fix
- Loading branch information
Showing
11 changed files
with
448 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file error_reporter.h | ||
* \brief The set of errors raised by Relay. | ||
*/ | ||
#ifndef TVM_RELAY_ERROR_REPORTER_H_ | ||
#define TVM_RELAY_ERROR_REPORTER_H_ | ||
|
||
#include <tvm/relay/error.h> | ||
#include <tvm/relay/expr.h> | ||
#include <string> | ||
#include <sstream> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! \brief An abstraction around how errors are stored and reported. | ||
* Designed to be opaque to users, so we can support a robust and simpler | ||
* error reporting mode, as well as a more complex mode. | ||
* | ||
* The first mode is the most accurate: we report a Relay error at a specific | ||
* Span, and then render the error message directly against a textual representation | ||
* of the program, highlighting the exact lines in which it occurs. This mode is not | ||
* implemented in this PR and will not work. | ||
* | ||
* The second mode is a general-purpose mode, which attempts to annotate the program's | ||
* textual format with errors. | ||
* | ||
* The final mode represents the old mode, if we report an error that has no span or | ||
* expression, we will default to throwing an exception with a textual representation | ||
* of the error and no indication of where it occured in the original program. | ||
* | ||
* The latter mode is not ideal, and the goal of the new error reporting machinery is | ||
* to avoid ever reporting errors in this style. | ||
*/ | ||
class ErrorReporter { | ||
public: | ||
ErrorReporter() : errors_(), node_to_error_() {} | ||
|
||
/*! \brief Report a tvm::relay::Error. | ||
* | ||
* This API is useful for reporting spanned errors. | ||
* | ||
* \param err The error to report. | ||
*/ | ||
void Report(const Error& err) { | ||
if (!err.sp.defined()) { | ||
throw err; | ||
} | ||
|
||
this->errors_.push_back(err); | ||
} | ||
|
||
/*! \brief Report an error against a program, using the full program | ||
* error reporting strategy. | ||
* | ||
* This error reporting method requires the global function in which | ||
* to report an error, the expression to report the error on, | ||
* and the error object. | ||
* | ||
* \param global The global function in which the expression is contained. | ||
* \param node The expression or type to report the error at. | ||
* \param err The error message to report. | ||
*/ | ||
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { | ||
this->ReportAt(global, node, Error(err)); | ||
} | ||
|
||
/*! \brief Report an error against a program, using the full program | ||
* error reporting strategy. | ||
* | ||
* This error reporting method requires the global function in which | ||
* to report an error, the expression to report the error on, | ||
* and the error object. | ||
* | ||
* \param global The global function in which the expression is contained. | ||
* \param node The expression or type to report the error at. | ||
* \param err The error to report. | ||
*/ | ||
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); | ||
|
||
/*! \brief Render all reported errors and exit the program. | ||
* | ||
* This function should be used after executing a pass to render reported errors. | ||
* | ||
* It will build an error message from the set of errors, depending on the error | ||
* reporting strategy. | ||
* | ||
* \param module The module to report errors on. | ||
* \param use_color Controls whether to colorize the output. | ||
*/ | ||
[[noreturn]] void RenderErrors(const Module& module, bool use_color = true); | ||
|
||
private: | ||
std::vector<Error> errors_; | ||
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_; | ||
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_; | ||
}; | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
|
||
#endif // TVM_RELAY_ERROR_REPORTER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file error_reporter.h | ||
* \brief The set of errors raised by Relay. | ||
*/ | ||
|
||
#include <tvm/relay/expr.h> | ||
#include <tvm/relay/module.h> | ||
#include <tvm/relay/error_reporter.h> | ||
#include <string> | ||
#include <vector> | ||
#include <rang.hpp> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
template<typename T, typename U> | ||
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>; | ||
|
||
[[noreturn]] void ErrorReporter::RenderErrors(const Module& module, bool use_color) { | ||
// First we pick an error reporting strategy for each error. | ||
// TODO(@jroesch): Spanned errors are currently not supported. | ||
for (auto err : this->errors_) { | ||
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; | ||
} | ||
|
||
NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps; | ||
|
||
// Set control mode in order to produce colors; | ||
if (use_color) { | ||
rang::setControlMode(rang::control::Force); | ||
} | ||
|
||
for (auto pair : this->node_to_gv_) { | ||
auto node = pair.first; | ||
auto global = Downcast<GlobalVar>(pair.second); | ||
|
||
auto has_errs = this->node_to_error_.find(node); | ||
|
||
CHECK(has_errs != this->node_to_error_.end()); | ||
|
||
const auto& error_indicies = has_errs->second; | ||
|
||
std::stringstream err_msg; | ||
|
||
err_msg << rang::fg::red; | ||
for (auto index : error_indicies) { | ||
err_msg << this->errors_[index].what() << "; "; | ||
} | ||
err_msg << rang::fg::reset; | ||
|
||
// Setup error map. | ||
auto it = error_maps.find(global); | ||
if (it != error_maps.end()) { | ||
it->second.insert({ node, err_msg.str() }); | ||
} else { | ||
error_maps.insert({ global, { { node, err_msg.str() }}}); | ||
} | ||
} | ||
|
||
// Now we will construct the fully-annotated program to display to | ||
// the user. | ||
std::stringstream annotated_prog; | ||
|
||
// First we output a header for the errors. | ||
annotated_prog << | ||
rang::style::bold << std::endl << | ||
"Error(s) have occurred. We have annotated the program with them:" | ||
<< std::endl << std::endl << rang::style::reset; | ||
|
||
// For each global function which contains errors, we will | ||
// construct an annotated function. | ||
for (auto pair : error_maps) { | ||
auto global = pair.first; | ||
auto err_map = pair.second; | ||
auto func = module->Lookup(global); | ||
|
||
// We output the name of the function before displaying | ||
// the annotated program. | ||
annotated_prog << | ||
rang::style::bold << | ||
"In `" << global->name_hint << "`: " << | ||
std::endl << | ||
rang::style::reset; | ||
|
||
// We then call into the Relay printer to generate the program. | ||
// | ||
// The annotation callback will annotate the error messages | ||
// contained in the map. | ||
annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) { | ||
auto it = err_map.find(expr); | ||
if (it != err_map.end()) { | ||
return it->second; | ||
} else { | ||
return std::string(""); | ||
} | ||
}); | ||
} | ||
|
||
auto msg = annotated_prog.str(); | ||
|
||
if (use_color) { | ||
rang::setControlMode(rang::control::Auto); | ||
} | ||
|
||
// Finally we report the error, currently we do so to LOG(FATAL), | ||
// it may be good to instead report it to std::cout. | ||
LOG(FATAL) << annotated_prog.str() << std::endl; | ||
|
||
exit(1); | ||
} | ||
|
||
void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { | ||
size_t index_to_insert = this->errors_.size(); | ||
this->errors_.push_back(err); | ||
auto it = this->node_to_error_.find(node); | ||
if (it != this->node_to_error_.end()) { | ||
it->second.push_back(index_to_insert); | ||
} else { | ||
this->node_to_error_.insert({ node, { index_to_insert }}); | ||
} | ||
this->node_to_gv_.insert({ node, global }); | ||
} | ||
|
||
} // namespace relay | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.