Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARK-4156 [MLLIB] EM algorithm for GMMs #3022

Closed
wants to merge 31 commits into from
Closed

Conversation

tgaloppo
Copy link
Contributor

Implementation of Expectation-Maximization for Gaussian Mixture Models.

This is my maiden contribution to Apache Spark, so I apologize now if I have done anything incorrectly; having said that, this work is my own, and I offer it to the project under the project's open source license.

@AmplabJenkins
Copy link

Can one of the admins verify this patch?

@rxin
Copy link
Contributor

rxin commented Nov 1, 2014

Jenkins, test this please.

@manishamde
Copy link
Contributor

@tgaloppo Thanks for the PR and congratulations on the first contribution. Apologies for the lack of feedback thus far -- I guess everyone is busy with the 1.2 release deadline on Nov 1. I will take a look at the PR in the next few days.

Please make sure you get the JIRA assigned to yourself next time before working. It's the only way to avoid duplicate work.

cc: @jkbradley, @mengxr

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/22688/
Test FAILed.

@tgaloppo
Copy link
Contributor Author

tgaloppo commented Nov 8, 2014

This test appeared to fail due to some form of timeout during the pull; is there any action I need to take?

@SparkQA
Copy link

SparkQA commented Nov 9, 2014

Test build #514 has started for PR 3022 at commit c15405c.

  • This patch does not merge cleanly.

@SparkQA
Copy link

SparkQA commented Nov 9, 2014

Test build #514 has finished for PR 3022 at commit c15405c.

  • This patch passes all tests.
  • This patch does not merge cleanly.
  • This patch adds the following public classes (experimental):
    • class GaussianMixtureModel(val w: Array[Double], val mu: Array[Vector], val sigma: Array[Matrix])

