Skip to content

Commit

Permalink
Addressing PR issues
Browse files Browse the repository at this point in the history
  • Loading branch information
str-janus committed Dec 10, 2014
1 parent 277d367 commit f8d5928
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
23 changes: 12 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,13 @@ class KMeans private (
private var initializationMode: String,
private var initializationSteps: Int,
private var epsilon: Double,
private var seed: Long = System.nanoTime()) extends Serializable with Logging {
private var seed: Long) extends Serializable with Logging {

/**
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: System.nanoTime()}.
*/
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)

/** Set the random seed stored on this instance; returns this instance so calls can be chained. */
def setSeed(seed: Long): this.type = {
this.seed = seed
this
}
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, System.nanoTime())

/** Set the number of clusters to create (k). Default: 2. */
def setK(k: Int): this.type = {
Expand Down Expand Up @@ -118,6 +113,12 @@ class KMeans private (
this
}

/**
* Set the random seed for cluster initialization.
*
* @param seed the seed value to store on this instance
* @return this instance, so that setter calls can be chained
*/
def setSeed(seed: Long): this.type = {
this.seed = seed
this
}

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
Expand Down Expand Up @@ -339,7 +340,7 @@ object KMeans {
/**
* Trains a k-means model using the given set of parameters.
*
* @param data training points stored as `RDD[Array[Double]]`
* @param data training points stored as `RDD[Vector]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param runs number of parallel runs, defaults to 1. The best model is returned.
Expand All @@ -361,12 +362,12 @@ object KMeans {
/**
* Trains a k-means model using the given set of parameters.
*
* @param data training points stored as `RDD[Array[Double]]`
* @param data training points stored as `RDD[Vector]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param runs number of parallel runs, defaults to 1. The best model is returned.
* @param initializationMode initialization mode, either "random" or "k-means||" (default).
* @param seed seed value for cluster initialization
* @param seed random seed value for cluster initialization
*/
def train(
data: RDD[Vector],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
assert(model.clusterCenters.size === 3)
}

test("deterministic initilization") {
// Create a large-ish set of point to cluster
val points = List.tabulate(1000)(n => Vectors.dense(n,n))
test("deterministic initialization") {
// Create a large-ish set of points for clustering
val points = List.tabulate(1000)(n => Vectors.dense(n, n))
val rdd = sc.parallelize(points, 3)

for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
Expand Down

0 comments on commit f8d5928

Please sign in to comment.