Skip to content

Commit

Permalink
SPARK-4156 [MLLIB] EM algorithm for GMMs
Browse files Browse the repository at this point in the history
Implementation of Expectation-Maximization for Gaussian Mixture Models.

This is my maiden contribution to Apache Spark, so I apologize now if I have done anything incorrectly; having said that, this work is my own, and I offer it to the project under the project's open source license.

Author: Travis Galoppo <[email protected]>
Author: Travis Galoppo <[email protected]>
Author: tgaloppo <[email protected]>
Author: FlytxtRnD <[email protected]>

Closes #3022 from tgaloppo/master and squashes the following commits:

aaa8f25 [Travis Galoppo] MLUtils: changed privacy of EPSILON from [util] to [mllib]
709e4bf [Travis Galoppo] fixed usage line to include optional maxIterations parameter
acf1fba [Travis Galoppo] Fixed parameter comment in GaussianMixtureModel Made maximum iterations an optional parameter to DenseGmmEM
9b2fc2a [Travis Galoppo] Style improvements Changed ExpectationSum to a private class
b97fe00 [Travis Galoppo] Minor fixes and tweaks.
1de73f3 [Travis Galoppo] Removed redundant array from array creation
578c2d1 [Travis Galoppo] Removed unused import
227ad66 [Travis Galoppo] Moved prediction methods into model class.
308c8ad [Travis Galoppo] Numerous changes to improve code
cff73e0 [Travis Galoppo] Replaced accumulators with RDD.aggregate
20ebca1 [Travis Galoppo] Removed unusued code
42b2142 [Travis Galoppo] Added functionality to allow setting of GMM starting point. Added two cluster test to testing suite.
8b633f3 [Travis Galoppo] Style issue
9be2534 [Travis Galoppo] Style issue
d695034 [Travis Galoppo] Fixed style issues
c3b8ce0 [Travis Galoppo] Merge branch 'master' of https://github.com/tgaloppo/spark   Adds predict() method
2df336b [Travis Galoppo] Fixed style issue
b99ecc4 [tgaloppo] Merge pull request #1 from FlytxtRnD/predictBranch
f407b4c [FlytxtRnD] Added predict() to return the cluster labels and membership values
97044cf [Travis Galoppo] Fixed style issues
dc9c742 [Travis Galoppo] Moved MultivariateGaussian utility class
e7d413b [Travis Galoppo] Moved multivariate Gaussian utility class to mllib/stat/impl Improved comments
9770261 [Travis Galoppo] Corrected a variety of style and naming issues.
8aaa17d [Travis Galoppo] Added additional train() method to companion object for cluster count and tolerance parameters.
676e523 [Travis Galoppo] Fixed to no longer ignore delta value provided on command line
e6ea805 [Travis Galoppo] Merged with master branch; update test suite with latest context changes. Improved cluster initialization strategy.
86fb382 [Travis Galoppo] Merge remote-tracking branch 'upstream/master'
719d8cc [Travis Galoppo] Added scala test suite with basic test
c1a8e16 [Travis Galoppo] Made GaussianMixtureModel class serializable Modified sum function for better performance
5c96c57 [Travis Galoppo] Merge remote-tracking branch 'upstream/master'
c15405c [Travis Galoppo] SPARK-4156
  • Loading branch information
tgaloppo authored and mengxr committed Dec 29, 2014
1 parent 9bc0df6 commit 6cf6fdf
Show file tree
Hide file tree
Showing 6 changed files with 517 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.mllib

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixtureEM
import org.apache.spark.mllib.linalg.Vectors

