Skip to content

Commit

Permalink
subscan preprocess fixes and upgrades
Browse files Browse the repository at this point in the history
  • Loading branch information
earosenberg committed Nov 20, 2024
1 parent e2c8b10 commit 67b59b0
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 52 deletions.
37 changes: 9 additions & 28 deletions sotodlib/preprocess/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,32 +395,6 @@ def save(self, proc_aman, fft_aman):
proc_aman.wrap(self.wrap, fft_aman)


class IdentifySubscans(_Preprocess):
""" Get subscan info and wrap it into aman.subscans.
Example config block::
- "name : "subscans"
"calc": True
"save": True
"""

name = "subscans"

def __init__(self, step_cfgs):
super().__init__(step_cfgs)

def process(self, aman, proc_aman):
tod_ops.flags.get_subscans(aman, merge=True)

def calc_and_save(self, aman, proc_aman):
self.save(proc_aman, aman.subscans)

def save(self, proc_aman, subscan_aman):
if not(self.save_cfgs is None):
proc_aman.wrap('subscans', subscan_aman)

class TODStats(_Preprocess):
""" Get basic statistics from a TOD or its power spectrum.
Expand Down Expand Up @@ -868,13 +842,20 @@ def calc_and_save(self, aman, proc_aman):
calc_aman = core.AxisManager(aman.dets, aman.samps)
calc_aman.wrap('turnarounds', ta, [(0, 'dets'), (1, 'samps')])

self.save(proc_aman, calc_aman)
if ('merge_subscans' not in self.calc_cfgs) or (self.calc_cfgs['merge_subscans']):
subscan_aman = aman.subscans
else:
subscan_aman = None

self.save(proc_aman, calc_aman, subscan_aman)

def save(self, proc_aman, turn_aman):
def save(self, proc_aman, turn_aman, subscan_aman):
if self.save_cfgs is None:
return
if self.save_cfgs:
proc_aman.wrap("turnaround_flags", turn_aman)
if subscan_aman is not None:
proc_aman.wrap("subscans", subscan_aman)

def process(self, aman, proc_aman):
tod_ops.flags.get_turnaround_flags(aman, **self.process_cfgs)
Expand Down
114 changes: 99 additions & 15 deletions sotodlib/tod_ops/fft_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from sotodlib import core

from . import detrend_tod
from .flags import get_subscan_signal

def _get_num_threads():
# Guess how many threads we should be using in FFT ops...
Expand Down Expand Up @@ -282,6 +281,13 @@ def calc_psd(
return freqs, Pxx

def calc_psd_subscan(aman, signal=None, freq_spacing=None, merge=False, wrap=None, **kwargs):
"""
Calculate the power spectrum density of subscans using signal.welch().
Data defaults to aman.signal. aman.timestamps is used for times.
aman.subscans is used to identify subscans.
See calc_psd for arguments.
"""
from .flags import get_subscan_signal
if signal is None:
signal = aman.signal

Expand All @@ -290,15 +296,20 @@ def calc_psd_subscan(aman, signal=None, freq_spacing=None, merge=False, wrap=Non
if freq_spacing is not None:
nperseg = int(2 ** (np.around(np.log2(fs / freq_spacing))))
else:
duration_samps = np.around((aman.subscans.stop_time - aman.subscans.start_time) * fs)
duration_samps = np.sum(aman.subscans.subscan_flags.mask(), axis=1)
duration_samps = duration_samps[duration_samps > 0]
nperseg = int(2 ** (np.around(np.log2(np.median(duration_samps) / 4))))
kwargs["nperseg"] = nperseg

Pxx = []
for iss in range(aman.subscans.subscans.count):
signal_ss = get_subscan_signal(aman, signal, iss)
freqs, pxx_sub = welch(signal_ss, fs, **kwargs)
Pxx.append(pxx_sub)
axis = -1 if "axis" not in kwargs else kwargs["axis"]
if signal_ss.shape[axis] >= kwargs["nperseg"]:
freqs, pxx_sub = welch(signal_ss, fs, **kwargs)
Pxx.append(pxx_sub)
else:
Pxx.append(np.full((signal.shape[0], kwargs["nperseg"]//2+1), np.nan)) # Add nans if subscan is too short
Pxx = np.array(Pxx)
Pxx = Pxx.transpose(1, 2, 0) # Dets, nusamps, subscans
if merge:
Expand Down Expand Up @@ -379,6 +390,8 @@ def fit_noise_model(
f_max=100,
merge_name="noise_fit_stats",
merge_psd=True,
freq_spacing=None,
approx_fit=False,
):
"""
Fits noise model with white and 1/f noise to the PSD of signal.
Expand Down Expand Up @@ -418,6 +431,10 @@ def fit_noise_model(
If ``merge_fit`` is True then addes into axis manager with merge_name.
merge_psd : bool
If ``merg_psd`` is True then adds fres and Pxx to the axis manager.
freq_spacing : float
The approximate desired frequency spacing of the PSD. Passed to calc_psd.
approx_fit : bool
Get a rough fit instead of minimizing loglike.
Returns
-------
noise_fit_stats : AxisManager
Expand All @@ -431,13 +448,14 @@ def fit_noise_model(
if f is None or pxx is None:
if psdargs is None:
f, pxx = calc_psd(
aman, signal=signal, timestamps=aman.timestamps, merge=merge_psd
aman, signal=signal, timestamps=aman.timestamps, freq_spacing=freq_spacing, merge=merge_psd
)
else:
f, pxx = calc_psd(
aman,
signal=signal,
timestamps=aman.timestamps,
freq_spacing=freq_spacing,
merge=merge_psd,
**psdargs,
)
Expand All @@ -454,17 +472,21 @@ def fit_noise_model(
pfit = np.polyfit(np.log10(f[f < lowf]), np.log10(p[f < lowf]), 1)
fidx = np.argmin(np.abs(10 ** np.polyval(pfit, np.log10(f)) - wnest))
p0 = [f[fidx], wnest, -pfit[0]]
bounds = [(0, None), (sys.float_info.min, None), (None, None)]
res = minimize(neglnlike, p0, args=(f, p), bounds=bounds, method="Nelder-Mead")
try:
Hfun = ndt.Hessian(lambda params: neglnlike(params, f, p), full_output=True)
hessian_ndt, _ = Hfun(res["x"])
# Inverse of the hessian is an estimator of the covariance matrix
# sqrt of the diagonals gives you the standard errors.
covout[i] = np.linalg.inv(hessian_ndt)
except np.linalg.LinAlgError:
if approx_fit:
covout[i] = np.full((3, 3), np.nan)
fitout[i] = res.x
fitout[i] = p0
else:
bounds = [(0, None), (sys.float_info.min, None), (None, None)]
res = minimize(neglnlike, p0, args=(f, p), bounds=bounds, method="Nelder-Mead")
try:
Hfun = ndt.Hessian(lambda params: neglnlike(params, f, p), full_output=True)
hessian_ndt, _ = Hfun(res["x"])
# Inverse of the hessian is an estimator of the covariance matrix
# sqrt of the diagonals gives you the standard errors.
covout[i] = np.linalg.inv(hessian_ndt)
except np.linalg.LinAlgError:
covout[i] = np.full((3, 3), np.nan)
fitout[i] = res.x

noise_model_coeffs = ["fknee", "white_noise", "alpha"]
noise_fit_stats = core.AxisManager(
Expand All @@ -485,6 +507,68 @@ def fit_noise_model(
return noise_fit_stats


def fit_noise_model_subscan(
aman,
signal=None,
f=None,
pxx=None,
psdargs={},
fwhite=(10, 100),
lowf=1,
merge_fit=False,
f_max=100,
merge_name="noise_fit_stats_ss",
merge_psd=True,
freq_spacing=None,
approx_fit=False,
):
"""
Fits noise model with white and 1/f noise to the PSD of signal subscans.
Args are as for fit_noise_model.
"""
fitout = np.empty((aman.dets.count, 3, aman.subscans.subscans.count))
covout = np.empty((aman.dets.count, 3, 3, aman.subscans.subscans.count))

if signal is None:
signal = aman.signal

if f is None or pxx is None:
f, pxx = calc_psd_subscan(
aman,
signal=signal,
freq_spacing=freq_spacing,
merge=merge_psd,
**psdargs,
)

for isub in range(aman.subscans.subscans.count):
if np.all(np.isnan(pxx[...,isub])): # Subscan has been fully cut
fitout[..., isub] = np.full((aman.dets.count, 3), np.nan)
covout[..., isub] = np.full((aman.dets.count, 3, 3), np.nan)
else:
noise_model = fit_noise_model(aman, f=f, pxx=pxx[...,isub], fwhite=fwhite, lowf=lowf,
merge_fit=False, f_max=f_max, merge_psd=False, approx_fit=approx_fit)

fitout[..., isub] = noise_model.fit
covout[..., isub] = noise_model.cov

noise_fit_stats = core.AxisManager(
aman.dets,
noise_model.noise_model_coeffs,
aman.subscans.subscans
)
noise_fit_stats.wrap("fit", fitout, [(0, "dets"), (1, "noise_model_coeffs"), (2, "subscans")])
noise_fit_stats.wrap(
"cov",
covout,
[(0, "dets"), (1, "noise_model_coeffs"), (2, "noise_model_coeffs"), (3, "subscans")],
)

if merge_fit:
aman.wrap(merge_name, noise_fit_stats)
return noise_fit_stats


def build_hpf_params_dict(
filter_name,
noise_fit=None,
Expand Down
52 changes: 43 additions & 9 deletions sotodlib/tod_ops/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_det_bias_flags(aman, detcal=None, rfrac_range=(0.1, 0.7),
def get_turnaround_flags(aman, az=None, method='scanspeed', name='turnarounds',
merge=True, merge_lr=True, overwrite=True,
t_buffer=2., kernel_size=400, peak_threshold=0.1, rel_distance_peaks=0.3,
truncate=False, qlim=1):
truncate=False, qlim=1, merge_subscans=True, turnarounds_in_subscan=False):
"""
Compute turnaround flags for a dataset.
Expand Down Expand Up @@ -172,6 +172,10 @@ def get_turnaround_flags(aman, az=None, method='scanspeed', name='turnarounds',
(Optional). Truncate unstable scan segments if True in ``scanspeed`` method.
qlim : float
(Optional). Azimuth threshold percentile for ``az`` method turnaround detection.
merge_subscans : bool
(Optional). Also merge an AxisManager with subscan information.
turnarounds_in_subscan : bool
(Optional). Turnarounds are included as part of a subscan.
Returns
-------
Expand Down Expand Up @@ -299,6 +303,10 @@ def get_turnaround_flags(aman, az=None, method='scanspeed', name='turnarounds',
aman.flags[name] = ta_flag
else:
aman.flags.wrap(name, ta_flag)

if merge_subscans:
get_subscans(aman, merge=True, include_turnarounds=turnarounds_in_subscan)

if method == 'az':
ta_exp = RangesMatrix([ta_flag for i in range(aman.dets.count)])
return ta_exp
Expand Down Expand Up @@ -685,13 +693,22 @@ def get_inv_var_flags(aman, signal_name='signal', nsigma=5,

return mskinvar

def get_subscans(aman, merge=True):
def get_subscans(aman, merge=True, include_turnarounds=False):
"""
Returns an axis manager with information about subscans.
This includes direction, start time, stop time, and a ranges matrix (subscans samps)
True inside each subscan. Subscans are defined excluding turnarounds.
"""
ss_ind = (~aman.flags.turnarounds).ranges() # sliceable indices (first inclusive, last exclusive) for subscans
if not include_turnarounds:
ss_ind = (~aman.flags.turnarounds).ranges() # sliceable indices (first inclusive, last exclusive) for subscans
else:
left = aman.flags.left_scan.ranges()
right = aman.flags.right_scan.ranges()
start_left = 0 if (left[0,0] < right[0,0]) else 1
ss_ind = np.empty((left.shape[0] + right.shape[0], 2), dtype=left.dtype)
ss_ind[start_left::2] = left
ss_ind[(start_left-1)%2::2] = right

start_inds, end_inds = ss_ind.T
n_subscan = ss_ind.shape[0]
tt = aman.timestamps
Expand All @@ -710,7 +727,7 @@ def get_subscans(aman, merge=True):
aman.wrap('subscans', subscan_aman)
return subscan_aman

def get_subscan_signal(aman, arr, isub=None):
def get_subscan_signal(aman, arr, isub=None, trim=False):
"""
Split an array into subscans.
Expand All @@ -724,12 +741,21 @@ def get_subscan_signal(aman, arr, isub=None):
(Optional). Index of the desired subscan. May also be a list of indices.
If None, all are used.
"""
if isinstance(arr, str):
arr = aman[arr]
if np.isscalar(isub):
return apply_rng(arr, aman.subscans.subscan_flags[isub])
out = apply_rng(arr, aman.subscans.subscan_flags[isub])
if trim and out.size == 0:
out = None
else:
if isub is None:
isub = range(len(aman.subscans.subscan_flags))
return [apply_rng(arr, aman.subscans.subscan_flags[ii]) for ii in isub]
out = [apply_rng(arr, aman.subscans.subscan_flags[ii]) for ii in isub]
if trim:
out = [x for x in out if x.size > 0]

return out


def apply_rng(arr, rng):
"""
Expand All @@ -742,7 +768,10 @@ def apply_rng(arr, rng):
rng : Ranges
Ranges object of len (samps) selecting the desired range
"""
slc = slice(*rng.ranges()[0])
if rng.ranges().size == 0:
slc = slice(0,0) # Return an empty array if rng is empty
else:
slc = slice(*np.squeeze(rng.ranges()))
isamps = np.where(np.array(arr.shape) == rng.count)[0][0]
ndslice = tuple((slice(None) if ii != isamps else slc for ii in range(arr.ndim)))
return arr[ndslice]
Expand Down Expand Up @@ -788,18 +817,23 @@ def get_stats(aman, signal, stat_names, split_subscans=False, mask=None, name="s
fn_dict = {'mean': np.mean, 'median': np.median, 'ptp': np.ptp, 'std': np.std,
'kurtosis': stats.kurtosis, 'skew': stats.skew}

if isinstance(signal, str):
signal = aman[signal]
if split_subscans:
if mask is not None:
raise ValueError("Cannot mask samples and split subscans")
stats_arr = []
for iss in range(aman.subscans.subscans.count):
data = get_subscan_signal(aman, signal, iss)
stats_arr.append([fn_dict[name](data, axis=1) for name in stat_names]) # Samps axis assumed to be 1
if data.size > 0:
stats_arr.append([fn_dict[name](data, axis=1) for name in stat_names]) # Samps axis assumed to be 1
else:
stats_arr.append(np.full((len(stat_names), signal.shape[0]), np.nan)) # Add nans if subscan has been entirely cut
stats_arr = np.array(stats_arr).transpose(1, 2, 0) # stat, dets, subscan
else:
if mask is None:
mask = slice(None)
stats_arr = np.array([fn_dict[name](data[:, mask], axis=1) for name in stat_names]) # Samps axis assumed to be 1
stats_arr = np.array([fn_dict[name](signal[:, mask], axis=1) for name in stat_names]) # Samps axis assumed to be 1

info_aman = wrap_info(aman, name, stats_arr, stat_names, merge)
return info_aman

0 comments on commit 67b59b0

Please sign in to comment.