Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch search inside stack search #530

Merged
merged 10 commits into from
Apr 8, 2024
40 changes: 40 additions & 0 deletions src/kbmod/batch_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
class BatchSearchManager:
Copy link
Collaborator Author

@vlnistor vlnistor Mar 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main idea here was that I didn't want the user to have to call prepare_batch_search and finish_search manually

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but something to consider as a future extension: In theory this could be extended further by making this a generator:

for batch_search in BatchSearchManager(stack_search, search_list, min_observations, x_size, y_size):
    batch_results.extend(batch_search.search_batch())

where the BatchSearchManager tracks the block sizes to search, current block locations, etc.

def __init__(self, stack_search, search_list, min_observations):
"""
Initialize the context manager with an instance of StackSearch,
the list of trajectories to search, and the minimum number of observations.

Parameters:
vlnistor marked this conversation as resolved.
Show resolved Hide resolved
- stack_search: Instance of the StackSearch class.
- search_list: List of trajectories to search.
- min_observations: Minimum number of observations for the search.
"""
self.stack_search = stack_search
self.search_list = search_list
self.min_observations = min_observations

def __enter__(self):
"""
This method is called when entering the context managed by the `with` statement.
It prepares the batch search by calling `prepare_batch_search` on the StackSearch instance.
"""
# Initialize or prepare memory for the batch search.
self.stack_search.prepare_batch_search(self.search_list, self.min_observations)
# Return the object that should be used within the `with` block. Here, it's the StackSearch instance.
return self.stack_search

def __exit__(self, exc_type, exc_value, traceback):
"""
This method is called when exiting the context.
It cleans up resources by calling `finish_search` on the StackSearch instance.

Parameters:
- exc_type: The exception type if an exception was raised in the `with` block.
- exc_value: The exception value if an exception was raised.
- traceback: The traceback object if an exception was raised.
"""
# Clean up resources or delete initialized memory.
self.stack_search.finish_search()
# Returning False means any exception raised within the `with` block will be propagated.
# To suppress exceptions, return True.
return False
9 changes: 9 additions & 0 deletions src/kbmod/search/gpu_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ class GPUArray {
GPUArray(const GPUArray&) = delete;
GPUArray& operator=(GPUArray&) = delete;
GPUArray& operator=(const GPUArray&) = delete;
GPUArray& operator=(GPUArray&& other) noexcept {
if (this != &other) {
size = other.size;
memory_size = other.memory_size;
gpu_ptr = other.gpu_ptr;
other.gpu_ptr = nullptr;
}
return *this;
}

virtual ~GPUArray() {
if (gpu_ptr != nullptr) free_gpu_memory();
Expand Down
30 changes: 30 additions & 0 deletions src/kbmod/search/pydocs/stack_search_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,36 @@ static const auto DOC_StackSearch_get_results = R"doc(
``RunTimeError`` if start < 0 or count <= 0.
)doc";

static const auto DOC_StackSearch_prepare_batch_search = R"doc(
Prepare the search for a batch of trajectories.

Parameters
----------
search_list : `List`
A list of ``Trajectory`` objects to search.
min_observations : `int`
The minimum number of observations for a trajectory to be considered.
)doc";

static const auto DOC_StackSearch_search_batch = R"doc(
Perform a batch search of the trajectories in the list.

Returns
-------
results : `List`
A list of ``Trajectory`` search results
)doc";

static const auto DOC_StackSearch_finish_search = R"doc(
Clears memory used for the batch search.

This method should be called after a batch search is completed to ensure that any resources allocated during the search are properly freed.

Returns
-------
None
)doc";

static const auto DOC_StackSearch_set_results = R"doc(
Set the cached results. Used for testing.

