Skip to content

Commit

Permalink
Implemented get_first_level_conditionals to try to get rid of the a…
Browse files Browse the repository at this point in the history
…dded `conditional_on` attribute of every distribution. This function does a breadth first search on the node's `logpt` or `transformed.logpt` graph, looking for named nodes which are different from the root node, or the node's transformed, and is also not a `TensorConstant` or `SharedVariable`. Each branch was searched until the first named node was found. This way, the parent conditionals of the `root` searched node, which were only one step away from it in the bayesian network were returned. However, this ran into a problem with `Mixture` classes. These add to the `logpt` graph, another `logpt` graph from the `comp_dists`. This leads to the problem that the `logpt`'s first level conditionals will also be seen as if they were first level conditional of the `root`. Furthermore, many copies of nodes done by the added `logpt` ended up being inserted into the computed `conditional_on`. This lead to a very strange error, in which loops appeared in the DAG, and depths started to be wrong. In particular, there were no depth 0 nodes. My view is that the explicit `conditional_on` attribute prevents problems like this one from happening, and so I left it as is, to discuss. Other changes done in this commit are that `test_exact_step` for the SMC uses `draw_values` on a hierarchy, and given that `draw_values`'s behavior changed in the hierarchy situations, the exact trace values must also be adjusted. Finally `test_bad_init` was changed to run on one core, this way the parallel exception chaining does not change the exception type.
  • Loading branch information
lucianopaz committed Sep 27, 2018
1 parent d43d149 commit 339828d
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 52 deletions.
47 changes: 47 additions & 0 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,52 @@ def not_shared_or_constant_variable(x):
) or (isinstance(x, (FreeRV, MultiObservedRV, TransformedRV)))


def get_first_level_conditionals(root):
"""Performs a breadth first search on the supplied root node's logpt or
transformed logpt graph searching for named input nodes, which are
different from the supplied root. Each explored branch will stop when
either when it ends or when it finds its first named node.
Parameters
----------
root: theano.Variable (mandatory)
The node from which to get the transformed.logpt or logpt and perform
the search. If root does not have either of these attributes, the
function returns None.
Returns
-------
conditional_on : set, with named nodes that are not theano.Constant nor
SharedVariable. The input `root` is conditionally dependent on these nodes
and is one step away from them in the bayesian network that specifies the
relationships, hence the name `get_first_level_conditionals`.
"""
transformed = getattr(root, 'transformed', None)
try:
cond = transformed.logpt
except AttributeError:
cond = getattr(root, 'logpt', None)
if cond is None:
return None
conditional_on = set()
queue = copy(getattr(cond.owner, 'inputs', []))
while queue:
parent = queue.pop(0)
if (parent is not None and getattr(parent, 'name', None) is not None
and not_shared_or_constant_variable(parent)):
# We don't include as a conditional relation either logpt depending
# on root or on transformed because they are both deterministic
# relations
if parent == root and parent == transformed:
conditional_on.add(parent)
else:
parent_owner = getattr(parent, 'owner', None)
queue.extend(getattr(parent_owner, 'inputs', []))
if not conditional_on:
return None
return conditional_on


