[SPARK-19636][ML] Feature parity for correlation statistics in MLlib #17108
Conversation
Test build #73627 has finished for PR 17108 at commit
package org.apache.spark.ml.stat

/**
 *
oops, sorry, removing this file.
 * API for statistical functions in MLlib, compatible with Dataframes and Datasets.
 *
 * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
 * to MLlib's Vector types.
minor terminology comment: should this be ML instead of MLLib? I understand this is for the new ML vector types?
I will use `spark.ml`, which is the most correct terminology.
 * Compute the correlation matrix for the input RDD of Vectors using the specified method.
 * Methods currently supported: `pearson` (default), `spearman`.
 *
 * @param dataset a dataset or a dataframe
very minor: "Sentence case" params, as in "A dataset...", "The name..."
it seems there are inconsistencies in a lot of comments. I wish we had something like scalastyle for comments...
oh yes, thank you. I am correcting the other instances of course.
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
  val rdd = dataset.select(column).rdd.map {
    case Row(v: Vector) => OldVectors.fromML(v)
    // case r: GenericRowWithSchema => OldVectors.fromML(r.getAs[Vector](0))
remove commented out code (?)
  }
  val oldM = OldStatistics.corr(rdd, method)
  val name = s"$method($column)"
  val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = true)))
minor comment: ideally, shouldn't you check for collisions prior to creating the name? E.g., add a suffix such as "_2" or "_i" if the column name already exists.
ideally this would be an infrastructure-level method that just finds a new column name and would be reusable in other code. I don't believe something like this exists.
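A minimal sketch of the reusable helper described above (hypothetical; no such utility exists in the PR — the name and signature are illustrative only):

```scala
// Hypothetical helper: pick a column name that does not collide with any
// existing column, by appending a numeric suffix ("_2", "_3", ...).
def freshColumnName(base: String, existing: Set[String]): String = {
  if (!existing.contains(base)) base
  else Iterator.from(2).map(i => s"${base}_$i").find(n => !existing.contains(n)).get
}
```

For example, `freshColumnName("pearson(features)", df.columns.toSet)` would yield `pearson(features)_2` if the base name is already taken.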
 * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
 * avoid recomputing the common lineage.
 */
// TODO: how do we handle missing values?
This is more of a comment on the internal implementation of the Pearson/Spearman calculation; I don't think it should be at this level (maybe it should be moved into the MLlib code?). I think they should just ignore the rows where one of the compared columns has a missing/NaN value, and log a warning (but only once) when they encounter this. If all values are missing, we just assign a 0 score.
Good point. I will remove the comment at this point, since this should be decided in JIRA instead of during the implementation.
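The missing-value policy proposed in the comment above could be sketched as follows (purely illustrative — this is not what the PR implements, and the final behavior was deferred to JIRA):

```scala
// Sketch of the proposed policy: drop rows containing NaNs before computing
// the correlation, and warn only once when such rows are encountered.
private var warnedAboutNaN = false

def filterValidRows(rows: Seq[Array[Double]]): Seq[Array[Double]] = {
  val (valid, invalid) = rows.partition(r => !r.exists(_.isNaN))
  if (invalid.nonEmpty && !warnedAboutNaN) {
    // A real implementation would use the logger rather than println.
    println(s"Ignoring ${invalid.size} rows with missing/NaN values")
    warnedAboutNaN = true
  }
  valid
}
```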
object StatisticsSuite extends Logging {

  def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
these are very nice methods! would it be possible to move them to a place where every test suite could use them? Specifically the matrixApproxEqual.
Moved
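For context, the helpers discussed above amount to something like the following (a hedged sketch; `matrixApproxEqual` here takes row-major arrays, whereas the real utilities operate on ML matrix types):

```scala
// Scalar comparison with an absolute tolerance; NaNs compare equal to NaNs.
def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
  if (v1.isNaN) v2.isNaN else math.abs(v1 - v2) < threshold
}

// Element-wise comparison of two matrices given as row-major value arrays.
def matrixApproxEqual(a: Array[Double], b: Array[Double],
    threshold: Double = 1e-6): Boolean = {
  a.length == b.length && a.zip(b).forall { case (x, y) => approxEqual(x, y, threshold) }
}
```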
test("corr(X) default, pearson") {
  val defaultMat = Statistics.corr(X, "features")
  val pearsonMat = Statistics.corr(X, "features", "pearson")
  // scalastyle:off
what is the error that the scalastyle gives? I wish there was some way to avoid turning it off.
The problem is the alignment of the values, which we achieve by padding with 0's.
The changes look good to me. I just had a few minor comments. I wish we could just natively implement the correlations in Spark to avoid extra copying between the old and new implementations, but this seems like a move in the right direction.
 * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
 * to MLlib's Vector types.
 */
@Since("2.2.0")
shouldn't this have @experimental tag at the top? similar to:
https://github.com/apache/spark/pull/17110/files
Good point, thanks
Given further thought, I'd prefer we stick to the API specified in the design doc, with a Correlations object instead of a generic Statistics object. In the future, we may want optional Params such as weightCol, in which case we may switch to a builder pattern for Correlations and ChiSquare and move away from a shared Statistics object. I'm going to proceed with #17110 using a separate ChiSquare object.
I moved the code
Test build #74626 has finished for PR 17108 at commit
Test build #74627 has finished for PR 17108 at commit
Taking a look now
 */
@Since("2.2.0")
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
  val rdd = dataset.select(column).rdd.map {
Not related to the code, but does this generate a new RDD or just reference the data in the input dataset? Also, in performance testing I noticed that a lot of operations on RDDs are more expensive than on DataFrames and Datasets (probably because the Catalyst optimizations are not used), so it seems we should try to avoid using RDDs when doing computations. Is this true?
@Since("2.2.0")
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
  val rdd = dataset.select(column).rdd.map {
    case Row(v: Vector) => OldVectors.fromML(v)
If this is not a Row containing a Vector, should we throw a nice error message? Otherwise the map will fail.
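One way to address this comment (a hypothetical sketch; the PR's final code may differ) is to add a fallback case that fails fast with a descriptive message instead of an opaque `MatchError` inside the map:

```scala
// Convert the selected column to old-style Vectors, failing with a clear
// message when a row does not contain a Vector.
val rdd = dataset.select(column).rdd.map {
  case Row(v: Vector) => OldVectors.fromML(v)
  case row => throw new IllegalArgumentException(
    s"Column $column must contain Vector values, but found row: $row")
}
```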
import org.apache.spark.sql.{DataFrame, Row}

class CorrelationsSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
maybe a negative test case where we pass a single column instead of a vector in a column?
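The suggested negative test might look like this (hypothetical sketch; it assumes the suite's `testImplicits` are imported for `toDF`, and only checks that some exception is raised):

```scala
// Negative test: a plain Double column instead of a Vector column should fail.
test("corr on a non-Vector column fails") {
  val df = Seq(1.0, 2.0, 3.0).toDF("value")
  val e = intercept[Exception] {
    Correlations.corr(df, "value", "pearson").collect()
  }
  assert(e != null)
}
```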
}

/**
 * Compute the correlation matrix for the input Dataset of Vectors.
should this specify "pearson" correlation in the documentation to be precise?
/**
 * Utility test methods for linear algebra.
 */
object LinalgUtils extends Logging {
this is nice, thank you for refactoring the test code here!
The code looks good to me. I added some minor comments, thank you!
Done with review; just a few comments. Thanks!
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * API for statistical functions in MLlib, compatible with Dataframes and Datasets.
This should be limited to correlations
done
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * API for statistical functions in MLlib, compatible with Dataframes and Datasets.
Add :: Experimental ::
done
 */
@Since("2.2.0")
@Experimental
object Correlations {
How about calling it "Correlation" (singular)? Especially if we add a builder pattern, then I feel like `new Correlation().set...` seems more natural.
sure, I do not know if there is a convention for that.
Not really, but let's make one?
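The builder pattern floated in this thread might look like the following (purely illustrative — none of this API exists in the PR; the setter names and `weightCol` Param are assumptions borrowed from the earlier discussion):

```scala
// Hypothetical builder-style API motivating the singular name "Correlation".
class Correlation {
  private var method: String = "pearson"
  private var weightCol: Option[String] = None

  def setMethod(value: String): this.type = { method = value; this }
  def setWeightCol(value: String): this.type = { weightCol = Some(value); this }
}

// Usage sketch: new Correlation().setMethod("spearman").setWeightCol("w")
```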
  }
  val oldM = OldStatistics.corr(rdd, method)
  val name = s"$method($column)"
  val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = true)))
nullable = false?
Good point. It seems that Spark can be quite liberal with the nullability.
}

/**
 * Compute the correlation matrix for the input Dataset of Vectors.
Just say that this is a version of corr which defaults to "pearson" for the method. Don't document params or return value.
/**
 * Utility test methods for linear algebra.
 */
object LinalgUtils extends Logging {
Can't you use org.apache.spark.ml.util.TestingUtils from mllib-local?
You are right, I had missed that file.
@@ -32,6 +32,10 @@ object TestingUtils {
 * the relative tolerance is meaningless, so the exception will be raised to warn users.
 */
private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
  // Special case for NaNs
@jkbradley I do not think this change is going to be controversial, but I want to point out that from now on, matrix/vector checks will not always throw errors when comparing NaN: the previous code would throw whenever a NaN was found.
I agree with you that the update has the right semantics. SGTM
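The semantics agreed on above can be sketched as follows (a simplified illustration; the real `TestingUtils` comparator also raises an exception when relative tolerance is meaningless for tiny values):

```scala
// NaN special case: two NaNs compare as equal for test purposes, instead of
// the previous behavior of always failing whenever a NaN was found.
def relativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
  if (x.isNaN && y.isNaN) {
    true
  } else if (x == y) {
    true
  } else {
    val diff = math.abs(x - y)
    diff / math.max(math.abs(x), math.abs(y)) < eps
  }
}
```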
Test build #75060 has finished for PR 17108 at commit
}

/**
 * Compute the correlation matrix for the input Dataset of Vectors.
Say "pearson" here explicitly.
LGTM except for the one doc nit.
LGTM, will merge after tests.
Test build #75118 has finished for PR 17108 at commit
Merging with master
What changes were proposed in this pull request?
This patch adds the DataFrames-based support for the correlation statistics found in `org.apache.spark.mllib.stat.correlation.Statistics`, following the design doc discussed in the JIRA ticket. The current implementation is a simple wrapper around the `spark.mllib` implementation. Future optimizations can be implemented at a later stage.
How was this patch tested?