Skip to content

Commit

Permalink
Merge pull request #374 from keflavich/fix_beam_masking
Browse files Browse the repository at this point in the history
Beam masking bugfixes
  • Loading branch information
keflavich authored Mar 21, 2017
2 parents 8630fb7 + 2abbc65 commit dfa99ed
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
45 changes: 32 additions & 13 deletions spectral_cube/spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2895,8 +2895,8 @@ def identify_bad_beams(self, threshold, reference_beam=None,
A beam to use as the reference. If unspecified, ``mid_value`` will
be used to select a middle beam
criteria : list
A list of criteria to compare. Can include 'sr','major','minor' or
any subset of those.
A list of criteria to compare. Can include
'sr','major','minor','pa' or any subset of those.
mid_value : function
The function used to determine the 'mid' value to compare to. This
will identify the middle-valued beam area.
Expand All @@ -2906,15 +2906,26 @@ def identify_bad_beams(self, threshold, reference_beam=None,
includemask : np.array
A boolean array where ``True`` indicates the good beams
"""
from radio_beam import Beam

includemask = np.ones(len(self.beams), dtype='bool')

sr = u.Quantity([getattr(beam, 'sr') for beam in self.beams])
all_criteria = ['sr','major','minor','pa']
if not set.issubset(set(criteria), set(all_criteria)):
raise ValueError("Criteria must be one of the allowed options: "
"{0}".format(all_criteria))

props = {prop: u.Quantity([getattr(beam, prop) for beam in self.beams])
for prop in all_criteria}

if reference_beam is None:
reference_beam = mid_value(sr)
reference_beam = Beam(major=mid_value(props['major']),
minor=mid_value(props['minor']),
pa=mid_value(props['pa'])
)

for prop in criteria:
val = u.Quantity([getattr(beam, prop) for beam in self.beams])
val = props[prop]
mid = getattr(reference_beam, prop)

diff = np.abs((val-mid)/mid)
Expand Down Expand Up @@ -2950,7 +2961,7 @@ def mask_out_bad_beams(self, threshold, reference_beam=None,
beam_threshold=threshold)


def average_beams(self, threshold, mask='compute'):
def average_beams(self, threshold, mask='compute', warn=False):
"""
Average the beams. Note that this operation only makes sense in
limited contexts! Generally one would want to convolve all the beams
Expand All @@ -2966,6 +2977,13 @@ def average_beams(self, threshold, mask='compute'):
mask : 'compute', None, or boolean array
The mask to apply to the beams. Useful for excluding bad channels
and edge beams.
warn : bool
Warn if successful?
Returns
-------
new_beam : radio_beam.Beam
A new radio beam object that is the average of the unmasked beams
"""
if mask == 'compute':
beam_mask = np.any(self.mask.include(), axis=(1,2))
Expand All @@ -2975,12 +2993,13 @@ def average_beams(self, threshold, mask='compute'):
new_beam = cube_utils.average_beams(self.beams, includemask=beam_mask)
assert not np.isnan(new_beam)
self._check_beam_areas(threshold, mean_beam=new_beam, mask=beam_mask)
warnings.warn("Arithmetic beam averaging is being performed. This is "
"not a mathematically robust operation, but is being "
"permitted because the beams differ by "
"<{0}".format(threshold),
BeamAverageWarning
)
if warn:
warnings.warn("Arithmetic beam averaging is being performed. This is "
"not a mathematically robust operation, but is being "
"permitted because the beams differ by "
"<{0}".format(threshold),
BeamAverageWarning
)
return new_beam


Expand Down Expand Up @@ -3018,7 +3037,7 @@ def newfunc(*args, **kwargs):
function.__name__))

if need_to_handle_beams:
avg_beam = self.average_beams(beam_threshold)
avg_beam = self.average_beams(beam_threshold, warn=True)
result.meta['beam'] = avg_beam

return result
Expand Down
6 changes: 6 additions & 0 deletions spectral_cube/tests/test_spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,3 +1291,9 @@ def test_mask_bad_beams():

mean = masked_cube.mean(axis=0)
assert np.all(mean == cube[2,:,:])


masked_cube2 = cube.mask_out_bad_beams(0.5,)

mean2 = masked_cube2.mean(axis=0)
assert np.all(mean2 == (cube[2,:,:]+cube[1,:,:])/2)

0 comments on commit dfa99ed

Please sign in to comment.