/**
* An example Gaussian Mixture Model EM app. Run with
* {{{
* ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM <input> <k> <covergenceTol>
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object DenseGmmEM {
def main(args: Array[String]): Unit = {
if (args.length < 3) {
println("usage: DenseGmmEM <input file> <k> <convergenceTol> [maxIterations]")
} else {
val maxIterations = if (args.length > 3) args(3).toInt else 100
run(args(0), args(1).toInt, args(2).toDouble, maxIterations)
}
}

private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
val ctx = new SparkContext(conf)

val data = ctx.textFile(inputFile).map { line =>
Vectors.dense(line.trim.split(' ').map(_.toDouble))
}.cache()

val clusters = new GaussianMixtureEM()
.setK(k)
.setConvergenceTol(convergenceTol)
.setMaxIterations(maxIterations)
.run(data)

for (i <- 0 until clusters.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
}

println("Cluster labels (first <= 100):")
val clusterLabels = clusters.predict(data)
clusterLabels.take(100).foreach { x =>
print(" " + x)
}
println()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.clustering

import scala.collection.mutable.IndexedSeq

import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils

/**
* This class performs expectation maximization for multivariate Gaussian
* Mixture Models (GMMs). A GMM represents a composite distribution of
* independent Gaussian distributions with associated "mixing" weights
* specifying each's contribution to the composite.
*
* Given a set of sample points, this class will maximize the log-likelihood
* for a mixture of k Gaussians, iterating until the log-likelihood changes by
* less than convergenceTol, or until it has reached the max number of iterations.
* While this process is generally guaranteed to converge, it is not guaranteed
* to find a global optimum.
*
* @param k The number of independent Gaussians in the mixture model
* @param convergenceTol The maximum change in log-likelihood at which convergence
* is considered to have occurred.
* @param maxIterations The maximum number of iterations to perform
*/
class GaussianMixtureEM private (
private var k: Int,
private var convergenceTol: Double,
private var maxIterations: Int) extends Serializable {

/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
def this() = this(2, 0.01, 100)

// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5

// an initializing GMM can be provided rather than using the
// default random starting point
private var initialModel: Option[GaussianMixtureModel] = None

/** Set the initial GMM starting point, bypassing the random initialization.
* You must call setK() prior to calling this method, and the condition
* (model.k == this.k) must be met; failure will result in an IllegalArgumentException
*/
def setInitialModel(model: GaussianMixtureModel): this.type = {
if (model.k == k) {
initialModel = Some(model)
} else {
throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
}
this
}

/** Return the user supplied initial GMM, if supplied */
def getInitialModel: Option[GaussianMixtureModel] = initialModel

/** Set the number of Gaussians in the mixture model. Default: 2 */
def setK(k: Int): this.type = {
this.k = k
this
}

/** Return the number of Gaussians in the mixture model */
def getK: Int = k

/** Set the maximum number of iterations to run. Default: 100 */
def setMaxIterations(maxIterations: Int): this.type = {
this.maxIterations = maxIterations
this
}

/** Return the maximum number of iterations to run */
def getMaxIterations: Int = maxIterations

/**
* Set the largest change in log-likelihood at which convergence is
* considered to have occurred.
*/
def setConvergenceTol(convergenceTol: Double): this.type = {
this.convergenceTol = convergenceTol
this
}

/** Return the largest change in log-likelihood at which convergence is
* considered to have occurred.
*/
def getConvergenceTol: Double = convergenceTol

/** Perform expectation maximization */
def run(data: RDD[Vector]): GaussianMixtureModel = {
val sc = data.sparkContext

// we will operate on the data as breeze data
val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()

// Get length of the input vectors
val d = breezeData.first.length

// Determine initial weights and corresponding Gaussians.
// If the user supplied an initial GMM, we use those values, otherwise
// we start with uniform weights, a random mean from the data, and
// diagonal covariance matrices using component variances
// derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
})

case None => {
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
})
}
}

var llh = Double.MinValue // current log-likelihood
var llhp = 0.0 // previous log-likelihood

var iter = 0
while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) {
// create and broadcast curried cluster contribution function
val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)

// aggregate the cluster contribution for all sample points
val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)

// Create new distributions based on the partial assignments
// (often referred to as the "M" step in literature)
val sumWeights = sums.weights.sum
var i = 0
while (i < k) {
val mu = sums.means(i) / sums.weights(i)
val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
weights(i) = sums.weights(i) / sumWeights
gaussians(i) = new MultivariateGaussian(mu, sigma)
i = i + 1
}

llhp = llh // current becomes previous
llh = sums.logLikelihood // this is the freshly computed log-likelihood
iter += 1
}

// Need to convert the breeze matrices to MLlib matrices
val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) }
val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) }
new GaussianMixtureModel(weights, means, sigmas)
}

/** Average of dense breeze vectors */
private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
val v = BreezeVector.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
v / x.length.toDouble
}

/**
* Construct matrix where diagonal entries are element-wise
* variance of input vectors (computes biased variance)
*/
private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
val mu = vectorMean(x)
val ss = BreezeVector.zeros[Double](x(0).length)
x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
diag(ss / x.length.toDouble)
}
}

// companion class to provide zero constructor for ExpectationSum
private object ExpectationSum {
def zero(k: Int, d: Int): ExpectationSum = {
new ExpectationSum(0.0, Array.fill(k)(0.0),
Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
}

// compute cluster contributions for each input point
// (U, T) => U for aggregation
def add(
weights: Array[Double],
dists: Array[MultivariateGaussian])
(sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
val p = weights.zip(dists).map {
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
}
val pSum = p.sum
sums.logLikelihood += math.log(pSum)
val xxt = x * new Transpose(x)
var i = 0
while (i < sums.k) {
p(i) /= pSum
sums.weights(i) += p(i)
sums.means(i) += x * p(i)
sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
i = i + 1
}
sums
}
}

// Aggregation class for partial expectation results
private class ExpectationSum(
var logLikelihood: Double,
val weights: Array[Double],
val means: Array[BreezeVector[Double]],
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {

val k = weights.length

def +=(x: ExpectationSum): ExpectationSum = {
var i = 0
while (i < k) {
weights(i) += x.weights(i)
means(i) += x.means(i)
sigmas(i) += x.sigmas(i)
i = i + 1
}
logLikelihood += x.logLikelihood
this
}
}
Loading

0 comments on commit 6cf6fdf

Please sign in to comment.