diff --git a/tests/cpp/cpp_api_test.h b/tests/cpp/cpp_api_test.h index 2291f814cd..3addfbc2ed 100644 --- a/tests/cpp/cpp_api_test.h +++ b/tests/cpp/cpp_api_test.h @@ -6,12 +6,12 @@ #include "torch/script.h" #include "torch_tensorrt/torch_tensorrt.h" -using PathAndInSize = std::tuple>, float>; +using PathAndInput = std::tuple>, std::vector, float>; -class CppAPITests : public testing::TestWithParam { +class CppAPITests : public testing::TestWithParam { public: void SetUp() override { - PathAndInSize params = GetParam(); + PathAndInput params = GetParam(); std::string path = std::get<0>(params); try { // Deserialize the ScriptModule from a file using torch::jit::load(). @@ -21,7 +21,8 @@ class CppAPITests : public testing::TestWithParam { ASSERT_TRUE(false); } input_shapes = std::get<1>(params); - threshold = std::get<2>(params); + input_types = std::get<2>(params); + threshold = std::get<3>(params); } void TearDown() { @@ -32,5 +33,6 @@ class CppAPITests : public testing::TestWithParam { protected: torch::jit::script::Module mod; std::vector> input_shapes; + std::vector input_types; float threshold; }; diff --git a/tests/cpp/test_compiled_modules.cpp b/tests/cpp/test_compiled_modules.cpp index c61a8f76f1..595dd7044f 100644 --- a/tests/cpp/test_compiled_modules.cpp +++ b/tests/cpp/test_compiled_modules.cpp @@ -3,20 +3,42 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { std::vector jit_inputs_ivalues; std::vector trt_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + std::vector shapes; + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); jit_inputs_ivalues.push_back(in.clone()); trt_inputs_ivalues.push_back(in.clone()); + auto in_spec = torch_tensorrt::Input(input_shapes[i]); + in_spec.dtype = input_types[i]; + shapes.push_back(in_spec); + std::cout << in_spec << std::endl; } torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(mod, jit_inputs_ivalues); std::vector jit_results; - jit_results.push_back(jit_results_ivalues.toTensor()); + if (jit_results_ivalues.isTuple()) { + auto tuple = jit_results_ivalues.toTuple(); + for (auto t : tuple->elements()) { + jit_results.push_back(t.toTensor()); + } + } else { + jit_results.push_back(jit_results_ivalues.toTensor()); + } + + auto spec = torch_tensorrt::ts::CompileSpec(shapes); + spec.truncate_long_and_double = true; - auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes); + auto trt_mod = torch_tensorrt::ts::compile(mod, spec); torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); std::vector trt_results; - trt_results.push_back(trt_results_ivalues.toTensor()); + if (trt_results_ivalues.isTuple()) { + auto tuple = trt_results_ivalues.toTuple(); + for (auto t : tuple->elements()) { + trt_results.push_back(t.toTensor()); + } + } else { + trt_results.push_back(trt_results_ivalues.toTensor()); + } for (size_t i = 0; i < trt_results.size(); i++) { ASSERT_TRUE( @@ -30,13 +52,14 @@ INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3}), - PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-2}))); + PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-3}), + PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}, 8e-2}), + PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2}))); #endif diff --git a/tests/cpp/test_default_input_types.cpp b/tests/cpp/test_default_input_types.cpp index 63904c7416..752f51eecb 100644 --- a/tests/cpp/test_default_input_types.cpp +++ b/tests/cpp/test_default_input_types.cpp @@ -78,7 +78,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) { } auto in = torch_tensorrt::Input(input_shapes[0]); - in.dtype = torch::kF32; + in.dtype = torch::kFloat; auto spec = torch_tensorrt::ts::CompileSpec({in}); spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf); @@ -116,4 +116,4 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) { INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, - testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}))); + testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5}))); diff --git a/tests/cpp/test_example_tensors.cpp b/tests/cpp/test_example_tensors.cpp index fc77d9e4d4..6561cd16a0 100644 --- a/tests/cpp/test_example_tensors.cpp +++ b/tests/cpp/test_example_tensors.cpp @@ -3,8 +3,8 @@ TEST_P(CppAPITests, InputsFromTensors) { std::vector jit_inputs_ivalues; std::vector trt_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randn(in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]); jit_inputs_ivalues.push_back(in.clone()); trt_inputs_ivalues.push_back(in.clone()); } @@ -20,4 +20,4 @@ TEST_P(CppAPITests, InputsFromTensors) { INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, - testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}))); + testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}))); diff --git a/tests/cpp/test_modules_as_engines.cpp b/tests/cpp/test_modules_as_engines.cpp index ab4ccc1ae7..c77919c8b9 100644 --- a/tests/cpp/test_modules_as_engines.cpp +++ b/tests/cpp/test_modules_as_engines.cpp @@ -4,8 +4,8 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) { std::vector inputs; std::vector inputs_ivalues; - for (auto in_shape : input_shapes) { - inputs.push_back(at::randint(5, in_shape, {at::kCUDA})); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i])); inputs_ivalues.push_back(inputs[inputs.size() - 1].clone()); } @@ -21,8 +21,8 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) { TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) { std::vector inputs; std::vector inputs_ivalues; - for (auto in_shape : input_shapes) { - inputs.push_back(at::randint(5, in_shape, {at::kCUDA})); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i])); inputs_ivalues.push_back(inputs[inputs.size() - 1].clone()); } @@ -57,13 +57,13 @@ INSTANTIATE_TEST_SUITE_P( ModuleAsEngineForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-2}))); + PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2}))); #endif diff --git a/tests/cpp/test_multi_gpu_serde.cpp b/tests/cpp/test_multi_gpu_serde.cpp index 2356583fa3..366c287c32 100644 --- a/tests/cpp/test_multi_gpu_serde.cpp +++ b/tests/cpp/test_multi_gpu_serde.cpp @@ -4,8 +4,8 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { std::vector jit_inputs_ivalues; std::vector trt_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); jit_inputs_ivalues.push_back(in.clone()); trt_inputs_ivalues.push_back(in.clone()); } @@ -31,4 +31,4 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, - testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}))); \ No newline at end of file + testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}))); \ No newline at end of file diff --git a/tests/cpp/test_serialization.cpp b/tests/cpp/test_serialization.cpp index 877e42e6ab..0086500be5 100644 --- a/tests/cpp/test_serialization.cpp +++ b/tests/cpp/test_serialization.cpp @@ -21,8 +21,8 @@ std::vector toInputRangesDynamic(std::vector post_serialized_inputs_ivalues; std::vector pre_serialized_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); post_serialized_inputs_ivalues.push_back(in.clone()); pre_serialized_inputs_ivalues.push_back(in.clone()); } @@ -50,8 +50,8 @@ TEST_P(CppAPITests, SerializedModuleIsStillCorrect) { TEST_P(CppAPITests, SerializedDynamicModuleIsStillCorrect) { std::vector post_serialized_inputs_ivalues; std::vector pre_serialized_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); post_serialized_inputs_ivalues.push_back(in.clone()); pre_serialized_inputs_ivalues.push_back(in.clone()); } @@ -81,5 +81,5 @@ INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, 2e-5}))); + PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, {at::kFloat}, 2e-5}))); diff --git a/tests/modules/hub.py b/tests/modules/hub.py index fa4b7892ef..2664517f4f 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -217,4 +217,4 @@ def forward(self, x): model = BertModel.from_pretrained("bert-base-uncased", torchscript=True) traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) -torch.jit.save(traced_model, "bert_base_uncased_traced.jit..pt") \ No newline at end of file +torch.jit.save(traced_model, "bert_base_uncased_traced.jit.pt") \ No newline at end of file