Skip to content

Commit

Permalink
Add subscan operations to preprocess (#1028)
Browse files Browse the repository at this point in the history
Add basic statistics, PSDs, and glitches on subscans.

* New `subscan_info` aman added in get_turnaround_flags.
* Added arguments "subscan=True" for get_glitch_flags, calc_psd, fit_noise_model.
* New function flags.get_stats for computing basic TOD/PSD statistics on subscans/full obs.
* New preprocess function GetStats with option to plot the TOD.
* Added plot function to preprocess.PSDCalc to plot the PSD.
  • Loading branch information
earosenberg authored Dec 4, 2024
1 parent 5b66cae commit cfc268f
Show file tree
Hide file tree
Showing 6 changed files with 560 additions and 96 deletions.
1 change: 1 addition & 0 deletions docs/preprocess.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ Flagging and Products
.. autoclass:: sotodlib.preprocess.processes.FlagTurnarounds
.. autoclass:: sotodlib.preprocess.processes.DarkDets
.. autoclass:: sotodlib.preprocess.processes.SourceFlags
.. autoclass:: sotodlib.preprocess.processes.GetStats

HWP Related
:::::::::::
Expand Down
19 changes: 16 additions & 3 deletions sotodlib/preprocess/pcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .. import core
from so3g.proj import Ranges, RangesMatrix
from scipy.sparse import csr_array
from matplotlib import pyplot as plt

class _Preprocess(object):
"""The base class for Preprocessing modules which defines the required
Expand Down Expand Up @@ -270,16 +271,27 @@ def _expand(new, full, wrap_valid=True):
continue
out.wrap_new( k, new._assignments[k], cls=_zeros_cls(v))
oidx=[]; nidx=[]
for a in new._assignments[k]:
for ii, a in enumerate(new._assignments[k]):
if a == 'dets':
oidx.append(fs_dets)
nidx.append(ns_dets)
elif a == 'samps':
oidx.append(fs_samps)
nidx.append(ns_samps)
else:
oidx.append(slice(None))
nidx.append(slice(None))
if (ii == 0) and isinstance(out[k], RangesMatrix): # Treat like dets
# _ranges_matrix_match expects oidx[0] and nidx[0] to be list(inds), not slice.
# Unknown axes treated as dets if first entry, else like samps. Added to support (subscans, samps) RangesMatrix.
if a in full._axes:
_, fs, ns = full[a].intersection(new[a], return_slices=True)
else:
fs = range(new[a].count)
ns = range(new[a].count)
oidx.append(fs)
nidx.append(ns)
else: # Treat like samps
oidx.append(slice(None))
nidx.append(slice(None))
oidx = tuple(oidx)
nidx = tuple(nidx)
if isinstance(out[k], RangesMatrix):
Expand Down Expand Up @@ -456,6 +468,7 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
update_full_aman( proc_aman, full, self.wrap_valid)
if update_plot:
process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png'))
plt.close()
if select:
process.select(aman, proc_aman)
proc_aman.restrict('dets', aman.dets.vals)
Expand Down
38 changes: 38 additions & 0 deletions sotodlib/preprocess/preprocess_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,44 @@ def plot_trending_flags(aman, trend_aman, filename='./trending_flags.png'):
os.makedirs(head_tail[0], exist_ok=True)
plt.savefig(filename)

def plot_signal(aman, signal=None, xx=None, signal_name="signal", x_name="timestamps", plot_ds_factor=50, plot_ds_factor_dets=None, xlim=None, alpha=0.2, yscale='linear', y_unit=None, filename="./signal.png"):
from operator import attrgetter
if plot_ds_factor_dets is None:
plot_ds_factor_dets = plot_ds_factor
if signal is None:
signal = attrgetter(signal_name)(aman)
if xx is None:
xx = attrgetter(x_name)(aman)
yy = signal[::plot_ds_factor_dets, 1::plot_ds_factor].copy() # (dets, samps); (dets, nusamps); (dets, nusamps, subscans)
xx = xx[1::plot_ds_factor].copy() # (samps); (nusamps)
if x_name == "timestamps":
xx -= xx[0]
if yy.ndim > 2: # Flatten subscan axis into dets
yy = yy.swapaxes(1,2).reshape(-1, yy.shape[1])

if xlim is not None:
xinds = np.logical_and(xx >= xlim[0], xx <= xlim[1])
xx = xx[xinds]
yy = yy[:,xinds]

fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
ax.plot(xx, yy.T, color='k', alpha=0.2)
ax.set_yscale(yscale)
if "freqs" in x_name:
ax.set_xlabel("freq [Hz]")
else:
ax.set_xlabel(f"{x_name} [s]")
y_unit = "" if y_unit is None else f" [{y_unit}]"
ax.set_ylabel(f"{signal_name.replace('.Pxx', '')}{y_unit}")
plt.suptitle(f"{aman.obs_info.obs_id}, dT = {np.ptp(aman.timestamps)/60:.1f} min")
plt.tight_layout()
head_tail = os.path.split(filename)
os.makedirs(head_tail[0], exist_ok=True)
plt.savefig(filename)

def plot_psd(aman, signal=None, xx=None, signal_name="psd.Pxx", x_name="psd.freqs", plot_ds_factor=4, plot_ds_factor_dets=20, xlim=None, alpha=0.2, yscale='log', y_unit=None, filename="./psd.png"):
return plot_signal(aman, signal, xx, signal_name, x_name, plot_ds_factor, plot_ds_factor_dets, xlim, alpha, yscale, y_unit, filename)

def plot_signal_diff(aman, flag_aman, flag_type="glitches", flag_threshold=10, plot_ds_factor=50, filename="./glitch_signal_diff.png"):
"""
Function for plotting the difference in signal before and after cuts from either glitches or jumps.
Expand Down
142 changes: 124 additions & 18 deletions sotodlib/preprocess/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class GlitchDetection(_FracFlaggedMixIn, _Preprocess):
buffer: 10
hp_fc: 1
n_sig: 10
subscan: False
save: True
plot:
plot_ds_factor: 50
Expand Down Expand Up @@ -340,6 +341,7 @@ def plot(self, aman, proc_aman, filename):
plot_ds_factor=self.plot_cfgs.get("plot_ds_factor", 50), filename=filename.replace('{name}', f'{ufm}_jump_signal_diff'))
plot_flag_stats(aman, proc_aman[name], flag_type='jumps', filename=filename.replace('{name}', f'{ufm}_jumps_stats'))


