From bb331dbb96e83089cac941f1eb0f51aedc72d89c Mon Sep 17 00:00:00 2001 From: Semyon Date: Sat, 24 Feb 2024 00:17:27 +0100 Subject: [PATCH] [FEAT][PySpark] Update PySpark bindings following GraphAr Spark (#374) --- pyspark/graphar_pyspark/info.py | 40 +++++++++++++++++++++-- pyspark/tests/test_info.py | 58 +++++++++++++++------------------ 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/pyspark/graphar_pyspark/info.py b/pyspark/graphar_pyspark/info.py index d20e5a4da..9ea156b9e 100644 --- a/pyspark/graphar_pyspark/info.py +++ b/pyspark/graphar_pyspark/info.py @@ -41,6 +41,7 @@ def __init__( name: Optional[str], data_type: Optional[GarType], is_primary: Optional[bool], + is_nullable: Optional[bool], jvm_obj: Optional[JavaObject] = None, ) -> None: """One should not use this constructor directly, please use `from_scala` or `from_python`.""" @@ -52,6 +53,7 @@ def __init__( property_pyobj.setName(name) property_pyobj.setData_type(data_type.value) property_pyobj.setIs_primary(is_primary) + property_pyobj.setIs_nullable(is_nullable) self._jvm_property_obj = property_pyobj @@ -97,6 +99,20 @@ def set_is_primary(self, is_primary: bool) -> None: """ self._jvm_property_obj.setIs_primary(is_primary) + def set_is_nullable(self, is_nullable: bool) -> None: + """Mutate corresponding JVM object. + + :param is_nullable: is nullable + """ + self._jvm_property_obj.setIs_nullable(is_nullable) + + def get_is_nullable(self) -> bool: + """Get is nullable flag from corresponding JVM object. + + :returns: is nullable + """ + return self._jvm_property_obj.getIs_nullable() + def to_scala(self) -> JavaObject: """Transform object to JVM representation. @@ -111,7 +127,7 @@ def from_scala(cls: type[PropertyType], jvm_obj: JavaObject) -> PropertyType: :param jvm_obj: scala object in JVM. :returns: instance of Python Class. """ - return cls(None, None, None, jvm_obj) + return cls(None, None, None, None, jvm_obj) @classmethod def from_python( @@ -119,15 +135,17 @@ def from_python( name: str, data_type: GarType, is_primary: bool, + is_nullable: Optional[bool] = None, ) -> PropertyType: """Create an instance of the Class from Python arguments. :param name: property name :param data_type: property data type :param is_primary: flag that property is primary + :param is_nullable: flag that property is nullable (optional, default is None) :returns: instance of Python Class. """ - return cls(name, data_type, is_primary, None) + return cls(name, data_type, is_primary, is_nullable, None) def __eq__(self, other: object) -> bool: if not isinstance(other, Property): @@ -137,6 +155,7 @@ def __eq__(self, other: object) -> bool: (self.get_name() == other.get_name()) and (self.get_data_type() == other.get_data_type()) and (self.get_is_primary() == other.get_is_primary()) + and (self.get_is_nullable() == other.get_is_nullable()) ) @@ -427,6 +446,14 @@ def get_primary_key(self) -> str: """ return self._jvm_vertex_info_obj.getPrimaryKey() + def is_nullable_key(self, property_name: str) -> bool: + """Check if the property is nullable key. + + :param property_name: name of the property to check. + :returns: true if the property if a nullable key of vertex info, otherwise return false + """ + return self._jvm_vertex_info_obj.isNullableKey(property_name) + def is_validated(self) -> bool: """Check if the vertex info is validated. @@ -1047,6 +1074,15 @@ def is_primary_key(self, property_name: str) -> bool: """ return self._jvm_edge_info_obj.isPrimaryKey(property_name) + def is_nullable_key(self, property_name: str) -> bool: + """Check the property is nullable key of edge info. + + :param property_name: name of the property. + :returns: true if the property is the nullable key of edge info, false if not. If + edge info not contains the property, raise an IllegalArgumentException error. + """ + return self._jvm_edge_info_obj.isNullableKey(property_name) + def get_primary_key(self) -> str: """Get Primary key of edge info. diff --git a/pyspark/tests/test_info.py b/pyspark/tests/test_info.py index c62136e48..1ec25f53b 100644 --- a/pyspark/tests/test_info.py +++ b/pyspark/tests/test_info.py @@ -16,16 +16,11 @@ import pytest import yaml + from graphar_pyspark import initialize from graphar_pyspark.enums import AdjListType, FileType, GarType -from graphar_pyspark.info import ( - AdjList, - EdgeInfo, - GraphInfo, - Property, - PropertyGroup, - VertexInfo, -) +from graphar_pyspark.info import (AdjList, EdgeInfo, GraphInfo, Property, + PropertyGroup, VertexInfo) from pyspark.sql.utils import IllegalArgumentException GRAPHAR_TESTS_EXAMPLES = Path(__file__).parent.parent.parent.joinpath("testing") @@ -33,19 +28,21 @@ def test_property(spark): initialize(spark) - property_from_py = Property.from_python("name", GarType.BOOL, False) + property_from_py = Property.from_python("name", GarType.BOOL, False, False) assert property_from_py == Property.from_scala(property_from_py.to_scala()) assert property_from_py != 0 - assert property_from_py == Property.from_python("name", GarType.BOOL, False) + assert property_from_py == Property.from_python("name", GarType.BOOL, False, False) property_from_py.set_name("new_name") property_from_py.set_data_type(GarType.INT32) property_from_py.set_is_primary(True) + property_from_py.set_is_nullable(True) assert property_from_py.get_name() == "new_name" assert property_from_py.get_data_type() == GarType.INT32 assert property_from_py.get_is_primary() == True + assert property_from_py.get_is_nullable() == True def test_property_group(spark): @@ -54,8 +51,8 @@ def test_property_group(spark): "prefix", FileType.CSV, [ - Property.from_python("non_primary", GarType.DOUBLE, False), - Property.from_python("primary", GarType.INT64, True), + Property.from_python("non_primary", GarType.DOUBLE, False, False), + Property.from_python("primary", GarType.INT64, True, False), ], ) @@ -66,7 +63,7 @@ def test_property_group(spark): p_group_from_py.set_file_type(FileType.ORC) p_group_from_py.set_properties( p_group_from_py.get_properties() - + [Property("another_one", GarType.LIST, False)] + + [Property("another_one", GarType.LIST, False, False)] ) assert p_group_from_py.get_prefix() == "new_prefix" @@ -76,9 +73,9 @@ def test_property_group(spark): for p_left, p_right in zip( p_group_from_py.get_properties(), [ - Property.from_python("non_primary", GarType.DOUBLE, False), - Property.from_python("primary", GarType.INT64, True), - Property("another_one", GarType.LIST, False), + Property.from_python("non_primary", GarType.DOUBLE, False, False), + Property.from_python("primary", GarType.INT64, True, False), + Property("another_one", GarType.LIST, False, False), ], ) ) @@ -115,14 +112,14 @@ def test_vertex_info(spark): initialize(spark) props_list_1 = [ - Property.from_python("non_primary", GarType.DOUBLE, False), - Property.from_python("primary", GarType.INT64, True), + Property.from_python("non_primary", GarType.DOUBLE, False, False), + Property.from_python("primary", GarType.INT64, True, False), ] props_list_2 = [ - Property.from_python("non_primary", GarType.DOUBLE, False), - Property.from_python("primary", GarType.INT64, True), - Property("another_one", GarType.LIST, False), + Property.from_python("non_primary", GarType.DOUBLE, False, False), + Property.from_python("primary", GarType.INT64, True, False), + Property("another_one", GarType.LIST, False, False), ] vertex_info_from_py = VertexInfo.from_python( @@ -136,6 +133,8 @@ def test_vertex_info(spark): "1", ) + assert vertex_info_from_py.is_nullable_key("non_primary") == False + assert vertex_info_from_py.contain_property_group( PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1) ) @@ -289,8 +288,8 @@ def test_edge_info(spark): assert py_edge_info.get_version() == "v2" props_list_1 = [ - Property.from_python("non_primary", GarType.DOUBLE, False), - Property.from_python("primary", GarType.INT64, True), + Property.from_python("non_primary", GarType.DOUBLE, False, False), + Property.from_python("primary", GarType.INT64, True, False), ] py_edge_info.set_adj_lists( [ @@ -304,12 +303,12 @@ def test_edge_info(spark): ) py_edge_info.set_property_groups( [ - PropertyGroup.from_python( - "prefix1", FileType.PARQUET, props_list_1 - ), + PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1), ], ) - + + assert py_edge_info.is_nullable_key("non_primary") == False + assert len(py_edge_info.get_adj_lists()) == 1 assert len(py_edge_info.get_property_groups()) == 1 @@ -354,10 +353,7 @@ def test_edge_info(spark): person_knows_person_info.get_adj_list_file_type(AdjListType.ORDERED_BY_SOURCE) != 0 ) - assert ( - len(person_knows_person_info.get_property_groups()) - == 1 - ) + assert len(person_knows_person_info.get_property_groups()) == 1 assert person_knows_person_info.contain_property_group( person_knows_person_info.get_property_groups()[0], )