Skip to content

Commit

Permalink
allow external step method
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Mar 9, 2022
1 parent 1f37454 commit 5f44242
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 45 deletions.
50 changes: 6 additions & 44 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,7 @@
)
from pymc.model import Model, modelcontext
from pymc.parallel_sampling import Draw, _cpu_count
from pymc.step_methods import (
NUTS,
BinaryGibbsMetropolis,
BinaryMetropolis,
CategoricalGibbsMetropolis,
CompoundStep,
DEMetropolis,
HamiltonianMC,
Metropolis,
Slice,
)
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
from pymc.util import (
Expand All @@ -98,15 +88,6 @@
"draw",
]

STEP_METHODS = (
NUTS,
HamiltonianMC,
Metropolis,
BinaryMetropolis,
BinaryGibbsMetropolis,
Slice,
CategoricalGibbsMetropolis,
)
Step: TypeAlias = Union[BlockedStep, CompoundStep]

ArrayLike: TypeAlias = Union[np.ndarray, List[float]]
Expand Down Expand Up @@ -164,7 +145,7 @@ def instantiate_steppers(
return steps


def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None):
def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
"""Assign model variables to appropriate step methods.
Passing a specified model will auto-assign its constituent stochastic
Expand Down Expand Up @@ -197,6 +178,9 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
steps = []
assigned_vars = set()

if methods is None:
methods = pm.STEP_METHODS

if step is not None:
try:
steps += list(step)
Expand Down Expand Up @@ -481,29 +465,7 @@ def sample(
draws += tune

initial_points = None
if step is None and init is not None and all_continuous(model.value_vars):
try:
# By default, try to use NUTS
_log.info("Auto-assigning NUTS sampler...")
initial_points, step = init_nuts(
init=init,
chains=chains,
n_init=n_init,
model=model,
seeds=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
**kwargs,
)
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
# gradient computation failed
_log.info("Initializing NUTS failed. Falling back to elementwise auto-assignment.")
_log.debug("Exception in init nuts", exc_info=True)
step = assign_step_methods(model, step, step_kwargs=kwargs)
else:
step = assign_step_methods(model, step, step_kwargs=kwargs)
step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)

if isinstance(step, list):
step = CompoundStep(step)
Expand Down
10 changes: 10 additions & 0 deletions pymc/step_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@
RecursiveDAProposal,
)
from pymc.step_methods.slicer import Slice

STEP_METHODS = (
NUTS,
HamiltonianMC,
Metropolis,
BinaryMetropolis,
BinaryGibbsMetropolis,
Slice,
CategoricalGibbsMetropolis,
)
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def competence(var, has_grad):

dist = getattr(var.owner, "op", None)
if var.dtype in continuous_types and has_grad:
return Competence.IDEAL
return Competence.PREFERRED
return Competence.INCOMPATIBLE

def warnings(self):
Expand Down
22 changes: 22 additions & 0 deletions pymc/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from aesara.graph.op import Op
from numpy.testing import assert_array_almost_equal

import pymc as pm

from pymc.aesaraf import floatX
from pymc.data import Data
from pymc.distributions import (
Expand Down Expand Up @@ -741,6 +743,26 @@ def kill_grad(x):
steps = assign_step_methods(model, [])
assert isinstance(steps, Slice)

def test_modify_step_methods(self):
"""Test step methods can be changed"""
# remove nuts from step_methods
step_methods = list(pm.STEP_METHODS)
step_methods.remove(NUTS)
pm.STEP_METHODS = step_methods

with Model() as model:
Normal("x", 0, 1)
steps = assign_step_methods(model, [])
assert not isinstance(steps, NUTS)

# add back nuts
pm.STEP_METHODS = step_methods + [NUTS]

with Model() as model:
Normal("x", 0, 1)
steps = assign_step_methods(model, [])
assert isinstance(steps, NUTS)


class TestPopulationSamplers:

Expand Down

0 comments on commit 5f44242

Please sign in to comment.