diff --git a/docs/tex/iclr2017.tex b/docs/tex/iclr2017.tex index 0fac2bb06..61380058c 100644 --- a/docs/tex/iclr2017.tex +++ b/docs/tex/iclr2017.tex @@ -333,8 +333,7 @@ \subsubsection{Appendix A. Model Examples} N = 1000 # number of data points D = 5 # data dimensionality -dp = DirichletProcess( - alpha=1.0, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D)) +dp = DirichletProcess(alpha=1.0, base=Normal(mu=tf.zeros(D), sigma=tf.ones(D))) mu = dp.sample(N) x = Normal(mu=mu, sigma=tf.ones([N, D])) \end{lstlisting} diff --git a/edward/models/dirichlet_process.py b/edward/models/dirichlet_process.py index a73779821..b2a95fb75 100644 --- a/edward/models/dirichlet_process.py +++ b/edward/models/dirichlet_process.py @@ -14,47 +14,48 @@ class DirichletProcess(RandomVariable, Distribution): - def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True, - name="DirichletProcess", value=None, *args, **kwargs): - """Dirichlet process :math:`\mathcal{DP}(\\alpha, H)`. + """Dirichlet process :math:`\mathcal{DP}(\\alpha, H)`. - It has two parameters: a positive real value :math:`\\alpha`, - known as the concentration parameter (``alpha``), and a base - distribution :math:`H` (``base_cls(*args, **kwargs)``). + It has two parameters: a positive real value :math:`\\alpha`, + known as the concentration parameter (``alpha``), and a base + distribution :math:`H` (``base``). + """ + def __init__(self, alpha, base, validate_args=False, allow_nan_stats=True, + name="DirichletProcess", *args, **kwargs): + """Initialize a batch of Dirichlet processes. Parameters ---------- alpha : tf.Tensor Concentration parameter. Must be positive real-valued. Its shape determines the number of independent DPs (batch shape). - base_cls : RandomVariable - Class of base distribution. Its shape (when instantiated) - determines the shape of an individual DP (event shape). - *args, **kwargs : optional - Arguments passed into ``base_cls``. + base : RandomVariable + Base distribution. Its shape determines the shape of an + individual DP (event shape). Examples -------- >>> # scalar concentration parameter, scalar base distribution - >>> dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0) + >>> dp = DirichletProcess(0.1, Normal(mu=0.0, sigma=1.0)) >>> assert dp.shape == () >>> >>> # vector of concentration parameters, matrix of Exponentials >>> dp = DirichletProcess(tf.constant([0.1, 0.4]), - ... Exponential, lam=tf.ones([5, 3])) + ... Exponential(lam=tf.ones([5, 3]))) >>> assert dp.shape == (2, 5, 3) """ parameters = locals() parameters.pop("self") - with tf.name_scope(name, values=[alpha]) as ns: - with tf.control_dependencies([]): + with tf.name_scope(name, values=[alpha]): + with tf.control_dependencies([ + tf.assert_positive(alpha), + ] if validate_args else []): + if validate_args and isinstance(base, RandomVariable): + raise TypeError("base must be a ed.RandomVariable object.") + self._alpha = tf.identity(alpha, name="alpha") - self._base_cls = base_cls - self._base_args = args - self._base_kwargs = kwargs + self._base = base - # Instantiate base distribution. - self._base = self._base_cls(*self._base_args, **self._base_kwargs) # Create empty tensor to store future atoms. self._theta = tf.zeros( [0] + @@ -63,28 +64,33 @@ def __init__(self, alpha, base_cls, validate_args=False, allow_nan_stats=True, dtype=self._base.dtype) # Instantiate beta distribution for stick breaking proportions. - self._betadist = Beta(a=tf.ones_like(self.alpha), b=self.alpha) + self._betadist = Beta(a=tf.ones_like(self._alpha), b=self._alpha) # Create empty tensor to store stick breaking proportions. self._beta = tf.zeros( [0] + self.get_batch_shape().as_list(), dtype=self._betadist.dtype) - super(DirichletProcess, self).__init__( - dtype=tf.int32, - is_continuous=False, - is_reparameterized=False, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - parameters=parameters, - graph_parents=[self._alpha, self._beta, self._theta], - name=ns, - value=value) + super(DirichletProcess, self).__init__( + dtype=tf.int32, + is_continuous=False, + is_reparameterized=False, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._alpha, self._beta, self._theta], + name=name, + *args, **kwargs) @property def alpha(self): """Concentration parameter.""" return self._alpha + @property + def base(self): + """Base distribution used for drawing the atoms.""" + return self._base + @property def beta(self): """Stick breaking proportions. It has shape [None] + batch_shape, where @@ -106,10 +112,10 @@ def _get_batch_shape(self): return self.alpha.shape def _event_shape(self): - return tf.shape(self._base) + return tf.shape(self.base) def _get_event_shape(self): - return self._base.shape + return self.base.shape def _sample_n(self, n, seed=None): """Sample ``n`` draws from the DP. Draws from the base @@ -154,7 +160,7 @@ def _sample_n(self, n, seed=None): bools = tf.ones([n] + batch_shape, dtype=tf.bool) # Initialize all samples as zero, they will be overwritten in any case - draws = tf.zeros([n] + batch_shape + event_shape, dtype=self._base.dtype) + draws = tf.zeros([n] + batch_shape + event_shape, dtype=self.base.dtype) # Calculate shape invariance conditions for theta and beta as these # can change shape between loop iterations. @@ -187,7 +193,7 @@ def _sample_n_body(self, k, bools, theta, beta, draws): lambda: (theta, beta), lambda: ( tf.concat( - [theta, tf.expand_dims(self._base.sample(batch_shape), 0)], 0), + [theta, tf.expand_dims(self.base.sample(batch_shape), 0)], 0), tf.concat( [beta, tf.expand_dims(self._betadist.sample(), 0)], 0))) theta_k = tf.gather(theta, k) diff --git a/edward/models/empirical.py b/edward/models/empirical.py index 5a741cc12..1e504580f 100644 --- a/edward/models/empirical.py +++ b/edward/models/empirical.py @@ -12,9 +12,27 @@ class Empirical(RandomVariable, Distribution): """Empirical random variable.""" def __init__(self, params, validate_args=False, allow_nan_stats=True, name="Empirical", *args, **kwargs): + """Initialize an ``Empirical`` random variable. + + Parameters + ---------- + params : tf.Tensor + Collection of samples. Its outer (left-most) dimension + determines the number of samples. + + Examples + -------- + >>> # 100 samples of a scalar + >>> x = Empirical(params=tf.zeros(100)) + >>> assert x.shape == () + >>> + >>> # 5 samples of a 2 x 3 matrix + >>> dp = Empirical(params=tf.zeros([5, 2, 3])) + >>> assert x.shape == (2, 3) + """ parameters = locals() parameters.pop("self") - with tf.name_scope(name, values=[params]) as ns: + with tf.name_scope(name, values=[params]): with tf.control_dependencies([]): self._params = tf.identity(params, name="params") try: @@ -22,16 +40,16 @@ def __init__(self, params, validate_args=False, allow_nan_stats=True, except ValueError: # scalar params self._n = tf.constant(1) - super(Empirical, self).__init__( - dtype=self._params.dtype, - is_continuous=False, - is_reparameterized=True, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - parameters=parameters, - graph_parents=[self._params, self._n], - name=ns, - *args, **kwargs) + super(Empirical, self).__init__( + dtype=self._params.dtype, + is_continuous=False, + is_reparameterized=True, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._params, self._n], + name=name, + *args, **kwargs) @staticmethod def _param_shapes(sample_shape): diff --git a/edward/models/point_mass.py b/edward/models/point_mass.py index f3197b2ab..61400369e 100644 --- a/edward/models/point_mass.py +++ b/edward/models/point_mass.py @@ -16,21 +16,39 @@ class PointMass(RandomVariable, Distribution): """ def __init__(self, params, validate_args=False, allow_nan_stats=True, name="PointMass", *args, **kwargs): + """Initialize a ``PointMass`` random variable. + + Parameters + ---------- + params : tf.Tensor + The location with all probability mass. + + Examples + -------- + >>> # scalar + >>> x = PointMass(params=28.3) + >>> assert x.shape == () + >>> + >>> # 5 x 2 x 3 tensor + >>> dp = PointMass(params=tf.zeros([5, 2, 3])) + >>> assert x.shape == (5, 2, 3) + """ parameters = locals() parameters.pop("self") - with tf.name_scope(name, values=[params]) as ns: + with tf.name_scope(name, values=[params]): with tf.control_dependencies([]): self._params = tf.identity(params, name="params") - super(PointMass, self).__init__( - dtype=self._params.dtype, - is_continuous=False, - is_reparameterized=True, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - parameters=parameters, - graph_parents=[self._params], - name=ns, - *args, **kwargs) + + super(PointMass, self).__init__( + dtype=self._params.dtype, + is_continuous=False, + is_reparameterized=True, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._params], + name=name, + *args, **kwargs) @staticmethod def _param_shapes(sample_shape): diff --git a/examples/pp_dirichlet_process.py b/examples/pp_dirichlet_process.py index f3df5b798..4de141e3a 100644 --- a/examples/pp_dirichlet_process.py +++ b/examples/pp_dirichlet_process.py @@ -46,12 +46,11 @@ def body(k, beta_k): print(sess.run(dp)) # Demo of the DirichletProcess random variable in Edward. -base_cls = Normal -kwargs = {'mu': 0.0, 'sigma': 1.0} +base = Normal(mu=0.0, sigma=1.0) # Highly concentrated DP. alpha = 1.0 -dp = DirichletProcess(alpha, base_cls, **kwargs) +dp = DirichletProcess(alpha, base) x = dp.sample(1000) samples = sess.run(x) plt.hist(samples, bins=100, range=(-3.0, 3.0)) @@ -60,7 +59,7 @@ def body(k, beta_k): # More spread out DP. alpha = 50.0 -dp = DirichletProcess(alpha, base_cls, **kwargs) +dp = DirichletProcess(alpha, base) x = dp.sample(1000) samples = sess.run(x) plt.hist(samples, bins=100, range=(-3.0, 3.0)) @@ -69,7 +68,7 @@ def body(k, beta_k): # States persist across calls to sample() in a DP. alpha = 1.0 -dp = DirichletProcess(alpha, base_cls, **kwargs) +dp = DirichletProcess(alpha, base) x = dp.sample(50) y = dp.sample(75) samples_x, samples_y = sess.run([x, y]) @@ -82,13 +81,13 @@ def body(k, beta_k): # ``theta`` is the distribution indirectly returned by the DP. # Fetching theta is the same as fetching the Dirichlet process. -dp = DirichletProcess(alpha, base_cls, **kwargs) -theta = base_cls(value=tf.cast(dp, tf.float32), **kwargs) +dp = DirichletProcess(alpha, base) +theta = Normal(0.0, 1.0, value=tf.cast(dp, tf.float32)) print(sess.run([dp, theta])) print(sess.run([dp, theta])) # DirichletProcess can also take in non-scalar concentrations and bases. -base_cls = Exponential -kwargs = {'lam': tf.ones([5, 2])} -dp = DirichletProcess(tf.constant([0.1, 0.6, 0.4]), base_cls, **kwargs) +alpha = tf.constant([0.1, 0.6, 0.4]) +base = Exponential(lam=tf.ones([5, 2])) +dp = DirichletProcess(alpha, base) print(dp) diff --git a/notebooks/iclr2017.ipynb b/notebooks/iclr2017.ipynb index 46f2964d0..3e735f5fd 100644 --- a/notebooks/iclr2017.ipynb +++ b/notebooks/iclr2017.ipynb @@ -518,8 +518,7 @@ "N = 1000 # number of data points\n", "D = 5 # data dimensionality\n", "\n", - "dp = DirichletProcess(\n", - " alpha=1.0, base_cls=Normal, mu=tf.zeros(D), sigma=tf.ones(D))\n", + "dp = DirichletProcess(alpha=1.0, base=Normal(mu=tf.zeros(D), sigma=tf.ones(D)))\n", "mu = dp.sample(N)\n", "x = Normal(mu=mu, sigma=tf.ones([N, D]))" ] diff --git a/tests/test-models/test_dirichlet_process_sample.py b/tests/test-models/test_dirichlet_process_sample.py index 8f7912971..870bd44a0 100644 --- a/tests/test-models/test_dirichlet_process_sample.py +++ b/tests/test-models/test_dirichlet_process_sample.py @@ -10,9 +10,8 @@ class test_dirichletprocess_sample_class(tf.test.TestCase): - def _test(self, n, alpha, base_cls, *args, **kwargs): - x = DirichletProcess(alpha=alpha, base_cls=base_cls, *args, **kwargs) - base = base_cls(*args, **kwargs) + def _test(self, n, alpha, base): + x = DirichletProcess(alpha=alpha, base=base) val_est = x.sample(n).shape.as_list() val_true = n + tf.convert_to_tensor(alpha).shape.as_list() + \ tf.convert_to_tensor(base).shape.as_list() @@ -20,36 +19,36 @@ def _test(self, n, alpha, base_cls, *args, **kwargs): def test_alpha_0d_base_0d(self): with self.test_session(): - self._test([1], 0.5, Normal, mu=0.0, sigma=0.5) - self._test([5], tf.constant(0.5), Normal, mu=0.0, sigma=0.5) + self._test([1], 0.5, Normal(mu=0.0, sigma=0.5)) + self._test([5], tf.constant(0.5), Normal(mu=0.0, sigma=0.5)) def test_alpha_1d_base0d(self): with self.test_session(): - self._test([1], np.array([0.5]), Normal, mu=0.0, sigma=0.5) - self._test([5], tf.constant([0.5]), Normal, mu=0.0, sigma=0.5) - self._test([1], tf.constant([0.2, 1.5]), Normal, mu=0.0, sigma=0.5) - self._test([5], tf.constant([0.2, 1.5]), Normal, mu=0.0, sigma=0.5) + self._test([1], np.array([0.5]), Normal(mu=0.0, sigma=0.5)) + self._test([5], tf.constant([0.5]), Normal(mu=0.0, sigma=0.5)) + self._test([1], tf.constant([0.2, 1.5]), Normal(mu=0.0, sigma=0.5)) + self._test([5], tf.constant([0.2, 1.5]), Normal(mu=0.0, sigma=0.5)) def test_alpha_0d_base1d(self): with self.test_session(): - self._test([1], 0.5, Normal, mu=tf.zeros(3), sigma=tf.ones(3)) - self._test([5], tf.constant(0.5), Normal, - mu=tf.zeros(3), sigma=tf.ones(3)) + self._test([1], 0.5, Normal(mu=tf.zeros(3), sigma=tf.ones(3))) + self._test([5], tf.constant(0.5), + Normal(mu=tf.zeros(3), sigma=tf.ones(3))) def test_alpha_1d_base2d(self): with self.test_session(): - self._test([1], np.array([0.5]), Normal, - mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])) - self._test([5], tf.constant([0.5]), Normal, - mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])) - self._test([1], tf.constant([0.2, 1.5]), Normal, - mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])) - self._test([5], tf.constant([0.2, 1.5]), Normal, - mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4])) + self._test([1], np.array([0.5]), + Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))) + self._test([5], tf.constant([0.5]), + Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))) + self._test([1], tf.constant([0.2, 1.5]), + Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))) + self._test([5], tf.constant([0.2, 1.5]), + Normal(mu=tf.zeros([3, 4]), sigma=tf.ones([3, 4]))) def test_persistent_state(self): with self.test_session() as sess: - dp = DirichletProcess(0.1, Normal, mu=0.0, sigma=1.0) + dp = DirichletProcess(0.1, Normal(mu=0.0, sigma=1.0)) x = dp.sample(5) y = dp.sample(5) x_data, y_data, theta = sess.run([x, y, dp.theta])