Skip to content

Commit

Permalink
subscan preprocess fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
earosenberg committed Nov 21, 2024
1 parent 67b59b0 commit 475c87d
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 45 deletions.
15 changes: 12 additions & 3 deletions sotodlib/preprocess/pcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,16 +270,25 @@ def _expand(new, full, wrap_valid=True):
continue
out.wrap_new( k, new._assignments[k], cls=_zeros_cls(v))
oidx=[]; nidx=[]
for a in new._assignments[k]:
for ii, a in enumerate(new._assignments[k]):
if a == 'dets':
oidx.append(fs_dets)
nidx.append(ns_dets)
elif a == 'samps':
oidx.append(fs_samps)
nidx.append(ns_samps)
else:
oidx.append(slice(None))
nidx.append(slice(None))
if ii == 0: # Treat like dets
if a in full._axes:
_, fs, ns = full[a].intersection(new[a], return_slices=True)
else:
fs = range(new[a].count)
ns = range(new[a].count)
oidx.append(fs)
nidx.append(ns)
else: # Treat like samps
oidx.append(slice(None))
nidx.append(slice(None))
oidx = tuple(oidx)
nidx = tuple(nidx)
if isinstance(out[k], RangesMatrix):
Expand Down
77 changes: 48 additions & 29 deletions sotodlib/preprocess/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,22 @@ def __init__(self, step_cfgs):


def process(self, aman, proc_aman):
if not self.subscan:
calc_psd = tod_ops.fft_ops.calc_psd
else:
if self.subscan:
calc_psd = tod_ops.fft_ops.calc_psd_subscan
else:
calc_psd = tod_ops.fft_ops.calc_psd

freqs, Pxx = calc_psd(aman, signal=aman[self.signal],
**self.process_cfgs)
fft_aman = core.AxisManager(
aman.dets,
core.OffsetAxis("nusamps",len(freqs))
)
axis_list = [aman.dets, core.OffsetAxis("nusamps", len(freqs))]
pxx_axis_map = [(0, "dets"), (1, "nusamps")]
if self.subscan:
axis_list.append(aman.subscan_info.subscans)
pxx_axis_map.append((2, "subscans"))

fft_aman = core.AxisManager(*axis_list)
fft_aman.wrap("freqs", freqs, [(0,"nusamps")])
fft_aman.wrap("Pxx", Pxx, [(0,"dets"), (1,"nusamps")])
fft_aman.wrap("Pxx", Pxx, pxx_axis_map)
aman.wrap(self.wrap, fft_aman)

