Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel committed Nov 15, 2023
1 parent eb0a780 commit 58b84c4
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 76 deletions.
46 changes: 12 additions & 34 deletions ema_workbench/analysis/regional_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,45 +82,15 @@ def plot_discrete_cdf(ax, unc, x, y, xticklabels_on, ccdf):
x_plot = [j * 1, j * 1 + 1]
y_plot = [freq] * 2

ax.plot(x_plot, y_plot, c=cp[i + 1], label=i == 1)
ax.scatter(
x_plot[0],
y_plot[0],
edgecolors=cp[i + 1],
facecolors=cp[i + 1],
linewidths=1,
zorder=2,
)
ax.scatter(
x_plot[1],
y_plot[0],
edgecolors=cp[i + 1],
facecolors="white",
linewidths=1,
zorder=2,
)
ax.plot(x_plot, y_plot, c=cp[i + 1], label=i == 1, marker="o")

# misnomer
cum_freq_un = (j + 1) / n_cat
if ccdf:
cum_freq_un = (len(freqs) - j - 1) / n_cat

ax.plot(x_plot, [cum_freq_un] * 2, lw=1, c="darkgrey", zorder=1, label="x==y")
ax.scatter(
x_plot[0],
cum_freq_un,
edgecolors="darkgrey",
facecolors="darkgrey",
linewidths=1,
zorder=1,
)
ax.scatter(
x_plot[1],
cum_freq_un,
edgecolors="darkgrey",
facecolors="white",
linewidths=1,
zorder=1,
ax.plot(
x_plot, [cum_freq_un] * 2, lw=1, c="darkgrey", zorder=1, label="x==y", marker="o"
)

ax.set_xticklabels([])
Expand Down Expand Up @@ -256,10 +226,18 @@ def plot_cdfs(x, y, ccdf=False):
x = x.copy()

try:
x = x.drop("scenario", axis=1)
x = x.drop("scenario_id", axis=1)
except KeyError:
pass

for entry in ["model", "policy"]:
if x.loc[:, entry].unique().shape != (1,):
continue
try:
x = x.drop(entry, axis=1)
except KeyError:
pass

uncs = x.columns.tolist()
cp = sns.color_palette()

Expand Down
49 changes: 27 additions & 22 deletions ema_workbench/analysis/scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def plot_pair_wise_scatter(
restricted_dims,
diag="kde",
upper="scatter",
lower="contour",
lower="hist",
fill_subplots=True,
):
"""helper function for pair wise scatter plotting
Expand Down Expand Up @@ -399,7 +399,9 @@ def plot_pair_wise_scatter(
elif upper == "none":
None
else:
raise NotImplementedError(f"upper = {upper} not implemented.")
raise NotImplementedError(
f"upper = {upper} not implemented. Use either 'scatter', 'contour', 'hist' (bivariate histogram) or None plots for upper triangle."
)

# lower triangle
if lower == "contour":
Expand All @@ -417,42 +419,45 @@ def plot_pair_wise_scatter(
elif lower == "none":
raise ValueError(f"Lower triangle cannot be none.")
else:
raise NotImplementedError(f"lower = {lower} not implemented.")
raise NotImplementedError(
f"lower = {lower} not implemented. Use either 'scatter', 'contour' or 'hist' (bivariate histogram) plots for lower triangle."
)

# diagonal
if diag == "cdf":
grid.map_diag(sns.ecdfplot)
elif diag == "kde":
grid.map_diag(sns.kdeplot, fill=False, common_norm=False, cut=0)
else:
raise NotImplementedError(f"diag = {diag} not implemented.")
raise NotImplementedError(
f"diag = {diag} not implemented. Use either 'kde' (kernel density estimate) or 'cdf' (cumulative density function)."
)

# draw box
pad = 0.1

cats = set(categorical_columns)
for row, ylabel in zip(grid.axes, grid.y_vars):
ylim = boxlim[ylabel]

if ylabel in cats:
height = (len(ylim[0]) - 1) + pad
y = -pad / 2
# y = -0.2
# height = len(ylim[0]) - 0.6 # 2 * 0.2
else:
y = ylim[0]
height = ylim[1] - ylim[0]

for ax, xlabel in zip(row, grid.x_vars):
if ylabel == xlabel:
continue

xrange = ax.get_xlim()[1] - ax.get_xlim()[0]
yrange = ax.get_ylim()[1] - ax.get_ylim()[0]

ylim = boxlim[ylabel]

if ylabel in cats:
height = (len(ylim[0]) - 1) + pad * yrange
y = -yrange * pad / 2
else:
y = ylim[0]
height = ylim[1] - ylim[0]

if xlabel in cats:
xlim = boxlim.at[0, xlabel]
width = (len(xlim) - 1) + pad
x = -pad / 2
# x = -0.2
# width = len(xlim) - 0.6 # 2 * 0.2
width = (len(xlim) - 1) + pad * xrange
x = -xrange * pad / 2
else:
xlim = boxlim[xlabel]
x = xlim[0]
Expand All @@ -462,7 +467,7 @@ def plot_pair_wise_scatter(
box = patches.Rectangle(
xy, width, height, edgecolor="red", facecolor="none", lw=3, zorder=100
)
if ax.has_data() == True:
if ax.has_data(): # keeps box from being drawn in upper triangle if empty
ax.add_patch(box)
else:
ax.set_axis_off()
Expand Down Expand Up @@ -508,15 +513,15 @@ def plot_pair_wise_scatter(
upper = data[subplot.get_xlabel()].max()
lower = data[subplot.get_xlabel()].min()

pad_rel = (upper - lower) * pad # padding relative to range of data points
pad_rel = (upper - lower) * 0.1 # padding relative to range of data points

subplot.set_xlim(lower - pad_rel, upper + pad_rel)

if subplot.get_ylabel() != "":
upper = data[subplot.get_ylabel()].max()
lower = data[subplot.get_ylabel()].min()

pad_rel = (upper - lower) * pad # padding relative to range of data points
pad_rel = (upper - lower) * 0.1 # padding relative to range of data points

subplot.set_ylim(lower - pad_rel, upper + pad_rel)

Expand Down
10 changes: 10 additions & 0 deletions ema_workbench/em_framework/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ def __init__(self, name, model_name, policy, scenario, experiment_id):
self.model_name = model_name
self.scenario = scenario

def __repr__(self):
return (
f"Experiment(name={self.name!r}, model_name={self.model_name!r}, "
f"policy={self.policy!r}, scenario={self.scenario!r}, "
f"experiment_id={self.experiment_id!r})"
)

def __str__(self):
return f"Experiment {self.experiment_id} (model: {self.model_name}, policy: {self.policy.name}, scenario: {self.scenario.name})"


class ExperimentReplication(NamedDict):
"""helper class that combines scenario, policy, any constants, and
Expand Down
50 changes: 30 additions & 20 deletions ema_workbench/examples/example_lake_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,37 +35,47 @@ def lake_problem(
decisions = [kwargs[str(i)] for i in range(100)]
except KeyError:
decisions = [0] * 100
print("No valid decisions found, using 0 water release every year as default")

Pcrit = brentq(lambda x: x**q / (1 + x**q) - b * x, 0.01, 1.5)
nvars = len(decisions)
X = np.zeros((nvars,))
average_daily_P = np.zeros((nvars,))
decisions = np.array(decisions)
reliability = 0.0

for _ in range(nsamples):
X[0] = 0.0
# Calculate the critical pollution level (Pcrit)
Pcrit = brentq(lambda x: x**q / (1 + x**q) - b * x, 0.01, 1.5)

natural_inflows = np.random.lognormal(
math.log(mean**2 / math.sqrt(stdev**2 + mean**2)),
math.sqrt(math.log(1.0 + stdev**2 / mean**2)),
size=nvars,
# Generate natural inflows using lognormal distribution
natural_inflows = np.random.lognormal(
mean=math.log(mean**2 / math.sqrt(stdev**2 + mean**2)),
sigma=math.sqrt(math.log(1.0 + stdev**2 / mean**2)),
size=(nsamples, nvars),
)

# Initialize the pollution level matrix X
X = np.zeros((nsamples, nvars))

# Loop through time to compute the pollution levels
for t in range(1, nvars):
X[:, t] = (
(1 - b) * X[:, t - 1]
+ (X[:, t - 1] ** q / (1 + X[:, t - 1] ** q))
+ decisions[t - 1]
+ natural_inflows[:, t - 1]
)

for t in range(1, nvars):
X[t] = (
(1 - b) * X[t - 1]
+ X[t - 1] ** q / (1 + X[t - 1] ** q)
+ decisions[t - 1]
+ natural_inflows[t - 1]
)
average_daily_P[t] += X[t] / float(nsamples)
# Calculate the average daily pollution for each time step
average_daily_P = np.mean(X, axis=0)

reliability += np.sum(X < Pcrit) / float(nsamples * nvars)
# Calculate the reliability (probability of the pollution level being below Pcrit)
reliability = np.sum(X < Pcrit) / float(nsamples * nvars)

# Calculate the maximum pollution level (max_P)
max_P = np.max(average_daily_P)

# Calculate the utility by discounting the decisions using the discount factor (delta)
utility = np.sum(alpha * decisions * np.power(delta, np.arange(nvars)))
inertia = np.sum(np.absolute(np.diff(decisions)) < 0.02) / float(nvars - 1)

# Calculate the inertia (the fraction of time steps with changes larger than 0.02)
inertia = np.sum(np.abs(np.diff(decisions)) > 0.02) / float(nvars - 1)

return max_P, utility, inertia, reliability

Expand Down
18 changes: 18 additions & 0 deletions ema_workbench/examples/regional_sa_flu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
""" A simple example of performing regional sensitivity analysis
"""
import matplotlib.pyplot as plt

from ema_workbench.analysis import regional_sa
from ema_workbench import ema_logging, load_results

fn = "./data/1000 flu cases with policies.tar.gz"
results = load_results(fn)
x, outcomes = results

y = outcomes["deceased population region 1"][:, -1] > 1000000

fig = regional_sa.plot_cdfs(x, y)

plt.show()

0 comments on commit 58b84c4

Please sign in to comment.