Skip to content

Commit

Permalink
feat(//tests): Adding BERT to the test suite
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 5, 2022
1 parent 72c7b76 commit 7996a10
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 44 deletions.
10 changes: 6 additions & 4 deletions tests/cpp/cpp_api_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

using PathAndInSize = std::tuple<std::string, std::vector<std::vector<int64_t>>, float>;
using PathAndInput = std::tuple<std::string, std::vector<std::vector<int64_t>>, std::vector<c10::ScalarType>, float>;

class CppAPITests : public testing::TestWithParam<PathAndInSize> {
class CppAPITests : public testing::TestWithParam<PathAndInput> {
public:
void SetUp() override {
PathAndInSize params = GetParam();
PathAndInput params = GetParam();
std::string path = std::get<0>(params);
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
Expand All @@ -21,7 +21,8 @@ class CppAPITests : public testing::TestWithParam<PathAndInSize> {
ASSERT_TRUE(false);
}
input_shapes = std::get<1>(params);
threshold = std::get<2>(params);
input_types = std::get<2>(params);
threshold = std::get<3>(params);
}

void TearDown() {
Expand All @@ -32,5 +33,6 @@ class CppAPITests : public testing::TestWithParam<PathAndInSize> {
protected:
torch::jit::script::Module mod;
std::vector<std::vector<int64_t>> input_shapes;
std::vector<c10::ScalarType> input_types;
float threshold;
};
49 changes: 36 additions & 13 deletions tests/cpp/test_compiled_modules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,42 @@
TEST_P(CppAPITests, CompiledModuleIsClose) {
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randint(5, in_shape, {at::kCUDA});
std::vector<torch_tensorrt::Input> shapes;
for (uint64_t i = 0; i < input_shapes.size(); i++) {
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
auto in_spec = torch_tensorrt::Input(input_shapes[i]);
in_spec.dtype = input_types[i];
shapes.push_back(in_spec);
std::cout << in_spec << std::endl;
}

torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(mod, jit_inputs_ivalues);
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());
if (jit_results_ivalues.isTuple()) {
auto tuple = jit_results_ivalues.toTuple();
for (auto t : tuple->elements()) {
jit_results.push_back(t.toTensor());
}
} else {
jit_results.push_back(jit_results_ivalues.toTensor());
}

auto spec = torch_tensorrt::ts::CompileSpec(shapes);
spec.truncate_long_and_double = true;

auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes);
auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
if (trt_results_ivalues.isTuple()) {
auto tuple = trt_results_ivalues.toTuple();
for (auto t : tuple->elements()) {
trt_results.push_back(t.toTensor());
}
} else {
trt_results.push_back(trt_results_ivalues.toTensor());
}

for (size_t i = 0; i < trt_results.size(); i++) {
ASSERT_TRUE(
Expand All @@ -30,13 +52,14 @@ INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
testing::Values(
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3}),
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-2})));
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-3}),
PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}, 8e-2}),
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2})));

#endif
4 changes: 2 additions & 2 deletions tests/cpp/test_default_input_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) {
}

auto in = torch_tensorrt::Input(input_shapes[0]);
in.dtype = torch::kF32;
in.dtype = torch::kFloat;
auto spec = torch_tensorrt::ts::CompileSpec({in});
spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf);

Expand Down Expand Up @@ -116,4 +116,4 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));
6 changes: 3 additions & 3 deletions tests/cpp/test_example_tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
TEST_P(CppAPITests, InputsFromTensors) {
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randn(in_shape, {at::kCUDA});
for (uint64_t i = 0; i < input_shapes.size(); i++) {
auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]);
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
}
Expand All @@ -20,4 +20,4 @@ TEST_P(CppAPITests, InputsFromTensors) {
INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5})));
24 changes: 12 additions & 12 deletions tests/cpp/test_modules_as_engines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
TEST_P(CppAPITests, ModuleAsEngineIsClose) {
std::vector<at::Tensor> inputs;
std::vector<torch::jit::IValue> inputs_ivalues;
for (auto in_shape : input_shapes) {
inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
for (uint64_t i = 0; i < input_shapes.size(); i++) {
inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]));
inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
}

Expand All @@ -21,8 +21,8 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) {
TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) {
std::vector<at::Tensor> inputs;
std::vector<torch::jit::IValue> inputs_ivalues;
for (auto in_shape : input_shapes) {
inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
for (uint64_t i = 0; i < input_shapes.size(); i++) {
inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]));
inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
}

Expand Down Expand Up @@ -57,13 +57,13 @@ INSTANTIATE_TEST_SUITE_P(
ModuleAsEngineForwardIsCloseSuite,
CppAPITests,
testing::Values(
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-2})));
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2})));

#endif
6 changes: 3 additions & 3 deletions tests/cpp/test_multi_gpu_serde.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
TEST_P(CppAPITests, CompiledModuleIsClose) {
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randint(5, in_shape, {at::kCUDA});
for (uint64_t i = 0; i < input_shapes.size(); i++) {
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
}
Expand All @@ -31,4 +31,4 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5})));
12 changes: 6 additions & 6 deletions tests/cpp/test_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ std::vector<torch_tensorrt::Input> toInputRangesDynamic(std::vector<std::vector<
TEST_P(CppAPITests, SerializedModuleIsStillCorrect) {
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randint(5, in_shape, {at::kCUDA});
for (uint64_t i = 0; i < input_shapes.size(); i++) {
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
post_serialized_inputs_ivalues.push_back(in.clone());
pre_serialized_inputs_ivalues.push_back(in.clone());
}
Expand Down Expand Up @@ -50,8 +50,8 @@ TEST_P(CppAPITests, SerializedModuleIsStillCorrect) {
TEST_P(CppAPITests, SerializedDynamicModuleIsStillCorrect) {
std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randint(5, in_shape, {at::kCUDA});
for (uint64_t i = 0; i < input_shapes.size(); i++) {
auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
post_serialized_inputs_ivalues.push_back(in.clone());
pre_serialized_inputs_ivalues.push_back(in.clone());
}
Expand Down Expand Up @@ -81,5 +81,5 @@ INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
testing::Values(
PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}),
PathAndInSize({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, 2e-5})));
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
PathAndInput({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, {at::kFloat}, 2e-5})));
2 changes: 1 addition & 1 deletion tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,4 @@ def forward(self, x):
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "bert_base_uncased_traced.jit..pt")
torch.jit.save(traced_model, "bert_base_uncased_traced.jit.pt")

0 comments on commit 7996a10

Please sign in to comment.