diff --git a/docs/tex/api/inference-classes.tex b/docs/tex/api/inference-classes.tex
index d2547f70b..e57957b13 100644
--- a/docs/tex/api/inference-classes.tex
+++ b/docs/tex/api/inference-classes.tex
@@ -161,6 +161,9 @@ \subsubsection{Exact Inference}
 .. autoclass:: edward.inferences.WGANInference
    :members:
 
+.. autoclass:: edward.inferences.ImplicitKLqp
+   :members:
+
 .. autoclass:: edward.inferences.KLpq
    :members:
 
diff --git a/edward/__init__.py b/edward/__init__.py
index bd33e8870..d71eab3ee 100644
--- a/edward/__init__.py
+++ b/edward/__init__.py
@@ -16,7 +16,7 @@
     HMC, MetropolisHastings, SGLD, SGHMC, \
     KLpq, KLqp, MFVI, ReparameterizationKLqp, ReparameterizationKLKLqp, \
     ReparameterizationEntropyKLqp, ScoreKLqp, ScoreKLKLqp, ScoreEntropyKLqp, \
-    GANInference, WGANInference, MAP, Laplace
+    GANInference, WGANInference, ImplicitKLqp, MAP, Laplace
 from edward.models import PyMC3Model, PythonModel, StanModel, \
     RandomVariable
 from edward.util import copy, dot, get_ancestors, get_children, \
diff --git a/edward/inferences/__init__.py b/edward/inferences/__init__.py
index 979c72af0..7577d15f2 100644
--- a/edward/inferences/__init__.py
+++ b/edward/inferences/__init__.py
@@ -4,6 +4,7 @@
 
 from edward.inferences.gan_inference import *
 from edward.inferences.hmc import *
+from edward.inferences.implicit_klqp import *
 from edward.inferences.inference import *
 from edward.inferences.klpq import *
 from edward.inferences.klqp import *
diff --git a/edward/inferences/implicit_klqp.py b/edward/inferences/implicit_klqp.py
new file mode 100644
index 000000000..491e03656
--- /dev/null
+++ b/edward/inferences/implicit_klqp.py
@@ -0,0 +1,237 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+import tensorflow as tf
+
+from edward.inferences.gan_inference import GANInference
+from edward.models import RandomVariable
+from edward.util import copy
+
+
+class ImplicitKLqp(GANInference):
+  """Variational inference with implicit probabilistic models.
+
+  It minimizes the KL divergence
+
+  .. math::
+
+    \\text{KL}( q(z, \\beta; \\lambda) \| p(z, \\beta \mid x) ),
+
+  where :math:`z` are local variables associated with a data point and
+  :math:`\\beta` are global variables shared across data points.
+
+  Global latent variables require ``log_prob()`` and need to return a
+  random sample when fetched from the graph. Local latent variables
+  and observed variables require only a random sample when fetched
+  from the graph. (This is true for both :math:`p` and :math:`q`.)
+
+  All variational factors must be reparameterizable: each of the
+  random variables (``rv``) satisfies ``rv.is_reparameterized`` and
+  ``rv.is_continuous``.
+  """
+  def __init__(self, latent_vars, data=None, discriminator=None,
+               global_vars=None):
+    """
+    Parameters
+    ----------
+    discriminator : function
+      Function (with parameters). Unlike ``GANInference``, it is
+      interpreted as a ratio estimator rather than a discriminator.
+      It takes three arguments: a data dict, a local latent variable
+      dict, and a global latent variable dict. As with GAN
+      discriminators, it can take a batch of data points and local
+      variables, of size :math:`M`, and output a vector of length
+      :math:`M`.
+    global_vars : dict of RandomVariable to RandomVariable, optional
+      Identifies which variables in ``latent_vars`` are global
+      variables, shared across data points.
+      These will not be included in the ratio estimation problem,
+      and will instead be estimated with tractable variational
+      approximations.
+
+    Notes
+    -----
+    Unlike ``GANInference``, ``discriminator`` takes dicts as input,
+    and must select the appropriate values through lexical scoping
+    from the previously defined model and latent variables. This is
+    necessary as the discriminator can take an arbitrary set of data,
+    latent, and global variables.
+
+    Note that the type of ``discriminator``'s output changes when the
+    ``scale`` argument is passed to ``initialize()``.
+
+    + If ``scale`` has at most one item, then ``discriminator``
+      outputs a tensor whose multiplication with that element is
+      broadcastable. (For example, the output is a tensor and the
+      single scale factor is a scalar.)
+    + If ``scale`` has more than one item, then in order to scale
+      its corresponding output, ``discriminator`` must output a
+      dictionary of the same size and with the same keys as ``scale``.
+    """
+    if discriminator is None:
+      raise NotImplementedError()
+
+    if global_vars is None:
+      global_vars = {}
+    elif not isinstance(global_vars, dict):
+      raise TypeError()
+
+    self.discriminator = discriminator
+    self.global_vars = global_vars
+    # Call the grandparent's method; avoid the parent (GANInference).
+    super(GANInference, self).__init__(latent_vars, data, model_wrapper=None)
+
+  def initialize(self, ratio_loss='log', *args, **kwargs):
+    """Initialization.
+
+    Parameters
+    ----------
+    ratio_loss : str or fn, optional
+      Loss function minimized to estimate the ratio: 'log' or 'hinge'.
+      Alternatively, one can pass in a function of two inputs,
+      ``psamples`` and ``qsamples``, which outputs a point-wise value
+      with shape matching the shapes of the two inputs.
+    """
+    if callable(ratio_loss):
+      self.ratio_loss = ratio_loss
+    elif ratio_loss == 'log':
+      self.ratio_loss = log_loss
+    elif ratio_loss == 'hinge':
+      self.ratio_loss = hinge_loss
+    else:
+      raise ValueError('Ratio loss not found:', ratio_loss)
+
+    return super(ImplicitKLqp, self).initialize(*args, **kwargs)
+
+  def build_loss_and_gradients(self, var_list):
+    """Build the loss function
+
+    .. math::
+
+      -[ \mathbb{E}_{q(\\beta)} [ \log p(\\beta) - \log q(\\beta) ] +
+      \sum_{n=1}^N \mathbb{E}_{q(\\beta)q(z_n|\\beta)}
+      [ r^*(x_n, z_n, \\beta) ] ]
+
+    We minimize it with respect to the parameterized variational
+    family :math:`q(z, \\beta; \\lambda)`.
+
+    :math:`r^*(x_n, z_n, \\beta)` is a function of a single data point
+    :math:`x_n`, a single local variable :math:`z_n`, and all global
+    variables :math:`\\beta`. It is equal to the log-ratio
+
+    .. math::
+
+      \log p(x_n, z_n | \\beta) - \log q(z_n | \\beta).
+
+    Rather than being calculated explicitly, :math:`r^*(x, z, \\beta)`
+    is the solution to a ratio estimation problem, minimizing the
+    specified ``ratio_loss``.
+
+    Gradients are taken using the reparameterization trick (Kingma and
+    Welling, 2014).
+
+    Notes
+    -----
+    This approach also supports model parameters
+    :math:`p(x, z, \\beta; \\theta)` and variational distributions with
+    inference networks :math:`q(z | x)`.
+
+    There are several extensions that could easily be supported in this
+    implementation:
+
+    + further factorizations can be used to better leverage the
+      graph structure for more complicated models;
+    + score function gradients for global variables;
+    + use of more samples; this would require the ``copy()`` utility
+      function for q's as well, and an additional loop. We opt not to
+      because it complicates the code;
+    + analytic KL terms, i.e., swapping out the penalty term for the
+      global variables.
+    """
+    # Collect tensors used in calculation of losses.
+    scope = 'inference_' + str(id(self))
+    qbeta_sample = {}
+    pbeta_log_prob = 0.0
+    qbeta_log_prob = 0.0
+    for beta, qbeta in six.iteritems(self.global_vars):
+      # Draw a sample beta' ~ q(beta) and calculate
+      # log p(beta') and log q(beta').
+      qbeta_sample[beta] = qbeta.value()
+      pbeta_log_prob += tf.reduce_sum(beta.log_prob(qbeta_sample[beta]))
+      qbeta_log_prob += tf.reduce_sum(qbeta.log_prob(qbeta_sample[beta]))
+
+    pz_sample = {}
+    qz_sample = {}
+    for z, qz in six.iteritems(self.latent_vars):
+      if z not in self.global_vars:
+        # Copy local variables p(z), q(z) to draw samples
+        # z' ~ p(z | beta'), z' ~ q(z | beta').
+        pz_copy = copy(z, dict_swap=qbeta_sample, scope=scope)
+        pz_sample[z] = pz_copy.value()
+        qz_sample[z] = qz.value()
+
+    # Collect x' ~ p(x | z', beta') and x' ~ q(x).
+    dict_swap = qbeta_sample.copy()
+    dict_swap.update(qz_sample)
+    x_psample = {}
+    x_qsample = {}
+    for x, x_data in six.iteritems(self.data):
+      if isinstance(x, tf.Tensor):
+        if "Placeholder" not in x.op.type:
+          # Copy p(x | z, beta) to get a draw from p(x | z', beta').
+          x_copy = copy(x, dict_swap=dict_swap, scope=scope)
+          x_psample[x] = x_copy
+          x_qsample[x] = x_data
+      elif isinstance(x, RandomVariable):
+        # Copy p(x | z, beta) to get a draw from p(x | z', beta').
+        x_copy = copy(x, dict_swap=dict_swap, scope=scope)
+        x_psample[x] = x_copy.value()
+        x_qsample[x] = x_data
+
+    with tf.variable_scope("Disc"):
+      r_psample = self.discriminator(x_psample, pz_sample, qbeta_sample)
+
+    with tf.variable_scope("Disc", reuse=True):
+      r_qsample = self.discriminator(x_qsample, qz_sample, qbeta_sample)
+
+    # Form ratio loss and ratio estimator.
+    if len(self.scale) <= 1:
+      loss_d = tf.reduce_mean(self.ratio_loss(r_psample, r_qsample))
+      scale = list(six.itervalues(self.scale))
+      scale = scale[0] if scale else 1.0
+      scaled_ratio = tf.reduce_sum(scale * r_qsample)
+    else:
+      loss_d = [tf.reduce_mean(self.ratio_loss(r_psample[key], r_qsample[key]))
+                for key in six.iterkeys(self.scale)]
+      loss_d = tf.reduce_sum(loss_d)
+      scaled_ratio = [tf.reduce_sum(self.scale[key] * r_qsample[key])
+                      for key in six.iterkeys(self.scale)]
+      scaled_ratio = tf.reduce_sum(scaled_ratio)
+
+    # Form variational objective.
+    loss = -(pbeta_log_prob - qbeta_log_prob + scaled_ratio)
+
+    var_list_d = tf.get_collection(
+        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
+    if var_list is None:
+      var_list = [v for v in tf.trainable_variables() if v not in var_list_d]
+
+    grads = tf.gradients(loss, var_list)
+    grads_d = tf.gradients(loss_d, var_list_d)
+    grads_and_vars = list(zip(grads, var_list))
+    grads_and_vars_d = list(zip(grads_d, var_list_d))
+    return loss, grads_and_vars, loss_d, grads_and_vars_d
+
+
+def log_loss(psample, qsample):
+  """Point-wise log loss."""
+  loss = tf.nn.sigmoid_cross_entropy_with_logits(
+      labels=tf.ones_like(psample), logits=psample) + \
+      tf.nn.sigmoid_cross_entropy_with_logits(
+          labels=tf.zeros_like(qsample), logits=qsample)
+  return loss
+
+
+def hinge_loss(psample, qsample):
+  """Point-wise hinge loss."""
+  loss = tf.nn.relu(1.0 - psample) + tf.nn.relu(1.0 + qsample)
+  return loss
diff --git a/examples/bayesian_linear_regression_implicitklqp.py b/examples/bayesian_linear_regression_implicitklqp.py
new file mode 100644
index 000000000..d49425211
--- /dev/null
+++ b/examples/bayesian_linear_regression_implicitklqp.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+"""Bayesian linear regression. Inference uses data subsampling and
+scales the log-likelihood.
+
+One local optimum is an inferred posterior mean of about [-5.0 5.0].
+This suggests a symmetry in the objective; the result can be
+reproduced by initializing the first coordinate to be negative. The
+same holds for the second coordinate.
+
+Note that, as with all GAN-style training, the algorithm is not stable.
+It is recommended to monitor training and halt manually according to
+some criterion (e.g., prediction accuracy on a validation set, quality
+of samples).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import edward as ed
+import numpy as np
+import tensorflow as tf
+
+from edward.models import Normal
+from tensorflow.contrib import slim
+
+
+def build_toy_dataset(N, w, noise_std=0.1):
+  D = len(w)
+  x = np.random.randn(N, D).astype(np.float32)
+  y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
+  return x, y
+
+
+def ratio_estimator(data, local_vars, global_vars):
+  """Takes as input a dict of data x, local variable samples z, and
+  global variable samples beta; outputs real values of shape
+  (x.shape[0] + z.shape[0],). In this example, there are no local
+  variables.
+  """
+  # data[y] has shape (M,); global_vars[w] has shape (D,).
+  # We concatenate w to each data point y; the input has shape (M, 1 + D).
+  input = tf.concat([
+      tf.reshape(data[y], [M, 1]),
+      tf.tile(tf.reshape(global_vars[w], [1, D]), [M, 1])], 1)
+  hidden = slim.fully_connected(input, 64, activation_fn=tf.nn.relu)
+  output = slim.fully_connected(hidden, 1, activation_fn=None)
+  return output
+
+
+def next_batch(size, i):
+  diff = (i + 1) * size - X_train.shape[0]
+  if diff <= 0:
+    X_batch = X_train[(i * size):((i + 1) * size), :]
+    y_batch = y_train[(i * size):((i + 1) * size)]
+    i = i + 1
+  else:
+    X_batch = np.concatenate((X_train[(i * size):, :], X_train[:diff, :]))
+    y_batch = np.concatenate((y_train[(i * size):], y_train[:diff]))
+    i = 0
+
+  return X_batch, y_batch, i
+
+
+ed.set_seed(42)
+
+N = 500  # number of data points
+M = 50  # batch size during training
+D = 2  # number of features
+
+# DATA
+w_true = np.ones(D) * 5.0
+X_train, y_train = build_toy_dataset(N, w_true)
+X_test, y_test = build_toy_dataset(N, w_true)
+
+# MODEL
+X = tf.placeholder(tf.float32, [M, D])
+y_ph = tf.placeholder(tf.float32, [M])
+w = Normal(mu=tf.zeros(D), sigma=tf.ones(D))
+y = Normal(mu=ed.dot(X, w), sigma=tf.ones(M))
+
+# INFERENCE
+qw = Normal(mu=tf.Variable(tf.random_normal([D]) + 1.0),
+            sigma=tf.nn.softplus(tf.Variable(tf.random_normal([D]))))
+
+inference = ed.ImplicitKLqp(
+    {w: qw}, data={y: y_ph},
+    discriminator=ratio_estimator, global_vars={w: qw})
+inference.initialize(n_iter=5000, n_print=100, scale={y: float(N) / M})
+
+sess = ed.get_session()
+tf.global_variables_initializer().run()
+
+i = 0
+for _ in range(inference.n_iter):
+  X_batch, y_batch, i = next_batch(M, i)
+  for _ in range(5):
+    info_dict_d = inference.update(
+        variables="Disc", feed_dict={X: X_batch, y_ph: y_batch})
+
+  info_dict = inference.update(
+      variables="Gen", feed_dict={X: X_batch, y_ph: y_batch})
+  info_dict['loss_d'] = info_dict_d['loss_d']
+  info_dict['t'] = info_dict['t'] // 6  # count 6 updates as 1 iteration
+
+  t = info_dict['t']
+  inference.print_progress(info_dict)
+  if t == 1 or t % inference.n_print == 0:
+    # Check inferred posterior parameters.
+ mean, std = sess.run([qw.mean(), qw.std()]) + print("Inferred mean & std") + print(mean) + print(std) diff --git a/tests/test-inferences/test_implicitklqp.py b/tests/test-inferences/test_implicitklqp.py new file mode 100644 index 000000000..38d5997d1 --- /dev/null +++ b/tests/test-inferences/test_implicitklqp.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import edward as ed +import tensorflow as tf + +from edward.models import Normal + + +class test_implicit_klqp_class(tf.test.TestCase): + + def test_normal_run(self): + def ratio_estimator(data, local_vars, global_vars): + """Use the optimal ratio estimator, r(z) = log p(z). We add a + TensorFlow variable as the algorithm assumes that the function + has parameters to optimize.""" + w = tf.get_variable("w", []) + return z.log_prob(local_vars[z]) + w + + with self.test_session() as sess: + z = Normal(mu=5.0, sigma=1.0) + + qz = Normal(mu=tf.Variable(tf.random_normal([])), + sigma=tf.nn.softplus(tf.Variable(tf.random_normal([])))) + + inference = ed.ImplicitKLqp({z: qz}, discriminator=ratio_estimator) + inference.run(n_iter=200) + + self.assertAllClose(qz.mean().eval(), 5.0, atol=1.0) + +if __name__ == '__main__': + ed.set_seed(47324) + tf.test.main()
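
Usage sketch for the ``ratio_loss`` hook described in ``initialize()``: besides the built-in 'log' and 'hinge' options, any point-wise function of the ratio estimator's outputs under p- and q-samples can be passed. The snippet below is not part of this diff and reuses the model, variational approximation, and ratio estimator defined in the linear regression example above; it supplies a softplus form that is mathematically identical to the built-in log loss, since sigmoid cross entropy with label 1 equals softplus(-logit) and with label 0 equals softplus(logit).

import edward as ed
import tensorflow as tf


def softplus_log_loss(psample, qsample):
  """Point-wise ratio loss, equivalent to the built-in 'log' loss."""
  # softplus(-r) penalizes small ratio values on samples from p;
  # softplus(r) penalizes large ratio values on samples from q.
  return tf.nn.softplus(-psample) + tf.nn.softplus(qsample)


# w, qw, y, y_ph, ratio_estimator, N, and M are taken from
# examples/bayesian_linear_regression_implicitklqp.py above.
inference = ed.ImplicitKLqp(
    {w: qw}, data={y: y_ph},
    discriminator=ratio_estimator, global_vars={w: qw})
inference.initialize(ratio_loss=softplus_log_loss,
                     n_iter=5000, n_print=100, scale={y: float(N) / M})

When ``scale`` has more than one item, the same hook applies unchanged; only the ratio estimator changes, returning a dict keyed like ``scale`` as described in the class docstring.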