encode root for path algo (#71)
Nicole00 authored Jun 16, 2023
1 parent 2b364a2 commit 397a8a3
Showing 6 changed files with 74 additions and 29 deletions.
@@ -200,16 +200,16 @@ object CoefficientConfig {
 /**
   * bfs
   */
-case class BfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
+case class BfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
 object BfsConfig {
   var maxIter: Int = _
-  var root: Long = _
+  var root: String = _
   var encodeId: Boolean = false

   def getBfsConfig(configs: Configs): BfsConfig = {
     val bfsConfig = configs.algorithmConfig.map
     maxIter = bfsConfig("algorithm.bfs.maxIter").toInt
-    root = bfsConfig("algorithm.bfs.root").toLong
+    root = bfsConfig("algorithm.bfs.root").toString
     encodeId = ConfigUtil.getOrElseBoolean(bfsConfig, "algorithm.bfs.encodeId", false)
     BfsConfig(maxIter, root, encodeId)
   }
@@ -218,16 +218,16 @@ object BfsConfig {
 /**
   * dfs
   */
-case class DfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
+case class DfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
 object DfsConfig {
   var maxIter: Int = _
-  var root: Long = _
+  var root: String = _
   var encodeId: Boolean = false

   def getDfsConfig(configs: Configs): DfsConfig = {
     val dfsConfig = configs.algorithmConfig.map
     maxIter = dfsConfig("algorithm.dfs.maxIter").toInt
-    root = dfsConfig("algorithm.dfs.root").toLong
+    root = dfsConfig("algorithm.dfs.root").toString
     encodeId = ConfigUtil.getOrElseBoolean(dfsConfig, "algorithm.dfs.encodeId", false)
     DfsConfig(maxIter, root, encodeId)
   }
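Note: with root now read as a String, non-numeric vertex ids become legal roots. A minimal sketch of what the new parsing accepts, assuming the algorithm.bfs.* key names used by getBfsConfig above (the config map literal is a hypothetical illustration, not part of the commit):

    // Hypothetical config map, shaped like what getBfsConfig reads.
    val conf = Map(
      "algorithm.bfs.maxIter"  -> "5",
      "algorithm.bfs.root"     -> "player100", // a string id; needs encodeId = true downstream
      "algorithm.bfs.encodeId" -> "true"
    )
    // Before this commit, conf("algorithm.bfs.root").toLong threw a
    // NumberFormatException on "player100"; now the raw string is kept.
    val root: String = conf("algorithm.bfs.root")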
@@ -386,6 +386,7 @@ object AlgoConstants {
   val HANP_RESULT_COL: String = "hanp"
   val NODE2VEC_RESULT_COL: String = "node2vec"
   val BFS_RESULT_COL: String = "bfs"
+  val DFS_RESULT_COL: String = "dfs"
   val ENCODE_ID_COL: String = "encodedId"
   val ORIGIN_ID_COL: String = "id"
 }
@@ -26,15 +26,18 @@ object BfsAlgo {
    */
   def apply(spark: SparkSession, dataset: Dataset[Row], bfsConfig: BfsConfig): DataFrame = {
     var encodeIdDf: DataFrame = null
+    var finalRoot: Long = 0

     val graph: Graph[None.type, Double] = if (bfsConfig.encodeId) {
       val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
       encodeIdDf = encodeId
+      finalRoot = encodeIdDf.filter(row => row.get(0).toString == bfsConfig.root).first().getLong(1)
       NebulaUtil.loadInitGraph(data, false)
     } else {
+      finalRoot = bfsConfig.root.toLong
       NebulaUtil.loadInitGraph(dataset, false)
     }
-    val bfsGraph = execute(graph, bfsConfig.maxIter, bfsConfig.root)
+    val bfsGraph = execute(graph, bfsConfig.maxIter, finalRoot)

     // filter out the vertices that were not traversed
     val visitedVertices = bfsGraph.vertices.filter(v => v._2 != Double.PositiveInfinity)
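The encodeId branch above resolves the user-facing string root to the internal long id before the traversal runs. A hedged sketch of that lookup, assuming DecodeUtil.convertStringId2LongId returns a mapping DataFrame whose first column is the original id and whose second is the generated long id (the data and column names here are illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    // Stand-in for the mapping assumed to come from DecodeUtil.convertStringId2LongId.
    val encodeIdDf = Seq(("player100", 0L), ("player101", 1L)).toDF("id", "encodedId")

    // Same shape as the lookup in BfsAlgo.apply: match on the original id,
    // take the encoded long id that GraphX operates on.
    val root      = "player101"
    val finalRoot = encodeIdDf.filter(row => row.get(0).toString == root).first().getLong(1)
    assert(finalRoot == 1L)

Note that first() throws if the root is absent from the data, so a misconfigured root surfaces immediately rather than producing an empty result.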
@@ -5,9 +5,16 @@

 package com.vesoft.nebula.algorithm.lib

+import com.vesoft.nebula.algorithm.config.AlgoConstants.{
+  ALGO_ID_COL,
+  DFS_RESULT_COL,
+  ENCODE_ID_COL,
+  ORIGIN_ID_COL
+}
 import com.vesoft.nebula.algorithm.config.{AlgoConstants, BfsConfig, DfsConfig}
 import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil}
-import org.apache.spark.graphx.{EdgeDirection, Graph, VertexId}
+import org.apache.spark.graphx.{EdgeDirection, EdgeTriplet, Graph, Pregel, VertexId}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructField, StructType}

@@ -18,21 +25,28 @@ object DfsAlgo {

   def apply(spark: SparkSession, dataset: Dataset[Row], dfsConfig: DfsConfig): DataFrame = {
     var encodeIdDf: DataFrame = null
+    var finalRoot: Long = 0

     val graph: Graph[None.type, Double] = if (dfsConfig.encodeId) {
       val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
       encodeIdDf = encodeId
+      finalRoot = encodeIdDf.filter(row => row.get(0).toString == dfsConfig.root).first().getLong(1)
       NebulaUtil.loadInitGraph(data, false)
     } else {
+      finalRoot = dfsConfig.root.toLong
       NebulaUtil.loadInitGraph(dataset, false)
     }
-    val bfsVertices = dfs(graph, dfsConfig.root, mutable.Seq.empty[VertexId])(dfsConfig.maxIter)
+    val bfsVertices =
+      dfs(graph, finalRoot, mutable.Seq.empty[VertexId])(dfsConfig.maxIter).vertices.filter(v =>
+        v._2 != Double.PositiveInfinity)

-    val schema = StructType(List(StructField("dfs", LongType, nullable = false)))
+    val schema = StructType(
+      List(StructField(ALGO_ID_COL, LongType, nullable = false),
+           StructField(DFS_RESULT_COL, DoubleType, nullable = true)))

-    val rdd = spark.sparkContext.parallelize(bfsVertices.toSeq, 1).map(row => Row(row))
-    val algoResult = spark.sqlContext
-      .createDataFrame(rdd, schema)
+    val resultRDD = bfsVertices.map(v => Row(v._1, v._2))
+    val algoResult =
+      spark.sqlContext.createDataFrame(resultRDD, schema).orderBy(col(DFS_RESULT_COL))

     if (dfsConfig.encodeId) {
       DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf).coalesce(1)
@@ -42,18 +56,35 @@ object DfsAlgo {
   }

   def dfs(g: Graph[None.type, Double], vertexId: VertexId, visited: mutable.Seq[VertexId])(
-      maxIter: Int): mutable.Seq[VertexId] = {
-    if (visited.contains(vertexId)) {
-      visited
-    } else {
-      if (iterNums > maxIter) {
-        return visited
-      }
-      val newVisited = visited :+ vertexId
-      val neighbors  = g.collectNeighbors(EdgeDirection.Out).lookup(vertexId).flatten
-      iterNums = iterNums + 1
-      neighbors.foldLeft(newVisited)((visited, neighbor) => dfs(g, neighbor._1, visited)(maxIter))
-    }
-  }
+      maxIter: Int): Graph[Double, Double] = {
+
+    val initialGraph =
+      g.mapVertices((id, _) => if (id == vertexId) 0.0 else Double.PositiveInfinity)
+
+    def vertexProgram(id: VertexId, attr: Double, msg: Double): Double = {
+      math.min(attr, msg)
+    }
+
+    def sendMessage(edge: EdgeTriplet[Double, Double]): Iterator[(VertexId, Double)] = {
+      val sourceVertex = edge.srcAttr
+      val targetVertex = edge.dstAttr
+      if (sourceVertex + 1 < targetVertex && sourceVertex < maxIter) {
+        Iterator((edge.dstId, sourceVertex + 1))
+      } else {
+        Iterator.empty
+      }
+    }
+
+    def mergeMessage(a: Double, b: Double): Double = {
+      math.min(a, b)
+    }
+
+    // start the iteration
+    val resultGraph =
+      Pregel(initialGraph, Double.PositiveInfinity)(vertexProgram, sendMessage, mergeMessage)
+
+    resultGraph
+  }

 }
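The rewritten dfs is a Pregel computation: every vertex starts at Double.PositiveInfinity except the root at 0.0, each superstep relaxes a destination to source + 1, and messages stop flowing once a source's distance reaches maxIter. In effect it computes minimum hop counts from the root (breadth-first semantics) rather than a depth-first visit order. A self-contained sketch of the same pattern on a toy graph, using only the standard GraphX API (the vertex ids and edges are made up for illustration):

    import org.apache.spark.graphx.{Edge, Graph, Pregel}
    import org.apache.spark.sql.SparkSession

    object PregelHopCountDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local").getOrCreate()
        val sc    = spark.sparkContext

        // Toy graph: 1 -> 2 -> 3 and 1 -> 4 (ids are already longs here).
        val edges = sc.parallelize(Seq(Edge(1L, 2L, 1.0), Edge(2L, 3L, 1.0), Edge(1L, 4L, 1.0)))
        val graph = Graph.fromEdges(edges, defaultValue = None)

        val root    = 1L
        val maxIter = 5
        // Root starts at distance 0, everything else at infinity.
        val initial = graph.mapVertices((id, _) => if (id == root) 0.0 else Double.PositiveInfinity)

        val distances = Pregel(initial, Double.PositiveInfinity)(
          vprog = (_, attr, msg) => math.min(attr, msg), // keep the smallest hop count seen
          sendMsg = t =>
            if (t.srcAttr + 1 < t.dstAttr && t.srcAttr < maxIter) Iterator((t.dstId, t.srcAttr + 1))
            else Iterator.empty,
          mergeMsg = (a, b) => math.min(a, b) // combine concurrent messages to one vertex
        )

        // Unreached vertices keep infinity, mirroring the filter in DfsAlgo.apply.
        distances.vertices.filter(_._2 != Double.PositiveInfinity).collect().foreach(println)
        spark.stop()
      }
    }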
@@ -14,7 +14,7 @@ class BfsAlgoSuite {
   def bfsAlgoSuite(): Unit = {
     val spark = SparkSession.builder().master("local").getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
-    val bfsAlgoConfig = new BfsConfig(5, 1)
+    val bfsAlgoConfig = new BfsConfig(5, "1")
     val result = BfsAlgo.apply(spark, data, bfsAlgoConfig)
     result.show()
     assert(result.count() == 4)
@@ -13,11 +13,21 @@ import org.junit.Test
 class DfsAlgoSuite {
   @Test
   def bfsAlgoSuite(): Unit = {
-    val spark = SparkSession.builder().master("local").getOrCreate()
+    val spark = SparkSession
+      .builder()
+      .master("local")
+      .config("spark.sql.shuffle.partitions", 5)
+      .getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
-    val dfsAlgoConfig = new DfsConfig(5, 3)
-    val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
-    result.show()
-    assert(result.count() == 4)
+    val dfsAlgoConfig = new DfsConfig(5, "3")
+    // val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
+    // result.show()
+    // assert(result.count() == 4)
+
+    val encodeDfsConfig = new DfsConfig(5, "3", true)
+    val encodeResult = DfsAlgo.apply(spark, data, encodeDfsConfig)
+
+    encodeResult.show()
+    assert(encodeResult.count() == 4)
   }
 }
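The encode test above exercises the main point of the commit: a string root works end to end once encodeId is true. A hypothetical usage sketch with non-numeric ids (the edge data below is invented, and it assumes the first two columns are treated as source and destination ids, as with the edge.csv fixture used by these suites):

    import com.vesoft.nebula.algorithm.config.DfsConfig
    import com.vesoft.nebula.algorithm.lib.DfsAlgo
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    // String vertex ids, no numeric encoding required from the caller.
    val stringEdges = Seq(("a", "b"), ("b", "c"), ("a", "d")).toDF("src", "dst")
    val result = DfsAlgo.apply(spark, stringEdges, new DfsConfig(5, "a", true))
    result.show() // original string ids restored by convertAlgoId2StringId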
