Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a table-based result data structure #549

Merged
merged 24 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bdc2968
Basic skeleton
jeremykubica Apr 5, 2024
f494eaa
Merge branch 'main' into result_table
jeremykubica Apr 5, 2024
032e1b9
Framework of the psi/phi information
jeremykubica Apr 5, 2024
673bd7e
Finish psi/phi updates
jeremykubica Apr 5, 2024
c444c5b
Fix typo
jeremykubica Apr 8, 2024
0776084
Add ability to save and load
jeremykubica Apr 8, 2024
1849512
Merge branch 'main' into result_table
jeremykubica Apr 9, 2024
2a46779
Address some PR comments
jeremykubica Apr 15, 2024
d6ec08f
Update comment
jeremykubica Apr 15, 2024
32659f3
Add helper to compute the likelihood curves
jeremykubica Apr 15, 2024
f8938f5
Create a vectorized version of sigmaG
jeremykubica Apr 15, 2024
ba7dc44
Add ability to save in a results file format
jeremykubica Apr 16, 2024
c77b95c
Generalize clustering code so it takes Results objects
jeremykubica Apr 17, 2024
fd3fe77
Fix comment
jeremykubica Apr 17, 2024
26b6755
Merge branch 'main' into result_table
jeremykubica Apr 18, 2024
933cf9e
Fix bad merge
jeremykubica Apr 18, 2024
5fa1334
Improve flow of sigmaG filtering for Results object
jeremykubica Apr 19, 2024
72ed621
Add ability to fetch all stamps for a Results object
jeremykubica Apr 19, 2024
5fe9e62
Extend stamp filtering to use Results object
jeremykubica Apr 19, 2024
1b964d4
Fix negative_clipping on sigmaG filtering with a Results object
jeremykubica Apr 19, 2024
04e8926
Allow mismatched joins with an empty table
jeremykubica Apr 22, 2024
deb3c9f
Update constructor
jeremykubica Apr 22, 2024
9e790cf
Update results.py
jeremykubica Apr 22, 2024
84484cd
Address PR comments
jeremykubica Apr 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
364 changes: 364 additions & 0 deletions src/kbmod/result_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,364 @@
"""ResultTable is a column-based data structure for tracking results with additional global data
and helper functions for filtering and maintaining consistency between different attributes in each row.
"""
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np
from pathlib import Path

from astropy.table import Table, vstack

from kbmod.trajectory_utils import make_trajectory, update_trajectory_from_psi_phi
from kbmod.search import Trajectory


