Skip to content

Commit

Permalink
feat: add plotting functions for latin hypercube sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Nov 19, 2024
1 parent ef5239e commit 2607284
Showing 1 changed file with 99 additions and 0 deletions.
99 changes: 99 additions & 0 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import verde as vd
import xarray as xr
from IPython.display import clear_output
from numpy.typing import NDArray
from polartoolkit import maps
from polartoolkit import utils as polar_utils

Expand Down Expand Up @@ -1437,3 +1438,101 @@ def remove_df_from_hoverdata(
text = tuple(text) # type: ignore[assignment]

return plot.update_traces(text=text)


def plot_latin_hypercube(
    params_dict: dict[str, dict[str, typing.Any]],
    plot_individual_dists: bool = True,
    plot_2d_projections: bool = True,
) -> None:
    """
    Plot sampled parameter values as individual distributions and / or as the 2D
    projections of each parameter pair.

    Parameters
    ----------
    params_dict : dict[str, dict[str, typing.Any]]
        dictionary of sampled parameter values, keyed by parameter name. Each value
        is a dictionary which must contain the key "sampled_values". Can be
        created manually or from the output of :func:`.uncertainty.create_lhc`
    plot_individual_dists : bool, optional
        choose to plot distribution of each parameter, by default True
    plot_2d_projections : bool, optional
        choose to plot the 2D projection of each parameter pair, by default True
    """
    # build one row per parameter, then transpose so each parameter is a column
    df = pd.DataFrame(
        [spec["sampled_values"] for spec in params_dict.values()],
    ).transpose()
    df.columns = params_dict.keys()

    # plot individual variables
    if plot_individual_dists:
        num_params = len(df.columns)

        # squeeze=False makes subplots always return a 2D array of axes, so the
        # indexing below also works when there is only a single parameter
        _, axes = plt.subplots(
            1,
            num_params,
            figsize=(3 * num_params, 1.8),
            squeeze=False,
        )

        for ax, name in zip(axes[0], df.columns):
            sns.kdeplot(ax=ax, data=df, x=name)
            sns.rugplot(ax=ax, data=df, x=name, linewidth=2.5, height=0.07)
            ax.set_xlabel(name.replace("_", " ").capitalize())
            ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
            ax.set_ylabel(None)

        plt.show()

    # nothing below is needed unless the 2D projections are requested
    if not plot_2d_projections:
        return

    dim = np.shape(df)[1]
    names = [name.replace("_", " ") for name in df.columns]
    bounds = [[-1, 1]] * dim

    # Rescale to the unit hypercube for the analysis
    # NOTE(review): the rescaling is done by utils.scale_normalized using fixed
    # bounds of [-1, 1] for every parameter - confirm this matches the helper
    sample = utils.scale_normalized(df.to_numpy(), bounds)

    # 2D projection
    projection_2d(sample, names)


def projection_2d(
    sample: NDArray,
    var_names: list[str],
) -> None:
    """
    Draw a square grid of scatter plots, one for every pair of sample columns.

    The panel in grid row ``r`` and grid column ``c`` plots column ``c`` of
    ``sample`` on the x-axis against column ``r`` on the y-axis. Tick marks are
    removed from every panel; variable names are only written along the left
    column (y labels) and the bottom row (x labels). The figure is shown once all
    panels are drawn.

    Parameters
    ----------
    sample : NDArray
        2D array of sampled values with one column per variable
    var_names : list[str]
        names of the variables, in the same order as the columns of ``sample``
    """
    n_vars = sample.shape[1]
    bottom_row = n_vars - 1

    # walk the grid row by row using a single flat panel counter
    for panel in range(n_vars * n_vars):
        row, col = divmod(panel, n_vars)

        # subplot positions are 1-based
        plt.subplot(n_vars, n_vars, panel + 1)
        plt.scatter(sample[:, col], sample[:, row], s=2)

        # label only the outer edges of the grid
        if col == 0:
            plt.ylabel(var_names[row], rotation=0, ha="right")
        if row == bottom_row:
            plt.xlabel(var_names[col], rotation=20, ha="right")

        plt.xticks([])
        plt.yticks([])

    plt.show()

0 comments on commit 2607284

Please sign in to comment.