Skip to content

Commit

Permalink
add DP example with base distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Mar 6, 2017
1 parent fae636e commit 45daeec
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/pp_dirichlet_process.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/env python
"""Dirichlet process.
We sample from a Dirichlet process (with no base distribution) by
using its stick breaking construction.
We sample from a Dirichlet process (with no base distribution) via its
stick breaking construction.
References
----------
Expand Down
86 changes: 86 additions & 0 deletions examples/pp_dirichlet_process_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/env python
"""Dirichlet process.
We sample from a Dirichlet process (with a user-specified base
distribution) via its stick-breaking construction.
References
----------
https://probmods.org/chapters/12-non-parametric-models.html#infinite-discrete-distributions-the-dirichlet-processes
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from edward.models import Bernoulli, Beta, Normal



def dirichlet_process(alpha, base_cls, sample_n=50, *args, **kwargs):
    """Draw from DP(``alpha``, ``base_cls(*args, **kwargs)``).

    ``sample_n`` draws are produced with a stochastic ``tf.while_loop``.
    Only works for scalar ``alpha`` and a scalar base distribution.

    Parameters
    ----------
    alpha : float
        Concentration parameter. It is used as the ``b`` parameter of
        every ``Beta`` draw.
    base_cls : RandomVariable
        Class of the base distribution. It is instantiated as
        ``base_cls(*args, **kwargs)`` once before the loop and once more
        in every loop iteration to obtain a new candidate value.
    sample_n : int, optional
        Number of samples to draw.
    *args, **kwargs : optional
        Arguments passed into ``base_cls``. Because ``sample_n`` precedes
        ``*args`` in the signature, positional ``args`` can only be
        supplied when ``sample_n`` is also given positionally.

    Returns
    -------
    total_sticks : tf.Tensor
        Final value of the loop counter ``k``, i.e. the number of times
        the loop body was executed.
    samples : tf.Tensor
        The draws. For a scalar base distribution this has shape
        ``[sample_n]``.
    """
    def cond(k, beta_k, draws, bools):
        # Proceed if at least one bool is True.
        return tf.reduce_any(bools)

    def body(k, beta_k, draws, bools):
        k = k + 1
        # Scale the running per-sample probability by a fresh Beta draw.
        beta_k = beta_k * Beta(a=1.0, b=alpha)
        # New candidate value from the base distribution.
        theta_k = base_cls(*args, **kwargs)

        # Assign ongoing samples to the new theta_k; samples whose flag
        # is False keep their current value.
        indicator = tf.cast(bools, draws.dtype)
        new = indicator * theta_k
        draws = draws * (1.0 - indicator) + new

        # A sample stays in play only if its coin lands heads and it
        # currently holds the newest theta_k.
        flips = tf.cast(Bernoulli(p=beta_k), tf.bool)
        bools = tf.logical_and(flips, tf.equal(draws, theta_k))
        return k, beta_k, draws, bools

    k = 0
    beta_k = Beta(a=tf.ones(sample_n), b=alpha * tf.ones(sample_n))
    theta_k = base_cls(*args, **kwargs)

    # Initialize all samples as theta_k.
    draws = tf.ones(sample_n) * theta_k
    # Flip ``sample_n`` coins, one for each sample.
    flips = tf.cast(Bernoulli(p=beta_k), tf.bool)
    # Get boolean tensor for samples that return heads
    # and are currently equal to theta_k.
    bools = tf.logical_and(flips, tf.equal(draws, theta_k))

    # Only the final counter and the final draws are returned; the last
    # ``beta_k`` and ``bools`` are discarded.
    total_sticks, _, samples, _ = tf.while_loop(
        cond, body, loop_vars=[k, beta_k, draws, bools])
    return total_sticks, samples


# Build the graph once: ``dp`` is the ``(total_sticks, samples)`` pair.
dp = dirichlet_process(0.1, Normal, mu=0.0, sigma=1.0)

# Use the session as a context manager so it is closed when done, and
# evaluate the same graph three times, printing each result.
with tf.Session() as sess:
    for _ in range(3):
        print(sess.run(dp))

0 comments on commit 45daeec

Please sign in to comment.