feat(tests/util): added RunGraphEngineDynamic to handle dynamic input sized tensors

Signed-off-by: Abhiram Iyer <[email protected]>

abhi-iyer committed Jun 18, 2020
1 parent 98c797d commit 9458f21
Showing 2 changed files with 48 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/util/run_graph_engine.cpp
@@ -6,6 +6,9 @@
#include "core/conversion/conversion.h"
#include "cuda_runtime_api.h"

#include <vector>
#include <math.h>

namespace trtorch {
namespace tests {
namespace util {
@@ -18,6 +21,34 @@ std::vector<core::conversion::InputRange> toInputRanges(std::vector<at::Tensor>
  return std::move(a);
}

std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten) {
  std::vector<core::conversion::InputRange> a;

  for (auto i : ten) {
    auto opt = core::util::toVec(i.sizes());
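    // Use the observed shape as the optimal profile; only the batch dimension
    // is made dynamic, ranging from half to double the observed batch size.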

    std::vector<int64_t> min_range(opt);
    std::vector<int64_t> max_range(opt);

    min_range[0] = ceil(opt[0] / 2.0);
    max_range[0] = 2 * opt[0];

    a.push_back(core::conversion::InputRange(min_range, opt, max_range));
  }

  return std::move(a);
}

std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs) {
  auto rt = nvinfer1::createInferRuntime(core::util::logging::get_logger());
  auto engine = rt->deserializeCudaEngine(eng.c_str(), eng.size());
@@ -71,6 +102,17 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g,
  return RunEngine(eng, inputs);
}

std::vector<at::Tensor> RunGraphEngineDynamic(std::shared_ptr<torch::jit::Graph>& g,
    core::conversion::GraphParams& named_params,
    std::vector<at::Tensor> inputs) {
  LOG_DEBUG("Running TRT version");
  auto in = toInputRangesDynamic(inputs);
  auto info = core::conversion::ConversionInfo(in);
  info.engine_settings.workspace_size = 1 << 20;
  std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
  return RunEngine(eng, inputs);
}

} // namespace util
} // namespace tests
} // namespace trtorch
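For reference, the ranges produced by the new helper vary only the first (batch) dimension. Below is a minimal sketch of what toInputRangesDynamic would yield inside run_graph_engine.cpp for a hypothetical NCHW input of shape {4, 3, 300, 300}; the shape and variable names are illustrative and not part of the commit.

// Hypothetical input: one NCHW tensor with batch size 4.
auto t = at::randn({4, 3, 300, 300}, {at::kCUDA});
auto ranges = toInputRangesDynamic({t});
// The single resulting InputRange holds:
//   min = {2, 3, 300, 300}   // ceil(4 / 2.0)
//   opt = {4, 3, 300, 300}   // the observed shape
//   max = {8, 3, 300, 300}   // 2 * 4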
6 changes: 6 additions & 0 deletions tests/util/util.h
@@ -28,6 +28,12 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g,
    core::conversion::GraphParams& named_params,
    std::vector<at::Tensor> inputs);

// Runs an arbitrary JIT graph with dynamic input sizes by converting it to TensorRT,
// running inference, and returning the results
std::vector<at::Tensor> RunGraphEngineDynamic(std::shared_ptr<torch::jit::Graph>& g,
    core::conversion::GraphParams& named_params,
    std::vector<at::Tensor> inputs);

// Run the forward method of a module and return results
torch::jit::IValue RunModuleForward(torch::jit::Module& mod,
    std::vector<torch::jit::IValue> inputs);
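To show how the new RunGraphEngineDynamic declaration is meant to be used, here is a minimal, hypothetical test sketch. The graph IR string, the include paths, and the default-constructed GraphParams map are illustrative assumptions, not code taken from this commit or the repository's test suite.

#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h" // header location varies across PyTorch versions

void SketchDynamicBatchRun() {
  // A toy single-op graph; any converter-supported graph would do.
  const auto graph_ir = R"IR(
    graph(%0 : Tensor):
      %1 : Tensor = aten::relu(%0)
      return (%1))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // The toy graph has no frozen weights, so an empty parameter map is assumed to suffice.
  trtorch::core::conversion::GraphParams named_params;

  // The engine is built from an observed batch of 4, so the dynamic profile
  // should accept batch sizes from 2 to 8 (see toInputRangesDynamic).
  auto in = at::randn({4, 3, 16, 16}, {at::kCUDA});
  auto outputs = trtorch::tests::util::RunGraphEngineDynamic(g, named_params, {in});
  (void)outputs;
}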
