diff --git a/README.md b/README.md
index 3c2a46e..237c5e3 100644
--- a/README.md
+++ b/README.md
@@ -22,11 +22,13 @@ or Velox.
 ### Locally
 
 To run the gateway locally - you need to setup a Python (Conda) environment.
+To run the Spark tests you will need Java installed.
+
 Ensure you have [Miniconda](https://docs.anaconda.com/miniconda/miniconda-install/) and [Rust/Cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html) installed.
 Once that is done - run these steps from a bash terminal:
 
 ```bash
-git clone --recursive https://github.com//spark-substrait-gateway.git
+git clone https://github.com//spark-substrait-gateway.git
 cd spark-substrait-gateway
 conda init bash
 . ~/.bashrc
diff --git a/environment.yml b/environment.yml
index a6ac78d..f85638d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -11,7 +11,7 @@ dependencies:
   - setuptools >= 61.0.0
   - setuptools_scm >= 6.2.0
   - mypy-protobuf
-  - types-protobuf >= 4.25.0, < 5.0.0
+  - types-protobuf >= 5.0.0
   - numpy < 2.0.0
   - Faker
   - pip:
@@ -27,7 +27,7 @@ dependencies:
   - substrait == 0.21.0
   - substrait-validator
   - pytest-timeout
-  - protobuf >= 4.25.3, < 5.0.0
+  - protobuf >= 5.0.0
   - cryptography == 43.0.*
   - click == 8.1.*
   - pyjwt == 2.8.*
diff --git a/src/backends/arrow_tools.py b/src/backends/arrow_tools.py
new file mode 100644
index 0000000..39dc93a
--- /dev/null
+++ b/src/backends/arrow_tools.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Routines to manipulate arrow tables."""
+import pyarrow as pa
+
+
+def _reapply_names_to_type(array: pa.ChunkedArray, names: list[str]) -> tuple[pa.Array, list[str]]:
+    new_arrays = []
+    new_schema = []
+
+    if array.type.num_fields > len(names):
+        raise ValueError('Insufficient number of names provided to reapply names.')
+
+    remaining_names = names
+    if pa.types.is_list(array.type):
+        raise NotImplementedError('Reapplying names to lists not yet supported')
+    if pa.types.is_map(array.type):
+        raise NotImplementedError('Reapplying names to maps not yet supported')
+    if pa.types.is_struct(array.type):
+        field_num = 0
+        while field_num < array.type.num_fields:
+            field = array.chunks[0].field(field_num)
+            this_name = remaining_names.pop(0)
+
+            new_array, remaining_names = _reapply_names_to_type(field, remaining_names)
+            new_arrays.append(new_array)
+
+            new_schema.append(pa.field(this_name, new_array.type))
+
+            field_num += 1
+
+        return pa.StructArray.from_arrays(new_arrays, fields=new_schema), remaining_names
+    if array.type.num_fields != 0:
+        raise ValueError(f'Unsupported complex type: {array.type}')
+    return array, remaining_names
+
+
+def reapply_names(table: pa.Table, names: list[str]) -> pa.Table:
+    """Apply the provided names to the given table recursively."""
+    new_arrays = []
+    new_schema = []
+
+    remaining_names = names
+    for column in iter(table.columns):
+        if not remaining_names:
+            raise ValueError('Insufficient number of names provided to reapply names.')
+
+        this_name = remaining_names.pop(0)
+
+        new_array, remaining_names = _reapply_names_to_type(column, remaining_names)
+        new_arrays.append(new_array)
+
+        new_schema.append(pa.field(this_name, new_array.type))
+
+    if remaining_names:
+        raise ValueError('Too many names provided to reapply names.')
+
+    return pa.Table.from_arrays(new_arrays, schema=pa.schema(new_schema))
diff --git a/src/backends/duckdb_backend.py b/src/backends/duckdb_backend.py
index eead9b7..576d0bd 100644
--- a/src/backends/duckdb_backend.py
+++ b/src/backends/duckdb_backend.py
@@ -11,6 +11,7 @@
 from substrait.gen.proto import plan_pb2
 
 from backends.backend import Backend
+from src.backends.arrow_tools import reapply_names
 from transforms.rename_functions import RenameFunctionsForDuckDB
 
 
@@ -73,7 +74,8 @@ def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table:
             query_result = self._connection.from_substrait(proto=plan_data)
         except Exception as err:
             raise ValueError(f"DuckDB Execution Error: {err}") from err
-        return query_result.arrow()
+        arrow = query_result.arrow()
+        return reapply_names(arrow, plan.relations[0].root.names)
 
     def register_table(
         self,
diff --git a/src/backends/tests/arrow_tools_test.py b/src/backends/tests/arrow_tools_test.py
new file mode 100644
index 0000000..596822a
--- /dev/null
+++ b/src/backends/tests/arrow_tools_test.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+import pyarrow as pa
+import pytest
+
+from src.backends.arrow_tools import reapply_names
+
+
+@dataclass
+class TestCase:
+    name: str
+    input: pa.Table
+    names: list[str]
+    expected: pa.Table
+    fail: bool = False
+
+
+cases: list[TestCase] = [
+    TestCase('empty table', pa.Table.from_arrays([]), [], pa.Table.from_arrays([])),
+    TestCase('too many names', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([]),
+             fail=True),
+    TestCase('normal columns',
+             pa.Table.from_pydict(
+                 {"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]},
+                 schema=pa.schema({"name": pa.string(), "age": pa.int32()})
+             ),
+             ['renamed_name', 'renamed_age'],
+             pa.Table.from_pydict(
+                 {"renamed_name": [None, "Joe", "Sarah", None],
+                  "renamed_age": [99, None, 42, None]},
+                 schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()})
+             )),
+    TestCase('too few names',
+             pa.Table.from_pydict(
+                 {"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]},
+                 schema=pa.schema({"name": pa.string(), "age": pa.int32()})
+             ),
+             ['renamed_name'],
+             pa.Table.from_pydict(
+                 {"renamed_name": [None, "Joe", "Sarah", None],
+                  "renamed_age": [99, None, 42, None]},
+                 schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()})
+             ),
+             fail=True),
+    TestCase('struct column',
+             pa.Table.from_arrays(
+                 [pa.array([{"": 1, "b": "b"}],
+                           type=pa.struct([("", pa.int64()), ("b", pa.string())]))],
+                 names=["r"]),
+             ['r', 'a', 'b'],
+             pa.Table.from_arrays(
+                 [pa.array([{"a": 1, "b": "b"}],
+                           type=pa.struct([("a", pa.int64()), ("b", pa.string())]))], names=["r"])
+             ),
+    # TODO -- Test nested structs.
+    # TODO -- Test a list.
+    # TODO -- Test a map.
+    # TODO -- Test a mixture of complex and simple types.
+]
+
+
+class TestArrowTools:
+    """Tests the functionality of the arrow tools package."""
+
+    @pytest.mark.parametrize(
+        "case", cases, ids=lambda case: case.name
+    )
+    def test_reapply_names(self, case):
+        failed = False
+        try:
+            result = reapply_names(case.input, case.names)
+        except ValueError as _:
+            result = None
+            failed = True
+        if case.fail:
+            assert failed
+        else:
+            assert result == case.expected
+
diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py
index c4acffd..f42468d 100644
--- a/src/gateway/converter/conversion_options.py
+++ b/src/gateway/converter/conversion_options.py
@@ -11,7 +11,7 @@
 class ConversionOptions:
     """Holds all the possible conversion options."""
 
-    def __init__(self, backend: BackendOptions = None):
+    def __init__(self, backend: BackendOptions):
         """Initialize the conversion options."""
         self.use_named_table_workaround = False
         self.needs_scheme_in_path_uris = False
diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py
index 720336c..f067827 100644
--- a/src/gateway/converter/spark_to_substrait.py
+++ b/src/gateway/converter/spark_to_substrait.py
@@ -57,13 +57,13 @@
     string_type,
     strlen,
 )
-from gateway.converter.symbol_table import SymbolTable
+from gateway.converter.symbol_table import Field, SymbolTable
 
 
 class ExpressionProcessingMode(Enum):
     """The mode of processing expressions."""
 
-    # Processing of an expression outside of an aggregate relation.
+    # Processing of an expression outside an aggregate relation.
     NORMAL = 0
     # Processing of a measure at depth 0.
     AGGR_TOP_LEVEL = 1
@@ -81,6 +81,13 @@ def _extract_decimal_parameters(type_name: str) -> tuple[int, int, int]:
     return int(match.group(1)), int(match.group(2)), int(match.group(3))
 
 
+def _find_reference_number(fields: list[Field], name: str) -> int | None:
+    for reference, field in enumerate(fields):
+        if field.name == name:
+            return reference
+    return None
+
+
 # ruff: noqa: RUF005
 class SparkSubstraitConverter:
     """Converts SparkConnect plans to Substrait plans."""
@@ -101,7 +108,7 @@ def __init__(self, options: ConversionOptions):
         # These are used when processing expressions inside aggregate relations.
self._expression_processing_mode = ExpressionProcessingMode.NORMAL self._top_level_projects: list[algebra_pb2.Rel] = [] - self._next_aggregation_reference_id = None + self._next_aggregation_reference_id = 0 self._aggregations: list[algebra_pb2.AggregateFunction] = [] self._next_under_aggregation_reference_id = 0 self._under_aggregation_projects: list[algebra_pb2.Rel] = [] @@ -114,8 +121,9 @@ def set_backends(self, backend, sql_backend) -> None: def lookup_function_by_name(self, name: str) -> ExtensionFunction: """Find the function reference for a given Spark function name.""" - if name in self._functions: - return self._functions.get(name) + function = self._functions.get(name) + if function is not None: + return function func = lookup_spark_function(name, self._conversion_options) if not func: raise LookupError( @@ -125,7 +133,9 @@ def lookup_function_by_name(self, name: str) -> ExtensionFunction: self._functions[name] = func if not self._function_uris.get(func.uri): self._function_uris[func.uri] = len(self._function_uris) + 1 - return self._functions.get(name) + function = self._functions.get(name) + assert function is not None + return function def update_field_references(self, plan_id: int) -> None: """Use the field references using the specified portion of the plan.""" @@ -133,17 +143,17 @@ def update_field_references(self, plan_id: int) -> None: current_symbol = self._symbol_table.get_symbol(self._current_plan_id) original_output_fields = current_symbol.output_fields for symbol in source_symbol.output_fields: - new_name = symbol - while new_name in original_output_fields: - new_name = new_name + "_dup" - current_symbol.input_fields.append(new_name) - current_symbol.output_fields.append(new_name) + new_symbol = symbol + while _find_reference_number(original_output_fields, new_symbol.name): + new_symbol.name = new_symbol.name + "_dup" + current_symbol.input_fields.append(new_symbol) + current_symbol.output_fields.append(new_symbol) def find_field_by_name(self, field_name: str) -> int | None: """Look up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) try: - return current_symbol.input_fields.index(field_name) + return _find_reference_number(current_symbol.input_fields, field_name) except ValueError: return None @@ -172,7 +182,7 @@ def convert_string_literal(self, s: str) -> algebra_pb2.Expression.Literal: return algebra_pb2.Expression.Literal(string=s) def convert_literal_expression( - self, literal: spark_exprs_pb2.Expression.Literal + self, literal: spark_exprs_pb2.Expression.Literal ) -> algebra_pb2.Expression: """Convert a Spark literal into a Substrait literal.""" match literal.WhichOneof("literal_type"): @@ -251,7 +261,7 @@ def convert_literal_expression( return algebra_pb2.Expression(literal=result) def convert_unresolved_attribute( - self, attr: spark_exprs_pb2.Expression.UnresolvedAttribute + self, attr: spark_exprs_pb2.Expression.UnresolvedAttribute ) -> algebra_pb2.Expression: """Convert a Spark unresolved attribute into a Substrait field reference.""" field_ref = self.find_field_by_name(attr.unparsed_identifier) @@ -308,7 +318,7 @@ def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2 ) def convert_when_function( - self, when: spark_exprs_pb2.Expression.UnresolvedFunction + self, when: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark when function into a Substrait if-then expression.""" ifthen = 
algebra_pb2.Expression.IfThen() @@ -324,6 +334,8 @@ def convert_when_function( else: nullable_literal = self.determine_type_of_expression(ifthen.ifs[-1].then) kind = nullable_literal.WhichOneof("kind") + if not kind: + raise ValueError('Missing kind one_of in when function.') getattr( nullable_literal, kind ).nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE @@ -336,7 +348,7 @@ def convert_when_function( return algebra_pb2.Expression(if_then=ifthen) def convert_in_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark in function into a Substrait switch expression.""" @@ -378,7 +390,7 @@ def is_switch_expression_appropriate() -> bool: return algebra_pb2.Expression(if_then=ifthen) def convert_rlike_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark rlike function into a Substrait expression.""" if self._conversion_options.use_duckdb_regexp_matches_function: @@ -419,7 +431,7 @@ def convert_rlike_function( return greater_function(greater_func, regexp_expr, bigint_literal(0)) def convert_nanvl_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark nanvl function into a Substrait expression.""" isnan_func = self.lookup_function_by_name("isnan") @@ -438,7 +450,7 @@ def convert_nanvl_function( return if_then_else_operation(expr, arg1, arg0) def convert_nvl_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark nvl function into a Substrait expression.""" isnull_func = self.lookup_function_by_name("isnull") @@ -457,7 +469,7 @@ def convert_nvl_function( return if_then_else_operation(expr, arg1, arg0) def convert_nvl2_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark nvl2 function into a Substrait expression.""" isnotnull_func = self.lookup_function_by_name("isnotnull") @@ -477,7 +489,7 @@ def convert_nvl2_function( return if_then_else_operation(expr, arg1, arg2) def convert_ifnull_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark ifnull function into a Substrait expression.""" isnan_func = self.lookup_function_by_name("isnull") @@ -495,8 +507,19 @@ def convert_ifnull_function( return if_then_else_operation(expr, arg1, arg0) + def convert_struct_function( + self, func: spark_exprs_pb2.Expression.UnresolvedFunction + ) -> algebra_pb2.Expression: + """Convert a Spark struct function into a Substrait nested expression.""" + nested = algebra_pb2.Expression.Nested() + nested.nullable = False + for spark_arg in func.arguments: + arg = self.convert_expression(spark_arg) + nested.struct.fields.append(arg) + return algebra_pb2.Expression(nested=nested) + def convert_unresolved_function( - self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction + self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression | algebra_pb2.AggregateFunction: """Convert a Spark unresolved function into a Substrait scalar function.""" 
parent_processing_mode = self._expression_processing_mode @@ -517,11 +540,13 @@ def convert_unresolved_function( return self.convert_nvl2_function(unresolved_function) if unresolved_function.function_name == "ifnull": return self.convert_ifnull_function(unresolved_function) + if unresolved_function.function_name == "struct": + return self.convert_struct_function(unresolved_function) func = algebra_pb2.Expression.ScalarFunction() function_def = self.lookup_function_by_name(unresolved_function.function_name) if ( - parent_processing_mode == ExpressionProcessingMode.AGGR_NOT_TOP_LEVEL - and function_def.function_type == FunctionType.AGGREGATE + parent_processing_mode == ExpressionProcessingMode.AGGR_NOT_TOP_LEVEL + and function_def.function_type == FunctionType.AGGREGATE ): self._expression_processing_mode = ExpressionProcessingMode.AGGR_UNDER_AGGREGATE func.function_reference = function_def.anchor @@ -529,8 +554,8 @@ def convert_unresolved_function( if function_def.max_args is not None and idx >= function_def.max_args: break if ( - unresolved_function.function_name == "count" - and arg.WhichOneof("expr_type") == "unresolved_star" + unresolved_function.function_name == "count" + and arg.WhichOneof("expr_type") == "unresolved_star" ): # Ignore all the rest of the arguments. func.arguments.append(algebra_pb2.FunctionArgument(value=bigint_literal(1))) @@ -576,7 +601,7 @@ def convert_unresolved_function( self._expression_processing_mode = parent_processing_mode def convert_alias_expression( - self, alias: spark_exprs_pb2.Expression.Alias + self, alias: spark_exprs_pb2.Expression.Alias ) -> algebra_pb2.Expression: """Convert a Spark alias into a Substrait expression.""" # We do nothing here and let the magic happen in the calling project relation. @@ -607,7 +632,7 @@ def convert_type(self, spark_type: spark_types_pb2.DataType) -> type_pb2.Type: return self.convert_type_str(spark_type.WhichOneof("kind")) def convert_cast_expression( - self, cast: spark_exprs_pb2.Expression.Cast + self, cast: spark_exprs_pb2.Expression.Cast ) -> algebra_pb2.Expression: """Convert a Spark cast expression into a Substrait cast expression.""" cast_rel = algebra_pb2.Expression.Cast( @@ -739,7 +764,7 @@ def convert_window_expression( return algebra_pb2.Expression(window_function=func) def convert_extract_value( - self, extract: spark_exprs_pb2.Expression.UnresolvedExtractValue + self, extract: spark_exprs_pb2.Expression.UnresolvedExtractValue ) -> algebra_pb2.Expression: """Convert a Spark extract value expression into a Substrait extract value expression.""" # TODO -- Add support for list and map operations. @@ -751,9 +776,16 @@ def convert_extract_value( value=self.convert_unresolved_attribute(extract.child.unresolved_attribute) ) ) - func.arguments.append( - algebra_pb2.FunctionArgument(value=string_literal(extract.extraction.literal.string)) - ) + field_ref = self.find_field_by_name(extract.extraction.literal.string) + if field_ref is None: + func.arguments.append( + algebra_pb2.FunctionArgument(value=string_literal(extract.extraction.literal.string)) + ) + else: + # TODO -- Fix this to be relative to the struct value we found above. 
+ func.arguments.append( + algebra_pb2.FunctionArgument(value=integer_literal(field_ref + 1)) + ) return algebra_pb2.Expression(scalar_function=func) def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Expression | None: @@ -841,11 +873,12 @@ def get_primary_names(self, schema: type_pb2.NamedStruct) -> list[str]: return primary_names def convert_read_named_table_relation( - self, rel: spark_relations_pb2.Read.NamedTable + self, rel: spark_relations_pb2.Read.NamedTable ) -> algebra_pb2.Rel: """Convert a read named table relation to a Substrait relation.""" table_name = rel.unparsed_identifier + assert self._backend is not None arrow_schema = self._backend.describe_table(table_name) schema = self.convert_arrow_schema(arrow_schema) @@ -853,9 +886,10 @@ def convert_read_named_table_relation( symbol = self._symbol_table.get_symbol(self._current_plan_id) if self._conversion_options.use_duckdb_struct_name_behavior: for field_name in self.get_primary_names(schema): - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) else: - symbol.output_fields.extend(schema.names) + for field_name in schema.names: + symbol.output_fields.append(Field(field_name)) return algebra_pb2.Rel( read=algebra_pb2.ReadRel( @@ -892,7 +926,7 @@ def convert_type_name(self, type_name: str) -> type_pb2.Type: case _: raise NotImplementedError(f"Unexpected type name: {type_name}") - def convert_field(self, field: types_pb2.DataType) -> (type_pb2.Type, list[str]): + def convert_field(self, field: types_pb2.DataType) -> tuple[type_pb2.Type, list[str]]: """Convert a Spark field into a Substrait field.""" if field.get("nullable"): nullability = type_pb2.Type.NULLABILITY_NULLABLE @@ -940,6 +974,8 @@ def convert_field(self, field: types_pb2.DataType) -> (type_pb2.Type, list[str]) elif ft.get("type") == "struct": field_type = type_pb2.Type(struct=type_pb2.Type.Struct(nullability=nullability)) sub_type = self.convert_schema_dict(ft) + if not sub_type: + raise ValueError(f'Could not parse type struct type {ft}') more_names.extend(sub_type.names) field_type.struct.types.extend(sub_type.struct.types) else: @@ -970,8 +1006,8 @@ def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: return self.convert_schema_dict(schema_data) def convert_arrow_datatype( - self, arrow_type: pa.DataType, nullable: bool = False - ) -> (type_pb2.Type, list[str]): + self, arrow_type: pa.DataType, nullable: bool = False + ) -> tuple[type_pb2.Type, list[str]]: """Convert an Arrow datatype into a Substrait type.""" if nullable: nullability = type_pb2.Type.NULLABILITY_NULLABLE @@ -1012,7 +1048,7 @@ def convert_arrow_datatype( sub_type = self.convert_arrow_schema(y.type.schema) more_names.extend(y.name) field_type.struct.types.extend(sub_type.struct.types) - return field_type + return field_type, more_names raise NotImplementedError(f"Unexpected arrow datatype: {arrow_type}") return field_type, more_names @@ -1106,7 +1142,7 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) if self._conversion_options.use_named_table_workaround: return algebra_pb2.Rel( read=algebra_pb2.ReadRel( @@ -1304,7 +1340,7 @@ def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate, rel_grouping_expressions = rel.grouping_expressions 
for idx, grouping in enumerate(rel_grouping_expressions): grouping_expression_list.append(self.convert_expression(grouping)) - symbol.generated_fields.append(self.determine_name_for_grouping(grouping)) + symbol.generated_fields.append(Field(self.determine_name_for_grouping(grouping))) self._top_level_projects.append(field_reference(idx)) match rel.group_type: @@ -1332,7 +1368,7 @@ def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate, result = self.convert_expression(expr) if result: self._top_level_projects.append(result) - symbol.generated_fields.append(self.determine_expression_name(expr)) + symbol.generated_fields.append(Field(self.determine_expression_name(expr))) symbol.output_fields.clear() symbol.output_fields.extend(symbol.generated_fields) if len(rel.grouping_expressions) > 1: @@ -1466,18 +1502,19 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a least_function( greater_func, field_reference(column_number), bigint_literal(rel.truncate) ), - strlen(strlen_func, string_literal(symbol.input_fields[column_number])), + strlen(strlen_func, string_literal(symbol.input_fields[column_number].name)), ) for column_number in range(len(symbol.input_fields)) ], ) - def field_header_fragment(field_number: int) -> list[algebra_pb2.Expression]: + def field_header_fragment(fields: list[Field], + field_number: int) -> list[algebra_pb2.Expression]: return [ string_literal("|"), lpad_function( lpad_func, - string_literal(symbol.input_fields[field_number]), + string_literal(fields[field_number].name), field_reference(field_number), ), ] @@ -1488,7 +1525,8 @@ def field_line_fragment(field_number: int) -> list[algebra_pb2.Expression]: repeat_function(repeat_func, "-", field_reference(field_number)), ] - def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: + def field_body_fragment(fields: list[Field], + field_number: int) -> list[algebra_pb2.Expression]: return [ string_literal("|"), if_then_else_operation( @@ -1498,7 +1536,7 @@ def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: strlen_func, cast_operation(field_reference(field_number), string_type()), ), - field_reference(field_number + len(symbol.input_fields)), + field_reference(field_number + len(fields)), ), concat( concat_func, @@ -1508,7 +1546,7 @@ def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: field_reference(field_number), minus_function( minus_func, - field_reference(field_number + len(symbol.input_fields)), + field_reference(field_number + len(fields)), bigint_literal(3), ), ), @@ -1518,17 +1556,18 @@ def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: lpad_function( lpad_func, field_reference(field_number), - field_reference(field_number + len(symbol.input_fields)), + field_reference(field_number + len(fields)), ), ), ] - def header_line(fields: list[str]) -> list[algebra_pb2.Expression]: + def header_line(fields: list[Field]) -> list[algebra_pb2.Expression]: return [ concat( concat_func, flatten( - [field_header_fragment(field_number) for field_number in range(len(fields))] + [field_header_fragment(fields, field_number) + for field_number in range(len(fields))] ) + [ string_literal("|\n"), @@ -1536,7 +1575,7 @@ def header_line(fields: list[str]) -> list[algebra_pb2.Expression]: ) ] - def full_line(fields: list[str]) -> list[algebra_pb2.Expression]: + def full_line(fields: list[Field]) -> list[algebra_pb2.Expression]: return [ concat( concat_func, @@ -1574,7 +1613,7 @@ def 
full_line(fields: list[str]) -> list[algebra_pb2.Expression]: concat_func, flatten( [ - field_body_fragment(field_number) + field_body_fragment(symbol.input_fields, field_number) for field_number in range(len(symbol.input_fields)) ] ) @@ -1595,7 +1634,7 @@ def full_line(fields: list[str]) -> list[algebra_pb2.Expression]: symbol = self._symbol_table.get_symbol(self._current_plan_id) symbol.output_fields.clear() - symbol.output_fields.append("show_string") + symbol.output_fields.append(Field("show_string")) def compute_row_count_footer(num_rows: int) -> str: if num_rows == 1: @@ -1627,7 +1666,7 @@ def compute_row_count_footer(num_rows: int) -> str: return project5 def convert_with_columns_relation( - self, rel: spark_relations_pb2.WithColumns + self, rel: spark_relations_pb2.WithColumns ) -> algebra_pb2.Rel: """Convert a with columns relation into a Substrait project relation.""" input_rel = self.convert_relation(rel.input) @@ -1639,14 +1678,15 @@ def convert_with_columns_relation( if len(alias.name) != 1: raise ValueError("Only one name part is supported in an alias.") name = alias.name[0] - if name in symbol.input_fields: - proposed_expressions[symbol.input_fields.index(name)] = self.convert_expression( - alias.expr - ) + field_num = _find_reference_number(symbol.input_fields, name) + if field_num is not None: + # Overwrite our the old expression with our own. + proposed_expressions[field_num] = self.convert_expression(alias.expr) else: + # This is a new column. proposed_expressions.append(self.convert_expression(alias.expr)) - symbol.generated_fields.append(name) - symbol.output_fields.append(name) + symbol.generated_fields.append(Field(name)) + symbol.output_fields.append(Field(name)) project.common.CopyFrom(self.create_common_relation()) project.expressions.extend(proposed_expressions) for i in range(len(proposed_expressions)): @@ -1654,7 +1694,7 @@ def convert_with_columns_relation( return algebra_pb2.Rel(project=project) def convert_with_columns_renamed_relation( - self, rel: spark_relations_pb2.WithColumnsRenamed + self, rel: spark_relations_pb2.WithColumnsRenamed ) -> algebra_pb2.Rel: """Update the columns names based on the Spark with columns renamed relation.""" input_rel = self.convert_relation(rel.input) @@ -1665,11 +1705,11 @@ def convert_with_columns_renamed_relation( aliases = {r.col_name: r.new_col_name for r in rel.renames} else: aliases = rel.rename_columns_map - for field_name in symbol.input_fields: - if field_name in aliases: - symbol.output_fields.append(aliases[field_name]) + for field in symbol.input_fields: + if field.name in aliases: + symbol.output_fields.append(field.alias(aliases[field.name])) else: - symbol.output_fields.append(field_name) + symbol.output_fields.append(field) return input_rel def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Rel: @@ -1683,9 +1723,9 @@ def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Re else: column_names = rel.column_names symbol.output_fields.clear() - for field_number, field_name in enumerate(symbol.input_fields): - if field_name not in column_names: - symbol.output_fields.append(field_name) + for field_number, field in enumerate(symbol.input_fields): + if field.name not in column_names: + symbol.output_fields.append(field) if self._conversion_options.drop_emit_workaround: project.common.emit.output_mapping.append(len(project.expressions)) project.expressions.append(field_reference(field_number)) @@ -1707,7 +1747,7 @@ def convert_to_df_relation(self, rel: 
spark_relations_pb2.ToDF) -> algebra_pb2.R ) symbol.output_fields.clear() for field_name in rel.column_names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) return input_rel def convert_arrow_to_literal(self, val: pa.Scalar) -> algebra_pb2.Expression.Literal: @@ -1773,9 +1813,11 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge """Convert a Spark local relation into a virtual table.""" read = algebra_pb2.ReadRel(virtual_table=self.convert_arrow_data_to_virtual_table(rel.data)) schema = self.convert_schema(rel.schema) + if not schema: + raise ValueError(f'Received an empty schema in plan id {self._current_plan_id}') symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in schema.names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) read.base_schema.CopyFrom(schema) read.common.CopyFrom(self.create_common_relation()) return algebra_pb2.Rel(read=read) @@ -1786,7 +1828,7 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: plan = self._sql_backend.convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) for field_name in plan.relations[0].root.names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) # TODO -- Correctly capture all the used functions and extensions. self._saved_extension_uris = plan.extension_uris self._saved_extensions = plan.extensions @@ -1795,7 +1837,7 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: return plan.relations[0].root.input def convert_spark_join_type( - self, join_type: spark_relations_pb2.Join.JoinType + self, join_type: spark_relations_pb2.Join.JoinType ) -> algebra_pb2.JoinRel.JoinType: """Convert a Spark join type into a Substrait join type.""" match join_type: @@ -1813,7 +1855,7 @@ def convert_spark_join_type( return algebra_pb2.JoinRel.JOIN_TYPE_ANTI case spark_relations_pb2.Join.JOIN_TYPE_LEFT_SEMI: return algebra_pb2.JoinRel.JOIN_TYPE_SEMI - case spark_relations_pb2.Join.CROSS: + case spark_relations_pb2.Join.JOIN_TYPE_CROSS: raise RuntimeError("Internal error: cross joins should be handled elsewhere") case _: raise ValueError(f"Unexpected join type: {join_type}") @@ -1848,12 +1890,14 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re else: equal_func = self.lookup_function_by_name("==") if rel.using_columns: - column = rel.using_columns[0] - left_fields = self._symbol_table.get_symbol(rel.left.common.plan_id).output_fields + column_name = rel.using_columns[0] + left_symbol = self._symbol_table.get_symbol(rel.left.common.plan_id) + left_fields = left_symbol.output_fields left_column_count = len(left_fields) - left_column_reference = left_fields.index(column) - right_fields = self._symbol_table.get_symbol(rel.right.common.plan_id).output_fields - right_column_reference = right_fields.index(column) + left_column_reference = _find_reference_number(left_fields, column_name) + right_symbol = self._symbol_table.get_symbol(rel.right.common.plan_id) + right_fields = right_symbol.output_fields + right_column_reference = _find_reference_number(right_fields, column_name) join.expression.CopyFrom( equal_function( equal_func, @@ -1878,6 +1922,17 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re del symbol.output_fields[right_column_reference + left_column_count] return algebra_pb2.Rel(join=join) + def 
find_additional_names(self, expr: spark_exprs_pb2.Expression) -> list[str]: + """Find any extra names used by this expression.""" + current_expr = expr.alias.expr if expr.WhichOneof("expr_type") == "alias" else expr + if current_expr.unresolved_function.function_name == "struct": + new_names = [] + for arg in current_expr.unresolved_function.arguments: + new_names.append(arg.unresolved_attribute.unparsed_identifier) + # TODO -- Handle nested structures. + return new_names + return [] + def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_pb2.Rel: """Convert a Spark project relation into a Substrait project relation.""" input_rel = self.convert_relation(rel.input) @@ -1890,9 +1945,10 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ matcher = re.compile(regex) found = False for column in symbol.input_fields: - if matcher.match(column): + if matcher.match(column.name): project.expressions.append( - field_reference(symbol.input_fields.index(column)) + field_reference( + _find_reference_number(symbol.input_fields, column.name)) ) symbol.generated_fields.append(column) symbol.output_fields.append(column) @@ -1909,8 +1965,10 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ name = expr.unresolved_attribute.unparsed_identifier else: name = f"generated_field_{field_number}" - symbol.generated_fields.append(name) - symbol.output_fields.append(name) + projected_symbol = Field(name) + projected_symbol.child_names.extend(self.find_additional_names(expr)) + symbol.generated_fields.append(projected_symbol) + symbol.output_fields.append(projected_symbol) project.common.CopyFrom(self.create_common_relation()) symbol.output_fields = symbol.generated_fields for field_number in range(len(symbol.output_fields)): @@ -1918,7 +1976,7 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ return algebra_pb2.Rel(project=project) def convert_subquery_alias_relation( - self, rel: spark_relations_pb2.SubqueryAlias + self, rel: spark_relations_pb2.SubqueryAlias ) -> algebra_pb2.Rel: """Convert a Spark subquery alias relation into a Substrait relation.""" raise NotImplementedError("Subquery alias relations are not yet implemented") @@ -1961,7 +2019,7 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> return project def convert_set_operation_relation( - self, rel: spark_relations_pb2.SetOperation + self, rel: spark_relations_pb2.SetOperation ) -> algebra_pb2.Rel: """Convert a Spark set operation relation into a Substrait set operation relation.""" left = self.convert_relation(rel.left_input) @@ -2016,7 +2074,7 @@ def convert_dropna_relation(self, rel: spark_relations_pb2.NADrop) -> algebra_pb filter_rel.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) if rel.cols: - cols = [symbol.input_fields.index(col) for col in rel.cols] + cols = [_find_reference_number(symbol.input_fields, col) for col in rel.cols] else: cols = range(len(symbol.input_fields)) min_non_nulls = rel.min_non_nulls if rel.min_non_nulls else len(cols) @@ -2110,8 +2168,8 @@ def convert_plan(self, plan: spark_pb2.Plan) -> plan_pb2.Plan: if plan.HasField("root"): rel_root = algebra_pb2.RelRoot(input=self.convert_relation(plan.root)) symbol = self._symbol_table.get_symbol(plan.root.common.plan_id) - for name in symbol.output_fields: - rel_root.names.append(name) + for field in symbol.output_fields: + rel_root.names.extend(field.output_names()) 
result.relations.append(plan_pb2.PlanRel(root=rel_root)) for uri in sorted(self._function_uris.items(), key=operator.itemgetter(1)): result.extension_uris.append( diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index aa862b8..184fb1b 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -4,21 +4,45 @@ import dataclasses +@dataclasses.dataclass +class Field: + """Tracks the names used by a field used as the input or output of a relation.""" + + name: str + # TODO -- Also track the field's type. + child_names: list[str] + + def __init__(self, name: str, child_names=None): + """Create the Field structure.""" + self.name = name + self.child_names = child_names or [] + + def alias(self, name: str): + """Create a copy with an alternate name.""" + new_field = Field(name) + new_field.child_names = self.child_names + return new_field + + def output_names(self) -> list[str]: + """Return all the names used by this field (including subtypes).""" + return [self.name, *self.child_names] + + @dataclasses.dataclass class PlanMetadata: """Tracks various information about a specific plan id.""" plan_id: int - type: str | None + symbol_type: str | None parent_plan_id: int | None - input_fields: list[str] # And maybe type - generated_fields: list[str] - output_fields: list[str] + input_fields: list[Field] + generated_fields: list[Field] + output_fields: list[Field] def __init__(self, plan_id: int): """Create the PlanMetadata structure.""" self.plan_id = plan_id - self.symbol_type = None + self.symbol_type = None # Useful when debugging. self.parent_plan_id = None self.input_fields = [] self.generated_fields = [] diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 7fe9fa8..25b381a 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -57,6 +57,7 @@ rtrim, sqrt, startswith, + struct, substr, substring, trim, @@ -159,6 +160,13 @@ def mark_dataframe_tests_as_xfail(request): if source == "gateway-over-duckdb" and originalname == "test_cube": pytest.skip(reason="cube aggregation not yet implemented in DuckDB") + if source == "gateway-over-duckdb" and originalname in ["test_column_getfield", + "test_struct_and_getfield"]: + pytest.skip(reason="fully named structs not yet tracked in gateway") + if source == "gateway-over-datafusion" and originalname in ["test_struct", + "test_struct_and_getfield"]: + pytest.skip(reason="nested expressions not supported") + # ruff: noqa: E712 class TestDataFrameAPI: @@ -2791,7 +2799,7 @@ def userage_dataframe(spark_session_for_setup): class TestDataFrameDecisionSupport: - """Tests data science methods of the dataframe side of SparkConnect.""" + """Tests decision support methods of the dataframe side of SparkConnect.""" def test_groupby(self, userage_dataframe): expected = [ @@ -2840,3 +2848,35 @@ def test_cube(self, userage_dataframe): "age").collect() assertDataFrameEqual(outcome, expected) + + +class TestDataFrameComplexDatastructures: + """Tests the use of complex datastructures in the dataframe side of SparkConnect.""" + + def test_struct(self, register_tpch_dataset, spark_session, caplog): + expected = [ + Row(test_struct=Row(c_custkey=1, c_name='Customer#000000001')), + Row(test_struct=Row(c_custkey=2, c_name='Customer#000000002')), + Row(test_struct=Row(c_custkey=3, c_name='Customer#000000003')), + ] + + # TODO -- Validate once the validator supports nested expressions. 
+ customer_df = spark_session.table("customer") + outcome = customer_df.select( + struct(col('c_custkey'), col('c_name')).alias('test_struct')).limit(3).collect() + + assertDataFrameEqual(outcome, expected) + + def test_struct_and_getfield(self, register_tpch_dataset, spark_session, caplog): + expected = [ + Row(result=1), + ] + + # TODO -- Validate once the validator supports nested expressions. + customer_df = spark_session.table("customer") + outcome = customer_df.select( + struct(col('c_custkey'), col('c_name')).alias('test_struct')).agg( + pyspark.sql.functions.min(col('test_struct').getField('c_custkey')).alias( + 'result')).collect() + + assertDataFrameEqual(outcome, expected) diff --git a/src/transforms/output_field_tracking_visitor.py b/src/transforms/output_field_tracking_visitor.py index 8d78e0e..093bf77 100644 --- a/src/transforms/output_field_tracking_visitor.py +++ b/src/transforms/output_field_tracking_visitor.py @@ -6,6 +6,7 @@ from substrait.gen.proto import algebra_pb2, plan_pb2 from gateway.converter.symbol_table import SymbolTable +from src.gateway.converter.symbol_table import Field from substrait_visitors.substrait_plan_visitor import SubstraitPlanVisitor from transforms.label_relations import get_common_section @@ -48,8 +49,8 @@ def visit_read_relation(self, rel: algebra_pb2.ReadRel) -> Any: super().visit_read_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) # TODO -- Validate this logic where complicated data structures are used. - for field in rel.base_schema.names: - symbol.output_fields.append(field) + for field_name in rel.base_schema.names: + symbol.output_fields.append(Field(field_name)) def visit_filter_relation(self, rel: algebra_pb2.FilterRel) -> Any: """Collect the field references from the filter relation.""" @@ -66,9 +67,9 @@ def visit_aggregate_relation(self, rel: algebra_pb2.AggregateRel) -> Any: super().visit_aggregate_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) for _ in rel.groupings: - symbol.generated_fields.append("grouping") + symbol.generated_fields.append(Field("grouping")) for _ in rel.measures: - symbol.generated_fields.append("measure") + symbol.generated_fields.append(Field("measure")) self.update_field_references(get_plan_id(rel.input)) def visit_sort_relation(self, rel: algebra_pb2.SortRel) -> Any: @@ -81,7 +82,7 @@ def visit_project_relation(self, rel: algebra_pb2.ProjectRel) -> Any: super().visit_project_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) for _ in rel.expressions: - symbol.generated_fields.append("intermediate") + symbol.generated_fields.append(Field("intermediate")) self.update_field_references(get_plan_id(rel.input)) def visit_extension_single_relation(self, rel: algebra_pb2.ExtensionSingleRel) -> Any:
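---

Usage note (not part of the patch): the snippet below is a minimal sketch of how the new `reapply_names` helper in `src/backends/arrow_tools.py` is intended to behave, mirroring the 'struct column' case added in `arrow_tools_test.py`. The `src.backends` import path assumes the tests' convention of running with the repository root on `PYTHONPATH`.

```python
# Illustrative sketch only -- mirrors the 'struct column' test case in arrow_tools_test.py.
import pyarrow as pa

from src.backends.arrow_tools import reapply_names  # import path as used by the new tests

# A table with one struct column whose first child field has an empty name,
# as DuckDB can return for nested expressions.
table = pa.Table.from_arrays(
    [pa.array([{"": 1, "b": "b"}], type=pa.struct([("", pa.int64()), ("b", pa.string())]))],
    names=["r"],
)

# Names are consumed in order: the column name first, then its child fields.
renamed = reapply_names(table, ["r", "a", "b"])
print(renamed.schema)  # r: struct<a: int64, b: string>
```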