Skip to content

Commit

Permalink
Merge pull request #62 from voetberg/variable_naming_verbosity
Browse files Browse the repository at this point in the history
Modification of simulator for extra control, extra verbose variable names
  • Loading branch information
bnord authored May 20, 2024
2 parents 399fb3d + 685c2ef commit c6076c7
Show file tree
Hide file tree
Showing 25 changed files with 689 additions and 440 deletions.
92 changes: 59 additions & 33 deletions src/client/client.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,82 @@
import os
import yaml
import yaml
from argparse import ArgumentParser

from utils.config import Config
from utils.defaults import Defaults
from data import DataModules
from models import ModelModules
from metrics import Metrics
from data import DataModules
from models import ModelModules
from metrics import Metrics
from plots import Plots


def parser():
    """Build the run configuration from command-line arguments.

    When ``--config`` is supplied, that file is loaded directly. Otherwise the
    individual command-line options are collected into a dictionary, written
    as YAML to the temporary config path named in ``Defaults``, and loaded
    from there.

    Returns:
        Config: the configuration object for this run.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--config", "-c", default=None)

    # Model
    arg_parser.add_argument("--model_path", "-m", default=None)
    arg_parser.add_argument(
        "--model_engine",
        "-e",
        default=Defaults["model"]["model_engine"],
        choices=ModelModules.keys(),
    )

    # Data
    arg_parser.add_argument("--data_path", "-d", default=None)
    arg_parser.add_argument(
        "--data_engine",
        "-g",
        default=Defaults["data"]["data_engine"],
        choices=DataModules.keys(),
    )
    arg_parser.add_argument("--simulator", "-s", default=None)

    # Common
    arg_parser.add_argument("--out_dir", default=Defaults["common"]["out_dir"])

    # List of metrics (cannot supply specific kwargs)
    arg_parser.add_argument(
        "--metrics",
        nargs="+",
        default=list(Defaults["metrics"].keys()),
        choices=Metrics.keys(),
    )

    # List of plots
    arg_parser.add_argument(
        "--plots",
        nargs="+",
        default=list(Defaults["plots"].keys()),
        choices=Plots.keys(),
    )

    args = arg_parser.parse_args()
    if args.config is not None:
        return Config(args.config)

    temp_config = Defaults["common"]["temp_config"]
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    config_dir = os.path.dirname(temp_config)
    if config_dir:
        os.makedirs(config_dir, exist_ok=True)

    input_yaml = {
        "common": {"out_dir": args.out_dir},
        "model": {"model_path": args.model_path, "model_engine": args.model_engine},
        "data": {
            "data_path": args.data_path,
            "data_engine": args.data_engine,
            "simulator": args.simulator,
        },
        "plots": {key: {} for key in args.plots},
        "metrics": {key: {} for key in args.metrics},
    }

    # Close the handle before Config re-reads the file so the YAML is flushed.
    with open(temp_config, "w") as config_file:
        yaml.dump(input_yaml, config_file)

    return Config(temp_config)


def main():
config = parser()

Expand All @@ -66,14 +90,16 @@ def main():
data = DataModules[data_engine](data_path, simulator_name)

out_dir = config.get_item("common", "out_dir", raise_exception=False)
if not os.path.exists(os.path.dirname(out_dir)):
if not os.path.exists(os.path.dirname(out_dir)):
os.makedirs(os.path.dirname(out_dir))

metrics = config.get_section("metrics", raise_exception=False)
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, **metrics_args)()

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(**plot_args)
for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
**plot_args
)
6 changes: 1 addition & 5 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@

from data.h5_data import H5Data
from data.pickle_data import PickleData

# Registry of the available data loader classes, keyed by the engine name
# that the client accepts for its ``--data_engine`` option.
DataModules = {
"H5Data": H5Data,
"PickleData": PickleData
}
DataModules = {"H5Data": H5Data, "PickleData": PickleData}
140 changes: 109 additions & 31 deletions src/data/data.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,128 @@
import importlib.util
import sys
import os
import numpy as np

from utils.config import get_item
from utils.defaults import Defaults

class Data:
    """Base class pairing a loaded dataset with a simulator and a prior.

    Subclasses provide the storage-specific pieces: ``_load``, ``save`` and
    ``true_context``, and optionally ``read_prior``.
    """

    def __init__(
        self,
        path: str,
        simulator_name: str,
        simulator_kwargs: dict = None,
        prior: str = None,
        prior_kwargs: dict = None,
    ):
        """Load the data, the registered simulator and the prior.

        Args:
            path: location of the data file, passed to ``_load``.
            simulator_name: name under which the simulator was registered.
            simulator_kwargs: keyword arguments for the simulator constructor;
                when ``None`` the ``data.simulator_kwargs`` config item is used.
            prior: name of the prior distribution; when ``None`` the
                ``data.prior`` config item is used.
            prior_kwargs: keyword arguments forwarded to the prior sampler.
        """
        self.rng = np.random.default_rng(
            get_item("common", "random_seed", raise_exception=False)
        )
        self.data = self._load(path)
        self.simulator = self._load_simulator(simulator_name, simulator_kwargs)
        self.prior_dist = self.load_prior(prior, prior_kwargs)
        self.n_dims = self.get_theta_true().shape[1]

    @staticmethod
    def _module_name_from_path(simulator_path):
        """Return the importable module name for a simulator source path."""
        # str.rstrip(".py") strips any trailing '.', 'p' or 'y' characters
        # (turning "happy.py" into "ha"); splitext removes only the extension.
        return os.path.splitext(os.path.basename(simulator_path))[0]

    def _load_simulator(self, name, simulator_kwargs):
        """Import and instantiate the simulator registered under ``name``.

        The simulator's source path is read from the environment variable
        ``"<sim_location>:<name>"``; its directory is added to ``sys.path``
        and the class called ``name`` is imported from that module.

        Raises:
            RuntimeError: if the environment variable is missing, or the
                instance lacks a ``generate_context`` or ``simulate`` method.
        """
        try:
            sim_location = get_item("common", "sim_location", raise_exception=False)
            simulator_path = os.environ[f"{sim_location}:{name}"]
        except KeyError as e:
            raise RuntimeError(
                f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?"
            ) from e

        new_class = os.path.dirname(simulator_path)
        sys.path.insert(1, new_class)

        # TODO robust error checks
        module_name = self._module_name_from_path(simulator_path)
        m = importlib.import_module(module_name)

        simulator = getattr(m, name)

        # Explicit kwargs win; otherwise fall back to the config, then to none.
        if simulator_kwargs is None:
            simulator_kwargs = get_item(
                "data", "simulator_kwargs", raise_exception=False
            )
        if simulator_kwargs is None:
            simulator_kwargs = {}
        simulator_instance = simulator(**simulator_kwargs)

        for required_method in ("generate_context", "simulate"):
            if not hasattr(simulator_instance, required_method):
                raise RuntimeError(
                    f"Simulator improperly formed - requires a {required_method} method."
                )

        return simulator_instance

    def _load(self, path: str):
        """Load the data stored at ``path``; implemented by subclasses."""
        raise NotImplementedError

    def true_context(self):
        """Return the context stored with the data; implemented by subclasses."""
        # From Data
        raise NotImplementedError

    def true_simulator_outcome(self):
        """Call the simulator with the true parameters and the true context."""
        return self.simulator(self.get_theta_true(), self.true_context())

    def sample_prior(self, n_samples: int):
        """Draw an ``(n_samples, n_dims)`` sample from the prior."""
        return self.prior_dist(size=(n_samples, self.n_dims))

    def simulator_outcome(self, theta, condition_context=None, n_samples=None):
        """Run the simulator for ``theta``.

        With ``condition_context`` the simulator's ``simulate`` method is used
        on that context; without it, the simulator is called with
        ``n_samples`` instead.

        Raises:
            ValueError: if neither ``condition_context`` nor ``n_samples`` is
                given.
        """
        if condition_context is not None:
            return self.simulator.simulate(theta, condition_context)
        if n_samples is None:
            raise ValueError("Samples required if condition context is not specified")
        return self.simulator(theta, n_samples)

    def simulated_context(self, n_samples):
        """Ask the simulator to generate a context of ``n_samples``."""
        return self.simulator.generate_context(n_samples)

    def get_theta_true(self):
        """Return ``self.theta_true`` if set, else the ``data.theta_true`` config item."""
        if hasattr(self, "theta_true"):
            return self.theta_true
        return get_item("data", "theta_true", raise_exception=True)

    def get_sigma_true(self):
        """Return ``self.sigma_true`` if set, else the ``data.sigma_true`` config item."""
        if hasattr(self, "sigma_true"):
            return self.sigma_true
        return get_item("data", "sigma_true", raise_exception=True)

    def save(self, data, path: str):
        """Write ``data`` to ``path``; implemented by subclasses."""
        raise NotImplementedError

    def read_prior(self):
        """Return a prior sampler read from the data; implemented by subclasses."""
        raise NotImplementedError

    def load_prior(self, prior, prior_kwargs):
        """Return a callable ``f(size=...)`` that samples from the prior.

        A prior supplied by ``read_prior`` takes precedence. When the subclass
        does not implement ``read_prior``, ``prior`` (or the ``data.prior``
        config item when ``prior`` is ``None``) selects one of the generator
        methods of ``self.rng`` and ``prior_kwargs`` are forwarded to it.

        Raises:
            NotImplementedError: if ``prior`` is not one of the known names.
            RuntimeError: if ``read_prior`` fails with a ``KeyError``.
        """
        if prior is None:
            prior = get_item("data", "prior", raise_exception=False)

        try:
            # Return the stored prior; it used to be assigned and then dropped,
            # which left the caller with None.
            return self.read_prior()
        except NotImplementedError:
            pass
        except KeyError as e:
            raise RuntimeError(f"Data missing a prior specification - {e}") from e

        choices = {
            "normal": self.rng.normal,
            "poisson": self.rng.poisson,
            "uniform": self.rng.uniform,
            "gamma": self.rng.gamma,
            "beta": self.rng.beta,
            "binomial": self.rng.binomial,
            # Historical misspelling, kept so existing configs keep working.
            "binominal": self.rng.binomial,
        }

        if prior not in choices:
            raise NotImplementedError(
                f"{prior} is not an option for a prior, choose from {list(choices.keys())}"
            )
        if prior_kwargs is None:
            prior_kwargs = {}
        return lambda size: choices[prior](**prior_kwargs, size=size)
40 changes: 19 additions & 21 deletions src/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,26 @@
import h5py
import numpy as np
import torch
import os
import os

from data.data import Data

class H5Data(Data):
def __init__(
    self,
    path: str,
    simulator: str,
    simulator_kwargs: dict = None,
    prior: str = None,
    prior_kwargs: dict = None,
):
    """Load an ``.h5`` dataset together with its simulator and prior.

    Args:
        path: location of the ``.h5`` file.
        simulator: name under which the simulator was registered; it is
            passed on as the parent's ``simulator_name``.
        simulator_kwargs: optional keyword arguments for the simulator.
        prior: optional name of the prior distribution.
        prior_kwargs: optional keyword arguments for the prior sampler.
    """
    # Forward the optional settings so they are reachable through this subclass.
    super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs)

def _load(self, path):
def _load(self, path):
assert path.split(".")[-1] == "h5", "File extension must be h5"
loaded_data = {}
with h5py.File(path, "r") as file:
for key in file.keys():
loaded_data[key] = torch.Tensor(file[key][...])
return loaded_data
def save(self, data:dict[str, Any], path: str): # Todo typing for data dict

def save(self, data: dict[str, Any], path: str): # Todo typing for data dict
assert path.split(".")[-1] == "h5", "File extension must be h5"
if not os.path.exists(os.path.dirname(path)):
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))

data_arrays = {key: np.asarray(value) for key, value in data.items()}
Expand All @@ -29,22 +30,19 @@ def save(self, data:dict[str, Any], path: str): # Todo typing for data dict
for key, value in data_arrays.items():
file.create_dataset(key, data=value)

def true_context(self):
    """Return the context stored in the loaded data under the ``"xs"`` key."""
    # From Data
    # TODO change name
    context = self.data["xs"]
    return context


def prior(self):
    """Not provided for H5 data; always raises ``NotImplementedError``."""
    # From Data
    raise NotImplementedError

def get_theta_true(self):
    """Return the true parameters stored in the loaded data under ``"thetas"``."""
    return self.data["thetas"]


def get_sigma_true(self):
    """Return the parent's sigma value, or ``1`` if its lookup fails.

    The fallback applies when the parent lookup raises ``AssertionError`` or
    ``KeyError``.
    """
    try:
        # The parent accessor is named get_sigma_true; calling the old name
        # sigma_true raised an AttributeError that the handler does not catch.
        return super().get_sigma_true()
    except (AssertionError, KeyError):
        return 1
Loading

0 comments on commit c6076c7

Please sign in to comment.