Skip to content

Commit

Permalink
Fix Inference Anomaly Caused by preprocess.cu on Linux #5
Browse files Browse the repository at this point in the history
  • Loading branch information
laugh12321 committed Feb 25, 2024
1 parent 42276e1 commit cdb5c35
Showing 1 changed file with 58 additions and 77 deletions.
135 changes: 58 additions & 77 deletions python/infer/preprocess.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,115 +12,96 @@ extern "C" __global__ void preprocess_kernel_fp16(uint8_t* src, int src_line_siz
// Body of preprocess_kernel_fp16. For one destination pixel (dx, dy) it maps
// the pixel into the source image through the six coefficients in d2s, blends
// the four neighbouring 3-byte source pixels, divides by 255 and stores the
// three channels as half values in three consecutive dst planes.
//
// NOTE(review): the signature is cut off in this view, so parameter types are
// inferred from how the names are used below — TODO confirm.
// NOTE(review): this listing appears to come from a rendered diff in which
// removed and added lines are run together without +/- markers, so two
// interpolation passes are present — TODO confirm against the repository.
//
// _X / _Y are defined outside this view; presumably this thread's global
// x / y index — TODO confirm.
int dx = _X, dy = _Y;
// Threads that fall outside the destination extent do nothing.
if (dx >= dst_width || dy >= dst_height) return;

// Source coordinates for this destination pixel:
//   src_x = d2s[0]*dx + d2s[1]*dy + d2s[2] + 0.5
//   src_y = d2s[3]*dx + d2s[4]*dy + d2s[5] + 0.5
float src_x = d2s[0] * dx + d2s[1] * dy + d2s[2] + 0.5f;
float src_y = d2s[3] * dx + d2s[4] * dy + d2s[5] + 0.5f;
// Channels start at fill_value; they are reassigned unconditionally below.
float c0 = fill_value, c1 = fill_value, c2 = fill_value;

// The upper-bound tests repeat the early return above, so only the
// non-negativity tests can still fail here.
if (dx >= 0 && dy >= 0 && dx < dst_width && dy < dst_height) {


// Clamp the top-left neighbour into the source image, then take the next
// pixel to the right / below, also clamped to the last column / row.
int x_low = max(0, min(static_cast<int>(floorf(src_x)), src_width - 1));
int y_low = max(0, min(static_cast<int>(floorf(src_y)), src_height - 1));
int x_high = min(src_width - 1, x_low + 1);
int y_high = min(src_height - 1, y_low + 1);

// Offsets of the mapped point from the (clamped) top-left neighbour and
// their complements.
float ly = src_y - y_low;
float lx = src_x - x_low;
float hy = 1 - ly;
float hx = 1 - lx;

// Weights for the four neighbours: top-left, top-right, bottom-left,
// bottom-right.
float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

// Byte pointers to the four neighbouring pixels: rows are src_line_size
// bytes apart and each pixel occupies 3 bytes.
uint8_t* v1 = src + y_low * src_line_size + x_low * 3;
uint8_t* v2 = src + y_low * src_line_size + x_high * 3;
uint8_t* v3 = src + y_high * src_line_size + x_low * 3;
uint8_t* v4 = src + y_high * src_line_size + x_high * 3;

// Weighted sum of the four neighbours for each byte channel, written as
// nested fmaf calls.
c0 = fmaf(w1, v1[0], fmaf(w2, v2[0], fmaf(w3, v3[0], w4 * v4[0])));
c1 = fmaf(w1, v1[1], fmaf(w2, v2[1], fmaf(w3, v3[1], w4 * v4[1])));
c2 = fmaf(w1, v1[2], fmaf(w2, v2[2], fmaf(w3, v3[2], w4 * v4[2])));
}
// Swap c0 and c2, reversing the channel order (the original comment read
// "gbr -> rgb", presumably meaning BGR -> RGB).
float temp = c2;
c2 = c0;
c0 = temp;

// Second interpolation pass: the same mapping is recomputed, but the
// neighbour coordinates are not clamped to the source image here.
// NOTE(review): without clamping, the reads below can fall outside the image
// when the mapped coordinate does — confirm how callers guarantee bounds.
int y_low = floorf(d2s[3] * dx + d2s[4] * dy + d2s[5] + 0.5f);
int x_low = floorf(d2s[0] * dx + d2s[1] * dy + d2s[2] + 0.5f);
int y_high = y_low + 1;
int x_high = x_low + 1;

// Byte offsets of the four neighbours (3 bytes per pixel).
int indices[4];
indices[0] = y_low * src_line_size + x_low * sizeof(uint8_t) * 3;
indices[1] = y_low * src_line_size + x_high * sizeof(uint8_t) * 3;
indices[2] = y_high * src_line_size + x_low * sizeof(uint8_t) * 3;
indices[3] = y_high * src_line_size + x_high * sizeof(uint8_t) * 3;

// View each neighbour as a uchar3 so its bytes are reachable as x / y / z.
uchar3* v1 = reinterpret_cast<uchar3*>(src + indices[0]);
uchar3* v2 = reinterpret_cast<uchar3*>(src + indices[1]);
uchar3* v3 = reinterpret_cast<uchar3*>(src + indices[2]);
uchar3* v4 = reinterpret_cast<uchar3*>(src + indices[3]);

// Fractional part of the mapped coordinate and the four neighbour weights.
float ly = d2s[3] * dx + d2s[4] * dy + d2s[5] + 0.5f - y_low;
float lx = d2s[0] * dx + d2s[1] * dy + d2s[2] + 0.5f - x_low;
float hy = 1 - ly;
float hx = 1 - lx;
float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

// bgr -> rgb: c0 takes the third byte (z), c1 the second (y), c2 the first
// (x). These assignments overwrite the values computed by the first pass.
c0 = w1 * v1->z + w2 * v2->z + w3 * v3->z + w4 * v4->z;
c1 = w1 * v1->y + w2 * v2->y + w3 * v3->y + w4 * v4->y;
c2 = w1 * v1->x + w2 * v2->x + w3 * v3->x + w4 * v4->x;

// Normalise: divide every channel by 255.
c0 /= 255.0f;
c1 /= 255.0f;
c2 /= 255.0f;

// Interleaved (rgbrgbrgb) -> planar (rrrgggbbb): each channel goes to its own
// plane of dst_width * dst_height elements, at offset dy * dst_width + dx.
int area = dst_width * dst_height;
half* pdst_c0 = dst + dy * dst_width + dx;
half* pdst_c1 = pdst_c0 + area;
half* pdst_c2 = pdst_c1 + area;

// Convert the float results to half on store.
*pdst_c0 = __float2half(c0);
*pdst_c1 = __float2half(c1);
*pdst_c2 = __float2half(c2);
}


