diff --git a/README.md b/README.md index 1f383d8..8aac4d9 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV mIOUs and fps on cityscapes val set: | none | ss | ssc | msf | mscf | fps(fp16/fp32) | link | |------|:--:|:---:|:---:|:----:|:---:|:----:| -| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 68/23 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) | -| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 59/21 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) | +| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 78/25 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) | +| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 67/26 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) | mIOUs on cocostuff val2017 set: | none | ss | ssc | msf | mscf | link | diff --git a/tensorrt/CMakeLists.txt b/tensorrt/CMakeLists.txt index 1ccace7..1336e3a 100644 --- a/tensorrt/CMakeLists.txt +++ b/tensorrt/CMakeLists.txt @@ -1,11 +1,13 @@ -CMAKE_MINIMUM_REQUIRED(VERSION 2.8) +CMAKE_MINIMUM_REQUIRED(VERSION 3.17) PROJECT(segment) -set(CMAKE_CXX_FLAGS "-std=c++14 -O1") +set(CMAKE_CXX_FLAGS "-std=c++14 -O2") +set(CMAKE_NVCC_FLAGS "-std=c++14 -O2") link_directories(/usr/local/cuda/lib64) +link_directories(${PROJECT_SOURCE_DIR}/build) # include_directories(/root/build/TensorRT-8.2.5.1/include) # link_directories(/root/build/TensorRT-8.2.5.1/lib) @@ -17,7 +19,8 @@ add_executable(segment segment.cpp trt_dep.cpp) target_include_directories( segment PUBLIC ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS}) target_link_libraries( - segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser + segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser -lkernels ${CUDA_LIBRARIES} - ${OpenCV_LIBRARIES} - ) + ${OpenCV_LIBRARIES}) + +cuda_add_library(kernels STATIC kernels.cu) diff --git a/tensorrt/README.md b/tensorrt/README.md index f35e04b..a0227db 100644 --- a/tensorrt/README.md +++ b/tensorrt/README.md @@ -5,7 +5,7 @@ Firstly, We should export our trained model to onnx model: ``` $ cd BiSeNet/ -$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx +$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx --aux-mode eval ``` **NOTE:** I use cropsize of `1024x2048` here in my example, you should change it according to your specific application. The inference cropsize is fixed from this step on, so you should decide the inference cropsize when you export the model here. diff --git a/tensorrt/kernels.cu b/tensorrt/kernels.cu new file mode 100644 index 0000000..b9bf4de --- /dev/null +++ b/tensorrt/kernels.cu @@ -0,0 +1,158 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "NvInfer.h" + + + +#define BLOCKSIZE 512 + +#define ivpair thrust::pair + + +template +__forceinline__ __device__ void reduce_max(ivpair* sdata, int blocksize, int tid) { + __syncthreads(); + for (int s{blocksize / 2}; s > 0; s >>= 1) { + if (tid < s) { + if (sdata[tid].first < sdata[tid + s].first) { + sdata[tid] = sdata[tid + s]; + } + } + __syncthreads(); + } +} + + +template +__global__ void arg_max_depth(const int n_size, + const int dimsize, const int m_size, + const scalar_t *inten, + int *oten) { + extern __shared__ __align__(sizeof(ivpair)) unsigned char sdata_raw[]; + ivpair *sdata = reinterpret_cast(sdata_raw); + sdata = sdata + blockDim.x * threadIdx.y; + + int sample_offset = gridDim.x * blockDim.y; + int bid = threadIdx.y + blockIdx.x * blockDim.y; + int samplesize = n_size * m_size; + + for (int i{bid}; i < samplesize; i += sample_offset) { + int n_idx = i / m_size; + int m_idx = i % m_size; + + /// NOTE: This is not memory-safe when dimsize < blockDim.x + int idx = n_idx * dimsize * m_size + threadIdx.x * m_size + m_idx; + ivpair maxp = thrust::make_pair(inten[idx], threadIdx.x); + int j = threadIdx.x + blockDim.x; + for (; j < dimsize; j += blockDim.x) { + idx += blockDim.x * m_size; + scalar_t val = inten[idx]; + if (val > maxp.first) { + maxp = thrust::make_pair(val, j); + } + } + sdata[threadIdx.x] = maxp; + __syncthreads(); + reduce_max(sdata, blockDim.x, threadIdx.x); + + idx = n_idx * m_size + m_idx; + oten[idx] = sdata[0].second; + } +} + + +template +__global__ void arg_max_spatial(const int n_size, + const int dimsize, const int m_size, + const scalar_t *inten, + int *oten) { + + int sample_offset = gridDim.x * blockDim.x; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int samplesize = n_size * m_size; + + for (int i{tid}; i < samplesize; i += sample_offset) { + int n_idx = i / m_size; + int m_idx = i % m_size; + + // obtain max + int idx = n_idx * dimsize * m_size + m_idx; + scalar_t max_val = inten[idx]; + int res = 0; + for (int j{1}; j < dimsize; ++j) { + idx += m_size; + scalar_t val = inten[idx]; + if (val > max_val) { + max_val = val; + res = j; + } + } + idx = n_idx * m_size + m_idx; + oten[idx] = res; + } +} + + +void argMaxFunc(const void *inten, + void *oten, const int n_size, + const int dimsize, const int m_size, + cudaStream_t* stream) { + if (inten == nullptr or oten == nullptr) std::abort(); + + int samplesize = n_size * m_size; + int shm_size = 0; + dim3 grid, block; + + if (dimsize <= 256) { + int blockx, gridx; + cudaOccupancyMaxPotentialBlockSize(&gridx, &blockx, + arg_max_spatial, 0, samplesize); + gridx = std::min(4096, gridx << 2); + block.x = blockx; grid.x = gridx; + + if (stream == nullptr) { + arg_max_spatial<<>>( + n_size, dimsize, m_size, + reinterpret_cast(inten), + reinterpret_cast(oten)); + } else { + arg_max_spatial<<>>( + n_size, dimsize, m_size, + reinterpret_cast(inten), + reinterpret_cast(oten)); + } + + } else { + int blockx, blocky, gridx; + shm_size = (sizeof(float) + sizeof(int)) * BLOCKSIZE; + int block_lmt = std::min(BLOCKSIZE, dimsize); + blockx = 32; + while (blockx <= block_lmt) blockx = (blockx << 1); + blockx = (blockx >> 1); // must make sure dimsize > blockx + blocky = BLOCKSIZE / blockx; + gridx = std::min(4096, samplesize / blocky); + block.x = blockx; block.y = blocky; grid.x = gridx; + + if (stream == nullptr) { + arg_max_depth<<>>( + n_size, dimsize, m_size, + reinterpret_cast(inten), + reinterpret_cast(oten)); + } else { + arg_max_depth<<>>( + n_size, dimsize, m_size, + reinterpret_cast(inten), + reinterpret_cast(oten)); + } + } + + +} + diff --git a/tensorrt/kernels.hpp b/tensorrt/kernels.hpp new file mode 100644 index 0000000..db05f6e --- /dev/null +++ b/tensorrt/kernels.hpp @@ -0,0 +1,13 @@ +#ifndef _KERNELS_HPP_ +#define _KERNELS_HPP_ + +#include +#include + + +void argMaxFunc(const void *inten, + void *oten, const int n_size, + const int dimsize, const int m_size, + cudaStream_t* stream); + +#endif diff --git a/tensorrt/segment.cpp b/tensorrt/segment.cpp index 47f86d2..2d00ce7 100644 --- a/tensorrt/segment.cpp +++ b/tensorrt/segment.cpp @@ -102,7 +102,7 @@ void run_with_trt(vector args) { Dims3 o_dims = static_cast( engine->getBindingDimensions(engine->getBindingIndex("preds"))); const int iH{i_dims.d[2]}, iW{i_dims.d[3]}; - const int oH{o_dims.d[1]}, oW{o_dims.d[2]}; + const int oH{o_dims.d[2]}, oW{o_dims.d[3]}; // prepare image and resize Mat im = cv::imread(args[2]); @@ -150,13 +150,13 @@ void run_with_trt(vector args) { ptr[1] = color_map[res[idx]][1]; ptr[2] = color_map[res[idx]][2]; ptr += 3; - ++ idx; + ++idx; } } // resize back and save if ((orgH != oH) || orgW != oW) { - cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_NEAREST); + cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_CUBIC); } cv::imwrite(args[3], pred); diff --git a/tensorrt/segment.py b/tensorrt/segment.py index 44af4f4..32e34d4 100644 --- a/tensorrt/segment.py +++ b/tensorrt/segment.py @@ -143,10 +143,10 @@ def main(): cuda.memcpy_dtoh_async(h_output, d_output, stream) stream.synchronize() - out = palette[h_outputs[0]] - outshape = engine.get_binding_shape(1) - H, W = outshape[1], outshape[2] - out = out.reshape(H, W, 3) + oshape = engine.get_binding_shape(1) + pred = np.argmax(h_outputs[0].reshape(oshape), axis=1) + out = palette[pred] + out = out.reshape(*oshape[2:], 3) out = cv2.resize(out, (orgW, orgH)) cv2.imwrite(args.outpth, out) diff --git a/tensorrt/trt_dep.cpp b/tensorrt/trt_dep.cpp index 355f4af..905bed9 100644 --- a/tensorrt/trt_dep.cpp +++ b/tensorrt/trt_dep.cpp @@ -4,10 +4,12 @@ #include #include #include +#include #include #include #include "trt_dep.hpp" +#include "kernels.hpp" using nvinfer1::IHostMemory; @@ -43,7 +45,7 @@ TrtSharedEnginePtr shared_engine_ptr(ICudaEngine* ptr) { TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) { unsigned int maxBatchSize{1}; - int memory_limit = 1U << 30; // 1G + long memory_limit = 1UL << 32; // 4G auto builder = TrtUnqPtr(nvinfer1::createInferBuilder(gLogger)); if (!builder) { @@ -86,7 +88,8 @@ TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) { } auto output = network->getOutput(0); - output->setType(nvinfer1::DataType::kINT32); + // output->setType(nvinfer1::DataType::kINT32); + output->setType(nvinfer1::DataType::kFLOAT); cout << " start to build \n"; CudaStreamUnqPtr stream(new cudaStream_t); @@ -154,7 +157,7 @@ TrtSharedEnginePtr deserialize(string serpth) { auto runtime = TrtUnqPtr(nvinfer1::createInferRuntime(gLogger)); TrtSharedEnginePtr engine = shared_engine_ptr( - runtime->deserializeCudaEngine((void*)&buf[0], mdsize, nullptr)); + runtime->deserializeCudaEngine((void*)&buf[0], mdsize)); return engine; } @@ -163,10 +166,12 @@ vector infer_with_engine(TrtSharedEnginePtr engine, vector& data) { Dims3 out_dims = static_cast( engine->getBindingDimensions(engine->getBindingIndex("preds"))); - const int batchsize{1}, H{out_dims.d[1]}, W{out_dims.d[2]}; + const int batchsize{1}, H{out_dims.d[2]}, W{out_dims.d[3]}; + const int n_classes{out_dims.d[1]}; const int in_size{static_cast(data.size())}; + const int logits_size{batchsize * n_classes * H * W}; const int out_size{batchsize * H * W}; - vector buffs(2); + vector buffs(3); vector res(out_size); auto context = TrtUnqPtr(engine->createExecutionContext()); @@ -181,7 +186,12 @@ vector infer_with_engine(TrtSharedEnginePtr engine, vector& data) { cout << "allocate memory failed\n"; std::abort(); } - state = cudaMalloc(&buffs[1], out_size * sizeof(int)); + state = cudaMalloc(&buffs[1], logits_size * sizeof(float)); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + state = cudaMalloc(&buffs[2], out_size * sizeof(int)); if (state) { cout << "allocate memory failed\n"; std::abort(); @@ -199,19 +209,24 @@ vector infer_with_engine(TrtSharedEnginePtr engine, vector& data) { cout << "transmit to device failed\n"; std::abort(); } + context->enqueueV2(&buffs[0], *stream, nullptr); // context->enqueue(1, &buffs[0], stream, nullptr); + argMaxFunc(buffs[1], buffs[2], batchsize, n_classes, H * W, stream.get()); + state = cudaMemcpyAsync( - &res[0], buffs[1], out_size * sizeof(int), + &res[0], buffs[2], out_size * sizeof(int), cudaMemcpyDeviceToHost, *stream); if (state) { cout << "transmit to host failed \n"; std::abort(); } + cudaStreamSynchronize(*stream); cudaFree(buffs[0]); cudaFree(buffs[1]); + cudaFree(buffs[2]); return res; } @@ -222,10 +237,13 @@ void test_fps_with_engine(TrtSharedEnginePtr engine) { engine->getBindingDimensions(engine->getBindingIndex("input_image"))); Dims3 out_dims = static_cast( engine->getBindingDimensions(engine->getBindingIndex("preds"))); + const int batchsize{1}; - const int oH{out_dims.d[1]}, oW{out_dims.d[2]}; + const int oH{out_dims.d[2]}, oW{out_dims.d[3]}; + const int n_classes{out_dims.d[1]}; const int iH{in_dims.d[2]}, iW{in_dims.d[3]}; const int in_size{batchsize * 3 * iH * iW}; + const int logits_size{batchsize * n_classes * oH * oW}; const int out_size{batchsize * oH * oW}; auto context = TrtUnqPtr(engine->createExecutionContext()); @@ -234,14 +252,19 @@ void test_fps_with_engine(TrtSharedEnginePtr engine) { std::abort(); } - vector buffs(2); + vector buffs(3); cudaError_t state; state = cudaMalloc(&buffs[0], in_size * sizeof(float)); if (state) { cout << "allocate memory failed\n"; std::abort(); } - state = cudaMalloc(&buffs[1], out_size * sizeof(int)); + state = cudaMalloc(&buffs[1], logits_size * sizeof(float)); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + state = cudaMalloc(&buffs[2], out_size * sizeof(int)); if (state) { cout << "allocate memory failed\n"; std::abort(); @@ -253,6 +276,7 @@ void test_fps_with_engine(TrtSharedEnginePtr engine) { for (int i{0}; i < n_loops; ++i) { // context->execute(1, &buffs[0]); context->executeV2(&buffs[0]); + argMaxFunc(buffs[1], buffs[2], batchsize, n_classes, oH * oW, nullptr); } auto end = std::chrono::steady_clock::now(); double duration = std::chrono::duration(end - start).count(); @@ -261,7 +285,9 @@ void test_fps_with_engine(TrtSharedEnginePtr engine) { << duration << "s" << endl; cout << "fps is: " << static_cast(n_loops) / duration << endl; + cudaFree(buffs[0]); cudaFree(buffs[1]); + cudaFree(buffs[2]); }