diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index 530ac5aa3..2ccbf73c2 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -157,11 +157,38 @@ static const auto DOC_StackSearch_prepare_psi_phi = R"doc( )doc"; static const auto DOC_StackSearch_get_results = R"doc( - todo + Get a batch of cached results. + + Parameters + ---------- + start : `int` + The starting index of the results to retrieve. Returns + an empty list if start is past the end of the cache. + count : `int` + The maximum number of results to retrieve. Returns fewer + results if there are not enough in the cache. + + Returns + ------- + results : `List` + A list of ``Trajectory`` objects for the cached results. + + Raises + ------ + ``RunTimeError`` if start < 0 or count <= 0. )doc"; static const auto DOC_StackSearch_set_results = R"doc( - todo + Set the cached results. Used for testing. + + Parameters + ---------- + new_results : `List` + The list of results to store. + )doc"; + +static const auto DOC_StackSearch_clear_results = R"doc( + Clear the cached results. )doc"; static const auto DOC_StackSearch_evaluate_single_trajectory = R"doc( diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index be3f32224..ca91644ad 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -256,15 +256,18 @@ void StackSearch::sort_results() { } std::vector StackSearch::get_results(int start, int count) { + if (start < 0) throw std::runtime_error("start must be 0 or greater"); + if (count <= 0) throw std::runtime_error("count must be greater than 0"); + if (start + count >= results.size()) { count = results.size() - start; } - if (start < 0) throw std::runtime_error("start must be 0 or greater"); return std::vector(results.begin() + start, results.begin() + start + count); } // This function is used only for testing by injecting known result trajectories. void StackSearch::set_results(const std::vector& new_results) { results = new_results; } +void StackSearch::clear_results() { results.clear(); } #ifdef Py_PYTHON_H static void stack_search_bindings(py::module& m) { @@ -303,7 +306,8 @@ static void stack_search_bindings(py::module& m) { .def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi) .def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi) .def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results) - .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results); + .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results) + .def("clear_results", &ks::clear_results, pydocs::DOC_StackSearch_clear_results); } #endif /* Py_PYTHON_H */ diff --git a/src/kbmod/search/stack_search.h b/src/kbmod/search/stack_search.h index 9d59342df..9c492dad5 100644 --- a/src/kbmod/search/stack_search.h +++ b/src/kbmod/search/stack_search.h @@ -63,6 +63,7 @@ class StackSearch { // Helper functions for testing void set_results(const std::vector& new_results); + void clear_results(); virtual ~StackSearch(){}; diff --git a/tests/test_search.py b/tests/test_search.py index 8e6985d39..a60b7457e 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -89,6 +89,38 @@ def setUp(self): self.params.m02_limit = 35.5 self.params.m20_limit = 35.5 + def test_set_get_results(self): + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 0) + + trjs = [make_trajectory(i, i, 0.0, 0.0) for i in range(10)] + self.search.set_results(trjs) + + # Check that we extract them all. + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 10) + for i in range(10): + self.assertEqual(results[i].x, i) + + # Check that we can run past the end of the results. + results = self.search.get_results(0, 100) + self.assertEqual(len(results), 10) + + # Check that we can pull a subset. + results = self.search.get_results(2, 2) + self.assertEqual(len(results), 2) + self.assertEqual(results[0].x, 2) + self.assertEqual(results[1].x, 3) + + # Check invalid settings + self.assertRaises(RuntimeError, self.search.get_results, -1, 5) + self.assertRaises(RuntimeError, self.search.get_results, 0, 0) + + # Check that clear works. + self.search.clear_results() + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 0) + @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") def test_evaluate_single_trajectory(self): test_trj = make_trajectory(