Skip to content

Commit

Permalink
feat(//core/quantization): skeleton of INT8 PTQ calibrator
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 3, 2020
1 parent aef6003 commit dd443a6
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 11 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ experiments/
py/build/
py/tmp/
py/.eggs
.vscode/
.vscode/
.DS_Store
._DS_Store
15 changes: 10 additions & 5 deletions core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
<< "\n Max Workspace Size: " << s.workspace_size \
<< "\n Device Type: " << s.device \
<< "\n Engine Capability: " << s.capability;
<< "\n Engine Capability: " << s.capability \
<< "\n Calibrator Created: " << s.calibrator ? true : false;
return os;
}

Expand All @@ -36,13 +37,17 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)

switch(settings.op_precision) {
case nvinfer1::DataType::kHALF:
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does support FP16");
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
input_type = nvinfer1::DataType::kHALF;
break;
// case nvinfer1::DataType::kINT8:
// cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
// input_type = nvinfer1::DataType::kFLOAT;
// break;
case nvinfer1::DataType::kINT8:
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8");
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
input_type = nvinfer1::DataType::kINT8;
// If the calibrator is nullptr then TRT will use default quantization
cfg->setInt8Calibrator(settings.calibrator);
break;
case nvinfer1::DataType::kFLOAT:
default:
input_type = nvinfer1::DataType::kFLOAT;
Expand Down
1 change: 1 addition & 0 deletions core/conversion/conversionctx/ConversionCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct BuilderSettings {
bool allow_gpu_fallback = true;
nvinfer1::DeviceType device = nvinfer1::DeviceType::kGPU;
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
nvinfer1::IInt8Calibrator* calibrator = nullptr;
uint64_t num_min_timing_iters = 2;
uint64_t num_avg_timing_iters = 1;
uint64_t workspace_size = 0;
Expand Down
Empty file added core/quantization/BUILD
Empty file.
64 changes: 64 additions & 0 deletions core/quantization/TRTEntropyCalibrator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "core/util/prelude.h"
#include "core/quantization/quantization.h"

namespace trtorch {
namespace core {
namespace quantization {

Int8CalibratorImpl::Int8CalibratorImpl(QuantizationSettings&& settings)
: dataset_(std::move(settings.calibration_dataset),
cache_file_path_(settings.calibration_cache_file),
use_cache_(settings.use_cache) {
buffers_.reserve(dataset_.size);

}

int Int8CalibratorImpl::GetBatchSize() const {

}

bool Int8CalibratorImpl::GetBatch(void* bindings[], const char* names[], int num_bindings) {
if (!is_next_batch) {
return false;
}

for (size_t i = 0; i < num_bindings; i++) {
auto batch = next_binding_batch(names[i]);
batch = batch.to(at::kCUDA).contiguous();
bindings[i] = batch.data_ptr();
}
return true;
}

const void* Int8CalibratorImpl::ReadCalibrationCache(size_t& length) {
cache_.clear();
std::ifstream cache_file(cache_file_path_, std::ios::binary);
cache_file >> std::noskipws;
if (use_cache && cache_file.good()) {
std::copy(std::istream_iterator<char>(input),
std::istream_iterator<char>(),
std::back_inserter(cache_));
}
cache_size_ = cache_.size();
return cache_size ? cache_.data() : nullptr;
}

void Int8CalibratorImpl::WriteCalibrationCache(const void* cache, size_t length) {
std::ofstream cache_file(cache_file_path_, std::ios::binary);
cache_file.write(reinterpret_cast<const char*>(cache_), cache_size_);
}

nvinfer1::IInt8Calibrator create_int8_calibrator(QuantizationSettings settings) {
auto calibrator_impl = Int8CalibratorImpl(settings);
switch(settings.calibrator_type) {
case CalibratorKind::kMinMax:
return TRTInt8MinMaxCalibrator(std::move(calibrator_impl));
case CalibratorKind::kEntropy:
default:
return TRTInt8EntropyCalibrator(std::move(calibrator_impl));
}
}

} // quantization
} // core
} // trtorch
69 changes: 69 additions & 0 deletions core/quantization/quantization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#pragma once
#include "ATen/tensor.h"
#include "NvInfer.h"

namespace trtorch {
namespace core {
namespace quantization {

enum class CalibratorKind {
kEntropy,
kMinMax,
}

in conveter or whatever
in order given std::vector<at::Tensor> -> map<input_name, at::Tensor>

struct QuantizationSettings {
CalibratorKind calibrator_type = CalibratorKind::kEntropy;
const std::string& calibration_cache_file = "";
bool use_cache = false;
std::unordered_map<std::string, at::Tensor> calibration_dataset;
};

class CalibrationBatchStream {

};

class Int8CalibratorImpl {
public:
TRTInt8CalibratorImpl(QuantizationSettings& settings);
int GetBatchSize() const;
bool GetBatch(void* bindings[], const char* names[], int num_bindings);
const void* ReadCalibrationCache(size_t& length);
void WriteCalibrationCache(const void* cache, size_t length);
private:
std::unordered_map<std::string, at::Tensor> dataset_;
const std::string& cache_file_path_;
std::vector<char> cache_;
bool use_cache_;
size_t cache_size_ = 0;
};

class TRTInt8EntropyCalibrator : nvinfer1::IInt8EntropyCalibrator2 {
public:
TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(impl) {}
int getBatchSize() const override {return impl_.GetBatchSize();}
bool getBatch(void* bindings[], const char* names[], int nbBindings) override {return impl_.GetBatch(bindings, names, nbBindings)};
const void* readCalibrationCache(size_t& length) override {return impl_.ReadCalibrationCache(size_t& length)};
void writeCalibrationCache(const void* cache, size_t length) override {impl_.WriteCalibrationCache(const void* cache, size_t length)};
private:
Int8CalibratorImpl impl_;
};

class TRTInt8MinMaxCalibrator : nvinfer1::IInt8MinMaxCalibrator {
public:
TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(impl) {}
int getBatchSize() const override {return impl_.GetBatchSize();}
bool getBatch(void* bindings[], const char* names[], int nbBindings) override {return impl_.GetBatch(bindings, names, nbBindings)};
const void* readCalibrationCache(size_t& length) override {return impl_.ReadCalibrationCache(size_t& length)};
void writeCalibrationCache(const void* cache, size_t length) override {impl_.WriteCalibrationCache(const void* cache, size_t length)};
private:
Int8CalibratorImpl impl_;
};

nvinfer1::IInt8Calibrator create_int8_calibrator(QuantizationSettings settings);

} // quantization
} // core
} // trtorch
12 changes: 12 additions & 0 deletions cpp/ptq/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package(default_visibility = ["//visibility:public"])

cc_binary(
name = "ptq",
srcs = [
"main.cpp"
],
deps = [
"@libtorch//:libtorch",
"//cpp/api:trtorch"
],
)
21 changes: 21 additions & 0 deletions cpp/ptq/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# ptq

This is a short example application that shows how to use TRTorch to perform post-training quantization for a module.

## Compilation

``` shell
bazel build //cpp/ptq --cxxopt="-DNDEBUG"
```

If you want insight into what is going under the hood or need debug symbols

``` shell
bazel build //cpp/ptq --compilation_mode=dbg
```

## Usage

``` shell
ptq
```
36 changes: 36 additions & 0 deletions cpp/ptq/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "torch/script.h"
#include "torch/csrc/api/include/torch/data/datasets/mnist.h"
#include "trtorch/trtorch.h"

#include <iostream>
#include <sstream>
#include <memory>

int main(int argc, const char* argv[]) {
if (argc < 3) {
std::cerr << "usage: ptq <path-to-module> <path-to-mnist>\n";
return -1;
}

torch::jit::script::Module mod;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
mod = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}

const std::string data_dir = std::string(argv[2]);
auto calibration_dataset = torch::data::datasets::MNIST(data_dir, torch::data::datasets::MNIST::Mode::kTest)
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
.map(torch::data::transforms::Stack<>());
auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset), torch::data::DataLoaderOptions()
.batch_size(32)
.workers(1))

for (auto batch : batched_calibration_dataset) {
std::cout << batch.data().sizes() << std::endl;
}
}
10 changes: 5 additions & 5 deletions cpp/trtorchexec/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
}
std::cout << "Max Difference: " << diff.abs().max().item<float>() << std::endl;
return diff.abs().max().item<float>() <= 2e-6 * maxValue;
return diff.abs().max().item<float>() <= 2e-5 * maxValue;
}

bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
Expand All @@ -25,8 +25,8 @@ int main(int argc, const char* argv[]) {
<< " trtorchexec <path-to-exported-script-module> <min-input-size> <opt-input-size> <max-input-size>\n";
return -1;
}


torch::jit::script::Module mod;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
Expand All @@ -38,7 +38,7 @@ int main(int argc, const char* argv[]) {
}

mod.to(at::kCUDA);

std::vector<std::vector<int64_t>> dims;
for (int i = 2; i < argc; i++) {
auto arg = std::string(argv[i]);
Expand Down Expand Up @@ -74,7 +74,7 @@ int main(int argc, const char* argv[]) {
torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
std::vector<at::Tensor> jit_results;
jit_results.push_back(jit_results_ivalues.toTensor());

auto trt_mod = trtorch::CompileGraph(mod, dims);
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
Expand Down

0 comments on commit dd443a6

Please sign in to comment.