[SPARK-2062][GraphX] VertexRDD.apply does not use the mergeFunc #1903

Status: Closed (13 commits)
4 changes: 2 additions & 2 deletions graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -392,7 +392,7 @@ object VertexRDD {
    */
   def apply[VD: ClassTag](
       vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = {
-    VertexRDD(vertices, edges, defaultVal, (a, b) => b)
+    VertexRDD(vertices, edges, defaultVal, (a, b) => a)
   }

   /**
@@ -419,7 +419,7 @@ object VertexRDD {
       (vertexIter, routingTableIter) =>
         val routingTable =
           if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty
-        Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal))
+        Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal, mergeFunc))
     }
     new VertexRDD(vertexPartitions)
   }
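The effect of the VertexRDD.scala change: the three-argument `apply` now delegates a deterministic first-wins merge (`(a, b) => a`), and the four-argument overload threads the caller's `mergeFunc` down into each `ShippableVertexPartition` instead of silently dropping it. A minimal sketch of calling the fixed overload, mirroring the API in this diff and the PR's own test further down (`mergeExample` is an illustrative name, and a live `SparkContext` is assumed):

import org.apache.spark.SparkContext
import org.apache.spark.graphx._

// Illustrative sketch, not part of the PR: duplicate attributes for vertex 1L
// are now combined with the supplied mergeFunc instead of one being dropped.
def mergeExample(sc: SparkContext): Unit = {
  val verts = sc.parallelize(Seq((1L, "a"), (1L, "b"), (2L, "c")))
  val edges = EdgeRDD.fromEdges(sc.parallelize(Seq.empty[Edge[Int]]))
  val merged = VertexRDD(verts, edges, "", (a: String, b: String) => a + b)
  // merged.collect().toSet == Set((1L, "ab"), (2L, "c"))
}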
28 changes: 23 additions & 5 deletions graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
@@ -36,18 +36,36 @@ private[graphx]
 object ShippableVertexPartition {
   /** Construct a `ShippableVertexPartition` from the given vertices without any routing table. */
   def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)]): ShippableVertexPartition[VD] =
-    apply(iter, RoutingTablePartition.empty, null.asInstanceOf[VD])
+    apply(iter, RoutingTablePartition.empty, null.asInstanceOf[VD], (a, b) => a)

   /**
    * Construct a `ShippableVertexPartition` from the given vertices with the specified routing
    * table, filling in missing vertices mentioned in the routing table using `defaultVal`.
    */
   def apply[VD: ClassTag](
       iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD)
-    : ShippableVertexPartition[VD] = {
-    val fullIter = iter ++ routingTable.iterator.map(vid => (vid, defaultVal))
-    val (index, values, mask) = VertexPartitionBase.initFrom(fullIter, (a: VD, b: VD) => a)
-    new ShippableVertexPartition(index, values, mask, routingTable)
+    : ShippableVertexPartition[VD] =
+    apply(iter, routingTable, defaultVal, (a, b) => a)
+
+  /**
+   * Construct a `ShippableVertexPartition` from the given vertices with the specified routing
+   * table, filling in missing vertices mentioned in the routing table using `defaultVal`,
+   * and merging duplicate vertex attributes with `mergeFunc`.
+   */
+  def apply[VD: ClassTag](
+      iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD,
+      mergeFunc: (VD, VD) => VD): ShippableVertexPartition[VD] = {
+    val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD]
+    // Merge the given vertices using mergeFunc
+    iter.foreach { pair =>
+      map.setMerge(pair._1, pair._2, mergeFunc)
+    }
+    // Fill in missing vertices mentioned in the routing table
+    routingTable.iterator.foreach { vid =>
+      map.changeValue(vid, defaultVal, identity)
+    }
+
+    new ShippableVertexPartition(map.keySet, map._values, map.keySet.getBitSet, routingTable)
   }

   import scala.language.implicitConversions
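The new four-argument `ShippableVertexPartition.apply` builds the partition in two phases: it first folds every incoming (id, attribute) pair into a primitive open hash map, combining duplicates with `mergeFunc` (`setMerge` inserts on first sight and merges afterwards), then walks the routing table and inserts `defaultVal` only for ids that are still missing (`changeValue` with `identity` leaves existing entries untouched). A self-contained sketch of the same two-phase logic, using `mutable.HashMap` in place of the `private[graphx]` `GraphXPrimitiveKeyOpenHashMap` (names here are illustrative, not GraphX API):

import scala.collection.mutable

def buildVertexMap[VD](
    vertices: Iterator[(Long, VD)],
    routingTableIds: Iterator[Long],
    defaultVal: VD,
    mergeFunc: (VD, VD) => VD): mutable.HashMap[Long, VD] = {
  val map = mutable.HashMap.empty[Long, VD]
  // Phase 1: insert-or-merge, like setMerge(vid, vd, mergeFunc).
  vertices.foreach { case (vid, vd) =>
    map(vid) = map.get(vid).map(old => mergeFunc(old, vd)).getOrElse(vd)
  }
  // Phase 2: like changeValue(vid, defaultVal, identity): insert defaultVal
  // only if vid is missing; existing attributes stay untouched.
  routingTableIds.foreach { vid =>
    if (!map.contains(vid)) map(vid) = defaultVal
  }
  map
}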
11 changes: 11 additions & 0 deletions graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -99,4 +99,15 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
     }
   }

+  test("mergeFunc") {
+    // test to see if the mergeFunc is working correctly
+    withSpark { sc =>
+      val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
+      val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
+      val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b)
+      // duplicates are summed: 1L -> 1 + 2 = 3, 2L -> 3 + 3 + 3 = 9
+      assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9)))
+    }
+  }
+
 }
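As a sanity check on the asserted values, the same merge can be reproduced with a plain `reduceByKey` over the test input. This is a hypothetical cross-check, not part of the suite; it assumes a `SparkContext` named `sc` as in the test, with `SparkContext._` supplying the pair-RDD implicits of this Spark era:

import org.apache.spark.SparkContext._  // pair-RDD functions (pre-1.3 implicits)

// Hypothetical cross-check: 1L merges 1 + 2 = 3, 2L merges 3 + 3 + 3 = 9.
val expected = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
  .reduceByKey(_ + _)
  .collect
  .toSet
assert(expected == Set((0L, 0), (1L, 3), (2L, 9)))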