Skip to content

Commit

Permalink
take discount on previous weights; use BLAS; detect dying clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Oct 31, 2014
1 parent 0411bf5 commit 2e682c0
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
object StreamingKMeans {

def main(args: Array[String]) {

if (args.length != 5) {
System.err.println(
"Usage: StreamingKMeans " +
Expand All @@ -67,14 +66,12 @@ object StreamingKMeans {
val model = new StreamingKMeans()
.setK(args(3).toInt)
.setDecayFactor(1.0)
.setRandomCenters(args(4).toInt)
.setRandomCenters(args(4).toInt, 0.0)

model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@ package org.apache.spark.mllib.clustering

import scala.reflect.ClassTag

import breeze.linalg.{Vector => BV}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -66,55 +65,81 @@ import org.apache.spark.util.Utils
@DeveloperApi
class StreamingKMeansModel(
override val clusterCenters: Array[Vector],
val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {

/** Perform a k-means update on a batch of data. */
def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {

val centers = clusterCenters
val counts = clusterCounts

// find nearest cluster to each point
val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong)))
val closest = data.map(point => (this.predict(point), (point, 1L)))

// get sums and counts for updating each cluster
type WeightedPoint = (BV[Double], Long)
def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
(p1._1 += p2._1, p1._2 + p2._2)
val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
BLAS.axpy(1.0, p2._1, p1._1)
(p1._1, p1._2 + p2._2)
}
val pointStats: Array[(Int, (BV[Double], Long))] =
closest.reduceByKey(mergeContribs).collect()
val dim = clusterCenters(0).size
val pointStats: Array[(Int, (Vector, Long))] = closest
.aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
.collect()

val discount = timeUnit match {
case StreamingKMeans.BATCHES => decayFactor
case StreamingKMeans.POINTS =>
val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
n
}.sum
math.pow(decayFactor, numNewPoints)
}

// apply discount to weights
BLAS.scal(discount, Vectors.dense(clusterWeights))

// implement update rule
pointStats.foreach { case (label, (mean, count)) =>
// store old count and centroid
val oldCount = counts(label)
val oldCentroid = centers(label).toBreeze
// get new count and centroid
val newCount = count
val newCentroid = mean / newCount.toDouble
// compute the normalized scale factor that controls forgetting
val lambda = timeUnit match {
case "batches" => newCount / (decayFactor * oldCount + newCount)
case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
}
// perform the update
val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
// store the new counts and centers
counts(label) = oldCount + newCount
centers(label) = Vectors.fromBreeze(updatedCentroid)
pointStats.foreach { case (label, (sum, count)) =>
val centroid = clusterCenters(label)

val updatedWeight = clusterWeights(label) + count
val lambda = count / math.max(updatedWeight, 1e-16)

clusterWeights(label) = updatedWeight
BLAS.scal(1.0 - lambda, centroid)
BLAS.axpy(lambda / count, sum, centroid)

// display the updated cluster centers
val display = centers(label).size match {
case x if x > 100 => centers(label).toArray.take(100).mkString("[", ",", "...")
case _ => centers(label).toArray.mkString("[", ",", "]")
val display = clusterCenters(label).size match {
case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
case _ => centroid.toArray.mkString("[", ",", "]")
}

logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
}

// Check whether the smallest cluster is dying. If so, split the largest cluster.
val weightsWithIndex = clusterWeights.view.zipWithIndex
val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
if (minWeight < 1e-8 * maxWeight) {
logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
val weight = (maxWeight + minWeight) / 2.0
clusterWeights(largest) = weight
clusterWeights(smallest) = weight
val largestClusterCenter = clusterCenters(largest)
val smallestClusterCenter = clusterCenters(smallest)
var j = 0
while (j < dim) {
val x = largestClusterCenter(j)
val p = 1e-14 * math.max(math.abs(x), 1.0)
largestClusterCenter.toBreeze(j) = x + p
smallestClusterCenter.toBreeze(j) = x - p
j += 1
}
logInfo("Cluster %d updated: %s ".format (label, display))
}
new StreamingKMeansModel(centers, counts)
}

this
}
}

/**
* :: DeveloperApi ::
* StreamingKMeans provides methods for configuring a
Expand All @@ -128,7 +153,7 @@ class StreamingKMeansModel(
* val model = new StreamingKMeans()
* .setDecayFactor(0.5)
* .setK(3)
* .setRandomCenters(5)
* .setRandomCenters(5, 100.0)
* .trainOn(DStream)
*/
@DeveloperApi
Expand All @@ -137,9 +162,9 @@ class StreamingKMeans(
var decayFactor: Double,
var timeUnit: String) extends Logging {

protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
def this() = this(2, 1.0, StreamingKMeans.BATCHES)

def this() = this(2, 1.0, "batches")
protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)

/** Set the number of clusters. */
def setK(k: Int): this.type = {
Expand All @@ -155,7 +180,7 @@ class StreamingKMeans(

/** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
if (timeUnit != "batches" && timeUnit != "points") {
if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
}
this.decayFactor = math.exp(math.log(0.5) / halfLife)
Expand All @@ -165,26 +190,23 @@ class StreamingKMeans(
}

/** Specify initial centers directly. */
def setInitialCenters(initialCenters: Array[Vector]): this.type = {
val clusterCounts = new Array[Long](this.k)
this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
model = new StreamingKMeansModel(centers, weights)
this
}

/** Initialize random centers, requiring only the number of dimensions.
*
* @param dim Number of dimensions
* @param seed Random seed
* */
def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = {

val random = Utils.random
random.setSeed(seed)

val initialCenters = (0 until k)
.map(_ => Vectors.dense(Array.fill(dim)(random.nextGaussian()))).toArray
val clusterCounts = new Array[Long](this.k)
this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
/**
* Initialize random centers, requiring only the number of dimensions.
*
* @param dim Number of dimensions
* @param weight Weight for each center
* @param seed Random seed
*/
def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
val random = new XORShiftRandom(seed)
val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
val weights = Array.fill(k)(weight)
model = new StreamingKMeansModel(centers, weights)
this
}

Expand All @@ -202,9 +224,9 @@ class StreamingKMeans(
* @param data DStream containing vector data
*/
def trainOn(data: DStream[Vector]) {
this.assertInitialized()
assertInitialized()
data.foreachRDD { (rdd, time) =>
model = model.update(rdd, this.decayFactor, this.timeUnit)
model = model.update(rdd, decayFactor, timeUnit)
}
}

Expand All @@ -215,7 +237,7 @@ class StreamingKMeans(
* @return DStream containing predictions
*/
def predictOn(data: DStream[Vector]): DStream[Int] = {
this.assertInitialized()
assertInitialized()
data.map(model.predict)
}

Expand All @@ -227,16 +249,20 @@ class StreamingKMeans(
* @return DStream containing the input keys and the predictions as values
*/
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
this.assertInitialized()
assertInitialized()
data.mapValues(model.predict)
}

/** Check whether cluster centers have been initialized. */
def assertInitialized(): Unit = {
if (Option(model.clusterCenters) == None) {
private[this] def assertInitialized(): Unit = {
if (model.clusterCenters == null) {
throw new IllegalStateException(
"Initial cluster centers must be set before starting predictions")
}
}
}

private[clustering] object StreamingKMeans {
final val BATCHES = "batches"
final val POINTS = "points"
}
Loading

0 comments on commit 2e682c0

Please sign in to comment.