Skip to content

Commit

Permalink
Initial commit for correelation and covariance matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
NarineK committed Oct 29, 2015
1 parent cf2e0ae commit 74bdf54
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,39 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
StatFunctions.calculateCov(df, Seq(col1, col2))
}

/**
* Calculate the sample covariance between columns of a DataFrame.
*
* @return a covariance matrix as a DataFrame
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
* .withColumn("rand2", rand(seed=27))
* val covmatrix = df.stat.cov()
* covmatrix.show()
* +---------+------------------+-------------------+-------------------+
* |FieldName| id| rand1| rand2|
* +---------+------------------+-------------------+-------------------+
* | id| 9.166666666666666| 0.4131594565676311| 0.7012982830955725|
* | rand1|0.4131594565676311|0.11982701890603772|0.06500805072758595|
* | rand2|0.7012982830955725|0.06500805072758595|0.09383550706974164|
* +---------+------------------+-------------------+-------------------+
* }}}
*
* @since 1.6.0
*/
def cov(): DataFrame = {
StatFunctions.calculateCov(df)
}

/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
* @param col1 the name of the column
* @param col2 the name of the column to calculate the correlation against
* @param method the name of the correlation method
* @return The Pearson Correlation Coefficient as a Double.
*
* {{{
Expand Down Expand Up @@ -96,6 +122,63 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
corr(col1, col2, "pearson")
}

/**
* Calculates the correlation of columns in the DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
* @param method the name of the correlation method
* @return The Pearson Correlation matrix as a DataFrame.
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
* .withColumn("rand2", rand(seed=27))
* val corrmatrix = df.stat.corr()
* corrmatrix.show()
* +---------+------------------+------------------+------------------+
* |FieldName| id| rand1| rand2|
* +---------+------------------+------------------+------------------+
* | id| 1.0| 0.3942163209095|0.7561595709319909|
* | rand1| 0.3942163209095| 1.0|0.6130644931298477|
* | rand2|0.7561595709319909|0.6130644931298477| 1.0|
* +---------+------------------+------------------+------------------+
* }}}
*
* @since 1.6.0
*/
def corr(method: String): DataFrame = {
require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
"coefficient is supported.")
StatFunctions.pearsonCorrelation(df)
}

/**
* Calculates the correlation of columns in the DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
* @return The Pearson Correlation matrix as a DataFrame.
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
* .withColumn("rand2", rand(seed=27))
* val corrmatrix = df.stat.corr()
* corrmatrix.show()
* +---------+------------------+------------------+------------------+
* |FieldName| id| rand1| rand2|
* +---------+------------------+------------------+------------------+
* | id| 1.0| 0.3942163209095|0.7561595709319909|
* | rand1| 0.3942163209095| 1.0|0.6130644931298477|
* | rand2|0.7561595709319909|0.6130644931298477| 1.0|
* +---------+------------------+------------------+------------------+
* }}}
*
* @since 1.6.0
*/
def corr(): DataFrame = {
corr("pearson")
}

/**
* Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
* The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.stat

import org.apache.spark.Logging
import org.apache.spark.sql.{Row, Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast, AttributeReference}
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
Expand All @@ -33,6 +34,31 @@ private[sql] object StatFunctions extends Logging {
counts.Ck / math.sqrt(counts.MkX * counts.MkY)
}

/** Calculate the Pearson Correlation matrix for given DataFrame */
private[sql] def pearsonCorrelation(df: DataFrame): DataFrame = {
val fieldNames = df.schema.fieldNames
val dfStructAttrs = ArrayBuffer[AttributeReference](
AttributeReference("FieldName", StringType, true)())
val rows = fieldNames.map{fname => val countsRow = new GenericMutableRow(fieldNames.length + 1)
countsRow.update(0, UTF8String.fromString(fname))
countsRow
}.toSeq
// generates field types of the output DataFrame
for(field <- fieldNames) dfStructAttrs += AttributeReference(field, DoubleType, true)()

// fills the correlation matrix by computing column-by-column correlations
for (i <- 0 to fieldNames.length - 1){
for (j <- 0 to i){
val corr = pearsonCorrelation(df, Seq(fieldNames(i), fieldNames(j)))
rows(i).setDouble(j + 1, corr)
rows(j).setDouble(i + 1, corr)
}
rows(i).setDouble(i + 1, 1.0)
}

new DataFrame(df.sqlContext, new LocalRelation(dfStructAttrs, rows))
}

/** Helper class to simplify tracking and merging counts. */
private class CovarianceCounter extends Serializable {
var xAvg = 0.0 // the mean of all examples seen so far in col1
Expand Down Expand Up @@ -102,6 +128,34 @@ private[sql] object StatFunctions extends Logging {
counts.cov
}

/**
* Calculate the covariance of two numerical columns of a DataFrame.
* @param df The DataFrame
* @return the covariance matrix.
*/
private[sql] def calculateCov(df: DataFrame): DataFrame = {
val fieldNames = df.schema.fieldNames
val dfStructAttrs = ArrayBuffer[AttributeReference](
AttributeReference("FieldName", StringType, true)())
val rows = fieldNames.map{fname => val countsRow = new GenericMutableRow(fieldNames.length + 1)
countsRow.update(0, UTF8String.fromString(fname))
countsRow
}.toSeq
// generates field types of the output DataFrame
for(field <- fieldNames) dfStructAttrs += AttributeReference(field, DoubleType, true)()

// fills the covariance matrix by computing column-by-column covariances
for (i <- 0 to fieldNames.length-1){
for (j <- 0 to i){
val cov = calculateCov(df, Seq(fieldNames(i), fieldNames(j)))
rows(i).setDouble(j + 1, cov)
rows(j).setDouble(i + 1, cov)
}
}

new DataFrame(df.sqlContext, new LocalRelation(dfStructAttrs, rows))
}

/** Generate a table of frequencies for the elements of two columns. */
private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
val tableName = s"${col1}_$col2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,35 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}

test("pearson correlation matrix") {
val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

intercept[IllegalArgumentException] {
df1.stat.corr() // doesn't accept non-numerical dataTypes
}

val df2 = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val results = df2.stat.corr()

val row1 = results.where($"FieldName" === "a").collect()(0)
assert(row1.getString(0) == "a")
assert(row1.getDouble(1) == 1.0)
assert(row1.getDouble(2) == 1.0)
assert(row1.getDouble(3) == -1.0)

val row2 = results.where($"FieldName" === "b").collect()(0)
assert(row2.getString(0) == "b")
assert(row2.getDouble(1) == 1.0)
assert(row2.getDouble(2) == 1.0)
assert(row2.getDouble(3) == -1.0)

val row3 = results.where($"FieldName" === "c").collect()(0)
assert(row3.getString(0) == "c")
assert(row3.getDouble(1) == -1.0)
assert(row3.getDouble(2) == -1.0)
assert(row3.getDouble(3) == 1.0)
}

test("covariance") {
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

Expand All @@ -98,6 +127,27 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
assert(math.abs(decimalRes) < 1e-12)
}

test("covariance matrix") {
val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

intercept[IllegalArgumentException] {
df1.stat.cov() // doesn't accept non-numerical dataTypes
}

val df2 = Seq.tabulate(10)(i => (i, 2.0 * i)).toDF("singles", "doubles")
val results = df2.stat.cov()

val row1 = results.where($"FieldName" === "singles").collect()(0)
assert(row1.getString(0) == "singles")
assert(row1.getDouble(1) == 9.166666666666666)
assert(row1.getDouble(2) == 18.333333333333332)

val row2 = results.where($"FieldName" === "doubles").collect()(0)
assert(row2.getString(0) == "doubles")
assert(row2.getDouble(1) == row1.getDouble(2))
assert(row2.getDouble(2) == 36.666666666666664)
}

test("crosstab") {
val rng = new Random()
val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10)))
Expand Down

0 comments on commit 74bdf54

Please sign in to comment.