Skip to content

Commit

Permalink
add Dfs algorithm (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicole00 authored Dec 27, 2022
1 parent d7037a3 commit 78713c1
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 0 deletions.
6 changes: 6 additions & 0 deletions nebula-algorithm/src/main/resources/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@
root:"10"
}

# DFS parameter
dfs:{
maxIter:5
root:"10"
}

# HanpAlgo parameter
hanp:{
hopAttenuation:0.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.vesoft.nebula.algorithm.config.{
CcConfig,
CoefficientConfig,
Configs,
DfsConfig,
HanpConfig,
JaccardConfig,
KCoreConfig,
Expand All @@ -30,6 +31,7 @@ import com.vesoft.nebula.algorithm.lib.{
ClusteringCoefficientAlgo,
ConnectedComponentsAlgo,
DegreeStaticAlgo,
DfsAlgo,
GraphTriangleCountAlgo,
HanpAlgo,
JaccardAlgo,
Expand Down Expand Up @@ -204,6 +206,10 @@ object Main {
val bfsConfig = BfsConfig.getBfsConfig(configs)
BfsAlgo(spark, dataSet, bfsConfig)
}
case "dfs" => {
val dfsConfig = DfsConfig.getDfsConfig(configs)
DfsAlgo(spark, dataSet, dfsConfig)
}
case "jaccard" => {
val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
JaccardAlgo(spark, dataSet, jaccardConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,22 @@ object BfsConfig {
}
}

/**
* dfs
*/
/**
 * Configuration for the DFS algorithm: the maximum number of expansion
 * steps (`maxIter`) and the id of the vertex the traversal starts from (`root`).
 */
case class DfsConfig(maxIter: Int, root: Long)
object DfsConfig {
  // Mutable mirrors of the last-read values; kept to match the convention of
  // the other *Config companions in this file. Prefer the returned DfsConfig.
  var maxIter: Int = _
  var root: Long = _

  /**
   * Read the `algorithm.dfs.*` entries from the resolved configuration and
   * build a [[DfsConfig]]. Throws if a key is missing or not numeric.
   */
  def getDfsConfig(configs: Configs): DfsConfig = {
    val algoConf = configs.algorithmConfig.map
    maxIter = algoConf("algorithm.dfs.maxIter").toInt
    root = algoConf("algorithm.dfs.root").toLong
    DfsConfig(maxIter = maxIter, root = root)
  }
}

/**
* Hanp
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.vesoft.nebula.algorithm.lib

import com.vesoft.nebula.algorithm.config.{AlgoConstants, BfsConfig, DfsConfig}
import com.vesoft.nebula.algorithm.utils.NebulaUtil
import org.apache.spark.graphx.{EdgeDirection, Graph, VertexId}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructField, StructType}

import scala.collection.mutable

object DfsAlgo {
  // Counts expansion steps of the current traversal. Shared object state:
  // reset in apply() so repeated invocations do not inherit a stale count.
  var iterNums = 0

  /**
   * Run a depth-first traversal from `dfsConfig.root`, bounded by
   * `dfsConfig.maxIter` expansion steps, and return the visited vertex ids
   * as a single-column ("dfs") DataFrame in one partition.
   */
  def apply(spark: SparkSession, dataset: Dataset[Row], dfsConfig: DfsConfig): DataFrame = {
    val graph: Graph[None.type, Double] = NebulaUtil.loadInitGraph(dataset, false)
    // Bug fix: without this reset a second call to apply() would start with
    // iterNums already past maxIter and return only the root vertex.
    iterNums = 0
    val dfsVertices = dfs(graph, dfsConfig.root, mutable.Seq.empty[VertexId])(dfsConfig.maxIter)

    val schema = StructType(List(StructField("dfs", LongType, nullable = false)))

    val rdd = spark.sparkContext.parallelize(dfsVertices.toSeq, 1).map(row => Row(row))
    val algoResult = spark.sqlContext
      .createDataFrame(rdd, schema)

    algoResult.repartition(1)
  }

  /**
   * Recursive DFS step: returns `visited` unchanged when `vertexId` was
   * already seen or the step budget is exhausted, otherwise appends
   * `vertexId` and descends into its out-neighbors.
   *
   * NOTE(review): collectNeighbors(...).lookup(...) launches a Spark job per
   * visited vertex — acceptable for small graphs, costly for large ones.
   */
  def dfs(g: Graph[None.type, Double], vertexId: VertexId, visited: mutable.Seq[VertexId])(
      maxIter: Int): mutable.Seq[VertexId] = {
    // Merged guards replace the early `return` in the original else-branch;
    // the visited-check still short-circuits first, as before.
    if (visited.contains(vertexId) || iterNums > maxIter) {
      visited
    } else {
      val newVisited = visited :+ vertexId
      val neighbors  = g.collectNeighbors(EdgeDirection.Out).lookup(vertexId).flatten
      iterNums = iterNums + 1
      neighbors.foldLeft(newVisited)((acc, neighbor) => dfs(g, neighbor._1, acc)(maxIter))
    }
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package scala.com.vesoft.nebula.algorithm.lib

import com.vesoft.nebula.algorithm.config.{BfsConfig, DfsConfig}
import com.vesoft.nebula.algorithm.lib.{BfsAlgo, DfsAlgo}
import org.apache.spark.sql.SparkSession
import org.junit.Test

class DfsAlgoSuite {
  /**
   * Runs DFS over the sample edge file from root vertex 3 with a step budget
   * of 5 and checks the number of visited vertices. (Renamed from the
   * copy-pasted `bfsAlgoSuite` — this suite exercises DfsAlgo.)
   */
  @Test
  def dfsAlgoSuite(): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    try {
      val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
      // DfsConfig is a case class — no `new` needed; named args for clarity.
      val dfsAlgoConfig = DfsConfig(maxIter = 5, root = 3)
      val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
      result.show()
      assert(result.count() == 4)
    } finally {
      // Stop the session so it does not leak into other JUnit tests.
      spark.stop()
    }
  }
}

0 comments on commit 78713c1

Please sign in to comment.