diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 8178ae0d22..057f7b7904 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -92,10 +92,40 @@ def __new__(cls, name, *args, **kwargs): if not isinstance(name, string_types): raise TypeError(f"Name needs to be a string but got: {name}") - data = kwargs.pop("observed", None) - cls.data = data - if isinstance(data, ObservedRV) or isinstance(data, FreeRV): - raise TypeError("observed needs to be data but got: {}".format(type(data))) + observed_data = kwargs.pop("observed", None) + if isinstance(observed_data, ObservedRV) or isinstance(observed_data, FreeRV): + raise TypeError("observed needs to be data but got: {}".format(type(observed_data))) + given_data = kwargs.pop("givens", None) + if given_data is None: + cls.data = observed_data + elif not isinstance(given_data, dict): + raise TypeError(f"givens needs to be of type dict but got: {type(given_data)}") + elif observed_data is None: + cls.data = given_data + elif isinstance(observed_data, dict): + non_data_obs = { + key: type(value) + for key, value in observed_data.items() + if isinstance(value, ObservedRV) or isinstance(value, FreeRV) + } + if non_data_obs: + raise TypeError( + f"All values in observed dict need to be data but got: {non_data_obs}. " + "You may want to use the givens argument in DensityDist." 
+ ) + intersection = given_data.keys() & observed_data.keys() + if intersection: + raise ValueError( + f"{intersection} keys found in both givens and observed dicts but " + "they can not have repeated keys" + ) + cls.data = {**observed_data, **given_data} + else: + raise ValueError( + "If both observed and givens argument are present, observed needs to " + f"be a dict but got: {type(observed_data)}" + ) + data = cls.data total_size = kwargs.pop("total_size", None) dims = kwargs.pop("dims", None) @@ -119,7 +149,7 @@ def __new__(cls, name, *args, **kwargs): dist = cls.dist(*args, **kwargs, shape=shape) else: dist = cls.dist(*args, **kwargs) - return model.Var(name, dist, data, total_size, dims=dims) + return model.Var(name, dist, data, total_size, dims=dims, givens=given_data) def __getnewargs__(self): return (_Unpickling,) @@ -403,6 +433,8 @@ def __init__( If ``True``, the shape of the random samples generate in the ``random`` method is checked with the expected return shape. This test is only performed if ``wrap_random_with_dist_shape is False``. + givens : dict, optional + Model variables on which the DensityDist is conditioned. args, kwargs: (Optional) These are passed to the parent class' ``__init__``. @@ -506,24 +538,6 @@ def __init__( the returned array of samples. It is the user's responsibility to wrap the callable to make it comply with PyMC3's interpretation of ``size``. - - - .. 
code-block:: python - - with pm.Model(): - mu = pm.Normal('mu', 0 , 1) - normal_dist = pm.Normal.dist(mu, 1, shape=3) - dens = pm.DensityDist( - 'density_dist', - normal_dist.logp, - observed=np.random.randn(100, 3), - shape=3, - random=stats.norm.rvs, - pymc3_size_interpretation=False, # Is True by default - ) - prior = pm.sample_prior_predictive(10)['density_dist'] - assert prior.shape == (10, 100, 3) - """ if dtype is None: dtype = theano.config.floatX diff --git a/pymc3/model.py b/pymc3/model.py index 393c4d2f6a..b343074bde 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -1109,7 +1109,7 @@ def add_coords(self, coords): else: self.coords[name] = coords[name] - def Var(self, name, dist, data=None, total_size=None, dims=None): + def Var(self, name, dist, data=None, total_size=None, dims=None, givens=None): """Create and add (un)observed random variable to the model with an appropriate prior distribution. @@ -1161,6 +1161,7 @@ def Var(self, name, dist, data=None, total_size=None, dims=None): var = MultiObservedRV( name=name, data=data, + givens=givens, distribution=dist, total_size=total_size, model=self, @@ -1834,7 +1835,7 @@ class MultiObservedRV(Factor): Potentially partially observed. """ - def __init__(self, name, data, distribution, total_size=None, model=None): + def __init__(self, name, data, distribution, total_size=None, model=None, givens=None): """ Parameters ---------- @@ -1850,6 +1851,7 @@ def __init__(self, name, data, distribution, total_size=None, model=None): self.data = { name: as_tensor(data, name, model, distribution) for name, data in data.items() } + self.givens = givens self.missing_values = [ datum.missing_values for datum in self.data.values() if datum.missing_values is not None