Bones #43

Merged
merged 9 commits into from
Apr 1, 2024
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@

## General pieces to ignore during git pushing, so you can use git add . without adding everything.

*__pycache__
*sbi-logs

# Distribution / packaging
.Python
build/
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
[submodule ".showyourwork"]
path = .showyourwork
url = https://github.com/showyourwork/showyourwork
1 change: 0 additions & 1 deletion .showyourwork
Submodule .showyourwork deleted from bee25e
10 changes: 0 additions & 10 deletions environment.yml

This file was deleted.

Binary file added plots/generative/mackelab_pairplot.pdf
Binary file not shown.
Binary file added plots/static/getdist_cornerplot.pdf
Binary file not shown.
Binary file added plots/static/mackelab_pairplot.pdf
Binary file not shown.
1,787 changes: 103 additions & 1,684 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -6,22 +6,25 @@ description = "a package for diagnosing posterior quality from inference methods
authors = ["Becky Nevin <[email protected]>"]
license = "MIT"

[tool.poetry.scripts]
diagnose = "client.client:main"

[tool.poetry.dependencies]
python = "^3.9"
jupyter = "^1.0.0"
sbi = "^0.22.0"
flake8 = "^7.0.0"
black = "^24.1.1"
pyarrow = "^15.0.0"
getdist = "^1.4.7"
h5py = "^3.10.0"
numpy = "^1.26.4"
matplotlib = "^3.8.3"

[tool.poetry.dev-dependencies]
pytest = "^7.3.2"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.2"
pytest-cov = "^4.1.0"
flake8 = "^7.0.0"
pytest = "^7.3.2"
black = "^24.3.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Binary file added resources/saveddata/data_train.h5
Binary file not shown.
Binary file added resources/saveddata/data_validation.h5
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added results/plots/generative/mackelab_pairplot.pdf
Binary file not shown.
Binary file added results/plots/static/getdist_cornerplot.pdf
Binary file not shown.
Binary file added results/plots/static/mackelab_pairplot.pdf
Binary file not shown.
Binary file added saveddata/data_train.h5
Binary file not shown.
Binary file added saveddata/data_validation.h5
Binary file not shown.
Binary file added savedmodels/sbi/sbi_linear_from_data.pkl
Binary file not shown.
Binary file added savedmodels/sbi/sbi_linear_generative.pkl
Binary file not shown.
7 changes: 0 additions & 7 deletions showyourwork.yml

This file was deleted.

File renamed without changes.
79 changes: 79 additions & 0 deletions src/client/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os
import yaml
from argparse import ArgumentParser

from utils.config import Config
from utils.defaults import Defaults
from data import DataModules
from models import ModelModules
from metrics import Metrics
from plots import Plots


def parser():
    parser = ArgumentParser()
    parser.add_argument("--config", '-c', default=None)

    # Model
    parser.add_argument("--model_path", '-m', default=None)
    parser.add_argument("--model_engine", '-e', default=Defaults['model']['model_engine'], choices=ModelModules.keys())

    # Data
    parser.add_argument("--data_path", '-d', default=None)
    parser.add_argument("--data_engine", '-g', default=Defaults['data']['data_engine'], choices=DataModules.keys())
    parser.add_argument("--simulator", '-s', default=None)
    # Common
    parser.add_argument("--out_dir", default=Defaults['common']['out_dir'])

    # List of metrics (cannot supply specific kwargs)
    parser.add_argument("--metrics", nargs='+', default=list(Defaults['metrics'].keys()), choices=Metrics.keys())

    # List of plots
    parser.add_argument("--plots", nargs='+', default=list(Defaults['plots'].keys()), choices=Plots.keys())


    args = parser.parse_args()
    if args.config is not None:
        config = Config(args.config)

    else:
        temp_config = Defaults['common']['temp_config']
        os.makedirs(os.path.dirname(temp_config), exist_ok=True)

        input_yaml = {
            "common": {"out_dir":args.out_dir},
            "model": {"model_path":args.model_path, "model_engine":args.model_engine},
            "data": {"data_path":args.data_path, "data_engine":args.data_engine, "simulator": args.simulator},
            "plots": {key: {} for key in args.plots},
            "metrics": {key: {} for key in args.metrics},
        }

        # Write and close the temporary config before Config reads it back
        with open(temp_config, "w") as f:
            yaml.dump(input_yaml, f)
        config = Config(temp_config)

    return config

def main():
    config = parser()

    model_path = config.get_item("model", "model_path")
    model_engine = config.get_item("model", "model_engine", raise_exception=False)
    model = ModelModules[model_engine](model_path)

    data_path = config.get_item("data", "data_path")
    data_engine = config.get_item("data", "data_engine", raise_exception=False)
    simulator_name = config.get_item("data", "simulator")
    data = DataModules[data_engine](data_path, simulator_name)

    out_dir = config.get_item("common", "out_dir", raise_exception=False)
    if not os.path.exists(os.path.dirname(out_dir)):
        os.makedirs(os.path.dirname(out_dir))

    metrics = config.get_section("metrics", raise_exception=False)
    plots = config.get_section("plots", raise_exception=False)

    for metrics_name, metrics_args in metrics.items():
        Metrics[metrics_name](model, data, **metrics_args)()

    for plot_name, plot_args in plots.items():
        Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(**plot_args)
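A note on the fallback branch of `parser()`: it folds command-line flags into the same nested layout a YAML config would have. The sketch below shows that mapping using only the standard library. `build_config`, the `./results/` default, and the plot names are hypothetical stand-ins rather than part of this package.

```python
from argparse import ArgumentParser


def build_config(argv):
    """Map CLI flags onto a nested config dict (hypothetical helper)."""
    parser = ArgumentParser()
    parser.add_argument("--model_path", "-m", default=None)
    parser.add_argument("--data_path", "-d", default=None)
    parser.add_argument("--out_dir", default="./results/")  # assumed default
    # nargs="+" collects one or more names into a list
    parser.add_argument("--plots", nargs="+", default=[])
    parser.add_argument("--metrics", nargs="+", default=[])
    args = parser.parse_args(argv)

    # Each requested plot or metric gets an empty kwargs dict, because
    # per-item arguments cannot be passed on the command line.
    return {
        "common": {"out_dir": args.out_dir},
        "model": {"model_path": args.model_path},
        "data": {"data_path": args.data_path},
        "plots": {name: {} for name in args.plots},
        "metrics": {name: {} for name in args.metrics},
    }


# Plot names here are made up for illustration.
config = build_config(["-m", "model.pkl", "--plots", "PlotA", "PlotB"])
```

Building the dict in one place keeps the CLI path and the YAML path on the same schema, so downstream code only needs to handle one layout.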
9 changes: 9 additions & 0 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

from data.h5_data import H5Data
from data.pickle_data import PickleData
from data.simulator import Simulator

DataModules = {
    "H5Data": H5Data,
    "PickleData": PickleData
}
50 changes: 50 additions & 0 deletions src/data/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import importlib.util
import sys
import os

from utils.config import get_item
from utils.defaults import Defaults

class Data:
    def __init__(self, path:str, simulator_name: str):
        self.data = self._load(path)
        self.simulator = self._load_simulator(simulator_name)

    def _load_simulator(self, name):
        try:
            simulator_path = os.environ[f"{Defaults['common']['sim_location']}:{name}"]
        except KeyError as e:
            raise RuntimeError(f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?")

        new_class = os.path.dirname(simulator_path)
        sys.path.insert(1, new_class)

        # TODO robust error checks
        # splitext drops only the extension; rstrip('.py') would also strip trailing 'p'/'y' characters
        module_name = os.path.splitext(os.path.basename(simulator_path))[0]
        m = importlib.import_module(module_name)

        simulator = getattr(m, name)
        return simulator()

    def _load(self, path:str):
        raise NotImplementedError

    def x_true(self):
        # From Data
        raise NotImplementedError

    def y_true(self):
        return self.simulator(self.theta_true(), self.x_true())

    def prior(self):
        # From Data
        raise NotImplementedError

    def theta_true(self):
        return get_item("data", "theta_true")

    def sigma_true(self):
        return get_item("data", "sigma_true")

    def save(self, data, path:str):
        raise NotImplementedError
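A note on deriving the module name from the simulator path: `str.rstrip` removes any trailing characters drawn from the given set, not a literal suffix, while `os.path.splitext` removes only the final extension. The sketch below contrasts the two. The helper name `module_name_from_path` and the example path are hypothetical.

```python
import os


def module_name_from_path(path):
    """Return the file name without its final extension (hypothetical helper)."""
    return os.path.splitext(os.path.basename(path))[0]


path = "/sims/happy.py"  # made-up simulator location

# splitext removes the ".py" extension only
suffix_stripped = module_name_from_path(path)

# rstrip(".py") removes every trailing '.', 'p' or 'y', so it eats into the name
char_stripped = os.path.basename(path).rstrip(".py")
```

For this path, `suffix_stripped` is `"happy"` while `char_stripped` is `"ha"`, so any simulator file whose name ends in `p` or `y` would fail to import under the `rstrip` approach.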
30 changes: 30 additions & 0 deletions src/data/h5_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any, Callable
import h5py
import numpy as np
import torch
import os

from data.data import Data

class H5Data(Data):
    def __init__(self, path:str, simulator:Callable):
        super().__init__(path, simulator)

    def _load(self, path):
        assert path.split(".")[-1] == "h5", "File extension must be h5"
        loaded_data = {}
        with h5py.File(path, "r") as file:
            for key in file.keys():
                loaded_data[key] = torch.Tensor(file[key][...])
        return loaded_data

    def save(self, data:dict[str, Any], path: str): # Todo typing for data dict
        assert path.split(".")[-1] == "h5", "File extension must be h5"
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))

        data_arrays = {key: np.asarray(value) for key, value in data.items()}
        with h5py.File(path, "w") as file:
            # Save each array as a dataset in the HDF5 file
            for key, value in data_arrays.items():
                file.create_dataset(key, data=value)
18 changes: 18 additions & 0 deletions src/data/pickle_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pickle
from typing import Any, Callable
from data.data import Data

class PickleData(Data):
    def __init__(self, path:str, simulator:Callable):
        super().__init__(path, simulator)

    def _load(self, path:str):
        assert path.split('.')[-1] == 'pkl', "File extension must be 'pkl'"
        with open(path, "rb") as file:
            data = pickle.load(file)
        return data

    def save(self, data:Any, path:str):
        assert path.split('.')[-1] == 'pkl', "File extension must be 'pkl'"
        with open(path, "wb") as file:
            pickle.dump(data, file)
10 changes: 10 additions & 0 deletions src/data/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

from typing import Any

class Simulator:
    def __init__(self) -> None:
        pass

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        pass

3 changes: 3 additions & 0 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Metrics = {

}
44 changes: 44 additions & 0 deletions src/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Optional
import json
import os

from data import data
from models import model

class Metric:
    def __init__(self, model:model, data:data, out_dir:Optional[str]=None) -> None:
        self.model = model
        self.data = data

        self.out_dir = out_dir
        self.output = None

    def _collect_data_params(self):
        raise NotImplementedError

    def _run_model_inference(self):
        raise NotImplementedError

    def calculate(self):
        raise NotImplementedError

    def _finish(self):
        assert self.output is not None, "Calculation has not been completed, have you run Metric.calculate?"

        if self.out_dir is not None:
            if not os.path.exists(os.path.dirname(self.out_dir)):
                os.makedirs(os.path.dirname(self.out_dir))

            # Merge with any existing results, then rewrite the file
            data = {}
            if os.path.exists(self.out_dir):
                with open(self.out_dir) as f:
                    data = json.load(f)
            data.update(self.output)
            with open(self.out_dir, "w") as f:
                json.dump(data, f, ensure_ascii=True)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        self._collect_data_params()
        self._run_model_inference()
        self.calculate()
        self._finish()
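A note on persisting metric output: a file opened in the default read mode cannot be written to, so merging new results into an existing JSON file takes a read pass followed by a separate write pass. The sketch below is one way to do that under assumed names; `merge_json`, `metrics.json`, and the metric keys are hypothetical.

```python
import json
import os
import tempfile


def merge_json(path, new_values):
    """Update the JSON object stored at `path` with `new_values`."""
    existing = {}
    if os.path.exists(path):
        # Read pass: load whatever has been saved so far
        with open(path) as f:
            existing = json.load(f)
    existing.update(new_values)
    # Write pass: reopen in write mode and replace the file contents
    with open(path, "w") as f:
        json.dump(existing, f, ensure_ascii=True)
    return existing


with tempfile.TemporaryDirectory() as tmp:
    results_path = os.path.join(tmp, "metrics.json")
    merge_json(results_path, {"coverage": 0.9})
    merged = merge_json(results_path, {"ranks": [1, 2]})
```

The `with` blocks close each handle on exit, so no explicit `close()` call is needed.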


5 changes: 5 additions & 0 deletions src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from models.sbi_model import SBIModel

ModelModules = {
"SBIModel": SBIModel
}
15 changes: 15 additions & 0 deletions src/models/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class Model:
    def __init__(self, model_path:str) -> None:
        self.model = self._load(model_path)

    def _load(self, path:str) -> None:
        raise NotImplementedError

    def sample_posterior(self):
        raise NotImplementedError

    def sample_simulation(self, data):
        raise NotImplementedError

    def predict_posterior(self, data):
        raise NotImplementedError
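A note on placeholder methods in a base class: `return NotImplementedError` hands the exception class back to the caller without signalling anything, while `raise NotImplementedError` stops execution. The sketch below shows the difference on a throwaway class; `Base` and its method names are hypothetical.

```python
class Base:
    def returns_it(self):
        # Hands back the exception class; nothing is raised
        return NotImplementedError

    def raises_it(self):
        # Stops the caller until a subclass overrides this method
        raise NotImplementedError


base = Base()
returned = base.returns_it()

try:
    base.raises_it()
    raised = False
except NotImplementedError:
    raised = True
```

With the `return` form, a subclass that forgets to override the method fails later and less clearly, for example when the returned class is used as if it were a loaded model.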
24 changes: 24 additions & 0 deletions src/models/sbi_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import pickle

from models.model import Model

class SBIModel(Model):
    def __init__(self, model_path:str):
        super().__init__(model_path)

    def _load(self, path:str) -> None:
        assert os.path.exists(path), f"Cannot find model file at location {path}"
        assert path.split(".")[-1] == 'pkl', "File extension must be 'pkl'"

        with open(path, "rb") as file:
            posterior = pickle.load(file)
Contributor

@beckynevin is the posterior the only object we'll do any diagnostics on? If we wanted to expand to include the likelihood (e.g., SNLE) in a beta version, do you know if that would require significant structural changes?

Collaborator

What would using the likelihood require? Would it require running an MCMC chain to get to the posterior? Or are you talking about evaluating the likelihood itself? Both would require significant structural changes: the former would require the package to handle MCMC sampling, while the latter, I think, would require redesigning the metrics.

Contributor

I was thinking if people wanted to stop at the Likelihood and not go to the Posterior.

I don't think we've talked about metrics/diagnostics for that, but I wasn't sure if you had any schemes going in that direction.

Contributor

I'm not advocating for adding more tools to get a posterior from the likelihood.

        self.posterior = posterior

    def sample_posterior(self, n_samples:int, data): # TODO typing
        # y_true is a method on Data, so it has to be called
        return self.posterior.sample((n_samples,), x=data.y_true())

    def predict_posterior(self, data, n_samples:int):
        # sample_posterior expects the sample count first, then the Data object
        posterior_samples = self.sample_posterior(n_samples, data)
        posterior_predictive_samples = data.simulator(data.theta_true(), posterior_samples)
        return posterior_predictive_samples
7 changes: 7 additions & 0 deletions src/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import TypeVar
from src.plots.plot import Display

# A TypeVar cannot take a single constraint; bound= restricts it to Display subclasses
display = TypeVar("display", bound=Display)
Plots = {

}
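A note on the `TypeVar` declaration: `typing.TypeVar` rejects a single positional constraint with a `TypeError`, and the usual way to say "this type or any subclass" is the `bound=` keyword. The sketch below uses a stand-in `Display` class, since the real one in `plots.plot` is not shown in this diff.

```python
from typing import TypeVar


class Display:  # stand-in for the real Display class, which is not shown here
    pass


# One positional constraint is rejected at definition time
try:
    TypeVar("display", Display)
    single_constraint_ok = True
except TypeError:
    single_constraint_ok = False

# bound= accepts Display and any of its subclasses
display = TypeVar("display", bound=Display)
```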