Skip to content

Commit

Permalink
fix: blt slicing for new uvdata 3.0 functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-murray committed Jan 31, 2024
1 parent bf8e867 commit 54c41e5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 22 deletions.
30 changes: 18 additions & 12 deletions hera_cal/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,11 @@ def get_blt_slices(uvo, tried_to_reorder=False):
Returns:
        blt_slices: dictionary mapping antenna pair tuples to baseline-time slice objects
'''
if hasattr(uvo, 'blts_are_rectangular') and uvo.blts_are_rectangular is None:
uvo.set_rectangularity()
if uvo.blts_are_rectangular is None:
uvo.set_rectangularity(force=True)

blt_slices = {}
if getattr(uvo, 'blts_are_rectangular', False):
if uvo.blts_are_rectangular:
if uvo.time_axis_faster_than_bls:
for i in range(uvo.Nbls):
start = i * uvo.Ntimes
Expand All @@ -474,19 +474,22 @@ def get_blt_slices(uvo, tried_to_reorder=False):
else:
for ant1, ant2 in uvo.get_antpairs():
indices = uvo.antpair2ind(ant1, ant2)
if len(indices) == 1: # only one blt matches
if isinstance(indices, slice):
blt_slices[(ant1, ant2)] = indices
elif indices is None:
raise ValueError(f"Antpair ({ant1}, {ant2}) does not exist in the data.")
elif len(indices) == 1: # only one blt matches
blt_slices[(ant1, ant2)] = slice(indices[0], indices[0] + 1, uvo.Nblts)
elif not (len(set(np.ediff1d(indices))) == 1): # checks if the consecutive differences are all the same
elif len(set(np.ediff1d(indices))) != 1: # checks if the consecutive differences are all the same
if not tried_to_reorder:
uvo.reorder_blts(order='time')
return get_blt_slices(uvo, tried_to_reorder=True)
else:
raise NotImplementedError(
'UVData objects with non-regular spacing of '
'baselines in its baseline-times are not supported.'
f'Got indices {indices} for baseline {ant1}, {ant2}.'
)
else:
blt_slices[(ant1, ant2)] = slice(indices[0], indices[-1] + 1, indices[1] - indices[0])
return blt_slices


Expand Down Expand Up @@ -626,6 +629,7 @@ def get_metadata_dict(self):
times_by_bl = {antpair: np.array(self.time_array[self._blt_slices[antpair]])
for antpair in antpairs}
times_by_bl.update({(ant1, ant0): times_here for (ant0, ant1), times_here in times_by_bl.items()})

lsts_by_bl = {antpair: np.array(self.lst_array[self._blt_slices[antpair]])
for antpair in antpairs}
lsts_by_bl.update({(ant1, ant0): lsts_here for (ant0, ant1), lsts_here in lsts_by_bl.items()})
Expand All @@ -635,6 +639,7 @@ def get_metadata_dict(self):

def _determine_blt_slicing(self):
'''Determine the mapping between antenna pairs and slices of the blt axis of the data_array.'''

self._blt_slices = get_blt_slices(self)

def get_polstr_index(self, pol: str) -> int:
Expand Down Expand Up @@ -853,9 +858,10 @@ def read(self, bls=None, polarizations=None, times=None, time_range=None, lsts=N
self.set_rectangularity(force=True)

# process data into DataContainers
if read_data or self.filetype in ['uvh5', 'uvfits']:
self._determine_blt_slicing()
self._determine_pol_indexing()
self._clear_antpair2ind_cache(self) # required because we over-wrote read()
self._determine_blt_slicing()
self._determine_pol_indexing()

if read_data and return_data:
return self.build_datacontainers()

Expand Down Expand Up @@ -1443,8 +1449,8 @@ def __init__(self, input_data, read_metadata=True, check=False, skip_lsts=False)
def _adapt_metadata(self, info_dict, skip_lsts=False):
'''Updates metadata from read_hera_hdf5 to better match HERAData. Updates info_dict in place.'''
info_dict['data_ants'] = sorted(info_dict['data_ants'])
info_dict['antpairs'] = sorted(info_dict['bls'])
info_dict['bls'] = sorted(set([ap + (pol, ) for ap in info_dict['antpairs'] for pol in info_dict['pols']]))
info_dict['antpairs'] = info_dict['bls']
info_dict['bls'] = list(set([ap + (pol, ) for ap in info_dict['antpairs'] for pol in info_dict['pols']]))
XYZ = XYZ_from_LatLonAlt(info_dict['latitude'] * np.pi / 180, info_dict['longitude'] * np.pi / 180, info_dict['altitude'])
enu_antpos = ENU_from_ECEF(
np.array([antpos for ant, antpos in info_dict['antpos'].items()]) + XYZ,
Expand Down
13 changes: 7 additions & 6 deletions hera_cal/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def compare_datacontainers(self, dc1, dc2, allow_close: bool = False):
np.testing.assert_array_equal(dc1.times_by_bl[ap], dc2.times_by_bl[ap])
for ap in dc1.lsts_by_bl:
np.testing.assert_allclose(dc1.lsts_by_bl[ap], dc2.lsts_by_bl[ap])

@pytest.mark.parametrize(
'infile', (['uvh5_1'], ['uvh5_1', 'uvh5_2'], 'uvh5_h4c')
)
Expand Down Expand Up @@ -952,7 +952,7 @@ def test_comp_to_HERAData_dc(self, infile, bls, pols):
d2, f2, n2 = hd2.read(bls=bls, polarizations=pols)
# compare all data and metadata
for dc1, dc2 in zip([d, f, n], [d2, f2, n2]):
self.compare_datacontainers(dc1, dc2, allow_close=infile != 'uvh5_h4c')
self.compare_datacontainers(dc1, dc2, allow_close=infile != 'uvh5_h4c')

@pytest.mark.parametrize(
'infile', (['uvh5_1'], 'uvh5_h4c')
Expand All @@ -975,8 +975,9 @@ def test_comp_to_HERAData_hd(self, infile):
np.testing.assert_array_equal(hd1.ants, hd2.ants)
np.testing.assert_array_equal(hd1.data_ants, hd2.data_ants)
np.testing.assert_array_equal(hd1.pols, hd2.pols)
np.testing.assert_array_equal(hd1.antpairs, hd2.antpairs)
np.testing.assert_array_equal(hd1.bls, hd2.bls)
assert set(hd1.antpairs) == set(hd2.antpairs)
assert set(hd1.bls) == set(hd2.bls)

for ant in hd1.antpos:
np.testing.assert_array_almost_equal(hd1.antpos[ant] - hd2.antpos[ant], 0)
for ant in hd1.data_antpos:
Expand Down Expand Up @@ -1788,7 +1789,7 @@ def test_default(self):
np.testing.assert_equal(self.uvd_default.ant_2_array, uvd.ant_2_array)
np.testing.assert_equal(self.uvd_default.time_array, uvd.time_array)
np.testing.assert_equal(self.uvd_default.lst_array, uvd.lst_array)

def test_blfirst(self):
uvd = io.uvdata_from_fastuvh5(self.meta_blfirst)

Expand All @@ -1798,7 +1799,7 @@ def test_blfirst(self):
np.testing.assert_equal(self.uvd_blfirst.ant_2_array, uvd.ant_2_array)
np.testing.assert_equal(self.uvd_blfirst.time_array, uvd.time_array)
np.testing.assert_equal(self.uvd_blfirst.lst_array, uvd.lst_array)

def test_lsts_without_start_jd(self):
with pytest.raises(AttributeError, match='if times is not given, start_jd must be given'):
io.uvdata_from_fastuvh5(self.meta_default, times=None, start_jd=None, lsts=np.array([0.1, 0.2]))
Expand Down
6 changes: 2 additions & 4 deletions hera_cal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,9 +1410,7 @@ def red_average(data, reds=None, bl_tol=1.0, inplace=False,
n = np.asarray([data.get_nsamples(bl + (pol,)) for bl in blg])
tint = []
for bl in blg:
blinds = data.antpair2ind(bl + (pol,))
if len(blinds) == 0:
blinds = data.antpair2ind(reverse_bl(bl + (pol,)))
blinds = data.antpair2ind(bl + (pol,), ordered=False)
tint.append(data.integration_time[blinds])
tint = np.asarray(tint)[:, :, None]
w = np.asarray([wgts[bl + (pol,)] for bl in blg]) * tint
Expand Down Expand Up @@ -1445,7 +1443,7 @@ def red_average(data, reds=None, bl_tol=1.0, inplace=False,
else:
blinds = data.antpair2ind(blk)
polind = pols.index(pol)
if len(blinds) == 0:
if blinds is None:
blinds = data.antpair2ind(reverse_bl(blk))
davg = np.conj(davg)
data.data_array[blinds, :, polind] = davg
Expand Down

0 comments on commit 54c41e5

Please sign in to comment.