Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Feb 7, 2024
1 parent 947d591 commit c7c10a3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 73 deletions.
59 changes: 19 additions & 40 deletions src/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
Includes utilities for posterior diagnostics as well as some
inference functions.
"""
from scripts.io import ModelLoader

from scripts.io import ModelLoader
import argparse
from sbi.analysis import run_sbc, sbc_rank_plot, check_sbc, pairplot
from sbi.analysis import run_sbc, sbc_rank_plot, check_sbc
import numpy as np
from tqdm import tqdm

Expand All @@ -24,12 +24,9 @@


class Diagnose_generative:
def posterior_predictive(self,
theta_true,
x_true,
simulator,
posterior_samples,
true_sigma):
def posterior_predictive(
self, theta_true, x_true, simulator, posterior_samples, true_sigma
):
# not sure how or where to define the simulator
# could require that people input posterior predictive samples,
# already drawn from the simulator
Expand Down Expand Up @@ -98,9 +95,7 @@ def sbc_statistics(self,
if these values are close to 0.5, dap is like the prior distribution.
"""
check_stats = check_sbc(
ranks,
thetas,
dap_samples,
ranks, thetas, dap_samples,
num_posterior_samples=num_posterior_samples
)
return check_stats
Expand Down Expand Up @@ -195,11 +190,7 @@ def plot_cdf_1d_ranks(
plt.show()

def calculate_coverage_fraction(
self,
posterior,
thetas,
ys,
percentile_list,
self, posterior, thetas, ys, percentile_list,
samples_per_inference=1_000
):
"""
Expand All @@ -209,7 +200,8 @@ def calculate_coverage_fraction(
"""
# this holds all posterior samples for each inference run
all_samples = np.empty((len(ys), samples_per_inference,
all_samples = np.empty((len(ys),
samples_per_inference,
np.shape(thetas)[1]))
count_array = []
# make this for loop into a progress bar:
Expand Down Expand Up @@ -321,8 +313,8 @@ def plot_coverage_fraction(
)

ax.plot(
[0, 0.5, 1], [0, 0.5, 1],
"k--", lw=3, zorder=1000, label="Reference Line"
[0, 0.5, 1], [0, 0.5, 1], "k--", lw=3, zorder=1000,
label="Reference Line"
)
ax.set_xlim([-0.05, 1.05])
ax.set_ylim([-0.05, 1.05])
Expand Down Expand Up @@ -512,10 +504,7 @@ def generate_sbc_samples(
)
return thetas, ys, ranks, dap_samples

def sbc_statistics(self,
ranks,
thetas,
dap_samples,
def sbc_statistics(self, ranks, thetas, dap_samples,
num_posterior_samples):
"""
The ks pvalues are vanishingly small here,
Expand All @@ -533,9 +522,7 @@ def sbc_statistics(self,
if these values are close to 0.5, dap is like the prior distribution.
"""
check_stats = check_sbc(
ranks,
thetas,
dap_samples,
ranks, thetas, dap_samples,
num_posterior_samples=num_posterior_samples
)
return check_stats
Expand Down Expand Up @@ -590,7 +577,6 @@ def plot_1d_ranks(
if plot:
plt.show()


def plot_cdf_1d_ranks(
self,
ranks,
Expand Down Expand Up @@ -631,11 +617,7 @@ def plot_cdf_1d_ranks(
plt.show()

def calculate_coverage_fraction(
self,
posterior,
thetas,
ys,
percentile_list,
self, posterior, thetas, ys, percentile_list,
samples_per_inference=1_000
):
"""
Expand All @@ -650,8 +632,7 @@ def calculate_coverage_fraction(
count_array = []
# make this for loop into a progress bar:
for i in tqdm(
range(len(ys)),
desc="Sampling from the posterior for each obs",
range(len(ys)), desc="Sampling from the posterior for each obs",
unit="obs"
):
# for i in range(len(ys)):
Expand Down Expand Up @@ -684,11 +665,9 @@ def calculate_coverage_fraction(
# find the percentile for the posterior for this observation
# this is n_params dimensional
# the units are in parameter space
confidence_l = np.percentile(samples.cpu(),
percentile_l,
confidence_l = np.percentile(samples.cpu(), percentile_l,
axis=0)
confidence_u = np.percentile(samples.cpu(),
percentile_u,
confidence_u = np.percentile(samples.cpu(), percentile_u,
axis=0)
# this is asking if the true parameter value
# is contained between the
Expand Down Expand Up @@ -757,8 +736,8 @@ def plot_coverage_fraction(
)

ax.plot(
[0, 0.5, 1], [0, 0.5, 1],
"k--", lw=3, zorder=1000, label="Reference Line"
[0, 0.5, 1], [0, 0.5, 1], "k--", lw=3, zorder=1000,
label="Reference Line"
)
ax.set_xlim([-0.05, 1.05])
ax.set_ylim([-0.05, 1.05])
Expand Down
19 changes: 6 additions & 13 deletions src/scripts/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch


class ModelLoader:
def save_model_pkl(self, path, model_name, posterior):
"""
Expand Down Expand Up @@ -43,10 +44,7 @@ def predict(input, model):


class DataLoader:
def save_data_pkl(self,
data_name,
data,
path='../saveddata/'):
def save_data_pkl(self, data_name, data, path="../saveddata/"):
"""
Save and load the pkl'ed training/test set
Expand All @@ -58,9 +56,7 @@ def save_data_pkl(self,
with open(file_name, "wb") as file:
pickle.dump(data, file)

def load_data_pkl(self,
data_name,
path='../saveddata/'):
def load_data_pkl(self, data_name, path="../saveddata/"):
"""
Load the pkl'ed saved posterior model
Expand All @@ -73,10 +69,7 @@ def load_data_pkl(self,
data = pickle.load(file)
return data

def save_data_h5(self,
data_name,
data,
path='../saveddata/'):
def save_data_h5(self, data_name, data, path="../saveddata/"):
"""
Save data to an h5 file.
Expand All @@ -92,7 +85,7 @@ def save_data_h5(self,
for key, value in data_arrays.items():
file.create_dataset(key, data=value)

def load_data_h5(self, data_name, path='../saveddata/'):
def load_data_h5(self, data_name, path="../saveddata/"):
"""
Load data from an h5 file.
Expand All @@ -105,4 +98,4 @@ def load_data_h5(self, data_name, path='../saveddata/'):
with h5py.File(file_name, "r") as file:
for key in file.keys():
loaded_data[key] = torch.Tensor(file[key][...])
return loaded_data
return loaded_data
47 changes: 27 additions & 20 deletions src/scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# plotting style things:
import matplotlib
import matplotlib.pyplot as plt
from cycler import cycler
# from cycler import cycler

from typing import List, Union

Expand All @@ -21,11 +21,10 @@ def mackelab_corner_plot(
labels_list=None,
limit_list=None,
truth_list=None,
truth_color='red',
truth_color="red",
plot=False,
save=True,
path='plots/',

path="plots/",
):
"""
Uses existing pairplot from mackelab analysis
Expand All @@ -48,8 +47,7 @@ def mackelab_corner_plot(
truths=truth_list,
figsize=(5, 5),
)
axes[0, 1].plot([truth_list[1]], [truth_list[0]],
marker="o",
axes[0, 1].plot([truth_list[1]], [truth_list[0]], marker="o",
color=truth_color)
axes[0, 0].axvline(x=truth_list[0], color=truth_color)
axes[1, 1].axvline(x=truth_list[1], color=truth_color)
Expand All @@ -63,12 +61,14 @@ def getdist_corner_plot(
self,
posterior_samples: Union[List[np.ndarray], np.ndarray],
labels_list: List[str] = None,
limit_list: List[List[float]] = None, # Each inner list contains [lower_limit, upper_limit]
limit_list: List[
List[float]
] = None, # Each inner list contains [lower_limit, upper_limit]
truth_list: List[float] = None,
truth_color: str = 'orange',
truth_color: str = "orange",
plot: bool = False,
save: bool = True,
path: str = 'plots/',
path: str = "plots/",
):
"""
Uses existing getdist
Expand All @@ -87,10 +87,12 @@ def getdist_corner_plot(
# Handle the case where 'posterior_samples' is a list of samples
# You may want to customize this part based on your requirements
samples_list = [
MCSamples(samples=samps,
names=labels_list,
labels=labels_list,
ranges=limit_list)
MCSamples(
samples=samps,
names=labels_list,
labels=labels_list,
ranges=limit_list,
)
for samps in posterior_samples
]

Expand All @@ -101,7 +103,12 @@ def getdist_corner_plot(
g.triangle_plot(samples_list, filled=True)
else:
# Assume 'posterior_samples' is a 2D numpy array or similar
samples = MCSamples(samples=posterior_samples, names=labels_list, labels=labels_list, ranges=limit_list)
samples = MCSamples(
samples=posterior_samples,
names=labels_list,
labels=labels_list,
ranges=limit_list,
)

# Create a getdist Plotter
g = plots.get_subplot_plotter()
Expand All @@ -118,22 +125,22 @@ def getdist_corner_plot(
# which is on the diagnoal
g.subplots[i, j].axvline(x=truth_list[i],
color=truth_color)

try:
# plot as a point for the posteriors
g.subplots[int(1 + i), int(0 + j)].scatter(truth_list[0+i],
truth_list[1+i],
color=truth_color)
g.subplots[int(1 + i), int(0 + j)].scatter(
truth_list[0 + i], truth_list[1 + i],
color=truth_color
)
except IndexError:
continue

# Save or show the plot
if save:
plt.savefig(path + "getdist_cornerplot.pdf")

if plot:
plt.show()


def improved_corner_plot(self, posterior):
"""
Expand Down

0 comments on commit c7c10a3

Please sign in to comment.