Skip to content

Commit

Permalink
feat: get all algo function
Browse files Browse the repository at this point in the history
  • Loading branch information
wey-gu committed Mar 1, 2023
1 parent fc721f9 commit aa2ea04
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ngdi.`NebulaDataFrameObject` is a Spark DataFrame or Pandas DataFrame, which can

### Functions

- `ngdi.NebulaGraphObject.algo.get_all_algo()` returns all algorithms supported by the engine.
- `ngdi.NebulaGraphObject.algo.pagerank()` runs the PageRank algorithm on the NetworkX Graph. not yet implemented.

## NebulaAlgorithm
Expand Down
61 changes: 52 additions & 9 deletions ngdi/nebula_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from ngdi.nebula_data import NebulaDataFrameObject as NebulaDataFrameObjectImpl


def algo(func):
func.is_algo = True
return func


class NebulaAlgorithm:
def __init__(self, obj: NebulaGraphObjectImpl or NebulaDataFrameObjectImpl):
if isinstance(obj, NebulaGraphObjectImpl):
Expand All @@ -33,6 +38,17 @@ class NebulaDataFrameAlgorithm:

def __init__(self, ndf_obj: NebulaDataFrameObjectImpl):
self.ndf_obj = ndf_obj
self.algorithms = []

def register_algo(self, func):
self.algorithms.append(func.__name__)

def get_all_algo(self):
if not self.algorithms:
for name, func in NebulaDataFrameAlgorithm.__dict__.items():
if hasattr(func, "is_algo"):
self.register_algo(func)
return self.algorithms

def check_engine(self):
"""
Expand Down Expand Up @@ -73,6 +89,7 @@ def get_spark_dataframe(self):
)
return df

@algo
def pagerank(
self, reset_prob: float = 0.15, max_iter: int = 10, weighted: bool = False
):
Expand All @@ -85,6 +102,7 @@ def pagerank(

return result

@algo
def connected_components(self, max_iter: int = 10, weighted: bool = False):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"CcConfig", "ConnectedComponentsAlgo"
Expand All @@ -97,6 +115,7 @@ def connected_components(self, max_iter: int = 10, weighted: bool = False):

return result

@algo
def label_propagation(self, max_iter: int = 10, weighted: bool = False):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"LPAConfig", "LabelPropagationAlgo"
Expand All @@ -110,6 +129,7 @@ def label_propagation(self, max_iter: int = 10, weighted: bool = False):

return result

@algo
def louvain(self, max_iter: int = 10, internalIter: int = 10, tol: float = 0.0001):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"LouvainConfig", "LouvainAlgo"
Expand All @@ -121,6 +141,7 @@ def louvain(self, max_iter: int = 10, internalIter: int = 10, tol: float = 0.000

return result

@algo
def k_core(self, max_iter: int = 10, degree: int = 2):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"KCoreConfig", "KCoreAlgo"
Expand All @@ -145,6 +166,7 @@ def k_core(self, max_iter: int = 10, degree: int = 2):

# return result

@algo
def degree_statics(self):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"DegreeStaticConfig", "DegreeStaticAlgo"
Expand All @@ -156,6 +178,7 @@ def degree_statics(self):

return result

@algo
def betweenness_centrality(
self, max_iter: int = 10, degree: int = 2, weighted: bool = False
):
Expand All @@ -171,6 +194,7 @@ def betweenness_centrality(

return result

@algo
def coefficient_centrality(self, type: str = "local"):
# type could be either "local" or "global"
assert type.lower() in ["local", "global"], (
Expand All @@ -187,6 +211,7 @@ def coefficient_centrality(self, type: str = "local"):

return result

@algo
def bfs(self, max_depth: int = 10, root: int = 1):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"BfsConfig", "BfsAlgo"
Expand All @@ -199,18 +224,19 @@ def bfs(self, max_depth: int = 10, root: int = 1):
return result

# dfs is not yet supported, need to revisit upstream nebula-algorithm
#
# def dfs(self, max_depth: int = 10, root: int = 1):
# engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
# "DfsConfig", "DfsAlgo"
# )
# df = self.get_spark_dataframe()
@algo
def dfs(self, max_depth: int = 10, root: int = 1):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"DfsConfig", "DfsAlgo"
)
df = self.get_spark_dataframe()

# config = spark._jvm.DfsConfig(max_depth, root, encode_vertex_id)
# result = spark._jvm.DfsAlgo.apply(jspark, df._jdf, config)
config = spark._jvm.DfsConfig(max_depth, root, encode_vertex_id)
result = spark._jvm.DfsAlgo.apply(jspark, df._jdf, config)

# return result
return result

@algo
def hanp(
self,
hop_attenuation: float = 0.5,
Expand All @@ -233,6 +259,7 @@ def hanp(

return result

# @algo
# def node2vec(
# self,
# max_iter: int = 10,
Expand Down Expand Up @@ -277,6 +304,7 @@ def hanp(

# return result

@algo
def jaccard(self, tol: float = 1.0):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"JaccardConfig", "JaccardAlgo"
Expand All @@ -288,6 +316,7 @@ def jaccard(self, tol: float = 1.0):

return result

@algo
def strong_connected_components(self, max_iter: int = 10, weighted: bool = False):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"CcConfig", "StronglyConnectedComponentsAlgo"
Expand All @@ -300,6 +329,7 @@ def strong_connected_components(self, max_iter: int = 10, weighted: bool = False

return result

@algo
def triangle_count(self):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
"TriangleConfig", "TriangleCountAlgo"
Expand All @@ -310,6 +340,7 @@ def triangle_count(self):

return result

# @algo
# def closeness(self, weighted: bool = False):
# # TBD: ClosenessAlgo is not yet encodeID compatible
# engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
Expand All @@ -329,6 +360,17 @@ class NebulaGraphAlgorithm:

def __init__(self, graph):
self.graph = graph
self.algorithms = []

def register_algo(self, func):
self.algorithms.append(func.__name__)

def get_all_algo(self):
if not self.algorithms:
for name, func in NebulaGraphAlgorithm.__dict__.items():
if hasattr(func, "is_algo"):
self.register_algo(func)
return self.algorithms

def check_engine(self):
"""
Expand All @@ -343,6 +385,7 @@ def check_engine(self):
"For example: df = nebula_graph.to_df; df.algo.pagerank()",
)

@algo
def pagerank(self, reset_prob=0.15, max_iter=10):
self.check_engine()
pass

0 comments on commit aa2ea04

Please sign in to comment.