Skip to content

Commit

Permalink
Merge pull request #43 from voetberg/bones
Browse files Browse the repository at this point in the history
Bones
  • Loading branch information
voetberg authored Apr 1, 2024
2 parents 07cb472 + 25fc729 commit 87a2e42
Show file tree
Hide file tree
Showing 52 changed files with 720 additions and 2,106 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@

## General pieces to ignore during git pushing, so you can use git add . without adding everything.

*__pycache__
*sbi-logs

# Distribution / packaging
.Python
build/
Expand Down
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
[submodule ".showyourwork"]
path = .showyourwork
url = https://github.com/showyourwork/showyourwork
1 change: 0 additions & 1 deletion .showyourwork
Submodule .showyourwork deleted from bee25e
10 changes: 0 additions & 10 deletions environment.yml

This file was deleted.

Binary file added plots/generative/mackelab_pairplot.pdf
Binary file not shown.
Binary file added plots/static/getdist_cornerplot.pdf
Binary file not shown.
Binary file added plots/static/mackelab_pairplot.pdf
Binary file not shown.
1,787 changes: 103 additions & 1,684 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,25 @@ description = "a package for diagnosing posterior quality from inference methods
authors = ["Becky Nevin <[email protected]>"]
license = "MIT"

[tool.poetry.scripts]
diagnose = "client.client:main"

[tool.poetry.dependencies]
python = "^3.9"
jupyter = "^1.0.0"
sbi = "^0.22.0"
flake8 = "^7.0.0"
black = "^24.1.1"
pyarrow = "^15.0.0"
getdist = "^1.4.7"
h5py = "^3.10.0"
numpy = "^1.26.4"
matplotlib = "^3.8.3"

[tool.poetry.dev-dependencies]
pytest = "^7.3.2"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.2"
pytest-cov = "^4.1.0"
flake8 = "^7.0.0"
pytest = "^7.3.2"
black = "^24.3.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
Binary file added resources/saveddata/data_train.h5
Binary file not shown.
Binary file added resources/saveddata/data_validation.h5
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added results/plots/generative/mackelab_pairplot.pdf
Binary file not shown.
Binary file added results/plots/static/getdist_cornerplot.pdf
Binary file not shown.
Binary file added results/plots/static/mackelab_pairplot.pdf
Binary file not shown.
Binary file added saveddata/data_train.h5
Binary file not shown.
Binary file added saveddata/data_validation.h5
Binary file not shown.
Binary file added savedmodels/sbi/sbi_linear_from_data.pkl
Binary file not shown.
Binary file added savedmodels/sbi/sbi_linear_generative.pkl
Binary file not shown.
7 changes: 0 additions & 7 deletions showyourwork.yml

This file was deleted.

File renamed without changes.
79 changes: 79 additions & 0 deletions src/client/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os
import yaml
from argparse import ArgumentParser

from utils.config import Config
from utils.defaults import Defaults
from data import DataModules
from models import ModelModules
from metrics import Metrics
from plots import Plots


def parser():
parser = ArgumentParser()
parser.add_argument("--config", '-c', default=None)

# Model
parser.add_argument("--model_path", '-m', default=None)
parser.add_argument("--model_engine", '-e', default=Defaults['model']['model_engine'], choices=ModelModules.keys())

# Data
parser.add_argument("--data_path", '-d', default=None)
parser.add_argument("--data_engine", '-g', default=Defaults['data']['data_engine'], choices=DataModules.keys())
parser.add_argument("--simulator", '-s', default=None)
# Common
parser.add_argument("--out_dir", default=Defaults['common']['out_dir'])

# List of metrics (cannot supply specific kwargs)
parser.add_argument("--metrics", nargs='+', default=list(Defaults['metrics'].keys()), choices=Metrics.keys())

# List of plots
parser.add_argument("--plots", nargs='+', default=list(Defaults['plots'].keys()), choices=Plots.keys())


args = parser.parse_args()
if args.config is not None:
config = Config(args.config)

else:
temp_config = Defaults['common']['temp_config']
os.makedirs(os.path.dirname(temp_config), exist_ok=True)

input_yaml = {
"common": {"out_dir":args.out_dir},
"model": {"model_path":args.model_path, "model_engine":args.model_engine},
"data": {"data_path":args.data_path, "data_engine":args.data_engine, "simulator": args.simulator},
"plots": {key: {} for key in args.plots},
"metrics": {key: {} for key in args.metrics},
}

yaml.dump(input_yaml, open(temp_config, "w"))
config = Config(temp_config)

return config

def main():
config = parser()

model_path = config.get_item("model", "model_path")
model_engine = config.get_item("model", "model_engine", raise_exception=False)
model = ModelModules[model_engine](model_path)

data_path = config.get_item("data", "data_path")
data_engine = config.get_item("data", "data_engine", raise_exception=False)
simulator_name = config.get_item("data", "simulator")
data = DataModules[data_engine](data_path, simulator_name)

out_dir = config.get_item("common", "out_dir", raise_exception=False)
if not os.path.exists(os.path.dirname(out_dir)):
os.makedirs(os.path.dirname(out_dir))

metrics = config.get_section("metrics", raise_exception=False)
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, **metrics_args)()

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(**plot_args)
9 changes: 9 additions & 0 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

from data.h5_data import H5Data
from data.pickle_data import PickleData
from data.simulator import Simulator

DataModules = {
"H5Data": H5Data,
"PickleData": PickleData
}
50 changes: 50 additions & 0 deletions src/data/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import importlib.util
import sys
import os