Expand Down
67 changes: 65 additions & 2 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ extern "C" void evaluateTrajectory(PsiPhiArrayMeta psi_phi_meta, void* psi_phi_v
// I'd imaging...
auto rs_logger = logging::getLogger("kbmod.search.run_search");

StackSearch::StackSearch(ImageStack& imstack) : stack(imstack), results(0) {
StackSearch::StackSearch(ImageStack& imstack) : stack(imstack), results(0), gpu_search_list(0) {
debug_info = false;
psi_phi_generated = false;

Expand Down Expand Up @@ -155,6 +155,26 @@ Trajectory StackSearch::search_linear_trajectory(short x, short y, float vx, flo
return result;
}

void StackSearch::finish_search(){
psi_phi_array.clear_from_gpu();
gpu_search_list.move_to_cpu();
}

void StackSearch::prepare_batch_search(std::vector<Trajectory>& search_list, int min_observations){
DebugTimer psi_phi_timer = DebugTimer("Creating psi/phi buffers", rs_logger);
prepare_psi_phi();
psi_phi_array.move_to_gpu();
psi_phi_timer.stop();


int num_to_search = search_list.size();
if (debug_info) std::cout << "Preparing to search " << num_to_search << " trajectories... \n" << std::flush;
gpu_search_list = TrajectoryList(search_list);
vlnistor marked this conversation as resolved.
Show resolved Hide resolved
gpu_search_list.move_to_gpu();

params.min_observations = min_observations;
}

void StackSearch::search(std::vector<Trajectory>& search_list, int min_observations) {
DebugTimer core_timer = DebugTimer("core search", rs_logger);

Expand Down Expand Up @@ -212,6 +232,46 @@ void StackSearch::search(std::vector<Trajectory>& search_list, int min_observati
core_timer.stop();
}


std::vector<Trajectory> StackSearch::search_batch(){
if(!psi_phi_array.gpu_array_allocated()){
throw std::runtime_error("PsiPhiArray array not allocated on GPU. Did you forget to call prepare_search?");
}

DebugTimer core_timer = DebugTimer("Running batch search", rs_logger);
// Allocate a vector for the results and move it onto the GPU.
int search_width = params.x_start_max - params.x_start_min;
int search_height = params.y_start_max - params.y_start_min;
int num_search_pixels = search_width * search_height;
int max_results = num_search_pixels * RESULTS_PER_PIXEL;

if (debug_info) {
std::cout << "Searching X=[" << params.x_start_min << ", " << params.x_start_max << "]"
<< " Y=[" << params.y_start_min << ", " << params.y_start_max << "]\n";
std::cout << "Allocating space for " << max_results << " results.\n";
}
results.resize(max_results);
results.move_to_gpu();

// Do the actual search on the GPU.
DebugTimer search_timer = DebugTimer("Running search", rs_logger);
#ifdef HAVE_CUDA
deviceSearchFilter(psi_phi_array, params, gpu_search_list, results);
#else
throw std::runtime_error("Non-GPU search is not implemented.");
#endif
search_timer.stop();

results.move_to_cpu();
DebugTimer sort_timer = DebugTimer("Sorting results", rs_logger);
results.sort_by_likelihood();
sort_timer.stop();
core_timer.stop();

return results.get_batch(0, max_results);
}


std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool extract_psi) {
prepare_psi_phi();

Expand Down Expand Up @@ -288,7 +348,10 @@ static void stack_search_bindings(py::module& m) {
.def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi)
.def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi)
.def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results)
.def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results);
.def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results)
.def("search_batch", &ks::search_batch, pydocs::DOC_StackSearch_search_batch)
.def("prepare_batch_search", &ks::prepare_batch_search, pydocs::DOC_StackSearch_prepare_batch_search)
.def("finish_search", &ks::finish_search, pydocs::DOC_StackSearch_finish_search);
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved
}
#endif /* Py_PYTHON_H */

Expand Down
6 changes: 6 additions & 0 deletions src/kbmod/search/stack_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class StackSearch {
// The primary search functions
void evaluate_single_trajectory(Trajectory& trj);
Trajectory search_linear_trajectory(short x, short y, float vx, float vy);
void prepare_batch_search(std::vector<Trajectory>& search_list, int min_observations);
void finish_search();
std::vector<Trajectory> search_batch();
vlnistor marked this conversation as resolved.
Show resolved Hide resolved
void search(std::vector<Trajectory>& search_list, int min_observations);

// Gets the vector of result trajectories from the grid search.
Expand Down Expand Up @@ -82,6 +85,9 @@ class StackSearch {

// Results from the grid search.
TrajectoryList results;

// Trajectories that are being searched.
TrajectoryList gpu_search_list;
vlnistor marked this conversation as resolved.
Show resolved Hide resolved
};

} /* namespace search */
Expand Down
12 changes: 12 additions & 0 deletions src/kbmod/search/trajectory_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ TrajectoryList::TrajectoryList(int max_list_size) {
gpu_array.resize(max_size);
}

