From 6712b7d7b05eb7c7ca730fa3cc05233b04b552b6 Mon Sep 17 00:00:00 2001
From: voetberg
Date: Fri, 14 Jun 2024 09:34:26 -0500
Subject: [PATCH] Parity plots with difference, percentage, and residuals #29

---
Review note: the line structure of this patch was reconstructed from a
whitespace-collapsed copy, so the indentation of context lines is a best
guess (apply with --ignore-whitespace if a hunk is rejected).
src/plots/parity.py was amended during review, so the blob id on its index
line predates those changes.

 src/plots/__init__.py |   4 +-
 src/plots/parity.py   | 151 ++++++++++++++++++++++++++++++++++++++++++
 src/utils/defaults.py |   5 +-
 tests/test_plots.py   |  29 ++++++++-
 4 files changed, 184 insertions(+), 5 deletions(-)
 create mode 100644 src/plots/parity.py

diff --git a/src/plots/__init__.py b/src/plots/__init__.py
index b037003..45e4d4a 100644
--- a/src/plots/__init__.py
+++ b/src/plots/__init__.py
@@ -4,6 +4,7 @@
 from plots.tarp import TARP
 from plots.local_two_sample import LocalTwoSampleTest
 from plots.predictive_posterior_check import PPC
+from plots.parity import Parity
 
 Plots = {
     CDFRanks.__name__: CDFRanks,
@@ -11,5 +12,6 @@
     Ranks.__name__: Ranks,
     TARP.__name__: TARP,
     "LC2ST": LocalTwoSampleTest,
-    PPC.__name__: PPC
+    PPC.__name__: PPC,
+    "Parity": Parity
 }
