diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala
index 9811185..cba8814 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala
@@ -200,16 +200,16 @@ object CoefficientConfig {
 /**
  * bfs
  */
-case class BfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
+case class BfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
 object BfsConfig {
   var maxIter: Int = _
-  var root: Long = _
+  var root: String = _
   var encodeId: Boolean = false
 
   def getBfsConfig(configs: Configs): BfsConfig = {
     val bfsConfig = configs.algorithmConfig.map
     maxIter = bfsConfig("algorithm.bfs.maxIter").toInt
-    root = bfsConfig("algorithm.bfs.root").toLong
+    root = bfsConfig("algorithm.bfs.root").toString
     encodeId = ConfigUtil.getOrElseBoolean(bfsConfig, "algorithm.bfs.encodeId", false)
     BfsConfig(maxIter, root, encodeId)
   }
@@ -218,16 +218,16 @@ object BfsConfig {
 /**
  * dfs
  */
-case class DfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
+case class DfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
 object DfsConfig {
   var maxIter: Int = _
-  var root: Long = _
+  var root: String = _
   var encodeId: Boolean = false
 
   def getDfsConfig(configs: Configs): DfsConfig = {
     val dfsConfig = configs.algorithmConfig.map
     maxIter = dfsConfig("algorithm.dfs.maxIter").toInt
-    root = dfsConfig("algorithm.dfs.root").toLong
+    root = dfsConfig("algorithm.dfs.root").toString
     encodeId = ConfigUtil.getOrElseBoolean(dfsConfig, "algorithm.dfs.encodeId", false)
     DfsConfig(maxIter, root, encodeId)
   }
diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
index abb94af..93f7311 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
@@ -386,6 +386,7 @@ object AlgoConstants {
   val HANP_RESULT_COL: String = "hanp"
   val NODE2VEC_RESULT_COL: String = "node2vec"
   val BFS_RESULT_COL: String = "bfs"
+  val DFS_RESULT_COL: String = "dfs"
   val ENCODE_ID_COL: String = "encodedId"
   val ORIGIN_ID_COL: String = "id"
 }
diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala
index 8765f93..96a3fe7 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala
@@ -26,15 +26,18 @@ object BfsAlgo {
    */
   def apply(spark: SparkSession, dataset: Dataset[Row], bfsConfig: BfsConfig): DataFrame = {
     var encodeIdDf: DataFrame = null
+    var finalRoot: Long = 0
 
     val graph: Graph[None.type, Double] = if (bfsConfig.encodeId) {
       val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
       encodeIdDf = encodeId
+      finalRoot = encodeIdDf.filter(row => row.get(0).toString == bfsConfig.root).first().getLong(1)
       NebulaUtil.loadInitGraph(data, false)
     } else {
+      finalRoot = bfsConfig.root.toLong
       NebulaUtil.loadInitGraph(dataset, false)
     }
-    val bfsGraph = execute(graph, bfsConfig.maxIter, bfsConfig.root)
+    val bfsGraph = execute(graph, bfsConfig.maxIter, finalRoot)
 
     // filter out the vertices that were not traversed
     val visitedVertices =
      bfsGraph.vertices.filter(v => v._2 != Double.PositiveInfinity)
diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DfsAlgo.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DfsAlgo.scala
index 1789ae6..1ee5d4f 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DfsAlgo.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DfsAlgo.scala
@@ -5,9 +5,16 @@
 package com.vesoft.nebula.algorithm.lib
 
+import com.vesoft.nebula.algorithm.config.AlgoConstants.{
+  ALGO_ID_COL,
+  DFS_RESULT_COL,
+  ENCODE_ID_COL,
+  ORIGIN_ID_COL
+}
 import com.vesoft.nebula.algorithm.config.{AlgoConstants, BfsConfig, DfsConfig}
 import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil}
-import org.apache.spark.graphx.{EdgeDirection, Graph, VertexId}
+import org.apache.spark.graphx.{EdgeDirection, EdgeTriplet, Graph, Pregel, VertexId}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructField, StructType}
 
@@ -18,21 +25,28 @@ object DfsAlgo {
   def apply(spark: SparkSession, dataset: Dataset[Row], dfsConfig: DfsConfig): DataFrame = {
     var encodeIdDf: DataFrame = null
+    var finalRoot: Long = 0
 
     val graph: Graph[None.type, Double] = if (dfsConfig.encodeId) {
       val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
       encodeIdDf = encodeId
+      finalRoot = encodeIdDf.filter(row => row.get(0).toString == dfsConfig.root).first().getLong(1)
       NebulaUtil.loadInitGraph(data, false)
     } else {
+      finalRoot = dfsConfig.root.toLong
       NebulaUtil.loadInitGraph(dataset, false)
     }
-    val bfsVertices = dfs(graph, dfsConfig.root, mutable.Seq.empty[VertexId])(dfsConfig.maxIter)
+    val bfsVertices =
+      dfs(graph, finalRoot, mutable.Seq.empty[VertexId])(dfsConfig.maxIter).vertices.filter(v =>
+        v._2 != Double.PositiveInfinity)
 
-    val schema = StructType(List(StructField("dfs", LongType, nullable = false)))
+    val schema = StructType(
+      List(StructField(ALGO_ID_COL, LongType, nullable = false),
+           StructField(DFS_RESULT_COL, DoubleType, nullable = true)))
 
-    val rdd = spark.sparkContext.parallelize(bfsVertices.toSeq, 1).map(row => Row(row))
-    val algoResult = spark.sqlContext
-      .createDataFrame(rdd, schema)
+    val resultRDD = bfsVertices.map(v => Row(v._1, v._2))
+    val algoResult =
+      spark.sqlContext.createDataFrame(resultRDD, schema).orderBy(col(DFS_RESULT_COL))
 
     if (dfsConfig.encodeId) {
       DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf).coalesce(1)
@@ -42,18 +56,35 @@ object DfsAlgo {
   }
 
   def dfs(g: Graph[None.type, Double], vertexId: VertexId, visited: mutable.Seq[VertexId])(
-      maxIter: Int): mutable.Seq[VertexId] = {
-    if (visited.contains(vertexId)) {
-      visited
-    } else {
-      if (iterNums > maxIter) {
-        return visited
+      maxIter: Int): Graph[Double, Double] = {
+
+    // the root starts at depth 0.0; every other vertex starts unreached (infinity)
+    val initialGraph =
+      g.mapVertices((id, _) => if (id == vertexId) 0.0 else Double.PositiveInfinity)
+
+    def vertexProgram(id: VertexId, attr: Double, msg: Double): Double = {
+      math.min(attr, msg)
+    }
+
+    // offer depth + 1 to a neighbor that is currently further away, as long as
+    // the frontier is still within maxIter hops of the root
+    def sendMessage(edge: EdgeTriplet[Double, Double]): Iterator[(VertexId, Double)] = {
+      val sourceVertex = edge.srcAttr
+      val targetVertex = edge.dstAttr
+      if (sourceVertex + 1 < targetVertex && sourceVertex < maxIter) {
+        Iterator((edge.dstId, sourceVertex + 1))
+      } else {
+        Iterator.empty
       }
-      val newVisited = visited :+ vertexId
-      val neighbors = g.collectNeighbors(EdgeDirection.Out).lookup(vertexId).flatten
-      iterNums = iterNums + 1
-      neighbors.foldLeft(newVisited)((visited, neighbor) => dfs(g, neighbor._1, visited)(maxIter))
     }
+
+    def mergeMessage(a: Double, b: Double): Double = {
+      math.min(a, b)
+    }
+
+    // start the Pregel iteration
+    val resultGraph =
+      Pregel(initialGraph, Double.PositiveInfinity)(vertexProgram, sendMessage, mergeMessage)
+
+    resultGraph
+  }
 }
diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/BfsAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/BfsAlgoSuite.scala
index 73c75df..a4c4f12 100644
--- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/BfsAlgoSuite.scala
+++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/BfsAlgoSuite.scala
@@ -14,7 +14,7 @@ class BfsAlgoSuite {
   def bfsAlgoSuite(): Unit = {
     val spark = SparkSession.builder().master("local").getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
-    val bfsAlgoConfig = new BfsConfig(5, 1)
+    val bfsAlgoConfig = new BfsConfig(5, "1")
     val result = BfsAlgo.apply(spark, data, bfsAlgoConfig)
     result.show()
     assert(result.count() == 4)
diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DfsAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DfsAlgoSuite.scala
index e710d58..6d0f3c5 100644
--- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DfsAlgoSuite.scala
+++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DfsAlgoSuite.scala
@@ -13,11 +13,21 @@ import org.junit.Test
 class DfsAlgoSuite {
   @Test
   def bfsAlgoSuite(): Unit = {
-    val spark = SparkSession.builder().master("local").getOrCreate()
+    val spark = SparkSession
+      .builder()
+      .master("local")
+      .config("spark.sql.shuffle.partitions", 5)
+      .getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
-    val dfsAlgoConfig = new DfsConfig(5, 3)
-    val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
-    result.show()
-    assert(result.count() == 4)
+    val dfsAlgoConfig = new DfsConfig(5, "3")
+//    val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
+//    result.show()
+//    assert(result.count() == 4)
+
+    val encodeDfsConfig = new DfsConfig(5, "3", true)
+    val encodeResult = DfsAlgo.apply(spark, data, encodeDfsConfig)
+
+    encodeResult.show()
+    assert(encodeResult.count() == 4)
   }
 }
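
Usage sketch for reviewers: a minimal, self-contained driver showing how the changed API is called after this patch. It assumes the same src/test/resources/edge.csv layout that the suites above read; `TraversalExample` is a hypothetical name and not part of the patch.

import com.vesoft.nebula.algorithm.config.{BfsConfig, DfsConfig}
import com.vesoft.nebula.algorithm.lib.{BfsAlgo, DfsAlgo}
import org.apache.spark.sql.SparkSession

// hypothetical driver, not part of this patch
object TraversalExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    // string-id edge list, read the same way as in the test suites
    val edges = spark.read.option("header", true).csv("src/test/resources/edge.csv")

    // root is now passed as a String in both configs
    val bfsResult = BfsAlgo.apply(spark, edges, BfsConfig(5, "1"))

    // with encodeId = true, the String root is looked up in the encodeId
    // DataFrame and translated to the internal Long vertex id before traversal
    val dfsResult = DfsAlgo.apply(spark, edges, DfsConfig(5, "3", encodeId = true))

    bfsResult.show()
    // result columns are ALGO_ID_COL and DFS_RESULT_COL ("dfs"); the dfs value
    // is the hop depth at which each reachable vertex was reached, and rows
    // are ordered by that depth
    dfsResult.show()
  }
}

Note that the rewritten dfs relies on Pregel's synchronous, level-at-a-time message passing, so the "dfs" column records each vertex's minimum hop distance from the root, bounded by maxIter, rather than a depth-first visit order.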