Merge pull request #59 from deepskies/issue/remove_inference
Issue/remove inference
beckynevin authored Feb 5, 2024
2 parents 9d2f71d + 2717219 commit 832c5f4
Showing 16 changed files with 4 additions and 33,663 deletions.
317 changes: 0 additions & 317 deletions notebooks/SBI.ipynb

This file was deleted.

697 changes: 0 additions & 697 deletions notebooks/SBI_hierarchical_csv.ipynb

This file was deleted.

16,178 changes: 0 additions & 16,178 deletions notebooks/SBI_linefit.ipynb

This file was deleted.

Empty file removed notebooks/example.ipynb
1,125 changes: 0 additions & 1,125 deletions notebooks/numpyro_iterative_dataset_varying_noise.ipynb

This file was deleted.

663 changes: 0 additions & 663 deletions notebooks/numpyro_linefit.ipynb

This file was deleted.

1,916 changes: 0 additions & 1,916 deletions notebooks/pendulum_error_one_moment_in_time_DeepEnsemble.ipynb

This file was deleted.

3,995 changes: 0 additions & 3,995 deletions notebooks/pendulum_numpyro_many_times_hierarchical_ex.ipynb

This file was deleted.

2,105 changes: 0 additions & 2,105 deletions notebooks/pendulum_one_time_hierarchical.ipynb

This file was deleted.

3,248 changes: 0 additions & 3,248 deletions notebooks/pendulum_simple_numpyro_inference.ipynb

This file was deleted.

2,192 changes: 0 additions & 2,192 deletions notebooks/sampling_numpyro.ipynb

This file was deleted.

579 changes: 1 addition & 578 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions pyproject.toml
@@ -1,14 +1,14 @@
 [tool.poetry]
 name = "DeepUQ"
+packages = [{include = "*", from="src"}]
 version = "0.1.0"
-description = ""
+description = "a package for investigating and comparing ML model's predictive uncertainties"
 authors = ["beckynevin <[email protected]>"]
 readme = "README.md"
 license = "MIT"

 [tool.poetry.dependencies]
 python = ">=3.9,<3.11"
-numpyro = "^0.13.2"
 jupyter = "^1.0.0"
 matplotlib = "^3.7.1"
 arviz = "^0.15.1"
@@ -17,7 +17,6 @@
 scikit-learn = "^1.3.0"
 graphviz = "^0.20.1"
 seaborn = "^0.12.2"
 torch = "^2.0.1"
-sbi = "^0.21.0"
 pytest-cov = "^4.1.0"
 deepbench = "^0.2.2"

96 changes: 0 additions & 96 deletions src/scripts/models.py
@@ -1,8 +1,4 @@
-import numpyro
-import numpyro.distributions as dist
 import numpy as np
-import jax
-import jax.numpy as jnp  # yes i know this is confusing
 import torch.nn as nn
 import torch
 import math
@@ -63,101 +59,9 @@ def forward(self, x):
 # HMC after SBI to look at degeneracies between params
 # different guides (some are slower but better at showing degeneracies)

-## define the platform and number of cores (one chain per core)
-numpyro.set_platform('cpu')
-core_num = 4
-numpyro.set_host_device_count(core_num)
-
-def hierarchical_model(planet_code,
-                       pendulum_code,
-                       times,
-                       exponential,
-                       pos_obs=None):
-    """
-    """
-    ## inputs to a numpyro model are rows from a dataframe:
-    ## planet_code - array of integer codes identifying which planet {0...1}
-    ## pendulum_code - array of integer codes identifying which pendulum {0...7}
-    ## times - moments in time (s)
-    ## pos_obs - optional, defaults to None; used to condition the model on
-    ## data when the observed positions (xpos) are supplied
-
-    ## numpyro models work by drawing parameters from sample statements
-    ## first, we define the global parameters: the mean and sigma of the normal
-    ## from which the individual a_g values of each planet will be drawn
-
-    # μ_a_g = numpyro.sample("μ_a_g", dist.LogUniform(5.0, 15.0))
-    μ_a_g = numpyro.sample("μ_a_g", dist.TruncatedNormal(12.5, 5, low=0.01))
-    # scale parameters should be log uniform so that they don't go negative
-    # and so that they're not uniform
-    # (1 / x in linear space)
-    σ_a_g = numpyro.sample("σ_a_g", dist.TruncatedNormal(0.1, 0.01, low=0.01))
-    n_planets = len(np.unique(planet_code))
-    n_pendulums = len(np.unique(pendulum_code))
-
-    ## plates are a numpyro primitive (a context manager) for handling
-    ## conditional independence; for instance, we wish to model a_g for each
-    ## planet independently
-    with numpyro.plate("planet_i", n_planets):
-        a_g = numpyro.sample("a_g", dist.TruncatedNormal(μ_a_g, σ_a_g,
-                                                         low=0.01))
-        # helps because the a_g values are pulled from the same normal dist;
-        # removes the dependency of a_g on σ_a_g at the prior level,
-        # removing one covariance from the model, so the model is easier
-        # to sample from
-
-    ## we also wish to model L and theta for each pendulum independently
-    ## here we draw from truncated normal distributions
-    with numpyro.plate("pend_i", n_pendulums):
-        L = numpyro.sample("L", dist.TruncatedNormal(5, 2, low=0.01))
-        theta = numpyro.sample("theta", dist.TruncatedNormal(jnp.pi/100,
-                                                             jnp.pi/500,
-                                                             low=0.00001))
-
-    ## σ is the error on the position measurement at each moment in time,
-    ## which we also model
-    ## (eventually, we should also model the error on each parameter independently)
-    ## draw from an exponential distribution parameterized by a rate parameter;
-    ## the mean of an exponential distribution is 1/r where r is the rate
-    ## exponential distributions are never negative, which is good for an error term
-    σ = numpyro.sample("σ", dist.Exponential(exponential))
-
-    ## the moments in time are not independent, so we do not place the following
-    ## in a plate; instead, the index brackets segment the model by pendulum and
-    ## by planet, telling numpyro how to conduct the inference
-    modelx = L[pendulum_code] * jnp.sin(theta[pendulum_code] *
-                                        jnp.cos(jnp.sqrt(a_g[planet_code] /
-                                                         L[pendulum_code]) * times))
-    ## don't forget to use jnp instead of np so jax knows what to do
-    ## A BIG QUESTION I STILL HAVE IS WHAT IS THE LIKELIHOOD? IS IT JUST SAMPLED FROM?
-    ## again, for each pendulum we compare the observed to the modeled position:
-    with numpyro.plate("data", len(pendulum_code)):
-        pos = numpyro.sample("obs", dist.Normal(modelx, σ), obs=pos_obs)
-
-
-def unpooled_model(planet_code,
-                   pendulum_code,
-                   times,
-                   exponential,
-                   pos_obs=None):
-    n_planets = len(np.unique(planet_code))
-    n_pendulums = len(np.unique(pendulum_code))
-    with numpyro.plate("planet_i", n_planets):
-        a_g = numpyro.sample("a_g", dist.TruncatedNormal(12.5, 5,
-                                                         low=0, high=25))
-    with numpyro.plate("pend_i", n_pendulums):
-        L = numpyro.sample("L", dist.TruncatedNormal(5, 2, low=0.01))
-        theta = numpyro.sample("theta", dist.TruncatedNormal(jnp.pi/100,
-                                                             jnp.pi/500,
-                                                             low=0.00001))
-    σ = numpyro.sample("σ", dist.Exponential(exponential))
-    modelx = L[pendulum_code] * jnp.sin(theta[pendulum_code] *
-                                        jnp.cos(jnp.sqrt(a_g[planet_code] /
-                                                         L[pendulum_code]) * times))
-    with numpyro.plate("data", len(pendulum_code)):
-        pos = numpyro.sample("obs", dist.Normal(modelx, σ), obs=pos_obs)

 # This is from PasteurLabs -
 # https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/models.py


 class Model(nn.Module):
     def __init__(self, n_output, n_hidden=64):
         super().__init__()
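For readers tracing this commit's history, the following is a minimal, self-contained sketch of the numpyro pattern the deleted models implemented. It is an illustration, not repo code: pendulum_model condenses the deleted hierarchical_model down to a single shared a_g, and the rate argument (10.0) and toy data are invented values.

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def pendulum_model(pendulum_code, times, exponential, pos_obs=None):
    # one shared gravitational acceleration; per-pendulum length and amplitude
    a_g = numpyro.sample("a_g", dist.TruncatedNormal(12.5, 5.0, low=0.01))
    n_pendulums = len(np.unique(pendulum_code))
    with numpyro.plate("pend_i", n_pendulums):
        L = numpyro.sample("L", dist.TruncatedNormal(5.0, 2.0, low=0.01))
        theta = numpyro.sample("theta", dist.TruncatedNormal(jnp.pi / 100,
                                                             jnp.pi / 500,
                                                             low=1e-5))
    # exponential prior keeps the measurement error positive
    sigma = numpyro.sample("sigma", dist.Exponential(exponential))
    modelx = L[pendulum_code] * jnp.sin(theta[pendulum_code] *
             jnp.cos(jnp.sqrt(a_g / L[pendulum_code]) * times))
    with numpyro.plate("data", len(pendulum_code)):
        numpyro.sample("obs", dist.Normal(modelx, sigma), obs=pos_obs)

# toy data: 2 pendulums, 20 time samples each, with stand-in observations
pendulum_code = np.repeat(np.arange(2), 20)
times = np.tile(np.linspace(0.0, 5.0, 20), 2)
pos_obs = np.random.default_rng(0).normal(0.0, 0.1, size=40)

mcmc = MCMC(NUTS(pendulum_model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), pendulum_code, times, 10.0, pos_obs=pos_obs)
mcmc.print_summary()

This runs a single chain, so it skips the set_host_device_count call the deleted module used for multi-chain sampling.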
24 changes: 1 addition & 23 deletions src/scripts/train.py
@@ -4,13 +4,12 @@
"""
import argparse
import torch
import sbi
import time
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
from src.scripts import models
from scripts import models
import functools


@@ -428,27 +427,6 @@ def train_DE(trainDataLoader,



-def train_SBI_hierarchical(thetas, xs, prior):
-    # Now let's put them in a tensor form that SBI can read.
-    theta = torch.tensor(thetas, dtype=torch.float32)
-    x = torch.tensor(xs, dtype=torch.float32)
-
-    # instantiate the neural density estimator
-    neural_posterior = sbi.utils.posterior_nn(model='maf')  # ,
-    #                                         embedding_net=embedding_net,
-    #                                         hidden_features=hidden_features,
-    #                                         num_transforms=num_transforms)
-    # set up the inference procedure with the SNPE-C procedure
-    inference = sbi.inference.SNPE(prior=prior,
-                                   density_estimator=neural_posterior,
-                                   device="cpu")
-
-    # now that we have both the simulated images and
-    # parameters defined properly, we can train the SBI.
-    density_estimator = inference.append_simulations(theta, x).train()
-    return inference.build_posterior(density_estimator)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--data_source", type=str, help="Data used to train the model")
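As a companion to the deletion above, here is a hypothetical usage sketch of the removed train_SBI_hierarchical helper, assuming the sbi ~0.21 API that was pinned in pyproject.toml before this commit. The toy line-fit simulator, the BoxUniform prior bounds, and the observation x_o are all invented for illustration; train_SBI_hierarchical refers to the deleted function shown in the diff.

import numpy as np
import torch
from sbi.utils import BoxUniform

# toy line-fit simulator: theta = (slope, intercept), x = noisy y(t) samples
rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 10)
thetas = rng.uniform(-1.0, 1.0, size=(500, 2))
xs = thetas[:, :1] * t + thetas[:, 1:] + rng.normal(0.0, 0.05, size=(500, 10))

# box prior over (slope, intercept), matching the simulator's range
prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))

# train the neural posterior on the simulated (theta, x) pairs
posterior = train_SBI_hierarchical(thetas, xs, prior)

# condition on a made-up observation and draw posterior samples
x_o = torch.tensor(0.5 * t + 0.1, dtype=torch.float32)
samples = posterior.sample((1000,), x=x_o)
print(samples.mean(dim=0))  # should land roughly near (0.5, 0.1)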
