diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index f181554387..e68a28cb17 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -1,9 +1,22 @@ -from theano.gof.graph import inputs +import itertools + +from theano.gof.graph import ancestors from .util import get_default_varnames import pymc3 as pm +def powerset(iterable): + """All *nonempty* subsets of an iterable. + + From itertools docs. + + powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) + """ + s = list(iterable) + return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(1, len(s)+1)) + + class ModelGraph(object): def __init__(self, model): self.model = model @@ -21,30 +34,29 @@ def get_deterministics(self, var): deterministics.append(v) return deterministics - def _inputs(self, var, func, blockers=None): - """Get inputs to a function that are also named PyMC3 variables""" - return set([j for j in inputs([func], blockers=blockers) if j in self.var_list and j != var]) + def _ancestors(self, var, func, blockers=None): + """Get ancestors of a function that are also named PyMC3 variables""" + return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var]) - def _get_inputs(self, var, func): - """Get all inputs to a function, doing some accounting for deterministics + def _get_ancestors(self, var, func): + """Get all ancestors of a function, doing some accounting for deterministics - Specifically, if a deterministic is an input, theano.gof.graph.inputs will + Specifically, if a deterministic is an input, theano.gof.graph.ancestors will return only the inputs *to the deterministic*. However, if we pass in the deterministic as a blocker, it will skip those nodes. """ deterministics = self.get_deterministics(var) - upstream = self._inputs(var, func) - parents = self._inputs(var, func, blockers=deterministics) - if parents != upstream: - det_map = {} - for d in deterministics: - d_set = {j for j in inputs([func], blockers=[d])} - if upstream - d_set: - det_map[d] = d_set - for d, d_set in det_map.items(): - if all(d_set.issubset(other) for other in det_map.values()): - parents.add(d) - return parents + upstream = self._ancestors(var, func) + + # Usual case + if upstream == self._ancestors(var, func, blockers=upstream): + return upstream + else: # deterministic accounting + for d in powerset(upstream): + blocked = self._ancestors(var, func, blockers=d) + if set(d) == blocked: + return d + raise RuntimeError('Could not traverse graph. Consider raising an issue with developers.') def _filter_parents(self, var, parents): """Get direct parents of a var, as strings""" @@ -70,7 +82,7 @@ def get_parents(self, var): else: func = var - parents = self._get_inputs(var, func) + parents = self._get_ancestors(var, func) return self._filter_parents(var, parents) def make_compute_graph(self):