diff --git a/src/kbmod/search/psf.cpp b/src/kbmod/search/psf.cpp index fa2735cbb..ee25f75a3 100644 --- a/src/kbmod/search/psf.cpp +++ b/src/kbmod/search/psf.cpp @@ -87,7 +87,12 @@ PSF& PSF::operator=(PSF&& other) { void PSF::calc_sum() { sum = 0.0; - for (auto& i : kernel) sum += i; + for (auto& i : kernel) { + if (std::isnan(i) || std::isinf(i)) { + throw std::runtime_error("Invalid value in PSF kernel (NaN or inf)"); + } + sum += i; + } } void PSF::square_psf() { diff --git a/src/kbmod/search/psf.h b/src/kbmod/search/psf.h index 37b7a4d26..2dc96a17b 100644 --- a/src/kbmod/search/psf.h +++ b/src/kbmod/search/psf.h @@ -31,18 +31,20 @@ class PSF { float get_std() const { return width; } float get_sum() const { return sum; } float get_value(int x, int y) const { return kernel[y * dim + x]; } - int get_dim() const { return dim; } + int get_dim() const { return dim; } // Length of one side of the kernel. int get_radius() const { return radius; } int get_size() const { return kernel.size(); } const std::vector& get_kernel() const { return kernel; }; float* data() { return kernel.data(); } // Computation functions. - void calc_sum(); void square_psf(); std::string print(); private: + // Validates the PSF array and computes the sum. + void calc_sum(); + std::vector kernel; float width; float sum; diff --git a/src/kbmod/search/pydocs/psf_docs.h b/src/kbmod/search/pydocs/psf_docs.h index a58fb07f4..22e18cabd 100644 --- a/src/kbmod/search/pydocs/psf_docs.h +++ b/src/kbmod/search/pydocs/psf_docs.h @@ -23,7 +23,8 @@ static const auto DOC_PSF = R"doc( Raises ------ - Raises a ``RuntimeError`` when given an invalid stdev. + Raises a ``RuntimeError`` when given an invalid stdev or an array + containing invalid entries, such as NaN or infinity. )doc"; static const auto DOC_PSF_set_array = R"doc( @@ -49,11 +50,11 @@ static const auto DOC_PSF_get_sum = R"doc( ")doc"; static const auto DOC_PSF_get_dim = R"doc( - "Returns the PSF kernel dimensions. + "Returns the PSF kernel dimension D where the kernel is a D by D array. ")doc"; static const auto DOC_PSF_get_radius = R"doc( - "Returns the radius of the PSF + "Returns the radius of the PSF. ")doc"; static const auto DOC_PSF_get_size = R"doc( diff --git a/tests/test_psf.py b/tests/test_psf.py index a7eba0ea9..bb37d9157 100644 --- a/tests/test_psf.py +++ b/tests/test_psf.py @@ -1,3 +1,5 @@ +import math +import numpy as np import unittest from kbmod.search import PSF @@ -19,10 +21,24 @@ def test_make_noop(self): self.assertEqual(len(kernel0), 1) self.assertEqual(kernel0[0], 1.0) - def test_make_invalud(self): + def test_make_invalid(self): # Raise an error if creating a PSF with a negative stdev. self.assertRaises(RuntimeError, PSF, -1.0) + def test_make_from_array(self): + arr = np.full((3, 3), 1.0 / 9.0) + psf_arr = PSF(arr) + self.assertEqual(psf_arr.get_size(), 9) + self.assertEqual(psf_arr.get_dim(), 3) + + # We get an error if we include a NaN. + arr[0][0] = math.nan + self.assertRaises(RuntimeError, PSF, arr) + + # We get an error if we include a inf. + arr[0][0] = math.inf + self.assertRaises(RuntimeError, PSF, arr) + def test_to_string(self): result = self.psf_list[0].__str__() self.assertGreater(len(result), 1)