Skip to content

Commit

Permalink
changed default optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-a-cox committed Feb 5, 2024
1 parent b0e71db commit a509e52
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions hera_cal/nucal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
# Approved Optax Optimizers
OPTIMIZERS = {
'adabelief': optax.adabelief, 'adafactor': optax.adafactor, 'adagrad': optax.adagrad, 'adam': optax.adam,
'adamw': optax.adamw, 'fromage': optax.fromage, 'lamb': optax.lamb, 'lars': optax.lars,
'adamax': optax.adamax, 'adamaxw': optax.adamaxw, 'amsgrad': optax.amsgrad, 'adamw': optax.adamw,
'fromage': optax.fromage, 'lamb': optax.lamb, 'lars': optax.lars, 'lion': optax.lion, 'novograd': optax.novograd,
'noisy_sgd': optax.noisy_sgd, 'dpsgd': optax.dpsgd, 'radam': optax.radam, 'rmsprop': optax.rmsprop,
'sgd': optax.sgd, 'sm3': optax.sm3, 'yogi': optax.yogi
'sgd': optax.sgd, 'sm3': optax.sm3, 'yogi': optax.yogi, 'optimistic_gradient_descent': optax.optimistic_gradient_descent
}

# Constants
Expand Down Expand Up @@ -874,7 +875,7 @@ def evaluate_foreground_model(radial_reds, fg_model_comps, spatial_filters, spec
return RedDataContainer(model, radial_reds.reds)


def fit_nucal_foreground_model(data, data_wgts, radial_reds, spatial_filters, spectral_filters=None, tol=1e-15,
def fit_nucal_foreground_model(data, data_wgts, radial_reds, spatial_filters, spectral_filters=None, tol=1e-12,
share_fg_model=False, return_model_comps=False, solver="lu_solve"):
"""
Compute a foreground model for a set of radially redundant baselines. The model is computed by performing a linear
Expand Down Expand Up @@ -1110,7 +1111,7 @@ def _mean_squared_error(model_parameters, data_r, data_i, wgts, fg_model_r, fg_m
return jnp.sum((jnp.square(model_r - data_r) + jnp.square(model_i - data_i)) * wgts)

@jax.jit
def _calibration_loss_function(model_parameters, data_r, data_i, wgts, spectral_filters, spatial_filters, idealized_blvecs):
def _calibration_loss_function(model_parameters, data_r, data_i, wgts, spectral_filters, spatial_filters, idealized_blvecs, alpha=1e-12):
"""
Function which computes the value of the loss from the degenerate parameters, DPSS foreground components, and the data
Expand Down Expand Up @@ -1139,12 +1140,17 @@ def _calibration_loss_function(model_parameters, data_r, data_i, wgts, spectral_
# Compute foreground model from the model_parameters and DPSS filters
fg_model_r, fg_model_i = _foreground_model(model_parameters, spectral_filters, spatial_filters)

# Regularize the loss with a sum-of-squares penalty on the foreground model components, scaled by alpha
param_loss = 0
for fgr, fgi in zip(model_parameters['fg_r'], model_parameters['fg_i']):
param_loss += (jnp.square(fgr).sum() + jnp.square(fgi).sum()) * alpha

# Compute loss
return _mean_squared_error(model_parameters, data_r, data_i, wgts, fg_model_r, fg_model_i, idealized_blvecs)
return _mean_squared_error(model_parameters, data_r, data_i, wgts, fg_model_r, fg_model_i, idealized_blvecs) + param_loss

def _nucal_post_redcal(
data_r, data_i, wgts, model_parameters, optimizer, spectral_filters, spatial_filters, idealized_blvecs,
major_cycle_maxiter=100, convergence_criteria=1e-10, minor_cycle_maxiter=10
major_cycle_maxiter=100, convergence_criteria=1e-10, minor_cycle_maxiter=10, alpha=1e-12
):
"""
Function to perform frequency redundant calibration using gradient descent. Calibrates the
Expand Down Expand Up @@ -1199,13 +1205,13 @@ def _nucal_post_redcal(

# Initialize variables used in calibration loop
losses = []
previous_loss = np.inf

# Start gradient descent
for step in range(major_cycle_maxiter):
# Compute loss and gradient
loss, gradient = jax.value_and_grad(_calibration_loss_function)(
model_parameters, data_r, data_i, wgts, spectral_filters=spectral_filters, spatial_filters=spatial_filters, idealized_blvecs=idealized_blvecs
model_parameters, data_r, data_i, wgts, spectral_filters=spectral_filters, spatial_filters=spatial_filters,
idealized_blvecs=idealized_blvecs, alpha=alpha
)
# Update optimizer state and parameters
updates, opt_state = optimizer.update(gradient, opt_state, model_parameters)
Expand All @@ -1220,7 +1226,7 @@ def _nucal_post_redcal(
# Since the foreground model is fixed, we can just use the _mean_squared_error
# function as our loss function
minor_cycle_loss, gradient = jax.value_and_grad(_mean_squared_error)(
model_parameters, data_r, data_i, wgts, fg_model_r, fg_model_i, idealized_blvecs=idealized_blvecs
model_parameters, data_r, data_i, wgts, fg_model_r, fg_model_i, idealized_blvecs=idealized_blvecs,
)
# Update optimizer state and parameters
updates, opt_state = optimizer.update(gradient, opt_state, model_parameters)
Expand Down Expand Up @@ -1370,7 +1376,7 @@ def _estimate_degeneracies(self, data, model, wgts):

def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only=False, linear_solver="lu_solve", linear_tol=1e-12, share_fg_model=False,
spectral_filter_half_width=30e-9, spatial_filter_half_width=1, eigenval_cutoff=1e-12, umin=None, umax=None, estimate_degeneracies=False,
optimizer_name='adabelief', learning_rate=1e-3, major_cycle_maxiter=100, minor_cycle_maxiter=0, convergence_criteria=1e-10, return_model=False):
optimizer_name='novograd', learning_rate=1e-3, major_cycle_maxiter=100, minor_cycle_maxiter=0, convergence_criteria=1e-10, return_model=False):
"""
Estimates redundant calibration degeneracies by building a DPSS-based sky model and solving for the parameters which lead to the smoothest
calibrated visibilities. Function starts by estimating a sky-model by using DPSS filters (which can start with spatial dependence or spectral
Expand Down Expand Up @@ -1402,6 +1408,7 @@ def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only
robust.
linear_tol : float, default=1e-12
Regularization parameter for linear least-squares fit when computing the initial estimate of the foreground model.
The same value is also used as the regularization strength (alpha) applied to the foreground model parameters in the gradient descent step.
share_fg_model : bool, default=False
If True, the foreground model for each radially-redundant group is shared across the time axis for both the least-squares and
gradient descent steps. One useful application of this option is when performing calibration of data across multiple nights at
Expand Down Expand Up @@ -1429,7 +1436,7 @@ def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only
abscal techniques and the initial nucal model as the sky model. If False, the amplitude degeneracies will be
initialized to 1 and tip-tilt degeneracies will be initialized to 0. If the data are well-calibrated,
setting this option to False can improve the runtime of the calibration.
optimizer_name : str, default="adabelief"
optimizer_name : str, default="novograd"
Name of the optimizer to use for gradient descent. Options are keys in nucal.OPTIMIZERS.
learning_rate : float, default=1e-3
Learning rate for the gradient descent optimizer
Expand Down Expand Up @@ -1544,7 +1551,7 @@ def post_redcal_nucal(self, data, data_wgts, cal_flags={}, spatial_estimate_only
model_parameters[pol], metadata[pol] = _nucal_post_redcal(
data_real, data_imag, wgts, init_model_parameters, optimizer, spectral_filters=self.spectral_filters,
spatial_filters=spatial_filters, idealized_blvecs=idealized_blvecs, major_cycle_maxiter=major_cycle_maxiter,
convergence_criteria=convergence_criteria, minor_cycle_maxiter=minor_cycle_maxiter
convergence_criteria=convergence_criteria, minor_cycle_maxiter=minor_cycle_maxiter, alpha=linear_tol
)

if return_model:
Expand Down

0 comments on commit a509e52

Please sign in to comment.