Skip to content

Commit

Permalink
Update existing plots and metrics to work with 2d #72
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Jun 14, 2024
1 parent 2f5e5a6 commit c05ebdc
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 70 deletions.
68 changes: 50 additions & 18 deletions src/metrics/local_two_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,43 @@ def _collect_data_params(self):
# P is the prior and x_P is generated via the simulator from the parameters P.
self.p = self.data.sample_prior(self.number_simulations)
self.q = np.zeros_like(self.p)

context_size = self.data.true_context().shape[-1]
self.outcome_given_p = np.zeros(
(self.number_simulations, context_size)
remove_first_dim = False

)
if self.data.simulator_dimensions == 1:
self.outcome_given_p = np.zeros((self.number_simulations, context_size))
elif self.data.simulator_dimensions == 2:
sim_out_shape = self.data.get_simulator_output_shape()
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

sim_out_shape = np.product(sim_out_shape)
self.outcome_given_p = np.zeros((self.number_simulations, sim_out_shape))
else:
raise NotImplementedError("LC2ST only implemented for 1 or two dimensions.")

self.outcome_given_q = np.zeros_like(self.outcome_given_p)
self.evaluation_context = np.zeros_like(self.outcome_given_p)
self.evaluation_context = np.zeros((self.number_simulations, context_size))

for index, p in enumerate(self.p):
context = self.data.simulator.generate_context(context_size)
self.outcome_given_p[index] = self.data.simulator.simulate(p, context)
# Q is the approximate posterior amortized in x
q = self.model.sample_posterior(1, context).ravel()
self.q[index] = q
self.outcome_given_q[index] = self.data.simulator.simulate(q, context)
self.evaluation_context[index] = context

p_outcome = self.data.simulator.simulate(p, context)
q_outcome = self.data.simulator.simulate(q, context)

if remove_first_dim:
p_outcome = p_outcome[0]
q_outcome = q_outcome[0]

self.outcome_given_p[index] = p_outcome.ravel()
self.outcome_given_q[index] = q_outcome.ravel() # Q is the approximate posterior amortized in x


self.evaluation_context = np.array(
[
self.data.simulator.generate_context(context_size)
for _ in range(self.num_simulations)
]
)

def train_linear_classifier(
self, p, q, x_p, x_q, classifier: str, classifier_kwargs: dict = {}
Expand Down Expand Up @@ -127,9 +141,21 @@ def _cross_eval_score(
cv_splits = kf.split(p)
# train classifiers over cv-folds
probabilities = []
self.evaluation_data = np.zeros(
(n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])
)

remove_first_dim = False
if self.data.simulator_dimensions == 1:
self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1]))

elif self.data.simulator_dimensions == 2:
sim_out_shape = self.data.get_simulator_output_shape()
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

sim_out_shape = np.product(sim_out_shape)
self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), sim_out_shape))

self.prior_evaluation = np.zeros_like(p)

kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)
Expand All @@ -142,10 +168,16 @@ def _cross_eval_score(
p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs
)
p_evaluate = p[val_index]

