Skip to content

Commit

Permalink
starting on syntax refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
darrenjw committed Nov 3, 2024
1 parent 71a139c commit 3835819
Show file tree
Hide file tree
Showing 37 changed files with 100 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Makefile

VERSION=1.0.1
VERSION=1.1.0

FORCE:
make install
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import jax
import jsmfsb

lvmod = jsmfsb.models.lv()
step = lvmod.stepGillespie()
step = lvmod.step_gillespie()
k0 = jax.random.key(42)
out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step)
assert(out.shape == (300, 2))
Expand All @@ -38,7 +38,7 @@ fig.savefig("lv.pdf")

The API for this package is very similar to that of the `smfsb` package. The main difference is that non-deterministic (random) functions have an extra argument (typically the first argument) that corresponds to a JAX random number key. See the [relevant section](https://jax.readthedocs.io/en/latest/random-numbers.html) of the JAX documentation for further information regarding random numbers in JAX code.

For further information, see the [demo directory](https://github.com/darrenjw/jax-smfsb/tree/main/demos) and the [API documentation](https://jax-smfsb.readthedocs.io/en/latest/index.html). Within the demos directory, see [shbuild.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/shbuild.py) for an example of how to specify a (SEIR epidemic) model using SBML-shorthand and [stepCLE2Df.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/stepCLE2Df.py) for a 2-d reaction-diffusion simulation. For parameter inference (from time course data), see [abc-cal.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc-cal.py) for ABC inference, [abcSmc.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abcSmc.py) for ABC-SMC inference and [pmmh.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/pmmh.py) for particle marginal Metropolis-Hastings MCMC-based inference. There are many other demos besides these.
For further information, see the [demo directory](https://github.com/darrenjw/jax-smfsb/tree/main/demos) and the [API documentation](https://jax-smfsb.readthedocs.io/en/latest/index.html). Within the demos directory, see [shbuild.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/shbuild.py) for an example of how to specify a (SEIR epidemic) model using SBML-shorthand and [step_cle2Df.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/step_cle2Df.py) for a 2-d reaction-diffusion simulation. For parameter inference (from time course data), see [abc-cal.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc-cal.py) for ABC inference, [abcSmc.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abcSmc.py) for ABC-SMC inference and [pmmh.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/pmmh.py) for particle marginal Metropolis-Hastings MCMC-based inference. There are many other demos besides these.

You can view this package on [GitHub](https://github.com/darrenjw/jax-smfsb) or [PyPI](https://pypi.org/project/jsmfsb/).

Expand Down
2 changes: 1 addition & 1 deletion demos/abc-cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def rpr(k):

def rmod(k, th):
return jsmfsb.simTs(k, jnp.array([50.0, 100.0]), 0, 30, 2,
jsmfsb.models.lv(th).stepCLE(0.1))
jsmfsb.models.lv(th).step_cle(0.1))

def ss1d(vec):
n = len(vec)
Expand Down
2 changes: 1 addition & 1 deletion demos/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def rpr(k):

def rmod(k, th):
return jsmfsb.simTs(k, jnp.array([50.0, 100.0]), 0, 30, 2,
jsmfsb.models.lv(th).stepCLE(0.1))
jsmfsb.models.lv(th).step_cle(0.1))

def sumStats(dat):
return dat
Expand Down
2 changes: 1 addition & 1 deletion demos/abcSmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def dpr(th):

def rmod(k, th):
return jsmfsb.simTs(k, jnp.array([50.0, 100]), 0, 30, 2,
jsmfsb.models.lv(jnp.exp(th)).stepCLE(0.1))
jsmfsb.models.lv(jnp.exp(th)).step_cle(0.1))

print("Pilot run...")

Expand Down
4 changes: 2 additions & 2 deletions demos/lv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import jsmfsb

lvmod = jsmfsb.models.lv()
step = lvmod.stepGillespie()
step = lvmod.step_gillespie()
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

stepC = lvmod.stepCLE(0.01)
stepC = lvmod.step_cle(0.01)
print(stepC(k0, lvmod.m, 0, 30))

out = jsmfsb.simSample(k0, 10000, lvmod.m, 0, 30, stepC)
Expand Down
2 changes: 1 addition & 1 deletion demos/m-bd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

bdmod = jsmfsb.models.bd()
step = bdmod.stepGillespie()
step = bdmod.step_gillespie()
k0 = jax.random.key(42)
print(step(k0, bdmod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-dimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

dimermod = jsmfsb.models.dimer()
step = dimermod.stepGillespie()
step = dimermod.step_gillespie()
k0 = jax.random.key(42)
print(step(k0, dimermod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-id.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

idmod = jsmfsb.models.id()
step = idmod.stepGillespie()
step = idmod.step_gillespie()
k0 = jax.random.key(42)
print(step(k0, idmod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv-cle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

lvmod = jsmfsb.models.lv()
step = lvmod.stepCLE()
step = lvmod.step_cle()
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv-euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

lvmod = jsmfsb.models.lv()
step = lvmod.stepEuler(0.001)
step = lvmod.step_euler(0.001)
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv-pts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

lvmod = jsmfsb.models.lv()
step = lvmod.stepPTS()
step = lvmod.step_poisson()
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-lv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

lvmod = jsmfsb.models.lv()
step = lvmod.stepGillespie()
step = lvmod.step_gillespie()
k0 = jax.random.key(42)
print(step(k0, lvmod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

mmmod = jsmfsb.models.mm()
step = mmmod.stepGillespie()
step = mmmod.step_gillespie()
k0 = jax.random.key(42)
print(step(k0, mmmod.m, 0, 30))

Expand Down
2 changes: 1 addition & 1 deletion demos/m-sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jsmfsb

sirmod = jsmfsb.models.sir()
step = sirmod.stepGillespie()
step = sirmod.step_gillespie()
k0 = jax.random.key(42)
print(step(k0, sirmod.m, 0, 30))

Expand Down
4 changes: 2 additions & 2 deletions demos/pfMLLik.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def simX(k, t0, th):
return jnp.array([jax.random.poisson(k1, 50),
jax.random.poisson(k2, 100)]).astype(jnp.float32)
def step(k, x, t, dt, th):
sf = jsmfsb.models.lv(th).stepCLE(0.1)
#sf = jsmfsb.models.lv(th).stepGillespie()
sf = jsmfsb.models.lv(th).step_cle(0.1)
#sf = jsmfsb.models.lv(th).step_gillespie()
return sf(k, x, t, dt)
mll = jsmfsb.pfMLLik(100, simX, 0, step, obsll, jsmfsb.data.LVnoise10)

Expand Down
4 changes: 2 additions & 2 deletions demos/pmmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def simX(k, t0, th):
return jnp.array([jax.random.poisson(k1, 50),
jax.random.poisson(k2, 100)]).astype(jnp.float32)
def step(k, x, t, dt, th):
#sf = jsmfsb.models.lv(th).stepGillespie()
sf = jsmfsb.models.lv(th).stepCLE(0.1)
#sf = jsmfsb.models.lv(th).step_gillespie()
sf = jsmfsb.models.lv(th).step_cle(0.1)
return sf(k, x, t, dt)
mll = jsmfsb.pfMLLik(100, simX, 0, step, obsll, jsmfsb.data.LVnoise10)

Expand Down
2 changes: 1 addition & 1 deletion demos/shbuild.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"""

seir = jsmfsb.sh2Spn(seirSH)
stepSeir = seir.stepGillespie()
stepSeir = seir.step_gillespie()
k0 = jax.random.key(42)
out = jsmfsb.simTs(k0, seir.m, 0, 40, 0.05, stepSeir)

Expand Down
4 changes: 2 additions & 2 deletions demos/stepCLE1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
x0 = jnp.zeros((2,N))
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(N/2)].set(lv.m)
stepLv1d = lv.stepCLE1D(jnp.array([0.6, 0.6]))
stepLv1d = lv.step_cle1D(jnp.array([0.6, 0.6]))
k0 = jax.random.key(42)
x1 = stepLv1d(k0, x0, 0, 1)
print(x1)
Expand All @@ -22,7 +22,7 @@
for i in range(2):
axis.imshow(out[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"stepCLE1D{i}.pdf")
fig.savefig(f"step_cle1D{i}.pdf")


# eof
4 changes: 2 additions & 2 deletions demos/stepCLE2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
x0 = jnp.zeros((2,M,N))
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
stepLv2d = lv.stepCLE2D(jnp.array([0.6, 0.6]))
stepLv2d = lv.step_cle2D(jnp.array([0.6, 0.6]))
k0 = jax.random.key(42)
x1 = stepLv2d(k0, x0, 0, T)

fig, axis = plt.subplots()
for i in range(2):
axis.imshow(x1[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"stepCLE2D{i}.pdf")
fig.savefig(f"step_cle2D{i}.pdf")



Expand Down
4 changes: 2 additions & 2 deletions demos/stepCLE2Df.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
x0 = jnp.zeros((2,M,N))
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
stepLv2d = lv.stepCLE2D(jnp.array([0.6, 0.6]), 0.1)
stepLv2d = lv.step_cle2D(jnp.array([0.6, 0.6]), 0.1)
k0 = jax.random.key(42)
x1 = stepLv2d(k0, x0, 0, T)

fig, axis = plt.subplots()
for i in range(2):
axis.imshow(x1[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"stepCLE2Df{i}.pdf")
fig.savefig(f"step_cle2Df{i}.pdf")



Expand Down
4 changes: 2 additions & 2 deletions demos/stepGillespie1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(N/2)].set(lv.m)
k0 = jax.random.key(42)
stepLv1d = lv.stepGillespie1D(jnp.array([0.6, 0.6]))
stepLv1d = lv.step_gillespie1D(jnp.array([0.6, 0.6]))
x1 = stepLv1d(k0, x0, 0, 1)
print(x1)
out = jsmfsb.simTs1D(k0, x0, 0, T, 1, stepLv1d, True)
Expand All @@ -22,7 +22,7 @@
for i in range(2):
axis.imshow(out[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"stepGillespie1D{i}.pdf")
fig.savefig(f"step_gillespie1D{i}.pdf")


# eof
4 changes: 2 additions & 2 deletions demos/stepGillespie2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
x0 = jnp.zeros((2,M,N))
lv = jsmfsb.models.lv()
x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
stepLv2d = lv.stepGillespie2D(jnp.array([0.6, 0.6]))
stepLv2d = lv.step_gillespie2D(jnp.array([0.6, 0.6]))
k0 = jax.random.key(42)
x1 = stepLv2d(k0, x0, 0, T)

fig, axis = plt.subplots()
for i in range(2):
axis.imshow(x1[i,:,:])
axis.set_title(lv.n[i])
fig.savefig(f"stepGillespie2D{i}.pdf")
fig.savefig(f"step_gillespie2D{i}.pdf")


# eof
2 changes: 1 addition & 1 deletion demos/time-lv-cle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time

lvmod = jsmfsb.models.lv()
step = lvmod.stepCLE(0.01)
step = lvmod.step_cle(0.01)
k0 = jax.random.key(42)

## Start timer
Expand Down
2 changes: 1 addition & 1 deletion demos/time-lv-gillespie.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time

lvmod = jsmfsb.models.lv()
step = lvmod.stepGillespie()
step = lvmod.step_gillespie()
k0 = jax.random.key(42)

## Start timer
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "jsmfsb"
version = "1.0.1"
version = "1.1.0"
authors = [
{ name="Darren Wilkinson", email="[email protected]" },
]
Expand Down
2 changes: 1 addition & 1 deletion src/jsmfsb/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def pfMLLik(n, simX0, t0, stepFun, dataLLik, data, debug=False):
>>> jax.random.poisson(k2, 100)]).astype(jnp.float32)
>>>
>>> def step(key, x, t, dt, th):
>>> sf = jsmfsb.models.lv(th).stepGillespie()
>>> sf = jsmfsb.models.lv(th).step_gillespie()
>>> return sf(key, x, t, dt)
>>>
>>> mll = jsmfsb.pfMLLik(80, simX, 0, step, obsll, jsmfsb.data.LVnoise10)
Expand Down
12 changes: 6 additions & 6 deletions src/jsmfsb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def bd(th=[1, 1.1]):
>>> import jsmfsb
>>> import jax
>>> bd = jsmfsb.models.bd()
>>> step = bd.stepGillespie()
>>> step = bd.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.simTs(k, bd.m, 0, 50, 0.1, step)
"""
Expand Down Expand Up @@ -53,7 +53,7 @@ def dimer(th=[0.00166, 0.2]):
>>> import jsmfsb
>>> import jax
>>> dimer = jsmfsb.models.dimer()
>>> step = dimer.stepGillespie()
>>> step = dimer.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.simTs(k, dimer.m, 0, 50, 0.1, step)
"""
Expand Down Expand Up @@ -82,7 +82,7 @@ def id(th=[1, 0.1]):
>>> import smfsb
>>> import jax
>>> id = jsmfsb.models.id()
>>> step = id.stepGillespie()
>>> step = id.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.simTs(k, id.m, 0, 50, 0.1, step)
"""
Expand Down Expand Up @@ -112,7 +112,7 @@ def lv(th=[1, 0.005, 0.6]):
>>> import jsmfsb
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> step = lv.stepGillespie()
>>> step = lv.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.simTs(k, lv.m, 0, 50, 0.1, step)
"""
Expand Down Expand Up @@ -142,7 +142,7 @@ def mm(th=[0.00166, 1e-4, 0.1]):
>>> import jsmfsb
>>> import jax
>>> mm = jsmfsb.models.mm()
>>> step = mm.stepGillespie()
>>> step = mm.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.simTs(k, mm.m, 0, 50, 0.1, step)
"""
Expand Down Expand Up @@ -173,7 +173,7 @@ def sir(th=[0.0015, 0.1]):
>>> import jsmfsb
>>> import jax
>>> sir = jsmfsb.models.sir()
>>> step = sir.stepGillespie()
>>> step = sir.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.simTs(k, sir.m, 0, 50, 0.1, step)
"""
Expand Down
Loading

0 comments on commit 3835819

Please sign in to comment.