
[SPARK-19636][ML] Feature parity for correlation statistics in MLlib #17108

Closed · wants to merge 10 commits · changes shown from 3 commits
25 changes: 25 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Correlations.scala
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat

/**
*
Contributor Author: Oops, sorry, removing this file.

*/
object Correlations {
Member: How about calling it "Correlation" (singular)? Especially if we add a builder pattern, then I feel like `new Correlation().set...` seems more natural.

Contributor Author: Sure, I do not know if there is a convention for that.

Member: Not really, but let's make one?


}
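The builder pattern floated in the thread above might look roughly like this. This is a hypothetical sketch: neither the `Correlation` class shape nor the `setMethod` setter exists in this PR; they only illustrate the chained-call style the reviewer has in mind.

```scala
// Hypothetical sketch of the builder pattern suggested above; the class and
// its setter are illustrative only, not code from this PR.
class Correlation {
  private var method: String = "pearson"

  // Returning `this` allows chained calls like new Correlation().setMethod(...)
  def setMethod(value: String): this.type = {
    require(Set("pearson", "spearman").contains(value), s"Unknown method: $value")
    this.method = value
    this
  }

  def getMethod: String = method
}

val corr = new Correlation().setMethod("spearman")
```

The `this.type` return type is the usual Scala idiom for chainable setters, mirroring the `Params` setters elsewhere in spark.ml.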
89 changes: 89 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Statistics.scala
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * API for statistical functions in MLlib, compatible with DataFrames and Datasets.
*
* The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
* to MLlib's Vector types.

Contributor: Minor terminology comment: should this be ML instead of MLlib? I understand this is for the new ML vector types?

Contributor Author: I will use spark.ml, which is the most correct terminology.

*/
@Since("2.2.0")
Contributor: Shouldn't this have an @Experimental tag at the top? Similar to:
https://github.com/apache/spark/pull/17110/files

Contributor Author: Good point, thanks.

object Statistics {

/**
 * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
* Methods currently supported: `pearson` (default), `spearman`.
*
* @param dataset a dataset or a dataframe

Contributor: Very minor: "Sentence case" params, as in "A dataset...", "The name..."

Contributor: It seems there are inconsistencies in a lot of comments. I wish we had something like scalastyle for comments...

Contributor Author: Oh yes, thank you. I am correcting the other instances, of course.

* @param column the name of the column of vectors for which the correlation coefficient needs
* to be computed. This must be a column of the dataset, and it must contain
* Vector objects.
* @param method String specifying the method to use for computing correlation.
* Supported: `pearson` (default), `spearman`
* @return A dataframe that contains the correlation matrix of the column of vectors. This
* dataframe contains a single row and a single column of name
* '$METHODNAME($COLUMN)'.
* @throws IllegalArgumentException if the column is not a valid column in the dataset, or if
* the content of this column is not of type Vector.
*
* Here is how to access the correlation coefficient:
* {{{
* val data: Dataset[Vector] = ...
* val Row(coeff: Matrix) = Statistics.corr(data, "value").head
* // coeff now contains the Pearson correlation matrix.
* }}}
*
* @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
* and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
* which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
* avoid recomputing the common lineage.
*/
// TODO: how do we handle missing values?
Contributor: This is more a comment for the internal implementation of the pearson/spearman calculation; I don't think it should be at this level (maybe moved into the MLlib code?). I think they should just ignore the rows where one of the compared columns has a missing/NaN value and log a warning (but only once) when they encounter this -- if all are missing, we just assign a 0 score.

Contributor Author: Good point. I will remove the comment at this point, since this should be decided in JIRA instead of during the implementation.
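The row-dropping policy sketched in the comment above could look like this in plain Scala. This is a hypothetical sketch of the discussed behavior, not code from this PR, and it leaves out the one-time warning and the all-missing fallback:

```scala
// Hypothetical sketch of the missing-value policy discussed above: ignore any
// row whose vector contains a NaN before computing the correlation.
def dropRowsWithNaN(rows: Seq[Array[Double]]): Seq[Array[Double]] =
  rows.filterNot(_.exists(_.isNaN))

val rows = Seq(
  Array(1.0, 2.0),
  Array(Double.NaN, 3.0), // dropped: contains a NaN
  Array(4.0, 5.0)
)
val kept = dropRowsWithNaN(rows) // keeps only the first and last rows
```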

@Since("2.2.0")
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
val rdd = dataset.select(column).rdd.map {
case Row(v: Vector) => OldVectors.fromML(v)
// case r: GenericRowWithSchema => OldVectors.fromML(r.getAs[Vector](0))
Contributor: Remove commented-out code (?)

}
val oldM = OldStatistics.corr(rdd, method)
val name = s"$method($column)"
val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = true)))
Contributor: Minor comment: ideally shouldn't you check for collisions prior to creating the name -- e.g. add a suffix such as "_2" or "_i" if the column name already exists?

Contributor: Ideally this would be an infrastructure-level method that just finds a new column name and would be reusable in other code. I don't believe something like this exists.

dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
}
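A reusable helper along the lines the reviewers describe might look like this. The name `freshColumnName` is hypothetical; no such infrastructure method exists in the PR or, per the discussion, elsewhere in Spark at this point:

```scala
// Hypothetical helper sketching the reviewers' suggestion: pick a result
// column name that does not collide with existing columns by appending
// the suffixes _2, _3, ... until a free name is found.
def freshColumnName(base: String, existing: Set[String]): String =
  if (!existing.contains(base)) base
  else Iterator.from(2).map(i => s"${base}_$i").find(n => !existing.contains(n)).get
```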

/**
 * Compute the Pearson correlation matrix for the input Dataset of Vectors.
 * @param dataset A dataset or a dataframe.
 * @param column The name of a column of this dataset, containing Vector objects.
 * @return A dataframe that contains the Pearson correlation matrix of the column of vectors.
 */
@Since("2.2.0")
def corr(dataset: Dataset[_], column: String): DataFrame = {
corr(dataset, column, "pearson")
}
}
102 changes: 102 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/stat/StatisticsSuite.scala
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat

import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.Matrix
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}


class StatisticsSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {

import StatisticsSuite._

val xData = Array(1.0, 0.0, -2.0)
val yData = Array(4.0, 5.0, 3.0)
val zeros = new Array[Double](3)
val data = Seq(
Vectors.dense(1.0, 0.0, 0.0, -2.0),
Vectors.dense(4.0, 5.0, 0.0, 3.0),
Vectors.dense(6.0, 7.0, 0.0, 8.0),
Vectors.dense(9.0, 0.0, 0.0, 1.0)
)

private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

private def extract(df: DataFrame): BDM[Double] = {
val Array(Row(mat: Matrix)) = df.collect()
mat.asBreeze.toDenseMatrix
}


test("corr(X) default, pearson") {
val defaultMat = Statistics.corr(X, "features")
val pearsonMat = Statistics.corr(X, "features", "pearson")
// scalastyle:off
Contributor: What is the error that the scalastyle gives? I wish there was some way to avoid turning it off.

Contributor Author: The problem is the alignment of the values, which we realize by padding with 0's.

val expected = BDM(
(1.00000000, 0.05564149, Double.NaN, 0.4004714),
(0.05564149, 1.00000000, Double.NaN, 0.9135959),
(Double.NaN, Double.NaN, 1.00000000, Double.NaN),
(0.40047142, 0.91359586, Double.NaN, 1.0000000))
// scalastyle:on

assert(matrixApproxEqual(extract(defaultMat), expected))
assert(matrixApproxEqual(extract(pearsonMat), expected))
}
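The expected Pearson values above can be reproduced with a few lines of plain Scala. This is a standalone sketch of the textbook formula, independent of the Spark implementation under test:

```scala
// Plain-Scala Pearson correlation of two equal-length samples; used here only
// to show where the expected values in the test above come from.
def pearson(x: Seq[Double], y: Seq[Double]): Double = {
  val n = x.length
  val mx = x.sum / n
  val my = y.sum / n
  // Covariance numerator and the two standard-deviation factors.
  val cov = x.zip(y).map { case (a, b) => (a - mx) * (b - my) }.sum
  val sx = math.sqrt(x.map(a => (a - mx) * (a - mx)).sum)
  val sy = math.sqrt(y.map(b => (b - my) * (b - my)).sum)
  cov / (sx * sy)
}

// First and second columns of the test data: (1, 4, 6, 9) and (0, 5, 7, 0).
val r = pearson(Seq(1.0, 4.0, 6.0, 9.0), Seq(0.0, 5.0, 7.0, 0.0))
// r ≈ 0.05564149, matching expected(0, 1) above.
```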

test("corr(X) spearman") {
val spearmanMat = Statistics.corr(X, "features", "spearman")
// scalastyle:off
val expected = BDM(
(1.0000000, 0.1054093, Double.NaN, 0.4000000),
(0.1054093, 1.0000000, Double.NaN, 0.9486833),
(Double.NaN, Double.NaN, 1.00000000, Double.NaN),
(0.4000000, 0.9486833, Double.NaN, 1.0000000))
// scalastyle:on
assert(matrixApproxEqual(extract(spearmanMat), expected))
}
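Likewise, the Spearman values above can be checked in plain Scala by ranking each column (ties receive the average of their positions) and taking the Pearson correlation of the ranks. A standalone sketch, not the Spark implementation:

```scala
// Average-rank assignment: ties get the mean of their sorted positions.
def avgRanks(xs: Seq[Double]): Seq[Double] = {
  val byIndex = xs.zipWithIndex.sortBy(_._1).zipWithIndex
    .groupBy { case ((value, _), _) => value }
    .flatMap { case (_, group) =>
      val avg = group.map { case (_, sortedPos) => sortedPos + 1.0 }.sum / group.size
      group.map { case ((_, origIdx), _) => origIdx -> avg }
    }
  xs.indices.map(byIndex)
}

// Pearson correlation of two equal-length samples.
def pearson(x: Seq[Double], y: Seq[Double]): Double = {
  val mx = x.sum / x.length
  val my = y.sum / y.length
  val cov = x.zip(y).map { case (a, b) => (a - mx) * (b - my) }.sum
  cov / math.sqrt(x.map(a => (a - mx) * (a - mx)).sum *
    y.map(b => (b - my) * (b - my)).sum)
}

// Spearman's rho is, by definition, Pearson correlation over the ranks.
def spearman(x: Seq[Double], y: Seq[Double]): Double =
  pearson(avgRanks(x), avgRanks(y))

// Columns 1 and 2 of the test data: the tie in (0, 5, 7, 0) gets rank 1.5 twice.
val rho = spearman(Seq(1.0, 4.0, 6.0, 9.0), Seq(0.0, 5.0, 7.0, 0.0))
// rho ≈ 0.1054093, matching expected(0, 1) above.
```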

}


object StatisticsSuite extends Logging {

def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
Contributor: These are very nice methods! Would it be possible to move them to a place where every test suite could use them? Specifically matrixApproxEqual.

Contributor Author: Moved.

if (v1.isNaN) {
v2.isNaN
} else {
math.abs(v1 - v2) <= threshold
}
}

def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
for (i <- 0 until A.rows; j <- 0 until A.cols) {
if (!approxEqual(A(i, j), B(i, j), threshold)) {
logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
return false
}
}
true
}

}
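One subtlety worth noting in `approxEqual` above: the explicit `isNaN` branch is needed because NaN never compares equal to itself, so a plain difference check would report two NaN cells (which this suite's expected matrices contain) as unequal. A standalone illustration, restating the same logic:

```scala
// NaN is not equal to itself, and any comparison involving NaN is false,
// so a naive difference check also fails for two NaN values.
val plainEquality = Double.NaN == Double.NaN              // false
val diffCheck = math.abs(Double.NaN - Double.NaN) <= 1e-6 // false: NaN <= x is false

// Hence the explicit isNaN branch (the same logic as approxEqual above).
def nanAwareEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean =
  if (v1.isNaN) v2.isNaN else math.abs(v1 - v2) <= threshold
```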