Skip to content

Commit

Permalink
Merge pull request #736 from dirac-institute/save_wcs
Browse files Browse the repository at this point in the history
Save the WCS as a column in the results table
  • Loading branch information
jeremykubica authored Nov 5, 2024
2 parents df931a2 + c7e4d56 commit f8e979b
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 28 deletions.
5 changes: 2 additions & 3 deletions src/kbmod/reprojection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def correct_parallax(
use_bounds=False,
):
"""Calculate the parallax corrected postions for a given object at a given
time, observation location on Earth, and user defined distance from th
e Sun.
time, observation location on Earth, and user defined distance from the Sun.
By default, this function will use the geometric solution for objects beyond 1au.
If the distance is less than 1au, the function will use the scipy minimizer
to find the best geocentric distance.
Expand Down Expand Up @@ -66,7 +66,6 @@ def correct_parallax(
ICRS, and the best fit geocentric distance (float).
"""

if use_minimizer or heliocentric_distance < 1.02:
return correct_parallax_with_minimizer(
coord, obstime, point_on_earth, heliocentric_distance, geocentric_distance, method, use_bounds
Expand Down
57 changes: 39 additions & 18 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from kbmod.trajectory_utils import trajectory_from_np_object
from kbmod.search import Trajectory
from kbmod.wcs_utils import deserialize_wcs, serialize_wcs


logger = logging.getLogger(__name__)
Expand All @@ -28,6 +29,9 @@ class Results:
----------
table : `astropy.table.Table`
The stored results data.
wcs : `astropy.wcs.WCS`
A global WCS for all the results. This is optional and primarily used when saving
the results to a file so as to preserve the WCS for future analysis.
track_filtered : `bool`
Whether to track (save) the filtered trajectories. This will use
more memory and is recommended only for analysis.
Expand All @@ -53,16 +57,21 @@ class Results:
]
_required_col_names = set([rq_col[0] for rq_col in required_cols])

def __init__(self, data=None, track_filtered=False):
def __init__(self, data=None, track_filtered=False, wcs=None):
"""Create a ResultTable class.
Parameters
----------
data : `dict`, `astropy.table.Table`
The data for the results table.
track_filtered : `bool`
Whether to track (save) the filtered trajectories. This will use
more memory and is recommended only for analysis.
wcs : `astropy.wcs.WCS`, optional
A gloabl WCS for the results.
"""
self.wcs = wcs

# Set up information to track which row is filtered at which round.
self.track_filtered = track_filtered
self.filtered = {}
Expand Down Expand Up @@ -179,7 +188,14 @@ def read_table(cls, filename, track_filtered=False):
if not Path(filename).is_file():
raise FileNotFoundError(f"File {filename} not found.")
data = Table.read(filename)
return Results(data, track_filtered=track_filtered)

# Check if we have stored a global WCS.
if "wcs" in data.meta:
wcs = deserialize_wcs(data.meta["wcs"])
else:
wcs = None

return Results(data, track_filtered=track_filtered, wcs=wcs)

