[SPARK-11057] [SQL] Add correlation and covariance matrices #9366

Closed · wants to merge 1 commit
@@ -52,13 +52,39 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
StatFunctions.calculateCov(df, Seq(col1, col2))
}

/**
* Calculates the sample covariance matrix for the numerical columns of a DataFrame.
*
* @return a covariance matrix as a DataFrame
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
* .withColumn("rand2", rand(seed=27))
* val covmatrix = df.stat.cov()
* covmatrix.show()
* +---------+------------------+-------------------+-------------------+
* |FieldName| id| rand1| rand2|
* +---------+------------------+-------------------+-------------------+
* | id| 9.166666666666666| 0.4131594565676311| 0.7012982830955725|
* | rand1|0.4131594565676311|0.11982701890603772|0.06500805072758595|
* | rand2|0.7012982830955725|0.06500805072758595|0.09383550706974164|
* +---------+------------------+-------------------+-------------------+
* }}}
*
* @since 1.6.0
*/
def cov(): DataFrame = {
StatFunctions.calculateCov(df)
}

/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
* @param col1 the name of the column
* @param col2 the name of the column to calculate the correlation against
* @param method the name of the correlation method
* @return The Pearson Correlation Coefficient as a Double.
*
* {{{
@@ -96,6 +122,63 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
corr(col1, col2, "pearson")
}

/**
* Calculates the correlation matrix for the columns of the DataFrame. Currently only supports
* the Pearson Correlation Coefficient. For Spearman Correlation, consider using the RDD methods
* found in MLlib's Statistics.
*
* @param method the name of the correlation method
* @return The Pearson Correlation matrix as a DataFrame.
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
* .withColumn("rand2", rand(seed=27))
* val corrmatrix = df.stat.corr("pearson")
* corrmatrix.show()
* +---------+------------------+------------------+------------------+
* |FieldName| id| rand1| rand2|
* +---------+------------------+------------------+------------------+
* | id| 1.0| 0.3942163209095|0.7561595709319909|
* | rand1| 0.3942163209095| 1.0|0.6130644931298477|
* | rand2|0.7561595709319909|0.6130644931298477| 1.0|
* +---------+------------------+------------------+------------------+
* }}}
*
* @since 1.6.0
*/
def corr(method: String): DataFrame = {
require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
"coefficient is supported.")
StatFunctions.pearsonCorrelation(df)
}
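
For reference, the Spearman route that the scaladoc points to looks roughly like the sketch below. It is illustrative only: it assumes every column of df is numeric and uses MLlib's RDD-based Statistics.corr from the 1.x API.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

// Build an RDD[Vector] from the DataFrame rows (assumes all columns are numeric),
// then compute the Spearman correlation matrix with the RDD-based MLlib API.
val vectors = df.rdd.map { row =>
  Vectors.dense(row.toSeq.map { case n: java.lang.Number => n.doubleValue() }.toArray)
}
val spearmanMatrix = Statistics.corr(vectors, "spearman")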

/**
* Calculates the correlation matrix for the columns of the DataFrame. Currently only supports
* the Pearson Correlation Coefficient. For Spearman Correlation, consider using the RDD methods
* found in MLlib's Statistics.
*
* @return The Pearson Correlation matrix as a DataFrame.
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
* .withColumn("rand2", rand(seed=27))
* val corrmatrix = df.stat.corr()
* corrmatrix.show()
* +---------+------------------+------------------+------------------+
* |FieldName| id| rand1| rand2|
* +---------+------------------+------------------+------------------+
* | id| 1.0| 0.3942163209095|0.7561595709319909|
* | rand1| 0.3942163209095| 1.0|0.6130644931298477|
* | rand2|0.7561595709319909|0.6130644931298477| 1.0|
* +---------+------------------+------------------+------------------+
* }}}
*
* @since 1.6.0
*/
def corr(): DataFrame = {
corr("pearson")
}

/**
* Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
* The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.stat

import org.apache.spark.Logging
import org.apache.spark.sql.{Row, Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast, AttributeReference}
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -33,6 +34,31 @@ private[sql] object StatFunctions extends Logging {
counts.Ck / math.sqrt(counts.MkX * counts.MkY)
}

/** Calculate the Pearson Correlation matrix for the given DataFrame. */
private[sql] def pearsonCorrelation(df: DataFrame): DataFrame = {
val fieldNames = df.schema.fieldNames
val dfStructAttrs = ArrayBuffer[AttributeReference](
AttributeReference("FieldName", StringType, true)())
val rows = fieldNames.map { fname =>
  val countsRow = new GenericMutableRow(fieldNames.length + 1)
  countsRow.update(0, UTF8String.fromString(fname))
  countsRow
}.toSeq
// add a DoubleType output attribute for each input field
for (field <- fieldNames) dfStructAttrs += AttributeReference(field, DoubleType, true)()

// fills the correlation matrix by computing column-by-column correlations
for (i <- 0 to fieldNames.length - 1) {
for (j <- 0 to i) {
val corr = pearsonCorrelation(df, Seq(fieldNames(i), fieldNames(j)))
Contributor:
You can't assume all columns are of numeric type. Catch the exception here and use null as the value if an exception happens?

Contributor Author:
I'm not sure showing null is valid. If a column is not numeric, I think it should not be shown at all.

rows(i).setDouble(j + 1, corr)
rows(j).setDouble(i + 1, corr)
}
rows(i).setDouble(i + 1, 1.0)
}

new DataFrame(df.sqlContext, new LocalRelation(dfStructAttrs, rows))
}
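
A minimal sketch (not part of this patch) of the null-fallback variant discussed in the review thread above, assuming the pairwise pearsonCorrelation call throws an IllegalArgumentException for non-numeric columns, as the new tests expect:

// Hypothetical variant of the loop above: store a null cell instead of failing the
// whole matrix when a column pair cannot be correlated.
for (i <- 0 to fieldNames.length - 1) {
  for (j <- 0 to i) {
    try {
      val corr = pearsonCorrelation(df, Seq(fieldNames(i), fieldNames(j)))
      rows(i).setDouble(j + 1, corr)
      rows(j).setDouble(i + 1, corr)
    } catch {
      case _: IllegalArgumentException =>
        rows(i).update(j + 1, null)
        rows(j).update(i + 1, null)
    }
  }
  // diagonal kept as in the patch; a complete fix would also null it for non-numeric columns
  rows(i).setDouble(i + 1, 1.0)
}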

/** Helper class to simplify tracking and merging counts. */
private class CovarianceCounter extends Serializable {
var xAvg = 0.0 // the mean of all examples seen so far in col1
@@ -102,6 +128,34 @@ private[sql] object StatFunctions extends Logging {
counts.cov
}

/**
* Calculate the covariance matrix for the numerical columns of a DataFrame.
* @param df The DataFrame
* @return the covariance matrix as a DataFrame.
*/
private[sql] def calculateCov(df: DataFrame): DataFrame = {
val fieldNames = df.schema.fieldNames
val dfStructAttrs = ArrayBuffer[AttributeReference](
AttributeReference("FieldName", StringType, true)())
val rows = fieldNames.map { fname =>
  val countsRow = new GenericMutableRow(fieldNames.length + 1)
  countsRow.update(0, UTF8String.fromString(fname))
  countsRow
}.toSeq
// add a DoubleType output attribute for each input field
for (field <- fieldNames) dfStructAttrs += AttributeReference(field, DoubleType, true)()

// fills the covariance matrix by computing column-by-column covariances
for (i <- 0 to fieldNames.length - 1) {
for (j <- 0 to i) {
val cov = calculateCov(df, Seq(fieldNames(i), fieldNames(j)))
Contributor:
You can't assume all columns are of numeric type. Catch the exception here and use null as the value if an exception happens?

rows(i).setDouble(j + 1, cov)
rows(j).setDouble(i + 1, cov)
}
}

new DataFrame(df.sqlContext, new LocalRelation(dfStructAttrs, rows))
}
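
As the new "covariance matrix" test below confirms, the current behavior on a DataFrame with a non-numeric column is to fail fast rather than fill in nulls. An illustrative example (hypothetical column names; assumes sqlContext.implicits._ is in scope):

import sqlContext.implicits._

// "letters" is a String column, so the whole matrix computation is rejected.
val mixed = Seq.tabulate(10)(i => (i, 2.0 * i, i.toString)).toDF("singles", "doubles", "letters")
mixed.stat.cov() // throws IllegalArgumentException because "letters" is not numeric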

/** Generate a table of frequencies for the elements of two columns. */
private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
val tableName = s"${col1}_$col2"
@@ -85,6 +85,35 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}

test("pearson correlation matrix") {
val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

intercept[IllegalArgumentException] {
df1.stat.corr() // doesn't accept non-numerical data types
}

val df2 = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val results = df2.stat.corr()

val row1 = results.where($"FieldName" === "a").collect()(0)
assert(row1.getString(0) == "a")
assert(row1.getDouble(1) == 1.0)
assert(row1.getDouble(2) == 1.0)
assert(row1.getDouble(3) == -1.0)

val row2 = results.where($"FieldName" === "b").collect()(0)
assert(row2.getString(0) == "b")
assert(row2.getDouble(1) == 1.0)
assert(row2.getDouble(2) == 1.0)
assert(row2.getDouble(3) == -1.0)

val row3 = results.where($"FieldName" === "c").collect()(0)
assert(row3.getString(0) == "c")
assert(row3.getDouble(1) == -1.0)
assert(row3.getDouble(2) == -1.0)
assert(row3.getDouble(3) == 1.0)
}

test("covariance") {
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

@@ -98,6 +127,27 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
assert(math.abs(decimalRes) < 1e-12)
}

test("covariance matrix") {
val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

intercept[IllegalArgumentException] {
df1.stat.cov() // doesn't accept non-numerical data types
}

val df2 = Seq.tabulate(10)(i => (i, 2.0 * i)).toDF("singles", "doubles")
val results = df2.stat.cov()

val row1 = results.where($"FieldName" === "singles").collect()(0)
assert(row1.getString(0) == "singles")
assert(row1.getDouble(1) == 9.166666666666666)
assert(row1.getDouble(2) == 18.333333333333332)

val row2 = results.where($"FieldName" === "doubles").collect()(0)
assert(row2.getString(0) == "doubles")
assert(row2.getDouble(1) == row1.getDouble(2))
assert(row2.getDouble(2) == 36.666666666666664)
}

test("crosstab") {
val rng = new Random()
val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10)))