Commit: write argmax
CoinCheung committed Aug 11, 2022
1 parent 6b9e2cf commit c4e57f5
Showing 8 changed files with 225 additions and 25 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -6,8 +6,8 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV
mIOUs and fps on cityscapes val set:
| none | ss | ssc | msf | mscf | fps(fp16/fp32) | link |
|------|:--:|:---:|:---:|:----:|:---:|:----:|
-| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 68/23 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
-| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 59/21 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |
+| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 78/25 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
+| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 67/26 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |

mIOUs on cocostuff val2017 set:
| none | ss | ssc | msf | mscf | link |
tensorrt/CMakeLists.txt (13 changes: 8 additions & 5 deletions)
@@ -1,11 +1,13 @@
-CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
+CMAKE_MINIMUM_REQUIRED(VERSION 3.17)

PROJECT(segment)

-set(CMAKE_CXX_FLAGS "-std=c++14 -O1")
+set(CMAKE_CXX_FLAGS "-std=c++14 -O2")
+set(CMAKE_NVCC_FLAGS "-std=c++14 -O2")


link_directories(/usr/local/cuda/lib64)
+link_directories(${PROJECT_SOURCE_DIR}/build)
# include_directories(/root/build/TensorRT-8.2.5.1/include)
# link_directories(/root/build/TensorRT-8.2.5.1/lib)

@@ -17,7 +19,8 @@ add_executable(segment segment.cpp trt_dep.cpp)
target_include_directories(
segment PUBLIC ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS})
target_link_libraries(
-    segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser
+    segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser -lkernels
    ${CUDA_LIBRARIES}
-    ${OpenCV_LIBRARIES}
-    )
+    ${OpenCV_LIBRARIES})
+
+cuda_add_library(kernels STATIC kernels.cu)
tensorrt/README.md (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@
First, we should export our trained model to an ONNX model:
```
$ cd BiSeNet/
-$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx
+$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx --aux-mode eval
```

**NOTE:** I use a crop size of `1024x2048` in this example; you should change it according to your specific application. The inference crop size is fixed from this step on, so decide on the inference crop size when you export the model here.
tensorrt/kernels.cu (158 changes: 158 additions & 0 deletions)
@@ -0,0 +1,158 @@

#include <iostream>
#include <functional>
#include <algorithm>
#include <cfloat>
#include <thrust/pair.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "NvInfer.h"



#define BLOCKSIZE 512

#define ivpair thrust::pair<scalar_t, int>


template<typename scalar_t>
__forceinline__ __device__ void reduce_max(ivpair* sdata, int blocksize, int tid) {
    __syncthreads();
    for (int s{blocksize / 2}; s > 0; s >>= 1) {
        if (tid < s) {
            if (sdata[tid].first < sdata[tid + s].first) {
                sdata[tid] = sdata[tid + s];
            }
        }
        __syncthreads();
    }
}


template<typename scalar_t>
__global__ void arg_max_depth(const int n_size,
                              const int dimsize, const int m_size,
                              const scalar_t *inten,
                              int *oten) {
    extern __shared__ __align__(sizeof(ivpair)) unsigned char sdata_raw[];
    ivpair *sdata = reinterpret_cast<ivpair*>(sdata_raw);
    sdata = sdata + blockDim.x * threadIdx.y;

    int sample_offset = gridDim.x * blockDim.y;
    int bid = threadIdx.y + blockIdx.x * blockDim.y;
    int samplesize = n_size * m_size;

    for (int i{bid}; i < samplesize; i += sample_offset) {
        int n_idx = i / m_size;
        int m_idx = i % m_size;

        /// NOTE: This is not memory-safe when dimsize < blockDim.x
        int idx = n_idx * dimsize * m_size + threadIdx.x * m_size + m_idx;
        ivpair maxp = thrust::make_pair(inten[idx], threadIdx.x);
        int j = threadIdx.x + blockDim.x;
        for (; j < dimsize; j += blockDim.x) {
            idx += blockDim.x * m_size;
            scalar_t val = inten[idx];
            if (val > maxp.first) {
                maxp = thrust::make_pair(val, j);
            }
        }
        sdata[threadIdx.x] = maxp;
        __syncthreads();
        reduce_max(sdata, blockDim.x, threadIdx.x);

        idx = n_idx * m_size + m_idx;
        oten[idx] = sdata[0].second;
    }
}


template<typename scalar_t>
__global__ void arg_max_spatial(const int n_size,
                                const int dimsize, const int m_size,
                                const scalar_t *inten,
                                int *oten) {

    int sample_offset = gridDim.x * blockDim.x;
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int samplesize = n_size * m_size;

    for (int i{tid}; i < samplesize; i += sample_offset) {
        int n_idx = i / m_size;
        int m_idx = i % m_size;

        // obtain max
        int idx = n_idx * dimsize * m_size + m_idx;
        scalar_t max_val = inten[idx];
        int res = 0;
        for (int j{1}; j < dimsize; ++j) {
            idx += m_size;
            scalar_t val = inten[idx];
            if (val > max_val) {
                max_val = val;
                res = j;
            }
        }
        idx = n_idx * m_size + m_idx;
        oten[idx] = res;
    }
}


void argMaxFunc(const void *inten,
                void *oten, const int n_size,
                const int dimsize, const int m_size,
                cudaStream_t* stream) {
    if (inten == nullptr or oten == nullptr) std::abort();

    int samplesize = n_size * m_size;
    int shm_size = 0;
    dim3 grid, block;

    if (dimsize <= 256) {
        int blockx, gridx;
        cudaOccupancyMaxPotentialBlockSize(&gridx, &blockx,
                arg_max_spatial<float>, 0, samplesize);
        gridx = std::min(4096, gridx << 2);
        block.x = blockx; grid.x = gridx;

        if (stream == nullptr) {
            arg_max_spatial<float><<<grid, block, shm_size>>>(
                    n_size, dimsize, m_size,
                    reinterpret_cast<const float*>(inten),
                    reinterpret_cast<int*>(oten));
        } else {
            arg_max_spatial<float><<<grid, block, shm_size, *stream>>>(
                    n_size, dimsize, m_size,
                    reinterpret_cast<const float*>(inten),
                    reinterpret_cast<int*>(oten));
        }

    } else {
        int blockx, blocky, gridx;
        shm_size = (sizeof(float) + sizeof(int)) * BLOCKSIZE;
        int block_lmt = std::min(BLOCKSIZE, dimsize);
        blockx = 32;
        while (blockx <= block_lmt) blockx = (blockx << 1);
        blockx = (blockx >> 1); // must make sure dimsize > blockx
        blocky = BLOCKSIZE / blockx;
        gridx = std::min(4096, samplesize / blocky);
        block.x = blockx; block.y = blocky; grid.x = gridx;

        if (stream == nullptr) {
            arg_max_depth<float><<<grid, block, shm_size>>>(
                    n_size, dimsize, m_size,
                    reinterpret_cast<const float*>(inten),
                    reinterpret_cast<int*>(oten));
        } else {
            arg_max_depth<float><<<grid, block, shm_size, *stream>>>(
                    n_size, dimsize, m_size,
                    reinterpret_cast<const float*>(inten),
                    reinterpret_cast<int*>(oten));
        }
    }
}
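
Both kernels above implement the same contract: the input is read as an `n_size x dimsize x m_size` float tensor and, for every `(n, m)` position, the index of the maximum value along the `dimsize` axis is written to an `n_size x m_size` int tensor. A minimal CPU reference of that contract (illustrative only, not part of the commit) is handy for checking the kernel output:

```
// CPU reference for argMaxFunc: `in` is n_size x dimsize x m_size floats,
// `out` receives n_size x m_size ints holding the argmax over the middle axis.
// On exact ties this keeps the smallest index; the shared-memory reduction in
// arg_max_depth is not guaranteed to break ties the same way.
static void argMaxRef(const float* in, int* out,
                      int n_size, int dimsize, int m_size) {
    for (int n = 0; n < n_size; ++n) {
        for (int m = 0; m < m_size; ++m) {
            int best = 0;
            float best_val = in[(n * dimsize + 0) * m_size + m];
            for (int d = 1; d < dimsize; ++d) {
                float v = in[(n * dimsize + d) * m_size + m];
                if (v > best_val) { best_val = v; best = d; }
            }
            out[n * m_size + m] = best;
        }
    }
}
```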

tensorrt/kernels.hpp (13 changes: 13 additions & 0 deletions)
@@ -0,0 +1,13 @@
#ifndef _KERNELS_HPP_
#define _KERNELS_HPP_

#include <cuda.h>
#include <cuda_runtime.h>


void argMaxFunc(const void *inten,
                void *oten, const int n_size,
                const int dimsize, const int m_size,
                cudaStream_t* stream);

#endif
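
A minimal usage sketch for this interface, assuming class scores already on the GPU laid out as `N x C x H x W` floats; the wrapper name and buffer names are illustrative, not part of the commit:

```
#include <cuda_runtime.h>
#include "kernels.hpp"

// Sketch: argmax over the class axis of N x C x H x W scores on the device.
// d_scores and d_labels are assumed to be valid device buffers holding
// N*C*H*W floats and N*H*W ints respectively.
void run_argmax(const float* d_scores, int* d_labels,
                int N, int C, int H, int W) {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // n_size = batch, dimsize = class axis, m_size = flattened spatial size
    argMaxFunc(d_scores, d_labels, N, C, H * W, &stream);

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
}
```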
tensorrt/segment.cpp (6 changes: 3 additions & 3 deletions)
@@ -102,7 +102,7 @@ void run_with_trt(vector<string> args) {
Dims3 o_dims = static_cast<Dims3&&>(
engine->getBindingDimensions(engine->getBindingIndex("preds")));
const int iH{i_dims.d[2]}, iW{i_dims.d[3]};
-const int oH{o_dims.d[1]}, oW{o_dims.d[2]};
+const int oH{o_dims.d[2]}, oW{o_dims.d[3]};

// prepare image and resize
Mat im = cv::imread(args[2]);
@@ -150,13 +150,13 @@ void run_with_trt(vector<string> args) {
ptr[1] = color_map[res[idx]][1];
ptr[2] = color_map[res[idx]][2];
ptr += 3;
-    ++ idx;
+    ++idx;
}
}

// resize back and save
if ((orgH != oH) || orgW != oW) {
-cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_NEAREST);
+cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_CUBIC);
}
cv::imwrite(args[3], pred);

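Reading `oH`/`oW` from `d[2]`/`d[3]` suggests the `preds` binding is now 4-D, presumably `(N, C, H, W)` class scores rather than a pre-computed label map. Under that assumption, a hedged sketch of how those dimensions would feed the new launcher; the wrapper name and the `d_scores`/`d_labels` buffers are illustrative:

```
#include "NvInfer.h"
#include "kernels.hpp"

// Sketch: derive argMaxFunc's arguments from a 4-D "preds" binding, assuming
// it is laid out as (N, C, H, W). Device buffers and the stream are the
// caller's responsibility.
void argmax_preds(nvinfer1::ICudaEngine* engine,
                  const void* d_scores, void* d_labels,
                  cudaStream_t* stream) {
    nvinfer1::Dims o_dims = engine->getBindingDimensions(
        engine->getBindingIndex("preds"));
    const int n_size  = o_dims.d[0];               // batch size
    const int dimsize = o_dims.d[1];               // number of classes
    const int m_size  = o_dims.d[2] * o_dims.d[3]; // H * W
    argMaxFunc(d_scores, d_labels, n_size, dimsize, m_size, stream);
}
```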
tensorrt/segment.py (8 changes: 4 additions & 4 deletions)
@@ -143,10 +143,10 @@ def main():
cuda.memcpy_dtoh_async(h_output, d_output, stream)
stream.synchronize()

-out = palette[h_outputs[0]]
-outshape = engine.get_binding_shape(1)
-H, W = outshape[1], outshape[2]
-out = out.reshape(H, W, 3)
+oshape = engine.get_binding_shape(1)
+pred = np.argmax(h_outputs[0].reshape(oshape), axis=1)
+out = palette[pred]
+out = out.reshape(*oshape[2:], 3)
out = cv2.resize(out, (orgW, orgH))
cv2.imwrite(args.outpth, out)