def remove_column(self, colname):
"""Remove a column from the results table.
Expand Down Expand Up @@ -600,7 +616,7 @@ def revert_filter(self, label=None, add_column=None):

return self

def write_table(self, filename, overwrite=True, cols_to_drop=[]):
def write_table(self, filename, overwrite=True, cols_to_drop=()):
"""Write the unfiltered results to a single (ecsv) file.
Parameter
Expand All @@ -609,26 +625,28 @@ def write_table(self, filename, overwrite=True, cols_to_drop=[]):
The name of the result file.
overwrite : `bool`
Overwrite the file if it already exists. [default: True]
cols_to_drop : `list`
A list of columns to drop (to save space). [default: []]
cols_to_drop : `tuple`
A tuple of columns to drop (to save space). [default: ()]
"""
logger.info(f"Saving results to {filename}")

if len(cols_to_drop) > 0:
# Make a copy so we can modify the table
write_table = self.table.copy()
# Make a copy so we can modify the table
write_table = self.table.copy()

for col in cols_to_drop:
if col in write_table.colnames:
if col in self._required_col_names:
logger.debug(f"Unable to drop required column {col} for write.")
else:
write_table.remove_column(col)
# Drop the columns we need to drop.
for col in cols_to_drop:
if col in write_table.colnames:
if col in self._required_col_names:
logger.debug(f"Unable to drop required column {col} for write.")
else:
write_table.remove_column(col)

# Write out the table.
write_table.write(filename, overwrite=overwrite)
else:
self.table.write(filename, overwrite=overwrite)
# Add global meta data that we can retrieve.
if self.wcs is not None:
write_table.meta["wcs"] = serialize_wcs(self.wcs)

# Write out the table.
write_table.write(filename, overwrite=overwrite)

def write_trajectory_file(self, filename, overwrite=True):
"""Save the trajectories to a numpy file.
Expand Down Expand Up @@ -693,6 +711,8 @@ def write_column(self, colname, filename):
logger.info(f"Writing {colname} column data to {filename}")
if colname not in self.table.colnames:
raise KeyError(f"Column {colname} missing from data.")

# Save the column.
data = np.array(self.table[colname])
np.save(filename, data, allow_pickle=False)

Expand Down Expand Up @@ -722,6 +742,7 @@ def load_column(self, filename, colname):
raise ValueError(
f"Error loading {filename}: expected {len(self.table)} entries, but found {len(data)}."
)

self.table[colname] = data

def write_filtered_stats(self, filename):
Expand Down
10 changes: 8 additions & 2 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .results import Results
from .trajectory_generator import create_trajectory_generator
from .wcs_utils import wcs_to_dict
from .work_unit import WorkUnit


Expand Down Expand Up @@ -186,7 +187,7 @@ def do_gpu_search(self, config, stack, trj_generator):
keep = self.load_and_filter_results(search, config)
return keep

def run_search(self, config, stack, trj_generator=None):
def run_search(self, config, stack, trj_generator=None, wcs=None):
"""This function serves as the highest-level python interface for starting
a KBMOD search given an ImageStack and SearchConfiguration.
Expand All @@ -199,6 +200,8 @@ def run_search(self, config, stack, trj_generator=None):
trj_generator : `TrajectoryGenerator`, optional
The object to generate the candidate trajectories for each pixel.
If None uses the default EclipticCenteredSearch
wcs : `astropy.wcs.WCS`, optional
A global WCS for all images in the search.
Returns
-------
Expand Down Expand Up @@ -247,6 +250,9 @@ def run_search(self, config, stack, trj_generator=None):
if config["save_all_stamps"]:
append_all_stamps(keep, stack, config["stamp_radius"])

# Append the WCS information if it is provided. This will be saved with the results.
keep.table.wcs = wcs

logger.info(f"Found {len(keep)} potential trajectories.")

# Save the results in as an ecsv file and/or a legacy text file.
Expand Down Expand Up @@ -280,4 +286,4 @@ def run_search_from_work_unit(self, work):
trj_generator = create_trajectory_generator(work.config, work_unit=work)

# Run the search.
return self.run_search(work.config, work.im_stack, trj_generator=trj_generator)
return self.run_search(work.config, work.im_stack, trj_generator=trj_generator, wcs=work.wcs)
55 changes: 50 additions & 5 deletions src/kbmod/wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import astropy.coordinates
import astropy.units
import astropy.wcs
import json
import numpy


Expand Down Expand Up @@ -436,14 +437,52 @@ def wcs_to_dict(wcs):
result : `dict`
A dictionary containing the WCS header information.
"""
result = {}
if wcs is not None:
wcs_header = wcs.to_header()
for key in wcs_header:
result[key] = wcs_header[key]
result = dict(wcs.to_header(relax=True))
if wcs.pixel_shape is not None:
header["NAXIS1"], header["NAXIS2"] = wcs.pixel_shape
elif wcs.array_shape is not None:
header["NAXIS2"], header["NAXIS1"] = wcs.array_shape
return result


