Skip to content

Commit

Permalink
Merge branch 'main' into cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Dec 15, 2023
2 parents dc66b3f + 9d54136 commit 429c01c
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 106 deletions.
6 changes: 3 additions & 3 deletions src/kbmod/search/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ namespace py = pybind11;
#include "image_stack.cpp"
#include "stack_search.cpp"
#include "stamp_creator.cpp"
#include "filtering.cpp"
#include "kernel_testing_helpers.cpp"
#include "psi_phi_array.cpp"


PYBIND11_MODULE(search, m) {
m.attr("KB_NO_DATA") = pybind11::float_(search::NO_DATA);
m.attr("HAS_GPU") = pybind11::bool_(search::HAVE_GPU);
Expand Down Expand Up @@ -45,7 +46,6 @@ PYBIND11_MODULE(search, m) {
m.def("create_median_image", &search::create_median_image);
m.def("create_summed_image", &search::create_summed_image);
m.def("create_mean_image", &search::create_mean_image);
// Functions from filtering.cpp
// Functions from kernel_testing_helpers.cpp
m.def("sigmag_filtered_indices", &search::sigmaGFilteredIndices);
m.def("calculate_likelihood_psi_phi", &search::calculateLikelihoodFromPsiPhi);
}
65 changes: 0 additions & 65 deletions src/kbmod/search/filtering.cpp

This file was deleted.

15 changes: 0 additions & 15 deletions src/kbmod/search/filtering.h

This file was deleted.

43 changes: 43 additions & 0 deletions src/kbmod/search/kernel_testing_helpers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/* Helper functions for testing functions in the .cu files from Python. */

#include <vector>

namespace search {
#ifdef HAVE_CUDA
/* The filter_kenerls.cu functions. */
extern "C" void SigmaGFilteredIndicesCU(float *values, int num_values, float sgl0, float sgl1, float sg_coeff,
float width, int *idx_array, int *min_keep_idx, int *max_keep_idx);
#endif

/* Used for testing SigmaGFilteredIndicesCU for python
*
* Return the list of indices from the values array such that those elements
* pass the sigmaG filtering defined by percentiles [sgl0, sgl1] with coefficient
* sigma_g_coeff and a multiplicative factor of width.
*
* The vector values is passed by value to create a local copy which will be modified by
* SigmaGFilteredIndicesCU.
*/
std::vector<int> sigmaGFilteredIndices(std::vector<float> values, float sgl0, float sgl1, float sigma_g_coeff,
float width) {
int num_values = values.size();
std::vector<int> idx_array(num_values, 0);
int min_keep_idx = 0;
int max_keep_idx = num_values - 1;

#ifdef HAVE_CUDA
SigmaGFilteredIndicesCU(values.data(), num_values, sgl0, sgl1, sigma_g_coeff, width, idx_array.data(),
&min_keep_idx, &max_keep_idx);
#else
throw std::runtime_error("Non-GPU SigmaGFilteredIndicesCU is not implemented.");
#endif

// Copy the result into a vector and return it.
std::vector<int> result;
for (int i = min_keep_idx; i <= max_keep_idx; ++i) {
result.push_back(idx_array[i]);
}
return result;
}

} /* namespace search */
23 changes: 0 additions & 23 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,29 +98,6 @@ def test_sigmag_filtered_indices_three_outliers(self):
valid = i != 13 and i != 14 and i != 27
self.assertEqual(i in inds, valid)

def test_calculate_likelihood_psiphi(self):
# make sure that the calculate_likelihood_psi_phi works.
psi_values = [1.0 for _ in range(20)]
phi_values = [1.0 for _ in range(20)]

lh = calculate_likelihood_psi_phi(psi_values, phi_values)

self.assertEqual(lh, 4.47213595499958)

def test_calculate_likelihood_psiphi_zero_or_negative_phi(self):
# make sure that the calculate_likelihood_psi_phi works
# properly when phi values are less than or equal to zero.
psi_values = [1.0 for _ in range(20)]
phi_values = [-1.0 for _ in range(20)]

# test negatives
lh = calculate_likelihood_psi_phi(psi_values, phi_values)
self.assertEqual(lh, 0.0)

# test zero
lh = calculate_likelihood_psi_phi([1.0], [0.0])
self.assertEqual(lh, 0.0)


if __name__ == "__main__":
unittest.main()

0 comments on commit 429c01c

Please sign in to comment.