class PSDCalc(_Preprocess):
""" Calculate the PSD of the data and add it to the Preprocessing AxisManager under the
"psd" field.
Expand All @@ -353,6 +355,7 @@ class PSDCalc(_Preprocess):
"psd_cfgs": # optional, kwargs to scipy.welch
"nperseg": 1024
"wrap_name": "psd" # optional
"subscan": False
"save": True
.. autofunction:: sotodlib.tod_ops.fft_ops.calc_psd
Expand All @@ -368,21 +371,105 @@ def __init__(self, step_cfgs):
def calc_and_save(self, aman, proc_aman):
freqs, Pxx = tod_ops.fft_ops.calc_psd(aman, signal=aman[self.signal],
**self.calc_cfgs)
fft_aman = core.AxisManager(
aman.dets,
core.OffsetAxis("nusamps",len(freqs))
)

fft_aman = core.AxisManager(aman.dets,
core.OffsetAxis("nusamps", len(freqs)))
pxx_axis_map = [(0, "dets"), (1, "nusamps")]
if self.calc_cfgs.get('subscan', False):
fft_aman.wrap("Pxx_ss", Pxx, pxx_axis_map+[(2, aman.subscans)])
Pxx = np.nanmean(Pxx, axis=-1) # Mean of subscans

fft_aman.wrap("freqs", freqs, [(0,"nusamps")])
fft_aman.wrap("Pxx", Pxx, [(0,"dets"), (1,"nusamps")])
fft_aman.wrap("Pxx", Pxx, pxx_axis_map)

self.save(proc_aman, fft_aman)

def save(self, proc_aman, fft_aman):
if not(self.save_cfgs is None):
proc_aman.wrap(self.wrap, fft_aman)
def plot(self, aman, proc_aman, filename):
if self.plot_cfgs is None:
return
if self.plot_cfgs:
from .preprocess_plot import plot_psd

filename = filename.replace('{ctime}', f'{str(aman.timestamps[0])[:5]}')
filename = filename.replace('{obsid}', aman.obs_info.obs_id)
det = aman.dets.vals[0]
ufm = det.split('_')[2]
filename = filename.replace('{name}', f'{ufm}_{self.wrap}')

plot_psd(aman, signal=attrgetter(f"{self.wrap}.Pxx")(proc_aman),
xx=attrgetter(f"{self.wrap}.freqs")(proc_aman), filename=filename, **self.plot_cfgs)


class GetStats(_Preprocess):
""" Get basic statistics from a TOD or its power spectrum.
Example config block:
- name : "tod_stats"
signal: "signal" # optional
wrap: "tod_stats" # optional
calc:
stat_names: ["median", "std"]
split_subscans: False # optional
psd_mask: # optional, for cutting a power spectrum in frequency
freqs: "psd.freqs"
low_f: 1
high_f: 10
save: True
"""
name = "tod_stats"
def __init__(self, step_cfgs):
self.signal = step_cfgs.get('signal', 'signal')
self.wrap = step_cfgs.get('wrap', 'tod_stats')

super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
if self.calc_cfgs.get('psd_mask') is not None:
mask_dict = self.calc_cfgs.get('psd_mask')
_f = attrgetter(mask_dict['freqs'])
try:
freqs = _f(aman)
except KeyError:
freqs = _f(proc_aman)
low_f, high_f = mask_dict['low_f'], mask_dict['high_f']
fmask = np.all([freqs >= low_f, freqs <= high_f], axis=0)
self.calc_cfgs['mask'] = fmask
del self.calc_cfgs['psd_mask']

_f = attrgetter(self.signal)
try:
signal = _f(aman)
except KeyError:
signal = _f(proc_aman)
stats_aman = tod_ops.flags.get_stats(aman, signal, **self.calc_cfgs)
self.save(proc_aman, stats_aman)

def save(self, proc_aman, stats_aman):
if not(self.save_cfgs is None):
proc_aman.wrap(self.wrap, stats_aman)

def plot(self, aman, proc_aman, filename):
if self.plot_cfgs is None:
return
if self.plot_cfgs:
from .preprocess_plot import plot_signal

filename = filename.replace('{ctime}', f'{str(aman.timestamps[0])[:5]}')
filename = filename.replace('{obsid}', aman.obs_info.obs_id)
det = aman.dets.vals[0]
ufm = det.split('_')[2]
filename = filename.replace('{name}', f'{ufm}_{self.signal}')

plot_signal(aman, signal_name=self.signal, x_name="timestamps", filename=filename, **self.plot_cfgs)

class Noise(_Preprocess):
"""Estimate the white noise levels in the data. Assumes the PSD has been
wrapped into the preprocessing AxisManager. All calculation configs goes to `calc_wn`.
wrapped into the preprocessing AxisManager. All calculation configs goes to `calc_wn`.
Saves the results into the "noise" field of proc_aman.
Expand All @@ -391,6 +478,8 @@ class Noise(_Preprocess):
Example config block::
- name: "noise"
fit: False
subscan: False
calc:
low_f: 5
high_f: 10
Expand All @@ -408,28 +497,36 @@ class Noise(_Preprocess):
def __init__(self, step_cfgs):
self.psd = step_cfgs.get('psd', 'psd')
self.fit = step_cfgs.get('fit', False)
self.subscan = step_cfgs.get('subscan', False)

