Skip to content

Commit

Permalink
Lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonbb committed Nov 12, 2024
1 parent 77aeeb3 commit fd37556
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 114 deletions.
45 changes: 25 additions & 20 deletions src/kbmod/filters/known_object_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, table, mjd_col="mjd_mid", ra_col="RA", dec_col="DEC", name_co
------
ValueError
If the required columns are not present in the table.
Returns
-------
KnownObjs
Expand All @@ -64,25 +64,25 @@ def get_mjd(self, ko_idx):
Returns the MJD of the known object at a given index.
"""
return self.data[ko_idx][self.mjd_col]

def get_ra(self, ko_idx):
"""
Returns the RA of the known object at a given index.
"""
return self.data[ko_idx][self.ra_col]

def get_dec(self, ko_idx):
"""
Returns the DEC of the known object at a given index.
"""
return self.data[ko_idx][self.dec_col]

def get_name(self, ko_idx):
"""
Returns the name of the known object at a given index.
"""
return self.data[ko_idx][self.name_col]

def to_skycoords(self):
"""
Returns a SkyCoord representation of the known objects.
Expand All @@ -106,11 +106,10 @@ def filter_rows_by_time(self, start_mjd, end_mjd):
A new KnownObjs object filtered down withing the given time range.
"""
new_data = self.data[(self.data[self.mjd_col] >= start_mjd) & (self.data[self.mjd_col] <= end_mjd)]
return KnownObjs(new_data,
mjd_col=self.mjd_col,
ra_col=self.ra_col,
dec_col=self.dec_col,
name_col=self.name_col)
return KnownObjs(
new_data, mjd_col=self.mjd_col, ra_col=self.ra_col, dec_col=self.dec_col, name_col=self.name_col
)


def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_params, remove_match_obs=True):
"""This function takes a list of results and matches them to known objects.
Expand All @@ -131,7 +130,7 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param
filter_type, filter_params, and filter_obs.
remove_match_obs : bool
If True, remove observations that match to known objects from the results
Returns
-------
list(dict)
Expand All @@ -155,18 +154,22 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param
if len(result_data) == 0 or len(obstimes) == 0:
logger.info(f"{filter_type} : skipping, no results.")
return all_matches

# Skip matching known objects if there are none
if len(known_objs) == 0:
logger.info("Known Object Filtering : skipping, no objects to match agains.")
return all_matches

known_obj_thresh = filter_params['known_obj_thresh']
known_obj_thresh = filter_params["known_obj_thresh"]

sep_thresh = filter_params['known_obj_sep_thresh'] if 'known_obj_sep_thresh' in filter_params else 1.0
sep_thresh = filter_params["known_obj_sep_thresh"] if "known_obj_sep_thresh" in filter_params else 1.0
sep_thresh = sep_thresh * u.arcsec

time_sep_thresh_s = filter_params['known_obj_sep_time_thresh_s'] if "known_obj_sep_time_thresh_s" in filter_params else 1200.0
time_sep_thresh_s = (
filter_params["known_obj_sep_time_thresh_s"]
if "known_obj_sep_time_thresh_s" in filter_params
else 1200.0
)

# First filter down our cached data to a range of possible obstimes to speed up the search
start_mjd = max(0, min(obstimes) - 2 - time_sep_thresh_s)
Expand All @@ -182,7 +185,7 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param
# Becauase we're only using the valid obstimes, we can user this below to map back to
# the original observation index.
trj_idx_to_obs_idx = np.where(result_data[result_idx]["obs_valid"])[0]

# Now we can compare the SkyCoords of the known objects to the SkyCoords of the trajectories using search_around_sky
# This will return a list of indices of known objects that are within sep_thresh of a trajectory
trjs_idx, known_objs_idx, _, _ = search_around_sky(trj_skycoords, known_objs_ra_dec, sep_thresh)
Expand All @@ -191,7 +194,7 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param
matched_known_objs = {}
for t_idx, ko_idx in zip(trjs_idx, known_objs_idx):
# Check the time separation is witihin our threshold
if abs(known_objs.get_mjd(ko_idx) - valid_obstimes[t_idx])*3600 <= time_sep_thresh_s:
if abs(known_objs.get_mjd(ko_idx) - valid_obstimes[t_idx]) * 3600 <= time_sep_thresh_s:
# The name of the object that matched to this observation
obj_name = known_objs.get_name(ko_idx)
# Create an array of dimension trj_skycoords where each value is false
Expand All @@ -202,7 +205,9 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param
# want for results filtering.
obs_idx = trj_idx_to_obs_idx[t_idx]
if obs_idx >= len(matched_known_objs[obj_name]):
raise ValueError(f"obs_idx: {obs_idx}, \n t_idx: {t_idx}, \n trj_idx_to_obs_idx: {trj_idx_to_obs_idx}, \nvalid_obstimes: {valid_obstimes}\n,trjs_idx: {trjs_idx},\n known_objs_idx: {known_objs_idx}")
raise ValueError(
f"obs_idx: {obs_idx}, \n t_idx: {t_idx}, \n trj_idx_to_obs_idx: {trj_idx_to_obs_idx}, \nvalid_obstimes: {valid_obstimes}\n,trjs_idx: {trjs_idx},\n known_objs_idx: {known_objs_idx}"
)

matched_known_objs[obj_name][obs_idx] = True
all_matches.append(matched_known_objs)
Expand All @@ -211,8 +216,8 @@ def apply_known_obj_filters(result_data, known_objs, obstimes, wcs, filter_param

# Add matches as a result column
result_data.table[filter_params["filter_type"]] = all_matches

if remove_match_obs:
result_data.update_obs_valid(new_obs_valid)

return all_matches
Loading

0 comments on commit fd37556

Please sign in to comment.