Skip to content

Commit

Permalink
cleaned up naming errors
Browse files (browse the repository at this point in the history)
  • Loading branch information
voetberg committed May 17, 2024
1 parent 0172248 commit 685c2ef
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 51 deletions.
20 changes: 11 additions & 9 deletions src/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(
path: str,
simulator_name: str,
simulator_kwargs: dict = None,
prior: str = "data",
prior: str = None,
prior_kwargs: dict = None,
):
self.rng = np.random.default_rng(
Expand All @@ -21,9 +21,9 @@ def __init__(
self.data = self._load(path)
self.simulator = self._load_simulator(simulator_name, simulator_kwargs)
self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.theta_true().shape[1]
self.n_dims = self.get_theta_true().shape[1]

def _load_simulator(self, name):
def _load_simulator(self, name, simulator_kwargs):
try:
sim_location = get_item("common", "sim_location", raise_exception=False)
simulator_path = os.environ[f"{sim_location}:{name}"]
Expand All @@ -41,7 +41,7 @@ def _load_simulator(self, name):

simulator = getattr(m, name)

simulator_kwargs = get_item("data", "simulator_kwargs", raise_exception=False)
simulator_kwargs = simulator_kwargs if simulator_kwargs is not None else get_item("data", "simulator_kwargs", raise_exception=False)
simulator_kwargs = {} if simulator_kwargs is None else simulator_kwargs
simulator_instance = simulator(**simulator_kwargs)

Expand All @@ -65,7 +65,7 @@ def true_context(self):
raise NotImplementedError

def true_simulator_outcome(self):
return self.simulator(self.theta_true(), self.true_context())
return self.simulator(self.get_theta_true(), self.true_context())

def sample_prior(self, n_samples: int):
return self.prior_dist(size=(n_samples, self.n_dims))
Expand All @@ -83,17 +83,17 @@ def simulator_outcome(self, theta, condition_context=None, n_samples=None):
def simulated_context(self, n_samples):
return self.simulator.generate_context(n_samples)

def theta_true(self):
def get_theta_true(self):
if hasattr(self, "theta_true"):
return self.theta_true
else:
return get_item("data", "theta_true")
return get_item("data", "theta_true", raise_exception=True)

def sigma_true(self):
def get_sigma_true(self):
if hasattr(self, "sigma_true"):
return self.sigma_true
else:
return get_item("data", "sigma_true")
return get_item("data", "sigma_true", raise_exception=True)

def save(self, data, path: str):
raise NotImplementedError
Expand All @@ -102,6 +102,8 @@ def read_prior(self):
raise NotImplementedError

def load_prior(self, prior, prior_kwargs):
if prior is None:
prior = get_item("data", "prior", raise_exception=False)
try:
prior = self.read_prior()
except NotImplementedError:
Expand Down
7 changes: 2 additions & 5 deletions src/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,14 @@ def true_context(self):
# From Data
return self.data["xs"] # TODO change name

def true_simulator_outcome(self):
return self.simulator(self.theta_true(), self.x_true())

def prior(self):
# From Data
raise NotImplementedError

def theta_true(self):
def get_theta_true(self):
return self.data["thetas"]

def sigma_true(self):
def get_sigma_true(self):
try:
return super().sigma_true()
except (AssertionError, KeyError):
Expand Down
6 changes: 3 additions & 3 deletions src/metrics/all_sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def __init__(
self.samples_per_inference = samples_per_inference

def _collect_data_params(self):
self.thetas = tensor(self.data.theta_true())
self.y_true = tensor(self.data.x_true())
self.thetas = tensor(self.data.get_theta_true())
self.context = tensor(self.data.true_context())

def calculate(self):
ranks, dap_samples = run_sbc(
self.thetas,
self.y_true,
self.context,
self.model.posterior,
num_posterior_samples=self.samples_per_inference,
)
Expand Down
10 changes: 5 additions & 5 deletions src/metrics/coverage_fraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ def __init__(
)

def _collect_data_params(self):
self.thetas = self.data.theta_true()
self.y_true = self.data.x_true()
self.thetas = self.data.get_theta_true()
self.context = self.data.true_context()

def _run_model_inference(self, samples_per_inference, y_inference):
samples = self.model.sample_posterior(samples_per_inference, y_inference)
return samples

def calculate(self):
all_samples = np.empty(
(len(self.y_true), self.samples_per_inference, np.shape(self.thetas)[1])
(len(self.context), self.samples_per_inference, np.shape(self.thetas)[1])
)
count_array = []
iterator = enumerate(self.y_true)
iterator = enumerate(self.context)
if self.progress_bar:
iterator = tqdm(
iterator,
Expand Down Expand Up @@ -95,7 +95,7 @@ def calculate(self):

count_sum_array = np.sum(count_array, axis=0)
frac_lens_within_vol = np.array(count_sum_array)
coverage = frac_lens_within_vol / len(self.y_true)
coverage = frac_lens_within_vol / len(self.context)

self.output = coverage

Expand Down
9 changes: 3 additions & 6 deletions src/plots/cdf_ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,11 @@ def _plot_name(self):
return "cdf_ranks.png"

def _data_setup(self):
thetas = tensor(self.data.theta_true())
y_true = tensor(self.data.x_true())
self.num_samples = get_item(
"metrics_common", "samples_per_inference", raise_exception=False
)
thetas = tensor(self.data.get_theta_true())
context = tensor(self.data.true_context())

ranks, _ = run_sbc(
thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples
thetas, context, self.model.posterior, num_posterior_samples=self.num_samples
)
self.ranks = ranks

Expand Down
8 changes: 7 additions & 1 deletion src/plots/coverage_fraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
parameter_labels=None,
figure_size=None,
line_styles=None,
parameter_colors=None
):
super().__init__(model, data, save, show, out_dir)

Expand All @@ -26,6 +27,11 @@ def __init__(
if parameter_labels is not None
else get_item("plots_common", "parameter_labels", raise_exception=False)
)
self.colors = (
parameter_colors
if parameter_colors is not None
else get_item("plots_common", "parameter_colors", raise_exception=False)
)
self.n_parameters = len(self.labels)
self.figure_size = (
figure_size
Expand Down Expand Up @@ -65,7 +71,7 @@ def _plot(
):
n_steps = self.coverage_fractions.shape[0]
percentile_array = np.linspace(0, 1, n_steps)
color_cycler = iter(plt.cycler("color", cm.get_cmap(self.colorway).colors))
color_cycler = iter(plt.cycler("color", self.colors))
line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))

# Plotting
Expand Down
6 changes: 3 additions & 3 deletions src/plots/ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def _plot_name(self):
return "ranks.png"

def _data_setup(self):
thetas = tensor(self.data.theta_true())
y_true = tensor(self.data.x_true())
thetas = tensor(self.data.get_theta_true())
context = tensor(self.data.true_context())
self.num_samples = get_item(
"metrics_common", "samples_per_inference", raise_exception=False
)

ranks, _ = run_sbc(
thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples
thetas, context, self.model.posterior, num_posterior_samples=self.num_samples
)
self.ranks = ranks

Expand Down
15 changes: 7 additions & 8 deletions src/plots/tarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,25 @@ def _plot_name(self):
return "tarp.png"

def _data_setup(self):
self.rng = np.random.default_rng(
get_item("common", "random_seed", raise_exception=False)
)
self.theta_true = self.data.get_theta_true()

samples_per_inference = get_item(
"metrics_common", "samples_per_inference", raise_exception=False
)
num_simulations = get_item(
"metrics_common", "number_simulations", raise_exception=False
)

n_dims = self.data.theta_true().shape[1]
n_dims = self.theta_true.shape[1]
self.posterior_samples = np.zeros(
(num_simulations, samples_per_inference, n_dims)
)
self.thetas = np.zeros((num_simulations, n_dims))
for n in range(num_simulations):
sample_index = self.rng.integers(0, len(self.data.theta_true()))
sample_index = self.data.rng.integers(0, len(self.theta_true))

theta = self.data.theta_true()[sample_index, :]
x = self.data.x_true()[sample_index, :]
theta = self.theta_true[sample_index, :]
x = self.data.true_context()[sample_index, :]
self.posterior_samples[n] = self.model.sample_posterior(
samples_per_inference, x
)
Expand All @@ -56,7 +55,7 @@ def _get_hex_sigma_colors(self, n_colors, colorway=None):
"plots_common", "default_colorway", raise_exception=False
)

cmap = plt.cm.get_cmap(colorway)
cmap = plt.get_cmap(colorway)
hex_colors = []
arr = np.linspace(0, 1, n_colors)
for hit in arr:
Expand Down
7 changes: 6 additions & 1 deletion src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
"random_seed": 42,
},
"model": {"model_engine": "SBIModel"},
"data": {"data_engine": "H5Data"},
"data": {
"data_engine": "H5Data",
"prior":"normal",
"prior_kwargs": None,
"simulator_kwargs": None,
},
"plots_common": {
"axis_spines": False,
"tight_layout": True,
Expand Down
18 changes: 8 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@


class MockSimulator(Simulator):
def __init__(self):
pass

def __call__(self, thetas, samples):
thetas = np.atleast_2d(thetas)
def generate_context(self, n_samples: int) -> np.ndarray:
return np.linspace(0, 100, n_samples)
def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray:
thetas = np.atleast_2d(theta)
if thetas.shape[1] != 2:
raise ValueError("Input tensor must have shape (n, 2) where n is the number of parameter sets.")

Expand All @@ -23,19 +23,17 @@ def __call__(self, thetas, samples):
else:
# If there are multiple sets of parameters, extract them for each row
m, b = thetas[:, 0], thetas[:, 1]
x = np.linspace(0, 100, samples)
rs = np.random.RandomState()
sigma = 1
epsilon = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0]))
epsilon = rs.normal(loc=0, scale=sigma, size=(len(context_samples), thetas.shape[0]))

# Initialize an empty array to store the results for each set of parameters
y = np.zeros((len(x), thetas.shape[0]))
y = np.zeros((len(context_samples), thetas.shape[0]))
for i in range(thetas.shape[0]):
m, b = thetas[i, 0], thetas[i, 1]
y[:, i] = m * x + b + epsilon[:, i]
y[:, i] = m * context_samples + b + epsilon[:, i]
return y.T


@pytest.fixture
def model_path():
return "resources/savedmodels/sbi/sbi_linear_from_data.pkl"
Expand Down
1 change: 1 addition & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
def metric_config(config_factory):
metrics_settings={"use_progress_bar":False, "samples_per_inference":10, "percentiles":[95]}
config = config_factory(metrics_settings=metrics_settings)
Config(config)
return config

def test_all_metrics_catalogued():
Expand Down

0 comments on commit 685c2ef

Please sign in to comment.