From 9646018bb4466433521b4e602b808f16e8d0ffdb Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 3 May 2015 21:44:39 -0700 Subject: [PATCH] [SPARK-7241] Pearson correlation for DataFrames submitting this PR from a phone, excuse the brevity. adds Pearson correlation to Dataframes, reusing the covariance calculation code cc mengxr rxin Author: Burak Yavuz Closes #5858 from brkyvz/df-corr and squashes the following commits: 285b838 [Burak Yavuz] addressed comments v2.0 d10babb [Burak Yavuz] addressed comments v0.2 4b74b24 [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into df-corr 4fe693b [Burak Yavuz] addressed comments v0.1 a682d06 [Burak Yavuz] ready for PR --- python/pyspark/sql/dataframe.py | 26 +++++++++ python/pyspark/sql/tests.py | 6 ++ .../spark/sql/DataFrameStatFunctions.scala | 26 +++++++++ .../sql/execution/stat/StatFunctions.scala | 58 ++++++++++++------- .../apache/spark/sql/JavaDataFrameSuite.java | 7 +++ .../apache/spark/sql/DataFrameStatSuite.scala | 33 +++++++++-- 6 files changed, 130 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8ddcff8fcdf98..aac5b8c4c5770 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -875,6 +875,27 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + def corr(self, col1, col2, method=None): + """ + Calculates the correlation of two columns of a DataFrame as a double value. Currently only + supports the Pearson Correlation Coefficient. + :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases. + + :param col1: The name of the first column + :param col2: The name of the second column + :param method: The correlation method. Currently only supports "pearson" + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + if not method: + method = "pearson" + if not method == "pearson": + raise ValueError("Currently only the calculation of the Pearson Correlation " + + "coefficient is supported.") + return self._jdf.stat().corr(col1, col2, method) + def cov(self, col1, col2): """ Calculate the sample covariance for the given columns, specified by their names, as a @@ -1359,6 +1380,11 @@ class DataFrameStatFunctions(object): def __init__(self, df): self.df = df + def corr(self, col1, col2, method=None): + return self.df.corr(col1, col2, method) + + corr.__doc__ = DataFrame.corr.__doc__ + def cov(self, col1, col2): return self.df.cov(col1, col2) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 613efc0ac029d..d652c302a54ba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -394,6 +394,12 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_corr(self): + import math + df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() + corr = df.stat.corr("a", "b") + self.assertTrue(abs(corr - 0.95734012) < 1e-6) + def test_cov(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() cov = df.stat.cov("a", "b") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index e8fa82947759b..903532105284e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -27,6 +27,32 @@ import org.apache.spark.sql.execution.stat._ @Experimental final class DataFrameStatFunctions private[sql](df: DataFrame) { + /** + * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * MLlib's Statistics. + * + * @param col1 the name of the column + * @param col2 the name of the column to calculate the correlation against + * @return The Pearson Correlation Coefficient as a Double. + */ + def corr(col1: String, col2: String, method: String): Double = { + require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + + "coefficient is supported.") + StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) + } + + /** + * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. + * + * @param col1 the name of the column + * @param col2 the name of the column to calculate the correlation against + * @return The Pearson Correlation Coefficient as a Double. + */ + def corr(col1: String, col2: String): Double = { + corr(col1, col2, "pearson") + } + /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d4a94c24d9866..67b48e58b17ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -23,29 +23,43 @@ import org.apache.spark.sql.types.{DoubleType, NumericType} private[sql] object StatFunctions { + /** Calculate the Pearson Correlation Coefficient for the given columns */ + private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols) + counts.Ck / math.sqrt(counts.MkX * counts.MkY) + } + /** Helper class to simplify tracking and merging counts. */ private class CovarianceCounter extends Serializable { - var xAvg = 0.0 - var yAvg = 0.0 - var Ck = 0.0 - var count = 0L + var xAvg = 0.0 // the mean of all examples seen so far in col1 + var yAvg = 0.0 // the mean of all examples seen so far in col2 + var Ck = 0.0 // the co-moment after k examples + var MkX = 0.0 // sum of squares of differences from the (current) mean for col1 + var MkY = 0.0 // sum of squares of differences from the (current) mean for col1 + var count = 0L // count of observed examples // add an example to the calculation def add(x: Double, y: Double): this.type = { - val oldX = xAvg + val deltaX = x - xAvg + val deltaY = y - yAvg count += 1 - xAvg += (x - xAvg) / count - yAvg += (y - yAvg) / count - Ck += (y - yAvg) * (x - oldX) + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) this } // merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance def merge(other: CovarianceCounter): this.type = { val totalCount = count + other.count - Ck += other.Ck + - (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count + val deltaX = xAvg - other.xAvg + val deltaY = yAvg - other.yAvg + Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count xAvg = (xAvg * count + other.xAvg * other.count) / totalCount yAvg = (yAvg * count + other.yAvg * other.count) / totalCount + MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count + MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count count = totalCount this } @@ -53,13 +67,7 @@ private[sql] object StatFunctions { def cov: Double = Ck / (count - 1) } - /** - * Calculate the covariance of two numerical columns of a DataFrame. - * @param df The DataFrame - * @param cols the column names - * @return the covariance of the two columns. - */ - private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { require(cols.length == 2, "Currently cov supports calculating the covariance " + "between two columns.") cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => @@ -68,13 +76,23 @@ private[sql] object StatFunctions { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).rdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, combOp = (baseCounter, other) => { baseCounter.merge(other) - }) + }) + } + + /** + * Calculate the covariance of two numerical columns of a DataFrame. + * @param df The DataFrame + * @param cols the column names + * @return the covariance of the two columns. + */ + private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols) counts.cov } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 96fe66d0b84a6..78e847239f405 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -187,6 +187,13 @@ public void testFrequentItems() { Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } + @Test + public void testCorrelation() { + DataFrame df = context.table("testData2"); + Double pearsonCorr = df.stat().corr("a", "b", "pearson"); + Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + } + @Test public void testCovariance() { DataFrame df = context.table("testData2"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 4f5a2ff696789..06764d2a122f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite { def toLetter(i: Int): String = (i + 97).toChar.toString test("Frequent Items") { - val rows = Array.tabulate(1000) { i => + val rows = Seq.tabulate(1000) { i => if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) } - val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles") + val df = rows.toDF("numbers", "letters", "negDoubles") val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) val items = results.collect().head @@ -43,19 +43,40 @@ class DataFrameStatSuite extends FunSuite { val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) + } + test("pearson correlation") { + val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.stat.corr("a", "b", "pearson") + assert(math.abs(corr1 - 1.0) < 1e-12) + val corr2 = df.stat.corr("a", "c", "pearson") + assert(math.abs(corr2 + 1.0) < 1e-12) + // non-trivial example. To reproduce in python, use: + // >>> from scipy.stats import pearsonr + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> pearsonr(a, b) + // (0.95723391394758572, 3.8902121417802199e-11) + // In R, use: + // > a <- 0:19 + // > b <- mapply(function(x) x * x - 2 * x + 3.5, a) + // > cor(a, b) + // [1] 0.957233913947585835 + val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b") + val corr3 = df2.stat.corr("a", "b", "pearson") + assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) } test("covariance") { - val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i))) - val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters") + val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters") val results = df.stat.cov("singles", "doubles") - assert(math.abs(results - 55.0 / 3) < 1e-6) + assert(math.abs(results - 55.0 / 3) < 1e-12) intercept[IllegalArgumentException] { df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes } val decimalRes = decimalData.stat.cov("a", "b") - assert(math.abs(decimalRes) < 1e-6) + assert(math.abs(decimalRes) < 1e-12) } }