// Move assignment operator.
TrajectoryList& TrajectoryList::operator=(TrajectoryList&& other) noexcept {
if (this != &other) {
max_size = other.max_size;
data_on_gpu = other.data_on_gpu;
cpu_list = std::move(other.cpu_list);
gpu_array = std::move(other.gpu_array);
vlnistor marked this conversation as resolved.
Show resolved Hide resolved
other.data_on_gpu = false;
}
return *this;
}

TrajectoryList::TrajectoryList(const std::vector<Trajectory> &prev_list) {
max_size = prev_list.size();
cpu_list = prev_list; // Do a full copy.
Expand Down
1 change: 1 addition & 0 deletions src/kbmod/search/trajectory_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TrajectoryList {
TrajectoryList(const TrajectoryList&) = delete;
TrajectoryList& operator=(TrajectoryList&) = delete;
TrajectoryList& operator=(const TrajectoryList&) = delete;
TrajectoryList& operator=(TrajectoryList&&) noexcept;
vlnistor marked this conversation as resolved.
Show resolved Hide resolved

// --- Getter functions ----------------
inline int get_size() const { return max_size; }
Expand Down
68 changes: 68 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from kbmod.batch_search import BatchSearchManager
from kbmod.configuration import SearchConfiguration
from kbmod.fake_data.fake_data_creator import add_fake_object, make_fake_layered_image, FakeDataSet
from kbmod.run_search import SearchRunner
Expand Down Expand Up @@ -932,6 +933,73 @@ def test_coadd_filter_gpu(self):
self.assertEqual(meanStamps[2].width, 1)
self.assertEqual(meanStamps[2].height, 1)

@staticmethod
def result_hash(res):
return hash((res.x, res.y, res.vx, res.vy, res.lh, res.obs_count))

def test_search_batch(self):
width = 50
height = 50
results_per_pixel = 8
min_observations = 2

# Simple average PSF
psf_data = np.zeros((5, 5), dtype=np.single)
psf_data[1:4, 1:4] = 0.1111111
p = PSF(psf_data)

# Create a stack with 10 20x20 images with random noise and times ranging from 0 to 1
count = 10
imlist = [make_fake_layered_image(width, height, 5.0, 25.0, n / count, p) for n in range(count)]
stack = ImageStack(imlist)
im_list = stack.get_images()
# Create a new list of LayeredImages with the added object.
new_im_list = []
vlnistor marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove new_im_list

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed now

for im, time in zip(im_list, stack.build_zeroed_times()):
add_fake_object(im, 5.0 + (time * 8.0), 35.0 + (time * 0.0), 25000.0)
new_im_list.append(im)

# Save these images in a new ImageStack and create a StackSearch object from them.
stack = ImageStack(new_im_list)
search = StackSearch(stack)

# Sample generator
gen = KBMODV1Search(
10, 5, 15, 10, -0.1, 0.1
) # velocity_steps, min_vel, max_vel, angle_steps, min_ang, max_ang,
candidates = [trj for trj in gen]

# Peform complete in-memory search
search.search(candidates, min_observations)
total_results = width * height * results_per_pixel
# Need to filter as the fields are undefined otherwise
results = [
result
for result in search.get_results(0, total_results)
if result.lh > -1 and result.obs_count >= min_observations
]

with BatchSearchManager(StackSearch(stack), candidates, min_observations) as batch_search:

batch_results = []
for i in range(0, width, 5):
batch_search.set_start_bounds_x(i, i + 5)
for j in range(0, height, 5):
batch_search.set_start_bounds_y(j, j + 5)
batch_results.extend(batch_search.search_batch())

# Need to filter as the fields are undefined otherwise
batch_results = [
result for result in batch_results if result.lh > -1 and result.obs_count >= min_observations
]

# Check that the results are the same.
results_hash_set = {test_search.result_hash(result) for result in results}
batch_results_hash_set = {test_search.result_hash(result) for result in batch_results}

for res_hash in results_hash_set:
self.assertTrue(res_hash in batch_results_hash_set)


if __name__ == "__main__":
unittest.main()
Loading