diff --git a/src/data/data.py b/src/data/data.py index 895226c..129877d 100644 --- a/src/data/data.py +++ b/src/data/data.py @@ -12,7 +12,7 @@ def __init__( path: str, simulator_name: str, simulator_kwargs: dict = None, - prior: str = "data", + prior: str = None, prior_kwargs: dict = None, ): self.rng = np.random.default_rng( @@ -21,9 +21,9 @@ def __init__( self.data = self._load(path) self.simulator = self._load_simulator(simulator_name, simulator_kwargs) self.prior_dist = self.load_prior(prior, prior_kwargs) - self.n_dims = self.theta_true().shape[1] + self.n_dims = self.get_theta_true().shape[1] - def _load_simulator(self, name): + def _load_simulator(self, name, simulator_kwargs): try: sim_location = get_item("common", "sim_location", raise_exception=False) simulator_path = os.environ[f"{sim_location}:{name}"] @@ -41,7 +41,7 @@ def _load_simulator(self, name): simulator = getattr(m, name) - simulator_kwargs = get_item("data", "simulator_kwargs", raise_exception=False) + simulator_kwargs = simulator_kwargs if simulator_kwargs is not None else get_item("data", "simulator_kwargs", raise_exception=False) simulator_kwargs = {} if simulator_kwargs is None else simulator_kwargs simulator_instance = simulator(**simulator_kwargs) @@ -65,7 +65,7 @@ def true_context(self): raise NotImplementedError def true_simulator_outcome(self): - return self.simulator(self.theta_true(), self.true_context()) + return self.simulator(self.get_theta_true(), self.true_context()) def sample_prior(self, n_samples: int): return self.prior_dist(size=(n_samples, self.n_dims)) @@ -83,17 +83,17 @@ def simulator_outcome(self, theta, condition_context=None, n_samples=None): def simulated_context(self, n_samples): return self.simulator.generate_context(n_samples) - def theta_true(self): + def get_theta_true(self): if hasattr(self, "theta_true"): return self.theta_true else: - return get_item("data", "theta_true") + return get_item("data", "theta_true", raise_exception=True) - def sigma_true(self): + def get_sigma_true(self): if hasattr(self, "sigma_true"): return self.sigma_true else: - return get_item("data", "sigma_true") + return get_item("data", "sigma_true", raise_exception=True) def save(self, data, path: str): raise NotImplementedError @@ -102,6 +102,8 @@ def read_prior(self): raise NotImplementedError def load_prior(self, prior, prior_kwargs): + if prior is None: + prior = get_item("data", "prior", raise_exception=False) try: prior = self.read_prior() except NotImplementedError: diff --git a/src/data/h5_data.py b/src/data/h5_data.py index 5c0e9bd..80ddac0 100644 --- a/src/data/h5_data.py +++ b/src/data/h5_data.py @@ -34,17 +34,14 @@ def true_context(self): # From Data return self.data["xs"] # TODO change name - def true_simulator_outcome(self): - return self.simulator(self.theta_true(), self.x_true()) - def prior(self): # From Data raise NotImplementedError - def theta_true(self): + def get_theta_true(self): return self.data["thetas"] - def sigma_true(self): + def get_sigma_true(self): try: return super().sigma_true() except (AssertionError, KeyError): diff --git a/src/metrics/all_sbc.py b/src/metrics/all_sbc.py index 09d66a1..8193f68 100644 --- a/src/metrics/all_sbc.py +++ b/src/metrics/all_sbc.py @@ -24,13 +24,13 @@ def __init__( self.samples_per_inference = samples_per_inference def _collect_data_params(self): - self.thetas = tensor(self.data.theta_true()) - self.y_true = tensor(self.data.x_true()) + self.thetas = tensor(self.data.get_theta_true()) + self.context = tensor(self.data.true_context()) def calculate(self): ranks, dap_samples = run_sbc( self.thetas, - self.y_true, + self.context, self.model.posterior, num_posterior_samples=self.samples_per_inference, ) diff --git a/src/metrics/coverage_fraction.py b/src/metrics/coverage_fraction.py index f7d6b59..72c1df2 100644 --- a/src/metrics/coverage_fraction.py +++ b/src/metrics/coverage_fraction.py @@ -41,8 +41,8 @@ def __init__( ) def _collect_data_params(self): - self.thetas = self.data.theta_true() - self.y_true = self.data.x_true() + self.thetas = self.data.get_theta_true() + self.context = self.data.true_context() def _run_model_inference(self, samples_per_inference, y_inference): samples = self.model.sample_posterior(samples_per_inference, y_inference) @@ -50,10 +50,10 @@ def _run_model_inference(self, samples_per_inference, y_inference): def calculate(self): all_samples = np.empty( - (len(self.y_true), self.samples_per_inference, np.shape(self.thetas)[1]) + (len(self.context), self.samples_per_inference, np.shape(self.thetas)[1]) ) count_array = [] - iterator = enumerate(self.y_true) + iterator = enumerate(self.context) if self.progress_bar: iterator = tqdm( iterator, @@ -95,7 +95,7 @@ def calculate(self): count_sum_array = np.sum(count_array, axis=0) frac_lens_within_vol = np.array(count_sum_array) - coverage = frac_lens_within_vol / len(self.y_true) + coverage = frac_lens_within_vol / len(self.context) self.output = coverage diff --git a/src/plots/cdf_ranks.py b/src/plots/cdf_ranks.py index 6c0f333..62b7a20 100644 --- a/src/plots/cdf_ranks.py +++ b/src/plots/cdf_ranks.py @@ -41,14 +41,11 @@ def _plot_name(self): return "cdf_ranks.png" def _data_setup(self): - thetas = tensor(self.data.theta_true()) - y_true = tensor(self.data.x_true()) - self.num_samples = get_item( - "metrics_common", "samples_per_inference", raise_exception=False - ) + thetas = tensor(self.data.get_theta_true()) + context = tensor(self.data.true_context()) ranks, _ = run_sbc( - thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples + thetas, context, self.model.posterior, num_posterior_samples=self.num_samples ) self.ranks = ranks diff --git a/src/plots/coverage_fraction.py b/src/plots/coverage_fraction.py index 31a6f4a..bbfe293 100644 --- a/src/plots/coverage_fraction.py +++ b/src/plots/coverage_fraction.py @@ -18,6 +18,7 @@ def __init__( parameter_labels=None, figure_size=None, line_styles=None, + parameter_colors=None ): super().__init__(model, data, save, show, out_dir) @@ -26,6 +27,11 @@ def __init__( if parameter_labels is not None else get_item("plots_common", "parameter_labels", raise_exception=False) ) + self.colors = ( + parameter_colors + if parameter_colors is not None + else get_item("plots_common", "parameter_colors", raise_exception=False) + ) self.n_parameters = len(self.labels) self.figure_size = ( figure_size @@ -65,7 +71,7 @@ def _plot( ): n_steps = self.coverage_fractions.shape[0] percentile_array = np.linspace(0, 1, n_steps) - color_cycler = iter(plt.cycler("color", cm.get_cmap(self.colorway).colors)) + color_cycler = iter(plt.cycler("color", self.colors)) line_style_cycler = iter(plt.cycler("line_style", self.line_cycle)) # Plotting diff --git a/src/plots/ranks.py b/src/plots/ranks.py index ebf26aa..050dbca 100644 --- a/src/plots/ranks.py +++ b/src/plots/ranks.py @@ -13,14 +13,14 @@ def _plot_name(self): return "ranks.png" def _data_setup(self): - thetas = tensor(self.data.theta_true()) - y_true = tensor(self.data.x_true()) + thetas = tensor(self.data.get_theta_true()) + context = tensor(self.data.true_context()) self.num_samples = get_item( "metrics_common", "samples_per_inference", raise_exception=False ) ranks, _ = run_sbc( - thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples + thetas, context, self.model.posterior, num_posterior_samples=self.num_samples ) self.ranks = ranks diff --git a/src/plots/tarp.py b/src/plots/tarp.py index 653c3bd..e11c54d 100644 --- a/src/plots/tarp.py +++ b/src/plots/tarp.py @@ -18,9 +18,8 @@ def _plot_name(self): return "tarp.png" def _data_setup(self): - self.rng = np.random.default_rng( - get_item("common", "random_seed", raise_exception=False) - ) + self.theta_true = self.data.get_theta_true() + samples_per_inference = get_item( "metrics_common", "samples_per_inference", raise_exception=False ) @@ -28,16 +27,16 @@ def _data_setup(self): "metrics_common", "number_simulations", raise_exception=False ) - n_dims = self.data.theta_true().shape[1] + n_dims = self.theta_true.shape[1] self.posterior_samples = np.zeros( (num_simulations, samples_per_inference, n_dims) ) self.thetas = np.zeros((num_simulations, n_dims)) for n in range(num_simulations): - sample_index = self.rng.integers(0, len(self.data.theta_true())) + sample_index = self.data.rng.integers(0, len(self.theta_true)) - theta = self.data.theta_true()[sample_index, :] - x = self.data.x_true()[sample_index, :] + theta = self.theta_true[sample_index, :] + x = self.data.true_context()[sample_index, :] self.posterior_samples[n] = self.model.sample_posterior( samples_per_inference, x ) @@ -56,7 +55,7 @@ def _get_hex_sigma_colors(self, n_colors, colorway=None): "plots_common", "default_colorway", raise_exception=False ) - cmap = plt.cm.get_cmap(colorway) + cmap = plt.get_cmap(colorway) hex_colors = [] arr = np.linspace(0, 1, n_colors) for hit in arr: diff --git a/src/utils/defaults.py b/src/utils/defaults.py index 3d389a2..3e5a1ed 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -6,7 +6,12 @@ "random_seed": 42, }, "model": {"model_engine": "SBIModel"}, - "data": {"data_engine": "H5Data"}, + "data": { + "data_engine": "H5Data", + "prior":"normal", + "prior_kwargs": None, + "simulator_kwargs": None, + }, "plots_common": { "axis_spines": False, "tight_layout": True, diff --git a/tests/conftest.py b/tests/conftest.py index 3be2a31..094fbb6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,11 +9,11 @@ class MockSimulator(Simulator): - def __init__(self): - pass - - def __call__(self, thetas, samples): - thetas = np.atleast_2d(thetas) + def generate_context(self, n_samples: int) -> np.ndarray: + return np.linspace(0, 100, n_samples) + + def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: + thetas = np.atleast_2d(theta) if thetas.shape[1] != 2: raise ValueError("Input tensor must have shape (n, 2) where n is the number of parameter sets.") @@ -23,19 +23,17 @@ def __call__(self, thetas, samples): else: # If there are multiple sets of parameters, extract them for each row m, b = thetas[:, 0], thetas[:, 1] - x = np.linspace(0, 100, samples) rs = np.random.RandomState() sigma = 1 - epsilon = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) + epsilon = rs.normal(loc=0, scale=sigma, size=(len(context_samples), thetas.shape[0])) # Initialize an empty array to store the results for each set of parameters - y = np.zeros((len(x), thetas.shape[0])) + y = np.zeros((len(context_samples), thetas.shape[0])) for i in range(thetas.shape[0]): m, b = thetas[i, 0], thetas[i, 1] - y[:, i] = m * x + b + epsilon[:, i] + y[:, i] = m * context_samples + b + epsilon[:, i] return y.T - @pytest.fixture def model_path(): return "resources/savedmodels/sbi/sbi_linear_from_data.pkl" diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 5d39c4e..1cec089 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -13,6 +13,7 @@ def metric_config(config_factory): metrics_settings={"use_progress_bar":False, "samples_per_inference":10, "percentiles":[95]} config = config_factory(metrics_settings=metrics_settings) + Config(config) return config def test_all_metrics_catalogued():