From 9e1ea0e991c403e34fce0af258d25df469508850 Mon Sep 17 00:00:00 2001 From: Darren Wilkinson Date: Sun, 3 Nov 2024 16:28:15 +0000 Subject: [PATCH] rename methods to be more pythonic, in response to issue #8 --- README.md | 4 ++-- demos/abc-cal.py | 8 ++++---- demos/abc.py | 6 +++--- demos/abcRun.py | 6 +++--- demos/abcSmc.py | 12 ++++++------ demos/lv.py | 2 +- demos/m-bd.py | 2 +- demos/m-dimer.py | 2 +- demos/m-id.py | 2 +- demos/m-lv-cle.py | 2 +- demos/m-lv-euler.py | 2 +- demos/m-lv-pts.py | 2 +- demos/m-lv.py | 2 +- demos/m-mm.py | 2 +- demos/m-sir.py | 2 +- demos/metropolisHastings.py | 6 +++--- demos/pfMLLik.py | 4 ++-- demos/pmmh.py | 4 ++-- demos/shbuild.py | 8 ++++---- demos/stepCLE1D.py | 6 +++--- demos/stepCLE2D.py | 4 ++-- demos/stepCLE2Df.py | 4 ++-- demos/stepGillespie1D.py | 6 +++--- demos/stepGillespie2D.py | 4 ++-- demos/time-lv-cle.py | 4 ++-- demos/time-lv-gillespie.py | 2 +- src/jsmfsb/data.py | 6 +++--- src/jsmfsb/inference.py | 22 +++++++++++----------- src/jsmfsb/models.py | 12 ++++++------ src/jsmfsb/sim.py | 8 ++++---- src/jsmfsb/smfsbSbml.py | 16 ++++++++-------- src/jsmfsb/spatial.py | 20 ++++++++++---------- src/jsmfsb/spn.py | 34 +++++++++++++++++----------------- tests/test_inference.py | 14 +++++++------- tests/test_sim.py | 8 ++++---- tests/test_spatial.py | 28 ++++++++++++++-------------- 36 files changed, 138 insertions(+), 138 deletions(-) diff --git a/README.md b/README.md index e5d9c79..d73fa3e 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ import jsmfsb lvmod = jsmfsb.models.lv() step = lvmod.step_gillespie() k0 = jax.random.key(42) -out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step) +out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step) assert(out.shape == (300, 2)) ``` @@ -38,7 +38,7 @@ fig.savefig("lv.pdf") The API for this package is very similar to that of the `smfsb` package. The main difference is that non-deterministic (random) functions have an extra argument (typically the first argument) that corresponds to a JAX random number key. See the [relevant section](https://jax.readthedocs.io/en/latest/random-numbers.html) of the JAX documentation for further information regarding random numbers in JAX code. -For further information, see the [demo directory](https://github.com/darrenjw/jax-smfsb/tree/main/demos) and the [API documentation](https://jax-smfsb.readthedocs.io/en/latest/index.html). Within the demos directory, see [shbuild.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/shbuild.py) for an example of how to specify a (SEIR epidemic) model using SBML-shorthand and [step_cle2Df.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/step_cle2Df.py) for a 2-d reaction-diffusion simulation. For parameter inference (from time course data), see [abc-cal.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc-cal.py) for ABC inference, [abcSmc.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abcSmc.py) for ABC-SMC inference and [pmmh.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/pmmh.py) for particle marginal Metropolis-Hastings MCMC-based inference. There are many other demos besides these. +For further information, see the [demo directory](https://github.com/darrenjw/jax-smfsb/tree/main/demos) and the [API documentation](https://jax-smfsb.readthedocs.io/en/latest/index.html). 
Within the demos directory, see [shbuild.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/shbuild.py) for an example of how to specify a (SEIR epidemic) model using SBML-shorthand and [step_cle_2df.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/step_cle_2df.py) for a 2-d reaction-diffusion simulation. For parameter inference (from time course data), see [abc-cal.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc-cal.py) for ABC inference, [abc_smc.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/abc_smc.py) for ABC-SMC inference and [pmmh.py](https://github.com/darrenjw/jax-smfsb/tree/main/demos/pmmh.py) for particle marginal Metropolis-Hastings MCMC-based inference. There are many other demos besides these. You can view this package on [GitHub](https://github.com/darrenjw/jax-smfsb) or [PyPI](https://pypi.org/project/jsmfsb/). diff --git a/demos/abc-cal.py b/demos/abc-cal.py index d773a6e..9ed6215 100755 --- a/demos/abc-cal.py +++ b/demos/abc-cal.py @@ -10,7 +10,7 @@ print("ABC with calibrated summary stats") -data = jsmfsb.data.LVperfect[:,1:3] +data = jsmfsb.data.lv_perfect[:,1:3] def rpr(k): k1, k2, k3 = jax.random.split(k, 3) @@ -19,7 +19,7 @@ def rpr(k): jax.random.uniform(k3, minval=-4, maxval=2)])) def rmod(k, th): - return jsmfsb.simTs(k, jnp.array([50.0, 100.0]), 0, 30, 2, + return jsmfsb.sim_time_series(k, jnp.array([50.0, 100.0]), 0, 30, 2, jsmfsb.models.lv(th).step_cle(0.1)) def ss1d(vec): @@ -43,7 +43,7 @@ def ssi(ts): k0 = jax.random.key(42) k1, k2 = jax.random.split(k0) -p, d = jsmfsb.abcRun(k1, 100000, rpr, lambda k,th: ssi(rmod(k,th)), batch_size=10000) +p, d = jsmfsb.abc_run(k1, 100000, rpr, lambda k,th: ssi(rmod(k,th)), batch_size=10000) prmat = jnp.vstack(p) dmat = jnp.vstack(d) print(prmat.shape) @@ -67,7 +67,7 @@ def dist(ss): def rdis(k, th): return dist(sumStats(rmod(k, th))) -p, d = jsmfsb.abcRun(k2, 1000000, rpr, rdis, batch_size=100000, verb=False) +p, d = jsmfsb.abc_run(k2, 1000000, rpr, rdis, batch_size=100000, verb=False) q = jnp.nanquantile(d, 0.01) prmat = jnp.vstack(p) diff --git a/demos/abc.py b/demos/abc.py index f07e715..4b0f77b 100755 --- a/demos/abc.py +++ b/demos/abc.py @@ -9,7 +9,7 @@ print("ABC") -data = jsmfsb.data.LVperfect[:,1:3] +data = jsmfsb.data.lv_perfect[:,1:3] def rpr(k): k1, k2, k3 = jax.random.split(k, 3) @@ -18,7 +18,7 @@ def rpr(k): jax.random.uniform(k3, minval=-4, maxval=2)])) def rmod(k, th): - return jsmfsb.simTs(k, jnp.array([50.0, 100.0]), 0, 30, 2, + return jsmfsb.sim_time_series(k, jnp.array([50.0, 100.0]), 0, 30, 2, jsmfsb.models.lv(th).step_cle(0.1)) def sumStats(dat): @@ -34,7 +34,7 @@ def rdis(k, th): return dist(sumStats(rmod(k, th))) k0 = jax.random.key(42) -p, d = jsmfsb.abcRun(k0, 1000000, rpr, rdis, batch_size=100000, verb=False) +p, d = jsmfsb.abc_run(k0, 1000000, rpr, rdis, batch_size=100000, verb=False) q = jnp.nanquantile(d, 0.01) prmat = jnp.vstack(p) diff --git a/demos/abcRun.py b/demos/abcRun.py index 5a59ae4..b1b3fed 100755 --- a/demos/abcRun.py +++ b/demos/abcRun.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# abcRun.py +# abc_run.py import jsmfsb import jax @@ -29,7 +29,7 @@ def dist(ss): def rdis(k, th): return dist(sumStats(rmod(k, th))) -p, d = jsmfsb.abcRun(k2, 1000000, rpr, rdis) +p, d = jsmfsb.abc_run(k2, 1000000, rpr, rdis) q = jnp.quantile(d, 0.01) prmat = jnp.vstack(p) @@ -44,7 +44,7 @@ def rdis(k, th): axes[1, 1].plot(range(its), postmat[:,1], linewidth=0.1) axes[2, 0].hist(postmat[:,0], bins=30) axes[2, 1].hist(postmat[:,1], bins=30) -fig.savefig("abcRun.pdf") 
+fig.savefig("abc_run.pdf") # eof diff --git a/demos/abcSmc.py b/demos/abcSmc.py index c7efe44..6ad48a6 100755 --- a/demos/abcSmc.py +++ b/demos/abcSmc.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# abcSmc.py +# abc_smc.py import jsmfsb import jax @@ -8,7 +8,7 @@ print("ABC-SMC") -data = jsmfsb.data.LVperfect[:,1:3] +data = jsmfsb.data.lv_perfect[:,1:3] # Very vague prior @@ -44,7 +44,7 @@ def dpr(th): # Model def rmod(k, th): - return jsmfsb.simTs(k, jnp.array([50.0, 100]), 0, 30, 2, + return jsmfsb.sim_time_series(k, jnp.array([50.0, 100]), 0, 30, 2, jsmfsb.models.lv(jnp.exp(th)).step_cle(0.1)) print("Pilot run...") @@ -64,7 +64,7 @@ def ssi(ts): jnp.array([jnp.corrcoef(ts[:,0], ts[:,1])[0,1]]))) key = jax.random.key(42) -p, d = jsmfsb.abcRun(key, 20000, rpr, lambda k,th: ssi(rmod(k,th)), verb=False) +p, d = jsmfsb.abc_run(key, 20000, rpr, lambda k,th: ssi(rmod(k,th)), verb=False) prmat = jnp.vstack(p) dmat = jnp.vstack(d) print(prmat.shape) @@ -93,7 +93,7 @@ def rper(k,th): def dper(ne, ol): return jnp.sum(jsp.stats.norm.logpdf(ne, ol, 0.5)) -postmat = jsmfsb.abcSmc(key, 10000, rpr, dpr, rdis, rper, dper, +postmat = jsmfsb.abc_smc(key, 10000, rpr, dpr, rdis, rper, dper, factor=5, steps=8, verb=True) its, var = postmat.shape @@ -107,7 +107,7 @@ def dper(ne, ol): axes[1, 0].hist(postmat[:,0], bins=30) axes[1, 1].hist(postmat[:,1], bins=30) axes[1, 2].hist(postmat[:,2], bins=30) -fig.savefig("abcSmc.pdf") +fig.savefig("abc_smc.pdf") print("All done.") diff --git a/demos/lv.py b/demos/lv.py index e65311f..c77353e 100644 --- a/demos/lv.py +++ b/demos/lv.py @@ -19,7 +19,7 @@ stepC = lvmod.step_cle(0.01) print(stepC(k0, lvmod.m, 0, 30)) -out = jsmfsb.simSample(k0, 10000, lvmod.m, 0, 30, stepC) +out = jsmfsb.sim_sample(k0, 10000, lvmod.m, 0, 30, stepC) out = jnp.where(out > 1000, 1000, out) import scipy as sp print(sp.stats.describe(out)) diff --git a/demos/m-bd.py b/demos/m-bd.py index 4ac10b1..6540f79 100644 --- a/demos/m-bd.py +++ b/demos/m-bd.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, bdmod.m, 0, 30)) -out = jsmfsb.simTs(k0, bdmod.m, 0, 20, 0.1, step) +out = jsmfsb.sim_time_series(k0, bdmod.m, 0, 20, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-dimer.py b/demos/m-dimer.py index 910c052..fe2818d 100644 --- a/demos/m-dimer.py +++ b/demos/m-dimer.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, dimermod.m, 0, 30)) -out = jsmfsb.simTs(k0, dimermod.m, 0, 30, 0.1, step) +out = jsmfsb.sim_time_series(k0, dimermod.m, 0, 30, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-id.py b/demos/m-id.py index 7e6a547..1f439d5 100644 --- a/demos/m-id.py +++ b/demos/m-id.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, idmod.m, 0, 30)) -out = jsmfsb.simTs(k0, idmod.m, 0, 100, 0.1, step) +out = jsmfsb.sim_time_series(k0, idmod.m, 0, 100, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-lv-cle.py b/demos/m-lv-cle.py index 63be349..e9319b6 100644 --- a/demos/m-lv-cle.py +++ b/demos/m-lv-cle.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, lvmod.m, 0, 30)) -out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step) +out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-lv-euler.py b/demos/m-lv-euler.py index 4db4107..a4b0308 100644 --- a/demos/m-lv-euler.py +++ b/demos/m-lv-euler.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, lvmod.m, 0, 30)) -out = 
jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step) +out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-lv-pts.py b/demos/m-lv-pts.py index 51b5bb9..c076726 100644 --- a/demos/m-lv-pts.py +++ b/demos/m-lv-pts.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, lvmod.m, 0, 30)) -out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step) +out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-lv.py b/demos/m-lv.py index 20c9a9f..22d9fe6 100644 --- a/demos/m-lv.py +++ b/demos/m-lv.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, lvmod.m, 0, 30)) -out = jsmfsb.simTs(k0, lvmod.m, 0, 30, 0.1, step) +out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-mm.py b/demos/m-mm.py index 37e7245..83f1020 100644 --- a/demos/m-mm.py +++ b/demos/m-mm.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, mmmod.m, 0, 30)) -out = jsmfsb.simTs(k0, mmmod.m, 0, 100, 0.1, step) +out = jsmfsb.sim_time_series(k0, mmmod.m, 0, 100, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/m-sir.py b/demos/m-sir.py index 7d1c175..f254dea 100644 --- a/demos/m-sir.py +++ b/demos/m-sir.py @@ -15,7 +15,7 @@ k0 = jax.random.key(42) print(step(k0, sirmod.m, 0, 30)) -out = jsmfsb.simTs(k0, sirmod.m, 0, 100, 0.1, step) +out = jsmfsb.sim_time_series(k0, sirmod.m, 0, 100, 0.1, step) import matplotlib.pyplot as plt fig, axis = plt.subplots() diff --git a/demos/metropolisHastings.py b/demos/metropolisHastings.py index 7993f95..1cd97ac 100755 --- a/demos/metropolisHastings.py +++ b/demos/metropolisHastings.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# metropolisHastings.py +# metropolis_hastings.py import jsmfsb import jax @@ -11,7 +11,7 @@ data = jax.random.normal(k1, 250)*2 + 5 llik = lambda k,x: jnp.sum(jsp.stats.norm.logpdf(data, x[0], x[1])) prop = lambda k,x: jax.random.normal(k, 2)*0.1 + x -postmat = jsmfsb.metropolisHastings(k2, jnp.array([1.0,1.0]), llik, prop, verb=False) +postmat = jsmfsb.metropolis_hastings(k2, jnp.array([1.0,1.0]), llik, prop, verb=False) import matplotlib.pyplot as plt @@ -22,7 +22,7 @@ axes[1, 1].plot(range(10000), postmat[:,1], linewidth=0.1) axes[2, 0].hist(postmat[:,0], bins=30) axes[2, 1].hist(postmat[:,1], bins=30) -fig.savefig("metropolisHastings.pdf") +fig.savefig("metropolis_hastings.pdf") # eof diff --git a/demos/pfMLLik.py b/demos/pfMLLik.py index ce32a0f..af7885b 100755 --- a/demos/pfMLLik.py +++ b/demos/pfMLLik.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# pfMLLik.py +# pf_marginal_ll.py import jsmfsb import jax @@ -17,7 +17,7 @@ def step(k, x, t, dt, th): sf = jsmfsb.models.lv(th).step_cle(0.1) #sf = jsmfsb.models.lv(th).step_gillespie() return sf(k, x, t, dt) -mll = jsmfsb.pfMLLik(100, simX, 0, step, obsll, jsmfsb.data.LVnoise10) +mll = jsmfsb.pf_marginal_ll(100, simX, 0, step, obsll, jsmfsb.data.lv_noise_10) k = jax.random.split(jax.random.key(42), 5) print(mll(k[0], jnp.array([1, 0.005, 0.6]))) diff --git a/demos/pmmh.py b/demos/pmmh.py index b73e189..8a7483b 100755 --- a/demos/pmmh.py +++ b/demos/pmmh.py @@ -21,7 +21,7 @@ def step(k, x, t, dt, th): #sf = jsmfsb.models.lv(th).step_gillespie() sf = jsmfsb.models.lv(th).step_cle(0.1) return sf(k, x, t, dt) -mll = jsmfsb.pfMLLik(100, simX, 0, step, obsll, jsmfsb.data.LVnoise10) +mll = jsmfsb.pf_marginal_ll(100, simX, 0, step, obsll, 
jsmfsb.data.lv_noise_10) print("Test evals") k0 = jax.random.key(42) @@ -34,7 +34,7 @@ def step(k, x, t, dt, th): def prop(k, th, tune=0.01): return jnp.exp(jax.random.normal(k, shape=(3))*tune) * th -thmat = jsmfsb.metropolisHastings(k3, jnp.array([1, 0.005, 0.6]), mll, prop, +thmat = jsmfsb.metropolis_hastings(k3, jnp.array([1, 0.005, 0.6]), mll, prop, iters=5000, thin=1, verb=False) print("MCMC done. Now processing the results...") diff --git a/demos/shbuild.py b/demos/shbuild.py index 9631226..521b514 100644 --- a/demos/shbuild.py +++ b/demos/shbuild.py @@ -31,10 +31,10 @@ gamma*I : gamma=0.5 """ -seir = jsmfsb.sh2Spn(seirSH) +seir = jsmfsb.shorthand_to_spn(seirSH) stepSeir = seir.step_gillespie() k0 = jax.random.key(42) -out = jsmfsb.simTs(k0, seir.m, 0, 40, 0.05, stepSeir) +out = jsmfsb.sim_time_series(k0, seir.m, 0, 40, 0.05, stepSeir) import matplotlib.pyplot as plt fig, axis = plt.subplots() @@ -44,8 +44,8 @@ axis.legend(seir.n) fig.savefig("shbuild.pdf") -# simSample -out = jsmfsb.simSample(k0, 10000, seir.m, 0, 10, stepSeir) +# sim_sample +out = jsmfsb.sim_sample(k0, 10000, seir.m, 0, 10, stepSeir) import scipy as sp print(sp.stats.describe(out)) fig, axes = plt.subplots(4,1) diff --git a/demos/stepCLE1D.py b/demos/stepCLE1D.py index 62efa71..42927b8 100755 --- a/demos/stepCLE1D.py +++ b/demos/stepCLE1D.py @@ -11,18 +11,18 @@ x0 = jnp.zeros((2,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(N/2)].set(lv.m) -stepLv1d = lv.step_cle1D(jnp.array([0.6, 0.6])) +stepLv1d = lv.step_cle_1d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) x1 = stepLv1d(k0, x0, 0, 1) print(x1) -out = jsmfsb.simTs1D(k0, x0, 0, T, 1, stepLv1d, True) +out = jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, stepLv1d, True) #print(out) fig, axis = plt.subplots() for i in range(2): axis.imshow(out[i,:,:]) axis.set_title(lv.n[i]) - fig.savefig(f"step_cle1D{i}.pdf") + fig.savefig(f"step_cle_1d{i}.pdf") # eof diff --git a/demos/stepCLE2D.py b/demos/stepCLE2D.py index 170a9de..876bee9 100755 --- a/demos/stepCLE2D.py +++ b/demos/stepCLE2D.py @@ -12,7 +12,7 @@ x0 = jnp.zeros((2,M,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m) -stepLv2d = lv.step_cle2D(jnp.array([0.6, 0.6])) +stepLv2d = lv.step_cle_2d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) x1 = stepLv2d(k0, x0, 0, T) @@ -20,7 +20,7 @@ for i in range(2): axis.imshow(x1[i,:,:]) axis.set_title(lv.n[i]) - fig.savefig(f"step_cle2D{i}.pdf") + fig.savefig(f"step_cle_2d{i}.pdf") diff --git a/demos/stepCLE2Df.py b/demos/stepCLE2Df.py index de7a54b..2afc9ea 100755 --- a/demos/stepCLE2Df.py +++ b/demos/stepCLE2Df.py @@ -12,7 +12,7 @@ x0 = jnp.zeros((2,M,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m) -stepLv2d = lv.step_cle2D(jnp.array([0.6, 0.6]), 0.1) +stepLv2d = lv.step_cle_2d(jnp.array([0.6, 0.6]), 0.1) k0 = jax.random.key(42) x1 = stepLv2d(k0, x0, 0, T) @@ -20,7 +20,7 @@ for i in range(2): axis.imshow(x1[i,:,:]) axis.set_title(lv.n[i]) - fig.savefig(f"step_cle2Df{i}.pdf") + fig.savefig(f"step_cle_2df{i}.pdf") diff --git a/demos/stepGillespie1D.py b/demos/stepGillespie1D.py index abe8901..f34d558 100755 --- a/demos/stepGillespie1D.py +++ b/demos/stepGillespie1D.py @@ -12,17 +12,17 @@ lv = jsmfsb.models.lv() x0 = x0.at[:,int(N/2)].set(lv.m) k0 = jax.random.key(42) -stepLv1d = lv.step_gillespie1D(jnp.array([0.6, 0.6])) +stepLv1d = lv.step_gillespie_1d(jnp.array([0.6, 0.6])) x1 = stepLv1d(k0, x0, 0, 1) print(x1) -out = jsmfsb.simTs1D(k0, x0, 0, T, 1, stepLv1d, True) +out = jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, stepLv1d, True) 
#print(out) fig, axis = plt.subplots() for i in range(2): axis.imshow(out[i,:,:]) axis.set_title(lv.n[i]) - fig.savefig(f"step_gillespie1D{i}.pdf") + fig.savefig(f"step_gillespie_1d{i}.pdf") # eof diff --git a/demos/stepGillespie2D.py b/demos/stepGillespie2D.py index f33ba09..09432bf 100755 --- a/demos/stepGillespie2D.py +++ b/demos/stepGillespie2D.py @@ -12,7 +12,7 @@ x0 = jnp.zeros((2,M,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m) -stepLv2d = lv.step_gillespie2D(jnp.array([0.6, 0.6])) +stepLv2d = lv.step_gillespie_2d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) x1 = stepLv2d(k0, x0, 0, T) @@ -20,7 +20,7 @@ for i in range(2): axis.imshow(x1[i,:,:]) axis.set_title(lv.n[i]) - fig.savefig(f"step_gillespie2D{i}.pdf") + fig.savefig(f"step_gillespie_2d{i}.pdf") # eof diff --git a/demos/time-lv-cle.py b/demos/time-lv-cle.py index 4266df1..b511213 100644 --- a/demos/time-lv-cle.py +++ b/demos/time-lv-cle.py @@ -17,8 +17,8 @@ ## Start timer startTime = time.time() -out = jsmfsb.simSample(k0, 10000, lvmod.m, 0, 20, step) -#out = jsmfsb.simSampleMap(k0, 10000, lvmod.m, 0, 20, step) +out = jsmfsb.sim_sample(k0, 10000, lvmod.m, 0, 20, step) +#out = jsmfsb.sim_sampleMap(k0, 10000, lvmod.m, 0, 20, step) endTime = time.time() ## End timer elapsedTime = endTime - startTime diff --git a/demos/time-lv-gillespie.py b/demos/time-lv-gillespie.py index f1952c3..3ad4cca 100755 --- a/demos/time-lv-gillespie.py +++ b/demos/time-lv-gillespie.py @@ -17,7 +17,7 @@ ## Start timer startTime = time.time() -out = jsmfsb.simSample(k0, 10000, lvmod.m, 0, 20, step, batch_size=100) +out = jsmfsb.sim_sample(k0, 10000, lvmod.m, 0, 20, step, batch_size=100) endTime = time.time() ## End timer elapsedTime = endTime - startTime diff --git a/src/jsmfsb/data.py b/src/jsmfsb/data.py index df3537d..81e0332 100644 --- a/src/jsmfsb/data.py +++ b/src/jsmfsb/data.py @@ -5,7 +5,7 @@ # time, prey, predator -LVperfect = jnp.array([ +lv_perfect = jnp.array([ [ 0, 50, 100], [ 2, 145, 93], [ 4, 265, 248], @@ -26,7 +26,7 @@ # time, prey, predator -LVnoise10 = jnp.array([ +lv_noise_10 = jnp.array([ [ 0, 34.19903, 98.11945], [ 2, 156.54757, 86.52563], [ 4, 267.77267, 260.94433], @@ -47,7 +47,7 @@ # time, prey -LVpreyNoise10 = LVnoise10[:, [0,1]] +lv_prey_noise_10 = lv_noise_10[:, [0,1]] # eof diff --git a/src/jsmfsb/inference.py b/src/jsmfsb/inference.py index 8ad2ed1..196d138 100644 --- a/src/jsmfsb/inference.py +++ b/src/jsmfsb/inference.py @@ -8,7 +8,7 @@ # MCMC functions -def metropolisHastings(key, init, logLik, rprop, +def metropolis_hastings(key, init, logLik, rprop, ldprop=lambda n, o: 1, ldprior=lambda x: 1, iters=10000, thin=10, verb=True): """Run a Metropolis-Hastings MCMC algorithm for the parameters of a @@ -83,7 +83,7 @@ def metropolisHastings(key, init, logLik, rprop, >>> data = jax.random.normal(k1, 250)*2 + 5 >>> llik = lambda k, x: jnp.sum(jsp.stats.norm.logpdf(data, x[0], x[1])) >>> prop = lambda k, x: jax.random.normal(k, 2)*0.1 + x - >>> jsmfsb.metropolisHastings(k2, jnp.array([1.0,1.0]), llik, prop) + >>> jsmfsb.metropolis_hastings(k2, jnp.array([1.0,1.0]), llik, prop) """ def step(s, k): [x, ll] = s @@ -107,7 +107,7 @@ def itera(s, k): return states[0] -def pfMLLik(n, simX0, t0, stepFun, dataLLik, data, debug=False): +def pf_marginal_ll(n, simX0, t0, stepFun, dataLLik, data, debug=False): """Create a function for computing the log of an unbiased estimate of marginal likelihood of a time course data set @@ -168,7 +168,7 @@ def pfMLLik(n, simX0, t0, stepFun, dataLLik, data, debug=False): >>> sf = 
jsmfsb.models.lv(th).step_gillespie() >>> return sf(key, x, t, dt) >>> - >>> mll = jsmfsb.pfMLLik(80, simX, 0, step, obsll, jsmfsb.data.LVnoise10) + >>> mll = jsmfsb.pf_marginal_ll(80, simX, 0, step, obsll, jsmfsb.data.lv_noise_10) >>> k0 = jax.random.key(42) >>> mll(k0, jnp.array([1, 0.005, 0.6])) >>> mll(k0, jnp.array([2, 0.005, 0.6])) @@ -216,7 +216,7 @@ def prop(k, x): # ABC functions -def abcRun(key, n, rprior, rdist, batch_size=None, verb=False): +def abc_run(key, n, rprior, rdist, batch_size=None, verb=False): """Run a set of simulations initialised with parameters sampled from a given prior distribution, and compute statistics required for an ABC analaysis @@ -278,7 +278,7 @@ def abcRun(key, n, rprior, rdist, batch_size=None, verb=False): >>> def rdis(k, th): >>> return dist(sumStats(rmod(k, th))) >>> - >>> smfsb.abcRun(k2, 100, rpr, rdis) + >>> smfsb.abc_run(k2, 100, rpr, rdis) """ @jit def pair(k): @@ -295,11 +295,11 @@ def pair(k): # ABC-SMC functions -def abcSmcStep(key, dprior, priorSample, priorLW, rdist, rperturb, +def abc_smc_step(key, dprior, priorSample, priorLW, rdist, rperturb, dperturb, factor): """Carry out one step of an ABC-SMC algorithm - Not meant to be directly called by users. See abcSmc. + Not meant to be directly called by users. See abc_smc. """ k1, k2, k3 = jax.random.split(key, 3) n = priorSample.shape[0] @@ -326,7 +326,7 @@ def logWeight(th): return new, nlw -def abcSmc(key, N, rprior, dprior, rdist, rperturb, dperturb, +def abc_smc(key, N, rprior, dprior, rdist, rperturb, dperturb, factor=10, steps=15, verb=False, debug=False): """Run an ABC-SMC algorithm for infering the parameters of a forward model @@ -409,7 +409,7 @@ def abcSmc(key, N, rprior, dprior, rdist, rperturb, dperturb, >>> def rdis(k, th): >>> return dist(sumStats(rmod(k, th))) >>> - >>> jsmfsb.abcSmc(k2, 100, rpr, + >>> jsmfsb.abc_smc(k2, 100, rpr, >>> lambda x: jnp.sum(jnp.log(((x<3)&(x>-3))/6)), >>> rdis, >>> lambda k,x: jax.random.normal(k)*0.1 + x, @@ -424,7 +424,7 @@ def abcSmc(key, N, rprior, dprior, rdist, rperturb, dperturb, key, k1 = jax.random.split(key) if (verb): print(steps-i, end=' ', flush=True) - priorSample, priorLW = abcSmcStep(k1, dprior, priorSample, priorLW, + priorSample, priorLW = abc_smc_step(k1, dprior, priorSample, priorLW, rdist, rperturb, dperturb, factor) if (debug): print(priorSample.shape) diff --git a/src/jsmfsb/models.py b/src/jsmfsb/models.py index a8e8e01..de2dd03 100644 --- a/src/jsmfsb/models.py +++ b/src/jsmfsb/models.py @@ -27,7 +27,7 @@ def bd(th=[1, 1.1]): >>> bd = jsmfsb.models.bd() >>> step = bd.step_gillespie() >>> k = jax.random.key(42) - >>> jsmfsb.simTs(k, bd.m, 0, 50, 0.1, step) + >>> jsmfsb.sim_time_series(k, bd.m, 0, 50, 0.1, step) """ return Spn(["X"], ["Birth","Death"], [[1],[1]], [[2],[0]], lambda x, t: jnp.array([th[0]*x[0], th[1]*x[0]]), @@ -55,7 +55,7 @@ def dimer(th=[0.00166, 0.2]): >>> dimer = jsmfsb.models.dimer() >>> step = dimer.step_gillespie() >>> k = jax.random.key(42) - >>> jsmfsb.simTs(k, dimer.m, 0, 50, 0.1, step) + >>> jsmfsb.sim_time_series(k, dimer.m, 0, 50, 0.1, step) """ return Spn(["P", "P2"], ["Dim", "Diss"], [[2,0],[0,1]], [[0,1],[2,0]], lambda x, t: jnp.array([th[0]*x[0]*(x[0]-1)/2, th[1]*x[1]]), @@ -84,7 +84,7 @@ def id(th=[1, 0.1]): >>> id = jsmfsb.models.id() >>> step = id.step_gillespie() >>> k = jax.random.key(42) - >>> jsmfsb.simTs(k, id.m, 0, 50, 0.1, step) + >>> jsmfsb.sim_time_series(k, id.m, 0, 50, 0.1, step) """ return Spn(["X"], ["Immigration", "Death"], [[0],[1]], [[1],[0]], lambda x, t: 
jnp.array([th[0], th[1]*x[0]]), @@ -114,7 +114,7 @@ def lv(th=[1, 0.005, 0.6]): >>> lv = jsmfsb.models.lv() >>> step = lv.step_gillespie() >>> k = jax.random.key(42) - >>> jsmfsb.simTs(k, lv.m, 0, 50, 0.1, step) + >>> jsmfsb.sim_time_series(k, lv.m, 0, 50, 0.1, step) """ return Spn(["Prey", "Predator"], ["Prey rep", "Inter", "Pred death"], [[1,0],[1,1],[0,1]], [[2,0],[0,2],[0,0]], @@ -144,7 +144,7 @@ def mm(th=[0.00166, 1e-4, 0.1]): >>> mm = jsmfsb.models.mm() >>> step = mm.step_gillespie() >>> k = jax.random.key(42) - >>> jsmfsb.simTs(k, mm.m, 0, 50, 0.1, step) + >>> jsmfsb.sim_time_series(k, mm.m, 0, 50, 0.1, step) """ return Spn(["S", "E", "SE", "P"], ["Bind", "Unbind", "Produce"], [[1,1,0,0],[0,0,1,0],[0,0,1,0]], @@ -175,7 +175,7 @@ def sir(th=[0.0015, 0.1]): >>> sir = jsmfsb.models.sir() >>> step = sir.step_gillespie() >>> k = jax.random.key(42) - >>> jsmfsb.simTs(k, sir.m, 0, 50, 0.1, step) + >>> jsmfsb.sim_time_series(k, sir.m, 0, 50, 0.1, step) """ return Spn(["S", "I", "R"], ["S->I", "I->R"], [[1,1,0],[0,1,0]], [[0,2,0],[0,0,1]], lambda x, t: jnp.array([th[0]*x[0]*x[1], th[1]*x[1]]), diff --git a/src/jsmfsb/sim.py b/src/jsmfsb/sim.py index c9869cb..5a1e903 100644 --- a/src/jsmfsb/sim.py +++ b/src/jsmfsb/sim.py @@ -8,7 +8,7 @@ from jax import jit import jax.lax as jl -def simTs(key, x0, t0, tt, dt, stepFun): +def sim_time_series(key, x0, t0, tt, dt, stepFun): """Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model @@ -45,7 +45,7 @@ def simTs(key, x0, t0, tt, dt, stepFun): >>> import jsmfsb.models >>> lv = jsmfsb.models.lv() >>> stepLv = lv.step_gillespie() - >>> jsmfsb.simTs(jax.random.key(42), lv.m, 0, 100, 0.1, stepLv) + >>> jsmfsb.sim_time_series(jax.random.key(42), lv.m, 0, 100, 0.1, stepLv) """ n = int((tt-t0) // dt) + 1 keys = jax.random.split(key, n) @@ -59,7 +59,7 @@ def advance(state, key): return mat -def simSample(key, n, x0, t0, deltat, stepFun, batch_size=None): +def sim_sample(key, n, x0, t0, deltat, stepFun, batch_size=None): """Simulate a many realisations of a model at a given fixed time in the future given an initial time and state, using a function (closure) for advancing the state of the model @@ -98,7 +98,7 @@ def simSample(key, n, x0, t0, deltat, stepFun, batch_size=None): >>> import jsmfsb.models >>> lv = jsmfsb.models.lv() >>> stepLv = lv.step_gillespie() - >>> jsmfsb.simSample(jax.random.key(42), 10, lv.m, 0, 30, stepLv) + >>> jsmfsb.sim_sample(jax.random.key(42), 10, lv.m, 0, 30, stepLv) """ u = len(x0) keys = jax.random.split(key, n) diff --git a/src/jsmfsb/smfsbSbml.py b/src/jsmfsb/smfsbSbml.py index 09cac40..25e3034 100644 --- a/src/jsmfsb/smfsbSbml.py +++ b/src/jsmfsb/smfsbSbml.py @@ -9,7 +9,7 @@ from sbmlsh import mod2sbml -def mod2Spn(filename, verb=False): +def mod_to_spn(filename, verb=False): """Convert an SBML-shorthand model into a Spn object Read a file containing a model in SBML-shorthand and convert into @@ -43,11 +43,11 @@ def mod2Spn(filename, verb=False): if (m == None): sys.stderr.write("Error: can't extract SBML model\n") sys.exit(1) - return(model2Spn(m, verb)) + return(model_to_spn(m, verb)) -def sh2Spn(shString, verb=False): +def shorthand_to_spn(shString, verb=False): """Convert an SBML-shorthand model string into a Spn object Parse a string containing a model in SBML-shorthand and convert into @@ -79,11 +79,11 @@ def sh2Spn(shString, verb=False): if (m == None): sys.stderr.write("Error: couldn't parse the shorthand string\n") sys.exit(1) - return(model2Spn(m, verb)) + 
return(model_to_spn(m, verb)) -def file2Spn(filename, verb=False): +def file_to_spn(filename, verb=False): """Convert an SBML model into a Spn object Read a file containing a model in SBML and convert into @@ -111,11 +111,11 @@ def file2Spn(filename, verb=False): if (m == None): sys.stderr.write("Can't parse SBML file: "+filename+"\n") sys.exit(1) - return(model2Spn(m, verb)) + return(model_to_spn(m, verb)) -def model2Spn(m, verb=False): +def model_to_spn(m, verb=False): """Convert a libSBML model into a Spn object Convert a libSBML model into a Spn object for simulation and analysis. @@ -137,7 +137,7 @@ def model2Spn(m, verb=False): >>> import libsbml >>> d = libsbml.readSBML("myModel.xml") >>> m = d.getModel() - >>> myMod = smfsb.model2Spn(m) + >>> myMod = smfsb.model_to_spn(m) >>> step = myMod.step_gillespie() """ # Species and initial amounts diff --git a/src/jsmfsb/spatial.py b/src/jsmfsb/spatial.py index 19e6f03..e15379d 100644 --- a/src/jsmfsb/spatial.py +++ b/src/jsmfsb/spatial.py @@ -6,14 +6,14 @@ import jax.numpy as jnp import jax.lax as jl -def simTs1D(key, x0, t0, tt, dt, stepFun, verb=False): +def sim_time_series_1d(key, x0, t0, tt, dt, stepFun, verb=False): """Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model This function simulates single realisation of a model on a 1D regular spatial grid and regular grid of times using a function (closure) for advancing the state of the model, such as created by - `step_gillespie1D`. + `step_gillespie_1d`. Parameters ---------- @@ -33,7 +33,7 @@ def simTs1D(key, x0, t0, tt, dt, stepFun, verb=False): accuracy of the simulation process. stepFun : function A function (closure) for advancing the state of the process, - such as produced by `step_gillespie1D`. + such as produced by `step_gillespie_1d`. verb : boolean Output progress to the console (this function can be very slow). @@ -48,13 +48,13 @@ def simTs1D(key, x0, t0, tt, dt, stepFun, verb=False): >>> import jax >>> import jax.numpy as jnp >>> lv = jsmfsb.models.lv() - >>> stepLv1d = lv.step_gillespie1D(jnp.array([0.6,0.6])) + >>> stepLv1d = lv.step_gillespie_1d(jnp.array([0.6,0.6])) >>> N = 10 >>> T = 5 >>> x0 = jnp.zeros((2,N)) >>> x0 = x0.at[:,int(N/2)].set(lv.m) >>> k0 = jax.random.key(42) - >>> jsmfsb.simTs1D(k0, x0, 0, T, 1, stepLv1d, True) + >>> jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, stepLv1d, True) """ N = int((tt - t0)//dt + 1) u, n = x0.shape @@ -71,14 +71,14 @@ def advance(state, key): return jnp.moveaxis(arr, 0, 2) -def simTs2D(key, x0, t0, tt, dt, stepFun, verb=False): +def sim_time_series_2d(key, x0, t0, tt, dt, stepFun, verb=False): """Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model This function simulates single realisation of a model on a 2D regular spatial grid and regular grid of times using a function (closure) for advancing the state of the model, such as created by - `step_gillespie2D`. + `step_gillespie_2d`. Parameters ---------- @@ -98,7 +98,7 @@ def simTs2D(key, x0, t0, tt, dt, stepFun, verb=False): accuracy of the simulation process. stepFun : function A function (closure) for advancing the state of the process, - such as produced by `step_gillespie2D`. + such as produced by `step_gillespie_2d`. verb : boolean Output progress to the console (this function can be very slow). 
@@ -113,14 +113,14 @@ def simTs2D(key, x0, t0, tt, dt, stepFun, verb=False): >>> import jax >>> import jax.numpy as jnp >>> lv = jsmfsb.models.lv() - >>> stepLv2d = lv.step_gillespie2D(jnp.array([0.6,0.6])) + >>> stepLv2d = lv.step_gillespie_2d(jnp.array([0.6,0.6])) >>> M = 10 >>> N = 15 >>> T = 5 >>> x0 = jnp.zeros((2,M,N)) >>> x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m) >>> k0 = jax.random.key(42) - >>> jsmfsb.simTs2D(k0, x0, 0, T, 1, stepLv2d, True) + >>> jsmfsb.sim_time_series_2d(k0, x0, 0, T, 1, stepLv2d, True) """ N = int((tt - t0)//dt + 1) u, m, n = x0.shape diff --git a/src/jsmfsb/spn.py b/src/jsmfsb/spn.py index e6d0662..47ee9fa 100644 --- a/src/jsmfsb/spn.py +++ b/src/jsmfsb/spn.py @@ -51,7 +51,7 @@ def __init__(self, n, t, pre, post, h, m): lambda x, t: jnp.array([0.3*x[0]*x[1]/200, 0.1*x[1]]), [197, 3, 0]) >>> stepSir = sir.step_gillespie() - >>> jsmfsb.simSample(jax.random.key(42), 10, sir.m, 0, 20, stepSir) + >>> jsmfsb.sim_sample(jax.random.key(42), 10, sir.m, 0, 20, stepSir) """ self.n = n # species names self.t = t # reaction names @@ -77,7 +77,7 @@ def step_gillespie(self, minHaz=1e-10, maxHaz=1e07): This method returns a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as - `simTs`) for simulating realisations of SPN models. + `sim_time_series`) for simulating realisations of SPN models. Parameters ---------- @@ -137,7 +137,7 @@ def step_poisson(self, dt = 0.01): This method returns a function for advancing the state of an SPN model using a simple approximate Poisson time stepping method. The resulting function (closure) can be used in conjunction with other - functions (such as ‘simTs’) for simulating realisations of SPN + functions (such as ‘sim_time_series’) for simulating realisations of SPN models. Parameters @@ -196,7 +196,7 @@ def step_euler(self, dt = 0.01): This method returns a function for advancing the state of an SPN model using a simple continuous deterministic Euler integration method. The resulting function (closure) can be used in - conjunction with other functions (such as ‘simTs’) for simulating + conjunction with other functions (such as ‘sim_time_series’) for simulating realisations of SPN models. Parameters @@ -252,7 +252,7 @@ def step_cle(self, dt = 0.01): model using a simple Euler-Maruyama integration method method for the chemical Langevin equation form of the model.The resulting function (closure) can be used in - conjunction with other functions (such as `simTs`) for simulating + conjunction with other functions (such as `sim_time_series`) for simulating realisations of SPN models. Parameters @@ -304,14 +304,14 @@ def step(key, x0, t0, deltat): # spatial simulation functions, from chapter 9 - def step_gillespie1D(self, d, minHaz=1e-10, maxHaz=1e07): + def step_gillespie_1d(self, d, minHaz=1e-10, maxHaz=1e07): """Create a function for advancing the state of an SPN by using the Gillespie algorithm on a 1D regular grid This method creates a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as - `simTs1D`) for simulating realisations of SPN models in space and + `sim_time_series_1d`) for simulating realisations of SPN models in space and time. 
Parameters @@ -347,7 +347,7 @@ def step_gillespie1D(self, d, minHaz=1e-10, maxHaz=1e07): >>> import jax >>> import jax.numpy as jnp >>> lv = jsmfsb.models.lv() - >>> stepLv1d = lv.step_gillespie1D(jnp.array([0.6, 0.6])) + >>> stepLv1d = lv.step_gillespie_1d(jnp.array([0.6, 0.6])) >>> N = 20 >>> x0 = jnp.zeros((2,N)) >>> x0 = x0.at[:,int(N/2)].set(lv.m) @@ -402,14 +402,14 @@ def step(key, x0, t0, deltat): return step - def step_gillespie2D(self, d, minHaz=1e-10, maxHaz=1e07): + def step_gillespie_2d(self, d, minHaz=1e-10, maxHaz=1e07): """Create a function for advancing the state of an SPN by using the Gillespie algorithm on a 2D regular grid This method creates a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as - `simTs2D`) for simulating realisations of SPN models in space and + `sim_time_series_2d`) for simulating realisations of SPN models in space and time. Parameters @@ -445,7 +445,7 @@ def step_gillespie2D(self, d, minHaz=1e-10, maxHaz=1e07): >>> import jax >>> import jax.numpy as jnp >>> lv = jsmfsb.models.lv() - >>> stepLv2d = lv.step_gillespie2D(jnp.array([0.6, 0.6])) + >>> stepLv2d = lv.step_gillespie_2d(jnp.array([0.6, 0.6])) >>> N = 20 >>> x0 = jnp.zeros((2, N, N)) >>> x0 = x0.at[:, int(N/2), int(N/2)].set(lv.m) @@ -509,14 +509,14 @@ def step(key, x0, t0, deltat): return step - def step_cle1D(self, d, dt = 0.01): + def step_cle_1d(self, d, dt = 0.01): """Create a function for advancing the state of an SPN by using a simple Euler-Maruyama discretisation of the CLE on a 1D regular grid This method creates a function for advancing the state of an SPN model using a simple Euler-Maruyama discretisation of the CLE on a 1D regular grid. The resulting function (closure) can be used in - conjunction with other functions (such as `simTs1D`) for + conjunction with other functions (such as `sim_time_series_1d`) for simulating realisations of SPN models in space and time. Parameters @@ -550,7 +550,7 @@ def step_cle1D(self, d, dt = 0.01): >>> import jax >>> import jax.numpy as jnp >>> lv = jsmfsb.models.lv() - >>> stepLv1d = lv.step_cle1D(jnp.array([0.6,0.6])) + >>> stepLv1d = lv.step_cle_1d(jnp.array([0.6,0.6])) >>> N = 20 >>> x0 = jnp.zeros((2,N)) >>> x0 = x0.at[:,int(N/2)].set(lv.m) @@ -596,14 +596,14 @@ def advance(state, key): return step - def step_cle2D(self, d, dt = 0.01): + def step_cle_2d(self, d, dt = 0.01): """Create a function for advancing the state of an SPN by using a simple Euler-Maruyama discretisation of the CLE on a 2D regular grid This method creates a function for advancing the state of an SPN model using a simple Euler-Maruyama discretisation of the CLE on a 2D regular grid. The resulting function (closure) can be used in - conjunction with other functions (such as `simTs2D`) for + conjunction with other functions (such as `sim_time_series_2d`) for simulating realisations of SPN models in space and time. 
Parameters @@ -636,7 +636,7 @@ def step_cle2D(self, d, dt = 0.01): >>> import jax >>> import jax.numpy as jnp >>> lv = jsmfsb.models.lv() - >>> stepLv2d = lv.step_cle2D(jnp.array([0.6,0.6])) + >>> stepLv2d = lv.step_cle_2d(jnp.array([0.6,0.6])) >>> M = 15 >>> N = 20 >>> x0 = jnp.zeros((2,M,N)) diff --git a/tests/test_inference.py b/tests/test_inference.py index 1b52d6e..4014b9b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -8,12 +8,12 @@ -def test_metropolisHastings(): +def test_metropolis_hastings(): key = jax.random.key(42) data = jax.random.normal(key, 250)*2 + 5 llik = lambda k,x: jnp.sum(jsp.stats.norm.logpdf(data, x[0], x[1])) prop = lambda k,x: jax.random.normal(k, 2)*0.1 + x - out = jsmfsb.metropolisHastings(key, jnp.array([1.0,1.0]), llik, prop, + out = jsmfsb.metropolis_hastings(key, jnp.array([1.0,1.0]), llik, prop, iters=1000, thin=2, verb=False) assert(out.shape == (1000, 2)) @@ -28,12 +28,12 @@ def simX(k, t0, th): def step(k, x, t, dt, th): sf = jsmfsb.models.lv(th).step_cle() return sf(k, x, t, dt) - mll = jsmfsb.pfMLLik(50, simX, 0, step, obsll, jsmfsb.data.LVnoise10) + mll = jsmfsb.pf_marginal_ll(50, simX, 0, step, obsll, jsmfsb.data.lv_noise_10) k = jax.random.key(42) assert (mll(k, jnp.array([1, 0.005, 0.6])) > mll(k, jnp.array([2, 0.005, 0.6]))) -def test_abcRun(): +def test_abc_run(): k0 = jax.random.key(42) k1, k2 = jax.random.split(k0) data = jax.random.normal(k1, 250)*2 + 5 @@ -49,7 +49,7 @@ def dist(ss): return jnp.sqrt(jnp.sum(diff*diff)) def rdis(k, th): return dist(sumStats(rmod(k, th))) - p, d = jsmfsb.abcRun(k2, 100, rpr, rdis) + p, d = jsmfsb.abc_run(k2, 100, rpr, rdis) assert(len(p) == 100) assert(len(d) == 100) @@ -73,7 +73,7 @@ def rdis(k, th): N = 100 keys = jax.random.split(k2, N) samples = jax.lax.map(rpr, keys) - th, lw = jsmfsb.abcSmcStep(k0, + th, lw = jsmfsb.abc_smc_step(k0, lambda x: jnp.log(jnp.sum(((x<3)&(x>-3))/6)), samples, jnp.zeros(N) + jnp.log(1/N), @@ -102,7 +102,7 @@ def dist(ss): def rdis(k,th): return dist(sumStats(rmod(k,th))) N = 100 - post = jsmfsb.abcSmc(k2, N, rpr, lambda x: jnp.sum(jnp.log(((x<3)&(x>-3))/6)), + post = jsmfsb.abc_smc(k2, N, rpr, lambda x: jnp.sum(jnp.log(((x<3)&(x>-3))/6)), rdis, lambda k,x: jax.random.normal(k)*0.1 + x, lambda x,y: jnp.sum(jsp.stats.norm.logpdf(y, x, 0.1))) assert(post.shape == (N, 2)) diff --git a/tests/test_sim.py b/tests/test_sim.py index d0e04ae..d21a8aa 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -8,22 +8,22 @@ def test_simts(): lv = jsmfsb.models.lv() step = lv.step_gillespie() k0 = jax.random.key(42) - out = jsmfsb.simTs(k0, lv.m, 0, 10, 0.1, step) + out = jsmfsb.sim_time_series(k0, lv.m, 0, 10, 0.1, step) assert(out.shape == (100, 2)) def test_simsample(): lv = jsmfsb.models.lv() step = lv.step_gillespie() k0 = jax.random.key(42) - out = jsmfsb.simSample(k0, 20, lv.m, 0, 10, step) + out = jsmfsb.sim_sample(k0, 20, lv.m, 0, 10, step) assert(out.shape == (20, 2)) def test_simsamples(): lv = jsmfsb.models.lv() step = lv.step_gillespie() k0 = jax.random.key(42) - out = jsmfsb.simSample(k0, 20, lv.m, 0, 10, step) - outB = jsmfsb.simSample(k0, 20, lv.m, 0, 10, step, batch_size=5) + out = jsmfsb.sim_sample(k0, 20, lv.m, 0, 10, step) + outB = jsmfsb.sim_sample(k0, 20, lv.m, 0, 10, step, batch_size=5) assert(jnp.all(out == outB)) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index 1049950..cef9cf6 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -7,66 +7,66 @@ import matplotlib.pyplot as plt import jsmfsb.models -def 
test_step_gillespie1D(): +def test_step_gillespie_1d(): N=20 x0 = jnp.zeros((2,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(N/2)].set(lv.m) - stepLv1d = lv.step_gillespie1D(jnp.array([0.6, 0.6])) + stepLv1d = lv.step_gillespie_1d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) x1 = stepLv1d(k0, x0, 0, 1) assert(x1.shape == (2,N)) -def test_simTs1D(): +def test_sim_time_series_1d(): N=8 T=6 x0 = jnp.zeros((2,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(N/2)].set(lv.m) - stepLv1d = lv.step_gillespie1D(jnp.array([0.6, 0.6])) + stepLv1d = lv.step_gillespie_1d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) - out = jsmfsb.simTs1D(k0, x0, 0, T, 1, stepLv1d) + out = jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, stepLv1d) assert(out.shape == (2, N, T+1)) -def test_step_gillespie2D(): +def test_step_gillespie_2d(): M=16 N=20 x0 = jnp.zeros((2,M,N)) lv = jsmfsb.models.lv() x0 = x0.at[:, int(M/2), int(N/2)].set(lv.m) - stepLv2d = lv.step_gillespie2D(jnp.array([0.6, 0.6])) + stepLv2d = lv.step_gillespie_2d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) x1 = stepLv2d(k0, x0, 0, 1) assert(x1.shape == (2, M, N)) -def test_simTs2D(): +def test_sim_time_series_2d(): M=16 N=20 x0 = jnp.zeros((2,M,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m) - stepLv2d = lv.step_gillespie2D(jnp.array([0.6, 0.6])) + stepLv2d = lv.step_gillespie_2d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) - out = jsmfsb.simTs2D(k0, x0, 0, 5, 1, stepLv2d) + out = jsmfsb.sim_time_series_2d(k0, x0, 0, 5, 1, stepLv2d) assert(out.shape == (2, M, N, 6)) -def test_step_cle1D(): +def test_step_cle_1d(): N=20 x0 = jnp.zeros((2,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(N/2)].set(lv.m) - stepLv1d = lv.step_cle1D(jnp.array([0.6, 0.6])) + stepLv1d = lv.step_cle_1d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) x1 = stepLv1d(k0, x0, 0, 1) assert(x1.shape == (2, N)) -def test_step_cle2D(): +def test_step_cle_2d(): M=16 N=20 x0 = jnp.zeros((2,M,N)) lv = jsmfsb.models.lv() x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m) - stepLv2d = lv.step_cle2D(jnp.array([0.6, 0.6])) + stepLv2d = lv.step_cle_2d(jnp.array([0.6, 0.6])) k0 = jax.random.key(42) x1 = stepLv2d(k0, x0, 0, 1) assert(x1.shape == (2, M, N))
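
For quick reference, a minimal sketch of the renamed simulation API introduced by this patch. It mirrors the updated README example and the `sim_sample` call from the test suite; the model, step function, key handling, and expected shapes are taken directly from the patched sources, and nothing beyond those is assumed.

```python
# Sketch of the snake_case API after this patch (mirrors the patched README
# example and tests/test_sim.py); assumes a jsmfsb version containing the rename.
import jax
import jsmfsb

lvmod = jsmfsb.models.lv()      # Lotka-Volterra model
step = lvmod.step_gillespie()   # exact simulation step function (closure)
k0 = jax.random.key(42)         # JAX random key, passed as the first argument

# simTs -> sim_time_series: one realisation on a regular grid of times
out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step)
assert out.shape == (300, 2)

# simSample -> sim_sample: many realisations of the state at a fixed future time
samples = jsmfsb.sim_sample(k0, 1000, lvmod.m, 0, 30, step)
assert samples.shape == (1000, 2)
```

The same pattern applies to the inference helpers renamed in this patch: `abcRun` becomes `abc_run`, `abcSmc` becomes `abc_smc`, `metropolisHastings` becomes `metropolis_hastings`, and `pfMLLik` becomes `pf_marginal_ll`.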