Skip to content

Commit

Permalink
feat: add option to sigma-clip sub-bands homogeneously
Browse files Browse the repository at this point in the history
This adds several options to a new `sigma_clip()` function, including
``clip_type`` which specifies whether data should be directly clipped
or accumulated over some axis before assessing the threshold. A placeholder
for allowing the expected variance to be passed has also been added, but
not piped through (i.e. you can't specify this on the CLI yet).
  • Loading branch information
steven-murray committed Jan 9, 2024
1 parent d45df27 commit aea94b7
Showing 1 changed file with 167 additions and 9 deletions.
176 changes: 167 additions & 9 deletions hera_cal/lstbin_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from . import utils
import warnings
from pathlib import Path
from .lstbin import sigma_clip, make_lst_grid
from .lstbin import make_lst_grid
from . import abscal
import os
from . import io
Expand All @@ -35,6 +35,7 @@
import yaml
from .types import Antpair
from .datacontainer import DataContainer
from typing import Literal

try:
profile
Expand Down Expand Up @@ -501,6 +502,114 @@ def threshold_flags(
return flags


def sigma_clip(
array: np.ndarray | np.ma.MaskedArray,
expected_variance: np.ndarray | None = None,
threshold: float = 4.0,
min_N: int = 4,
median_axis: int = 0,
threshold_axis: int = 0,
clip_type: Literal['direct', 'mean', 'median'] = 'direct',
flag_bands: list[tuple[int, int]] | None = None
):
"""
One-iteration robust sigma clipping algorithm.
Parameters
----------
array
ndarray of *real* data.
expected_variance
Expected variance of the data over the median_axis.
If None, it is estimated from the data.
threshold
Threshold to cut above, in units of the standard deviation.
min_N
minimum length of array to sigma clip, below which no sigma
clipping is performed. Non-clipped values are *not* flagged.
median_axis
Axis along which to perform the median to determine the zscore of individual
data.
threshold_axis
Axis along which to perform the thresholding, if multiple data are to be
combined before thresholding. This is only applicable if ``clip_type`` is
``mean`` or ``median``. In this case, if for example a 2D array is passed in
and ``threshold_axis=1`` (but no ``flag_bands`` is passed), then the mean of
the absolute zscores is take along the final axis, and the output flags are
applied homogeneously across this axis based on this mean.
clip_type
The type of sigma clipping to perform. If ``direct``, each datum is flagged
individually. If ``mean`` or ``median``, an entire sub-band of the data is
flagged if its mean (absolute) zscore is beyond the threshold.
flag_bands
A list of tuples specifying the start and end indices of the threshold axis
to perform sigma clipping over. If None, the entire threshold axis is used at
once.
Output
------
clip_flags
A boolean array with same shape as input array,
with clipped values set to True.
"""
# ensure array is an array
if not isinstance(array, np.ndarray):
array = np.array(array)

# ensure array passes min_N criterion:
if array.shape[median_axis] < min_N:
return np.zeros_like(array, dtype=bool)

if not isinstance(array, np.ma.MaskedArray):
array = np.ma.MaskedArray(array, mask=np.nonzero(np.isnan(array)))

location = np.expand_dims(np.ma.median(array, axis=median_axis), axis=median_axis)

if expected_variance is None:
scale = np.expand_dims(np.ma.median(np.abs(array - location), axis=median_axis) * 1.482579, axis=median_axis)
else:
scale = np.expand_dims(np.sqrt(expected_variance), axis=median_axis)

if flag_bands is None:
# Use entire threshold axis together
flag_bands = [(0, array.shape[threshold_axis])]

zscore = np.abs(array - location) / scale

clip_flags = np.zeros_like(array, dtype=bool)

for band in flag_bands:
# build the slice index. Don't use np.take with axis= parameter because it
# creates a new array, instead of a view.
mask = [slice(None)] * array.ndim
mask[threshold_axis] = slice(*band)
mask = tuple(mask)

subz = zscore[mask]
subflags = clip_flags[mask]

if clip_type == 'direct':
# In this mode, each datum is flagged individually.
subflags[:] = subz > threshold
elif clip_type == 'mean':
# In this mode, an entire sub-band of the data is flagged if its mean
# (absolute) zscore is beyond the threshold.
mean_abs_dev = np.mean(subz, axis=threshold_axis)
subflags[:] = np.expand_dims(mean_abs_dev > threshold, axis=threshold_axis)
elif clip_type == 'median':
# In this mode, an entire sub-band of the data is flagged if its median
# (absolute) zscore is beyond the threshold.
mean_abs_dev = np.median(subz, axis=threshold_axis)
subflags[:] = np.expand_dims(mean_abs_dev > threshold, axis=threshold_axis)

else:
raise ValueError(
f"clip_type must be 'direct', 'mean' or 'median', got {clip_type}"
)

return clip_flags


@profile
def lst_average(
data: np.ndarray | np.ma.MaskedArray,
Expand All @@ -510,6 +619,9 @@ def lst_average(
sigma_clip_thresh: float | None = None,
sigma_clip_min_N: int = 4,
flag_below_min_N: bool = False,
sigma_clip_subbands: list[int] | None = None,
sigma_clip_type: Literal['direct', 'mean', 'median'] = 'direct',
sigma_clip_expected_variance: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Compute statistics of a set of data over its zeroth axis.
Expand Down Expand Up @@ -541,6 +653,17 @@ def lst_average(
The minimum number of unflagged samples required to perform sigma clipping.
flag_below_min_N
Whether to flag data that has fewer than ``sigma_clip_min_N`` unflagged samples.
sigma_clip_subbands
A list of integers specifying the start and end indices of the frequency axis
to perform sigma clipping over. If None, the entire frequency axis is used at
once.
sigma_clip_type
The type of sigma clipping to perform. If ``direct``, each datum is flagged
individually. If ``mean`` or ``median``, an entire sub-band of the data is
flagged if its mean (absolute) zscore is beyond the threshold.
sigma_clip_expected_variance
Expected variance of the data over the sigma_clip_subbands. If None, it is
estimated from the data.
Returns
-------
Expand All @@ -562,20 +685,25 @@ def lst_average(

# Now do sigma-clipping.
if sigma_clip_thresh is not None:
if inpainted_mode:
if inpainted_mode and sigma_clip_type == 'direct':
warnings.warn(
"Sigma-clipping in in-painted mode is a bad idea, because it creates "
"Direct-mode sigma-clipping in in-painted mode is a bad idea, because it creates "
"non-uniform flags over frequency, which can cause artificial spectral "
"structure. In-painted mode specifically attempts to avoid this."
)

nflags = np.sum(flags)
clip_flags = sigma_clip(
data.real, sigma=sigma_clip_thresh, min_N=sigma_clip_min_N
)
clip_flags |= sigma_clip(
data.imag, sigma=sigma_clip_thresh, min_N=sigma_clip_min_N
)
kw = {
'expected_variance': sigma_clip_expected_variance,
'threshold': sigma_clip_thresh,
'min_N': sigma_clip_min_N,
'clip_type': sigma_clip_type,
'median_axis': 0,
'threshold_axis': 0 if sigma_clip_type == 'direct' else -2,
'flag_bands': list(zip(sigma_clip_subbands[:-1], sigma_clip_subbands[1:])) if sigma_clip_subbands else None,
}
clip_flags = sigma_clip(data.real, **kw)
clip_flags |= sigma_clip(data.imag, **kw)

# Need to restore min_N condition properly here because it's not done properly in sigma_clip
sc_min_N = np.sum(~flags, axis=0) < sigma_clip_min_N
Expand Down Expand Up @@ -1140,6 +1268,9 @@ def lst_bin_files_single_outfile(
golden_lsts: tuple[float] = (),
sigma_clip_thresh: float | None = None,
sigma_clip_min_N: int = 4,
sigma_clip_subbands: list[int] | None = None,
sigma_clip_type: Literal['direct', 'mean', 'median'] = 'direct',
sigma_clip_expected_variance: np.ndarray | None = None,
flag_below_min_N: bool = False,
flag_thresh: float = 0.7,
freq_min: float | None = None,
Expand Down Expand Up @@ -1258,6 +1389,17 @@ def lst_bin_files_single_outfile(
within an LST-bin required to perform sigma clipping. If `flag_below_min_N`
is False, these (antpair,pol,channel) combinations are not flagged by
sigma-clipping (otherwise they are).
sigma_clip_subbands
A list of integers specifying the start and end indices of the frequency axis
to perform sigma clipping over. If None, the entire frequency axis is used at
once.
sigma_clip_type
The type of sigma clipping to perform. If ``direct``, each datum is flagged
individually. If ``mean`` or ``median``, an entire sub-band of the data is
flagged if its mean (absolute) zscore is beyond the threshold.
sigma_clip_expected_variance
Expected variance of the data over the sigma_clip_subbands. If None, it is
estimated from the data.
flag_below_min_N
If True, flag all (antpair, pol,channel) combinations for an LST-bin that
contiain fewer than `flag_below_min_N` unflagged integrations within the bin.
Expand Down Expand Up @@ -1630,6 +1772,9 @@ def lst_bin_files_single_outfile(
sigma_clip_min_N=sigma_clip_min_N,
flag_below_min_N=flag_below_min_N,
get_mad=write_med_mad,
sigma_clip_subbands=sigma_clip_subbands,
sigma_clip_type=sigma_clip_type,
sigma_clip_expected_variance=sigma_clip_expected_variance,
)

write_baseline_slc_to_file(
Expand Down Expand Up @@ -2371,6 +2516,19 @@ def lst_bin_arg_parser():
help="number of unflagged data points over time to require before considering sigma clipping",
default=4,
)
a.add_argument(
"--sigma-clip-subbands",
type=str,
help="Channels at which bands are separated for homogeneous sigma clipping. Separated by commas.",
default=None,
)
a.add_argument(
"--sigma-clip-type",
type='str',
default='direct',
choices=['direct', 'mean', 'median'],
help="How to threshold the absolute zscores for sigma clipping."
)
a.add_argument(
"--flag-below-min-N",
action="store_true",
Expand Down

0 comments on commit aea94b7

Please sign in to comment.