From df741364a21b30e5d5990f088d2cb238825bfed3 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 24 Apr 2020 13:06:44 -0700 Subject: [PATCH] feat(//tests): New optional accuracy tests to check INT8 and FP16 Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .gitignore | 1 + tests/BUILD | 9 +- tests/accuracy/BUILD | 73 ++++++++++ tests/accuracy/accuracy_test.h | 30 +++++ tests/accuracy/datasets/BUILD | 23 ++++ tests/accuracy/datasets/cifar10.cpp | 132 +++++++++++++++++++ tests/accuracy/datasets/cifar10.h | 45 +++++++ tests/accuracy/test_fp16_accuracy.cpp | 56 ++++++++ tests/accuracy/test_fp32_accuracy.cpp | 56 ++++++++ tests/accuracy/test_int8_accuracy.cpp | 83 ++++++++++++ tests/modules/BUILD | 13 -- tests/modules/test_fp16_compiled_modules.cpp | 41 ------ tests/util/run_forward.cpp | 2 +- 13 files changed, 508 insertions(+), 56 deletions(-) create mode 100644 tests/accuracy/BUILD create mode 100644 tests/accuracy/accuracy_test.h create mode 100644 tests/accuracy/datasets/BUILD create mode 100644 tests/accuracy/datasets/cifar10.cpp create mode 100644 tests/accuracy/datasets/cifar10.h create mode 100644 tests/accuracy/test_fp16_accuracy.cpp create mode 100644 tests/accuracy/test_fp32_accuracy.cpp create mode 100644 tests/accuracy/test_int8_accuracy.cpp delete mode 100644 tests/modules/test_fp16_compiled_modules.cpp diff --git a/.gitignore b/.gitignore index 37dd6a63ea..2fc1b55b80 100644 --- a/.gitignore +++ b/.gitignore @@ -21,5 +21,6 @@ py/.eggs cpp/ptq/training/vgg16/data/* *.bin cpp/ptq/datasets/data/ +tests/accuracy/datasets/data/* ._.DS_Store *.tar.gz diff --git a/tests/BUILD b/tests/BUILD index 1006bec047..66967d02c7 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -5,4 +5,11 @@ test_suite( "//tests/modules:test_modules" ], ) - + +test_suite( + name = "required_and_optional_tests", + tests = [ + ":tests", + "//tests/accuracy:test_accuracy" + ] +) \ No newline at end of file diff --git a/tests/accuracy/BUILD b/tests/accuracy/BUILD new file mode 100644 index 0000000000..713f886ef2 --- /dev/null +++ b/tests/accuracy/BUILD @@ -0,0 +1,73 @@ +filegroup( + name = "jit_models", + srcs = glob(["**/*.jit.pt"]) +) + +test_suite( + name = "test_accuracy", + tests = [ + ":test_int8_accuracy", + ":test_fp16_accuracy", + ":test_fp32_accuracy", + ] +) + +cc_test( + name = "test_int8_accuracy", + srcs = ["test_int8_accuracy.cpp"], + deps = [ + ":accuracy_test", + "//tests/accuracy/datasets:cifar10" + ], + data = [ + ":jit_models", + ] +) + +cc_test( + name = "test_fp16_accuracy", + srcs = ["test_fp16_accuracy.cpp"], + deps = [ + ":accuracy_test", + "//tests/accuracy/datasets:cifar10" + ], + data = [ + ":jit_models", + ] +) + +cc_test( + name = "test_fp32_accuracy", + srcs = ["test_fp32_accuracy.cpp"], + deps = [ + ":accuracy_test", + "//tests/accuracy/datasets:cifar10" + ], + data = [ + ":jit_models", + ] +) + +cc_binary( + name = "test", + srcs = ["test.cpp"], + deps = [ + ":accuracy_test", + "//tests/accuracy/datasets:cifar10" + ], + data = [ + ":jit_models", + ] +) + + +cc_library( + name = "accuracy_test", + hdrs = ["accuracy_test.h"], + deps = [ + "//cpp/api:trtorch", + "//tests/util", + "@libtorch//:libtorch", + "@googletest//:gtest_main", + ], +) diff --git a/tests/accuracy/accuracy_test.h b/tests/accuracy/accuracy_test.h new file mode 100644 index 0000000000..229608de6d --- /dev/null +++ b/tests/accuracy/accuracy_test.h @@ -0,0 +1,30 @@ +#include +#include "torch/script.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "trtorch/trtorch.h" +#include 
"c10/cuda/CUDACachingAllocator.h" + +// TODO: Extend this to support other datasets +class AccuracyTests + : public testing::TestWithParam { +public: + void SetUp() override { + auto params = GetParam(); + auto module_path = params; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(module_path); + } + catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + return; + } + } + + void TearDown() { + c10::cuda::CUDACachingAllocator::emptyCache(); + } +protected: + torch::jit::script::Module mod; +}; diff --git a/tests/accuracy/datasets/BUILD b/tests/accuracy/datasets/BUILD new file mode 100644 index 0000000000..23b7134bab --- /dev/null +++ b/tests/accuracy/datasets/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cifar10", + hdrs = [ + "cifar10.h" + ], + srcs = [ + "cifar10.cpp" + ], + deps = [ + "@libtorch//:libtorch" + ], + data = [ + ":cifar10_data" + ] + +) + +filegroup( + name = "cifar10_data", + srcs = glob(["data/cifar-10-batches-bin/**/*.bin"]) +) \ No newline at end of file diff --git a/tests/accuracy/datasets/cifar10.cpp b/tests/accuracy/datasets/cifar10.cpp new file mode 100644 index 0000000000..87b6868c6e --- /dev/null +++ b/tests/accuracy/datasets/cifar10.cpp @@ -0,0 +1,132 @@ +#include "tests/accuracy/datasets/cifar10.h" + +#include "torch/torch.h" +#include "torch/data/example.h" +#include "torch/types.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace datasets { +namespace { +constexpr const char* kTrainFilenamePrefix = "data_batch_"; +constexpr const uint32_t kNumTrainFiles = 5; +constexpr const char* kTestFilename = "test_batch.bin"; +constexpr const size_t kLabelSize = 1; // B +constexpr const size_t kImageSize = 3072; // B +constexpr const size_t kImageDim = 32; +constexpr const size_t kImageChannels = 3; +constexpr const size_t kBatchSize = 10000; + +std::pair read_batch(const std::string& path) { + std::ifstream batch; + batch.open(path, std::ios::in|std::ios::binary|std::ios::ate); + + auto file_size = batch.tellg(); + std::unique_ptr buf(new char[file_size]); + + batch.seekg(0, std::ios::beg); + batch.read(buf.get(), file_size); + batch.close(); + + std::vector labels; + std::vector images; + labels.reserve(kBatchSize); + images.reserve(kBatchSize); + + for (size_t i = 0; i < kBatchSize; i++) { + uint8_t label = buf[i * (kImageSize + kLabelSize)]; + std::vector image; + image.reserve(kImageSize); + std::copy(&buf[i * (kImageSize + kLabelSize) + 1], &buf[i * (kImageSize + kLabelSize) + kImageSize], std::back_inserter(image)); + labels.push_back(label); + auto image_tensor = torch::from_blob(image.data(), + {kImageChannels, kImageDim, kImageDim}, + torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); + images.push_back(image_tensor); + } + + auto labels_tensor = torch::from_blob(labels.data(), + {kBatchSize}, + torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); + assert(labels_tensor.size(0) == kBatchSize); + + auto images_tensor = torch::stack(images); + assert(images_tensor.size(0) == kBatchSize); + + return std::make_pair(images_tensor, labels_tensor); +} + +std::pair read_train_data(const std::string& root) { + std::vector images, targets; + for(uint32_t i = 1; i <= 5; i++) { + std::stringstream ss; + ss << root << '/' << kTrainFilenamePrefix << i << ".bin"; + auto batch = read_batch(ss.str()); + images.push_back(batch.first); + targets.push_back(batch.second); + } + + torch::Tensor 
+  torch::Tensor target_tensor = std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) {return torch::cat({a, b}, 0);});
+
+  return std::make_pair(image_tensor, target_tensor);
+}
+
+std::pair<torch::Tensor, torch::Tensor> read_test_data(const std::string& root) {
+  std::stringstream ss;
+  ss << root << '/' << kTestFilename;
+  return read_batch(ss.str());
+}
+}
+
+CIFAR10::CIFAR10(const std::string& root, Mode mode)
+  : mode_(mode) {
+
+  std::pair<torch::Tensor, torch::Tensor> data;
+  if (mode_ == Mode::kTrain) {
+    data = read_train_data(root);
+  } else {
+    data = read_test_data(root);
+  }
+
+  images_ = std::move(data.first);
+  targets_ = std::move(data.second);
+  assert(images_.sizes()[0] == targets_.sizes()[0]);
+}
+
+torch::data::Example<> CIFAR10::get(size_t index) {
+  return {images_[index], targets_[index]};
+}
+
+c10::optional<size_t> CIFAR10::size() const {
+  return images_.size(0);
+}
+
+bool CIFAR10::is_train() const noexcept {
+  return mode_ == Mode::kTrain;
+}
+
+const torch::Tensor& CIFAR10::images() const {
+  return images_;
+}
+
+const torch::Tensor& CIFAR10::targets() const {
+  return targets_;
+}
+
+CIFAR10&& CIFAR10::use_subset(int64_t new_size) {
+  assert(new_size <= images_.sizes()[0]);
+  images_ = images_.slice(0, 0, new_size);
+  targets_ = targets_.slice(0, 0, new_size);
+  return std::move(*this);
+}
+
+} // namespace datasets
+
diff --git a/tests/accuracy/datasets/cifar10.h b/tests/accuracy/datasets/cifar10.h
new file mode 100644
index 0000000000..ce2bccebb6
--- /dev/null
+++ b/tests/accuracy/datasets/cifar10.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include "torch/data/datasets/base.h"
+#include "torch/data/example.h"
+#include "torch/types.h"
+
+#include <cstddef>
+#include <string>
+
+namespace datasets {
+// The CIFAR10 Dataset
+class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
+public:
+  // The mode in which the dataset is loaded
+  enum class Mode { kTrain, kTest };
+
+  // Loads CIFAR10 from un-tarred file
+  // Dataset can be found https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
+  // Root path should be the directory that contains the content of tarball
+  explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);
+
+  // Returns the pair at index in the dataset
+  torch::data::Example<> get(size_t index) override;
+
+  // The size of the dataset
+  c10::optional<size_t> size() const override;
+
+  // The mode the dataset is in
+  bool is_train() const noexcept;
+
+  // Returns all images stacked into a single tensor
+  const torch::Tensor& images() const;
+
+  // Returns all targets stacked into a single tensor
+  const torch::Tensor& targets() const;
+
+  // Trims the dataset to the first n pairs
+  CIFAR10&& use_subset(int64_t new_size);
+
+
+private:
+  Mode mode_;
+  torch::Tensor images_, targets_;
+};
+} // namespace datasets
diff --git a/tests/accuracy/test_fp16_accuracy.cpp b/tests/accuracy/test_fp16_accuracy.cpp
new file mode 100644
index 0000000000..7ebcc8b0fb
--- /dev/null
+++ b/tests/accuracy/test_fp16_accuracy.cpp
@@ -0,0 +1,56 @@
+#include "accuracy_test.h"
+#include "datasets/cifar10.h"
+#include "torch/torch.h"
+
+TEST_P(AccuracyTests, FP16AccuracyIsClose) {
+  auto eval_dataset = datasets::CIFAR10("tests/accuracy/datasets/data/cifar-10-batches-bin/", datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
+                          .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
+                                                                    {0.2023, 0.1994, 0.2010}))
+                          .map(torch::data::transforms::Stack<>());
+  auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset), torch::data::DataLoaderOptions()
+                                                                                    .batch_size(32)
+                                                                                    .workers(2));
+
+  // Check the FP32 accuracy in JIT
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
+  for (auto batch : *eval_dataloader) {
+    auto images = batch.data.to(torch::kCUDA);
+    auto targets = batch.target.to(torch::kCUDA);
+
+    auto outputs = mod.forward({images});
+    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
+  }
+  torch::Tensor jit_accuracy = jit_correct / jit_total;
+
+  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  auto extra_info = trtorch::ExtraInfo({input_shape});
+  extra_info.op_precision = torch::kF16;
+
+  auto trt_mod = trtorch::CompileGraph(mod, extra_info);
+
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
+  for (auto batch : *eval_dataloader) {
+    auto images = batch.data.to(torch::kCUDA).to(torch::kF16);
+    auto targets = batch.target.to(torch::kCUDA).to(torch::kF16);
+
+    auto outputs = trt_mod.forward({images});
+    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+    predictions = predictions.reshape(predictions.sizes()[0]);
+
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets));
+  }
+
+  torch::Tensor trt_accuracy = trt_correct / trt_total;
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
+}
+
+
+INSTANTIATE_TEST_SUITE_P(FP16AccuracyIsCloseSuite,
+                         AccuracyTests,
+                         testing::Values("tests/accuracy/vgg16_cifar10.jit.pt"));
diff --git a/tests/accuracy/test_fp32_accuracy.cpp b/tests/accuracy/test_fp32_accuracy.cpp
new file mode 100644
index 0000000000..b014340e82
--- /dev/null
+++ b/tests/accuracy/test_fp32_accuracy.cpp
@@ -0,0 +1,56 @@
+#include "accuracy_test.h"
+#include "datasets/cifar10.h"
+#include "torch/torch.h"
+
+TEST_P(AccuracyTests, FP32AccuracyIsClose) {
+  auto eval_dataset = datasets::CIFAR10("tests/accuracy/datasets/data/cifar-10-batches-bin/", datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
+                          .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
+                                                                    {0.2023, 0.1994, 0.2010}))
+                          .map(torch::data::transforms::Stack<>());
+  auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset), torch::data::DataLoaderOptions()
+                                                                                    .batch_size(32)
+                                                                                    .workers(2));
+
+  // Check the FP32 accuracy in JIT
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
+  for (auto batch : *eval_dataloader) {
+    auto images = batch.data.to(torch::kCUDA);
+    auto targets = batch.target.to(torch::kCUDA);
+
+    auto outputs = mod.forward({images});
+    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
+  }
+  torch::Tensor jit_accuracy = jit_correct / jit_total;
+
+  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  auto extra_info = trtorch::ExtraInfo({input_shape});
+  extra_info.op_precision = torch::kF32;
+
+  auto trt_mod = trtorch::CompileGraph(mod, extra_info);
+
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
+  for (auto batch : *eval_dataloader) {
+    auto images = batch.data.to(torch::kCUDA);
+    auto targets = batch.target.to(torch::kCUDA);
+
+    auto outputs = trt_mod.forward({images});
+    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+    predictions = predictions.reshape(predictions.sizes()[0]);
+
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets));
+  }
+
+  torch::Tensor trt_accuracy = trt_correct / trt_total;
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
+}
+
+
+INSTANTIATE_TEST_SUITE_P(FP32AccuracyIsCloseSuite,
+                         AccuracyTests,
+                         testing::Values("tests/accuracy/vgg16_cifar10.jit.pt"));
diff --git a/tests/accuracy/test_int8_accuracy.cpp b/tests/accuracy/test_int8_accuracy.cpp
new file mode 100644
index 0000000000..07d399c96d
--- /dev/null
+++ b/tests/accuracy/test_int8_accuracy.cpp
@@ -0,0 +1,83 @@
+#include "accuracy_test.h"
+#include "datasets/cifar10.h"
+#include "torch/torch.h"
+
+TEST_P(AccuracyTests, INT8AccuracyIsClose) {
+  auto calibration_dataset = datasets::CIFAR10("tests/accuracy/datasets/data/cifar-10-batches-bin/", datasets::CIFAR10::Mode::kTest)
+                                 .use_subset(320)
+                                 .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
+                                                                           {0.2023, 0.1994, 0.2010}))
+                                 .map(torch::data::transforms::Stack<>());
+  auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset), torch::data::DataLoaderOptions()
+                                                                                                  .batch_size(32)
+                                                                                                  .workers(2));
+
+  std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache";
+
+  auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
+  //auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
+
+
+  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  // Configure settings for compilation
+  auto extra_info = trtorch::ExtraInfo({input_shape});
+  // Set operating precision to INT8
+  extra_info.op_precision = torch::kI8;
+  // Use the TensorRT Entropy Calibrator
+  extra_info.ptq_calibrator = calibrator;
+  // Set max batch size for the engine
+  extra_info.max_batch_size = 32;
+  // Set a larger workspace
+  extra_info.workspace_size = 1 << 28;
+
+  mod.eval();
+
+  // Dataloader moved into calibrator so need another for inference
+  auto eval_dataset = datasets::CIFAR10("tests/accuracy/datasets/data/cifar-10-batches-bin/", datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
+                          .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
+                                                                    {0.2023, 0.1994, 0.2010}))
+                          .map(torch::data::transforms::Stack<>());
+  auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset), torch::data::DataLoaderOptions()
+                                                                                    .batch_size(32)
+                                                                                    .workers(2));
+
+  // Check the FP32 accuracy in JIT
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
+  for (auto batch : *eval_dataloader) {
+    auto images = batch.data.to(torch::kCUDA);
+    auto targets = batch.target.to(torch::kCUDA);
+
+    auto outputs = mod.forward({images});
+    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
+  }
+  torch::Tensor jit_accuracy = jit_correct / jit_total;
+
+  // Compile Graph
+  auto trt_mod = trtorch::CompileGraph(mod, extra_info);
+
+  // Check the INT8 accuracy in TRT
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
+  for (auto batch : *eval_dataloader) {
+    auto images = batch.data.to(torch::kCUDA);
+    auto targets = batch.target.to(torch::kCUDA);
+
+    auto outputs = trt_mod.forward({images});
+    auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+    predictions = predictions.reshape(predictions.sizes()[0]);
+
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+  }
+  torch::Tensor trt_accuracy = trt_correct / trt_total;
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
+}
+
+
+INSTANTIATE_TEST_SUITE_P(INT8AccuracyIsCloseSuite,
+                         AccuracyTests,
+                         testing::Values("tests/accuracy/vgg16_cifar10.jit.pt"));
\ No newline at end of file
diff --git a/tests/modules/BUILD b/tests/modules/BUILD
index 1bbc677bb1..b9e3545ce2 100644
--- a/tests/modules/BUILD
+++ b/tests/modules/BUILD
@@ -8,7 +8,6 @@ test_suite(
     tests = [
         ":test_modules_as_engines",
         ":test_compiled_modules",
-        ":test_fp16_compiled_modules",
         ":test_multiple_registered_engines"
     ]
 )
@@ -24,7 +23,6 @@ cc_test(
     ]
 )
 
-
 cc_test(
     name = "test_modules_as_engines",
     srcs = ["test_modules_as_engines.cpp"],
@@ -36,17 +34,6 @@ cc_test(
     ]
 )
 
-cc_test(
-    name = "test_fp16_compiled_modules",
-    srcs = ["test_fp16_compiled_modules.cpp"],
-    deps = [
-        ":module_test"
-    ],
-    data = [
-        ":jit_models"
-    ]
-)
-
 cc_test(
     name = "test_compiled_modules",
     srcs = ["test_compiled_modules.cpp"],
diff --git a/tests/modules/test_fp16_compiled_modules.cpp b/tests/modules/test_fp16_compiled_modules.cpp
deleted file mode 100644
index 0828197b79..0000000000
--- a/tests/modules/test_fp16_compiled_modules.cpp
+++ /dev/null
@@ -1,41 +0,0 @@
-#include "module_test.h"
-
-TEST_P(ModuleTests, FP16CompiledModuleIsClose) {
-  std::vector<torch::jit::IValue> jit_inputs_ivalues;
-  std::vector<torch::jit::IValue> trt_inputs_ivalues;
-  for (auto in_shape : input_shapes) {
-    auto in = at::randint(5, in_shape, {at::kCUDA});
-    in = in.to(torch::kF16);
-    jit_inputs_ivalues.push_back(in.clone());
-    trt_inputs_ivalues.push_back(in.clone());
-  }
-
-  auto extra_info = trtorch::ExtraInfo({input_shapes});
-  extra_info.op_precision = torch::kF16;
-  extra_info.strict_types = true;
-
-  auto trt_mod = trtorch::CompileGraph(mod, extra_info);
-  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
-  std::vector<at::Tensor> trt_results;
-  trt_results.push_back(trt_results_ivalues.toTensor());
-
-  mod.to(torch::kF16);
-  torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, jit_inputs_ivalues);
-  std::vector<at::Tensor> jit_results;
-  jit_results.push_back(jit_results_ivalues.toTensor());
-
-  for (size_t i = 0; i < trt_results.size(); i++) {
-    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 2e-5));
-  }
-}
-
-
-INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite,
-                         ModuleTests,
-                         testing::Values(
-                             PathAndInSize({"tests/modules/resnet18.jit.pt",
-                                            {{1,3,224,224}}}),
-                             PathAndInSize({"tests/modules/resnet50.jit.pt",
-                                            {{1,3,224,224}}}),
-                             PathAndInSize({"tests/modules/mobilenet_v2.jit.pt",
-                                            {{1,3,224,224}}})));
diff --git a/tests/util/run_forward.cpp b/tests/util/run_forward.cpp
index e552cdc104..f8f22b900d 100644
--- a/tests/util/run_forward.cpp
+++ b/tests/util/run_forward.cpp
@@ -18,7 +18,7 @@ std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::script::Module& mod
   for (auto in : inputs) {
     input_ranges.push_back(in.sizes());
   }
-  
+
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
   return RunEngine(engine, inputs);
 }
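
Note for reviewers (not part of the patch itself): the new optional suite deliberately keeps its inputs out of the tree — the CIFAR-10 binary batches are expected under tests/accuracy/datasets/data/ (hence the new .gitignore entry), and a TorchScript VGG16 checkpoint is expected at tests/accuracy/vgg16_cifar10.jit.pt (picked up by the jit_models filegroup). Assuming both are in place, a minimal invocation sketch looks like the following; the exact wget/tar workflow and any extra Bazel flags depend on the local CUDA/TensorRT setup:

    # Fetch the binary CIFAR-10 archive referenced in cifar10.h and unpack it where the dataset loader expects it.
    mkdir -p tests/accuracy/datasets/data
    wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
    tar -xzf cifar-10-binary.tar.gz -C tests/accuracy/datasets/data/

    # Run just the optional accuracy suite, or the combined required + optional suite.
    bazel test //tests/accuracy:test_accuracy
    bazel test //tests:required_and_optional_tests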