Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [SPARK-1328] Add vector statistics #268

Closed
wants to merge 38 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8c6c0e1
add basic statistics
yinxusen Mar 28, 2014
54b19ab
add new API to shrink RDD[Vector]
yinxusen Mar 28, 2014
28cf060
fix error of column means
yinxusen Mar 29, 2014
8ef3377
pass all tests
yinxusen Mar 29, 2014
e09d5d2
add scala docs and refine shrink method
yinxusen Mar 29, 2014
ad6c82d
add shrink test
yinxusen Mar 29, 2014
9af2e95
refine the code style
yinxusen Mar 29, 2014
cc65810
add parallel mean and variance
yinxusen Mar 30, 2014
1338ea1
all-in-one version test passed
yinxusen Mar 30, 2014
c4651bb
remove row-wise APIs and refine code
yinxusen Mar 30, 2014
d816ac7
remove useless APIs
yinxusen Mar 30, 2014
9a75ebd
add case class to wrap return values
yinxusen Apr 1, 2014
62a2c3e
use axpy and in-place if possible
yinxusen Apr 1, 2014
3980287
rename variables
yinxusen Apr 1, 2014
a6d5a2e
rewrite for only computing non-zero elements
yinxusen Apr 1, 2014
4e4fbd1
separate seqop and combop out as independent functions
yinxusen Apr 1, 2014
4cfbadf
fix bug of min max
yinxusen Apr 1, 2014
f6e8e9a
add sparse vectors test
yinxusen Apr 1, 2014
036b7a5
fix the bug of Nan occur
yinxusen Apr 1, 2014
4a5c38d
add scala doc, refine code and comments
yinxusen Apr 1, 2014
1376ff4
rename variables and adjust code
yinxusen Apr 2, 2014
138300c
add new Aggregator class
yinxusen Apr 2, 2014
967d041
full revision with Aggregator class
yinxusen Apr 2, 2014
f7a3ca2
fix the corner case of maxmin
yinxusen Apr 2, 2014
18cf072
change def to lazy val to make sure that the computations in function…
yinxusen Apr 2, 2014
dc77e38
test sparse vector RDD
yinxusen Apr 2, 2014
86522c4
add comments on functions
yinxusen Apr 2, 2014
548e9de
minor revision
yinxusen Apr 3, 2014
69e1f37
remove lazy eval, and minor memory footprint
yinxusen Apr 3, 2014
1fba230
merge while loop together
yinxusen Apr 3, 2014
e624f93
fix scala style error
yinxusen Apr 3, 2014
48ee053
fix minor error
yinxusen Apr 4, 2014
4eaf28a
merge VectorRDDStatistics into RowMatrix
yinxusen Apr 9, 2014
cbbefdb
update multivariate statistical summary interface and clean tests
mengxr Apr 9, 2014
b064714
remove computeStat in MLUtils
yinxusen Apr 10, 2014
10cf5d3
refine some return type
yinxusen Apr 10, 2014
16ae684
fix minor error and remove useless method
yinxusen Apr 10, 2014
d61363f
rebase to latest master
yinxusen Apr 10, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,146 @@ package org.apache.spark.mllib.linalg.distributed

import java.util

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary

/**
 * Column statistics aggregator implementing
 * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
 * together with add() and merge() function.
 * A numerically stable algorithm is implemented to compute sample mean and variance:
 * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
 * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
 * to have time complexity O(nnz) instead of O(n) for each column.
 *
 * @param n number of columns; all aggregated rows must have this dimension.
 */
