diff --git a/pyproject.toml b/pyproject.toml index 095de07..eea7716 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ numpy = "^1.26.4" matplotlib = "^3.8.3" tarp = "^0.1.1" deprecation = "^2.1.0" -scipy = "1.12.0" [tool.poetry.group.dev.dependencies] diff --git a/src/client/client.py b/src/client/client.py index 0ffe68f..063b1d3 100644 --- a/src/client/client.py +++ b/src/client/client.py @@ -97,15 +97,9 @@ def main(): plots = config.get_section("plots", raise_exception=False) for metrics_name, metrics_args in metrics.items(): - try: - Metrics[metrics_name](model, data, **metrics_args)() - except (NotImplementedError, RuntimeError) as error: - print(f"WARNING - skipping metric {metrics_name} due to error: {error}") + Metrics[metrics_name](model, data, **metrics_args)() for plot_name, plot_args in plots.items(): - try: - Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( - **plot_args - ) - except (NotImplementedError, RuntimeError) as error: - print(f"WARNING - skipping plot {plot_name} due to error: {error}") + Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( + **plot_args + ) diff --git a/src/data/data.py b/src/data/data.py index a022d6f..129877d 100644 --- a/src/data/data.py +++ b/src/data/data.py @@ -125,4 +125,4 @@ def load_prior(self, prior, prior_kwargs): return lambda size: choices[prior](**prior_kwargs, size=size) except KeyError as e: - raise RuntimeError(f"Data missing a prior specification - {e}") \ No newline at end of file + raise RuntimeError(f"Data missing a prior specification - {e}") diff --git a/src/data/h5_data.py b/src/data/h5_data.py index c10b4a5..80ddac0 100644 --- a/src/data/h5_data.py +++ b/src/data/h5_data.py @@ -10,8 +10,7 @@ class H5Data(Data): def __init__(self, path: str, simulator: Callable): super().__init__(path, simulator) - self.theta_true = self.get_theta_true() - + def _load(self, path): assert path.split(".")[-1] == "h5", "File extension must be h5" loaded_data = {} diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 450669f..6d58c90 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -1,7 +1,4 @@ from metrics.all_sbc import AllSBC from metrics.coverage_fraction import CoverageFraction -from metrics.local_two_sample import LocalTwoSampleTest - -_all = [CoverageFraction, AllSBC, LocalTwoSampleTest] -Metrics = {m.__name__: m for m in _all} +Metrics = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC} diff --git a/src/metrics/local_two_sample.py b/src/metrics/local_two_sample.py deleted file mode 100644 index e078670..0000000 --- a/src/metrics/local_two_sample.py +++ /dev/null @@ -1,175 +0,0 @@ -from typing import Any, Optional, Union -import numpy as np - -from sklearn.model_selection import KFold -from sklearn.neural_network import MLPClassifier -from sklearn.utils import shuffle - -from metrics.metric import Metric -from utils.config import get_item - -class LocalTwoSampleTest(Metric): - def __init__(self, model: Any, data: Any, out_dir: str | None = None, num_simulations: Optional[int] = None) -> None: - super().__init__(model, data, out_dir) - self.num_simulations = num_simulations if num_simulations is not None else get_item( - "metrics_common", "number_simulations", raise_exception=False - ) - def _collect_data_params(self): - - # P is the prior and x_P is generated via the simulator from the parameters P. - self.p = self.data.sample_prior(self.num_simulations) - self.q = np.zeros_like(self.p) - - self.outcome_given_p = np.zeros((self.num_simulations, self.data.simulator.generate_context().shape[-1])) - self.outcome_given_q = np.zeros_like(self.outcome_given_p) - self.evaluation_context = np.zeros_like(self.outcome_given_p) - - for index, p in enumerate(self.p): - context = self.data.simulator.generate_context() - self.outcome_given_p[index] = self.data.simulator.simulate(p, context) - # Q is the approximate posterior amortized in x - q = self.model.sample_posterior(1, context).ravel() - self.q[index] = q - self.outcome_given_q[index] = self.data.simulator.simulate(q, context) - - self.evaluation_context = np.array([self.data.simulator.generate_context() for _ in range(self.num_simulations)]) - - def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}): - classifier_map = { - "MLP":MLPClassifier - } - try: - classifier = classifier_map[classifier](**classifier_kwargs) - except KeyError: - raise NotImplementedError( - f"{classifier} not implemented, choose from {list(classifier_map.keys())}.") - - joint_P_x = np.concatenate([p, x_p], axis=1) - joint_Q_x = np.concatenate([q, x_q], axis=1) - - features = np.concatenate([joint_P_x, joint_Q_x], axis=0) - labels = np.concatenate( - [np.array([0] * len(joint_P_x)), np.array([1] * len(joint_Q_x))] - ).ravel() - - # shuffle features and labels - features, labels = shuffle(features, labels) - - # train the classifier - classifier.fit(X=features, y=labels) - return classifier - - def _eval_model(self, P, evaluation_sample, classifier): - evaluation = np.concatenate([P, evaluation_sample], axis=1) - probability = classifier.predict_proba(evaluation)[:, 0] - return probability - - def _scores(self, p, q, x_p, x_q, classifier, cross_evaluate: bool=True, classifier_kwargs=None): - model_probabilities = [] - for model, model_args in zip(classifier, classifier_kwargs): - if cross_evaluate: - model_probabilities.append(self._cross_eval_score(p, q, x_p, x_q, model, model_args)) - else: - trained_model = self.train_linear_classifier(p, q, x_p, x_q, model, model_args) - model_probabilities.append(self._eval_model(P=p, classifier=trained_model)) - - return np.mean(model_probabilities, axis=0) - - def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cross_folds=5): - kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) # Getting the shape - cv_splits = kf.split(p) - # train classifiers over cv-folds - probabilities = [] - self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])) - - kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) - cv_splits = kf.split(p) - for cross_trial, (train_index, val_index) in enumerate(cv_splits): - # get train split - p_train, x_p_train = p[train_index,:], x_p[train_index,:] - q_train, x_q_train = q[train_index,:], x_q[train_index,:] - trained_nth_classifier = self.train_linear_classifier(p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs) - p_evaluate = p[val_index] - for index, p_validation in enumerate(p_evaluate): - self.evaluation_data[cross_trial][index] = self.data.simulator.simulate( - p_validation, self.evaluation_context[val_index][index] - ) - probabilities.append(self._eval_model(p_evaluate, self.evaluation_data[cross_trial], trained_nth_classifier)) - return probabilities - - def permute_data(self, P, Q): - """Permute the concatenated data [P,Q] to create null-hyp samples. - - Args: - P (torch.Tensor): data of shape (n_samples, dim) - Q (torch.Tensor): data of shape (n_samples, dim) - """ - n_samples = P.shape[0] - X = np.concatenate([P, Q], axis=0) - X_perm = X[self.data.rng.permutation(np.arange(n_samples * 2))] - return X_perm[:n_samples], X_perm[n_samples:] - - def calculate( - self, - linear_classifier:Union[str, list[str]]='MLP', - cross_evaluate:bool=True, - n_null_hypothesis_trials=100, - classifier_kwargs:Union[dict, list[dict]]=None - ): - - if isinstance(linear_classifier, str): - linear_classifier = [linear_classifier] - - if classifier_kwargs is None: - classifier_kwargs = {} - if isinstance(classifier_kwargs, dict): - classifier_kwargs = [classifier_kwargs] - - probabilities = self._scores( - self.p, - self.q, - self.outcome_given_p, - self.outcome_given_q, - classifier=linear_classifier, - cross_evaluate=cross_evaluate, - classifier_kwargs=classifier_kwargs - ) - null_hypothesis_probabilities = [] - for _ in range(n_null_hypothesis_trials): - joint_P_x = np.concatenate([self.p, self.outcome_given_p], axis=1) - joint_Q_x = np.concatenate([self.q, self.outcome_given_q], axis=1) - joint_P_x_perm, joint_Q_x_perm = self.permute_data( - joint_P_x, joint_Q_x, - ) - p_null = joint_P_x_perm[:, : self.p.shape[-1]] - p_given_x_null = joint_P_x_perm[:, self.p.shape[-1] :] - q_null = joint_Q_x_perm[:, : self.q.shape[-1]] - q_given_x_null = joint_Q_x_perm[:, self.q.shape[-1] :] - - null_result = self._scores( - p_null, - q_null, - p_given_x_null, - q_given_x_null, - classifier=linear_classifier, - cross_evaluate=cross_evaluate, - classifier_kwargs=classifier_kwargs - ) - - null_hypothesis_probabilities.append(null_result) - - null = np.array(null_hypothesis_probabilities) - self.output = { - "lc2st_probabilities": probabilities, - "lc2st_null_hypothesis_probabilities": null - } - return probabilities, null - - def __call__(self, **kwds: Any) -> Any: - try: - self._collect_data_params() - except NotImplementedError: - pass - - self.calculate(**kwds) - self._finish() \ No newline at end of file diff --git a/src/models/sbi_model.py b/src/models/sbi_model.py index a244da1..9085402 100644 --- a/src/models/sbi_model.py +++ b/src/models/sbi_model.py @@ -24,8 +24,6 @@ def sample_posterior(self, n_samples: int, y_true): # TODO typing def predict_posterior(self, data): posterior_samples = self.sample_posterior(data.y_true) posterior_predictive_samples = data.simulator( - data.get_theta_true(), posterior_samples + data.theta_true(), posterior_samples ) return posterior_predictive_samples - - \ No newline at end of file diff --git a/src/plots/__init__.py b/src/plots/__init__.py index b186bc2..f576bd7 100644 --- a/src/plots/__init__.py +++ b/src/plots/__init__.py @@ -1,8 +1,11 @@ from plots.cdf_ranks import CDFRanks from plots.coverage_fraction import CoverageFraction from plots.ranks import Ranks -from plots.local_two_sample import LocalTwoSampleTest from plots.tarp import TARP -_all = [CoverageFraction, CDFRanks, Ranks, LocalTwoSampleTest, TARP] -Plots = {m.__name__: m for m in _all} \ No newline at end of file +Plots = { + CDFRanks.__name__: CDFRanks, + CoverageFraction.__name__: CoverageFraction, + Ranks.__name__: Ranks, + TARP.__name__: TARP, +} diff --git a/src/plots/local_two_sample.py b/src/plots/local_two_sample.py deleted file mode 100644 index 0923763..0000000 --- a/src/plots/local_two_sample.py +++ /dev/null @@ -1,212 +0,0 @@ -from typing import Optional, Sequence, Union -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.colors import Normalize -from matplotlib.patches import Rectangle - -from plots.plot import Display -from metrics.local_two_sample import LocalTwoSampleTest as l2st -from utils.config import get_item -from utils.plotting_utils import get_hex_colors - -class LocalTwoSampleTest(Display): - - # https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 - - def __init__(self, - model, - data, - save:bool, - show:bool, - out_dir:Optional[str]=None, - percentiles: Optional[Sequence] = None, - parameter_names: Optional[Sequence] = None, - parameter_colors: Optional[Sequence]= None, - figure_size: Optional[Sequence] = None, - num_simulations: Optional[int] = None, - colorway: Optional[str]=None): - super().__init__(model, data, save, show, out_dir) - self.percentiles = percentiles if percentiles is not None else get_item("metrics_common", item='percentiles', raise_exception=False) - - self.param_names = parameter_names if parameter_names is not None else get_item("plots_common", item="parameter_labels", raise_exception=False) - self.param_colors = parameter_colors if parameter_colors is not None else get_item("plots_common", item="parameter_colors", raise_exception=False) - self.figure_size = figure_size if figure_size is not None else get_item("plots_common", item="figure_size", raise_exception=False) - - colorway = colorway if colorway is not None else get_item( - "plots_common", "default_colorway", raise_exception=False - ) - self.region_colors = get_hex_colors(n_colors=len(self.percentiles), colorway=colorway) - - num_simulations = num_simulations if num_simulations is not None else get_item( - "metrics_common", "number_simulations", raise_exception=False - ) - self.l2st = l2st(model, data, out_dir, num_simulations) - - def _plot_name(self): - return "local_C2ST.png" - - def _make_pairplot_values(self, random_samples): - pp_vals = np.array([np.mean(random_samples <= alpha) for alpha in self.cdf_alphas]) - return pp_vals - - def lc2st_pairplot(self, subplot, confidence_region_alpha=0.2): - - null_cdf = self._make_pairplot_values([0.5] * len(self.probability)) - subplot.plot( - self.cdf_alphas, null_cdf, "--", color="black", label="Theoretical Null CDF" - ) - - null_hypothesis_pairplot = np.zeros((len(self.cdf_alphas), *null_cdf.shape)) - - for t in range(len(self.null_hypothesis_probability)): - null_hypothesis_pairplot[t] = self._make_pairplot_values(self.null_hypothesis_probability[t]) - - - for percentile, color in zip(self.percentiles, self.region_colors): - low_null = np.quantile(null_hypothesis_pairplot, percentile/100, axis=1) - up_null = np.quantile(null_hypothesis_pairplot, (100-percentile)/100, axis=1) - - subplot.fill_between( - self.cdf_alphas, - low_null, - up_null, - color=color, - alpha=confidence_region_alpha, - label=f"{percentile}% Conf. region", - ) - - for prob, label, color in zip(self.probability, self.param_names, self.param_colors): - pairplot_values = self._make_pairplot_values(prob) - subplot.plot(self.cdf_alphas, pairplot_values, label=label, color=color) - - def probability_intensity(self, subplot, plot_dims, features, n_bins=20): - evaluation_data = self.l2st.evaluation_data - - if len(evaluation_data.shape) >=3: # Used the kfold option - evaluation_data = evaluation_data.reshape(( - evaluation_data.shape[0]*evaluation_data.shape[1], - evaluation_data.shape[-1])) - self.probability = self.probability.ravel() - - if plot_dims==1: - - _, bins, patches = subplot.hist( - evaluation_data[:,features], n_bins, weights=self.probability, density=True, color=self.param_colors[features]) - - eval_bins = np.select( - [evaluation_data[:,features] <= i for i in bins[1:]], list(range(n_bins)) - ) - - # get mean predicted proba for each bin - weights = np.array([self.probability[eval_bins==i].mean() for i in np.unique(eval_bins)]) #df_probas.groupby(["bins"]).mean().probas - colors = plt.get_cmap(self.colorway) - - for w, p in zip(weights, patches): - p.set_facecolor(colors(w)) # color is mean predicted proba - - else: - - _, x_edges, y_edges, patches = subplot.hist2d( - evaluation_data[:,features[0]], - evaluation_data[:,features[1]], - n_bins, - density=True, color=self.param_colors[features[0]]) - - eval_bins_dim_1 = np.select( - [evaluation_data[:,features[0]] <= i for i in x_edges[1:]], list(range(n_bins)) - ) - eval_bins_dim_2 = np.select( - [evaluation_data[:,features[1]] <= i for i in y_edges[1:]], list(range(n_bins)) - ) - - colors = plt.get_cmap(self.colorway) - - weights = np.empty((n_bins, n_bins)) - for i in range(n_bins): - for j in range(n_bins): - try: - weights[i, j] = self.probability[np.logical_and(eval_bins_dim_1==i, eval_bins_dim_2==j)].mean() - except KeyError: - pass - - for i in range(len(x_edges) - 1): - for j in range(len(y_edges) - 1): - weight = weights[i,j] - facecolor = colors(weight) - # if no sample in bin, set color to white - if weight == np.nan: - facecolor = "white" - rect = Rectangle( - (x_edges[i], y_edges[j]), - x_edges[i + 1] - x_edges[i], - y_edges[j + 1] - y_edges[j], - facecolor=facecolor, - edgecolor="none", - ) - subplot.add_patch(rect) - - - def _plot(self, - use_intensity_plot:bool=True, - n_alpha_samples:int=100, - confidence_region_alpha:float=0.2, - n_intensity_bins:int=20, - intensity_dimension:int=2, - intensity_feature_index:Union[int, Sequence[int]]=[0,1], - linear_classifier:Union[str, list[str]]='MLP', - cross_evaluate:bool=True, - n_null_hypothesis_trials=100, - classifier_kwargs:Union[dict, list[dict]]=None, - y_label="Empirical CDF", - x_label="", - title="Local Classifier 2-Sample Test" - ): - - if use_intensity_plot: - if intensity_dimension not in (1, 2): - raise NotImplementedError("LC2ST Intensity Plot only implemented in 1D and 2D") - - if intensity_dimension == 1: - try: - int(intensity_feature_index) - except TypeError: - raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply an integer value index.") - - else: - try: - assert len(intensity_feature_index) == intensity_dimension - int(intensity_feature_index[0]) - int(intensity_feature_index[1]) - except (AssertionError, TypeError): - raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply a list of 2 integer value indices.") - - self.l2st(**{ - "linear_classifier":linear_classifier, - "cross_evaluate": cross_evaluate, - "n_null_hypothesis_trials": n_null_hypothesis_trials, - "classifier_kwargs": classifier_kwargs}) - - self.probability, self.null_hypothesis_probability = self.l2st.output["lc2st_probabilities"], self.l2st.output["lc2st_null_hypothesis_probabilities"] - - # Plots to make - - # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 - # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 - - n_plots = 1 if not use_intensity_plot else 2 - figure_size = self.figure_size if n_plots==1 else (int(self.figure_size[0]*1.8),self.figure_size[1]) - fig, subplots = plt.subplots(1, n_plots, figsize=figure_size) - self.cdf_alphas = np.linspace(0, 1, n_alpha_samples) - - self.lc2st_pairplot(subplots[0] if n_plots == 2 else subplots, confidence_region_alpha=confidence_region_alpha) - if use_intensity_plot: - self.probability_intensity( - subplots[1], - intensity_dimension, - n_bins=n_intensity_bins, - features=intensity_feature_index - ) - - fig.legend() - fig.supylabel(y_label) - fig.supxlabel(x_label) - fig.suptitle(title) \ No newline at end of file diff --git a/src/plots/plot.py b/src/plots/plot.py index e3ac508..0448800 100644 --- a/src/plots/plot.py +++ b/src/plots/plot.py @@ -27,6 +27,7 @@ def __init__( self.model = model self._common_settings() + self._plot_settings() self.plot_name = self._plot_name() def _plot_name(self): @@ -76,14 +77,6 @@ def _finish(self): plt.cla() def __call__(self, **plot_args) -> None: - try: - self._data_setup() - except NotImplementedError: - pass - try: - self._plot_settings() - except NotImplementedError: - pass - + self._data_setup() self._plot(**plot_args) self._finish() diff --git a/src/utils/defaults.py b/src/utils/defaults.py index 3073bdd..3e5a1ed 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -10,10 +10,7 @@ "data_engine": "H5Data", "prior":"normal", "prior_kwargs": None, - "simulator_kwargs": None, - "prior": "normal", - "prior_kwargs":{} - + "simulator_kwargs": None, }, "plots_common": { "axis_spines": False, @@ -29,7 +26,6 @@ "CDFRanks": {}, "Ranks": {"num_bins": None}, "CoverageFraction": {}, - "LocalTwoSampleTest":{}, "TARP": { "coverage_sigma": 3 # How many sigma to show coverage over }, @@ -43,9 +39,5 @@ "metrics": { "AllSBC": {}, "CoverageFraction": {}, - "LocalTwoSampleTest":{ - "linear_classifier":"MLP", - "classifier_kwargs":{"alpha":0, "max_iter":2500} - } }, } diff --git a/src/utils/plotting_utils.py b/src/utils/plotting_utils.py deleted file mode 100644 index dc138d6..0000000 --- a/src/utils/plotting_utils.py +++ /dev/null @@ -1,11 +0,0 @@ -import numpy as np -import matplotlib as mpl - -def get_hex_colors(n_colors:int, colorway:str): - cmap = mpl.pyplot.get_cmap(colorway) - hex_colors = [] - arr=np.linspace(0, 1, n_colors) - for hit in arr: - hex_colors.append(mpl.colors.rgb2hex(cmap(hit))) - - return hex_colors \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 26b8af2..094fbb6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,8 +9,8 @@ class MockSimulator(Simulator): - def generate_context(self, n_samples=None) -> np.ndarray: - return np.linspace(0, 100, 101) + def generate_context(self, n_samples: int) -> np.ndarray: + return np.linspace(0, 100, n_samples) def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: thetas = np.atleast_2d(theta) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py new file mode 100644 index 0000000..3ed6116 --- /dev/null +++ b/tests/test_evaluate.py @@ -0,0 +1,185 @@ +import sys +import pytest +import torch +import numpy as np +import sbi +import os + +# flake8: noqa +#sys.path.append("..") +print(sys.path) +from scripts.evaluate import Diagnose_static, Diagnose_generative +from scripts.io import ModelLoader +#from src.scripts import evaluate + + +""" +""" + + +""" +Test the evaluate module +""" + + +@pytest.fixture +def diagnose_static_instance(): + return Diagnose_static() + +@pytest.fixture +def diagnose_generative_instance(): + return Diagnose_generative() + + +@pytest.fixture +def posterior_generative_sbi_model(): + # create a temporary directory for the saved model + #dir = "savedmodels/sbi/" + #os.makedirs(dir) + + # now save the model + low_bounds = torch.tensor([0, -10]) + high_bounds = torch.tensor([10, 10]) + + prior = sbi.utils.BoxUniform(low = low_bounds, high = high_bounds) + + posterior = sbi.inference.base.infer(simulator, prior, "SNPE", num_simulations=10000) + + # Provide the posterior to the tests + yield prior, posterior + + # Teardown: Remove the temporary directory and its contents + #shutil.rmtree(dataset_dir) + +@pytest.fixture +def setup_plot_dir(): + # create a temporary directory for the saved model + dir = "tests/plots/" + os.makedirs(dir) + yield dir + +def simulator(thetas): # , percent_errors): + # convert to numpy array (if tensor): + thetas = np.atleast_2d(thetas) + # Check if the input has the correct shape + if thetas.shape[1] != 2: + raise ValueError( + "Input tensor must have shape (n, 2) \ + where n is the number of parameter sets." + ) + + # Unpack the parameters + if thetas.shape[0] == 1: + # If there's only one set of parameters, extract them directly + m, b = thetas[0, 0], thetas[0, 1] + else: + # If there are multiple sets of parameters, extract them for each row + m, b = thetas[:, 0], thetas[:, 1] + x = np.linspace(0, 100, 101) + rs = np.random.RandomState() # 2147483648)# + # I'm thinking sigma could actually be a function of x + # if we want to get fancy down the road + # Generate random noise (epsilon) based + # on a normal distribution with mean 0 and standard deviation sigma + sigma = 5 + ε = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) + + # Initialize an empty array to store the results for each set of parameters + y = np.zeros((len(x), thetas.shape[0])) + for i in range(thetas.shape[0]): + m, b = thetas[i, 0], thetas[i, 1] + y[:, i] = m * x + b + ε[:, i] + return torch.Tensor(y.T) + + +def test_generate_sbc_samples(diagnose_generative_instance, + posterior_generative_sbi_model): + # Mock data + #low_bounds = torch.tensor([0, -10]) + #high_bounds = torch.tensor([10, 10]) + + #prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds) + prior, posterior = posterior_generative_sbi_model + #inference_instance # provide a mock posterior object + simulator_test = simulator # provide a mock simulator function + num_sbc_runs = 1000 + num_posterior_samples = 1000 + + # Generate SBC samples + thetas, ys, ranks, dap_samples = diagnose_generative_instance.generate_sbc_samples( + prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples + ) + + # Add assertions based on the expected behavior of the method + + +def test_run_all_sbc(diagnose_generative_instance, + posterior_generative_sbi_model, + setup_plot_dir): + labels_list = ["$m$", "$b$"] + colorlist = ["#9C92A3", "#0F5257"] + + prior, posterior = posterior_generative_sbi_model + simulator_test = simulator # provide a mock simulator function + + save_path = setup_plot_dir + + diagnose_generative_instance.run_all_sbc( + prior, + posterior, + simulator_test, + labels_list, + colorlist, + num_sbc_runs=1_000, + num_posterior_samples=1_000, + samples_per_inference=1_000, + plot=False, + save=True, + path=save_path, + ) + # Check if PDF files were saved + assert os.path.exists(save_path), f"No 'plots' folder found at {save_path}" + + # List all files in the directory + files_in_directory = os.listdir(save_path) + + # Check if at least one PDF file is present + pdf_files = [file for file in files_in_directory if file.endswith(".pdf")] + assert pdf_files, "No PDF files found in the 'plots' folder" + + # We expect the pdfs to exist in the directory + expected_pdf_files = ["sbc_ranks.pdf", "sbc_ranks_cdf.pdf", "coverage.pdf"] + for expected_file in expected_pdf_files: + assert ( + expected_file in pdf_files + ), f"Expected PDF file '{expected_file}' not found" + + +""" +def test_sbc_statistics(diagnose_instance): + # Mock data + ranks = # provide mock ranks + thetas = # provide mock thetas + dap_samples = # provide mock dap_samples + num_posterior_samples = 1000 + + # Calculate SBC statistics + check_stats = diagnose_instance.sbc_statistics( + ranks, thetas, dap_samples, num_posterior_samples + ) + + # Add assertions based on the expected behavior of the method + +def test_plot_1d_ranks(diagnose_instance): + # Mock data + ranks = # provide mock ranks + num_posterior_samples = 1000 + labels_list = # provide mock labels_list + colorlist = # provide mock colorlist + + # Plot 1D ranks + diagnose_instance.plot_1d_ranks( + ranks, num_posterior_samples, labels_list, + colorlist, plot=False, save=False + ) +""" diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 127ae28..1cec089 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -6,8 +6,7 @@ from metrics import ( Metrics, CoverageFraction, - AllSBC, - LocalTwoSampleTest + AllSBC ) @pytest.fixture @@ -32,6 +31,7 @@ def test_all_defaults(metric_config, mock_model, mock_data): Ensures each metric has a default set of parameters and is included in the defaults list Ensures each test can initialize, regardless of the veracity of the output """ + Config(metric_config) for metric_name, metric_obj in Metrics.items(): assert metric_name in Defaults['metrics'] @@ -39,6 +39,7 @@ def test_all_defaults(metric_config, mock_model, mock_data): def test_coverage_fraction(metric_config, mock_model, mock_data): + Config(metric_config) coverage_fraction = CoverageFraction(mock_model, mock_data) _, coverage = coverage_fraction.calculate() assert coverage_fraction.output.all() is not None @@ -47,11 +48,7 @@ def test_coverage_fraction(metric_config, mock_model, mock_data): assert coverage.shape def test_all_sbc(metric_config, mock_model, mock_data): + Config(metric_config) all_sbc = AllSBC(mock_model, mock_data) all_sbc() - # TODO What is this supposed to be - -def test_lc2st(metric_config, mock_model, mock_data): - lc2st = LocalTwoSampleTest(mock_model, mock_data) - lc2st() - assert lc2st.output is not None \ No newline at end of file + # TODO What is this supposed to be \ No newline at end of file diff --git a/tests/test_plots.py b/tests/test_plots.py index 4006ac9..253343b 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -8,8 +8,7 @@ CDFRanks, Ranks, CoverageFraction, - TARP, - LocalTwoSampleTest + TARP ) @pytest.fixture @@ -56,10 +55,4 @@ def test_plot_coverage(plot_config, mock_model, mock_data): def test_plot_tarp(plot_config, mock_model, mock_data): plot = TARP(mock_model, mock_data, save=True, show=False) - plot(**get_item("plots", "TARP", raise_exception=False)) - assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") - -def test_plot_lc2st(plot_config, mock_model, mock_data): - plot = LocalTwoSampleTest(mock_model, mock_data, save=True, show=False) - plot(**get_item("plots", "LocalTwoSampleTest", raise_exception=False)) - assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") \ No newline at end of file + plot(**get_item("plots", "TARP", raise_exception=False)) \ No newline at end of file