super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
if self.psd not in proc_aman:
raise ValueError("PSD is not saved in Preprocessing AxisManager")
psd = proc_aman[self.psd]

pxx = psd.Pxx_ss if self.subscan else psd.Pxx

if self.calc_cfgs is None:
self.calc_cfgs = {}

if self.fit:
calc_aman = tod_ops.fft_ops.fit_noise_model(aman, pxx=psd.Pxx,
if self.calc_cfgs.get('subscan') is None:
self.calc_cfgs['subscan'] = self.subscan
calc_aman = tod_ops.fft_ops.fit_noise_model(aman, pxx=pxx,
f=psd.freqs,
merge_fit=True,
**self.calc_cfgs)
else:
wn = tod_ops.fft_ops.calc_wn(aman, pxx=psd.Pxx,
wn = tod_ops.fft_ops.calc_wn(aman, pxx=pxx,
freqs=psd.freqs,
**self.calc_cfgs)
calc_aman = core.AxisManager(aman.dets)
calc_aman.wrap("white_noise", wn, [(0,"dets")])
if not self.subscan:
calc_aman = core.AxisManager(aman.dets)
calc_aman.wrap("white_noise", wn, [(0,"dets")])
else:
calc_aman = core.AxisManager(aman.dets, aman.subscan_info.subscans)
calc_aman.wrap("white_noise", wn, [(0,"dets"), (1,"subscans")])

self.save(proc_aman, calc_aman)

Expand Down Expand Up @@ -457,10 +554,12 @@ def select(self, meta, proc_aman=None):
self.select_cfgs['name'] = self.select_cfgs.get('name','noise')

if self.fit:
keep = proc_aman[self.select_cfgs['name']].fit[:,1] <= self.select_cfgs["max_noise"]
wn = proc_aman[self.select_cfgs['name']].fit[:,1]
else:
keep = proc_aman[self.select_cfgs['name']].white_noise <= self.select_cfgs["max_noise"]

wn = proc_aman[self.select_cfgs['name']].white_noise
if self.subscan:
wn = np.nanmean(wn, axis=-1) # Mean over subscans
keep = wn <= np.float64(self.select_cfgs["max_noise"])
meta.restrict("dets", meta.dets.vals[keep])
return meta

Expand Down Expand Up @@ -786,6 +885,9 @@ def calc_and_save(self, aman, proc_aman):
calc_aman = core.AxisManager(aman.dets, aman.samps)
calc_aman.wrap('turnarounds', ta, [(0, 'dets'), (1, 'samps')])

if ('merge_subscans' not in self.calc_cfgs) or (self.calc_cfgs['merge_subscans']):
calc_aman.wrap('subscan_info', aman.subscan_info)

self.save(proc_aman, calc_aman)

def save(self, proc_aman, turn_aman):
Expand Down Expand Up @@ -1083,9 +1185,9 @@ class PCARelCal(_Preprocess):
yfac: 1.5
calc_good_medianw: True
lpf:
type: "low_pass_sine2"
type: "sine2"
cutoff: 1
width: 0.1
trans_width: 0.1
trim_samps: 2000
save: True
plot:
Expand All @@ -1102,6 +1204,7 @@ def __init__(self, step_cfgs):
super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
self.plot_signal = self.signal
if self.calc_cfgs.get("lpf") is not None:
filt = tod_ops.filters.get_lpf(self.calc_cfgs.get("lpf"))
filt_tod = tod_ops.fourier_filter(aman, filt, signal_name='signal')
Expand All @@ -1117,6 +1220,8 @@ def calc_and_save(self, aman, proc_aman):
proc_aman.samps.offset + proc_aman.samps.count - trim))
filt_aman.restrict('samps', (filt_aman.samps.offset + trim,
filt_aman.samps.offset + filt_aman.samps.count - trim))
if self.plot_cfgs:
self.plot_signal = filt_aman[self.signal]

