tests: use IRParser in test_tensorrt_conversion and test_stitched_graph
Signed-off-by: Bo Wang <[email protected]>
bowang007 committed Apr 16, 2021
1 parent 437670e commit f722035
Showing 3 changed files with 206 additions and 48 deletions.
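
Note: this commit switches the two partitioning test suites from loading serialized traced modules (resnet50_traced.jit.pt, mobilenet_v2_traced.jit.pt) to building small graphs in-process with the IR parser. As a minimal standalone sketch of the torch::jit::parseIR entry point the new tests rely on (illustrative only, not part of this commit):

#include "torch/csrc/jit/ir/irparser.h"
#include "torch/script.h"

#include <iostream>

int main() {
  // Textual TorchScript IR for a one-op graph.
  const auto ir = R"IR(
      graph(%x : Tensor):
        %1 : Tensor = aten::relu(%x)
        return (%1))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, g.get()); // populates g from the textual IR
  std::cout << *g << "\n";          // print the parsed graph
  return 0;
}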
3 changes: 1 addition & 2 deletions core/compiler.cpp

@@ -172,7 +172,6 @@ void AddSegmentedBlockToGraph(
     old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
   }
 
-  LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
   return;
 }
 
@@ -187,7 +186,6 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
     if (method.name().rfind("_", 0)) {
       auto new_g = std::make_shared<torch::jit::Graph>();
       auto graph_and_parameters = lowering::Lower(mod, method.name());
-      // LOG_INFO(*(method.graph()) << "Original graph\n");
 
       auto g = graph_and_parameters.first;
       auto params = graph_and_parameters.second;
@@ -204,6 +202,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
       int trt_engine_id = 1;
      std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
       for (auto& seg_block : segmented_blocks) {
+        LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n");
         if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
           std::vector<ir::InputRange> input_ranges;
           for (auto& shape : seg_block.in_shape()) {
124 changes: 102 additions & 22 deletions tests/core/partitioning/test_stitched_graph.cpp

@@ -3,6 +3,8 @@
 #include "core/compiler.h"
 #include "core/util/trt_util.h"
 #include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/ir/irparser.h"
 #include "torch/script.h"
 
 bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
@@ -22,39 +24,117 @@ bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
   return true;
 }
 
-TEST(Partitioning, StitchResNet50SegmentedBlockCorrectly) {
-  torch::jit::script::Module mod;
-  try {
-    mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
-  } catch (const c10::Error& e) {
-    std::cerr << "error loading the model\n";
-    return;
-  }
+TEST(Partitioning, StitchSequentialModelSegmentedBlockCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
+            %b1 : Float(32),
+            %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
+            %b2 : Float(16),
+            %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
+            %b3 : Float(8)):
+        %2 : int[] = prim::Constant[value=[1, 1]]()
+        %3 : int = prim::Constant[value=1]()
+        %10 : bool = prim::Constant[value=0]()
+        %11 : int[] = prim::Constant[value=[0, 0]]()
+        %12 : Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+        %13 : Tensor = aten::relu(%12)
+        %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+        %15 : Tensor = aten::log_sigmoid(%14)
+        %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+        return (%16))IR";
+
+  auto parsed_g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, parsed_g.get());
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}, {8, 16, 3, 3}, {8}};
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
+  for (size_t i = 0; i < all_shapes.size(); ++i) {
+    auto in = at::randint(5, all_shapes[i], {at::kCUDA});
+    torch::jit::IValue cur_val = in.clone();
+    auto new_val = g->insertConstant(cur_val);
+    tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
+  }
+  for (auto node : parsed_g->nodes()) {
+    if (node->kind() == torch::jit::prim::Constant)
+      continue;
+    trtorch::core::util::cloneNode(node, g, tensor_to_constant);
+  }
+  g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
 
-  std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
+  std::vector<trtorch::core::ir::InputRange> input_ranges;
+  input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
   trtorch::core::CompileSpec cfg(input_ranges);
   cfg.partition_info.enabled = true;
   cfg.partition_info.forced_fallback_operators.push_back("aten::add");
+  torch::jit::script::Module mod(c10::QualifiedName("module"));
+
+  auto self = g->insertInput(0, "self_1");
+  self->setType(mod.type());
+  auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
+  auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
+  mod.type()->addMethod(cur_method);
+  cur_method->setSchema(schema);
 
   torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
-  auto g = new_mod.get_method("forward").graph();
-  ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
+  auto fallback_g = new_mod.get_method("forward").graph();
+  ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
 }
 
-TEST(Partitioning, StitchMobileNetSegmentedBlockCorrectlyEdge) {
-  torch::jit::script::Module mod;
-  try {
-    mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
-  } catch (const c10::Error& e) {
-    std::cerr << "error loading the model\n";
-    return;
-  }
+TEST(Partitioning, StitchBranchModelSegmentedBlockCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
+            %2 : Float(32),
+            %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
+            %4 : Float(16)):
+        %5 : int[] = prim::Constant[value=[0, 0]]()
+        %6 : int[] = prim::Constant[value=[2, 2]]()
+        %7 : bool = prim::Constant[value=0]()
+        %8 : int[] = prim::Constant[value=[1, 1]]()
+        %9 : int = prim::Constant[value=1]()
+        %10 : Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+        %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+        %12 : Tensor = aten::log_sigmoid(%10)
+        %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+        %14 : Tensor = aten::relu(%11)
+        %15 : Tensor = aten::add(%13, %14, %9)
+        %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
+        return (%16))IR";
+
+  auto parsed_g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, parsed_g.get());
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}};
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
+  for (size_t i = 0; i < all_shapes.size(); ++i) {
+    auto in = at::randint(5, all_shapes[i], {at::kCUDA});
+    torch::jit::IValue cur_val = in.clone();
+    auto new_val = g->insertConstant(cur_val);
+    tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
+  }
+  for (auto node : parsed_g->nodes()) {
+    if (node->kind() == torch::jit::prim::Constant)
+      continue;
+    trtorch::core::util::cloneNode(node, g, tensor_to_constant);
+  }
+  g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
 
-  std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
+  std::vector<trtorch::core::ir::InputRange> input_ranges;
+  input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
   trtorch::core::CompileSpec cfg(input_ranges);
   cfg.partition_info.enabled = true;
   cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
+  torch::jit::script::Module mod(c10::QualifiedName("module"));
+
+  auto self = g->insertInput(0, "self_1");
+  self->setType(mod.type());
+  auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
+  auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
+  mod.type()->addMethod(cur_method);
+  cur_method->setSchema(schema);
 
   torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
-  auto g = new_mod.get_method("forward").graph();
-  ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
+  auto fallback_g = new_mod.get_method("forward").graph();
+  ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
 }
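
Note: both tests above repeat the same pattern: parse textual IR, freeze the weight/bias graph inputs into prim::Constant tensors, and clone the remaining nodes into a fresh graph so only the activation stays a real input. The same pattern can be written with public torch::jit APIs alone; a sketch under that assumption (graphWithFrozenParams is a hypothetical helper, and Graph::createClone stands in for the trtorch::core::util::cloneNode used in the tests):

#include "torch/csrc/jit/ir/irparser.h"
#include "torch/script.h"

// Hypothetical helper: returns a copy of the parsed graph in which inputs
// 1..N have been replaced by freshly generated constant tensors.
std::shared_ptr<torch::jit::Graph> graphWithFrozenParams(
    const std::string& ir,
    const std::vector<std::vector<int64_t>>& param_shapes) {
  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed.get());

  auto g = std::make_shared<torch::jit::Graph>();
  std::unordered_map<torch::jit::Value*, torch::jit::Value*> env;

  // Input 0 stays a real input; every parameter input becomes a constant.
  env[parsed->inputs()[0]] = g->addInput()->setType(parsed->inputs()[0]->type());
  for (size_t i = 0; i < param_shapes.size(); ++i) {
    auto w = at::randn(param_shapes[i], {at::kCUDA});
    env[parsed->inputs()[i + 1]] = g->insertConstant(w);
  }

  // Clone every node (constants included) into the new graph,
  // remapping operands through env as we go.
  for (auto* node : parsed->nodes()) {
    auto* cloned = g->insertNode(
        g->createClone(node, [&](torch::jit::Value* v) { return env.at(v); }));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      env[node->outputs()[i]] = cloned->outputs()[i];
    }
  }
  g->registerOutput(env.at(parsed->outputs()[0]));
  return g;
}

With the parameters frozen this way, the CompileSpec needs only a single InputRange for %0, which is why each test passes just one {3, 3, 16, 16} range.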
127 changes: 103 additions & 24 deletions tests/core/partitioning/test_tensorrt_conversion.cpp

@@ -15,40 +15,119 @@ int count_trt_engines(std::shared_ptr<torch::jit::Graph> g) {
   return count;
 }
 
-TEST(Partitioning, ConvertResNet50SegmentedBlockCorrectly) {
-  torch::jit::script::Module mod;
-  try {
-    mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
-  } catch (const c10::Error& e) {
-    std::cerr << "error loading the model\n";
-    return;
-  }
+TEST(Partitioning, ConvertSequentialModelSegmentedBlockCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
+            %b1 : Float(32),
+            %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
+            %b2 : Float(16),
+            %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
+            %b3 : Float(8)):
+        %2 : int[] = prim::Constant[value=[1, 1]]()
+        %3 : int = prim::Constant[value=1]()
+        %10 : bool = prim::Constant[value=0]()
+        %11 : int[] = prim::Constant[value=[0, 0]]()
+        %12 : Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+        %13 : Tensor = aten::relu(%12)
+        %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+        %15 : Tensor = aten::log_sigmoid(%14)
+        %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+        return (%16))IR";
+
+  auto parsed_g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, parsed_g.get());
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}, {8, 16, 3, 3}, {8}};
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
+  for (size_t i = 0; i < all_shapes.size(); ++i) {
+    auto in = at::randint(5, all_shapes[i], {at::kCUDA});
+    torch::jit::IValue cur_val = in.clone();
+    auto new_val = g->insertConstant(cur_val);
+    tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
+  }
+  for (auto node : parsed_g->nodes()) {
+    if (node->kind() == torch::jit::prim::Constant)
+      continue;
+    trtorch::core::util::cloneNode(node, g, tensor_to_constant);
+  }
+  g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
 
-  std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
+  std::vector<trtorch::core::ir::InputRange> input_ranges;
+  input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
   trtorch::core::CompileSpec cfg(input_ranges);
   cfg.partition_info.enabled = true;
   cfg.partition_info.forced_fallback_operators.push_back("aten::add");
+  torch::jit::script::Module mod(c10::QualifiedName("module"));
+
+  auto self = g->insertInput(0, "self_1");
+  self->setType(mod.type());
+  auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
+  auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
+  mod.type()->addMethod(cur_method);
+  cur_method->setSchema(schema);
 
   torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
-  auto g = new_mod.get_method("forward").graph();
-  int count = count_trt_engines(g);
-  ASSERT_TRUE(count == 17);
+  auto fallback_g = new_mod.get_method("forward").graph();
+  int count = count_trt_engines(fallback_g);
+  ASSERT_TRUE(count == 2);
 }
 
-TEST(Partitioning, ConvertMobileNetSegmentedBlockCorrectly) {
-  torch::jit::script::Module mod;
-  try {
-    mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
-  } catch (const c10::Error& e) {
-    std::cerr << "error loading the model\n";
-    return;
-  }
+TEST(Partitioning, ConvertBranchModelSegmentedBlockCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
+            %2 : Float(32),
+            %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
+            %4 : Float(16)):
+        %5 : int[] = prim::Constant[value=[0, 0]]()
+        %6 : int[] = prim::Constant[value=[2, 2]]()
+        %7 : bool = prim::Constant[value=0]()
+        %8 : int[] = prim::Constant[value=[1, 1]]()
+        %9 : int = prim::Constant[value=1]()
+        %10 : Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+        %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+        %12 : Tensor = aten::log_sigmoid(%10)
+        %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+        %14 : Tensor = aten::relu(%11)
+        %15 : Tensor = aten::add(%13, %14, %9)
+        %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
+        return (%16))IR";
+
+  auto parsed_g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, parsed_g.get());
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}};
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
+  for (size_t i = 0; i < all_shapes.size(); ++i) {
+    auto in = at::randint(5, all_shapes[i], {at::kCUDA});
+    torch::jit::IValue cur_val = in.clone();
+    auto new_val = g->insertConstant(cur_val);
+    tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
+  }
+  for (auto node : parsed_g->nodes()) {
+    if (node->kind() == torch::jit::prim::Constant)
+      continue;
+    trtorch::core::util::cloneNode(node, g, tensor_to_constant);
+  }
+  g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);
 
-  std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
+  std::vector<trtorch::core::ir::InputRange> input_ranges;
+  input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
   trtorch::core::CompileSpec cfg(input_ranges);
   cfg.partition_info.enabled = true;
   cfg.partition_info.forced_fallback_operators.push_back("aten::add");
+  torch::jit::script::Module mod(c10::QualifiedName("module"));
+
+  auto self = g->insertInput(0, "self_1");
+  self->setType(mod.type());
+  auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
+  auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
+  mod.type()->addMethod(cur_method);
+  cur_method->setSchema(schema);
 
   torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
-  auto g = new_mod.get_method("forward").graph();
-  int count = count_trt_engines(g);
-  ASSERT_TRUE(count == 11);
+  auto fallback_g = new_mod.get_method("forward").graph();
+  int count = count_trt_engines(fallback_g);
+  ASSERT_TRUE(count == 2);
 }
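
Note: count_trt_engines (body elided in this diff) walks the stitched graph and counts TensorRT engine nodes; the updated assertions expect exactly two engines per graph. A plausible sketch of such a counter, assuming compiled segments surface as tensorrt::execute_engine nodes (that op name is an assumption, not confirmed by this diff):

#include "torch/csrc/jit/ir/ir.h"

// Count how many TensorRT engine nodes the partitioner stitched into g.
int count_trt_engines(std::shared_ptr<torch::jit::Graph> g) {
  int count = 0;
  for (const auto n : g->nodes()) {
    // Assumed engine op name; adjust to whatever kind TRTorch registers.
    if (n->kind() == torch::jit::Symbol::fromQualString("tensorrt::execute_engine")) {
      ++count;
    }
  }
  return count;
}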
