Fix for pymc-devs#3210, which uses a completely different approach than PR pymc-devs#3214. It uses a context manager inside `draw_values` that makes all the values drawn from `TensorVariable`s or `MultiObservedRV`s available to nested calls of the original `draw_values` invocation. It is partly inspired by how Edward2 approaches the problem of forward sampling: Ed2 tensors fix a `_values` attribute after the first call to `sample`, and from then on only return that. They can do this because of their functional scheme, where the entire graph is recreated each time the generative function is called. Our object-oriented paradigm cannot set a fixed `_values`; a variable has to know that it is in the context of a single `draw_values` call. That is why I opted for context managers to store the drawn values.
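A minimal sketch of the caching idea described above, stripped of pymc3 specifics (the class name `DrawCache` and its attributes are hypothetical; the actual implementation is the `_DrawValuesContext` manager in the diff below):

class DrawCache:
    # Stack of active contexts. A nested context shares its parent's
    # cache, so sibling branches see the same drawn values.
    _stack = []

    def __enter__(self):
        parent = DrawCache._stack[-1] if DrawCache._stack else None
        self.drawn = parent.drawn if parent is not None else {}
        DrawCache._stack.append(self)
        return self

    def __exit__(self, *exc):
        DrawCache._stack.pop()

with DrawCache() as outer:
    outer.drawn['mu'] = 0.5
    with DrawCache() as inner:
        assert inner.drawn is outer.drawn  # a nested call reuses the cache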
lucianopaz committed Nov 27, 2018
1 parent cf62eec commit 6cefd17
Showing 4 changed files with 482 additions and 297 deletions.
311 changes: 214 additions & 97 deletions pymc3/distributions/distribution.py
@@ -1,4 +1,4 @@
import collections
import six
import numbers

import numpy as np
@@ -8,7 +8,7 @@
from ..memoize import memoize
from ..model import (
Model, get_named_nodes_and_relations, FreeRV,
ObservedRV, MultiObservedRV
ObservedRV, MultiObservedRV, Context, InitContextMeta
)
from ..vartypes import string_types

@@ -214,6 +214,48 @@ def random(self, *args, **kwargs):
"Define a custom random method and pass it as kwarg random")


class _DrawValuesContext(six.with_metaclass(InitContextMeta, Context)):
""" A context manager class used while drawing values with draw_values
"""

def __new__(cls, *args, **kwargs):
# resolves the parent instance
instance = super(_DrawValuesContext, cls).__new__(cls)
if cls.get_contexts():
            potential_parent = cls.get_contexts()[-1]
            # We have to make sure that the context is a _DrawValuesContext
            # and not a Model
            if isinstance(potential_parent, cls):
                instance._parent = potential_parent
else:
instance._parent = None
else:
instance._parent = None
return instance

def __init__(self):
if self.parent is not None:
# All _DrawValuesContext instances that are in the context of
# another _DrawValuesContext will share the reference to the
# drawn_vars dictionary. This means that separate branches
# in the nested _DrawValuesContext context tree will see the
# same drawn values
self.drawn_vars = self.parent.drawn_vars
else:
self.drawn_vars = dict()

@property
def parent(self):
return self._parent
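
    # Editor's sketch (illustrative, not part of this commit): nested
    # contexts share a single drawn_vars dict, so a value drawn in an
    # outer draw_values call is visible to every nested call, e.g.
    #
    #     with _DrawValuesContext() as outer:
    #         outer.drawn_vars[rv] = value  # rv and value are hypothetical
    #         with _DrawValuesContext() as inner:
    #             assert inner.parent is outer
    #             assert inner.drawn_vars is outer.drawn_vars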


def is_fast_drawable(var):
return isinstance(var, (numbers.Number,
np.ndarray,
tt.TensorConstant,
tt.sharedvar.SharedVariable))
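
# Editor's note (illustrative examples): is_fast_drawable(3.0),
# is_fast_drawable(np.zeros(2)) and is_fast_drawable(tt.constant(1.0)) all
# return True; such values are handled by _draw_value directly, without
# walking the theano graph.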


def draw_values(params, point=None, size=None):
"""
Draw (fix) parameter values. Handles a number of cases:
@@ -232,97 +274,134 @@ def draw_values(params, point=None, size=None):
b) are *RVs with a random method
"""
# Distribution parameters may be nodes which have named node-inputs
# specified in the point. Need to find the node-inputs, their
# parents and children to replace them.
leaf_nodes = {}
named_nodes_parents = {}
named_nodes_children = {}
for param in params:
if hasattr(param, 'name'):
# Get the named nodes under the `param` node
nn, nnp, nnc = get_named_nodes_and_relations(param)
leaf_nodes.update(nn)
# Update the discovered parental relationships
for k in nnp.keys():
if k not in named_nodes_parents.keys():
named_nodes_parents[k] = nnp[k]
else:
named_nodes_parents[k].update(nnp[k])
# Update the discovered child relationships
for k in nnc.keys():
if k not in named_nodes_children.keys():
named_nodes_children[k] = nnc[k]
else:
named_nodes_children[k].update(nnc[k])

# Init givens and the stack of nodes to try to `_draw_value` from
givens = {}
    stored = set()  # Names of nodes whose values have already been drawn
stack = list(leaf_nodes.values()) # A queue would be more appropriate
while stack:
next_ = stack.pop(0)
if next_ in stored:
# If the node already has a givens value, skip it
continue
elif isinstance(next_, (tt.TensorConstant,
tt.sharedvar.SharedVariable)):
# If the node is a theano.tensor.TensorConstant or a
# theano.tensor.sharedvar.SharedVariable, its value will be
# available automatically in _compile_theano_function so
# we can skip it. Furthermore, if this node was treated as a
# TensorVariable that should be compiled by theano in
# _compile_theano_function, it would raise a `TypeError:
# ('Constants not allowed in param list', ...)` for
# TensorConstant, and a `TypeError: Cannot use a shared
# variable (...) as explicit input` for SharedVariable.
stored.add(next_.name)
continue
else:
# If the node does not have a givens value, try to draw it.
# The named node's children givens values must also be taken
# into account.
children = named_nodes_children[next_]
temp_givens = [givens[k] for k in givens if k in children]
try:
# This may fail for autotransformed RVs, which don't
# have the random method
givens[next_.name] = (next_, _draw_value(next_,
point=point,
givens=temp_givens,
size=size))
stored.add(next_.name)
except theano.gof.fg.MissingInputError:
# The node failed, so we must add the node's parents to
# the stack of nodes to try to draw from. We exclude the
# nodes in the `params` list.
stack.extend([node for node in named_nodes_parents[next_]
if node is not None and
node.name not in stored and
node not in params])

# the below makes sure the graph is evaluated in order
# test_distributions_random::TestDrawValues::test_draw_order fails without it
params = dict(enumerate(params)) # some nodes are not hashable
evaluated = {}
to_eval = set()
missing_inputs = set(params)
while to_eval or missing_inputs:
if to_eval == missing_inputs:
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
to_eval = set(missing_inputs)
missing_inputs = set()
for param_idx in to_eval:
param = params[param_idx]
if hasattr(param, 'name') and param.name in givens:
evaluated[param_idx] = givens[param.name][1]
    # Get fast drawable values (i.e. things in point or numbers, arrays,
    # constants or shared variables, or things that were already drawn in
    # related contexts)
if point is None:
point = {}
with _DrawValuesContext() as context:
params = dict(enumerate(params))
drawn = context.drawn_vars
evaluated = {}
symbolic_params = []
for i, p in params.items():
# If the param is fast drawable, then draw the value immediately
if is_fast_drawable(p):
v = _draw_value(p, point=point, size=size)
evaluated[i] = v
continue

name = getattr(p, 'name', None)
if p in drawn:
# param was drawn in related contexts
v = drawn[p]
evaluated[i] = v
elif name is not None and name in point:
# param.name is in point
v = point[name]
evaluated[i] = drawn[p] = v
else:
try: # might evaluate in a bad order,
evaluated[param_idx] = _draw_value(param, point=point, givens=givens.values(), size=size)
if isinstance(param, collections.Hashable) and named_nodes_parents.get(param):
givens[param.name] = (param, evaluated[param_idx])
# param still needs to be drawn
symbolic_params.append((i, p))

        if not symbolic_params:
            # No symbolic params remain, so the evaluated values can be
            # returned in the original order right away
            return [evaluated[i] for i in params]

# Distribution parameters may be nodes which have named node-inputs
# specified in the point. Need to find the node-inputs, their
# parents and children to replace them.
leaf_nodes = {}
named_nodes_parents = {}
named_nodes_children = {}
for _, param in symbolic_params:
if hasattr(param, 'name'):
# Get the named nodes under the `param` node
nn, nnp, nnc = get_named_nodes_and_relations(param)
leaf_nodes.update(nn)
# Update the discovered parental relationships
for k in nnp.keys():
if k not in named_nodes_parents.keys():
named_nodes_parents[k] = nnp[k]
else:
named_nodes_parents[k].update(nnp[k])
# Update the discovered child relationships
for k in nnc.keys():
if k not in named_nodes_children.keys():
named_nodes_children[k] = nnc[k]
else:
named_nodes_children[k].update(nnc[k])

# Init givens and the stack of nodes to try to `_draw_value` from
givens = {p.name: (p, v) for p, v in drawn.items()
if getattr(p, 'name', None) is not None}
stack = list(leaf_nodes.values()) # A queue would be more appropriate
while stack:
next_ = stack.pop(0)
if next_ in drawn:
# If the node already has a givens value, skip it
continue
elif isinstance(next_, (tt.TensorConstant,
tt.sharedvar.SharedVariable)):
# If the node is a theano.tensor.TensorConstant or a
# theano.tensor.sharedvar.SharedVariable, its value will be
# available automatically in _compile_theano_function so
# we can skip it. Furthermore, if this node was treated as a
# TensorVariable that should be compiled by theano in
# _compile_theano_function, it would raise a `TypeError:
# ('Constants not allowed in param list', ...)` for
# TensorConstant, and a `TypeError: Cannot use a shared
# variable (...) as explicit input` for SharedVariable.
continue
else:
# If the node does not have a givens value, try to draw it.
# The named node's children givens values must also be taken
# into account.
children = named_nodes_children[next_]
temp_givens = [givens[k] for k in givens if k in children]
try:
# This may fail for autotransformed RVs, which don't
# have the random method
value = _draw_value(next_,
point=point,
givens=temp_givens,
size=size)
givens[next_.name] = (next_, value)
drawn[next_] = value
except theano.gof.fg.MissingInputError:
# The node failed, so we must add the node's parents to
# the stack of nodes to try to draw from. We exclude the
# nodes in the `params` list.
stack.extend([node for node in named_nodes_parents[next_]
if node is not None and
node.name not in drawn and
node not in params])

# the below makes sure the graph is evaluated in order
# test_distributions_random::TestDrawValues::test_draw_order fails without it
# The remaining params that must be drawn are all hashable
to_eval = set()
missing_inputs = set([j for j, p in symbolic_params])
while to_eval or missing_inputs:
if to_eval == missing_inputs:
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
to_eval = set(missing_inputs)
missing_inputs = set()
for param_idx in to_eval:
param = params[param_idx]
if param in drawn:
evaluated[param_idx] = drawn[param]
else:
try: # might evaluate in a bad order,
value = _draw_value(param,
point=point,
givens=givens.values(),
size=size)
evaluated[param_idx] = drawn[param] = value
givens[param.name] = (param, value)
except theano.gof.fg.MissingInputError:
missing_inputs.add(param_idx)

return [evaluated[j] for j in params] # set the order back
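
# Editor's sketch (illustrative, not pymc3 API): the to_eval/missing_inputs
# loop above is a fixed-point iteration that keeps retrying params until
# either everything is drawn or a full pass makes no progress:
#
#     def resolve_all(params, resolve):
#         evaluated, missing, to_eval = {}, set(params), set()
#         while to_eval or missing:
#             if to_eval == missing:
#                 raise ValueError('Cannot resolve inputs for %s' % to_eval)
#             to_eval, missing = set(missing), set()
#             for p in to_eval:
#                 try:
#                     evaluated[p] = resolve(p, evaluated)
#                 except KeyError:  # stands in for MissingInputError
#                     missing.add(p)
#         return evaluated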

@@ -400,8 +479,16 @@ def _draw_value(param, point=None, givens=None, size=None):
# reset shape to account for shape changes
# with theano.shared inputs
dist_tmp.shape = np.array([])
val = dist_tmp.random(point=point, size=None)
dist_tmp.shape = val.shape
val = np.atleast_1d(dist_tmp.random(point=point,
size=None))
# Sometimes point may change the size of val but not the
# distribution's shape
if point and size is not None:
temp_size = np.atleast_1d(size)
if all(val.shape[:len(temp_size)] == temp_size):
dist_tmp.shape = val.shape[len(temp_size):]
else:
dist_tmp.shape = val.shape
return dist_tmp.random(point=point, size=size)
else:
return param.distribution.random(point=point, size=size)
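
# Editor's note: a worked example of the shape bookkeeping above. With
# size=(2,), a trial draw may return val.shape == (2, 3); temp_size is then
# [2] and val.shape[:1] == (2,) matches it, so the size prefix is stripped
# and dist_tmp.shape becomes (3,). If the trial draw instead returns
# val.shape == (3,), the prefix does not match and dist_tmp.shape is set to
# the full (3,).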
@@ -411,10 +498,24 @@
else:
variables = values = []
func = _compile_theano_function(param, variables)
if size and values and not all(var.dshape == val.shape for var, val in zip(variables, values)):
return np.array([func(*v) for v in zip(*values)])
if size is not None:
size = np.atleast_1d(size)
dshaped_variables = all((hasattr(var, 'dshape')
for var in variables))
if (values and dshaped_variables and
not all(var.dshape == getattr(val, 'shape', tuple())
for var, val in zip(variables, values))):
output = np.array([func(*v) for v in zip(*values)])
elif (size is not None and any((val.ndim > var.ndim)
for var, val in zip(variables, values))):
output = np.array([func(*v) for v in zip(*values)])
else:
return func(*values)
output = func(*values)
return output
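
# Editor's note (illustrative): the fallback above maps the compiled
# function over a leading sample axis whenever a drawn value has more
# dimensions than its variable, e.g. for two stacked draws of a single
# length-2 variable:
#
#     values = [np.array([[1., 2.], [3., 4.]])]  # val.ndim > var.ndim
#     output = np.array([func(*v) for v in zip(*values)])  # one call per draw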
raise ValueError('Unexpected type in draw_value: %s' % type(param))


@@ -499,6 +600,20 @@ def generate_samples(generator, *args, **kwargs):
samples = generator(size=broadcast_shape, *args, **kwargs)
elif dist_shape == broadcast_shape:
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
elif len(dist_shape) == 0 and size_tup and broadcast_shape[:len(size_tup)] == size_tup:
# Input's dist_shape is scalar, but it has size repetitions.
# So now the size matches but we have to manually broadcast to
# the right dist_shape
samples = [generator(*args, **kwargs)]
if samples[0].shape == broadcast_shape:
samples = samples[0]
else:
suffix = broadcast_shape[len(size_tup):] + dist_shape
samples.extend([generator(*args, **kwargs).
reshape(broadcast_shape)[..., np.newaxis]
for _ in range(np.prod(suffix,
dtype=int) - 1)])
samples = np.hstack(samples).reshape(size_tup + suffix)
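
# Editor's note: a worked example of the branch above. With dist_shape ==
# (), size_tup == (5,) and broadcast_shape == (5, 2), the suffix is (2,),
# so np.prod(suffix, dtype=int) - 1 == 1 extra draw is taken; the draws are
# stacked with np.hstack and reshaped to size_tup + suffix == (5, 2).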
else:
samples = None
# Args have been broadcast correctly, can just ask for the right shape out
@@ -515,9 +630,11 @@ def generate_samples(generator, *args, **kwargs):
if samples is None:
raise TypeError('''Attempted to generate values with incompatible shapes:
size: {size}
size_tup: {size_tup}
broadcast_shape[:len(size_tup)] == size_tup: {test}
dist_shape: {dist_shape}
broadcast_shape: {broadcast_shape}
'''.format(size=size, dist_shape=dist_shape, broadcast_shape=broadcast_shape))
'''.format(size=size, size_tup=size_tup, dist_shape=dist_shape, broadcast_shape=broadcast_shape, test=broadcast_shape[:len(size_tup)] == size_tup))

# reshape samples here
if samples.shape[0] == 1 and size == 1:
10 changes: 10 additions & 0 deletions pymc3/model.py
@@ -1386,6 +1386,16 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
self.distribution = distribution
self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)

# Make hashable by id for draw_values
def __hash__(self):
return id(self)

def __eq__(self, other):
return self.id == other.id

def __ne__(self, other):
return not self == other
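
    # Editor's note (illustrative): identity-based hashing lets draw_values
    # cache MultiObservedRV draws in the shared drawn_vars dict, e.g.:
    #
    #     drawn = {}
    #     drawn[multi_obs_rv] = value   # requires __hash__
    #     if multi_obs_rv in drawn:     # lookup uses __hash__ and __eq__
    #         value = drawn[multi_obs_rv]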


def _walk_up_rv(rv):
"""Walk up theano graph to get inputs for deterministic RV."""