From fd375567c66597b3ddcf6d5c6c95e874dc9b4db1 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Tue, 12 Nov 2024 11:43:26 -0800 Subject: [PATCH] Lint fixes --- src/kbmod/filters/known_object_filters.py | 45 ++--- tests/test_known_object_filters.py | 210 ++++++++++++---------- 2 files changed, 141 insertions(+), 114 deletions(-) diff --git a/src/kbmod/filters/known_object_filters.py b/src/kbmod/filters/known_object_filters.py index 96c84a0ad..ceef4f1c5 100644 --- a/src/kbmod/filters/known_object_filters.py +++ b/src/kbmod/filters/known_object_filters.py @@ -37,7 +37,7 @@ def __init__(self, table, mjd_col="mjd_mid", ra_col="RA", dec_col="DEC", name_co ------ ValueError If the required columns are not present in the table. - + Returns ------- KnownObjs @@ -64,25 +64,25 @@ def get_mjd(self, ko_idx): Returns the MJD of the known object at a given index. """ return self.data[ko_idx][self.mjd_col] - + def get_ra(self, ko_idx): """ Returns the RA of the known object at a given index. """ return self.data[ko_idx][self.ra_col] - + def get_dec(self, ko_idx): """ Returns the DEC of the known object at a given index. """ return self.data[ko_idx][self.dec_col] - + def get_name(self, ko_idx): """ Returns the name of the known object at a given index. """ return self.data[ko_idx][self.name_col] - + def to_skycoords(self): """ Returns a SkyCoord representation of the known objects. @@ -106,11 +106,10 @@ def filter_rows_by_time(self, start_mjd, end_mjd): A new KnownObjs object filtered down withing the given time range. """ new_data = self.data[(self.data[self.mjd_col] >= start_mjd) & (self.data[self.mjd_col] <= end_mjd)] - return KnownObjs(new_data, - mjd_col=self.mjd_col, - ra_col=self.ra_col, - dec_col=self.dec_col, - name_col=self.name_col) + return KnownObjs( + new_data, mjd_col=self.mjd_col, ra_col=self.ra_col, dec_col=self.dec_col, name_col=self.name_col + ) + def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_params, remove_match_obs=True): """This function takes a list of results and matches them to known objects. @@ -131,7 +130,7 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param filter_type, filter_params, and filter_obs. remove_match_obs : bool If True, remove observations that match to known objects from the results - + Returns ------- list(dict) @@ -155,18 +154,22 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param if len(result_data) == 0 or len(obstimes) == 0: logger.info(f"{filter_type} : skipping, no results.") return all_matches - + # Skip matching known objects if there are none if len(known_objs) == 0: logger.info("Known Object Filtering : skipping, no objects to match agains.") return all_matches - known_obj_thresh = filter_params['known_obj_thresh'] + known_obj_thresh = filter_params["known_obj_thresh"] - sep_thresh = filter_params['known_obj_sep_thresh'] if 'known_obj_sep_thresh' in filter_params else 1.0 + sep_thresh = filter_params["known_obj_sep_thresh"] if "known_obj_sep_thresh" in filter_params else 1.0 sep_thresh = sep_thresh * u.arcsec - time_sep_thresh_s = filter_params['known_obj_sep_time_thresh_s'] if "known_obj_sep_time_thresh_s" in filter_params else 1200.0 + time_sep_thresh_s = ( + filter_params["known_obj_sep_time_thresh_s"] + if "known_obj_sep_time_thresh_s" in filter_params + else 1200.0 + ) # First filter down our cached data to a range of possible obstimes to speed up the search start_mjd = max(0, min(obstimes) - 2 - time_sep_thresh_s) @@ -182,7 +185,7 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param # Becauase we're only using the valid obstimes, we can user this below to map back to # the original observation index. trj_idx_to_obs_idx = np.where(result_data[result_idx]["obs_valid"])[0] - + # Now we can compare the SkyCoords of the known objects to the SkyCoords of the trajectories using search_around_sky # This will return a list of indices of known objects that are within sep_thresh of a trajectory trjs_idx, known_objs_idx, _, _ = search_around_sky(trj_skycoords, known_objs_ra_dec, sep_thresh) @@ -191,7 +194,7 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param matched_known_objs = {} for t_idx, ko_idx in zip(trjs_idx, known_objs_idx): # Check the time separation is witihin our threshold - if abs(known_objs.get_mjd(ko_idx) - valid_obstimes[t_idx])*3600 <= time_sep_thresh_s: + if abs(known_objs.get_mjd(ko_idx) - valid_obstimes[t_idx]) * 3600 <= time_sep_thresh_s: # The name of the object that matched to this observation obj_name = known_objs.get_name(ko_idx) # Create an array of dimension trj_skycoords where each value is false @@ -202,7 +205,9 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param # want for results filtering. obs_idx = trj_idx_to_obs_idx[t_idx] if obs_idx >= len(matched_known_objs[obj_name]): - raise ValueError(f"obs_idx: {obs_idx}, \n t_idx: {t_idx}, \n trj_idx_to_obs_idx: {trj_idx_to_obs_idx}, \nvalid_obstimes: {valid_obstimes}\n,trjs_idx: {trjs_idx},\n known_objs_idx: {known_objs_idx}") + raise ValueError( + f"obs_idx: {obs_idx}, \n t_idx: {t_idx}, \n trj_idx_to_obs_idx: {trj_idx_to_obs_idx}, \nvalid_obstimes: {valid_obstimes}\n,trjs_idx: {trjs_idx},\n known_objs_idx: {known_objs_idx}" + ) matched_known_objs[obj_name][obs_idx] = True all_matches.append(matched_known_objs) @@ -211,8 +216,8 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param # Add matches as a result column result_data.table[filter_params["filter_type"]] = all_matches - + if remove_match_obs: result_data.update_obs_valid(new_obs_valid) - + return all_matches diff --git a/tests/test_known_object_filters.py b/tests/test_known_object_filters.py index 284fb2a57..2f7514678 100644 --- a/tests/test_known_object_filters.py +++ b/tests/test_known_object_filters.py @@ -22,9 +22,10 @@ from kbmod.search import * from kbmod.wcs_utils import make_fake_wcs, wcs_fits_equal + class TestKnownObjFilters(unittest.TestCase): def setUp(self): - self.seed = 500 # Seed for reproducibility + self.seed = 500 # Seed for reproducibility np.random.seed(self.seed) random.seed(self.seed) @@ -56,62 +57,68 @@ def setUp(self): invalid_obs = np.random.choice(num_images, 5, replace=False) self.obs_valid[i][invalid_obs] = False self.res.update_obs_valid(self.obs_valid) - assert set(self.res.table.columns) == set(['x','y','vx','vy','likelihood','flux','obs_count', 'obs_valid']) + assert set(self.res.table.columns) == set( + ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count", "obs_valid"] + ) # Use the results' trajectories to generate a set of known objects that we can use to test the filter # Now we want to create a data set of known objects that interset our generated results in various - # ways. + # ways. known_obj_table = Table({"Name": np.empty(0, dtype=str), "RA": [], "DEC": [], "mjd_mid": []}) # Case 1: Near in space (<1") and near in time (>1 s) and near in time to result 1 self.generate_known_obj_from_result( known_obj_table, - 1, # Base off result 1 - self.obstimes, # Use all possible obstimes + 1, # Base off result 1 + self.obstimes, # Use all possible obstimes "spatial_close_time_close_1", spatial_offset=0.00001, - time_offset=0.00025) + time_offset=0.00025, + ) # Case 2 near in space to result 3, but farther in time. self.generate_known_obj_from_result( known_obj_table, - 3, # Base off result 3 - self.obstimes, # Use all possible obstimes + 3, # Base off result 3 + self.obstimes, # Use all possible obstimes "spatial_close_time_far_3", spatial_offset=0.0001, - time_offset=0.3) + time_offset=0.3, + ) # Case 3: A similar trajectory to result 5, but farther in space with similar timestamps. self.generate_known_obj_from_result( known_obj_table, - 5, # Base off result 5 - self.obstimes, # Use all possible obstimes + 5, # Base off result 5 + self.obstimes, # Use all possible obstimes "spatial_far_time_close_5", spatial_offset=5, - time_offset=0.00025) + time_offset=0.00025, + ) # Case 4: A similar trajectory to result 7, but far off spatially and temporally self.generate_known_obj_from_result( known_obj_table, - 7, # Base off result 7 - self.obstimes, # Use all possible obstimes + 7, # Base off result 7 + self.obstimes, # Use all possible obstimes "spatial_far_time_far_7", spatial_offset=5, - time_offset=0.3) - + time_offset=0.3, + ) # Case 5: a trajectory matching result 8 but with only a few observations. self.generate_known_obj_from_result( known_obj_table, - 8, # Base off result 8 - self.obstimes[::10], # Samples down to every 5th observation + 8, # Base off result 8 + self.obstimes[::10], # Samples down to every 5th observation "sparse_8", spatial_offset=0.0001, - time_offset=0.00025) - + time_offset=0.00025, + ) + self.known_objs = KnownObjs(known_obj_table) - def test_known_obj_init(self): # Test a table with no columns specified raises a ValueError + def test_known_obj_init(self): # Test a table with no columns specified raises a ValueError with self.assertRaises(ValueError): KnownObjs(Table()) @@ -122,7 +129,7 @@ def test_known_obj_init(self): # Test a table with no columns specified raises a # Test a table with no RA column raises a ValueError with self.assertRaises(ValueError): KnownObjs(Table({"Name": [], "DEC": [], "mjd_mid": []})) - + # Test a table with no DEC column raises a ValueError with self.assertRaises(ValueError): KnownObjs(Table({"Name": [], "RA": [], "mjd_mid": []})) @@ -135,54 +142,66 @@ def test_known_obj_init(self): # Test a table with no columns specified raises a self.assertEqual(0, len(KnownObjs(Table({"Name": [], "RA": [], "DEC": [], "mjd_mid": []})))) # Test a table where we override the names for each column - self.assertEqual(0, len(KnownObjs( - Table({"my_Name": [], "my_RA": [], "my_DEC": [], "my_mjd_mid": []}), - mjd_col="my_mjd_mid", - ra_col="my_RA", - dec_col="my_DEC", - name_col="my_Name", - ))) + self.assertEqual( + 0, + len( + KnownObjs( + Table({"my_Name": [], "my_RA": [], "my_DEC": [], "my_mjd_mid": []}), + mjd_col="my_mjd_mid", + ra_col="my_RA", + dec_col="my_DEC", + name_col="my_Name", + ) + ), + ) def generate_known_obj_from_result( - self, - known_obj_table, - res_idx, - obstimes, - name, - spatial_offset=0.0001, - time_offset=0.00025, - ): - """ Helper function to generate a known object based on existing result trajectory """ + self, + known_obj_table, + res_idx, + obstimes, + name, + spatial_offset=0.0001, + time_offset=0.00025, + ): + """Helper function to generate a known object based on existing result trajectory""" trj_skycoords = trajectory_predict_skypos( self.res.make_trajectory_list()[res_idx], self.wcs, - obstimes, + obstimes, ) for i in range(len(obstimes)): - known_obj_table.add_row({ - "Name": name, - "RA": trj_skycoords[i].ra.degree + spatial_offset, - "DEC": trj_skycoords[i].dec.degree + spatial_offset, - "mjd_mid": obstimes[i] + time_offset, - }) + known_obj_table.add_row( + { + "Name": name, + "RA": trj_skycoords[i].ra.degree + spatial_offset, + "DEC": trj_skycoords[i].dec.degree + spatial_offset, + "mjd_mid": obstimes[i] + time_offset, + } + ) def test_apply_known_obj_empty(self): # Here we test that the filter across various empty parameters # Test that the filter is not applied when no known objects were provided - empty_objs = KnownObjs( - Table({"Name": np.empty(0, dtype=str), "RA": [], "DEC": [], "mjd_mid": []})) - matches = apply_known_obj_filters(self.res, empty_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params) + empty_objs = KnownObjs(Table({"Name": np.empty(0, dtype=str), "RA": [], "DEC": [], "mjd_mid": []})) + matches = apply_known_obj_filters( + self.res, empty_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params + ) self.assertEqual(0, sum([len(m.keys()) for m in matches])) self.assertEqual(10, len(self.res)) # Test that the filter is not applied when there were no results. - matches = apply_known_obj_filters(Results(), self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params) + matches = apply_known_obj_filters( + Results(), self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params + ) self.assertEqual(0, sum([len(m.keys()) for m in matches])) self.assertEqual(10, len(self.res)) # Test that the filter is not applied when there were no obstimes - matches = apply_known_obj_filters(self.res, self.known_objs, obstimes=[], wcs=self.wcs, filter_params=self.filter_params) + matches = apply_known_obj_filters( + self.res, self.known_objs, obstimes=[], wcs=self.wcs, filter_params=self.filter_params + ) self.assertEqual(0, sum([len(m.keys()) for m in matches])) self.assertEqual(10, len(self.res)) @@ -190,8 +209,10 @@ def test_apply_known_obj_filtering(self): expected_matches = set(["spatial_close_time_close_1", "sparse_8"]) # Call the function under test - matches = apply_known_obj_filters(self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params) - + matches = apply_known_obj_filters( + self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params + ) + # Assert the expected result obs_matches = set() for m in matches: @@ -216,21 +237,20 @@ def test_apply_known_obj_excessive_spatial_filtering(self): # Here we only filter for exact spatial matches and should return no results self.filter_params["known_obj_sep_thresh"] = 0.0 matches = apply_known_obj_filters( - self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params) + self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params + ) self.assertEqual(0, sum([len(m.keys()) for m in matches])) self.assertEqual(10, len(self.res)) def test_apply_known_obj_spatial_filtering(self): # Here we use a filter that only matches spatially with an unreasonably generous time filter self.filter_params["known_obj_sep_time_thresh_s"] = 1000000 - expected_matches = set([ - "spatial_close_time_close_1", - "spatial_close_time_far_3", - "sparse_8"]) + expected_matches = set(["spatial_close_time_close_1", "spatial_close_time_far_3", "sparse_8"]) matches = apply_known_obj_filters( - self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params) - + self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params + ) + obs_matches = set() for m in matches: obs_matches.update(m.keys()) @@ -241,20 +261,24 @@ def test_apply_known_obj_spatial_filtering(self): # Check that the close known object we inserted near result 1 is present self.assertEqual(1, len(matches[1])) self.assertTrue("spatial_close_time_close_1" in matches[1]) - self.assertEqual(np.count_nonzero(self.obs_valid[1]), - np.count_nonzero(matches[1]["spatial_close_time_close_1"])) + self.assertEqual( + np.count_nonzero(self.obs_valid[1]), np.count_nonzero(matches[1]["spatial_close_time_close_1"]) + ) # Check that the close known object we inserted near result 3 is present self.assertEqual(1, len(matches[3])) self.assertTrue("spatial_close_time_far_3" in matches[3]) - self.assertEqual(np.count_nonzero(self.obs_valid[3]), - np.count_nonzero(matches[3]["spatial_close_time_far_3"])) - + self.assertEqual( + np.count_nonzero(self.obs_valid[3]), np.count_nonzero(matches[3]["spatial_close_time_far_3"]) + ) + # Check that the sparse known object we inserted near result 8 is present self.assertEqual(1, len(matches[8])) self.assertTrue("sparse_8" in matches[8]) - self.assertGreaterEqual(len(self.known_objs.data[self.known_objs.data["Name"] == "sparse_8"]), - np.count_nonzero(matches[8]["sparse_8"])) + self.assertGreaterEqual( + len(self.known_objs.data[self.known_objs.data["Name"] == "sparse_8"]), + np.count_nonzero(matches[8]["sparse_8"]), + ) # Check that no results other than results 1 and 3 are full matches # Since these are based off of random trajectories we can't guarantee there @@ -263,20 +287,18 @@ def test_apply_known_obj_spatial_filtering(self): if i not in [1, 3]: for obj_name in matches[i]: self.assertGreater( - np.count_nonzero(self.obs_valid[i]), - np.count_nonzero(matches[i][obj_name])) + np.count_nonzero(self.obs_valid[i]), np.count_nonzero(matches[i][obj_name]) + ) def test_apply_known_obj_temporal_filtering(self): # Here we use a filter that only matches temporally with an unreasonably generous spatial filter self.filter_params["known_obj_sep_thresh"] = 100000 - expected_matches = set([ - "spatial_close_time_close_1", - "spatial_far_time_close_5", - "sparse_8"]) - + expected_matches = set(["spatial_close_time_close_1", "spatial_far_time_close_5", "sparse_8"]) + matches = apply_known_obj_filters( - self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params) - + self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params + ) + obs_matches = set() for m in matches: obs_matches.update(m.keys()) @@ -291,27 +313,31 @@ def test_apply_known_obj_temporal_filtering(self): if obj_name == "sparse_8": self.assertGreaterEqual( len(self.known_objs.data[self.known_objs.data["Name"] == "sparse_8"]), - np.count_nonzero(matches[i]["sparse_8"])) + np.count_nonzero(matches[i]["sparse_8"]), + ) else: self.assertEqual( - np.count_nonzero(self.obs_valid[i]), - np.count_nonzero(matches[i][obj_name]) - ) + np.count_nonzero(self.obs_valid[i]), np.count_nonzero(matches[i][obj_name]) + ) def test_apply_known_obj_time_no_filtering(self): # Here we use generous temporal and spatial filters to uncover all objects self.filter_params["known_obj_sep_thresh"] = 100000 self.filter_params["known_obj_sep_time_thresh_s"] = 1000000 - expected_matches = set([ - "spatial_close_time_close_1", - "spatial_close_time_far_3", - "spatial_far_time_close_5", - "spatial_far_time_far_7", - "sparse_8"]) - + expected_matches = set( + [ + "spatial_close_time_close_1", + "spatial_close_time_far_3", + "spatial_far_time_close_5", + "spatial_far_time_far_7", + "sparse_8", + ] + ) + matches = apply_known_obj_filters( - self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params) - + self.res, self.known_objs, obstimes=self.obstimes, wcs=self.wcs, filter_params=self.filter_params + ) + obs_matches = set() for m in matches: obs_matches.update(m.keys()) @@ -319,14 +345,10 @@ def test_apply_known_obj_time_no_filtering(self): # Each result should have matched to every object self.assertEqual(0, len(self.res)) - + # Check that every result matches to all of expected known objects for i in range(len(matches)): self.assertEqual(expected_matches, set(matches[i].keys())) # Check that all observations were matched to the known objects for obj_name in matches[i]: - self.assertEqual( - np.count_nonzero(self.obs_valid[i]), - np.count_nonzero(matches[i][obj_name]) - ) - + self.assertEqual(np.count_nonzero(self.obs_valid[i]), np.count_nonzero(matches[i][obj_name]))