Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Arvind Sridhar <[email protected]>
  • Loading branch information
ArvindSridhar committed Sep 17, 2021
1 parent 2a6f999 commit 01c6952
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tests/core/partitioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ filegroup(
name = "jit_models",
srcs = ["//tests/modules:resnet50_traced.jit.pt",
"//tests/modules:mobilenet_v2_traced.jit.pt",
"//tests/modules:conditional_scripted.jit.pt"]
"//tests/modules:conditional_scripted.jit.pt",
"//tests/modules:loop_fallback_eval_scripted.jit.pt",
"//tests/modules:loop_fallback_no_eval_scripted.jit.pt"]
)

partitioning_test(
Expand Down Expand Up @@ -46,6 +48,22 @@ cc_test(
]
)

# Partitioning test for loop fallback: verifies that graphs containing
# TorchScript loops compile correctly when partitioning routes the loop
# back to Torch. Links core + gtest, and picks the libtorch variant
# matching the configured C++ ABI. Needs the scripted loop models from
# the jit_models filegroup at runtime.
cc_test(
name = "test_loop_fallback",
srcs = ["test_loop_fallback.cpp"],
deps = [
"//tests/util",
"//core",
"@googletest//:gtest_main",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
data = [
":jit_models"
]
)

cc_test(
name = "test_conditionals",
srcs = ["test_conditionals.cpp"],
Expand All @@ -70,6 +88,7 @@ test_suite(
":test_tensorrt_conversion",
":test_stitched_graph",
":test_fallback_graph_output",
":test_loop_fallback",
":test_conditionals"
]
)
62 changes: 62 additions & 0 deletions tests/core/partitioning/test_loop_fallback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include <string>
#include <unordered_set>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/script.h"

TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) {
torch::jit::script::Module mod;
try {
mod = torch::jit::load("tests/modules/loop_fallback_eval_scripted.jit.pt");
} catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return;
}

const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randint(5, in_shape, {at::kCUDA});
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
}

std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
trtorch::core::CompileSpec cfg(input_ranges);
cfg.partition_info.enabled = true;

auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
}

TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
torch::jit::script::Module mod;
try {
mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
} catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return;
}

const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randint(5, in_shape, {at::kCUDA});
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
}

std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
trtorch::core::CompileSpec cfg(input_ranges);
cfg.partition_info.enabled = true;

auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
}
32 changes: 32 additions & 0 deletions tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,38 @@ def forward(self, x):
torch.jit.save(module_fallback_script_model, "module_fallback_scripted.jit.pt")


# Sample Looping Modules (for loop fallback testing)
class LoopFallbackEval(nn.Module):
    """Fixture module whose forward holds a loop over evaluatable shape ops.

    The explicit Python loop is intentional: scripting it produces a loop
    node so the partitioner's loop-fallback path gets exercised.
    """

    def __init__(self):
        super(LoopFallbackEval, self).__init__()

    def forward(self, x):
        # Grow a 1-D tensor with one entry per column of x; every entry is
        # the column count itself, so the final add broadcasts over rows.
        accumulated = torch.empty(0).to(x.device)
        for _ in range(x.shape[1]):
            entry = torch.tensor([x.shape[1]]).to(x.device)
            accumulated = torch.cat((accumulated, entry), 0)
        return x + accumulated


class LoopFallbackNoEval(nn.Module):
    """Fixture module whose forward loops over pure tensor arithmetic.

    Unlike LoopFallbackEval, the loop body contains no shape/eval ops —
    just elementwise adds — covering the other loop-fallback case. The
    loop is kept explicit so scripting emits a loop node.
    """

    def __init__(self):
        super(LoopFallbackNoEval, self).__init__()

    def forward(self, x):
        # Bump every element by one, once per column of x.
        result = x
        for _step in range(x.shape[1]):
            result = result + torch.ones_like(result)
        return result


# Script and serialize the loop-fallback fixture modules. The output file
# names match the //tests/modules entries listed in the jit_models
# filegroup of tests/core/partitioning/BUILD and the paths loaded by
# test_loop_fallback.cpp. Requires a CUDA device (.cuda()).
loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
torch.jit.save(loop_fallback_eval_script_model, "loop_fallback_eval_scripted.jit.pt")
loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
torch.jit.save(loop_fallback_no_eval_script_model, "loop_fallback_no_eval_scripted.jit.pt")


# Sample Conditional Model (for testing partitioning and fallback in conditionals)
class FallbackIf(torch.nn.Module):

Expand Down

0 comments on commit 01c6952

Please sign in to comment.