Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parity plots with difference, percentage, and residuals #29 #73

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from plots.tarp import TARP
from plots.local_two_sample import LocalTwoSampleTest
from plots.predictive_posterior_check import PPC
from plots.parity import Parity

# Registry of the available plot classes, keyed by the name used to look
# them up. Most entries use the class name; "LC2ST" and "Parity" are
# spelled out explicitly.
Plots = {
    CDFRanks.__name__: CDFRanks,
    CoverageFraction.__name__: CoverageFraction,
    Ranks.__name__: Ranks,
    TARP.__name__: TARP,
    "LC2ST": LocalTwoSampleTest,
    PPC.__name__: PPC,
    "Parity": Parity,
}
134 changes: 134 additions & 0 deletions src/plots/parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from typing import Optional, Sequence
import matplotlib.pyplot as plt
import numpy as np

from plots.plot import Display

class Parity(Display):
    """Parity plot of posterior-mean predictions against true parameters.

    One column is drawn per parameter dimension (``self.data.n_dims``).
    The top row shows the predicted value (posterior mean, with the
    posterior standard deviation as the error bar) against the true value.
    Optional rows below show, against the true value:

    * difference: ``true - predicted``
    * residual:   ``(true - predicted) / true``
    * percentage: ``(true - predicted) / true * 100``
    """

    def __init__(
        self,
        model,
        data,
        save: bool,
        show: bool,
        out_dir: Optional[str] = None,
        percentiles: Optional[Sequence] = None,
        use_progress_bar: Optional[bool] = None,
        samples_per_inference: Optional[int] = None,
        number_simulations: Optional[int] = None,
        parameter_names: Optional[Sequence] = None,
        parameter_colors: Optional[Sequence] = None,
        colorway: Optional[str] = None,
    ):
        """Forward every argument, unchanged and in order, to ``Display``."""
        super().__init__(
            model,
            data,
            save,
            show,
            out_dir,
            percentiles,
            use_progress_bar,
            samples_per_inference,
            number_simulations,
            parameter_names,
            parameter_colors,
            colorway,
        )

    def _plot_name(self):
        """Return the file name used for the saved figure."""
        return "parity.png"

    def get_posterior(self, n_samples):
        """Sample the posterior at ``n_samples`` randomly chosen contexts.

        Fills three ``(n_samples, n_dims)`` arrays on the instance:
        ``posterior_sample_mean``, ``posterior_sample_std`` and
        ``true_samples``.

        Args:
            n_samples: Number of context rows to draw. Indices come from
                ``self.data.rng.integers``, so the same row may be drawn
                more than once.
        """
        # Fetch both arrays once instead of on every loop iteration.
        # NOTE(review): assumes these accessors are side-effect free.
        contexts = self.data.true_context()
        theta_true = self.data.get_theta_true()

        self.posterior_sample_mean = np.zeros((n_samples, self.data.n_dims))
        self.posterior_sample_std = np.zeros_like(self.posterior_sample_mean)
        self.true_samples = np.zeros_like(self.posterior_sample_mean)

        random_context_indices = self.data.rng.integers(
            0, contexts.shape[0], n_samples
        )
        for index, sample in enumerate(random_context_indices):
            posterior_sample = self.model.sample_posterior(
                self.samples_per_inference, contexts[sample, :]
            ).numpy()
            # Collapse the posterior draws (axis 0) to one mean/std per
            # parameter dimension.
            self.posterior_sample_mean[index] = np.mean(posterior_sample, axis=0)
            self.posterior_sample_std[index] = np.std(posterior_sample, axis=0)
            self.true_samples[index] = theta_true[sample, :]

    @staticmethod
    def _relative_error(true, predicted):
        """Return ``(true - predicted) / true``, with NaN where ``true == 0``."""
        true = np.asarray(true, dtype=float)
        difference = true - np.asarray(predicted, dtype=float)
        # NaN entries are simply not drawn by matplotlib, which avoids
        # infinities and divide-by-zero warnings for zero-valued truths.
        return np.divide(
            difference,
            true,
            out=np.full_like(difference, np.nan),
            where=true != 0,
        )

    def _plot(
        self,
        n_samples: int = 80,
        include_difference: bool = False,
        include_residual: bool = False,
        include_percentage: bool = False,
        show_ideal: bool = True,
        errorbar_color: str = 'black',
        title: str = "Parity",
        y_label: str = r"$\theta_{predicted}$",
        x_label: str = r"$\theta_{true}$"
    ):
        """Draw the parity figure and any requested comparison rows.

        Args:
            n_samples: Number of random contexts to evaluate.
            include_difference: Add a ``true - predicted`` row.
            include_residual: Add a ``(true - predicted) / true`` row.
            include_percentage: Add a ``(true - predicted) / true * 100`` row.
            show_ideal: Draw the perfect-prediction reference: the line
                ``y = x`` on the parity row and ``y = 0`` on the other rows.
            errorbar_color: Colour of the error bars on the parity row.
            title: Figure-level title.
            y_label: Figure-level y-axis label.
            x_label: Figure-level x-axis label.
        """
        self.get_posterior(n_samples)

        # Rows below the parity row, in the order they are drawn.
        extra_rows = []
        if include_difference:
            extra_rows.append("Difference")
        if include_residual:
            extra_rows.append("Residuals")
        if include_percentage:
            extra_rows.append("Percentage")

        # The parity row is three times as tall as each comparison row.
        height_ratios = [3] + [1] * len(extra_rows)
        n_rows = len(height_ratios)

        # squeeze=False keeps the axes grid two-dimensional even for a single
        # row and/or a single parameter, so one indexing scheme covers all cases.
        figure, subplots = plt.subplots(
            nrows=n_rows,
            ncols=self.data.n_dims,
            figsize=(
                int(self.figure_size[0] * self.data.n_dims * .8),
                int(self.figure_size[1] * n_rows * .6),
            ),
            height_ratios=height_ratios,
            sharex="col",
            sharey=False,
            squeeze=False,
        )

        figure.suptitle(title)
        figure.supxlabel(x_label)
        figure.supylabel(y_label)

        # Row labels go on the left-most column only.
        subplots[0, 0].set_ylabel("Parity")
        for row_index, row_label in enumerate(extra_rows, start=1):
            subplots[row_index, 0].set_ylabel(row_label)

        for theta_dimension in range(self.data.n_dims):
            true = self.true_samples[:, theta_dimension]
            predicted = self.posterior_sample_mean[:, theta_dimension]
            spread = self.posterior_sample_std[:, theta_dimension]
            # Reference lines need a data range; skip them when nothing was sampled.
            has_points = true.size > 0

            parity_plot = subplots[0, theta_dimension]
            parity_plot.title.set_text(self.parameter_names[theta_dimension])
            parity_plot.errorbar(
                true, predicted, yerr=spread, fmt="o", ecolor=errorbar_color
            )

            if show_ideal and has_points:
                # Draw y = x in data coordinates. A corner-to-corner line in
                # axes coordinates only matches y = x when both axes happen to
                # share identical limits, which is not the case here.
                lower = min(true.min(), (predicted - spread).min())
                upper = max(true.max(), (predicted + spread).max())
                parity_plot.plot(
                    [lower, upper], [lower, upper], color='black', linestyle="--"
                )

            relative = self._relative_error(true, predicted)
            series = {
                "Difference": true - predicted,
                "Residuals": relative,
                "Percentage": relative * 100,
            }
            for row_index, row_label in enumerate(extra_rows, start=1):
                axis = subplots[row_index, theta_dimension]
                axis.scatter(true, series[row_label])
                if show_ideal and has_points:
                    axis.hlines(
                        0,
                        xmin=true.min(),
                        xmax=true.max(),
                        alpha=0.4,
                        color='black',
                        linestyle="--",
                    )

