Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch search inside stack search #530

Merged
merged 10 commits into from
Apr 8, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ candidates = [trj for trj in gen]

# Do the actual search.
search = kb.StackSearch(stack)
search.search(
search.search_all(
strategy,
7, # The minimum number of observations
)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/Kbmod_Reference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@
"candidates = [trj for trj in gen]\n",
"print(f\"Created {len(candidates)} candidate trajectories per pixel.\")\n",
"\n",
"search.search(candidates, 2)"
"search.search_all(candidates, 2)"
]
},
{
Expand Down
31 changes: 31 additions & 0 deletions src/kbmod/batch_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
class BatchSearchManager:
Copy link
Collaborator Author

@vlnistor vlnistor Mar 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main idea here was that I didn't want the user to have to call prepare_batch_search and finish_search manually

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but something to consider as a future extension: In theory this could be extended further by making this a generator:

for batch_search in BatchSearchManager(stack_search, search_list, min_observations, x_size, y_size):
    batch_results.extend(batch_search.search_batch())

where the BatchSearchManager tracks the block sizes to search, current block locations, etc.

def __init__(self, stack_search, search_list, min_observations):
"""Manages a batch search over a list of Trajectory instances using a StackSearch instance.

Parameters
----------
stack_search: `StackSearch`
StackSearch instance to use for the batch search.
search_list: `list[Trajectory]`
List of Trajectory instances to search along the stack.
min_observations: `int`
Minimum number of observations required to consider a candidate.
"""
self.stack_search = stack_search
self.search_list = search_list
self.min_observations = min_observations

def __enter__(self):
# Prepare memory for the batch search.
self.stack_search.prepare_search(self.search_list, self.min_observations)
return self.stack_search

def __exit__(self, *_):
"""
This method is called when exiting the context.
It cleans up resources by calling `finish_search` on the StackSearch instance.
We return False to indicate that we do not want to suppress any exceptions that may have been raised.
"""
# Clean up
self.stack_search.finish_search()
return False
2 changes: 1 addition & 1 deletion src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def do_gpu_search(self, config, stack, trj_generator):

# Do the actual search.
candidates = [trj for trj in trj_generator]
search.search(candidates, int(config["num_obs"]))
search.search_all(candidates, int(config["num_obs"]))
search_timer.stop()

# Load the results.
Expand Down
40 changes: 40 additions & 0 deletions src/kbmod/search/pydocs/stack_search_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,46 @@ static const auto DOC_StackSearch_get_results = R"doc(
``RunTimeError`` if start < 0 or count <= 0.
)doc";

static const auto DOC_StackSearch_prepare_batch_search = R"doc(
Prepare the search for a batch of trajectories.

Parameters
----------
search_list : `List`
A list of ``Trajectory`` objects to search.
min_observations : `int`
The minimum number of observations for a trajectory to be considered.
)doc";

static const auto DOC_StackSearch_compute_max_results = R"doc(
Compute the maximum number of results according to the x, y bounds and the RESULTS_PER_PIXEL constant

Returns
-------
max_results : `int`
The maximum number of results that a search will return according to the current bounds and the RESULTS_PER_PIXEL constant.
)doc";

static const auto DOC_StackSearch_search_single_batch = R"doc(
Perform a search on the given trajectories for the current batch.
Batch is defined by the parameters set `set_start_bounds_x` & `set_start_bounds_y`.

Returns
-------
results : `List`
A list of ``Trajectory`` search results
)doc";

static const auto DOC_StackSearch_finish_search = R"doc(
Clears memory used for the batch search.

This method should be called after a batch search is completed to ensure that any resources allocated during the search are properly freed.

Returns
-------
None
)doc";

static const auto DOC_StackSearch_set_results = R"doc(
Set the cached results. Used for testing.

