
Commit 15c087f: "address comments"
WeichenXu123 committed Jun 5, 2018 (1 parent: de91f5d)

Showing 2 changed files with 44 additions and 20 deletions.
PowerIterationClustering.scala:
@@ -73,8 +73,6 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
   val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.",
     (value: String) => value.nonEmpty)
 
-  setDefault(srcCol, "src")
-
   /** @group getParam */
   @Since("2.4.0")
   def getSrcCol: String = getOrDefault(srcCol)
@@ -89,11 +87,11 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
"Name of the input column for destination vertex IDs.",
(value: String) => value.nonEmpty)

setDefault(dstCol, "dst")

/** @group getParam */
@Since("2.4.0")
def getDstCol: String = $(dstCol)

setDefault(srcCol -> "src", dstCol -> "dst")
}

/**
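Note on the consolidated default above: Params.setDefault has a varargs overload taking ParamPair values, and Param's -> operator builds a ParamPair, so the two single-param calls collapse into one. A minimal sketch of the pattern, assuming spark-mllib on the classpath (MyGraphParams and its params are illustrative names, not part of this commit):

    import org.apache.spark.ml.param.{Param, Params}

    // Hypothetical params trait using the single varargs setDefault call.
    trait MyGraphParams extends Params {
      val srcCol = new Param[String](this, "srcCol", "source vertex ID column")
      val dstCol = new Param[String](this, "dstCol", "destination vertex ID column")

      // One call with ParamPairs replaces two separate setDefault(param, value) calls.
      setDefault(srcCol -> "src", dstCol -> "dst")
    }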
@@ -148,6 +146,8 @@ class PowerIterationClustering private[clustering] (
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
   /**
+   * Runs the PIC algorithm and returns a cluster assignment for each input vertex.
+   *
    * @param dataset A dataset with columns src, dst, weight representing the affinity matrix,
    *                which is the matrix A in the PIC paper. Suppose the src column value is i,
    *                the dst column value is j, the weight column value is similarity s,,ij,,
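The @param documentation added above fixes the input contract for assignClusters. For context, a dataset satisfying it can be built in a few lines; a hedged sketch (spark is an assumed SparkSession, values are illustrative):

    import spark.implicits._

    // Each row is one entry (i, j, s_ij) of the symmetric affinity matrix A
    // from the PIC paper; unlisted pairs are treated as zero similarity.
    val affinity = Seq(
      (0L, 1L, 1.0),
      (1L, 2L, 0.9),
      (2L, 3L, 0.1)
    ).toDF("src", "dst", "weight")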
@@ -183,14 +183,8 @@
       .setMaxIterations($(maxIter))
     val model = algorithm.run(rdd)
 
-    val assignmentsRDD: RDD[Row] = model.assignments.map { assignment =>
-      Row(assignment.id, assignment.cluster)
-    }
-    val assignmentsSchema = StructType(Seq(
-      StructField("id", LongType, nullable = false),
-      StructField("cluster", IntegerType, nullable = false)))
-
-    dataset.sparkSession.createDataFrame(assignmentsRDD, assignmentsSchema)
+    import dataset.sparkSession.implicits._
+    model.assignments.toDF
   }
 
   @Since("2.4.0")
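The rewrite to toDF works because mllib's PowerIterationClustering.Assignment is a case class with fields id: Long and cluster: Int, so the product encoder derives the same schema (non-nullable LongType and IntegerType columns named id and cluster) that the deleted code spelled out by hand. A standalone sketch of the idiom, with Assignment as a stand-in case class rather than the mllib one:

    import org.apache.spark.sql.SparkSession

    case class Assignment(id: Long, cluster: Int)

    object ToDFSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("toDF-sketch").getOrCreate()
        import spark.implicits._

        // toDF derives the schema from the case class fields, so no manual
        // Row mapping or StructType construction is needed.
        val rdd = spark.sparkContext.parallelize(Seq(Assignment(0L, 1), Assignment(1L, 0)))
        rdd.toDF().printSchema()

        spark.stop()
      }
    }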
PowerIterationClusteringSuite.scala:

@@ -20,8 +20,8 @@ package org.apache.spark.ml.clustering
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types._


@@ -71,22 +71,28 @@ class PowerIterationClusteringSuite extends SparkFunSuite
test("power iteration clustering") {
val n = n1 + n2

val result = new PowerIterationClustering()
val assignments = new PowerIterationClustering()
.setK(2)
.setMaxIter(40)
.setWeightCol("weight")
.assignClusters(data).as[(Long, Int)].collect().toSet
.assignClusters(data)
val localAssignments = assignments
.select('id, 'cluster)
.as[(Long, Int)].collect().toSet
val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
(n1 until n).map(x => (x, 0)).toSet
assert(result === expectedResult)
assert(localAssignments === expectedResult)

val result2 = new PowerIterationClustering()
val assignments2 = new PowerIterationClustering()
.setK(2)
.setMaxIter(10)
.setInitMode("degree")
.setWeightCol("weight")
.assignClusters(data).as[(Long, Int)].collect().toSet
assert(result2 === expectedResult)
.assignClusters(data)
val localAssignments2 = assignments2
.select('id, 'cluster)
.as[(Long, Int)].collect().toSet
assert(localAssignments2 === expectedResult)
}

test("supported input types") {
@@ -129,6 +135,30 @@ class PowerIterationClusteringSuite extends SparkFunSuite
     assert(msg.contains("Similarity must be nonnegative"))
   }
 
+  test("test default weight") {
+    val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst)
+
+    val assignments = new PowerIterationClustering()
+      .setK(2)
+      .setMaxIter(40)
+      .assignClusters(dataWithoutWeight)
+    val localAssignments = assignments
+      .select('id, 'cluster)
+      .as[(Long, Int)].collect().toSet
+
+    val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0))
+
+    val assignments2 = new PowerIterationClustering()
+      .setK(2)
+      .setMaxIter(40)
+      .assignClusters(dataWithWeightOne)
+    val localAssignments2 = assignments2
+      .select('id, 'cluster)
+      .as[(Long, Int)].collect().toSet
+
+    assert(localAssignments === localAssignments2)
+  }
+
   test("read/write") {
     val t = new PowerIterationClustering()
       .setK(4)
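Taken together, the commit leaves the ml-package PIC API looking like this minimal end-to-end sketch (spark is an assumed SparkSession and the edge values are illustrative):

    import org.apache.spark.ml.clustering.PowerIterationClustering
    import spark.implicits._

    val edges = Seq((0L, 1L, 1.0), (1L, 2L, 1.0), (3L, 4L, 1.0)).toDF("src", "dst", "weight")

    val assignments = new PowerIterationClustering()
      .setK(2)
      .setMaxIter(20)
      .setWeightCol("weight")
      .assignClusters(edges)  // returns a DataFrame with columns id: Long, cluster: Int

    assignments.show()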