5 changes: 3 additions & 2 deletions src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
"Ranks": {"num_bins": None},
"CoverageFraction": {},
"TARP": {
"coverage_sigma": 3 # How many sigma to show coverage over
"coverage_sigma": 3
},
"LC2ST": {}
"LC2ST": {},
"Parity":{}
},
"metrics_common": {
"use_progress_bar": False,
Expand Down
29 changes: 27 additions & 2 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
CoverageFraction,
TARP,
LocalTwoSampleTest,
PPC
PPC,
Parity
)


Expand Down Expand Up @@ -68,4 +69,28 @@ def test_lc2st(plot_config, mock_model, mock_data):
def test_ppc(plot_config, mock_model, mock_data):
    """Check that calling a PPC plot with ``save=True`` writes its output file."""
    plot = PPC(mock_model, mock_data, save=True, show=False)
    plot(**get_item("plots", "PPC", raise_exception=False))
    assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")


def test_parity(plot_config, mock_model, mock_data):
    """Check that the parity plot is written for several row combinations.

    The output file is deleted between runs so that each assertion proves
    the file was produced by the call immediately before it. The file from
    the final run is left in place.
    """
    plot = Parity(mock_model, mock_data, save=True, show=False)

    # (include_difference, include_residual, include_percentage)
    flag_sets = (
        (False, False, False),
        (True, False, True),
        (True, True, True),
    )
    final_position = len(flag_sets) - 1

    for position, (difference, residual, percentage) in enumerate(flag_sets):
        plot(
            include_difference=difference,
            include_residual=residual,
            include_percentage=percentage,
        )

        assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")
        if position != final_position:
            os.remove(f"{plot.out_dir}/{plot.plot_name}")