Skip to content

Commit

Permalink
Merge pull request #460 from dirac-institute/stamps
Browse files Browse the repository at this point in the history
Improvements and reorganization of stamp generation
  • Loading branch information
jeremykubica authored Feb 7, 2024
2 parents 4ed4643 + 5cc08bd commit 3d0fe64
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 288 deletions.
1 change: 0 additions & 1 deletion data/demo_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ mom_lims:
- 37.5
- 37.5
- 1.5
- 1.5
- 1.0
- 1.0
num_cores: 1
Expand Down
3 changes: 3 additions & 0 deletions docs/source/user_manual/search_params.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ This document serves to provide a quick overview of the existing parameters and
| | | Can be use used in addition to |
| | | outputting individual result files. |
+------------------------+-----------------------------+----------------------------------------+
| ``save_all_stamps`` | True | Save the individual stamps for each |
| | | result and timestep. |
+------------------------+-----------------------------+----------------------------------------+
| ``sigmaG_lims`` | [25, 75] | The percentiles to use in sigmaG |
| | | filtering, if |
| | | ``filter_type= clipped_sigmaG``. |
Expand Down
143 changes: 0 additions & 143 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,149 +121,6 @@ def load_and_filter_results(
res_num += chunk_size
return keep

def get_all_stamps(self, result_list, search, stamp_radius):
"""Get the stamps for the final results from a kbmod search.
Parameters
----------
result_list : `ResultList`
The values from trajectories. The stamps are inserted into this data structure.
search : `kbmod.StackSearch`
The search object
stamp_radius : int
The radius of the stamps to create.
"""
stamp_edge = stamp_radius * 2 + 1
for row in result_list.results:
stamps = kb.StampCreator.get_stamps(search.get_imagestack(), row.trajectory, stamp_radius)
# TODO: a way to avoid a copy here would be to do
# np.array([s.image for s in stamps], dtype=np.single, copy=False)
# but that could cause a problem with reference counting at the m
# moment. The real fix is to make the stamps return Image not
# RawImage, return the Image and avoid a reference to a private
# attribute. This risks collecting RawImage but leaving a dangling
# ref to its private field. That's a fix for another time.
row.all_stamps = np.array([stamp.image for stamp in stamps])

def apply_stamp_filter(
self,
result_list,
search,
center_thresh=0.03,
peak_offset=[2.0, 2.0],
mom_lims=[35.5, 35.5, 1.0, 0.25, 0.25],
chunk_size=1000000,
stamp_type="sum",
stamp_radius=10,
):
"""This function filters result postage stamps based on their Gaussian
Moments. Results with stamps that are similar to a Gaussian are kept.
Parameters
----------
result_list : `ResultList`
The values from trajectories. This data gets modified directly by
the filtering.
search : `kbmod.StackSearch`
The search object.
center_thresh : float
The fraction of the total flux that must be contained in a single
central pixel.
peak_offset : list of floats
How far the brightest pixel in the stamp can be from the central
pixel.
mom_lims : list of floats
The maximum limit of the xx, yy, xy, x, and y central moments of
the stamp.
chunk_size : int
How many stamps to load and filter at a time.
stamp_type : string
Which method to use to generate stamps.
One of 'median', 'cpp_median', 'mean', 'cpp_mean', or 'sum'.
stamp_radius : int
The radius of the stamp.
"""
# Set the stamp creation and filtering parameters.
params = kb.StampParameters()
params.radius = stamp_radius
params.do_filtering = True
params.center_thresh = center_thresh
params.peak_offset_x = peak_offset[0]
params.peak_offset_y = peak_offset[1]
params.m20_limit = mom_lims[0]
params.m02_limit = mom_lims[1]
params.m11_limit = mom_lims[2]
params.m10_limit = mom_lims[3]
params.m01_limit = mom_lims[4]

if stamp_type == "cpp_median" or stamp_type == "median":
params.stamp_type = kb.StampType.STAMP_MEDIAN
elif stamp_type == "cpp_mean" or stamp_type == "mean":
params.stamp_type = kb.StampType.STAMP_MEAN
else:
params.stamp_type = kb.StampType.STAMP_SUM

# Save some useful helper data.
num_times = search.get_num_images()
all_valid_inds = []

# Run the stamp creation and filtering in batches of chunk_size.
print("---------------------------------------")
print("Applying Stamp Filtering")
print("---------------------------------------", flush=True)
start_time = time.time()
start_idx = 0
if result_list.num_results() <= 0:
print("Skipping. Nothing to filter.")
return

print("Stamp filtering %i results" % result_list.num_results())
while start_idx < result_list.num_results():
end_idx = min([start_idx + chunk_size, result_list.num_results()])

# Create a subslice of the results and the Boolean indices.
# Note that the sum stamp type does not filter out lc_index.
inds_to_use = [i for i in range(start_idx, end_idx)]
trj_slice = [result_list.results[i].trajectory for i in inds_to_use]
if params.stamp_type != kb.StampType.STAMP_SUM:
bool_slice = [result_list.results[i].valid_indices_as_booleans() for i in inds_to_use]
else:
# For the sum stamp, use all the indices for each trajectory.
all_true = [True] * num_times
bool_slice = [all_true for _ in inds_to_use]

# Create and filter the results, using the GPU if there is one and enough
# trajectories to make it worthwhile.
stamps_slice = kb.StampCreator.get_coadded_stamps(
search.get_imagestack(),
trj_slice,
bool_slice,
params,
kb.HAS_GPU and len(trj_slice) > 100,
)
# TODO: a way to avoid a copy here would be to do
# np.array([s.image for s in stamps], dtype=np.single, copy=False)
# but that could cause a problem with reference counting at the m
# moment. The real fix is to make the stamps return Image not
# RawImage and avoid reference to an private attribute and risking
# collecting RawImage but leaving a dangling ref to the attribute.
# That's a fix for another time so I'm leaving it as a copy here
for ind, stamp in enumerate(stamps_slice):
if stamp.width > 1:
result_list.results[ind + start_idx].stamp = np.array(stamp.image)
all_valid_inds.append(ind + start_idx)

# Move to the next chunk.
start_idx += chunk_size

# Do the actual filtering of results
result_list.filter_results(all_valid_inds)
print("Keeping %i results" % result_list.num_results(), flush=True)

end_time = time.time()
time_elapsed = end_time - start_time
print("{:.2f}s elapsed".format(time_elapsed))

def apply_clustering(self, result_list, cluster_params):
"""This function clusters results that have similar trajectories.
Expand Down
1 change: 1 addition & 0 deletions src/kbmod/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self):
"repeated_flag_keys": default_repeated_flag_keys,
"res_filepath": None,
"result_filename": None,
"save_all_stamps": True,
"sigmaG_lims": [25, 75],
"stamp_radius": 10,
"stamp_type": "sum",
Expand Down
170 changes: 169 additions & 1 deletion src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,20 @@
"""

import abc
import numpy as np
import time

from kbmod.configuration import SearchConfiguration
from kbmod.result_list import ResultRow
from kbmod.search import KB_NO_DATA, RawImage
from kbmod.search import (
HAS_GPU,
KB_NO_DATA,
ImageStack,
RawImage,
StampCreator,
StampParameters,
StampType,
)


class BaseStampFilter(abc.ABC):
Expand Down Expand Up @@ -237,3 +248,160 @@ def keep_row(self, row: ResultRow):
"""
image = RawImage(row.stamp)
return image.center_is_local_max(self.flux_thresh, self.local_max)


def extract_search_parameters_from_config(config):
"""Create an initialized StampParameters object from the configuration settings
while doing some validity checking.
Parameters
----------
config : `SearchConfiguration`
The configuration object.
Returns
-------
params : `StampParameters`
The StampParameters object with all fields set.
Raises
------
Raises a ``ValueError`` if parameter validation fails.
Raises a ``KeyError`` if a required parameter is not found.
"""
params = StampParameters()

# Construction parameters
params.radius = config["stamp_radius"]
if params.radius < 0:
raise ValueError(f"Invalid stamp radius {params.radius}")

stamp_type = config["stamp_type"]
if stamp_type == "cpp_median" or stamp_type == "median":
params.stamp_type = StampType.STAMP_MEDIAN
elif stamp_type == "cpp_mean" or stamp_type == "mean":
params.stamp_type = StampType.STAMP_MEAN
elif stamp_type == "cpp_sum" or stamp_type == "sum":
params.stamp_type = StampType.STAMP_SUM
else:
raise ValueError(f"Unrecognized stamp type: {stamp_type}")

# Filtering parameters (with validity checking)
params.do_filtering = config["do_stamp_filter"]
params.center_thresh = config["center_thresh"]

peak_offset = config["peak_offset"]
if len(peak_offset) != 2:
raise ValueError(f"Expected length 2 list for peak_offset. Found {peak_offset}")
params.peak_offset_x = peak_offset[0]
params.peak_offset_y = peak_offset[1]

mom_lims = config["mom_lims"]
if len(mom_lims) != 5:
raise ValueError(f"Expected length 5 list for mom_lims. Found {mom_lims}")
params.m20_limit = mom_lims[0]
params.m02_limit = mom_lims[1]
params.m11_limit = mom_lims[2]
params.m10_limit = mom_lims[3]
params.m01_limit = mom_lims[4]

return params


def get_coadds_and_filter(result_list, im_stack, stamp_params, chunk_size=1000000, debug=False):
"""Create the co-added postage stamps and filter them based on their statistical
properties. Results with stamps that are similar to a Gaussian are kept.
Parameters
----------
result_list : `ResultList`
The current set of results. Modified directly.
im_stack : `ImageStack`
The images from which to build the co-added stamps.
stamp_params : `StampParameters` or `SearchConfiguration`
The filtering parameters for the stamps.
chunk_size : `int`
How many stamps to load and filter at a time. Used to control memory.
debug : `bool`
Output verbose debugging messages.
"""
if type(stamp_params) is SearchConfiguration:
stamp_params = extract_search_parameters_from_config(stamp_params)

if debug:
print("---------------------------------------")
print("Applying Stamp Filtering")
print("---------------------------------------")
if result_list.num_results() <= 0:
print("Skipping. Nothing to filter.")
else:
print(f"Stamp filtering {result_list.num_results()} results.")
print(stamp_params)
print(f"Using chunksize = {chunk_size}")

# Run the stamp creation and filtering in batches of chunk_size.
start_time = time.time()
start_idx = 0
all_valid_inds = []
while start_idx < result_list.num_results():
end_idx = min([start_idx + chunk_size, result_list.num_results()])

# Create a subslice of the results and the Boolean indices.
# Note that the sum stamp type does not filter out lc_index.
inds_to_use = [i for i in range(start_idx, end_idx)]
trj_slice = [result_list.results[i].trajectory for i in inds_to_use]
if stamp_params.stamp_type != StampType.STAMP_SUM:
bool_slice = [result_list.results[i].valid_indices_as_booleans() for i in inds_to_use]
else:
# For the sum stamp, use all the indices for each trajectory.
all_true = [True] * im_stack.img_count()
bool_slice = [all_true for _ in inds_to_use]

# Create and filter the results, using the GPU if there is one and enough
# trajectories to make it worthwhile.
stamps_slice = StampCreator.get_coadded_stamps(
im_stack,
trj_slice,
bool_slice,
stamp_params,
HAS_GPU and len(trj_slice) > 100,
)
# TODO: a way to avoid a copy here would be to do
# np.array([s.image for s in stamps], dtype=np.single, copy=False)
# but that could cause a problem with reference counting at the m
# moment. The real fix is to make the stamps return Image not
# RawImage and avoid reference to an private attribute and risking
# collecting RawImage but leaving a dangling ref to the attribute.
# That's a fix for another time so I'm leaving it as a copy here
for ind, stamp in enumerate(stamps_slice):
if stamp.width > 1:
result_list.results[ind + start_idx].stamp = np.array(stamp.image)
all_valid_inds.append(ind + start_idx)

# Move to the next chunk.
start_idx += chunk_size

# Do the actual filtering of results
result_list.filter_results(all_valid_inds)
if debug:
print("Keeping %i results" % result_list.num_results(), flush=True)
time_elapsed = time.time() - start_time
print("{:.2f}s elapsed".format(time_elapsed))


def append_all_stamps(result_list, im_stack, stamp_radius):
"""Get the stamps for the final results from a kbmod search. These are appended
onto the corresponding entries in a ResultList.
Parameters
----------
result_list : `ResultList`
The current set of results. Modified directly.
im_stack : `ImageStack`
The stack of images.
stamp_radius : `int`
The radius of the stamps to create.
"""
for row in result_list.results:
stamps = StampCreator.get_stamps(im_stack, row.trajectory, stamp_radius)
row.all_stamps = np.array([stamp.image for stamp in stamps])
Loading

0 comments on commit 3d0fe64

Please sign in to comment.