Skip to content

Commit

Permalink
add int8 inference for trt
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed Oct 7, 2022
1 parent 7a6f123 commit 55b877b
Show file tree
Hide file tree
Showing 13 changed files with 459 additions and 50 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV


mIOUs and fps on cityscapes val set:
| none | ss | ssc | msf | mscf | fps(fp16/fp32) | link |
| none | ss | ssc | msf | mscf | fps(fp32/fp16/int8) | link |
|------|:--:|:---:|:---:|:----:|:---:|:----:|
| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 78/25 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 67/26 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |
| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 25/78/141 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 26/67/95 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |

mIOUs on cocostuff val2017 set:
| none | ss | ssc | msf | mscf | link |
Expand Down
4 changes: 3 additions & 1 deletion tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ link_directories(/usr/local/cuda/lib64)
link_directories(${PROJECT_SOURCE_DIR}/build)
# include_directories(/root/build/TensorRT-8.2.5.1/include)
# link_directories(/root/build/TensorRT-8.2.5.1/lib)
include_directories(/data/zzy/trt_install/TensorRT-8.2.5.1/include)
link_directories(/data/zzy/trt_install/TensorRT-8.2.5.1/lib)


find_package(CUDA REQUIRED)
find_package(OpenCV REQUIRED)

cuda_add_library(kernels STATIC kernels.cu)

add_executable(segment segment.cpp trt_dep.cpp)
add_executable(segment segment.cpp trt_dep.cpp read_img.cpp)
target_include_directories(
segment PUBLIC ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS})
target_link_libraries(
Expand Down
12 changes: 11 additions & 1 deletion tensorrt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,22 @@ This would generate a `./segment` in the `tensorrt/build` directory.


#### 3. Convert onnx to tensorrt model
If you can successfully compile the source code, you can parse the onnx model to tensorrt model like this:
If you can successfully compile the source code, you can parse the onnx model to tensorrt model with one of the following commands.
For fp32, command is:
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt
```
If your gpu support acceleration with fp16 inferenece, you can add a `--fp16` option to in this step:
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp16
```
Building an int8 engine is also supported. Firstly, you should make sure your gpu support int8 inference, or you model will not be faster than fp16/fp32. Then you should prepare certain amount of images for int8 calibration. In this example, I use train set of cityscapes for calibration. The command is like this:
```
$ calibrate_int8 # delete this if exists
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --int8 /path/to/BiSeNet/datasets/cityscapes /path/to/BiSeNet/datasets/cityscapes/train.txt
```
With the above commands, we will have an tensorrt engine named `saved_model.trt` generated.

Note that I use the simplest method to parse the command line args, so please do **Not** change the order of the args in above command.


Expand Down Expand Up @@ -74,6 +82,8 @@ Likewise, you do not need to worry about this anymore with version newer than 7.

4. On my platform, after compiling with tensorrt, the model size of bisenetv1 is 29Mb(fp16) and 128Mb(fp32), and the size of bisenetv2 is 16Mb(fp16) and 42Mb(fp32). However, the fps of bisenetv1 is 68(fp16) and 23(fp32), while the fps of bisenetv2 is 59(fp16) and 21(fp32). It is obvious that bisenetv2 has fewer parameters than bisenetv1, but the speed is otherwise. I am not sure whether it is because tensorrt has worse optimization strategy in some ops used in bisenetv2(such as depthwise convolution) or because of the limitation of the gpu on different ops. Please tell me if you have better idea on this.

5. int8 mode is not always greatly faster than fp16 mode. For example, I tested with bisenetv1-cityscapes and tensorrt 8.2.5.1. With v100 gpu and driver 515.65, the fp16/int8 fps is 185.89/186.85, while with t4 gpu and driver 450.80, it is 78.77/142.31.


### Using python

Expand Down
148 changes: 148 additions & 0 deletions tensorrt/batch_stream.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@

#ifndef BATCH_STREAM_HPP
#define BATCH_STREAM_HPP


#include <string>
#include <sstream>
#include <fstream>
#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#include <opencv2/opencv.hpp>

#include "NvInfer.h"
#include "read_img.hpp"

using nvinfer1::Dims;
using nvinfer1::Dims3;
using nvinfer1::Dims4;


class IBatchStream
{
public:
virtual void reset(int firstBatch) = 0;
virtual bool next() = 0;
virtual void skip(int skipCount) = 0;
virtual float* getBatch() = 0;
virtual int getBatchesRead() const = 0;
virtual int getBatchSize() const = 0;
virtual nvinfer1::Dims4 getDims() const = 0;
};


class BatchStream : public IBatchStream
{
public:
BatchStream(int batchSize, int maxBatches, Dims indim,
const std::string& dataRoot,
const std::string& dataFile)
: mBatchSize{batchSize}
, mMaxBatches{maxBatches}
{
mDims = Dims3(indim.d[1], indim.d[2], indim.d[3]);

readDataFile(dataFile, dataRoot);
mSampleSize = std::accumulate(
mDims.d, mDims.d + mDims.nbDims, 1, std::multiplies<int64_t>()) * sizeof(float);
mData.resize(mSampleSize * mBatchSize);
}

void reset(int firstBatch) override
{
cout << "mBatchCount: " << mBatchCount << endl;
mBatchCount = firstBatch;
}

bool next() override
{
if (mBatchCount >= mMaxBatches)
{
return false;
}
++mBatchCount;
return true;
}

void skip(int skipCount) override
{
mBatchCount += skipCount;
}

float* getBatch() override
{
int offset = mBatchCount * mBatchSize;
for (int i{0}; i < mBatchSize; ++i) {
int ind = offset + i;
read_data(mPaths[ind], &mData[i * mSampleSize], mDims.d[1], mDims.d[2]);
}
return mData.data();
}

int getBatchesRead() const override
{
return mBatchCount;
}

int getBatchSize() const override
{
return mBatchSize;
}

nvinfer1::Dims4 getDims() const override
{
return Dims4{mBatchSize, mDims.d[0], mDims.d[1], mDims.d[2]};
}

private:
void readDataFile(const std::string& dataFilePath, const std::string& dataRootPath)
{
std::ifstream file(dataFilePath, std::ios::in);
if (!file.is_open()) {
cout << "file open failed: " << dataFilePath << endl;
std::abort();
}
std::stringstream ss;
file >> ss.rdbuf();
file.close();

std::string impth;
int n_imgs = 0;
while (std::getline(ss, impth)) ++n_imgs;
ss.clear(); ss.seekg(0, std::ios::beg);
if (n_imgs <= 0) {
cout << "ann file is empty, cannot read image paths for int8 calibration: "
<< dataFilePath << endl;
std::abort();
}

mPaths.resize(n_imgs);
for (int i{0}; i < n_imgs; ++i) {
std::getline(ss, impth, ',');
mPaths[i] = dataRootPath + "/" + impth;
std::getline(ss, impth);
}
if (mMaxBatches < 0) {
mMaxBatches = n_imgs / mBatchSize - 1;
}
if (mMaxBatches <= 0) {
cout << "must have at least 1 batch for calibration\n";
std::abort();
}
cout << "mMaxBatches = " << mMaxBatches << endl;
}


int mBatchSize{0};
int mBatchCount{0};
int mMaxBatches{0};
Dims3 mDims{};
std::vector<string> mPaths;
std::vector<float> mData;
int mSampleSize{0};
};


#endif
160 changes: 160 additions & 0 deletions tensorrt/entropy_calibrator.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ENTROPY_CALIBRATOR_HPP
#define ENTROPY_CALIBRATOR_HPP

#include <algorithm>
#include <numeric>
#include <iterator>
#include "NvInfer.h"

//! \class EntropyCalibratorImpl
//!
//! \brief Implements common functionality for Entropy calibrators.
//!
template <typename TBatchStream>
class EntropyCalibratorImpl
{
public:
EntropyCalibratorImpl(
TBatchStream stream, int firstBatch, std::string cal_table_name, const char* inputBlobName, bool readCache = true)
: mStream{stream}
, mCalibrationTableName(cal_table_name)
, mInputBlobName(inputBlobName)
, mReadCache(readCache)
{
nvinfer1::Dims4 dims = mStream.getDims();
mInputCount = std::accumulate(
dims.d, dims.d + dims.nbDims, 1, std::multiplies<int64_t>());
cout << "dims.nbDims: " << dims.nbDims << endl;
for (int i{0}; i < dims.nbDims; ++i) {
cout << dims.d[i] << ", ";
}
cout << endl;

cudaError_t state;
state = cudaMalloc(&mDeviceInput, mInputCount * sizeof(float));
if (state) {
cout << "allocate memory failed\n";
std::abort();
}
cout << "mInputCount: " << mInputCount << endl;
mStream.reset(firstBatch);
}

virtual ~EntropyCalibratorImpl()
{
cudaError_t state;
state = cudaFree(mDeviceInput);
if (state) {
cout << "free memory failed\n";
std::abort();
}
}

int getBatchSize() const
{
return mStream.getBatchSize();
}

bool getBatch(void* bindings[], const char* names[], int nbBindings)
{
if (!mStream.next())
{
return false;
}
cudaError_t state;
state = cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice);
if (state) {
cout << "allocate memory failed\n";
std::abort();
}
assert(!strcmp(names[0], mInputBlobName));
bindings[0] = mDeviceInput;
return true;
}

const void* readCalibrationCache(size_t& length)
{
mCalibrationCache.clear();
std::ifstream input(mCalibrationTableName, std::ios::binary);
input >> std::noskipws;
if (mReadCache && input.good())
{
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
std::back_inserter(mCalibrationCache));
}
length = mCalibrationCache.size();
return length ? mCalibrationCache.data() : nullptr;
}

