Skip to content

Commit

Permalink
Merge pull request #68 from voetberg/demo_nb
Browse files Browse the repository at this point in the history
Demo notebook
  • Loading branch information
bnord authored Jun 6, 2024
2 parents 2a9817f + 7c7b158 commit 3f67b67
Show file tree
Hide file tree
Showing 23 changed files with 1,190 additions and 345 deletions.
782 changes: 782 additions & 0 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def parser():
# List of metrics (cannot supply specific kwargs)
parser.add_argument(
"--metrics",
nargs="+",
nargs="?",
default=list(Defaults["metrics"].keys()),
choices=Metrics.keys(),
)

# List of plots
parser.add_argument(
"--plots",
nargs="+",
nargs="?",
default=list(Defaults["plots"].keys()),
choices=Plots.keys(),
)
Expand Down
38 changes: 2 additions & 36 deletions src/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from utils.config import get_item

from utils.register import load_simulator

class Data:
def __init__(
Expand All @@ -19,44 +19,10 @@ def __init__(
get_item("common", "random_seed", raise_exception=False)
)
self.data = self._load(path)
self.simulator = self._load_simulator(simulator_name, simulator_kwargs)
self.simulator = load_simulator(simulator_name, simulator_kwargs)
self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.get_theta_true().shape[1]

def _load_simulator(self, name, simulator_kwargs):
    """Locate, import, and instantiate a registered simulator class.

    The simulator's file path is resolved from the environment variable
    ``f"{sim_location}:{name}"`` (set by ``utils.register_simulator``),
    its module is imported, and the class named *name* is instantiated.

    Args:
        name: Class name of the simulator to load from the module.
        simulator_kwargs: Keyword arguments for the simulator constructor;
            falls back to the ``data.simulator_kwargs`` config entry, then
            to an empty dict.

    Returns:
        An instance of the simulator class.

    Raises:
        RuntimeError: If the env var is not set, or the instance lacks a
            ``generate_context`` or ``simulate`` method.
    """
    try:
        sim_location = get_item("common", "sim_location", raise_exception=False)
        simulator_path = os.environ[f"{sim_location}:{name}"]
    except KeyError as e:
        raise RuntimeError(
            f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?"
        ) from e

    # Make the simulator's directory importable.
    sys.path.insert(1, os.path.dirname(simulator_path))

    # NOTE: splitext (not rstrip) — rstrip(".py") strips *characters*
    # from the set {'.', 'p', 'y'}, mangling names like "happy.py" -> "ha".
    module_name = os.path.splitext(os.path.basename(simulator_path))[0]
    m = importlib.import_module(module_name)

    simulator = getattr(m, name)

    if simulator_kwargs is None:
        simulator_kwargs = get_item("data", "simulator_kwargs", raise_exception=False)
    simulator_kwargs = {} if simulator_kwargs is None else simulator_kwargs
    simulator_instance = simulator(**simulator_kwargs)

    # Duck-type check: the pipeline requires these two methods.
    if not hasattr(simulator_instance, "generate_context"):
        raise RuntimeError(
            "Simulator improperly formed - requires a generate_context method."
        )

    if not hasattr(simulator_instance, "simulate"):
        raise RuntimeError(
            "Simulator improperly formed - requires a simulate method."
        )

    return simulator_instance

def _load(self, path: str):
raise NotImplementedError

Expand Down
7 changes: 6 additions & 1 deletion src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from metrics.all_sbc import AllSBC
from metrics.coverage_fraction import CoverageFraction
from metrics.local_two_sample import LocalTwoSampleTest

Metrics = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC}
# Registry of available metrics, keyed by the public name used on the CLI.
# The literal keys match the class names exactly (plus the "LC2ST" alias).
Metrics = {
    "CoverageFraction": CoverageFraction,
    "AllSBC": AllSBC,
    "LC2ST": LocalTwoSampleTest,
}
25 changes: 14 additions & 11 deletions src/metrics/all_sbc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional, Sequence
from torch import tensor
from sbi.analysis import run_sbc, check_sbc

Expand All @@ -12,16 +12,19 @@ def __init__(
model: Any,
data: Any,
out_dir: str | None = None,
samples_per_inference=None,
save: bool=True,
use_progress_bar: Optional[bool] = None,
samples_per_inference: Optional[int] = None,
percentiles: Optional[Sequence[int]] = None,
number_simulations: Optional[int] = None,
) -> None:
super().__init__(model, data, out_dir)