def calc_and_save(self, aman, proc_aman):
Expand All @@ -400,17 +403,17 @@ class TODStats(_Preprocess):
Example config block:
- "name : "stats"
"signal: "signal" # optional
"wrap": "stats" # optional
"calc":
"stat_names": ["median", "std"]
"split_subscans": False # optional
"mask": # optional, for cutting a power spectrum in frequency
"psd_aman": "psd"
"low_f": 1
"high_f": 10
"save": True
- name : "stats"
signal: "signal" # optional
wrap: "stats" # optional
calc:
stat_names: ["median", "std"]
split_subscans: False # optional
mask: # optional, for cutting a power spectrum in frequency
freqs: "psd.freqs"
low_f: 1
high_f: 10
save: True
"""
name = "stats"
Expand All @@ -421,14 +424,18 @@ def __init__(self, step_cfgs):
super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
def get_sub(aman, name): # Helper fn to access nested aman entries eg aman.psd.Pxx
cmd = "aman"+"".join([str([xx]) for xx in name.split(".")])
return eval(cmd)

if self.calc_cfgs.get('mask') is not None:
mask_dict = self.calc_cfgs.get('mask')
freqs = aman[mask_dict['psd_aman']]['freqs']
freqs = get_sub(aman, mask_dict['freqs'])
low_f, high_f = mask_dict['low_f'], mask_dict['high_f']
fmask = np.all([freqs >= low_f, freqs <= high_f], axis=0)
self.calc_cfgs['mask'] = fmask

stats_aman = tod_ops.flags.get_stats(aman, aman[self.signal], **self.calc_cfgs)
stats_aman = tod_ops.flags.get_stats(aman, get_sub(aman, self.signal), **self.calc_cfgs)
self.save(proc_aman, stats_aman)

def save(self, proc_aman, stats_aman):
Expand Down Expand Up @@ -464,6 +471,7 @@ class Noise(_Preprocess):
def __init__(self, step_cfgs):
self.psd = step_cfgs.get('psd', 'psd')
self.fit = step_cfgs.get('fit', False)
self.subscan = step_cfgs.get('subscan', False)

super().__init__(step_cfgs)

Expand All @@ -476,16 +484,24 @@ def calc_and_save(self, aman, proc_aman):
self.calc_cfgs = {}

if self.fit:
calc_aman = tod_ops.fft_ops.fit_noise_model(aman, pxx=psd.Pxx,
if not self.subscan:
fit_noise_model = tod_ops.fft_ops.fit_noise_model
else:
fit_noise_model = tod_ops.fft_ops.fit_noise_model_subscan
calc_aman = fit_noise_model(aman, pxx=psd.Pxx,
f=psd.freqs,
merge_fit=True,
**self.calc_cfgs)
else:
wn = tod_ops.fft_ops.calc_wn(aman, pxx=psd.Pxx,
freqs=psd.freqs,
**self.calc_cfgs)
calc_aman = core.AxisManager(aman.dets)
calc_aman.wrap("white_noise", wn, [(0,"dets")])
if not self.subscan:
calc_aman = core.AxisManager(aman.dets)
calc_aman.wrap("white_noise", wn, [(0,"dets")])
else:
calc_aman = core.AxisManager(aman.dets, aman.subscan_info.subscans)
calc_aman.wrap("white_noise", wn, [(0,"dets"), (1,"subscans")])

self.save(proc_aman, calc_aman)

Expand Down Expand Up @@ -513,10 +529,12 @@ def select(self, meta, proc_aman=None):
self.select_cfgs['name'] = self.select_cfgs.get('name','noise')

if self.fit:
keep = proc_aman[self.select_cfgs['name']].fit[:,1] <= self.select_cfgs["max_noise"]
wn = proc_aman[self.select_cfgs['name']].fit[:,1]
else:
keep = proc_aman[self.select_cfgs['name']].white_noise <= self.select_cfgs["max_noise"]

wn = proc_aman[self.select_cfgs['name']].white_noise
if self.subscan:
wn = np.nanmean(wn, axis=-1) # Mean over subscans
keep = wn <= np.float64(self.select_cfgs["max_noise"])
meta.restrict("dets", meta.dets.vals[keep])
return meta

Expand Down Expand Up @@ -843,7 +861,7 @@ def calc_and_save(self, aman, proc_aman):
calc_aman.wrap('turnarounds', ta, [(0, 'dets'), (1, 'samps')])

if ('merge_subscans' not in self.calc_cfgs) or (self.calc_cfgs['merge_subscans']):
subscan_aman = aman.subscans
subscan_aman = aman.subscan_info
else:
subscan_aman = None

Expand All @@ -855,7 +873,7 @@ def save(self, proc_aman, turn_aman, subscan_aman):
if self.save_cfgs:
proc_aman.wrap("turnaround_flags", turn_aman)
if subscan_aman is not None:
proc_aman.wrap("subscans", subscan_aman)
proc_aman.wrap("subscan_info", subscan_aman)

def process(self, aman, proc_aman):
tod_ops.flags.get_turnaround_flags(aman, **self.process_cfgs)
Expand Down Expand Up @@ -1385,3 +1403,4 @@ def process(self, aman, proc_aman):
_Preprocess.register(DarkDets)
_Preprocess.register(SourceFlags)
_Preprocess.register(HWPAngleModel)
_Preprocess.register(TODStats)
12 changes: 6 additions & 6 deletions sotodlib/tod_ops/fft_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ def calc_psd_subscan(aman, signal=None, freq_spacing=None, merge=False, wrap=Non
if freq_spacing is not None:
nperseg = int(2 ** (np.around(np.log2(fs / freq_spacing))))
else:
duration_samps = np.sum(aman.subscans.subscan_flags.mask(), axis=1)
duration_samps = np.sum(aman.subscan_info.subscan_flags.mask(), axis=1)
duration_samps = duration_samps[duration_samps > 0]
nperseg = int(2 ** (np.around(np.log2(np.median(duration_samps) / 4))))
kwargs["nperseg"] = nperseg

Pxx = []
for iss in range(aman.subscans.subscans.count):
for iss in range(aman.subscan_info.subscans.count):
signal_ss = get_subscan_signal(aman, signal, iss)
axis = -1 if "axis" not in kwargs else kwargs["axis"]
if signal_ss.shape[axis] >= kwargs["nperseg"]:
Expand Down Expand Up @@ -526,8 +526,8 @@ def fit_noise_model_subscan(
Fits noise model with white and 1/f noise to the PSD of signal subscans.
Args are as for fit_noise_model.
"""
fitout = np.empty((aman.dets.count, 3, aman.subscans.subscans.count))
covout = np.empty((aman.dets.count, 3, 3, aman.subscans.subscans.count))
fitout = np.empty((aman.dets.count, 3, aman.subscan_info.subscans.count))
covout = np.empty((aman.dets.count, 3, 3, aman.subscan_info.subscans.count))