from utils.config import get_item
from utils.defaults import Defaults

class Data:
def __init__(self, path:str, simulator_name: str):
self.data = self._load(path)
self.simulator = self._load_simulator(simulator_name)

def _load_simulator(self, name):
try:
simulator_path = os.environ[f"{Defaults['common']['sim_location']}:{name}"]
except KeyError as e:
raise RuntimeError(f"Simulator cannot be found using env var {e}. Hint: have you registed your simulation with utils.register_simulator?")

new_class = os.path.dirname(simulator_path)
sys.path.insert(1, new_class)

# TODO robust error checks
module_name = os.path.basename(simulator_path.rstrip('.py'))
m = importlib.import_module(module_name)

simulator = getattr(m, name)
return simulator()

def _load(self, path:str):
raise NotImplementedError

def x_true(self):
# From Data
raise NotImplementedError

def y_true(self):
return self.simulator(self.theta_true(), self.x_true())

def prior(self):
# From Data
raise NotImplementedError

def theta_true(self):
return get_item("data", "theta_true")

def sigma_true(self):
return get_item("data", "sigma_true")

def save(self, data, path:str):
raise NotImplementedError
30 changes: 30 additions & 0 deletions src/data/h5_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any, Callable
import h5py
import numpy as np
import torch
import os

from data.data import Data

class H5Data(Data):
def __init__(self, path:str, simulator:Callable):
super().__init__(path, simulator)

def _load(self, path):
assert path.split(".")[-1] == "h5", "File extension must be h5"
loaded_data = {}
with h5py.File(path, "r") as file:
for key in file.keys():
loaded_data[key] = torch.Tensor(file[key][...])
return loaded_data

def save(self, data:dict[str, Any], path: str): # Todo typing for data dict
assert path.split(".")[-1] == "h5", "File extension must be h5"
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))

data_arrays = {key: np.asarray(value) for key, value in data.items()}
with h5py.File(path, "w") as file:
# Save each array as a dataset in the HDF5 file
for key, value in data_arrays.items():
file.create_dataset(key, data=value)
18 changes: 18 additions & 0 deletions src/data/pickle_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pickle
from typing import Any, Callable
from data.data import Data

class PickleData(Data):
def __init__(self, path:str, simulator:Callable):
super().__init__(path, simulator)

def _load(self, path:str):
assert path.split('.')[-1] == 'pkl', "File extension must be 'pkl'"
with open(path, "rb") as file:
data = pickle.load(file)
return data

def save(self, data:Any, path:str):
assert path.split('.')[-1] == 'pkl', "File extension must be 'pkl'"
with open(path, "wb") as file:
pickle.dump(data, file)
10 changes: 10 additions & 0 deletions src/data/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

from typing import Any

class Simulator:
def __init__(self) -> None:
pass

def __call__(self, *args: Any, **kwds: Any) -> Any:
pass

3 changes: 3 additions & 0 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Metrics = {

}
44 changes: 44 additions & 0 deletions src/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Optional
import json
import os

from data import data
from models import model

class Metric:
def __init__(self, model:model, data:data, out_dir:Optional[str]=None) -> None:
self.model = model
self.data = data

self.out_dir = out_dir
self.output = None

def _collect_data_params():
raise NotImplementedError

def _run_model_inference():
raise NotImplementedError

def calculate(self):
raise NotImplementedError

def _finish(self):
assert self.output is not None, "Calculation has not been completed, have you run Metric.calculate?"

if self.out_dir is not None:
if not os.path.exists(os.path.dirname(self.out_dir)):
os.makedirs(os.path.dirname(self.out_dir))

with open(self.out_dir) as f:
data = json.load(f)
data.update(self.output)
json.dump(data, f, ensure_ascii=True)
f.close()

def __call__(self, *args: Any, **kwds: Any) -> Any:
self._collect_data_params()
self._run_model_inference()
self.calculate()
self._finish()


5 changes: 5 additions & 0 deletions src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from models.sbi_model import SBIModel

ModelModules = {
"SBIModel": SBIModel
}
15 changes: 15 additions & 0 deletions src/models/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class Model:
def __init__(self, model_path:str) -> None:
self.model = self._load(model_path)

def _load(self, path:str) -> None:
return NotImplementedError

def sample_posterior(self):
return NotImplementedError

def sample_simulation(self, data):
raise NotImplementedError

def predict_posterior(self, data):
raise NotImplementedError
24 changes: 24 additions & 0 deletions src/models/sbi_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import pickle

from models.model import Model

class SBIModel(Model):
def __init__(self, model_path:str):
super().__init__(model_path)

def _load(self, path:str) -> None:
assert os.path.exists(path), f"Cannot find model file at location {path}"
assert path.split(".")[-1] == 'pkl', "File extension must be 'pkl'"

with open(path, "rb") as file:
posterior = pickle.load(file)
self.posterior = posterior

def sample_posterior(self, n_samples:int, data): # TODO typing
return self.posterior.sample((n_samples,), x=data.y_true)

def predict_posterior(self, data):
posterior_samples = self.sample_posterior(data.y_true)
posterior_predictive_samples = data.simulator(data.theta_true(), posterior_samples)
return posterior_predictive_samples
7 changes: 7 additions & 0 deletions src/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import TypeVar
from src.plots.plot import Display

display = TypeVar("display", Display)
Plots = {

}
Loading

0 comments on commit 87a2e42

Please sign in to comment.