class DependenceDAG(object):
"""
`DependenceDAG` instances represent the directed acyclic graph (DAG) that
Expand Down Expand Up @@ -1866,6 +1912,7 @@ def add(self, node, force=False, return_added_node=False,
self.depth[node] = 0

# Try to get the conditional parents of node and add them
# cond = get_first_level_conditionals(node)
try:
cond = node.distribution.conditional_on
except AttributeError:
Expand Down
93 changes: 41 additions & 52 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,56 +105,46 @@ class TestStepMethods(object): # yield test doesn't work subclassing object
1.58740483, 1.67905741, 0.77744868, 0.15050587, 0.15050587,
0.73979127, 0.15445515, 0.13134717, 0.85068974, 0.85068974,
0.6974799 , 0.16170472, 0.86405959, 0.86405959, -0.22032854]),
SMC: np.array([ 5.10950205e-02, 1.09811720e+00, 1.78330202e-01, 6.85938766e-01,
1.42354476e-01, -1.59630758e+00, 1.57176810e+00, -4.01398917e-01,
1.14567871e+00, 1.14954938e+00, 4.94399840e-01, 1.16253017e+00,
1.17432244e+00, 7.79195162e-01, 1.29017945e+00, 2.53722905e-01,
5.38589898e-01, 3.52121216e-01, 1.35795966e+00, 1.02086933e-01,
1.58845251e+00, 6.76852927e-01, -1.04716592e-02, -1.01613324e-01,
1.37680965e+00, 7.40036542e-01, 2.89069320e-01, 1.48153741e+00,
9.58156958e-01, 5.73623782e-02, 7.68850721e-01, 3.68643390e-01,
1.47645964e+00, 2.32596780e-01, -1.85008158e-01, 3.71335958e-01,
2.68600102e+00, -4.89504443e-01, 6.54265561e-02, 3.80455349e-01,
1.17875338e+00, 2.30233324e-01, 6.90960231e-01, 8.81668685e-01,
-2.19754340e-01, 1.27686862e-01, 3.28444250e-01, 1.34820635e-01,
5.29725257e-01, 1.43783915e+00, -1.64754264e-01, 7.41446719e-01,
-1.17733186e+00, 6.01215658e-02, 1.82638158e-01, -2.23232214e-02,
-1.79877583e-02, 8.37949150e-01, 4.41964955e-01, -8.66524743e-01,
4.90738093e-01, 2.42056488e-01, 4.67699626e-01, 2.91075351e-01,
1.49541153e+00, 8.30730845e-01, 1.03956404e+00, -5.16162910e-01,
2.84338859e-01, 1.72305888e+00, 9.52445566e-01, 1.48831718e+00,
8.03455325e-01, 1.48840970e+00, 6.98122664e-01, 3.30187139e-01,
7.88029712e-01, 9.31510828e-01, 1.01326878e+00, 2.26637755e-01,
1.70703646e-01, -8.54429841e-01, 2.97254590e-01, -2.77843274e-01,
-2.25544207e-01, 1.98862826e-02, 5.05953885e-01, 4.98203941e-01,
1.20897382e+00, -6.32958669e-05, -7.22425896e-01, 1.60930869e+00,
-5.02773645e-01, 2.46405678e+00, 9.16039706e-01, 1.14146060e+00,
-1.95781984e-01, -2.44653942e-01, 2.67851290e-01, 2.37462012e-01,
6.71471950e-01, 1.18319765e+00, 1.29146530e+00, -3.14177753e-01,
-1.31041215e-02, 1.05029405e+00, 1.31202399e+00, 7.40532839e-02,
9.15510041e-01, 7.71054604e-01, 9.83483263e-01, 9.03032142e-01,
9.14191160e-01, 9.32285366e-01, 1.13937607e+00, -4.29155928e-01,
3.44609229e-02, -5.46423555e-02, 1.34625982e+00, -1.28287047e-01,
-1.55214879e-02, 3.25294234e-01, 1.06120585e+00, -5.09891282e-01,
1.25789335e+00, 1.01808348e+00, -9.92590713e-01, 1.72832932e+00,
1.12232980e+00, 8.54801892e-01, 1.41534752e+00, 3.50798405e-01,
3.69381623e-01, 1.48608411e+00, -1.15506310e-02, 1.57066360e+00,
2.00747378e-01, 4.47219763e-01, 5.57720524e-01, -7.74295353e-02,
1.79192501e+00, 7.66510475e-01, 1.38852488e+00, -4.06055122e-01,
2.73203156e-01, 3.61014687e-01, 1.23574043e+00, 1.64565746e-01,
-9.89896480e-02, 9.26130265e-02, 1.06440134e+00, -1.55890408e-01,
4.47131846e-01, -7.59186008e-01, -1.50881256e+00, -2.13928005e-01,
-4.19160151e-01, 1.75815544e+00, 7.45423008e-01, 6.94781506e-01,
1.58596346e+00, 1.75508724e+00, 4.56070434e-01, 2.94128709e-02,
1.17703970e+00, -9.90230827e-02, 8.42796845e-01, 1.79154944e+00,
5.92779197e-01, 2.73562285e-01, 1.61597907e+00, 1.23514403e+00,
4.86261080e-01, -3.10434934e-01, 5.57873722e-01, 6.50365217e-01,
-3.41009850e-01, 9.26851109e-01, 8.28936486e-01, 9.16180689e-02,
1.30226405e+00, 3.73945789e-01, 6.04560122e-02, 6.00698708e-01,
9.68764731e-02, 1.41904148e+00, 6.94182961e-03, 3.17504138e-01,
5.90956041e-01, -5.78113887e-01, 5.26615565e-01, -4.19715252e-01,
8.92891364e-01, 1.30207363e-01, 4.19899637e-01, 7.10275704e-01,
9.27418179e-02, 1.85758044e+00, 4.76988907e-01, -1.36341398e-01]),
SMC: np.array([ 0.40152748, -0.1440789 , 1.87105436, 1.65027354, 0.78140894,
-0.33437271, 0.55987446, 1.05976848, 0.52126327, 0.5295624 ,
-0.7120724 , 0.39250673, 0.92590897, 0.776836 , 0.30528805,
1.32178809, 1.30972392, 0.77107019, 1.11885364, 0.59633151,
0.63584096, -0.29117982, 0.97372731, 1.06270256, 0.87424729,
0.49249202, -0.55942483, -0.17608982, 0.47118016, 1.0026767 ,
1.42476886, 1.16505966, 0.71572226, 1.14267914, -0.27628211,
0.66712824, 0.58322462, 0.28193361, 0.30175522, -0.11615552,
-0.02127047, 0.01085484, 1.21229396, 0.50109798, 0.2046552 ,
0.95648093, 0.26673391, -0.703456 , 1.23223409, -0.87686456,
1.45480993, 1.04172093, 1.73512969, 1.00835375, 0.56551883,
0.43457948, 1.85267864, 0.51961398, 0.20641743, 0.70484816,
1.04491792, -0.70236338, 1.47248532, 0.57438209, -0.15590465,
0.51528505, 1.49158593, 0.02418851, -0.04563402, 1.50712686,
1.01211014, -0.1058956 , 1.91153929, 1.09281243, 0.78028316,
0.08148316, 0.3989925 , 0.30230531, 1.59469562, -0.53948736,
-0.35653048, 0.44440402, 1.02983002, 0.05184227, 0.78152799,
0.99204159, 0.44148902, -0.12657838, 0.97114256, 0.67963455,
1.33757129, 0.71977859, 0.09706076, -0.13609892, -0.39969385,
0.04687582, 0.053386 , 0.33382962, -0.36082645, 0.86597207,
0.09824643, -0.85212079, 0.54518473, -0.26622955, 0.71836765,
0.81359943, 1.39550066, 0.25118273, 1.03965837, -0.65995684,
-0.25522586, 2.12497766, 0.69534904, 0.74613619, -0.10312994,
1.3244944 , -0.036056 , 0.90976629, 0.49647046, 0.80779428,
0.18921903, -0.18365952, 0.56968353, -0.8232526 , -0.88612154,
-0.47326386, 0.18939692, 0.2298177 , 0.65693251, 1.08908496,
1.04748985, 0.53615771, -0.4611776 , 1.12076823, -0.79971572,
1.78908277, 1.32673932, 1.43691077, 0.2564599 , 0.08480867,
0.26340606, -0.86864626, 1.05716355, 0.18611255, 0.44701292,
-0.06966819, 0.3325726 , 0.94594745, -0.0904025 , 0.14349182,
0.83638941, 0.57657934, 0.9549692 , -0.18496471, 0.87838048,
0.66938294, 0.54401984, 0.47804147, 0.32545637, -0.82626784,
0.93390148, 0.39170683, -0.22244643, 0.36576256, 0.62426937,
-0.16594267, 1.55050592, 0.60508809, -1.28925325, 1.1470063 ,
0.71030941, 1.20896922, 1.23267962, 0.67278808, 0.5846423 ,
-0.09343583, -0.28323718, 0.87891542, 0.54779014, 0.17131075,
1.02287448, 0.61819842, 1.28724788, 0.641085 , 1.48324063,
-1.68770188, 0.03750369, 0.47352403, 0.22929128, 0.637757 ,
0.61735636, 0.17260147, 1.10929764, -0.33766643, 0.27064342,
-0.54594464, -1.23229206, -0.18328842, -0.78636148, 1.38189874]),
}

def setup_class(self):
Expand Down Expand Up @@ -202,7 +192,6 @@ def check_trace(self, step_method):
trace = sample(0, tune=n_steps,
discard_tuned_samples=False,
step=step_method(), random_seed=1, chains=1)

assert_array_almost_equal(
trace['x'],
self.master_samples[step_method],
Expand Down Expand Up @@ -428,7 +417,7 @@ def test_bad_init(self):
with Model():
HalfNormal('a', sd=1, testval=-1, transform=None)
with pytest.raises(ValueError) as error:
sample(init=None)
sample(init=None, cores=1)
error.match('Bad initial')

def test_linalg(self, caplog):
Expand Down

0 comments on commit 339828d

Please sign in to comment.