From 7ff6148c313296f01296cf360762606e21599536 Mon Sep 17 00:00:00 2001
From: Max West <110124344+maxwest-uw@users.noreply.github.com>
Date: Thu, 18 Apr 2024 14:01:26 -0700
Subject: [PATCH 1/2] adds `invert_correct_parallax` function (#559)

* invert_correct_parallax implementation
* make correct_parallax output SkyCoord
* black formatting
* fix docstrings for reproject_utils
* address pr comments
---
 src/kbmod/reprojection_utils.py  | 90 +++++++++++++++++++++++++------
 tests/test_reprojection_utils.py | 91 ++++++++++++++++++++++++--------
 2 files changed, 142 insertions(+), 39 deletions(-)

diff --git a/src/kbmod/reprojection_utils.py b/src/kbmod/reprojection_utils.py
index b843987f7..c324152bb 100644
--- a/src/kbmod/reprojection_utils.py
+++ b/src/kbmod/reprojection_utils.py
@@ -6,7 +6,7 @@ from scipy.optimize import minimize
 
 
-def correct_parallax(coord, obstime, point_on_earth, guess_distance):
+def correct_parallax(coord, obstime, point_on_earth, heliocentric_distance):
     """Calculate the parallax corrected positions for a given object at a given time and distance from Earth.
 
     Parameters
     ----------
@@ -17,12 +17,12 @@ def correct_parallax(coord, obstime, point_on_earth, guess_distance):
         The observation time.
     point_on_earth : `astropy.coordinate.EarthLocation`
         The location on Earth of the observation.
-    guess_distance : `float`
-        The guess distance to the object from Earth.
+    heliocentric_distance : `float`
+        The guessed distance to the object from the Sun, in AU.
 
     Returns
     ----------
-    An `astropy.coordinate.SkyCoord` containing the ra and dec of the pointin ICRS.
+    An `astropy.coordinate.SkyCoord` containing the ra and dec of the point in ICRS, and the best-fit geocentric distance (float).
 
     References
     ----------
@@ -38,9 +38,15 @@ def correct_parallax(coord, obstime, point_on_earth, guess_distance):
     # the object has an unknown distance from earth
     los_earth_obj = coord.transform_to(GCRS(obstime=obstime, obsgeoloc=loc))
 
-    cost = lambda d: np.abs(
-        guess_distance
-        - GCRS(ra=los_earth_obj.ra, dec=los_earth_obj.dec, distance=d * u.AU, obstime=obstime, obsgeoloc=loc)
+    cost = lambda geocentric_distance: np.abs(
+        heliocentric_distance
+        - GCRS(
+            ra=los_earth_obj.ra,
+            dec=los_earth_obj.dec,
+            distance=geocentric_distance * u.AU,
+            obstime=obstime,
+            obsgeoloc=loc,
+        )
         .transform_to(ICRS())
         .distance.to(u.AU)
         .value
@@ -48,18 +54,64 @@ def correct_parallax(coord, obstime, point_on_earth, guess_distance):
 
     fit = minimize(
         cost,
-        (guess_distance,),
+        (heliocentric_distance,),
     )
 
-    answer = GCRS(
-        ra=los_earth_obj.ra, dec=los_earth_obj.dec, distance=fit.x[0] * u.AU, obstime=obstime, obsgeoloc=loc
+    answer = SkyCoord(
+        ra=los_earth_obj.ra,
+        dec=los_earth_obj.dec,
+        distance=fit.x[0] * u.AU,
+        obstime=obstime,
+        obsgeoloc=loc,
+        frame="gcrs",
     ).transform_to(ICRS())
 
-    return answer
+    return answer, fit.x[0]
+
+
+def invert_correct_parallax(coord, obstime, point_on_earth, geocentric_distance, heliocentric_distance):
+    """Calculate the original ICRS coordinates of a point in EBD space, i.e. a result from `correct_parallax`.
+
+    Parameters
+    ----------
+    coord : `astropy.coordinate.SkyCoord`
+        The EBD coordinate whose original position in non-parallax-corrected space we want to find.
+    obstime : `astropy.time.Time` or `string`
+        The observation time.
+    point_on_earth : `astropy.coordinate.EarthLocation`
+        The location on Earth of the observation.
+    geocentric_distance : `float`
+        The distance from Earth to the object (generally a result from `correct_parallax`).
+    heliocentric_distance : `float`
+        The distance from the solar system barycenter to the object (generally an input for `correct_parallax`).
+
+    Returns
+    ----------
+    An `astropy.coordinate.SkyCoord` containing the ra and dec of the point in ICRS, corresponding to the
+    position in the original observation (before `correct_parallax`).
+
+    References
+    ----------
+    .. [1] `Jupyter Notebook `_
+    """
+    loc = (
+        point_on_earth.x,
+        point_on_earth.y,
+        point_on_earth.z,
+    ) * u.m
+    icrs_with_dist = ICRS(ra=coord.ra, dec=coord.dec, distance=heliocentric_distance * u.au)
+
+    gcrs_no_dist = icrs_with_dist.transform_to(GCRS(obsgeoloc=loc, obstime=obstime))
+    gcrs_with_dist = GCRS(
+        ra=gcrs_no_dist.ra, dec=gcrs_no_dist.dec, distance=geocentric_distance * u.AU, obsgeoloc=loc, obstime=obstime
+    )
+
+    original_icrs = gcrs_with_dist.transform_to(ICRS())
+    return original_icrs
 
 
 def fit_barycentric_wcs(
-    original_wcs, width, height, distance, obstime, point_on_earth, npoints=10, seed=None
+    original_wcs, width, height, heliocentric_distance, obstime, point_on_earth, npoints=10, seed=None
 ):
     """Given an ICRS WCS and an object's distance from the Sun,
     return a new WCS that has been corrected for parallax motion.
@@ -72,7 +124,7 @@ def fit_barycentric_wcs(
         The image's width (typically NAXIS1).
     height : `int`
         The image's height (typically NAXIS2).
-    distance : `float`
+    heliocentric_distance : `float`
         The distance of the object from the sun, in AU.
     obstime : `astropy.time.Time` or `string`
         The observation time.
@@ -89,7 +141,8 @@ def fit_barycentric_wcs(
 
     Returns
    ----------
     An `astropy.wcs.WCS` representing the original image in "Explicit Barycentric Distance" (EBD)
-    space, i.e. where the points have been corrected for parallax.
+    space, i.e. where the points have been corrected for parallax, as well as the average best-fit
+    geocentric distance of the object.
""" rng = np.random.default_rng(seed) @@ -104,12 +157,17 @@ def fit_barycentric_wcs( sampled_coordinates = SkyCoord(sampled_ra, sampled_dec, unit="deg") ebd_corrected_points = [] + geocentric_distances = [] for coord in sampled_coordinates: - ebd_corrected_points.append(correct_parallax(coord, obstime, point_on_earth, distance)) + coord, geo_dist = correct_parallax(coord, obstime, point_on_earth, heliocentric_distance) + ebd_corrected_points.append(coord) + geocentric_distances.append(geo_dist) ebd_corrected_points = SkyCoord(ebd_corrected_points) xy = (sampled_x_points, sampled_y_points) ebd_wcs = fit_wcs_from_points( xy, ebd_corrected_points, proj_point="center", projection="TAN", sip_degree=3 ) - return ebd_wcs + geocentric_distance = np.average(geocentric_distances) + + return ebd_wcs, geocentric_distance diff --git a/tests/test_reprojection_utils.py b/tests/test_reprojection_utils.py index 12c1a1357..130a2dab6 100644 --- a/tests/test_reprojection_utils.py +++ b/tests/test_reprojection_utils.py @@ -6,7 +6,7 @@ from astropy.time import Time from astropy.wcs import WCS -from kbmod.reprojection_utils import correct_parallax, fit_barycentric_wcs +from kbmod.reprojection_utils import correct_parallax, invert_correct_parallax, fit_barycentric_wcs class test_reprojection_utils(unittest.TestCase): @@ -25,26 +25,26 @@ def setUp(self): self.loc = EarthLocation.of_site(self.site) self.distance = 41.1592725489203 - def test_parallax_equinox(self): - icrs_ra1 = 88.74513571 - icrs_dec1 = 23.43426475 - time1 = Time("2023-03-20T16:00:00", format="isot", scale="utc") + self.icrs_ra1 = 88.74513571 + self.icrs_dec1 = 23.43426475 + self.icrs_time1 = Time("2023-03-20T16:00:00", format="isot", scale="utc") - icrs_ra2 = 91.24261107 - icrs_dec2 = 23.43437467 - time2 = Time("2023-09-24T04:00:00", format="isot", scale="utc") + self.icrs_ra2 = 91.24261107 + self.icrs_dec2 = 23.43437467 + self.icrs_time2 = Time("2023-09-24T04:00:00", format="isot", scale="utc") - sc1 = SkyCoord(ra=icrs_ra1, dec=icrs_dec1, unit="deg") - sc2 = SkyCoord(ra=icrs_ra2, dec=icrs_dec2, unit="deg") + self.sc1 = SkyCoord(ra=self.icrs_ra1, dec=self.icrs_dec1, unit="deg") + self.sc2 = SkyCoord(ra=self.icrs_ra2, dec=self.icrs_dec2, unit="deg") with solar_system_ephemeris.set("de432s"): - loc = EarthLocation.of_site("ctio") + self.eq_loc = EarthLocation.of_site("ctio") - corrected_coord1 = correct_parallax( - coord=sc1, - obstime=time1, - point_on_earth=loc, - guess_distance=50.0, + def test_parallax_equinox(self): + corrected_coord1, _ = correct_parallax( + coord=self.sc1, + obstime=self.icrs_time1, + point_on_earth=self.eq_loc, + heliocentric_distance=50.0, ) expected_ra = 90.0 @@ -53,16 +53,58 @@ def test_parallax_equinox(self): npt.assert_almost_equal(corrected_coord1.ra.value, expected_ra) npt.assert_almost_equal(corrected_coord1.dec.value, expected_dec) - corrected_coord2 = correct_parallax( - coord=sc2, - obstime=time2, - point_on_earth=loc, - guess_distance=50.0, + corrected_coord2, _ = correct_parallax( + coord=self.sc2, + obstime=self.icrs_time2, + point_on_earth=self.eq_loc, + heliocentric_distance=50.0, ) npt.assert_almost_equal(corrected_coord2.ra.value, expected_ra) npt.assert_almost_equal(corrected_coord2.dec.value, expected_dec) + assert type(corrected_coord1) is SkyCoord + assert type(corrected_coord2) is SkyCoord + + def test_invert_correct_parallax(self): + corrected_coord1, geo_dist1 = correct_parallax( + coord=self.sc1, + obstime=self.icrs_time1, + point_on_earth=self.eq_loc, + heliocentric_distance=50.0, + ) + + fresh_sc1 
= SkyCoord(ra=corrected_coord1.ra.degree, dec=corrected_coord1.dec.degree, unit="deg") + + uncorrected_coord1 = invert_correct_parallax( + coord=fresh_sc1, + obstime=self.icrs_time1, + point_on_earth=self.eq_loc, + geocentric_distance=geo_dist1, + heliocentric_distance=50.0, + ) + + assert self.sc1.separation(uncorrected_coord1).arcsecond < 0.001 + + corrected_coord2, geo_dist2 = correct_parallax( + coord=self.sc2, + obstime=self.icrs_time2, + point_on_earth=self.eq_loc, + heliocentric_distance=50.0, + ) + + fresh_sc2 = SkyCoord(ra=corrected_coord2.ra.degree, dec=corrected_coord2.dec.degree, unit="deg") + + uncorrected_coord2 = invert_correct_parallax( + coord=fresh_sc2, + obstime=self.icrs_time2, + point_on_earth=self.eq_loc, + geocentric_distance=geo_dist2, + heliocentric_distance=50.0, + ) + + assert self.sc2.separation(uncorrected_coord2).arcsecond < 0.001 + def test_fit_barycentric_wcs(self): x_points = np.array([247, 1252, 1052, 980, 420, 1954, 730, 1409, 1491, 803]) y_points = np.array([1530, 713, 3414, 3955, 1975, 123, 1456, 2008, 1413, 1756]) @@ -99,7 +141,7 @@ def test_fit_barycentric_wcs(self): expected_sc = SkyCoord(ra=expected_ra, dec=expected_dec, unit="deg") - corrected_wcs = fit_barycentric_wcs( + corrected_wcs, geo_dist = fit_barycentric_wcs( self.test_wcs, self.nx, self.ny, @@ -115,9 +157,10 @@ def test_fit_barycentric_wcs(self): # assert we have sub-milliarcsecond precision assert np.all(seps < 0.001) assert corrected_wcs.array_shape == (self.ny, self.nx) + npt.assert_almost_equal(geo_dist, 40.18622, decimal=4) def test_fit_barycentric_wcs_consistency(self): - corrected_wcs = fit_barycentric_wcs( + corrected_wcs, geo_dist = fit_barycentric_wcs( self.test_wcs, self.nx, self.ny, self.distance, self.time, self.loc, seed=24601 ) @@ -134,3 +177,5 @@ def test_fit_barycentric_wcs_consistency(self): npt.assert_almost_equal(corrected_wcs.wcs.cd[0][1], 3.459611876675614e-08) npt.assert_almost_equal(corrected_wcs.wcs.cd[1][0], 3.401472764249802e-08) npt.assert_almost_equal(corrected_wcs.wcs.cd[1][1], 5.4242245855217796e-05) + + npt.assert_almost_equal(geo_dist, 40.18622524245729) From 84484cd86f70cfb2f28b31ba89c41714484eacb2 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 22 Apr 2024 14:29:23 -0400 Subject: [PATCH 2/2] Address PR comments --- src/kbmod/filters/clustering_filters.py | 2 +- src/kbmod/filters/sigma_g_filter.py | 6 +- src/kbmod/filters/stamp_filters.py | 2 +- src/kbmod/results.py | 142 ++++++++++++------------ src/kbmod/trajectory_utils.py | 6 +- tests/test_results.py | 36 +++--- tests/test_sigma_g_filter.py | 2 +- 7 files changed, 100 insertions(+), 96 deletions(-) diff --git a/src/kbmod/filters/clustering_filters.py b/src/kbmod/filters/clustering_filters.py index a97535a46..6aade20bb 100644 --- a/src/kbmod/filters/clustering_filters.py +++ b/src/kbmod/filters/clustering_filters.py @@ -303,6 +303,6 @@ def apply_clustering(result_data, cluster_params): if type(result_data) is ResultList: result_data.filter_results(indices_to_keep, filt.get_filter_name()) elif type(result_data) is Results: - result_data.filter_by_index(indices_to_keep, filt.get_filter_name()) + result_data.filter_rows(indices_to_keep, filt.get_filter_name()) else: raise TypeError("Unknown data type for clustering.") diff --git a/src/kbmod/filters/sigma_g_filter.py b/src/kbmod/filters/sigma_g_filter.py index 2b599ec0b..fd516ebe2 100644 --- a/src/kbmod/filters/sigma_g_filter.py +++ b/src/kbmod/filters/sigma_g_filter.py @@ -190,9 +190,9 @@ def 
apply_clipped_sigma_g(clipper, result_data, num_threads=1): The number of threads to use. """ if type(result_data) is Results: - lh = result_data.compute_likelihood_curves(filter_indices=True, mask_value=np.NAN) - index_valid = clipper.compute_clipped_sigma_g_matrix(lh) - result_data.update_index_valid(index_valid) + lh = result_data.compute_likelihood_curves(filter_obs=True, mask_value=np.NAN) + obs_valid = clipper.compute_clipped_sigma_g_matrix(lh) + result_data.update_obs_valid(obs_valid) return # TODO: Remove this logic once we have switched over to Results. diff --git a/src/kbmod/filters/stamp_filters.py b/src/kbmod/filters/stamp_filters.py index b295f9de5..7347be19a 100644 --- a/src/kbmod/filters/stamp_filters.py +++ b/src/kbmod/filters/stamp_filters.py @@ -463,7 +463,7 @@ def get_coadds_and_filter_results(result_data, im_stack, stamp_params, chunk_siz start_idx += chunk_size # Do the actual filtering of results - result_data.filter_mask(keep_row, label="stamp_filter") + result_data.filter_rows(keep_row, label="stamp_filter") # Append the coadded stamps to the results. We do this after the filtering # so we are not adding a jagged array. diff --git a/src/kbmod/results.py b/src/kbmod/results.py index 3db822bd4..a7fe5d6c9 100644 --- a/src/kbmod/results.py +++ b/src/kbmod/results.py @@ -34,14 +34,14 @@ class Results: # The required columns list gives a list of tuples containing # (column name, dype, default value) for each required column. - _required_cols = [ - ("x", "int64", 0), - ("y", "int64", 0), - ("vx", "float64", 0.0), - ("vy", "float64", 0.0), - ("likelihood", "float64", 0.0), - ("flux", "float64", 0.0), - ("obs_count", "int64", 0), + required_cols = [ + ("x", int, 0), + ("y", int, 0), + ("vx", float, 0.0), + ("vy", float, 0.0), + ("likelihood", float, 0.0), + ("flux", float, 0.0), + ("obs_count", int, 0), ] def __init__(self, data=None, track_filtered=False): @@ -61,18 +61,18 @@ def __init__(self, data=None, track_filtered=False): if data is None: # Set up the basic table meta data. self.table = Table( - names=[col[0] for col in self._required_cols], - dtype=[col[1] for col in self._required_cols], + names=[col[0] for col in self.required_cols], + dtype=[col[1] for col in self.required_cols], ) - elif type(data) is dict: + elif isinstance(data, dict): self.table = Table(data) - elif type(data) is Table: + elif isinstance(data, Table): self.table = data.copy() else: raise TypeError(f"Incompatible data type {type(data)}") # Check that we have the correct columns. - for col in self._required_cols: + for col in self.required_cols: if col[0] not in self.table.colnames: raise KeyError(f"Column {col[0]} missing from input data.") @@ -109,7 +109,7 @@ def from_trajectories(cls, trajectories, track_filtered=False): # Create dictionaries for the required columns. input_d = {} invalid_d = {} - for col in cls._required_cols: + for col in cls.required_cols: input_d[col[0]] = [] invalid_d[col[0]] = [] num_valid = 0 @@ -128,6 +128,8 @@ def from_trajectories(cls, trajectories, track_filtered=False): input_d["obs_count"].append(trj.obs_count) num_valid += 1 elif track_filtered: + # Only fill in the invalid_d dictionary if we are going + # to use it (we are tracking the filtered values). invalid_d["x"].append(trj.x) invalid_d["y"].append(trj.y) invalid_d["vx"].append(trj.vx) @@ -138,7 +140,7 @@ def from_trajectories(cls, trajectories, track_filtered=False): num_invalid += 1 # Check for any missing columns and fill in the default value. 
-        for col in cls._required_cols:
+        for col in cls.required_cols:
             if col[0] not in input_d:
                 input_d[col[0]] = [col[2]] * num_valid
                 invalid_d[col[0]] = [col[2]] * num_invalid
@@ -225,14 +227,14 @@ def make_trajectory_list(self):
         ]
         return trajectories
 
-    def compute_likelihood_curves(self, filter_indices=True, mask_value=0.0):
+    def compute_likelihood_curves(self, filter_obs=True, mask_value=0.0):
         """Create a matrix of likelihood curves where each row has a likelihood
         curve for a single trajectory.
 
         Parameters
         ----------
-        filter_indices : `bool`
-            Filter any indices marked as invalid in the 'index_valid' column.
+        filter_obs : `bool`
+            Filter any observations marked as invalid in the 'obs_valid' column.
             Substitutes the value of ``mask_value`` in their place.
         mask_value : `float`
             A floating point value to substitute into the masked entries.
@@ -258,11 +260,11 @@ def compute_likelihood_curves(self, filter_indices=True, mask_value=0.0):
         # Create a mask of valid data.
         valid = (phi != 0) & np.isfinite(psi) & np.isfinite(phi)
-        if filter_indices and "index_valid" in self.table.colnames:
-            valid = valid & self.table["index_valid"]
+        if filter_obs and "obs_valid" in self.table.colnames:
+            valid = valid & self.table["obs_valid"]
 
         lh_matrix = np.full(psi.shape, mask_value)
-        lh_matrix[valid] = np.divide(psi[valid], np.sqrt(phi[valid]))
+        lh_matrix[valid] = psi[valid] / np.sqrt(phi[valid])
         return lh_matrix
 
     def _update_likelihood(self):
@@ -272,7 +274,7 @@ def _update_likelihood(self):
         Uses the (optional) 'obs_valid' column if it exists.
 
         This should be called any time that the psi_curve, phi_curve, or
-        index_valid columns are modified.
+        obs_valid columns are modified.
 
         Raises
         ------
@@ -285,10 +287,10 @@ def _update_likelihood(self):
         num_rows = len(self.table)
         num_times = len(self.table["phi_curve"][0])
 
-        if "index_valid" in self.table.colnames:
-            phi_sum = (self.table["phi_curve"] * self.table["index_valid"]).sum(axis=1)
-            psi_sum = (self.table["psi_curve"] * self.table["index_valid"]).sum(axis=1)
-            num_obs = self.table["index_valid"].sum(axis=1)
+        if "obs_valid" in self.table.colnames:
+            phi_sum = (self.table["phi_curve"] * self.table["obs_valid"]).sum(axis=1)
+            psi_sum = (self.table["psi_curve"] * self.table["obs_valid"]).sum(axis=1)
+            num_obs = self.table["obs_valid"].sum(axis=1)
         else:
             phi_sum = self.table["phi_curve"].sum(axis=1)
             psi_sum = self.table["psi_curve"].sum(axis=1)
@@ -301,7 +303,7 @@ def _update_likelihood(self):
         self.table["flux"][non_zero] = psi_sum[non_zero] / phi_sum[non_zero]
         self.table["obs_count"] = num_obs
 
-    def add_psi_phi_data(self, psi_array, phi_array, index_valid=None):
+    def add_psi_phi_data(self, psi_array, phi_array, obs_valid=None):
         """Append columns for the psi and phi data and use this
         to update the relevant trajectory information.
 
         Parameters
         ----------
         psi_array : `numpy.ndarray`
             An array of psi_curves with one for each row.
         phi_array : `numpy.ndarray`
             An array of phi_curves with one for each row.
-        index_valid : `numpy.ndarray`, optional
-            An optional array of index_valid arrays with one for each row.
+        obs_valid : `numpy.ndarray`, optional
+            An optional array of obs_valid arrays with one for each row.
 
         Returns
         -------
         self : `Results`
             Returns a reference to itself to allow chaining.
 
         Raises
         ------
         Raises a ValueError if the input arrays are not the same size as the table
         or a given pair of rows in the arrays are not the same length.
         """
         if len(psi_array) != len(self.table):
-            raise ValueError("Wrong number of psi curves provided.")
+            raise ValueError(
+                f"Wrong number of psi curves provided. Expected {len(self.table)} rows."
+                f" Found {len(psi_array)} rows."
+            )
         if len(phi_array) != len(self.table):
-            raise ValueError("Wrong number of phi curves provided.")
+            raise ValueError(
+                f"Wrong number of phi curves provided. Expected {len(self.table)} rows."
+                f" Found {len(phi_array)} rows."
+            )
 
         self.table["psi_curve"] = psi_array
         self.table["phi_curve"] = phi_array
 
-        if index_valid is not None:
+        if obs_valid is not None:
             # Make the data to match.
-            if len(index_valid) != len(self.table):
-                raise ValueError("Wrong number of index_valid lists provided.")
-            self.table["index_valid"] = index_valid
+            if len(obs_valid) != len(self.table):
+                raise ValueError(
+                    f"Wrong number of obs_valid arrays provided. Expected {len(self.table)} rows."
+                    f" Found {len(obs_valid)} rows."
+                )
+            self.table["obs_valid"] = obs_valid
 
         # Update the track likelihoods given this new information.
         self._update_likelihood()
 
         return self
 
-    def update_index_valid(self, index_valid):
-        """Updates or appends the 'index_valid' column.
+    def update_obs_valid(self, obs_valid):
+        """Updates or appends the 'obs_valid' column.
 
         Parameters
         ----------
-        index_valid : `numpy.ndarray`
+        obs_valid : `numpy.ndarray`
             An array with one row per result and one column per timestamp
             with Booleans indicating whether the corresponding observation
             is valid.
 
@@ -362,9 +373,12 @@ def update_index_valid(self, index_valid):
         Raises a ValueError if the input array is not the same size as the table
         or a given pair of rows in the arrays are not the same length.
         """
-        if len(index_valid) != len(self.table):
-            raise ValueError("Wrong number of index_valid lists provided.")
-        self.table["index_valid"] = index_valid
+        if len(obs_valid) != len(self.table):
+            raise ValueError(
+                f"Wrong number of obs_valid lists provided. Expected {len(self.table)} rows."
+                f" Found {len(obs_valid)} rows."
+            )
+        self.table["obs_valid"] = obs_valid
 
         # Update the track likelihoods given this new information.
         self._update_likelihood()
 
@@ -394,15 +408,16 @@ def _append_filtered(self, table, label=None):
         else:
             self.filtered[label] = table
 
-    def filter_mask(self, mask, label=None):
-        """Filter the rows in the ResultTable to only include those indices
-        that are marked True in the mask.
+    def filter_rows(self, rows, label=None):
+        """Filter the rows in the `Results` to only include those indices
+        that are provided in a list of row indices (integers) or marked
+        ``True`` in a mask.
 
         Parameters
         ----------
-        mask : `list` or `numpy.ndarray`
-            A list the same length as the table with True/False indicating
-            which row to keep.
+        rows : `numpy.ndarray`
+            Either a Boolean array of the same length as the table
+            or a list of integer row indices to keep.
         label : `str`
             The label of the filtering stage to use. Only used if
             we keep filtered trajectories.
@@ -412,6 +427,17 @@ def filter_mask(self, mask, label=None):
         self : `Results`
             Returns a reference to itself to allow chaining.
         """
+        rows = np.array(rows)
+        if rows.dtype == bool:
+            if len(rows) != len(self.table):
+                raise ValueError(
+                    f"Mask length mismatch. Expected {len(self.table)} rows, but found {len(rows)}."
+                )
+            mask = rows
+        else:
+            mask = np.full((len(self.table),), False)
+            mask[rows] = True
+
         if self.track_filtered:
             self._append_filtered(self.table[~mask], label)
 
@@ -421,28 +447,6 @@ def filter_mask(self, mask, label=None):
         # Return a reference to the current object to allow chaining.
         return self
 
-    def filter_by_index(self, rows_to_keep, label=None):
-        """Filter the rows in the ResultTable to only include those indices
-        in the list indices_to_keep.
-
-        Parameters
-        ----------
-        rows_to_keep : `list[int]`
-            The indices of the rows to keep.
-        label : `str`
-            The label of the filtering stage to use. Only used if
-            we keep filtered trajectories.
-
-        Returns
-        -------
-        self : `Results`
-            Returns a reference to itself to allow chaining.
-        """
-        row_set = set(rows_to_keep)
-        mask = np.array([i in row_set for i in range(len(self.table))])
-        self.filter_mask(mask, label)
-        return self
-
     def get_filtered(self, label=None):
         """Get the results filtered at a given stage or all stages.
 
diff --git a/src/kbmod/trajectory_utils.py b/src/kbmod/trajectory_utils.py
index aeb34d98a..76a9b7b0d 100644
--- a/src/kbmod/trajectory_utils.py
+++ b/src/kbmod/trajectory_utils.py
@@ -46,13 +46,13 @@ def make_trajectory(x=0, y=0, vx=0.0, vy=0.0, flux=0.0, lh=0.0, obs_count=0):
         The resulting Trajectory object.
     """
     trj = Trajectory()
-    trj.x = int(x)
-    trj.y = int(y)
+    trj.x = x
+    trj.y = y
     trj.vx = vx
     trj.vy = vy
     trj.flux = flux
     trj.lh = lh
-    trj.obs_count = int(obs_count)
+    trj.obs_count = obs_count
     trj.valid = True
     return trj
 
diff --git a/tests/test_results.py b/tests/test_results.py
index 3b93fb8e6..09140ab37 100644
--- a/tests/test_results.py
+++ b/tests/test_results.py
@@ -142,7 +142,7 @@ def test_add_psi_phi(self):
         table = Results.from_trajectories(self.trj_list[0:num_to_use])
         psi_array = np.array([[1.0, 1.1, 1.2, 1.3] for i in range(num_to_use)])
         phi_array = np.array([[1.0, 1.0, 0.0, 2.0] for i in range(num_to_use)])
-        index_valid = np.array(
+        obs_valid = np.array(
             [
                 [True, True, True, True],
                 [True, False, True, True],
@@ -155,17 +155,17 @@ def test_add_psi_phi(self):
         exp_obs = [4, 3, 0]
 
         # Check that the data has been inserted and the statistics have been updated.
-        table.add_psi_phi_data(psi_array, phi_array, index_valid)
+        table.add_psi_phi_data(psi_array, phi_array, obs_valid)
         for i in range(num_to_use):
             self.assertEqual(len(table["psi_curve"][i]), 4)
             self.assertEqual(len(table["phi_curve"][i]), 4)
-            self.assertEqual(len(table["index_valid"][i]), 4)
+            self.assertEqual(len(table["obs_valid"][i]), 4)
             self.assertAlmostEqual(table["likelihood"][i], exp_lh[i], delta=1e-5)
             self.assertAlmostEqual(table["flux"][i], exp_flux[i], delta=1e-5)
             self.assertEqual(table["obs_count"][i], exp_obs[i])
 
-    def test_update_index_valid(self):
+    def test_update_obs_valid(self):
         num_to_use = 3
         table = Results.from_trajectories(self.trj_list[0:num_to_use])
         psi_array = np.array([[1.0, 1.1, 1.2, 1.3] for i in range(num_to_use)])
         phi_array = np.array([[1.0, 1.0, 0.0, 2.0] for i in range(num_to_use)])
         table.add_psi_phi_data(psi_array, phi_array)
         for i in range(num_to_use):
             self.assertAlmostEqual(table["likelihood"][i], 2.3, delta=1e-5)
             self.assertAlmostEqual(table["flux"][i], 1.15, delta=1e-5)
             self.assertEqual(table["obs_count"][i], 4)
 
-        # Add the index_valid column later to simulate sigmaG clipping.
-        index_valid = np.array(
+        # Add the obs_valid column later to simulate sigmaG clipping.
+        obs_valid = np.array(
             [
                 [True, True, True, True],
                 [True, False, True, True],
                 [False, False, False, False],
             ]
         )
-        table.update_index_valid(index_valid)
+        table.update_obs_valid(obs_valid)
 
         exp_lh = [2.3, 2.020725, 0.0]
         exp_flux = [1.15, 1.1666667, 0.0]
         exp_obs = [4, 3, 0]
         for i in range(num_to_use):
-            self.assertEqual(len(table["index_valid"][i]), 4)
+            self.assertEqual(len(table["obs_valid"][i]), 4)
             self.assertAlmostEqual(table["likelihood"][i], exp_lh[i], delta=1e-5)
             self.assertAlmostEqual(table["flux"][i], exp_flux[i], delta=1e-5)
             self.assertEqual(table["obs_count"][i], exp_obs[i])
 
@@ -213,25 +213,25 @@ def test_compute_likelihood_curves(self):
                 [25.0, 16.0, 4.0, 16.0],
             ]
         )
-        index_valid = np.array(
+        obs_valid = np.array(
             [
                 [True, True, True, True],
                 [True, True, True, True],
                 [True, True, False, True],
             ]
         )
-        table.add_psi_phi_data(psi_array, phi_array, index_valid)
+        table.add_psi_phi_data(psi_array, phi_array, obs_valid)
 
         expected1 = np.array([[1.0, 1.1, 0.5, 0.0], [1.0, 0.0, 0.0, 0.0], [0.2, 1.0, 5.0, 0.25]])
-        lh_mat1 = table.compute_likelihood_curves(filter_indices=False)
+        lh_mat1 = table.compute_likelihood_curves(filter_obs=False)
         self.assertTrue(np.allclose(lh_mat1, expected1))
 
         expected2 = np.array([[1.0, 1.1, 0.5, 0.0], [1.0, 0.0, 0.0, 0.0], [0.2, 1.0, 0.0, 0.25]])
-        lh_mat2 = table.compute_likelihood_curves(filter_indices=True)
+        lh_mat2 = table.compute_likelihood_curves(filter_obs=True)
         self.assertTrue(np.allclose(lh_mat2, expected2))
 
         # Try masking with NAN. This replaces ALL the invalid cells.
-        lh_mat3 = table.compute_likelihood_curves(filter_indices=True, mask_value=np.NAN)
+        lh_mat3 = table.compute_likelihood_curves(filter_obs=True, mask_value=np.NAN)
         expected = np.array(
             [
                 [True, True, True, False],
@@ -247,7 +247,7 @@ def test_filter_by_index(self):
 
         # Do the filtering and check we have the correct ones.
         inds = [0, 2, 6, 7]
-        table.filter_by_index(inds)
+        table.filter_rows(inds)
         self.assertEqual(len(table), len(inds))
         for i in range(len(inds)):
             self.assertEqual(table["x"][i], self.trj_list[inds[i]].x)
@@ -268,9 +268,9 @@ def test_filter_by_index_tracked(self):
 
         # Do the filtering. First remove elements 0 and 2. Then remove elements
         # 0, 5, and 6 from the resulting list (1, 7, 8 in the original list).
-        table.filter_by_index([1, 3, 4, 5, 6, 7, 8, 9], label="filter1")
+        table.filter_rows([1, 3, 4, 5, 6, 7, 8, 9], label="filter1")
         self.assertEqual(len(table), 8)
-        table.filter_by_index([1, 2, 3, 4, 7], label="filter2")
+        table.filter_rows([1, 2, 3, 4, 7], label="filter2")
         self.assertEqual(len(table), 5)
         self.assertEqual(table["x"][0], 3)
         self.assertEqual(table["x"][1], 4)
@@ -303,8 +303,8 @@ def test_filter_by_index_tracked(self):
 
         # Check that we can revert the filtering and add a 'filtered_reason' column.
         table = Results.from_trajectories(self.trj_list[0:10], track_filtered=True)
-        table.filter_by_index([1, 3, 4, 5, 6, 7, 8, 9], label="filter1")
-        table.filter_by_index([1, 2, 3, 4, 7], label="filter2")
+        table.filter_rows([1, 3, 4, 5, 6, 7, 8, 9], label="filter1")
+        table.filter_rows([1, 2, 3, 4, 7], label="filter2")
         table.revert_filter(add_column="reason")
         self.assertEqual(len(table), 10)
         expected_order = [3, 4, 5, 6, 9, 0, 2, 1, 7, 8]

diff --git a/tests/test_sigma_g_filter.py b/tests/test_sigma_g_filter.py
index 1d8b29511..27e3efc67 100644
--- a/tests/test_sigma_g_filter.py
+++ b/tests/test_sigma_g_filter.py
@@ -158,7 +158,7 @@ def test_apply_clipped_sigma_g_results(self):
 
         # Confirm that the ResultRows were modified in place.
         for i in range(num_results):
-            valid = table["index_valid"][i]
+            valid = table["obs_valid"][i]
             for j in range(i):
                 self.assertFalse(valid[j])
             for j in range(i, num_times):
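
For reference, a minimal round-trip sketch of the new API from the first patch, using the same site, epoch, and 50.0 AU heliocentric guess as the tests above. `correct_parallax` now returns both the EBD `SkyCoord` and the best-fit geocentric distance in AU, and feeding both distances to `invert_correct_parallax` should recover the observed position to sub-milliarcsecond precision:

    from astropy.coordinates import EarthLocation, SkyCoord, solar_system_ephemeris
    from astropy.time import Time

    from kbmod.reprojection_utils import correct_parallax, invert_correct_parallax

    with solar_system_ephemeris.set("de432s"):
        loc = EarthLocation.of_site("ctio")
    obstime = Time("2023-03-20T16:00:00", format="isot", scale="utc")
    observed = SkyCoord(ra=88.74513571, dec=23.43426475, unit="deg")

    # correct_parallax now returns the parallax-corrected (EBD) SkyCoord in ICRS
    # *and* the best-fit geocentric distance in AU, instead of a bare GCRS coordinate.
    ebd_coord, geo_dist = correct_parallax(
        coord=observed,
        obstime=obstime,
        point_on_earth=loc,
        heliocentric_distance=50.0,
    )

    # Feeding both distances back in recovers the original observed position.
    recovered = invert_correct_parallax(
        coord=SkyCoord(ra=ebd_coord.ra.degree, dec=ebd_coord.dec.degree, unit="deg"),
        obstime=obstime,
        point_on_earth=loc,
        geocentric_distance=geo_dist,
        heliocentric_distance=50.0,
    )
    assert observed.separation(recovered).arcsecond < 0.001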
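The second patch collapses `Results.filter_mask` and `Results.filter_by_index` into a single `filter_rows` method that accepts either form. A short sketch of the unified call, assuming a populated `Results` table built as in the tests:

    import numpy as np

    from kbmod.results import Results
    from kbmod.trajectory_utils import make_trajectory

    trajectories = [make_trajectory(x=i, y=i, lh=float(i)) for i in range(10)]
    table = Results.from_trajectories(trajectories, track_filtered=True)

    # Keep rows by integer index (the old filter_by_index behavior)...
    table.filter_rows([1, 3, 4, 5, 6, 7, 8, 9], label="filter1")

    # ...or by Boolean mask (the old filter_mask behavior).
    mask = np.array([True, False, True, True, True, False, True, True])
    table.filter_rows(mask, label="filter2")

    # With track_filtered=True the dropped rows are retained per label
    # and can be inspected (or restored via revert_filter).
    dropped = table.get_filtered(label="filter1")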
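Likewise, the `index_valid` column and its helpers are renamed to `obs_valid` throughout. A sketch of the renamed observation-validity workflow, reusing the psi/phi values from the tests above:

    import numpy as np

    from kbmod.results import Results
    from kbmod.trajectory_utils import make_trajectory

    table = Results.from_trajectories([make_trajectory(x=i) for i in range(3)])

    psi = np.array([[1.0, 1.1, 1.2, 1.3] for _ in range(3)])
    phi = np.array([[1.0, 1.0, 0.0, 2.0] for _ in range(3)])
    table.add_psi_phi_data(psi, phi)

    # Marking observations invalid now goes through 'obs_valid'; likelihood,
    # flux, and obs_count are recomputed from the valid entries only.
    obs_valid = np.array(
        [
            [True, True, True, True],
            [True, False, True, True],
            [False, False, False, False],
        ]
    )
    table.update_obs_valid(obs_valid)

    # Per-observation likelihood curves with invalid entries masked to NaN.
    lh = table.compute_likelihood_curves(filter_obs=True, mask_value=np.NAN)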