Commit

add basic statistics
yinxusen committed Apr 10, 2014
1 parent 3bd3129 commit 8c6c0e1
Showing 3 changed files with 213 additions and 0 deletions.

@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.rdd

import breeze.linalg.{Vector => BV, _}

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.rdd.RDD

/**
 * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
 * implicit conversion. Import `org.apache.spark.mllib.util.MLUtils._` at the top of your program
 * to use these functions.
 */
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {

  /** Computes the mean of the elements of each row. */
  def rowMeans(): RDD[Double] = {
    self.map(x => x.toArray.sum / x.size)
  }

  /** Computes the L2 norm of each row. */
  def rowNorm2(): RDD[Double] = {
    self.map(x => math.sqrt(x.toArray.map(e => e * e).sum))
  }

  /** Computes the population standard deviation of the elements of each row. */
  def rowSDs(): RDD[Double] = {
    self.map { x =>
      val mean = x.toArray.sum / x.size
      math.sqrt(x.toArray.map(e => (e - mean) * (e - mean)).sum / x.size)
    }
  }

  def colMeansOption(): Vector = {
    ???
  }

  def colNorm2Option(): Vector = {
    ???
  }

  def colSDsOption(): Vector = {
    ???
  }

  /** Computes the mean of each column across all rows. */
  def colMeans(): Vector = {
    // Sum the rows element-wise and count them, then divide once at the end.
    val (sum, count) = self.map(v => (v.toBreeze, 1L)).reduce { case ((v1, c1), (v2, c2)) =>
      (v1 + v2, c1 + c2)
    }
    Vectors.fromBreeze(sum :/ count.toDouble)
  }

  /** Computes the L2 norm of each column across all rows. */
  def colNorm2(): Vector = Vectors.fromBreeze(
    self.map(v => v.toBreeze :* v.toBreeze).reduce(_ + _).map(math.sqrt))

  /** Computes the population standard deviation of each column across all rows. */
  def colSDs(): Vector = {
    val means = this.colMeans().toBreeze
    // Sum the squared deviations from the column means, then divide by the row count.
    val (sumSq, count) = self.map { v =>
      val dev = v.toBreeze - means
      (dev :* dev, 1L)
    }.reduce { case ((v1, c1), (v2, c2)) => (v1 + v2, c1 + c2) }
    Vectors.fromBreeze((sumSq :/ count.toDouble).map(math.sqrt))
  }

  private def maxMinOption(cmp: (Vector, Vector) => Boolean): Option[Vector] = {
    def cmpMaxMin(x1: Vector, x2: Vector) = if (cmp(x1, x2)) x1 else x2
    // Reduce within each partition first, then combine the per-partition winners on the driver.
    self.mapPartitions { iterator =>
      Seq(iterator.reduceOption(cmpMaxMin)).iterator
    }.collect { case Some(x) => x }.collect().reduceOption(cmpMaxMin)
  }

  /** Returns the maximum of the rows under the ordering induced by `cmp`, or `None` if empty. */
  def maxOption(cmp: (Vector, Vector) => Boolean): Option[Vector] = maxMinOption(cmp)

  /** Returns the minimum of the rows under the ordering induced by `cmp`, or `None` if empty. */
  def minOption(cmp: (Vector, Vector) => Boolean): Option[Vector] = maxMinOption(!cmp(_, _))

  def rowShrink(): RDD[Vector] = {
    ???
  }

  def colShrink(): RDD[Vector] = {
    ???
  }
}

@@ -265,4 +265,5 @@ object MLUtils {
    }
    sqDist
  }

  implicit def rddToVectorRDDFunctions(rdd: RDD[Vector]): VectorRDDFunctions =
    new VectorRDDFunctions(rdd)
}
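For reference, a minimal usage sketch of the implicit conversion added above, assuming an existing SparkContext named `sc` (the variable name and sample data are illustrative, not part of the commit):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils._  // brings rddToVectorRDDFunctions into scope

// `sc` is assumed to be an existing SparkContext.
val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(4.0, 5.0, 6.0)))

rows.rowMeans().collect()  // per-row means: Array(2.0, 5.0)
rows.colMeans()            // column means as a Vector: [2.5,3.5,4.5]
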
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.rdd

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.MLUtils._

import VectorRDDFunctionsSuite._

class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {

  val localData = Array(
    Vectors.dense(1.0, 2.0, 3.0),
    Vectors.dense(4.0, 5.0, 6.0),
    Vectors.dense(7.0, 8.0, 9.0)
  )

  val rowMeans = Array(2.0, 5.0, 8.0)
  val rowNorm2 = Array(math.sqrt(14.0), math.sqrt(77.0), math.sqrt(194.0))
  val rowSDs = Array(math.sqrt(2.0 / 3.0), math.sqrt(2.0 / 3.0), math.sqrt(2.0 / 3.0))

  val colMeans = Array(4.0, 5.0, 6.0)
  val colNorm2 = Array(math.sqrt(66.0), math.sqrt(93.0), math.sqrt(126.0))
  val colSDs = Array(math.sqrt(6.0), math.sqrt(6.0), math.sqrt(6.0))

  val maxVec = Array(7.0, 8.0, 9.0)
  val minVec = Array(1.0, 2.0, 3.0)

test("rowMeans") {
val data = sc.parallelize(localData)
assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)), "Row means do not match.")
}

test("rowNorm2") {
val data = sc.parallelize(localData)
assert(equivVector(Vectors.dense(data.rowNorm2().collect()), Vectors.dense(rowNorm2)), "Row norm2s do not match.")
}

test("rowSDs") {
val data = sc.parallelize(localData)
assert(equivVector(Vectors.dense(data.rowSDs().collect()), Vectors.dense(rowSDs)), "Row SDs do not match.")
}

test("colMeans") {
val data = sc.parallelize(localData)
assert(equivVector(data.colMeans(), Vectors.dense(colMeans)), "Column means do not match.")
}

test("colNorm2") {
val data = sc.parallelize(localData)
assert(equivVector(data.colNorm2(), Vectors.dense(colNorm2)), "Column norm2s do not match.")
}

test("colSDs") {
val data = sc.parallelize(localData)
assert(equivVector(data.colSDs(), Vectors.dense(colSDs)), "Column SDs do not match.")
}

test("maxOption") {
val data = sc.parallelize(localData)
assert(equivVectorOption(
data.maxOption((lhs: Vector, rhs: Vector) => lhs.toBreeze.norm(2) >= rhs.toBreeze.norm(2)),
Some(Vectors.dense(maxVec))),
"Optional maximum does not match."
)
}

test("minOption") {
val data = sc.parallelize(localData)
assert(equivVectorOption(
data.minOption((lhs: Vector, rhs: Vector) => lhs.toBreeze.norm(2) >= rhs.toBreeze.norm(2)),
Some(Vectors.dense(minVec))),
"Optional minimum does not match."
)
}
}

object VectorRDDFunctionsSuite {
  def equivVector(lhs: Vector, rhs: Vector): Boolean = {
    (lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9
  }

  def equivVectorOption(lhs: Option[Vector], rhs: Option[Vector]): Boolean = {
    (lhs, rhs) match {
      case (Some(a), Some(b)) => (a.toBreeze - b.toBreeze).norm(2) < 1e-9
      case (None, None) => true
      case _ => false
    }
  }
}