if samples_per_inference is None:
self.samples_per_inference = get_item(
"metrics_common", "samples_per_inference", raise_exception=False
)
else:
self.samples_per_inference = samples_per_inference

super().__init__(model, data, out_dir,
save,
use_progress_bar,
samples_per_inference,
percentiles,
number_simulations)

def _collect_data_params(self):
self.thetas = tensor(self.data.get_theta_true())
Expand All @@ -41,7 +44,7 @@ def calculate(self):
dap_samples,
num_posterior_samples=self.samples_per_inference,
)
self.output = sbc_stats
self.output = {key: value.numpy().tolist() for key, value in sbc_stats.items()}
return sbc_stats

def __call__(self, **kwds: Any) -> Any:
Expand Down
42 changes: 16 additions & 26 deletions src/metrics/coverage_fraction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from torch import tensor
from tqdm import tqdm
from typing import Any
from typing import Any, Optional, Sequence

from metrics.metric import Metric
from utils.config import get_item
Expand All @@ -14,32 +14,22 @@ def __init__(
self,
model: Any,
data: Any,
out_dir: str | None = None,
samples_per_inference=None,
percentiles=None,
progress_bar: bool = None,
out_dir: Optional[str] = None,
save: bool=True,
use_progress_bar: Optional[bool] = None,
samples_per_inference: Optional[int] = None,
percentiles: Optional[Sequence[int]] = None,
number_simulations: Optional[int] = None,
) -> None:
super().__init__(model, data, out_dir)

super().__init__(model, data, out_dir,
save,
use_progress_bar,
samples_per_inference,
percentiles,
number_simulations)
self._collect_data_params()

self.samples_per_inference = (
samples_per_inference
if samples_per_inference is not None
else get_item(
"metrics_common", "samples_per_inference", raise_exception=False
)
)
self.percentiles = (
percentiles
if percentiles is not None
else get_item("metrics_common", "percentiles", raise_exception=False)
)
self.progress_bar = (
progress_bar
if progress_bar is not None
else get_item("metrics_common", "use_progress_bar", raise_exception=False)
)

def _collect_data_params(self):
self.thetas = self.data.get_theta_true()
self.context = self.data.true_context()
Expand All @@ -54,11 +44,11 @@ def calculate(self):
)
count_array = []
iterator = enumerate(self.context)
if self.progress_bar:
if self.use_progress_bar:
iterator = tqdm(
iterator,
desc="Sampling from the posterior for each observation",
unit="observation",
unit=" observation",
)
for y_sample_index, y_sample in iterator:
samples = self._run_model_inference(self.samples_per_inference, y_sample)
Expand Down
44 changes: 32 additions & 12 deletions src/metrics/local_two_sample.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,58 @@
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union
import numpy as np

from sklearn.model_selection import KFold
from sklearn.neural_network import MLPClassifier
from sklearn.utils import shuffle

from metrics.metric import Metric
from utils.config import get_item

