diff --git a/sotodlib/preprocess/processes.py b/sotodlib/preprocess/processes.py index 488b18a57..63ee0a659 100644 --- a/sotodlib/preprocess/processes.py +++ b/sotodlib/preprocess/processes.py @@ -804,7 +804,7 @@ def process(self, aman, proc_aman): aman.samps.offset + aman.samps.count - trim)) -class EstimateAzSS(_Preprocess): +class AzSS(_Preprocess): """Estimates Azimuth Synchronous Signal (AzSS) by binning signal by azimuth of boresight. All process confgis go to `get_azss`. If `method` is 'interpolate', no fitting applied and binned signal is directly used as AzSS model. If `method` is 'fit', Legendre polynominal @@ -812,29 +812,66 @@ class EstimateAzSS(_Preprocess): Example configuration block:: - - name: "estimate_azss" + - name: "azss" + azss_stats_name: 'azss_statsQ' + proc_aman_turnaround_info: 'turnaround_flags' calc: signal: 'demodQ' - azss_stats_name: 'azss_statsQ' - range: [-1.57079, 7.85398] + frange: [-1.57079, 7.85398] bins: 1080 - merge_stats: False - merge_model: False + left_right: True save: True + process: + subtract: True .. autofunction:: sotodlib.tod_ops.azss.get_azss """ - name = "estimate_azss" + name = "azss" + def __init__(self, step_cfgs): + self.azss_stats_name = step_cfgs.get('azss_stats_name', 'azss_stats') + self.proc_aman_turnaround_info = step_cfgs.get('proc_aman_turnaround_info', None) + + super().__init__(step_cfgs) def calc_and_save(self, aman, proc_aman): - calc_aman, _ = tod_ops.azss.get_azss(aman, **self.calc_cfgs) - self.save(proc_aman, calc_aman) + # If process is run then just wrap info from process step + if self.process_cfgs: + self.save(proc_aman, aman[self.azss_stats_name]) + else: + if self.proc_aman_turnaround_info: + _f = attergetter(self.proc_aman_turnaround_info) + turnaround_info = _f(proc_aman) + else: + turnaround_info = None + azss_stats, _ = tod_ops.azss.get_azss(aman, turnaround_info=turnaround_info, + azss_stats_name=self.azss_stats_name, + merge_stats = False, merge_model=False, + **self.calc_cfgs) + self.save(proc_aman, azss_stats) def save(self, proc_aman, azss_stats): if self.save_cfgs is None: return if self.save_cfgs: - proc_aman.wrap(self.calc_cfgs["azss_stats_name"], azss_stats) + proc_aman.wrap(self.azss_stats_name, azss_stats) + + def process(self, aman, proc_aman): + subtract = self.process_cfgs.get('subtract', False) + if self.proc_aman_turnaround_info: + _f = attergetter(self.proc_aman_turnaround_info) + turnaround_info = _f(proc_aman) + else: + turnaround_info = None + if self.azss_stats_name in proc_aman: + if subtract: + tod_ops.azss.subtract_azss(aman, proc_aman[self.azss_stats_name], + signal = self.calc_cfgs.get('signal'), + in_place=True) + else: + tod_ops.azss.get_azss(aman, azss_stats_name=self.azss_stats_name, + turnaround_info=turnaround_info, + merge_stats = True, merge_model=False, + subtract_in_place=subtract, **self.calc_cfgs) class GlitchFill(_Preprocess): """Fill glitches. All process configs go to `fill_glitches`. @@ -1641,7 +1678,7 @@ def process(self, aman, proc_aman): _Preprocess.register(SubtractHWPSS) _Preprocess.register(Apodize) _Preprocess.register(Demodulate) -_Preprocess.register(EstimateAzSS) +_Preprocess.register(AzSS) _Preprocess.register(GlitchFill) _Preprocess.register(FlagTurnarounds) _Preprocess.register(SubPolyf) diff --git a/sotodlib/tod_ops/azss.py b/sotodlib/tod_ops/azss.py index e95f3ad43..6ef844f6b 100644 --- a/sotodlib/tod_ops/azss.py +++ b/sotodlib/tod_ops/azss.py @@ -5,11 +5,12 @@ from scipy.interpolate import interp1d from sotodlib import core, tod_ops from sotodlib.tod_ops import bin_signal, apodize, filters +from so3g.proj import Ranges import logging logger = logging.getLogger(__name__) -def bin_by_az(aman, signal=None, az=None, range=None, bins=100, flags=None, +def bin_by_az(aman, signal=None, az=None, frange=None, bins=100, flags=None, apodize_edges=True, apodize_edges_samps=1600, apodize_flags=True, apodize_flags_samps=200): """ @@ -23,13 +24,13 @@ def bin_by_az(aman, signal=None, az=None, range=None, bins=100, flags=None, numpy array of signal to be binned. If None, the signal is taken from aman.signal. az: array-like, optional A 1D numpy array representing the azimuth angles. If not provided, the azimuth angles are taken from aman.boresight.az attribute. - range: array-like, optional + frange: array-like, optional A list specifying the range of azimuth angles to consider for binning. Defaults to None. If None, [min(az), max(az)] will be used for binning. bins: int or sequence of scalars If bins is an int, it defines the number of equal-width bins in the given range (100, by default). If bins is a sequence, it defines the bin edges, including the rightmost edge, allowing for non-uniform bin widths. - If ``bins`` is a sequence, ``bins`` overwrite ``range``. + If ``bins`` is a sequence, ``bins`` overwrite ``frange``. flags: RangesMatrix, optional Flag indicating whether to exclude flagged samples when binning the signal. Default is no mask applied. @@ -80,7 +81,7 @@ def bin_by_az(aman, signal=None, az=None, range=None, bins=100, flags=None, else: weight_for_signal = None binning_dict = bin_signal(aman, bin_by=az, signal=signal, - range=range, bins=bins, flags=flags, weight_for_signal=weight_for_signal) + range=frange, bins=bins, flags=flags, weight_for_signal=weight_for_signal) return binning_dict def fit_azss(az, azss_stats, max_mode, fit_range=None): @@ -141,13 +142,41 @@ def fit_azss(az, azss_stats, max_mode, fit_range=None): return azss_stats, L.legval(x_legendre, coeffs.T) - -def get_azss(aman, signal='signal', az=None, range=None, bins=100, flags=None, - apodize_edges=True, apodize_edges_samps=40000, apodize_flags=True, apodize_flags_samps=200, +def _prepare_azss_stats(aman, signal, az, frange=None, bins=100, flags=None, + apodize_edges=True, apodize_edges_samps=40000, apodize_flags=True, + apodize_flags_samps=200, method='interpolate', max_mode=None): + """ + Helper function to collect initial info for azss_stats AxisManager. + """ + # do binning + binning_dict = bin_by_az(aman, signal=signal, az=az, frange=frange, bins=bins, flags=flags, + apodize_edges=apodize_edges, apodize_edges_samps=apodize_edges_samps, + apodize_flags=apodize_flags, apodize_flags_samps=apodize_flags_samps,) + bin_centers = binning_dict['bin_centers'] + bin_counts = binning_dict['bin_counts'] + binned_signal = binning_dict['binned_signal'] + binned_signal_sigma = binning_dict['binned_signal_sigma'] + uniform_binned_signal_sigma = np.nanmedian(binned_signal_sigma, axis=-1) + + azss_stats = core.AxisManager(aman.dets) + azss_stats.wrap('binned_az', bin_centers, [(0, core.IndexAxis('bin_az_samps', count=bins))]) + azss_stats.wrap('bin_counts', bin_counts, [(0, 'dets'), (1, 'bin_az_samps')]) + azss_stats.wrap('binned_signal', binned_signal, [(0, 'dets'), (1, 'bin_az_samps')]) + azss_stats.wrap('binned_signal_sigma', binned_signal_sigma, [(0, 'dets'), (1, 'bin_az_samps')]) + azss_stats.wrap('uniform_binned_signal_sigma', uniform_binned_signal_sigma, [(0, 'dets')]) + azss_stats.wrap('method', method) + azss_stats.wrap('frange_min', frange[0]) + azss_stats.wrap('frange_max', frange[1]) + if max_mode: + azss_stats.wrap('max_mode', max_mode) + return azss_stats + +def get_azss(aman, signal='signal', az=None, frange=None, bins=100, flags=None, + apodize_edges=True, apodize_edges_samps=1600, apodize_flags=True, apodize_flags_samps=200, apply_prefilt=True, prefilt_cfg=None, prefilt_detrend='linear', method='interpolate', max_mode=None, subtract_in_place=False, - merge_stats=True, azss_stats_name='azss_stats', - merge_model=True, azss_model_name='azss_model'): + merge_stats=True, azss_stats_name='azss_stats', turnaround_info=None, + merge_model=True, azss_model_name='azss_model', left_right=False): """ Derive azss (Azimuth Synchronous Signal) statistics and model from the given axismanager data. **NOTE:** This function does not modify the ``signal`` unless ``subtract_in_place = True``. @@ -160,14 +189,14 @@ def get_azss(aman, signal='signal', az=None, range=None, bins=100, flags=None, A numpy array representing the signal to be used for azss extraction. If not provided, the signal is taken from aman.signal. az: array-like, optional A 1D numpy array representing the azimuth angles. If not provided, the azimuth angles are taken from aman.boresight.az. - range: list, optional + frange: list, optional A list specifying the range of azimuth angles to consider for binning. Defaults to [-np.pi, np.pi]. If None, [min(az), max(az)] will be used for binning. bins: int or sequence of scalars If bins is an int, it defines the number of equal-width bins in the given range (100, by default). If bins is a sequence, it defines the bin edges, including the rightmost edge, allowing for non-uniform bin widths. - If `bins` is a sequence, `bins` overwrite `range`. - flags : RangesMatrix, optinal + If `bins` is a sequence, `bins` overwrite `frange`. + flags : RangesMatrix, optional Flag indicating whether to exclude flagged samples when binning the signal. Default is no mask applied. apodize_edges : bool, optional @@ -201,6 +230,11 @@ def get_azss(aman, signal='signal', az=None, range=None, bins=100, flags=None, Boolean flag indicating whether to merge the azss model with the aman. Defaults to True. azss_model_name: string, optional The name to assign to the merged azss model. Defaults to 'azss_model'. + left_right: bool + Default False. If True estimate (and subtract) the AzSS template for left and right subscans + separately. + turnaround_flags: FlagManager or AxisManager + Optional, default is aman.flags. Returns ------- @@ -235,43 +269,119 @@ def get_azss(aman, signal='signal', az=None, range=None, bins=100, flags=None, if az is None: az = aman.boresight.az - - # do binning - binning_dict = bin_by_az(aman, signal=signal, az=az, range=range, bins=bins, flags=flags, - apodize_edges=apodize_edges, apodize_edges_samps=apodize_edges_samps, - apodize_flags=apodize_flags, apodize_flags_samps=apodize_flags_samps,) - bin_centers = binning_dict['bin_centers'] - bin_counts = binning_dict['bin_counts'] - binned_signal = binning_dict['binned_signal'] - binned_signal_sigma = binning_dict['binned_signal_sigma'] - uniform_binned_signal_sigma = np.nanmedian(binned_signal_sigma, axis=-1) - - azss_stats = core.AxisManager(aman.dets) - azss_stats.wrap('binned_az', bin_centers, [(0, core.IndexAxis('bin_az_samps', count=bins))]) - azss_stats.wrap('bin_counts', bin_counts, [(0, 'dets'), (1, 'bin_az_samps')]) - azss_stats.wrap('binned_signal', binned_signal, [(0, 'dets'), (1, 'bin_az_samps')]) - azss_stats.wrap('binned_signal_sigma', binned_signal_sigma, [(0, 'dets'), (1, 'bin_az_samps')]) - azss_stats.wrap('uniform_binned_signal_sigma', uniform_binned_signal_sigma, [(0, 'dets')]) - - if method == 'fit': - if type(max_mode) is not int: - raise ValueError('max_mode is not provided as integer') - azss_stats, model_sig_tod = fit_azss(az=az, azss_stats=azss_stats, max_mode=max_mode, fit_range=range) - - if method == 'interpolate': - f_template = interp1d(bin_centers, binned_signal, fill_value='extrapolate') - model_sig_tod = f_template(aman.boresight.az) + + if turnaround_info is None: + turnaround_info = aman.flags + if isinstance(turnaround_info, str): + _f = attrgetter(turnaround_info) + turnaround_info = _f(aman) + if (not isinstance(turnaround_info, (core.AxisManager, core.FlagManager))) and left_right: + raise TypeError('turnaround_info must be AxisManager or FlagManager') + if flags is None: + flags = Ranges.from_mask(np.zeros(aman.samps.count).astype(bool)) + + if left_right: + if "valid_right_scans" not in turnaround_info: + left_mask = turnaround_info.left_scan + right_mask = turnaround_info.right_scan + else: + left_mask = turnaround_info.valid_left_scans + right_mask = turnaround_info.valid_right_scans + + azss_left = _prepare_azss_stats(aman, signal, az, frange, bins, flags+left_mask, apodize_edges, + apodize_edges_samps, apodize_flags, apodize_flags_samps, + method=method, max_mode=max_mode) + azss_left.add_axis(aman.samps) + azss_left.wrap('mask', left_mask, [(0, 'samps')]) + azss_right = _prepare_azss_stats(aman, signal, az, frange, bins, flags+right_mask, apodize_edges, + apodize_edges_samps, apodize_flags, apodize_flags_samps, + method=method, max_mode=max_mode) + azss_right.add_axis(aman.samps) + azss_right.wrap('mask', right_mask, [(0, 'samps')]) + azss_stats = core.AxisManager(aman.dets) + azss_stats.wrap('azss_stats_left', azss_left) + azss_stats.wrap('azss_stats_right', azss_right) + azss_stats.wrap('left_right', left_right) + else: + azss_stats = _prepare_azss_stats(aman, signal, az, frange, bins, flags, apodize_edges, + apodize_edges_samps, apodize_flags, apodize_flags_samps, + method=method, max_mode=max_mode) + azss_stats.wrap('left_right', left_right) if merge_stats: aman.wrap(azss_stats_name, azss_stats) - if merge_model: - aman.wrap(azss_model_name, model_sig_tod, [(0, 'dets'), (1, 'samps')]) - if subtract_in_place: - aman[signal_name] = np.subtract(signal, model_sig_tod, dtype='float32') + model_sig_tod = None + + if merge_model or subtract_in_place: + if left_right: + azss_stats, model_left, model_right = get_model_sig_tod(aman, azss_stats, az) + if merge_model: + aman.wrap(azss_model_name+'_left', model_left, [(0, 'dets'), (1, 'samps')]) + aman.wrap(azss_model_name+'_right', model_right, [(0, 'dets'), (1, 'samps')]) + if subtract_in_place: + if signal_name is None: + lmask = left_mask.mask() + signal[:,lmask] -= model_left[:,lmask].astype(signal.dtype) + rmask = right_mask.mask() + signal[:,rmask] -= model_right[:,rmask].astype(signal.dtype) + else: + lmask = left_mask.mask() + aman[signal_name][:,lmask] -= model_left[:,lmask].astype(aman[signal_name].dtype) + rmask = right_mask.mask() + aman[signal_name][:,rmask] -= model_right[:,rmask].astype(aman[signal_name].dtype) + else: + azss_stats, model = get_model_sig_tod(aman, azss_stats, az) + if merge_model: + aman.wrap(azss_model_name, model, [(0, 'dets'), (1, 'samps')]) + if subtract_in_place: + if signal_name is None: + signal -= model.astype(signal.dtype) + else: + aman[signal_name] -= model.astype(aman[signal_name].dtype) + return azss_stats, model_sig_tod -def subtract_azss(aman, signal='signal', azss_template_name='azss_model', - subtract_name='azss_remove', in_place=False, remove_template=True): +def get_model_sig_tod(aman, azss_stats, az=None): + """ + Function to return the azss template for subtraction given the azss_stats AxisManager + """ + # Need to handle left_right field in here. + if az is None: + az = aman.boresight.az + + if azss_stats.left_right: + model = [] + for fld in ['azss_stats_left', 'azss_stats_right']: + _azss_stats = azss_stats[fld] + if _azss_stats.method == 'fit': + if type(_azss_stats.max_mode) is not int: + raise ValueError('max_mode is not provided as integer') + _azss_stats, _model = fit_azss(az=az, azss_stats=_azss_stats, + max_mode=_azss_stats.max_mode, + fit_range=[_azss_stats.frange_min, _azss_stats.frange_max]) + azss_stats.wrap(fld, _azss_stats, overwrite=True) + model.append(_model) + + if _azss_stats.method == 'interpolate': + f_template = interp1d(_azss_stats.binned_az, + _azss_stats.binned_signal, fill_value='extrapolate') + _model = f_template(az) + model.append(_model) + return azss_stats, model[0], model[1] + + else: + if type(azss_stats.max_mode) is not int: + raise ValueError('max_mode is not provided as integer') + azss_stats, model = fit_azss(az=az, azss_stats=azss_stats, + max_mode=azss_stats.max_mode, + fit_range=[azss_stats.frange_min, azss_stats.frange_max]) + if azss_stats.method == 'interpolate': + f_template = interp1d(azss_stats.binned_az, azss_stats.binned_signal, fill_value='extrapolate') + model = f_template(az) + return azss_stats, model, None + +def subtract_azss(aman, azss_stats, signal='signal', subtract_name='azss_remove', + in_place=False): """ Subtract the scan synchronous signal (azss) template from the signal in the given axis manager. @@ -280,21 +390,17 @@ def subtract_azss(aman, signal='signal', azss_template_name='azss_model', ---------- aman : AxisManager The axis manager containing the signal and the azss template. + azss_stats: AxisManager + Contains AxisManager from get_azss. signal : str, optional The name of the field in the axis manager containing the signal to be processed. Defaults to 'signal'. - azss_template_name : str, optional - The name of the field in the axis manager containing the azss template. - Defaults to 'azss_model'. subtract_name : str, optional The name of the field in the axis manager that will store the azss-subtracted signal. Only used if in_place is False. Defaults to 'azss_remove'. in_place : bool, optional If True, the subtraction is done in place, modifying the original signal in the axis manager. If False, the result is stored in a new field specified by subtract_name. Defaults to False. - remove_template : bool, optional - If True, the azss template field is removed from the axis manager after subtraction. - Defaults to True. Returns ------- @@ -313,15 +419,31 @@ def subtract_azss(aman, signal='signal', azss_template_name='azss_model', else: raise TypeError("Signal must be None, str, or ndarray") + azss_stats, model_left, model_right = get_model_sig_tod(aman, azss_stats, az=None) + if in_place: if signal_name is None: - signal -= aman[azss_template_name].astype(signal.dtype) + if azss_stats.left_right: + for model, azss_fld in zip([model_left, model_right], ['azss_stats_left', 'azss_stats_right']): + mask = azss_stats[azss_fld]['mask'].mask() + signal[:, mask] -= model[:, mask].astype(signal.dtype) + else: + signal -= model_left.astype(signal.dtype) else: - aman[signal_name] -= aman[azss_template_name].astype(aman[signal_name].dtype) + if azss_stats.left_right: + for model, azss_fld in zip([model_left, model_right], ['azss_stats_left', 'azss_stats_right']): + mask = azss_stats[azss_fld]['mask'].mask() + aman[signal_name][:, mask] -= model[:, mask].astype(aman[signal_name].dtype) + else: + aman[signal_name] -= model_left.astype(aman[signal_name].dtype) else: - aman.wrap(subtract_name, - np.subtract(aman[signal_name], aman[azss_template_name], dtype='float32'), - [(0, 'dets'), (1, 'samps')]) - - if remove_template: - aman.move(azss_template_name, None) + if azss_stats.left_right: + wrap_sig = np.copy(signal) + for model, azss_fld in zip([model_left, model_right], ['azss_stats_left', 'azss_stats_right']): + mask = azss_stats[azss_fld]['mask'].mask() + wrap_sig[:, mask] -= model[:, mask].astype(signal.dtype) + aman.wrap(subtract_name, wrap_sig, [(0, 'dets'), (1, 'samps')]) + else: + aman.wrap(subtract_name, + np.subtract(aman[signal_name], model_left, dtype='float32'), + [(0, 'dets'), (1, 'samps')])