Skip to content

Commit

Permalink
refactor!: Changing the C++ api to be snake case
Browse files Browse the repository at this point in the history
BREAKING CHANGE: This changes the C++ API ::ts
APIs to be snake case and for CompileModules to
become just compile

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Oct 22, 2021
1 parent 4d606bc commit f34e230
Show file tree
Hide file tree
Showing 24 changed files with 46 additions and 45 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ examples/int8/ptq/ptq
examples/int8/qat/qat
examples/int8/training/vgg16/data/*
examples/int8/datasets/data/*
env/**/*
env/**/*
bazel-Torch-TensorRT-Preview
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ More Information / System Architecture:

## Building a docker container for Torch-TensorRT Preview

We provide a `Dockerfile` in `docker/` directory. We build `Torch-TensorRT` on top of a `Pytorch NGC container` which provide basic dependencies (like CUDA, CUDNN, CUBLAS, TensorRT, Pytorch and others) The dependency libraries in the container can be found in the <a href="https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html">release notes</a>.
We provide a `Dockerfile` in `docker/` directory. We build `Torch-TensorRT` on top of a `Pytorch NGC container` which provide basic dependencies (like CUDA, CUDNN, CUBLAS, TensorRT, Pytorch and others) The dependency libraries in the container can be found in the <a href="https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html">release notes</a>.

Please follow this instruction to build a Docker container.

