-
Notifications
You must be signed in to change notification settings - Fork 759
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add DP example with base distribution
- Loading branch information
1 parent
fae636e
commit 45daeec
Showing
2 changed files
with
88 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#!/usr/bin/env python | ||
"""Dirichlet process. | ||
We sample from a Dirichlet process (with inputted base distribution) | ||
via its stick breaking construction. | ||
References | ||
---------- | ||
https://probmods.org/chapters/12-non-parametric-models.html#infinite-discrete-distributions-the-dirichlet-processes | ||
""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
|
||
from edward.models import Bernoulli, Beta, Normal | ||
|
||
|
||
|
||
def dirichlet_process(alpha, base_cls, sample_n=50, *args, **kwargs): | ||
"""A Dirichlet process DP(``alpha``, ``base_cls(*args, **kwargs)``) | ||
with ``sample_n`` draws. | ||
Only works for scalar alpha and scalar base distribution. | ||
Parameters | ||
---------- | ||
alpha : float | ||
Concentration parameter, with same batch shape as the base distribution. | ||
base_cls : RandomVariable | ||
Class of Base distribution, whose instantiation has the same batch | ||
shape as ``alpha``. | ||
sample_n : int, optional | ||
Number of samples for each DP in the batch shape. | ||
*args, **kwargs : optional | ||
Arguments passed into ``base_cls``. | ||
Returns | ||
------- | ||
tf.Tensor | ||
A ``tf.Tensor`` of shape ``[sample_n] + batch_shape + event_shape``, | ||
where ``sample_n`` is the number of samples, ``batch_shape`` is the | ||
number of independent DPs, and ``event_shape`` is the ``event_shape`` | ||
of ``base_cls``. | ||
""" | ||
def cond(k, beta_k, draws, bools): | ||
# Proceed if at least one bool is True. | ||
return tf.reduce_any(bools) | ||
|
||
def body(k, beta_k, draws, bools): | ||
k = k + 1 | ||
beta_k = beta_k * Beta(a=1.0, b=alpha) | ||
theta_k = base_cls(*args, **kwargs) | ||
|
||
# Assign ongoing samples to the new theta_k. | ||
indicator = tf.cast(bools, draws.dtype) | ||
new = indicator * theta_k | ||
draws = draws * (1.0 - indicator) + new | ||
|
||
flips = tf.cast(Bernoulli(p=beta_k), tf.bool) | ||
bools = tf.logical_and(flips, tf.equal(draws, theta_k)) | ||
return k, beta_k, draws, bools | ||
|
||
k = 0 | ||
beta_k = Beta(a=tf.ones(sample_n), b=alpha * tf.ones(sample_n)) | ||
theta_k = base_cls(*args, **kwargs) | ||
|
||
# Initialize all samples as theta_k. | ||
draws = tf.ones(sample_n) * theta_k | ||
# Flip ``sample_n`` coins, one for each sample. | ||
flips = tf.cast(Bernoulli(p=beta_k), tf.bool) | ||
# Get boolean tensor for samples that return heads | ||
# and are currently equal to theta_k. | ||
bools = tf.logical_and(flips, tf.equal(draws, theta_k)) | ||
|
||
total_sticks, _, samples, _ = tf.while_loop( | ||
cond, body, loop_vars=[k, beta_k, draws, bools]) | ||
return total_sticks, samples | ||
|
||
|
||
dp = dirichlet_process(0.1, Normal, mu=0.0, sigma=1.0) | ||
sess = tf.Session() | ||
print(sess.run(dp)) | ||
print(sess.run(dp)) | ||
print(sess.run(dp)) |