[SPARK-19636][ML] Feature parity for correlation statistics in MLlib #17108

Closed
wants to merge 10 commits into from
Changes from 9 commits
@@ -32,6 +32,10 @@ object TestingUtils {
   * the relative tolerance is meaningless, so the exception will be raised to warn users.
   */
  private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
    // Special case for NaNs
Contributor Author

@jkbradley I do not think this change is going to be controversial, but I want to point out that from now on, matrix/vector checks will not always throw errors when comparing NaN: the previous code would throw whenever a NaN was found.

Member

I agree with you that the update has the right semantics. SGTM

    if (x.isNaN && y.isNaN) {
      return true
    }
    val absX = math.abs(x)
    val absY = math.abs(y)
    val diff = math.abs(x - y)
@@ -49,6 +53,10 @@ object TestingUtils {
   * Private helper function for comparing two values using absolute tolerance.
   */
  private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
    // Special case for NaNs
    if (x.isNaN && y.isNaN) {
      return true
    }
    math.abs(x - y) < eps
  }
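
The NaN semantics discussed in the review thread above can be tried in isolation. The following is a minimal standalone sketch, not the code from this PR: the object name `NaNAwareComparison` and both method names are hypothetical, and only the absolute-tolerance case is modelled. It contrasts a check that special-cases two NaNs with the same check without the special case.

```scala
object NaNAwareComparison {
  // Hypothetical stand-in for the patched absolute-tolerance helper:
  // two NaNs are treated as equal, anything else uses |x - y| < eps.
  def absTolEquals(x: Double, y: Double, eps: Double): Boolean = {
    if (x.isNaN && y.isNaN) true
    else math.abs(x - y) < eps
  }

  // The same check without the special case, for contrast.
  // NaN - NaN is NaN, and any ordered comparison against NaN is false.
  def naiveAbsTolEquals(x: Double, y: Double, eps: Double): Boolean =
    math.abs(x - y) < eps

  def main(args: Array[String]): Unit = {
    val nan = Double.NaN
    println(s"patched, NaN vs NaN: ${absTolEquals(nan, nan, 1e-4)}")
    println(s"naive,   NaN vs NaN: ${naiveAbsTolEquals(nan, nan, 1e-4)}")
    println(s"patched, NaN vs 1.0: ${absTolEquals(nan, 1.0, 1e-4)}")
  }
}
```

A NaN paired with an ordinary number still fails the check in this sketch, so only positions where both sides are NaN are accepted as matching.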

86 changes: 86 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat

import scala.collection.JavaConverters._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * API for correlation functions in MLlib, compatible with DataFrames and Datasets.
 *
 * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
 * to spark.ml's Vector types.
 */
@Since("2.2.0")
@Experimental
object Correlation {

  /**
   * :: Experimental ::
   * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
   * Methods currently supported: `pearson` (default), `spearman`.
   *
   * @param dataset A dataset or a dataframe
   * @param column The name of the column of vectors for which the correlation coefficient needs
   *               to be computed. This must be a column of the dataset, and it must contain
   *               Vector objects.
   * @param method String specifying the method to use for computing correlation.
   *               Supported: `pearson` (default), `spearman`
   * @return A dataframe that contains the correlation matrix of the column of vectors. This
   *         dataframe contains a single row and a single column of name
   *         '$METHODNAME($COLUMN)'.
   * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if
   *                                  the content of this column is not of type Vector.
   *
   *  Here is how to access the correlation coefficient:
   *  {{{
   *    val data: Dataset[Vector] = ...
   *    val Row(coeff: Matrix) = Correlation.corr(data, "value").head
   *    // coeff now contains the Pearson correlation matrix.
   *  }}}
   *
   * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
   * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
   * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"`
   * to avoid recomputing the common lineage.
   */
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
    val rdd = dataset.select(column).rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val oldM = OldStatistics.corr(rdd, method)
    val name = s"$method($column)"
    val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
    dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
  }

  /**
   * Compute the correlation matrix for the input Dataset of Vectors.
Member

Say "pearson" here explicitly.

   */
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String): DataFrame = {
    corr(dataset, column, "pearson")
  }
}
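
The `@note` on Spearman in the scaladoc above describes ranking each column and then correlating the ranks. A minimal pure-Scala sketch of that idea follows. It does not use Spark, the object `SpearmanSketch` and its methods are hypothetical names, and it assumes tied values receive the average of the ranks they span.

```scala
object SpearmanSketch {
  // Pearson correlation of two equally long sequences.
  // Returns NaN when either input has zero variance (0.0 / 0.0).
  def pearson(xs: Seq[Double], ys: Seq[Double]): Double = {
    require(xs.length == ys.length, "inputs must have the same length")
    val meanX = xs.sum / xs.length
    val meanY = ys.sum / ys.length
    val dx = xs.map(_ - meanX)
    val dy = ys.map(_ - meanY)
    val cov = dx.zip(dy).map { case (a, b) => a * b }.sum
    cov / math.sqrt(dx.map(a => a * a).sum * dy.map(b => b * b).sum)
  }

  // 1-based ranks; tied values share the average of the ranks they span
  // (an assumption of this sketch). Quadratic, which is fine for a demo.
  def averageRanks(xs: Seq[Double]): Seq[Double] = {
    val sorted = xs.sorted
    xs.map { x =>
      val first = sorted.indexOf(x) + 1
      val last = sorted.lastIndexOf(x) + 1
      (first + last) / 2.0
    }
  }

  // Spearman correlation as Pearson correlation of the ranks.
  def spearman(xs: Seq[Double], ys: Seq[Double]): Double =
    pearson(averageRanks(xs), averageRanks(ys))

  def main(args: Array[String]): Unit = {
    val a = Seq(1.0, 4.0, 6.0, 9.0)
    val b = Seq(0.0, 5.0, 7.0, 0.0) // contains a tie
    println(averageRanks(b))
    println(spearman(a, b))
  }
}
```

The sort inside `averageRanks` corresponds to the per-column sort the note describes as costly, which is why the note recommends caching the input before a Spearman computation.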
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat

import breeze.linalg.{DenseMatrix => BDM}

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}


class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {

  val xData = Array(1.0, 0.0, -2.0)
  val yData = Array(4.0, 5.0, 3.0)
  val zeros = new Array[Double](3)
  val data = Seq(
    Vectors.dense(1.0, 0.0, 0.0, -2.0),
    Vectors.dense(4.0, 5.0, 0.0, 3.0),
    Vectors.dense(6.0, 7.0, 0.0, 8.0),
    Vectors.dense(9.0, 0.0, 0.0, 1.0)
  )

  private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

  private def extract(df: DataFrame): BDM[Double] = {
    val Array(Row(mat: Matrix)) = df.collect()
    mat.asBreeze.toDenseMatrix
  }


  test("corr(X) default, pearson") {
    val defaultMat = Correlation.corr(X, "features")
    val pearsonMat = Correlation.corr(X, "features", "pearson")
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.00000000, 0.05564149, Double.NaN, 0.4004714),
      (0.05564149, 1.00000000, Double.NaN, 0.9135959),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.40047142, 0.91359586, Double.NaN, 1.0000000)))
    // scalastyle:on

    assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4)
    assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4)
  }

  test("corr(X) spearman") {
    val spearmanMat = Correlation.corr(X, "features", "spearman")
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.0000000,  0.1054093,  Double.NaN, 0.4000000),
      (0.1054093,  1.0000000,  Double.NaN, 0.9486833),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.4000000,  0.9486833,  Double.NaN, 1.0000000)))
    // scalastyle:on
    assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4)
  }

}
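
The expected Pearson values in the suite above can be cross-checked without Spark. The sketch below is a hedged illustration, not part of the PR: `PearsonCrossCheck` and its members are hypothetical names, and the textbook formula is applied directly to the columns of the test data. It also shows where the NaN entries come from: the third column is constant, so its variance is zero and the formula evaluates 0.0 / 0.0.

```scala
object PearsonCrossCheck {
  // The four rows of the suite's test data.
  val rows: Seq[Seq[Double]] = Seq(
    Seq(1.0, 0.0, 0.0, -2.0),
    Seq(4.0, 5.0, 0.0, 3.0),
    Seq(6.0, 7.0, 0.0, 8.0),
    Seq(9.0, 0.0, 0.0, 1.0)
  )

  // Extract column j (0-based) across all rows.
  def column(j: Int): Seq[Double] = rows.map(_(j))

  // Textbook Pearson correlation; NaN when either column has zero variance.
  def pearson(xs: Seq[Double], ys: Seq[Double]): Double = {
    require(xs.length == ys.length, "inputs must have the same length")
    val meanX = xs.sum / xs.length
    val meanY = ys.sum / ys.length
    val dx = xs.map(_ - meanX)
    val dy = ys.map(_ - meanY)
    val cov = dx.zip(dy).map { case (a, b) => a * b }.sum
    cov / math.sqrt(dx.map(a => a * a).sum * dy.map(b => b * b).sum)
  }

  def main(args: Array[String]): Unit = {
    // Print the full 4 x 4 matrix produced by the naive formula.
    for (i <- 0 until 4) {
      println((0 until 4).map(j => pearson(column(i), column(j))).mkString("\t"))
    }
  }
}
```

One difference is worth noting: this naive formula yields NaN for the constant column against itself, whereas the suite's expected matrix has 1.0 at that diagonal position, so the library evidently treats the diagonal separately.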