diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala index 79416ae734c52..33e5760aed997 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala @@ -50,7 +50,6 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} object StreamingKMeans { def main(args: Array[String]) { - if (args.length != 5) { System.err.println( "Usage: StreamingKMeans " + @@ -67,14 +66,12 @@ object StreamingKMeans { val model = new StreamingKMeans() .setK(args(3).toInt) .setDecayFactor(1.0) - .setRandomCenters(args(4).toInt) + .setRandomCenters(args(4).toInt, 0.0) model.trainOn(trainingData) model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 3a6451118ca5e..5919c3d30a277 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -19,16 +19,15 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag -import breeze.linalg.{Vector => BV} - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.{Vectors, Vector} -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom /** * :: DeveloperApi :: @@ -66,55 +65,81 @@ import org.apache.spark.util.Utils @DeveloperApi class StreamingKMeansModel( override val clusterCenters: Array[Vector], - val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging { + val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { /** Perform a k-means update on a batch of data. */ def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { - val centers = clusterCenters - val counts = clusterCounts - // find nearest cluster to each point - val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong))) + val closest = data.map(point => (this.predict(point), (point, 1L))) // get sums and counts for updating each cluster - type WeightedPoint = (BV[Double], Long) - def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = { - (p1._1 += p2._1, p1._2 + p2._2) + val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => { + BLAS.axpy(1.0, p2._1, p1._1) + (p1._1, p1._2 + p2._2) } - val pointStats: Array[(Int, (BV[Double], Long))] = - closest.reduceByKey(mergeContribs).collect() + val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest + .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) + .collect() + + val discount = timeUnit match { + case StreamingKMeans.BATCHES => decayFactor + case StreamingKMeans.POINTS => + val numNewPoints = pointStats.view.map { case (_, (_, n)) => + n + }.sum + math.pow(decayFactor, numNewPoints) + } + + // apply discount to weights + BLAS.scal(discount, Vectors.dense(clusterWeights)) // implement update rule - pointStats.foreach { case (label, (mean, count)) => - // store old count and centroid - val oldCount = counts(label) - val oldCentroid = centers(label).toBreeze - // get new count and centroid - val newCount = count - val newCentroid = mean / newCount.toDouble - // compute the normalized scale factor that controls forgetting - val lambda = timeUnit match { - case "batches" => newCount / (decayFactor * oldCount + newCount) - case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount) - } - // perform the update - val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda - // store the new counts and centers - counts(label) = oldCount + newCount - centers(label) = Vectors.fromBreeze(updatedCentroid) + pointStats.foreach { case (label, (sum, count)) => + val centroid = clusterCenters(label) + + val updatedWeight = clusterWeights(label) + count + val lambda = count / math.max(updatedWeight, 1e-16) + + clusterWeights(label) = updatedWeight + BLAS.scal(1.0 - lambda, centroid) + BLAS.axpy(lambda / count, sum, centroid) // display the updated cluster centers - val display = centers(label).size match { - case x if x > 100 => centers(label).toArray.take(100).mkString("[", ",", "...") - case _ => centers(label).toArray.mkString("[", ",", "]") + val display = clusterCenters(label).size match { + case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...") + case _ => centroid.toArray.mkString("[", ",", "]") + } + + logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display") + } + + // Check whether the smallest cluster is dying. If so, split the largest cluster. + val weightsWithIndex = clusterWeights.view.zipWithIndex + val (maxWeight, largest) = weightsWithIndex.maxBy(_._1) + val (minWeight, smallest) = weightsWithIndex.minBy(_._1) + if (minWeight < 1e-8 * maxWeight) { + logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.") + val weight = (maxWeight + minWeight) / 2.0 + clusterWeights(largest) = weight + clusterWeights(smallest) = weight + val largestClusterCenter = clusterCenters(largest) + val smallestClusterCenter = clusterCenters(smallest) + var j = 0 + while (j < dim) { + val x = largestClusterCenter(j) + val p = 1e-14 * math.max(math.abs(x), 1.0) + largestClusterCenter.toBreeze(j) = x + p + smallestClusterCenter.toBreeze(j) = x - p + j += 1 } - logInfo("Cluster %d updated: %s ".format (label, display)) } - new StreamingKMeansModel(centers, counts) - } + this + } } + /** * :: DeveloperApi :: * StreamingKMeans provides methods for configuring a @@ -128,7 +153,7 @@ class StreamingKMeansModel( * val model = new StreamingKMeans() * .setDecayFactor(0.5) * .setK(3) - * .setRandomCenters(5) + * .setRandomCenters(5, 100.0) * .trainOn(DStream) */ @DeveloperApi @@ -137,9 +162,9 @@ class StreamingKMeans( var decayFactor: Double, var timeUnit: String) extends Logging { - protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) + def this() = this(2, 1.0, StreamingKMeans.BATCHES) - def this() = this(2, 1.0, "batches") + protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) /** Set the number of clusters. */ def setK(k: Int): this.type = { @@ -155,7 +180,7 @@ class StreamingKMeans( /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ def setHalfLife(halfLife: Double, timeUnit: String): this.type = { - if (timeUnit != "batches" && timeUnit != "points") { + if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) } this.decayFactor = math.exp(math.log(0.5) / halfLife) @@ -165,26 +190,23 @@ class StreamingKMeans( } /** Specify initial centers directly. */ - def setInitialCenters(initialCenters: Array[Vector]): this.type = { - val clusterCounts = new Array[Long](this.k) - this.model = new StreamingKMeansModel(initialCenters, clusterCounts) + def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + model = new StreamingKMeansModel(centers, weights) this } - /** Initialize random centers, requiring only the number of dimensions. - * - * @param dim Number of dimensions - * @param seed Random seed - * */ - def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = { - - val random = Utils.random - random.setSeed(seed) - - val initialCenters = (0 until k) - .map(_ => Vectors.dense(Array.fill(dim)(random.nextGaussian()))).toArray - val clusterCounts = new Array[Long](this.k) - this.model = new StreamingKMeansModel(initialCenters, clusterCounts) + /** + * Initialize random centers, requiring only the number of dimensions. + * + * @param dim Number of dimensions + * @param weight Weight for each center + * @param seed Random seed + */ + def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + val random = new XORShiftRandom(seed) + val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) + val weights = Array.fill(k)(weight) + model = new StreamingKMeansModel(centers, weights) this } @@ -202,9 +224,9 @@ class StreamingKMeans( * @param data DStream containing vector data */ def trainOn(data: DStream[Vector]) { - this.assertInitialized() + assertInitialized() data.foreachRDD { (rdd, time) => - model = model.update(rdd, this.decayFactor, this.timeUnit) + model = model.update(rdd, decayFactor, timeUnit) } } @@ -215,7 +237,7 @@ class StreamingKMeans( * @return DStream containing predictions */ def predictOn(data: DStream[Vector]): DStream[Int] = { - this.assertInitialized() + assertInitialized() data.map(model.predict) } @@ -227,16 +249,20 @@ class StreamingKMeans( * @return DStream containing the input keys and the predictions as values */ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { - this.assertInitialized() + assertInitialized() data.mapValues(model.predict) } /** Check whether cluster centers have been initialized. */ - def assertInitialized(): Unit = { - if (Option(model.clusterCenters) == None) { + private[this] def assertInitialized(): Unit = { + if (model.clusterCenters == null) { throw new IllegalStateException( "Initial cluster centers must be set before starting predictions") } } +} +private[clustering] object StreamingKMeans { + final val BATCHES = "batches" + final val POINTS = "points" } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index de79c7026a696..850c9fce507cd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -17,21 +17,19 @@ package org.apache.spark.mllib.clustering -import scala.util.Random - import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.random.XORShiftRandom class StreamingKMeansSuite extends FunSuite with TestSuiteBase { override def maxWaitTimeMillis = 30000 test("accuracy for single center and equivalence to grand average") { - // set parameters val numBatches = 10 val numPoints = 50 @@ -43,9 +41,9 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { val model = new StreamingKMeans() .setK(1) .setDecayFactor(1.0) - .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0))) + .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0)) - // generate random data for kmeans + // generate random data for k-means val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training @@ -60,13 +58,12 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { // estimated center from streaming should exactly match the arithmetic mean of all data points // because the decay factor is set to 1.0 - val grandMean = input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + val grandMean = + input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) - } test("accuracy for two centers") { - val numBatches = 10 val numPoints = 5 val k = 2 @@ -74,27 +71,66 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { val r = 0.1 // create model with two clusters - val model = new StreamingKMeans() + val kMeans = new StreamingKMeans() .setK(2) - .setDecayFactor(1.0) - .setInitialCenters(Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1), - Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1))) + .setHalfLife(2, "batches") + .setInitialCenters( + Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1), + Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)), + Array(5.0, 5.0)) - // generate random data for kmeans + // generate random data for k-means val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { - model.trainOn(inputDStream) + kMeans.trainOn(inputDStream) inputDStream.count() }) runStreams(ssc, numBatches, numBatches) // check that estimated centers are close to true centers // NOTE exact assignment depends on the initialization! - assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1) - assert(centers(1) ~== model.latestModel().clusterCenters(1) absTol 1E-1) + assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + } + + test("detecting dying clusters") { + val numBatches = 10 + val numPoints = 5 + val k = 1 + val d = 1 + val r = 1.0 + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(0.5, "points") + .setInitialCenters( + Array(Vectors.dense(0.0), Vectors.dense(1000.0)), + Array(1.0, 1.0)) + + // new data are all around the first cluster 0.0 + val (input, _) = + StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + val model = kMeans.latestModel() + val c0 = model.clusterCenters(0)(0) + val c1 = model.clusterCenters(1)(0) + + assert(c0 * c1 < 0.0, "should have one positive center and one negative center") + // 0.8 is the mean of half-normal distribution + assert(math.abs(c0) ~== 0.8 absTol 0.6) + assert(math.abs(c1) ~== 0.8 absTol 0.6) } def StreamingKMeansDataGenerator( @@ -105,7 +141,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { r: Double, seed: Int, initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = { - val rand = new Random(seed) + val rand = new XORShiftRandom(seed) val centers = initCenters match { case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian()))) case _ => initCenters @@ -118,6 +154,4 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { } (data, centers) } - - }