Skip to content

Commit

Permalink
add ability to set rng seed for fit_barycentric_wcs + add a consisten…
Browse files Browse the repository at this point in the history
…cy unit test (#554)
  • Loading branch information
maxwest-uw authored Apr 9, 2024
1 parent 5adc40a commit 9682119
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 25 deletions.
12 changes: 9 additions & 3 deletions src/kbmod/reprojection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def correct_parallax(coord, obstime, point_on_earth, guess_distance):
return answer


def fit_barycentric_wcs(original_wcs, width, height, distance, obstime, point_on_earth, npoints=10):
def fit_barycentric_wcs(
original_wcs, width, height, distance, obstime, point_on_earth, npoints=10, seed=None
):
"""Given a ICRS WCS and an object's distance from the Sun,
return a new WCS that has been corrected for parallax motion.
Expand All @@ -81,17 +83,21 @@ def fit_barycentric_wcs(original_wcs, width, height, distance, obstime, point_on
Typically, the more points the higher the accuracy. The four corners
of the image will always be included, so setting npoints = 0 will mean
just using the corners.
seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator}
the seed that `numpy.random.default_rng` will use.
Returns
----------
An `astropy.wcs.WCS` representing the original image in "Explicity Barycentric Distance" (EBD)
space, i.e. where the points have been corrected for parallax.
"""
rng = np.random.default_rng(seed)

sampled_x_points = np.array([0, 0, width, width])
sampled_y_points = np.array([0, height, height, 0])
if npoints > 0:
sampled_x_points = np.append(sampled_x_points, np.random.rand(npoints) * width)
sampled_y_points = np.append(sampled_y_points, np.random.rand(npoints) * height)
sampled_x_points = np.append(sampled_x_points, rng.random(npoints) * width)
sampled_y_points = np.append(sampled_y_points, rng.random(npoints) * height)

sampled_ra, sampled_dec = original_wcs.all_pix2world(sampled_x_points, sampled_y_points, 0)

Expand Down
63 changes: 41 additions & 22 deletions tests/test_reprojection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@


class test_reprojection_utils(unittest.TestCase):
def setUp(self):
self.nx = 2046
self.ny = 4094
self.test_wcs = WCS(naxis=2)
self.test_wcs.pixel_shape = (self.ny, self.nx)
self.test_wcs.wcs.crpix = [self.nx / 2, self.ny / 2]
self.test_wcs.wcs.cdelt = np.array([-0.000055555555556, 0.000055555555556])
self.test_wcs.wcs.crval = [346.9681342111, -6.482196848597]
self.test_wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

self.time = "2021-08-24T20:59:06"
self.site = "ctio"
self.loc = EarthLocation.of_site(self.site)
self.distance = 41.1592725489203

def test_parallax_equinox(self):
icrs_ra1 = 88.74513571
icrs_dec1 = 23.43426475
Expand Down Expand Up @@ -49,17 +64,7 @@ def test_parallax_equinox(self):
npt.assert_almost_equal(corrected_coord2.dec.value, expected_dec)

def test_fit_barycentric_wcs(self):
nx = 2046
ny = 4094
test_wcs = WCS(naxis=2)
test_wcs.pixel_shape = (ny, nx)
test_wcs.wcs.crpix = [nx / 2, ny / 2]
test_wcs.wcs.cdelt = np.array([-0.000055555555556, 0.000055555555556])
test_wcs.wcs.crval = [346.9681342111, -6.482196848597]
test_wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]

x_points = np.array([247, 1252, 1052, 980, 420, 1954, 730, 1409, 1491, 803])

y_points = np.array([1530, 713, 3414, 3955, 1975, 123, 1456, 2008, 1413, 1756])

expected_ra = np.array(
Expand Down Expand Up @@ -94,18 +99,13 @@ def test_fit_barycentric_wcs(self):

expected_sc = SkyCoord(ra=expected_ra, dec=expected_dec, unit="deg")

time = "2021-08-24T20:59:06"
site = "ctio"
loc = EarthLocation.of_site(site)
distance = 41.1592725489203

corrected_wcs = fit_barycentric_wcs(
test_wcs,
nx,
ny,
distance,
time,
loc,
self.test_wcs,
self.nx,
self.ny,
self.distance,
self.time,
self.loc,
)

corrected_ra, corrected_dec = corrected_wcs.all_pix2world(x_points, y_points, 0)
Expand All @@ -114,4 +114,23 @@ def test_fit_barycentric_wcs(self):

# assert we have sub-milliarcsecond precision
assert np.all(seps < 0.001)
assert corrected_wcs.array_shape == (ny, nx)
assert corrected_wcs.array_shape == (self.ny, self.nx)

def test_fit_barycentric_wcs_consistency(self):
corrected_wcs = fit_barycentric_wcs(
self.test_wcs, self.nx, self.ny, self.distance, self.time, self.loc, seed=24601
)

# crval consistency
npt.assert_almost_equal(corrected_wcs.wcs.crval[0], 346.6498731934591)
npt.assert_almost_equal(corrected_wcs.wcs.crval[1], -6.593449653602658)

# crpix consistency
npt.assert_almost_equal(corrected_wcs.wcs.crpix[0], 1024.4630013095195)
npt.assert_almost_equal(corrected_wcs.wcs.crpix[1], 2047.9912979360922)

# cd consistency
npt.assert_almost_equal(corrected_wcs.wcs.cd[0][0], -5.424296904025753e-05)
npt.assert_almost_equal(corrected_wcs.wcs.cd[0][1], 3.459611876675614e-08)
npt.assert_almost_equal(corrected_wcs.wcs.cd[1][0], 3.401472764249802e-08)
npt.assert_almost_equal(corrected_wcs.wcs.cd[1][1], 5.4242245855217796e-05)

0 comments on commit 9682119

Please sign in to comment.