Skip to content

Commit

Permalink
refactor: cv score functions to return inversion results as well as s…
Browse files Browse the repository at this point in the history
…cores
  • Loading branch information
mdtanker committed Aug 2, 2024
1 parent 32a8bdf commit 15e4369
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
23 changes: 12 additions & 11 deletions src/invert4geom/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def grav_cv_score(
rmse_as_median: bool = False,
plot: bool = False,
**kwargs: typing.Any,
) -> float:
) -> tuple[float, tuple[pd.DataFrame, pd.DataFrame, dict[str, typing.Any], float]]:
"""
Find the score, represented by the root mean (or median) squared error (RMSE),
between the testing gravity data, and the predict gravity data after and
Expand All @@ -134,9 +134,9 @@ def grav_cv_score(
Returns
-------
float
tuple[float, tuple[pd.DataFrame, pd.DataFrame, dict[str, typing.Any], float]]
a score, represented by the root mean squared error, between the testing gravity
data and the predicted gravity data.
data and the predicted gravity data, and a tuple of the inversion results.
References
----------
Expand Down Expand Up @@ -222,7 +222,7 @@ def grav_cv_score(
rmse_in_title=False,
)

return score
return score, results


@deprecation.deprecated( # type: ignore[misc]
Expand Down Expand Up @@ -316,7 +316,7 @@ def grav_optimal_parameter(
# update parameter value in kwargs
kwargs[param_name] = value
# run cross validation
score = grav_cv_score(
score, _ = grav_cv_score(
training_data=train,
testing_data=test,
rmse_as_median=rmse_as_median,
Expand Down Expand Up @@ -401,7 +401,7 @@ def constraints_cv_score(
constraints_df: pd.DataFrame,
rmse_as_median: bool = False,
**kwargs: typing.Any,
) -> float:
) -> tuple[float, tuple[pd.DataFrame, pd.DataFrame, dict[str, typing.Any], float]]:
"""
Find the score, represented by the root mean squared error (RMSE), between the
constraint point elevation, and the inverted topography at the constraint points.
Expand All @@ -419,9 +419,10 @@ def constraints_cv_score(
False
Returns
-------
float
a score, represented by the root mean squared error, between the testing gravity
data and the predicted gravity data.
tuple[float, tuple[pd.DataFrame, pd.DataFrame, dict[str, typing.Any], float]]
a score, represented by the root mean squared error, between the constraint
point elevation and the inverted topography at the constraint points, and a
tuple of the inversion results.
References
----------
Expand Down Expand Up @@ -459,7 +460,7 @@ def constraints_cv_score(

dif = constraints_df.upward - constraints_df.inverted_topo

return utils.rmse(dif, as_median=rmse_as_median)
return utils.rmse(dif, as_median=rmse_as_median), results


# pylint: disable=duplicate-code
Expand Down Expand Up @@ -693,7 +694,7 @@ def zref_density_optimal_parameter(
]
}
# run cross validation
score = constraints_cv_score(
score, _ = constraints_cv_score(
grav_df=grav_df,
constraints_df=constraints_df,
results_fname=f"{results_fname}_trial_{i}",
Expand Down
9 changes: 7 additions & 2 deletions src/invert4geom/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,14 +521,18 @@ def __call__(self, trial: optuna.trial) -> float:

trial.set_user_attr("fname", f"{self.fname}_trial_{trial.number}")

return cross_validation.grav_cv_score(
score, results = cross_validation.grav_cv_score(
solver_damping=damping,
progressbar=False,
results_fname=trial.user_attrs.get("fname"),
plot=self.plot_grids,
**new_kwargs,
)

trial.set_user_attr("results", results)

return score


def optimize_inversion_damping(
training_df: pd.DataFrame,
Expand Down Expand Up @@ -901,12 +905,13 @@ def __call__(self, trial: optuna.trial) -> float:
trial.set_user_attr("fname", f"{self.fname}_trial_{trial.number}")

# run cross validation
return cross_validation.constraints_cv_score(
score, _ = cross_validation.constraints_cv_score(
grav_df=grav_df,
constraints_df=self.constraints_df,
results_fname=trial.user_attrs.get("fname"),
**new_kwargs,
)
return score


def optimize_inversion_zref_density_contrast(
Expand Down

0 comments on commit 15e4369

Please sign in to comment.