Skip to content

Commit

Permalink
[SPARK-19636][ML] Feature parity for correlation statistics in MLlib
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This patch adds DataFrame-based support for the correlation statistics found in `org.apache.spark.mllib.stat.Statistics`, following the design doc discussed in the JIRA ticket.

The current implementation is a simple wrapper around the `spark.mllib` implementation. Future optimizations can be implemented at a later stage.

## How was this patch tested?

```
build/sbt "testOnly org.apache.spark.ml.stat.CorrelationSuite"
```

Author: Timothy Hunter <[email protected]>

Closes #17108 from thunterdb/19636.
  • Loading branch information
thunterdb authored and jkbradley committed Mar 24, 2017
1 parent 93581fb commit d27daa5
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ object TestingUtils {
* the relative tolerance is meaningless, so the exception will be raised to warn users.
*/
private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
// Special case for NaNs
if (x.isNaN && y.isNaN) {
return true
}
val absX = math.abs(x)
val absY = math.abs(y)
val diff = math.abs(x - y)
Expand All @@ -49,6 +53,10 @@ object TestingUtils {
* Private helper function for comparing two values using absolute tolerance.
*/
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
  // Two NaN values are considered equal to each other; in every other case the
  // absolute difference must be strictly smaller than eps.
  val bothNaN = x.isNaN && y.isNaN
  bothNaN || math.abs(x - y) < eps
}

Expand Down
86 changes: 86 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat

import scala.collection.JavaConverters._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * API for correlation functions in MLlib, compatible with DataFrames and Datasets.
 *
 * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
 * to spark.ml's Vector types.
 */
@Since("2.2.0")
@Experimental
object Correlation {

  /**
   * :: Experimental ::
   * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
   * Methods currently supported: `pearson` (default), `spearman`.
   *
   * Here is how to access the correlation coefficient:
   * {{{
   *   val data: Dataset[Vector] = ...
   *   val Row(coeff: Matrix) = Correlation.corr(data, "value").head
   *   // coeff now contains the Pearson correlation matrix.
   * }}}
   *
   * @param dataset A dataset or a dataframe
   * @param column The name of the column of vectors for which the correlation coefficient needs
   *               to be computed. This must be a column of the dataset, and it must contain
   *               Vector objects.
   * @param method String specifying the method to use for computing correlation.
   *               Supported: `pearson` (default), `spearman`
   * @return A dataframe that contains the correlation matrix of the column of vectors. This
   *         dataframe contains a single row and a single column of name
   *         '$METHODNAME($COLUMN)'.
   * @throws IllegalArgumentException if the content of this column is not of type Vector.
   *                                  A column name that cannot be resolved is rejected by
   *                                  `dataset.select` before this check runs.
   *
   * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
   * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
   * which is fairly costly. Cache the input Dataset before calling corr with
   * `method = "spearman"` to avoid recomputing the common lineage.
   */
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
    val selected = dataset.select(column)
    // Validate the column type up front so that a non-Vector column is reported as the
    // documented IllegalArgumentException instead of a scala.MatchError raised by the
    // partial pattern match below when the rows are processed.
    val fieldTypes = selected.schema.fields.map(_.dataType)
    require(fieldTypes.length == 1 && fieldTypes.head == SQLDataTypes.VectorType,
      s"Column $column must contain Vector objects, but its type is " +
        s"${fieldTypes.mkString("[", ", ", "]")}.")
    // The spark.mllib implementation works on RDDs of the old vector type.
    val rdd = selected.rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val oldM = OldStatistics.corr(rdd, method)
    val name = s"$method($column)"
    val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
    // Wrap the resulting matrix in a single-row, single-column DataFrame.
    dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
  }

  /**
   * Compute the Pearson correlation matrix for the input Dataset of Vectors.
   *
   * Equivalent to calling `corr(dataset, column, "pearson")`.
   *
   * @param dataset A dataset or a dataframe
   * @param column The name of the column of vectors for which the correlation matrix is computed
   * @return A dataframe with a single row and a single column holding the correlation matrix
   */
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String): DataFrame = {
    corr(dataset, column, "pearson")
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat

import breeze.linalg.{DenseMatrix => BDM}

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}


/**
 * Checks the matrices produced by `Correlation.corr` (default, Pearson and Spearman) on a small
 * fixed data set against hard-coded expected values.
 */
class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {

  val xData = Array(1.0, 0.0, -2.0)
  val yData = Array(4.0, 5.0, 3.0)
  val zeros = new Array[Double](3)
  // The third component is 0.0 in every vector; the expected matrices below carry NaN
  // for every off-diagonal entry that involves it.
  val data = Seq(
    Vectors.dense(1.0, 0.0, 0.0, -2.0),
    Vectors.dense(4.0, 5.0, 0.0, 3.0),
    Vectors.dense(6.0, 7.0, 0.0, 8.0),
    Vectors.dense(9.0, 0.0, 0.0, 1.0)
  )

  /** Builds a one-column DataFrame ("features") holding the vectors of `data`. */
  private def featuresDf = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

  /** Collects the single-row, single-column result and returns its matrix in Breeze form. */
  private def extract(df: DataFrame): BDM[Double] = df.collect() match {
    case Array(Row(mat: Matrix)) => mat.asBreeze.toDenseMatrix
  }

  test("corr(X) default, pearson") {
    val viaDefault = Correlation.corr(featuresDf, "features")
    val viaPearson = Correlation.corr(featuresDf, "features", "pearson")
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.00000000, 0.05564149, Double.NaN, 0.4004714),
      (0.05564149, 1.00000000, Double.NaN, 0.9135959),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.40047142, 0.91359586, Double.NaN, 1.0000000)))
    // scalastyle:on

    // The default method and an explicit "pearson" must both match the same expected matrix.
    Seq(viaDefault, viaPearson).foreach { result =>
      assert(Matrices.fromBreeze(extract(result)) ~== expected absTol 1e-4)
    }
  }

  test("corr(X) spearman") {
    val viaSpearman = Correlation.corr(featuresDf, "features", "spearman")
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.0000000, 0.1054093, Double.NaN, 0.4000000),
      (0.1054093, 1.0000000, Double.NaN, 0.9486833),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.4000000, 0.9486833, Double.NaN, 1.0000000)))
    // scalastyle:on
    val actual = Matrices.fromBreeze(extract(viaSpearman))
    assert(actual ~== expected absTol 1e-4)
  }

}

0 comments on commit d27daa5

Please sign in to comment.