-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed #3610
Changes from 4 commits
35c1884
616d111
5d087b4
9156a57
277d367
f8d5928
7668124
a2ebbd3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,14 +43,20 @@ class KMeans private ( | |
private var runs: Int, | ||
private var initializationMode: String, | ||
private var initializationSteps: Int, | ||
private var epsilon: Double) extends Serializable with Logging { | ||
private var epsilon: Double, | ||
private var seed: Long = System.nanoTime()) extends Serializable with Logging { | ||
|
||
/** | ||
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, | ||
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}. | ||
*/ | ||
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) | ||
|
||
def setSeed(seed: Long): this.type = { | ||
this.seed = seed | ||
this | ||
} | ||
|
||
/** Set the number of clusters to create (k). Default: 2. */ | ||
def setK(k: Int): this.type = { | ||
this.k = k | ||
|
@@ -255,7 +261,7 @@ class KMeans private ( | |
private def initRandom(data: RDD[VectorWithNorm]) | ||
: Array[Array[VectorWithNorm]] = { | ||
// Sample all the cluster centers in one pass to avoid repeated scans | ||
val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq | ||
val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq | ||
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v => | ||
new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm) | ||
}.toArray) | ||
|
@@ -273,7 +279,7 @@ class KMeans private ( | |
private def initKMeansParallel(data: RDD[VectorWithNorm]) | ||
: Array[Array[VectorWithNorm]] = { | ||
// Initialize each run's center to a random point | ||
val seed = new XORShiftRandom().nextInt() | ||
val seed = new XORShiftRandom(this.seed).nextInt() | ||
val sample = data.takeSample(true, runs, seed).toSeq | ||
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) | ||
|
||
|
@@ -352,6 +358,31 @@ object KMeans { | |
.run(data) | ||
} | ||
|
||
/** | ||
* Trains a k-means model using the given set of parameters. | ||
* | ||
* @param data training points stored as `RDD[Array[Double]]` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please update this outdated doc here and in the other part of KMeans (now that we use Vector instead of Array[Double])? Thanks! |
||
* @param k number of clusters | ||
* @param maxIterations max number of iterations | ||
* @param runs number of parallel runs, defaults to 1. The best model is returned. | ||
* @param initializationMode initialization model, either "random" or "k-means||" (default). | ||
* @param seed seed value for cluster initialization | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In doc: Maybe say "random seed value" instead of "seed value" since I could imagine people mistaking "seed" to mean "initial cluster center" at first glance. |
||
*/ | ||
def train( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you move this to the beginning and make the one without seed call this? |
||
data: RDD[Vector], | ||
k: Int, | ||
maxIterations: Int, | ||
runs: Int, | ||
initializationMode: String, | ||
seed: Long): KMeansModel = { | ||
new KMeans().setK(k) | ||
.setMaxIterations(maxIterations) | ||
.setRuns(runs) | ||
.setInitializationMode(initializationMode) | ||
.setSeed(seed) | ||
.run(data) | ||
} | ||
|
||
/** | ||
* Trains a k-means model using specified parameters and the default values for unspecified. | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { | |
assert(model.clusterCenters.size === 3) | ||
} | ||
|
||
test("deterministic initilization") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: "initilization" --> "initialization" |
||
// Create a large-ish set of point to cluster | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: "point" --> "points" |
||
val points = List.tabulate(1000)(n => Vectors.dense(n,n)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scala style: space after comma |
||
val rdd = sc.parallelize(points, 3) | ||
|
||
for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { | ||
// Create three deterministic models and compare cluster means | ||
val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line too wide. wrap at 100 chars. |
||
val centers1 = model1.clusterCenters | ||
|
||
val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42) | ||
val centers2 = model2.clusterCenters | ||
|
||
val model3 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, initializationMode = initMode, seed = 42) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two should be sufficient. |
||
val centers3 = model3.clusterCenters | ||
|
||
assert(centers1.deep == centers2.deep) | ||
assert(centers1.deep == centers3.deep) | ||
} | ||
} | ||
|
||
test("single cluster with big dataset") { | ||
val smallData = Array( | ||
Vectors.dense(1.0, 2.0, 6.0), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,6 +129,23 @@ def test_clustering(self): | |
self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) | ||
self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) | ||
|
||
def test_clustering_deterministic(self): | ||
from pyspark.mllib.clustering import KMeans | ||
X = range(0, 100, 10) | ||
Y = range(0, 100, 10) | ||
data = [[x, y] for x, y in zip(X, Y)] | ||
clusters1 = KMeans.train(self.sc.parallelize(data), | ||
3, initializationMode="k-means||", seed=42) | ||
clusters2 = KMeans.train(self.sc.parallelize(data), | ||
3, initializationMode="k-means||", seed=42) | ||
clusters3 = KMeans.train(self.sc.parallelize(data), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two should be sufficient. |
||
3, initializationMode="k-means||", seed=42) | ||
centers1 = array(clusters1.centers).flatten().tolist() | ||
centers2 = array(clusters2.centers).flatten().tolist() | ||
centers3 = array(clusters3.centers).flatten().tolist() | ||
self.assertListEqual(centers1, centers2) | ||
self.assertListEqual(centers1, centers3) | ||
|
||
def test_classification(self): | ||
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes | ||
from pyspark.mllib.tree import DecisionTree | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you set the default in the one public constructor instead since that's where other defaults are set?