Commit 1376ff4
rename variables and adjust code
yinxusen committed Apr 10, 2014
1 parent 4a5c38d commit 1376ff4
Showing 2 changed files with 54 additions and 65 deletions.
mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala
@@ -18,34 +18,20 @@ package org.apache.spark.mllib.rdd
 
 import breeze.linalg.{axpy, Vector => BV}
 
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 
-/**
- * Case class of the summary statistics, including mean, variance, count, max, min, and non-zero
- * elements count.
- */
-case class VectorRDDStatisticalSummary(
-    mean: Vector,
-    variance: Vector,
-    count: Long,
-    max: Vector,
-    min: Vector,
-    nonZeroCnt: Vector) extends Serializable
-
 /**
- * Case class of the aggregate value for collecting summary statistics from RDD[Vector]. These
- * values are relatively with
- * [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]], the
- * latter is computed from the former.
+ * Case class of the aggregate value for collecting summary statistics from RDD[Vector].
  */
-private case class VectorRDDStatisticalRing(
-    fakeMean: BV[Double],
-    fakeM2n: BV[Double],
-    totalCnt: Double,
-    nnz: BV[Double],
-    fakeMax: BV[Double],
-    fakeMin: BV[Double])
+case class VectorRDDStatisticalAggregator(
+    mean: BV[Double],
+    statCounter: BV[Double],
+    totalCount: Double,
+    numNonZeros: BV[Double],
+    max: BV[Double],
+    min: BV[Double])
 
 /**
  * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
@@ -58,11 +44,12 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
    * Aggregate function used for aggregating elements in a worker together.
    */
   private def seqOp(
-      aggregator: VectorRDDStatisticalRing,
-      currData: BV[Double]): VectorRDDStatisticalRing = {
+      aggregator: VectorRDDStatisticalAggregator,
+      currData: BV[Double]): VectorRDDStatisticalAggregator = {
     aggregator match {
-      case VectorRDDStatisticalRing(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
+      case VectorRDDStatisticalAggregator(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
         currData.activeIterator.foreach {
+          case (id, 0.0) =>
           case (id, value) =>
             if (maxVec(id) < value) maxVec(id) = value
             if (minVec(id) > value) minVec(id) = value
@@ -74,7 +61,7 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
             nnzVec(id) += 1.0
         }
 
-        VectorRDDStatisticalRing(
+        VectorRDDStatisticalAggregator(
          prevMean,
          prevM2n,
          cnt + 1.0,
@@ -88,11 +75,11 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
    * Combine function used for combining intermediate results together from every worker.
    */
   private def combOp(
-      statistics1: VectorRDDStatisticalRing,
-      statistics2: VectorRDDStatisticalRing): VectorRDDStatisticalRing = {
+      statistics1: VectorRDDStatisticalAggregator,
+      statistics2: VectorRDDStatisticalAggregator): VectorRDDStatisticalAggregator = {
     (statistics1, statistics2) match {
-      case (VectorRDDStatisticalRing(mean1, m2n1, cnt1, nnz1, max1, min1),
-        VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
+      case (VectorRDDStatisticalAggregator(mean1, m2n1, cnt1, nnz1, max1, min1),
+        VectorRDDStatisticalAggregator(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
         val totalCnt = cnt1 + cnt2
         val deltaMean = mean2 - mean1

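A note on this merge step: the collapsed lines after deltaMean combine the two running means and M2n values. Assuming they apply the standard pairwise update for parallel variance (Chan et al.), here is a scalar sketch of that update; Partial and merge are names local to the sketch, not part of the commit.

// Sketch only: the pairwise merge that combOp's collapsed lines presumably
// perform, reduced to a single scalar column.
final case class Partial(mean: Double, m2n: Double, cnt: Double)

def merge(a: Partial, b: Partial): Partial = {
  val totalCnt = a.cnt + b.cnt
  val deltaMean = b.mean - a.mean
  Partial(
    // the merged mean shifts toward b by b's share of the total weight
    mean = a.mean + deltaMean * b.cnt / totalCnt,
    // m2n gains a cross term proportional to the squared gap between the means
    m2n = a.m2n + b.m2n + deltaMean * deltaMean * a.cnt * b.cnt / totalCnt,
    cnt = totalCnt)
}

// merge(Partial(2.0, 8.0, 4.0), Partial(6.0, 2.0, 2.0)) gives
// mean = 10.0 / 3, m2n = 10.0 + 128.0 / 6, cnt = 6.0, which matches
// computing both statistics directly over the pooled data.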
@@ -120,51 +107,50 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
         }
 
         axpy(1.0, nnz2, nnz1)
-        VectorRDDStatisticalRing(mean1, m2n1, totalCnt, nnz1, max1, min1)
+        VectorRDDStatisticalAggregator(mean1, m2n1, totalCnt, nnz1, max1, min1)
     }
   }
 
   /**
    * Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
    */
-  def summarizeStatistics(size: Int): VectorRDDStatisticalSummary = {
-    val zeroValue = VectorRDDStatisticalRing(
+  def summarizeStatistics(): VectorRDDStatisticalAggregator = {
+    val size = self.take(1).head.size
+    val zeroValue = VectorRDDStatisticalAggregator(
       BV.zeros[Double](size),
       BV.zeros[Double](size),
       0.0,
       BV.zeros[Double](size),
       BV.fill(size)(Double.MinValue),
       BV.fill(size)(Double.MaxValue))
 
-    val VectorRDDStatisticalRing(fakeMean, fakeM2n, totalCnt, nnz, fakeMax, fakeMin) =
+    val VectorRDDStatisticalAggregator(currMean, currM2n, totalCnt, nnz, currMax, currMin) =
       self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
 
     // solve real mean
-    val realMean = fakeMean :* nnz :/ totalCnt
+    val realMean = currMean :* nnz :/ totalCnt
 
     // solve real m2n
-    val deltaMean = fakeMean
-    val realM2n = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
+    val deltaMean = currMean
+    val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
 
     // remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
-    val max = Vectors.sparse(size, fakeMax.activeIterator.map { case (id, value) =>
-      if ((value == Double.MinValue) && (realMean(id) != Double.MinValue)) (id, 0.0)
-      else (id, value)
-    }.toSeq)
-    val min = Vectors.sparse(size, fakeMin.activeIterator.map { case (id, value) =>
-      if ((value == Double.MaxValue) && (realMean(id) != Double.MaxValue)) (id, 0.0)
-      else (id, value)
-    }.toSeq)
+    nnz.activeIterator.foreach {
+      case (id, 0.0) =>
+        currMax(id) = 0.0
+        currMin(id) = 0.0
+      case _ =>
+    }
 
     // get variance
     realM2n :/= totalCnt
 
-    VectorRDDStatisticalSummary(
-      Vectors.fromBreeze(realMean),
-      Vectors.fromBreeze(realM2n),
-      totalCnt.toLong,
-      Vectors.fromBreeze(nnz),
-      max,
-      min)
+    VectorRDDStatisticalAggregator(
+      realMean,
+      realM2n,
+      totalCnt,
+      nnz,
+      currMax,
+      currMin)
   }
 }
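A note on the correction step in summarizeStatistics: seqOp only visits non-zero entries, so currMean and currM2n describe the non-zero values of each column, and the realMean / realM2n lines fold the implicit zeros back in. A self-contained check of both identities on a single toy column; plain Scala, no Spark or Breeze required, and every name below is local to the sketch.

object SparseStatsCheck extends App {
  val column = Seq(0.0, 0.0, 3.0, 5.0)       // one column, zeros included
  val nonZeros = column.filter(_ != 0.0)

  val totalCnt = column.size.toDouble         // what the aggregator calls totalCnt
  val nnz = nonZeros.size.toDouble            // non-zero count for this column
  val currMean = nonZeros.sum / nnz           // mean over non-zeros only
  val currM2n = nonZeros.map(x => (x - currMean) * (x - currMean)).sum

  // The corrections applied after the aggregate, as in summarizeStatistics:
  val realMean = currMean * nnz / totalCnt
  val realM2n = currM2n - currMean * currMean * (nnz * (nnz - totalCnt)) / totalCnt

  // Reference values computed directly over the full column:
  val directMean = column.sum / totalCnt
  val directM2n = column.map(x => (x - directMean) * (x - directMean)).sum

  assert(math.abs(realMean - directMean) < 1e-12)  // 2.0 both ways
  assert(math.abs(realM2n - directM2n) < 1e-12)    // 18.0 both ways
  println(s"mean = $realMean, m2n = $realM2n, variance = ${realM2n / totalCnt}")
}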
mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.spark.mllib.rdd
 
 import scala.collection.mutable.ArrayBuffer
@@ -45,18 +44,23 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
 
   test("full-statistics") {
     val data = sc.parallelize(localData, 2)
-    val (VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min), denseTime) =
-      time(data.summarizeStatistics(3))
+    val (VectorRDDStatisticalAggregator(mean, variance, cnt, nnz, max, min), denseTime) =
+      time(data.summarizeStatistics())
 
-    assert(equivVector(mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
-    assert(equivVector(variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
-    assert(cnt === 3, "Column cnt do not match.")
-    assert(equivVector(nnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
-    assert(equivVector(max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
-    assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
+    assert(equivVector(Vectors.fromBreeze(mean), Vectors.dense(4.0, 5.0, 6.0)),
+      "Column mean do not match.")
+    assert(equivVector(Vectors.fromBreeze(variance), Vectors.dense(6.0, 6.0, 6.0)),
+      "Column variance do not match.")
+    assert(cnt === 3.0, "Column cnt do not match.")
+    assert(equivVector(Vectors.fromBreeze(nnz), Vectors.dense(3.0, 3.0, 3.0)),
+      "Column nnz do not match.")
+    assert(equivVector(Vectors.fromBreeze(max), Vectors.dense(7.0, 8.0, 9.0)),
+      "Column max do not match.")
+    assert(equivVector(Vectors.fromBreeze(min), Vectors.dense(1.0, 2.0, 3.0)),
+      "Column min do not match.")
 
     val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
-    val (_, sparseTime) = time(dataForSparse.summarizeStatistics(20))
+    val (_, sparseTime) = time(dataForSparse.summarizeStatistics())
 
     println(s"dense time is $denseTime, sparse time is $sparseTime.")
     assert(relativeTime(denseTime, sparseTime),
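A note on time(...): the helper is defined elsewhere in the suite and is not shown in these hunks. A minimal version consistent with how the test uses it (a result paired with an elapsed-time number); the Double return type and millisecond unit are assumptions.

// Sketch only: a plausible shape for the suite's time(...) helper.
def time[R](block: => R): (R, Double) = {
  val start = System.nanoTime()
  val result = block                               // force the by-name expression
  val elapsedMs = (System.nanoTime() - start) / 1e6
  (result, elapsedMs)
}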
@@ -80,5 +84,4 @@ object VectorRDDFunctionsSuite {
     val denominator = math.max(lhs, rhs)
     math.abs(lhs - rhs) / denominator < 0.3
   }
-}
-
+}
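As background, summarizeStatistics leans entirely on RDD.aggregate, which folds each partition with seqOp and then merges the per-partition results with combOp. A minimal standalone sketch of that pattern, computing a count and a sum in one pass; the local[2] master and the app name are arbitrary choices for the sketch.

import org.apache.spark.{SparkConf, SparkContext}

object AggregateDemo extends App {
  val conf = new SparkConf().setMaster("local[2]").setAppName("aggregate-demo")
  val sc = new SparkContext(conf)
  val data = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0), 2)

  val (count, sum) = data.aggregate((0L, 0.0))(
    (acc, x) => (acc._1 + 1, acc._2 + x),    // seqOp: fold one value into an accumulator
    (a, b) => (a._1 + b._1, a._2 + b._2))    // combOp: merge two partition accumulators

  println(s"count = $count, mean = ${sum / count}")  // count = 4, mean = 2.5
  sc.stop()
}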