
docs: Clean up testing and documentation
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed May 4, 2020
1 parent a9f33e4 commit cd6b1b9
Showing 16 changed files with 52 additions and 24 deletions.
2 changes: 1 addition & 1 deletion BUILD
@@ -11,7 +11,7 @@ pkg_tar(
"//core/conversion/evaluators:include",
"//core/execution:include",
"//core/lowering:include",
"//core/lowering/irfusers:include",
"//core/lowering/passes:include",
"//core/util:include",
"//core/util/logging:include"
],
4 changes: 4 additions & 0 deletions README.md
@@ -149,6 +149,10 @@ Thanks for wanting to contribute! There are two main ways to handle supporting a
You can register a converter for your op using the `NodeConverterRegistry` inside your application.
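
A hypothetical sketch of such an in-application registration is shown below. The helper name `RegisterNodeConversionPatterns`, the schema string, and the lambda signature are assumptions for illustration only; the authoritative `NodeConverterRegistry` interface lives in `core/conversion/converters`.

```cpp
// Hypothetical sketch: the helper, pattern struct, and lambda signature below are
// assumptions for illustration; see core/conversion/converters for the actual
// NodeConverterRegistry interface.
#include "core/conversion/converters/converters.h"

namespace conv = trtorch::core::conversion;

// Registers a converter for aten::relu from application code (illustrative schema).
static auto relu_registration = conv::converters::RegisterNodeConversionPatterns().pattern({
    "aten::relu(Tensor self) -> (Tensor)",
    [](conv::ConversionCtx* ctx, const torch::jit::Node* n, conv::converters::args& args) -> bool {
      // Translate the aten::relu node into TensorRT layers on ctx->net here.
      return true;
    }});
```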

## Known Limitations

- Adaptive pooling layers in PyTorch cannot be used together with TRTorch's dynamic input shape support

## Structure of the repo

| Component | Description |
2 changes: 1 addition & 1 deletion core/compiler.cpp
@@ -102,7 +102,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
ExtraInfo cfg) {
// TODO: Should be doing a functional transform but need PR #31978
// [jit] More robust mangling
// torch::jit::script::Module new_mod = mod.clone();
//torch::jit::script::Module new_mod = mod.clone();
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
for (const torch::jit::script::Method& method : mod.get_methods()) {
2 changes: 1 addition & 1 deletion core/lowering/lowering.cpp
@@ -26,7 +26,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::RemoveDropout(g);
passes::FuseFlattenLinear(g);
passes::UnpackAddMM(g);
passes::ExpandLogSoftmax(g);
passes::UnpackLogSoftmax(g);
//passes::RemoveDimExeception(g);
//irfusers::UnpackBatchNorm(g);
torch::jit::EliminateDeadCode(g);
10 changes: 5 additions & 5 deletions core/lowering/passes/BUILD
@@ -6,12 +6,12 @@ cc_library(
"passes.h",
],
srcs = [
"exception_elimination.cpp",
"fuse_flatten_linear.cpp",
"expand_log_softmax.cpp",
"remove_dropout.cpp",
"unpack_addmm.cpp",
"unpack_batch_norm.cpp",
"exception_elimination.cpp",
"unpack_addmm.cpp"
"unpack_log_softmax.cpp",
],
deps = [
"//core/util:prelude",
@@ -23,7 +23,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")

pkg_tar(
name = "include",
package_dir = "core/lowering/irfusers/",
srcs = ["irfusers.h"],
package_dir = "core/lowering/passes/",
srcs = ["passes.h"],
)

1 change: 0 additions & 1 deletion core/lowering/passes/exception_elimination.cpp
@@ -21,7 +21,6 @@ struct ExceptionOrPassPatternElimination
: graph_(std::move(graph)) {}

void run() {
LOG_GRAPH("Pre exeception or pass elimination: " << *graph_);
findExceptionOrPassNodes(graph_->block());
torch::jit::EliminateDeadCode(graph_);
LOG_GRAPH("Post exeception or pass elimination: " << *graph_);
3 changes: 3 additions & 0 deletions core/lowering/passes/fuse_flatten_linear.cpp
@@ -1,6 +1,8 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
@@ -38,6 +40,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
flatten_linear_bias_none_to_linear.RegisterRewritePattern(
flatten_linear_bias_none_pattern, fused_linear_bias_none);
flatten_linear_bias_none_to_linear.runOnGraph(graph);
LOG_GRAPH("Post flatten linear: " << *graph);
}

} // namespace passes
4 changes: 2 additions & 2 deletions core/lowering/passes/passes.h
@@ -8,10 +8,10 @@ namespace lowering {
namespace passes {

void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

} // namespace irfusers
3 changes: 3 additions & 0 deletions core/lowering/passes/remove_dropout.cpp
@@ -1,6 +1,8 @@
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
@@ -20,6 +22,7 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
remove_dropout.RegisterRewritePattern(
dropout_pattern, no_dropout_pattern);
remove_dropout.runOnGraph(graph);
LOG_GRAPH("Post remove dropout: " << *graph);
}

} // namespace passes
3 changes: 3 additions & 0 deletions core/lowering/passes/unpack_addmm.cpp
@@ -1,6 +1,8 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
@@ -23,6 +25,7 @@ void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::SubgraphRewriter unpack_addmm;
unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
unpack_addmm.runOnGraph(graph);
LOG_GRAPH("Post unpack addmm: " << *graph);
}


3 changes: 3 additions & 0 deletions core/lowering/passes/unpack_batch_norm.cpp
@@ -1,5 +1,7 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
@@ -39,6 +41,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::SubgraphRewriter unpack_batch_norm;
unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern);
unpack_batch_norm.runOnGraph(graph);
LOG_GRAPH("Post unpack batchnorm: " << *graph);
}
} // Namespace passes
} // namespace lowering
core/lowering/passes/{expand_log_softmax.cpp → unpack_log_softmax.cpp}
@@ -1,12 +1,14 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
// It's easier for TensorRT if we separate softmax and log
// There might need to be a reshape inserted; see:
// https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1593
@@ -43,6 +45,7 @@ void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
logsoftmax_none_to_softmax_log_none.RegisterRewritePattern(
logsoftmax_none_pattern, softmax_log_none_pattern);
logsoftmax_none_to_softmax_log_none.runOnGraph(graph);
LOG_GRAPH("Post unpack logsoftmax: " << *graph);
}

} // namespace passes
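For context, the unpacking this pass performs can be sketched with the same `torch::jit::SubgraphRewriter` calls used by the other lowering passes in this commit. The pattern strings below are illustrative; the actual schemas handled by `UnpackLogSoftmax` live in `unpack_log_softmax.cpp`.

```cpp
#include <memory>
#include <string>

#include "torch/csrc/jit/passes/subgraph_rewrite.h"

// Illustrative only: rewrite aten::log_softmax into aten::softmax followed by aten::log.
void UnpackLogSoftmaxSketch(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string logsoftmax_pattern = R"IR(
    graph(%input, %dim, %dtype):
        %out = aten::log_softmax(%input, %dim, %dtype)
        return (%out))IR";
  std::string softmax_log_pattern = R"IR(
    graph(%input, %dim, %dtype):
        %softmax = aten::softmax(%input, %dim, %dtype)
        %out = aten::log(%softmax)
        return (%out))IR";

  // Register the source/target patterns and apply the rewrite over the graph.
  torch::jit::SubgraphRewriter logsoftmax_to_softmax_log;
  logsoftmax_to_softmax_log.RegisterRewritePattern(logsoftmax_pattern, softmax_log_pattern);
  logsoftmax_to_softmax_log.runOnGraph(graph);
}
```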
2 changes: 1 addition & 1 deletion core/util/logging/TRTorchLogger.cpp
@@ -101,7 +101,7 @@ namespace {
TRTorchLogger& get_global_logger() {
#ifndef NDEBUG
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ",
LogLevel::kGRAPH,
LogLevel::kDEBUG,
true);
#else
static TRTorchLogger global_logger("[TRTorch] - ",
19 changes: 10 additions & 9 deletions cpp/api/include/trtorch/logging.h
@@ -9,11 +9,12 @@ namespace logging {
* Enum for setting message severity
*/
enum Level {
kINTERNAL_ERROR,
kERROR,
kWARNING,
kINFO,
kDEBUG,
kINTERNAL_ERROR, // Only print messages for internal errors
kERROR, // Print all internal errors and errors (default)
kWARNING, // Print warnings and errors
kINFO, // Print all info, warnings and errors
kDEBUG, // Print all debug info, info, warnings and errors
kGRAPH, // Print everything including the intermediate graphs of the lowering phase
};

// Are these ones necessary for the user?
@@ -35,7 +36,7 @@ TRTORCH_API void set_reportable_log_level(Level lvl);
TRTORCH_API void set_is_colored_output_on(bool colored_output_on);

/**
* @brief Get the current reportable log level
* @brief Get the current reportable log level
*/
TRTORCH_API Level get_reportable_log_level();

@@ -45,10 +46,10 @@ TRTORCH_API Level get_reportable_log_level();
TRTORCH_API bool get_is_colored_output_on();

/**
* @brief Adds a message to the global log
* @brief Adds a message to the global log
*
* @param lvl: trtorch::logging::Level - Severity of the message
* @param msg: std::string - Message to be logged
* @param lvl: trtorch::logging::Level - Severity of the message
* @param msg: std::string - Message to be logged
*/
// Dont know if we want this?
TRTORCH_API void log(Level lvl, std::string msg);
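
With `kGRAPH` added to the public enum, an application can opt in to graph dumps through the functions declared in this header. A minimal sketch follows, assuming the header is included as `trtorch/logging.h`, matching its location under `cpp/api/include/`.

```cpp
#include <string>

#include "trtorch/logging.h"

int main() {
  // Opt in to the most verbose level so the intermediate graphs from the
  // lowering phase are printed along with debug/info/warning/error messages.
  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kGRAPH);

  if (trtorch::logging::get_reportable_log_level() == trtorch::logging::Level::kGRAPH) {
    trtorch::logging::log(trtorch::logging::Level::kINFO, "Graph-level logging enabled");
  }
  return 0;
}
```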
9 changes: 7 additions & 2 deletions cpp/api/src/logging.cpp
@@ -7,7 +7,7 @@ namespace logging {
std::string get_logging_prefix() {
return core::util::logging::get_logger().get_logging_prefix();
}

void set_logging_prefix(std::string prefix) {
core::util::logging::get_logger().set_logging_prefix(prefix);
}
@@ -27,6 +27,9 @@ void set_reportable_log_level(Level lvl) {
case Level::kINFO:
log_lvl = core::util::logging::LogLevel::kINFO;
break;
case Level::kGRAPH:
log_lvl = core::util::logging::LogLevel::kGRAPH;
break;
case Level::kDEBUG:
default:
log_lvl = core::util::logging::LogLevel::kDEBUG;
@@ -50,12 +53,14 @@ Level get_reportable_log_level() {
return Level::kWARNING;
case core::util::logging::LogLevel::kINFO:
return Level::kINFO;
case core::util::logging::LogLevel::kGRAPH:
return Level::kGRAPH;
case core::util::logging::LogLevel::kDEBUG:
default:
return Level::kDEBUG;
}
}

bool get_is_colored_output_on() {
return core::util::logging::get_logger().get_is_colored_output_on();
}
4 changes: 4 additions & 0 deletions cpp/trtorchexec/main.cpp
@@ -55,11 +55,13 @@ int main(int argc, const char* argv[]) {
dims.push_back(v);
}

std::cout << "Checking operator support" << std::endl;
if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
std::cerr << "Method is not currently supported by TRTorch" << std::endl;
return -1;
}

std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims);
std::ofstream out("/tmp/engine_converted_from_jit.trt");
out << engine;
@@ -75,7 +77,9 @@ int main(int argc, const char* argv[]) {
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());

std::cout << "Compiling graph as module" << std::endl;
auto trt_mod = trtorch::CompileGraph(mod, dims);
std::cout << "Running TRT module" << std::endl;
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
