From c1934c1c3f9fa48cfe2428ead8848f7b3199206f Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 16 Apr 2021 16:57:27 -0500 Subject: [PATCH] chore: improve some minor code problems Signed-off-by: Bo Wang --- core/compiler.cpp | 4 ++++ core/util/trt_util.h | 4 ++++ tests/core/partitioning/BUILD | 14 +++++++++++++- tests/core/partitioning/partitioning_test.bzl | 3 --- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 44e9cfd8b6..48cfd05556 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -201,6 +201,10 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo int trt_engine_id = 1; std::unordered_map old_to_new_g; + // add global graph's input to old_to_new_g mapping + for (auto input : g->inputs()) { + util::getOrAddInputForValue(input, new_g, old_to_new_g); + } for (auto& seg_block : segmented_blocks) { LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n"); if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) { diff --git a/core/util/trt_util.h b/core/util/trt_util.h index f90e6d19fd..61ea59c127 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -110,6 +110,10 @@ at::ScalarType toATenDType(nvinfer1::DataType t); nvinfer1::DataType toTRTDataType(at::ScalarType t); c10::optional toTRTDataType(caffe2::TypeMeta dtype); c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr& g); +torch::jit::Value* getOrAddInputForValue( + torch::jit::Value* old_value, + std::shared_ptr& graph, + std::unordered_map& old_to_new); torch::jit::Node* cloneNode( torch::jit::Node* node, std::shared_ptr& graph, diff --git a/tests/core/partitioning/BUILD b/tests/core/partitioning/BUILD index 4fc51c2bbd..34fdc8c921 100644 --- a/tests/core/partitioning/BUILD +++ b/tests/core/partitioning/BUILD @@ -29,8 +29,20 @@ partitioning_test( name = "test_stitched_graph", ) -partitioning_test( +cc_test( name = "test_fallback_graph_output", + srcs = ["test_fallback_graph_output.cpp"], + deps = [ + "//tests/util", + "//core", + "@googletest//:gtest_main", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }), + data = [ + ":jit_models" + ] ) test_suite( diff --git a/tests/core/partitioning/partitioning_test.bzl b/tests/core/partitioning/partitioning_test.bzl index 55d1625e56..2c11a38b08 100644 --- a/tests/core/partitioning/partitioning_test.bzl +++ b/tests/core/partitioning/partitioning_test.bzl @@ -11,8 +11,5 @@ def partitioning_test(name, visibility=None): ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], "//conditions:default": ["@libtorch//:libtorch"], }), - data = [ - ":jit_models" - ], timeout="short" )