Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify inputs to DirichletProcess; improve docstrings #583

Merged
merged 2 commits into from
Mar 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/tex/iclr2017.tex
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,7 @@ \subsubsection{Appendix A. Model Examples}
N = 1000 # number of data points
D = 5 # data dimensionality

dp = DirichletProcess(
alpha=1.0, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D))
dp = DirichletProcess(alpha=1.0, base=Normal(mu=tf.zeros(D), sigma=tf.ones(D)))
mu = dp.sample(N)
x = Normal(mu=mu, sigma=tf.ones([N, D]))
\end{lstlisting}
Expand Down
76 changes: 41 additions & 35 deletions edward/models/dirichlet_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,48 @@


class DirichletProcess(RandomVariable, Distribution):
def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True,
name="DirichletProcess", value=None, *args, **kwargs):
"""Dirichlet process :math:`\mathcal{DP}(\\alpha, H)`.
"""Dirichlet process :math:`\mathcal{DP}(\\alpha, H)`.

It has two parameters: a positive real value :math:`\\alpha`,
known as the concentration parameter (``alpha``), and a base
distribution :math:`H` (``base_cls(*args, **kwargs)``).
It has two parameters: a positive real value :math:`\\alpha`,
known as the concentration parameter (``alpha``), and a base
distribution :math:`H` (``base``).
"""
def __init__(self, alpha, base, validate_args=False, allow_nan_stats=True,
name="DirichletProcess", *args, **kwargs):
"""Initialize a batch of Dirichlet processes.

Parameters
----------
alpha : tf.Tensor
Concentration parameter. Must be positive real-valued. Its shape
determines the number of independent DPs (batch shape).
base_cls : RandomVariable
Class of base distribution. Its shape (when instantiated)
determines the shape of an individual DP (event shape).
*args, **kwargs : optional
Arguments passed into ``base_cls``.
base : RandomVariable
Base distribution. Its shape determines the shape of an
individual DP (event shape).

Examples
--------
>>> # scalar concentration parameter, scalar base distribution
>>> dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0)
>>> dp = DirichletProcess(0.1, Normal(mu=0.0, sigma=1.0))
>>> assert dp.shape == ()
>>>
>>> # vector of concentration parameters, matrix of Exponentials
>>> dp = DirichletProcess(tf.constant([0.1, 0.4]),
... Exponential, lam=tf.ones([5, 3]))
... Exponential(lam=tf.ones([5, 3])))
>>> assert dp.shape == (2, 5, 3)
"""
parameters = locals()
parameters.pop("self")
with tf.name_scope(name, values=[alpha]) as ns:
with tf.control_dependencies([]):
with tf.name_scope(name, values=[alpha]):
with tf.control_dependencies([
tf.assert_positive(alpha),
] if validate_args else []):
if validate_args and isinstance(base, RandomVariable):
raise TypeError("base must be a ed.RandomVariable object.")

self._alpha = tf.identity(alpha, name="alpha")
self._base_cls = base_cls
self._base_args = args
self._base_kwargs = kwargs
self._base = base

# Instantiate base distribution.
self._base = self._base_cls(*self._base_args, **self._base_kwargs)
# Create empty tensor to store future atoms.
self._theta = tf.zeros(
[0] +
Expand All @@ -63,28 +64,33 @@ def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True,
dtype=self._base.dtype)

# Instantiate beta distribution for stick breaking proportions.
self._betadist = Beta(a=tf.ones_like(self.alpha), b=self.alpha)
self._betadist = Beta(a=tf.ones_like(self._alpha), b=self._alpha)
# Create empty tensor to store stick breaking proportions.
self._beta = tf.zeros(
[0] + self.get_batch_shape().as_list(),
dtype=self._betadist.dtype)

super(DirichletProcess, self).__init__(
dtype=tf.int32,
is_continuous=False,
is_reparameterized=False,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._alpha, self._beta, self._theta],
name=ns,
value=value)
super(DirichletProcess, self).__init__(
dtype=tf.int32,
is_continuous=False,
is_reparameterized=False,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._alpha, self._beta, self._theta],
name=name,
*args, **kwargs)

@property
def alpha(self):
"""Concentration parameter."""
return self._alpha

@property
def base(self):
"""Base distribution used for drawing the atoms."""
return self._base

@property
def beta(self):
"""Stick breaking proportions. It has shape [None] + batch_shape, where
Expand All @@ -106,10 +112,10 @@ def _get_batch_shape(self):
return self.alpha.shape

def _event_shape(self):
return tf.shape(self._base)
return tf.shape(self.base)

def _get_event_shape(self):
return self._base.shape
return self.base.shape

