Skip to content

Commit

Permalink
feat: add function regional_constraints_cv to automatically perform…
Browse files Browse the repository at this point in the history
… a K-Folds cv for finding optimal parameter values for constraint point minimization
  • Loading branch information
mdtanker committed Aug 2, 2024
1 parent 4982cc3 commit 06ebe79
Showing 1 changed file with 56 additions and 1 deletion.
57 changes: 56 additions & 1 deletion src/invert4geom/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,54 @@ def regional_constraints(
return grav_df


def regional_constraints_cv(
constraints_df: pd.DataFrame,
n_trials: int,
remove_starting_grav_mean: bool = False,
split_method: str = "KFold",
split_spacing: float | tuple[float, float] | None = None,
split_shape: tuple[float, float] | None = None,
n_splits: int = 5,
random_state: int = 10,
**kwargs: typing.Any,
) -> pd.DataFrame:
"""
separate the regional field by sampling and re-gridding at the constraint points
using cross-validation to find the best parameters
"""
df = constraints_df.copy()
# print("DF1\n", df)
df = df[df.columns.drop(list(df.filter(regex="fold_")))]

testing_training_df = cross_validation.split_test_train(
df,
method=split_method,
spacing=split_spacing,
shape=split_shape,
n_splits=n_splits,
random_state=random_state,
plot=False,
)
# print("DF2\n", testing_training_df)

log.addFilter(log_filter)

_, grav_df, _ = optimization.optimize_regional_constraint_point_minimization_kfolds(
testing_training_df=testing_training_df,
n_trials=n_trials,
plot=False,
plot_grid=False,
fold_progressbar=False,
separate_metrics=True,
remove_starting_grav_mean=remove_starting_grav_mean,
progressbar=False,
**kwargs,
)
log.removeFilter(log_filter)

return grav_df


def regional_separation(
method: str,
grav_df: pd.DataFrame,
Expand All @@ -571,7 +619,7 @@ def regional_separation(
----------
method : str
choose method to apply; one of "constant", "filter", "trend",
"eq_sources", "constraints".
"eq_sources", "constraints" or "constraints_cv".
grav_df : pd.DataFrame
gravity data with columns "easting", "northing", "gravity_anomaly", and
"starting_gravity".
Expand Down Expand Up @@ -634,5 +682,12 @@ def regional_separation(
regional_shift=regional_shift,
**kwargs,
)
if method == "constraints_cv":
return regional_constraints_cv(
grav_df=grav_df,
remove_starting_grav_mean=remove_starting_grav_mean,
**kwargs,
)

msg = "invalid string for regional method"
raise ValueError(msg)

0 comments on commit 06ebe79

Please sign in to comment.