diff --git a/sotodlib/preprocess/processes.py b/sotodlib/preprocess/processes.py index a4b738369..0d869f3d6 100644 --- a/sotodlib/preprocess/processes.py +++ b/sotodlib/preprocess/processes.py @@ -244,6 +244,35 @@ class Demodulate(_Preprocess): def process(self, aman, proc_aman): hwp.demod_tod(aman, **self.process_cfgs) + +class GlitchFill(_Preprocess): + """Fill glitches. All process configs go to `fill_glitches`. + + .. autofunction:: sotodlib.tod_ops.gapfill.fill_glitches + """ + name = "glitchfill" + + def process(self, aman): + pcfgs = np.fromiter(self.process_cfgs.keys(), dtype='U16') + if 'glitch_flags' in pcfgs: + flags = aman.flags[self.process_cfgs["glitch_flags"]] + pcfgs = np.delete(pcfgs, np.where(pcfgs == 'glitch_flags')) + else: + flags = None + + if 'signal' in pcfgs: + signal = aman[self.process_cfgs["signal"]] + pcfgs = np.delete(pcfgs, np.where(pcfgs == 'signal')) + else: + signal = None + + args = {} + for pcfg in pcfgs: + args[pcfg] = self.process_cfgs[pcfg] + + tod_ops.gapfill.fill_glitches(aman, signal=signal, glitch_flags=flags, **args) + + _Preprocess.register(Trends.name, Trends) _Preprocess.register(FFTTrim.name, FFTTrim) _Preprocess.register(Detrend.name, Detrend) @@ -255,3 +284,4 @@ def process(self, aman, proc_aman): _Preprocess.register(SubtractHWPSS.name, SubtractHWPSS) _Preprocess.register(Apodize.name, Apodize) _Preprocess.register(Demodulate.name, Demodulate) +_Preprocess.register(GlitchFill.name, GlitchFill) diff --git a/sotodlib/tod_ops/gapfill.py b/sotodlib/tod_ops/gapfill.py index 87369eb23..ce863eb9f 100644 --- a/sotodlib/tod_ops/gapfill.py +++ b/sotodlib/tod_ops/gapfill.py @@ -1,6 +1,9 @@ import numpy as np import so3g +from . import pca +import logging +logger = logging.getLogger(__name__) class Extract: """Container for storage of sparse sub-segments of a vector. This @@ -380,3 +383,80 @@ def get_contaminated_ranges(good_flags, bad_flags): rs.add_interval(int(i0), int(i1)) return contam + +def fill_glitches(aman, nbuf=10, use_pca=False, modes=3, signal=None, + glitch_flags=None, wrap=True): + """ + This function fills pre-computed glitches provided by the caller in + time-ordered data using either a polynomial (default) or PCA-based + approach. Wraps the other functions in the ``tod_ops.gapfill`` module. + + Args + ----- + aman : AxisManager + AxisManager to fill glitches in + nbuf : int + Number of buffer samples to use in polynomial gap filling. + use_pca : bool + Whether or not to fill glitches using pca model. Default is False + modes : int + Number of modes in the pca to use if pca=True. Default is 3. + signal : ndarray or None + Array of data to fill glitches in. If None then uses ``aman.signal``. + Default is None. + glitch_flags : RangesMatrix or None + RangesMatrix containing flags to use for gap filling. If None then + uses ``aman.flags.glitches``. + wrap : bool or str + If True wraps new field called ``gap_filled``, if False returns the + gap filled array, if a string wraps new field with provided name. + + Returns + ------- + signal : ndarray + Returns ndarray with gaps filled from input signal. + """ + # Process Args + if signal is None: + sig = np.copy(aman.signal) + else: + sig = np.copy(signal) + + if glitch_flags is None: + glitch_flags = aman.flags.glitches + + # Polyfill + gaps = get_gap_fill(aman, nbuf=nbuf, flags=glitch_flags, + signal=np.float32(sig)) + sig = gaps.swap(aman, signal=sig) + + #PCA Fill + if use_pca: + if modes > aman.dets.count: + logger.warning(f'modes = {modes} > number of detectors = ' + + f'{aman.dets.count}, setting modes = number of ' + + 'detectors') + modes = aman.dets.count + # fill with poly fill before PCA + gaps = get_gap_fill(aman, nbuf=nbuf, flags=glitch_flags, + signal=np.float32(sig)) + sig = gaps.swap(aman, signal=sig) + # PCA fill + mod = pca.get_pca_model(tod=aman, n_modes=modes, + signal=sig) + gfill = get_gap_model(tod=aman, model=mod, flags=glitch_flags) + sig = gfill.swap(aman, signal=sig) + + # Wrap and Return + if isinstance(wrap, str): + if wrap in aman._assignments: + aman.move(wrap, None) + aman.wrap(wrap, sig, [(0, 'dets'), (1, 'samps')]) + return sig + elif wrap: + if 'gap_filled' in aman._assignments: + aman.move('gap_filled', None) + aman.wrap('gap_filled', sig, [(0, 'dets'), (1, 'samps')]) + return sig + else: + return sig diff --git a/sotodlib/tod_ops/pca.py b/sotodlib/tod_ops/pca.py index bec3bc47e..a8efb201d 100644 --- a/sotodlib/tod_ops/pca.py +++ b/sotodlib/tod_ops/pca.py @@ -109,8 +109,11 @@ def get_pca(tod=None, cov=None, signal=None, wrap=None): output = core.AxisManager(dets, mode_axis) output.wrap('cov', cov, [(0, dets.name), (1, dets.name)]) + # Note eig will sometimes return complex eigenvalues. E, R = np.linalg.eig(cov) # eigh nans sometimes... E[np.isnan(E)] = 0. + E, R = E.real, R.real + idx = np.argsort(-E) output.wrap('E', E[idx], [(0, mode_axis.name)]) output.wrap('R', R[:, idx], [(0, dets.name), (1, mode_axis.name)]) diff --git a/tests/test_tod_ops.py b/tests/test_tod_ops.py index ef64cff93..b362be585 100644 --- a/tests/test_tod_ops.py +++ b/tests/test_tod_ops.py @@ -12,7 +12,7 @@ from numpy.testing import assert_array_equal, assert_allclose -from sotodlib import core, tod_ops +from sotodlib import core, tod_ops, sim_flags import so3g from ._helpers import mpi_multi @@ -40,6 +40,38 @@ def get_tod(sig_type='trendy'): raise RuntimeError(f'sig_type={sig_type}?') return tod + +def get_glitchy_tod(ts, noise_amp=0, ndets=2, npoly=3, poly_coeffs=None): + """Returns axis manager to test fill_glitches""" + fake_signal = np.zeros((ndets, len(ts))) + input_sig = np.zeros((ndets, len(ts))) + if poly_coeffs is None: + poly_coeffs = np.random.uniform(0.5, 1.6, npoly)*1e-1 + poly_sig = np.polyval(poly_coeffs, ts-np.mean(ts)) + for nd in range(ndets): + input_sig[nd] = poly_sig + fake_signal[nd] = poly_sig + noise = np.random.normal(0, noise_amp, size=len(ts)) + fake_signal[nd] += noise + + dets = ['det%i' % i for i in range(ndets)] + + tod_fake = core.AxisManager(core.LabelAxis('dets', vals=dets), + core.OffsetAxis('samps', count=len(ts))) + tod_fake.wrap('timestamps', ts, axis_map=[(0, 'samps')]) + tod_fake.wrap('signal', np.atleast_2d(fake_signal), + axis_map=[(0, 'dets'), (1, 'samps')]) + tod_fake.wrap('inputsignal', np.atleast_2d(input_sig), + axis_map=[(0, 'dets'), (1, 'samps')]) + flgs = core.AxisManager() + tod_fake.wrap('flags', flgs) + params = {'n_glitches': 10, 'sig_n_glitch': 10, 'h_glitc h': 10, + 'sig_h_glitch': 2} + sim_flags.add_random_glitches(tod_fake, params=params, signal='signal', + flag='glitches', overwrite=False) + return tod_fake + + class FactorsTest(unittest.TestCase): def test_inf(self): f = tod_ops.fft_ops.find_inferior_integer @@ -121,6 +153,28 @@ def test_basic(self): assert_allclose(tod.signal[1][gap_mask], sentinel) # ... check "extraction" has model values. assert_allclose(ex[1].data, sig[gap_mask], atol=atol) + + def test_fillglitches(self): + """Tests fill glitches wrapper function""" + ts = np.arange(0, 1*60, 1/200) + aman = get_glitchy_tod(ts, ndets=100) + # test poly fill + up, mg = False, False + glitch_filled = tod_ops.gapfill.fill_glitches(aman, use_pca=up, + wrap=mg) + self.assertTrue(np.max(np.abs(glitch_filled-aman.inputsignal)) < 1e-3) + + # test pca fill + up, mg = True, False + glitch_filled = tod_ops.gapfill.fill_glitches(aman, use_pca=up, + wrap=mg) + print(np.max(np.abs(glitch_filled-aman.inputsignal))) + + # test wrap new field + up, mg = False, True + glitch_filled = tod_ops.gapfill.fill_glitches(aman, use_pca=up, + wrap=mg) + self.assertTrue('gap_filled' in aman._assignments) class FilterTest(unittest.TestCase): def test_basic(self):