diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 5806a42ee5..5988beb564 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -48,11 +48,12 @@ def instantiate_steppers(model, steps, selected_steps, step_kwargs=None): ---------- model : Model object A fully-specified model object - step : step function or vector of step functions + step : step function or iterable of step functions One or more step functions that have been assigned to some subset of the model's parameters. Defaults to None (no assigned variables). selected_steps: dictionary of step methods and variables - The step methods and the variables that have were assigned to them. + Variables with selected step methods. Keys are the step methods, and + values are the variables that have been assigned to them. step_kwargs : dict Parameters for the samplers. Keys are the lower case names of the step method, values a dict of arguments. @@ -65,18 +66,18 @@ def instantiate_steppers(model, steps, selected_steps, step_kwargs=None): if step_kwargs is None: step_kwargs = {} - used_keys = set() - for step_class, vars in selected_steps.items(): - if len(vars) == 0: + used_args = set() + for step, var_list in selected_steps.items(): + if len(var_list) == 0: continue - args = step_kwargs.get(step_class.name, {}) - used_keys.add(step_class.name) - step = step_class(vars=vars, **args) + args = step_kwargs.get(step.name, {}) + used_args.add(step.name) steps.append(step) - unused_args = set(step_kwargs).difference(used_keys) + unused_args = set(step_kwargs).difference(used_args) if unused_args: - raise ValueError('Unused step method arguments: %s' % unused_args) + raise ValueError('Unused arguments for step method(s): %s' + % [s for s in unused_args]) if len(steps) == 1: steps = steps[0] @@ -100,7 +101,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, ---------- model : Model object A fully-specified model object - step : step function or vector of step functions + step : step function or iterable of step functions One or more step functions that have been assigned to some subset of the model's parameters. Defaults to None (no assigned variables). methods : vector of step method classes @@ -116,26 +117,29 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, List of step methods associated with the model's variables. """ steps = [] - assigned_vars = set() + selected_steps = defaultdict(list) if step is not None: - try: + try: # If `step` is a list, concatenate steps += list(step) - except TypeError: + except TypeError: # If `step` is not iterable, append steps.append(step) + for step in steps: try: - assigned_vars = assigned_vars.union(set(step.vars)) + selected_steps[step] += step.vars except AttributeError: for method in step.methods: - assigned_vars = assigned_vars.union(set(method.vars)) + selected_steps[step] += method.vars # Use competence classmethods to select step methods for remaining # variables - selected_steps = defaultdict(list) for var in model.free_RVs: + # Flatten assigned variables into a set + assigned_vars = set(var for lst in selected_steps.values() + for var in lst) if var not in assigned_vars: - # determine if a gradient can be computed + # Determine if a gradient can be computed has_gradient = var.dtype not in discrete_types if has_gradient: try: @@ -144,7 +148,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, NotImplementedError, tg.NullTypeGradError): has_gradient = False - # select the best method + # Select the method with maximum competence selected = max(methods, key=lambda method, var=var, has_gradient=has_gradient: method._competence(var, has_gradient)) diff --git a/pymc3/step_methods/compound.py b/pymc3/step_methods/compound.py index 8deb0555fd..ac2ca9a777 100644 --- a/pymc3/step_methods/compound.py +++ b/pymc3/step_methods/compound.py @@ -9,6 +9,7 @@ class CompoundStep(object): """Step method composed of a list of several other step methods applied in sequence.""" + name = 'compound' def __init__(self, methods): self.methods = list(methods) diff --git a/pymc3/step_methods/elliptical_slice.py b/pymc3/step_methods/elliptical_slice.py index 5936c554c0..ed58035df8 100644 --- a/pymc3/step_methods/elliptical_slice.py +++ b/pymc3/step_methods/elliptical_slice.py @@ -66,7 +66,7 @@ class EllipticalSlice(ArrayStep): Artificial Intelligence and Statistics (AISTATS), JMLR W&CP 9:541-548, 2010. """ - + name = 'elliptical_slice' default_blocked = True def __init__(self, vars=None, prior_cov=None, prior_chol=None, model=None,