Skip to content

Commit

Permalink
Clarify documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-a-cox committed Dec 20, 2023
1 parent 7042c41 commit 3c02e14
Showing 1 changed file with 45 additions and 27 deletions.
72 changes: 45 additions & 27 deletions hera_cal/nucal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,12 +1068,18 @@ def _foreground_model(model_parameters, spectral_filters, spatial_filters):
@jax.jit
def _mean_squared_error(model_parameters, data_r, data_i, wgts, fg_model_r, fg_model_i, idealized_blvecs):
"""
Computes the mean squared error between the data and foreground model multiplied by the degenerate parameters
Computes the mean squared error between the data and foreground model multiplied by the
redundant calibration degenerate parameters. Used as the loss function in the gradient descent
in SpectralRedundantCalibrator.post_redcal to solve for the redundant calibration degrees of freedom
Parameters:
----------
model_parameters : dictionary
Parameters for fitting
Parameters used to fit the DPSS-based foreground model and redundant calibration degeneracies.
Keys are "fg_r", "fg_i", "amplitude", and "tip_tilt". Parameters "fg_r" and "fg_i" are the
real and imaginary components of the DPSS foreground model but are not used in this function.
"amplitude" is redundant calibration amplitude degeneracy and "tip_tilt" are the redundant
calibration phase gradient degeneracies.
data_r : np.ndarray
Array of real component of data with shape (Ntimes, Nbls)
data_i : np.ndarray
Expand Down Expand Up @@ -1155,17 +1161,22 @@ def _nucal_post_redcal(
wgts : np.ndarray
Array of weights with shape (Ntimes, Nbls)
model_parameters : dictionary
Parameters for fitting
Parameters used to fit the DPSS-based foreground model and redundant calibration degeneracies.
Keys are "fg_r", "fg_i", "amplitude", and "tip_tilt". Parameters "fg_r" and "fg_i" are the
real and imaginary components of the DPSS foreground model but are not used in this function.
"amplitude" is redundant calibration amplitude degeneracy and "tip_tilt" are the redundant
calibration phase gradient degeneracies.
optimizer : optax optimizer
Optimizer to use for gradient descent
Optimizer to use for gradient descent.
spectral_filters : np.ndarray
Array of spectral filters with shape (Nfreqs, Nfilters)
spatial_filters : List
List of spatial filters for each baseline in the group
idealized_blvecs : np.ndarray
Array of idealized baseline vectors with shape (Nbls, Ndims)
maxiter : int, optional, default=100
Maximum number of iterations to perform
Maximum number of iterations to perform in the gradient descent loop. If convergence_criteria is not met
after maxiter iterations, the optimization will stop.
tol : float, optional, default=1e-10
Tolerance for stopping criterion. If the difference of the loss between two iterations is less than tol,
the optimization will stop.
Expand Down Expand Up @@ -1250,23 +1261,29 @@ def __init__(self, radial_reds):
RadialRedundancy object containing a list of list baseline tuples of radially redundant
groups. Can be generated using nucal.RadialRedundancy.
"""
# Store the radial redundancy object and antpos
self.radial_reds = radial_reds
self.antpos = radial_reds.antpos

# Initialize variables for tracking whether filters have been computed
self._filters_computed = False
self._most_recent_filter_params = {}

def _compute_filters(self, freqs, spectral_filter_half_width, spatial_filter_half_width=1, eigenval_cutoff=1e-12, umin=None, umax=None):
"""
"""
# Get all parameter names and local variables
sig = inspect.signature(self._compute_filters)
local_vars = locals()
local_vars.pop("self")

if self._filters_computed:
for key in sig.parameters:
if not np.array_equal(local_vars[key], getattr(self, key)):
# Loop over all parameters and check if they have changed
for key in local_vars:
if not np.array_equal(local_vars[key], self._most_recent_filter_params[key]):
recompute_filters = True
break
else:
# if for loop reaches completion, parameters haven't changed -- return
return

if recompute_filters:
Expand All @@ -1275,9 +1292,9 @@ def _compute_filters(self, freqs, spectral_filter_half_width, spatial_filter_hal
self.radial_reds, freqs, spatial_filter_half_width, eigenval_cutoff=eigenval_cutoff, umin=umin, umax=umax
)

# Set most recent parameters
for key in sig.parameters:
setattr(self, key, local_vars[key])
# Set most recent set of filter parameters
for key in local_vars:
self._most_recent_filter_params[key] = local_vars[key]

self._filters_computed = True

Expand All @@ -1288,8 +1305,8 @@ def _compute_filters(self, freqs, spectral_filter_half_width, spatial_filter_hal
)

# Set most recent parameters
for key in sig.parameters:
setattr(self, key, local_vars[key])
for key in local_vars:
self._most_recent_filter_params[key] = local_vars[key]

self._filters_computed = True

Expand Down Expand Up @@ -1323,7 +1340,8 @@ def _estimate_degeneracies(self, data, model, wgts):
)

