Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gp/fix/act flags #947

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sotodlib/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
import logging
import numpy as np

from typing import Union, Dict, Tuple, List
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to context.py are just noise; please revert.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will. I put them when I was trying to understand what was happening.


from . import metadata
from .util import tag_substr
from .axisman import AxisManager, OffsetAxis, AxisInterface

logger = logging.getLogger(__name__)


class Context(odict):
# Sets of special handlers may be registered in this class variable, then
# requested by name in the context.yaml key "context_hooks".
Expand Down Expand Up @@ -322,7 +325,8 @@ def get_meta(self,
check=False,
ignore_missing=False,
on_missing=None,
det_info_scan=False):
det_info_scan=False
):
"""Load supporting metadata for an observation and return it in an
AxisManager.

Expand Down
2 changes: 1 addition & 1 deletion sotodlib/core/g3_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

"""

from spt3g import core
from so3g.spt3g import core
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an import bug.



class DataG3Module(object):
Expand Down
66 changes: 42 additions & 24 deletions sotodlib/mapmaking/ml_mapmaker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import numpy as np
from pixell import enmap, utils, tilemap, bunch
import h5py
import so3g
from typing import Optional
from pixell import bunch, enmap, tilemap
from pixell import utils as putils

from .. import coords
from .utilities import *
from .pointing_matrix import *
from .pointing_matrix import PmatCut
from .utilities import (MultiZipper, get_flags_from_path, recentering_to_quat_lonlat,
evaluate_recentering, TileMapZipper, MapZipper,
safe_invert_div, unarr, ArrayZipper)
from .noise_model import NmatUncorr


class MLMapmaker:
def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False):
def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False,
glitch_flags:str = "flags.glitch_flags"):
"""Initialize a Maximum Likelihood Mapmaker.
Arguments:
* signals: List of Signal-objects representing the models that will be solved
Expand All @@ -26,6 +35,7 @@ def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False
self.data = []
self.dof = MultiZipper()
self.ready = False
self.glitch_flags_path = glitch_flags

def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None):
# Prepare our tod
Expand All @@ -36,7 +46,7 @@ def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None)
# the noise model, if available
if signal_estimate is not None: tod -= signal_estimate
if deslope:
utils.deslope(tod, w=5, inplace=True)
putils.deslope(tod, w=5, inplace=True)
# Allow the user to override the noise model on a per-obs level
if noise_model is None: noise_model = self.noise_model
# Build the noise model from the obs unless a fully
Expand All @@ -55,12 +65,12 @@ def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None)
# The signal estimate might not be desloped, so
# adding it back can reintroduce a slope. Fix that here.
if deslope:
utils.deslope(tod, w=5, inplace=True)
putils.deslope(tod, w=5, inplace=True)
# And apply it to the tod
tod = nmat.apply(tod)
# Add the observation to each of our signals
for signal in self.signals:
signal.add_obs(id, obs, nmat, tod)
signal.add_obs(id, obs, nmat, tod, glitch_flags=self.glitch_flags_path)
# Save what we need about this observation
self.data.append(bunch.Bunch(id=id, ndet=obs.dets.count, nsamp=len(ctime),
dets=obs.dets.vals, nmat=nmat))
Expand Down Expand Up @@ -119,7 +129,7 @@ def solve(self, maxiter=500, maxerr=1e-6, x0=None):
self.prepare()
rhs = self.dof.zip(*[signal.rhs for signal in self.signals])
if x0 is not None: x0 = self.dof.zip(*x0)
solver = utils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0)
solver = putils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0)
while solver.i < maxiter and solver.err > maxerr:
solver.step()
yield bunch.Bunch(i=solver.i, err=solver.err, x=self.dof.unzip(solver.x))
Expand All @@ -146,22 +156,25 @@ def transeval(self, id, obs, other, x, tod=None):

class Signal:
"""This class represents a thing we want to solve for, e.g. the sky, ground, cut samples, etc."""
def __init__(self, name, ofmt, output, ext):
def __init__(self, name, ofmt, output, ext, **kwargs):
"""Initialize a Signal. It probably doesn't make sense to construct a generic signal
directly, though. Use one of the subclasses.
Arguments:
* name: The name of this signal, e.g. "sky", "cut", etc.
* ofmt: The format used when constructing output file prefix
* output: Whether this signal should be part of the output or not.
* ext: The extension used for the files.
* **kwargs: additional keyword based parameters, accessible as class parameters
"""
self.name = name
self.ofmt = ofmt
self.output = output
self.ext = ext
self.dof = None
self.ready = False
def add_obs(self, id, obs, nmat, Nd): pass
self.__dict__.update(kwargs)

def add_obs(self, id, obs, nmat, Nd, **kwargs): pass
def prepare(self): self.ready = True
def forward (self, id, tod, x): pass
def backward(self, id, tod, x): pass
Expand All @@ -176,12 +189,12 @@ class SignalMap(Signal):
"""Signal describing a non-distributed sky map."""
def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", output=True,
ext="fits", dtype=np.float32, sys=None, recenter=None, tile_shape=(500,500), tiled=False,
interpol=None):
interpol=None, glitch_flags: str = "flags.glitch_flags"):
"""Signal describing a sky map in the coordinate system given by "sys", which defaults
to equatorial coordinates. If tiled==True, then this will be a distributed map with
the given tile_shape, otherwise it will be a plain enmap. interpol controls the
pointing matrix interpolation mode. See so3g's Projectionist docstring for details."""
Signal.__init__(self, name, ofmt, output, ext)
Signal.__init__(self, name, ofmt, output, ext, glitch_flags=glitch_flags)
self.comm = comm
self.comps = comps
self.sys = sys
Expand All @@ -192,6 +205,7 @@ def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", out
self.data = {}
ncomp = len(comps)
shape = tuple(shape[-2:])

