-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(//tests): New optional accuracy tests to check INT8 and FP16
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
- Loading branch information
1 parent
b989c7f
commit df74136
Showing
13 changed files
with
508 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
filegroup( | ||
name = "jit_models", | ||
srcs = glob(["**/*.jit.pt"]) | ||
) | ||
|
||
test_suite( | ||
name = "test_accuracy", | ||
tests = [ | ||
":test_int8_accuracy", | ||
":test_fp16_accuracy", | ||
":test_fp32_accuracy", | ||
] | ||
) | ||
|
||
cc_test( | ||
name = "test_int8_accuracy", | ||
srcs = ["test_int8_accuracy.cpp"], | ||
deps = [ | ||
":accuracy_test", | ||
"//tests/accuracy/datasets:cifar10" | ||
], | ||
data = [ | ||
":jit_models", | ||
] | ||
) | ||
|
||
cc_test( | ||
name = "test_fp16_accuracy", | ||
srcs = ["test_fp16_accuracy.cpp"], | ||
deps = [ | ||
":accuracy_test", | ||
"//tests/accuracy/datasets:cifar10" | ||
], | ||
data = [ | ||
":jit_models", | ||
] | ||
) | ||
|
||
cc_test( | ||
name = "test_fp32_accuracy", | ||
srcs = ["test_fp32_accuracy.cpp"], | ||
deps = [ | ||
":accuracy_test", | ||
"//tests/accuracy/datasets:cifar10" | ||
], | ||
data = [ | ||
":jit_models", | ||
] | ||
) | ||
|
||
cc_binary( | ||
name = "test", | ||
srcs = ["test.cpp"], | ||
deps = [ | ||
":accuracy_test", | ||
"//tests/accuracy/datasets:cifar10" | ||
], | ||
data = [ | ||
":jit_models", | ||
] | ||
) | ||
|
||
|
||
cc_library( | ||
name = "accuracy_test", | ||
hdrs = ["accuracy_test.h"], | ||
deps = [ | ||
"//cpp/api:trtorch", | ||
"//tests/util", | ||
"@libtorch//:libtorch", | ||
"@googletest//:gtest_main", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include <utility> | ||
#include "torch/script.h" | ||
#include "gtest/gtest.h" | ||
#include "tests/util/util.h" | ||
#include "trtorch/trtorch.h" | ||
#include "c10/cuda/CUDACachingAllocator.h" | ||
|
||
// TODO: Extend this to support other datasets | ||
class AccuracyTests | ||
: public testing::TestWithParam<std::string> { | ||
public: | ||
void SetUp() override { | ||
auto params = GetParam(); | ||
auto module_path = params; | ||
try { | ||
// Deserialize the ScriptModule from a file using torch::jit::load(). | ||
mod = torch::jit::load(module_path); | ||
} | ||
catch (const c10::Error& e) { | ||
std::cerr << "error loading the model\n"; | ||
return; | ||
} | ||
} | ||
|
||
void TearDown() { | ||
c10::cuda::CUDACachingAllocator::emptyCache(); | ||
} | ||
protected: | ||
torch::jit::script::Module mod; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
|
||
cc_library( | ||
name = "cifar10", | ||
hdrs = [ | ||
"cifar10.h" | ||
], | ||
srcs = [ | ||
"cifar10.cpp" | ||
], | ||
deps = [ | ||
"@libtorch//:libtorch" | ||
], | ||
data = [ | ||
":cifar10_data" | ||
] | ||
|
||
) | ||
|
||
filegroup( | ||
name = "cifar10_data", | ||
srcs = glob(["data/cifar-10-batches-bin/**/*.bin"]) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#include "tests/accuracy/datasets/cifar10.h" | ||
|
||
#include "torch/torch.h" | ||
#include "torch/data/example.h" | ||
#include "torch/types.h" | ||
|
||
#include <iostream> | ||
#include <cstddef> | ||
#include <fstream> | ||
#include <string> | ||
#include <vector> | ||
#include <utility> | ||
#include <sstream> | ||
#include <memory> | ||
|
||
namespace datasets { | ||
namespace { | ||
constexpr const char* kTrainFilenamePrefix = "data_batch_"; | ||
constexpr const uint32_t kNumTrainFiles = 5; | ||
constexpr const char* kTestFilename = "test_batch.bin"; | ||
constexpr const size_t kLabelSize = 1; // B | ||
constexpr const size_t kImageSize = 3072; // B | ||
constexpr const size_t kImageDim = 32; | ||
constexpr const size_t kImageChannels = 3; | ||
constexpr const size_t kBatchSize = 10000; | ||
|
||
std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) { | ||
std::ifstream batch; | ||
batch.open(path, std::ios::in|std::ios::binary|std::ios::ate); | ||
|
||
auto file_size = batch.tellg(); | ||
std::unique_ptr<char[]> buf(new char[file_size]); | ||
|
||
batch.seekg(0, std::ios::beg); | ||
batch.read(buf.get(), file_size); | ||
batch.close(); | ||
|
||
std::vector<uint8_t> labels; | ||
std::vector<torch::Tensor> images; | ||
labels.reserve(kBatchSize); | ||
images.reserve(kBatchSize); | ||
|
||
for (size_t i = 0; i < kBatchSize; i++) { | ||
uint8_t label = buf[i * (kImageSize + kLabelSize)]; | ||
std::vector<uint8_t> image; | ||
image.reserve(kImageSize); | ||
std::copy(&buf[i * (kImageSize + kLabelSize) + 1], &buf[i * (kImageSize + kLabelSize) + kImageSize], std::back_inserter(image)); | ||
labels.push_back(label); | ||
auto image_tensor = torch::from_blob(image.data(), | ||
{kImageChannels, kImageDim, kImageDim}, | ||
torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); | ||
images.push_back(image_tensor); | ||
} | ||
|
||
auto labels_tensor = torch::from_blob(labels.data(), | ||
{kBatchSize}, | ||
torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); | ||
assert(labels_tensor.size(0) == kBatchSize); | ||
|
||
auto images_tensor = torch::stack(images); | ||
assert(images_tensor.size(0) == kBatchSize); | ||
|
||
return std::make_pair(images_tensor, labels_tensor); | ||
} | ||
|
||
std::pair<torch::Tensor, torch::Tensor> read_train_data(const std::string& root) { | ||
std::vector<torch::Tensor> images, targets; | ||
for(uint32_t i = 1; i <= 5; i++) { | ||
std::stringstream ss; | ||
ss << root << '/' << kTrainFilenamePrefix << i << ".bin"; | ||
auto batch = read_batch(ss.str()); | ||
images.push_back(batch.first); | ||
targets.push_back(batch.second); | ||
} | ||
|
||
torch::Tensor image_tensor = std::accumulate(++images.begin(), images.end(), *images.begin(), [&](torch::Tensor a, torch::Tensor b) {return torch::cat({a, b}, 0);}); | ||
torch::Tensor target_tensor = std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) {return torch::cat({a, b}, 0);}); | ||
|
||
return std::make_pair(image_tensor, target_tensor); | ||
} | ||
|
||
std::pair<torch::Tensor, torch::Tensor> read_test_data(const std::string& root) { | ||
std::stringstream ss; | ||
ss << root << '/' << kTestFilename; | ||
return read_batch(ss.str()); | ||
} | ||
} | ||
|
||
CIFAR10::CIFAR10(const std::string& root, Mode mode) | ||
: mode_(mode) { | ||
|
||
std::pair<torch::Tensor, torch::Tensor> data; | ||
if (mode_ == Mode::kTrain) { | ||
data = read_train_data(root); | ||
} else { | ||
data = read_test_data(root); | ||
} | ||
|
||
images_ = std::move(data.first); | ||
targets_ = std::move(data.second); | ||
assert(images_.sizes()[0] == images_.sizes()[0]); | ||
} | ||
|
||
torch::data::Example<> CIFAR10::get(size_t index) { | ||
return {images_[index], targets_[index]}; | ||
} | ||
|
||
c10::optional<size_t> CIFAR10::size() const { | ||
return images_.size(0); | ||
} | ||
|
||
bool CIFAR10::is_train() const noexcept { | ||
return mode_ == Mode::kTrain; | ||
} | ||
|
||
const torch::Tensor& CIFAR10::images() const { | ||
return images_; | ||
} | ||
|
||
const torch::Tensor& CIFAR10::targets() const { | ||
return targets_; | ||
} | ||
|
||
CIFAR10&& CIFAR10::use_subset(int64_t new_size) { | ||
assert(new_size <= images_.sizes()[0]); | ||
images_ = images_.slice(0, 0, new_size); | ||
targets_ = targets_.slice(0, 0, new_size); | ||
return std::move(*this); | ||
} | ||
|
||
} // namespace datasets | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#pragma once | ||
|
||
#include "torch/data/datasets/base.h" | ||
#include "torch/data/example.h" | ||
#include "torch/types.h" | ||
|
||
#include <cstddef> | ||
#include <string> | ||
|
||
namespace datasets { | ||
// The CIFAR10 Dataset | ||
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> { | ||
public: | ||
// The mode in which the dataset is loaded | ||
enum class Mode { kTrain, kTest }; | ||
|
||
// Loads CIFAR10 from un-tarred file | ||
// Dataset can be found https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz | ||
// Root path should be the directory that contains the content of tarball | ||
explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain); | ||
|
||
// Returns the pair at index in the dataset | ||
torch::data::Example<> get(size_t index) override; | ||
|
||
// The size of the dataset | ||
c10::optional<size_t> size() const override; | ||
|
||
// The mode the dataset is in | ||
bool is_train() const noexcept; | ||
|
||
// Returns all images stacked into a single tensor | ||
const torch::Tensor& images() const; | ||
|
||
// Returns all targets stacked into a single tensor | ||
const torch::Tensor& targets() const; | ||
|
||
// Trims the dataset to the first n pairs | ||
CIFAR10&& use_subset(int64_t new_size); | ||
|
||
|
||
private: | ||
Mode mode_; | ||
torch::Tensor images_, targets_; | ||
}; | ||
} // namespace datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#include "accuracy_test.h" | ||
#include "datasets/cifar10.h" | ||
#include "torch/torch.h" | ||
|
||
TEST_P(AccuracyTests, FP16AccuracyIsClose) { | ||
auto eval_dataset = datasets::CIFAR10("tests/accuracy/datasets/data/cifar-10-batches-bin/", datasets::CIFAR10::Mode::kTest) | ||
.use_subset(3200) | ||
.map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, | ||
{0.2023, 0.1994, 0.2010})) | ||
.map(torch::data::transforms::Stack<>()); | ||
auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset), torch::data::DataLoaderOptions() | ||
.batch_size(32) | ||
.workers(2)); | ||
|
||
// Check the FP32 accuracy in JIT | ||
torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA}); | ||
for (auto batch : *eval_dataloader) { | ||
auto images = batch.data.to(torch::kCUDA); | ||
auto targets = batch.target.to(torch::kCUDA); | ||
|
||
auto outputs = mod.forward({images}); | ||
auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); | ||
|
||
jit_total += targets.sizes()[0]; | ||
jit_correct += torch::sum(torch::eq(predictions, targets)); | ||
} | ||
torch::Tensor jit_accuracy = jit_correct / jit_total; | ||
|
||
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}}; | ||
auto extra_info = trtorch::ExtraInfo({input_shape}); | ||
extra_info.op_precision = torch::kF16; | ||
|
||
auto trt_mod = trtorch::CompileGraph(mod, extra_info); | ||
|
||
torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA}); | ||
for (auto batch : *eval_dataloader) { | ||
auto images = batch.data.to(torch::kCUDA).to(torch::kF16); | ||
auto targets = batch.target.to(torch::kCUDA).to(torch::kF16); | ||
|
||
auto outputs = trt_mod.forward({images}); | ||
auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); | ||
predictions = predictions.reshape(predictions.sizes()[0]); | ||
|
||
trt_total += targets.sizes()[0]; | ||
trt_correct += torch::sum(torch::eq(predictions, targets)); | ||
} | ||
|
||
torch::Tensor trt_accuracy = trt_correct / trt_total; | ||
|
||
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3)); | ||
} | ||
|
||
|
||
INSTANTIATE_TEST_SUITE_P(FP16AccuracyIsCloseSuite, | ||
AccuracyTests, | ||
testing::Values("tests/accuracy/vgg16_cifar10.jit.pt")); |
Oops, something went wrong.