diff --git a/src/plots/parity.py b/src/plots/parity.py
new file mode 100644
index 0000000..b504756
--- /dev/null
+++ b/src/plots/parity.py
@@ -0,0 +1,151 @@
+from typing import Optional, Sequence
+import matplotlib.pyplot as plt
+import numpy as np
+
+from plots.plot import Display
+
+
+def _relative_error(difference, true):
+    """Return ``difference / true`` element-wise, with NaN where it is undefined."""
+    with np.errstate(divide="ignore", invalid="ignore"):
+        ratio = np.asarray(difference, dtype=float) / np.asarray(true, dtype=float)
+    # A true value of 0 yields inf/nan; NaN keeps those points off the plot.
+    return np.where(np.isfinite(ratio), ratio, np.nan)
+
+
+class Parity(Display):
+    """Parity plot of the posterior mean against the true parameter values.
+
+    Optional extra rows show the difference, the residual and the percentage
+    error of the posterior mean, each plotted against the true value.
+    """
+
+    def __init__(
+        self,
+        model,
+        data,
+        save:bool,
+        show:bool,
+        out_dir:Optional[str]=None,
+        percentiles: Optional[Sequence] = None,
+        use_progress_bar: Optional[bool] = None,
+        samples_per_inference: Optional[int] = None,
+        number_simulations: Optional[int] = None,
+        parameter_names: Optional[Sequence] = None,
+        parameter_colors: Optional[Sequence]= None,
+        colorway: Optional[str]=None
+    ):
+        # Every argument is forwarded unchanged, in order, to Display.
+        super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)
+
+    def _plot_name(self):
+        """Return the file name of this plot."""
+        return "parity.png"
+
+    def get_posterior(self, n_samples):
+        """Sample the posterior for ``n_samples`` randomly chosen contexts.
+
+        Stores the per-context posterior mean and standard deviation
+        (``posterior_sample_mean``, ``posterior_sample_std``) and the matching
+        true parameters (``true_samples``), each of shape (n_samples, n_dims).
+        """
+        # Fetched once; assumes both accessors return the same arrays on every call.
+        context = self.data.true_context()
+        theta_true = self.data.get_theta_true()
+
+        shape = (n_samples, self.data.n_dims)
+        self.posterior_sample_mean = np.zeros(shape)
+        self.posterior_sample_std = np.zeros(shape)
+        self.true_samples = np.zeros(shape)
+
+        random_context_indices = self.data.rng.integers(0, context.shape[0], n_samples)
+        for index, sample in enumerate(random_context_indices):
+            posterior_sample = self.model.sample_posterior(
+                self.samples_per_inference, context[sample, :]
+            ).numpy()
+            self.posterior_sample_mean[index] = np.mean(posterior_sample, axis=0)
+            self.posterior_sample_std[index] = np.std(posterior_sample, axis=0)
+            self.true_samples[index] = theta_true[sample, :]
+
+    def _plot(
+        self,
+        n_samples: int = 80,
+        include_difference: bool = False,
+        include_residual: bool = False,
+        include_percentage: bool = False,
+        show_ideal: bool = True,
+        errorbar_color: str = 'black',
+        title:str="Parity",
+        y_label:str=r"$\theta_{predicted}$",
+        x_label:str=r"$\theta_{true}$"
+    ):
+        """Draw the parity figure with one column per parameter dimension.
+
+        Args:
+            n_samples: Number of contexts to sample the posterior for.
+            include_difference: Add a row of ``true - predicted``.
+            include_residual: Add a row of ``(true - predicted) / true``.
+            include_percentage: Add a row of ``100 * (true - predicted) / true``.
+            show_ideal: Draw dashed reference lines (predicted == true on the
+                parity row, zero on the extra rows).
+            errorbar_color: Color of the posterior standard-deviation bars.
+            title: Figure title.
+            y_label: Figure-level y-axis label.
+            x_label: Figure-level x-axis label.
+        """
+        self.get_posterior(n_samples)
+
+        # Optional rows in display order: (y-axis label, values to plot).
+        extra_rows = []
+        if include_difference:
+            extra_rows.append(("Difference", lambda difference, true: difference))
+        if include_residual:
+            extra_rows.append(("Residuals", _relative_error))
+        if include_percentage:
+            extra_rows.append(
+                ("Percentage", lambda difference, true: _relative_error(100 * difference, true))
+            )
+
+        n_rows = 1 + len(extra_rows)
+        # The parity row is three times as tall as each extra row.
+        height_ratios = [3] + [1] * len(extra_rows)
+
+        # squeeze=False keeps `subplots` two-dimensional even with a single row
+        # or a single parameter, so [row, column] indexing always works.
+        figure, subplots = plt.subplots(
+            nrows=n_rows,
+            ncols=self.data.n_dims,
+            figsize=(int(self.figure_size[0]*self.data.n_dims*.8), int(self.figure_size[1]*n_rows*.6)),
+            height_ratios=height_ratios,
+            sharex="col",
+            sharey=False,
+            squeeze=False)
+
+        figure.suptitle(title)
+        figure.supxlabel(x_label)
+        figure.supylabel(y_label)
+
+        # Row labels go on the first column only.
+        subplots[0, 0].set_ylabel("Parity")
+        for row_index, (row_label, _) in enumerate(extra_rows, start=1):
+            subplots[row_index, 0].set_ylabel(row_label)
+
+        for theta_dimension in range(self.data.n_dims):
+            true = self.true_samples[:, theta_dimension]
+            posterior_sample = self.posterior_sample_mean[:, theta_dimension]
+            posterior_errorbar = self.posterior_sample_std[:, theta_dimension]
+
+            parity_plot = subplots[0, theta_dimension]
+            parity_plot.title.set_text(self.parameter_names[theta_dimension])
+            parity_plot.errorbar(true, posterior_sample, yerr=posterior_errorbar, fmt="o", ecolor=errorbar_color)
+            if show_ideal:
+                # A data-space line of slope 1 is predicted == true whatever the
+                # axis limits are; an axes-space diagonal is not.
+                parity_plot.axline((true.min(), true.min()), slope=1, color='black', linestyle="--")
+
+            difference = true - posterior_sample
+            for row_index, (_, row_values) in enumerate(extra_rows, start=1):
+                row_plot = subplots[row_index, theta_dimension]
+                row_plot.scatter(true, row_values(difference, true))
+                if show_ideal:
+                    row_plot.hlines(0, xmin=true.min(), xmax=true.max(), alpha=0.4, color='black', linestyle="--")
diff --git a/src/utils/defaults.py b/src/utils/defaults.py
index 3956bb2..89c6c5b 100644
--- a/src/utils/defaults.py
+++ b/src/utils/defaults.py
@@ -27,9 +27,10 @@
         "Ranks": {"num_bins": None},
         "CoverageFraction": {},
         "TARP": {
-            "coverage_sigma": 3  # How many sigma to show coverage over
+            "coverage_sigma": 3
         },
-        "LC2ST": {}
+        "LC2ST": {},
+        "Parity":{}
     },
     "metrics_common": {
         "use_progress_bar": False,
diff --git a/tests/test_plots.py b/tests/test_plots.py
index f32e546..a600630 100644
--- a/tests/test_plots.py
+++ b/tests/test_plots.py
@@ -11,7 +11,8 @@
     CoverageFraction,
     TARP,
     LocalTwoSampleTest,
-    PPC
+    PPC,
+    Parity
 )
 
 
@@ -68,4 +69,28 @@ def test_lc2st(plot_config, mock_model, mock_data):
 def test_ppc(plot_config, mock_model, mock_data):
     plot = PPC(mock_model, mock_data, save=True, show=False)
     plot(**get_item("plots", "PPC", raise_exception=False))
-    assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")
\ No newline at end of file
+    assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")
+
+
+def test_parity(plot_config, mock_model, mock_data):
+    plot = Parity(mock_model, mock_data, save=True, show=False)
+
+    plot(include_difference= False,
+         include_residual = False,
+         include_percentage = False)
+
+    assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")
+    os.remove(f"{plot.out_dir}/{plot.plot_name}")
+
+    plot(include_difference= True,
+         include_residual = False,
+         include_percentage = True)
+
+    assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")
+    os.remove(f"{plot.out_dir}/{plot.plot_name}")
+
+    plot(include_difference= True,
+         include_residual = True,
+         include_percentage = True)
+
+    assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")