def _sample_n(self, n, seed=None):
"""Sample ``n`` draws from the DP. Draws from the base
Expand Down Expand Up @@ -154,7 +160,7 @@ def _sample_n(self, n, seed=None):
bools = tf.ones([n] + batch_shape, dtype=tf.bool)

# Initialize all samples as zero, they will be overwritten in any case
draws = tf.zeros([n] + batch_shape + event_shape, dtype=self._base.dtype)
draws = tf.zeros([n] + batch_shape + event_shape, dtype=self.base.dtype)

# Calculate shape invariance conditions for theta and beta as these
# can change shape between loop iterations.
Expand Down Expand Up @@ -187,7 +193,7 @@ def _sample_n_body(self, k, bools, theta, beta, draws):
lambda: (theta, beta),
lambda: (
tf.concat(
[theta, tf.expand_dims(self._base.sample(batch_shape), 0)], 0),
[theta, tf.expand_dims(self.base.sample(batch_shape), 0)], 0),
tf.concat(
[beta, tf.expand_dims(self._betadist.sample(), 0)], 0)))
theta_k = tf.gather(theta, k)
Expand Down
40 changes: 29 additions & 11 deletions edward/models/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,44 @@ class Empirical(RandomVariable, Distribution):
"""Empirical random variable."""
def __init__(self, params, validate_args=False, allow_nan_stats=True,
name="Empirical", *args, **kwargs):
"""Initialize an ``Empirical`` random variable.

Parameters
----------
params : tf.Tensor
Collection of samples. Its outer (left-most) dimension
determines the number of samples.

Examples
--------
>>> # 100 samples of a scalar
>>> x = Empirical(params=tf.zeros(100))
>>> assert x.shape == ()
>>>
>>> # 5 samples of a 2 x 3 matrix
>>> dp = Empirical(params=tf.zeros([5, 2, 3]))
>>> assert x.shape == (2, 3)
"""
parameters = locals()
parameters.pop("self")
with tf.name_scope(name, values=[params]) as ns:
with tf.name_scope(name, values=[params]):
with tf.control_dependencies([]):
self._params = tf.identity(params, name="params")
try:
self._n = tf.shape(self._params)[0]
except ValueError: # scalar params
self._n = tf.constant(1)

super(Empirical, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params, self._n],
name=ns,
*args, **kwargs)
super(Empirical, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params, self._n],
name=name,
*args, **kwargs)

@staticmethod
def _param_shapes(sample_shape):
Expand Down
40 changes: 29 additions & 11 deletions edward/models/point_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,39 @@ class PointMass(RandomVariable, Distribution):
"""
def __init__(self, params, validate_args=False, allow_nan_stats=True,
name="PointMass", *args, **kwargs):
"""Initialize a ``PointMass`` random variable.

Parameters
----------
params : tf.Tensor
The location with all probability mass.

Examples
--------
>>> # scalar
>>> x = PointMass(params=28.3)
>>> assert x.shape == ()
>>>
>>> # 5 x 2 x 3 tensor
>>> dp = PointMass(params=tf.zeros([5, 2, 3]))
>>> assert x.shape == (5, 2, 3)
"""
parameters = locals()
parameters.pop("self")
with tf.name_scope(name, values=[params]) as ns:
with tf.name_scope(name, values=[params]):
with tf.control_dependencies([]):
self._params = tf.identity(params, name="params")
super(PointMass, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params],
name=ns,
*args, **kwargs)

super(PointMass, self).__init__(
dtype=self._params.dtype,
is_continuous=False,
is_reparameterized=True,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[self._params],
name=name,
*args, **kwargs)

@staticmethod
def _param_shapes(sample_shape):
Expand Down
19 changes: 9 additions & 10 deletions examples/pp_dirichlet_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ def body(k, beta_k):
print(sess.run(dp))

# Demo of the DirichletProcess random variable in Edward.
base_cls = Normal
kwargs = {'mu': 0.0, 'sigma': 1.0}
base = Normal(mu=0.0, sigma=1.0)

# Highly concentrated DP.
alpha = 1.0
dp = DirichletProcess(alpha, base_cls, **kwargs)
dp = DirichletProcess(alpha, base)
x = dp.sample(1000)
samples = sess.run(x)
plt.hist(samples, bins=100, range=(-3.0, 3.0))
Expand All @@ -60,7 +59,7 @@ def body(k, beta_k):

# More spread out DP.
alpha = 50.0
dp = DirichletProcess(alpha, base_cls, **kwargs)
dp = DirichletProcess(alpha, base)
x = dp.sample(1000)
samples = sess.run(x)
plt.hist(samples, bins=100, range=(-3.0, 3.0))
Expand All @@ -69,7 +68,7 @@ def body(k, beta_k):

# States persist across calls to sample() in a DP.
alpha = 1.0
dp = DirichletProcess(alpha, base_cls, **kwargs)
dp = DirichletProcess(alpha, base)
x = dp.sample(50)
y = dp.sample(75)
samples_x, samples_y = sess.run([x, y])
Expand All @@ -82,13 +81,13 @@ def body(k, beta_k):