if tiled:
geo = tilemap.geometry(shape, wcs, tile_shape=tile_shape)
self.rhs = tilemap.zeros(geo.copy(pre=(ncomp,)), dtype=dtype)
Expand All @@ -202,15 +216,16 @@ def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", out
self.div = enmap.zeros((ncomp,ncomp)+shape, wcs, dtype=dtype)
self.hits= enmap.zeros( shape, wcs, dtype=dtype)

def add_obs(self, id, obs, nmat, Nd, pmap=None):
def add_obs(self, id, obs, nmat, Nd, pmap=None, glitch_flags: Optional[str] = None):
"""Add and process an observation, building the pointing matrix
and our part of the RHS. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data.
"""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
ctime = obs.timestamps
pcut = PmatCut(obs.flags.glitch_flags) # could pass this in, but fast to construct
gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
pcut = PmatCut(get_flags_from_path(obs, gflags)) # could pass this in, but fast to construct
if pmap is None:
# Build the local geometry and pointing matrix for this observation
if self.recenter:
Expand Down Expand Up @@ -261,9 +276,9 @@ def prepare(self):
self.dof = TileMapZipper(self.rhs.geometry, dtype=self.dtype, comm=self.comm)
else:
if self.comm is not None:
self.rhs = utils.allreduce(self.rhs, self.comm)
self.div = utils.allreduce(self.div, self.comm)
self.hits = utils.allreduce(self.hits, self.comm)
self.rhs = putils.allreduce(self.rhs, self.comm)
self.div = putils.allreduce(self.div, self.comm)
self.hits = putils.allreduce(self.hits, self.comm)
self.dof = MapZipper(*self.rhs.geometry, dtype=self.dtype)
self.idiv = safe_invert_div(self.div)
self.ready = True
Expand Down Expand Up @@ -300,7 +315,7 @@ def from_work(self, map):
return tilemap.redistribute(map, self.comm, self.rhs.geometry.active)
else:
if self.comm is None: return map
else: return utils.allreduce(map, self.comm)
else: return putils.allreduce(map, self.comm)

def write(self, prefix, tag, m):
if not self.output: return
Expand Down Expand Up @@ -347,6 +362,7 @@ def transeval(self, id, obs, other, map, tod):
# Currently we don't support any actual translation, but could handle
# resolution changes in the future (probably not useful though)
self._checkcompat(other)
ctime = obs.timestamps
# Build the local geometry and pointing matrix for this observation
if self.recenter:
rot = recentering_to_quat_lonlat(*evaluate_recentering(self.recenter,
Expand All @@ -361,9 +377,9 @@ def transeval(self, id, obs, other, map, tod):

class SignalCut(Signal):
def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
output=False, cut_type=None):
output=False, cut_type=None, glitch_flags: str = "flags.glitch_flags"):
"""Signal for handling the ML solution for the values of the cut samples."""
Signal.__init__(self, name, ofmt, output, ext="hdf")
Signal.__init__(self, name, ofmt, output, ext="hdf", glitch_flags=glitch_flags)
self.comm = comm
self.data = {}
self.dtype = dtype
Expand All @@ -372,12 +388,13 @@ def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
self.rhs = []
self.div = []

def add_obs(self, id, obs, nmat, Nd):
def add_obs(self, id, obs, nmat, Nd, glitch_flags: Optional[str] = None):
"""Add and process an observation. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data."""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
pcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type)
gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
pcut = PmatCut(get_flags_from_path(obs, gflags), model=self.cut_type)
# Build our RHS
obs_rhs = np.zeros(pcut.njunk, self.dtype)
pcut.backward(Nd, obs_rhs)
Expand Down Expand Up @@ -441,15 +458,16 @@ def translate(self, other, junk):
so3g.translate_cuts(odata.pcut.cuts, sdata.pcut.cuts, sdata.pcut.model, sdata.pcut.params, junk[odata.i1:odata.i2], res[sdata.i1:sdata.i2])
return res

def transeval(self, id, obs, other, junk, tod):
def transeval(self, id, obs, other, junk, tod, glitch_flags: Optional[str] = None):
"""Translate data junk from SignalCut other to the current SignalCut,
and then evaluate it for the given observation, returning a tod.
This is used when building a signal-free tod for the noise model
in multipass mapmaking."""
self._checkcompat(other)
# We have to make a pointing matrix from scratch because add_obs
# won't have been called yet at this point
spcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type)
gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
spcut = PmatCut(get_flags_from_path(obs, gflags), model=self.cut_type)
# We do have one for other though, since that will be the output
# from the previous round of multiplass mapmaking.
odata = other.data[id]
Expand Down
Loading