Skip to content

Commit

Permalink
For private vars needed for testing, I made them private and added ac…
Browse files Browse the repository at this point in the history
…cessors. Java doesn’t understand package-private tags, so this minimizes the issues Java users might encounter.

Change miniBatchFraction default to 0.05 to match maxIterations.

Added a little doc.

Changed end of main online LDA update code to avoid the kron() call.  Please confirm if you agree that should be more efficient (not explicitly instantiating a big matrix).

Changed Gamma() to use random seed.

Scala style updates
  • Loading branch information
jkbradley committed May 2, 2015
1 parent 6149ca6 commit cf376ff
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.Random

import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
import breeze.numerics.{digamma, exp, abs}
import breeze.stats.distributions.Gamma
import breeze.stats.distributions.{Gamma, RandBasis}

import org.apache.spark.annotation.Experimental
import org.apache.spark.graphx._
Expand Down Expand Up @@ -227,20 +227,37 @@ class OnlineLDAOptimizer extends LDAOptimizer {
private var k: Int = 0
private var corpusSize: Long = 0
private var vocabSize: Int = 0
private[clustering] var alpha: Double = 0
private[clustering] var eta: Double = 0

/** alias for docConcentration */
private var alpha: Double = 0

/** (private[clustering] for debugging) Get docConcentration */
private[clustering] def getAlpha: Double = alpha

/** alias for topicConcentration */
private var eta: Double = 0

/** (private[clustering] for debugging) Get topicConcentration */
private[clustering] def getEta: Double = eta

private var randomGenerator: java.util.Random = null

// Online LDA specific parameters
// Learning rate is: (tau_0 + t)^{-kappa}
private var tau_0: Double = 1024
private var kappa: Double = 0.51
private var miniBatchFraction: Double = 0.01
private var miniBatchFraction: Double = 0.05

// internal data structure
private var docs: RDD[(Long, Vector)] = null
private[clustering] var lambda: BDM[Double] = null

// count of invocation to next, which helps deciding the weight for each iteration
/** Dirichlet parameter for the posterior over topics */
private var lambda: BDM[Double] = null

/** (private[clustering] for debugging) Get parameter for topics */
private[clustering] def getLambda: BDM[Double] = lambda

/** Current iteration (count of invocations of [[next()]]) */
private var iteration: Int = 0
private var gammaShape: Double = 100

Expand Down Expand Up @@ -285,7 +302,12 @@ class OnlineLDAOptimizer extends LDAOptimizer {
/**
* Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in
* each iteration.
* Default: 0.01, i.e., 1% of total documents
*
* Note that this should be adjusted in synch with [[LDA.setMaxIterations()]]
* so the entire corpus is used. Specifically, set both so that
* maxIterations * miniBatchFraction >= 1.
*
* Default: 0.05, i.e., 5% of total documents.
*/
def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
Expand All @@ -295,15 +317,20 @@ class OnlineLDAOptimizer extends LDAOptimizer {
}

/**
* The function is for test only now. In the future, it can help support training stop/resume
* (private[clustering])
* Set the Dirichlet parameter for the posterior over topics.
* This is only used for testing now. In the future, it can help support training stop/resume.
*/
private[clustering] def setLambda(lambda: BDM[Double]): this.type = {
this.lambda = lambda
this
}

/**
* Used to control the gamma distribution. Larger value produces values closer to 1.0.
* (private[clustering])
* Used for random initialization of the variational parameters.
* Larger value produces values closer to 1.0.
* This is only used for testing currently.
*/
private[clustering] def setGammaShape(shape: Double): this.type = {
this.gammaShape = shape
Expand Down Expand Up @@ -380,12 +407,11 @@ class OnlineLDAOptimizer extends LDAOptimizer {
meanchange = sum(abs(gammad - lastgamma)) / k
}

val m1 = expElogthetad.t.toDenseMatrix.t
val m2 = (ctsVector / phinorm).t.toDenseMatrix
val outerResult = kron(m1, m2) // K * ids
val m1 = expElogthetad.t
val m2 = (ctsVector / phinorm).t.toDenseVector
var i = 0
while (i < ids.size) {
stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
i += 1
}
}
Expand Down Expand Up @@ -423,7 +449,9 @@ class OnlineLDAOptimizer extends LDAOptimizer {
* Get a random matrix to initialize lambda
*/
private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)
val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
randomGenerator.nextLong()))
val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis)
val temp = gammaRandomGenerator.sample(row * col).toArray
new BDM[Double](col, row, temp).t
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.Serializable;
import java.util.ArrayList;

