Skip to content

Commit

Permalink
feat(CheckMethodOperatorSupport): A new API which will check the graph
Browse files Browse the repository at this point in the history
to see if all operators are supported. Addresses #26.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Mar 24, 2020
1 parent 3da4947 commit 28ee445
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 30 deletions.
23 changes: 22 additions & 1 deletion core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
return;
}

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
std::string method_name) {
auto g = mod.get_method(method_name).graph();
// Go through PyTorch Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());

g = graph_and_parameters.first;

// Go through TRTorch Lowering to reformat graph to be conversion friendly
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
lowering::LowerGraph(g);

auto params = graph_and_parameters.second;
auto named_params = conversion::get_named_params(g->inputs(), params);
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");

// Is this necessary?
lowering::LowerBlock(g->block());

return conversion::VerifyConverterSupportForBlock(g->block());
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name,
conversion::ExtraInfo cfg) {
Expand All @@ -87,7 +109,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
return std::move(engine);
}

// TODO: Consider if there is a better way to deal with input size
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
conversion::ExtraInfo cfg) {
// TODO: Should be doing a functional transform but need PR #31978
Expand Down
2 changes: 2 additions & 0 deletions core/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

namespace trtorch {
namespace core {
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name, conversion::ExtraInfo cfg);

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg);

} // namespace core
Expand Down
45 changes: 36 additions & 9 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ namespace core {
namespace conversion {

// Defined in core/conversion/conversion_blacklist.cpp
bool isNodeConversionBlacklisted(torch::jit::Node* n);
bool isNodeConversionBlacklisted(const torch::jit::Node* n);

bool OpSupported(const torch::jit::Node* n) {
bool evalable = evaluators::shouldEvalAtConversionTime(n);
bool convertable = converters::node_is_convertable(n);
return evalable || convertable;
}

c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, torch::jit::Node* n, int level=0, int limit=10) {
c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level=0, int limit=10) {
// Check to see if you can just go through and eval all of these AOT (saves the recursion)
// Also probably a better way to deal with the two error cases;
TRTORCH_CHECK(level < limit, "Failed to evaluate node: " << *n \
Expand Down Expand Up @@ -55,7 +55,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, torch::jit::N
return eval;
}

bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) {
bool AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
LOG_INFO(ctx->logger,
"Adding Layer " << util::node_info(n) << " (ctx.AddLayer)");
converters::args node_args;
Expand Down Expand Up @@ -114,11 +114,11 @@ bool AddLayer(ConversionCtx* ctx, torch::jit::Node* n) {
}

bool AddInputs(ConversionCtx* ctx,
at::ArrayRef<torch::jit::Value*> inputs,
at::ArrayRef<const torch::jit::Value*> inputs,
std::vector<InputRange>& input_dims) {

auto type_lut = torch::jit::script::string_to_type_lut();
std::vector<torch::jit::Value*> input_tensors;
std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
//
Expand Down Expand Up @@ -163,7 +163,7 @@ bool AddInputs(ConversionCtx* ctx,
return true;
}

bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef<torch::jit::Value*> outputs) {
bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
for (auto out : outputs) {
ctx->net->markOutput(*(ctx->value_tensor_map[out]));
LOG_INFO(ctx->logger,
Expand All @@ -178,7 +178,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
}
}

void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
LOG_INFO(ctx->logger, "Converting Block");

auto inputs = b->inputs();
Expand All @@ -188,7 +188,6 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu
auto nodes = b->nodes();

for (const auto n : nodes) {

bool to_eval = evaluators::shouldEvalAtConversionTime(n);
bool blacklisted = isNodeConversionBlacklisted(n);
if (!to_eval && !blacklisted) {
Expand Down Expand Up @@ -220,13 +219,41 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, torch::jit::Block* b, ExtraInfo bu
// a serialized TensorRT engine that can be deserialized and run

// Probably should consolidate these two functions
std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
ConversionCtx ctx(build_info.engine_settings);
ConvertBlockToNetDef(&ctx, b, build_info, static_params);
std::string engine = ctx.SerializeEngine();
return engine;
}

bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
bool supported = true;
std::set<std::string> unsupported_ops;
for (const auto n : b->nodes()) {
if (!OpSupported(n)) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
<< " (conversion.AddLayer)");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
supported = false;
}
}

if (!supported) {
std::stringstream unsupported_msg;
unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:" << std::endl;
for (auto s : unsupported_ops) {
unsupported_msg << " - " << s << std::endl;
}
unsupported_msg << "You can either implement converters for these ops in your application or file a bug" << std::endl;
unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
LOG_ERROR(unsupported_msg.str());
}
return supported;
}

} // namespace conversion
} // namespace core
} // namespace trtorch
4 changes: 3 additions & 1 deletion core/conversion/conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect

