Skip to content

Commit

Permalink
Merge branch 'main' into rawimage_time
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Sep 20, 2024
2 parents b6ffa34 + 2008221 commit 698c5a7
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 148 deletions.
17 changes: 9 additions & 8 deletions src/kbmod/reprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from kbmod import is_interactive
from kbmod.search import KB_NO_DATA, PSF, ImageStack, LayeredImage, RawImage
from kbmod.work_unit import WorkUnit
from kbmod.tqdm_utils import TQDMUtils
from kbmod.wcs_utils import append_wcs_to_hdu_header
from astropy.io import fits
import os
Expand All @@ -17,6 +16,7 @@

# The number of executors to use in the parallel reprojecting function.
MAX_PROCESSES = 8
_DEFAULT_TQDM_BAR = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]"


def reproject_image(image, original_wcs, common_wcs):
Expand All @@ -34,7 +34,7 @@ def reproject_image(image, original_wcs, common_wcs):
The WCS to reproject all the images into.
Returns
----------
-------
new_image : `numpy.ndarray`
The image data reprojected with a common `astropy.wcs.WCS`.
footprint : `numpy.ndarray`
Expand Down Expand Up @@ -115,7 +115,7 @@ def reproject_work_unit(
displayed or hidden.
Returns
----------
-------
A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case
where `write_output` is set to True.
"""
Expand Down Expand Up @@ -186,8 +186,9 @@ def _reproject_work_unit(
The base filename where output will be written if `write_output` is set to True.
disable_show_progress : `bool`
Whether or not to disable the `tqdm` show_progress bar.
Returns
----------
-------
A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case
where `write_output` is set to True.
"""
Expand All @@ -197,7 +198,7 @@ def _reproject_work_unit(
stack = ImageStack()
for obstime_index, o_i in tqdm(
enumerate(zip(unique_obstimes, unique_obstime_indices)),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
bar_format=_DEFAULT_TQDM_BAR,
desc="Reprojecting",
disable=not show_progress,
):
Expand Down Expand Up @@ -344,7 +345,7 @@ def _reproject_work_unit_in_parallel(
Whether or not to enable the `tqdm` show_progress bar.
Returns
----------
-------
A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case
where `write_output` is set to True.
"""
Expand Down Expand Up @@ -406,7 +407,7 @@ def _reproject_work_unit_in_parallel(
tqdm(
concurrent.futures.as_completed(future_reprojections),
total=len(future_reprojections),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
bar_format=_DEFAULT_TQDM_BAR,
desc="Reprojecting",
disable=not show_progress,
)
Expand Down Expand Up @@ -538,7 +539,7 @@ def reproject_lazy_work_unit(
tqdm(
concurrent.futures.as_completed(future_reprojections),
total=len(future_reprojections),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
bar_format=_DEFAULT_TQDM_BAR,
desc="Reprojecting",
disable=not show_progress,
)
Expand Down
7 changes: 3 additions & 4 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def load_and_filter_results(self, search, config):
logger.info(f"Chunk Min. Likelihood = {results[-1].lh}")

trj_batch = []
psi_batch = []
phi_batch = []
for i, trj in enumerate(results):
# Stop as soon as we hit a result below our limit, because anything after
# that is not guarrenteed to be valid due to potential on-GPU filtering.
Expand All @@ -93,14 +91,15 @@ def load_and_filter_results(self, search, config):

if trj.lh < max_lh:
trj_batch.append(trj)
psi_batch.append(search.get_psi_curves(trj))
phi_batch.append(search.get_phi_curves(trj))
total_count += 1

batch_size = len(trj_batch)
logger.info(f"Extracted batch of {batch_size} results for total of {total_count}")

if batch_size > 0:
psi_batch = search.get_psi_curves(trj_batch)
phi_batch = search.get_phi_curves(trj_batch)

result_batch = Results.from_trajectories(trj_batch, track_filtered=do_tracking)
result_batch.add_psi_phi_data(psi_batch, phi_batch)

Expand Down
8 changes: 4 additions & 4 deletions src/kbmod/search/image_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ void ImageStack::set_single_image(int index, LayeredImage& img, bool force_move)
assert_sizes_equal(img.get_height(), height, "ImageStack image height");

if (force_move) {
images[index] = img;
images[index] = img;
} else {
images[index] = std::move(img);
images[index] = std::move(img);
}
}

Expand Down Expand Up @@ -186,8 +186,8 @@ static void image_stack_bindings(py::module& m) {
.def("get_single_image", &is::get_single_image, py::return_value_policy::reference_internal,
pydocs::DOC_ImageStack_get_single_image)
.def("set_single_image", &is::set_single_image, py::arg("index"), py::arg("img"),
py::arg("force_move")=false, pydocs::DOC_ImageStack_set_single_image)
.def("append_image", &is::append_image, py::arg("img"), py::arg("force_move")=false,
py::arg("force_move") = false, pydocs::DOC_ImageStack_set_single_image)
.def("append_image", &is::append_image, py::arg("img"), py::arg("force_move") = false,
pydocs::DOC_ImageStack_append_image)
.def("get_obstime", &is::get_obstime, pydocs::DOC_ImageStack_get_obstime)
.def("get_zeroed_time", &is::get_zeroed_time, pydocs::DOC_ImageStack_get_zeroed_time)
Expand Down
4 changes: 2 additions & 2 deletions src/kbmod/search/image_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class ImageStack {

// Functions for setting or appending a single LayeredImage. If force_move is true,
// then the code uses move semantics and destroys the input object.
void set_single_image(int index, LayeredImage& img, bool force_move=false);
void append_image(LayeredImage& img, bool force_move=false);
void set_single_image(int index, LayeredImage& img, bool force_move = false);
void append_image(LayeredImage& img, bool force_move = false);

// Functions for getting or using times.
double get_obstime(int index) const;
Expand Down
20 changes: 10 additions & 10 deletions src/kbmod/search/pydocs/stack_search_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ static const auto DOC_StackSearch_get_psi_curves = R"doc(
Parameters
----------
trj : `kb.Trajectory`
The input trajectory.
trj : `kb.Trajectory` or `list` of `kb.Trajectory`
The input trajectory or trajectories.
Returns
-------
result : `list` of `float`
result : `list` of `float` or `list` of `list` of `float`
The psi values at each time step with NO_DATA replaced by 0.0.
)doc";

Expand All @@ -144,12 +144,12 @@ static const auto DOC_StackSearch_get_phi_curves = R"doc(
Parameters
----------
trj : `kb.Trajectory`
The input trajectory.
trj : `kb.Trajectory` or `list` of `kb.Trajectory`
The input trajectory or trajectories.
Returns
-------
result : `list` of `float`
result : `list` of `float` or `list` of `list` of `float`
The phi values at each time step with NO_DATA replaced by 0.0.
)doc";

Expand Down Expand Up @@ -249,8 +249,8 @@ static const auto DOC_StackSearch_evaluate_single_trajectory = R"doc(
Performs the evaluation of a single Trajectory object. Modifies the object
in-place.
Note
----
Notes
-----
Runs on the CPU, but requires CUDA compiler.
Parameters
Expand All @@ -262,8 +262,8 @@ static const auto DOC_StackSearch_evaluate_single_trajectory = R"doc(
static const auto DOC_StackSearch_search_linear_trajectory = R"doc(
Performs the evaluation of a linear trajectory in pixel space.
Note
----
Notes
-----
Runs on the CPU, but requires CUDA compiler.
Parameters
Expand Down
34 changes: 27 additions & 7 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ uint64_t StackSearch::compute_max_results() {
return num_search_pixels * params.results_per_pixel;
}

std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool extract_psi) {
std::vector<float> StackSearch::extract_psi_or_phi_curve(const Trajectory& trj, bool extract_psi) {
prepare_psi_phi();

const unsigned int num_times = stack.img_count();
Expand All @@ -256,17 +256,33 @@ std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool e
return result;
}

std::vector<float> StackSearch::get_psi_curves(Trajectory& trj) {
std::vector<std::vector<float> > StackSearch::get_psi_curves(const std::vector<Trajectory>& trajectories) {
std::vector<std::vector<float> > all_results;
for (const auto& trj : trajectories) {
all_results.push_back(extract_psi_or_phi_curve(trj, true));
}
return all_results;
}

std::vector<float> StackSearch::get_psi_curves(const Trajectory& trj) {
return extract_psi_or_phi_curve(trj, true);
}

std::vector<float> StackSearch::get_phi_curves(Trajectory& trj) {
std::vector<std::vector<float> > StackSearch::get_phi_curves(const std::vector<Trajectory>& trajectories) {
std::vector<std::vector<float> > all_results;
for (const auto& trj : trajectories) {
all_results.push_back(extract_psi_or_phi_curve(trj, false));
}
return all_results;
}

std::vector<float> StackSearch::get_phi_curves(const Trajectory& trj) {
return extract_psi_or_phi_curve(trj, false);
}

std::vector<Trajectory> StackSearch::get_results(uint64_t start, uint64_t count) {
rs_logger->debug("Reading results [" + std::to_string(start) + ", " +
std::to_string(start + count) + ")");
rs_logger->debug("Reading results [" + std::to_string(start) + ", " + std::to_string(start + count) +
")");
return results.get_batch(start, count);
}

Expand Down Expand Up @@ -306,10 +322,14 @@ static void stack_search_bindings(py::module& m) {
.def("get_imagestack", &ks::get_imagestack, py::return_value_policy::reference_internal,
pydocs::DOC_StackSearch_get_imagestack)
// For testings
.def("get_psi_curves", (std::vector<float>(ks::*)(tj&)) & ks::get_psi_curves,
.def("get_psi_curves", (std::vector<float>(ks::*)(const tj&)) & ks::get_psi_curves,
pydocs::DOC_StackSearch_get_psi_curves)
.def("get_phi_curves", (std::vector<float>(ks::*)(tj&)) & ks::get_phi_curves,
.def("get_phi_curves", (std::vector<float>(ks::*)(const tj&)) & ks::get_phi_curves,
pydocs::DOC_StackSearch_get_phi_curves)
.def("get_psi_curves",
(std::vector<std::vector<float> >(ks::*)(const std::vector<tj>&)) & ks::get_psi_curves)
.def("get_phi_curves",
(std::vector<std::vector<float> >(ks::*)(const std::vector<tj>&)) & ks::get_phi_curves)
.def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi)
.def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi)
.def("get_number_total_results", &ks::get_number_total_results,
Expand Down
8 changes: 5 additions & 3 deletions src/kbmod/search/stack_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ class StackSearch {
std::vector<Trajectory> get_results(uint64_t start, uint64_t count);

// Getters for the Psi and Phi data.
std::vector<float> get_psi_curves(Trajectory& t);
std::vector<float> get_phi_curves(Trajectory& t);
std::vector<float> get_psi_curves(const Trajectory& t);
std::vector<float> get_phi_curves(const Trajectory& t);
std::vector<std::vector<float> > get_psi_curves(const std::vector<Trajectory>& trajectories);
std::vector<std::vector<float> > get_phi_curves(const std::vector<Trajectory>& trajectories);

// Helper functions for computing Psi and Phi
void prepare_psi_phi();
Expand All @@ -73,7 +75,7 @@ class StackSearch {
virtual ~StackSearch(){};

protected:
std::vector<float> extract_psi_or_phi_curve(Trajectory& trj, bool extract_psi);
std::vector<float> extract_psi_or_phi_curve(const Trajectory& trj, bool extract_psi);

// Core data and search parameters. Note the StackSearch does not own
// the ImageStack and it must exist for the duration of the object's life.
Expand Down
4 changes: 0 additions & 4 deletions src/kbmod/tqdm_utils.py

This file was deleted.

6 changes: 3 additions & 3 deletions src/kbmod/wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ def calc_ecliptic_angle(wcs, center_pixel=(1000, 2000), step=12):
with the image axes. Used to transform the specified search angles,
with respect to the ecliptic, to search angles within the image.
Note
----
Notes
-----
It is not neccessary to calculate this angle for each image in an
image set if they have all been warped to a common WCS.
"""
Expand Down Expand Up @@ -336,7 +336,7 @@ def extract_wcs_from_hdu_header(header):
The header from which to read the data.
Returns
--------
-------
curr_wcs : `astropy.wcs.WCS`
The WCS or None if it does not exist.
"""
Expand Down
8 changes: 5 additions & 3 deletions src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
wcs_to_dict,
)
from kbmod.reprojection_utils import invert_correct_parallax
from kbmod.tqdm_utils import TQDMUtils


_DEFAULT_WORKUNIT_TQDM_BAR = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]"


logger = Logging.getLogger(__name__)
Expand Down Expand Up @@ -311,7 +313,7 @@ def from_fits(cls, filename, show_progress=None):
# Read in all the image files.
for i in tqdm(
range(num_images),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
bar_format=_DEFAULT_WORKUNIT_TQDM_BAR,
desc="Loading images",
disable=not show_progress,
):
Expand Down Expand Up @@ -340,7 +342,7 @@ def from_fits(cls, filename, show_progress=None):
constituent_images = []
for i in tqdm(
range(n_constituents),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
bar_format=_DEFAULT_WORKUNIT_TQDM_BAR,
desc="Loading WCS",
disable=not show_progress,
):
Expand Down
Loading

0 comments on commit 698c5a7

Please sign in to comment.