Skip to content

Commit

Permalink
[FEAT][PySpark] Update PySpark bindings following GraphAr Spark (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
SemyonSinchenko authored Feb 23, 2024
1 parent a0fdabe commit bb331db
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 33 deletions.
40 changes: 38 additions & 2 deletions pyspark/graphar_pyspark/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
name: Optional[str],
data_type: Optional[GarType],
is_primary: Optional[bool],
is_nullable: Optional[bool],
jvm_obj: Optional[JavaObject] = None,
) -> None:
"""One should not use this constructor directly, please use `from_scala` or `from_python`."""
Expand All @@ -52,6 +53,7 @@ def __init__(
property_pyobj.setName(name)
property_pyobj.setData_type(data_type.value)
property_pyobj.setIs_primary(is_primary)
property_pyobj.setIs_nullable(is_nullable)

self._jvm_property_obj = property_pyobj

Expand Down Expand Up @@ -97,6 +99,20 @@ def set_is_primary(self, is_primary: bool) -> None:
"""
self._jvm_property_obj.setIs_primary(is_primary)

def set_is_nullable(self, is_nullable: bool) -> None:
    """Set the is-nullable flag on the underlying JVM property object.

    :param is_nullable: new value of the nullable flag
    """
    jvm_property = self._jvm_property_obj
    jvm_property.setIs_nullable(is_nullable)

def get_is_nullable(self) -> bool:
    """Read the is-nullable flag from the underlying JVM property object.

    :returns: the nullable flag as reported by the JVM object
    """
    jvm_property = self._jvm_property_obj
    return jvm_property.getIs_nullable()

def to_scala(self) -> JavaObject:
"""Transform object to JVM representation.
Expand All @@ -111,23 +127,25 @@ def from_scala(cls: type[PropertyType], jvm_obj: JavaObject) -> PropertyType:
:param jvm_obj: scala object in JVM.
:returns: instance of Python Class.
"""
return cls(None, None, None, jvm_obj)
return cls(None, None, None, None, jvm_obj)

@classmethod
def from_python(
cls: type[PropertyType],
name: str,
data_type: GarType,
is_primary: bool,
is_nullable: Optional[bool] = None,
) -> PropertyType:
"""Create an instance of the Class from Python arguments.
:param name: property name
:param data_type: property data type
:param is_primary: flag that property is primary
:param is_nullable: flag that property is nullable (optional, default is None)
:returns: instance of Python Class.
"""
return cls(name, data_type, is_primary, None)
return cls(name, data_type, is_primary, is_nullable, None)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Property):
Expand All @@ -137,6 +155,7 @@ def __eq__(self, other: object) -> bool:
(self.get_name() == other.get_name())
and (self.get_data_type() == other.get_data_type())
and (self.get_is_primary() == other.get_is_primary())
and (self.get_is_nullable() == other.get_is_nullable())
)


Expand Down Expand Up @@ -427,6 +446,14 @@ def get_primary_key(self) -> str:
"""
return self._jvm_vertex_info_obj.getPrimaryKey()

def is_nullable_key(self, property_name: str) -> bool:
    """Check whether the given property is a nullable key of the vertex info.

    Delegates the check to the corresponding JVM ``VertexInfo`` object.

    :param property_name: name of the property to check.
    :returns: True if the property is a nullable key of the vertex info,
        otherwise False.
    """
    jvm_vertex_info = self._jvm_vertex_info_obj
    return jvm_vertex_info.isNullableKey(property_name)

def is_validated(self) -> bool:
"""Check if the vertex info is validated.
Expand Down Expand Up @@ -1047,6 +1074,15 @@ def is_primary_key(self, property_name: str) -> bool:
"""
return self._jvm_edge_info_obj.isPrimaryKey(property_name)

def is_nullable_key(self, property_name: str) -> bool:
    """Check whether the given property is a nullable key of the edge info.

    Delegates the check to the corresponding JVM ``EdgeInfo`` object; the
    JVM side raises an ``IllegalArgumentException`` when the edge info does
    not contain the property.

    :param property_name: name of the property to check.
    :returns: True if the property is a nullable key of the edge info,
        otherwise False.
    """
    jvm_edge_info = self._jvm_edge_info_obj
    return jvm_edge_info.isNullableKey(property_name)

def get_primary_key(self) -> str:
"""Get Primary key of edge info.
Expand Down
58 changes: 27 additions & 31 deletions pyspark/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,33 @@

import pytest
import yaml

from graphar_pyspark import initialize
from graphar_pyspark.enums import AdjListType, FileType, GarType
from graphar_pyspark.info import (
AdjList,
EdgeInfo,
GraphInfo,
Property,
PropertyGroup,
VertexInfo,
)
from graphar_pyspark.info import (AdjList, EdgeInfo, GraphInfo, Property,
PropertyGroup, VertexInfo)
from pyspark.sql.utils import IllegalArgumentException

GRAPHAR_TESTS_EXAMPLES = Path(__file__).parent.parent.parent.joinpath("testing")


def test_property(spark):
    """Round-trip and mutator checks for the Property PySpark bindings.

    Covers Python -> Scala -> Python conversion, equality, and every
    setter/getter pair, including the new is_nullable flag.
    """
    initialize(spark)
    property_from_py = Property.from_python("name", GarType.BOOL, False, False)

    # Converting to the JVM representation and back must preserve equality.
    assert property_from_py == Property.from_scala(property_from_py.to_scala())
    assert property_from_py != 0
    assert property_from_py == Property.from_python("name", GarType.BOOL, False, False)

    property_from_py.set_name("new_name")
    property_from_py.set_data_type(GarType.INT32)
    property_from_py.set_is_primary(True)
    property_from_py.set_is_nullable(True)

    assert property_from_py.get_name() == "new_name"
    assert property_from_py.get_data_type() == GarType.INT32
    # `is True` instead of `== True`: idiomatic singleton comparison (flake8 E712).
    assert property_from_py.get_is_primary() is True
    assert property_from_py.get_is_nullable() is True


def test_property_group(spark):
Expand All @@ -54,8 +51,8 @@ def test_property_group(spark):
"prefix",
FileType.CSV,
[
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
],
)

Expand All @@ -66,7 +63,7 @@ def test_property_group(spark):
p_group_from_py.set_file_type(FileType.ORC)
p_group_from_py.set_properties(
p_group_from_py.get_properties()
+ [Property("another_one", GarType.LIST, False)]
+ [Property("another_one", GarType.LIST, False, False)]
)

assert p_group_from_py.get_prefix() == "new_prefix"
Expand All @@ -76,9 +73,9 @@ def test_property_group(spark):
for p_left, p_right in zip(
p_group_from_py.get_properties(),
[
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property("another_one", GarType.LIST, False),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
Property("another_one", GarType.LIST, False, False),
],
)
)
Expand Down Expand Up @@ -115,14 +112,14 @@ def test_vertex_info(spark):
initialize(spark)

props_list_1 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
]

props_list_2 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property("another_one", GarType.LIST, False),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
Property("another_one", GarType.LIST, False, False),
]

vertex_info_from_py = VertexInfo.from_python(
Expand All @@ -136,6 +133,8 @@ def test_vertex_info(spark):
"1",
)

assert vertex_info_from_py.is_nullable_key("non_primary") == False

assert vertex_info_from_py.contain_property_group(
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1)
)
Expand Down Expand Up @@ -289,8 +288,8 @@ def test_edge_info(spark):
assert py_edge_info.get_version() == "v2"

props_list_1 = [
Property.from_python("non_primary", GarType.DOUBLE, False),
Property.from_python("primary", GarType.INT64, True),
Property.from_python("non_primary", GarType.DOUBLE, False, False),
Property.from_python("primary", GarType.INT64, True, False),
]
py_edge_info.set_adj_lists(
[
Expand All @@ -304,12 +303,12 @@ def test_edge_info(spark):
)
py_edge_info.set_property_groups(
[
PropertyGroup.from_python(
"prefix1", FileType.PARQUET, props_list_1
),
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1),
],
)


assert py_edge_info.is_nullable_key("non_primary") == False

assert len(py_edge_info.get_adj_lists()) == 1
assert len(py_edge_info.get_property_groups()) == 1

Expand Down Expand Up @@ -354,10 +353,7 @@ def test_edge_info(spark):
person_knows_person_info.get_adj_list_file_type(AdjListType.ORDERED_BY_SOURCE)
!= 0
)
assert (
len(person_knows_person_info.get_property_groups())
== 1
)
assert len(person_knows_person_info.get_property_groups()) == 1
assert person_knows_person_info.contain_property_group(
person_knows_person_info.get_property_groups()[0],
)
Expand Down

0 comments on commit bb331db

Please sign in to comment.