From ab74f67c0168dbe2c010f2d3dc262bd4ca987640 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 9 Mar 2014 20:26:03 -0700 Subject: [PATCH] add fastSquaredDistance for KMeans --- .../org/apache/spark/mllib/util/MLUtils.scala | 51 ++++++++++++++++++- .../spark/mllib/util/MLUtilsSuite.scala | 47 +++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 64c6136a8b89d..4b461f4000208 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -21,9 +21,14 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ +import org.apache.commons.math3.util.Precision.EPSILON + import org.jblas.DoubleMatrix + import org.apache.spark.mllib.regression.LabeledPoint +import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} + /** * Helper methods to load, save and pre-process data used in ML Lib. */ @@ -120,4 +125,48 @@ object MLUtils { } sum } -} + + /** + * Returns the squared Euclidean distance between two vectors. The following formula will be used + * if it does not introduce too much numerical error: + *
+   *   \|a - b\|_2^2 = \|a\|_2^2 + \|b\|_2^2 - 2 a^T b.
+   * 
+ * When both vector norms are given, this is faster than computing the squared distance directly, + * especially when one of the vectors is a sparse vector. + * + * @param v1 the first vector + * @param squaredNorm1 the squared norm of the first vector, non-negative + * @param v2 the second vector + * @param squaredNorm2 the squared norm of the second vector, non-negative + * @param precision desired relative precision for the squared distance + * @return squared distance between v1 and v2 within the specified precision + */ + private[mllib] def fastSquaredDistance( + v1: BV[Double], + squaredNorm1: Double, + v2: BV[Double], + squaredNorm2: Double, + precision: Double = 1e-6): Double = { + val n = v1.size + require(v2.size == n) + require(squaredNorm1 >= 0.0 && squaredNorm2 >= 0.0) + val sumSquaredNorm = squaredNorm1 + squaredNorm2 + val normDiff = math.sqrt(squaredNorm1) - math.sqrt(squaredNorm2) + var sqDist = 0.0 + val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) + if (precisionBound1 < precision) { + sqDist = sumSquaredNorm - 2.0 * v1.dot(v2) + } else if (v1.isInstanceOf[BSV[Double]] || v2.isInstanceOf[BSV[Double]]) { + val dot = v1.dot(v2) + sqDist = math.max(sumSquaredNorm - 2.0 * dot, 0.0) + val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dot)) / (sqDist + EPSILON) + if (precisionBound2 > precision) { + sqDist = breezeSquaredDistance(v1, v2) + } + } else { + sqDist = breezeSquaredDistance(v1, v2) + } + sqDist + } +} \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala new file mode 100644 index 0000000000000..e1c5d93220579 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.scalatest.FunSuite + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm, + squaredDistance => breezeSquaredDistance} + +import org.apache.spark.mllib.util.MLUtils._ + +class MLUtilsSuite extends FunSuite { + + test("fast squared distance") { + val a = (30 to 0 by -1).map(math.pow(2.0, _)).toArray + val n = a.length + val v1 = new BDV[Double](a) + val norm1: Double = breezeNorm(v1) + val squaredNorm1 = norm1 * norm1 + val precision = 1e-6 + for (m <- 0 until n) { + val indices = (0 to m).toArray + val values = indices.map(i => a(i)) + val v2 = new BSV[Double](indices, values, n) + val norm2: Double = breezeNorm(v2) + val squaredNorm2 = norm2 * norm2 + val squaredDist: Double = breezeSquaredDistance(v1, v2) + val fastSquaredDist = fastSquaredDistance(v1, squaredNorm1, v2, squaredNorm2, precision) + assert((fastSquaredDist - squaredDist) <= precision * squaredDist, s"failed with m = $m") + } + } +}