# ``theta`` is the distribution indirectly returned by the DP.
# Fetching theta is the same as fetching the Dirichlet process.
dp = DirichletProcess(alpha, base_cls, **kwargs)
theta = base_cls(value=tf.cast(dp, tf.float32), **kwargs)
dp = DirichletProcess(alpha, base)
theta = Normal(0.0, 1.0, value=tf.cast(dp, tf.float32))
print(sess.run([dp, theta]))
print(sess.run([dp, theta]))

# DirichletProcess can also take in non-scalar concentrations and bases.
base_cls = Exponential
kwargs = {'lam': tf.ones([5, 2])}
dp = DirichletProcess(tf.constant([0.1, 0.6, 0.4]), base_cls, **kwargs)
alpha = tf.constant([0.1, 0.6, 0.4])
base = Exponential(lam=tf.ones([5, 2]))
dp = DirichletProcess(alpha, base)
print(dp)
3 changes: 1 addition & 2 deletions notebooks/iclr2017.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,7 @@
"N = 1000 # number of data points\n",
"D = 5 # data dimensionality\n",
"\n",
"dp = DirichletProcess(\n",
" alpha=1.0, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D))\n",
"dp = DirichletProcess(alpha=1.0, base=Normal(mu=tf.zeros(D), sigma=tf.ones(D)))\n",
"mu = dp.sample(N)\n",
"x = Normal(mu=mu, sigma=tf.ones([N, D]))"
]
Expand Down
41 changes: 20 additions & 21 deletions tests/test-models/test_dirichlet_process_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,46 +10,45 @@

class test_dirichletprocess_sample_class(tf.test.TestCase):

def _test(self, n, alpha, base_cls, *args, **kwargs):
x = DirichletProcess(alpha=alpha, base_cls=base_cls, *args, **kwargs)
base = base_cls(*args, **kwargs)
def _test(self, n, alpha, base):
x = DirichletProcess(alpha=alpha, base=base)
val_est = x.sample(n).shape.as_list()
val_true = n + tf.convert_to_tensor(alpha).shape.as_list() + \
tf.convert_to_tensor(base).shape.as_list()
self.assertEqual(val_est, val_true)

def test_alpha_0d_base_0d(self):
with self.test_session():
self._test([1], 0.5, Normal, mu=0.0, sigma=0.5)
self._test([5], tf.constant(0.5), Normal, mu=0.0, sigma=0.5)
self._test([1], 0.5, Normal(mu=0.0, sigma=0.5))
self._test([5], tf.constant(0.5), Normal(mu=0.0, sigma=0.5))

def test_alpha_1d_base0d(self):
with self.test_session():
self._test([1], np.array([0.5]), Normal, mu=0.0, sigma=0.5)
self._test([5], tf.constant([0.5]), Normal, mu=0.0, sigma=0.5)
self._test([1], tf.constant([0.2, 1.5]), Normal, mu=0.0, sigma=0.5)
self._test([5], tf.constant([0.2, 1.5]), Normal, mu=0.0, sigma=0.5)
self._test([1], np.array([0.5]), Normal(mu=0.0, sigma=0.5))
self._test([5], tf.constant([0.5]), Normal(mu=0.0, sigma=0.5))
self._test([1], tf.constant([0.2, 1.5]), Normal(mu=0.0, sigma=0.5))
self._test([5], tf.constant([0.2, 1.5]), Normal(mu=0.0, sigma=0.5))

def test_alpha_0d_base1d(self):
with self.test_session():
self._test([1], 0.5, Normal, mu=tf.zeros(3), sigma=tf.ones(3))
self._test([5], tf.constant(0.5), Normal,
mu=tf.zeros(3), sigma=tf.ones(3))
self._test([1], 0.5, Normal(mu=tf.zeros(3), sigma=tf.ones(3)))
self._test([5], tf.constant(0.5),
Normal(mu=tf.zeros(3), sigma=tf.ones(3)))

def test_alpha_1d_base2d(self):
with self.test_session():
self._test([1], np.array([0.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([5], tf.constant([0.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([1], tf.constant([0.2, 1.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([5], tf.constant([0.2, 1.5]), Normal,
mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))
self._test([1], np.array([0.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))
self._test([5], tf.constant([0.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))
self._test([1], tf.constant([0.2, 1.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))
self._test([5], tf.constant([0.2, 1.5]),
Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])))

def test_persistent_state(self):
with self.test_session() as sess:
dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0)
dp = DirichletProcess(0.1, Normal(mu=0.0, sigma=1.0))
x = dp.sample(5)
y = dp.sample(5)
x_data, y_data, theta = sess.run([x, y, dp.theta])
Expand Down