From 37d42393a048a241a70c247ef557ba606f35b4af Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 18 Sep 2024 10:25:36 -0400 Subject: [PATCH] Do batched curve lookup --- src/kbmod/run_search.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 642c452a9..77fcecff7 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -82,8 +82,6 @@ def load_and_filter_results(self, search, config): logger.info(f"Chunk Min. Likelihood = {results[-1].lh}") trj_batch = [] - psi_batch = [] - phi_batch = [] for i, trj in enumerate(results): # Stop as soon as we hit a result below our limit, because anything after # that is not guarrenteed to be valid due to potential on-GPU filtering. @@ -93,14 +91,15 @@ def load_and_filter_results(self, search, config): if trj.lh < max_lh: trj_batch.append(trj) - psi_batch.append(search.get_psi_curves(trj)) - phi_batch.append(search.get_phi_curves(trj)) total_count += 1 batch_size = len(trj_batch) logger.info(f"Extracted batch of {batch_size} results for total of {total_count}") if batch_size > 0: + psi_batch = search.get_psi_curves(trj_batch) + phi_batch = search.get_phi_curves(trj_batch) + result_batch = Results.from_trajectories(trj_batch, track_filtered=do_tracking) result_batch.add_psi_phi_data(psi_batch, phi_batch)