Skip to content

Commit

Permalink
Update to more accurate way of calculating ancestors
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll authored and twiecki committed Oct 11, 2018
1 parent fd4c71d commit 4e33b32
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from theano.gof.graph import inputs
import itertools

from theano.gof.graph import ancestors

from .util import get_default_varnames
import pymc3 as pm


def powerset(iterable):
"""All *nonempty* subsets of an iterable.
From itertools docs.
powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
"""
s = list(iterable)
return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(1, len(s)+1))


class ModelGraph(object):
def __init__(self, model):
self.model = model
Expand All @@ -21,30 +34,29 @@ def get_deterministics(self, var):
deterministics.append(v)
return deterministics

def _inputs(self, var, func, blockers=None):
"""Get inputs to a function that are also named PyMC3 variables"""
return set([j for j in inputs([func], blockers=blockers) if j in self.var_list and j != var])
def _ancestors(self, var, func, blockers=None):
"""Get ancestors of a function that are also named PyMC3 variables"""
return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var])

def _get_inputs(self, var, func):
"""Get all inputs to a function, doing some accounting for deterministics
def _get_ancestors(self, var, func):
"""Get all ancestors of a function, doing some accounting for deterministics
Specifically, if a deterministic is an input, theano.gof.graph.inputs will
Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
return only the inputs *to the deterministic*. However, if we pass in the
deterministic as a blocker, it will skip those nodes.
"""
deterministics = self.get_deterministics(var)
upstream = self._inputs(var, func)
parents = self._inputs(var, func, blockers=deterministics)
if parents != upstream:
det_map = {}
for d in deterministics:
d_set = {j for j in inputs([func], blockers=[d])}
if upstream - d_set:
det_map[d] = d_set
for d, d_set in det_map.items():
if all(d_set.issubset(other) for other in det_map.values()):
parents.add(d)
return parents
upstream = self._ancestors(var, func)

# Usual case
if upstream == self._ancestors(var, func, blockers=upstream):
return upstream
else: # deterministic accounting
for d in powerset(upstream):
blocked = self._ancestors(var, func, blockers=d)
if set(d) == blocked:
return d
raise RuntimeError('Could not traverse graph. Consider raising an issue with developers.')

def _filter_parents(self, var, parents):
"""Get direct parents of a var, as strings"""
Expand All @@ -70,7 +82,7 @@ def get_parents(self, var):
else:
func = var

parents = self._get_inputs(var, func)
parents = self._get_ancestors(var, func)
return self._filter_parents(var, parents)

def make_compute_graph(self):
Expand Down

0 comments on commit 4e33b32

Please sign in to comment.