for index, p_validation in enumerate(p_evaluate):
self.evaluation_data[cross_trial][index] = self.data.simulator.simulate(
sim_output = self.data.simulator.simulate(
p_validation, self.evaluation_context[val_index][index]
)

if remove_first_dim:
sim_output = sim_output[0]
self.evaluation_data[cross_trial][index] = sim_output.ravel()

self.prior_evaluation[index] = p_validation
probabilities.append(
self._eval_model(
Expand Down
145 changes: 102 additions & 43 deletions src/plots/predictive_posterior_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,41 @@ def __init__(
def _plot_name(self):
return "predictive_posterior_check.png"

def get_posterior(self, n_simulator_draws):
def get_posterior_2d(self, n_simulator_draws):
context_shape = self.data.true_context().shape
self.posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference,context_shape[-1]))
sim_out_shape = self.data.get_simulator_output_shape()
remove_first_dim = False
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

self.posterior_predictive_samples = np.zeros((n_simulator_draws, *sim_out_shape))
self.posterior_true_samples = np.zeros_like(self.posterior_predictive_samples)

random_context_indices = self.data.rng.integers(0, context_shape[0], n_simulator_draws)
for index, sample in enumerate(random_context_indices):
context_sample = self.data.true_context()[sample, :]
posterior_sample = self.model.sample_posterior(1, context_sample)

# get the posterior samples for that context
sim_out_posterior = self.data.simulator.simulate(
theta=posterior_sample, context_samples = context_sample
)
sim_out_true = self.data.simulator.simulate(
theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
)
if remove_first_dim:
sim_out_posterior = sim_out_posterior[0]
sim_out_true = sim_out_true[0]

self.posterior_predictive_samples[index] = sim_out_posterior
self.posterior_true_samples[index] = sim_out_true


def get_posterior_1d(self, n_simulator_draws):
context_shape = self.data.true_context().shape
self.posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, context_shape[-1]))
self.posterior_true_samples = np.zeros_like(self.posterior_predictive_samples)
self.context = np.zeros((n_simulator_draws, context_shape[-1]))

Expand All @@ -48,70 +80,97 @@ def get_posterior(self, n_simulator_draws):
theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
)

def _plot_1d(self,
subplots: np.ndarray,
subplot_index: int,
n_coverage_sigma: Optional[int] = 3,
theta_true_marker: Optional[str] = '^'
):

dimension_y_simulation = self.posterior_predictive_samples[subplot_index]
y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
y_simulation_std = np.std(dimension_y_simulation, axis=0).ravel()

for sigma, color in zip(range(n_coverage_sigma), self.colors):
subplots[0, subplot_index].fill_between(
self.context[subplot_index].ravel(),
y_simulation_mean - sigma * y_simulation_std,
y_simulation_mean + sigma * y_simulation_std,
color=color,
alpha=0.6,
label=rf"Pred. with {sigma} $\sigma$",
)

subplots[0, subplot_index].plot(
self.context[subplot_index],
y_simulation_mean - self.true_sigma,
color="black",
linestyle="dashdot",
label="True Input Error"
)
subplots[0, subplot_index].plot(
self.context[subplot_index],
y_simulation_mean + self.true_sigma,
color="black",
linestyle="dashdot",
)

true_y = np.mean(self.posterior_true_samples[subplot_index, :, :], axis=0).ravel()
subplots[1, subplot_index].scatter(
self.context[subplot_index],
true_y,
marker=theta_true_marker,
label='Theta True'
)

def _plot_2d(self, subplots, subplot_index, include_axis_ticks):
subplots[1, subplot_index].imshow(self.posterior_predictive_samples[subplot_index])
subplots[0, subplot_index].imshow(self.posterior_true_samples[subplot_index])

if not include_axis_ticks:
subplots[1, subplot_index].set_xticks([])
subplots[1, subplot_index].set_yticks([])

subplots[0, subplot_index].set_xticks([])
subplots[0, subplot_index].set_yticks([])

def _plot(
self,
n_coverage_sigma: Optional[int] = 3,
true_sigma: Optional[float] = None,
theta_true_marker: Optional[str] = '^',
n_unique_plots: Optional[int] = 3,
include_axis_ticks: bool = False,
title:str="Predictive Posterior",
y_label:str="Simulation Output",
x_label:str="X"):

if self.data.simulator_dimensions == 1:
self.get_posterior_1d(n_unique_plots)
self.true_sigma = true_sigma if true_sigma is not None else self.data.get_sigma_true()
self.colors = get_hex_colors(n_coverage_sigma, self.colorway)

self.get_posterior(n_unique_plots)
true_sigma = true_sigma if true_sigma is not None else self.data.get_sigma_true()
elif self.data.simulator_dimensions == 2:
self.get_posterior_2d(n_unique_plots)

else:
raise NotImplementedError("Posterior Checks only implemented for 1 or two dimensions.")

