Skip to content

Commit

Permalink
Update constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Apr 22, 2024
1 parent 04e8926 commit deb3c9f
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 102 deletions.
175 changes: 100 additions & 75 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,24 @@ class Results:
of the results removed by that filter.
"""

_required_cols = ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count"]

def __init__(self, trajectories, track_filtered=False):
# The required columns list gives a list of tuples containing
# (column name, dtype, default value) for each required column.
_required_cols = [
("x", "int64", 0),
("y", "int64", 0),
("vx", "float64", 0.0),
("vy", "float64", 0.0),
("likelihood", "float64", 0.0),
("flux", "float64", 0.0),
("obs_count", "int64", 0),
]

def __init__(self, data=None, track_filtered=False):
"""Create a ResultTable class.
Parameters
----------
trajectories : `list[Trajectory]`
A list of trajectories to include in these results.
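data : `dict`, `astropy.table.Table`, optional
The initial data for the results table. A `dict` is converted to a
Table and a Table is copied. If None, an empty table with the
required columns is created.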
track_filtered : `bool`
Whether to track (save) the filtered trajectories. This will use
more memory and is recommended only for analysis.
Expand All @@ -49,36 +58,23 @@ def __init__(self, trajectories, track_filtered=False):
self.track_filtered = track_filtered
self.filtered = {}

# Create dictionaries for the required columns.
input_d = {}
invalid_d = {}
for col in self._required_cols:
input_d[col] = []
invalid_d[col] = []

# Add the valid trajectories to the table. If we are tracking filtered
# data, add invalid trajectories to the invalid_d dictionary.
for trj in trajectories:
if trj.valid:
input_d["x"].append(trj.x)
input_d["y"].append(trj.y)
input_d["vx"].append(trj.vx)
input_d["vy"].append(trj.vy)
input_d["likelihood"].append(trj.lh)
input_d["flux"].append(trj.flux)
input_d["obs_count"].append(trj.obs_count)
elif track_filtered:
invalid_d["x"].append(trj.x)
invalid_d["y"].append(trj.y)
invalid_d["vx"].append(trj.vx)
invalid_d["vy"].append(trj.vy)
invalid_d["likelihood"].append(trj.lh)
invalid_d["flux"].append(trj.flux)
invalid_d["obs_count"].append(trj.obs_count)
if data is None:
# Set up the basic table meta data.
self.table = Table(
names=[col[0] for col in self._required_cols],
dtype=[col[1] for col in self._required_cols],
)
elif type(data) is dict:
self.table = Table(data)
elif type(data) is Table:
self.table = data.copy()
else:
raise TypeError(f"Incompatible data type {type(data)}")

self.table = Table(input_d)
if track_filtered:
self.filtered["invalid_trajectory"] = Table(invalid_d)
# Check that we have the correct columns.
for col in self._required_cols:
if col[0] not in self.table.colnames:
raise KeyError(f"Column {col[0]} missing from input data.")

def __len__(self):
return len(self.table)
Expand All @@ -100,48 +96,59 @@ def colnames(self):
return self.table.colnames

@classmethod
def from_table(cls, data, track_filtered=False):
"""Extract data from an astropy Table with the minimum trajectory information.
def from_trajectories(cls, trajectories, track_filtered=False):
"""Extract data from a list of Trajectory objects.
Parameters
----------
data : `astropy.table.Table`
The input data.
trajectories : `list[Trajectory]`
A list of trajectories to include in these results.
track_filtered : `bool`
Indicates whether to track future filtered points.
Raises
------
Raises a KeyError if any required columns are missing.
"""
# Check that the minimum information is present.
# Create dictionaries for the required columns.
input_d = {}
invalid_d = {}
for col in cls._required_cols:
if col not in data.colnames:
raise KeyError(f"Column {col} missing from input data.")
input_d[col[0]] = []
invalid_d[col[0]] = []
num_valid = 0
num_invalid = 0

# Create an empty Results object and append the data table.
results = Results([], track_filtered=track_filtered)
results.table = data
# Add the valid trajectories to the table. If we are tracking filtered
# data, add invalid trajectories to the invalid_d dictionary.
for trj in trajectories:
if trj.valid:
input_d["x"].append(trj.x)
input_d["y"].append(trj.y)
input_d["vx"].append(trj.vx)
input_d["vy"].append(trj.vy)
input_d["likelihood"].append(trj.lh)
input_d["flux"].append(trj.flux)
input_d["obs_count"].append(trj.obs_count)
num_valid += 1
elif track_filtered:
invalid_d["x"].append(trj.x)
invalid_d["y"].append(trj.y)
invalid_d["vx"].append(trj.vx)
invalid_d["vy"].append(trj.vy)
invalid_d["likelihood"].append(trj.lh)
invalid_d["flux"].append(trj.flux)
invalid_d["obs_count"].append(trj.obs_count)
num_invalid += 1

# Check for any missing columns and fill in the default value.
for col in cls._required_cols:
if col[0] not in input_d:
input_d[col[0]] = [col[2]] * num_valid
invalid_d[col[0]] = [col[2]] * num_invalid

# Create the table and add the unfiltered (and filtered) results.
results = Results(input_d, track_filtered=track_filtered)
if track_filtered and num_invalid > 0:
results.filtered["invalid_trajectory"] = Table(invalid_d)
return results

@classmethod
def from_dict(cls, input_dict, track_filtered=False):
"""Extract data from a dictionary with the minimum trajectory information.
Parameters
----------
input_dict : `dict`
The input data.
track_filtered : `bool`
Indicates whether to track future filtered points.
Raises
------
Raises a KeyError if any required columns are missing.
"""
return cls.from_table(Table(input_dict))

@classmethod
def read_table(cls, filename, track_filtered=False):
"""Read the ResultList from a table file.
Expand All @@ -161,7 +168,7 @@ def read_table(cls, filename, track_filtered=False):
if not Path(filename).is_file():
raise FileNotFoundError(f"File {filename} not found.")
data = Table.read(filename)
return Results.from_table(data, track_filtered=track_filtered)
return Results(data, track_filtered=track_filtered)

def extend(self, results2):
"""Append the results in a second `Results` object to the current one.
Expand All @@ -180,7 +187,7 @@ def extend(self, results2):
------
Raises a ValueError if the columns of the results do not match.
"""
if len(self) > 0 and set(self.colnames) != set(results2.colnames):
if set(self.colnames) != set(results2.colnames):
raise ValueError("Column mismatch when merging results")

self.table = vstack([self.table, results2.table])
Expand Down Expand Up @@ -362,6 +369,30 @@ def update_index_valid(self, index_valid):
self._update_likelihood()
return self

def _append_filtered(self, table, label=None):
"""Appended a filtered table onto the current tables for
tracking the filtered values.
Parameters
----------
table : `astropy.table.Table`
The rows that were filtered out. They are stored under ``label``,
stacked onto any rows already recorded for that label.
label : `str`
The label of the filtering stage to use. Only used if
we keep filtered trajectories.
"""
if not self.track_filtered:
return

if label is None:
label = ""

if label in self.filtered:
self.filtered[label] = vstack([self.filtered[label], table])
else:
self.filtered[label] = table

def filter_mask(self, mask, label=None):
"""Filter the rows in the ResultTable to only include those indices
that are marked True in the mask.
Expand All @@ -381,13 +412,7 @@ def filter_mask(self, mask, label=None):
Returns a reference to itself to allow chaining.
"""
if self.track_filtered:
if label is None:
label = ""

if label in self.filtered:
self.filtered[label] = vstack([self.filtered[label], self.table[~mask]])
else:
self.filtered[label] = self.table[~mask]
self._append_filtered(self.table[~mask], label)

# Do the actual filtering.
self.table = self.table[mask]
Expand Down Expand Up @@ -562,4 +587,4 @@ def from_trajectory_file(cls, filename, track_filtered=False):
raise FileNotFoundError(f"{filename} not found for load.")

trj_list = FileUtils.load_results_file_as_trajectories(filename)
return cls(trj_list, track_filtered)
return cls.from_trajectories(trj_list, track_filtered)
6 changes: 3 additions & 3 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ def make_trajectory(x=0, y=0, vx=0.0, vy=0.0, flux=0.0, lh=0.0, obs_count=0):
The resulting Trajectory object.
"""
trj = Trajectory()
trj.x = x
trj.y = y
trj.x = int(x)
trj.y = int(y)
trj.vx = vx
trj.vy = vy
trj.flux = flux
trj.lh = lh
trj.obs_count = obs_count
trj.obs_count = int(obs_count)
trj.valid = True
return trj

Expand Down
2 changes: 1 addition & 1 deletion tests/test_clustering_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _make_result_data(self, objs):
`Results`
"""
trj_list = [make_trajectory(x[0], x[1], x[2], x[3], lh=100.0) for x in objs]
return Results(trj_list)
return Results.from_trajectories(trj_list)

def test_dbscan_position_result_list(self):
rs = self._make_result_list(
Expand Down
Loading

0 comments on commit deb3c9f

Please sign in to comment.