Skip to content

Commit

Permalink
add dims check for nms_kernel (PaddlePaddle#49993)
Browse files Browse the repository at this point in the history
  • Loading branch information
RedContritio authored and pangengzheng committed Feb 2, 2023
1 parent 18756fa commit a951644
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
12 changes: 12 additions & 0 deletions paddle/phi/kernels/cpu/nms_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes,
float threshold,
DenseTensor* output) {
PADDLE_ENFORCE_EQ(
boxes.dims().size(),
2,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));

PADDLE_ENFORCE_EQ(
boxes.dims()[1],
4,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));

int64_t num_boxes = boxes.dims()[0];
DenseTensor output_tmp;
output_tmp.Resize(phi::make_ddim({num_boxes}));
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/kernels/gpu/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes,
float threshold,
DenseTensor* output) {
PADDLE_ENFORCE_EQ(
boxes.dims().size(),
2,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));

PADDLE_ENFORCE_EQ(
boxes.dims()[1],
4,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));

const int64_t num_boxes = boxes.dims()[0];
const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
dim3 block(threadsPerBlock);
Expand Down

0 comments on commit a951644

Please sign in to comment.