From f3554119dd7a555fd1ec4d9d61edd0cbf5787e77 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 10 Mar 2014 22:51:20 -0700 Subject: [PATCH] add BreezeVectorWithSquaredNorm case class --- .../spark/mllib/clustering/KMeans.scala | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index f0aeec0c882d8..52e2b20a4883b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -28,6 +28,9 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +private[clustering] +case class BreezeVectorWithSquaredNorm(vector: BV[Double], squaredNorm: Double) + /** * K-means clustering with support for multiple parallel runs and a k-means++ like initialization * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -362,6 +365,28 @@ object KMeans { (bestIndex, bestDistance) } + /** + * Returns the index of the closest center to the given point, as well as the squared distance. + */ + private[mllib] def findClosest( + centers: TraversableOnce[BreezeVectorWithSquaredNorm], + point: BreezeVectorWithSquaredNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val distance: Double = MLUtils.fastSquaredDistance( + center.vector, center.squaredNorm, point.vector, point.squaredNorm + ) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + /** * Return the K-means cost of a given point against the given cluster centers. */