-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SPARK-1462: Examples of ML algorithms are using deprecated APIs #416
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,11 +18,14 @@ | |
package org.apache.spark.examples | ||
|
||
import java.util.Random | ||
import org.apache.spark.util.Vector | ||
import org.apache.spark.SparkContext._ | ||
|
||
import scala.collection.mutable.HashMap | ||
import scala.collection.mutable.HashSet | ||
|
||
import breeze.linalg.{Vector, DenseVector, squaredDistance} | ||
|
||
import org.apache.spark.SparkContext._ | ||
|
||
/** | ||
* K-means clustering. | ||
*/ | ||
|
@@ -36,19 +39,19 @@ object LocalKMeans { | |
|
||
def generateData = { | ||
def generatePoint(i: Int) = { | ||
Vector(D, _ => rand.nextDouble * R) | ||
DenseVector.fill(D){rand.nextDouble * R} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ditto. |
||
} | ||
Array.tabulate(N)(generatePoint) | ||
} | ||
|
||
def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { | ||
def closestPoint(p: Vector[Double], centers: HashMap[Int, Vector[Double]]): Int = { | ||
var index = 0 | ||
var bestIndex = 0 | ||
var closest = Double.PositiveInfinity | ||
|
||
for (i <- 1 to centers.size) { | ||
val vCurr = centers.get(i).get | ||
val tempDist = p.squaredDist(vCurr) | ||
val tempDist = squaredDistance(p, vCurr) | ||
if (tempDist < closest) { | ||
closest = tempDist | ||
bestIndex = i | ||
|
@@ -60,8 +63,8 @@ object LocalKMeans { | |
|
||
def main(args: Array[String]) { | ||
val data = generateData | ||
var points = new HashSet[Vector] | ||
var kPoints = new HashMap[Int, Vector] | ||
var points = new HashSet[Vector[Double]] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be better if you import |
||
var kPoints = new HashMap[Int, Vector[Double]] | ||
var tempDist = 1.0 | ||
|
||
while (points.size < K) { | ||
|
@@ -81,16 +84,17 @@ object LocalKMeans { | |
var mappings = closest.groupBy[Int] (x => x._1) | ||
|
||
var pointStats = mappings.map { pair => | ||
pair._2.reduceLeft [(Int, (Vector, Int))] { | ||
pair._2.reduceLeft [(Int, (Vector[Double], Int))] { | ||
case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1 + y2)) | ||
} | ||
} | ||
|
||
var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)} | ||
var newPoints = pointStats.map {mapping => | ||
(mapping._1, mapping._2._1 * (1.0 / mapping._2._2))} | ||
|
||
tempDist = 0.0 | ||
for (mapping <- newPoints) { | ||
tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2) | ||
tempDist += squaredDistance(kPoints.get(mapping._1).get, mapping._2) | ||
} | ||
|
||
for (newP <- newPoints) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should use `rand.nextDouble()` instead of `rand.nextDouble` because it changes the internal state of `rand`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not related to this PR, but it would be great if you update it in the next pass.