Skip to content

Commit

Permalink
feat: add fname kwarg to some plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Nov 19, 2024
1 parent 2607284 commit 155f645
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def plot_cv_scores(
param_name: str = "Hyperparameter",
figsize: tuple[float, float] = (5, 3.5),
plot_title: str | None = None,
fname: str | None = None,
) -> None:
"""
plot a graph of cross-validation scores vs hyperparameter values
Expand All @@ -226,6 +227,8 @@ def plot_cv_scores(
size of the figure, by default (5, 3.5)
plot_title : str | None, optional
title of figure, by default None
fname : str | None, optional
filename to save figure, by default None
"""

sns.set_theme()
Expand Down Expand Up @@ -259,12 +262,16 @@ def plot_cv_scores(

plt.tight_layout()

if fname is not None:
plt.savefig(fname)


def plot_convergence(
results: pd.DataFrame,
params: dict[str, typing.Any],
inversion_region: tuple[float, float, float, float] | None = None,
figsize: tuple[float, float] = (5, 3.5),
fname: str | None = None,
) -> None:
"""
plot a graph of L2-norm and delta L2-norm vs iteration number.
Expand All @@ -279,6 +286,8 @@ def plot_convergence(
inside region of inversion, by default None
figsize : tuple[float, float], optional
width and height of figure, by default (5, 3.5)
fname : str | None, optional
filename to save figure, by default None
"""

sns.set_theme()
Expand Down Expand Up @@ -357,6 +366,10 @@ def plot_convergence(

plt.title("Inversion convergence")
plt.tight_layout()

if fname is not None:
plt.savefig(fname)

plt.show()


Expand Down

0 comments on commit 155f645

Please sign in to comment.