Skip to content

Commit

Permalink
add simple unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Mar 1, 2017
1 parent b93b193 commit 34277ab
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
2 changes: 1 addition & 1 deletion edward/inferences/implicit_klqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, latent_vars, data=None, discriminator=None,
necessary as the discriminator can take an arbitrary set of data,
latent, and global variables.
Note the type for ``discriminator``'s output change when one
Note the type for ``discriminator``'s output changes when one
passes in the ``scale`` argument to ``initialize()``.
+ If ``scale`` has at most one item, then ``discriminator``
Expand Down
47 changes: 47 additions & 0 deletions tests/test-inferences/test_implicitklqp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import tensorflow as tf

from edward.models import Normal
from tensorflow.contrib import slim


class test_implicit_klqp_class(tf.test.TestCase):

def test_normal_run(self):
def ratio_estimator(data, local_vars, global_vars):
input = tf.reshape(local_vars[z], [1, 1]) # reshape scalar as matrix
h1 = slim.fully_connected(input, 10, activation_fn=tf.nn.relu)
h2 = slim.fully_connected(h1, 1, activation_fn=None)
return h2

with self.test_session() as sess:
z = Normal(mu=5.0, sigma=1.0)

qz = Normal(mu=tf.Variable(tf.random_normal([])),
sigma=tf.nn.softplus(tf.Variable(tf.random_normal([]))))

inference = ed.ImplicitKLqp({z: qz}, discriminator=ratio_estimator)
# inference.run(n_iter=1000)
inference.initialize(n_iter=1000, n_print=100)

sess = ed.get_session()
tf.global_variables_initializer().run()

for _ in range(inference.n_iter):
info_dict = inference.update()
t = info_dict['t']
inference.print_progress(info_dict)
if t == 1 or t % inference.n_print == 0:
# Check inferred posterior parameters.
mean, std = sess.run([qz.mean(), qz.std()])
print("Inferred mean & std: {} {}".format(mean, std))

self.assertAllClose(qz.mean().eval(), 5.0, atol=1.0)

if __name__ == '__main__':
ed.set_seed(47324)
tf.test.main()

0 comments on commit 34277ab

Please sign in to comment.