From 946cb3c19527434773038991db3ffa68c781bf9f Mon Sep 17 00:00:00 2001 From: qili93 Date: Mon, 27 May 2024 18:09:53 +0800 Subject: [PATCH 1/2] Fix Histogram kernel to check range error --- paddle/phi/kernels/gpu/histogram_kernel.cu | 27 +++++++++++++++++++--- test/legacy_test/test_histogram_op.py | 23 ++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu index aa10aea35f867a..3cad1c7d30ef92 100644 --- a/paddle/phi/kernels/gpu/histogram_kernel.cu +++ b/paddle/phi/kernels/gpu/histogram_kernel.cu @@ -35,8 +35,8 @@ __device__ static IndexType GetBin(T input_value, T min_value, T max_value, int64_t nbins) { - IndexType bin = static_cast((input_value - min_value) * nbins / - (max_value - min_value)); + IndexType bin = static_cast((input_value - min_value) * nbins / + (max_value - min_value)); IndexType output_index = bin < nbins - 1 ? bin : nbins - 1; return output_index; } @@ -151,7 +151,7 @@ void HistogramKernel(const Context& dev_ctx, min_max.Resize({2 * block_num}); auto* min_block_ptr = dev_ctx.template Alloc(&min_max); auto* max_block_ptr = min_block_ptr + block_num; - if (output_min == output_max) { + if (min == max) { KernelMinMax<< min_max_vec; + phi::TensorToVector(min_max, dev_ctx, &min_max_vec); + output_min = min_max_vec[0]; + output_max = min_max_vec[1]; + + // check if out of range + double range = + static_cast(output_max) - static_cast(output_min); + PADDLE_ENFORCE_LT( + range, + static_cast(std::numeric_limits::max()), + phi::errors::InvalidArgument( + "The range of max - min is out of range for target type, " + "current kernel type is %s, the range should less than %f " + "but now min is %f, max is %f.", + typeid(T).name(), + std::numeric_limits::max(), + output_min, + output_max)); + PADDLE_ENFORCE_EQ((std::isinf(static_cast(output_min)) || std::isnan(static_cast(output_max)) || std::isinf(static_cast(output_min)) || diff --git a/test/legacy_test/test_histogram_op.py b/test/legacy_test/test_histogram_op.py index 06d7bec5450876..4f47888b83dc2f 100644 --- a/test/legacy_test/test_histogram_op.py +++ b/test/legacy_test/test_histogram_op.py @@ -117,6 +117,29 @@ def net_func(): with self.assertRaises(TypeError): self.run_network(net_func) + @test_with_pir_api + def test_input_range_error(self): + """Test range of input is out of bound""" + + def net_func(): + input_value = paddle.to_tensor( + [ + -7095538316670326452, + -6102192280439741006, + 2040176985344715288, + -6276983991026997920, + -6570715756420355710, + -5998045007776667296, + -6763099356862306438, + 3166073479842736625, + ], + dtype=paddle.int64, + ) + paddle.histogram(input=input_value, bins=1, min=0, max=0) + + with self.assertRaises(ValueError): + self.run_network(net_func) + @test_with_pir_api def test_type_errors(self): with paddle.static.program_guard(paddle.static.Program()): From 7da3a7480709c828171d84cb9f11923df9beb129 Mon Sep 17 00:00:00 2001 From: qili93 Date: Tue, 28 May 2024 10:33:41 +0800 Subject: [PATCH 2/2] fix windows openblas ci fail --- paddle/phi/kernels/cpu/histogram_kernel.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/paddle/phi/kernels/cpu/histogram_kernel.cc b/paddle/phi/kernels/cpu/histogram_kernel.cc index 030dee9908b31c..eba05f4b0810a2 100644 --- a/paddle/phi/kernels/cpu/histogram_kernel.cc +++ b/paddle/phi/kernels/cpu/histogram_kernel.cc @@ -51,6 +51,21 @@ void HistogramKernel(const Context& dev_ctx, output_max = output_max + 1; } + // check if out of range + double range = + static_cast(output_max) - static_cast(output_min); + PADDLE_ENFORCE_LT( + range, + static_cast(std::numeric_limits::max()), + phi::errors::InvalidArgument( + "The range of max - min is out of range for target type, " + "current kernel type is %s, the range should less than %f " + "but now min is %f, max is %f.", + typeid(T).name(), + std::numeric_limits::max(), + output_min, + output_max)); + PADDLE_ENFORCE_EQ((std::isinf(static_cast(output_min)) || std::isnan(static_cast(output_max)) || std::isinf(static_cast(output_min)) ||