figure, subplots = plt.subplots(
2,
n_unique_plots,
figsize=(int(self.figure_size[0]*n_unique_plots*.6), self.figure_size[1]),
sharex=False,
sharey=True
)
colors = get_hex_colors(n_coverage_sigma, self.colorway)

for plot_index in range(n_unique_plots):
if self.data.simulator_dimensions == 1:
self._plot_1d(subplots, plot_index, n_coverage_sigma, theta_true_marker)

dimension_y_simulation = self.posterior_predictive_samples[plot_index]

y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
y_simulation_std = np.std(dimension_y_simulation, axis=0).ravel()

for sigma, color in zip(range(n_coverage_sigma), colors):
subplots[0, plot_index].fill_between(
self.context[plot_index].ravel(),
y_simulation_mean - sigma * y_simulation_std,
y_simulation_mean + sigma * y_simulation_std,
color=color,
alpha=0.6,
label=rf"Pred. with {sigma} $\sigma$",
)

subplots[0, plot_index].plot(
self.context[plot_index],
y_simulation_mean - true_sigma,
color="black",
linestyle="dashdot",
label="True Input Error"
)
subplots[0, plot_index].plot(
self.context[plot_index],
y_simulation_mean + true_sigma,
color="black",
linestyle="dashdot",
)

true_y = np.mean(self.posterior_true_samples[plot_index, :, :], axis=0).ravel()
subplots[1, plot_index].scatter(
self.context[plot_index],
true_y,
marker=theta_true_marker,
label='Theta True'
)
else:
self._plot_2d(subplots, plot_index, include_axis_ticks)

subplots[1, -1].legend()
subplots[0, -1].legend()

subplots[1, 0].set_ylabel("True Parameters")
subplots[0, 0].set_ylabel("Predicted Parameters")
Expand Down
35 changes: 26 additions & 9 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,72 @@

@pytest.fixture
def plot_config(config_factory):
out_dir = "./temp_results/"
metrics_settings = {
"use_progress_bar": False,
"samples_per_inference": 10,
"percentiles": [95],
}
config = config_factory(out_dir=out_dir, metrics_settings=metrics_settings)
Config(config)
config = config_factory(metrics_settings=metrics_settings)
return config


def test_all_defaults(plot_config, mock_model, mock_data):
"""
Ensures each metric has a default set of parameters and is included in the defaults list
Ensures each test can initialize, regardless of the veracity of the output
"""
Config(plot_config)
for plot_name, plot_obj in Plots.items():
assert plot_name in Defaults["plots"]
plot_obj(mock_model, mock_data, save=True, show=False)


def test_plot_cdf(plot_config, mock_model, mock_data):
Config(plot_config)
plot = CDFRanks(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "CDFRanks", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")


def test_plot_ranks(plot_config, mock_model, mock_data):
Config(plot_config)
plot = Ranks(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "Ranks", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")


def test_plot_coverage(plot_config, mock_model, mock_data):
Config(plot_config)
plot = CoverageFraction(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "CoverageFraction", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")


def test_plot_tarp(plot_config, mock_model, mock_data):
Config(plot_config)
plot = TARP(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "TARP", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")

def test_lc2st(plot_config, mock_model, mock_data):
def test_lc2st(plot_config, mock_model, mock_data, mock_2d_data, result_output):
Config(plot_config)
plot = LocalTwoSampleTest(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "LC2ST", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")

def test_ppc(plot_config, mock_model, mock_data):
plot = LocalTwoSampleTest(
mock_model, mock_2d_data, save=True, show=False,
out_dir=f"{result_output.strip('/')}/mock_2d/")
assert type(plot.data.simulator).__name__ == "Mock2DSimulator"
plot(**get_item("plots", "LC2ST", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")

def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output):
Config(plot_config)
plot = PPC(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "PPC", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")

plot = PPC(
mock_model,
mock_2d_data, save=True, show=False,
out_dir=f"{result_output.strip('/')}/mock_2d/")
assert type(plot.data.simulator).__name__ == "Mock2DSimulator"
plot(**get_item("plots", "PPC", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")

0 comments on commit c05ebdc

Please sign in to comment.