def serialize_wcs(wcs):
"""Convert a WCS into a JSON string.
Parameters
----------
wcs : `astropy.wcs.WCS`
The WCS to convert.
Returns
-------
wcs_str : `str`
The serialized WCS.
"""
# Since AstroPy's WCS does not output NAXIS, we need to manually add those.
header = wcs.to_header(relax=True)
header["NAXIS1"], header["NAXIS2"] = wcs.pixel_shape
return json.dumps(dict(header))


def deserialize_wcs(wcs_str):
"""Convert a JSON string into a WCS object.
Parameters
----------
wcs_str : `str`
The serialized WCS.
Returns
-------
wcs : `astropy.wcs.WCS`
The resulting WCS.
"""
wcs_dict = json.loads(wcs_str)
wcs = astropy.wcs.WCS(wcs_dict)
wcs.pixel_shape = (wcs_dict["NAXIS1"], wcs_dict["NAXIS2"])
return wcs


def make_fake_wcs_info(center_ra, center_dec, height, width, deg_per_pixel=None):
"""Create a fake WCS dictionary given basic information. This is not a realistic
WCS in terms of astronomy, but can provide a place holder for many tests.
Expand Down Expand Up @@ -548,4 +587,10 @@ def wcs_fits_equal(wcs_a, wcs_b):
if header_a[key] != header_b[key]:
return False

# Check that we correctly kept the shape of the matrix.
if wcs_a.array_shape != wcs_b.array_shape:
return False
if wcs_a.pixel_shape != wcs_b.pixel_shape:
return False

return True
10 changes: 10 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from kbmod.results import Results
from kbmod.search import Trajectory
from kbmod.wcs_utils import make_fake_wcs, wcs_fits_equal


class test_results(unittest.TestCase):
Expand Down Expand Up @@ -388,6 +389,10 @@ def test_to_from_table_file(self):
table.table["other"] = [i for i in range(max_save)]
self.assertEqual(len(table), max_save)

# Create a fake WCS to use for serialization tests.
fake_wcs = make_fake_wcs(25.0, -7.5, 800, 600, deg_per_pixel=0.01)
table.wcs = fake_wcs

# Test read/write to file.
with tempfile.TemporaryDirectory() as dir_name:
file_path = os.path.join(dir_name, "results.ecsv")
Expand All @@ -402,6 +407,11 @@ def test_to_from_table_file(self):
for col in ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count", "other"]:
self.assertTrue(np.allclose(table[col], table2[col]))

# Check that we reloaded the WCS's, including the correct shape.
self.assertIsNotNone(table2.wcs)
self.assertTrue(wcs_fits_equal(table2.wcs, fake_wcs))
self.assertEqual(table2.wcs.pixel_shape, fake_wcs.pixel_shape)

# Cannot overwrite with it set to False
with self.assertRaises(OSError):
table.write_table(file_path, overwrite=False, cols_to_drop=["other"])
Expand Down
10 changes: 10 additions & 0 deletions tests/test_wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ def test_wcs_to_dict(self):
self.assertTrue(key in new_dict)
self.assertAlmostEqual(new_dict[key], self.header_dict[key])

def test_serialization(self):
self.wcs.pixel_shape = (200, 250)
wcs_str = serialize_wcs(self.wcs)
self.assertTrue(isinstance(wcs_str, str))

wcs2 = deserialize_wcs(wcs_str)
self.assertTrue(isinstance(wcs2, WCS))
self.assertEqual(self.wcs.pixel_shape, wcs2.pixel_shape)
self.assertTrue(wcs_fits_equal(self.wcs, wcs2))

def test_append_wcs_to_hdu_header(self):
for use_dictionary in [True, False]:
if use_dictionary:
Expand Down

0 comments on commit f8e979b

Please sign in to comment.