forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay] Add generic & informative Relay error reporting (apache#2408)
- Loading branch information
1 parent
5e2f354
commit c78a067
Showing
14 changed files
with
537 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file error_reporter.h | ||
* \brief The set of errors raised by Relay. | ||
*/ | ||
|
||
#include <tvm/relay/expr.h> | ||
#include <tvm/relay/module.h> | ||
#include <tvm/relay/error.h> | ||
#include <string> | ||
#include <vector> | ||
#include <rang.hpp> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
void RelayErrorStream::Raise() const { | ||
throw Error(*this); | ||
} | ||
|
||
template<typename T, typename U> | ||
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>; | ||
|
||
void ErrorReporter::RenderErrors(const Module& module, bool use_color) { | ||
// First we pick an error reporting strategy for each error. | ||
// TODO(@jroesch): Spanned errors are currently not supported. | ||
for (auto err : this->errors_) { | ||
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; | ||
} | ||
|
||
NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps; | ||
|
||
// Set control mode in order to produce colors; | ||
if (use_color) { | ||
rang::setControlMode(rang::control::Force); | ||
} | ||
|
||
for (auto pair : this->node_to_gv_) { | ||
auto node = pair.first; | ||
auto global = Downcast<GlobalVar>(pair.second); | ||
|
||
auto has_errs = this->node_to_error_.find(node); | ||
|
||
CHECK(has_errs != this->node_to_error_.end()); | ||
|
||
const auto& error_indicies = has_errs->second; | ||
|
||
std::stringstream err_msg; | ||
|
||
err_msg << rang::fg::red; | ||
for (auto index : error_indicies) { | ||
err_msg << this->errors_[index].what() << "; "; | ||
} | ||
err_msg << rang::fg::reset; | ||
|
||
// Setup error map. | ||
auto it = error_maps.find(global); | ||
if (it != error_maps.end()) { | ||
it->second.insert({ node, err_msg.str() }); | ||
} else { | ||
error_maps.insert({ global, { { node, err_msg.str() }}}); | ||
} | ||
} | ||
|
||
// Now we will construct the fully-annotated program to display to | ||
// the user. | ||
std::stringstream annotated_prog; | ||
|
||
// First we output a header for the errors. | ||
annotated_prog << | ||
rang::style::bold << std::endl << | ||
"Error(s) have occurred. We have annotated the program with them:" | ||
<< std::endl << std::endl << rang::style::reset; | ||
|
||
// For each global function which contains errors, we will | ||
// construct an annotated function. | ||
for (auto pair : error_maps) { | ||
auto global = pair.first; | ||
auto err_map = pair.second; | ||
auto func = module->Lookup(global); | ||
|
||
// We output the name of the function before displaying | ||
// the annotated program. | ||
annotated_prog << | ||
rang::style::bold << | ||
"In `" << global->name_hint << "`: " << | ||
std::endl << | ||
rang::style::reset; | ||
|
||
// We then call into the Relay printer to generate the program. | ||
// | ||
// The annotation callback will annotate the error messages | ||
// contained in the map. | ||
annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) { | ||
auto it = err_map.find(expr); | ||
if (it != err_map.end()) { | ||
return it->second; | ||
} else { | ||
return std::string(""); | ||
} | ||
}); | ||
} | ||
|
||
auto msg = annotated_prog.str(); | ||
|
||
if (use_color) { | ||
rang::setControlMode(rang::control::Auto); | ||
} | ||
|
||
// Finally we report the error, currently we do so to LOG(FATAL), | ||
// it may be good to instead report it to std::cout. | ||
LOG(FATAL) << annotated_prog.str() << std::endl; | ||
} | ||
|
||
void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { | ||
size_t index_to_insert = this->errors_.size(); | ||
this->errors_.push_back(err); | ||
auto it = this->node_to_error_.find(node); | ||
if (it != this->node_to_error_.end()) { | ||
it->second.push_back(index_to_insert); | ||
} else { | ||
this->node_to_error_.insert({ node, { index_to_insert }}); | ||
} | ||
this->node_to_gv_.insert({ node, global }); | ||
} | ||
|
||
} // namespace relay | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.