Skip to content

Commit

Permalink
rename methods to be more pythonic, in response to issue #8
Browse files Browse the repository at this point in the history
  • Loading branch information
darrenjw committed Nov 3, 2024
1 parent 3835819 commit 9e1ea0e
Show file tree
Hide file tree
Showing 36 changed files with 138 additions and 138 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import jsmfsb
lvmod = jsmfsb.models.lv()
step = lvmod.step_gillespie()
k0 = jax.random.key(42)
out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step)
out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step)
assert(out.shape == (300, 2))
```

Expand All @@ -38,7 +38,7 @@ fig.savefig("lv.pdf")

The API for this package is very similar to that of the `smfsb` package. The main difference is that non-deterministic (random) functions have an extra argument (typically the first argument) that corresponds to a JAX random number key. See the [relevant section](https://jax.readthedocs.io/en/latest/random-numbers.html) of the JAX documentation for further information regarding random numbers in JAX code.

For further information, see the [demo directory](https://github.com/darrenjw/jax-smfsb/tree/main/demos) and the [API documentation](https://jax-smfsb.readthedocs.io/en/latest/index.html). Within the demos directory, see [shbuild.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/shbuild.py) for an example of how to specify a (SEIR epidemic) model using SBML-shorthand and [step_cle2Df.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/step_cle2Df.py) for a 2-d reaction-diffusion simulation. For parameter inference (from time course data), see [abc-cal.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc-cal.py) for ABC inference, [abcSmc.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abcSmc.py) for ABC-SMC inference and [pmmh.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/pmmh.py) for particle marginal Metropolis-Hastings MCMC-based inference. There are many other demos besides these.
For further information, see the [demo directory](https://github.com/darrenjw/jax-smfsb/tree/main/demos) and the [API documentation](https://jax-smfsb.readthedocs.io/en/latest/index.html). Within the demos directory, see [shbuild.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/shbuild.py) for an example of how to specify a (SEIR epidemic) model using SBML-shorthand and [step_cle_2df.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/step_cle_2df.py) for a 2-d reaction-diffusion simulation. For parameter inference (from time course data), see [abc-cal.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc-cal.py) for ABC inference, [abc_smc.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc_smc.py) for ABC-SMC inference and [pmmh.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/pmmh.py) for particle marginal Metropolis-Hastings MCMC-based inference. There are many other demos besides these.

You can view this package on [GitHub](https://github.com/darrenjw/jax-smfsb) or [PyPI](https://pypi.org/project/jsmfsb/).

Expand Down
8 changes: 4 additions & 4 deletions demos/abc-cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

print("ABC with calibrated summary stats")

data = jsmfsb.data.LVperfect[:,1:3]
data = jsmfsb.data.lv_perfect[:,1:3]

def rpr(k):
k1, k2, k3 = jax.random.split(k, 3)
Expand All @@ -19,7 +19,7 @@ def rpr(k):
jax.random.uniform(k3, minval=-4, maxval=2)]))

def rmod(k, th):
return jsmfsb.simTs(k, jnp.array([50.0, 100.0]), 0, 30, 2,
return jsmfsb.sim_time_series(k, jnp.array([50.0, 100.0]), 0, 30, 2,
jsmfsb.models.lv(th).step_cle(0.1))

def ss1d(vec):
Expand All @@ -43,7 +43,7 @@ def ssi(ts):
k0 = jax.random.key(42)
k1, k2 = jax.random.split(k0)

p, d = jsmfsb.abcRun(k1, 100000, rpr, lambda k,th: ssi(rmod(k,th)), batch_size=10000)
p, d = jsmfsb.abc_run(k1, 100000, rpr, lambda k,th: ssi(rmod(k,th)), batch_size=10000)
prmat = jnp.vstack(p)
dmat = jnp.vstack(d)
print(prmat.shape)
Expand All @@ -67,7 +67,7 @@ def dist(ss):
def rdis(k, th):
return dist(sumStats(rmod(k, th)))

p, d = jsmfsb.abcRun(k2, 1000000, rpr, rdis, batch_size=100000, verb=False)
p, d = jsmfsb.abc_run(k2, 1000000, rpr, rdis, batch_size=100000, verb=False)

q = jnp.nanquantile(d, 0.01)
prmat = jnp.vstack(p)
Expand Down
6 changes: 3 additions & 3 deletions demos/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

print("ABC")

data = jsmfsb.data.LVperfect[:,1:3]
data = jsmfsb.data.lv_perfect[:,1:3]

def rpr(k):
k1, k2, k3 = jax.random.split(k, 3)
Expand All @@ -18,7 +18,7 @@ def rpr(k):
jax.random.uniform(k3, minval=-4, maxval=2)]))

def rmod(k, th):
return jsmfsb.simTs(k, jnp.array([50.0, 100.0]), 0, 30, 2,
return jsmfsb.sim_time_series(k, jnp.array([50.0, 100.0]), 0, 30, 2,
jsmfsb.models.lv(th).step_cle(0.1))

def sumStats(dat):
Expand All @@ -34,7 +34,7 @@ def rdis(k, th):
return dist(sumStats(rmod(k, th)))

k0 = jax.random.key(42)
p, d = jsmfsb.abcRun(k0, 1000000, rpr, rdis, batch_size=100000, verb=False)
p, d = jsmfsb.abc_run(k0, 1000000, rpr, rdis, batch_size=100000, verb=False)

q = jnp.nanquantile(d, 0.01)
prmat = jnp.vstack(p)
Expand Down
6 changes: 3 additions & 3 deletions demos/abcRun.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# abcRun.py
# abc_run.py

import jsmfsb
import jax
Expand Down Expand Up @@ -29,7 +29,7 @@ def dist(ss):
def rdis(k, th):
return dist(sumStats(rmod(k, th)))

p, d = jsmfsb.abcRun(k2, 1000000, rpr, rdis)
p, d = jsmfsb.abc_run(k2, 1000000, rpr, rdis)

q = jnp.quantile(d, 0.01)
prmat = jnp.vstack(p)
Expand All @@ -44,7 +44,7 @@ def rdis(k, th):
axes[1, 1].plot(range(its), postmat[:,1], linewidth=0.1)
axes[2, 0].hist(postmat[:,0], bins=30)
axes[2, 1].hist(postmat[:,1], bins=30)
fig.savefig("abcRun.pdf")
fig.savefig("abc_run.pdf")


# eof
Expand Down
12 changes: 6 additions & 6 deletions demos/abcSmc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# abcSmc.py
# abc_smc.py

import jsmfsb
import jax
Expand All @@ -8,7 +8,7 @@

print("ABC-SMC")

data = jsmfsb.data.LVperfect[:,1:3]
data = jsmfsb.data.lv_perfect[:,1:3]

# Very vague prior

Expand Down Expand Up @@ -44,7 +44,7 @@ def dpr(th):
# Model

def rmod(k, th):
return jsmfsb.simTs(k, jnp.array([50.0, 100]), 0, 30, 2,
return jsmfsb.sim_time_series(k, jnp.array([50.0, 100]), 0, 30, 2,
jsmfsb.models.lv(jnp.exp(th)).step_cle(0.1))

print("Pilot run...")
Expand All @@ -64,7 +64,7 @@ def ssi(ts):
jnp.array([jnp.corrcoef(ts[:,0], ts[:,1])[0,1]])))

key = jax.random.key(42)
p, d = jsmfsb.abcRun(key, 20000, rpr, lambda k,th: ssi(rmod(k,th)), verb=False)
p, d = jsmfsb.abc_run(key, 20000, rpr, lambda k,th: ssi(rmod(k,th)), verb=False)
prmat = jnp.vstack(p)
dmat = jnp.vstack(d)
print(prmat.shape)
Expand Down Expand Up @@ -93,7 +93,7 @@ def rper(k,th):
def dper(ne, ol):
return jnp.sum(jsp.stats.norm.logpdf(ne, ol, 0.5))

postmat = jsmfsb.abcSmc(key, 10000, rpr, dpr, rdis, rper, dper,
postmat = jsmfsb.abc_smc(key, 10000, rpr, dpr, rdis, rper, dper,
factor=5, steps=8, verb=True)

its, var = postmat.shape
Expand All @@ -107,7 +107,7 @@ def dper(ne, ol):
axes[1, 0].hist(postmat[:,0], bins=30)
axes[1, 1].hist(postmat[:,1], bins=30)
axes[1, 2].hist(postmat[:,2], bins=30)
fig.savefig("abcSmc.pdf")
fig.savefig("abc_smc.pdf")

print("All done.")

Expand Down
2 changes: 1 addition & 1 deletion demos/lv.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
stepC = lvmod.step_cle(0.01)
print(stepC(k0, lvmod.m, 0, 30))

out = jsmfsb.simSample(k0, 10000, lvmod.m, 0, 30, stepC)
out = jsmfsb.sim_sample(k0, 10000, lvmod.m, 0, 30, stepC)
out = jnp.where(out > 1000, 1000, out)
import scipy as sp
print(sp.stats.describe(out))
Expand Down
2 changes: 1 addition & 1 deletion demos/m-bd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, bdmod.m, 0, 30))

out = jsmfsb.simTs(k0, bdmod.m, 0, 20, 0.1, step)
out = jsmfsb.sim_time_series(k0, bdmod.m, 0, 20, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-dimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, dimermod.m, 0, 30))

out = jsmfsb.simTs(k0, dimermod.m, 0, 30, 0.1, step)
out = jsmfsb.sim_time_series(k0, dimermod.m, 0, 30, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-id.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, idmod.m, 0, 30))

out = jsmfsb.simTs(k0, idmod.m, 0, 100, 0.1, step)
out = jsmfsb.sim_time_series(k0, idmod.m, 0, 100, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv-cle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step)
out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv-euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step)
out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv-pts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step)
out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step)
out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, mmmod.m, 0, 30))

out = jsmfsb.simTs(k0, mmmod.m, 0, 100, 0.1, step)
out = jsmfsb.sim_time_series(k0, mmmod.m, 0, 100, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
2 changes: 1 addition & 1 deletion demos/m-sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
k0 = jax.random.key(42)
print(step(k0, sirmod.m, 0, 30))

out = jsmfsb.simTs(k0, sirmod.m, 0, 100, 0.1, step)
out = jsmfsb.sim_time_series(k0, sirmod.m, 0, 100, 0.1, step)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand Down
6 changes: 3 additions & 3 deletions demos/metropolisHastings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# metropolisHastings.py
# metropolis_hastings.py

import jsmfsb
import jax
Expand All @@ -11,7 +11,7 @@
data = jax.random.normal(k1, 250)*2 + 5
llik = lambda k,x: jnp.sum(jsp.stats.norm.logpdf(data, x[0], x[1]))
prop = lambda k,x: jax.random.normal(k, 2)*0.1 + x
postmat = jsmfsb.metropolisHastings(k2, jnp.array([1.0,1.0]), llik, prop, verb=False)
postmat = jsmfsb.metropolis_hastings(k2, jnp.array([1.0,1.0]), llik, prop, verb=False)


import matplotlib.pyplot as plt
Expand All @@ -22,7 +22,7 @@
axes[1, 1].plot(range(10000), postmat[:,1], linewidth=0.1)
axes[2, 0].hist(postmat[:,0], bins=30)
axes[2, 1].hist(postmat[:,1], bins=30)
fig.savefig("metropolisHastings.pdf")
fig.savefig("metropolis_hastings.pdf")


# eof
Expand Down
4 changes: 2 additions & 2 deletions demos/pfMLLik.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# pfMLLik.py
# pf_marginal_ll.py

import jsmfsb
import jax
Expand All @@ -17,7 +17,7 @@ def step(k, x, t, dt, th):
sf = jsmfsb.models.lv(th).step_cle(0.1)
#sf = jsmfsb.models.lv(th).step_gillespie()
return sf(k, x, t, dt)
mll = jsmfsb.pfMLLik(100, simX, 0, step, obsll, jsmfsb.data.LVnoise10)
mll = jsmfsb.pf_marginal_ll(100, simX, 0, step, obsll, jsmfsb.data.lv_noise_10)

k = jax.random.split(jax.random.key(42), 5)
print(mll(k[0], jnp.array([1, 0.005, 0.6])))
Expand Down
4 changes: 2 additions & 2 deletions demos/pmmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def step(k, x, t, dt, th):
#sf = jsmfsb.models.lv(th).step_gillespie()
sf = jsmfsb.models.lv(th).step_cle(0.1)
return sf(k, x, t, dt)
mll = jsmfsb.pfMLLik(100, simX, 0, step, obsll, jsmfsb.data.LVnoise10)
mll = jsmfsb.pf_marginal_ll(100, simX, 0, step, obsll, jsmfsb.data.lv_noise_10)

print("Test evals")
k0 = jax.random.key(42)
Expand All @@ -34,7 +34,7 @@ def step(k, x, t, dt, th):
def prop(k, th, tune=0.01):
return jnp.exp(jax.random.normal(k, shape=(3))*tune) * th

thmat = jsmfsb.metropolisHastings(k3, jnp.array([1, 0.005, 0.6]), mll, prop,
thmat = jsmfsb.metropolis_hastings(k3, jnp.array([1, 0.005, 0.6]), mll, prop,
iters=5000, thin=1, verb=False)

print("MCMC done. Now processing the results...")
Expand Down
8 changes: 4 additions & 4 deletions demos/shbuild.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
gamma*I : gamma=0.5
"""

seir = jsmfsb.sh2Spn(seirSH)
seir = jsmfsb.shorthand_to_spn(seirSH)
stepSeir = seir.step_gillespie()
k0 = jax.random.key(42)
out = jsmfsb.simTs(k0, seir.m, 0, 40, 0.05, stepSeir)
out = jsmfsb.sim_time_series(k0, seir.m, 0, 40, 0.05, stepSeir)

import matplotlib.pyplot as plt
fig, axis = plt.subplots()
Expand All @@ -44,8 +44,8 @@
axis.legend(seir.n)
fig.savefig("shbuild.pdf")

# simSample
out = jsmfsb.simSample(k0, 10000, seir.m, 0, 10, stepSeir)
# sim_sample
out = jsmfsb.sim_sample(k0, 10000, seir.m, 0, 10, stepSeir)
import scipy as sp
print(sp.stats.describe(out))
fig, axes = plt.subplots(4,1)
Expand Down
6 changes: 3 additions & 3 deletions demos/stepCLE1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@
x0 = jnp.zeros((2,N))
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(N/2)].set(lv.m)
stepLv1d = lv.step_cle1D(jnp.array([0.6, 0.6]))
stepLv1d = lv.step_cle_1d(jnp.array([0.6, 0.6]))
k0 = jax.random.key(42)
x1 = stepLv1d(k0, x0, 0, 1)
print(x1)
out = jsmfsb.simTs1D(k0, x0, 0, T, 1, stepLv1d, True)
out = jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, stepLv1d, True)
#print(out)

fig, axis = plt.subplots()
for i in range(2):
axis.imshow(out[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"step_cle1D{i}.pdf")
fig.savefig(f"step_cle_1d{i}.pdf")


# eof
4 changes: 2 additions & 2 deletions demos/stepCLE2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
x0 = jnp.zeros((2,M,N))
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
stepLv2d = lv.step_cle2D(jnp.array([0.6, 0.6]))
stepLv2d = lv.step_cle_2d(jnp.array([0.6, 0.6]))
k0 = jax.random.key(42)
x1 = stepLv2d(k0, x0, 0, T)

fig, axis = plt.subplots()
for i in range(2):
axis.imshow(x1[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"step_cle2D{i}.pdf")
fig.savefig(f"step_cle_2d{i}.pdf")



Expand Down
4 changes: 2 additions & 2 deletions demos/stepCLE2Df.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
x0 = jnp.zeros((2,M,N))
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
stepLv2d = lv.step_cle2D(jnp.array([0.6, 0.6]), 0.1)
stepLv2d = lv.step_cle_2d(jnp.array([0.6, 0.6]), 0.1)
k0 = jax.random.key(42)
x1 = stepLv2d(k0, x0, 0, T)

fig, axis = plt.subplots()
for i in range(2):
axis.imshow(x1[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"step_cle2Df{i}.pdf")
fig.savefig(f"step_cle_2df{i}.pdf")



Expand Down
Loading

0 comments on commit 9e1ea0e

Please sign in to comment.