Skip to content

Commit

Permalink
feat: show pytorch code of unsupported operators
Browse files Browse the repository at this point in the history
Signed-off-by: lamhoangtung <[email protected]>
  • Loading branch information
lamhoangtung committed Jul 1, 2021
1 parent bdaacf1 commit 2ee2a84
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
11 changes: 7 additions & 4 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
return engine;
}

std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
std::set<std::string> unsupported_ops;
std::set<std::pair<std::string, std::string>> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
std::set<std::pair<std::string, std::string>> unsupported_ops;
for (const auto n : b->nodes()) {
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
auto schema = n->maybeSchema();
Expand All @@ -438,7 +438,9 @@ std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
"Unable to get schema for Node " << util::node_info(n) << " (conversion.VerifyCoverterSupportForBlock)");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
std::string pytorch_code = trtorch::core::util::GetPyTorchSourceCode(n);
auto current_unsupported_op = std::make_pair(ss.str(), pytorch_code);
unsupported_ops.insert(current_unsupported_op);
}
for (const auto sub_b : n->blocks()) {
auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
Expand Down Expand Up @@ -480,7 +482,8 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:"
<< std::endl;
for (auto s : unsupported_ops) {
unsupported_msg << " - " << s << std::endl;
unsupported_msg << " - " << s.first << std::endl;
unsupported_msg << " Related PyTorch code:" << std::endl << s.second << std::endl;
}
unsupported_msg << "You can either implement converters for these ops in your application or request implementation"
<< std::endl;
Expand Down
5 changes: 5 additions & 0 deletions core/util/jit_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ inline c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::sha
return c10::FunctionSchema(method_name, method_name, args, returns);
}

inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
std::string source_code = n->sourceRange().str();
return source_code;
}

} // namespace util
} // namespace core
} // namespace trtorch

0 comments on commit 2ee2a84

Please sign in to comment.