
feat: add tests
pip3 install pdm
pdm install
pdm run test
pdm run lint
pdm run format
wey-gu committed Mar 21, 2023
1 parent 7e11f93 commit df961fa
Showing 16 changed files with 843 additions and 227 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -133,4 +133,8 @@ dmypy.json
.pdm.toml

# requirements.txt is only for local development
requirements.txt
requirements.txt

# ide

.vscode
15 changes: 15 additions & 0 deletions README.md
@@ -193,6 +193,21 @@ ng_ai is an unified abstraction layer for different engines, the current impleme
- [NebulaGraph Python Client 3.4+](https://github.com/vesoft-inc/nebula-python)
- [NetworkX](https://networkx.org/)


## Contributing

```bash
pip3 install pdm
# build and install ng_ai
pdm install
# run tests
pdm run test
# lint
pdm run lint
# format
pdm run format
```

## License

This project is licensed under the terms of the Apache License 2.0.
7 changes: 3 additions & 4 deletions ng_ai/__init__.py
@@ -3,14 +3,13 @@

from pkgutil import extend_path


__path__ = extend_path(__path__, __name__) # type: ignore

from ng_ai.nebula_reader import NebulaReader
from ng_ai.nebula_writer import NebulaWriter
from ng_ai.config import NebulaGraphConfig
from ng_ai.nebula_algo import NebulaAlgorithm
from ng_ai.nebula_gnn import NebulaGNN
from ng_ai.config import NebulaGraphConfig
from ng_ai.nebula_reader import NebulaReader
from ng_ai.nebula_writer import NebulaWriter
from ng_ai.ng_ai_api.app import app as ng_ai_api_app

# export
9 changes: 5 additions & 4 deletions ng_ai/engines.py
@@ -2,8 +2,6 @@
# Copyright 2023 The NebulaGraph Authors. All rights reserved.
from __future__ import annotations

from ng_ai.config import NebulaGraphConfig

# SPARK DEFAULTS
DEFAULT_SHUFFLE_PARTITIONS = 5
DEFAULT_EXECUTOR_MEMORY = "8g"
@@ -58,7 +56,8 @@ def __init__(self, config):
self.prepare()
self.nebula_spark_ds = NEBULA_SPARK_CONNECTOR_DATASOURCE

# TBD: ping NebulaGraph Meta and Storage Server, fail and guide user to check config
# TBD: ping NebulaGraph Meta and Storage Server
# fail and guide user to check config

def __str__(self):
return f"SparkEngine: {self.spark}"
@@ -91,7 +90,9 @@ def _get_java_import(self, force=False):

# scala:
# import "com.vesoft.nebula.algorithm.config.SparkConfig"
java_import(self.spark._jvm, "com.vesoft.nebula.algorithm.config.SparkConfig")
java_import(
self.spark._jvm, "com.vesoft.nebula.algorithm.config.SparkConfig"
)
return java_import

def import_scala_class(self, class_name):
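The engines.py hunk above only re-wraps a `java_import` call for line length. For readers new to the mechanism, here is a minimal sketch of how py4j's `java_import` exposes a Scala class from the nebula-algorithm jar to PySpark; the SparkSession setup and classpath handling are illustrative assumptions, not part of this commit.

```python
# Sketch only: py4j's java_import, as used by SparkEngine._get_java_import above.
# Assumes a plain SparkSession and that the nebula-algorithm jar is on the classpath.
from py4j.java_gateway import java_import
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ng_ai-java-import-sketch").getOrCreate()

# Python equivalent of the Scala line:
#   import com.vesoft.nebula.algorithm.config.SparkConfig
java_import(spark._jvm, "com.vesoft.nebula.algorithm.config.SparkConfig")

# The imported class is then reachable through the JVM view:
spark_config_class = spark._jvm.SparkConfig
print(spark_config_class)
```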
92 changes: 49 additions & 43 deletions ng_ai/nebula_algo.py
@@ -3,8 +3,8 @@

from __future__ import annotations

from ng_ai.nebula_data import NebulaGraphObject as NebulaGraphObjectImpl
from ng_ai.nebula_data import NebulaDataFrameObject as NebulaDataFrameObjectImpl
from ng_ai.nebula_data import NebulaGraphObject as NebulaGraphObjectImpl


def algo(func):
@@ -53,13 +53,14 @@ def get_all_algo(self):
def check_engine(self):
"""
Check if the engine is supported.
For networkx, we need to convert the NebulaDataFrameObject to NebulaGraphObject
For networkx, we need to convert the NebulaDataFrameObject
to NebulaGraphObject
For spark, we can directly use the NebulaDataFrameObject
"""
if self.ndf_obj.engine.type == "networkx":
raise Exception(
"For NebulaDataFrameObject in networkx engine,"
"Please transform it to NebulaGraphObject to run algorithm",
"Plz transform it to NebulaGraphObject to run algorithm",
"For example: g = nebula_df.to_graph; g.algo.pagerank()",
)

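The `check_engine` docstring above describes the NetworkX-engine rule: a `NebulaDataFrameObject` must first be converted to a `NebulaGraphObject` before an algorithm can run. Below is a rough usage sketch of that flow; the config fields, reader arguments, engine name, and graph space are illustrative assumptions based on the classes exported in `ng_ai/__init__.py`, not code from this commit.

```python
# Sketch only: the NetworkX-engine flow that check_engine enforces above.
# Engine name, scan() arguments, and the graph space are illustrative.
from ng_ai import NebulaReader
from ng_ai.config import NebulaGraphConfig

config = NebulaGraphConfig(space="basketballplayer")
reader = NebulaReader(engine="nebula", config=config)  # NetworkX-backed engine (assumed name)
reader.scan(edge="follow", props="degree")

df = reader.read()   # NebulaDataFrameObject
g = df.to_graph()    # convert first; the exception example writes nebula_df.to_graph
result = g.algo.pagerank()
```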
@@ -73,7 +74,7 @@ def get_spark_engine_context(self, config_class: str, lib_class: str):
jspark = engine.jspark
engine.import_algo_config_class(config_class)
engine.import_algo_lib_class(lib_class)
return engine, spark, jspark, engine.encode_vertex_id
return engine, spark, jspark, engine.encode_vid

def get_spark_dataframe(self):
"""
@@ -93,22 +94,22 @@ def get_spark_dataframe(self):
def pagerank(
self, reset_prob: float = 0.15, max_iter: int = 10, weighted: bool = False
):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"PRConfig", "PageRankAlgo"
)
df = self.get_spark_dataframe()
config = spark._jvm.PRConfig(max_iter, reset_prob, encode_vertex_id)
config = spark._jvm.PRConfig(max_iter, reset_prob, encode_vid)
result = spark._jvm.PageRankAlgo.apply(jspark, df._jdf, config, weighted)

return result

@algo
def connected_components(self, max_iter: int = 10, weighted: bool = False):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"CcConfig", "ConnectedComponentsAlgo"
)
df = self.get_spark_dataframe()
config = spark._jvm.CcConfig(max_iter, encode_vertex_id)
config = spark._jvm.CcConfig(max_iter, encode_vid)
result = spark._jvm.ConnectedComponentsAlgo.apply(
jspark, df._jdf, config, weighted
)
@@ -117,12 +118,12 @@ def connected_components(self, max_iter: int = 10, weighted: bool = False):

@algo
def label_propagation(self, max_iter: int = 10, weighted: bool = False):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"LPAConfig", "LabelPropagationAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.LPAConfig(max_iter, encode_vertex_id)
config = spark._jvm.LPAConfig(max_iter, encode_vid)
result = spark._jvm.LabelPropagationAlgo.apply(
jspark, df._jdf, config, weighted
)
@@ -131,49 +132,50 @@ def label_propagation(self, max_iter: int = 10, weighted: bool = False):

@algo
def louvain(self, max_iter: int = 20, internalIter: int = 10, tol: float = 0.5):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"LouvainConfig", "LouvainAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.LouvainConfig(max_iter, internalIter, tol, encode_vertex_id)
config = spark._jvm.LouvainConfig(max_iter, internalIter, tol, encode_vid)
result = spark._jvm.LouvainAlgo.apply(jspark, df._jdf, config, False)

return result

@algo
def k_core(self, max_iter: int = 10, degree: int = 2):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"KCoreConfig", "KCoreAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.KCoreConfig(max_iter, degree, encode_vertex_id)
config = spark._jvm.KCoreConfig(max_iter, degree, encode_vid)

result = spark._jvm.KCoreAlgo.apply(jspark, df._jdf, config)

return result

# def shortest_path(self, landmarks: list, weighted: bool = False):
# engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
# engine, spark, jspark, encode_vid = self.get_spark_engine_context(
# "ShortestPathConfig", "ShortestPathAlgo"
# )
# # TBD: ShortestPathAlgo is not yet encodeID compatible
# df = self.get_spark_dataframe()

# config = spark._jvm.ShortestPathConfig(landmarks, encode_vertex_id)
# result = spark._jvm.ShortestPathAlgo.apply(jspark, df._jdf, config, weighted)
# config = spark._jvm.ShortestPathConfig(landmarks, encode_vid)
# result = spark._jvm.ShortestPathAlgo.apply(
# jspark, df._jdf, config, weighted)

# return result

@algo
def degree_statics(self):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"DegreeStaticConfig", "DegreeStaticAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.DegreeStaticConfig(encode_vertex_id)
config = spark._jvm.DegreeStaticConfig(encode_vid)
result = spark._jvm.DegreeStaticAlgo.apply(jspark, df._jdf, config)

return result
@@ -182,12 +184,12 @@ def degree_statics(self):
def betweenness_centrality(
self, max_iter: int = 10, degree: int = 2, weighted: bool = False
):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"BetweennessConfig", "BetweennessCentralityAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.BetweennessConfig(max_iter, encode_vertex_id)
config = spark._jvm.BetweennessConfig(max_iter, encode_vid)
result = spark._jvm.BetweennessCentralityAlgo.apply(
jspark, df._jdf, config, weighted
)
@@ -201,37 +203,37 @@ def coefficient_centrality(self, type: str = "local"):
"type should be either local or global"
f"in coefficient_centrality algo. Got type: {type}"
)
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"CoefficientConfig", "ClusteringCoefficientAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.CoefficientConfig(type, encode_vertex_id)
config = spark._jvm.CoefficientConfig(type, encode_vid)
result = spark._jvm.ClusteringCoefficientAlgo.apply(jspark, df._jdf, config)

return result

@algo
def bfs(self, max_depth: int = 10, root: int = 1):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"BfsConfig", "BfsAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.BfsConfig(max_depth, root, encode_vertex_id)
config = spark._jvm.BfsConfig(max_depth, root, encode_vid)
result = spark._jvm.BfsAlgo.apply(jspark, df._jdf, config)

return result

# dfs is not yet supported, need to revisit upstream nebula-algorithm
@algo
def dfs(self, max_depth: int = 10, root: int = 1):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"DfsConfig", "DfsAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.DfsConfig(max_depth, root, encode_vertex_id)
config = spark._jvm.DfsConfig(max_depth, root, encode_vid)
result = spark._jvm.DfsAlgo.apply(jspark, df._jdf, config)

return result
@@ -245,13 +247,13 @@ def hanp(
weighted: bool = False,
preferences=None,
):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"HanpConfig", "HanpAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.HanpConfig(
hop_attenuation, max_iter, preference, encode_vertex_id
hop_attenuation, max_iter, preference, encode_vid
)
result = spark._jvm.HanpAlgo.apply(
jspark, df._jdf, config, weighted, preferences
@@ -278,7 +280,7 @@ def hanp(
# model_path: str = "hdfs://127.0.0.1:9000/model",
# weighted: bool = False,
# ):
# engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
# engine, spark, jspark, encode_vid = self.get_spark_engine_context(
# "Node2vecConfig", "Node2VecAlgo"
# )
# # TBD: Node2VecAlgo is not yet encodeID compatible
@@ -298,31 +300,33 @@ def hanp(
# degree,
# emb_separator,
# model_path,
# encode_vertex_id,
# encode_vid,
# )
# result = spark._jvm.Node2VecAlgo.apply(jspark, df._jdf, config, weighted)

# return result

@algo
def jaccard(self, tol: float = 1.0):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"JaccardConfig", "JaccardAlgo"
)
df = self.get_spark_dataframe()

config = spark._jvm.JaccardConfig(tol, encode_vertex_id)
config = spark._jvm.JaccardConfig(tol, encode_vid)
result = spark._jvm.JaccardAlgo.apply(jspark, df._jdf, config)

return result

@algo
def strong_connected_components(self, max_iter: int = 10, weighted: bool = False):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
def strong_connected_components(
self, max_iter: int = 10, weighted: bool = False
):
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"CcConfig", "StronglyConnectedComponentsAlgo"
)
df = self.get_spark_dataframe()
config = spark._jvm.CcConfig(max_iter, encode_vertex_id)
config = spark._jvm.CcConfig(max_iter, encode_vid)
result = spark._jvm.StronglyConnectedComponentsAlgo.apply(
jspark, df._jdf, config, weighted
)
@@ -331,24 +335,25 @@ def strong_connected_components(self, max_iter: int = 10, weighted: bool = False

@algo
def triangle_count(self):
engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
engine, spark, jspark, encode_vid = self.get_spark_engine_context(
"TriangleConfig", "TriangleCountAlgo"
)
df = self.get_spark_dataframe()
config = spark._jvm.TriangleConfig(encode_vertex_id)
config = spark._jvm.TriangleConfig(encode_vid)
result = spark._jvm.TriangleCountAlgo.apply(jspark, df._jdf, config)

return result

# @algo
# def closeness(self, weighted: bool = False):
# # TBD: ClosenessAlgo is not yet encodeID compatible
# engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
# engine, spark, jspark, encode_vid = self.get_spark_engine_context(
# "ClosenessConfig", "ClosenessAlgo"
# )
# df = self.get_spark_dataframe()
# config = spark._jvm.ClosenessConfig(weighted, encode_vertex_id)
# result = spark._jvm.ClosenessAlgo.apply(jspark, df._jdf, config, False)
# config = spark._jvm.ClosenessConfig(weighted, encode_vid)
# result = spark._jvm.ClosenessAlgo.apply(
# jspark, df._jdf, config, False)

# return result

@@ -376,12 +381,13 @@ def check_engine(self):
"""
Check if the engine is supported.
For networkx, we can directly call .algo.pagerank()
For spark, we need to convert the NebulaGraphObject to NebulaDataFrameObject
For spark, we need to convert the NebulaGraphObject
to NebulaDataFrameObject
"""
if self.graph.engine.type == "spark":
raise Exception(
"For NebulaGraphObject in spark engine,"
"Please transform it to NebulaDataFrameObject to run algorithm",
"Plz transform it to NebulaDataFrameObject to run algorithm",
"For example: df = nebula_graph.to_df; df.algo.pagerank()",
)

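The mirror-image rule appears in the docstring above for the Spark engine: a `NebulaGraphObject` must be converted to a `NebulaDataFrameObject` before calling one of the wrappers defined in this file (`pagerank`, `louvain`, `bfs`, and so on). A rough sketch of that direction follows; connection details and `scan()` arguments are illustrative assumptions, while the `pagerank` keyword arguments match the signature shown earlier in this diff.

```python
# Sketch only: the Spark-engine flow that check_engine enforces above.
# Connection details and scan() arguments are illustrative.
from ng_ai import NebulaReader
from ng_ai.config import NebulaGraphConfig

config = NebulaGraphConfig(space="basketballplayer")
reader = NebulaReader(engine="spark", config=config)
reader.scan(edge="follow", props="degree")

df = reader.read()  # NebulaDataFrameObject backed by a Spark DataFrame
result = df.algo.pagerank(reset_prob=0.15, max_iter=10)  # wraps PageRankAlgo via py4j

# Starting from a NebulaGraphObject instead, convert first, as the exception
# message instructs: df = nebula_graph.to_df; df.algo.pagerank()
```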
2 changes: 1 addition & 1 deletion ng_ai/nebula_data.py
@@ -114,7 +114,7 @@ def to_networkx(self):

def to_graphx(self):
if self.engine.type == "spark":
df = self.data
df = self.data # noqa: F841
# convert the df to a graphx graph, not implemented now
raise NotImplementedError
else:
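`to_graphx` above is deliberately left raising `NotImplementedError`. Since GraphX has no Python API, a PySpark-side conversion would more plausibly go through GraphFrames; the sketch below is purely illustrative, uses the optional `graphframes` package, and is not part of this commit or its dependencies.

```python
# Sketch only: one possible DataFrame-to-graph bridge from PySpark.
# Requires the optional graphframes package (e.g. spark-submit --packages graphframes:graphframes:...).
from graphframes import GraphFrame
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("to_graphx-sketch").getOrCreate()

# Illustrative edge DataFrame; ng_ai's real column names may differ.
edge_df = spark.createDataFrame(
    [("player100", "player101", 90), ("player101", "player102", 80)],
    ["src", "dst", "degree"],
)
# GraphFrame needs a vertex DataFrame with an "id" column and an edge
# DataFrame with "src"/"dst" columns.
vertices = (
    edge_df.selectExpr("src as id")
    .union(edge_df.selectExpr("dst as id"))
    .distinct()
)
g = GraphFrame(vertices, edge_df)
print(g.edges.count())
```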