Skip to content

Commit

Permalink
Cleaned up model.py, made it comply with pep8, and fixed lint error o…
Browse files Browse the repository at this point in the history
…n distribution.py.
  • Loading branch information
lucianopaz committed Sep 27, 2018
1 parent 339828d commit 890ae74
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 103 deletions.
35 changes: 24 additions & 11 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import theano
from ..memoize import memoize
from ..model import (
Model, modelcontext, FreeRV, ObservedRV, MultiObservedRV,
Model, modelcontext, FreeRV, ObservedRV,
not_shared_or_constant_variable, DependenceDAG
)
from ..vartypes import string_types
Expand Down Expand Up @@ -35,12 +35,14 @@ def __new__(cls, name, *args, **kwargs):
if isinstance(name, string_types):
data = kwargs.pop('observed', None)
if isinstance(data, ObservedRV) or isinstance(data, FreeRV):
raise TypeError("observed needs to be data but got: {}".format(type(data)))
raise TypeError("observed needs to be data but got: {}".
format(type(data)))
total_size = kwargs.pop('total_size', None)
dist = cls.dist(*args, **kwargs)
return model.Var(name, dist, data, total_size)
else:
raise TypeError("Name needs to be a string but got: {}".format(name))
raise TypeError("Name needs to be a string but got: {}".
format(name))

def __getnewargs__(self):
return _Unpickling,
Expand All @@ -64,12 +66,14 @@ def __init__(self, shape, dtype, testval=None, defaults=(),
self.conditional_on = None

def default(self):
return np.asarray(self.get_test_val(self.testval, self.defaults), self.dtype)
return np.asarray(self.get_test_val(self.testval, self.defaults),
self.dtype)

def get_test_val(self, val, defaults):
if val is None:
for v in defaults:
if hasattr(self, v) and np.all(np.isfinite(self.getattr_value(v))):
if (hasattr(self, v) and
np.all(np.isfinite(self.getattr_value(v)))):
return self.getattr_value(v)
else:
return self.getattr_value(val)
Expand Down Expand Up @@ -132,7 +136,8 @@ class NoDistribution(Distribution):
def __init__(self, shape, dtype, testval=None, defaults=(),
transform=None, parent_dist=None, *args, **kwargs):
super(NoDistribution, self).__init__(shape=shape, dtype=dtype,
testval=testval, defaults=defaults,
testval=testval,
defaults=defaults,
*args, **kwargs)
self.parent_dist = parent_dist

Expand Down Expand Up @@ -161,7 +166,8 @@ def __init__(self, shape=(), dtype=None, defaults=('mode',),
else:
dtype = 'int64'
if dtype != 'int16' and dtype != 'int64':
raise TypeError('Discrete classes expect dtype to be int16 or int64.')
raise TypeError('Discrete classes expect dtype to be int16 or '
'int64.')

if kwargs.get('transform', None) is not None:
raise ValueError("Transformations for discrete distributions "
Expand All @@ -174,7 +180,8 @@ def __init__(self, shape=(), dtype=None, defaults=('mode',),
class Continuous(Distribution):
"""Base class for continuous distributions"""

def __init__(self, shape=(), dtype=None, defaults=('median', 'mean', 'mode'),
def __init__(self, shape=(), dtype=None,
defaults=('median', 'mean', 'mode'),
*args, **kwargs):
if dtype is None:
dtype = theano.config.floatX
Expand All @@ -195,12 +202,15 @@ class DensityDist(Distribution):
with pm.Model():
mu = pm.Normal('mu',0,1)
normal_dist = pm.Normal.dist(mu, 1)
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
pm.DensityDist('density_dist', normal_dist.logp,
observed=np.random.randn(100),
random=normal_dist.random)
trace = pm.sample(100)
"""

def __init__(self, logp, shape=(), dtype=None, testval=0, random=None, *args, **kwargs):
def __init__(self, logp, shape=(), dtype=None, testval=0, random=None,
*args, **kwargs):
if dtype is None:
dtype = theano.config.floatX
super(DensityDist, self).__init__(
Expand All @@ -213,7 +223,8 @@ def random(self, *args, **kwargs):
return self.rand(*args, **kwargs)
else:
raise ValueError("Distribution was not passed any random method "
"Define a custom random method and pass it as kwarg random")
"Define a custom random method and pass it as "
"kwarg random")


def draw_values(params, point=None, size=None, model=None):
Expand Down Expand Up @@ -462,6 +473,7 @@ def to_tuple(shape):
shape = tuple(shape)
return shape


def _is_one_d(dist_shape):
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
return True
Expand All @@ -471,6 +483,7 @@ def _is_one_d(dist_shape):
return True
return False


def generate_samples(generator, *args, **kwargs):
"""Generate samples from the distribution of a random variable.
Expand Down
Loading

0 comments on commit 890ae74

Please sign in to comment.