Skip to content

Commit

Permalink
Speedup model compilation in slow sampling tests
Browse files Browse the repository at this point in the history
Add a specific FAST_COMPILE mode that skips canonicalization and specialization, while keeping the rewrites that are required by aeppl and pymc for proper sampling. This mode is used in tests that take a long time to compile and for which numerical accuracy is not important (e.g., because we care only about the shape of the draws or deterministics of observed values).
  • Loading branch information
ricardoV94 authored and twiecki committed Jan 20, 2022
1 parent 333f7f3 commit da7c5df
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 34 deletions.
18 changes: 17 additions & 1 deletion pymc/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import numpy.random as nr

from aesara.gradient import verify_grad as at_verify_grad
from aesara.graph.opt import in2out
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream

from pymc.aesaraf import at_rng, set_at_rng
from pymc.aesaraf import at_rng, local_check_parameter_to_ninf_switch, set_at_rng


class SeededTest:
Expand Down Expand Up @@ -132,3 +133,18 @@ def assert_random_state_equal(state1, state2):
np.testing.assert_array_equal(field1, field2)
else:
assert field1 == field2


# Compilation mode for tests where model compilation takes the bulk of the runtime
# AND where we don't care about posterior numerical or sampling stability (e.g., when
# all that matters is the shape of the draws or the deterministic values of observed
# data).
# DO NOT USE UNLESS YOU HAVE A GOOD REASON TO!
#
# It is built by chaining calls on Aesara's stock FAST_COMPILE mode; the order of the
# chained calls below is preserved as-is.
fast_unstable_sampling_mode = (
aesara.compile.mode.FAST_COMPILE
# Remove slow rewrite phases
.excluding("canonicalize", "specialize")
# Include necessary rewrites for proper logp handling
.including("remove_TransformedVariables").register(
# Register the parameter-check rewrite explicitly, wrapped with in2out.
# NOTE(review): presumably this is needed because the rewrite would otherwise be
# dropped together with the excluded phases, and -1 is its position in the
# rewrite pipeline -- confirm against aesara's Mode.register documentation.
(in2out(local_check_parameter_to_ninf_switch), -1)
)
)
38 changes: 22 additions & 16 deletions pymc/tests/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pymc.ode import DifferentialEquation
from pymc.ode.utils import augment_system
from pymc.tests.helpers import fast_unstable_sampling_mode

IS_FLOAT32 = aesara.config.floatX == "float32"
IS_WINDOWS = sys.platform == "win32"
Expand Down Expand Up @@ -291,11 +292,13 @@ def system(y, t, p):
sigma = pm.HalfCauchy("sigma", 1)
forward = ode_model(theta=[alpha], y0=[y0])
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)
idata = pm.sample(100, tune=0, chains=1)

assert idata.posterior["alpha"].shape == (1, 100)
assert idata.posterior["y0"].shape == (1, 100)
assert idata.posterior["sigma"].shape == (1, 100)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
idata = pm.sample(50, tune=0, chains=1)

assert idata.posterior["alpha"].shape == (1, 50)
assert idata.posterior["y0"].shape == (1, 50)
assert idata.posterior["sigma"].shape == (1, 50)

def test_scalar_ode_2_param(self):
"""Test running model for a scalar ODE with 2 parameters"""
Expand All @@ -321,12 +324,13 @@ def system(y, t, p):
forward = ode_model(theta=[alpha, beta], y0=[y0])
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)

idata = pm.sample(100, tune=0, chains=1)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
idata = pm.sample(50, tune=0, chains=1)

assert idata.posterior["alpha"].shape == (1, 100)
assert idata.posterior["beta"].shape == (1, 100)
assert idata.posterior["y0"].shape == (1, 100)
assert idata.posterior["sigma"].shape == (1, 100)
assert idata.posterior["alpha"].shape == (1, 50)
assert idata.posterior["beta"].shape == (1, 50)
assert idata.posterior["y0"].shape == (1, 50)
assert idata.posterior["sigma"].shape == (1, 50)

def test_vector_ode_1_param(self):
"""Test running model for a vector ODE with 1 parameter"""
Expand Down Expand Up @@ -362,10 +366,11 @@ def system(y, t, p):
forward = ode_model(theta=[R], y0=[0.99, 0.01])
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)

