Skip to content

Commit

Permalink
fix: Fix modules_as_engines test case to use trt_mod instead of pyt_mod
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Oct 9, 2021
1 parent a38f0c7 commit 282e98a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/cpp/test_modules_as_engines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());

auto forward_graph = mod.get_method("forward");
std::vector<c10::ArrayRef<int64_t>> input_ranges;
for (auto in : inputs) {
input_ranges.push_back(in.sizes());
Expand All @@ -43,7 +42,7 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
auto trt_mod = trtorch::EmbedEngineInNewModule(engine, compile_spec.device);

torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());

Expand All @@ -61,4 +60,4 @@ INSTANTIATE_TEST_SUITE_P(
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3})));
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3})));

0 comments on commit 282e98a

Please sign in to comment.