Skip to content

Commit

Permalink
dbtsai-summarizer
Browse files Browse the repository at this point in the history
  • Loading branch information
dbtsai authored and DB Tsai committed Jul 11, 2014
1 parent f4f46de commit b13ac90
Show file tree
Hide file tree
Showing 5 changed files with 458 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,138 +28,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary

/**
* Column statistics aggregator implementing
* [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
* together with add() and merge() function.
* A numerically stable algorithm is implemented to compute sample mean and variance:
* [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
* Zero elements (including explicit zero values) are skipped when calling add() and merge(),
* to have time complexity O(nnz) instead of O(n) for each column.
*/
private class ColumnStatisticsAggregator(private val n: Int)
extends MultivariateStatisticalSummary with Serializable {

private val currMean: BDV[Double] = BDV.zeros[Double](n)
private val currM2n: BDV[Double] = BDV.zeros[Double](n)
private var totalCnt = 0.0
private val nnz: BDV[Double] = BDV.zeros[Double](n)
private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)

override def mean: Vector = {
val realMean = BDV.zeros[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * nnz(i) / totalCnt
i += 1
}
Vectors.fromBreeze(realMean)
}

override def variance: Vector = {
val realVariance = BDV.zeros[Double](n)

val denominator = totalCnt - 1.0

// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
if (denominator > 0.0) {
val deltaMean = currMean
var i = 0
while (i < currM2n.size) {
realVariance(i) =
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
realVariance(i) /= denominator
i += 1
}
}

Vectors.fromBreeze(realVariance)
}

override def count: Long = totalCnt.toLong

override def numNonzeros: Vector = Vectors.fromBreeze(nnz)

override def max: Vector = {
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMax)
}

override def min: Vector = {
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMin)
}

/**
* Aggregates a row.
*/
def add(currData: BV[Double]): this.type = {
currData.activeIterator.foreach {
case (_, 0.0) => // Skip explicit zero elements.
case (i, value) =>
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}

val tmpPrevMean = currMean(i)
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)

nnz(i) += 1.0
}

totalCnt += 1.0
this
}

/**
* Merges another aggregator.
*/
def merge(other: ColumnStatisticsAggregator): this.type = {
require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")

totalCnt += other.totalCnt
val deltaMean = currMean - other.currMean

var i = 0
while (i < n) {
// merge mean together
if (other.currMean(i) != 0.0) {
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
(nnz(i) + other.nnz(i))
}
// merge m2n together
if (nnz(i) + other.nnz(i) != 0.0) {
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
(nnz(i) + other.nnz(i))
}
if (currMax(i) < other.currMax(i)) {
currMax(i) = other.currMax(i)
}
if (currMin(i) > other.currMin(i)) {
currMin(i) = other.currMin(i)
}
i += 1
}

nnz += other.nnz
this
}
}
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}

/**
* :: Experimental ::
Expand Down Expand Up @@ -478,8 +347,7 @@ class RowMatrix(
* Computes column-wise summary statistics.
*/
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.stat

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vectors, Vector}

/**
* :: DeveloperApi ::
* MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
* variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
* format in a online fashion.
*
* Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of
* the corresponding joint dataset.
*
* A numerically stable algorithm is implemented to compute sample mean and variance:
* Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
* Zero elements (including explicit zero values) are skipped when calling add(),
* to have time complexity O(nnz) instead of O(n) for each column.
*/
@DeveloperApi
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {

private var n = 0
private var currMean: BDV[Double] = _
private var currM2n: BDV[Double] = _
private var totalCnt: Long = 0
private var nnz: BDV[Double] = _
private var currMax: BDV[Double] = _
private var currMin: BDV[Double] = _

/**
* Add a new sample to this summarizer, and update the statistical summary.
*
* @param sample The sample in dense/sparse vector format to be added into this summarizer.
* @return This MultivariateOnlineSummarizer object.
*/
def add(sample: Vector): this.type = {
if (n == 0) {
require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.")
n = sample.toBreeze.length

currMean = BDV.zeros[Double](n)
currM2n = BDV.zeros[Double](n)
nnz = BDV.zeros[Double](n)
currMax = BDV.fill(n)(Double.MinValue)
currMin = BDV.fill(n)(Double.MaxValue)
}

require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.toBreeze.length}.")

sample.toBreeze.activeIterator.foreach {
case (_, 0.0) => // Skip explicit zero elements.
case (i, value) =>
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}

val tmpPrevMean = currMean(i)
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)

nnz(i) += 1.0
}

totalCnt += 1
this
}

/**
* Merge another MultivariateOnlineSummarizer, and update the statistical summary.
* (Note that it's in place merging; as a result, `this` object will be modified.)
*
* @param other The other MultivariateOnlineSummarizer to be merged.
* @return This MultivariateOnlineSummarizer object.
*/
def merge(other: MultivariateOnlineSummarizer): this.type = {
if (this.totalCnt != 0 && other.totalCnt != 0) {
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
val deltaMean: BDV[Double] = currMean - other.currMean
var i = 0
while (i < n) {
// merge mean together
if (other.currMean(i) != 0.0) {
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
(nnz(i) + other.nnz(i))
}
// merge m2n together
if (nnz(i) + other.nnz(i) != 0.0) {
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
(nnz(i) + other.nnz(i))
}
if (currMax(i) < other.currMax(i)) {
currMax(i) = other.currMax(i)
}
if (currMin(i) > other.currMin(i)) {
currMin(i) = other.currMin(i)
}
i += 1
}
nnz += other.nnz
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
this.currMean = other.currMean.copy
this.currM2n = other.currM2n.copy
this.totalCnt = other.totalCnt
this.nnz = other.nnz.copy
this.currMax = other.currMax.copy
this.currMin = other.currMin.copy
}
this
}

override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realMean = BDV.zeros[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
i += 1
}
Vectors.fromBreeze(realMean)
}

override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realVariance = BDV.zeros[Double](n)

val denominator = totalCnt - 1.0

// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
if (denominator > 0.0) {
val deltaMean = currMean
var i = 0
while (i < currM2n.size) {
realVariance(i) =
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
realVariance(i) /= denominator
i += 1
}
}

Vectors.fromBreeze(realVariance)
}

override def count: Long = totalCnt

override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

Vectors.fromBreeze(nnz)
}

override def max: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMax)
}

override def min: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMin)
}
}
Loading

0 comments on commit b13ac90

Please sign in to comment.