[Feat][Spark] Update PySpark bindings following GraphAr Spark #374

Merged
merged 1 commit into from Feb 23, 2024
40 changes: 38 additions & 2 deletions pyspark/graphar_pyspark/info.py
@@ -41,6 +41,7 @@ def __init__(
name: Optional[str],
data_type: Optional[GarType],
is_primary: Optional[bool],
is_nullable: Optional[bool],
jvm_obj: Optional[JavaObject] = None,
) -> None:
"""One should not use this constructor directly, please use `from_scala` or `from_python`."""
@@ -52,6 +53,7 @@ def __init__(
property_pyobj.setName(name)
property_pyobj.setData_type(data_type.value)
property_pyobj.setIs_primary(is_primary)
property_pyobj.setIs_nullable(is_nullable)

self._jvm_property_obj = property_pyobj

@@ -97,6 +99,20 @@ def set_is_primary(self, is_primary: bool) -> None:
"""
self._jvm_property_obj.setIs_primary(is_primary)

def set_is_nullable(self, is_nullable: bool) -> None:
"""Mutate corresponding JVM object.

:param is_nullable: is nullable
"""
self._jvm_property_obj.setIs_nullable(is_nullable)

def get_is_nullable(self) -> bool:
"""Get is nullable flag from corresponding JVM object.

:returns: is nullable
"""
return self._jvm_property_obj.getIs_nullable()

def to_scala(self) -> JavaObject:
"""Transform object to JVM representation.

@@ -111,23 +127,25 @@ def from_scala(cls: type[PropertyType], jvm_obj: JavaObject) -> PropertyType:
:param jvm_obj: scala object in JVM.
:returns: instance of Python Class.
"""
return cls(None, None, None, jvm_obj)
return cls(None, None, None, None, jvm_obj)

@classmethod
def from_python(
cls: type[PropertyType],
name: str,
data_type: GarType,
is_primary: bool,
is_nullable: Optional[bool] = None,
) -> PropertyType:
"""Create an instance of the Class from Python arguments.

:param name: property name
:param data_type: property data type
:param is_primary: flag that property is primary
:param is_nullable: flag that property is nullable (optional, default is None)
:returns: instance of Python Class.
"""
return cls(name, data_type, is_primary, None)
return cls(name, data_type, is_primary, is_nullable, None)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Property):
@@ -137,6 +155,7 @@ def __eq__(self, other: object) -> bool:
(self.get_name() == other.get_name())
and (self.get_data_type() == other.get_data_type())
and (self.get_is_primary() == other.get_is_primary())
and (self.get_is_nullable() == other.get_is_nullable())
)


@@ -427,6 +446,14 @@ def get_primary_key(self) -> str:
"""
return self._jvm_vertex_info_obj.getPrimaryKey()

def is_nullable_key(self, property_name: str) -> bool:
"""Check if the property is nullable key.

:param property_name: name of the property to check.
:returns: true if the property if a nullable key of vertex info, otherwise return false
"""
return self._jvm_vertex_info_obj.isNullableKey(property_name)

def is_validated(self) -> bool:
"""Check if the vertex info is validated.

@@ -1047,6 +1074,15 @@ def is_primary_key(self, property_name: str) -> bool:
"""
return self._jvm_edge_info_obj.isPrimaryKey(property_name)

def is_nullable_key(self, property_name: str) -> bool:
"""Check the property is nullable key of edge info.

:param property_name: name of the property.
:returns: true if the property is the nullable key of edge info, false if not. If
edge info not contains the property, raise an IllegalArgumentException error.
"""
return self._jvm_edge_info_obj.isNullableKey(property_name)

def get_primary_key(self) -> str:
"""Get Primary key of edge info.

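For reviewers, a minimal usage sketch of the new `is_nullable` flag on `Property` (not part of the diff); it assumes a plain local SparkSession is enough to back the Py4J bridge and follows the `from_python` signature shown above.

```python
from pyspark.sql import SparkSession

from graphar_pyspark import initialize
from graphar_pyspark.enums import GarType
from graphar_pyspark.info import Property

# Assumption: a local SparkSession is sufficient for the JVM-backed bindings.
spark = SparkSession.builder.master("local[1]").getOrCreate()
initialize(spark)

# from_python now accepts is_nullable as an optional fourth argument.
prop = Property.from_python("age", GarType.INT32, False, True)
assert prop.get_is_nullable()

# The flag can also be mutated on the underlying JVM object.
prop.set_is_nullable(False)
assert not prop.get_is_nullable()

# __eq__ now takes is_nullable into account as well.
assert prop != Property.from_python("age", GarType.INT32, False, True)
```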
58 changes: 27 additions & 31 deletions pyspark/tests/test_info.py
@@ -16,36 +16,33 @@

import pytest
import yaml

from graphar_pyspark import initialize
from graphar_pyspark.enums import AdjListType, FileType, GarType
from graphar_pyspark.info import (
AdjList,
EdgeInfo,
GraphInfo,
Property,
PropertyGroup,
VertexInfo,
)
from graphar_pyspark.info import (AdjList, EdgeInfo, GraphInfo, Property,
PropertyGroup, VertexInfo)
from pyspark.sql.utils import IllegalArgumentException

GRAPHAR_TESTS_EXAMPLES = Path(__file__).parent.parent.parent.joinpath("testing")


def test_property(spark):
initialize(spark)
property_from_py = Property.from_python("name", GarType.BOOL, False)
property_from_py = Property.from_python("name", GarType.BOOL, False, False)

assert property_from_py == Property.from_scala(property_from_py.to_scala())
assert property_from_py != 0
assert property_from_py == Property.from_python("name", GarType.BOOL, False)
assert property_from_py == Property.from_python("name", GarType.BOOL, False, False)

property_from_py.set_name("new_name")
property_from_py.set_data_type(GarType.INT32)
property_from_py.set_is_primary(True)
property_from_py.set_is_nullable(True)

assert property_from_py.get_name() == "new_name"
assert property_from_py.get_data_type() == GarType.INT32
assert property_from_py.get_is_primary() == True
assert property_from_py.get_is_nullable() == True


def test_property_group(spark):
@@ -54,8 +51,8 @@ def test_property_group(spark):
"prefix",
FileType.CSV,
[
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
],
)

@@ -66,7 +63,7 @@ def test_property_group(spark):
p_group_from_py.set_file_type(FileType.ORC)
p_group_from_py.set_properties(
p_group_from_py.get_properties()
+ [Property("another_one", GarType.LIST, False)]
+ [Property("another_one", GarType.LIST, False, False)]
)

assert p_group_from_py.get_prefix() == "new_prefix"
@@ -76,9 +73,9 @@
for p_left, p_right in zip(
p_group_from_py.get_properties(),
[
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property("another_one", GarType.LIST, False),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
Property("another_one", GarType.LIST, False, False),
],
)
)
@@ -115,14 +112,14 @@ def test_vertex_info(spark):
initialize(spark)

props_list_1 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
]

props_list_2 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property("another_one", GarType.LIST, False),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
Property("another_one", GarType.LIST, False, False),
]

vertex_info_from_py = VertexInfo.from_python(
@@ -136,6 +133,8 @@
"1",
)

assert vertex_info_from_py.is_nullable_key("non_primary") == False

assert vertex_info_from_py.contain_property_group(
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1)
)
@@ -289,8 +288,8 @@ def test_edge_info(spark):
assert py_edge_info.get_version() == "v2"

props_list_1 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
]
py_edge_info.set_adj_lists(
[
@@ -304,12 +303,12 @@
)
py_edge_info.set_property_groups(
[
PropertyGroup.from_python(
"prefix1", FileType.PARQUET, props_list_1
),
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1),
],
)


assert py_edge_info.is_nullable_key("non_primary") == False

assert len(py_edge_info.get_adj_lists()) == 1
assert len(py_edge_info.get_property_groups()) == 1

@@ -354,10 +353,7 @@ def test_edge_info(spark):
person_knows_person_info.get_adj_list_file_type(AdjListType.ORDERED_BY_SOURCE)
!= 0
)
assert (
len(person_knows_person_info.get_property_groups())
== 1
)
assert len(person_knows_person_info.get_property_groups()) == 1
assert person_knows_person_info.contain_property_group(
person_knows_person_info.get_property_groups()[0],
)
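
Similarly, a hedged sketch of the new `is_nullable_key` checks; the helper name and the idea of passing in already-built `VertexInfo`/`EdgeInfo` objects are illustrative assumptions, and the exception handling mirrors the docstring added in info.py.

```python
from pyspark.sql.utils import IllegalArgumentException

from graphar_pyspark.info import EdgeInfo, VertexInfo


def describe_nullable(vertex_info: VertexInfo, edge_info: EdgeInfo, name: str) -> None:
    """Hypothetical helper (not part of this diff): report nullable-key status."""
    print(name, "nullable on vertex:", vertex_info.is_nullable_key(name))
    try:
        print(name, "nullable on edge:", edge_info.is_nullable_key(name))
    except IllegalArgumentException:
        # Per the docstring added above, querying a property the edge info
        # does not contain raises an IllegalArgumentException on the JVM side.
        print(name, "is not a property of this edge info")
```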