
Commit d6e6c07
add predict(RDD[Vector]) to KMeansModel
add a test for two clusters
mengxr committed Mar 12, 2014
1 parent 42b4e50 commit d6e6c07
Showing 2 changed files with 35 additions and 5 deletions.
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -28,10 +28,6 @@ import org.apache.spark.mllib.linalg.Vector
  */
 class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable {
 
-  private val breezeClusterCenters = clusterCenters.map { v =>
-    new BreezeDenseVector[Double](v)
-  }
-
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
 
@@ -40,10 +36,18 @@ class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable
     KMeans.findClosest(clusterCenters, point)._1
   }
 
   /** Returns the cluster index that a given point belongs to. */
   def predict(point: Vector): Int = {
+    val breezeClusterCenters = clusterCenters.view.map(new BreezeDenseVector[Double](_))
     KMeans.findClosest(breezeClusterCenters, point.toBreeze)._1
   }
 
+  /** Maps given points to their cluster indices. */
+  def predict(points: RDD[Vector]): RDD[Int] = {
+    val breezeClusterCenters = clusterCenters.map(new BreezeDenseVector[Double](_))
+    points.map(p => KMeans.findClosest(breezeClusterCenters, p.toBreeze)._1)
+  }
+
   /**
    * Return the K-means cost (sum of squared distances of points to their nearest center) for this
    * model on the given data.
@@ -57,6 +61,7 @@ class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable
    * model on the given data.
    */
   def computeCost(data: RDD[Vector])(implicit d: DummyImplicit): Double = {
+    val breezeClusterCenters = clusterCenters.map(new BreezeDenseVector[Double](_))
     data.map(p => KMeans.pointCost(breezeClusterCenters, p.toBreeze)).sum()
   }
 }
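
The change replaces the cached breezeClusterCenters field with per-method conversion: the single-point predict builds the Breeze centers lazily through a view, while the RDD-based predict and computeCost materialize the converted array once, so a single copy is shipped in the task closure. A minimal sketch of how the new batch API can be driven, assuming a live SparkContext `sc` and the four-argument KMeans.train overload that appears in the test suite below; the `data` and `model` names are illustrative, not part of the commit:

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0),
  Vectors.dense(0.1, 0.0),
  Vectors.dense(9.0, 0.0),
  Vectors.dense(9.2, 0.0)
))
val model = KMeans.train(data, k = 2, maxIterations = 10, runs = 1)

// New in this commit: map a whole RDD of points to cluster indices in one
// call, instead of collecting to the driver and calling predict per point.
val clusterIndices = model.predict(data)
clusterIndices.collect() // e.g. Array(0, 0, 1, 1), up to label permutation
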
mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.Vectors
 
 class KMeansSuite extends FunSuite with LocalSparkContext {
 
@@ -204,4 +204,29 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
     model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
     assertSetsEqual(model.clusterCenters, points)
   }
+
+  test("two clusters") {
+    val points = Array(
+      Array(0.0, 0.0),
+      Array(0.0, 0.1),
+      Array(0.1, 0.0),
+      Array(9.0, 0.0),
+      Array(9.0, 0.2),
+      Array(9.2, 0.0)
+    ).map(Vectors.dense)
+    val rdd = sc.parallelize(points, 3)
+
+    for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
+      // Two iterations are sufficient no matter where the initial centers are.
+      val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1, initMode)
+
+      val predicts = model.predict(rdd).collect()
+
+      assert(predicts(0) === predicts(1))
+      assert(predicts(0) === predicts(2))
+      assert(predicts(3) === predicts(4))
+      assert(predicts(3) === predicts(5))
+      assert(predicts(0) != predicts(3))
+    }
+  }
 }
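
The assertions pin down the expected grouping without fixing the label order. Equivalently, the batch path should agree element-wise with the existing single-point predict; a hypothetical check of that invariant (not part of the commit), reusing `model` and `rdd` from the test body:

// collect() preserves element order, so index i of both arrays refers to the
// same input point.
val batch = model.predict(rdd).collect()
val oneByOne = rdd.collect().map(p => model.predict(p))
assert(batch.sameElements(oneByOne))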