// Converts a already lowered block (blocks with no sub blocks) to
// a serialized TensorRT engine that can be deserialized and run
std::string ConvertBlockToEngine(torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);

bool OpSupported(const torch::jit::Node* n);

bool VerifyConverterSupportForBlock(const torch::jit::Block* b);

} // namespace conversion
} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion core/conversion/conversion_blacklist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
return nonconvertable_nodes;
}

bool isNodeConversionBlacklisted(torch::jit::Node* n) {
bool isNodeConversionBlacklisted(const torch::jit::Node* n) {
auto kind = n->kind();
auto convertableIt = get_non_convertable_nodes().find(kind.toQualString());
if (convertableIt == get_non_convertable_nodes().end()) {
Expand Down
15 changes: 5 additions & 10 deletions core/conversion/converters/NodeConverterRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ using ConverterLUT = std::unordered_map<torch::jit::Symbol, OpConverter>;
class NodeConverterRegistry {
public:
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
// NOTE: This is useful for people developing extentions to the conversion registry as is
// If you are working on the core conversion library and the conversion registry
// itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the
// registration of core converters during init, otherwise the messages will be masked
LOG_DEBUG("Registering Converter for " << canonical_schema_string(*signature));
auto sym = torch::jit::Symbol::fromQualString(signature->name());
converter_lut_[sym] = std::move(converter);
Expand All @@ -70,13 +66,12 @@ class NodeConverterRegistry {
bool Convertable(const torch::jit::Node* n) {
auto schema = n->maybeSchema();
if (schema) {
auto converter = GetConverter(schema);
if (converter) {
return true;
auto sym = torch::jit::Symbol::fromQualString(schema->name());
auto iter = converter_lut_.find(sym);
if (iter == converter_lut_.end()) {
return false;
} else {
LOG_DEBUG("Node has no registered converter: " << util::node_info(n) \
<< " (NodeConverterRegistry.Convertable)\nSchema: " << *schema);
return false;
return true;
}
} else {
LOG_DEBUG("Unable to get schema for Node " << util::node_info(n) \
Expand Down
4 changes: 0 additions & 4 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, NodeEvaluator>;
class NodeEvaluatorRegistry {
public:
void RegisterEvaluator(torch::jit::NodeKind node_kind, NodeEvaluator& evaluator) {
// NOTE: This is useful for people developing extentions to the conversion registry as is
// If you are working on the core conversion library and the conversion registry
// itself, it might helpful to set -DDEBUG_MSGS when you compile so you can watch the
// registration of core converters during init, otherwise the messages will be masked
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
evaluator_lut_[node_kind] = std::move(evaluator);
}
Expand Down
15 changes: 14 additions & 1 deletion cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,19 @@ TRTORCH_API std::string get_build_info();
*/
TRTORCH_API void dump_build_info();

/**
* @brief Check to see if a module is fully supported by the compiler
*
* @param module: torch::jit::script::Module - Existing TorchScript module
* @param method_name: std::string - Name of method to compile
*
* Takes a module and a method name and checks if the method graph contains purely
* convertable operators
*
* Will print out a list of unsupported operators if the graph is unsupported
*/
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name);

/**
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
*
Expand All @@ -239,5 +252,5 @@ TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Mo
* and will convert selected method to a serialized TensorRT engine which can be run with
* TensorRT
*/
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, ExtraInfo info);
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
} // namespace trtorch
9 changes: 7 additions & 2 deletions cpp/api/src/trtorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@ namespace trtorch {
// Defined in extra_info.cpp
core::conversion::ExtraInfo to_internal_extra_info(ExtraInfo external);

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
bool CheckMethodOperatorSupport(const torch::jit::script::Module& module,
std::string method_name) {
return core::CheckMethodOperatorSupport(module, method_name);
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module,
std::string method_name, ExtraInfo info) {
LOG_DEBUG(get_build_info());
// Want to export a much simpler (non TRT header dependent) API so doing the
// type conversion here
return std::move(core::ConvertGraphToTRTEngine(mod, method_name, to_internal_extra_info(info)));
return std::move(core::ConvertGraphToTRTEngine(module, method_name, to_internal_extra_info(info)));
}

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info) {
Expand Down
6 changes: 5 additions & 1 deletion cpp/trtorchexec/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ int main(int argc, const char* argv[]) {
dims.push_back(v);
}

if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
std::cerr << "Method is not currently supported by TRTorch" << std::endl;
return -1;
}

auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims);
std::ofstream out("/tmp/engine_converted_from_jit.trt");
out << engine;
Expand All @@ -69,7 +74,6 @@ int main(int argc, const char* argv[]) {
torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());


auto trt_mod = trtorch::CompileGraph(mod, dims);
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
Expand Down

0 comments on commit 28ee445

Please sign in to comment.