Expand All @@ -41,7 +41,7 @@ auto compile_settings = torch_tensorrt::ts::CompileSpec({input});
// FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
// Compile module
auto trt_mod = torch_tensorrt::ts::CompileModule(ts_mod, compile_settings);
auto trt_mod = torch_tensorrt::ts::compile(ts_mod, compile_settings);
// Run like normal
auto results = trt_mod.forward({in_tensor});
// Save module for later
Expand Down
2 changes: 1 addition & 1 deletion core/partitioning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ torchtrt::ts::CompileSpec cfg(input_sizes);
cfg.torch_fallback = torchtrt::CompileSpec::TorchFallback(true);
cfg.torch_fallback.min_block_size = 2;
cfg.torch_fallback.forced_fallback_ops.push_back("aten::relu");
auto trt_mod = torchtrt::ts::CompileModule(mod, cfg);
auto trt_mod = torchtrt::ts::compile(mod, cfg);
auto out = trt_mod.forward({in});
```
6 changes: 3 additions & 3 deletions cpp/bin/torchtrtc/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ int main(int argc, char** argv) {
// Instead of compiling, just embed engine in a PyTorch module
if (embed_engine) {
std::string serialized_engine = read_buf(real_input_path);
auto trt_mod = torchtrt::ts::EmbedEngineInNewModule(serialized_engine, compile_settings.device);
auto trt_mod = torchtrt::ts::embed_engine_in_new_module(serialized_engine, compile_settings.device);
trt_mod.save(real_output_path);
return 0;
}
Expand All @@ -622,12 +622,12 @@ int main(int argc, char** argv) {
}

if (save_engine) {
auto engine = torchtrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_settings);
auto engine = torchtrt::ts::convert_method_to_trt_engine(mod, "forward", compile_settings);
std::ofstream out(real_output_path);
out << engine;
out.close();
} else {
auto trt_mod = torchtrt::ts::CompileModule(mod, compile_settings);
auto trt_mod = torchtrt::ts::compile(mod, compile_settings);

if (!no_threshold_check &&
(compile_settings.enabled_precisions.size() == 1 &&
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/torch_tensorrt/torch_tensorrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ struct TORCHTRT_API CompileSpec {
*
* @returns bool: Method is supported by Torch-TensorRT.TorchScript
*/
TORCHTRT_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);
TORCHTRT_API bool check_method_operator_support(const torch::jit::Module& module, std::string method_name);

/**
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
Expand All @@ -717,7 +717,7 @@ TORCHTRT_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, s
*
* @return: A new module trageting a TensorRT engine
*/
TORCHTRT_API torch::jit::Module CompileModule(const torch::jit::Module& module, CompileSpec info);
TORCHTRT_API torch::jit::Module compile(const torch::jit::Module& module, CompileSpec info);

/**
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
Expand All @@ -733,7 +733,7 @@ TORCHTRT_API torch::jit::Module CompileModule(const torch::jit::Module& module,
* @return: std::string: Serialized TensorRT engine equivilant to the method
* graph
*/
TORCHTRT_API std::string ConvertMethodToTRTEngine(
TORCHTRT_API std::string convert_method_to_trt_engine(
const torch::jit::Module& module,
std::string method_name,
CompileSpec info);
Expand All @@ -751,6 +751,6 @@ TORCHTRT_API std::string ConvertMethodToTRTEngine(
*
* @return: A new module trageting a TensorRT engine
*/
TORCHTRT_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine, Device device);
TORCHTRT_API torch::jit::Module embed_engine_in_new_module(const std::string& engine, Device device);
} // namespace torchscript
} // namespace torch_tensorrt
8 changes: 4 additions & 4 deletions cpp/src/torch_tensorrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ namespace torchscript {
// Defined in compile_spec.cpp
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external);

bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) {
bool check_method_operator_support(const torch::jit::script::Module& module, std::string method_name) {
return torch_tensorrt::core::CheckMethodOperatorSupport(module, method_name);
}

std::string ConvertMethodToTRTEngine(
std::string convert_method_to_trt_engine(
const torch::jit::script::Module& module,
std::string method_name,
CompileSpec info) {
Expand All @@ -26,14 +26,14 @@ std::string ConvertMethodToTRTEngine(
return torch_tensorrt::core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info));
}

torch::jit::script::Module CompileModule(const torch::jit::script::Module& module, CompileSpec info) {
torch::jit::script::Module compile(const torch::jit::script::Module& module, CompileSpec info) {
LOG_DEBUG(get_build_info());
// Want to export a much simpler (non TRT header dependent) API so doing the
// type conversion here
return torch_tensorrt::core::CompileGraph(module, to_internal_compile_spec(info));
}

torch::jit::Module EmbedEngineInNewModule(const std::string& engine, Device device) {
torch::jit::Module embed_engine_in_new_module(const std::string& engine, Device device) {
return torch_tensorrt::core::EmbedEngineInNewModule(engine, to_internal_cuda_device(device));
}

Expand Down
4 changes: 2 additions & 2 deletions examples/benchmark/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ int main(int argc, const char* argv[]) {
compile_spec.enabled_precisions.insert(torch::kF16);
#endif

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

#ifdef SAVE_ENGINE
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_spec);
auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", compile_spec);
std::ofstream out("/tmp/engine_converted_from_jit.trt");
out << engine;
out.close();
Expand Down
4 changes: 2 additions & 2 deletions examples/int8/ptq/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M

#ifdef SAVE_ENGINE
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_spec);
auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", compile_spec);
std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
out << engine;
out.close();
#endif

std::cout << "Compiling and quantizing module" << std::endl;
auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);
return std::move(trt_mod);
}

Expand Down
4 changes: 2 additions & 2 deletions examples/int8/qat/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ torch::jit::Module compile_int8_qat_model(const std::string& data_dir, torch::ji

#ifdef SAVE_ENGINE
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", compile_spec);
auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", compile_spec);
std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
out << engine;
out.close();
#endif

std::cout << "Compiling and quantizing module" << std::endl;
auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);
return std::move(trt_mod);
}

Expand Down
2 changes: 1 addition & 1 deletion tests/accuracy/test_dla_fp16_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ TEST_P(AccuracyTests, DLAFP16AccuracyIsClose) {
compile_spec.device.allow_gpu_fallback = true;
compile_spec.workspace_size = 1 << 28;

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
for (auto batch : *eval_dataloader) {
Expand Down
2 changes: 1 addition & 1 deletion tests/accuracy/test_dla_int8_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ TEST_P(AccuracyTests, DLAINT8AccuracyIsClose) {
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

// Compile Graph
auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

// Check the INT8 accuracy in TRT
torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
Expand Down
2 changes: 1 addition & 1 deletion tests/accuracy/test_fp16_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape});
compile_spec.enabled_precisions.insert(torch::kF16);

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
for (auto batch : *eval_dataloader) {
Expand Down
2 changes: 1 addition & 1 deletion tests/accuracy/test_fp32_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TEST_P(AccuracyTests, FP32AccuracyIsClose) {
auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape});
compile_spec.enabled_precisions.insert(torch::kF32);

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
for (auto batch : *eval_dataloader) {
Expand Down
2 changes: 1 addition & 1 deletion tests/accuracy/test_int8_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TEST_P(AccuracyTests, INT8AccuracyIsClose) {
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

// Compile Graph
auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

// Check the INT8 accuracy in TRT
torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/test_compiled_modules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, input_shapes);
auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
Expand Down
10 changes: 5 additions & 5 deletions tests/cpp/test_default_input_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ TEST_P(CppAPITests, InputsUseDefaultFP32) {
auto spec = torch_tensorrt::ts::CompileSpec({in});
spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf);

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
Expand All @@ -38,7 +38,7 @@ TEST_P(CppAPITests, InputsUseDefaultFP16) {

mod.to(torch::kHalf);

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
Expand All @@ -60,7 +60,7 @@ TEST_P(CppAPITests, InputsUseDefaultFP16WithoutFP16Enabled) {

mod.to(torch::kHalf);

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
Expand All @@ -84,7 +84,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) {

mod.to(torch::kHalf);

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
Expand All @@ -106,7 +106,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
auto spec = torch_tensorrt::ts::CompileSpec({in});
spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf);

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/test_example_tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ TEST_P(CppAPITests, InputsFromTensors) {

auto spec = torch_tensorrt::ts::CompileSpec({trt_inputs_ivalues[0].toTensor()});

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, spec);
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_module_fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
cfg.torch_executed_modules.push_back("torchvision.models.resnet.BasicBlock");

auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
auto trt_mod = torch_tensorrt::ts::CompileModule(mod, cfg);
auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
}
Expand Down Expand Up @@ -54,7 +54,7 @@ TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
cfg.torch_executed_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");

auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
auto trt_mod = torch_tensorrt::ts::CompileModule(mod, cfg);
auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);

auto g = trt_mod.get_method("forward").graph();
auto nodes = g->block()->nodes();
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_modules_as_engines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
cudaGetDevice(&device_id);
compile_spec.device.device_type = torch_tensorrt::Device::DeviceType::kGPU;
compile_spec.device.gpu_id = device_id;
auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", input_ranges);
auto trt_mod = torch_tensorrt::ts::EmbedEngineInNewModule(engine, compile_spec.device);
auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", input_ranges);
auto trt_mod = torch_tensorrt::ts::embed_engine_in_new_module(engine, compile_spec.device);

torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
std::vector<at::Tensor> trt_results;
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/test_multi_gpu_serde.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());

auto trt_mod = torch_tensorrt::ts::CompileModule(mod, input_shapes);
auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes);

// Deliberately changing the device ID. torch_tensorrt runtime should correct the Device ID internally
torch_tensorrt::set_device(1);
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_multiple_registered_engines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ TEST(CppAPITest, CanRunMultipleEngines) {
std::vector<at::Tensor> jit2_results;
jit2_results.push_back(jit2_results_ivalues.toTensor());

auto trt_mod1 = torch_tensorrt::ts::CompileModule(mod1, input_shapes);
auto trt_mod1 = torch_tensorrt::ts::compile(mod1, input_shapes);
torch::jit::IValue trt1_results_ivalues =
torch_tensorrt::tests::util::RunModuleForward(trt_mod1, trt1_inputs_ivalues);
std::vector<at::Tensor> trt1_results;
trt1_results.push_back(trt1_results_ivalues.toTensor());

auto trt_mod2 = torch_tensorrt::ts::CompileModule(mod2, input_shapes);
auto trt_mod2 = torch_tensorrt::ts::compile(mod2, input_shapes);
torch::jit::IValue trt2_results_ivalues =
torch_tensorrt::tests::util::RunModuleForward(trt_mod2, trt2_inputs_ivalues);
std::vector<at::Tensor> trt2_results;
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_runtime_thread_safety.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ TEST(CppAPITests, RuntimeThreadSafety) {
// FP32 execution
compile_settings.enabled_precisions = {torch::kFloat};
compile_settings.strict_types = true;
auto trt_mod = torch_tensorrt::ts::CompileModule(mod, compile_settings);
std::cout << "torch_tensorrt::ts::CompileModule" << std::endl;
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_settings);
std::cout << "torch_tensorrt::ts::compile" << std::endl;

int num_threads = 10;
std::vector<torch::jit::IValue> out_vec(num_threads), trt_out_vec(num_threads);
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ TEST_P(CppAPITests, SerializedModuleIsStillCorrect) {
pre_serialized_inputs_ivalues.push_back(in.clone());
}

auto pre_serialized_mod = torch_tensorrt::ts::CompileModule(mod, input_shapes);
auto pre_serialized_mod = torch_tensorrt::ts::compile(mod, input_shapes);
torch::jit::IValue pre_serialized_results_ivalues =
torch_tensorrt::tests::util::RunModuleForward(pre_serialized_mod, pre_serialized_inputs_ivalues);
std::vector<at::Tensor> pre_serialized_results;
Expand Down Expand Up @@ -57,7 +57,7 @@ TEST_P(CppAPITests, SerializedDynamicModuleIsStillCorrect) {
}

auto pre_serialized_mod =
torch_tensorrt::ts::CompileModule(mod, torch_tensorrt::ts::CompileSpec(toInputRangesDynamic(input_shapes)));
torch_tensorrt::ts::compile(mod, torch_tensorrt::ts::CompileSpec(toInputRangesDynamic(input_shapes)));
torch::jit::IValue pre_serialized_results_ivalues =
torch_tensorrt::tests::util::RunModuleForward(pre_serialized_mod, pre_serialized_inputs_ivalues);
std::vector<at::Tensor> pre_serialized_results;
Expand Down
2 changes: 1 addition & 1 deletion tests/util/run_forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::Module& mod, std::v
input_ranges.push_back(in.sizes());
}

auto engine = torch_tensorrt::ts::ConvertMethodToTRTEngine(mod, "forward", input_ranges);
auto engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", input_ranges);
return RunEngine(engine, inputs);
}

Expand Down

0 comments on commit f34e230

Please sign in to comment.