private class ColumnStatisticsAggregator(private val n: Int)
  extends MultivariateStatisticalSummary with Serializable {

  // Running mean of the NON-ZERO entries of each column (zeros are folded in lazily).
  private val currMean: BDV[Double] = BDV.zeros[Double](n)
  // Running sum of squared deviations from currMean (the M2n term of Welford's algorithm).
  private val currM2n: BDV[Double] = BDV.zeros[Double](n)
  // Total number of rows seen, including rows whose entries are all zero.
  private var totalCnt = 0.0
  // Number of non-zero entries seen per column.
  private val nnz: BDV[Double] = BDV.zeros[Double](n)
  // Extrema over non-zero entries; corrected for implicit zeros in max/min below.
  private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
  private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)

  override def mean: Vector = {
    val realMean = BDV.zeros[Double](n)
    var i = 0
    while (i < n) {
      // currMean averages only the nnz(i) non-zero entries; rescale by
      // nnz(i) / totalCnt to account for the skipped zeros.
      realMean(i) = currMean(i) * nnz(i) / totalCnt
      i += 1
    }
    Vectors.fromBreeze(realMean)
  }

  override def variance: Vector = {
    val realVariance = BDV.zeros[Double](n)
    val denominator = totalCnt - 1.0
    // Sample variance is computed; if the denominator is not positive (sample
    // size <= 1), the variance is just 0.
    if (denominator > 0.0) {
      val deltaMean = currMean
      var i = 0
      while (i < currM2n.size) {
        // Correct M2n for the (totalCnt - nnz) implicit zero entries, whose
        // deviation from the true mean is deltaMean * nnz / totalCnt each.
        realVariance(i) =
          currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
        realVariance(i) /= denominator
        i += 1
      }
    }
    Vectors.fromBreeze(realVariance)
  }

  override def count: Long = totalCnt.toLong

  override def numNonzeros: Vector = Vectors.fromBreeze(nnz)

  override def max: Vector = {
    var i = 0
    while (i < n) {
      // If a column has fewer non-zeros than rows, it contains at least one
      // zero entry, so its max cannot be below 0.
      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(currMax)
  }

  override def min: Vector = {
    var i = 0
    while (i < n) {
      // Symmetric correction: a column containing a zero entry has min <= 0.
      if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(currMin)
  }

  /**
   * Aggregates a row.
   *
   * @param currData the row to aggregate; must have dimension n.
   * @return this aggregator, updated in place.
   */
  def add(currData: BV[Double]): this.type = {
    currData.activeIterator.foreach {
      case (_, 0.0) => // Skip explicit zero elements.
      case (i, value) =>
        if (currMax(i) < value) {
          currMax(i) = value
        }
        if (currMin(i) > value) {
          currMin(i) = value
        }
        // Welford's online update over the non-zero entries only.
        val tmpPrevMean = currMean(i)
        currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
        currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
        nnz(i) += 1.0
    }
    totalCnt += 1.0
    this
  }

  /**
   * Merges another aggregator into this one.
   *
   * @param other aggregator over the same number of columns.
   * @return this aggregator, updated in place.
   */
  def merge(other: ColumnStatisticsAggregator): this.type = {
    require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
    totalCnt += other.totalCnt
    val deltaMean = currMean - other.currMean
    var i = 0
    while (i < n) {
      // Merge means, weighted by each side's non-zero counts.
      if (other.currMean(i) != 0.0) {
        currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
          (nnz(i) + other.nnz(i))
      }
      // Merge M2n using the standard pairwise-combination formula.
      if (nnz(i) + other.nnz(i) != 0.0) {
        currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
          (nnz(i) + other.nnz(i))
      }
      if (currMax(i) < other.currMax(i)) {
        currMax(i) = other.currMax(i)
      }
      if (currMin(i) > other.currMin(i)) {
        currMin(i) = other.currMin(i)
      }
      i += 1
    }
    nnz += other.nnz
    this
  }
}

/**
* :: Experimental ::
Expand Down Expand Up @@ -182,13 +314,7 @@ class RowMatrix(
combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
)

// Update _m if it is not set, or verify its value.
if (nRows <= 0L) {
nRows = m
} else {
require(nRows == m,
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
}
updateNumRows(m)

mean :/= m.toDouble

Expand Down Expand Up @@ -240,6 +366,19 @@ class RowMatrix(
}
}

/**
 * Computes column-wise summary statistics (mean, variance, count, nnz, max, min)
 * for this matrix in a single pass over the rows.
 */
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
  val emptyAgg = new ColumnStatisticsAggregator(numCols().toInt)
  val colStats = rows
    .map(_.toBreeze)
    .aggregate[ColumnStatisticsAggregator](emptyAgg)(
      seqOp = (agg, row) => agg.add(row),
      combOp = (aggA, aggB) => aggA.merge(aggB))
  // The row count falls out of the aggregation for free; record or verify it.
  updateNumRows(colStats.count)
  colStats
}

/**
* Multiply this matrix by a local matrix on the right.
*
Expand Down Expand Up @@ -276,6 +415,16 @@ class RowMatrix(
}
mat
}

/** Updates or verifies the number of rows. */
private def updateNumRows(m: Long) {
  if (nRows <= 0) {
    // Cache the freshly computed count. (Bug fix: the original used the
    // comparison `nRows == m`, a no-op whose result was discarded, so the
    // count was never actually stored.)
    nRows = m
  } else {
    require(nRows == m,
      s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
  }
}
}

object RowMatrix {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.stat

import org.apache.spark.mllib.linalg.Vector

/**
 * Trait for multivariate statistical summary of a data matrix.
 */
trait MultivariateStatisticalSummary {

  /**
   * Sample mean vector.
   */
  def mean: Vector

  /**
   * Sample variance vector. Should return a zero vector if the sample size is 1.
   */
  def variance: Vector