Expand Down
82 changes: 52 additions & 30 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ extern "C" void evaluateTrajectory(PsiPhiArrayMeta psi_phi_meta, void* psi_phi_v
// I'd imaging...
auto rs_logger = logging::getLogger("kbmod.search.run_search");

StackSearch::StackSearch(ImageStack& imstack) : stack(imstack), results(0) {
StackSearch::StackSearch(ImageStack& imstack) : stack(imstack), results(0), gpu_search_list(0) {
debug_info = false;
psi_phi_generated = false;

Expand Down Expand Up @@ -155,63 +155,81 @@ Trajectory StackSearch::search_linear_trajectory(short x, short y, float vx, flo
return result;
}

void StackSearch::search(std::vector<Trajectory>& search_list, int min_observations) {
DebugTimer core_timer = DebugTimer("core search", rs_logger);
void StackSearch::finish_search(){
psi_phi_array.clear_from_gpu();
gpu_search_list.move_to_cpu();
}

DebugTimer psi_phi_timer = DebugTimer("creating psi/phi buffers", rs_logger);
void StackSearch::prepare_search(std::vector<Trajectory>& search_list, int min_observations){
DebugTimer psi_phi_timer = DebugTimer("Creating psi/phi buffers", rs_logger);
prepare_psi_phi();
psi_phi_array.move_to_gpu();
psi_phi_timer.stop();

// Allocate a vector for the results and move it onto the GPU.
int search_width = params.x_start_max - params.x_start_min;
int search_height = params.y_start_max - params.y_start_min;
int num_search_pixels = search_width * search_height;
int max_results = num_search_pixels * RESULTS_PER_PIXEL;

int num_to_search = search_list.size();
if (debug_info) std::cout << "Preparing to search " << num_to_search << " trajectories... \n" << std::flush;
gpu_search_list.set_trajectories(search_list);
gpu_search_list.move_to_gpu();

params.min_observations = min_observations;
}

void StackSearch::search_all(std::vector<Trajectory>& search_list, int min_observations) {
prepare_search(search_list, min_observations);
search_batch();
finish_search();
}

void StackSearch::search_batch(){
if(!psi_phi_array.gpu_array_allocated()){
throw std::runtime_error("PsiPhiArray array not allocated on GPU. Did you forget to call prepare_search?");
}

DebugTimer core_timer = DebugTimer("Running batch search", rs_logger);
int max_results = compute_max_results();

// staple C++
std::stringstream logmsg;
logmsg << "Searching X=[" << params.x_start_min << ", " << params.x_start_max << "] "
<< "Y=[" << params.y_start_min << ", " << params.y_start_max << "]\n"
<< "Allocating space for " << max_results << " results.";
rs_logger->info(logmsg.str());


results.resize(max_results);
results.move_to_gpu();

// Allocate space for the search list and move that to the GPU.
int num_to_search = search_list.size();

logmsg.str("");
logmsg << search_list.size() << " trajectories...";
rs_logger->info(logmsg.str());

TrajectoryList gpu_search_list(search_list);
gpu_search_list.move_to_gpu();

// Set the minimum number of observations.
params.min_observations = min_observations;

// Do the actual search on the GPU.
DebugTimer search_timer = DebugTimer("search execution", rs_logger);
DebugTimer search_timer = DebugTimer("Running search", rs_logger);
#ifdef HAVE_CUDA
deviceSearchFilter(psi_phi_array, params, gpu_search_list, results);
#else
throw std::runtime_error("Non-GPU search is not implemented.");
#endif
search_timer.stop();

// Move data back to CPU to unallocate GPU space (this will happen automatically
// for gpu_search_list when the object goes out of scope, but we do it explicitly here).
psi_phi_array.clear_from_gpu();
results.move_to_cpu();
gpu_search_list.move_to_cpu();

DebugTimer sort_timer = DebugTimer("Sorting results", rs_logger);
results.sort_by_likelihood();
sort_timer.stop();
core_timer.stop();
}

std::vector<Trajectory> StackSearch::search_single_batch(){
int max_results = compute_max_results();
search_batch();
return results.get_batch(0, max_results);
}


int StackSearch::compute_max_results(){
int search_width = params.x_start_max - params.x_start_min;
int search_height = params.y_start_max - params.y_start_min;
int num_search_pixels = search_width * search_height;
return num_search_pixels * RESULTS_PER_PIXEL;
}

std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool extract_psi) {
prepare_psi_phi();

Expand Down Expand Up @@ -261,7 +279,7 @@ static void stack_search_bindings(py::module& m) {

py::class_<ks>(m, "StackSearch", pydocs::DOC_StackSearch)
.def(py::init<is&>())
.def("search", &ks::search, pydocs::DOC_StackSearch_search)
.def("search_all", &ks::search_all, pydocs::DOC_StackSearch_search)
.def("evaluate_single_trajectory", &ks::evaluate_single_trajectory,
pydocs::DOC_StackSearch_evaluate_single_trajectory)
.def("search_linear_trajectory", &ks::search_linear_trajectory,
Expand All @@ -288,7 +306,11 @@ static void stack_search_bindings(py::module& m) {
.def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi)
.def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi)
.def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results)
.def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results);
.def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results)
.def("compute_max_results", &ks::compute_max_results, pydocs::DOC_StackSearch_compute_max_results)
.def("search_single_batch", &ks::search_single_batch, pydocs::DOC_StackSearch_search_single_batch)
.def("prepare_search", &ks::prepare_search, pydocs::DOC_StackSearch_prepare_batch_search)
.def("finish_search", &ks::finish_search, pydocs::DOC_StackSearch_finish_search);
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved
}
#endif /* Py_PYTHON_H */

Expand Down
11 changes: 9 additions & 2 deletions src/kbmod/search/stack_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using Image = search::Image;
class StackSearch {
public:
StackSearch(ImageStack& imstack);

int compute_max_results();
int num_images() const { return stack.img_count(); }
int get_image_width() const { return stack.get_width(); }
int get_image_height() const { return stack.get_height(); }
Expand All @@ -50,7 +50,11 @@ class StackSearch {
// The primary search functions
void evaluate_single_trajectory(Trajectory& trj);
Trajectory search_linear_trajectory(short x, short y, float vx, float vy);
void search(std::vector<Trajectory>& search_list, int min_observations);
void prepare_search(std::vector<Trajectory>& search_list, int min_observations);
std::vector<Trajectory> search_single_batch();
void search_batch();
void search_all(std::vector<Trajectory>& search_list, int min_observations);
void finish_search();

// Gets the vector of result trajectories from the grid search.
std::vector<Trajectory> get_results(int start, int end);
Expand Down Expand Up @@ -82,6 +86,9 @@ class StackSearch {

// Results from the grid search.
TrajectoryList results;

// Trajectories that are being searched.
TrajectoryList gpu_search_list;
vlnistor marked this conversation as resolved.
Show resolved Hide resolved
};

} /* namespace search */
Expand Down
2 changes: 1 addition & 1 deletion tests/test_readme_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_make_and_copy(self):

# Do the actual search.
search = kb.StackSearch(stack)
search.search(
search.search_all(
candidates,
7, # The minimum number of observations
)
Expand Down
Loading
Loading