Skip to content

Commit

Permalink
Extend stamp filtering to use Results object
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Apr 19, 2024
1 parent 72ed621 commit 5fe9e62
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 1 deletion.
85 changes: 85 additions & 0 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def extract_search_parameters_from_config(config):
return params


# TODO remove once we full replace ResultList with Results
def get_coadds_and_filter(result_list, im_stack, stamp_params, chunk_size=1000000, debug=False):
"""Create the co-added postage stamps and filter them based on their statistical
properties. Results with stamps that are similar to a Gaussian are kept.
Expand Down Expand Up @@ -388,6 +389,90 @@ def get_coadds_and_filter(result_list, im_stack, stamp_params, chunk_size=100000
logger.debug("{:.2f}s elapsed".format(time.time() - start_time))


def get_coadds_and_filter_results(result_data, im_stack, stamp_params, chunk_size=1000000, debug=False):
"""Create the co-added postage stamps and filter them based on their statistical
properties. Results with stamps that are similar to a Gaussian are kept.
Parameters
----------
result_data : `Results`
The current set of results. Modified directly.
im_stack : `ImageStack`
The images from which to build the co-added stamps.
stamp_params : `StampParameters` or `SearchConfiguration`
The filtering parameters for the stamps.
chunk_size : `int`
How many stamps to load and filter at a time. Used to control memory.
debug : `bool`
Output verbose debugging messages.
"""
num_results = len(result_data)

if type(stamp_params) is SearchConfiguration:
stamp_params = extract_search_parameters_from_config(stamp_params)

if num_results <= 0:
logger.debug("Stamp Filtering : skipping, othing to filter.")
else:
logger.debug(f"Stamp filtering {num_results} results.")
logger.debug(f"Using filtering params: {stamp_params}")
logger.debug(f"Using chunksize = {chunk_size}")

trj_list = result_data.make_trajectory_list()
keep_row = [False] * num_results
stamps_to_keep = []

# Run the stamp creation and filtering in batches of chunk_size.
start_time = time.time()
start_idx = 0
while start_idx < num_results:
end_idx = min([start_idx + chunk_size, num_results])
slice_size = end_idx - start_idx

# Create a subslice of the results and the Boolean indices.
# Note that the sum stamp type does not filter out lc_index.
trj_slice = trj_list[start_idx:end_idx]
if stamp_params.stamp_type != StampType.STAMP_SUM and "index_valid" in result_data.colnames:
bool_slice = result_data["index_valid"][start_idx:end_idx]
else:
# Use all the indices for each trajectory.
bool_slice = [[True] * im_stack.img_count() for _ in range(slice_size)]

# Create and filter the results, using the GPU if there is one and enough
# trajectories to make it worthwhile.
stamps_slice = StampCreator.get_coadded_stamps(
im_stack,
trj_slice,
bool_slice,
stamp_params,
HAS_GPU and len(trj_slice) > 100,
)
# TODO: a way to avoid a copy here would be to do
# np.array([s.image for s in stamps], dtype=np.single, copy=False)
# but that could cause a problem with reference counting at the m
# moment. The real fix is to make the stamps return Image not
# RawImage and avoid reference to an private attribute and risking
# collecting RawImage but leaving a dangling ref to the attribute.
# That's a fix for another time so I'm leaving it as a copy here
for ind, stamp in enumerate(stamps_slice):
if stamp.width > 1:
stamps_to_keep.append(np.array(stamp.image))
keep_row[start_idx + ind] = True

# Move to the next chunk.
start_idx += chunk_size

# Do the actual filtering of results
result_data.filter_mask(keep_row, label="stamp_filter")

# Append the coadded stamps to the results. We do this after the filtering
# so we are not adding a jagged array.
result_data.table["stamp"] = np.array(stamps_to_keep)

logger.debug(f"Keeping {len(result_data)} results")
logger.debug("{:.2f}s elapsed".format(time.time() - start_time))


def append_all_stamps(result_data, im_stack, stamp_radius):
"""Get the stamps for the final results from a kbmod search. These are appended
onto the corresponding entries in a ResultList.
Expand Down
55 changes: 54 additions & 1 deletion tests/test_stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from kbmod.fake_data.fake_data_creator import add_fake_object, create_fake_times, FakeDataSet
from kbmod.filters.stamp_filters import *
from kbmod.result_list import *
from kbmod.results import Results
from kbmod.search import *


Expand Down Expand Up @@ -235,6 +236,58 @@ def test_get_coadds_and_filter(self):
self.assertIsNotNone(keep.results[0].stamp)
self.assertIsNotNone(keep.results[1].stamp)

@unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)")
def test_get_coadds_and_filter_results(self):
image_count = 10
fake_times = create_fake_times(image_count, 57130.2, 1, 0.01, 1)
ds = FakeDataSet(
25, # width
35, # height
fake_times, # time stamps
1.0, # noise level
0.5, # psf value
True, # Use a fixed seed for testing
)

# Insert a single fake object with known parameters.
trj = make_trajectory(8, 7, 2.0, 1.0, flux=250.0)
ds.insert_object(trj)

# Second Trajectory that isn't any good.
trj2 = make_trajectory(1, 1, 0.0, 0.0)

# Third Trajectory that is close to good, but offset.
trj3 = make_trajectory(trj.x + 2, trj.y + 2, trj.vx, trj.vy)

# Create a fourth Trajectory that is just close enough
trj4 = make_trajectory(trj.x + 1, trj.y + 1, trj.vx, trj.vy)

# Create the Results.
keep = Results([trj, trj2, trj3, trj4])
self.assertFalse("stamp" in keep.colnames)

# Create the stamp parameters we need.
config_dict = {
"center_thresh": 0.03,
"do_stamp_filter": True,
"mom_lims": [35.5, 35.5, 1.0, 1.0, 1.0],
"peak_offset": [1.5, 1.5],
"stamp_type": "cpp_mean",
"stamp_radius": 5,
}
config = SearchConfiguration.from_dict(config_dict)

# Do the filtering.
get_coadds_and_filter_results(keep, ds.stack, config, chunk_size=2, debug=False)

# The check that the correct indices and number of stamps are saved.
self.assertTrue("stamp" in keep.colnames)
self.assertEqual(len(keep), 2)
self.assertEqual(keep["x"][0], trj.x)
self.assertEqual(keep["x"][1], trj.x + 1)
self.assertEqual(keep["stamp"][0].shape, (11, 11))
self.assertEqual(keep["stamp"][1].shape, (11, 11))

def test_append_all_stamps(self):
image_count = 10
fake_times = create_fake_times(image_count, 57130.2, 1, 0.01, 1)
Expand Down Expand Up @@ -284,7 +337,7 @@ def test_append_all_stamps_results(self):

append_all_stamps(keep, ds.stack, 5)
self.assertTrue("all_stamps" in keep.colnames)
for i in range(len(trj_list)):
for i in range(len(keep)):
stamps_array = keep["all_stamps"][i]
self.assertEqual(stamps_array.shape[0], image_count)
self.assertEqual(stamps_array.shape[1], 11)
Expand Down

0 comments on commit 5fe9e62

Please sign in to comment.