Skip to content

Commit

Permalink
Merge pull request #408 from HERA-Team/recursive_combine_uvpspec
Browse files Browse the repository at this point in the history
Add recursive_combine_uvpspec()
  • Loading branch information
jsdillon authored Oct 28, 2024
2 parents 416f8e8 + 4712b4e commit 2fb76f0
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
49 changes: 49 additions & 0 deletions hera_pspec/tests/test_uvpspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,10 @@ def test_combine_uvpspec(self):
beam=beam)
uvp1 = self._add_optionals(uvp1)

# test single UVPSpec
out = uvpspec.combine_uvpspec([uvp1], verbose=False)
assert id(out) == id(uvp1)

# test concat across pol
uvp2 = copy.deepcopy(uvp1)
uvp2.polpair_array[0] = 1414
Expand Down Expand Up @@ -1030,6 +1034,51 @@ def test_combine_uvpspec_std(self):
out = uvp1 + uvp2 + uvp3
assert out.Npols == 3

def assert_uvpspec_equal(self, uvp1, uvp2):
"""Helper function to compare two UVPSpec objects."""
assert np.all(uvp1.spw_array == uvp2.spw_array)
assert np.all(uvp1.polpair_array == uvp2.polpair_array)
for k in uvp1.data_array:
assert np.allclose(uvp1.data_array[k], uvp2.data_array[k])
assert np.allclose(uvp1.nsample_array[k], uvp2.nsample_array[k])
assert np.allclose(uvp1.integration_array[k], uvp2.integration_array[k])

def test_recursive_combine_uvpspec_single(self):
"""Test recursive_combine_uvpspec with a single UVPSpec object."""
uvps_list = [copy.deepcopy(self.uvp)]
combined_recursive = uvpspec.recursive_combine_uvpspec(uvps_list)
self.assert_uvpspec_equal(combined_recursive, self.uvp)

def test_recursive_combine_uvpspec_pair(self):
"""Test recursive_combine_uvpspec with a pair of UVPSpec objects."""
uvp_copy = copy.deepcopy(self.uvp)
uvp_copy.polpair_array[0] = 1414 # Slight modification for differentiation
uvps_list = [self.uvp, uvp_copy]

combined_recursive = uvpspec.recursive_combine_uvpspec(uvps_list)
combined_standard = uvpspec.combine_uvpspec(uvps_list, merge_history=False, verbose=False)

self.assert_uvpspec_equal(combined_recursive,combined_standard)

def test_recursive_combine_uvpspec_multiple(self):
"""Test recursive_combine_uvpspec with multiple UVPSpec objects."""
uvp1 = copy.deepcopy(self.uvp)
uvp2 = copy.deepcopy(self.uvp)
uvp2.polpair_array[0] = 1414
uvp3 = copy.deepcopy(self.uvp)
uvp3.polpair_array[0] = 1313

uvps_list = [uvp1, uvp2, uvp3]
combined_recursive = uvpspec.recursive_combine_uvpspec(uvps_list)
combined_standard = uvpspec.combine_uvpspec(uvps_list, merge_history=False, verbose=False)

self.assert_uvpspec_equal(combined_recursive, combined_standard)

def test_recursive_combine_uvpspec_empty(self):
"""Test recursive_combine_uvpspec with an empty list."""
with pytest.raises(ValueError, match="Cannot run recursive_combine_uvpspec on length-0 objects."):
uvpspec.recursive_combine_uvpspec([])

def test_conj_blpair_int():
conj_blpair = uvputils._conj_blpair_int(101102103104)
assert conj_blpair == 103104101102
Expand Down
39 changes: 39 additions & 0 deletions hera_pspec/uvpspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2351,6 +2351,10 @@ def combine_uvpspec(uvps, merge_history=True, verbose=True):
u : UVPSpec object
A UVPSpec object with the data of all the inputs combined.
"""
# Check if only one UVPSpec object is given
if (len(uvps) == 1) and issubclass(type(uvps[0]), UVPSpec):
return uvps[0]

# Perform type checks and get concatenation axis
(uvps, concat_ax, new_spws, new_blpts, new_polpairs,
static_meta) = get_uvp_overlap(uvps, just_meta=False, verbose=verbose)
Expand Down Expand Up @@ -2712,6 +2716,41 @@ def combine_uvpspec(uvps, merge_history=True, verbose=True):
return u


def recursive_combine_uvpspec(uvps):
"""
Method for faster combination of UVPSpec objects by combining them recursively.
This is faster than combine_uvpspec for long lists of files---e.g. if you have
one uvpspec object for every unique baseline and hundreds of baselines. Note:
Histories are not merged, so this is the equivalent of running combine_uvpspec
with merge_history=False.
Parameters
----------
uvps : list
A list of UVPSpec objects to combine.
Returns
-------
u : UVPSpec object
A UVPSpec object with the data of all the inputs combined.
"""
if len(uvps) == 0:
raise ValueError('Cannot run recursive_combine_uvpspec on length-0 objects.')
if len(uvps) == 1:
# Base case: only one object left, return it
return uvps[0]
elif len(uvps) == 2:
# Base case: two uvp objects, add them together
return combine_uvpspec(uvps, merge_history=False, verbose=False) # prevents exponential profileration of copied histories
else:
# Recursive case: split the list in half and add each half
midpoint = len(uvps) // 2
left_sum = recursive_combine_uvpspec(uvps[:midpoint])
right_sum = recursive_combine_uvpspec(uvps[midpoint:])
return combine_uvpspec([left_sum, right_sum], merge_history=False, verbose=False)


def get_uvp_overlap(uvps, just_meta=True, verbose=True):
"""
Given a list of UVPSpec objects or a list of paths to UVPSpec objects,
Expand Down

0 comments on commit 2fb76f0

Please sign in to comment.