Skip to content

Commit

Permalink
make methods argument optional
Browse files Browse the repository at this point in the history
remove unused imports

add test modify step methods

allow external step method

resolve conflict

allow external step method
  • Loading branch information
aloctavodia committed Mar 9, 2022
1 parent 9af48dd commit 4635bf3
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 45 deletions.
50 changes: 6 additions & 44 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,7 @@
)
from pymc.model import Model, modelcontext
from pymc.parallel_sampling import Draw, _cpu_count
from pymc.step_methods import (
NUTS,
BinaryGibbsMetropolis,
BinaryMetropolis,
CategoricalGibbsMetropolis,
CompoundStep,
DEMetropolis,
HamiltonianMC,
Metropolis,
Slice,
)
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
from pymc.util import (
Expand All @@ -98,15 +88,6 @@
"draw",
]

STEP_METHODS = (
NUTS,
HamiltonianMC,
Metropolis,
BinaryMetropolis,
BinaryGibbsMetropolis,
Slice,
CategoricalGibbsMetropolis,
)
Step: TypeAlias = Union[BlockedStep, CompoundStep]

ArrayLike: TypeAlias = Union[np.ndarray, List[float]]
Expand Down Expand Up @@ -164,7 +145,7 @@ def instantiate_steppers(
return steps


def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None):
def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
"""Assign model variables to appropriate step methods.
Passing a specified model will auto-assign its constituent stochastic
Expand Down Expand Up @@ -197,6 +178,9 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
steps = []
assigned_vars = set()

if methods is None:
methods = pm.STEP_METHODS

if step is not None:
try:
steps += list(step)
Expand Down Expand Up @@ -481,29 +465,7 @@ def sample(
draws += tune

initial_points = None
if step is None and init is not None and all_continuous(model.value_vars):
try:
# By default, try to use NUTS
_log.info("Auto-assigning NUTS sampler...")
initial_points, step = init_nuts(
init=init,
chains=chains,
n_init=n_init,
model=model,
seeds=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
**kwargs,
)
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
# gradient computation failed
_log.info("Initializing NUTS failed. Falling back to elementwise auto-assignment.")
_log.debug("Exception in init nuts", exc_info=True)
step = assign_step_methods(model, step, step_kwargs=kwargs)
else:
step = assign_step_methods(model, step, step_kwargs=kwargs)
step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)

if isinstance(step, list):
step = CompoundStep(step)
Expand Down
10 changes: 10 additions & 0 deletions pymc/step_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@
RecursiveDAProposal,
)
from pymc.step_methods.slicer import Slice

STEP_METHODS = (
NUTS,
HamiltonianMC,
Metropolis,
BinaryMetropolis,
BinaryGibbsMetropolis,
Slice,
CategoricalGibbsMetropolis,
)
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def competence(var, has_grad):

dist = getattr(var.owner, "op", None)
if var.dtype in continuous_types and has_grad:
return Competence.IDEAL
return Competence.PREFERRED
return Competence.INCOMPATIBLE

def warnings(self):
Expand Down
22 changes: 22 additions & 0 deletions pymc/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from aesara.graph.op import Op
from numpy.testing import assert_array_almost_equal

import pymc as pm

from pymc.aesaraf import floatX
from pymc.data import Data
from pymc.distributions import (
Expand Down Expand Up @@ -741,6 +743,26 @@ def kill_grad(x):
steps = assign_step_methods(model, [])
assert isinstance(steps, Slice)

def test_modify_step_methods(self):
"""Test step methods can be changed"""
# remove nuts from step_methods
step_methods = list(pm.STEP_METHODS)
step_methods.remove(NUTS)
pm.STEP_METHODS = step_methods

with Model() as model:
Normal("x", 0, 1)
steps = assign_step_methods(model, [])
assert not isinstance(steps, NUTS)

# add back nuts
pm.STEP_METHODS = step_methods + [NUTS]

with Model() as model:
Normal("x", 0, 1)
steps = assign_step_methods(model, [])
assert isinstance(steps, NUTS)


class TestPopulationSamplers:

Expand Down

0 comments on commit 4635bf3

Please sign in to comment.