From 74bdf5451dd92de10acfea3e1db9cd3325bf6dd7 Mon Sep 17 00:00:00 2001 From: Narine Kokhlikyan Date: Thu, 29 Oct 2015 07:49:55 -0700 Subject: [PATCH] Initial commit for correelation and covariance matrices --- .../spark/sql/DataFrameStatFunctions.scala | 83 +++++++++++++++++++ .../sql/execution/stat/StatFunctions.scala | 56 ++++++++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 50 +++++++++++ 3 files changed, 188 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 69c984717526d..64259bb4de067 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -52,6 +52,31 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { StatFunctions.calculateCov(df, Seq(col1, col2)) } + /** + * Calculate the sample covariance between columns of a DataFrame. + * + * @return a covariance matrix as a DataFrame + * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * val covmatrix = df.stat.cov() + * covmatrix.show() + * +---------+------------------+-------------------+-------------------+ + * |FieldName| id| rand1| rand2| + * +---------+------------------+-------------------+-------------------+ + * | id| 9.166666666666666| 0.4131594565676311| 0.7012982830955725| + * | rand1|0.4131594565676311|0.11982701890603772|0.06500805072758595| + * | rand2|0.7012982830955725|0.06500805072758595|0.09383550706974164| + * +---------+------------------+-------------------+-------------------+ + * }}} + * + * @since 1.6.0 + */ + def cov(): DataFrame = { + StatFunctions.calculateCov(df) + } + /** * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in @@ -59,6 +84,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @param col1 the name of the column * @param col2 the name of the column to calculate the correlation against + * @param method the name of the correlation method * @return The Pearson Correlation Coefficient as a Double. * * {{{ @@ -96,6 +122,63 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { corr(col1, col2, "pearson") } + /** + * Calculates the correlation of columns in the DataFrame. Currently only supports the Pearson + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * MLlib's Statistics. + * + * @param method the name of the correlation method + * @return The Pearson Correlation matrix as a DataFrame. + * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * val corrmatrix = df.stat.corr() + * corrmatrix.show() + * +---------+------------------+------------------+------------------+ + * |FieldName| id| rand1| rand2| + * +---------+------------------+------------------+------------------+ + * | id| 1.0| 0.3942163209095|0.7561595709319909| + * | rand1| 0.3942163209095| 1.0|0.6130644931298477| + * | rand2|0.7561595709319909|0.6130644931298477| 1.0| + * +---------+------------------+------------------+------------------+ + * }}} + * + * @since 1.6.0 + */ + def corr(method: String): DataFrame = { + require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + + "coefficient is supported.") + StatFunctions.pearsonCorrelation(df) + } + + /** + * Calculates the correlation of columns in the DataFrame. Currently only supports the Pearson + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * MLlib's Statistics. + * + * @return The Pearson Correlation matrix as a DataFrame. + * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * val corrmatrix = df.stat.corr() + * corrmatrix.show() + * +---------+------------------+------------------+------------------+ + * |FieldName| id| rand1| rand2| + * +---------+------------------+------------------+------------------+ + * | id| 1.0| 0.3942163209095|0.7561595709319909| + * | rand1| 0.3942163209095| 1.0|0.6130644931298477| + * | rand2|0.7561595709319909|0.6130644931298477| 1.0| + * +---------+------------------+------------------+------------------+ + * }}} + * + * @since 1.6.0 + */ + def corr(): DataFrame = { + corr("pearson") + } + /** * Computes a pair-wise frequency table of the given columns. Also known as a contingency table. * The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 00231d65a7d54..04d265423ea4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging import org.apache.spark.sql.{Row, Column, DataFrame} -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast, AttributeReference} +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -33,6 +34,31 @@ private[sql] object StatFunctions extends Logging { counts.Ck / math.sqrt(counts.MkX * counts.MkY) } + /** Calculate the Pearson Correlation matrix for given DataFrame */ + private[sql] def pearsonCorrelation(df: DataFrame): DataFrame = { + val fieldNames = df.schema.fieldNames + val dfStructAttrs = ArrayBuffer[AttributeReference]( + AttributeReference("FieldName", StringType, true)()) + val rows = fieldNames.map{fname => val countsRow = new GenericMutableRow(fieldNames.length + 1) + countsRow.update(0, UTF8String.fromString(fname)) + countsRow + }.toSeq + // generates field types of the output DataFrame + for(field <- fieldNames) dfStructAttrs += AttributeReference(field, DoubleType, true)() + + // fills the correlation matrix by computing column-by-column correlations + for (i <- 0 to fieldNames.length - 1){ + for (j <- 0 to i){ + val corr = pearsonCorrelation(df, Seq(fieldNames(i), fieldNames(j))) + rows(i).setDouble(j + 1, corr) + rows(j).setDouble(i + 1, corr) + } + rows(i).setDouble(i + 1, 1.0) + } + + new DataFrame(df.sqlContext, new LocalRelation(dfStructAttrs, rows)) + } + /** Helper class to simplify tracking and merging counts. */ private class CovarianceCounter extends Serializable { var xAvg = 0.0 // the mean of all examples seen so far in col1 @@ -102,6 +128,34 @@ private[sql] object StatFunctions extends Logging { counts.cov } + /** + * Calculate the covariance of two numerical columns of a DataFrame. + * @param df The DataFrame + * @return the covariance matrix. + */ + private[sql] def calculateCov(df: DataFrame): DataFrame = { + val fieldNames = df.schema.fieldNames + val dfStructAttrs = ArrayBuffer[AttributeReference]( + AttributeReference("FieldName", StringType, true)()) + val rows = fieldNames.map{fname => val countsRow = new GenericMutableRow(fieldNames.length + 1) + countsRow.update(0, UTF8String.fromString(fname)) + countsRow + }.toSeq + // generates field types of the output DataFrame + for(field <- fieldNames) dfStructAttrs += AttributeReference(field, DoubleType, true)() + + // fills the covariance matrix by computing column-by-column covariances + for (i <- 0 to fieldNames.length-1){ + for (j <- 0 to i){ + val cov = calculateCov(df, Seq(fieldNames(i), fieldNames(j))) + rows(i).setDouble(j + 1, cov) + rows(j).setDouble(i + 1, cov) + } + } + + new DataFrame(df.sqlContext, new LocalRelation(dfStructAttrs, rows)) + } + /** Generate a table of frequencies for the elements of two columns. */ private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { val tableName = s"${col1}_$col2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 6524abcf5e97f..744769c2457f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -85,6 +85,35 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) } + test("pearson correlation matrix") { + val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters") + + intercept[IllegalArgumentException] { + df1.stat.corr() // doesn't accept non-numerical dataTypes + } + + val df2 = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") + val results = df2.stat.corr() + + val row1 = results.where($"FieldName" === "a").collect()(0) + assert(row1.getString(0) == "a") + assert(row1.getDouble(1) == 1.0) + assert(row1.getDouble(2) == 1.0) + assert(row1.getDouble(3) == -1.0) + + val row2 = results.where($"FieldName" === "b").collect()(0) + assert(row2.getString(0) == "b") + assert(row2.getDouble(1) == 1.0) + assert(row2.getDouble(2) == 1.0) + assert(row2.getDouble(3) == -1.0) + + val row3 = results.where($"FieldName" === "c").collect()(0) + assert(row3.getString(0) == "c") + assert(row3.getDouble(1) == -1.0) + assert(row3.getDouble(2) == -1.0) + assert(row3.getDouble(3) == 1.0) + } + test("covariance") { val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters") @@ -98,6 +127,27 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(decimalRes) < 1e-12) } + test("covariance matrix") { + val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters") + + intercept[IllegalArgumentException] { + df1.stat.cov() // doesn't accept non-numerical dataTypes + } + + val df2 = Seq.tabulate(10)(i => (i, 2.0 * i)).toDF("singles", "doubles") + val results = df2.stat.cov() + + val row1 = results.where($"FieldName" === "singles").collect()(0) + assert(row1.getString(0) == "singles") + assert(row1.getDouble(1) == 9.166666666666666) + assert(row1.getDouble(2) == 18.333333333333332) + + val row2 = results.where($"FieldName" === "doubles").collect()(0) + assert(row2.getString(0) == "doubles") + assert(row2.getDouble(1) == row1.getDouble(2)) + assert(row2.getDouble(2) == 36.666666666666664) + } + test("crosstab") { val rng = new Random() val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10)))