Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bones #43

Merged
merged 9 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,787 changes: 103 additions & 1,684 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ license = "MIT"

[tool.poetry.dependencies]
python = "^3.9"
jupyter = "^1.0.0"
sbi = "^0.22.0"
flake8 = "^7.0.0"
black = "^24.1.1"
pyarrow = "^15.0.0"
getdist = "^1.4.7"
h5py = "^3.10.0"
numpy = "^1.26.4"
matplotlib = "^3.8.3"

[tool.poetry.dev-dependencies]
pytest = "^7.3.2"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.2"
pytest-cov = "^4.1.0"
flake8 = "^7.0.0"
pytest = "^7.3.2"
black = "^24.3.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
68 changes: 68 additions & 0 deletions src/client/DeepDiagnostics
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3

import yaml
from argparse import ArgumentParser

from src.utils.config import Config
from src.utils.defaults import Defaults
from src.data import DataModules
from src.models import ModelModules
from src.metrics import Metrics
from src.plots import Plots

def parser():
parser = ArgumentParser()
parser.add_argument("--config", '-c', default=None)

# Model
parser.add_argument("--model_path", '-m', default=None)
parser.add_argument("--model_engine", '-e', default=Defaults['model']['model_engine'], choices=ModelModules.keys())

# Data
parser.add_argument("--data_path", '-d', default=None)
parser.add_argument("--data_engine", '-g', default=Defaults['data']['data_engine'], choices=DataModules.keys())

# Common
parser.add_argument("--out_dir", default=Defaults['common']['out_dir'])

# List of metrics (cannot supply specific kwargs)
parser.add_argument("--metrics", nargs='+', default=Defaults['metrics'].keys(), choices=Metrics.keys())

# List of plots
parser.add_argument("--plots", nargs='+', default=Defaults['plots'].keys(), choices=Plots.keys())


args = parser.parse_args()
if args.config is not None:
config = Config(args.config)

else:
tmp_config = './DeepDiagonistics/temp/temp_config.yml' # TODO Un-hardcode
input_yaml = vars(args) # TODO Sections
yaml.dump(input_yaml, open(tmp_config, "w"))
config = Config(tmp_config)

return config


if __name__ == "__main__":
config = parser()

model_path = config.get_item("model", "model_path")
model_engine = config.get_item("model", "model_engine")
model = ModelModules[model_engine](model_path)

data_path = config.get_item("data", "data_path")
data_engine = config.get_item("data", "data_engine")
data = DataModules[data_engine](data_path)

out_dir = config.get_item("common", "out_dir")

metrics = config.get_section("metrics")
plots = config.get_section("plots")

for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, **metrics_args)()

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(**plot_args)
Empty file added src/client/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

from typing import TypeVar

from src.data.data import Data
from src.data.h5_data import H5Data
from src.data.pickle_data import PickleData

data = TypeVar("data", Data)

DataModules = {
"H5Data": H5Data,
"PickleData": PickleData
}
8 changes: 8 additions & 0 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import TypeVar

from src.metrics.metric import Metric

metric = TypeVar("metric", Metric)
Metrics = {

}
26 changes: 21 additions & 5 deletions src/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Any
from src.data.data import Data
from src.models.model import Model
from typing import Any, Optional
import json

from src.data import data
from src.models import model

class Metric:
def __init__(self, model:Model, data:Data) -> None:
def __init__(self, model:model, data:data, out_dir:Optional[str]=None) -> None:
self.model = model
self.data = data

self.out_dir = out_dir
self.output = None

def _collect_data_params():
raise NotImplementedError

Expand All @@ -16,9 +21,20 @@ def _run_model_inference():
def calculate(self):
raise NotImplementedError

def _finish(self):
assert self.output is not None, "Calculation has not been completed, have you run Metric.calculate?"

if self.out_dir is not None:
with open(self.out_dir) as f:
data = json.load(f)
data.update(self.output)
json.dump(data, f, ensure_ascii=True)
f.close()

def __call__(self, *args: Any, **kwds: Any) -> Any:
self._collect_data_params()
self._run_model_inference()
return self.calculate()
self.calculate()
self._finish()


9 changes: 9 additions & 0 deletions src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import TypeVar

from src.models.sbi_model import SBIModel
from src.models.model import Model

model = TypeVar("model", Model)
ModelModules = {
"SBIModel": SBIModel
}
4 changes: 2 additions & 2 deletions src/models/sbi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle

from src.models.model import Model
from src.data.data import Data
from src.data import data

class SBIModel(Model):
def __init__(self, model_path:str):
Expand All @@ -16,7 +16,7 @@ def _load(self, path:str) -> None:
posterior = pickle.load(file)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@beckynevin is the posterior the only object we'll do any diagnostics on? If we wanted to expand to include the likelihood (e.g., SNLE) in a beta version, do you know if that would require significant structural changes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would using the likelihood require? Would it require running the MCMC chain to get to the posterior? Or are you talking about evaluating the likelihood itself? Both would require significant structural changes; the former would require the package to handle MCMCing while the later I think would require redesigning the metrics.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking if people wanted to stop at the Likelihood and not go to the Posterior.

I don't think we've talked about metrics/diagnostics for that, but I wasn't sure if you had any schemes going in that direction.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not advocating for adding more tools to get a posterior from the likelihood.

