New API to display if a method can be converted to TensorRT #27

Merged · 1 commit · Mar 24, 2020
core/compiler.cpp (22 additions, 1 deletion)
@@ -64,6 +64,28 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
return;
}

+bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
+                                std::string method_name) {
+  auto g = mod.get_method(method_name).graph();
+  // Go through PyTorch Lowering to simplify graph and extract weight parameters
+  auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
+
+  g = graph_and_parameters.first;
+
+  // Go through TRTorch Lowering to reformat graph to be conversion friendly
+  // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
+  lowering::LowerGraph(g);
+
+  auto params = graph_and_parameters.second;
+  auto named_params = conversion::get_named_params(g->inputs(), params);
+  LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
+
+  // Is this necessary?
+  lowering::LowerBlock(g->block());
+
+  return conversion::VerifyConverterSupportForBlock(g->block());
+}
+
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name,
conversion::ExtraInfo cfg) {
@@ -87,7 +109,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
return std::move(engine);
}

-// TODO: Consider if there is a better way to deal with input size
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
conversion::ExtraInfo cfg) {
// TODO: Should be doing a functional transform but need PR #31978
core/compiler.h (2 additions, 0 deletions)
@@ -6,9 +6,11 @@

namespace trtorch {
namespace core {
+bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
+
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name, conversion::ExtraInfo cfg);

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg);

} // namespace core
core/conversion/conversion.cpp (36 additions, 9 deletions)
@@ -11,15 +11,15 @@ namespace core {
namespace conversion {

// Defined in core/conversion/conversion_blacklist.cpp
-bool isNodeConversionBlacklisted(torch::jit::Node* n);
+bool isNodeConversionBlacklisted(const torch::jit::Node* n);

bool OpSupported(const torch::jit::Node* n) {
bool evalable = evaluators::shouldEvalAtConversionTime(n);
bool convertable = converters::node_is_convertable(n);
return evalable || convertable;
}

-c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, torch::jit::Node* n, int level=0, int limit=10) {
+c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level=0, int limit=10) {
// Check to see if you can just go through and eval all of these AOT (saves the recursion)
// Also probably a better way to deal with the two error cases;
TRTORCH_CHECK(level < limit, "Failed to evaluate node: " << *n \
@@ -55,7 +55,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, torch::jit::N
return eval;
}

-bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) {
+bool AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
LOG_INFO(ctx->logger,
"Adding Layer " << util::node_info(n) << " (ctx.AddLayer)");
converters::args node_args;
@@ -114,11 +114,11 @@ bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) {
}

bool AddInputs(ConversionCtx* ctx,
-               at::ArrayRef<torch::jit::Value*> inputs,
+               at::ArrayRef<const torch::jit::Value*> inputs,
std::vector<InputRange>& input_dims) {

auto type_lut = torch::jit::script::string_to_type_lut();
-  std::vector<torch::jit::Value*> input_tensors;
+  std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
//
@@ -163,7 +163,7 @@ bool AddInputs(ConversionCtx* ctx,
return true;
}

-bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef<torch::jit::Value*> outputs) {
+bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
for (auto out : outputs) {
ctx->net->markOutput(*(ctx->value_tensor_map[out]));
LOG_INFO(ctx->logger,
@@ -178,7 +178,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
}
}

-void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
+void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
LOG_INFO(ctx->logger, "Converting Block");

auto inputs = b->inputs();
@@ -188,7 +188,6 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu
auto nodes = b->nodes();

for (const auto n : nodes) {
-
bool to_eval = evaluators::shouldEvalAtConversionTime(n);
bool blacklisted = isNodeConversionBlacklisted(n);
if (!to_eval && !blacklisted) {
@@ -220,13 +219,41 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu
// a serialized TensorRT engine that can be deserialized and run

// Probably should consolidate these two functions
-std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
+std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
ConversionCtx ctx(build_info.engine_settings);
ConvertBlockToNetDef(&ctx, b, build_info, static_params);
std::string engine = ctx.SerializeEngine();
return engine;
}

+bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
+  bool supported = true;
+  std::set<std::string> unsupported_ops;
+  for (const auto n : b->nodes()) {
+    if (!OpSupported(n)) {
+      auto schema = n->maybeSchema();
+      TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
+                      << " (conversion.VerifyConverterSupportForBlock)");
+      std::stringstream ss;
+      ss << *schema;
+      unsupported_ops.insert(ss.str());
+      supported = false;
+    }
+  }
+
+  if (!supported) {
+    std::stringstream unsupported_msg;
+    unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:" << std::endl;
+    for (auto s : unsupported_ops) {
+      unsupported_msg << "  - " << s << std::endl;
+    }
+    unsupported_msg << "You can either implement converters for these ops in your application or file a bug" << std::endl;
+    unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
+    LOG_ERROR(unsupported_msg.str());
+  }
+  return supported;
+}
+
} // namespace conversion
} // namespace core
} // namespace trtorch
core/conversion/conversion.h (3 additions, 1 deletion)
@@ -43,10 +43,12 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect

// Converts an already lowered block (blocks with no sub blocks) to
// a serialized TensorRT engine that can be deserialized and run
-std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);
+std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);

bool OpSupported(const torch::jit::Node* n);

+bool VerifyConverterSupportForBlock(const torch::jit::Block* b);
+
} // namespace conversion
} // namespace core
} // namespace trtorch
core/conversion/conversion_blacklist.cpp (1 addition, 1 deletion)
@@ -24,7 +24,7 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
return nonconvertable_nodes;
}

-bool isNodeConversionBlacklisted(torch::jit::Node* n) {
+bool isNodeConversionBlacklisted(const torch::jit::Node* n) {
auto kind = n->kind();
auto convertableIt = get_non_convertable_nodes().find(kind.toQualString());
if (convertableIt == get_non_convertable_nodes().end()) {
core/conversion/converters/NodeConverterRegistry.cpp (5 additions, 10 deletions)
@@ -46,10 +46,6 @@ using ConverterLUT = std::unordered_map<torch::jit::Symbol, OpConverter>;
class NodeConverterRegistry {
public:
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
-    // NOTE: This is useful for people developing extentions to the conversion registry as is
-    // If you are working on the core conversion library and the conversion registry
-    // itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the
-    // registration of core converters during init, otherwise the messages will be masked
LOG_DEBUG("Registering Converter for " << canonical_schema_string(*signature));
auto sym = torch::jit::Symbol::fromQualString(signature->name());
converter_lut_[sym] = std::move(converter);
@@ -70,13 +66,12 @@ class NodeConverterRegistry {
bool Convertable(const torch::jit::Node* n) {
auto schema = n->maybeSchema();
if (schema) {
-      auto converter = GetConverter(schema);
-      if (converter) {
-        return true;
+      auto sym = torch::jit::Symbol::fromQualString(schema->name());
+      auto iter = converter_lut_.find(sym);
+      if (iter == converter_lut_.end()) {
+        return false;
       } else {
-        LOG_DEBUG("Node has no registered converter: " << util::node_info(n) \
-                    << " (NodeConverterRegistry.Convertable)\nSchema: " << *schema);
-        return false;
+        return true;
}
} else {
LOG_DEBUG("Unable to get schema for Node " << util::node_info(n) \
core/conversion/evaluators/NodeEvaluatorRegistry.cpp (0 additions, 4 deletions)
@@ -20,10 +20,6 @@ using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, NodeEvaluator>;
class NodeEvaluatorRegistry {
public:
void RegisterEvaluator(torch::jit::NodeKind node_kind, NodeEvaluator& evaluator) {
-    // NOTE: This is useful for people developing extentions to the conversion registry as is
-    // If you are working on the core conversion library and the conversion registry
-    // itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the
-    // registration of core converters during init, otherwise the messages will be masked
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
evaluator_lut_[node_kind] = std::move(evaluator);
}
cpp/api/include/trtorch/trtorch.h (14 additions, 1 deletion)
@@ -215,6 +215,19 @@ TRTORCH_API std::string get_build_info();
*/
TRTORCH_API void dump_build_info();

+/**
+ * @brief Check to see if a module's method is fully supported by the compiler
+ *
+ * @param module: torch::jit::script::Module - Existing TorchScript module
+ * @param method_name: std::string - Name of the method to check
+ *
+ * Takes a module and a method name and checks whether the method's graph
+ * contains only convertible operators
+ *
+ * Will print out a list of unsupported operators if the graph is unsupported
+ */
+TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name);
+
/**
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
*
@@ -239,5 +252,5 @@ TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Mo
* and will convert selected method to a serialized TensorRT engine which can be run with
* TensorRT
*/
-TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, ExtraInfo info);
+TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
} // namespace trtorch
cpp/api/src/trtorch.cpp (7 additions, 2 deletions)
@@ -10,12 +10,17 @@ namespace trtorch {
// Defined in extra_info.cpp
core::conversion::ExtraInfo to_internal_extra_info(ExtraInfo external);

-std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
+bool CheckMethodOperatorSupport(const torch::jit::script::Module& module,
+                                std::string method_name) {
+  return core::CheckMethodOperatorSupport(module, method_name);
+}
+
+std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module,
std::string method_name, ExtraInfo info) {
LOG_DEBUG(get_build_info());
// Want to export a much simpler (non TRT header dependent) API so doing the
// type conversion here
-  return std::move(core::ConvertGraphToTRTEngine(mod, method_name, to_internal_extra_info(info)));
+  return std::move(core::ConvertGraphToTRTEngine(module, method_name, to_internal_extra_info(info)));
}

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info) {
Expand Down
cpp/trtorchexec/main.cpp (5 additions, 1 deletion)
@@ -55,6 +55,11 @@ int main(int argc, const char* argv[]) {
dims.push_back(v);
}

+  if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
+    std::cerr << "Method is not currently supported by TRTorch" << std::endl;
+    return -1;
+  }
+
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims);
std::ofstream out("/tmp/engine_converted_from_jit.trt");
out << engine;
@@ -69,7 +74,6 @@ torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());
-

auto trt_mod = trtorch::CompileGraph(mod, dims);
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);