Skip to content

Commit

Permalink
fix: Restrict TRTorch to compile only forward methods
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Jun 16, 2021
1 parent 930d582 commit 9f006d5
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
for (const torch::jit::script::Method& method : mod.get_methods()) {
// Don't convert hidden methods
if (method.name().rfind("_", 0)) {
if (method.name().compare("forward")==0) {
auto new_g = std::make_shared<torch::jit::Graph>();
auto graph_and_parameters = lowering::Lower(mod, method.name());

Expand Down Expand Up @@ -257,7 +257,8 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
for (const torch::jit::script::Method& method : mod.get_methods()) {
// Don't convert hidden methods
if (method.name().rfind("_", 0)) {
//
if (method.name().compare("forward")==0) {
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
auto new_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, new_g, engine);
Expand Down

0 comments on commit 9f006d5

Please sign in to comment.