/** Sum the values in array of doubles */
private def sum(x : Array[Double]) : Double = {
var s : Double = 0.0
x.foreach(u => s += u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might not care about this at all, but calling foreach on an Array is actually notably slower than using a while loop over the indices. foreach over a Range is actually pretty close to while loop (ie. (0 until x.length).foreach{idx => s += x(idx)}. Or if you don't care about runtimes, then you can always just call array.sum (it actually comes from an implicit conversion to WrappedArray):

scala> ((0 to 100).map{_ / 100.0}.toArray).sum
res2: Double = 50.5

@tgaloppo
Copy link
Contributor Author

Please advise how to resolve merge issues.

@tgaloppo
Copy link
Contributor Author

Thanks, @squito ... while I expect the array to only have a few elements, I have made changes according to your advice.

@tgaloppo
Copy link
Contributor Author

Merged with the latest master branch to hopefully fix any merge issues.
Updated scala test suite to use new MLlibSparkTestContext
Improved cluster initialization strategy to average several samples per cluster.

… and tolerance parameters.

Modified cluster initialization strategy to use an initial covariance matrix derived from the sample points used to initialize the mean.
package org.apache.spark.examples.mllib

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixtureModel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for this import

val mu = vectorMean(x)
val ss = BreezeVector.zeros[Double](x(0).length)
val cov = BreezeMatrix.eye[Double](ss.length)
x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breeze has squaredDistance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

squaredDistance returns a scalar... I want the squared entry values.

Changed ExpectationSum to a private class
@tgaloppo
Copy link
Contributor Author

I've performed most of the requested changes. I do not see the BLAS function mentioned (dsyr), so I left this as a TODO. Also, I could not find EPSILON in MLUtils.

I left predictMembership public and changed predict to predictLabels, providing soft and hard label assignments, respectively. I know there are some other thoughts around improving these, but I am not clear on what I should do.

cc: @mengxr @jkbradley

@FlytxtRnD
Copy link
Contributor

Sorry for late reply.predictLabels() and predictMembership() looks fine.But what about moving the computeSoftAssignments() to GaussianMixtureModelEM class(in KMeans, findClosest() is defined in KMeans rather than in KMeansModel)

It will be good if the name of the class GaussianMixtureModelEM is changed as @mengxr suggested.

}
}

private def run(inputFile: String, k: Int, convergenceTol: Double) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we take maxIterations as an optional input parameter?


/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
def this() = this(2, 0.01, 100)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove extra newlines

@jkbradley
Copy link
Member

@tgaloppo MLUtils.EPSILON is actually private[util]. I think it would be fine to change it to be private[mllib]. CC: @mengxr

@tgaloppo I strongly recommend predict() instead of predictLabels() to be consistent with KMeansModel.

@FlytxtRnD computeSoftAssignments() is a function of the model, not the learning algorithm, so I think it belongs in the model. IMO, findClosest() should be in KMeansModel instead of KMeans, but that should be fixed in another PR. (It is not too important though since it is a private[mllib] API.)

GaussianMixtureEM: Renamed from GaussianMixtureModelEM; corrected formatting issues

GaussianMixtureModel: Renamed predictLabels() to predict()

Others: Modifications based on rename of GaussianMixtureEM
@tgaloppo
Copy link
Contributor Author

Ok. I changed the privacy of EPSILON and am now using it in this code.
I changed the name from GaussianMixtureModelEM to GaussianMixtureEM.
I've changed predictLabels() back to predict().

@SparkQA
Copy link

SparkQA commented Dec 29, 2014

Test build #555 has started for PR 3022 at commit aaa8f25.

  • This patch merges cleanly.

@SparkQA
Copy link

SparkQA commented Dec 29, 2014

Test build #555 has finished for PR 3022 at commit aaa8f25.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class GaussianMixtureModel(

@jkbradley
Copy link
Member

@tgaloppo Thanks for the updates, and thanks for all of your work in getting this ready!

LGTM

CC: @mengxr

After this is merged, I'll make some JIRAs for the various item we've discussed along the way + a few more. Let me know if I've missed anything here:

  • Add parameters: seed, maxIterations
  • Use sparse vectors more efficiently
  • If numFeatures or k are large, distribute matrix inverses for Gaussian initialization.
  • Breeze pinv fails when the matrix is singular: [https://github.com/MatrixSingularException when column is 0 all at pinv scalanlp/breeze#304] Do SVD instead.
  • Make MultivariateGaussian public, and update GMM API
  • Check for NaNs:
    • in computeSoftAssignments (if all pdfs = 0)
    • in values when constructing a GMM

@tgaloppo
Copy link
Contributor Author

@jkbradley Thank you for your help and feedback along the way. Please assign some (or all) of those tickets to me and I will continue to improve the implementation. In particular, you mentioned that there are a number of PR's with code for common distributions... I would be happy to help formalize a common interface and make these a public part of the library.

@asfgit asfgit closed this in 6cf6fdf Dec 29, 2014
@mengxr
Copy link
Contributor

mengxr commented Dec 29, 2014

@tgaloppo I've merged this into master. Thanks for contributing GMM!

@FlytxtRnD
Copy link
Contributor

@tgaloppo Good Work
@mengxr Thanks for giving us a chance to be a part of this contribution

@jkbradley
Copy link
Member

@tgaloppo @FlytxtRnD I made some JIRAs for the to-do items above.

I'd say the most important are:

It would be great to do:

Some less critical ones are:

I removed the NAN JIRAs, but we should investigate numerical stability at some point.

Please let me know if you'd like any assigned to you, and thanks in advance for your work on this! If I'm able to work on one of the JIRAs, I'll make a note on the JIRA page.

@tgaloppo
Copy link
Contributor Author

@jkbradley Please assign 5017, 5018, 5019, and 5020 to me. Regarding 5018, can you refer me to other PR's that are bringing in common distributions? I can work toward formalizing an API to make all of them public.

I also indicated that I would be happy to provide the Python wrappers for the algorithm (ticket 5012); @FlytxtRnD had provided an initial Python implementation of the algorithm... if they would like to provide the wrappers instead, that would be cool (but I am still definitely happy to do it if not).

CC: @mengxr

@jkbradley
Copy link
Member

@tgaloppo It's ideal if we assign & fix one JIRA at a time (as separate PRs). Can I start by assigning one of your choosing?

For 5018, there is only one other such PR I know of, and it uses a Dirichlet distribution. But for API examples, I would recommend checking out popular libraries, such as R, Matlab, numpy, etc.

@tgaloppo
Copy link
Contributor Author

@jkbradley No problem. Let's start with 5020, and I'll move on from there.

@tgaloppo
Copy link
Contributor Author

@jkbradley Please assign me SPARK-5017, and I will take care of this in preparation for 5018 and 5019.

@mengxr
Copy link
Contributor

mengxr commented Jan 1, 2015

Done :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants