diff --git a/src/kbmod/results.py b/src/kbmod/results.py index c045fa7c7..798d75c57 100644 --- a/src/kbmod/results.py +++ b/src/kbmod/results.py @@ -32,15 +32,24 @@ class Results: of the results removed by that filter. """ - _required_cols = ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count"] - - def __init__(self, trajectories, track_filtered=False): + # The required columns list gives a list of tuples containing + # (column name, dtype, default value) for each required column. + _required_cols = [ + ("x", "int64", 0), + ("y", "int64", 0), + ("vx", "float64", 0.0), + ("vy", "float64", 0.0), + ("likelihood", "float64", 0.0), + ("flux", "float64", 0.0), + ("obs_count", "int64", 0), + ] + + def __init__(self, data=None, track_filtered=False): """Create a ResultTable class. Parameters ---------- - trajectories : `list[Trajectory]` - A list of trajectories to include in these results. + data : `dict`, `astropy.table.Table` track_filtered : `bool` Whether to track (save) the filtered trajectories. This will use more memory and is recommended only for analysis. @@ -49,36 +58,23 @@ def __init__(self, trajectories, track_filtered=False): self.track_filtered = track_filtered self.filtered = {} - # Create dictionaries for the required columns. - input_d = {} - invalid_d = {} - for col in self._required_cols: - input_d[col] = [] - invalid_d[col] = [] - - # Add the valid trajectories to the table. If we are tracking filtered - # data, add invalid trajectories to the invalid_d dictionary.
- for trj in trajectories: - if trj.valid: - input_d["x"].append(trj.x) - input_d["y"].append(trj.y) - input_d["vx"].append(trj.vx) - input_d["vy"].append(trj.vy) - input_d["likelihood"].append(trj.lh) - input_d["flux"].append(trj.flux) - input_d["obs_count"].append(trj.obs_count) - elif track_filtered: - invalid_d["x"].append(trj.x) - invalid_d["y"].append(trj.y) - invalid_d["vx"].append(trj.vx) - invalid_d["vy"].append(trj.vy) - invalid_d["likelihood"].append(trj.lh) - invalid_d["flux"].append(trj.flux) - invalid_d["obs_count"].append(trj.obs_count) + if data is None: + # Set up the basic table meta data. + self.table = Table( + names=[col[0] for col in self._required_cols], + dtype=[col[1] for col in self._required_cols], + ) + elif type(data) is dict: + self.table = Table(data) + elif type(data) is Table: + self.table = data.copy() + else: + raise TypeError(f"Incompatible data type {type(data)}") - self.table = Table(input_d) - if track_filtered: - self.filtered["invalid_trajectory"] = Table(invalid_d) + # Check that we have the correct columns. + for col in self._required_cols: + if col[0] not in self.table.colnames: + raise KeyError(f"Column {col[0]} missing from input data.") def __len__(self): return len(self.table) @@ -100,48 +96,59 @@ def colnames(self): return self.table.colnames @classmethod - def from_table(cls, data, track_filtered=False): - """Extract data from an astropy Table with the minimum trajectory information. + def from_trajectories(cls, trajectories, track_filtered=False): + """Extract data from a list of Trajectory objects. Parameters ---------- - data : `astropy.table.Table` - The input data. + trajectories : `list[Trajectory]` + A list of trajectories to include in these results. track_filtered : `bool` Indicates whether to track future filtered points. - - Raises - ------ - Raises a KeyError if any required columns are missing. """ - # Check that the minimum information is present. + # Create dictionaries for the required columns. 
+ input_d = {} + invalid_d = {} for col in cls._required_cols: - if col not in data.colnames: - raise KeyError(f"Column {col} missing from input data.") + input_d[col[0]] = [] + invalid_d[col[0]] = [] + num_valid = 0 + num_invalid = 0 - # Create an empty Results object and append the data table. - results = Results([], track_filtered=track_filtered) - results.table = data + # Add the valid trajectories to the table. If we are tracking filtered + # data, add invalid trajectories to the invalid_d dictionary. + for trj in trajectories: + if trj.valid: + input_d["x"].append(trj.x) + input_d["y"].append(trj.y) + input_d["vx"].append(trj.vx) + input_d["vy"].append(trj.vy) + input_d["likelihood"].append(trj.lh) + input_d["flux"].append(trj.flux) + input_d["obs_count"].append(trj.obs_count) + num_valid += 1 + elif track_filtered: + invalid_d["x"].append(trj.x) + invalid_d["y"].append(trj.y) + invalid_d["vx"].append(trj.vx) + invalid_d["vy"].append(trj.vy) + invalid_d["likelihood"].append(trj.lh) + invalid_d["flux"].append(trj.flux) + invalid_d["obs_count"].append(trj.obs_count) + num_invalid += 1 + # Check for any missing columns and fill in the default value. + for col in cls._required_cols: + if col[0] not in input_d: + input_d[col[0]] = [col[2]] * num_valid + invalid_d[col[0]] = [col[2]] * num_invalid + + # Create the table and add the unfiltered (and filtered) results. + results = Results(input_d, track_filtered=track_filtered) + if track_filtered and num_invalid > 0: + results.filtered["invalid_trajectory"] = Table(invalid_d) return results - @classmethod - def from_dict(cls, input_dict, track_filtered=False): - """Extract data from a dictionary with the minimum trajectory information. - - Parameters - ---------- - input_dict : `dict` - The input data. - track_filtered : `bool` - Indicates whether to track future filtered points. - - Raises - ------ - Raises a KeyError if any required columns are missing. 
- """ - return cls.from_table(Table(input_dict)) - @classmethod def read_table(cls, filename, track_filtered=False): """Read the ResultList from a table file. @@ -161,7 +168,7 @@ def read_table(cls, filename, track_filtered=False): if not Path(filename).is_file(): raise FileNotFoundError(f"File {filename} not found.") data = Table.read(filename) - return Results.from_table(data, track_filtered=track_filtered) + return Results(data, track_filtered=track_filtered) def extend(self, results2): """Append the results in a second `Results` object to the current one. @@ -180,7 +187,7 @@ def extend(self, results2): ------ Raises a ValueError if the columns of the results do not match. """ - if len(self) > 0 and set(self.colnames) != set(results2.colnames): + if set(self.colnames) != set(results2.colnames): raise ValueError("Column mismatch when merging results") self.table = vstack([self.table, results2.table]) @@ -362,6 +369,30 @@ def update_index_valid(self, index_valid): self._update_likelihood() return self + def _append_filtered(self, table, label=None): + """Append a filtered table onto the current tables for + tracking the filtered values. + + Parameters + ---------- + table : `astropy.table.Table` + The rows removed by a filter, to be appended to the table + tracked under `label`. + label : `str` + The label of the filtering stage to use. Only used if + we keep filtered trajectories. + """ + if not self.track_filtered: + return + + if label is None: + label = "" + + if label in self.filtered: + self.filtered[label] = vstack([self.filtered[label], table]) + else: + self.filtered[label] = table + def filter_mask(self, mask, label=None): """Filter the rows in the ResultTable to only include those indices that are marked True in the mask. @@ -381,13 +412,7 @@ def filter_mask(self, mask, label=None): Returns a reference to itself to allow chaining.
""" if self.track_filtered: - if label is None: - label = "" - - if label in self.filtered: - self.filtered[label] = vstack([self.filtered[label], self.table[~mask]]) - else: - self.filtered[label] = self.table[~mask] + self._append_filtered(self.table[~mask], label) # Do the actual filtering. self.table = self.table[mask] @@ -562,4 +587,4 @@ def from_trajectory_file(cls, filename, track_filtered=False): raise FileNotFoundError(f"{filename} not found for load.") trj_list = FileUtils.load_results_file_as_trajectories(filename) - return cls(trj_list, track_filtered) + return cls.from_trajectories(trj_list, track_filtered) diff --git a/src/kbmod/trajectory_utils.py b/src/kbmod/trajectory_utils.py index 76a9b7b0d..aeb34d98a 100644 --- a/src/kbmod/trajectory_utils.py +++ b/src/kbmod/trajectory_utils.py @@ -46,13 +46,13 @@ def make_trajectory(x=0, y=0, vx=0.0, vy=0.0, flux=0.0, lh=0.0, obs_count=0): The resulting Trajectory object. """ trj = Trajectory() - trj.x = x - trj.y = y + trj.x = int(x) + trj.y = int(y) trj.vx = vx trj.vy = vy trj.flux = flux trj.lh = lh - trj.obs_count = obs_count + trj.obs_count = int(obs_count) trj.valid = True return trj diff --git a/tests/test_clustering_filters.py b/tests/test_clustering_filters.py index a279d5992..7cd2c0c3f 100644 --- a/tests/test_clustering_filters.py +++ b/tests/test_clustering_filters.py @@ -46,7 +46,7 @@ def _make_result_data(self, objs): `Results` """ trj_list = [make_trajectory(x[0], x[1], x[2], x[3], lh=100.0) for x in objs] - return Results(trj_list) + return Results.from_trajectories(trj_list) def test_dbscan_position_result_list(self): rs = self._make_result_list( diff --git a/tests/test_results.py b/tests/test_results.py index 3bae93c9b..3b93fb8e6 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -48,8 +48,13 @@ def _assert_results_match_dict(self, results, test_dict): for i in range(len(results)): self.assertEqual(results[col][i], test_dict[col][i]) - def test_create(self): - table = 
Results(self.trj_list) + def test_empty(self): + table = Results() + self.assertEqual(len(table), 0) + self.assertEqual(len(table.colnames), 7) + + def test_from_trajectories(self): + table = Results.from_trajectories(self.trj_list) self.assertEqual(len(table), self.num_entries) self.assertEqual(len(table.colnames), 7) self._assert_results_match_dict(table, self.input_dict) @@ -57,7 +62,7 @@ def test_create(self): # Test that we ignore invalid results, but track them in the filtered table. self.trj_list[2].valid = False self.trj_list[7].valid = False - table2 = Results(self.trj_list, track_filtered=True) + table2 = Results.from_trajectories(self.trj_list, track_filtered=True) self.assertEqual(len(table2), self.num_entries - 2) for i in range(self.num_entries - 2): self.assertFalse(table2["x"][i] == 2 or table2["x"][i] == 7) @@ -73,11 +78,11 @@ def test_from_dict(self): # Missing 'x' column del self.input_dict["x"] with self.assertRaises(KeyError): - _ = Results.from_dict(self.input_dict) + _ = Results(self.input_dict) # Add back the 'x' column. self.input_dict["x"] = [trj.x for trj in self.trj_list] - table = Results.from_dict(self.input_dict) + table = Results(self.input_dict) self._assert_results_match_dict(table, self.input_dict) def test_from_table(self): @@ -86,16 +91,16 @@ def test_from_table(self): # Missing 'x' column del self.input_dict["x"] with self.assertRaises(KeyError): - _ = Results.from_table(Table(self.input_dict)) + _ = Results(Table(self.input_dict)) # Add back the 'x' column. 
self.input_dict["x"] = [trj.x for trj in self.trj_list] - table = Results.from_table(Table(self.input_dict)) + table = Results(Table(self.input_dict)) self._assert_results_match_dict(table, self.input_dict) def test_make_trajectory_list(self): self.input_dict["something_added"] = [i for i in range(self.num_entries)] - table = Results.from_dict(self.input_dict) + table = Results(self.input_dict) trajectories = table.make_trajectory_list() self.assertEqual(len(trajectories), self.num_entries) @@ -109,10 +114,10 @@ def test_make_trajectory_list(self): self.assertEqual(trj.lh, table["likelihood"][i]) def test_extend(self): - table1 = Results(self.trj_list) + table1 = Results.from_trajectories(self.trj_list) for i in range(self.num_entries): self.trj_list[i].x += self.num_entries - table2 = Results(self.trj_list) + table2 = Results.from_trajectories(self.trj_list) table1.extend(table2) self.assertEqual(len(table1), 2 * self.num_entries) @@ -121,12 +126,12 @@ def test_extend(self): # Fail with a mismatched table. self.input_dict["something_added"] = [i for i in range(self.num_entries)] - table3 = Results.from_dict(self.input_dict) + table3 = Results(self.input_dict) with self.assertRaises(ValueError): table1.extend(table3) # Test starting from an empty table. 
- table4 = Results([]) + table4 = Results() table4.extend(table1) self.assertEqual(len(table1), len(table4)) for i in range(self.num_entries): @@ -134,7 +139,7 @@ def test_extend(self): def test_add_psi_phi(self): num_to_use = 3 - table = Results(self.trj_list[0:num_to_use]) + table = Results.from_trajectories(self.trj_list[0:num_to_use]) psi_array = np.array([[1.0, 1.1, 1.2, 1.3] for i in range(num_to_use)]) phi_array = np.array([[1.0, 1.0, 0.0, 2.0] for i in range(num_to_use)]) index_valid = np.array( @@ -162,7 +167,7 @@ def test_add_psi_phi(self): def test_update_index_valid(self): num_to_use = 3 - table = Results(self.trj_list[0:num_to_use]) + table = Results.from_trajectories(self.trj_list[0:num_to_use]) psi_array = np.array([[1.0, 1.1, 1.2, 1.3] for i in range(num_to_use)]) phi_array = np.array([[1.0, 1.0, 0.0, 2.0] for i in range(num_to_use)]) table.add_psi_phi_data(psi_array, phi_array) @@ -192,7 +197,7 @@ def test_update_index_valid(self): def test_compute_likelihood_curves(self): num_to_use = 3 - table = Results(self.trj_list[0:num_to_use]) + table = Results.from_trajectories(self.trj_list[0:num_to_use]) psi_array = np.array( [ @@ -237,7 +242,7 @@ def test_compute_likelihood_curves(self): self.assertTrue(np.array_equal(np.isfinite(lh_mat3), expected)) def test_filter_by_index(self): - table = Results(self.trj_list) + table = Results.from_trajectories(self.trj_list) self.assertEqual(len(table), self.num_entries) # Do the filtering and check we have the correct ones. @@ -258,7 +263,7 @@ def test_filter_by_index(self): table.revert_filter() def test_filter_by_index_tracked(self): - table = Results(self.trj_list[0:10], track_filtered=True) + table = Results.from_trajectories(self.trj_list[0:10], track_filtered=True) self.assertEqual(len(table), 10) # Do the filtering. First remove elements 0 and 2. 
Then remove elements @@ -297,7 +302,7 @@ def test_filter_by_index_tracked(self): self.assertEqual(table["x"][i], value) # Check that we can revert the filtering and add a 'filtered_reason' column. - table = Results(self.trj_list[0:10], track_filtered=True) + table = Results.from_trajectories(self.trj_list[0:10], track_filtered=True) table.filter_by_index([1, 3, 4, 5, 6, 7, 8, 9], label="filter1") table.filter_by_index([1, 2, 3, 4, 7], label="filter2") table.revert_filter(add_column="reason") @@ -310,7 +315,7 @@ def test_filter_by_index_tracked(self): def test_to_from_table_file(self): max_save = 5 - table = Results(self.trj_list[0:max_save], track_filtered=True) + table = Results.from_trajectories(self.trj_list[0:max_save], track_filtered=True) table.table["other"] = [i for i in range(max_save)] self.assertEqual(len(table), max_save) @@ -343,7 +348,7 @@ def test_to_from_table_file(self): self.assertTrue("other" in table.colnames) def test_save_and_load_trajectories(self): - table = Results(self.trj_list) + table = Results.from_trajectories(self.trj_list) # Try outputting the ResultList with tempfile.TemporaryDirectory() as dir_name: diff --git a/tests/test_sigma_g_filter.py b/tests/test_sigma_g_filter.py index a8673aab6..1d8b29511 100644 --- a/tests/test_sigma_g_filter.py +++ b/tests/test_sigma_g_filter.py @@ -143,7 +143,7 @@ def test_apply_clipped_sigma_g_results(self): num_times = 20 num_results = 5 trj_all = [Trajectory() for _ in range(num_results)] - table = Results(trj_all) + table = Results.from_trajectories(trj_all) phi_all = np.full((num_results, num_times), 0.1) psi_all = np.full((num_results, num_times), 1.0) diff --git a/tests/test_stamp_filters.py b/tests/test_stamp_filters.py index 64b35d4fe..dbb8265a6 100644 --- a/tests/test_stamp_filters.py +++ b/tests/test_stamp_filters.py @@ -263,7 +263,7 @@ def test_get_coadds_and_filter_results(self): trj4 = make_trajectory(trj.x + 1, trj.y + 1, trj.vx, trj.vy) # Create the Results. 
- keep = Results([trj, trj2, trj3, trj4]) + keep = Results.from_trajectories([trj, trj2, trj3, trj4]) self.assertFalse("stamp" in keep.colnames) # Create the stamp parameters we need. @@ -332,7 +332,7 @@ def test_append_all_stamps_results(self): make_trajectory(10, 22, -2.0, -1.0), make_trajectory(8, 7, -2.0, -1.0), ] - keep = Results(trj_list) + keep = Results.from_trajectories(trj_list) self.assertFalse("all_stamps" in keep.colnames) append_all_stamps(keep, ds.stack, 5)