// preprocess_kernel_fp32: for one destination pixel (dx, dy), map it into the
// source image through the six coefficients in d2s, blend the four
// neighbouring 3-byte source pixels, divide by 255 and store the three
// channels as floats in three consecutive dst planes.
//
// Parameters (as used in the visible body):
//   src                    byte pointer to the source image, 3 bytes per pixel
//   src_line_size          byte offset between consecutive source rows
//   src_width, src_height  used only to clamp coordinates in the first pass
//   dst                    float output; plane stride is dst_width * dst_height
//   dst_width, dst_height  output extent, also the bounds for (dx, dy)
//   fill_value             initial value of c0..c2 (reassigned further down)
//   d2s                    six floats: x' = d2s[0]*dx + d2s[1]*dy + d2s[2],
//                          y' = d2s[3]*dx + d2s[4]*dy + d2s[5]
//
// NOTE(review): this listing appears to come from a rendered diff in which
// removed and added lines are run together without +/- markers, so two
// interpolation passes are present and the closing brace is not shown —
// TODO confirm against the repository before relying on it.
extern "C" __global__ void preprocess_kernel_fp32(uint8_t* src, int src_line_size,
int src_width, int src_height, float* dst,
int dst_width, int dst_height,
uint8_t fill_value, const float* d2s) {
// _X / _Y are defined outside this view; presumably this thread's global
// x / y index — TODO confirm.
int dx = _X, dy = _Y;
// Threads that fall outside the destination extent do nothing.
if (dx >= dst_width || dy >= dst_height) return;

// Source coordinates for this destination pixel, including a +0.5 offset.
float src_x = d2s[0] * dx + d2s[1] * dy + d2s[2] + 0.5f;
float src_y = d2s[3] * dx + d2s[4] * dy + d2s[5] + 0.5f;
// Channels start at fill_value; they are reassigned unconditionally below.
float c0 = fill_value, c1 = fill_value, c2 = fill_value;

// The upper-bound tests repeat the early return above, so only the
// non-negativity tests can still fail here.
if (dx >= 0 && dy >= 0 && dx < dst_width && dy < dst_height) {


// Clamp the top-left neighbour into the source image, then take the next
// pixel to the right / below, also clamped to the last column / row.
int x_low = max(0, min(static_cast<int>(floorf(src_x)), src_width - 1));
int y_low = max(0, min(static_cast<int>(floorf(src_y)), src_height - 1));
int x_high = min(src_width - 1, x_low + 1);
int y_high = min(src_height - 1, y_low + 1);

// Offsets of the mapped point from the (clamped) top-left neighbour and
// their complements.
float ly = src_y - y_low;
float lx = src_x - x_low;
float hy = 1 - ly;
float hx = 1 - lx;

// Weights for the four neighbours: top-left, top-right, bottom-left,
// bottom-right.
float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

// Byte pointers to the four neighbouring pixels: rows are src_line_size
// bytes apart and each pixel occupies 3 bytes.
uint8_t* v1 = src + y_low * src_line_size + x_low * 3;
uint8_t* v2 = src + y_low * src_line_size + x_high * 3;
uint8_t* v3 = src + y_high * src_line_size + x_low * 3;
uint8_t* v4 = src + y_high * src_line_size + x_high * 3;

// Weighted sum of the four neighbours for each byte channel, written as
// nested fmaf calls.
c0 = fmaf(w1, v1[0], fmaf(w2, v2[0], fmaf(w3, v3[0], w4 * v4[0])));
c1 = fmaf(w1, v1[1], fmaf(w2, v2[1], fmaf(w3, v3[1], w4 * v4[1])));
c2 = fmaf(w1, v1[2], fmaf(w2, v2[2], fmaf(w3, v3[2], w4 * v4[2])));
}
// Swap c0 and c2, reversing the channel order (the original comment read
// "gbr -> rgb", presumably meaning BGR -> RGB).
float temp = c2;
c2 = c0;
c0 = temp;

// Second interpolation pass: the same mapping is recomputed, but the
// neighbour coordinates are not clamped to the source image here.
// NOTE(review): without clamping, the reads below can fall outside the image
// when the mapped coordinate does — confirm how callers guarantee bounds.
int y_low = floorf(d2s[3] * dx + d2s[4] * dy + d2s[5] + 0.5f);
int x_low = floorf(d2s[0] * dx + d2s[1] * dy + d2s[2] + 0.5f);
int y_high = y_low + 1;
int x_high = x_low + 1;

// Byte offsets of the four neighbours (3 bytes per pixel).
int indices[4];
indices[0] = y_low * src_line_size + x_low * sizeof(uint8_t) * 3;
indices[1] = y_low * src_line_size + x_high * sizeof(uint8_t) * 3;
indices[2] = y_high * src_line_size + x_low * sizeof(uint8_t) * 3;
indices[3] = y_high * src_line_size + x_high * sizeof(uint8_t) * 3;

// View each neighbour as a uchar3 so its bytes are reachable as x / y / z.
uchar3* v1 = reinterpret_cast<uchar3*>(src + indices[0]);
uchar3* v2 = reinterpret_cast<uchar3*>(src + indices[1]);
uchar3* v3 = reinterpret_cast<uchar3*>(src + indices[2]);
uchar3* v4 = reinterpret_cast<uchar3*>(src + indices[3]);

// Fractional part of the mapped coordinate and the four neighbour weights.
float ly = d2s[3] * dx + d2s[4] * dy + d2s[5] + 0.5f - y_low;
float lx = d2s[0] * dx + d2s[1] * dy + d2s[2] + 0.5f - x_low;
float hy = 1 - ly;
float hx = 1 - lx;
float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

// bgr -> rgb: c0 takes the third byte (z), c1 the second (y), c2 the first
// (x). These assignments overwrite the values computed by the first pass.
c0 = w1 * v1->z + w2 * v2->z + w3 * v3->z + w4 * v4->z;
c1 = w1 * v1->y + w2 * v2->y + w3 * v3->y + w4 * v4->y;
c2 = w1 * v1->x + w2 * v2->x + w3 * v3->x + w4 * v4->x;

// Normalise: divide every channel by 255.
c0 /= 255.0f;
c1 /= 255.0f;
c2 /= 255.0f;

// Interleaved (rgbrgbrgb) -> planar (rrrgggbbb): each channel goes to its own
// plane of dst_width * dst_height elements, at offset dy * dst_width + dx.
int area = dst_width * dst_height;
float* pdst_c0 = dst + dy * dst_width + dx;
float* pdst_c1 = pdst_c0 + area;
float* pdst_c2 = pdst_c1 + area;

// Store the three channel values into their planes.
*pdst_c0 = c0;
*pdst_c1 = c1;
*pdst_c2 = c2;
Expand Down

0 comments on commit cdb5c35

Please sign in to comment.