bands = np.unique(aman.det_info.wafer.bandpass)
bands = bands[bands != 'NC']
Expand Down Expand Up @@ -1184,7 +1289,7 @@ def plot(self, aman, proc_aman, filename):
for band in bands:
pca_aman = aman.restrict('dets', aman.dets.vals[proc_aman[self.run_name][f'{band}_idx']], in_place=False)
band_aman = proc_aman[self.run_name].restrict('dets', aman.dets.vals[proc_aman[self.run_name][f'{band}_idx']], in_place=False)
plot_pcabounds(pca_aman, band_aman, filename=filename.replace('{name}', f'{ufm}_{band}_pca'), signal=self.signal, band=band, plot_ds_factor=self.plot_cfgs.get('plot_ds_factor', 20))
plot_pcabounds(pca_aman, band_aman, filename=filename.replace('{name}', f'{ufm}_{band}_pca'), signal=self.plot_signal, band=band, plot_ds_factor=self.plot_cfgs.get('plot_ds_factor', 20))


class PTPFlags(_Preprocess):
Expand Down Expand Up @@ -1384,3 +1489,4 @@ def save(self, proc_aman, split_flg_aman):
_Preprocess.register(DarkDets)
_Preprocess.register(SourceFlags)
_Preprocess.register(HWPAngleModel)
_Preprocess.register(GetStats)
Loading

0 comments on commit cfc268f

Please sign in to comment.