idata = pm.sample(100, tune=0, chains=1)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
idata = pm.sample(50, tune=0, chains=1)

assert idata.posterior["R"].shape == (1, 100)
assert idata.posterior["sigma"].shape == (1, 100, 2)
assert idata.posterior["R"].shape == (1, 50)
assert idata.posterior["sigma"].shape == (1, 50, 2)

def test_vector_ode_2_param(self):
"""Test running model for a vector ODE with 2 parameters"""
Expand Down Expand Up @@ -402,8 +407,9 @@ def system(y, t, p):
forward = ode_model(theta=[beta, gamma], y0=[0.99, 0.01])
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)

idata = pm.sample(100, tune=0, chains=1)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
idata = pm.sample(50, tune=0, chains=1)

assert idata.posterior["beta"].shape == (1, 100)
assert idata.posterior["gamma"].shape == (1, 100)
assert idata.posterior["sigma"].shape == (1, 100, 2)
assert idata.posterior["beta"].shape == (1, 50)
assert idata.posterior["gamma"].shape == (1, 50)
assert idata.posterior["sigma"].shape == (1, 50, 2)
49 changes: 32 additions & 17 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pymc.backends.base import MultiTrace
from pymc.backends.ndarray import NDArray
from pymc.exceptions import IncorrectArgumentsError, SamplingError
from pymc.tests.helpers import SeededTest
from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode
from pymc.tests.models import simple_init


Expand Down Expand Up @@ -665,7 +665,8 @@ def test_model_not_drawable_prior(self):
with model:
mu = pm.HalfFlat("sigma")
pm.Poisson("foo", mu=mu, observed=data)
idata = pm.sample(tune=1000)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
idata = pm.sample(tune=10, draws=40, chains=1)

with model:
with pytest.raises(NotImplementedError) as excinfo:
Expand Down Expand Up @@ -718,12 +719,15 @@ def test_deterministic_of_observed(self):
out_diff = in_1 + in_2
pm.Deterministic("out", out_diff)

trace = pm.sample(
100,
chains=nchains,
return_inferencedata=False,
compute_convergence_checks=False,
)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
trace = pm.sample(
tune=100,
draws=100,
chains=nchains,
step=pm.Metropolis(),
return_inferencedata=False,
compute_convergence_checks=False,
)

rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4

Expand Down Expand Up @@ -754,11 +758,14 @@ def test_deterministic_of_observed_modified_interface(self):
out_diff = in_1 + in_2
pm.Deterministic("out", out_diff)

trace = pm.sample(
100,
return_inferencedata=False,
compute_convergence_checks=False,
)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
trace = pm.sample(
tune=100,
draws=100,
step=pm.Metropolis(),
return_inferencedata=False,
compute_convergence_checks=False,
)
varnames = [v for v in trace.varnames if v != "out"]
ppc_trace = [
dict(zip(varnames, row)) for row in zip(*(trace.get_values(v) for v in varnames))
Expand All @@ -779,7 +786,10 @@ def test_variable_type(self):
mu = pm.HalfNormal("mu", 1)
a = pm.Normal("a", mu=mu, sigma=2, observed=np.array([1, 2]))
b = pm.Poisson("b", mu, observed=np.array([1, 2]))
trace = pm.sample(compute_convergence_checks=False, return_inferencedata=False)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
trace = pm.sample(
tune=10, draws=10, compute_convergence_checks=False, return_inferencedata=False
)

with model:
ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False, samples=1)
Expand Down Expand Up @@ -998,9 +1008,14 @@ def test_multivariate2(self):
with pm.Model() as dm_model:
probs = pm.Dirichlet("probs", a=np.ones(6))
obs = pm.Multinomial("obs", n=100, p=probs, observed=mn_data)
burned_trace = pm.sample(
20, tune=10, cores=1, return_inferencedata=False, compute_convergence_checks=False
)
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
burned_trace = pm.sample(
tune=10,
draws=20,
chains=1,
return_inferencedata=False,
compute_convergence_checks=False,
)
sim_priors = pm.sample_prior_predictive(
return_inferencedata=False, samples=20, model=dm_model
)
Expand Down

0 comments on commit da7c5df

Please sign in to comment.