full revision with Aggregator class
yinxusen committed Apr 10, 2014
1 parent 138300c commit 967d041
Showing 2 changed files with 42 additions and 127 deletions.
155 changes: 35 additions & 120 deletions mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
@@ -40,14 +40,12 @@ private class Aggregator(
var totalCnt: Double,
val nnz: BV[Double],
val currMax: BV[Double],
val currMin: BV[Double]) extends VectorRDDStatisticalSummary {
nnz.activeIterator.foreach {
case (id, 0.0) =>
currMax(id) = 0.0
currMin(id) = 0.0
case _ =>
val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {

override def mean(): Vector = {
Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
}
override def mean(): Vector = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)

override def variance(): Vector = {
val deltaMean = currMean
val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
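Worth pausing on the arithmetic these overrides rely on: currMean and currM2n are accumulated over nonzero entries only, so mean() folds the implicit zeros back in by scaling with nnz :/ totalCnt, and variance() shifts the squared-deviation sum from the nonzero mean to the true mean via the deltaMean term. A self-contained scalar sketch of the same correction (names here are illustrative, not from the patch):

```scala
// Sketch: recover full-column mean and variance from accumulators that only
// ever saw the nonzero values. `meanNZ`/`m2nNZ` are the mean and sum of
// squared deviations over the `nnz` nonzero entries; `total` is the row count.
def corrected(meanNZ: Double, m2nNZ: Double, nnz: Double, total: Double): (Double, Double) = {
  val mean = meanNZ * nnz / total // zeros add nothing to the column sum
  // shift m2n from the nonzero mean to the true mean, as variance() does:
  val m2n = m2nNZ - meanNZ * meanNZ * nnz * (nnz - total) / total
  (mean, m2n / total) // population variance
}
```

For a dense column (nnz == total) the correction term vanishes and the usual formulas fall out.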
@@ -58,8 +56,23 @@ private class Aggregator(
override def totalCount(): Long = totalCnt.toLong

override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)
override def max(): Vector = Vectors.fromBreeze(currMax)
override def min(): Vector = Vectors.fromBreeze(currMin)

override def max(): Vector = {
nnz.activeIterator.foreach {
case (id, 0.0) => currMax(id) = 0.0
case _ =>
}
Vectors.fromBreeze(currMax)
}

override def min(): Vector = {
nnz.activeIterator.foreach {
case (id, 0.0) => currMin(id) = 0.0
case _ =>
}
Vectors.fromBreeze(currMin)
}
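The rewritten max()/min() move the sentinel cleanup from the constructor into the accessors: currMax and currMin start at Double.MinValue and Double.MaxValue (see the zeroValue construction further down), and a column whose nnz is still 0 only ever held zeros, so its extremum is reset to 0.0 on read. A standalone illustration of the pattern (toy data, not the patch's Breeze vectors):

```scala
// Toy version of the sentinel pattern behind max(): extrema start at the
// opposite extreme and only nonzero entries are ever folded in, so an
// all-zero column keeps its sentinel until it is reset to 0.0 on read.
val data    = Seq(Array(1.0, 0.0, 3.0), Array(4.0, 0.0, 5.0))
val currMax = Array.fill(3)(Double.MinValue)
val nnz     = Array.fill(3)(0.0)
for (row <- data; i <- row.indices if row(i) != 0.0) {
  if (row(i) > currMax(i)) currMax(i) = row(i)
  nnz(i) += 1.0
}
for (i <- currMax.indices if nnz(i) == 0.0) currMax(i) = 0.0
// currMax: Array(4.0, 0.0, 5.0) -- column 1 never saw a nonzero value
```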

/**
* Aggregate function used for folding each sample vector into the running statistics within a worker.
*/
@@ -75,15 +88,19 @@ private class Aggregator(
currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)

nnz(id) += 1.0
totalCnt += 1.0
}

totalCnt += 1.0
this
}
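Two things happen in this hunk. First, the fix: totalCnt now advances once per sample vector, after the per-element foreach, where the earlier revision bumped it once per nonzero entry and so skewed every nnz :/ totalCnt correction above. Second, the per-element update is the classic single-pass (Welford-style) recurrence: advance the mean first, then let the squared-deviation sum absorb (value − newMean)·(value − oldMean). A scalar sketch of that recurrence, with illustrative names:

```scala
// Scalar sketch of the single-pass update used per column in add():
// `n` plays the role of nnz(id) for one column.
var n = 0.0; var mean = 0.0; var m2n = 0.0
def addValue(value: Double): Unit = {
  val prevMean = mean
  mean = (mean * n + value) / (n + 1.0)      // advance the running mean first
  m2n += (value - mean) * (value - prevMean) // uses both old and new mean
  n += 1.0
}
```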

/**
* Combine function used for merging intermediate results from different workers.
*/
def merge(other: this.type): this.type = {
def merge(other: Aggregator): this.type = {

totalCnt += other.totalCnt

val deltaMean = currMean - other.currMean

other.currMean.activeIterator.foreach {
@@ -114,132 +131,30 @@
}
}
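merge() (partially collapsed above) combines per-partition aggregates with the standard parallel-variance rule, the same delta-squared term the old combOp carries below: M2 = M2₁ + M2₂ + (mean₂ − mean₁)² · n₁n₂ / (n₁ + n₂). A scalar sketch of that combination (illustrative names; assumes at least one side is non-empty):

```scala
// Sketch: combine two partial (count, mean, squared-deviation sum) triples,
// the rule behind Aggregator.merge and the old combOp.
case class Partial(n: Double, mean: Double, m2n: Double) {
  def merge(that: Partial): Partial = {
    val total = n + that.n // assumed > 0
    val delta = that.mean - mean
    Partial(
      total,
      (mean * n + that.mean * that.n) / total,             // count-weighted mean
      m2n + that.m2n + delta * delta * n * that.n / total) // cross-part term
  }
}
```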

case class VectorRDDStatisticalAggregator(
mean: BV[Double],
statCnt: BV[Double],
totalCnt: Double,
nnz: BV[Double],
currMax: BV[Double],
currMin: BV[Double])

/**
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
* implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
* these functions.
*/
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {

/**
* Aggregate function used for aggregating elements in a worker together.
*/
private def seqOp(
aggregator: VectorRDDStatisticalAggregator,
currData: BV[Double]): VectorRDDStatisticalAggregator = {
aggregator match {
case VectorRDDStatisticalAggregator(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
currData.activeIterator.foreach {
case (id, 0.0) =>
case (id, value) =>
if (maxVec(id) < value) maxVec(id) = value
if (minVec(id) > value) minVec(id) = value

val tmpPrevMean = prevMean(id)
prevMean(id) = (prevMean(id) * cnt + value) / (cnt + 1.0)
prevM2n(id) += (value - prevMean(id)) * (value - tmpPrevMean)

nnzVec(id) += 1.0
}

VectorRDDStatisticalAggregator(
prevMean,
prevM2n,
cnt + 1.0,
nnzVec,
maxVec,
minVec)
}
}

/**
* Combine function used for combining intermediate results together from every worker.
*/
private def combOp(
statistics1: VectorRDDStatisticalAggregator,
statistics2: VectorRDDStatisticalAggregator): VectorRDDStatisticalAggregator = {
(statistics1, statistics2) match {
case (VectorRDDStatisticalAggregator(mean1, m2n1, cnt1, nnz1, max1, min1),
VectorRDDStatisticalAggregator(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
val totalCnt = cnt1 + cnt2
val deltaMean = mean2 - mean1

mean2.activeIterator.foreach {
case (id, 0.0) =>
case (id, value) =>
mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
}

m2n2.activeIterator.foreach {
case (id, 0.0) =>
case (id, value) =>
m2n1(id) +=
value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+nnz2(id))
}

max2.activeIterator.foreach {
case (id, value) =>
if (max1(id) < value) max1(id) = value
}

min2.activeIterator.foreach {
case (id, value) =>
if (min1(id) > value) min1(id) = value
}

axpy(1.0, nnz2, nnz1)
VectorRDDStatisticalAggregator(mean1, m2n1, totalCnt, nnz1, max1, min1)
}
}

/**
* Compute full column-wise statistics for the RDD; the vector size is taken from the first element.
*/
def summarizeStatistics(): VectorRDDStatisticalAggregator = {
def summarizeStatistics(): VectorRDDStatisticalSummary = {
val size = self.take(1).head.size
val zeroValue = VectorRDDStatisticalAggregator(

val zeroValue = new Aggregator(
BV.zeros[Double](size),
BV.zeros[Double](size),
0.0,
BV.zeros[Double](size),
BV.fill(size)(Double.MinValue),
BV.fill(size)(Double.MaxValue))

val VectorRDDStatisticalAggregator(currMean, currM2n, totalCnt, nnz, currMax, currMin) =
self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)

// solve real mean
val realMean = currMean :* nnz :/ totalCnt

// solve real m2n
val deltaMean = currMean
val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)

// remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
nnz.activeIterator.foreach {
case (id, 0.0) =>
currMax(id) = 0.0
currMin(id) = 0.0
case _ =>
}

// get variance
realM2n :/= totalCnt

VectorRDDStatisticalAggregator(
realMean,
realM2n,
totalCnt,
nnz,
currMax,
currMin)
self.map(_.toBreeze).aggregate[Aggregator](zeroValue)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
}
}
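With this revision the caller receives a VectorRDDStatisticalSummary instead of destructuring a raw aggregator. A minimal usage sketch, assuming a live SparkContext `sc` and the implicit-conversion import named in the class scaladoc above (the exact import path may differ in the merged patch):

```scala
import org.apache.spark.MLContext._ // per the scaladoc; location may change
import org.apache.spark.mllib.linalg.Vectors

val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(4.0, 5.0, 6.0),
  Vectors.dense(7.0, 8.0, 9.0)))

val summary = rows.summarizeStatistics()
summary.mean()     // Vector(4.0, 5.0, 6.0), matching the suite below
summary.variance() // Vector(6.0, 6.0, 6.0)
```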
14 changes: 7 additions & 7 deletions mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
@@ -44,19 +44,19 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

test("full-statistics") {
val data = sc.parallelize(localData, 2)
val (VectorRDDStatisticalAggregator(mean, variance, cnt, nnz, max, min), denseTime) =
val (summary, denseTime) =
time(data.summarizeStatistics())

assert(equivVector(Vectors.fromBreeze(mean), Vectors.dense(4.0, 5.0, 6.0)),
assert(equivVector(summary.mean(), Vectors.dense(4.0, 5.0, 6.0)),
"Column mean do not match.")
assert(equivVector(Vectors.fromBreeze(variance), Vectors.dense(6.0, 6.0, 6.0)),
assert(equivVector(summary.variance(), Vectors.dense(6.0, 6.0, 6.0)),
"Column variance do not match.")
assert(cnt === 3.0, "Column cnt do not match.")
assert(equivVector(Vectors.fromBreeze(nnz), Vectors.dense(3.0, 3.0, 3.0)),
assert(summary.totalCount() === 3, "Column count does not match.")
assert(equivVector(summary.numNonZeros(), Vectors.dense(3.0, 3.0, 3.0)),
"Column nnz do not match.")
assert(equivVector(Vectors.fromBreeze(max), Vectors.dense(7.0, 8.0, 9.0)),
assert(equivVector(summary.max(), Vectors.dense(7.0, 8.0, 9.0)),
"Column max do not match.")
assert(equivVector(Vectors.fromBreeze(min), Vectors.dense(1.0, 2.0, 3.0)),
assert(equivVector(summary.min(), Vectors.dense(1.0, 2.0, 3.0)),
"Column min do not match.")

val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
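The assertions above lean on an equivVector helper defined elsewhere in the suite. A plausible stand-in, purely so the assertions read on their own (this definition is an assumption, not taken from the patch):

```scala
// Hypothetical equivVector: elementwise approximate equality of two Vectors.
def equivVector(lhs: Vector, rhs: Vector): Boolean =
  lhs.toArray.zip(rhs.toArray).forall { case (a, b) => math.abs(a - b) < 1e-9 }
```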