  /**
   * Sample size.
   */
  def count: Long

  /**
   * Number of nonzero elements in each column. Explicitly stored zero values
   * are not counted (the aggregator skips them when adding and merging rows).
   */
  def numNonzeros: Vector

  /**
   * Maximum value of each column.
   */
  def max: Vector

  /**
   * Minimum value of each column.
   */
  def min: Vector
}
57 changes: 2 additions & 55 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

package org.apache.spark.mllib.util

import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
squaredDistance => breezeSquaredDistance}
import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}

import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.Vectors

/**
* Helper methods to load, save and pre-process data used in ML Lib.
Expand Down Expand Up @@ -158,58 +157,6 @@ object MLUtils {
dataStr.saveAsTextFile(dir)
}

/**
 * Utility function to compute mean and standard deviation on a given dataset.
 *
 * @param data - input data set whose statistics are computed
 * @param numFeatures - number of features
 * @param numExamples - number of examples in input dataset
 *
 * @return (yMean, xColMean, xColSd) - Tuple consisting of
 *     yMean - mean of the labels
 *     xColMean - Row vector with mean for every column (or feature) of the input data
 *     xColSd - Row vector standard deviation for every column (or feature) of the input data.
 */
private[mllib] def computeStats(
    data: RDD[LabeledPoint],
    numFeatures: Int,
    numExamples: Long): (Double, Vector, Vector) = {
  // Convert to Breeze vectors so the in-place `+=` ops below reuse buffers.
  val brzData = data.map { case LabeledPoint(label, features) =>
    (label, features.toBreeze)
  }
  // One pass accumulating (count, label sum, per-feature sums, per-feature squared sums).
  val aggStats = brzData.aggregate(
    (0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
  )(
    seqOp = (c, v) => (c, v) match {
      case ((n, sumLabel, sum, sumSq), (label, features)) =>
        // Only active (stored) entries contribute to the squared sums.
        features.activeIterator.foreach { case (i, x) =>
          sumSq(i) += x * x
        }
        (n + 1L, sumLabel + label, sum += features, sumSq)
    },
    combOp = (c1, c2) => (c1, c2) match {
      case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
        (n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
    }
  )
  val (nl, sumLabel, sum, sumSq) = aggStats

  require(nl > 0, "Input data is empty.")
  require(nl == numExamples)

  val n = nl.toDouble
  val yMean = sumLabel / n
  val mean = sum / n
  // NOTE(review): despite the name, `std(i)` is E[x^2] - E[x]^2 with no sqrt
  // applied, i.e. the population VARIANCE, not the standard deviation — the
  // existing unit test uses columns whose variance is 1.0 (== its sqrt), which
  // masks the discrepancy. TODO confirm intended semantics before reuse.
  val std = new Array[Double](sum.length)
  var i = 0
  while (i < numFeatures) {
    std(i) = sumSq(i) / n - mean(i) * mean(i)
    i += 1
  }

  (yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
}

/**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,19 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
))
}
}

// Verifies computeColumnSummaryStatistics on both the dense and sparse fixtures.
test("compute column summary statistics") {
  for (mat <- Seq(denseMat, sparseMat)) {
    val summary = mat.computeColumnSummaryStatistics()
    // Run twice to make sure no internal states are changed.
    for (k <- 0 to 1) {
      assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch")
      assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch")
      assert(summary.count === m, "count mismatch")
      assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch")
      assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch")
      // Bug fix: the min assertion previously reported "column mismatch.",
      // which mislabels the failing statistic; also dropped the stray period
      // on the count message for consistency with the other messages.
      assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "min mismatch")
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._

class MLUtilsSuite extends FunSuite with LocalSparkContext {
Expand Down Expand Up @@ -56,18 +55,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
}
}

test("compute stats") {
  // Three copies of a two-point dataset: labels alternate {1.0, 0.0}; each
  // feature column takes two values (e.g. {1.0, 3.0}) three times each.
  val data = Seq.fill(3)(Seq(
    LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
    LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
  )).flatten
  val rdd = sc.parallelize(data, 2)
  val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
  assert(meanLabel === 0.5)
  assert(mean === Vectors.dense(2.0, 3.0, 4.0))
  // Each column's variance is exactly 1.0, and sqrt(1.0) == 1.0, so this
  // assertion passes whether computeStats returns std or variance.
  // NOTE(review): computeStats actually returns variance (no sqrt) — this
  // fixture cannot distinguish the two; TODO confirm intended behavior.
  assert(std === Vectors.dense(1.0, 1.0, 1.0))
}

test("loadLibSVMData") {
val lines =
"""
Expand Down