diff --git a/core/compiler.cpp b/core/compiler.cpp index 7f5e9e3076..0c3bb35957 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -183,7 +183,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo std::vector> graphs; for (const torch::jit::script::Method& method : mod.get_methods()) { // Don't convert hidden methods - if (method.name().rfind("_", 0)) { + if (method.name().compare("forward")==0) { auto new_g = std::make_shared(); auto graph_and_parameters = lowering::Lower(mod, method.name()); @@ -257,7 +257,8 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C std::vector> graphs; for (const torch::jit::script::Method& method : mod.get_methods()) { // Don't convert hidden methods - if (method.name().rfind("_", 0)) { + // + if (method.name().compare("forward")==0) { auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg); auto new_g = std::make_shared(); AddEngineToGraph(new_mod, new_g, engine);