self.posterior = posterior

def sample_posterior(self, n_samples:int, data:Data) -> "": # TODO typing
def sample_posterior(self, n_samples:int, data:data) -> "": # TODO typing
return self.posterior.sample((n_samples,), x=data.y_true)

def predict_posterior(self, data):
Expand Down
7 changes: 7 additions & 0 deletions src/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import TypeVar
from src.plots.plot import Display

display = TypeVar("display", Display)
Plots = {

}
26 changes: 15 additions & 11 deletions src/plots/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import matplotlib.pyplot as plt
from matplotlib import rcParams

from src.data.data import Data
from src.data import data
from src.models import model

from src.utils.config import get_item, get_section

class Display:
def __init__(self, data:Data, save:bool, show:bool, out_path:Optional[str]):
def __init__(self, model:model, data:data, save:bool, show:bool, out_path:Optional[str]):
self.save = save
self.show = show
self.out_path = out_path.rstrip("/")
Expand All @@ -17,6 +19,7 @@ def __init__(self, data:Data, save:bool, show:bool, out_path:Optional[str]):
if not os.path.exists(os.path.dirname(out_path)):
os.makedirs(os.path.dirname(out_path))

self.model = model
self._data_setup(data)
self._common_settings()
self._plot_settings()
Expand All @@ -33,21 +36,22 @@ def _plot_settings(self):
# TODO Pull fom a config for specific plots
raise NotImplementedError

def _plot(self, *args, **kwrgs):
def _plot(self, **kwrgs):
# Make the plot object with plt.
raise NotImplementedError

def _common_settings(self):
# TODO Pull from a common config
rcParams["axes.spines.right"] = False
rcParams["axes.spines.top"] = False
plot_common = get_section("plot_common", raise_exception=False)
rcParams["axes.spines.right"] = bool(plot_common['axis_spines'])
rcParams["axes.spines.top"] = bool(plot_common['axis_spines'])

# Style
self.colorway = ""
tight_layout = ""

self.colorway = plot_common["colorway"]
tight_layout = bool(plot_common['tight_layout'])
if tight_layout:
plt.tight_layout()
plot_style = plot_common['plot_style']
plt.style.use(plot_style)

def _finish(self):
assert os.path.splitext(self.plot_name)[-1] != '', f"plot name, {self.plot_name}, is malformed. Please supply a name with an extension."
Expand All @@ -56,6 +60,6 @@ def _finish(self):
if self.plot:
plt.show()

def __call__(self) -> None:
self._plot()
def __call__(self, **kwargs) -> None:
self._plot(**kwargs)
self._finish()
35 changes: 24 additions & 11 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,56 @@
import os
import yaml

from src.utils.defaults import Defaults

def get_item(section, item, raise_exception=True, default=None):
return Config().get_item(section, item, raise_exception, default)
def get_item(section, item, raise_exception=True):
return Config().get_item(section, item, raise_exception)

def get_section(section):
return Config().get_section(section)
def get_section(section, raise_exception=True):
return Config().get_section(section, raise_exception)


class Config:
ENV_VAR = "DeepDiagnostics_Config"
ENV_VAR_PATH = "DeepDiagnostics_Config"
def __init__(self, config_path:Optional[str]) -> None:
if config_path is not None:
# Add it to the env vars in case we need to get it later.
os.environ[self.ENV_VAR] = config_path
os.environ[self.ENV_VAR_PATH] = config_path
else:
# Get it from the env vars
try:
config_path = os.environ[self.ENV_VAR]
config_path = os.environ[self.ENV_VAR_PATH]
except KeyError:
assert False, "Cannot load config from enviroment. Hint: Have you set the config path by pasing a str path to Config?"

self.config = self._read_config(config_path)
self._validate_config()

def _validate_config(self):
# Validate common
# TODO
pass

def _read_config(self, path):
assert os.path.exist(path), f"Config path at {path} does not exist."
with open(path, 'r') as f:
config = yaml.safe_load(f)
return config

def get_item(self, section, item, raise_exception=True, default=None):
def get_item(self, section, item, raise_exception=True):
try:
return self.config[section][item]
except KeyError as e:
if raise_exception:
raise KeyError(e)
else:
return default
return Defaults[section][item]

def get_section(self, section):
return self.config[section]
def get_section(self, section, raise_exception=True):
try:
return self.config[section]
except KeyError as e:
if raise_exception:
raise KeyError(e)
else:
return Defaults[section]
26 changes: 26 additions & 0 deletions src/utils/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
Defaults = {
"common":{
"out_dir":"./DeepDiagonistics/results/",
},
"model": {
"model_engine": "SBIModel"
},
"data":{
"data_engine": "H5Data"
},
"plot_common": {
"axis_spines": False,
"tight_layout": True,
"colorway": "virdids",
"plot_style": "fast"
},
"plots":{
"type_of_plot":{"specific_kwargs"}
},
"metric_common": {

},
"metrics":{
"type_of_metrics":{"specific_kwargs"}
}
}
12 changes: 12 additions & 0 deletions src/utils/variable_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class Varaibles:
def __init__(self, session="") -> None:
self.session = {}

def add_variable():
pass

def read_variable():
pass