Skip to content

Commit

Permalink
feat(//cpp/trtorchexec): TRTorch exec now supports checking correctness
Browse files Browse the repository at this point in the history
of multiple outputs

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 23, 2020
1 parent 8171f79 commit 80808b7
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions cpp/trtorchexec/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@ int main(int argc, const char* argv[]) {
dims.push_back(v);
}

auto extra_info = trtorch::ExtraInfo(dims);
extra_info.workspace_size = 1 << 24;

std::cout << "Checking operator support" << std::endl;
if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
std::cerr << "Method is not currently supported by TRTorch" << std::endl;
return -1;
}

std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims);
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info);
std::ofstream out("/tmp/engine_converted_from_jit.trt");
out << engine;
out.close();
Expand All @@ -75,14 +78,28 @@ int main(int argc, const char* argv[]) {

torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());
if (jit_results_ivalues.isTensor()) {
jit_results.push_back(jit_results_ivalues.toTensor());
} else {
auto results = jit_results_ivalues.toTuple()->elements();
for (auto r : results) {
jit_results.push_back(r.toTensor());
}
}

std::cout << "Compiling graph as module" << std::endl;
auto trt_mod = trtorch::CompileGraph(mod, dims);
auto trt_mod = trtorch::CompileGraph(mod, extra_info);
std::cout << "Running TRT module" << std::endl;
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
if (trt_results_ivalues.isTensor()) {
trt_results.push_back(trt_results_ivalues.toTensor());
} else {
auto results = trt_results_ivalues.toTuple()->elements();
for (auto r : results) {
trt_results.push_back(r.toTensor());
}
}

for (size_t i = 0; i < trt_results.size(); i++) {
almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]));
Expand Down

0 comments on commit 80808b7

Please sign in to comment.