Skip to content

Commit

Permalink
[SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single ve…
Browse files Browse the repository at this point in the history
…ctor
  • Loading branch information
dkobylarz committed Jul 16, 2015
1 parent a5f602e commit cef1f0a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class GaussianMixtureModel(
responsibilityMatrix.map(r => r.indexOf(r.max))
}

/** Maps given point to its cluster index. */
def predict(point: Vector): Int = {
val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
r.indexOf(r.max)
}

/** Java-friendly version of [[predict()]] */
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
Expand All @@ -83,6 +89,13 @@ class GaussianMixtureModel(
}
}

/**
* Given the input vector, return the membership values to all mixture components.
*/
def predictSoft(point: Vector): Array[Double] = {
computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
}

/**
* Compute the partial assignments for each vector
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}

test("model prediction, parallel and local") {
val data = sc.parallelize(GaussianTestData.data)
val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)

val batchPredictions = gmm.predict(data)
batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
assert(batchPred === gmm.predict(datum))
}
}

object GaussianTestData {

val data = Array(
Expand Down

0 comments on commit cef1f0a

Please sign in to comment.