From 1aa7354f57f9aa2a2f28ea94b5e2c566d6a58fb1 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Mon, 18 Nov 2024 19:55:58 -0800 Subject: [PATCH] Add a KBMOD results filter for matching "known objects" (#741) * Add filter for known objects * Modify init with known object filter * Clean up comments and tests * Lint fixes * Refactored to KnownObjsMatcher and added filters * More refactoring and renaming * Separate match vs obs_valid filters and clean up comments. * Update test names and documentaiton * Format * Revert local change * Remove blank line * Address small comments * Fix time filter unit conversion and testing * Make obs_ratio and min_obs function parameters * Remove unneeded obs_match_ratio field --- src/kbmod/filters/__init__.py | 1 + src/kbmod/filters/known_object_filters.py | 438 +++++++++++++++ tests/test_known_object_filters.py | 627 ++++++++++++++++++++++ 3 files changed, 1066 insertions(+) create mode 100644 src/kbmod/filters/known_object_filters.py create mode 100644 tests/test_known_object_filters.py diff --git a/src/kbmod/filters/__init__.py b/src/kbmod/filters/__init__.py index 79d271cd0..f9e607cae 100644 --- a/src/kbmod/filters/__init__.py +++ b/src/kbmod/filters/__init__.py @@ -1,5 +1,6 @@ from . 
class KnownObjsMatcher:
    """
    A class which ingests an astropy table of object data expected to be found in the dataset
    searched by KBMOD (either real objects or inserted synthetic fakes) and provides methods for
    matching to the observations in a given set of KBMOD Results.

    It allows for configuration of how the matching is done, including the maximum
    separation in arcseconds between a known object and a result to be considered a match,
    and the maximum time separation in seconds between a known object and the observation
    used in a KBMOD result.

    In addition to modifying a KBMOD `Results` table to include columns for matched known objects,
    it also provides methods for filtering the results based on the matches. This includes
    marking observations that matched to known objects as invalid, and filtering out results
    that matched to known objects by either the minimum number of observations that matched to
    that known object or the proportion of observations from the catalog for that known object
    that were matched to a given result.
    """

    def __init__(
        self,
        table,
        obstimes,
        matcher_name,
        sep_thresh=1.0,
        time_thresh_s=600.0,
        mjd_col="mjd_mid",
        ra_col="RA",
        dec_col="DEC",
        name_col="Name",
    ):
        """
        Parameters
        ----------
        table : astropy.table.Table
            A table containing our catalog of observations of known objects.
        obstimes : list(float) or numpy.ndarray
            The MJD times of each observation within KBMOD results we want to match to
            the known objects. Stored internally as a numpy array.
        matcher_name : str
            The name of the filter to apply to the results. This both determines
            the name of the column of matched observations which may be added to
            the `Results` table and how the filtering and matching phases are
            identified within KBMOD logs.
        sep_thresh : float, optional
            The maximum separation in arcseconds between a known object and a result
            to be considered a match. Default is 1.0.
        time_thresh_s : float, optional
            The maximum time separation in seconds between a known object and the observation
            used in a KBMOD result. Default is 600.0.
        mjd_col : str, optional
            The name of the catalog column containing the MJD of the known objects.
            Default is "mjd_mid".
        ra_col : str, optional
            The name of the catalog column containing the RA of the known objects.
            Default is "RA".
        dec_col : str, optional
            The name of the catalog column containing the DEC of the known objects.
            Default is "DEC".
        name_col : str, optional
            The name of the catalog column containing the name of the known objects.
            Default is "Name".

        Raises
        ------
        ValueError
            If the required columns are not present in the table, or if `obstimes` is empty.
        """
        self.data = table

        # Map our required columns to any specified column names.
        self.mjd_col = mjd_col
        self.ra_col = ra_col
        self.dec_col = dec_col
        self.name_col = name_col

        # Check that the required columns are present
        user_cols = {self.mjd_col, self.ra_col, self.dec_col, self.name_col}
        invalid_cols = user_cols - set(self.data.colnames)
        if invalid_cols:
            raise ValueError(f"{invalid_cols} not found in KnownObjs data.")

        # match() indexes the observation times with each result's boolean `obs_valid`
        # mask, which only works on an array, so coerce list input here.
        self.obstimes = np.asarray(obstimes)
        if len(self.obstimes) == 0:
            raise ValueError("No obstimes provided")

        self.matcher_name = matcher_name
        self.sep_thresh = sep_thresh * u.arcsec
        self.time_thresh_s = time_thresh_s

        # Pre-filter down our data to window of temporally relevant observations to speed up matching.
        time_thresh_days = self.time_thresh_s / (24 * 3600)  # Convert seconds to days
        start_mjd = max(0, min(self.obstimes) - time_thresh_days - 1e-6)
        end_mjd = max(self.obstimes) + time_thresh_days + 1e-6

        # Filter out known object observations outside of our time thresholds
        self.data = self.data[(self.data[self.mjd_col] >= start_mjd) & (self.data[self.mjd_col] <= end_mjd)]

    def match_min_obs_col(self, min_obs):
        """A column name for objects that matched results based on the minimum number of observations."""
        return f"recovered_{self.matcher_name}_min_obs_{min_obs}"

    def match_obs_ratio_col(self, obs_ratio):
        """A column name for objects that matched results based on the proportion of observations
        that matched to the known observations for that object within the catalog."""
        return f"recovered_{self.matcher_name}_obs_ratio_{obs_ratio}"

    def __len__(self):
        """Returns the number of known-object observations in this matcher's catalog."""
        return len(self.data)

    def get_mjd(self, ko_idx):
        """
        Returns the MJD of the known object at a given index.
        """
        return self.data[ko_idx][self.mjd_col]

    def get_ra(self, ko_idx):
        """
        Returns the RA of the known object at a given index.
        """
        return self.data[ko_idx][self.ra_col]

    def get_dec(self, ko_idx):
        """
        Returns the DEC of the known object at a given index.
        """
        return self.data[ko_idx][self.dec_col]

    def get_name(self, ko_idx):
        """
        Returns the name of the known object at a given index.
        """
        return self.data[ko_idx][self.name_col]

    def to_skycoords(self):
        """
        Returns a SkyCoord representation of the known objects.
        """
        return SkyCoord(ra=self.data[self.ra_col], dec=self.data[self.dec_col], unit="deg")

    def match(self, result_data, wcs):
        """This function takes a list of results and matches them to known objects.

        This modifies the `Results` table by adding a column with name `self.matcher_name` that
        provides for each result a dictionary mapping the names of known objects (as defined by
        the catalog's `name_col`) to a boolean array indicating which observations in the result
        matched to that known object. Note that depending on the matching parameters, a result
        can match to multiple known objects from the catalog even at the same observation time.

        So for a dataset with 5 observations a result matching to 2 known objects, A and B,
        might have an entry in the column `self.matcher_name` like:
        ```{
            "A": [True, True, False, False, False],
            "B": [False, False, False, True, True],
        }```

        Parameters
        ----------
        result_data: `Results`
            The set of results to filter. This data gets modified directly by
            the filtering.
        wcs: `astropy.wcs.WCS`
            The common WCS object for the stack of images.

        Returns
        -------
        `Results`
            The modified `Results` object returned for chaining.
        """
        logger.info(f"Matching known objects to {len(result_data)} results using {self.matcher_name} filter")
        all_matches = []

        # Get the RA and DEC of the known objects and the trajectories of the results for matching
        known_objs_ra_dec = self.to_skycoords()
        trj_list = result_data.make_trajectory_list()

        for result_idx in range(len(result_data)):
            obs_valid = result_data[result_idx]["obs_valid"]

            # Generate (RA, Dec) pairs for all of the valid observations for this result trajectory
            valid_obstimes = self.obstimes[obs_valid]
            trj_skycoords = trajectory_predict_skypos(trj_list[result_idx], wcs, valid_obstimes)

            # Because we're only matching using the subset of obstimes that were valid for this result,
            # we can use this to later map back to the original index of all observations in the stack.
            trj_idx_to_obs_idx = np.where(obs_valid)[0]

            # Compare the SkyCoords of the known objects to the SkyCoords of the result trajectory
            # using search_around_sky. This returns index pairs (trajectory position, known object)
            # for everything within sep_thresh of each other.
            # Note that subsequent calls by default will use the same underlying KD-Tree in coords2.cache.
            trjs_idx, known_objs_idx, _, _ = search_around_sky(
                trj_skycoords, known_objs_ra_dec, self.sep_thresh
            )

            # Now we can count per-known object how many observations matched within this result
            matched_known_objs = {}
            for t_idx, ko_idx in zip(trjs_idx, known_objs_idx):
                # The observation spatially matched but now check that the time separation
                # (converted from days to seconds) is within our threshold
                if abs(self.get_mjd(ko_idx) - valid_obstimes[t_idx]) * 24 * 3600 <= self.time_thresh_s:
                    # The name of the object that matched to this observation
                    obj_name = self.get_name(ko_idx)
                    if obj_name not in matched_known_objs:
                        # Create an array of which observations match to this object.
                        # Note that we need to use the length of all obstimes, not just the
                        # presently valid ones
                        matched_known_objs[obj_name] = np.full(len(self.obstimes), False)
                    # Map to the original set of all obstimes (valid or invalid) since that's what we
                    # want for results filtering.
                    obs_idx = trj_idx_to_obs_idx[t_idx]
                    matched_known_objs[obj_name][obs_idx] = True
            all_matches.append(matched_known_objs)

        # Add matches as a result column
        result_data.table[self.matcher_name] = all_matches

        logger.info(f"Matched known objects to {len(result_data)} results using {self.matcher_name} filter")

        return result_data

    def mark_matched_obs_invalid(
        self,
        result_data,
        drop_empty_rows=True,
    ):
        """
        Mark observations that matched to known objects as invalid, by default dropping
        results that no longer have any valid observations.

        Note that a given result can match to multiple objects, and that we expect the
        `Results` table to have a column with name corresponding to `self.matcher_name` that
        contains which observations were matched to each known object.

        Parameters
        ----------
        result_data : `Results`
            The results to filter.
        drop_empty_rows : bool, optional
            If True, drop rows that have no valid observations after filtering. Default is True.

        Returns
        -------
        `Results`
            The modified `Results` object returned for chaining.

        Raises
        ------
        ValueError
            If the `self.matcher_name` column is missing, i.e. `match()` has not been run.
        """
        # Skip filtering if there is nothing to filter.
        if len(result_data) == 0 or len(self.obstimes) == 0 or len(self.data) == 0:
            return result_data

        if self.matcher_name not in result_data.table.colnames:
            raise ValueError(
                f"Column {self.matcher_name} not found in results table. Please run match() first."
            )

        matched_known_objs = result_data.table[self.matcher_name]
        new_obs_valid = result_data["obs_valid"]
        for result_idx in range(len(result_data)):
            # A result can match to multiple objects, so we combine the per-object
            # boolean arrays with a logical OR using np.any.
            # We can then use bitwise NOT and AND to mark any previously valid
            # observations that matched to known objects as invalid.
            new_obs_valid[result_idx] &= ~np.any(
                np.array(list(matched_known_objs[result_idx].values())), axis=0
            )

        return result_data.update_obs_valid(new_obs_valid, drop_empty_rows=drop_empty_rows)

    def match_on_min_obs(
        self,
        result_data,
        min_obs,
    ):
        """
        Create a column corresponding to the known objects that were matched to a result
        based on the minimum number of observations that matched to that known object.

        Note that a given result can match to multiple objects.

        Parameters
        ----------
        result_data : `Results`
            The results to filter. Must already contain the `self.matcher_name` column
            produced by `match()`.
        min_obs : int
            The minimum number of observations within a KBMOD result that must match to a known
            object for that result to be considered a match.

        Returns
        -------
        `Results`
            The modified `Results` object returned for chaining.
        """
        matched_objs = []
        for idx in range(len(result_data)):
            matches = result_data[self.matcher_name][idx]
            matched_objs.append({name for name in matches if np.count_nonzero(matches[name]) >= min_obs})
        result_data.table[self.match_min_obs_col(min_obs)] = matched_objs

        return result_data

    def match_on_obs_ratio(
        self,
        result_data,
        obs_ratio,
    ):
        """
        Create a column corresponding to the known objects that were matched to a result
        based on the proportion of observations that matched to that known object within the catalog.

        The ratio for an object is the number of observations in the result that matched it
        divided by the number of rows for that object in this matcher's catalog (`self.data`).
        An object is added to the result's set when that ratio is at least `obs_ratio`.

        Note that a given result can match to multiple objects.

        Parameters
        ----------
        result_data : `Results`
            The results to filter. Must already contain the `self.matcher_name` column
            produced by `match()`.
        obs_ratio : float
            The minimum ratio of observations within a KBMOD result that must match to the total
            observations within our catalog of known objects for that result to be considered a match.
            Must be within the range [0, 1].

        Returns
        -------
        `Results`
            The modified `Results` object returned for chaining.

        Raises
        ------
        ValueError
            If `obs_ratio` is not within the range [0, 1], or if a matched object name
            is not present in the catalog.
        """
        if obs_ratio < 0 or obs_ratio > 1:
            raise ValueError("obs_ratio must be within the range [0, 1].")

        # Create a dictionary of how many observations we have for each known object
        # in our catalog
        known_obj_cnts = dict(Counter(self.data[self.name_col]))
        matched_objs = []
        for idx in range(len(result_data)):
            matched_objs.append(set())
            matches = result_data[self.matcher_name][idx]
            for name in matches:
                if name not in known_obj_cnts:
                    raise ValueError(f"Unknown known object {name}")

                curr_obs_ratio = np.count_nonzero(matches[name]) / known_obj_cnts[name]
                # obs_ratio is a minimum: keep objects matched at or above the threshold.
                if curr_obs_ratio >= obs_ratio:
                    matched_objs[-1].add(name)

        result_data.table[self.match_obs_ratio_col(obs_ratio)] = matched_objs

        return result_data

    def get_recovered_objects(self, result_data, match_col):
        """
        Get the set of objects that were recovered or missed in the results.

        For our purposes, a recovered object is one that was matched to a result based on the
        matching column of choice in the results table and a missed object is an object in
        the catalog that was not matched. Note that not all catalogs may be
        constructed in a way where all objects could be spatially present and
        recoverable in the results.

        If either the results or the catalog are empty, two empty sets are returned.

        Parameters
        ----------
        result_data : `Results`
            The results to filter.
        match_col : str
            The name of the column in the results table that contains the matched objects.

        Returns
        -------
        set, set
            A tuple of sets where the first set contains the names of objects that were recovered
            and the second set contains the names of objects that were missed.

        Raises
        ------
        ValueError
            If the `match_col` is not present in the results table.
        """
        if match_col not in result_data.table.colnames:
            raise ValueError(f"Column {match_col} not found in results table.")

        if len(result_data) == 0 or len(self.data) == 0:
            return set(), set()

        expected_objects = set(self.data[self.name_col])
        matched_objects = set()
        for idx in range(len(result_data)):
            matched_objects.update(result_data[match_col][idx])
        # Only count names that are actually in our (time-filtered) catalog as recovered.
        recovered_objects = matched_objects.intersection(expected_objects)
        missed_objects = expected_objects - recovered_objects

        return recovered_objects, missed_objects

    def filter_matches(self, result_data, match_col):
        """
        Filter out the results table to only include results that did not match to any known objects.

        Parameters
        ----------
        result_data : `Results`
            The results to filter.
        match_col : str
            The name of the column in the results table that contains the matched objects.

        Returns
        -------
        `Results`
            The modified `Results` object returned for chaining.

        Raises
        ------
        ValueError
            If the `match_col` is not present in the results table.
        """
        if match_col not in result_data.table.colnames:
            raise ValueError(f"Column {match_col} not found in results table.")

        if len(result_data) == 0:
            return result_data

        # Only keep results that did not match to any known objects in our column
        idx_to_keep = np.array([len(x) == 0 for x in result_data[match_col]])
        # Use the name of our matching column as the filter name
        result_data = result_data.filter_rows(idx_to_keep, match_col)

        return result_data
+ num_images = 25 + self.obstimes = np.array(create_fake_times(num_images)) + ds = FakeDataSet(15, 10, self.obstimes, use_seed=True) + self.wcs = make_fake_wcs(10.0, 15.0, 15, 10) + ds.set_wcs(self.wcs) + + # Randomly generate a Trajectory for each of our 10 results + num_results = 10 + for i in range(num_results): + ds.insert_random_object(self.seed) + self.res = Results.from_trajectories(ds.trajectories, track_filtered=True) + self.assertEqual(len(ds.trajectories), num_results) + + # Generate which observations are valid observations for each result + self.obs_valid = np.full((num_results, num_images), True) + for i in range(num_results): + # For each result include a random set of 5 invalid observations + invalid_obs = np.random.choice(num_images, 5, replace=False) + self.obs_valid[i][invalid_obs] = False + self.res.update_obs_valid(self.obs_valid) + assert set(self.res.table.columns) == set( + ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count", "obs_valid"] + ) + + # Use the results' trajectories to generate a set of known objects that intersect our generated results in various + # ways. + self.known_objs = Table({"Name": np.empty(0, dtype=str), "RA": [], "DEC": [], "mjd_mid": []}) + + # Have the temporal offset for near and far objects be just below and above our time threshold + time_offset_mjd_close = (self.time_thresh_s - 1) / (24.0 * 3600) + time_offset_mjd_far = (self.time_thresh_s + 1) / (24.0 * 3600) + + # Case 1: Near in space and near in time just within the range of our filters to result 1 + self.generate_known_obj_from_result( + self.known_objs, + 1, # Base off result 1 + self.obstimes, # Use all possible obstimes + "spatial_close_time_close_1", + spatial_offset=0.00001, + time_offset=time_offset_mjd_close, + ) + + # Case 2 near in space to result 3, but farther in time. 
+ self.generate_known_obj_from_result( + self.known_objs, + 3, # Base off result 3 + self.obstimes, # Use all possible obstimes + "spatial_close_time_far_3", + spatial_offset=0.0001, + time_offset=time_offset_mjd_far, + ) + + # Case 3: A similar trajectory to result 5, but farther in space with similar timestamps. + self.generate_known_obj_from_result( + self.known_objs, + 5, # Base off result 5 + self.obstimes, # Use all possible obstimes + "spatial_far_time_close_5", + spatial_offset=5, + time_offset=time_offset_mjd_close, + ) + + # Case 4: A similar trajectory to result 7, but far off spatially and temporally + self.generate_known_obj_from_result( + self.known_objs, + 7, # Base off result 7 + self.obstimes, # Use all possible obstimes + "spatial_far_time_far_7", + spatial_offset=5, + time_offset=time_offset_mjd_far, + ) + + # Case 5: a trajectory matching result 8 but with only a few observations. + self.generate_known_obj_from_result( + self.known_objs, + 8, # Base off result 8 + self.obstimes[::10], # Samples down to every 10th observation + "sparse_8", + spatial_offset=0.0001, + time_offset=time_offset_mjd_close, + ) + + def test_known_objs_matcher_init( + self, + ): # Test that a table with no columns specified raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table(), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no Name column raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"RA": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no RA column raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"Name": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no DEC column raises a ValueError + with 
self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"Name": [], "RA": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no mjd_mid column raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"Name": [], "RA": [], "DEC": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with all columns specified does not raise an error + correct = KnownObjsMatcher( + Table({"Name": [], "RA": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + self.assertEqual(0, len(correct)) + + # Test a table where we override the names for each column + self.assertEqual( + 0, + len( + KnownObjsMatcher( + Table({"my_Name": [], "my_RA": [], "my_DEC": [], "my_mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + mjd_col="my_mjd_mid", + ra_col="my_RA", + dec_col="my_DEC", + name_col="my_Name", + ) + ), + ) + + def generate_known_obj_from_result( + self, + known_obj_table, + res_idx, + obstimes, + name, + spatial_offset=0.0001, + time_offset=0.00025, + ): + """Helper function to generate a known object based on an existing result trajectory""" + trj_skycoords = trajectory_predict_skypos( + self.res.make_trajectory_list()[res_idx], + self.wcs, + obstimes, + ) + for i in range(len(obstimes)): + known_obj_table.add_row( + { + "Name": name, + "RA": trj_skycoords[i].ra.degree + spatial_offset, + "DEC": trj_skycoords[i].dec.degree + spatial_offset, + "mjd_mid": obstimes[i] + time_offset, + } + ) + + def test_known_objs_match_empty(self): + # Here we test the filter across various empty parameters + + # Test that the filter is not applied when no known objects were provided + empty_objs = KnownObjsMatcher( + Table({"Name": np.empty(0, dtype=str), "RA": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + 
self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + self.res = empty_objs.match( + self.res, + self.wcs, + ) + # Though there were no known objects, check that the results table still has rows + self.assertEqual(10, len(self.res)) + # We should still apply the matching column to the results table even if empty + matches = self.res[empty_objs.matcher_name] + self.assertEqual(0, sum([len(m.keys()) for m in matches])) + + # Test that we can apply the filter even when there are known results + self.res = empty_objs.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(10, len(self.res)) + + # Test that the filter is not applied when there were no results. + empty_res = Results() + empty_res = empty_objs.match( + empty_res, + self.wcs, + ) + matches = empty_res[empty_objs.matcher_name] + self.assertEqual(0, sum([len(m.keys()) for m in matches])) + + empty_res = empty_objs.mark_matched_obs_invalid(empty_res, drop_empty_rows=True) + self.assertEqual(0, len(empty_res)) + + def test_match(self): + # We expect to find only the objects close in time and space to our results, + # including one object matching closely to a result across all observations + # and also a sparsely represented object with only a few observations. + expected_matches = set(["spatial_close_time_close_1", "sparse_8"]) + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Generate matches for the results according to the known objects + self.res = matcher.match( + self.res, + self.wcs, + ) + matches = self.res[self.matcher_name] + # Assert the expected result + obs_matches = set() + for m in matches: + obs_matches.update(m.keys()) + self.assertEqual(expected_matches, obs_matches) + + # Check that the close known object we inserted near result 1 is dropped + # But the sparsely observed known object will not get filtered out. 
+ self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(9, len(self.res)) + + # Check that the close known object we inserted near result 1 is present + self.assertEqual(len(matches[1]), 1) + self.assertTrue("spatial_close_time_close_1" in matches[1]) + + self.assertEqual(len(matches[8]), 1) + self.assertTrue("sparse_8" in matches[8]) + + # Check that no results other than results 1 and 8 have a match + for i in range(len(self.res)): + if i != 1 and i != 8: + self.assertEqual(0, len(matches[i])) + + def test_match_excessive_spatial_filtering(self): + # Here we only filter for exact spatial matches and should return no results + self.sep_thresh = 0.0 + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + self.res = matcher.match( + self.res, + self.wcs, + ) + matches = self.res[matcher.matcher_name] + self.assertEqual(0, sum([len(m.keys()) for m in matches])) + + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(10, len(self.res)) + + def test_match_spatial_filtering(self): + # Here we use a filter that only matches spatially with an unreasonably generous time filter + self.time_thresh_s += 2 + # Our expected matches now include all objects that are close in space to our results regardless + # of the time offset we generated. 
+ expected_matches = set(["spatial_close_time_close_1", "spatial_close_time_far_3", "sparse_8"]) + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Performing matching + self.res = matcher.match( + self.res, + self.wcs, + ) + matches = self.res[matcher.matcher_name] + + # Confirm that the expected matches are present + obs_matches = set() + for m in matches: + obs_matches.update(m.keys()) + self.assertEqual(expected_matches, obs_matches) + + # Check that the close known objects we inserted are removed by valid obs filtering + # while the sparse known object does not fully filter out that result. + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(8, len(self.res)) + + # Check that the close known object we inserted near result 1 is present + self.assertEqual(1, len(matches[1])) + self.assertTrue("spatial_close_time_close_1" in matches[1]) + self.assertEqual( + np.count_nonzero(self.obs_valid[1]), + np.count_nonzero(matches[1]["spatial_close_time_close_1"]), + ) + + # Check that the close known object we inserted near result 3 is present + self.assertEqual(1, len(matches[3])) + self.assertTrue("spatial_close_time_far_3" in matches[3]) + self.assertEqual( + np.count_nonzero(self.obs_valid[3]), + np.count_nonzero(matches[3]["spatial_close_time_far_3"]), + ) + + # Check that the sparse known object we inserted near result 8 is present + self.assertEqual(1, len(matches[8])) + self.assertTrue("sparse_8" in matches[8]) + self.assertGreaterEqual( + len(self.known_objs[self.known_objs["Name"] == "sparse_8"]), + np.count_nonzero(matches[8]["sparse_8"]), + ) + + # Check that no results other than results 1 and 3 are full matches + # Since these are based off of random trajectories we can't guarantee there + # won't some overlapping observations. 
+        for i in range(len(self.res)):
+            if i not in [1, 3]:
+                for obj_name in matches[i]:
+                    self.assertGreater(
+                        np.count_nonzero(self.obs_valid[i]),
+                        np.count_nonzero(matches[i][obj_name]),
+                    )
+
+    def test_match_temporal_filtering(self):
+        # Keep the fixture's time threshold but make the spatial threshold unreasonably generous
+        self.sep_thresh = 100000
+        expected_matches = set(["spatial_close_time_close_1", "spatial_far_time_close_5", "sparse_8"])
+        matcher = KnownObjsMatcher(
+            self.known_objs,
+            self.obstimes,
+            self.matcher_name,
+            self.sep_thresh,
+            self.time_thresh_s,
+        )
+
+        # Generate matches
+        self.res = matcher.match(
+            self.res,
+            self.wcs,
+        )
+        matches = self.res[matcher.matcher_name]
+
+        # Confirm that exactly the expected known objects were matched across all results
+        obs_matches = set()
+        for m in matches:
+            obs_matches.update(m.keys())
+        self.assertEqual(expected_matches, obs_matches)
+
+        # Because some known object matches each observation temporally, the generous spatial
+        # threshold marks every observation invalid, so dropping empty rows leaves no results.
+        self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True)
+        self.assertEqual(0, len(self.res))
+
+        for i in range(len(matches)):
+            self.assertEqual(expected_matches, set(matches[i].keys()))
+            # Check how many observations were matched to each known object
+            for obj_name in matches[i]:
+                if obj_name == "sparse_8":
+                    # The sparse object only has a few observations to match
+                    self.assertGreaterEqual(
+                        len(self.known_objs[self.known_objs["Name"] == "sparse_8"]),
+                        np.count_nonzero(matches[i]["sparse_8"]),
+                    )
+                else:
+                    # The other objects have a full set of observations to match
+                    self.assertEqual(
+                        np.count_nonzero(self.obs_valid[i]),
+                        np.count_nonzero(matches[i][obj_name]),
+                    )
+
+    def test_match_all(self):
+        # Here we use generous temporal and spatial filters to recover all objects
+        self.sep_thresh = 100000
+        self.time_thresh_s = 1000000
+        expected_matches = set(
+            [
+                "spatial_close_time_close_1",
+                "spatial_close_time_far_3",
+                "spatial_far_time_close_5",
+                "spatial_far_time_far_7",
+                "sparse_8",
+            ]
+        )
+        # Perform the matching
+        matcher = KnownObjsMatcher(
+            self.known_objs,
+            self.obstimes,
+            self.matcher_name,
+            self.sep_thresh,
+            self.time_thresh_s,
+        )
+        self.res = matcher.match(
+            self.res,
+            self.wcs,
+        )
+
+        # Here we expect to recover all of our known objects.
+        matches = self.res[matcher.matcher_name]
+        obs_matches = set()
+        for m in matches:
+            obs_matches.update(m.keys())
+        self.assertEqual(expected_matches, obs_matches)
+
+        # Each result should have matched to every object
+        self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True)
+        self.assertEqual(0, len(self.res))
+
+        # Check that every result matches to all of the expected known objects
+        for i in range(len(matches)):
+            self.assertEqual(expected_matches, set(matches[i].keys()))
+            # Check that all observations were matched to the known objects since
+            # even the most sparse object should match to every observation with
+            # our time filter.
+            for obj_name in matches[i]:
+                self.assertEqual(
+                    np.count_nonzero(self.obs_valid[i]),
+                    np.count_nonzero(matches[i][obj_name]),
+                )
+
+    def test_match_obs_ratio_invalid(self):
+        # Here we test that we raise an error for observation ratios outside of the valid range
+        matcher = KnownObjsMatcher(
+            self.known_objs,
+            self.obstimes,
+            self.matcher_name,
+        )
+        self.res = matcher.match(self.res, self.wcs)
+
+        # Test some invalid ratios outside of the range [0, 1]
+        with self.assertRaises(ValueError):
+            matcher.match_on_obs_ratio(self.res, 1.1)
+        with self.assertRaises(ValueError):
+            matcher.match_on_obs_ratio(self.res, -0.1)
+
+    def test_match_obs_ratio(self):
+        # Here we test considering a known object recovered based on the ratio of its
+        # catalog observations that were matched to a given result.
+        min_obs_ratios = [
+            0.0,
+            1.0,
+        ]
+        # The expected matching objects for each min_obs_ratio parameter chosen.
+        expected_matches = [
+            set([]),
+            set(["spatial_close_time_close_1", "sparse_8"]),
+        ]
+        orig_res = self.res.table.copy()
+        for obs_ratio, expected in zip(min_obs_ratios, expected_matches):
+            self.res = Results(data=orig_res.copy())
+            matcher = KnownObjsMatcher(
+                self.known_objs,
+                self.obstimes,
+                matcher_name=self.matcher_name,
+                sep_thresh=self.sep_thresh,
+                time_thresh_s=self.time_thresh_s,
+            )
+
+            # Perform the initial matching
+            self.res = matcher.match(
+                self.res,
+                self.wcs,
+            )
+
+            # Validate that we did not filter any results by obstimes
+            self.assertIn(self.matcher_name, self.res.table.columns)
+            self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=False)
+            self.assertEqual(10, len(self.res))
+
+            # Generate the column of which objects were "recovered"
+            matcher.match_on_obs_ratio(self.res, obs_ratio)
+            match_col = f"recovered_test_matches_obs_ratio_{obs_ratio}"
+            self.assertIn(match_col, self.res.table.columns)
+            self.assertEqual(match_col, matcher.match_obs_ratio_col(obs_ratio))
+
+            # Verify that we recovered the expected matches
+            recovered, missed = matcher.get_recovered_objects(
+                self.res, matcher.match_obs_ratio_col(obs_ratio)
+            )
+            self.assertEqual(expected, recovered)
+            # The missed objects are all known objects in our catalog minus the expected objects
+            expected_missed = set(self.known_objs["Name"]) - expected
+            self.assertEqual(expected_missed, missed)
+
+            # Verify that we filter out our expected results
+            matcher.filter_matches(self.res, match_col)
+            self.assertEqual(10 - len(expected), len(self.res))
+
+    def test_match_min_obs(self):
+        # Here we test considering a known object recovered based on the minimum number of
+        # observations that matched to that known object.
+        min_obs_settings = [
+            100,  # No objects should be recovered since our catalog objects have fewer observations
+            1,
+            5,  # The sparse object will not have enough observations to be recovered.
+        ]
+        expected_matches = [
+            set([]),
+            set(["spatial_close_time_close_1", "sparse_8"]),
+            set(["spatial_close_time_close_1"]),
+        ]
+        orig_res = self.res.table.copy()
+        for min_obs, expected in zip(min_obs_settings, expected_matches):
+            self.res = Results(data=orig_res.copy())
+            matcher = KnownObjsMatcher(
+                self.known_objs,
+                self.obstimes,
+                matcher_name=self.matcher_name,
+                sep_thresh=self.sep_thresh,
+                time_thresh_s=self.time_thresh_s,
+            )
+            # Perform the initial matching
+            matcher.match(
+                self.res,
+                self.wcs,
+            )
+            # Validate that we did not filter any results
+            self.assertIn(self.matcher_name, self.res.table.columns)
+            self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=False)
+            self.assertEqual(10, len(self.res))
+
+            # Generate the recovered object column for a minimum number of observations
+            matcher.match_on_min_obs(self.res, min_obs)
+            match_col = f"recovered_test_matches_min_obs_{min_obs}"
+            self.assertIn(match_col, self.res.table.columns)
+            self.assertEqual(match_col, matcher.match_min_obs_col(min_obs))
+
+            # Verify that we recovered the expected matches
+            recovered, missed = matcher.get_recovered_objects(self.res, matcher.match_min_obs_col(min_obs))
+            self.assertEqual(expected, recovered)
+            # The missed objects are all known objects in our catalog minus the expected objects
+            expected_missed = set(self.known_objs["Name"]) - expected
+            self.assertEqual(expected_missed, missed)
+
+            # Verify that we filter out our expected results
+            matcher.filter_matches(self.res, match_col)
+            self.assertEqual(10 - len(expected), len(self.res))
+
+    def test_empty_filter_matches(self):
+        # Test that we can filter matches with an empty Results table
+        matcher = KnownObjsMatcher(
+            self.known_objs,
+            self.obstimes,
+            self.matcher_name,
+        )
+        # Adds a matching column to our empty table.
+        empty_res = matcher.match_on_obs_ratio(Results(), 0.5)
+        with self.assertRaises(ValueError):
+            # Test an invalid matching column
+            matcher.filter_matches(empty_res, "empty")
+
+        empty_res = matcher.filter_matches(empty_res, matcher.match_obs_ratio_col(0.5))
+        self.assertEqual(0, len(empty_res))
+
+    def test_empty_get_recovered_objects(self):
+        # Test that we can get recovered objects with an empty Results table
+        matcher = KnownObjsMatcher(
+            self.known_objs,
+            self.obstimes,
+            self.matcher_name,
+        )
+        # Adds a matching column to our empty table.
+        empty_res = matcher.match_on_min_obs(Results(), 5)
+        with self.assertRaises(ValueError):
+            # Test an invalid matching column
+            matcher.get_recovered_objects(empty_res, "empty")
+
+        recovered, missed = matcher.get_recovered_objects(empty_res, matcher.match_min_obs_col(5))
+        self.assertEqual(0, len(recovered))
+        self.assertEqual(0, len(missed))