void writeCalibrationCache(const void* cache, size_t length)
{
std::ofstream output(mCalibrationTableName, std::ios::binary);
output.write(reinterpret_cast<const char*>(cache), length);
}

private:
TBatchStream mStream;
size_t mInputCount;
std::string mCalibrationTableName;
const char* mInputBlobName;
bool mReadCache{true};
void* mDeviceInput{nullptr};
std::vector<char> mCalibrationCache;
};

//! \class Int8EntropyCalibrator2
//!
//! \brief Implements Entropy calibrator 2.
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
//!
template <typename TBatchStream>
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2
{
public:
Int8EntropyCalibrator2(
TBatchStream stream, int firstBatch, const char* networkName, const char* inputBlobName, bool readCache = true)
: mImpl(stream, firstBatch, networkName, inputBlobName, readCache)
{
}

int getBatchSize() const noexcept override
{
return mImpl.getBatchSize();
}

bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override
{
return mImpl.getBatch(bindings, names, nbBindings);
}

const void* readCalibrationCache(size_t& length) noexcept override
{
return mImpl.readCalibrationCache(length);
}

void writeCalibrationCache(const void* cache, size_t length) noexcept override
{
mImpl.writeCalibrationCache(cache, length);
}

private:
EntropyCalibratorImpl<TBatchStream> mImpl;
};

#endif // ENTROPY_CALIBRATOR_H
Loading

0 comments on commit 55b877b

Please sign in to comment.