From e75f93f7ff9c4d3d3012ceb50f52f15b829d22e1 Mon Sep 17 00:00:00 2001 From: Vlad Date: Sat, 23 Mar 2024 17:02:32 +0000 Subject: [PATCH 01/17] batch search inside stack search --- src/kbmod/search/pydocs/stack_search_docs.h | 28 +++++++++ src/kbmod/search/stack_search.cpp | 68 +++++++++++++++++++- src/kbmod/search/stack_search.h | 6 ++ tests/test_search.py | 70 +++++++++++++++++++++ 4 files changed, 169 insertions(+), 3 deletions(-) diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index 9ddb520b3..f87b53c9c 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -177,6 +177,34 @@ static const auto DOC_StackSearch_get_results = R"doc( ``RunTimeError`` if start < 0 or count <= 0. )doc"; +static const auto DOC_StackSearch_prepare_batch_search = R"doc( + Prepare the search for a batch of trajectories. + + Parameters + ---------- + search_list : `List` + A list of ``Trajectory`` objects to search. + )doc"; + +static const auto DOC_StackSearch_search_batch = R"doc( + Perform a batch search of the trajectories in the list. + + Returns + ------- + results : `List` + A list of ``Trajectory`` search results + )doc"; + +static const auto DOC_StackSearch_finish_search = R"doc( + Clears memory used for the batch search. + + This method should be called after a batch search is completed to ensure that any resources allocated during the search are properly freed. + + Returns + ------- + None + )doc"; + static const auto DOC_StackSearch_set_results = R"doc( Set the cached results. Used for testing. diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index 95061f6a2..fdbf422ef 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -9,7 +9,7 @@ extern "C" void evaluateTrajectory(PsiPhiArrayMeta psi_phi_meta, void* psi_phi_v SearchParameters params, Trajectory* candidate); #endif -StackSearch::StackSearch(ImageStack& imstack) : stack(imstack), results(0) { +StackSearch::StackSearch(ImageStack& imstack) : stack(imstack), results(0), gpu_search_list(0) { debug_info = false; psi_phi_generated = false; @@ -149,6 +149,65 @@ Trajectory StackSearch::search_linear_trajectory(short x, short y, float vx, flo return result; } +void StackSearch::finish_search(){ + psi_phi_array.clear_from_gpu(); + gpu_search_list.move_to_cpu(); +} + +void StackSearch::prepare_batch_search(std::vector& search_list, int min_observations){ + DebugTimer psi_phi_timer = DebugTimer("Creating psi/phi buffers", debug_info); + prepare_psi_phi(); + psi_phi_array.move_to_gpu(); + psi_phi_timer.stop(); + + + int num_to_search = search_list.size(); + if (debug_info) std::cout << "Preparing to search " << num_to_search << " trajectories... \n" << std::flush; + gpu_search_list = TrajectoryList(search_list); + gpu_search_list.move_to_gpu(); + + params.min_observations = min_observations; +} + + +std::vector StackSearch::search_batch(){ + if(!psi_phi_array.gpu_array_allocated()){ + throw std::runtime_error("PsiPhiArray array not allocated on GPU. Did you forget to call prepare_search?"); + } + + DebugTimer core_timer = DebugTimer("Running batch search", debug_info); + // Allocate a vector for the results and move it onto the GPU. 
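+    // The number of results is bounded by the size of the configured
+    // starting-pixel region (the params.x/y start bounds), with up to
+    // RESULTS_PER_PIXEL candidates retained per starting pixel.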
+ int search_width = params.x_start_max - params.x_start_min; + int search_height = params.y_start_max - params.y_start_min; + int num_search_pixels = search_width * search_height; + int max_results = num_search_pixels * RESULTS_PER_PIXEL; + + if (debug_info) { + std::cout << "Searching X=[" << params.x_start_min << ", " << params.x_start_max << "]" + << " Y=[" << params.y_start_min << ", " << params.y_start_max << "]\n"; + std::cout << "Allocating space for " << max_results << " results.\n"; + } + results = TrajectoryList(max_results); + results.move_to_gpu(); + + // Do the actual search on the GPU. + DebugTimer search_timer = DebugTimer("Running search", debug_info); +#ifdef HAVE_CUDA + deviceSearchFilter(psi_phi_array, params, gpu_search_list, results); +#else + throw std::runtime_error("Non-GPU search is not implemented."); +#endif + search_timer.stop(); + + results.move_to_cpu(); + DebugTimer sort_timer = DebugTimer("Sorting results", debug_info); + results.sort_by_likelihood(); + sort_timer.stop(); + core_timer.stop(); + + return results.get_batch(0, max_results); +} + void StackSearch::search(std::vector& search_list, int min_observations) { DebugTimer core_timer = DebugTimer("Running core search", debug_info); @@ -173,7 +232,7 @@ void StackSearch::search(std::vector& search_list, int min_observati // Allocate space for the search list and move that to the GPU. int num_to_search = search_list.size(); if (debug_info) std::cout << "Searching " << num_to_search << " trajectories... \n" << std::flush; - TrajectoryList gpu_search_list = TrajectoryList(search_list); + gpu_search_list = TrajectoryList(search_list); gpu_search_list.move_to_gpu(); // Set the minimum number of observations. @@ -276,7 +335,10 @@ static void stack_search_bindings(py::module& m) { .def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi) .def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi) .def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results) - .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results); + .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results) + .def("search_batch", &ks::search_batch, pydocs::DOC_StackSearch_search_batch) + .def("prepare_batch_search", &ks::prepare_batch_search, pydocs::DOC_StackSearch_prepare_batch_search) + .def("finish_search", &ks::finish_search, pydocs::DOC_StackSearch_finish_search); } #endif /* Py_PYTHON_H */ diff --git a/src/kbmod/search/stack_search.h b/src/kbmod/search/stack_search.h index a1fb3f64e..eb6f0b780 100644 --- a/src/kbmod/search/stack_search.h +++ b/src/kbmod/search/stack_search.h @@ -48,6 +48,9 @@ class StackSearch { // The primary search functions void evaluate_single_trajectory(Trajectory& trj); Trajectory search_linear_trajectory(short x, short y, float vx, float vy); + void prepare_batch_search(std::vector& search_list, int min_observations); + void finish_search(); + std::vector search_batch(); void search(std::vector& search_list, int min_observations); // Gets the vector of result trajectories from the grid search. @@ -80,6 +83,9 @@ class StackSearch { // Results from the grid search. TrajectoryList results; + + // Trajectories that are being searched. 
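+    // Kept as a member (rather than a local in search()) so that
+    // prepare_batch_search(), search_batch(), and finish_search() can share
+    // one GPU-resident copy of the candidate list.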
+    TrajectoryList gpu_search_list;
 };
 
 } /* namespace search */
diff --git a/tests/test_search.py b/tests/test_search.py
index 36cbeb94b..fa331a966 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -932,6 +932,76 @@ def test_coadd_filter_gpu(self):
         self.assertEqual(meanStamps[2].width, 1)
         self.assertEqual(meanStamps[2].height, 1)
 
+    @staticmethod
+    def result_hash(res):
+        return hash((res.x, res.y, res.vx, res.vy, res.lh, res.obs_count))
+
+    def test_search_batch(self):
+        width = 50
+        height = 50
+        results_per_pixel = 8
+        min_observations = 2
+
+        # Simple average PSF
+        psf_data = np.zeros((5, 5), dtype=np.single)
+        psf_data[1:4, 1:4] = 0.1111111
+        p = PSF(psf_data)
+
+        # Create a stack with 10 50x50 images with random noise and times ranging from 0 to 1
+        count = 10
+        imlist = [make_fake_layered_image(width, height, 5.0, 25.0, n / count, p) for n in range(count)]
+        stack = ImageStack(imlist)
+        im_list = stack.get_images()
+        # Create a new list of LayeredImages with the added object.
+        new_im_list = []
+        for im, time in zip(im_list, stack.build_zeroed_times()):
+            add_fake_object(im, 5.0 + (time * 8.0), 35.0 + (time * 0.0), 25000.0)
+            new_im_list.append(im)
+
+        # Save these images in a new ImageStack and create a StackSearch object from them.
+        stack = ImageStack(new_im_list)
+        search = StackSearch(stack)
+
+        # Sample generator
+        gen = KBMODV1Search(
+            10, 5, 15, 10, -0.1, 0.1
+        )  # velocity_steps, min_vel, max_vel, angle_steps, min_ang, max_ang,
+        candidates = [trj for trj in gen]
+
+        # Perform complete in-memory search
+        search.search(candidates, min_observations)
+        total_results = width * height * results_per_pixel
+        # Need to filter as the fields are undefined otherwise
+        results = [
+            result
+            for result in search.get_results(0, total_results)
+            if result.lh > -1 and result.obs_count >= min_observations
+        ]
+
+        # Perform a batch search with the same images.
+        batch_search = StackSearch(stack)
+        batch_search.prepare_batch_search(candidates, min_observations)
+        batch_results = []
+        for i in range(0, width, 5):
+            batch_search.set_start_bounds_x(i, i + 5)
+            for j in range(0, height, 5):
+                batch_search.set_start_bounds_y(j, j + 5)
+                batch_results.extend(batch_search.search_batch())
+
+        # Need to filter as the fields are undefined otherwise
+        batch_results = [
+            result for result in batch_results if result.lh > -1 and result.obs_count >= min_observations
+        ]
+
+        # Check that the results are the same.
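+        # Hash the full trajectory tuple so the comparison below does not
+        # depend on the ordering of results between the two search modes.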
+ results_hash_set = {test_search.result_hash(result) for result in results} + batch_results_hash_set = {test_search.result_hash(result) for result in batch_results} + + for res_hash in results_hash_set: + self.assertTrue(res_hash in batch_results_hash_set) + + batch_search.finish_search() + if __name__ == "__main__": unittest.main() From 666d2e56722e3f9a2ddba2bcaf4959a0a78a8831 Mon Sep 17 00:00:00 2001 From: Vlad Date: Sat, 23 Mar 2024 17:34:26 +0000 Subject: [PATCH 02/17] moved the batch search inside a context manager to remove the need to setup/clear resources --- src/kbmod/batch_search.py | 40 +++++++++++++++++++++++++++++++++++++++ tests/test_search.py | 36 +++++++++++++++++------------------ 2 files changed, 57 insertions(+), 19 deletions(-) create mode 100644 src/kbmod/batch_search.py diff --git a/src/kbmod/batch_search.py b/src/kbmod/batch_search.py new file mode 100644 index 000000000..92813ef08 --- /dev/null +++ b/src/kbmod/batch_search.py @@ -0,0 +1,40 @@ +class BatchSearchManager: + def __init__(self, stack_search, search_list, min_observations): + """ + Initialize the context manager with an instance of StackSearch, + the list of trajectories to search, and the minimum number of observations. + + Parameters: + - stack_search: Instance of the StackSearch class. + - search_list: List of trajectories to search. + - min_observations: Minimum number of observations for the search. + """ + self.stack_search = stack_search + self.search_list = search_list + self.min_observations = min_observations + + def __enter__(self): + """ + This method is called when entering the context managed by the `with` statement. + It prepares the batch search by calling `prepare_batch_search` on the StackSearch instance. + """ + # Initialize or prepare memory for the batch search. + self.stack_search.prepare_batch_search(self.search_list, self.min_observations) + # Return the object that should be used within the `with` block. Here, it's the StackSearch instance. + return self.stack_search + + def __exit__(self, exc_type, exc_value, traceback): + """ + This method is called when exiting the context. + It cleans up resources by calling `finish_search` on the StackSearch instance. + + Parameters: + - exc_type: The exception type if an exception was raised in the `with` block. + - exc_value: The exception value if an exception was raised. + - traceback: The traceback object if an exception was raised. + """ + # Clean up resources or delete initialized memory. + self.stack_search.finish_search() + # Returning False means any exception raised within the `with` block will be propagated. + # To suppress exceptions, return True. + return False diff --git a/tests/test_search.py b/tests/test_search.py index fa331a966..f28705401 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2,6 +2,7 @@ import numpy as np +from kbmod.batch_search import BatchSearchManager from kbmod.configuration import SearchConfiguration from kbmod.fake_data.fake_data_creator import add_fake_object, make_fake_layered_image, FakeDataSet from kbmod.run_search import SearchRunner @@ -978,29 +979,26 @@ def test_search_batch(self): if result.lh > -1 and result.obs_count >= min_observations ] - # Perform a batch search with the same images. 
- batch_search = StackSearch(stack) - batch_search.prepare_batch_search(candidates, min_observations) - batch_results = [] - for i in range(0, width, 5): - batch_search.set_start_bounds_x(i, i + 5) - for j in range(0, height, 5): - batch_search.set_start_bounds_y(j, j + 5) - batch_results.extend(batch_search.search_batch()) + with BatchSearchManager(StackSearch(stack), candidates, min_observations) as batch_search: - # Need to filter as the fields are undefined otherwise - batch_results = [ - result for result in batch_results if result.lh > -1 and result.obs_count >= min_observations - ] + batch_results = [] + for i in range(0, width, 5): + batch_search.set_start_bounds_x(i, i + 5) + for j in range(0, height, 5): + batch_search.set_start_bounds_y(j, j + 5) + batch_results.extend(batch_search.search_batch()) - # Check that the results are the same. - results_hash_set = {test_search.result_hash(result) for result in results} - batch_results_hash_set = {test_search.result_hash(result) for result in batch_results} + # Need to filter as the fields are undefined otherwise + batch_results = [ + result for result in batch_results if result.lh > -1 and result.obs_count >= min_observations + ] - for res_hash in results_hash_set: - self.assertTrue(res_hash in batch_results_hash_set) + # Check that the results are the same. + results_hash_set = {test_search.result_hash(result) for result in results} + batch_results_hash_set = {test_search.result_hash(result) for result in batch_results} - batch_search.finish_search() + for res_hash in results_hash_set: + self.assertTrue(res_hash in batch_results_hash_set) if __name__ == "__main__": From 88d328995a140fd077952f2ace7099b059acd1a5 Mon Sep 17 00:00:00 2001 From: Vlad Date: Sat, 23 Mar 2024 17:41:31 +0000 Subject: [PATCH 03/17] fixed to include 2nd argument in prepare_batch --- src/kbmod/search/pydocs/stack_search_docs.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index f87b53c9c..8b71b7f75 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -184,6 +184,8 @@ static const auto DOC_StackSearch_prepare_batch_search = R"doc( ---------- search_list : `List` A list of ``Trajectory`` objects to search. + min_observations : `int` + The minimum number of observations for a trajectory to be considered. 
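+
+  Examples
+  --------
+  A minimal sketch of the intended call sequence (assuming ``search`` is a
+  ``StackSearch`` and ``candidates`` is a list of ``Trajectory`` objects)::
+
+      search.prepare_batch_search(candidates, 10)
+      results = search.search_batch()
+      search.finish_search()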
)doc"; static const auto DOC_StackSearch_search_batch = R"doc( From 0b25c3215e12a8e7779fb3a998dbf8da58c23037 Mon Sep 17 00:00:00 2001 From: Vlad Date: Sat, 23 Mar 2024 20:36:57 +0000 Subject: [PATCH 04/17] resolving search_batch merge conflicts with main removed accidental git conflict merge conflicts #2 --- .vscode/settings.json | 76 ------------------------------- src/kbmod/search/stack_search.cpp | 59 +++--------------------- 2 files changed, 7 insertions(+), 128 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 19b8cd605..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,76 +0,0 @@ -{ - "files.associations": { - "array": "cpp", - "atomic": "cpp", - "bit": "cpp", - "*.tcc": "cpp", - "bitset": "cpp", - "cctype": "cpp", - "chrono": "cpp", - "clocale": "cpp", - "cmath": "cpp", - "codecvt": "cpp", - "compare": "cpp", - "complex": "cpp", - "concepts": "cpp", - "condition_variable": "cpp", - "cstdarg": "cpp", - "cstddef": "cpp", - "cstdint": "cpp", - "cstdio": "cpp", - "cstdlib": "cpp", - "cstring": "cpp", - "ctime": "cpp", - "cwchar": "cpp", - "cwctype": "cpp", - "deque": "cpp", - "forward_list": "cpp", - "list": "cpp", - "map": "cpp", - "set": "cpp", - "string": "cpp", - "unordered_map": "cpp", - "unordered_set": "cpp", - "vector": "cpp", - "exception": "cpp", - "algorithm": "cpp", - "functional": "cpp", - "iterator": "cpp", - "memory": "cpp", - "memory_resource": "cpp", - "numeric": "cpp", - "optional": "cpp", - "random": "cpp", - "ratio": "cpp", - "regex": "cpp", - "string_view": "cpp", - "system_error": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "utility": "cpp", - "hash_map": "cpp", - "fstream": "cpp", - "future": "cpp", - "initializer_list": "cpp", - "iomanip": "cpp", - "iosfwd": "cpp", - "iostream": "cpp", - "istream": "cpp", - "limits": "cpp", - "mutex": "cpp", - "new": "cpp", - "numbers": "cpp", - "ostream": "cpp", - "semaphore": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "stop_token": "cpp", - "streambuf": "cpp", - "thread": "cpp", - "cinttypes": "cpp", - "typeindex": "cpp", - "typeinfo": "cpp", - "valarray": "cpp", - "variant": "cpp" - } -} \ No newline at end of file diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index d0ff313a6..198170f8f 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -15,12 +15,6 @@ extern "C" void evaluateTrajectory(PsiPhiArrayMeta psi_phi_meta, void* psi_phi_v // I'd imaging... auto rs_logger = logging::getLogger("kbmod.search.run_search"); -// This logger is often used in this module so we might as well declare it -// global, but this would generally be a one-liner like: -// logging::getLogger("kbmod.search.run_search") -> level(msg) -// I'd imaging... 
-auto rs_logger = logging::getLogger("kbmod.search.run_search"); - StackSearch::StackSearch(ImageStack& imstack) : stack(imstack), results(0), gpu_search_list(0) { debug_info = false; psi_phi_generated = false; @@ -167,7 +161,7 @@ void StackSearch::finish_search(){ } void StackSearch::prepare_batch_search(std::vector& search_list, int min_observations){ - DebugTimer psi_phi_timer = DebugTimer("Creating psi/phi buffers", debug_info); + DebugTimer psi_phi_timer = DebugTimer("Creating psi/phi buffers", rs_logger); prepare_psi_phi(); psi_phi_array.move_to_gpu(); psi_phi_timer.stop(); @@ -175,51 +169,12 @@ void StackSearch::prepare_batch_search(std::vector& search_list, int int num_to_search = search_list.size(); if (debug_info) std::cout << "Preparing to search " << num_to_search << " trajectories... \n" << std::flush; - gpu_search_list = TrajectoryList(search_list); - gpu_search_list.move_to_gpu(); + // gpu_search_list = TrajectoryList(search_list); + // gpu_search_list.move_to_gpu(); params.min_observations = min_observations; } - -std::vector StackSearch::search_batch(){ - if(!psi_phi_array.gpu_array_allocated()){ - throw std::runtime_error("PsiPhiArray array not allocated on GPU. Did you forget to call prepare_search?"); - } - - DebugTimer core_timer = DebugTimer("Running batch search", debug_info); - // Allocate a vector for the results and move it onto the GPU. - int search_width = params.x_start_max - params.x_start_min; - int search_height = params.y_start_max - params.y_start_min; - int num_search_pixels = search_width * search_height; - int max_results = num_search_pixels * RESULTS_PER_PIXEL; - - if (debug_info) { - std::cout << "Searching X=[" << params.x_start_min << ", " << params.x_start_max << "]" - << " Y=[" << params.y_start_min << ", " << params.y_start_max << "]\n"; - std::cout << "Allocating space for " << max_results << " results.\n"; - } - results = TrajectoryList(max_results); - results.move_to_gpu(); - - // Do the actual search on the GPU. - DebugTimer search_timer = DebugTimer("Running search", debug_info); -#ifdef HAVE_CUDA - deviceSearchFilter(psi_phi_array, params, gpu_search_list, results); -#else - throw std::runtime_error("Non-GPU search is not implemented."); -#endif - search_timer.stop(); - - results.move_to_cpu(); - DebugTimer sort_timer = DebugTimer("Sorting results", debug_info); - results.sort_by_likelihood(); - sort_timer.stop(); - core_timer.stop(); - - return results.get_batch(0, max_results); -} - void StackSearch::search(std::vector& search_list, int min_observations) { DebugTimer core_timer = DebugTimer("core search", rs_logger); @@ -283,7 +238,7 @@ std::vector StackSearch::search_batch(){ throw std::runtime_error("PsiPhiArray array not allocated on GPU. Did you forget to call prepare_search?"); } - DebugTimer core_timer = DebugTimer("Running batch search", debug_info); + DebugTimer core_timer = DebugTimer("Running batch search", rs_logger); // Allocate a vector for the results and move it onto the GPU. int search_width = params.x_start_max - params.x_start_min; int search_height = params.y_start_max - params.y_start_min; @@ -295,11 +250,11 @@ std::vector StackSearch::search_batch(){ << " Y=[" << params.y_start_min << ", " << params.y_start_max << "]\n"; std::cout << "Allocating space for " << max_results << " results.\n"; } - results = TrajectoryList(max_results); + results.resize(max_results); results.move_to_gpu(); // Do the actual search on the GPU. 
- DebugTimer search_timer = DebugTimer("Running search", debug_info); + DebugTimer search_timer = DebugTimer("Running search", rs_logger); #ifdef HAVE_CUDA deviceSearchFilter(psi_phi_array, params, gpu_search_list, results); #else @@ -308,7 +263,7 @@ std::vector StackSearch::search_batch(){ search_timer.stop(); results.move_to_cpu(); - DebugTimer sort_timer = DebugTimer("Sorting results", debug_info); + DebugTimer sort_timer = DebugTimer("Sorting results", rs_logger); results.sort_by_likelihood(); sort_timer.stop(); core_timer.stop(); From cbcaa9cd4e2a820797ff5371bc3bc555053aa2dc Mon Sep 17 00:00:00 2001 From: Vlad Date: Sat, 23 Mar 2024 21:10:15 +0000 Subject: [PATCH 05/17] added move assignment operators for both TrajectoryList and GPUArray so as to be able to init gpu_search_list outside of the main search func --- src/kbmod/search/gpu_array.h | 9 +++++++++ src/kbmod/search/stack_search.cpp | 4 ++-- src/kbmod/search/trajectory_list.cpp | 12 ++++++++++++ src/kbmod/search/trajectory_list.h | 1 + 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/kbmod/search/gpu_array.h b/src/kbmod/search/gpu_array.h index 3b57590c2..74e1cfe97 100644 --- a/src/kbmod/search/gpu_array.h +++ b/src/kbmod/search/gpu_array.h @@ -50,6 +50,15 @@ class GPUArray { GPUArray(const GPUArray&) = delete; GPUArray& operator=(GPUArray&) = delete; GPUArray& operator=(const GPUArray&) = delete; + GPUArray& operator=(GPUArray&& other) noexcept { + if (this != &other) { + size = other.size; + memory_size = other.memory_size; + gpu_ptr = other.gpu_ptr; + other.gpu_ptr = nullptr; + } + return *this; + } virtual ~GPUArray() { if (gpu_ptr != nullptr) free_gpu_memory(); diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index 198170f8f..dc6cb19eb 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -169,8 +169,8 @@ void StackSearch::prepare_batch_search(std::vector& search_list, int int num_to_search = search_list.size(); if (debug_info) std::cout << "Preparing to search " << num_to_search << " trajectories... \n" << std::flush; - // gpu_search_list = TrajectoryList(search_list); - // gpu_search_list.move_to_gpu(); + gpu_search_list = TrajectoryList(search_list); + gpu_search_list.move_to_gpu(); params.min_observations = min_observations; } diff --git a/src/kbmod/search/trajectory_list.cpp b/src/kbmod/search/trajectory_list.cpp index 76a9557a1..4608f6986 100644 --- a/src/kbmod/search/trajectory_list.cpp +++ b/src/kbmod/search/trajectory_list.cpp @@ -22,6 +22,18 @@ TrajectoryList::TrajectoryList(int max_list_size) { gpu_array.resize(max_size); } +// Move assignment operator. +TrajectoryList& TrajectoryList::operator=(TrajectoryList&& other) noexcept { + if (this != &other) { + max_size = other.max_size; + data_on_gpu = other.data_on_gpu; + cpu_list = std::move(other.cpu_list); + gpu_array = std::move(other.gpu_array); + other.data_on_gpu = false; + } + return *this; +} + TrajectoryList::TrajectoryList(const std::vector &prev_list) { max_size = prev_list.size(); cpu_list = prev_list; // Do a full copy. 
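The move assignment above hands ownership of the CPU list and the GPU buffer to the destination and clears the source's GPU state, so the allocation is freed exactly once. A minimal sketch of the pattern this enables inside StackSearch (illustrative; `search_list` is a `std::vector<Trajectory>`):

    // Re-point the member list at a fresh candidate set; the temporary
    // surrenders its GPU pointer before it is destroyed, so no double
    // free can occur.
    gpu_search_list = TrajectoryList(search_list);
    gpu_search_list.move_to_gpu();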
diff --git a/src/kbmod/search/trajectory_list.h b/src/kbmod/search/trajectory_list.h index 5cdffd4a8..dbab3e28e 100644 --- a/src/kbmod/search/trajectory_list.h +++ b/src/kbmod/search/trajectory_list.h @@ -29,6 +29,7 @@ class TrajectoryList { TrajectoryList(const TrajectoryList&) = delete; TrajectoryList& operator=(TrajectoryList&) = delete; TrajectoryList& operator=(const TrajectoryList&) = delete; + TrajectoryList& operator=(TrajectoryList&&) noexcept; // --- Getter functions ---------------- inline int get_size() const { return max_size; } From cc5d52db3c374024bb1d3e0346974f2184bccfc4 Mon Sep 17 00:00:00 2001 From: Vlad Date: Mon, 25 Mar 2024 21:29:25 +0000 Subject: [PATCH 06/17] cleaned up comments --- src/kbmod/batch_search.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/kbmod/batch_search.py b/src/kbmod/batch_search.py index 92813ef08..f1fac4c8b 100644 --- a/src/kbmod/batch_search.py +++ b/src/kbmod/batch_search.py @@ -18,9 +18,8 @@ def __enter__(self): This method is called when entering the context managed by the `with` statement. It prepares the batch search by calling `prepare_batch_search` on the StackSearch instance. """ - # Initialize or prepare memory for the batch search. + # Prepare memory for the batch search. self.stack_search.prepare_batch_search(self.search_list, self.min_observations) - # Return the object that should be used within the `with` block. Here, it's the StackSearch instance. return self.stack_search def __exit__(self, exc_type, exc_value, traceback): @@ -33,8 +32,6 @@ def __exit__(self, exc_type, exc_value, traceback): - exc_value: The exception value if an exception was raised. - traceback: The traceback object if an exception was raised. """ - # Clean up resources or delete initialized memory. + # Clean up self.stack_search.finish_search() - # Returning False means any exception raised within the `with` block will be propagated. - # To suppress exceptions, return True. return False From 2106801e1d03f5bdc597e9e48565a784cc7f0706 Mon Sep 17 00:00:00 2001 From: Vlad Date: Wed, 3 Apr 2024 23:16:54 +0100 Subject: [PATCH 07/17] search_batch functionality extracted from search_all and search_single_batch + removed move assignment operator in favour of set_trajectories + updated search to search_all across the python code --- README.md | 2 +- notebooks/Kbmod_Reference.ipynb | 54 ++++++------ src/kbmod/batch_search.py | 30 +++---- src/kbmod/run_search.py | 2 +- src/kbmod/search/gpu_array.h | 9 -- src/kbmod/search/pydocs/stack_search_docs.h | 5 +- src/kbmod/search/stack_search.cpp | 96 ++++++--------------- src/kbmod/search/stack_search.h | 8 +- src/kbmod/search/trajectory_list.cpp | 12 --- src/kbmod/search/trajectory_list.h | 1 - tests/test_readme_example.py | 2 +- tests/test_search.py | 20 ++--- tests/test_search_encode.py | 2 +- tests/test_search_filter.py | 2 +- 14 files changed, 89 insertions(+), 156 deletions(-) diff --git a/README.md b/README.md index 1c74ca704..560c2a22a 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ candidates = [trj for trj in gen] # Do the actual search. 
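+# search_all prepares the psi/phi arrays, evaluates every candidate trajectory
+# from every allowed starting pixel on the GPU, and frees GPU memory when done.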
search = kb.StackSearch(stack) -search.search( +search.search_all( strategy, 7, # The minimum number of observations ) diff --git a/notebooks/Kbmod_Reference.ipynb b/notebooks/Kbmod_Reference.ipynb index be6a67c9b..2a34d6226 100644 --- a/notebooks/Kbmod_Reference.ipynb +++ b/notebooks/Kbmod_Reference.ipynb @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -96,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -160,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -181,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -200,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -219,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -236,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -266,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -287,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -305,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -324,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -353,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -369,7 +369,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -394,7 +394,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -406,7 +406,7 @@ "candidates = [trj for trj in gen]\n", "print(f\"Created {len(candidates)} candidate trajectories per pixel.\")\n", "\n", - "search.search(candidates, 2)" + "search.search_all(candidates, 2)" ] }, { @@ -419,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -436,7 +436,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -471,7 +471,7 @@ }, { "cell_type": "code", - "execution_count": null, + 
"execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -486,7 +486,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -504,9 +504,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Jeremy's KBMOD", + "display_name": "kbmod_env", "language": "python", - "name": "kbmod_jk" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -518,7 +518,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.1" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/src/kbmod/batch_search.py b/src/kbmod/batch_search.py index f1fac4c8b..f34c33c74 100644 --- a/src/kbmod/batch_search.py +++ b/src/kbmod/batch_search.py @@ -1,36 +1,30 @@ class BatchSearchManager: def __init__(self, stack_search, search_list, min_observations): - """ - Initialize the context manager with an instance of StackSearch, - the list of trajectories to search, and the minimum number of observations. + """Manages a batch search over a list of Trajectory instances using a StackSearch instance. - Parameters: - - stack_search: Instance of the StackSearch class. - - search_list: List of trajectories to search. - - min_observations: Minimum number of observations for the search. + Parameters + ---------- + stack_search: `StackSearch` + StackSearch instance to use for the batch search. + search_list: `list[Trajectory]` + List of Trajectory instances to search along the stack. + min_observations: `int` + Minimum number of observations required to consider a candidate. """ self.stack_search = stack_search self.search_list = search_list self.min_observations = min_observations def __enter__(self): - """ - This method is called when entering the context managed by the `with` statement. - It prepares the batch search by calling `prepare_batch_search` on the StackSearch instance. - """ # Prepare memory for the batch search. - self.stack_search.prepare_batch_search(self.search_list, self.min_observations) + self.stack_search.prepare_search(self.search_list, self.min_observations) return self.stack_search - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, *_): """ This method is called when exiting the context. It cleans up resources by calling `finish_search` on the StackSearch instance. - - Parameters: - - exc_type: The exception type if an exception was raised in the `with` block. - - exc_value: The exception value if an exception was raised. - - traceback: The traceback object if an exception was raised. + We return False to indicate that we do not want to suppress any exceptions that may have been raised. """ # Clean up self.stack_search.finish_search() diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 4812f258c..412ec267b 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -197,7 +197,7 @@ def do_gpu_search(self, config, stack, trj_generator): # Do the actual search. candidates = [trj for trj in trj_generator] - search.search(candidates, int(config["num_obs"])) + search.search_all(candidates, int(config["num_obs"])) search_timer.stop() # Load the results. 
diff --git a/src/kbmod/search/gpu_array.h b/src/kbmod/search/gpu_array.h index 74e1cfe97..3b57590c2 100644 --- a/src/kbmod/search/gpu_array.h +++ b/src/kbmod/search/gpu_array.h @@ -50,15 +50,6 @@ class GPUArray { GPUArray(const GPUArray&) = delete; GPUArray& operator=(GPUArray&) = delete; GPUArray& operator=(const GPUArray&) = delete; - GPUArray& operator=(GPUArray&& other) noexcept { - if (this != &other) { - size = other.size; - memory_size = other.memory_size; - gpu_ptr = other.gpu_ptr; - other.gpu_ptr = nullptr; - } - return *this; - } virtual ~GPUArray() { if (gpu_ptr != nullptr) free_gpu_memory(); diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index 8b71b7f75..4ce0edb62 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -188,8 +188,9 @@ static const auto DOC_StackSearch_prepare_batch_search = R"doc( The minimum number of observations for a trajectory to be considered. )doc"; -static const auto DOC_StackSearch_search_batch = R"doc( - Perform a batch search of the trajectories in the list. +static const auto DOC_StackSearch_search_single_batch = R"doc( + Perform a search on the given trajectories for the current batch. + Batch is defined by the parameters set `set_start_bounds_x` & `set_start_bounds_y`. Returns ------- diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index dc6cb19eb..d3137a650 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -160,7 +160,7 @@ void StackSearch::finish_search(){ gpu_search_list.move_to_cpu(); } -void StackSearch::prepare_batch_search(std::vector& search_list, int min_observations){ +void StackSearch::prepare_search(std::vector& search_list, int min_observations){ DebugTimer psi_phi_timer = DebugTimer("Creating psi/phi buffers", rs_logger); prepare_psi_phi(); psi_phi_array.move_to_gpu(); @@ -169,25 +169,26 @@ void StackSearch::prepare_batch_search(std::vector& search_list, int int num_to_search = search_list.size(); if (debug_info) std::cout << "Preparing to search " << num_to_search << " trajectories... \n" << std::flush; - gpu_search_list = TrajectoryList(search_list); + gpu_search_list.set_trajectories(search_list); gpu_search_list.move_to_gpu(); params.min_observations = min_observations; } -void StackSearch::search(std::vector& search_list, int min_observations) { - DebugTimer core_timer = DebugTimer("core search", rs_logger); +void StackSearch::search_all(std::vector& search_list, int min_observations) { + prepare_search(search_list, min_observations); + search_batch(); + finish_search(); +} - DebugTimer psi_phi_timer = DebugTimer("creating psi/phi buffers", rs_logger); - prepare_psi_phi(); - psi_phi_array.move_to_gpu(); - psi_phi_timer.stop(); +void StackSearch::search_batch(){ + if(!psi_phi_array.gpu_array_allocated()){ + throw std::runtime_error("PsiPhiArray array not allocated on GPU. Did you forget to call prepare_search?"); + } + + DebugTimer core_timer = DebugTimer("Running batch search", rs_logger); + int max_results = extract_max_results(); - // Allocate a vector for the results and move it onto the GPU. 
- int search_width = params.x_start_max - params.x_start_min; - int search_height = params.y_start_max - params.y_start_min; - int num_search_pixels = search_width * search_height; - int max_results = num_search_pixels * RESULTS_PER_PIXEL; // staple C++ std::stringstream logmsg; logmsg << "Searching X=[" << params.x_start_min << ", " << params.x_start_max << "] " @@ -195,24 +196,12 @@ void StackSearch::search(std::vector& search_list, int min_observati << "Allocating space for " << max_results << " results."; rs_logger->info(logmsg.str()); + results.resize(max_results); results.move_to_gpu(); - // Allocate space for the search list and move that to the GPU. - int num_to_search = search_list.size(); - - logmsg.str(""); - logmsg << search_list.size() << " trajectories..."; - rs_logger->info(logmsg.str()); - - TrajectoryList gpu_search_list(search_list); - gpu_search_list.move_to_gpu(); - - // Set the minimum number of observations. - params.min_observations = min_observations; - - // Do the actual search on the GPU. - DebugTimer search_timer = DebugTimer("search execution", rs_logger); + // Do the actual search on the GPU. + DebugTimer search_timer = DebugTimer("Running search", rs_logger); #ifdef HAVE_CUDA deviceSearchFilter(psi_phi_array, params, gpu_search_list, results); #else @@ -220,58 +209,27 @@ void StackSearch::search(std::vector& search_list, int min_observati #endif search_timer.stop(); - // Move data back to CPU to unallocate GPU space (this will happen automatically - // for gpu_search_list when the object goes out of scope, but we do it explicitly here). - psi_phi_array.clear_from_gpu(); results.move_to_cpu(); - gpu_search_list.move_to_cpu(); - DebugTimer sort_timer = DebugTimer("Sorting results", rs_logger); results.sort_by_likelihood(); sort_timer.stop(); core_timer.stop(); } +std::vector StackSearch::search_single_batch(){ + int max_results = extract_max_results(); + search_batch(); + return results.get_batch(0, max_results); +} -std::vector StackSearch::search_batch(){ - if(!psi_phi_array.gpu_array_allocated()){ - throw std::runtime_error("PsiPhiArray array not allocated on GPU. Did you forget to call prepare_search?"); - } - DebugTimer core_timer = DebugTimer("Running batch search", rs_logger); - // Allocate a vector for the results and move it onto the GPU. +int StackSearch::extract_max_results(){ int search_width = params.x_start_max - params.x_start_min; int search_height = params.y_start_max - params.y_start_min; int num_search_pixels = search_width * search_height; - int max_results = num_search_pixels * RESULTS_PER_PIXEL; - - if (debug_info) { - std::cout << "Searching X=[" << params.x_start_min << ", " << params.x_start_max << "]" - << " Y=[" << params.y_start_min << ", " << params.y_start_max << "]\n"; - std::cout << "Allocating space for " << max_results << " results.\n"; - } - results.resize(max_results); - results.move_to_gpu(); - - // Do the actual search on the GPU. 
- DebugTimer search_timer = DebugTimer("Running search", rs_logger); -#ifdef HAVE_CUDA - deviceSearchFilter(psi_phi_array, params, gpu_search_list, results); -#else - throw std::runtime_error("Non-GPU search is not implemented."); -#endif - search_timer.stop(); - - results.move_to_cpu(); - DebugTimer sort_timer = DebugTimer("Sorting results", rs_logger); - results.sort_by_likelihood(); - sort_timer.stop(); - core_timer.stop(); - - return results.get_batch(0, max_results); + return num_search_pixels * RESULTS_PER_PIXEL; } - std::vector StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool extract_psi) { prepare_psi_phi(); @@ -321,7 +279,7 @@ static void stack_search_bindings(py::module& m) { py::class_(m, "StackSearch", pydocs::DOC_StackSearch) .def(py::init()) - .def("search", &ks::search, pydocs::DOC_StackSearch_search) + .def("search_all", &ks::search_all, pydocs::DOC_StackSearch_search) .def("evaluate_single_trajectory", &ks::evaluate_single_trajectory, pydocs::DOC_StackSearch_evaluate_single_trajectory) .def("search_linear_trajectory", &ks::search_linear_trajectory, @@ -349,8 +307,8 @@ static void stack_search_bindings(py::module& m) { .def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi) .def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results) .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results) - .def("search_batch", &ks::search_batch, pydocs::DOC_StackSearch_search_batch) - .def("prepare_batch_search", &ks::prepare_batch_search, pydocs::DOC_StackSearch_prepare_batch_search) + .def("search_single_batch", &ks::search_single_batch, pydocs::DOC_StackSearch_search_single_batch) + .def("prepare_search", &ks::prepare_search, pydocs::DOC_StackSearch_prepare_batch_search) .def("finish_search", &ks::finish_search, pydocs::DOC_StackSearch_finish_search); } #endif /* Py_PYTHON_H */ diff --git a/src/kbmod/search/stack_search.h b/src/kbmod/search/stack_search.h index 20b77c263..befc038b9 100644 --- a/src/kbmod/search/stack_search.h +++ b/src/kbmod/search/stack_search.h @@ -29,6 +29,7 @@ using Point = indexing::Point; using Image = search::Image; class StackSearch { + int extract_max_results(); public: StackSearch(ImageStack& imstack); @@ -50,10 +51,11 @@ class StackSearch { // The primary search functions void evaluate_single_trajectory(Trajectory& trj); Trajectory search_linear_trajectory(short x, short y, float vx, float vy); - void prepare_batch_search(std::vector& search_list, int min_observations); + void prepare_search(std::vector& search_list, int min_observations); + std::vector search_single_batch(); + void search_batch(); + void search_all(std::vector& search_list, int min_observations); void finish_search(); - std::vector search_batch(); - void search(std::vector& search_list, int min_observations); // Gets the vector of result trajectories from the grid search. std::vector get_results(int start, int end); diff --git a/src/kbmod/search/trajectory_list.cpp b/src/kbmod/search/trajectory_list.cpp index 4608f6986..76a9557a1 100644 --- a/src/kbmod/search/trajectory_list.cpp +++ b/src/kbmod/search/trajectory_list.cpp @@ -22,18 +22,6 @@ TrajectoryList::TrajectoryList(int max_list_size) { gpu_array.resize(max_size); } -// Move assignment operator. 
-TrajectoryList& TrajectoryList::operator=(TrajectoryList&& other) noexcept { - if (this != &other) { - max_size = other.max_size; - data_on_gpu = other.data_on_gpu; - cpu_list = std::move(other.cpu_list); - gpu_array = std::move(other.gpu_array); - other.data_on_gpu = false; - } - return *this; -} - TrajectoryList::TrajectoryList(const std::vector &prev_list) { max_size = prev_list.size(); cpu_list = prev_list; // Do a full copy. diff --git a/src/kbmod/search/trajectory_list.h b/src/kbmod/search/trajectory_list.h index dbab3e28e..5cdffd4a8 100644 --- a/src/kbmod/search/trajectory_list.h +++ b/src/kbmod/search/trajectory_list.h @@ -29,7 +29,6 @@ class TrajectoryList { TrajectoryList(const TrajectoryList&) = delete; TrajectoryList& operator=(TrajectoryList&) = delete; TrajectoryList& operator=(const TrajectoryList&) = delete; - TrajectoryList& operator=(TrajectoryList&&) noexcept; // --- Getter functions ---------------- inline int get_size() const { return max_size; } diff --git a/tests/test_readme_example.py b/tests/test_readme_example.py index 0e2a704ad..62f19d24c 100644 --- a/tests/test_readme_example.py +++ b/tests/test_readme_example.py @@ -53,7 +53,7 @@ def test_make_and_copy(self): # Do the actual search. search = kb.StackSearch(stack) - search.search( + search.search_all( candidates, 7, # The minimum number of observations ) diff --git a/tests/test_search.py b/tests/test_search.py index f28705401..ebab5c6ac 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -211,7 +211,7 @@ def test_search_linear_trajectory(self): @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") def test_results(self): candidates = [trj for trj in self.trj_gen] - self.search.search(candidates, int(self.img_count / 2)) + self.search.search_all(candidates, int(self.img_count / 2)) results = self.search.get_results(0, 10) best = results[0] @@ -227,7 +227,7 @@ def test_results_extended_bounds(self): self.search.set_start_bounds_y(-10, self.dim_y + 10) candidates = [trj for trj in self.trj_gen] - self.search.search(candidates, int(self.img_count / 2)) + self.search.search_all(candidates, int(self.img_count / 2)) results = self.search.get_results(0, 10) best = results[0] @@ -243,7 +243,7 @@ def test_results_reduced_bounds(self): self.search.set_start_bounds_y(5, self.dim_y - 5) candidates = [trj for trj in self.trj_gen] - self.search.search(candidates, int(self.img_count / 2)) + self.search.search_all(candidates, int(self.img_count / 2)) results = self.search.get_results(0, 10) best = results[0] @@ -291,7 +291,7 @@ def test_results_off_chip(self): search.set_start_bounds_x(-10, self.dim_x + 10) search.set_start_bounds_y(-10, self.dim_y + 10) candidates = [trj for trj in self.trj_gen] - search.search(candidates, int(self.img_count / 2)) + search.search_all(candidates, int(self.img_count / 2)) # Check the results. results = search.get_results(0, 10) @@ -955,12 +955,12 @@ def test_search_batch(self): im_list = stack.get_images() # Create a new list of LayeredImages with the added object. new_im_list = [] - for im, time in zip(im_list, stack.build_zeroed_times()): + + for i in range(count): + im = stack.get_single_image(i) + time = stack.get_zeroed_time(i) add_fake_object(im, 5.0 + (time * 8.0), 35.0 + (time * 0.0), 25000.0) - new_im_list.append(im) - # Save these images in a new ImageStack and create a StackSearch object from them. 
-        stack = ImageStack(new_im_list)
         search = StackSearch(stack)
 
         # Sample generator
@@ -970,7 +970,7 @@ def test_search_batch(self):
         candidates = [trj for trj in gen]
 
         # Perform complete in-memory search
-        search.search(candidates, min_observations)
+        search.search_all(candidates, min_observations)
         total_results = width * height * results_per_pixel
         # Need to filter as the fields are undefined otherwise
         results = [
diff --git a/tests/test_search_encode.py b/tests/test_search_encode.py
index fc1dbd0e8..5de0aaaac 100644
--- a/tests/test_search_encode.py
+++ b/tests/test_search_encode.py
@@ -72,7 +72,7 @@ def test_different_encodings(self):
             search = StackSearch(self.stack)
             search.enable_gpu_encoding(encoding_bytes)
             candidates = [trj for trj in self.trj_gen]
-            search.search(candidates, int(self.img_count / 2))
+            search.search_all(candidates, int(self.img_count / 2))
 
             results = search.get_results(0, 10)
             best = results[0]
diff --git a/tests/test_search_filter.py b/tests/test_search_filter.py
index dbbbe99b6..ccbbb321c 100644
--- a/tests/test_search_filter.py
+++ b/tests/test_search_filter.py
@@ -66,7 +66,7 @@ def setUp(self):
             self.max_angle,
         )
         candidates = [trj for trj in trj_gen]
-        self.search.search(candidates, int(self.img_count / 2))
+        self.search.search_all(candidates, int(self.img_count / 2))
 
     @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)")
     def test_results(self):

From 96e93234437db475b3b37b44957c7d257491b05b Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Thu, 4 Apr 2024 13:14:12 -0400
Subject: [PATCH 08/17] Update README.md

---
 README.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/README.md b/README.md
index 1c74ca704..3e2f7406d 100644
--- a/README.md
+++ b/README.md
@@ -158,3 +158,9 @@ print(results)
 ## License
 
 The software is open source and available under the BSD license.
+
+## Acknowledgements
+
+This project is supported by Schmidt Sciences.
+
+The team acknowledges support from the DIRAC Institute in the Department of Astronomy at the University of Washington. The DIRAC Institute is supported through generous gifts from the Charles and Lisa Simonyi Fund for Arts and Sciences, and the Washington Research Foundation.
From 55a0449d297002f1c8591708bd8fd6fe9715bdd7 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 5 Apr 2024 09:21:32 -0400 Subject: [PATCH 09/17] Add helper functions for working with result tables --- notebooks/kbmod_results_and_filtering.ipynb | 67 ++++++++++++++++++++- src/kbmod/result_list.py | 17 ++++++ tests/test_result_list.py | 24 ++++++++ 3 files changed, 106 insertions(+), 2 deletions(-) diff --git a/notebooks/kbmod_results_and_filtering.ipynb b/notebooks/kbmod_results_and_filtering.ipynb index af311af79..0de5b460c 100644 --- a/notebooks/kbmod_results_and_filtering.ipynb +++ b/notebooks/kbmod_results_and_filtering.ipynb @@ -292,7 +292,70 @@ "metadata": {}, "outputs": [], "source": [ - "results.to_table()" + "tbl = results.to_table()\n", + "tbl" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Depending on what columns are set in the result table, we can visualize results directly. For example we can look at the co-added stamps for the first few results by accessing them directly from the `stamp` column." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(1, 3)\n", + "\n", + "for i in range(3):\n", + " stamp = tbl[\"stamp\"][i].reshape([21, 21])\n", + " axs[i].imshow(stamp, cmap=\"gray\")\n", + " axs[i].set_title(f\"Codded stamp {i}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can use the AstroPy table to do filtering and then propagate the results back to the result list via the `index` column. For example if we want to filter on trajectories starting in a box of x=[100, 200] and y=[200,300] we would use a mask on the astropy table:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask = (\n", + " (tbl[\"trajectory_x\"] >= 100)\n", + " & (tbl[\"trajectory_x\"] <= 200)\n", + " & (tbl[\"trajectory_y\"] >= 200)\n", + " & (tbl[\"trajectory_x\"] <= 300)\n", + ")\n", + "tbl = tbl[mask]\n", + "tbl" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can then propagate the results back to the result list by using the `sync_table_indices()` function. This function will update *both* the `ResultList`'s entries and the table's index column so they are consistent. This is a short term work around until we rewrite `ResultList` as a full table itself." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results.sync_table_indices(tbl)\n", + "for res in results.results:\n", + " print(res.trajectory)" ] }, { @@ -319,7 +382,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.1" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/src/kbmod/result_list.py b/src/kbmod/result_list.py index 58beeed77..5b3e158c5 100644 --- a/src/kbmod/result_list.py +++ b/src/kbmod/result_list.py @@ -890,6 +890,20 @@ def revert_filter(self, label=None): return self + def sync_table_indices(self, table): + """Syncs the entries in the current list with those in a table version + of the results by filtering on the 'index' column. Rows that do not + appear in the table are removed from the list. The indices in the table + are updated to match the new ordering in the result list. + + Parameters + ---------- + table : `astropy.table.Table` + A table with the data as generated by to_table(). 
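+
+        Examples
+        --------
+        A short sketch of the round trip (assuming ``results`` is a
+        `ResultList`)::
+
+            tbl = results.to_table()
+            tbl = tbl[tbl["trajectory_x"] < 200]  # any astropy-style row filter
+            results.sync_table_indices(tbl)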
+ """ + self.filter_results(table["index"], "Table filtered") + table["index"] = range(len(table)) + def to_table(self, filtered_label=None, append_times=False): """Extract the results into an astropy table. @@ -943,6 +957,9 @@ def to_table(self, filtered_label=None, append_times=False): if append_times: table_dict["all_times"].append(self._all_times) + # Append the index information + table_dict["index"] = [i for i in range(len(list_ref))] + return Table(table_dict) def write_table(self, filename, overwrite=True, keep_all_stamps=False): diff --git a/tests/test_result_list.py b/tests/test_result_list.py index f2ca052e3..b44cbcd2a 100644 --- a/tests/test_result_list.py +++ b/tests/test_result_list.py @@ -497,6 +497,7 @@ def test_to_from_table(self): self.assertEqual(len(table["phi_curve"][i]), self.num_times) self.assertEqual(len(table["pred_ra"][i]), self.num_times) self.assertEqual(len(table["pred_dec"][i]), self.num_times) + self.assertEqual(table["index"][i], i) for j in range(self.num_times): self.assertEqual(table["all_stamps"][i][j].shape, (10, 10)) @@ -531,6 +532,29 @@ def test_to_from_table(self): with self.assertRaises(KeyError): rs.to_table(filtered_label="test2") + def test_sync_table_indices(self): + """Check that we correctly sync the table data with an existing ResultList""" + rs = ResultList(self.times, track_filtered=True) + for i in range(10): + trj = make_trajectory(x=i, y=2 * i, vx=100.0 - i, vy=-i, obs_count=self.num_times - i) + row = ResultRow(trj, self.num_times) + rs.append_result(row) + table = rs.to_table() + self.assertEqual(len(table), 10) + + # Filter the table to specific indices. + inds_to_keep = [0, 1, 3, 7, 9] + table = table[inds_to_keep] + self.assertEqual(len(table), len(inds_to_keep)) + + # Sync with the ResultList and confirm both are updated. + rs.sync_table_indices(table) + self.assertEqual(len(rs), len(inds_to_keep)) + for i, row in enumerate(rs.results): + self.assertEqual(row.trajectory.x, inds_to_keep[i]) + self.assertEqual(table["trajectory_x"][i], inds_to_keep[i]) + self.assertEqual(table["index"][i], i) + def test_to_from_table_file(self): rs = ResultList(self.times, track_filtered=False) for i in range(10): From 3612d959a2762de43f808683aa696c07191905c3 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Fri, 5 Apr 2024 11:56:13 -0700 Subject: [PATCH 10/17] Fetch VDR Data from Butler For Region Search (#531) * Collect VDR data from butler * Prevent importing LSST for unit tests * Lint fix * Change table representation * lint fixes * Update doc strings * Configure max_workers * Simplify LSST mocking * [deploy_alpha] docs formatting * Remove unused import and fix indentation * lint fix --- src/kbmod/region_search.py | 344 ++++++++++++++++++++++++++++++++++++ tests/test_butlerstd.py | 153 +--------------- tests/test_region_search.py | 200 +++++++++++++++++++++ tests/utils/__init__.py | 1 + tests/utils/mock_butler.py | 273 ++++++++++++++++++++++++++++ tests/utils/mock_fits.py | 7 + 6 files changed, 826 insertions(+), 152 deletions(-) create mode 100644 src/kbmod/region_search.py create mode 100644 tests/test_region_search.py create mode 100644 tests/utils/mock_butler.py diff --git a/src/kbmod/region_search.py b/src/kbmod/region_search.py new file mode 100644 index 000000000..9c7781b8b --- /dev/null +++ b/src/kbmod/region_search.py @@ -0,0 +1,344 @@ +try: + import lsst.daf.butler as dafButler +except ImportError: + raise ImportError("LSST stack not found. 
Please install the LSST stack to use this module.")

from concurrent.futures import ProcessPoolExecutor, as_completed

from astropy.table import Table


def _chunked_data_ids(dataIds, chunk_size=200):
    """Helper function to yield successive chunk_size chunks from dataIds."""
    for i in range(0, len(dataIds), chunk_size):
        yield dataIds[i : i + chunk_size]


class RegionSearch:
    """
    A class for searching through a dataset for data suitable for KBMOD processing.

    With a path to a butler repository, it provides helper methods for basic exploration of the data,
    methods for retrieving data from the butler for search, and transformation of the data
    into a KBMOD ImageCollection for further processing.

    Note that currently we store results from the butler in an Astropy Table. In the future,
    we will likely want to use a database for faster performance and to handle processing of
    datasets that are too large to fit in memory.
    """

    def __init__(
        self,
        repo_path,
        collections,
        dataset_types,
        butler=None,
        visit_info_str="Exposure.visitInfo",
        max_workers=None,
        fetch_data=False,
    ):
        """
        Parameters
        ----------
        repo_path : `str`
            The path to the LSST butler repository.
        collections : `list[str]`
            The list of desired collection names within the Butler repository.
        dataset_types : `list[str]`
            The list of desired dataset types within the Butler repository.
        butler : `lsst.daf.butler.Butler`, optional
            The Butler object to use for data access. If None, a new Butler object will be created from `repo_path`.
        visit_info_str : `str`
            The name used when querying the butler for VisitInfo for exposures. Default is "Exposure.visitInfo".
        max_workers : `int`, optional
            The maximum number of workers to use in parallel processing. Note that each parallel worker will instantiate its own Butler
            objects. If not provided, parallel processing is disabled.
        fetch_data : `bool`, optional
            If True, fetch the VDR data when the object is created. Default is False.
        """
        self.repo_path = repo_path
        if butler is not None:
            self.butler = butler
        else:
            self.butler = dafButler.Butler(self.repo_path)

        self.collections = collections
        self.dataset_types = dataset_types
        self.visit_info_str = visit_info_str
        self.max_workers = max_workers

        # Create an empty table to store the VDR (Visit, Detector, Region) data from the butler.
        self.vdr_data = Table()
        if fetch_data:
            # Fetch the VDR data from the butler
            self.vdr_data = self.fetch_vdr_data()

    @staticmethod
    def get_collection_names(butler=None, repo_path=None):
        """
        Get the list of the names of available collections in a butler repository.

        Parameters
        ----------
        butler | repo_path : `lsst.daf.butler.Butler` | `str`
            The Butler object or a path to the LSST butler repository from which to create a butler.

        Returns
        -------
        collections : `list[str]`
            The list of the names of available collections in the butler repository.
        """
        if butler is None:
            if repo_path is None:
                raise ValueError("Must specify one of repo_path or butler")
            butler = dafButler.Butler(repo_path)
        return butler.registry.queryCollections()

    @staticmethod
    def get_dataset_type_freq(butler=None, repo_path=None, collections=None):
        """
        Get the frequency of refs per dataset type across the given collections.

        Parameters
        ----------
        butler | repo_path : `lsst.daf.butler.Butler` | `str`
            The Butler object or a path to the LSST butler repository from which to create a butler.
+        collections : `list[str]`, optional
+            The names of collections from which we can query the dataset type frequencies. If None, use all collections.
+
+        Returns
+        -------
+        ref_freq : `dict`
+            A dictionary of the frequency of refs per dataset type in the given collections.
+        """
+        if butler is None:
+            if repo_path is None:
+                raise ValueError("Must specify one of repo_path or butler")
+            butler = dafButler.Butler(repo_path)
+
+        # Iterate over all dataset types and count the frequency of refs associated with each
+        ref_freq = {}
+        for dt in butler.registry.queryDatasetTypes():
+            refs = None
+            if collections:
+                refs = butler.registry.queryDatasets(dt, collections=collections)
+            else:
+                refs = butler.registry.queryDatasets(dt)
+            if refs is not None:
+                if dt.name not in ref_freq:
+                    ref_freq[dt.name] = 0
+                ref_freq[dt.name] += refs.count(exact=True, discard=True)
+
+        return ref_freq
+
+    def is_parallel(self):
+        """Returns True if parallel processing was requested."""
+        return self.max_workers is not None
+
+    def new_butler(self):
+        """Instantiates a new Butler object from the repo_path."""
+        return dafButler.Butler(self.repo_path)
+
+    def set_collections(self, collections):
+        """
+        Set which collections to use when querying data from the butler.
+
+        Parameters
+        ----------
+        collections : `list[str]`
+            The list of desired collections to use for the region search.
+        """
+        self.collections = collections
+
+    def set_dataset_types(self, dataset_types):
+        """
+        Set the desired dataset types to use when querying the butler.
+        """
+        self.dataset_types = dataset_types
+
+    def get_vdr_data(self):
+        """Returns the VDR data."""
+        return self.vdr_data
+
+    def fetch_vdr_data(self, collections=None, dataset_types=None):
+        """
+        Fetches the VDR (Visit, Detector, Region) data for the given collections and dataset types.
+
+        VDRs are the regions of the detector that are covered by a visit. They contain what we need
+        in terms of region hashes and unique dataIds.
+
+        Parameters
+        ----------
+        collections : `list[str]`
+            The names of the collections to fetch VDR data from. If None, use self.collections.
+        dataset_types : `list[str]`
+            The names of the dataset types to fetch VDR data for. If None, use self.dataset_types.
+
+        Returns
+        -------
+        vdr_data : `astropy.table.Table`
+            An Astropy Table containing the VDR data and associated URIs and RA/Dec center coordinates.
+        """
+        if not collections:
+            if not self.collections:
+                raise ValueError("No collections specified")
+            collections = self.collections
+
+        if not dataset_types:
+            if not self.dataset_types:
+                raise ValueError("No dataset types specified")
+            dataset_types = self.dataset_types
+
+        vdr_dict = {"data_id": [], "region": [], "detector": [], "uri": [], "center_coord": []}
+
+        for dt in dataset_types:
+            refs = self.butler.registry.queryDimensionRecords(
+                "visit_detector_region", datasets=dt, collections=collections
+            )
+            for ref in refs:
+                vdr_dict["data_id"].append(ref.dataId)
+                vdr_dict["region"].append(ref.region)
+                vdr_dict["detector"].append(ref.detector)
+                vdr_dict["center_coord"].append(self.get_center_ra_dec(ref.region))
+
+        # Now that we have the initial VDR data ids, we can also fetch the associated URIs
+        vdr_dict["uri"] = self.get_uris(vdr_dict["data_id"])
+
+        # return as an Astropy Table
+        return Table(vdr_dict)
+
+    def get_instruments(self, data_ids=None, first_instrument_only=False):
+        """
+        Get the instruments for the given VDR data ids.
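+        Each instrument is looked up by calling ``Butler.get`` with the
+        configured ``visit_info_str`` dataset name for the given data ID.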
+
+        Parameters
+        ----------
+        data_ids : `iterable(dict)`, optional
+            A collection of VDR data IDs to get the instruments for. By default uses the previously fetched data_ids.
+        first_instrument_only : `bool`, optional
+            If True, return only the first instrument we find. Default is False.
+
+        Returns
+        -------
+        instruments : `list`
+            A list of instrument objects for the given data IDs.
+        """
+        if data_ids is None:
+            data_ids = self.vdr_data["data_id"]
+
+        instruments = []
+        for data_id in data_ids:
+            instrument = self.butler.get(self.visit_info_str, dataId=data_id, collections=self.collections)
+            if first_instrument_only:
+                return [instrument]
+            instruments.append(instrument)
+        return instruments
+
+    def _get_uris_serial(self, data_ids, dataset_types=None, collections=None, butler=None):
+        """Fetch URIs for a list of dataIds in serial fashion.
+
+        Parameters
+        ----------
+        data_ids : `iterable(dict)`
+            A collection of data Ids to fetch URIs for.
+        dataset_types : `list[str]`
+            The dataset types to use when fetching URIs. If None, use self.dataset_types.
+        collections : `list[str]`
+            The collections to use when fetching URIs. If None, use self.collections.
+        butler : `lsst.daf.butler.Butler`, optional
+            The Butler object to use for data access. If None, use self.butler.
+
+        Returns
+        -------
+        uris : `list[str]`
+            The list of URIs for the given data Ids.
+        """
+        if butler is None:
+            butler = self.butler
+        if dataset_types is None:
+            if self.dataset_types is None:
+                raise ValueError("No dataset types specified")
+            dataset_types = self.dataset_types
+        if collections is None:
+            if self.collections is None:
+                raise ValueError("No collections specified")
+            collections = self.collections
+
+        uris = []
+        for data_id in data_ids:
+            try:
+                # Use the provided butler so that parallel workers query through
+                # their own Butler instance instead of self.butler.
+                uri = butler.getURI(dataset_types[0], dataId=data_id, collections=collections)
+                uri = uri.geturl()  # Convert to URL string
+                uris.append(uri)
+            except Exception as e:
+                print(f"Failed to retrieve path for dataId {data_id}: {e}")
+        return uris
+
+    def get_uris(self, data_ids, dataset_types=None, collections=None):
+        """
+        Get the URIs for the given dataIds.
+
+        Parameters
+        ----------
+        data_ids : `iterable(dict)`
+            A collection of data Ids to fetch URIs for.
+        dataset_types : `list[str]`
+            The dataset types to use when fetching URIs. If None, use self.dataset_types.
+        collections : `list[str]`
+            The collections to use when fetching URIs. If None, use self.collections.
+
+        Returns
+        -------
+        uris : `list[str]`
+            The list of URIs for the given data Ids.
+        """
+        if dataset_types is None:
+            if self.dataset_types is None:
+                raise ValueError("No dataset types specified")
+            dataset_types = self.dataset_types
+        if collections is None:
+            if self.collections is None:
+                raise ValueError("No collections specified")
+            collections = self.collections
+
+        if not self.is_parallel():
+            return self._get_uris_serial(data_ids, dataset_types, collections)
+
+        # Divide the data_ids into chunks to be processed in parallel
+        data_id_chunks = list(_chunked_data_ids(data_ids))
+
+        # Use a ProcessPoolExecutor to fetch URIs in parallel
+        uris = []
+        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
+            futures = [
+                executor.submit(
+                    self._get_uris_serial,
+                    chunk,
+                    dataset_types=dataset_types,
+                    collections=collections,
+                    butler=self.new_butler(),
+                )
+                for chunk in data_id_chunks
+            ]
+            for future in as_completed(futures):
+                uris.extend(future.result())
+
+        return uris
+
+    def get_center_ra_dec(self, region):
+        """
+        Get the center RA and Dec for the given region.
+ + Parameters + ---------- + region : `lsst::sphgeom::Region Class Reference` + The region for which to get the center RA and Dec. + + Returns + ------- + ra, dec : `float`, `float` + The center RA and Dec in degrees. + """ + # Note we get the 2D boundingBox (not the boundingBox3d) from a region. + # We then extract the RA and Dec from the center of the bounding box. + bbox_center = region.getBoundingBox().getCenter() + ra = bbox_center.getLon().asDegrees() + dec = bbox_center.getLat().asDegrees() + return ra, dec diff --git a/tests/test_butlerstd.py b/tests/test_butlerstd.py index e8b6bdfe8..f9f2307f1 100644 --- a/tests/test_butlerstd.py +++ b/tests/test_butlerstd.py @@ -8,7 +8,7 @@ from astropy.wcs import WCS import numpy as np -from utils import DECamImdiffFactory +from utils import DECamImdiffFactory, MockButler, Registry, Datastore, DatasetRef, DatasetId, dafButler from kbmod import PSF, Standardizer, StandardizerConfig from kbmod.standardizers import ButlerStandardizer, ButlerStandardizerConfig, KBMODV1Config @@ -18,157 +18,6 @@ FitsFactory = DECamImdiffFactory() -# Patch Rubin Middleware out of existence -class Registry: - def getDataset(self, ref): - return ref - - -class Datastore: - def __init__(self, root): - self.root = root - - -class DatasetRef: - def __init__(self, ref): - self.ref = ref - self.run = ref - - -class DatasetId: - def __init__(self, ref): - self.id = ref - self.ref = ref - self.run = ref - - -class MockButler: - """Mocked Vera C. Rubin Data Butler functionality sufficient to be used in - a ButlerStandardizer. - - The mocked .get method will return an mocked Exposure object with all the, - generally, expected attributes (info, visitInfo, image, variance, mask, - wcs). Most of these attributes are mocked such that they return an integer - id, which is then used in a FitsFactory to read out the serialized header - of some underlying real data. Particularly, we target DECam, such that - outputs of ButlerStandardizer and KBMODV1 are comparable. - - By default the mocked image arrays will contain the empty - `Butler.empty_arrat` but providing a callable `mock_images_f`, that takes - in a single mocked Exposure object, and assigns the: - * mocked.image.array - * mocked.variance.array - * mocked.mask.array - attributes can be used to customize the returned arrays. - """ - - def __init__(self, root, ref=None, mock_images_f=None): - self.datastore = Datastore(root) - self.registry = Registry() - self.current_ref = ref - self.mockImages = mock_images_f - - def getURI(self, ref, collections=None): - mocked = mock.Mock(name="ButlerURI") - mocked.geturl.return_value = f"file:/{self.datastore.root}" - return mocked - - def getDataset(self, datid): - return self.get(datid) - - def get(self, ref, collections=None): - # Butler.get gets a DatasetRef, but can take an DatasetRef or DatasetId - # DatasetId is type alias for UUID's, which are hex-strings when - # serialized. We short it to an integer, because We use an integer to - # read a particular file in FitsFactory. This means we got to cast - # all these different objects to int (somehow). Firstly, it's one of - # our mocks, dig out the value we really care about: - if isinstance(ref, (DatasetId, DatasetRef)): - ref = ref.ref - - # that value can be an int, a simple str(int) (used in testing only), - # a large hex UUID string, or a UUID object. 
Duck-type them to int - if isinstance(ref, uuid.UUID): - ref = ref.int - elif isinstance(ref, str): - try: - ref = uuid.UUID(ref).int - except (ValueError, AttributeError): - # likely a str(int) - pass - - # Cast to int to cover for all eventualities - ref = int(ref) - self.current_ref = ref - - # Finally we can proceed with mocking. Butler.get (the way we use it at - # least) returns an Exposure[F/I/...] object. Exposure is like our - # LayeredImage. We need to mock every attr, method and property that we - # call the standardizer. We shortcut the results to match the KBMODV1. - hdul = FitsFactory.get_fits(ref % FitsFactory.n_files, spoof_data=True) - prim = hdul["PRIMARY"].header - - mocked = mock.Mock( - name="Exposure", - spec_set=[ - "visitInfo", - "info", - "hasWcs", - "getWidth", - "getHeight", - "getFilter", - "image", - "variance", - "mask", - "wcs", - ], - ) - - # General metadata mocks - mocked.visitInfo.date.toAstropy.return_value = Time(hdul["PRIMARY"].header["DATE-AVG"], format="isot") - mocked.info.id = prim["EXPID"] - mocked.getWidth.return_value = hdul[1].header["NAXIS1"] - mocked.getHeight.return_value = hdul[1].header["NAXIS2"] - mocked.info.getFilter().physicalLabel = prim["FILTER"] - - # Rubin Sci. Pipes. return their own internal SkyWcs object. We mock a - # Header that'll work with ButlerStd instead. It works because in the - # STD we cast SkyWcs to dict-like thing, from which we make a WCS. What - # happens if SkyWcs changes though? - wcshdr = WCS(hdul[1].header).to_header(relax=True) - wcshdr["NAXIS1"] = hdul[1].header["NAXIS1"] - wcshdr["NAXIS2"] = hdul[1].header["NAXIS2"] - mocked.hasWcs.return_value = True - mocked.wcs.getFitsMetadata.return_value = wcshdr - - # Mocking the images consists of using the Factory default, then - # invoking any user specified method on the mocked exposure obj. - mocked.image.array = hdul["IMAGE"].data - mocked.variance.array = hdul["VARIANCE"].data - mocked.mask.array = hdul["MASK"].data - if self.mockImages is not None: - self.mockImages(mocked) - - # Same issue as with WCS, what if there's a change in definition of the - # mask plane? Note the change in definition of a flag to exponent only. - bit_flag_map = {} - for key, val in KBMODV1Config.bit_flag_map.items(): - bit_flag_map[key] = int(np.log2(val)) - mocked.mask.getMaskPlaneDict.return_value = bit_flag_map - - return mocked - - -class dafButler: - """Intercepts calls ``import lsst.daf.butler as dafButler`` and shortcuts - them to our mocks. - """ - - DatasetRef = DatasetRef - DatasetId = DatasetId - Butler = MockButler - - @mock.patch.dict( "sys.modules", { diff --git a/tests/test_region_search.py b/tests/test_region_search.py new file mode 100644 index 000000000..b89b48801 --- /dev/null +++ b/tests/test_region_search.py @@ -0,0 +1,200 @@ +import unittest + +# A path for our mock repository +MOCK_REPO_PATH = "far/far/away" + +from unittest import mock +from utils import DatasetRef, DatasetId, dafButler, MockButler + +with mock.patch.dict( + "sys.modules", + { + "lsst": mock.MagicMock(), # General mock for the LSST package import + "lsst.daf.butler": dafButler, + "lsst.daf.butler.core.DatasetRef": DatasetRef, + "lsst.daf.butler.core.DatasetId": DatasetId, + }, +): + from kbmod import region_search + + +class TestRegionSearch(unittest.TestCase): + """ + Test the region search functionality. 
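+
+    All butler access in these tests goes through the ``MockButler`` test
+    double from ``tests/utils``, so no LSST stack is required to run them.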
+    """
+
+    def setUp(self):
+        self.butler = MockButler(MOCK_REPO_PATH)
+
+        # For the default collections and dataset types, we'll just use the first two of each
+        self.default_collections = self.butler.registry.queryCollections()[:2]
+        self.default_datasetTypes = [dt.name for dt in self.butler.registry.queryDatasetTypes()][:2]
+
+        self.rs = region_search.RegionSearch(
+            MOCK_REPO_PATH,
+            self.default_collections,
+            self.default_datasetTypes,
+            butler=self.butler,
+        )
+
+    def test_init(self):
+        """
+        Test that the region search object can be initialized.
+        """
+        rs = region_search.RegionSearch(MOCK_REPO_PATH, [], [], butler=self.butler, fetch_data=False)
+        self.assertTrue(rs is not None)
+        self.assertEqual(0, len(rs.vdr_data))  # No data should be fetched
+
+    def test_init_with_fetch(self):
+        """
+        Test that the region search object can fetch data on initialization.
+        """
+        rs = region_search.RegionSearch(
+            MOCK_REPO_PATH,
+            self.default_collections,
+            self.default_datasetTypes,
+            butler=self.butler,
+            fetch_data=True,
+        )
+        self.assertTrue(rs is not None)
+
+        data = rs.fetch_vdr_data()
+        self.assertGreater(len(data), 0)
+
+        # Verify that the appropriate columns have been fetched
+        expected_columns = set(["data_id", "region", "detector", "uri", "center_coord"])
+        # Compute the set of differing columns
+        diff_columns = set(expected_columns).symmetric_difference(data.keys())
+        self.assertEqual(len(diff_columns), 0)
+
+    def test_chunked_data_ids(self):
+        """
+        Test the helper function for chunking data ids for parallel processing.
+        """
+        # Generate a list of fake data_ids
+        data_ids = [str(i) for i in range(100)]
+        chunk_size = 10
+        # Get all chunks from the generator
+        chunks = list(region_search._chunked_data_ids(data_ids, chunk_size))
+
+        for i in range(len(chunks)):
+            chunk = chunks[i]
+            self.assertEqual(len(chunk), chunk_size)
+            for j in range(len(chunk)):
+                self.assertEqual(chunk[j], data_ids[i * chunk_size + j])
+
+    def test_get_collection_names(self):
+        """
+        Test that the collection names are retrieved correctly.
+        """
+        with self.assertRaises(ValueError):
+            region_search.RegionSearch.get_collection_names(butler=None, repo_path=None)
+
+        self.assertGreater(
+            len(
+                region_search.RegionSearch.get_collection_names(butler=self.butler, repo_path=MOCK_REPO_PATH)
+            ),
+            0,
+        )
+
+    def test_set_collections(self):
+        """
+        Test that the desired collections are set correctly.
+        """
+        collection_names = region_search.RegionSearch.get_collection_names(
+            butler=self.butler, repo_path=MOCK_REPO_PATH
+        )
+        self.rs.set_collections(collection_names)
+        self.assertEqual(self.rs.collections, collection_names)
+
+    def test_get_dataset_type_freq(self):
+        """
+        Test that the dataset type frequency is retrieved correctly.
+        """
+        freq = self.rs.get_dataset_type_freq(butler=self.butler, collections=self.default_collections)
+        self.assertTrue(len(freq) > 0)
+        for dataset_type in freq:
+            self.assertTrue(freq[dataset_type] > 0)
+
+    def test_set_dataset_types(self):
+        """
+        Test that the desired dataset types are correctly set.
+        """
+        freq = self.rs.get_dataset_type_freq(butler=self.butler, collections=self.default_collections)
+
+        self.assertGreater(len(freq), 0)
+        dataset_types = list(freq.keys())[0]
+        self.rs.set_dataset_types(dataset_types=dataset_types)
+
+        self.assertEqual(self.rs.dataset_types, dataset_types)
+
+    def test_fetch_vdr_data(self):
+        """
+        Test that the VDR data is retrieved correctly.
+        """
+        # Get the VDR data
+        vdr_data = self.rs.fetch_vdr_data()
+        self.assertTrue(len(vdr_data) > 0)
+
+        # Verify that the appropriate columns have been fetched
+        expected_columns = set(["data_id", "region", "detector", "uri", "center_coord"])
+        # Compute the set of differing columns
+        diff_columns = set(expected_columns).symmetric_difference(vdr_data.keys())
+        self.assertEqual(len(diff_columns), 0)
+
+    def test_get_instruments(self):
+        """
+        Test that the instruments are retrieved correctly.
+        """
+        data_ids = self.rs.fetch_vdr_data()["data_id"]
+        # Get the instruments
+        first_instrument = self.rs.get_instruments(data_ids, first_instrument_only=True)
+        self.assertEqual(len(first_instrument), 1)
+
+        # Now test the default where getting the first instrument is False.
+        instruments = self.rs.get_instruments(data_ids)
+        self.assertGreater(len(instruments), 1)
+
+    def test_get_uris_serial(self):
+        """
+        Test that the URIs are retrieved correctly in serial mode.
+        """
+        data_ids = self.rs.fetch_vdr_data()["data_id"]
+        # Get the URIs
+        uris = self.rs.get_uris(data_ids)
+        self.assertTrue(len(uris) > 0)
+
+    def test_get_uris_parallel(self):
+        """
+        Test that the URIs are retrieved correctly in parallel mode.
+        """
+        data_ids = self.rs.fetch_vdr_data()["data_id"]
+
+        parallel_rs = region_search.RegionSearch(
+            MOCK_REPO_PATH,
+            self.default_collections,
+            self.default_datasetTypes,
+            butler=self.butler,
+            # TODO Set max_workers to turn parallelism back on after fixing
+            # the pickle issue for mocked objects.
+        )
+
+        uris = parallel_rs.get_uris(data_ids)
+        self.assertTrue(len(uris) > 0)
+
+    def test_get_center_ra_dec(self):
+        """
+        Test that the center RA and Dec are retrieved correctly.
+        """
+        region = self.rs.fetch_vdr_data()["region"][0]
+
+        # Get the center RA and Dec
+        center_ra_dec = self.rs.get_center_ra_dec(region)
+        self.assertTrue(len(center_ra_dec) > 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index 02cf2ed2c..99cd2dd2b 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -1,2 +1,3 @@
 from .mock_fits import *
+from .mock_butler import *
 from . import utils_for_tests
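
The workflow exercised by these tests looks roughly like the following
sketch (the repository path, collection, and dataset type names below are
hypothetical placeholders):

    from kbmod.region_search import RegionSearch

    rs = RegionSearch(
        "/path/to/butler_repo",          # hypothetical butler repository
        collections=["DECam/defaults"],  # hypothetical collection name
        dataset_types=["calexp"],        # hypothetical dataset type
        max_workers=4,                   # enables parallel URI fetching
    )
    vdrs = rs.fetch_vdr_data()  # Table with data_id, region, detector, uri, center_coord
    uris = list(vdrs["uri"])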
diff --git a/tests/utils/mock_butler.py b/tests/utils/mock_butler.py
new file mode 100644
index 000000000..7f1aa57cc
--- /dev/null
+++ b/tests/utils/mock_butler.py
@@ -0,0 +1,266 @@
+import uuid
+from unittest import mock
+
+import numpy as np
+from astropy.time import Time
+from astropy.wcs import WCS
+
+from kbmod.standardizers import KBMODV1Config
+
+from .mock_fits import DECamImdiffFactory
+
+__all__ = [
+    "MockButler",
+    "Registry",
+    "Datastore",
+    "DatasetRef",
+    "DatasetId",
+    "dafButler",
+]
+
+
+# Patch Rubin Middleware out of existence
+class Datastore:
+    def __init__(self, root):
+        self.root = root
+
+
+class DatasetType:
+    def __init__(self, name):
+        self.name = name
+
+
+class DatasetRef:
+    def __init__(self, ref):
+        self.ref = ref
+        self.run = ref
+        self.dataId = ref
+
+
+class DatasetId:
+    def __init__(self, ref):
+        self.id = ref
+        self.ref = ref
+        self.run = ref
+
+
+class DatasetQueryResults:
+    def __init__(self, dataset_refs):
+        self.refs = dataset_refs
+
+    def count(self, **kwargs):
+        return len(self.refs)
+
+
+class Angle:
+    def __init__(self, value):
+        self.value = value
+
+    def asDegrees(self):
+        return self.value
+
+
+class LonLat:
+    def __init__(self, lon, lat):
+        self.lon = Angle(lon)
+        self.lat = Angle(lat)
+
+    def getLon(self):
+        return self.lon
+
+    def getLat(self):
+        return self.lat
+
+
+class Box:
+    def __init__(self, x, y, width, height):
+        self.x = x
+        self.y = y
+        self.width = width
+        self.height = height
+
+    def getCenter(self):
+        return LonLat(self.x + self.width / 2, self.y + self.height / 2)
+
+
+class ConvexPolygon:
+    def __init__(self, vertices):
+        self.vertices = vertices
+
+    def getBoundingBox(self):
+        x = min([v[0] for v in self.vertices])
+        y = min([v[1] for v in self.vertices])
+        width = max([v[0] for v in self.vertices]) - x
+        height = max([v[1] for v in self.vertices]) - y
+        return Box(x, y, width, height)
+
+
+class DimensionRecord:
+    def __init__(self, dataId, region, detector):
+        self.dataId = dataId
+        self.region = region
+        self.detector = detector
+
+
+class Registry:
+    def getDataset(self, ref):
+        return ref
+
+    def queryDimensionRecords(self, type, **kwargs):
+        region1 = ConvexPolygon([(0, 0), (0, 1), (1, 1), (1, 0)])
+        region2 = ConvexPolygon([(1, 1), (1, 3), (3, 3), (3, 1)])
+        return [
+            DimensionRecord("dataId1", region1, "detector_replace_me"),
+            DimensionRecord("dataId2", region2, "detector_replace_me"),
+        ]
+
+    def queryCollections(self, **kwargs):
+        return ["replace_me", "replace_me2"]
+
+    def queryDatasetTypes(self, **kwargs):
+        return [
+            DatasetType("dataset_type_replace_me"),
+            DatasetType("dataset_type_replace_me2"),
+            DatasetType("dataset_type_replace_me3"),
+        ]
+
+    def queryDatasets(self, dataset_type, **kwargs):
+        return DatasetQueryResults(
+            [
+                DatasetRef("dataset_ref_replace_me"),
+                DatasetRef("dataset_ref_replace_me2"),
+                DatasetRef("dataset_ref_replace_me3"),
+            ]
+        )
+
+
+FitsFactory = DECamImdiffFactory()
+
+
+class MockButler:
+    """Mocked Vera C. Rubin Data Butler functionality sufficient to be used in
+    a ButlerStandardizer.
+
+    The mocked .get method will return a mocked Exposure object with all of
+    the generally expected attributes (info, visitInfo, image, variance, mask,
+    wcs).
+    Most of these attributes are mocked such that they return an integer
+    id, which is then used in a FitsFactory to read out the serialized header
+    of some underlying real data. In particular, we target DECam, such that
+    outputs of ButlerStandardizer and KBMODV1 are comparable.
+
+    By default the mocked image arrays will contain the empty
+    `Butler.empty_array`, but a callable `mock_images_f` that takes a single
+    mocked Exposure object and assigns its:
+    * mocked.image.array
+    * mocked.variance.array
+    * mocked.mask.array
+    attributes can be provided to customize the returned arrays.
+    """
+
+    def __init__(self, root, ref=None, mock_images_f=None):
+        self.datastore = Datastore(root)
+        self.registry = Registry()
+        self.mockImages = mock_images_f
+
+    def getURI(self, ref, dataId=None, collections=None):
+        mocked = mock.Mock(name="ButlerURI")
+        mocked.geturl.return_value = f"file:/{self.datastore.root}"
+        return mocked
+
+    def getDataset(self, datid):
+        return self.get(datid)
+
+    def get(self, ref, collections=None, dataId=None):
+        # Butler.get takes a DatasetRef or a DatasetId. DatasetId is a type
+        # alias for UUIDs, which are hex-strings when serialized. We shorten
+        # it to an integer because we use an integer to read a particular file
+        # in FitsFactory. This means we have to cast all these different
+        # objects to int (somehow). First, if it's one of our mocks, dig out
+        # the value we really care about:
+        if isinstance(ref, (DatasetId, DatasetRef)):
+            ref = ref.ref
+
+        # That value can be an int, a simple str(int) (used in testing only),
+        # a large hex UUID string, or a UUID object. Duck-type them to int.
+        if isinstance(ref, uuid.UUID):
+            ref = ref.int
+        elif isinstance(ref, str):
+            try:
+                ref = uuid.UUID(ref).int
+            except (ValueError, AttributeError):
+                # likely a str(int)
+                try:
+                    ref = int(ref)
+                except ValueError:
+                    # fall back to a deterministic int for arbitrary strings
+                    ref = len(ref)
+
+        # Finally we can proceed with mocking. Butler.get (the way we use it
+        # at least) returns an Exposure[F/I/...] object. Exposure is like our
+        # LayeredImage. We need to mock every attr, method and property that
+        # we call in the standardizer. We shortcut the results to match
+        # KBMODV1.
+        hdul = FitsFactory.get_fits(ref % FitsFactory.n_files, spoof_data=True)
+        prim = hdul["PRIMARY"].header
+
+        mocked = mock.Mock(
+            name="Exposure",
+            spec_set=[
+                "visitInfo",
+                "info",
+                "hasWcs",
+                "getWidth",
+                "getHeight",
+                "getFilter",
+                "image",
+                "variance",
+                "mask",
+                "wcs",
+            ],
+        )
+
+        # General metadata mocks
+        mocked.visitInfo.date.toAstropy.return_value = Time(hdul["PRIMARY"].header["DATE-AVG"], format="isot")
+        mocked.visitInfo.date.return_value = Time(hdul["PRIMARY"].header["DATE-AVG"], format="isot")
+        mocked.info.id = prim["EXPID"]
+        mocked.getWidth.return_value = hdul[1].header["NAXIS1"]
+        mocked.getHeight.return_value = hdul[1].header["NAXIS2"]
+        mocked.info.getFilter().physicalLabel = prim["FILTER"]
+
+        # Rubin Sci. Pipes. return their own internal SkyWcs object. We mock a
+        # Header that'll work with ButlerStd instead. It works because in the
+        # STD we cast SkyWcs to a dict-like thing, from which we make a WCS.
+        # What happens if SkyWcs changes though?
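+        # (Note that astropy's WCS.to_header() does not carry over the NAXISn
+        # keywords, which is why the image dimensions are re-attached below.)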
+ wcshdr = WCS(hdul[1].header).to_header(relax=True) + wcshdr["NAXIS1"] = hdul[1].header["NAXIS1"] + wcshdr["NAXIS2"] = hdul[1].header["NAXIS2"] + mocked.hasWcs.return_value = True + mocked.wcs.getFitsMetadata.return_value = wcshdr + + # Mocking the images consists of using the Factory default, then + # invoking any user specified method on the mocked exposure obj. + mocked.image.array = hdul["IMAGE"].data + mocked.variance.array = hdul["VARIANCE"].data + mocked.mask.array = hdul["MASK"].data + if self.mockImages is not None: + self.mockImages(mocked) + + # Same issue as with WCS, what if there's a change in definition of the + # mask plane? Note the change in definition of a flag to exponent only. + bit_flag_map = {} + for key, val in KBMODV1Config.bit_flag_map.items(): + bit_flag_map[key] = int(np.log2(val)) + mocked.mask.getMaskPlaneDict.return_value = bit_flag_map + + return mocked + + +class dafButler: + """Intercepts calls ``import lsst.daf.butler as dafButler`` and shortcuts + them to our mocks. + """ + + DatasetRef = DatasetRef + DatasetId = DatasetId + Butler = MockButler diff --git a/tests/utils/mock_fits.py b/tests/utils/mock_fits.py index ff4b4716d..305ee517f 100644 --- a/tests/utils/mock_fits.py +++ b/tests/utils/mock_fits.py @@ -7,8 +7,15 @@ from astropy.utils.exceptions import AstropyUserWarning from astropy.io.fits import HDUList, PrimaryHDU, CompImageHDU, BinTableHDU, Column +from unittest import mock + +from astropy.time import Time +from astropy.wcs import WCS + from .utils_for_tests import get_absolute_data_path +from kbmod.standardizers import KBMODV1Config +import uuid __all__ = [ "DECamImdiffFactory", From 9dc757aa452a7141fa4af78f6ee55056c259c0cd Mon Sep 17 00:00:00 2001 From: DinoBektesevic Date: Fri, 5 Apr 2024 13:18:26 -0700 Subject: [PATCH 11/17] Few ImageCollection and Standardizer bugfixes. - ImageCollection did not correctly track standardizers when a list was given to __getitem__ - ButlerStandardizer did not extract image dimensions when standardizing WCS - Rubin Middleware removed butler.datastore.root attribute for newer stack builds. --- src/kbmod/image_collection.py | 13 +++++++++---- src/kbmod/standardizers/__init__.py | 2 +- src/kbmod/standardizers/butler_standardizer.py | 18 +++++++++++++++++- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/kbmod/image_collection.py b/src/kbmod/image_collection.py index 3436869df..5906bfc16 100644 --- a/src/kbmod/image_collection.py +++ b/src/kbmod/image_collection.py @@ -318,7 +318,12 @@ def __getitem__(self, key): if isinstance(key, (int, str, np.integer)): return self.data[self._userColumns][key] elif isinstance(key, (list, np.ndarray, slice)): - return self.__class__(self.data[key], standardizers=self._standardizers[key]) + # current data table has standardizer idxs with respect to current + # list of standardizers. Sub-selecting them resets the count to 0 + meta = self.data[key] + stds = [self._standardizers[idx] for idx in meta["std_idx"]] + meta["std_idx"] = np.arange(len(stds)) + return self.__class__(meta, standardizers=stds) else: return self.data[key] @@ -502,14 +507,14 @@ def toImageStack(self): layeredImages = [img for std in self._standardizers for img in std.toLayeredImage()] return ImageStack(layeredImages) - def toWorkUnit(self, config): + def toWorkUnit(self, config=None): """Return an `~kbmod.WorkUnit` object for processing with KBMOD. Parameters ---------- - config : `~kbmod.SearchConfiguration` - Search configuration. 
+        config : `~kbmod.SearchConfiguration` or None, optional
+            Search configuration. Default ``None``.
 
         Returns
         -------
diff --git a/src/kbmod/standardizers/__init__.py b/src/kbmod/standardizers/__init__.py
index db45873d9..72e7c195c 100644
--- a/src/kbmod/standardizers/__init__.py
+++ b/src/kbmod/standardizers/__init__.py
@@ -1,3 +1,3 @@
-from .standardizer import *
 from .fits_standardizers import *
 from .butler_standardizer import *
+from .standardizer import *
diff --git a/src/kbmod/standardizers/butler_standardizer.py b/src/kbmod/standardizers/butler_standardizer.py
index aff59386b..6201a6594 100644
--- a/src/kbmod/standardizers/butler_standardizer.py
+++ b/src/kbmod/standardizers/butler_standardizer.py
@@ -142,6 +142,13 @@ def resolveTarget(self, tgt):
         return False
 
     def __init__(self, id, butler, config=None, **kwargs):
+        # Somewhere around w_2024_ builds the datastore.root attribute was
+        # removed from the datastore; it is not clear it was ever replaced
+        # with anything back-compatible, so fall back to the old attribute.
+        try:
+            super().__init__(str(butler._datastore.root), config=config)
+        except AttributeError:
+            super().__init__(butler.datastore.root, config=config)
+
-        super().__init__(butler.datastore.root, config=config)
 
         self.butler = butler
@@ -310,8 +317,16 @@ def standardizePSF(self):
         ]
 
     def standardizeWCS(self):
+        wcs = None
+        if self.exp.hasWcs():
+            meta = self.exp.wcs.getFitsMetadata()
+            # NAXIS values are required if we reproject,
+            # so we must extract them if we can.
+            meta["NAXIS1"] = self.exp.getWidth()
+            meta["NAXIS2"] = self.exp.getHeight()
+            wcs = WCS(meta)
         return [
-            WCS(self.exp.wcs.getFitsMetadata()) if self.exp.hasWcs() else None,
+            wcs,
         ]
 
     def standardizeBBox(self):
From 9b5226ff81dc3f2f39045ed286d96f7b543b2fb1 Mon Sep 17 00:00:00 2001
From: Vlad 
Date: Sun, 7 Apr 2024 22:06:29 +0100
Subject: [PATCH 12/17] renamed helper to compute_max_results and minor style
 changes
---
 src/kbmod/search/pydocs/stack_search_docs.h | 9 +++++++++
 src/kbmod/search/stack_search.cpp | 9 +++++----
 src/kbmod/search/stack_search.h | 3 +--
 tests/test_search.py | 3 ---
 4 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h
index 4ce0edb62..99b29a470 100644
--- a/src/kbmod/search/pydocs/stack_search_docs.h
+++ b/src/kbmod/search/pydocs/stack_search_docs.h
@@ -188,6 +188,15 @@ static const auto DOC_StackSearch_prepare_batch_search = R"doc(
         The minimum number of observations for a trajectory to be considered.
     )doc";
 
+static const auto DOC_StackSearch_compute_max_results = R"doc(
+    Compute the maximum number of results according to the x, y bounds and the RESULTS_PER_PIXEL constant.
+
+    Returns
+    -------
+    max_results : `int`
+        The maximum number of results that a search will return according to the current bounds and the RESULTS_PER_PIXEL constant.
+    )doc";
+
 static const auto DOC_StackSearch_search_single_batch = R"doc(
     Perform a search on the given trajectories for the current batch.
     Batch is defined by the parameters set `set_start_bounds_x` & `set_start_bounds_y`.
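
For reference, the batched-search bindings in this series compose roughly as
follows (a sketch, assuming a StackSearch instance `search`, a list of
Trajectory objects `trajectories`, and image dimensions `width`/`height`;
the `set_start_bounds_*` setters are the ones referenced in the docs above):

    search.prepare_search(trajectories, min_observations)
    search.set_start_bounds_x(0, width)     # define the current batch
    search.set_start_bounds_y(0, height)
    results = search.search_single_batch()  # at most compute_max_results() trajectories
    search.finish_search()                  # free the GPU psi/phi and trajectory buffers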
diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index d3137a650..6d3aeaa2a 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -187,7 +187,7 @@ void StackSearch::search_batch(){ } DebugTimer core_timer = DebugTimer("Running batch search", rs_logger); - int max_results = extract_max_results(); + int max_results = compute_max_results(); // staple C++ std::stringstream logmsg; @@ -200,7 +200,7 @@ void StackSearch::search_batch(){ results.resize(max_results); results.move_to_gpu(); - // Do the actual search on the GPU. + // Do the actual search on the GPU. DebugTimer search_timer = DebugTimer("Running search", rs_logger); #ifdef HAVE_CUDA deviceSearchFilter(psi_phi_array, params, gpu_search_list, results); @@ -217,13 +217,13 @@ void StackSearch::search_batch(){ } std::vector StackSearch::search_single_batch(){ - int max_results = extract_max_results(); + int max_results = compute_max_results(); search_batch(); return results.get_batch(0, max_results); } -int StackSearch::extract_max_results(){ +int StackSearch::compute_max_results(){ int search_width = params.x_start_max - params.x_start_min; int search_height = params.y_start_max - params.y_start_min; int num_search_pixels = search_width * search_height; @@ -307,6 +307,7 @@ static void stack_search_bindings(py::module& m) { .def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi) .def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results) .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results) + .def("compute_max_results", &ks::compute_max_results, pydocs::DOC_StackSearch_compute_max_results) .def("search_single_batch", &ks::search_single_batch, pydocs::DOC_StackSearch_search_single_batch) .def("prepare_search", &ks::prepare_search, pydocs::DOC_StackSearch_prepare_batch_search) .def("finish_search", &ks::finish_search, pydocs::DOC_StackSearch_finish_search); diff --git a/src/kbmod/search/stack_search.h b/src/kbmod/search/stack_search.h index befc038b9..5b7906f7d 100644 --- a/src/kbmod/search/stack_search.h +++ b/src/kbmod/search/stack_search.h @@ -29,10 +29,9 @@ using Point = indexing::Point; using Image = search::Image; class StackSearch { - int extract_max_results(); public: StackSearch(ImageStack& imstack); - + int compute_max_results(); int num_images() const { return stack.img_count(); } int get_image_width() const { return stack.get_width(); } int get_image_height() const { return stack.get_height(); } diff --git a/tests/test_search.py b/tests/test_search.py index ebab5c6ac..3eefe8837 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -952,9 +952,6 @@ def test_search_batch(self): count = 10 imlist = [make_fake_layered_image(width, height, 5.0, 25.0, n / count, p) for n in range(count)] stack = ImageStack(imlist) - im_list = stack.get_images() - # Create a new list of LayeredImages with the added object. 
- new_im_list = [] for i in range(count): im = stack.get_single_image(i) From 6180283f9c504080ebdc726d0d2da22a77496da5 Mon Sep 17 00:00:00 2001 From: Vlad Date: Sun, 7 Apr 2024 22:15:52 +0100 Subject: [PATCH 13/17] kbmod_reference notebook only change search.search -> search.search_all --- notebooks/Kbmod_Reference.ipynb | 52 ++++++++++++++++----------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/notebooks/Kbmod_Reference.ipynb b/notebooks/Kbmod_Reference.ipynb index 2a34d6226..6e2db6a78 100644 --- a/notebooks/Kbmod_Reference.ipynb +++ b/notebooks/Kbmod_Reference.ipynb @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -96,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -160,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -181,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -200,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -219,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -236,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -266,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -287,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -305,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -324,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -353,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -369,7 +369,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -394,7 +394,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -419,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -436,7 +436,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ 
-471,7 +471,7 @@
    },
    {
     "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -486,7 +486,7 @@
    },
    {
     "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -504,9 +504,9 @@
    ],
    "metadata": {
     "kernelspec": {
-     "display_name": "kbmod_env",
+     "display_name": "Jeremy's KBMOD",
      "language": "python",
-     "name": "python3"
+     "name": "kbmod_jk"
     },
     "language_info": {
      "codemirror_mode": {
@@ -518,7 +518,7 @@
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
-     "version": "3.10.13"
+     "version": "3.12.1"
     }
   },
   "nbformat": 4,
From 5adc40acd424940b2ad9e3b26478a58490da7b6e Mon Sep 17 00:00:00 2001
From: Max West <110124344+maxwest-uw@users.noreply.github.com>
Date: Mon, 8 Apr 2024 12:23:55 -0700
Subject: [PATCH 14/17] create fit_barycentric_wcs (#541)

* add fit_barycentric_wcs

* add dimension check

* remove scipy import

* fix indents
---
 src/kbmod/reprojection_utils.py  | 54 +++++++++++++++++++++-
 tests/test_reprojection_utils.py | 72 +++++++++++++++++++++++++++++++-
 2 files changed, 124 insertions(+), 2 deletions(-)

diff --git a/src/kbmod/reprojection_utils.py b/src/kbmod/reprojection_utils.py
index c118e5867..01c1b1840 100644
--- a/src/kbmod/reprojection_utils.py
+++ b/src/kbmod/reprojection_utils.py
@@ -1,7 +1,8 @@
 import astropy.units as u
 import numpy as np
 from astropy import units as u
-from astropy.coordinates import GCRS, ICRS
+from astropy.coordinates import SkyCoord, GCRS, ICRS
+from astropy.wcs.utils import fit_wcs_from_points
 from scipy.optimize import minimize
 
 
@@ -55,3 +56,54 @@ def correct_parallax(coord, obstime, point_on_earth, guess_distance):
     ).transform_to(ICRS())
 
     return answer
+
+
+def fit_barycentric_wcs(original_wcs, width, height, distance, obstime, point_on_earth, npoints=10):
+    """Given an ICRS WCS and an object's distance from the Sun,
+    return a new WCS that has been corrected for parallax motion.
+
+    Parameters
+    ----------
+    original_wcs : `astropy.wcs.WCS`
+        The image's WCS.
+    width : `int`
+        The image's width (typically NAXIS1).
+    height : `int`
+        The image's height (typically NAXIS2).
+    distance : `float`
+        The distance of the object from the sun, in AU.
+    obstime : `astropy.time.Time` or `string`
+        The observation time.
+    point_on_earth : `astropy.coordinate.EarthLocation`
+        The location on Earth of the observation.
+    npoints : `int`
+        The number of randomly sampled points to use during the WCS fitting.
+        Typically, the more points the higher the accuracy. The four corners
+        of the image will always be included, so setting npoints = 0 will mean
+        just using the corners.
+
+    Returns
+    ----------
+    An `astropy.wcs.WCS` representing the original image in "Explicitly Barycentric Distance" (EBD)
+        space, i.e. where the points have been corrected for parallax.
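+
+    The fit itself is delegated to `astropy.wcs.utils.fit_wcs_from_points`,
+    using a TAN projection with SIP distortion terms.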
+ """ + sampled_x_points = np.array([0, 0, width, width]) + sampled_y_points = np.array([0, height, height, 0]) + if npoints > 0: + sampled_x_points = np.append(sampled_x_points, np.random.rand(npoints) * width) + sampled_y_points = np.append(sampled_y_points, np.random.rand(npoints) * height) + + sampled_ra, sampled_dec = original_wcs.all_pix2world(sampled_x_points, sampled_y_points, 0) + + sampled_coordinates = SkyCoord(sampled_ra, sampled_dec, unit="deg") + + ebd_corrected_points = [] + for coord in sampled_coordinates: + ebd_corrected_points.append(correct_parallax(coord, obstime, point_on_earth, distance)) + + ebd_corrected_points = SkyCoord(ebd_corrected_points) + xy = (sampled_x_points, sampled_y_points) + ebd_wcs = fit_wcs_from_points( + xy, ebd_corrected_points, proj_point="center", projection="TAN", sip_degree=3 + ) + return ebd_wcs diff --git a/tests/test_reprojection_utils.py b/tests/test_reprojection_utils.py index 9cecbad45..01cf6ab97 100644 --- a/tests/test_reprojection_utils.py +++ b/tests/test_reprojection_utils.py @@ -1,10 +1,12 @@ import unittest +import numpy as np import numpy.testing as npt from astropy.coordinates import EarthLocation, SkyCoord, solar_system_ephemeris from astropy.time import Time +from astropy.wcs import WCS -from kbmod.reprojection_utils import correct_parallax +from kbmod.reprojection_utils import correct_parallax, fit_barycentric_wcs class test_reprojection_utils(unittest.TestCase): @@ -45,3 +47,71 @@ def test_parallax_equinox(self): npt.assert_almost_equal(corrected_coord2.ra.value, expected_ra) npt.assert_almost_equal(corrected_coord2.dec.value, expected_dec) + + def test_fit_barycentric_wcs(self): + nx = 2046 + ny = 4094 + test_wcs = WCS(naxis=2) + test_wcs.pixel_shape = (ny, nx) + test_wcs.wcs.crpix = [nx / 2, ny / 2] + test_wcs.wcs.cdelt = np.array([-0.000055555555556, 0.000055555555556]) + test_wcs.wcs.crval = [346.9681342111, -6.482196848597] + test_wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] + + x_points = np.array([247, 1252, 1052, 980, 420, 1954, 730, 1409, 1491, 803]) + + y_points = np.array([1530, 713, 3414, 3955, 1975, 123, 1456, 2008, 1413, 1756]) + + expected_ra = np.array( + [ + 346.69225567, + 346.63734563, + 346.64836252, + 346.65231188, + 346.68282256, + 346.59898412, + 346.66587788, + 346.62881986, + 346.6243199, + 346.66190162, + ] + ) + + expected_dec = np.array( + [ + -6.62151717, + -6.66580019, + -6.51929901, + -6.48995635, + -6.5973741, + -6.6977762, + -6.62551611, + -6.59555108, + -6.62782211, + -6.60924105, + ] + ) + + expected_sc = SkyCoord(ra=expected_ra, dec=expected_dec, unit="deg") + + time = "2021-08-24T20:59:06" + site = "ctio" + loc = EarthLocation.of_site(site) + distance = 41.1592725489203 + + corrected_wcs = fit_barycentric_wcs( + test_wcs, + nx, + ny, + distance, + time, + loc, + ) + + corrected_ra, corrected_dec = corrected_wcs.all_pix2world(x_points, y_points, 0) + corrected_sc = SkyCoord(corrected_ra, corrected_dec, unit="deg") + seps = expected_sc.separation(corrected_sc).arcsecond + + # assert we have sub-milliarcsecond precision + assert np.all(seps < 0.001) + assert corrected_wcs.array_shape == (ny, nx) From d8f16cacbbf617722ae79f2b89366e5b09c41dfc Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 9 Apr 2024 11:58:02 -0400 Subject: [PATCH 15/17] Add helper functions to update trajectory statistics. 
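
The helpers recompute a trajectory's statistics from its (optionally
masked) psi and phi curves. In rough Python terms:

    psi_sum = np.sum(psi_curve[index_valid])
    phi_sum = np.sum(phi_curve[index_valid])
    trj.lh = psi_sum / np.sqrt(phi_sum)      # likelihood
    trj.flux = psi_sum / phi_sum
    trj.obs_count = np.count_nonzero(index_valid)

with lh and flux forced to zero when phi_sum <= 0 to avoid a divide by zero.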
--- src/kbmod/trajectory_utils.py | 57 ++++++++++++++++++++++++++++++++++ tests/test_trajectory_utils.py | 32 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/src/kbmod/trajectory_utils.py b/src/kbmod/trajectory_utils.py index 2604f2567..7d5c4a6e7 100644 --- a/src/kbmod/trajectory_utils.py +++ b/src/kbmod/trajectory_utils.py @@ -221,3 +221,60 @@ def trajectory_to_yaml(trj): "valid": trj.valid, } return dump(yaml_dict) + + +def update_trajectory_from_psi_phi(trj, psi_curve, phi_curve, index_valid=None, in_place=True): + """Update the trajectory's statistic information from a psi_curve and + phi_curve. Uses an optional index_valid mask (True/False) to mask out + pixels. + Parameters + ---------- + trj : `Trajectory` + The trajectory to update. + psi_curve : `numpy.ndarray` + The float psi values at each time step. + phi_curve : `numpy.ndarray` + The float phi values at each time step. + index_valid : `numpy.ndarray`, optional + An array of Booleans indicating whether the time step is valid. + in_place : `bool` + Update the input trajectory in-place. + Returns + ------- + result : `Trajectory` + The updated trajectory. May be the same as trj if in_place=True. + Raises + ------ + Raises a ValueError if the input arrays are not the same size. + """ + if len(psi_curve) != len(phi_curve): + raise ValueError("Mismatched psi and phi curve lengths.") + + # Compute the sums of the (masked) arrays. + if index_valid is None: + psi_sum = np.sum(psi_curve) + phi_sum = np.sum(phi_curve) + num_obs = len(psi_curve) + else: + if len(psi_curve) != len(index_valid): + raise ValueError("Mismatched psi/phi curve and index_valid lengths.") + psi_sum = np.sum(psi_curve[index_valid]) + phi_sum = np.sum(phi_curve[index_valid]) + num_obs = len(psi_curve[index_valid]) + + # Create a copy of the trajectory if we are not modifying in-place. + if in_place: + result = trj + else: + result = make_trajectory(x=trj.x, y=trj.y, vx=trj.vx, vy=trj.vy) + + # Update the statistics information (avoiding divide by zero). + if phi_sum <= 0.0: + result.lh = 0.0 + result.flux = 0.0 + else: + result.lh = psi_sum / np.sqrt(phi_sum) + result.flux = psi_sum / phi_sum + result.obs_count = num_obs + + return result diff --git a/tests/test_trajectory_utils.py b/tests/test_trajectory_utils.py index 4dd75797d..61b99e52e 100644 --- a/tests/test_trajectory_utils.py +++ b/tests/test_trajectory_utils.py @@ -112,6 +112,38 @@ def test_trajectory_yaml(self): self.assertEqual(new_trj.lh, 6.0) self.assertEqual(new_trj.obs_count, 7) + def test_update_trajectory_from_psi_phi(self): + trj = make_trajectory(x=0, y=10, vx=-1.0, vy=2.0) + self.assertEqual(trj.lh, 0.0) + self.assertEqual(trj.flux, 0.0) + self.assertEqual(trj.obs_count, 0) + + # Non-in-place update + psi = np.array([1.0, 1.1, 1.2, 1.3]) + phi = np.array([1.0, 1.0, 0.0, 2.0]) + trj2 = update_trajectory_from_psi_phi(trj, psi, phi, in_place=False) + self.assertEqual(trj2.obs_count, 4) + self.assertAlmostEqual(trj2.flux, 1.15) + self.assertAlmostEqual(trj2.lh, 2.3) + + # Original Trajectory is unchanged. + self.assertEqual(trj.lh, 0.0) + self.assertEqual(trj.flux, 0.0) + self.assertEqual(trj.obs_count, 0) + + # Check the original is modified in the in-place update. + trj3 = update_trajectory_from_psi_phi(trj, psi, phi, in_place=True) + self.assertEqual(trj.obs_count, 4) + self.assertAlmostEqual(trj.flux, 1.15) + self.assertAlmostEqual(trj.lh, 2.3) + + # Mark index 1 invalid. 
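+        # With index 1 masked out, psi sums to 1.0 + 1.2 + 1.3 = 3.5 and
+        # phi sums to 1.0 + 0.0 + 2.0 = 3.0, so flux = 3.5 / 3.0 = 1.1666667
+        # and lh = 3.5 / sqrt(3.0) = 2.020725.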
+        index_valid = np.array([True, False, True, True])
+        trj4 = update_trajectory_from_psi_phi(trj, psi, phi, index_valid=index_valid, in_place=False)
+        self.assertEqual(trj4.obs_count, 3)
+        self.assertAlmostEqual(trj4.flux, 1.1666667, delta=1e-5)
+        self.assertAlmostEqual(trj4.lh, 2.020725, delta=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
From 9682119c59558037fc1b25681d4d8e345297ab73 Mon Sep 17 00:00:00 2001
From: Max West <110124344+maxwest-uw@users.noreply.github.com>
Date: Tue, 9 Apr 2024 10:11:50 -0700
Subject: [PATCH 16/17] add ability to set rng seed for fit_barycentric_wcs + add a consistency unit test (#554)
---
 src/kbmod/reprojection_utils.py  | 12 ++++--
 tests/test_reprojection_utils.py | 63 +++++++++++++++++++++-----------
 2 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/src/kbmod/reprojection_utils.py b/src/kbmod/reprojection_utils.py
index 01c1b1840..b843987f7 100644
--- a/src/kbmod/reprojection_utils.py
+++ b/src/kbmod/reprojection_utils.py
@@ -58,7 +58,9 @@ def correct_parallax(coord, obstime, point_on_earth, guess_distance):
     return answer
 
 
-def fit_barycentric_wcs(original_wcs, width, height, distance, obstime, point_on_earth, npoints=10):
+def fit_barycentric_wcs(
+    original_wcs, width, height, distance, obstime, point_on_earth, npoints=10, seed=None
+):
     """Given an ICRS WCS and an object's distance from the Sun,
     return a new WCS that has been corrected for parallax motion.
 
@@ -81,17 +83,21 @@ def fit_barycentric_wcs(original_wcs, width, height, distance, obstime, point_on
         Typically, the more points the higher the accuracy. The four corners
         of the image will always be included, so setting npoints = 0 will mean
         just using the corners.
+    seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator}
+        The seed that `numpy.random.default_rng` will use.
 
     Returns
     ----------
     An `astropy.wcs.WCS` representing the original image in "Explicitly Barycentric Distance" (EBD)
        space, i.e. where the points have been corrected for parallax.
""" + rng = np.random.default_rng(seed) + sampled_x_points = np.array([0, 0, width, width]) sampled_y_points = np.array([0, height, height, 0]) if npoints > 0: - sampled_x_points = np.append(sampled_x_points, np.random.rand(npoints) * width) - sampled_y_points = np.append(sampled_y_points, np.random.rand(npoints) * height) + sampled_x_points = np.append(sampled_x_points, rng.random(npoints) * width) + sampled_y_points = np.append(sampled_y_points, rng.random(npoints) * height) sampled_ra, sampled_dec = original_wcs.all_pix2world(sampled_x_points, sampled_y_points, 0) diff --git a/tests/test_reprojection_utils.py b/tests/test_reprojection_utils.py index 01cf6ab97..12c1a1357 100644 --- a/tests/test_reprojection_utils.py +++ b/tests/test_reprojection_utils.py @@ -10,6 +10,21 @@ class test_reprojection_utils(unittest.TestCase): + def setUp(self): + self.nx = 2046 + self.ny = 4094 + self.test_wcs = WCS(naxis=2) + self.test_wcs.pixel_shape = (self.ny, self.nx) + self.test_wcs.wcs.crpix = [self.nx / 2, self.ny / 2] + self.test_wcs.wcs.cdelt = np.array([-0.000055555555556, 0.000055555555556]) + self.test_wcs.wcs.crval = [346.9681342111, -6.482196848597] + self.test_wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] + + self.time = "2021-08-24T20:59:06" + self.site = "ctio" + self.loc = EarthLocation.of_site(self.site) + self.distance = 41.1592725489203 + def test_parallax_equinox(self): icrs_ra1 = 88.74513571 icrs_dec1 = 23.43426475 @@ -49,17 +64,7 @@ def test_parallax_equinox(self): npt.assert_almost_equal(corrected_coord2.dec.value, expected_dec) def test_fit_barycentric_wcs(self): - nx = 2046 - ny = 4094 - test_wcs = WCS(naxis=2) - test_wcs.pixel_shape = (ny, nx) - test_wcs.wcs.crpix = [nx / 2, ny / 2] - test_wcs.wcs.cdelt = np.array([-0.000055555555556, 0.000055555555556]) - test_wcs.wcs.crval = [346.9681342111, -6.482196848597] - test_wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] - x_points = np.array([247, 1252, 1052, 980, 420, 1954, 730, 1409, 1491, 803]) - y_points = np.array([1530, 713, 3414, 3955, 1975, 123, 1456, 2008, 1413, 1756]) expected_ra = np.array( @@ -94,18 +99,13 @@ def test_fit_barycentric_wcs(self): expected_sc = SkyCoord(ra=expected_ra, dec=expected_dec, unit="deg") - time = "2021-08-24T20:59:06" - site = "ctio" - loc = EarthLocation.of_site(site) - distance = 41.1592725489203 - corrected_wcs = fit_barycentric_wcs( - test_wcs, - nx, - ny, - distance, - time, - loc, + self.test_wcs, + self.nx, + self.ny, + self.distance, + self.time, + self.loc, ) corrected_ra, corrected_dec = corrected_wcs.all_pix2world(x_points, y_points, 0) @@ -114,4 +114,23 @@ def test_fit_barycentric_wcs(self): # assert we have sub-milliarcsecond precision assert np.all(seps < 0.001) - assert corrected_wcs.array_shape == (ny, nx) + assert corrected_wcs.array_shape == (self.ny, self.nx) + + def test_fit_barycentric_wcs_consistency(self): + corrected_wcs = fit_barycentric_wcs( + self.test_wcs, self.nx, self.ny, self.distance, self.time, self.loc, seed=24601 + ) + + # crval consistency + npt.assert_almost_equal(corrected_wcs.wcs.crval[0], 346.6498731934591) + npt.assert_almost_equal(corrected_wcs.wcs.crval[1], -6.593449653602658) + + # crpix consistency + npt.assert_almost_equal(corrected_wcs.wcs.crpix[0], 1024.4630013095195) + npt.assert_almost_equal(corrected_wcs.wcs.crpix[1], 2047.9912979360922) + + # cd consistency + npt.assert_almost_equal(corrected_wcs.wcs.cd[0][0], -5.424296904025753e-05) + npt.assert_almost_equal(corrected_wcs.wcs.cd[0][1], 3.459611876675614e-08) + 
npt.assert_almost_equal(corrected_wcs.wcs.cd[1][0], 3.401472764249802e-08) + npt.assert_almost_equal(corrected_wcs.wcs.cd[1][1], 5.4242245855217796e-05) From 8324351ca59934570b93fb27bd814ec5effbd213 Mon Sep 17 00:00:00 2001 From: Max West <110124344+maxwest-uw@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:12:20 -0700 Subject: [PATCH 17/17] use proper WorkUnit per image wcs accessor (#555) --- notebooks/reprojection/reproject_demo.ipynb | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/notebooks/reprojection/reproject_demo.ipynb b/notebooks/reprojection/reproject_demo.ipynb index b266969d7..a832502cd 100644 --- a/notebooks/reprojection/reproject_demo.ipynb +++ b/notebooks/reprojection/reproject_demo.ipynb @@ -99,16 +99,16 @@ "original_images = wunit.im_stack.get_images()\n", "\n", "o_image0 = CCDData(original_images[0].get_science().image, unit=\"adu\")\n", - "o_image0.wcs = wunit.per_image_wcs[0]\n", + "o_image0.wcs = wunit.get_wcs(0)\n", "\n", "o_image1 = CCDData(original_images[1].get_science().image, unit=\"adu\")\n", - "o_image1.wcs = wunit.per_image_wcs[1]\n", + "o_image1.wcs = wunit.get_wcs(1)\n", "\n", "o_image2 = CCDData(original_images[2].get_science().image, unit=\"adu\")\n", - "o_image2.wcs = wunit.per_image_wcs[2]\n", + "o_image2.wcs = wunit.get_wcs(2)\n", "\n", "o_image3 = CCDData(original_images[3].get_science().image, unit=\"adu\")\n", - "o_image3.wcs = wunit.per_image_wcs[3]\n", + "o_image3.wcs = wunit.get_wcs(3)\n", "\n", "plot_images(wunit.get_all_obstimes(), o_image0, o_image1, o_image2, o_image3)" ] @@ -118,7 +118,7 @@ "id": "dbd3ff9f-62ef-42c8-820f-ca26214d4214", "metadata": {}, "source": [ - "A couple of important attributes to point ou:\n", + "A couple of important attributes to point out:\n", "- Each images has a different WCS. The center ra/dec value shifted up and to the right ~5 pixels in each successive images, except for the last one which is below image 3.\n", "- The `obstime` is increasing for each one, except for the last one which has the same obstime as image 3.\n", "- They all have a synthetic object in them, moving across the field of view. The last image has a presumambly different object.\n", @@ -139,7 +139,7 @@ "metadata": {}, "outputs": [], "source": [ - "common = wunit.per_image_wcs[0]\n", + "common = wunit.get_wcs(0)\n", "\n", "uwunit = reprojection.reproject_work_unit(wunit, common)" ] @@ -195,7 +195,7 @@ "original_img = wunit.im_stack.get_single_image(2)\n", "o_d = original_img.get_mask().image\n", "original_image2_mask = CCDData(o_d, unit=\"adu\")\n", - "original_image2_mask.wcs = wunit.per_image_wcs[2]\n", + "original_image2_mask.wcs = wunit.get_wcs(2)\n", "\n", "image2_mask = CCDData(images[2].get_mask().image, unit=\"adu\")\n", "image2_mask.wcs = uwunit.wcs\n", @@ -220,7 +220,7 @@ "kernelspec": { "display_name": "kbmod", "language": "python", - "name": "kbmod" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -232,7 +232,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.7" } }, "nbformat": 4,