if signal is None:
signal = aman.signal
Expand All @@ -541,7 +541,7 @@ def fit_noise_model_subscan(
**psdargs,
)

for isub in range(aman.subscans.subscans.count):
for isub in range(aman.subscan_info.subscans.count):
if np.all(np.isnan(pxx[...,isub])): # Subscan has been fully cut
fitout[..., isub] = np.full((aman.dets.count, 3), np.nan)
covout[..., isub] = np.full((aman.dets.count, 3, 3), np.nan)
Expand All @@ -555,7 +555,7 @@ def fit_noise_model_subscan(
noise_fit_stats = core.AxisManager(
aman.dets,
noise_model.noise_model_coeffs,
aman.subscans.subscans
aman.subscan_info.subscans
)
noise_fit_stats.wrap("fit", fitout, [(0, "dets"), (1, "noise_model_coeffs"), (2, "subscans")])
noise_fit_stats.wrap(
Expand Down
17 changes: 10 additions & 7 deletions sotodlib/tod_ops/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def get_inv_var_flags(aman, signal_name='signal', nsigma=5,

return mskinvar

def get_subscans(aman, merge=True, include_turnarounds=False):
def get_subscans(aman, merge=True, include_turnarounds=False, overwrite=True):
"""
Returns an axis manager with information about subscans.
This includes direction, start time, stop time, and a ranges matrix (subscans samps)
Expand Down Expand Up @@ -724,7 +724,10 @@ def get_subscans(aman, merge=True, include_turnarounds=False):
rm = RangesMatrix([Ranges.from_array(np.atleast_2d(ss), tt.size) for ss in ss_ind])
subscan_aman.wrap('subscan_flags', rm, [(0, 'subscans'), (1, 'samps')]) # True in the subscan
if merge:
aman.wrap('subscans', subscan_aman)
name = 'subscan_info'
if overwrite and name in aman:
aman.move(name, None)
aman.wrap(name, subscan_aman)
return subscan_aman

def get_subscan_signal(aman, arr, isub=None, trim=False):
Expand All @@ -744,13 +747,13 @@ def get_subscan_signal(aman, arr, isub=None, trim=False):
if isinstance(arr, str):
arr = aman[arr]
if np.isscalar(isub):
out = apply_rng(arr, aman.subscans.subscan_flags[isub])
out = apply_rng(arr, aman.subscan_info.subscan_flags[isub])
if trim and out.size == 0:
out = None
else:
if isub is None:
isub = range(len(aman.subscans.subscan_flags))
out = [apply_rng(arr, aman.subscans.subscan_flags[ii]) for ii in isub]
isub = range(len(aman.subscan_info.subscan_flags))
out = [apply_rng(arr, aman.subscan_info.subscan_flags[ii]) for ii in isub]
if trim:
out = [x for x in out if x.size > 0]

Expand Down Expand Up @@ -782,7 +785,7 @@ def wrap_info(aman, info_aman_name, info, info_names, merge=True):
info_aman = core.AxisManager(aman.dets)
axmap = [(0, 'dets')]
elif info[0].ndim == 2:
info_aman = core.AxisManager(aman.dets, aman.subscans.subscans)
info_aman = core.AxisManager(aman.dets, aman.subscan_info.subscans)
axmap = [(0, 'dets'), (1, 'subscans')]

for ii in range(len(info_names)):
Expand Down Expand Up @@ -823,7 +826,7 @@ def get_stats(aman, signal, stat_names, split_subscans=False, mask=None, name="s
if mask is not None:
raise ValueError("Cannot mask samples and split subscans")
stats_arr = []
for iss in range(aman.subscans.subscans.count):
for iss in range(aman.subscan_info.subscans.count):
data = get_subscan_signal(aman, signal, iss)
if data.size > 0:
stats_arr.append([fn_dict[name](data, axis=1) for name in stat_names]) # Samps axis assumed to be 1
Expand Down

0 comments on commit 475c87d

Please sign in to comment.