class ResultTable:
    """This class stores a collection of related data from all of the kbmod results.

    At a minimum it contains columns for the trajectory information:
    (x, y, vx, vy, likelihood, flux, obs_count)
    but additional columns can be added as needed.

    Attributes
    ----------
    results : `astropy.table.Table`
        The underlying table of result rows.
    track_filtered : `bool`
        Whether filtered rows are saved for later analysis.
    filtered : `dict`
        Maps a filter-stage label to the table of rows removed at that
        stage. Only populated when ``track_filtered`` is True.
    """

    def __init__(self, trj_list, track_filtered=False):
        """Create a ResultTable class.

        Parameters
        ----------
        trj_list : `list[Trajectory]`
            A list of trajectories to include in these results.
        track_filtered : `bool`
            Whether to track (save) the filtered trajectories. This will use
            more memory and is recommended only for analysis.
        """
        # Keep only the trajectories marked valid, exploding their fields into
        # columns while preserving the original object in the "trajectory" column.
        valid_trjs = [trj for trj in trj_list if trj.valid]
        input_dict = {
            "x": [trj.x for trj in valid_trjs],
            "y": [trj.y for trj in valid_trjs],
            "vx": [trj.vx for trj in valid_trjs],
            "vy": [trj.vy for trj in valid_trjs],
            "likelihood": [trj.lh for trj in valid_trjs],
            "flux": [trj.flux for trj in valid_trjs],
            "obs_count": [trj.obs_count for trj in valid_trjs],
            "trajectory": valid_trjs,
        }
        self.results = Table(input_dict)

        # Set up information to track which row is filtered at which round.
        self.track_filtered = track_filtered
        self.filtered = {}

    def __len__(self):
        """Return the number of results in the list."""
        return len(self.results)

    @property
    def colnames(self):
        """`list[str]` : The names of the columns in the results table."""
        return self.results.colnames

    @classmethod
    def from_table(cls, data, track_filtered=False):
        """Extract the ResultTable from an astropy Table with the minimum
        trajectory information. Fills in missing columns (such as the trajectory
        object) if they are not present.

        Parameters
        ----------
        data : `astropy.table.Table`
            The input data.
        track_filtered : `bool`
            Indicates whether the ResultTable should track future filtered points.

        Returns
        -------
        table : `ResultTable`
            The constructed ResultTable wrapping ``data``.

        Raises
        ------
        KeyError if any required columns are missing.
        """
        # Check that the minimum information is present.
        required_cols = ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count"]
        for col in required_cols:
            if col not in data.colnames:
                raise KeyError(f"Column {col} missing from input data.")

        # Create an empty ResultTable and adopt the given data table.
        table = cls([], track_filtered=track_filtered)
        table.results = data

        # If the data did not have a column for Trajectory object, add it with
        # an expensive linear scan.
        if "trajectory" not in data.colnames:
            table.results["trajectory"] = [
                make_trajectory(
                    x=row["x"],
                    y=row["y"],
                    vx=row["vx"],
                    vy=row["vy"],
                    flux=row["flux"],
                    lh=row["likelihood"],
                    obs_count=row["obs_count"],
                )
                for row in data
            ]

        return table

    @classmethod
    def read_table(cls, filename, track_filtered=False):
        """Read the ResultTable from a table file.

        Parameters
        ----------
        filename : `str`
            The name of the file to load.
        track_filtered : `bool`
            Indicates whether the ResultTable should track future filtered points.

        Returns
        -------
        table : `ResultTable`
            The loaded ResultTable.

        Raises
        ------
        FileNotFoundError if the file is not found.
        KeyError if any of the columns are missing.
        """
        if not Path(filename).is_file():
            # Include the filename so the caller knows which file was missing.
            raise FileNotFoundError(filename)
        data = Table.read(filename)
        return cls.from_table(data, track_filtered=track_filtered)

    def extend(self, table2):
        """Append the results in a second ResultTable to the current one.

        Parameters
        ----------
        table2 : `ResultTable`
            The data structure containing additional `ResultTable` elements to add.
        """
        self.results = vstack([self.results, table2.results])

        # When merging the filtered results stack tables with the
        # same key and create new entries for new keys.
        for key in table2.filtered:
            if key in self.filtered:
                self.filtered[key] = vstack([self.filtered[key], table2.filtered[key]])
            else:
                self.filtered[key] = table2.filtered[key]

    def _update_likelihood(self):
        """Update the likelihood related trajectory information from the
        psi and phi information. Requires the existence of the columns
        'psi_curve' and 'phi_curve' which can be set with add_psi_phi_data().
        Uses the (optional) 'index_valid' column if it exists.

        Raises
        ------
        Raises an IndexError if the necessary columns are missing.
        """
        # BUGFIX: the first message previously said 'phi_curve' for the psi check.
        if "psi_curve" not in self.results.colnames:
            raise IndexError("Missing column 'psi_curve'. Use add_psi_phi_data()")
        if "phi_curve" not in self.results.colnames:
            raise IndexError("Missing column 'phi_curve'. Use add_psi_phi_data()")

        use_valid_indices = "index_valid" in self.results.colnames
        inds = None

        # Go through each row to update.
        for row in self.results:
            if use_valid_indices:
                inds = row["index_valid"]
            trj = update_trajectory_from_psi_phi(
                row["trajectory"], row["psi_curve"], row["phi_curve"], index_valid=inds, in_place=True
            )

            # Update the exploded columns to match the recomputed trajectory.
            row["likelihood"] = trj.lh
            row["flux"] = trj.flux
            row["obs_count"] = trj.obs_count

    def add_psi_phi_data(self, psi_array, phi_array, index_valid=None):
        """Append columns for the psi and phi data and use this to update the
        relevant trajectory information.

        Parameters
        ----------
        psi_array : `numpy.ndarray`
            An array of psi_curves with one for each row.
        phi_array : `numpy.ndarray`
            An array of phi_curves with one for each row.
        index_valid : `numpy.ndarray`, optional
            An optional array of index_valid arrays with one for each row.

        Raises
        ------
        Raises a ValueError if the input arrays are not the same size as the table
        or a given pair of rows in the arrays are not the same length.
        """
        if len(psi_array) != len(self.results):
            raise ValueError("Wrong number of psi curves provided.")
        if len(phi_array) != len(self.results):
            raise ValueError("Wrong number of phi curves provided.")
        self.results["psi_curve"] = psi_array
        self.results["phi_curve"] = phi_array

        if index_valid is not None:
            # Make the data to match.
            if len(index_valid) != len(self.results):
                raise ValueError("Wrong number of index_valid lists provided.")
            self.results["index_valid"] = index_valid

        # Update the track likelihoods given this new information.
        self._update_likelihood()

    def filter_mask(self, mask, label=None):
        """Filter the rows in the ResultTable to only include those indices
        that are marked True in the mask.

        Parameters
        ----------
        mask : `numpy.ndarray` or `list`
            A sequence the same length as the table with True/False indicating
            which rows to keep.
        label : `str`
            The label of the filtering stage to use. Only used if
            we keep filtered trajectories.

        Returns
        -------
        self : `ResultTable`
            Returns a reference to itself to allow chaining.
        """
        # Coerce to an ndarray so `~mask` works even for a plain Python list.
        mask = np.asarray(mask)

        if self.track_filtered:
            if label is None:
                label = ""

            if label in self.filtered:
                self.filtered[label] = vstack([self.filtered[label], self.results[~mask]])
            else:
                self.filtered[label] = self.results[~mask]

        # Do the actual filtering.
        self.results = self.results[mask]

        # Return a reference to the current object to allow chaining.
        return self

    def filter_by_index(self, indices_to_keep, label=None):
        """Filter the rows in the ResultTable to only include those indices
        in the list indices_to_keep.

        Parameters
        ----------
        indices_to_keep : `list[int]`
            The indices of the rows to keep.
        label : `str`
            The label of the filtering stage to use. Only used if
            we keep filtered trajectories.

        Returns
        -------
        self : `ResultTable`
            Returns a reference to itself to allow chaining.
        """
        # Use a set for O(1) membership tests while building the mask.
        indices_set = set(indices_to_keep)
        mask = np.array([i in indices_set for i in range(len(self.results))])
        self.filter_mask(mask, label)
        return self

    def get_filtered(self, label=None):
        """Get the results filtered at a given stage or all stages.

        Parameters
        ----------
        label : `str`
            The filtering stage to use. If no label is provided,
            return all filtered rows.

        Returns
        -------
        results : `astropy.table.Table` or None
            A table with the filtered rows or None if there are no entries.

        Raises
        ------
        ValueError if filtering is not enabled.
        """
        if not self.track_filtered:
            raise ValueError("ResultTable filter tracking not enabled.")

        result = None
        if label is not None:
            # Check if anything was filtered at this stage.
            if label in self.filtered:
                result = self.filtered[label]
        elif len(self.filtered) > 0:
            # vstack fails on an empty list, so only stack when at least one
            # filtering stage has saved rows; otherwise return None as documented.
            result = vstack(list(self.filtered.values()))

        return result

    def revert_filter(self, label=None, add_column=None):
        """Revert the filtering by re-adding filtered rows.

        Note
        ----
        Filtered rows are appended to the end of the list. Does not return
        the results to the original ordering.

        Parameters
        ----------
        label : `str`
            The filtering stage to use. If no label is provided,
            revert all filtered rows.
        add_column : `str`
            If not ``None``, add a tracking column with the given name
            that includes the original filtering reason.

        Returns
        -------
        self : `ResultTable`
            Returns a reference to itself to allow chaining.

        Raises
        ------
        ValueError if filtering is not enabled.
        KeyError if label is unknown.
        """
        if not self.track_filtered:
            raise ValueError("ResultTable filter tracking not enabled.")

        # Make a list of labels to revert.
        if label is not None:
            if label not in self.filtered:
                raise KeyError(f"Unknown filtered label {label}")
            to_revert = [label]
        else:
            to_revert = list(self.filtered.keys())

        # If we don't have the tracking column yet, add it.
        if add_column is not None and add_column not in self.results.colnames:
            self.results[add_column] = [""] * len(self.results)

        # Make a list of tables to merge, tagging each reverted table with
        # its original filter label when requested.
        table_list = [self.results]
        for key in to_revert:
            filtered_table = self.filtered[key]
            if add_column is not None:
                filtered_table[add_column] = [key] * len(filtered_table)
            table_list.append(filtered_table)
            del self.filtered[key]
        self.results = vstack(table_list)

        return self

    def write_table(self, filename, overwrite=True, cols_to_drop=None):
        """Write the unfiltered results to a single (ecsv) file.

        Parameters
        ----------
        filename : `str`
            The name of the result file.
        overwrite : `bool`
            Overwrite the file if it already exists. [default: True]
        cols_to_drop : `list`, optional
            A list of columns to drop (to save space). [default: None]
        """
        # Avoid the mutable default argument pitfall.
        if cols_to_drop is None:
            cols_to_drop = []

        # Make a copy so we can modify the table (drop the Trajectory objects).
        write_table = self.results.copy()

        all_cols_to_drop = ["trajectory"] + list(cols_to_drop)
        for col in all_cols_to_drop:
            if col in write_table.colnames:
                write_table.remove_column(col)

        # Write out the table.
        write_table.write(filename, overwrite=overwrite)
3 changes: 3 additions & 0 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def update_trajectory_from_psi_phi(trj, psi_curve, phi_curve, index_valid=None,
"""Update the trajectory's statistic information from a psi_curve and
phi_curve. Uses an optional index_valid mask (True/False) to mask out
pixels.

Parameters
----------
trj : `Trajectory`
Expand All @@ -239,10 +240,12 @@ def update_trajectory_from_psi_phi(trj, psi_curve, phi_curve, index_valid=None,
An array of Booleans indicating whether the time step is valid.
in_place : `bool`
Update the input trajectory in-place.

Returns
-------
result : `Trajectory`
The updated trajectory. May be the same as trj if in_place=True.

Raises
------
Raises a ValueError if the input arrays are not the same size.
Expand Down
Loading
Loading