Demo notebook #68

Merged: 6 commits, Jun 6, 2024
Changes from all commits
782 changes: 782 additions & 0 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/client/client.py
@@ -38,15 +38,15 @@ def parser():
     # List of metrics (cannot supply specific kwargs)
     parser.add_argument(
         "--metrics",
-        nargs="+",
+        nargs="?",
         default=list(Defaults["metrics"].keys()),
         choices=Metrics.keys(),
     )

     # List of plots
     parser.add_argument(
         "--plots",
-        nargs="+",
+        nargs="?",
         default=list(Defaults["plots"].keys()),
         choices=Plots.keys(),
     )
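A side note on the argparse semantics changed here (standalone sketch, not part of the diff): nargs="+" collects one or more values into a list, while nargs="?" consumes at most one value and stores it bare, so the parsed type now differs between the default and an explicitly passed flag.

# Standalone illustration of the nargs change above; not part of this PR.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--metrics", nargs="?", default=["CoverageFraction", "AllSBC"])

print(parser.parse_args([]).metrics)
# ['CoverageFraction', 'AllSBC']  (the default list is passed through untouched)
print(parser.parse_args(["--metrics", "AllSBC"]).metrics)
# 'AllSBC'  (a single string, not a list)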
38 changes: 2 additions & 36 deletions src/data/data.py
@@ -4,7 +4,7 @@
 import numpy as np

 from utils.config import get_item
-
+from utils.register import load_simulator

 class Data:
     def __init__(
@@ -19,44 +19,10 @@ def __init__(
             get_item("common", "random_seed", raise_exception=False)
         )
         self.data = self._load(path)
-        self.simulator = self._load_simulator(simulator_name, simulator_kwargs)
+        self.simulator = load_simulator(simulator_name, simulator_kwargs)
         self.prior_dist = self.load_prior(prior, prior_kwargs)
         self.n_dims = self.get_theta_true().shape[1]

-    def _load_simulator(self, name, simulator_kwargs):
-        try:
-            sim_location = get_item("common", "sim_location", raise_exception=False)
-            simulator_path = os.environ[f"{sim_location}:{name}"]
-        except KeyError as e:
-            raise RuntimeError(
-                f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?"
-            )
-
-        new_class = os.path.dirname(simulator_path)
-        sys.path.insert(1, new_class)
-
-        # TODO robust error checks
-        module_name = os.path.basename(simulator_path.rstrip(".py"))
-        m = importlib.import_module(module_name)
-
-        simulator = getattr(m, name)
-
-        simulator_kwargs = simulator_kwargs if simulator_kwargs is not None else get_item("data", "simulator_kwargs", raise_exception=False)
-        simulator_kwargs = {} if simulator_kwargs is None else simulator_kwargs
-        simulator_instance = simulator(**simulator_kwargs)
-
-        if not hasattr(simulator_instance, "generate_context"):
-            raise RuntimeError(
-                "Simulator improperly formed - requires a generate_context method."
-            )
-
-        if not hasattr(simulator_instance, "simulate"):
-            raise RuntimeError(
-                "Simulator improperly formed - requires a simulate method."
-            )
-
-        return simulator_instance
-
     def _load(self, path: str):
         raise NotImplementedError
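The deleted _load_simulator body presumably moved into utils.register.load_simulator, whose source is not shown in this diff. A sketch of what such a free function could look like, assuming it lifts the removed logic unchanged; removesuffix stands in for the original rstrip(".py"), which strips a character set rather than a suffix:

# Hypothetical sketch of utils.register.load_simulator, assuming it simply
# lifts the deleted Data._load_simulator logic into a free function.
import importlib
import os
import sys

from utils.config import get_item

def load_simulator(name, simulator_kwargs=None):
    # Resolve the registered simulator path from the environment.
    sim_location = get_item("common", "sim_location", raise_exception=False)
    try:
        simulator_path = os.environ[f"{sim_location}:{name}"]
    except KeyError as e:
        raise RuntimeError(
            f"Simulator cannot be found using env var {e}. "
            "Hint: have you registered your simulation with utils.register_simulator?"
        )

    # Import the module that defines the simulator class.
    sys.path.insert(1, os.path.dirname(simulator_path))
    module = importlib.import_module(os.path.basename(simulator_path).removesuffix(".py"))
    simulator_class = getattr(module, name)

    if simulator_kwargs is None:
        simulator_kwargs = get_item("data", "simulator_kwargs", raise_exception=False) or {}
    simulator = simulator_class(**simulator_kwargs)

    # Validate the minimal simulator interface.
    for method in ("generate_context", "simulate"):
        if not hasattr(simulator, method):
            raise RuntimeError(f"Simulator improperly formed - requires a {method} method.")
    return simulator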
7 changes: 6 additions & 1 deletion src/metrics/__init__.py
@@ -1,4 +1,9 @@
 from metrics.all_sbc import AllSBC
 from metrics.coverage_fraction import CoverageFraction
+from metrics.local_two_sample import LocalTwoSampleTest

-Metrics = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC}
+Metrics = {
+    CoverageFraction.__name__: CoverageFraction,
+    AllSBC.__name__: AllSBC,
+    "LC2ST": LocalTwoSampleTest
+}
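For context, the registry maps a CLI name to a metric class. A hypothetical dispatch; the model and data objects are assumed to come from elsewhere in the pipeline:

# Hypothetical use of the registry; model and data are assumed to be
# constructed elsewhere in the pipeline.
from metrics import Metrics

metric_class = Metrics["LC2ST"]      # resolves to LocalTwoSampleTest
metric = metric_class(model, data)   # remaining kwargs fall back to config defaults
metric()                             # __call__ runs the collect/calculate/save cycle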
25 changes: 14 additions & 11 deletions src/metrics/all_sbc.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional, Sequence
 from torch import tensor
 from sbi.analysis import run_sbc, check_sbc
@@ -12,16 +12,19 @@ def __init__(
         model: Any,
         data: Any,
         out_dir: str | None = None,
-        samples_per_inference=None,
+        save: bool=True,
+        use_progress_bar: Optional[bool] = None,
+        samples_per_inference: Optional[int] = None,
+        percentiles: Optional[Sequence[int]] = None,
+        number_simulations: Optional[int] = None,
     ) -> None:
-        super().__init__(model, data, out_dir)
-
-        if samples_per_inference is None:
-            self.samples_per_inference = get_item(
-                "metrics_common", "samples_per_inference", raise_exception=False
-            )
-        else:
-            self.samples_per_inference = samples_per_inference
+        super().__init__(model, data, out_dir,
+            save,
+            use_progress_bar,
+            samples_per_inference,
+            percentiles,
+            number_simulations)

     def _collect_data_params(self):
         self.thetas = tensor(self.data.get_theta_true())
@@ -41,7 +44,7 @@ def calculate(self):
             dap_samples,
             num_posterior_samples=self.samples_per_inference,
         )
-        self.output = sbc_stats
+        self.output = {key: value.numpy().tolist() for key, value in sbc_stats.items()}
        return sbc_stats

     def __call__(self, **kwds: Any) -> Any:
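The new dict comprehension exists because the SBC statistics come back as torch tensors, which json.dump cannot serialize; converting to nested lists makes self.output writable by Metric._finish. A minimal demonstration:

# Why the .numpy().tolist() conversion matters: torch tensors are not
# JSON-serializable, while plain nested lists are.
import json
import torch

sbc_stats = {"ks_pvals": torch.tensor([0.42, 0.37])}  # stand-in for check_sbc output
# json.dumps(sbc_stats)  # would raise: Object of type Tensor is not JSON serializable
print(json.dumps({k: v.numpy().tolist() for k, v in sbc_stats.items()}))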
42 changes: 16 additions & 26 deletions src/metrics/coverage_fraction.py
@@ -1,7 +1,7 @@
 import numpy as np
 from torch import tensor
 from tqdm import tqdm
-from typing import Any
+from typing import Any, Optional, Sequence

 from metrics.metric import Metric
 from utils.config import get_item
@@ -14,32 +14,22 @@ def __init__(
         self,
         model: Any,
         data: Any,
-        out_dir: str | None = None,
-        samples_per_inference=None,
-        percentiles=None,
-        progress_bar: bool = None,
+        out_dir: Optional[str] = None,
+        save: bool=True,
+        use_progress_bar: Optional[bool] = None,
+        samples_per_inference: Optional[int] = None,
+        percentiles: Optional[Sequence[int]] = None,
+        number_simulations: Optional[int] = None,
     ) -> None:
-        super().__init__(model, data, out_dir)
-
-        self.samples_per_inference = (
-            samples_per_inference
-            if samples_per_inference is not None
-            else get_item(
-                "metrics_common", "samples_per_inference", raise_exception=False
-            )
-        )
-        self.percentiles = (
-            percentiles
-            if percentiles is not None
-            else get_item("metrics_common", "percentiles", raise_exception=False)
-        )
-        self.progress_bar = (
-            progress_bar
-            if progress_bar is not None
-            else get_item("metrics_common", "use_progress_bar", raise_exception=False)
-        )
+        super().__init__(model, data, out_dir,
+            save,
+            use_progress_bar,
+            samples_per_inference,
+            percentiles,
+            number_simulations)
+        self._collect_data_params()

     def _collect_data_params(self):
         self.thetas = self.data.get_theta_true()
         self.context = self.data.true_context()
@@ -54,11 +44,11 @@ def calculate(self):
         )
         count_array = []
         iterator = enumerate(self.context)
-        if self.progress_bar:
+        if self.use_progress_bar:
             iterator = tqdm(
                 iterator,
                 desc="Sampling from the posterior for each observation",
-                unit="observation",
+                unit=" observation",
             )
         for y_sample_index, y_sample in iterator:
             samples = self._run_model_inference(self.samples_per_inference, y_sample)
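For readers new to the metric, an illustrative numpy-only sketch of the coverage-fraction idea (not this repo's implementation): draw posterior samples per observation, form central percentile intervals, and count how often the true parameter lands inside.

# Illustrative-only sketch of coverage fractions on synthetic 1-D posteriors.
import numpy as np

rng = np.random.default_rng(0)
percentiles = [68, 95]
theta_true = rng.normal(size=100)
# Pretend posterior samples for each observation: true value plus noise.
posterior_samples = theta_true[:, None] + rng.normal(size=(100, 1000))

for p in percentiles:
    lo = np.percentile(posterior_samples, 50 - p / 2, axis=1)
    hi = np.percentile(posterior_samples, 50 + p / 2, axis=1)
    coverage = np.mean((theta_true >= lo) & (theta_true <= hi))
    print(f"{p}% interval empirical coverage: {coverage:.2f}")  # ~p/100 when calibrated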
44 changes: 32 additions & 12 deletions src/metrics/local_two_sample.py
@@ -1,38 +1,58 @@
-from typing import Any, Optional, Union
+from typing import Any, Optional, Sequence, Union
 import numpy as np

 from sklearn.model_selection import KFold
 from sklearn.neural_network import MLPClassifier
 from sklearn.utils import shuffle

 from metrics.metric import Metric
 from utils.config import get_item

 class LocalTwoSampleTest(Metric):
-    def __init__(self, model: Any, data: Any, out_dir: str | None = None, num_simulations: Optional[int] = None) -> None:
-        super().__init__(model, data, out_dir)
-        self.num_simulations = num_simulations if num_simulations is not None else get_item(
-            "metrics_common", "number_simulations", raise_exception=False
-        )
+    def __init__(
+        self,
+        model: Any,
+        data: Any,
+        out_dir: Optional[str] = None,
+        save: bool=True,
+        use_progress_bar: Optional[bool] = None,
+        samples_per_inference: Optional[int] = None,
+        percentiles: Optional[Sequence[int]] = None,
+        number_simulations: Optional[int] = None,
+    ) -> None:
+
+        super().__init__(
+            model,
+            data,
+            out_dir,
+            save,
+            use_progress_bar,
+            samples_per_inference,
+            percentiles,
+            number_simulations
+        )

     def _collect_data_params(self):

         # P is the prior and x_P is generated via the simulator from the parameters P.
-        self.p = self.data.sample_prior(self.num_simulations)
+        self.p = self.data.sample_prior(self.number_simulations)
         self.q = np.zeros_like(self.p)

-        self.outcome_given_p = np.zeros((self.num_simulations, self.data.simulator.generate_context().shape[-1]))
+        context_size = self.data.true_context().shape[-1]
+        self.outcome_given_p = np.zeros(
+            (self.number_simulations, context_size)
+        )
         self.outcome_given_q = np.zeros_like(self.outcome_given_p)
         self.evaluation_context = np.zeros_like(self.outcome_given_p)

         for index, p in enumerate(self.p):
-            context = self.data.simulator.generate_context()
+            context = self.data.simulator.generate_context(context_size)
             self.outcome_given_p[index] = self.data.simulator.simulate(p, context)
             # Q is the approximate posterior amortized in x
             q = self.model.sample_posterior(1, context).ravel()
             self.q[index] = q
             self.outcome_given_q[index] = self.data.simulator.simulate(q, context)

-        self.evaluation_context = np.array([self.data.simulator.generate_context() for _ in range(self.num_simulations)])
+        self.evaluation_context = np.array([self.data.simulator.generate_context(context_size) for _ in range(self.number_simulations)])

     def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}):
         classifier_map = {
@@ -162,8 +182,8 @@ def calculate(

         null = np.array(null_hypothesis_probabilities)
         self.output = {
-            "lc2st_probabilities": probabilities,
-            "lc2st_null_hypothesis_probabilities": null
+            "lc2st_probabilities": probabilities.tolist(),
+            "lc2st_null_hypothesis_probabilities": null.tolist()
         }
         return probabilities, null
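The metric above builds paired parameter/context samples from the prior-predictive (P) and the amortized posterior (Q), then trains classifiers to tell them apart. A self-contained sketch of that classifier two-sample idea, independent of this repo's API: predicted probabilities near 0.5 mean the classifier cannot distinguish P from Q.

# Minimal classifier two-sample test on synthetic stand-ins for (theta, x) pairs.
import numpy as np
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
p_samples = rng.normal(0.0, 1.0, size=(500, 2))  # stand-in for pairs drawn from P
q_samples = rng.normal(0.1, 1.0, size=(500, 2))  # stand-in for pairs drawn from Q

features = np.vstack([p_samples, q_samples])
labels = np.concatenate([np.zeros(500), np.ones(500)])

clf = MLPClassifier(max_iter=500).fit(features, labels)
probs = clf.predict_proba(features)[:, 1]
print(probs.mean())  # close to 0.5 when P and Q are nearly indistinguishable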
48 changes: 39 additions & 9 deletions src/metrics/metric.py
@@ -1,19 +1,37 @@
-from typing import Any, Optional
+from typing import Any, Optional, Sequence
 import json
 import os

 from data import data
 from models import model
 from utils.config import get_item


 class Metric:
-    def __init__(self, model: model, data: data, out_dir: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: model,
+        data: data,
+        out_dir: Optional[str] = None,
+        save: bool=True,
+        use_progress_bar: Optional[bool] = None,
+        samples_per_inference: Optional[int] = None,
+        percentiles: Optional[Sequence[int]] = None,
+        number_simulations: Optional[int] = None,
+    ) -> None:
         self.model = model
         self.data = data
-
-        self.out_dir = out_dir
+        if save:
+            self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)
+
         self.output = None

+        self.use_progress_bar = use_progress_bar if use_progress_bar is not None else get_item("metrics_common", "use_progress_bar", raise_exception=False)
+        self.samples_per_inference = samples_per_inference if samples_per_inference is not None else get_item("metrics_common", "samples_per_inference", raise_exception=False)
+        self.percentiles = percentiles if percentiles is not None else get_item("metrics_common", "percentiles", raise_exception=False)
+        self.number_simulations = number_simulations if number_simulations is not None else get_item("metrics_common", "number_simulations", raise_exception=False)
+
     def _collect_data_params():
         raise NotImplementedError
@@ -29,17 +47,29 @@ def _finish(self):
         ), "Calculation has not been completed, have you run Metric.calculate?"

         if self.out_dir is not None:
-            if not os.path.exists(os.path.dirname(self.out_dir)):
-                os.makedirs(os.path.dirname(self.out_dir))
+            if not os.path.exists(self.out_dir):
+                os.makedirs(self.out_dir)

-            with open(self.out_dir) as f:
-                data = json.load(f)
+            with open(f"{self.out_dir.rstrip('/')}/diagnostic_metrics.json", "w+") as f:
+                try:
+                    data = json.load(f)
+                except json.decoder.JSONDecodeError:
+                    data = {}
+
                 data.update(self.output)
                 json.dump(data, f, ensure_ascii=True)
                 f.close()

     def __call__(self, **kwds: Any) -> Any:
-        self._collect_data_params()
-        self._run_model_inference()
+
+        try:
+            self._collect_data_params()
+        except NotImplementedError:
+            pass
+        try:
+            self._run_model_inference()
+        except NotImplementedError:
+            pass

         self.calculate(kwds)
         self._finish()
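One subtlety in the new _finish: opening with mode "w+" truncates the file before json.load runs, so the try always falls through to an empty dict and any previously saved metrics are lost. If the intent is to merge with earlier results (an assumption), a read-then-write pattern would look like this sketch:

# Sketch of a read-modify-write that preserves earlier results; the
# append_metrics name is hypothetical, not part of this PR.
import json
import os

def append_metrics(out_dir: str, output: dict) -> None:
    path = os.path.join(out_dir, "diagnostic_metrics.json")
    data = {}
    if os.path.exists(path):
        with open(path) as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = {}
    data.update(output)
    with open(path, "w") as f:
        json.dump(data, f, ensure_ascii=True)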
2 changes: 2 additions & 0 deletions src/plots/__init__.py
@@ -2,10 +2,12 @@
 from plots.coverage_fraction import CoverageFraction
 from plots.ranks import Ranks
 from plots.tarp import TARP
+from plots.local_two_sample import LocalTwoSampleTest

 Plots = {
     CDFRanks.__name__: CDFRanks,
     CoverageFraction.__name__: CoverageFraction,
     Ranks.__name__: Ranks,
     TARP.__name__: TARP,
+    "LC2ST": LocalTwoSampleTest
 }