From c7c10a325b85484fb48e1deb7cee0440c45fdb7d Mon Sep 17 00:00:00 2001 From: beckynevin Date: Wed, 7 Feb 2024 08:44:46 -0700 Subject: [PATCH] flake8 --- src/scripts/evaluate.py | 59 +++++++++++++---------------------------- src/scripts/io.py | 19 +++++-------- src/scripts/plot.py | 47 ++++++++++++++++++-------------- 3 files changed, 52 insertions(+), 73 deletions(-) diff --git a/src/scripts/evaluate.py b/src/scripts/evaluate.py index 97ea8f0..5275656 100644 --- a/src/scripts/evaluate.py +++ b/src/scripts/evaluate.py @@ -5,10 +5,10 @@ Includes utilities for posterior diagnostics as well as some inference functions. """ -from scripts.io import ModelLoader +from scripts.io import ModelLoader import argparse -from sbi.analysis import run_sbc, sbc_rank_plot, check_sbc, pairplot +from sbi.analysis import run_sbc, sbc_rank_plot, check_sbc import numpy as np from tqdm import tqdm @@ -24,12 +24,9 @@ class Diagnose_generative: - def posterior_predictive(self, - theta_true, - x_true, - simulator, - posterior_samples, - true_sigma): + def posterior_predictive( + self, theta_true, x_true, simulator, posterior_samples, true_sigma + ): # not sure how or where to define the simulator # could require that people input posterior predictive samples, # already drawn from the simulator @@ -98,9 +95,7 @@ def sbc_statistics(self, if these values are close to 0.5, dap is like the prior distribution. """ check_stats = check_sbc( - ranks, - thetas, - dap_samples, + ranks, thetas, dap_samples, num_posterior_samples=num_posterior_samples ) return check_stats @@ -195,11 +190,7 @@ def plot_cdf_1d_ranks( plt.show() def calculate_coverage_fraction( - self, - posterior, - thetas, - ys, - percentile_list, + self, posterior, thetas, ys, percentile_list, samples_per_inference=1_000 ): """ @@ -209,7 +200,8 @@ def calculate_coverage_fraction( """ # this holds all posterior samples for each inference run - all_samples = np.empty((len(ys), samples_per_inference, + all_samples = np.empty((len(ys), + samples_per_inference, np.shape(thetas)[1])) count_array = [] # make this for loop into a progress bar: @@ -321,8 +313,8 @@ def plot_coverage_fraction( ) ax.plot( - [0, 0.5, 1], [0, 0.5, 1], - "k--", lw=3, zorder=1000, label="Reference Line" + [0, 0.5, 1], [0, 0.5, 1], "k--", lw=3, zorder=1000, + label="Reference Line" ) ax.set_xlim([-0.05, 1.05]) ax.set_ylim([-0.05, 1.05]) @@ -512,10 +504,7 @@ def generate_sbc_samples( ) return thetas, ys, ranks, dap_samples - def sbc_statistics(self, - ranks, - thetas, - dap_samples, + def sbc_statistics(self, ranks, thetas, dap_samples, num_posterior_samples): """ The ks pvalues are vanishingly small here, @@ -533,9 +522,7 @@ def sbc_statistics(self, if these values are close to 0.5, dap is like the prior distribution. """ check_stats = check_sbc( - ranks, - thetas, - dap_samples, + ranks, thetas, dap_samples, num_posterior_samples=num_posterior_samples ) return check_stats @@ -590,7 +577,6 @@ def plot_1d_ranks( if plot: plt.show() - def plot_cdf_1d_ranks( self, ranks, @@ -631,11 +617,7 @@ def plot_cdf_1d_ranks( plt.show() def calculate_coverage_fraction( - self, - posterior, - thetas, - ys, - percentile_list, + self, posterior, thetas, ys, percentile_list, samples_per_inference=1_000 ): """ @@ -650,8 +632,7 @@ def calculate_coverage_fraction( count_array = [] # make this for loop into a progress bar: for i in tqdm( - range(len(ys)), - desc="Sampling from the posterior for each obs", + range(len(ys)), desc="Sampling from the posterior for each obs", unit="obs" ): # for i in range(len(ys)): @@ -684,11 +665,9 @@ def calculate_coverage_fraction( # find the percentile for the posterior for this observation # this is n_params dimensional # the units are in parameter space - confidence_l = np.percentile(samples.cpu(), - percentile_l, + confidence_l = np.percentile(samples.cpu(), percentile_l, axis=0) - confidence_u = np.percentile(samples.cpu(), - percentile_u, + confidence_u = np.percentile(samples.cpu(), percentile_u, axis=0) # this is asking if the true parameter value # is contained between the @@ -757,8 +736,8 @@ def plot_coverage_fraction( ) ax.plot( - [0, 0.5, 1], [0, 0.5, 1], - "k--", lw=3, zorder=1000, label="Reference Line" + [0, 0.5, 1], [0, 0.5, 1], "k--", lw=3, zorder=1000, + label="Reference Line" ) ax.set_xlim([-0.05, 1.05]) ax.set_ylim([-0.05, 1.05]) diff --git a/src/scripts/io.py b/src/scripts/io.py index 9cb70f4..135842a 100644 --- a/src/scripts/io.py +++ b/src/scripts/io.py @@ -3,6 +3,7 @@ import numpy as np import torch + class ModelLoader: def save_model_pkl(self, path, model_name, posterior): """ @@ -43,10 +44,7 @@ def predict(input, model): class DataLoader: - def save_data_pkl(self, - data_name, - data, - path='../saveddata/'): + def save_data_pkl(self, data_name, data, path="../saveddata/"): """ Save and load the pkl'ed training/test set @@ -58,9 +56,7 @@ def save_data_pkl(self, with open(file_name, "wb") as file: pickle.dump(data, file) - def load_data_pkl(self, - data_name, - path='../saveddata/'): + def load_data_pkl(self, data_name, path="../saveddata/"): """ Load the pkl'ed saved posterior model @@ -73,10 +69,7 @@ def load_data_pkl(self, data = pickle.load(file) return data - def save_data_h5(self, - data_name, - data, - path='../saveddata/'): + def save_data_h5(self, data_name, data, path="../saveddata/"): """ Save data to an h5 file. @@ -92,7 +85,7 @@ def save_data_h5(self, for key, value in data_arrays.items(): file.create_dataset(key, data=value) - def load_data_h5(self, data_name, path='../saveddata/'): + def load_data_h5(self, data_name, path="../saveddata/"): """ Load data from an h5 file. @@ -105,4 +98,4 @@ def load_data_h5(self, data_name, path='../saveddata/'): with h5py.File(file_name, "r") as file: for key in file.keys(): loaded_data[key] = torch.Tensor(file[key][...]) - return loaded_data \ No newline at end of file + return loaded_data diff --git a/src/scripts/plot.py b/src/scripts/plot.py index 4dae205..7f9b3f9 100644 --- a/src/scripts/plot.py +++ b/src/scripts/plot.py @@ -5,7 +5,7 @@ # plotting style things: import matplotlib import matplotlib.pyplot as plt -from cycler import cycler +# from cycler import cycler from typing import List, Union @@ -21,11 +21,10 @@ def mackelab_corner_plot( labels_list=None, limit_list=None, truth_list=None, - truth_color='red', + truth_color="red", plot=False, save=True, - path='plots/', - + path="plots/", ): """ Uses existing pairplot from mackelab analysis @@ -48,8 +47,7 @@ def mackelab_corner_plot( truths=truth_list, figsize=(5, 5), ) - axes[0, 1].plot([truth_list[1]], [truth_list[0]], - marker="o", + axes[0, 1].plot([truth_list[1]], [truth_list[0]], marker="o", color=truth_color) axes[0, 0].axvline(x=truth_list[0], color=truth_color) axes[1, 1].axvline(x=truth_list[1], color=truth_color) @@ -63,12 +61,14 @@ def getdist_corner_plot( self, posterior_samples: Union[List[np.ndarray], np.ndarray], labels_list: List[str] = None, - limit_list: List[List[float]] = None, # Each inner list contains [lower_limit, upper_limit] + limit_list: List[ + List[float] + ] = None, # Each inner list contains [lower_limit, upper_limit] truth_list: List[float] = None, - truth_color: str = 'orange', + truth_color: str = "orange", plot: bool = False, save: bool = True, - path: str = 'plots/', + path: str = "plots/", ): """ Uses existing getdist @@ -87,10 +87,12 @@ def getdist_corner_plot( # Handle the case where 'posterior_samples' is a list of samples # You may want to customize this part based on your requirements samples_list = [ - MCSamples(samples=samps, - names=labels_list, - labels=labels_list, - ranges=limit_list) + MCSamples( + samples=samps, + names=labels_list, + labels=labels_list, + ranges=limit_list, + ) for samps in posterior_samples ] @@ -101,7 +103,12 @@ def getdist_corner_plot( g.triangle_plot(samples_list, filled=True) else: # Assume 'posterior_samples' is a 2D numpy array or similar - samples = MCSamples(samples=posterior_samples, names=labels_list, labels=labels_list, ranges=limit_list) + samples = MCSamples( + samples=posterior_samples, + names=labels_list, + labels=labels_list, + ranges=limit_list, + ) # Create a getdist Plotter g = plots.get_subplot_plotter() @@ -118,22 +125,22 @@ def getdist_corner_plot( # which is on the diagnoal g.subplots[i, j].axvline(x=truth_list[i], color=truth_color) - + try: # plot as a point for the posteriors - g.subplots[int(1 + i), int(0 + j)].scatter(truth_list[0+i], - truth_list[1+i], - color=truth_color) + g.subplots[int(1 + i), int(0 + j)].scatter( + truth_list[0 + i], truth_list[1 + i], + color=truth_color + ) except IndexError: continue - + # Save or show the plot if save: plt.savefig(path + "getdist_cornerplot.pdf") if plot: plt.show() - def improved_corner_plot(self, posterior): """