From 9f006d5e66cd686fdb6e93747c8493029ef06fd9 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 15 Jun 2021 18:33:28 -0700 Subject: [PATCH] fix: Restrict TRTorch to compile only forward methods Signed-off-by: Dheeraj Peri --- core/compiler.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 7f5e9e3076..0c3bb35957 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -183,7 +183,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo std::vector> graphs; for (const torch::jit::script::Method& method : mod.get_methods()) { // Don't convert hidden methods - if (method.name().rfind("_", 0)) { + if (method.name().compare("forward")==0) { auto new_g = std::make_shared(); auto graph_and_parameters = lowering::Lower(mod, method.name()); @@ -257,7 +257,8 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C std::vector> graphs; for (const torch::jit::script::Method& method : mod.get_methods()) { // Don't convert hidden methods - if (method.name().rfind("_", 0)) { + // + if (method.name().compare("forward")==0) { auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg); auto new_g = std::make_shared(); AddEngineToGraph(new_mod, new_g, engine);