class LocalTwoSampleTest(Metric):
def __init__(self, model: Any, data: Any, out_dir: str | None = None, num_simulations: Optional[int] = None) -> None:
super().__init__(model, data, out_dir)
self.num_simulations = num_simulations if num_simulations is not None else get_item(
"metrics_common", "number_simulations", raise_exception=False
def __init__(
self,
model: Any,
data: Any,
out_dir: Optional[str] = None,
save: bool=True,
use_progress_bar: Optional[bool] = None,
samples_per_inference: Optional[int] = None,
percentiles: Optional[Sequence[int]] = None,
number_simulations: Optional[int] = None,
) -> None:

super().__init__(
model,
data,
out_dir,
save,
use_progress_bar,
samples_per_inference,
percentiles,
number_simulations
)

def _collect_data_params(self):

# P is the prior and x_P is generated via the simulator from the parameters P.
self.p = self.data.sample_prior(self.num_simulations)
self.p = self.data.sample_prior(self.number_simulations)
self.q = np.zeros_like(self.p)

self.outcome_given_p = np.zeros((self.num_simulations, self.data.simulator.generate_context().shape[-1]))
context_size = self.data.true_context().shape[-1]
self.outcome_given_p = np.zeros(
(self.number_simulations, context_size)
)
self.outcome_given_q = np.zeros_like(self.outcome_given_p)
self.evaluation_context = np.zeros_like(self.outcome_given_p)

for index, p in enumerate(self.p):
context = self.data.simulator.generate_context()
context = self.data.simulator.generate_context(context_size)
self.outcome_given_p[index] = self.data.simulator.simulate(p, context)
# Q is the approximate posterior amortized in x
q = self.model.sample_posterior(1, context).ravel()
self.q[index] = q
self.outcome_given_q[index] = self.data.simulator.simulate(q, context)

self.evaluation_context = np.array([self.data.simulator.generate_context() for _ in range(self.num_simulations)])
self.evaluation_context = np.array([self.data.simulator.generate_context(context_size) for _ in range(self.number_simulations)])

def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}):
classifier_map = {
Expand Down Expand Up @@ -162,8 +182,8 @@ def calculate(

null = np.array(null_hypothesis_probabilities)
self.output = {
"lc2st_probabilities": probabilities,
"lc2st_null_hypothesis_probabilities": null
"lc2st_probabilities": probabilities.tolist(),
"lc2st_null_hypothesis_probabilities": null.tolist()
}
return probabilities, null

Expand Down
48 changes: 39 additions & 9 deletions src/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
from typing import Any, Optional
from typing import Any, Optional, Sequence
import json
import os

from data import data
from models import model
from utils.config import get_item


class Metric:
def __init__(self, model: model, data: data, out_dir: Optional[str] = None) -> None:
def __init__(
self,
model: model,
data: data,
out_dir: Optional[str] = None,
save: bool=True,
use_progress_bar: Optional[bool] = None,
samples_per_inference: Optional[int] = None,
percentiles: Optional[Sequence[int]] = None,
number_simulations: Optional[int] = None,
) -> None:
self.model = model
self.data = data

self.out_dir = out_dir
if save:
self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)

self.output = None

self.use_progress_bar = use_progress_bar if use_progress_bar is not None else get_item("metrics_common", "use_progress_bar", raise_exception=False)
self.samples_per_inference = samples_per_inference if samples_per_inference is not None else get_item("metrics_common", "samples_per_inference", raise_exception=False)
self.percentiles = percentiles if percentiles is not None else get_item("metrics_common", "percentiles", raise_exception=False)
self.number_simulations = number_simulations if number_simulations is not None else get_item("metrics_common", "number_simulations", raise_exception=False)

def _collect_data_params():
raise NotImplementedError

Expand All @@ -29,17 +47,29 @@ def _finish(self):
), "Calculation has not been completed, have you run Metric.calculate?"

if self.out_dir is not None:
if not os.path.exists(os.path.dirname(self.out_dir)):
os.makedirs(os.path.dirname(self.out_dir))
if not os.path.exists(self.out_dir):
os.makedirs(self.out_dir)

with open(f"{self.out_dir.rstrip('/')}/diagnostic_metrics.json", "w+") as f:
try:
data = json.load(f)
except json.decoder.JSONDecodeError:
data = {}

with open(self.out_dir) as f:
data = json.load(f)
data.update(self.output)
json.dump(data, f, ensure_ascii=True)
f.close()

def __call__(self, **kwds: Any) -> Any:
self._collect_data_params()
self._run_model_inference()

try:
self._collect_data_params()
except NotImplementedError:
pass
try:
self._run_model_inference()
except NotImplementedError:
pass

self.calculate(kwds)
self._finish()
2 changes: 2 additions & 0 deletions src/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from plots.coverage_fraction import CoverageFraction
from plots.ranks import Ranks
from plots.tarp import TARP
from plots.local_two_sample import LocalTwoSampleTest

# Registry of available plots, keyed by the public name used on the CLI.
# The literal keys match the class names exactly (plus the "LC2ST" alias).
Plots = {
    "CDFRanks": CDFRanks,
    "CoverageFraction": CoverageFraction,
    "Ranks": Ranks,
    "TARP": TARP,
    "LC2ST": LocalTwoSampleTest,
}
Loading

0 comments on commit 3f67b67

Please sign in to comment.