From 28ee445db142568d7e8df578f80f1d1e3c9cb3d3 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Mon, 23 Mar 2020 23:57:52 -0700 Subject: [PATCH] feat(CheckMethodOperatorSupport): A new API which will check the graph to see if all operators are supported. Addresses #26. Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/compiler.cpp | 23 +++++++++- core/compiler.h | 2 + core/conversion/conversion.cpp | 45 +++++++++++++++---- core/conversion/conversion.h | 4 +- core/conversion/conversion_blacklist.cpp | 2 +- .../converters/NodeConverterRegistry.cpp | 15 +++---- .../evaluators/NodeEvaluatorRegistry.cpp | 4 -- cpp/api/include/trtorch/trtorch.h | 15 ++++++- cpp/api/src/trtorch.cpp | 9 +++- cpp/trtorchexec/main.cpp | 6 ++- 10 files changed, 95 insertions(+), 30 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 135faba03f..33e2f04bff 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -64,6 +64,28 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptrinputs(), params); + LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n"); + + // Is this necessary? 
+ lowering::LowerBlock(g->block()); + + return conversion::VerifyConverterSupportForBlock(g->block()); +} + std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, conversion::ExtraInfo cfg) { @@ -87,7 +109,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, return std::move(engine); } -// TODO: Consider if there is a better way to deal with input size torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, conversion::ExtraInfo cfg) { // TODO: Should be doing a functional transform but need PR #31978 diff --git a/core/compiler.h b/core/compiler.h index 61daea2f5f..17ab1719db 100644 --- a/core/compiler.h +++ b/core/compiler.h @@ -6,9 +6,11 @@ namespace trtorch { namespace core { +bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name); std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, conversion::ExtraInfo cfg); + torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg); } // namespace core diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index b8e95c96d9..6ce3571012 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -11,7 +11,7 @@ namespace core { namespace conversion { // Defined in core/conversion/conversion_blacklist.cpp -bool isNodeConversionBlacklisted(torch::jit::Node* n); +bool isNodeConversionBlacklisted(const torch::jit::Node* n); bool OpSupported(const torch::jit::Node* n) { bool evalable = evaluators::shouldEvalAtConversionTime(n); @@ -19,7 +19,7 @@ bool OpSupported(const torch::jit::Node* n) { return evalable || convertable; } -c10::optional EvaluateNode(ConversionCtx* ctx, torch::jit::Node* n, int level=0, int limit=10) { +c10::optional EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level=0, int limit=10) { // Check to see if you can just go 
through and eval all of these AOT (saves the recursion) // Also probably a better way to deal with the two error cases; TRTORCH_CHECK(level < limit, "Failed to evaluate node: " << *n \ @@ -55,7 +55,7 @@ c10::optional EvaluateNode(ConversionCtx* ctx, torch::jit::N return eval; } -bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) { +bool AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { LOG_INFO(ctx->logger, "Adding Layer " << util::node_info(n) << " (ctx.AddLayer)"); converters::args node_args; @@ -114,11 +114,11 @@ bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) { } bool AddInputs(ConversionCtx* ctx, - at::ArrayRef inputs, + at::ArrayRef inputs, std::vector& input_dims) { auto type_lut = torch::jit::script::string_to_type_lut(); - std::vector input_tensors; + std::vector input_tensors; for (auto in : inputs) { // Disregarding inputs that are not tensors // @@ -163,7 +163,7 @@ bool AddInputs(ConversionCtx* ctx, return true; } -bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { +bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { for (auto out : outputs) { ctx->net->markOutput(*(ctx->value_tensor_map[out])); LOG_INFO(ctx->logger, @@ -178,7 +178,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) { } } -void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) { +void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) { LOG_INFO(ctx->logger, "Converting Block"); auto inputs = b->inputs(); @@ -188,7 +188,6 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu auto nodes = b->nodes(); for (const auto n : nodes) { - bool to_eval = evaluators::shouldEvalAtConversionTime(n); bool blacklisted = isNodeConversionBlacklisted(n); if (!to_eval && !blacklisted) { @@ -220,13 +219,41 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu 
// a serialized TensorRT engine that can be deserialized and run // Probably should consolidate these two functions -std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) { +std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) { ConversionCtx ctx(build_info.engine_settings); ConvertBlockToNetDef(&ctx, b, build_info, static_params); std::string engine = ctx.SerializeEngine(); return engine; } +bool VerifyConverterSupportForBlock(const torch::jit::Block* b) { + bool supported = true; + std::set unsupported_ops; + for (const auto n : b->nodes()) { + if (!OpSupported(n)) { + auto schema = n->maybeSchema(); + TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \ + << " (conversion.AddLayer)"); + std::stringstream ss; + ss << *schema; + unsupported_ops.insert(ss.str()); + supported = false; + } + } + + if (!supported) { + std::stringstream unsupported_msg; + unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:" << std::endl; + for (auto s : unsupported_ops) { + unsupported_msg << " - " << s << std::endl; + } + unsupported_msg << "You can either implement converters for these ops in your application or file a bug" << std::endl; + unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl; + LOG_ERROR(unsupported_msg.str()); + } + return supported; +} + } // namespace conversion } // namespace core } // namespace trtorch diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h index 4a696a3d41..f053eced84 100644 --- a/core/conversion/conversion.h +++ b/core/conversion/conversion.h @@ -43,10 +43,12 @@ GraphParams get_named_params(c10::ArrayRef inputs, std::vect // Converts a already lowered block (blocks with no sub blocks) to // a serialized TensorRT engine that can be deserialized and run -std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo 
build_info, GraphParams& static_params); +std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params); bool OpSupported(const torch::jit::Node* n); +bool VerifyConverterSupportForBlock(const torch::jit::Block* b); + } // namespace conversion } // namespace core } // namespace trtorch diff --git a/core/conversion/conversion_blacklist.cpp b/core/conversion/conversion_blacklist.cpp index 277978f188..c20ccc7db7 100644 --- a/core/conversion/conversion_blacklist.cpp +++ b/core/conversion/conversion_blacklist.cpp @@ -24,7 +24,7 @@ const std::unordered_set& get_non_convertable_nodes() { return nonconvertable_nodes; } -bool isNodeConversionBlacklisted(torch::jit::Node* n) { +bool isNodeConversionBlacklisted(const torch::jit::Node* n) { auto kind = n->kind(); auto convertableIt = get_non_convertable_nodes().find(kind.toQualString()); if (convertableIt == get_non_convertable_nodes().end()) { diff --git a/core/conversion/converters/NodeConverterRegistry.cpp b/core/conversion/converters/NodeConverterRegistry.cpp index f124ed6d52..15335a1a73 100644 --- a/core/conversion/converters/NodeConverterRegistry.cpp +++ b/core/conversion/converters/NodeConverterRegistry.cpp @@ -46,10 +46,6 @@ using ConverterLUT = std::unordered_map; class NodeConverterRegistry { public: bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) { - // NOTE: This is useful for people developing extentions to the conversion registry as is - // If you are working on the core conversion library and the conversion registry - // itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the - // registration of core converters during init, otherwise the messages will be masked LOG_DEBUG("Registering Converter for " << canonical_schema_string(*signature)); auto sym = torch::jit::Symbol::fromQualString(signature->name()); converter_lut_[sym] = std::move(converter); @@ -70,13 +66,12 @@ class NodeConverterRegistry { 
bool Convertable(const torch::jit::Node* n) { auto schema = n->maybeSchema(); if (schema) { - auto converter = GetConverter(schema); - if (converter) { - return true; + auto sym = torch::jit::Symbol::fromQualString(schema->name()); + auto iter = converter_lut_.find(sym); + if (iter == converter_lut_.end()) { + return false; } else { - LOG_DEBUG("Node has no registered converter: " << util::node_info(n) \ - << " (NodeConverterRegistry.Convertable)\nSchema: " << *schema); - return false; + return true; } } else { LOG_DEBUG("Unable to get schema for Node " << util::node_info(n) \ diff --git a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp index 35aab687d6..a810c44584 100644 --- a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp +++ b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp @@ -20,10 +20,6 @@ using EvaluatorLUT = std::unordered_map; class NodeEvaluatorRegistry { public: void RegisterEvaluator(torch::jit::NodeKind node_kind, NodeEvaluator& evaluator) { - // NOTE: This is useful for people developing extentions to the conversion registry as is - // If you are working on the core conversion library and the conversion registry - // itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the - // registration of core converters during init, otherwise the messages will be masked LOG_DEBUG("Registering evaluator for " << node_kind.toQualString()); evaluator_lut_[node_kind] = std::move(evaluator); } diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index 1044bb31f4..8275b4b0a2 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -215,6 +215,19 @@ TRTORCH_API std::string get_build_info(); */ TRTORCH_API void dump_build_info(); +/** + * @brief Check to see if a module is fully supported by the compiler + * + * @param module: torch::jit::script::Module - Existing TorchScript module + * @param method_name: 
std::string - Name of method to check + * + * Takes a module and a method name and checks if the method graph contains purely + * convertible operators + * + * Will print out a list of unsupported operators if the graph is unsupported + */ +TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name); + /** * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT * * @@ -239,5 +252,5 @@ TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Mo * and will convert selected method to a serialized TensorRT engine which can be run with * TensorRT */ -TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, ExtraInfo info); +TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info); } // namespace trtorch diff --git a/cpp/api/src/trtorch.cpp b/cpp/api/src/trtorch.cpp index 1eef399845..562f4faa9f 100644 --- a/cpp/api/src/trtorch.cpp +++ b/cpp/api/src/trtorch.cpp @@ -10,12 +10,17 @@ namespace trtorch { // Defined in extra_info.cpp core::conversion::ExtraInfo to_internal_extra_info(ExtraInfo external); -std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, +bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, + std::string method_name) { + return core::CheckMethodOperatorSupport(module, method_name); +} + +std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info) { LOG_DEBUG(get_build_info()); // Want to export a much simpler (non TRT header dependent) API so doing the // type conversion here - return std::move(core::ConvertGraphToTRTEngine(mod, method_name, to_internal_extra_info(info))); + return std::move(core::ConvertGraphToTRTEngine(module, method_name, to_internal_extra_info(info))); } torch::jit::script::Module CompileGraph(const torch::jit::script::Module&
module, ExtraInfo info) { diff --git a/cpp/trtorchexec/main.cpp b/cpp/trtorchexec/main.cpp index 157448678a..dec456e913 100644 --- a/cpp/trtorchexec/main.cpp +++ b/cpp/trtorchexec/main.cpp @@ -55,6 +55,11 @@ int main(int argc, const char* argv[]) { dims.push_back(v); } + if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) { + std::cerr << "Method is not currently supported by TRTorch" << std::endl; + return -1; + } + auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims); std::ofstream out("/tmp/engine_converted_from_jit.trt"); out << engine; @@ -69,7 +74,6 @@ int main(int argc, const char* argv[]) { torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues); std::vector jit_results; jit_results.push_back(jit_results_ivalues.toTensor()); - auto trt_mod = trtorch::CompileGraph(mod, dims); torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);