Skip to content

Commit

Permalink
Merge pull request #397 from HERA-Team/future-array-shapes-bug
Browse files Browse the repository at this point in the history
Added tests re. future array shapes and ensured compatibility in utils
  • Loading branch information
JianrongTan authored Mar 17, 2024
2 parents a5f458f + 01857a8 commit cd12a55
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 136 deletions.
308 changes: 173 additions & 135 deletions hera_pspec/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import os, sys, copy
from hera_pspec.data import DATA_PATH
from .. import utils, testing
from .. import utils, testing, pspecbeam
from collections import OrderedDict as odict
from pyuvdata import UVData
from hera_cal import redcal
Expand Down Expand Up @@ -35,6 +35,7 @@ def test_cov():
w1 *= -1.0
pytest.raises(ValueError, utils.cov, d1, w1)


def test_load_config():
"""
Check YAML config file handling.
Expand Down Expand Up @@ -71,9 +72,16 @@ def setUp(self):
self.uvd.read_miriad(os.path.join(DATA_PATH,
"zen.2458042.17772.xx.HH.uvXA"),
use_future_array_shapes=True)

# without future array shapes
self.uvd2 = UVData()
self.uvd2.read_miriad(os.path.join(DATA_PATH,
"zen.2458042.17772.xx.HH.uvXA"),
use_future_array_shapes=False)
# Load PSpecBeam object
beamfile = os.path.join(DATA_PATH, 'HERA_NF_dipole_power.beamfits')
self.beam = pspecbeam.PSpecBeamUV(beamfile)
# Create UVPSpec object
self.uvp, cosmo = testing.build_vanilla_uvpspec()
self.uvp, _ = testing.build_vanilla_uvpspec()

def tearDown(self):
pass
Expand All @@ -89,7 +97,7 @@ def test_spw_range_from_freqs(self):
# Check that type errors and bounds errors are raised
pytest.raises(AttributeError, utils.spw_range_from_freqs, np.arange(3),
freq_range=(100e6, 110e6))
for obj in [self.uvd, self.uvp]:
for obj in [self.uvd, self.uvd2, self.uvp]:
pytest.raises(ValueError, utils.spw_range_from_freqs, obj,
freq_range=(98e6, 110e6)) # lower bound
pytest.raises(ValueError, utils.spw_range_from_freqs, obj,
Expand All @@ -98,20 +106,22 @@ def test_spw_range_from_freqs(self):
freq_range=(190e6, 180e6)) # wrong order

# Check that valid frequency ranges are returned
# with and without future array shapes
freq_list = [(100e6, 120e6), (120e6, 140e6), (140e6, 160e6)]
spw1 = utils.spw_range_from_freqs(self.uvd, freq_range=(110e6, 130e6))
spw2 = utils.spw_range_from_freqs(self.uvd, freq_range=freq_list)
spw3 = utils.spw_range_from_freqs(self.uvd, freq_range=(98e6, 120e6),
bounds_error=False)
spw4 = utils.spw_range_from_freqs(self.uvd, freq_range=(100e6, 120e6))
for obj in [self.uvd, self.uvd2]:
spw1 = utils.spw_range_from_freqs(obj, freq_range=(110e6, 130e6))
spw2 = utils.spw_range_from_freqs(obj, freq_range=freq_list)
spw3 = utils.spw_range_from_freqs(obj, freq_range=(98e6, 120e6),
bounds_error=False)
spw4 = utils.spw_range_from_freqs(obj, freq_range=(100e6, 120e6))

# Make sure tuple vs. list arguments were handled correctly
assert( isinstance(spw1, tuple) )
assert( isinstance(spw2, list) )
assert( len(spw2) == len(freq_list) )
# Make sure tuple vs. list arguments were handled correctly
assert( isinstance(spw1, tuple) )
assert( isinstance(spw2, list) )
assert( len(spw2) == len(freq_list) )

# Make sure that bounds_error=False works
assert( spw3 == spw4 )
# Make sure that bounds_error=False works
assert( spw3 == spw4 )

# Make sure that this also works for UVPSpec objects
spw5 = utils.spw_range_from_freqs(self.uvp, freq_range=(100e6, 104e6))
Expand All @@ -126,7 +136,7 @@ def test_spw_range_from_redshifts(self):
# Check that type errors and bounds errors are raised
pytest.raises(AttributeError, utils.spw_range_from_redshifts,
np.arange(3), z_range=(9.7, 12.1))
for obj in [self.uvd, self.uvp]:
for obj in [self.uvd, self.uvd2, self.uvp]:
pytest.raises(ValueError, utils.spw_range_from_redshifts, obj,
z_range=(5., 8.)) # lower bound
pytest.raises(ValueError, utils.spw_range_from_redshifts, obj,
Expand All @@ -135,17 +145,19 @@ def test_spw_range_from_redshifts(self):
z_range=(11., 10.)) # wrong order

# Check that valid frequency ranges are returned
# with and without future array shapes
z_list = [(6.5, 7.5), (7.5, 8.5), (8.5, 9.5)]
spw1 = utils.spw_range_from_redshifts(self.uvd, z_range=(7., 8.))
spw2 = utils.spw_range_from_redshifts(self.uvd, z_range=z_list)
spw3 = utils.spw_range_from_redshifts(self.uvd, z_range=(12., 14.),
bounds_error=False)
spw4 = utils.spw_range_from_redshifts(self.uvd, z_range=(6.2, 7.2))

# Make sure tuple vs. list arguments were handled correctly
assert( isinstance(spw1, tuple) )
assert( isinstance(spw2, list) )
assert( len(spw2) == len(z_list) )
for obj in [self.uvd, self.uvd2]:
spw1 = utils.spw_range_from_redshifts(obj, z_range=(7., 8.))
spw2 = utils.spw_range_from_redshifts(obj, z_range=z_list)
spw3 = utils.spw_range_from_redshifts(obj, z_range=(12., 14.),
bounds_error=False)
spw4 = utils.spw_range_from_redshifts(obj, z_range=(6.2, 7.2))

# Make sure tuple vs. list arguments were handled correctly
assert( isinstance(spw1, tuple) )
assert( isinstance(spw2, list) )
assert( len(spw2) == len(z_list) )

# Make sure that this also works for UVPSpec objects
spw5 = utils.spw_range_from_redshifts(self.uvp, z_range=(13.1, 13.2))
Expand All @@ -154,124 +166,127 @@ def test_spw_range_from_redshifts(self):

def test_calc_blpair_reds(self):
fname = os.path.join(DATA_PATH, 'zen.all.xx.LST.1.06964.uvA')
uvd = UVData()
uvd.read_miriad(fname, use_future_array_shapes=True)

# basic execution
(bls1, bls2, blps, xants1, xants2, rgrps, lens,
angs) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, extra_info=True,
exclude_auto_bls=False, exclude_permutations=True)
assert len(bls1) == len(bls2) == 15
assert blps == list(zip(bls1, bls2))
assert xants1 == xants2
assert len(xants1) == 42
assert len(rgrps) == len(bls1) # assert rgrps matches bls1 shape
assert np.max(rgrps) == len(lens) - 1 # assert rgrps indexes lens / angs

# test xant_flag_thresh
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=True, exclude_permutations=True,
xant_flag_thresh=0.0)
assert len(bls1) == len(bls2) == 0

# test bl_len_range
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=False, exclude_permutations=True,
bl_len_range=(0, 15.0))
assert len(bls1) == len(bls2) == 12
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=True, exclude_permutations=True,
bl_len_range=(0, 15.0))
assert len(bls1) == len(bls2) == 5
assert np.all([bls1[i] != bls2[i] for i in range(len(blps))])

# test grouping
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=False, exclude_permutations=True,
Nblps_per_group=2)
assert len(blps) == 10
assert isinstance(blps[0], list)
assert blps[0] == [((24, 37), (25, 38)), ((24, 37), (24, 37))]

# test baseline select on uvd
uvd2 = copy.deepcopy(uvd)
uvd2.select(bls=[(24, 25), (37, 38), (24, 39)])
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd2, uvd2, filter_blpairs=True, exclude_auto_bls=True, exclude_permutations=True,
bl_len_range=(10.0, 20.0))
assert blps == [((24, 25), (37, 38))]

# test exclude_cross_bls
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_cross_bls=True)
for bl1, bl2 in blps:
assert bl1 == bl2

# test exceptions
uvd2 = copy.deepcopy(uvd)
uvd2.antenna_positions[0] += 2
pytest.raises(AssertionError, utils.calc_blpair_reds, uvd, uvd2)
pytest.raises(AssertionError, utils.calc_blpair_reds, uvd, uvd, exclude_auto_bls=True, exclude_cross_bls=True)

for future_array_shapes in [True, False]:
uvd = UVData()
uvd.read_miriad(fname, use_future_array_shapes=future_array_shapes)

# basic execution
(bls1, bls2, blps, xants1, xants2, rgrps, lens,
angs) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, extra_info=True,
exclude_auto_bls=False, exclude_permutations=True)
assert len(bls1) == len(bls2) == 15
assert blps == list(zip(bls1, bls2))
assert xants1 == xants2
assert len(xants1) == 42
assert len(rgrps) == len(bls1) # assert rgrps matches bls1 shape
assert np.max(rgrps) == len(lens) - 1 # assert rgrps indexes lens / angs

# test xant_flag_thresh
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=True, exclude_permutations=True,
xant_flag_thresh=0.0)
assert len(bls1) == len(bls2) == 0

# test bl_len_range
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=False, exclude_permutations=True,
bl_len_range=(0, 15.0))
assert len(bls1) == len(bls2) == 12
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=True, exclude_permutations=True,
bl_len_range=(0, 15.0))
assert len(bls1) == len(bls2) == 5
assert np.all([bls1[i] != bls2[i] for i in range(len(blps))])

# test grouping
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_auto_bls=False, exclude_permutations=True,
Nblps_per_group=2)
assert len(blps) == 10
assert isinstance(blps[0], list)
assert blps[0] == [((24, 37), (25, 38)), ((24, 37), (24, 37))]

# test baseline select on uvd
uvd2 = copy.deepcopy(uvd)
uvd2.select(bls=[(24, 25), (37, 38), (24, 39)])
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd2, uvd2, filter_blpairs=True, exclude_auto_bls=True, exclude_permutations=True,
bl_len_range=(10.0, 20.0))
assert blps == [((24, 25), (37, 38))]

# test exclude_cross_bls
(bls1, bls2, blps, xants1,
xants2) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, exclude_cross_bls=True)
for bl1, bl2 in blps:
assert bl1 == bl2

# test exceptions
uvd2 = copy.deepcopy(uvd)
uvd2.antenna_positions[0] += 2
pytest.raises(AssertionError, utils.calc_blpair_reds, uvd, uvd2)
pytest.raises(AssertionError, utils.calc_blpair_reds, uvd, uvd, exclude_auto_bls=True, exclude_cross_bls=True)

def test_calc_blpair_reds_autos_only(self):
# test include_crosscorrs selection option being set to false.
fname = os.path.join(DATA_PATH, 'zen.all.xx.LST.1.06964.uvA')
uvd = UVData()
uvd.read_miriad(fname, use_future_array_shapes=True)
# basic execution
(bls1, bls2, blps, xants1, xants2, rgrps, lens,
angs) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, extra_info=True,
exclude_auto_bls=False, exclude_permutations=True, include_crosscorrs=False,
include_autocorrs=True)
assert len(bls1) > 0
for bl1, bl2 in zip(bls1, bls2):
assert bl1[0] == bl1[1]
assert bl2[0] == bl2[1]

for future_array_shapes in [True, False]:
uvd = UVData()
uvd.read_miriad(fname, use_future_array_shapes=future_array_shapes)
# basic execution
(bls1, bls2, blps, xants1, xants2, rgrps, lens,
angs) = utils.calc_blpair_reds(uvd, uvd, filter_blpairs=True, extra_info=True,
exclude_auto_bls=False, exclude_permutations=True, include_crosscorrs=False,
include_autocorrs=True)
assert len(bls1) > 0
for bl1, bl2 in zip(bls1, bls2):
assert bl1[0] == bl1[1]
assert bl2[0] == bl2[1]

def test_get_delays(self):
utils.get_delays(np.linspace(100., 200., 50)*1e6)

def test_get_reds(self):
fname = os.path.join(DATA_PATH, 'zen.all.xx.LST.1.06964.uvA')
uvd = UVData()
uvd.read_miriad(fname, read_data=False, use_future_array_shapes=True)
antpos, ants = uvd.get_ENU_antpos()
antpos_d = dict(list(zip(ants, antpos)))

# test basic execution
xants = [0, 1, 2]
r, l, a = utils.get_reds(fname, xants=xants)
assert np.all([np.all([bl[0] not in xants and bl[1] not in xants for bl in _r]) for _r in r])
assert len(r) == len(a) == len(l)
assert len(r) == 104

r2, l2, a2 = utils.get_reds(uvd, xants=xants)
_ = [np.testing.assert_array_equal(_r1, _r2) for _r1, _r2 in zip(r, r2)]

r2, l2, a2 = utils.get_reds(antpos_d, xants=xants)
_ = [np.testing.assert_array_equal(_r1, _r2) for _r1, _r2 in zip(r, r2)]

# restrict
bl_len_range = (14, 16)
bl_deg_range = (55, 65)
r, l, a = utils.get_reds(uvd, bl_len_range=bl_len_range, bl_deg_range=bl_deg_range)
assert (np.all([_l > bl_len_range[0] and _l < bl_len_range[1] for _l in l]))
assert (np.all([_a > bl_deg_range[0] and _a < bl_deg_range[1] for _a in a]))

# min EW cut
r, l, a = utils.get_reds(uvd, bl_len_range=(14, 16), min_EW_cut=14)
assert len(l) == len(a) == 1
assert np.isclose(a[0] % 180, 0, atol=1)

# autos
r, l, a = utils.get_reds(fname, xants=xants, add_autos=True)
np.testing.assert_almost_equal(l[0], 0)
np.testing.assert_almost_equal(a[0], 0)
assert len(r) == 105

# Check errors when wrong types input
pytest.raises(TypeError, utils.get_reds, [1., 2.])
for future_array_shapes in [True, False]:
uvd = UVData()
uvd.read_miriad(fname, read_data=False, use_future_array_shapes=future_array_shapes)
antpos, ants = uvd.get_ENU_antpos()
antpos_d = dict(list(zip(ants, antpos)))

# test basic execution
xants = [0, 1, 2]
r, l, a = utils.get_reds(fname, xants=xants)
assert np.all([np.all([bl[0] not in xants and bl[1] not in xants for bl in _r]) for _r in r])
assert len(r) == len(a) == len(l)
assert len(r) == 104

r2, l2, a2 = utils.get_reds(uvd, xants=xants)
_ = [np.testing.assert_array_equal(_r1, _r2) for _r1, _r2 in zip(r, r2)]

r2, l2, a2 = utils.get_reds(antpos_d, xants=xants)
_ = [np.testing.assert_array_equal(_r1, _r2) for _r1, _r2 in zip(r, r2)]

# restrict
bl_len_range = (14, 16)
bl_deg_range = (55, 65)
r, l, a = utils.get_reds(uvd, bl_len_range=bl_len_range, bl_deg_range=bl_deg_range)
assert (np.all([_l > bl_len_range[0] and _l < bl_len_range[1] for _l in l]))
assert (np.all([_a > bl_deg_range[0] and _a < bl_deg_range[1] for _a in a]))

# min EW cut
r, l, a = utils.get_reds(uvd, bl_len_range=(14, 16), min_EW_cut=14)
assert len(l) == len(a) == 1
assert np.isclose(a[0] % 180, 0, atol=1)

# autos
r, l, a = utils.get_reds(fname, xants=xants, add_autos=True)
np.testing.assert_almost_equal(l[0], 0)
np.testing.assert_almost_equal(a[0], 0)
assert len(r) == 105

# Check errors when wrong types input
pytest.raises(TypeError, utils.get_reds, [1., 2.])

def test_get_reds_autos_only(self):
fname = os.path.join(DATA_PATH, 'zen.all.xx.LST.1.06964.uvA')
Expand Down Expand Up @@ -313,7 +328,30 @@ def test_config_pspec_blpairs(self):
# test exceptions
pytest.raises(AssertionError, utils.config_pspec_blpairs, uv_template, [('xx', 'xx'), ('xx', 'xx')], [('even', 'odd')], verbose=False)


def test_uvd_to_Tsys(self):

# PROPER USAGE
# check different ways to call method work and are equivalent
# with or without future array shapes
for obj in [self.uvd, self.uvd2]:
tsys_estimate = utils.uvd_to_Tsys(obj, self.beam)
tsys_estimate2 = utils.uvd_to_Tsys(
obj,
os.path.join(DATA_PATH, 'HERA_NF_dipole_power.beamfits')
)
assert np.allclose(tsys_estimate.data_array, tsys_estimate2.data_array)
uvp2, _ = testing.build_vanilla_uvpspec(beam=self.beam)
tsys_estimate3 = utils.uvd_to_Tsys(obj, uvp2)
assert np.allclose(tsys_estimate.data_array, tsys_estimate3.data_array)

# CHECK ERROR CALLS
# uvp called for beam has no beam information
pytest.raises(ValueError, utils.uvd_to_Tsys, obj, self.uvp)
# beam has wrong format
pytest.raises(ValueError, utils.uvd_to_Tsys, obj, 12.)



def test_log():
"""
Test that log() prints output.
Expand Down
2 changes: 1 addition & 1 deletion hera_pspec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,7 @@ def uvd_to_Tsys(uvd, beam, Tsys_outfile=None):
if pol.upper() in STOKPOLS:
pol = 'pI'
pind = pols.index(pol)
uvd.data_array[tinds, :, pind] *= J2K[pol]
uvd.data_array[tinds, ..., pind] *= J2K[pol]

if Tsys_outfile is not None:
uvd.write_uvh5(Tsys_outfile, clobber=True)
Expand Down

0 comments on commit cd12a55

Please sign in to comment.