Skip to content

Commit

Permalink
[Feat][Spark] Align info implementation of spark with c++ (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
acezen authored Jan 12, 2024
1 parent 2faccd8 commit 000e27b
Show file tree
Hide file tree
Showing 18 changed files with 177 additions and 339 deletions.
105 changes: 36 additions & 69 deletions pyspark/graphar_pyspark/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,6 @@ def __init__(
aligned_by: Optional[str],
prefix: Optional[str],
file_type: Optional[FileType],
property_groups: Optional[Sequence[PropertyGroup]],
jvm_obj: Optional[JavaObject],
) -> None:
"""One should not use this constructor directly, please use `from_scala` or `from_python`."""
Expand All @@ -552,9 +551,6 @@ def __init__(
jvm_adj_list.setAligned_by(aligned_by)
jvm_adj_list.setPrefix(prefix)
jvm_adj_list.setFile_type(file_type.value)
jvm_adj_list.setProperty_groups(
[py_property_group.to_scala() for py_property_group in property_groups],
)
self._jvm_adj_list_obj = jvm_adj_list

def get_ordered(self) -> bool:
Expand Down Expand Up @@ -615,25 +611,6 @@ def set_file_type(self, file_type: FileType) -> None:
"""
self._jvm_adj_list_obj.setFile_type(file_type.value)

def get_property_groups(self) -> Sequence[PropertyGroup]:
"""Get property groups from the corresponding JVM object.
:returns: property groups
"""
return [
PropertyGroup.from_scala(jvm_property_group)
for jvm_property_group in self._jvm_adj_list_obj.getProperty_groups()
]

def set_property_groups(self, property_groups: Sequence[PropertyGroup]) -> None:
"""Mutate the corresponding JVM object.
:param property_groups: new property groups
"""
self._jvm_adj_list_obj.setProperty_groups(
[p_group.to_scala() for p_group in property_groups],
)

def get_adj_list_type(self) -> AdjListType:
"""Get adj list type.
Expand All @@ -658,7 +635,7 @@ def from_scala(
:param jvm_obj: scala object in JVM.
:returns: instance of Python Class.
"""
return AdjList(None, None, None, None, None, jvm_obj)
return AdjList(None, None, None, None, jvm_obj)

@classmethod
def from_python(
Expand All @@ -667,19 +644,17 @@ def from_python(
aligned_by: str,
prefix: str,
file_type: FileType,
property_groups: Sequence[PropertyGroup],
) -> AdjListClassType:
"""Create an instance of the class from python arguments.
:param ordered: ordered flag
:param aligned_by: recommended values are "src" or "dst"
:param prefix: path prefix
:param file_type: file type
:param property_groups: sequence of PropertyGroup objects
"""
if not prefix.endswith(os.sep):
prefix += os.sep
return AdjList(ordered, aligned_by, prefix, file_type, property_groups, None)
return AdjList(ordered, aligned_by, prefix, file_type, None)

def __eq__(self, other: object) -> bool:
if not isinstance(other, AdjList):
Expand All @@ -690,14 +665,6 @@ def __eq__(self, other: object) -> bool:
and (self.get_aligned_by() == other.get_aligned_by())
and (self.get_prefix() == other.get_prefix())
and (self.get_file_type() == other.get_file_type())
and (len(self.get_property_groups()) == len(other.get_property_groups()))
and all(
left_pg == right_pg
for left_pg, right_pg in zip(
self.get_property_groups(),
other.get_property_groups(),
)
)
)


Expand All @@ -719,6 +686,7 @@ def __init__(
directed: Optional[bool],
prefix: Optional[str],
adj_lists: Sequence[AdjList],
property_groups: Optional[Sequence[PropertyGroup]],
version: Optional[str],
jvm_edge_info_obj: JavaObject,
) -> None:
Expand All @@ -739,6 +707,9 @@ def __init__(
edge_info.setAdj_lists(
[py_adj_list.to_scala() for py_adj_list in adj_lists],
)
edge_info.setProperty_groups(
[py_property_group.to_scala() for py_property_group in property_groups],
)
edge_info.setVersion(version)
self._jvm_edge_info_obj = edge_info

Expand Down Expand Up @@ -873,6 +844,27 @@ def set_adj_lists(self, adj_lists: Sequence[AdjList]) -> None:
[py_adj_list.to_scala() for py_adj_list in adj_lists],
)

def get_property_groups(self) -> Sequence[PropertyGroup]:
"""Get the property groups of adj list type.
WARNING! Exceptions from the JVM are not checked inside, it is just a proxy-method!
:returns: property groups of edge info.
"""
return [
PropertyGroup.from_scala(jvm_property_group)
for jvm_property_group in self._jvm_edge_info_obj.getProperty_groups()
]

def set_property_groups(self, property_groups: Sequence[PropertyGroup]) -> None:
"""Mutate the corresponding JVM object.
:param property_groups: the new property groups, sequence of PropertyGroup
"""
self._jvm_edge_info_obj.setProperty_groups(
[py_property_group.to_scala() for py_property_group in property_groups],
)

def get_version(self) -> str:
"""Get GAR version from the corresponding JVM object.
Expand Down Expand Up @@ -912,6 +904,7 @@ def from_scala(cls: type[EdgeInfoType], jvm_obj: JavaObject) -> EdgeInfoType:
None,
None,
None,
None,
jvm_obj,
)

Expand All @@ -927,6 +920,7 @@ def from_python(
directed: bool,
prefix: str,
adj_lists: Sequence[AdjList],
property_groups: Sequence[PropertyGroup],
version: str,
) -> EdgeInfoType:
"""Create an instance of the class from python arguments.
Expand All @@ -940,6 +934,7 @@ def from_python(
:param directed: directed graph flag
:param prefix: path prefix
:param adj_lists: sequence of AdjList objects
:param property_groups: sequence of PropertyGroup objects
:param version: version of GAR format
"""
if not prefix.endswith(os.sep):
Expand All @@ -955,6 +950,7 @@ def from_python(
directed,
prefix,
adj_lists,
property_groups,
version,
None,
)
Expand Down Expand Up @@ -990,41 +986,18 @@ def get_adj_list_file_type(self, adj_list_type: AdjListType) -> FileType:
self._jvm_edge_info_obj.getAdjListFileType(adj_list_type.to_scala()),
)

def get_property_groups(
self,
adj_list_type: AdjListType,
) -> Sequence[PropertyGroup]:
"""Get the property groups of adj list type.
WARNING! Exceptions from the JVM are not checked inside, it is just a proxy-method!
:param adj_list_type: the input adj list type.
:returns: property group of the input adj list type, if edge info not support the adj list type,
raise an IllegalArgumentException error.
"""
return [
PropertyGroup.from_scala(property_group)
for property_group in self._jvm_edge_info_obj.getPropertyGroups(
adj_list_type.to_scala(),
)
]

def contain_property_group(
self,
property_group: PropertyGroup,
adj_list_type: AdjListType,
) -> bool:
"""Check if the edge info contains the property group in certain adj list structure.
"""Check if the edge info contains the property group.
:param property_group: the property group to check.
:param adj_list_type: the type of adj list structure.
:returns: true if the edge info contains the property group in certain adj list
structure. If edge info not support the given adj list type or not
contains the property group in the adj list structure, return false.
structure.
"""
return self._jvm_edge_info_obj.containPropertyGroup(
property_group.to_scala(),
adj_list_type.to_scala(),
)

def contain_property(self, property_name: str) -> bool:
Expand All @@ -1038,23 +1011,17 @@ def contain_property(self, property_name: str) -> bool:
def get_property_group(
self,
property_name: str,
adj_list_type: AdjListType,
) -> PropertyGroup:
"""Get property group that contains property with adj list type.
WARNING! Exceptions from the JVM are not checked inside, it is just a proxy-method!
:param property_name: name of the property.
:param adj_list_type: the type of adj list structure.
:returns: property group that contains the property. If edge info not support the
adj list type, or not find the property group that contains the property,
return false.
:returns: property group that contains the property. If the edge info does not
contain a property group with the property, an error is raised.
"""
return PropertyGroup.from_scala(
self._jvm_edge_info_obj.getPropertyGroup(
property_name,
adj_list_type.to_scala(),
),
self._jvm_edge_info_obj.getPropertyGroup(property_name),
)

def get_property_type(self, property_name: str) -> GarType:
Expand Down
64 changes: 16 additions & 48 deletions pyspark/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,11 @@ def test_property_group(spark):
def test_adj_list(spark):
initialize(spark)

props_list_1 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
]

props_list_2 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property("another_one", GarType.LIST, False),
]

adj_list_from_py = AdjList.from_python(
True,
"dest",
"prefix",
FileType.PARQUET,
[
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1),
PropertyGroup.from_python("prefix2", FileType.ORC, props_list_2),
],
)

assert adj_list_from_py == AdjList.from_scala(adj_list_from_py.to_scala())
Expand All @@ -125,28 +110,6 @@ def test_adj_list(spark):
adj_list_from_py.set_file_type(FileType.CSV)
assert adj_list_from_py.get_file_type() == FileType.CSV

adj_list_from_py.set_property_groups(
adj_list_from_py.get_property_groups()
+ [
PropertyGroup.from_python(
"prefix3", FileType.CSV, props_list_1 + props_list_2
)
]
)
assert all(
pg_left == pg_right
for pg_left, pg_right in zip(
adj_list_from_py.get_property_groups(),
[
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1),
PropertyGroup.from_python("prefix2", FileType.ORC, props_list_2),
PropertyGroup.from_python(
"prefix3", FileType.CSV, props_list_1 + props_list_2
),
],
)
)


def test_vertex_info(spark):
initialize(spark)
Expand Down Expand Up @@ -293,6 +256,7 @@ def test_edge_info(spark):
directed=True,
prefix="prefix",
adj_lists=[],
property_groups=[],
version="v1",
)

Expand Down Expand Up @@ -335,15 +299,19 @@ def test_edge_info(spark):
"dest",
"prefix",
FileType.PARQUET,
[
PropertyGroup.from_python(
"prefix1", FileType.PARQUET, props_list_1
),
],
)
]
)
py_edge_info.set_property_groups(
[
PropertyGroup.from_python(
"prefix1", FileType.PARQUET, props_list_1
),
],
)

assert len(py_edge_info.get_adj_lists()) == 1
assert len(py_edge_info.get_property_groups()) == 1

# Load from YAML
person_knows_person_info = EdgeInfo.load_edge_info(
Expand Down Expand Up @@ -387,12 +355,11 @@ def test_edge_info(spark):
!= 0
)
assert (
len(person_knows_person_info.get_property_groups(AdjListType.ORDERED_BY_SOURCE))
len(person_knows_person_info.get_property_groups())
== 1
)
assert person_knows_person_info.contain_property_group(
person_knows_person_info.get_property_groups(AdjListType.UNORDERED_BY_DEST)[0],
AdjListType.UNORDERED_BY_DEST,
person_knows_person_info.get_property_groups()[0],
)
assert person_knows_person_info.get_property_type("creationDate") == GarType.STRING
assert person_knows_person_info.is_primary_key("creationDate") == False
Expand Down Expand Up @@ -443,7 +410,7 @@ def test_edge_info(spark):
assert (
person_knows_person_info.get_property_file_path(
person_knows_person_info.get_property_group(
"creationDate", AdjListType.ORDERED_BY_SOURCE
"creationDate",
),
AdjListType.ORDERED_BY_SOURCE,
0,
Expand All @@ -454,7 +421,7 @@ def test_edge_info(spark):
assert (
person_knows_person_info.get_property_group_path_prefix(
person_knows_person_info.get_property_group(
"creationDate", AdjListType.ORDERED_BY_SOURCE
"creationDate",
),
AdjListType.ORDERED_BY_SOURCE,
0,
Expand All @@ -463,7 +430,7 @@ def test_edge_info(spark):
assert (
person_knows_person_info.get_property_group_path_prefix(
person_knows_person_info.get_property_group(
"creationDate", AdjListType.ORDERED_BY_SOURCE
"creationDate",
),
AdjListType.ORDERED_BY_SOURCE,
None,
Expand Down Expand Up @@ -536,6 +503,7 @@ def test_graph_info(spark):
True,
"prefix",
[],
[],
"v1",
)
)
Expand Down
4 changes: 2 additions & 2 deletions pyspark/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def test_edge_reader(spark):
assert (
"_graphArEdgeIndex"
in edge_reader.read_edge_property_group(
edge_info.get_property_group("weight", AdjListType.ORDERED_BY_SOURCE)
edge_info.get_property_group("weight")
).columns
)
assert (
edge_reader.read_edge_property_group(
edge_info.get_property_group("weight", AdjListType.ORDERED_BY_SOURCE)
edge_info.get_property_group("weight")
).count()
> 0
)
Expand Down
Loading

0 comments on commit 000e27b

Please sign in to comment.