Skip to content

Commit

Permalink
add BreezeVectorWithSquaredNorm case class
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 11, 2014
1 parent ab74f67 commit f355411
Showing 1 changed file with 25 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom

/**
 * Bundles a Breeze vector with a squared-norm value so the two can be passed around together.
 *
 * Visibility is limited to the `clustering` package.
 *
 * @param vector the Breeze vector
 * @param squaredNorm squared norm associated with `vector`. NOTE(review): presumably the squared
 *                    L2 norm of `vector`; nothing in this declaration computes or checks it, so
 *                    callers must supply a consistent value -- confirm at construction sites.
 */
private[clustering]
case class BreezeVectorWithSquaredNorm(vector: BV[Double], squaredNorm: Double)

/**
* K-means clustering with support for multiple parallel runs and a k-means++ like initialization
* mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
Expand Down Expand Up @@ -362,6 +365,28 @@ object KMeans {
(bestIndex, bestDistance)
}

/**
 * Finds the center nearest to `point` and reports its position together with the distance.
 *
 * Distances are obtained from `MLUtils.fastSquaredDistance`, fed with the vector and squared
 * norm carried by each center and by the point. Centers are visited exactly once, in traversal
 * order; positions are zero-based. On ties the earliest center wins because only a strictly
 * smaller distance replaces the current best.
 *
 * @param centers candidate centers, each carrying its own squared norm
 * @param point the point to assign, carrying its own squared norm
 * @return `(index, distance)` of the nearest center; `(0, Double.PositiveInfinity)` when
 *         `centers` yields no elements
 */
private[mllib] def findClosest(
    centers: TraversableOnce[BreezeVectorWithSquaredNorm],
    point: BreezeVectorWithSquaredNorm): (Int, Double) = {
  val remaining = centers.toIterator
  var closestIndex = 0
  var closestDistance = Double.PositiveInfinity
  var position = 0
  while (remaining.hasNext) {
    val candidate = remaining.next()
    val candidateDistance: Double = MLUtils.fastSquaredDistance(
      candidate.vector, candidate.squaredNorm, point.vector, point.squaredNorm)
    // Strict comparison: an equal distance never displaces an earlier center.
    if (candidateDistance < closestDistance) {
      closestIndex = position
      closestDistance = candidateDistance
    }
    position += 1
  }
  (closestIndex, closestDistance)
}

/**
* Return the K-means cost of a given point against the given cluster centers.
*/
Expand Down

0 comments on commit f355411

Please sign in to comment.