[SPARK-19636][ML] Feature parity for correlation statistics in MLlib #17108
Changes from 3 commits
42c26bd
d9f6a6c
7d4ccfe
a2d7e2d
a85a889
2aeb6ee
903e6d0
6040e4c
2151e8a
7c540e5
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.stat | ||
|
||
/** | ||
* | ||
*/ | ||
object Correlations { | ||
Comment: How about calling it "Correlation" (singular)? Especially if we add a builder pattern, then I feel like |
Reply: Sure, I do not know if there is a convention for that. |
Reply: Not really, but let's make one? |
||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.stat | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
import org.apache.spark.annotation.Since | ||
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector} | ||
import org.apache.spark.mllib.linalg.{Vectors => OldVectors} | ||
import org.apache.spark.mllib.stat.{Statistics => OldStatistics} | ||
import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
import org.apache.spark.sql.types.{StructField, StructType} | ||
|
||
/** | ||
* API for statistical functions in MLlib, compatible with Dataframes and Datasets. | ||
* | ||
* The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]] | ||
* to MLlib's Vector types. | ||
Comment: Minor terminology comment: should this be ML instead of MLlib? I understand this is for the new ML vector types? |
Reply: I will use |
||
*/ | ||
@Since("2.2.0") | ||
Comment: Shouldn't this have the @Experimental tag at the top? Similar to: |
Reply: Good point, thanks. |
||
object Statistics { | ||
|
||
/** | ||
* Compute the correlation matrix for the input RDD of Vectors using the specified method. | ||
* Methods currently supported: `pearson` (default), `spearman`. | ||
* | ||
* @param dataset a dataset or a dataframe | ||
Comment: Very minor: "Sentence case" params, as in "A dataset...", "The name..." |
Reply: It seems there are inconsistencies in a lot of comments. I wish we had something like scalastyle for comments... |
Reply: Oh yes, thank you. I am correcting the other instances of course. |
||
* @param column the name of the column of vectors for which the correlation coefficient needs | ||
* to be computed. This must be a column of the dataset, and it must contain | ||
* Vector objects. | ||
* @param method String specifying the method to use for computing correlation. | ||
* Supported: `pearson` (default), `spearman` | ||
* @return A dataframe that contains the correlation matrix of the column of vectors. This | ||
* dataframe contains a single row and a single column of name | ||
* '$METHODNAME($COLUMN)'. | ||
* @throws IllegalArgumentException if the column is not a valid column in the dataset, or if | ||
* the content of this column is not of type Vector. | ||
* | ||
* Here is how to access the correlation coefficient: | ||
* {{{ | ||
* val data: Dataset[Vector] = ... | ||
* val Row(coeff: Matrix) = Statistics.corr(data, "value").head | ||
* // coeff now contains the Pearson correlation matrix. | ||
* }}} | ||
* | ||
* @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column | ||
* and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], | ||
* which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to | ||
* avoid recomputing the common lineage. | ||
*/ | ||
// TODO: how do we handle missing values? | ||
Comment: This is more a comment for the internal implementation of the Pearson/Spearman calculation; I don't think it should be at this level (maybe moved into the MLlib code?). I think they should just ignore the rows where one of the compared columns has a missing/NaN value and log a warning (but only once) when they encounter this. If all are missing, we just assign a score of 0. |
Reply: Good point. I will remove the comment at this point, since this should be decided in JIRA instead of during the implementation. |
||
@Since("2.2.0") | ||
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { | ||
val rdd = dataset.select(column).rdd.map { | ||
case Row(v: Vector) => OldVectors.fromML(v) | ||
// case r: GenericRowWithSchema => OldVectors.fromML(r.getAs[Vector](0)) | ||
Comment: Remove commented-out code (?) |
||
} | ||
val oldM = OldStatistics.corr(rdd, method) | ||
val name = s"$method($column)" | ||
val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = true))) | ||
Comment: Minor comment: ideally, shouldn't you check for collisions prior to creating the name, e.g. add a suffix such as "_2" or "_i" if the column name already exists? |
Reply: Ideally this would be an infrastructure-level method that just finds a new column name and would be reusable in other code. I don't believe something like this exists. |
||
dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema) | ||
} | ||
|
||
/** | ||
* Compute the correlation matrix for the input Dataset of Vectors. | ||
* @param dataset a dataset or dataframe | ||
* @param column a column of this dataset | ||
* @return | ||
*/ | ||
@Since("2.2.0") | ||
def corr(dataset: Dataset[_], column: String): DataFrame = { | ||
corr(dataset, column, "pearson") | ||
} | ||
} |
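The suffix scheme suggested in the review comment on the result column name could be sketched roughly as follows. This is a minimal illustration in plain Scala, not part of the PR or of Spark: `ColumnNames` and `freshName` are hypothetical names, and the `_2`, `_3`, ... suffix convention is an assumption taken from the reviewer's wording.

```scala
// Hypothetical helper for illustration only; not an existing Spark API.
object ColumnNames {

  /**
   * Returns `base` if it is not already taken. Otherwise returns the first of
   * base_2, base_3, ... that does not collide with an existing name.
   */
  def freshName(base: String, existing: Set[String]): String = {
    if (!existing.contains(base)) {
      base
    } else {
      // `existing` is finite, so the search over suffixes always terminates.
      Iterator.from(2)
        .map(i => s"${base}_$i")
        .find(candidate => !existing.contains(candidate))
        .get
    }
  }
}
```

Under these assumptions, a caller could pass the set of names to avoid, for example `ColumnNames.freshName(s"$method($column)", dataset.columns.toSet)`.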
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.stat | ||
|
||
import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} | ||
|
||
import org.apache.spark.SparkFunSuite | ||
import org.apache.spark.internal.Logging | ||
import org.apache.spark.ml.linalg.Matrix | ||
import org.apache.spark.ml.linalg.Vectors | ||
import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
import org.apache.spark.sql.{DataFrame, Row} | ||
|
||
|
||
class StatisticsSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { | ||
|
||
import StatisticsSuite._ | ||
|
||
val xData = Array(1.0, 0.0, -2.0) | ||
val yData = Array(4.0, 5.0, 3.0) | ||
val zeros = new Array[Double](3) | ||
val data = Seq( | ||
Vectors.dense(1.0, 0.0, 0.0, -2.0), | ||
Vectors.dense(4.0, 5.0, 0.0, 3.0), | ||
Vectors.dense(6.0, 7.0, 0.0, 8.0), | ||
Vectors.dense(9.0, 0.0, 0.0, 1.0) | ||
) | ||
|
||
private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") | ||
|
||
private def extract(df: DataFrame): BDM[Double] = { | ||
val Array(Row(mat: Matrix)) = df.collect() | ||
mat.asBreeze.toDenseMatrix | ||
} | ||
|
||
|
||
test("corr(X) default, pearson") { | ||
val defaultMat = Statistics.corr(X, "features") | ||
val pearsonMat = Statistics.corr(X, "features", "pearson") | ||
// scalastyle:off | ||
Comment: What is the error that scalastyle gives? I wish there was some way to avoid turning it off. |
Reply: The problem is the alignment of the values, which we realize by padding with |
||
val expected = BDM( | ||
(1.00000000, 0.05564149, Double.NaN, 0.4004714), | ||
(0.05564149, 1.00000000, Double.NaN, 0.9135959), | ||
(Double.NaN, Double.NaN, 1.00000000, Double.NaN), | ||
(0.40047142, 0.91359586, Double.NaN, 1.0000000)) | ||
// scalastyle:on | ||
|
||
assert(matrixApproxEqual(extract(defaultMat), expected)) | ||
assert(matrixApproxEqual(extract(pearsonMat), expected)) | ||
} | ||
|
||
test("corr(X) spearman") { | ||
val spearmanMat = Statistics.corr(X, "features", "spearman") | ||
// scalastyle:off | ||
val expected = BDM( | ||
(1.0000000, 0.1054093, Double.NaN, 0.4000000), | ||
(0.1054093, 1.0000000, Double.NaN, 0.9486833), | ||
(Double.NaN, Double.NaN, 1.00000000, Double.NaN), | ||
(0.4000000, 0.9486833, Double.NaN, 1.0000000)) | ||
// scalastyle:on | ||
assert(matrixApproxEqual(extract(spearmanMat), expected)) | ||
} | ||
|
||
} | ||
|
||
|
||
object StatisticsSuite extends Logging { | ||
|
||
def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = { | ||
Comment: These are very nice methods! Would it be possible to move them to a place where every test suite could use them? Specifically the matrixApproxEqual. |
Reply: Moved |
||
if (v1.isNaN) { | ||
v2.isNaN | ||
} else { | ||
math.abs(v1 - v2) <= threshold | ||
} | ||
} | ||
|
||
def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = { | ||
for (i <- 0 until A.rows; j <- 0 until A.cols) { | ||
if (!approxEqual(A(i, j), B(i, j), threshold)) { | ||
logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) | ||
return false | ||
} | ||
} | ||
true | ||
} | ||
|
||
} |
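One detail of `matrixApproxEqual` as written in the diff: it loops over the dimensions of `A` only, so a `B` with extra rows or columns would still compare as equal. If the helper is moved to a shared location, a shape check may be worth adding. The sketch below uses plain nested arrays instead of Breeze matrices so that it is self-contained; the `ApproxEqual` object name and the array-based signature are assumptions for illustration, not the code that was merged.

```scala
// Illustrative sketch only: plain arrays stand in for Breeze matrices.
object ApproxEqual {

  /** NaN-aware scalar comparison, mirroring approxEqual in the diff above. */
  def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean =
    if (v1.isNaN) v2.isNaN else math.abs(v1 - v2) <= threshold

  /** Element-wise comparison that first requires both matrices to have the same shape. */
  def matrixApproxEqual(
      a: Array[Array[Double]],
      b: Array[Array[Double]],
      threshold: Double = 1e-6): Boolean = {
    val sameShape = a.length == b.length &&
      a.zip(b).forall { case (rowA, rowB) => rowA.length == rowB.length }
    sameShape && a.zip(b).forall { case (rowA, rowB) =>
      rowA.zip(rowB).forall { case (x, y) => approxEqual(x, y, threshold) }
    }
  }
}
```

A NaN in one matrix matches only a NaN in the same position of the other, which is the behavior the expected correlation matrices in the tests above rely on for the all-zero column.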
Comment: Oops, sorry, removing this file.