feat: get all algo function #12

Merged (1 commit, Mar 1, 2023)
docs/API.md (1 addition, 0 deletions)

@@ -53,6 +53,7 @@ ngdi.`NebulaDataFrameObject` is a Spark DataFrame or Pandas DataFrame, which can

### Functions

- `ngdi.NebulaGraphObject.algo.get_all_algo()` returns all algorithms supported by the engine.
- `ngdi.NebulaGraphObject.algo.pagerank()` runs the PageRank algorithm on the NetworkX graph (not yet implemented).

## NebulaAlgorithm
ngdi/nebula_algo.py (52 additions, 9 deletions)

@@ -7,6 +7,11 @@
from ngdi.nebula_data import NebulaDataFrameObject as NebulaDataFrameObjectImpl


def algo(func):
    # Mark the function as an algorithm so get_all_algo() can discover it
    func.is_algo = True
    return func


class NebulaAlgorithm:
    def __init__(self, obj: NebulaGraphObjectImpl or NebulaDataFrameObjectImpl):
        if isinstance(obj, NebulaGraphObjectImpl):
@@ -33,6 +38,17 @@ class NebulaDataFrameAlgorithm:

    def __init__(self, ndf_obj: NebulaDataFrameObjectImpl):
        self.ndf_obj = ndf_obj
        self.algorithms = []

    def register_algo(self, func):
        self.algorithms.append(func.__name__)

    def get_all_algo(self):
        # Lazily scan this class for methods tagged by the @algo decorator
        if not self.algorithms:
            for name, func in NebulaDataFrameAlgorithm.__dict__.items():
                if hasattr(func, "is_algo"):
                    self.register_algo(func)
        return self.algorithms

    def check_engine(self):
        """
@@ -73,6 +89,7 @@ def get_spark_dataframe(self):
            )
        return df

    @algo
    def pagerank(
        self, reset_prob: float = 0.15, max_iter: int = 10, weighted: bool = False
    ):
@@ -85,6 +102,7 @@

        return result

    @algo
    def connected_components(self, max_iter: int = 10, weighted: bool = False):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "CcConfig", "ConnectedComponentsAlgo"
@@ -97,6 +115,7 @@ def connected_components(self, max_iter: int = 10, weighted: bool = False):

        return result

    @algo
    def label_propagation(self, max_iter: int = 10, weighted: bool = False):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "LPAConfig", "LabelPropagationAlgo"
@@ -110,6 +129,7 @@ def label_propagation(self, max_iter: int = 10, weighted: bool = False):

        return result

    @algo
    def louvain(self, max_iter: int = 10, internalIter: int = 10, tol: float = 0.0001):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "LouvainConfig", "LouvainAlgo"
@@ -121,6 +141,7 @@ def louvain(self, max_iter: int = 10, internalIter: int = 10, tol: float = 0.0001):

        return result

    @algo
    def k_core(self, max_iter: int = 10, degree: int = 2):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "KCoreConfig", "KCoreAlgo"
@@ -145,6 +166,7 @@ def k_core(self, max_iter: int = 10, degree: int = 2):

        # return result

    @algo
    def degree_statics(self):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "DegreeStaticConfig", "DegreeStaticAlgo"
@@ -156,6 +178,7 @@ def degree_statics(self):

        return result

    @algo
    def betweenness_centrality(
        self, max_iter: int = 10, degree: int = 2, weighted: bool = False
    ):
@@ -171,6 +194,7 @@

        return result

    @algo
    def coefficient_centrality(self, type: str = "local"):
        # type could be either "local" or "global"
        assert type.lower() in ["local", "global"], (
@@ -187,6 +211,7 @@

        return result

    @algo
    def bfs(self, max_depth: int = 10, root: int = 1):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "BfsConfig", "BfsAlgo"
@@ -199,18 +224,19 @@ def bfs(self, max_depth: int = 10, root: int = 1):
        return result

-    # dfs is not yet supported, need to revisit upstream nebula-algorithm
-    #
-    # def dfs(self, max_depth: int = 10, root: int = 1):
-    #     engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
-    #         "DfsConfig", "DfsAlgo"
-    #     )
-    #     df = self.get_spark_dataframe()
+    @algo
+    def dfs(self, max_depth: int = 10, root: int = 1):
+        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
+            "DfsConfig", "DfsAlgo"
+        )
+        df = self.get_spark_dataframe()

-    #     config = spark._jvm.DfsConfig(max_depth, root, encode_vertex_id)
-    #     result = spark._jvm.DfsAlgo.apply(jspark, df._jdf, config)
+        config = spark._jvm.DfsConfig(max_depth, root, encode_vertex_id)
+        result = spark._jvm.DfsAlgo.apply(jspark, df._jdf, config)

-    #     return result
+        return result

    @algo
    def hanp(
        self,
        hop_attenuation: float = 0.5,
@@ -233,6 +259,7 @@ def hanp(

        return result

    # @algo
    # def node2vec(
    #     self,
    #     max_iter: int = 10,
@@ -277,6 +304,7 @@ def hanp(

    # return result

    @algo
    def jaccard(self, tol: float = 1.0):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "JaccardConfig", "JaccardAlgo"
@@ -288,6 +316,7 @@ def jaccard(self, tol: float = 1.0):

        return result

    @algo
    def strong_connected_components(self, max_iter: int = 10, weighted: bool = False):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "CcConfig", "StronglyConnectedComponentsAlgo"
@@ -300,6 +329,7 @@ def strong_connected_components(self, max_iter: int = 10, weighted: bool = False):

        return result

    @algo
    def triangle_count(self):
        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
            "TriangleConfig", "TriangleCountAlgo"
@@ -310,6 +340,7 @@ def triangle_count(self):

        return result

    # @algo
    # def closeness(self, weighted: bool = False):
    #     # TBD: ClosenessAlgo is not yet encodeID compatible
    #     engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
@@ -329,6 +360,17 @@ class NebulaGraphAlgorithm:

    def __init__(self, graph):
        self.graph = graph
        self.algorithms = []

    def register_algo(self, func):
        self.algorithms.append(func.__name__)

    def get_all_algo(self):
        # Same lazy scan as NebulaDataFrameAlgorithm, over this class's methods
        if not self.algorithms:
            for name, func in NebulaGraphAlgorithm.__dict__.items():
                if hasattr(func, "is_algo"):
                    self.register_algo(func)
        return self.algorithms

    def check_engine(self):
        """
@@ -343,6 +385,7 @@ def check_engine(self):
                "For example: df = nebula_graph.to_df; df.algo.pagerank()",
            )

    @algo
    def pagerank(self, reset_prob=0.15, max_iter=10):
        self.check_engine()
        pass