Skip to content

Commit

Permalink
chore: improve some minor code problems
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Apr 16, 2021
1 parent f722035 commit c1934c1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
4 changes: 4 additions & 0 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo

int trt_engine_id = 1;
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
// add global graph's input to old_to_new_g mapping
for (auto input : g->inputs()) {
util::getOrAddInputForValue(input, new_g, old_to_new_g);
}
for (auto& seg_block : segmented_blocks) {
LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n");
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
Expand Down
4 changes: 4 additions & 0 deletions core/util/trt_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ at::ScalarType toATenDType(nvinfer1::DataType t);
nvinfer1::DataType toTRTDataType(at::ScalarType t);
c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype);
c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g);
torch::jit::Value* getOrAddInputForValue(
torch::jit::Value* old_value,
std::shared_ptr<torch::jit::Graph>& graph,
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);
torch::jit::Node* cloneNode(
torch::jit::Node* node,
std::shared_ptr<torch::jit::Graph>& graph,
Expand Down
14 changes: 13 additions & 1 deletion tests/core/partitioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,20 @@ partitioning_test(
name = "test_stitched_graph",
)

partitioning_test(
cc_test(
name = "test_fallback_graph_output",
srcs = ["test_fallback_graph_output.cpp"],
deps = [
"//tests/util",
"//core",
"@googletest//:gtest_main",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
data = [
":jit_models"
]
)

test_suite(
Expand Down
3 changes: 0 additions & 3 deletions tests/core/partitioning/partitioning_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,5 @@ def partitioning_test(name, visibility=None):
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
data = [
":jit_models"
],
timeout="short"
)

0 comments on commit c1934c1

Please sign in to comment.