import org.apache.spark.api.java.JavaRDD;
import scala.Tuple2;

import org.junit.After;
Expand All @@ -30,6 +29,7 @@
import org.junit.Test;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
Expand Down Expand Up @@ -148,6 +148,6 @@ public void OnlineOptimizerCompatibility() {
private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription =
LDASuite$.MODULE$.tinyTopicDescription();
JavaPairRDD<Long, Vector> corpus;
private JavaPairRDD<Long, Vector> corpus;

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {

// Check: describeTopics() with all terms
val fullTopicSummary = model.describeTopics()
assert(fullTopicSummary.size === tinyK)
assert(fullTopicSummary.length === tinyK)
fullTopicSummary.zip(tinyTopicDescription).foreach {
case ((algTerms, algTermWeights), (terms, termWeights)) =>
assert(algTerms === terms)
Expand Down Expand Up @@ -101,7 +101,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
// Check: per-doc topic distributions
val topicDistributions = model.topicDistributions.collect()
// Ensure all documents are covered.
assert(topicDistributions.size === tinyCorpus.size)
assert(topicDistributions.length === tinyCorpus.length)
assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
// Ensure we have proper distributions
topicDistributions.foreach { case (docId, topicDistribution) =>
Expand Down Expand Up @@ -139,8 +139,8 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
val corpus = sc.parallelize(tinyCorpus, 2)
val op = new OnlineLDAOptimizer().initialize(corpus, lda)
op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau_0(567)
assert(op.alpha == 0.5) // default 1.0 / k
assert(op.eta == 0.5) // default 1.0 / k
assert(op.getAlpha == 0.5) // default 1.0 / k
assert(op.getEta == 0.5) // default 1.0 / k
assert(op.getKappa == 0.9876)
assert(op.getMiniBatchFraction == 0.123)
assert(op.getTau_0 == 567)
Expand All @@ -154,14 +154,14 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {

def docs: Array[(Long, Vector)] = Array(
Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana
Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1))) // tiger, cat, dog
.zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog
).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
val corpus = sc.parallelize(docs, 2)

// setGammaShape large so to avoid the stochastic impact.
// Set GammaShape large to avoid the stochastic impact.
val op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51).setGammaShape(1e40)
.setMiniBatchFraction(1)
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op)
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345)

val state = op.initialize(corpus, lda)
// override lambda to simulate an intermediate state
Expand All @@ -175,8 +175,8 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {

// verify the result, Note this generate the identical result as
// [[https://github.com/Blei-Lab/onlineldavb]]
val topic1 = op.lambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic2 = op.lambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
}
Expand All @@ -186,7 +186,6 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
Vectors.sparse(6, Array(0, 1), Array(1, 1)),
Vectors.sparse(6, Array(1, 2), Array(1, 1)),
Vectors.sparse(6, Array(0, 2), Array(1, 1)),

Vectors.sparse(6, Array(3, 4), Array(1, 1)),
Vectors.sparse(6, Array(3, 5), Array(1, 1)),
Vectors.sparse(6, Array(4, 5), Array(1, 1))
Expand All @@ -200,6 +199,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
.setTopicConcentration(0.01)
.setMaxIterations(100)
.setOptimizer(op)
.setSeed(12345)

val ldaModel = lda.run(docs)
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
Expand All @@ -208,10 +208,10 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
}

// check distribution for each topic, typical distribution is (0.3, 0.3, 0.3, 0.02, 0.02, 0.02)
topics.foreach(topic =>{
val smalls = topic.filter(t => (t._2 < 0.1)).map(_._2)
assert(smalls.size == 3 && smalls.sum < 0.2)
})
topics.foreach { topic =>
val smalls = topic.filter(t => t._2 < 0.1).map(_._2)
assert(smalls.length == 3 && smalls.sum < 0.2)
}
}

}
Expand Down

0 comments on commit cf376ff

Please sign in to comment.