# Unpack solution into dictionary
# Degeneracy as written in gradient descent is exp(2 * eta)
# Degeneracy as written in gradient descent is exp(2 * eta) because the amplitude degeneracy
# in nucal is written as the square of the amplitude degeneracy in the abscal solution
amplitude = {
pol: np.exp(2 * amp_sol[f"eta_J{pol}"]) for pol in data.pols()
}
Expand All @@ -1344,7 +1362,7 @@ def _estimate_degeneracies(self, data, model, wgts):

return amplitude, tip_tilt

def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only=True, linear_solver="lu_solve", linear_tol=1e-12, share_fg_model=False,
def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only=False, linear_solver="lu_solve", linear_tol=1e-12, share_fg_model=False,
spectral_filter_half_width=30e-9, spatial_filter_half_width=1, eigenval_cutoff=1e-12, umin=None, umax=None, estimate_degeneracies=False,
optimizer_name='adabelief', learning_rate=1e-3, maxiter=100, minor_cycle_maxiter=0, convergence_criteria=1e-10, return_model=False):
"""
Expand All @@ -1357,14 +1375,14 @@ def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only
Parameters:
----------
data : DataContainer
data : DataContainer (or RedDataContainer)
Data to be calibrated. Data is assumed to be redundantly averaged. DataContainer is of the form {(ant1, ant2, pol): np.array([Ntimes, Nfreqs])}
data_wgts : DataContainer
data_wgts : DataContainer (or RedDataContainer)
Weights associated with data. DataContainer is of the form {(ant1, ant2, pol): np.array([Ntimes, Nfreqs])}
cal_flags : dictionary, default={}
Dictionary containing flags for each antenna. Keys are antenna numbers and values are boolean arrays of shape (Ntimes,).
Dictionary containing flags for each antenna. Keys are antenna numbers and values are boolean arrays of shape (Ntimes, Nfreqs).
This dictionary is primarily used for computing the idealized antenna positions.
spatial_estimate_only : bool, default="False"
spatial_estimate_only : bool, default=False
If True, the initial estimate of the foreground model will be computed from the data assuming that the evolution foreground model
is entirely restricted to the spatial axis. This estimate will then be projected onto the eigenmodes of the spectral DPSS modes
for refinement in the gradient descent step. If False, the initial estimate of the foreground model will be computed from the
Expand All @@ -1380,7 +1398,11 @@ def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only
Regularization parameter for linear least-squares fit when computing the initial estimate of the foreground model.
share_fg_model : bool, default=False
If True, the foreground model for each radially-redundant group is shared across the time axis for both the least-squares and
gradient descent steps.
gradient descent steps. One useful application of this option is when performing calibration of data across multiple nights at
the same LST where the data have shape (N_nights, Nfreqs). In this case, the foreground model is expected to be the same across nights, so sharing the foreground model
across nights can greatly reduce the number of parameters to fit. This parameter could also be used for subsequent times to share a
sky model assuming the sky doesn't evolve much in the subsequent integrations. If False, a nucal foreground will be solved for independently
for each time integration.
spectral_filter_half_width : float, default=20e-9
Fourier half-width of the spectral axis DPSS filters in units of seconds.
spatial_filter_half_width : float, default=1
Expand All @@ -1402,7 +1424,7 @@ def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only
initialized to 1 and tip-tilt degeneracies will be initialized to 0. If the data are well-calibrated,
setting this option to False can improve the runtime of the calibration.
optimizer_name : str, default="adabelief"
Name of the optimizer to use for gradient descent. Options are listed in the "OPTIMIZERS" variable.
Name of the optimizer to use for gradient descent. Options are keys in nucal.OPTIMIZERS.
learning_rate : float, default=1e-3
Learning rate for the gradient descent optimizer
maxiter : int, default=100
Expand Down Expand Up @@ -1504,22 +1526,18 @@ def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only
]

# Compute idealized baseline vectors
idealized_blvecs = np.array([
idealized_blvecs = jnp.array([
idealized_antpos[blkey[1]] - idealized_antpos[blkey[0]]
for rdgrp in self.radial_reds.get_pol(pol) for blkey in rdgrp
])

# Run optimization
_model_parameters, _metadata = _nucal_post_redcal(
model_parameters[pol], metadata[pol] = _nucal_post_redcal(
data_real, data_imag, wgts, init_model_parameters, optimizer, spectral_filters=self.spectral_filters,
spatial_filters=spatial_filters, idealized_blvecs=idealized_blvecs, maxiter=maxiter, convergence_criteria=convergence_criteria,
minor_cycle_maxiter=minor_cycle_maxiter
)

# Pack model parameters and metadata
metadata[pol] = _metadata
model_parameters[pol] = _model_parameters

if return_model:
# Compute the foreground model from the model parameters
fg_model_comps = {
Expand Down

0 comments on commit 3c02e14

Please sign in to comment.