From b2e1ce2e35a9f9f3b35785c21573cbcfd818f0df Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 19 Sep 2024 14:59:25 -0700 Subject: [PATCH 01/18] added test for struct/getfield --- src/gateway/tests/test_dataframe_api.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 7fe9fa8..ad98c55 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -62,7 +62,7 @@ trim, try_sum, ucase, - upper, + upper, struct, ) from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType from pyspark.sql.window import Window @@ -2791,7 +2791,7 @@ def userage_dataframe(spark_session_for_setup): class TestDataFrameDecisionSupport: - """Tests data science methods of the dataframe side of SparkConnect.""" + """Tests decision support methods of the dataframe side of SparkConnect.""" def test_groupby(self, userage_dataframe): expected = [ @@ -2840,3 +2840,22 @@ def test_cube(self, userage_dataframe): "age").collect() assertDataFrameEqual(outcome, expected) + + +class TestDataFrameComplexDatastructures: + """Tests the use of complex datastructures in the dataframe side of SparkConnect.""" + + @pytest.mark.interesting + def test_struct(self, register_tpch_dataset, spark_session): + expected = [ + Row(result=1), + ] + + with utilizes_valid_plans(spark_session): + customer_df = spark_session.table("customer") + outcome = customer_df.select( + struct(col('c_custkey'), col('c_name')).alias('test_struct')).agg( + pyspark.sql.functions.min(col('test_struct').getField('c_custkey')).alias( + 'result')).collect() + + assertDataFrameEqual(outcome, expected) From f27d98d1396755bce122cf420f74c9c3d9775f42 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 19 Sep 2024 21:13:13 -0700 Subject: [PATCH 02/18] start implementing struct --- src/gateway/converter/spark_to_substrait.py | 13 +++++++++ src/gateway/tests/test_dataframe_api.py | 29 ++++++++++++++++----- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 720336c..e38d15e 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -495,6 +495,17 @@ def convert_ifnull_function( return if_then_else_operation(expr, arg1, arg0) + def convert_struct_function( + self, func: spark_exprs_pb2.Expression.UnresolvedFunction + ) -> algebra_pb2.Expression: + """Convert a Spark struct function into a Substrait nested expression.""" + nested = algebra_pb2.Expression.Nested() + nested.nullable = False + for spark_arg in func.arguments: + arg = self.convert_expression(spark_arg) + nested.struct.fields.append(arg) + return algebra_pb2.Expression(nested=nested) + def convert_unresolved_function( self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression | algebra_pb2.AggregateFunction: @@ -517,6 +528,8 @@ def convert_unresolved_function( return self.convert_nvl2_function(unresolved_function) if unresolved_function.function_name == "ifnull": return self.convert_ifnull_function(unresolved_function) + if unresolved_function.function_name == "struct": + return self.convert_struct_function(unresolved_function) func = algebra_pb2.Expression.ScalarFunction() function_def = self.lookup_function_by_name(unresolved_function.function_name) if ( diff --git a/src/gateway/tests/test_dataframe_api.py 
b/src/gateway/tests/test_dataframe_api.py index ad98c55..acf1f87 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -2846,16 +2846,31 @@ class TestDataFrameComplexDatastructures: """Tests the use of complex datastructures in the dataframe side of SparkConnect.""" @pytest.mark.interesting - def test_struct(self, register_tpch_dataset, spark_session): + def test_struct(self, register_tpch_dataset, spark_session, caplog): + expected = [ + Row(test_struct=Row(c_custkey=1, c_name='Customer#000000001')), + Row(test_struct=Row(c_custkey=2, c_name='Customer#000000002')), + Row(test_struct=Row(c_custkey=3, c_name='Customer#000000003')), + ] + + # TODO -- Validate once the validator supports nested expressions. + customer_df = spark_session.table("customer") + outcome = customer_df.select( + struct(col('c_custkey'), col('c_name')).alias('test_struct')).limit(3).collect() + + assertDataFrameEqual(outcome, expected) + + @pytest.mark.interesting + def test_struct_and_getfield(self, register_tpch_dataset, spark_session, caplog): expected = [ Row(result=1), ] - with utilizes_valid_plans(spark_session): - customer_df = spark_session.table("customer") - outcome = customer_df.select( - struct(col('c_custkey'), col('c_name')).alias('test_struct')).agg( - pyspark.sql.functions.min(col('test_struct').getField('c_custkey')).alias( - 'result')).collect() + # TODO -- Validate once the validator supports nested expressions. + customer_df = spark_session.table("customer") + outcome = customer_df.select( + struct(col('c_custkey'), col('c_name')).alias('test_struct')).agg( + pyspark.sql.functions.min(col('test_struct').getField('c_custkey')).alias( + 'result')).collect() assertDataFrameEqual(outcome, expected) From 2494a49723c19aa0de24ae31b1975b2e52f05dc1 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 19 Sep 2024 21:29:29 -0700 Subject: [PATCH 03/18] undo server changes --- src/gateway/server.py | 1 - src/gateway/tests/test_dataframe_api.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gateway/server.py b/src/gateway/server.py index 2419bcb..353ad56 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -386,7 +386,6 @@ def Config(self, request, context): for pair in request.operation.set.pairs: if pair.key == "spark-substrait-gateway.backend": # Set the server backend for all connections (including ongoing ones). 
- need_reset = False match pair.value: case "arrow": if (self._backend is not None and diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index acf1f87..140b6d5 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -57,12 +57,13 @@ rtrim, sqrt, startswith, + struct, substr, substring, trim, try_sum, ucase, - upper, struct, + upper, ) from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType from pyspark.sql.window import Window From 025004c7ffac478a3c53b9552c5fe09466357b4e Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 20 Sep 2024 15:53:55 -0700 Subject: [PATCH 04/18] support numeric info --- src/gateway/converter/spark_to_substrait.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index e38d15e..6efc90d 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -764,8 +764,15 @@ def convert_extract_value( value=self.convert_unresolved_attribute(extract.child.unresolved_attribute) ) ) + # MEGAHACK -- Fix this to be relative to the struct value we found above. + field_ref = self.find_field_by_name(extract.extraction.literal.string) + if field_ref is None: + raise ValueError( + f"could not locate field named {extract.extraction.literal.string} in plan id " + f"{self._current_plan_id}" + ) func.arguments.append( - algebra_pb2.FunctionArgument(value=string_literal(extract.extraction.literal.string)) + algebra_pb2.FunctionArgument(value=integer_literal(field_ref+1)) ) return algebra_pb2.Expression(scalar_function=func) From 388b61dc3e1b7cae04846db4f7abc03b09e569ae Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 20 Sep 2024 23:47:09 -0700 Subject: [PATCH 05/18] investigating struct behavior --- src/backends/duckdb_backend.py | 19 ++++++++++++++++++- src/gateway/server.py | 1 + src/gateway/tests/test_dataframe_api.py | 4 ++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/backends/duckdb_backend.py b/src/backends/duckdb_backend.py index eead9b7..1ea8242 100644 --- a/src/backends/duckdb_backend.py +++ b/src/backends/duckdb_backend.py @@ -67,13 +67,30 @@ def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]: # ruff: noqa: BLE001 def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against DuckDB.""" + if False: + plan.relations[0].root.names.append("custid") + plan.relations[0].root.names.append("custname") plan_data = plan.SerializeToString() try: query_result = self._connection.from_substrait(proto=plan_data) except Exception as err: raise ValueError(f"DuckDB Execution Error: {err}") from err - return query_result.arrow() + if False: + arrow = query_result.arrow() + new_struct_array = pa.StructArray.from_arrays(arrow, names=["custid", "custname"]) + new_schema = pa.schema( + [ + pa.field("test_struct", pa.struct([ + pa.field("custid", pa.int64()), + pa.field("custname", pa.string()), + ])), + ] + ) + new_table = pa.Table.from_arrays([new_struct_array], schema=new_schema) + return new_table + else: + return query_result.arrow() def register_table( self, diff --git a/src/gateway/server.py b/src/gateway/server.py index 353ad56..2419bcb 100644 --- a/src/gateway/server.py +++ b/src/gateway/server.py @@ -386,6 +386,7 @@ def Config(self, request, context): for pair in request.operation.set.pairs: if pair.key == 
"spark-substrait-gateway.backend": # Set the server backend for all connections (including ongoing ones). + need_reset = False match pair.value: case "arrow": if (self._backend is not None and diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 140b6d5..8731759 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -160,6 +160,10 @@ def mark_dataframe_tests_as_xfail(request): if source == "gateway-over-duckdb" and originalname == "test_cube": pytest.skip(reason="cube aggregation not yet implemented in DuckDB") + if source == "gateway-over-datafusion" and originalname in ["test_struct", + "test_struct_and_getfield"]: + pytest.skip(reason="nested expressions not supported") + # ruff: noqa: E712 class TestDataFrameAPI: From 74aa64585b7a55b1d677e2f16ee30049aa241990 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 23 Sep 2024 17:01:52 -0700 Subject: [PATCH 06/18] progress --- src/gateway/converter/symbol_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index aa862b8..704dd45 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -11,14 +11,14 @@ class PlanMetadata: plan_id: int type: str | None parent_plan_id: int | None - input_fields: list[str] # And maybe type + input_fields: list[str] # And maybe type with additional names generated_fields: list[str] output_fields: list[str] def __init__(self, plan_id: int): """Create the PlanMetadata structure.""" self.plan_id = plan_id - self.symbol_type = None + self.symbol_type = None # Useful when debugging. self.parent_plan_id = None self.input_fields = [] self.generated_fields = [] From e94b1071d6d8ffaed11473087c60bf490f8a180a Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 3 Oct 2024 00:01:12 -0700 Subject: [PATCH 07/18] update installation instructions --- README.md | 4 +++- environment.yml | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3c2a46e..237c5e3 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,13 @@ or Velox. ### Locally To run the gateway locally - you need to setup a Python (Conda) environment. +To run the Spark tests you will need Java installed. + Ensure you have [Miniconda](https://docs.anaconda.com/miniconda/miniconda-install/) and [Rust/Cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html) installed. Once that is done - run these steps from a bash terminal: ```bash -git clone --recursive https://github.com//spark-substrait-gateway.git +git clone https://github.com//spark-substrait-gateway.git cd spark-substrait-gateway conda init bash . 
~/.bashrc diff --git a/environment.yml b/environment.yml index a6ac78d..f85638d 100644 --- a/environment.yml +++ b/environment.yml @@ -11,7 +11,7 @@ dependencies: - setuptools >= 61.0.0 - setuptools_scm >= 6.2.0 - mypy-protobuf - - types-protobuf >= 4.25.0, < 5.0.0 + - types-protobuf >= 5.0.0 - numpy < 2.0.0 - Faker - pip: @@ -27,7 +27,7 @@ dependencies: - substrait == 0.21.0 - substrait-validator - pytest-timeout - - protobuf >= 4.25.3, < 5.0.0 + - protobuf >= 5.0.0 - cryptography == 43.0.* - click == 8.1.* - pyjwt == 2.8.* From abbd6d8302a8efe4e68a6a5fb921c39c3f5776cf Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 3 Oct 2024 19:40:38 -0700 Subject: [PATCH 08/18] working on arrow tools package --- src/backends/arrow_tools.py | 33 ++++++++++++++++++++++++++ src/backends/duckdb_backend.py | 22 +++++------------ src/backends/tests/arrow_tools_test.py | 33 ++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 16 deletions(-) create mode 100644 src/backends/arrow_tools.py create mode 100644 src/backends/tests/arrow_tools_test.py diff --git a/src/backends/arrow_tools.py b/src/backends/arrow_tools.py new file mode 100644 index 0000000..2c58fd8 --- /dev/null +++ b/src/backends/arrow_tools.py @@ -0,0 +1,33 @@ +from typing import List + +import pyarrow as pa + + +def _reapply_names_to_struct(struct: pa.StructArray, names: List[str]) -> pa.StructArray: + return struct + + +def reapply_names(table: pa.Table, names: List[str]) -> pa.Table: + new_arrays = [] + new_schema = [] + + for column in iter(table.columns): + # TODO: Rebuild the data. + # TODO: Save the schema. + pass + + new_schema = pa.schema( + [ + pa.field("test_struct", pa.struct([ + pa.field("custid", pa.int64()), + pa.field("custname", pa.string()), + ])), + ] + ) + custid_array = table.columns[0].chunks[0].field(0) + custname_array = table.columns[0].chunks[0].field(1) + + new_struct_array = pa.StructArray.from_arrays([custid_array, custname_array], names=['custid', 'custname']) + new_table = pa.Table.from_arrays([new_struct_array], schema=new_schema) + + return new_table diff --git a/src/backends/duckdb_backend.py b/src/backends/duckdb_backend.py index 1ea8242..42212dd 100644 --- a/src/backends/duckdb_backend.py +++ b/src/backends/duckdb_backend.py @@ -4,6 +4,7 @@ from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path +from typing import List import duckdb import pyarrow as pa @@ -13,6 +14,8 @@ from backends.backend import Backend from transforms.rename_functions import RenameFunctionsForDuckDB +from src.backends.arrow_tools import reapply_names + # pylint: disable=fixme class DuckDBBackend(Backend): @@ -67,7 +70,7 @@ def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]: # ruff: noqa: BLE001 def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against DuckDB.""" - if False: + if True: plan.relations[0].root.names.append("custid") plan.relations[0].root.names.append("custname") plan_data = plan.SerializeToString() @@ -76,21 +79,8 @@ def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table: query_result = self._connection.from_substrait(proto=plan_data) except Exception as err: raise ValueError(f"DuckDB Execution Error: {err}") from err - if False: - arrow = query_result.arrow() - new_struct_array = pa.StructArray.from_arrays(arrow, names=["custid", "custname"]) - new_schema = pa.schema( - [ - pa.field("test_struct", pa.struct([ - pa.field("custid", pa.int64()), - pa.field("custname", pa.string()), - ])), - ] - ) 
- new_table = pa.Table.from_arrays([new_struct_array], schema=new_schema) - return new_table - else: - return query_result.arrow() + arrow = query_result.arrow() + return reapply_names(arrow, plan.relations[0].root.names) def register_table( self, diff --git a/src/backends/tests/arrow_tools_test.py b/src/backends/tests/arrow_tools_test.py new file mode 100644 index 0000000..635f65f --- /dev/null +++ b/src/backends/tests/arrow_tools_test.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import List +import pytest +import pyarrow as pa + +from src.backends.arrow_tools import reapply_names + + +@dataclass +class TestCase: + name: str + input: pa.Table + names: List[str] + expected: pa.table + + +cases: List[TestCase] = [ + TestCase('empty table', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), + TestCase('normal columns', pa.Table.from_arrays([]), ['a', 'b', 'c'], pa.Table.from_arrays([])), + TestCase('struct column', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), + TestCase('nested structs', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), +] + + +class TestArrowTools: + """Tests the functionality of the arrow tools package.""" + + @pytest.mark.parametrize( + "case", cases, ids=lambda case: case.name + ) + def test_reapply_names(self, case): + result = reapply_names(case.input, case.names) + assert result == case.expected From c4defda987af718cc886b44eecd3a38d46b95a59 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 4 Oct 2024 16:55:04 -0700 Subject: [PATCH 09/18] Now properly renames structs. --- src/backends/arrow_tools.py | 53 +++++++++++++++++--------- src/backends/duckdb_backend.py | 1 + src/backends/tests/arrow_tools_test.py | 23 ++++++++++- 3 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/backends/arrow_tools.py b/src/backends/arrow_tools.py index 2c58fd8..9048a62 100644 --- a/src/backends/arrow_tools.py +++ b/src/backends/arrow_tools.py @@ -3,31 +3,46 @@ import pyarrow as pa -def _reapply_names_to_struct(struct: pa.StructArray, names: List[str]) -> pa.StructArray: - return struct +def _reapply_names_to_type(array: pa.ChunkedArray, names: List[str]) -> (pa.Array, List[str]): + new_arrays = [] + new_schema = [] + + remaining_names = names + if pa.types.is_list(array.type): + raise NotImplementedError('Reapplying names to lists not yet supported') + elif pa.types.is_map(array.type): + raise NotImplementedError('Reapplying names to maps not yet supported') + elif pa.types.is_struct(array.type): + field_num = 0 + while field_num < array.type.num_fields: + field = array.chunks[0].field(field_num) + this_name = remaining_names.pop(0) + + new_array, remaining_names = _reapply_names_to_type(field, remaining_names) + new_arrays.append(new_array) + + new_schema.append(pa.field(this_name, new_array.type)) + + field_num += 1 + + return pa.StructArray.from_arrays(new_arrays, fields=new_schema), remaining_names + if array.type.num_fields != 0: + raise ValueError(f'Unsupported complex type: {array.type}') + return array, remaining_names def reapply_names(table: pa.Table, names: List[str]) -> pa.Table: new_arrays = [] new_schema = [] + remaining_names = names for column in iter(table.columns): - # TODO: Rebuild the data. - # TODO: Save the schema. 
- pass - - new_schema = pa.schema( - [ - pa.field("test_struct", pa.struct([ - pa.field("custid", pa.int64()), - pa.field("custname", pa.string()), - ])), - ] - ) - custid_array = table.columns[0].chunks[0].field(0) - custname_array = table.columns[0].chunks[0].field(1) - - new_struct_array = pa.StructArray.from_arrays([custid_array, custname_array], names=['custid', 'custname']) - new_table = pa.Table.from_arrays([new_struct_array], schema=new_schema) + this_name = remaining_names.pop(0) + + new_array, remaining_names = _reapply_names_to_type(column, remaining_names) + new_arrays.append(new_array) + + new_schema.append(pa.field(this_name, new_array.type)) + new_table = pa.Table.from_arrays(new_arrays, schema=pa.schema(new_schema)) return new_table diff --git a/src/backends/duckdb_backend.py b/src/backends/duckdb_backend.py index 42212dd..f601a2b 100644 --- a/src/backends/duckdb_backend.py +++ b/src/backends/duckdb_backend.py @@ -71,6 +71,7 @@ def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]: def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against DuckDB.""" if True: + # MEGAHACK -- Modify the plan conversion to include these fields. plan.relations[0].root.names.append("custid") plan.relations[0].root.names.append("custname") plan_data = plan.SerializeToString() diff --git a/src/backends/tests/arrow_tools_test.py b/src/backends/tests/arrow_tools_test.py index 635f65f..203bbc5 100644 --- a/src/backends/tests/arrow_tools_test.py +++ b/src/backends/tests/arrow_tools_test.py @@ -16,9 +16,28 @@ class TestCase: cases: List[TestCase] = [ TestCase('empty table', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), - TestCase('normal columns', pa.Table.from_arrays([]), ['a', 'b', 'c'], pa.Table.from_arrays([])), - TestCase('struct column', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), + TestCase('normal columns', + pa.Table.from_pydict( + {"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]}, + schema=pa.schema({"name": pa.string(), "age": pa.int32()}) + ), + ['renamed_name', 'renamed_age'], + pa.Table.from_pydict( + {"renamed_name": [None, "Joe", "Sarah", None], "renamed_age": [99, None, 42, None]}, + schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()}) + )), + TestCase('struct column', + pa.Table.from_arrays( + [pa.array([{"": 1, "b": "b"}], type=pa.struct([("", pa.int64()), ("b", pa.string())]))], + names=["r"]), + ['r', 'a', 'b'], + pa.Table.from_arrays( + [pa.array([{"a": 1, "b": "b"}], type=pa.struct([("a", pa.int64()), ("b", pa.string())]))], names=["r"]) + ), TestCase('nested structs', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), + # TODO -- Test a list. + # TODO -- Test a map. + # TODO -- Test a mixture of complex and simple types. 
] From 80e9124be9bde50c85943e6f8828b2ac5874bf1c Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 4 Oct 2024 20:19:16 -0700 Subject: [PATCH 10/18] remove the root names hack --- src/backends/arrow_tools.py | 3 +++ src/backends/duckdb_backend.py | 4 ---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/backends/arrow_tools.py b/src/backends/arrow_tools.py index 9048a62..d9b3297 100644 --- a/src/backends/arrow_tools.py +++ b/src/backends/arrow_tools.py @@ -7,6 +7,9 @@ def _reapply_names_to_type(array: pa.ChunkedArray, names: List[str]) -> (pa.Arra new_arrays = [] new_schema = [] + if array.type.num_fields > len(names): + raise ValueError('Insufficient number of names provided to reapply names.') + remaining_names = names if pa.types.is_list(array.type): raise NotImplementedError('Reapplying names to lists not yet supported') diff --git a/src/backends/duckdb_backend.py b/src/backends/duckdb_backend.py index f601a2b..25aabc5 100644 --- a/src/backends/duckdb_backend.py +++ b/src/backends/duckdb_backend.py @@ -70,10 +70,6 @@ def adjust_plan(self, plan: plan_pb2.Plan) -> Iterator[plan_pb2.Plan]: # ruff: noqa: BLE001 def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table: """Execute the given Substrait plan against DuckDB.""" - if True: - # MEGAHACK -- Modify the plan conversion to include these fields. - plan.relations[0].root.names.append("custid") - plan.relations[0].root.names.append("custname") plan_data = plan.SerializeToString() try: From 593638ed632d82ae9df880996c337abec5531327 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 7 Oct 2024 00:23:48 -0700 Subject: [PATCH 11/18] massive type refactor in progress --- src/backends/arrow_tools.py | 16 +- src/backends/duckdb_backend.py | 4 +- src/backends/tests/arrow_tools_test.py | 18 +- src/gateway/converter/conversion_options.py | 2 +- src/gateway/converter/spark_to_substrait.py | 255 ++++++++++++------ src/gateway/converter/symbol_table.py | 30 ++- .../output_field_tracking_visitor.py | 11 +- 7 files changed, 224 insertions(+), 112 deletions(-) diff --git a/src/backends/arrow_tools.py b/src/backends/arrow_tools.py index d9b3297..0871d52 100644 --- a/src/backends/arrow_tools.py +++ b/src/backends/arrow_tools.py @@ -1,9 +1,9 @@ -from typing import List - +# SPDX-License-Identifier: Apache-2.0 +"""Routines to manipulate arrow tables.""" import pyarrow as pa -def _reapply_names_to_type(array: pa.ChunkedArray, names: List[str]) -> (pa.Array, List[str]): +def _reapply_names_to_type(array: pa.ChunkedArray, names: list[str]) -> (pa.Array, list[str]): new_arrays = [] new_schema = [] @@ -13,9 +13,9 @@ def _reapply_names_to_type(array: pa.ChunkedArray, names: List[str]) -> (pa.Arra remaining_names = names if pa.types.is_list(array.type): raise NotImplementedError('Reapplying names to lists not yet supported') - elif pa.types.is_map(array.type): + if pa.types.is_map(array.type): raise NotImplementedError('Reapplying names to maps not yet supported') - elif pa.types.is_struct(array.type): + if pa.types.is_struct(array.type): field_num = 0 while field_num < array.type.num_fields: field = array.chunks[0].field(field_num) @@ -34,7 +34,8 @@ def _reapply_names_to_type(array: pa.ChunkedArray, names: List[str]) -> (pa.Arra return array, remaining_names -def reapply_names(table: pa.Table, names: List[str]) -> pa.Table: +def reapply_names(table: pa.Table, names: list[str]) -> pa.Table: + """Apply the provided names to the given table recursively.""" new_arrays = [] new_schema = [] @@ -47,5 +48,4 @@ def reapply_names(table: pa.Table, 
names: List[str]) -> pa.Table: new_schema.append(pa.field(this_name, new_array.type)) - new_table = pa.Table.from_arrays(new_arrays, schema=pa.schema(new_schema)) - return new_table + return pa.Table.from_arrays(new_arrays, schema=pa.schema(new_schema)) diff --git a/src/backends/duckdb_backend.py b/src/backends/duckdb_backend.py index 25aabc5..576d0bd 100644 --- a/src/backends/duckdb_backend.py +++ b/src/backends/duckdb_backend.py @@ -4,7 +4,6 @@ from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path -from typing import List import duckdb import pyarrow as pa @@ -12,9 +11,8 @@ from substrait.gen.proto import plan_pb2 from backends.backend import Backend -from transforms.rename_functions import RenameFunctionsForDuckDB - from src.backends.arrow_tools import reapply_names +from transforms.rename_functions import RenameFunctionsForDuckDB # pylint: disable=fixme diff --git a/src/backends/tests/arrow_tools_test.py b/src/backends/tests/arrow_tools_test.py index 203bbc5..fb4a57c 100644 --- a/src/backends/tests/arrow_tools_test.py +++ b/src/backends/tests/arrow_tools_test.py @@ -1,7 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import List -import pytest + import pyarrow as pa +import pytest from src.backends.arrow_tools import reapply_names @@ -10,11 +11,11 @@ class TestCase: name: str input: pa.Table - names: List[str] + names: list[str] expected: pa.table -cases: List[TestCase] = [ +cases: list[TestCase] = [ TestCase('empty table', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), TestCase('normal columns', pa.Table.from_pydict( @@ -23,16 +24,19 @@ class TestCase: ), ['renamed_name', 'renamed_age'], pa.Table.from_pydict( - {"renamed_name": [None, "Joe", "Sarah", None], "renamed_age": [99, None, 42, None]}, + {"renamed_name": [None, "Joe", "Sarah", None], + "renamed_age": [99, None, 42, None]}, schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()}) )), TestCase('struct column', pa.Table.from_arrays( - [pa.array([{"": 1, "b": "b"}], type=pa.struct([("", pa.int64()), ("b", pa.string())]))], + [pa.array([{"": 1, "b": "b"}], + type=pa.struct([("", pa.int64()), ("b", pa.string())]))], names=["r"]), ['r', 'a', 'b'], pa.Table.from_arrays( - [pa.array([{"a": 1, "b": "b"}], type=pa.struct([("a", pa.int64()), ("b", pa.string())]))], names=["r"]) + [pa.array([{"a": 1, "b": "b"}], + type=pa.struct([("a", pa.int64()), ("b", pa.string())]))], names=["r"]) ), TestCase('nested structs', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), # TODO -- Test a list. 
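For reference, a minimal standalone sketch of the PyArrow mechanics that reapply_names relies on when it rebuilds a struct column: pull the child arrays out of the chunked column, wrap them in StructArray.from_arrays with the desired field names, and reassemble the table. The custid/custname names below echo the temporary hack in duckdb_backend.py above; they are illustrative placeholders rather than a real schema.

    import pyarrow as pa

    # A result table whose struct column came back with placeholder field names,
    # similar to what DuckDB returns for a Substrait nested expression.
    raw = pa.table({
        "s": pa.array(
            [{"f0": 1, "f1": "a"}, {"f0": 2, "f1": "b"}],
            type=pa.struct([("f0", pa.int64()), ("f1", pa.string())]),
        )
    })

    names = ["test_struct", "custid", "custname"]  # flattened names, outer column first

    chunk = raw.column(0).chunks[0]  # the StructArray backing the first chunk
    children = [chunk.field(i) for i in range(chunk.type.num_fields)]
    fields = [pa.field(names[i + 1], child.type) for i, child in enumerate(children)]
    renamed = pa.StructArray.from_arrays(children, fields=fields)

    fixed = pa.Table.from_arrays(
        [renamed], schema=pa.schema([pa.field(names[0], renamed.type)])
    )
    print(fixed.schema)  # test_struct: struct<custid: int64, custname: string>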
diff --git a/src/gateway/converter/conversion_options.py b/src/gateway/converter/conversion_options.py index c4acffd..f42468d 100644 --- a/src/gateway/converter/conversion_options.py +++ b/src/gateway/converter/conversion_options.py @@ -11,7 +11,7 @@ class ConversionOptions: """Holds all the possible conversion options.""" - def __init__(self, backend: BackendOptions = None): + def __init__(self, backend: BackendOptions): """Initialize the conversion options.""" self.use_named_table_workaround = False self.needs_scheme_in_path_uris = False diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 6efc90d..d873c13 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -57,7 +57,11 @@ string_type, strlen, ) -from gateway.converter.symbol_table import SymbolTable +from gateway.converter.symbol_table import Field, SymbolTable + + +class InternalError(Exception): + pass class ExpressionProcessingMode(Enum): @@ -81,6 +85,13 @@ def _extract_decimal_parameters(type_name: str) -> tuple[int, int, int]: return int(match.group(1)), int(match.group(2)), int(match.group(3)) +def _find_reference_number(fields: list[Field], name: str) -> int | None: + for reference, field in enumerate(fields): + if field.name == name: + return reference + return None + + # ruff: noqa: RUF005 class SparkSubstraitConverter: """Converts SparkConnect plans to Substrait plans.""" @@ -101,7 +112,7 @@ def __init__(self, options: ConversionOptions): # These are used when processing expressions inside aggregate relations. self._expression_processing_mode = ExpressionProcessingMode.NORMAL self._top_level_projects: list[algebra_pb2.Rel] = [] - self._next_aggregation_reference_id = None + self._next_aggregation_reference_id = 0 self._aggregations: list[algebra_pb2.AggregateFunction] = [] self._next_under_aggregation_reference_id = 0 self._under_aggregation_projects: list[algebra_pb2.Rel] = [] @@ -114,8 +125,9 @@ def set_backends(self, backend, sql_backend) -> None: def lookup_function_by_name(self, name: str) -> ExtensionFunction: """Find the function reference for a given Spark function name.""" - if name in self._functions: - return self._functions.get(name) + function = self._functions.get(name) + if function is not None: + return function func = lookup_spark_function(name, self._conversion_options) if not func: raise LookupError( @@ -125,25 +137,33 @@ def lookup_function_by_name(self, name: str) -> ExtensionFunction: self._functions[name] = func if not self._function_uris.get(func.uri): self._function_uris[func.uri] = len(self._function_uris) + 1 - return self._functions.get(name) + function = self._functions.get(name) + assert function is not None + return function def update_field_references(self, plan_id: int) -> None: """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) + if not source_symbol: + raise InternalError(f'Could not find plan id {plan_id} that we constructed earlier.') current_symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not current_symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') original_output_fields = current_symbol.output_fields for symbol in source_symbol.output_fields: - new_name = symbol - while new_name in original_output_fields: - new_name = new_name + "_dup" - current_symbol.input_fields.append(new_name) - 
current_symbol.output_fields.append(new_name) + new_symbol = symbol + while _find_reference_number(original_output_fields, new_symbol.name): + new_symbol.name = new_symbol.name + "_dup" + current_symbol.input_fields.append(new_symbol) + current_symbol.output_fields.append(new_symbol) def find_field_by_name(self, field_name: str) -> int | None: """Look up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not current_symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') try: - return current_symbol.input_fields.index(field_name) + return _find_reference_number(current_symbol.input_fields, field_name) except ValueError: return None @@ -172,7 +192,7 @@ def convert_string_literal(self, s: str) -> algebra_pb2.Expression.Literal: return algebra_pb2.Expression.Literal(string=s) def convert_literal_expression( - self, literal: spark_exprs_pb2.Expression.Literal + self, literal: spark_exprs_pb2.Expression.Literal ) -> algebra_pb2.Expression: """Convert a Spark literal into a Substrait literal.""" match literal.WhichOneof("literal_type"): @@ -251,7 +271,7 @@ def convert_literal_expression( return algebra_pb2.Expression(literal=result) def convert_unresolved_attribute( - self, attr: spark_exprs_pb2.Expression.UnresolvedAttribute + self, attr: spark_exprs_pb2.Expression.UnresolvedAttribute ) -> algebra_pb2.Expression: """Convert a Spark unresolved attribute into a Substrait field reference.""" field_ref = self.find_field_by_name(attr.unparsed_identifier) @@ -308,7 +328,7 @@ def determine_type_of_expression(self, expr: algebra_pb2.Expression) -> type_pb2 ) def convert_when_function( - self, when: spark_exprs_pb2.Expression.UnresolvedFunction + self, when: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark when function into a Substrait if-then expression.""" ifthen = algebra_pb2.Expression.IfThen() @@ -324,6 +344,8 @@ def convert_when_function( else: nullable_literal = self.determine_type_of_expression(ifthen.ifs[-1].then) kind = nullable_literal.WhichOneof("kind") + if not kind: + raise ValueError('Missing kind one_of in when function.') getattr( nullable_literal, kind ).nullability = type_pb2.Type.Nullability.NULLABILITY_NULLABLE @@ -336,7 +358,7 @@ def convert_when_function( return algebra_pb2.Expression(if_then=ifthen) def convert_in_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark in function into a Substrait switch expression.""" @@ -378,7 +400,7 @@ def is_switch_expression_appropriate() -> bool: return algebra_pb2.Expression(if_then=ifthen) def convert_rlike_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark rlike function into a Substrait expression.""" if self._conversion_options.use_duckdb_regexp_matches_function: @@ -419,7 +441,7 @@ def convert_rlike_function( return greater_function(greater_func, regexp_expr, bigint_literal(0)) def convert_nanvl_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark nanvl function into a Substrait expression.""" isnan_func = self.lookup_function_by_name("isnan") @@ -438,7 +460,7 @@ def convert_nanvl_function( 
return if_then_else_operation(expr, arg1, arg0) def convert_nvl_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark nvl function into a Substrait expression.""" isnull_func = self.lookup_function_by_name("isnull") @@ -457,7 +479,7 @@ def convert_nvl_function( return if_then_else_operation(expr, arg1, arg0) def convert_nvl2_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark nvl2 function into a Substrait expression.""" isnotnull_func = self.lookup_function_by_name("isnotnull") @@ -477,7 +499,7 @@ def convert_nvl2_function( return if_then_else_operation(expr, arg1, arg2) def convert_ifnull_function( - self, in_: spark_exprs_pb2.Expression.UnresolvedFunction + self, in_: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark ifnull function into a Substrait expression.""" isnan_func = self.lookup_function_by_name("isnull") @@ -496,7 +518,7 @@ def convert_ifnull_function( return if_then_else_operation(expr, arg1, arg0) def convert_struct_function( - self, func: spark_exprs_pb2.Expression.UnresolvedFunction + self, func: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression: """Convert a Spark struct function into a Substrait nested expression.""" nested = algebra_pb2.Expression.Nested() @@ -507,7 +529,7 @@ def convert_struct_function( return algebra_pb2.Expression(nested=nested) def convert_unresolved_function( - self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction + self, unresolved_function: spark_exprs_pb2.Expression.UnresolvedFunction ) -> algebra_pb2.Expression | algebra_pb2.AggregateFunction: """Convert a Spark unresolved function into a Substrait scalar function.""" parent_processing_mode = self._expression_processing_mode @@ -533,8 +555,8 @@ def convert_unresolved_function( func = algebra_pb2.Expression.ScalarFunction() function_def = self.lookup_function_by_name(unresolved_function.function_name) if ( - parent_processing_mode == ExpressionProcessingMode.AGGR_NOT_TOP_LEVEL - and function_def.function_type == FunctionType.AGGREGATE + parent_processing_mode == ExpressionProcessingMode.AGGR_NOT_TOP_LEVEL + and function_def.function_type == FunctionType.AGGREGATE ): self._expression_processing_mode = ExpressionProcessingMode.AGGR_UNDER_AGGREGATE func.function_reference = function_def.anchor @@ -542,8 +564,8 @@ def convert_unresolved_function( if function_def.max_args is not None and idx >= function_def.max_args: break if ( - unresolved_function.function_name == "count" - and arg.WhichOneof("expr_type") == "unresolved_star" + unresolved_function.function_name == "count" + and arg.WhichOneof("expr_type") == "unresolved_star" ): # Ignore all the rest of the arguments. func.arguments.append(algebra_pb2.FunctionArgument(value=bigint_literal(1))) @@ -589,7 +611,7 @@ def convert_unresolved_function( self._expression_processing_mode = parent_processing_mode def convert_alias_expression( - self, alias: spark_exprs_pb2.Expression.Alias + self, alias: spark_exprs_pb2.Expression.Alias ) -> algebra_pb2.Expression: """Convert a Spark alias into a Substrait expression.""" # We do nothing here and let the magic happen in the calling project relation. 
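The struct conversion above produces a Substrait nested expression rather than a scalar function call. As a rough sketch of the message it builds, struct(col('a'), col('b')) with 'a' and 'b' at input positions 0 and 1 becomes the following; the field_reference helper here is a simplified stand-in for the one used throughout this file, and the indices are illustrative.

    from substrait.gen.proto import algebra_pb2

    def field_reference(index: int) -> algebra_pb2.Expression:
        # Simplified stand-in for the gateway's field_reference builder helper.
        return algebra_pb2.Expression(
            selection=algebra_pb2.Expression.FieldReference(
                direct_reference=algebra_pb2.Expression.ReferenceSegment(
                    struct_field=algebra_pb2.Expression.ReferenceSegment.StructField(field=index)
                ),
                root_reference=algebra_pb2.Expression.FieldReference.RootReference(),
            )
        )

    # struct(col('a'), col('b')) -> a nested struct expression over two field references.
    nested = algebra_pb2.Expression.Nested(nullable=False)
    nested.struct.fields.append(field_reference(0))
    nested.struct.fields.append(field_reference(1))
    expr = algebra_pb2.Expression(nested=nested)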
@@ -620,7 +642,7 @@ def convert_type(self, spark_type: spark_types_pb2.DataType) -> type_pb2.Type: return self.convert_type_str(spark_type.WhichOneof("kind")) def convert_cast_expression( - self, cast: spark_exprs_pb2.Expression.Cast + self, cast: spark_exprs_pb2.Expression.Cast ) -> algebra_pb2.Expression: """Convert a Spark cast expression into a Substrait cast expression.""" cast_rel = algebra_pb2.Expression.Cast( @@ -752,7 +774,7 @@ def convert_window_expression( return algebra_pb2.Expression(window_function=func) def convert_extract_value( - self, extract: spark_exprs_pb2.Expression.UnresolvedExtractValue + self, extract: spark_exprs_pb2.Expression.UnresolvedExtractValue ) -> algebra_pb2.Expression: """Convert a Spark extract value expression into a Substrait extract value expression.""" # TODO -- Add support for list and map operations. @@ -764,6 +786,7 @@ def convert_extract_value( value=self.convert_unresolved_attribute(extract.child.unresolved_attribute) ) ) + # MEGAHACK -- here # MEGAHACK -- Fix this to be relative to the struct value we found above. field_ref = self.find_field_by_name(extract.extraction.literal.string) if field_ref is None: @@ -772,7 +795,7 @@ def convert_extract_value( f"{self._current_plan_id}" ) func.arguments.append( - algebra_pb2.FunctionArgument(value=integer_literal(field_ref+1)) + algebra_pb2.FunctionArgument(value=integer_literal(field_ref + 1)) ) return algebra_pb2.Expression(scalar_function=func) @@ -861,21 +884,25 @@ def get_primary_names(self, schema: type_pb2.NamedStruct) -> list[str]: return primary_names def convert_read_named_table_relation( - self, rel: spark_relations_pb2.Read.NamedTable + self, rel: spark_relations_pb2.Read.NamedTable ) -> algebra_pb2.Rel: """Convert a read named table relation to a Substrait relation.""" table_name = rel.unparsed_identifier + assert self._backend is not None arrow_schema = self._backend.describe_table(table_name) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') if self._conversion_options.use_duckdb_struct_name_behavior: for field_name in self.get_primary_names(schema): - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) else: - symbol.output_fields.extend(schema.names) + for field_name in schema.names: + symbol.output_fields.append(Field(field_name)) return algebra_pb2.Rel( read=algebra_pb2.ReadRel( @@ -912,7 +939,7 @@ def convert_type_name(self, type_name: str) -> type_pb2.Type: case _: raise NotImplementedError(f"Unexpected type name: {type_name}") - def convert_field(self, field: types_pb2.DataType) -> (type_pb2.Type, list[str]): + def convert_field(self, field: types_pb2.DataType) -> tuple[type_pb2.Type, list[str]]: """Convert a Spark field into a Substrait field.""" if field.get("nullable"): nullability = type_pb2.Type.NULLABILITY_NULLABLE @@ -960,6 +987,8 @@ def convert_field(self, field: types_pb2.DataType) -> (type_pb2.Type, list[str]) elif ft.get("type") == "struct": field_type = type_pb2.Type(struct=type_pb2.Type.Struct(nullability=nullability)) sub_type = self.convert_schema_dict(ft) + if not sub_type: + raise ValueError(f'Could not parse type struct type {ft}') more_names.extend(sub_type.names) field_type.struct.types.extend(sub_type.struct.types) else: @@ -990,8 +1019,8 @@ def convert_schema(self, schema_str: str) -> type_pb2.NamedStruct | None: return 
self.convert_schema_dict(schema_data) def convert_arrow_datatype( - self, arrow_type: pa.DataType, nullable: bool = False - ) -> (type_pb2.Type, list[str]): + self, arrow_type: pa.DataType, nullable: bool = False + ) -> tuple[type_pb2.Type, list[str]]: """Convert an Arrow datatype into a Substrait type.""" if nullable: nullability = type_pb2.Type.NULLABILITY_NULLABLE @@ -1032,7 +1061,7 @@ def convert_arrow_datatype( sub_type = self.convert_arrow_schema(y.type.schema) more_names.extend(y.name) field_type.struct.types.extend(sub_type.struct.types) - return field_type + return field_type, more_names raise NotImplementedError(f"Unexpected arrow datatype: {arrow_type}") return field_type, more_names @@ -1125,8 +1154,10 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al arrow_schema = self._backend.describe_files(paths) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') for field_name in schema.names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) if self._conversion_options.use_named_table_workaround: return algebra_pb2.Rel( read=algebra_pb2.ReadRel( @@ -1195,6 +1226,8 @@ def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: if not self._conversion_options.use_emits_instead_of_direct: return algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct()) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') emit = algebra_pb2.RelCommon.Emit() if emit_overrides: for field_number in emit_overrides: @@ -1277,6 +1310,8 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge self.update_field_references(rel.input.common.plan_id) aggregate.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') # Start tracking the parts of the expressions we are interested in. self._top_level_projects = [] @@ -1352,7 +1387,7 @@ def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate, result = self.convert_expression(expr) if result: self._top_level_projects.append(result) - symbol.generated_fields.append(self.determine_expression_name(expr)) + symbol.generated_fields.append(Field(self.determine_expression_name(expr))) symbol.output_fields.clear() symbol.output_fields.extend(symbol.generated_fields) if len(rel.grouping_expressions) > 1: @@ -1458,6 +1493,8 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a # Now that we've processed the input, do the bookkeeping. self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') # Find the length of each column in every row. 
project1 = project_relation( @@ -1486,18 +1523,18 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a least_function( greater_func, field_reference(column_number), bigint_literal(rel.truncate) ), - strlen(strlen_func, string_literal(symbol.input_fields[column_number])), + strlen(strlen_func, string_literal(symbol.input_fields[column_number].name)), ) for column_number in range(len(symbol.input_fields)) ], ) - def field_header_fragment(field_number: int) -> list[algebra_pb2.Expression]: + def field_header_fragment(fields: list[Field], field_number: int) -> list[algebra_pb2.Expression]: return [ string_literal("|"), lpad_function( lpad_func, - string_literal(symbol.input_fields[field_number]), + string_literal(fields[field_number].name), field_reference(field_number), ), ] @@ -1508,7 +1545,7 @@ def field_line_fragment(field_number: int) -> list[algebra_pb2.Expression]: repeat_function(repeat_func, "-", field_reference(field_number)), ] - def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: + def field_body_fragment(fields: list[Field], field_number: int) -> list[algebra_pb2.Expression]: return [ string_literal("|"), if_then_else_operation( @@ -1518,7 +1555,7 @@ def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: strlen_func, cast_operation(field_reference(field_number), string_type()), ), - field_reference(field_number + len(symbol.input_fields)), + field_reference(field_number + len(fields)), ), concat( concat_func, @@ -1528,7 +1565,7 @@ def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: field_reference(field_number), minus_function( minus_func, - field_reference(field_number + len(symbol.input_fields)), + field_reference(field_number + len(fields)), bigint_literal(3), ), ), @@ -1538,17 +1575,17 @@ def field_body_fragment(field_number: int) -> list[algebra_pb2.Expression]: lpad_function( lpad_func, field_reference(field_number), - field_reference(field_number + len(symbol.input_fields)), + field_reference(field_number + len(fields)), ), ), ] - def header_line(fields: list[str]) -> list[algebra_pb2.Expression]: + def header_line(fields: list[Field]) -> list[algebra_pb2.Expression]: return [ concat( concat_func, flatten( - [field_header_fragment(field_number) for field_number in range(len(fields))] + [field_header_fragment(fields, field_number) for field_number in range(len(fields))] ) + [ string_literal("|\n"), @@ -1556,7 +1593,7 @@ def header_line(fields: list[str]) -> list[algebra_pb2.Expression]: ) ] - def full_line(fields: list[str]) -> list[algebra_pb2.Expression]: + def full_line(fields: list[Field]) -> list[algebra_pb2.Expression]: return [ concat( concat_func, @@ -1594,7 +1631,7 @@ def full_line(fields: list[str]) -> list[algebra_pb2.Expression]: concat_func, flatten( [ - field_body_fragment(field_number) + field_body_fragment(symbol.input_fields, field_number) for field_number in range(len(symbol.input_fields)) ] ) @@ -1614,8 +1651,10 @@ def full_line(fields: list[str]) -> list[algebra_pb2.Expression]: join2 = join_relation(project3, aggregate2) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') symbol.output_fields.clear() - symbol.output_fields.append("show_string") + symbol.output_fields.append(Field("show_string")) def compute_row_count_footer(num_rows: int) -> str: if num_rows == 1: @@ -1647,26 +1686,29 @@ def compute_row_count_footer(num_rows: 
int) -> str: return project5 def convert_with_columns_relation( - self, rel: spark_relations_pb2.WithColumns + self, rel: spark_relations_pb2.WithColumns ) -> algebra_pb2.Rel: """Convert a with columns relation into a Substrait project relation.""" input_rel = self.convert_relation(rel.input) project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') proposed_expressions = [field_reference(i) for i in range(len(symbol.input_fields))] for alias in rel.aliases: if len(alias.name) != 1: raise ValueError("Only one name part is supported in an alias.") name = alias.name[0] - if name in symbol.input_fields: - proposed_expressions[symbol.input_fields.index(name)] = self.convert_expression( - alias.expr - ) + field_num = _find_reference_number(symbol.input_fields, name) + if field_num is not None: + # Overwrite our the old expression with our own. + proposed_expressions[field_num] = self.convert_expression(alias.expr) else: + # This is a new column. proposed_expressions.append(self.convert_expression(alias.expr)) - symbol.generated_fields.append(name) - symbol.output_fields.append(name) + symbol.generated_fields.append(Field(name)) + symbol.output_fields.append(Field(name)) project.common.CopyFrom(self.create_common_relation()) project.expressions.extend(proposed_expressions) for i in range(len(proposed_expressions)): @@ -1674,22 +1716,24 @@ def convert_with_columns_relation( return algebra_pb2.Rel(project=project) def convert_with_columns_renamed_relation( - self, rel: spark_relations_pb2.WithColumnsRenamed + self, rel: spark_relations_pb2.WithColumnsRenamed ) -> algebra_pb2.Rel: """Update the columns names based on the Spark with columns renamed relation.""" input_rel = self.convert_relation(rel.input) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') self.update_field_references(rel.input.common.plan_id) symbol.output_fields.clear() if hasattr(rel, "renames"): aliases = {r.col_name: r.new_col_name for r in rel.renames} else: aliases = rel.rename_columns_map - for field_name in symbol.input_fields: - if field_name in aliases: - symbol.output_fields.append(aliases[field_name]) + for field in symbol.input_fields: + if field.name in aliases: + symbol.output_fields.append(field.alias(aliases[field.name])) else: - symbol.output_fields.append(field_name) + symbol.output_fields.append(field) return input_rel def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Rel: @@ -1698,14 +1742,16 @@ def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Re project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') if rel.columns: column_names = [c.unresolved_attribute.unparsed_identifier for c in rel.columns] else: column_names = rel.column_names symbol.output_fields.clear() - for field_number, field_name in enumerate(symbol.input_fields): - if field_name not in column_names: - symbol.output_fields.append(field_name) + for field_number, field in enumerate(symbol.input_fields): + if field.name 
not in column_names: + symbol.output_fields.append(field) if self._conversion_options.drop_emit_workaround: project.common.emit.output_mapping.append(len(project.expressions)) project.expressions.append(field_reference(field_number)) @@ -1720,6 +1766,8 @@ def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.R input_rel = self.convert_relation(rel.input) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') if len(rel.column_names) != len(symbol.input_fields): raise ValueError( "column_names does not match the number of input fields at " @@ -1727,7 +1775,7 @@ def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.R ) symbol.output_fields.clear() for field_name in rel.column_names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) return input_rel def convert_arrow_to_literal(self, val: pa.Scalar) -> algebra_pb2.Expression.Literal: @@ -1793,9 +1841,13 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge """Convert a Spark local relation into a virtual table.""" read = algebra_pb2.ReadRel(virtual_table=self.convert_arrow_data_to_virtual_table(rel.data)) schema = self.convert_schema(rel.schema) + if not schema: + raise ValueError(f'Received an empty schema in plan id {self._current_plan_id}') symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') for field_name in schema.names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) read.base_schema.CopyFrom(schema) read.common.CopyFrom(self.create_common_relation()) return algebra_pb2.Rel(read=read) @@ -1805,6 +1857,8 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: # TODO -- Handle multithreading in the case with a persistent backend. plan = self._sql_backend.convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') for field_name in plan.relations[0].root.names: symbol.output_fields.append(field_name) # TODO -- Correctly capture all the used functions and extensions. 
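Several of the conversions above (with_columns, with_columns_renamed, drop, to_df) now look columns up through the Field records introduced in symbol_table.py later in this patch instead of through plain strings. A small, self-contained illustration of that bookkeeping, using the struct column from the earlier test; the rename target name is made up.

    import dataclasses

    @dataclasses.dataclass
    class Field:
        # Simplified copy of the Field class added to gateway/converter/symbol_table.py.
        name: str
        child_names: list[str]

        def alias(self, name: str) -> "Field":
            # Renaming a column keeps the names of any nested struct fields.
            return Field(name, self.child_names)

        def output_names(self) -> list[str]:
            # convert_plan flattens these into the plan's root names, which
            # reapply_names later consumes to rename the Arrow result.
            return [self.name, *self.child_names]

    col = Field("test_struct", ["c_custkey", "c_name"])
    renamed = col.alias("customer_struct")  # e.g. withColumnsRenamed
    print(renamed.output_names())           # ['customer_struct', 'c_custkey', 'c_name']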
@@ -1815,7 +1869,7 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: return plan.relations[0].root.input def convert_spark_join_type( - self, join_type: spark_relations_pb2.Join.JoinType + self, join_type: spark_relations_pb2.Join.JoinType ) -> algebra_pb2.JoinRel.JoinType: """Convert a Spark join type into a Substrait join type.""" match join_type: @@ -1833,7 +1887,7 @@ def convert_spark_join_type( return algebra_pb2.JoinRel.JOIN_TYPE_ANTI case spark_relations_pb2.Join.JOIN_TYPE_LEFT_SEMI: return algebra_pb2.JoinRel.JOIN_TYPE_SEMI - case spark_relations_pb2.Join.CROSS: + case spark_relations_pb2.Join.JOIN_TYPE_CROSS: raise RuntimeError("Internal error: cross joins should be handled elsewhere") case _: raise ValueError(f"Unexpected join type: {join_type}") @@ -1868,12 +1922,20 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re else: equal_func = self.lookup_function_by_name("==") if rel.using_columns: - column = rel.using_columns[0] - left_fields = self._symbol_table.get_symbol(rel.left.common.plan_id).output_fields + column_name = rel.using_columns[0] + left_symbol = self._symbol_table.get_symbol(rel.left.common.plan_id) + if not left_symbol: + raise InternalError( + f'Could not find plan id {rel.left.common.plan_id} that we constructed earlier.') + left_fields = left_symbol.output_fields left_column_count = len(left_fields) - left_column_reference = left_fields.index(column) - right_fields = self._symbol_table.get_symbol(rel.right.common.plan_id).output_fields - right_column_reference = right_fields.index(column) + left_column_reference = _find_reference_number(left_fields, column_name) + right_symbol = self._symbol_table.get_symbol(rel.right.common.plan_id) + if not right_symbol: + raise InternalError( + f'Could not find plan id {rel.right.common.plan_id} that we constructed earlier.') + right_fields = right_symbol.output_fields + right_column_reference = _find_reference_number(right_fields, column_name) join.expression.CopyFrom( equal_function( equal_func, @@ -1884,6 +1946,8 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re # Avoid emitting the join column twice. symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') if self._conversion_options.join_not_honoring_emit_workaround: project = algebra_pb2.ProjectRel(input=algebra_pb2.Rel(join=join)) for column_number in range(len(symbol.output_fields)): @@ -1898,21 +1962,36 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re del symbol.output_fields[right_column_reference + left_column_count] return algebra_pb2.Rel(join=join) + def find_additional_names(self, expr: spark_exprs_pb2.Expression) -> list[str]: + if expr.WhichOneof("expr_type") == "alias": + current_expr = expr.alias.expr + else: + current_expr = expr + if current_expr.unresolved_function.function_name == "struct": + new_names = [] + for arg in current_expr.unresolved_function.arguments: + new_names.append(arg.unresolved_attribute.unparsed_identifier) + # TODO -- Handle nested structures. 
+ return new_names + return [] + def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_pb2.Rel: """Convert a Spark project relation into a Substrait project relation.""" input_rel = self.convert_relation(rel.input) project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') for field_number, expr in enumerate(rel.expressions): if expr.WhichOneof("expr_type") == "unresolved_regex": regex = expr.unresolved_regex.col_name.replace("`", "") matcher = re.compile(regex) found = False for column in symbol.input_fields: - if matcher.match(column): + if matcher.match(column.name): project.expressions.append( - field_reference(symbol.input_fields.index(column)) + field_reference(_find_reference_number(symbol.input_fields, column.name)) ) symbol.generated_fields.append(column) symbol.output_fields.append(column) @@ -1929,8 +2008,10 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ name = expr.unresolved_attribute.unparsed_identifier else: name = f"generated_field_{field_number}" - symbol.generated_fields.append(name) - symbol.output_fields.append(name) + projected_symbol = Field(name) + projected_symbol.child_names.extend(self.find_additional_names(expr)) + symbol.generated_fields.append(projected_symbol) + symbol.output_fields.append(projected_symbol) project.common.CopyFrom(self.create_common_relation()) symbol.output_fields = symbol.generated_fields for field_number in range(len(symbol.output_fields)): @@ -1938,7 +2019,7 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ return algebra_pb2.Rel(project=project) def convert_subquery_alias_relation( - self, rel: spark_relations_pb2.SubqueryAlias + self, rel: spark_relations_pb2.SubqueryAlias ) -> algebra_pb2.Rel: """Convert a Spark subquery alias relation into a Substrait relation.""" raise NotImplementedError("Subquery alias relations are not yet implemented") @@ -1954,6 +2035,8 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> self.update_field_references(rel.input.common.plan_id) aggregate.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') grouping = aggregate.groupings.add() for idx, field in enumerate(symbol.input_fields): grouping.grouping_expressions.append(field_reference(idx)) @@ -1981,7 +2064,7 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> return project def convert_set_operation_relation( - self, rel: spark_relations_pb2.SetOperation + self, rel: spark_relations_pb2.SetOperation ) -> algebra_pb2.Rel: """Convert a Spark set operation relation into a Substrait set operation relation.""" left = self.convert_relation(rel.left_input) @@ -2035,8 +2118,10 @@ def convert_dropna_relation(self, rel: spark_relations_pb2.NADrop) -> algebra_pb self.update_field_references(rel.input.common.plan_id) filter_rel.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') if rel.cols: - cols = [symbol.input_fields.index(col) for col in 
rel.cols] + cols = [_find_reference_number(symbol.input_fields, col) for col in rel.cols] else: cols = range(len(symbol.input_fields)) min_non_nulls = rel.min_non_nulls if rel.min_non_nulls else len(cols) @@ -2130,8 +2215,10 @@ def convert_plan(self, plan: spark_pb2.Plan) -> plan_pb2.Plan: if plan.HasField("root"): rel_root = algebra_pb2.RelRoot(input=self.convert_relation(plan.root)) symbol = self._symbol_table.get_symbol(plan.root.common.plan_id) - for name in symbol.output_fields: - rel_root.names.append(name) + if not symbol: + raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + for field in symbol.output_fields: + rel_root.names.extend(field.output_names()) result.relations.append(plan_pb2.PlanRel(root=rel_root)) for uri in sorted(self._function_uris.items(), key=operator.itemgetter(1)): result.extension_uris.append( diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index 704dd45..7db567a 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -4,16 +4,38 @@ import dataclasses +@dataclasses.dataclass +class Field: + """Tracks the names used by a field used as the input or output of a relation.""" + + name: str + # TODO -- Also track the field's type. + child_names: list[str] + + def __init__(self, name: str, child_names=None): + """Create the Field structure.""" + self.name = name + self.child_names = child_names or [] + + def alias(self, name: str): + new_field = Field(name) + new_field.child_names = self.child_names + return new_field + + def output_names(self) -> list[str]: + return [self.name] + self.child_names + + @dataclasses.dataclass class PlanMetadata: """Tracks various information about a specific plan id.""" plan_id: int - type: str | None + symbol_type: str | None parent_plan_id: int | None - input_fields: list[str] # And maybe type with additional names - generated_fields: list[str] - output_fields: list[str] + input_fields: list[Field] + generated_fields: list[Field] + output_fields: list[Field] def __init__(self, plan_id: int): """Create the PlanMetadata structure.""" diff --git a/src/transforms/output_field_tracking_visitor.py b/src/transforms/output_field_tracking_visitor.py index 8d78e0e..093bf77 100644 --- a/src/transforms/output_field_tracking_visitor.py +++ b/src/transforms/output_field_tracking_visitor.py @@ -6,6 +6,7 @@ from substrait.gen.proto import algebra_pb2, plan_pb2 from gateway.converter.symbol_table import SymbolTable +from src.gateway.converter.symbol_table import Field from substrait_visitors.substrait_plan_visitor import SubstraitPlanVisitor from transforms.label_relations import get_common_section @@ -48,8 +49,8 @@ def visit_read_relation(self, rel: algebra_pb2.ReadRel) -> Any: super().visit_read_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) # TODO -- Validate this logic where complicated data structures are used. 
- for field in rel.base_schema.names: - symbol.output_fields.append(field) + for field_name in rel.base_schema.names: + symbol.output_fields.append(Field(field_name)) def visit_filter_relation(self, rel: algebra_pb2.FilterRel) -> Any: """Collect the field references from the filter relation.""" @@ -66,9 +67,9 @@ def visit_aggregate_relation(self, rel: algebra_pb2.AggregateRel) -> Any: super().visit_aggregate_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) for _ in rel.groupings: - symbol.generated_fields.append("grouping") + symbol.generated_fields.append(Field("grouping")) for _ in rel.measures: - symbol.generated_fields.append("measure") + symbol.generated_fields.append(Field("measure")) self.update_field_references(get_plan_id(rel.input)) def visit_sort_relation(self, rel: algebra_pb2.SortRel) -> Any: @@ -81,7 +82,7 @@ def visit_project_relation(self, rel: algebra_pb2.ProjectRel) -> Any: super().visit_project_relation(rel) symbol = self._symbol_table.get_symbol(self._current_plan_id) for _ in rel.expressions: - symbol.generated_fields.append("intermediate") + symbol.generated_fields.append(Field("intermediate")) self.update_field_references(get_plan_id(rel.input)) def visit_extension_single_relation(self, rel: algebra_pb2.ExtensionSingleRel) -> Any: From 790c677d2310b314b3bdd5496694a6636f79941b Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 7 Oct 2024 22:44:52 -0700 Subject: [PATCH 12/18] another bug fix --- src/gateway/converter/spark_to_substrait.py | 2 +- src/gateway/tests/test_dataframe_api.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index d873c13..5ebb3cf 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -1359,7 +1359,7 @@ def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate, rel_grouping_expressions = rel.grouping_expressions for idx, grouping in enumerate(rel_grouping_expressions): grouping_expression_list.append(self.convert_expression(grouping)) - symbol.generated_fields.append(self.determine_name_for_grouping(grouping)) + symbol.generated_fields.append(Field(self.determine_name_for_grouping(grouping))) self._top_level_projects.append(field_reference(idx)) match rel.group_type: diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index 8731759..a7602b3 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -2850,7 +2850,6 @@ def test_cube(self, userage_dataframe): class TestDataFrameComplexDatastructures: """Tests the use of complex datastructures in the dataframe side of SparkConnect.""" - @pytest.mark.interesting def test_struct(self, register_tpch_dataset, spark_session, caplog): expected = [ Row(test_struct=Row(c_custkey=1, c_name='Customer#000000001')), @@ -2865,7 +2864,6 @@ def test_struct(self, register_tpch_dataset, spark_session, caplog): assertDataFrameEqual(outcome, expected) - @pytest.mark.interesting def test_struct_and_getfield(self, register_tpch_dataset, spark_session, caplog): expected = [ Row(result=1), From f86642237162f11dd9d691ff2a3ca9aa741cd587 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 7 Oct 2024 23:24:46 -0700 Subject: [PATCH 13/18] fixed inline SQL output fields --- src/gateway/converter/spark_to_substrait.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gateway/converter/spark_to_substrait.py 
b/src/gateway/converter/spark_to_substrait.py index 5ebb3cf..0f42789 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -1860,7 +1860,7 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: if not symbol: raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') for field_name in plan.relations[0].root.names: - symbol.output_fields.append(field_name) + symbol.output_fields.append(Field(field_name)) # TODO -- Correctly capture all the used functions and extensions. self._saved_extension_uris = plan.extension_uris self._saved_extensions = plan.extensions From c915d839c410f35a762fc1b67abf4b3f679e6527 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 7 Oct 2024 23:36:25 -0700 Subject: [PATCH 14/18] disabled getfield tests as we need to track names with types --- src/gateway/converter/spark_to_substrait.py | 15 +++++++-------- src/gateway/tests/test_dataframe_api.py | 3 +++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 0f42789..5d0293b 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -786,17 +786,16 @@ def convert_extract_value( value=self.convert_unresolved_attribute(extract.child.unresolved_attribute) ) ) - # MEGAHACK -- here - # MEGAHACK -- Fix this to be relative to the struct value we found above. field_ref = self.find_field_by_name(extract.extraction.literal.string) if field_ref is None: - raise ValueError( - f"could not locate field named {extract.extraction.literal.string} in plan id " - f"{self._current_plan_id}" + func.arguments.append( + algebra_pb2.FunctionArgument(value=string_literal(extract.extraction.literal.string)) + ) + else: + # TODO -- Fix this to be relative to the struct value we found above. 
+ func.arguments.append( + algebra_pb2.FunctionArgument(value=integer_literal(field_ref + 1)) ) - func.arguments.append( - algebra_pb2.FunctionArgument(value=integer_literal(field_ref + 1)) - ) return algebra_pb2.Expression(scalar_function=func) def convert_expression(self, expr: spark_exprs_pb2.Expression) -> algebra_pb2.Expression | None: diff --git a/src/gateway/tests/test_dataframe_api.py b/src/gateway/tests/test_dataframe_api.py index a7602b3..25b381a 100644 --- a/src/gateway/tests/test_dataframe_api.py +++ b/src/gateway/tests/test_dataframe_api.py @@ -160,6 +160,9 @@ def mark_dataframe_tests_as_xfail(request): if source == "gateway-over-duckdb" and originalname == "test_cube": pytest.skip(reason="cube aggregation not yet implemented in DuckDB") + if source == "gateway-over-duckdb" and originalname in ["test_column_getfield", + "test_struct_and_getfield"]: + pytest.skip(reason="fully named structs not yet tracked in gateway") if source == "gateway-over-datafusion" and originalname in ["test_struct", "test_struct_and_getfield"]: pytest.skip(reason="nested expressions not supported") From 9932c06fc24fd0c0c55e8bcd654d987b50b1c30b Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 7 Oct 2024 23:54:18 -0700 Subject: [PATCH 15/18] ruff --- src/gateway/converter/spark_to_substrait.py | 85 +++++++++++++-------- src/gateway/converter/symbol_table.py | 7 +- 2 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index 5d0293b..f248e66 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -61,13 +61,13 @@ class InternalError(Exception): - pass + """Raised when there is an internal gateway issue (should not occur).""" class ExpressionProcessingMode(Enum): """The mode of processing expressions.""" - # Processing of an expression outside of an aggregate relation. + # Processing of an expression outside an aggregate relation. NORMAL = 0 # Processing of a measure at depth 0. 
AGGR_TOP_LEVEL = 1 @@ -145,10 +145,11 @@ def update_field_references(self, plan_id: int) -> None: """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) if not source_symbol: - raise InternalError(f'Could not find plan id {plan_id} that we constructed earlier.') + raise InternalError(f'Could not find plan id {plan_id} constructed earlier.') current_symbol = self._symbol_table.get_symbol(self._current_plan_id) if not current_symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') original_output_fields = current_symbol.output_fields for symbol in source_symbol.output_fields: new_symbol = symbol @@ -161,7 +162,8 @@ def find_field_by_name(self, field_name: str) -> int | None: """Look up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) if not current_symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') try: return _find_reference_number(current_symbol.input_fields, field_name) except ValueError: @@ -895,7 +897,8 @@ def convert_read_named_table_relation( symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') if self._conversion_options.use_duckdb_struct_name_behavior: for field_name in self.get_primary_names(schema): symbol.output_fields.append(Field(field_name)) @@ -1154,7 +1157,8 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_name in schema.names: symbol.output_fields.append(Field(field_name)) if self._conversion_options.use_named_table_workaround: @@ -1226,7 +1230,8 @@ def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: return algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct()) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') emit = algebra_pb2.RelCommon.Emit() if emit_overrides: for field_number in emit_overrides: @@ -1310,7 +1315,8 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge aggregate.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') # Start tracking the parts of the expressions we are interested in. 
self._top_level_projects = [] @@ -1493,7 +1499,8 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') # Find the length of each column in every row. project1 = project_relation( @@ -1528,7 +1535,8 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a ], ) - def field_header_fragment(fields: list[Field], field_number: int) -> list[algebra_pb2.Expression]: + def field_header_fragment(fields: list[Field], + field_number: int) -> list[algebra_pb2.Expression]: return [ string_literal("|"), lpad_function( @@ -1544,7 +1552,8 @@ def field_line_fragment(field_number: int) -> list[algebra_pb2.Expression]: repeat_function(repeat_func, "-", field_reference(field_number)), ] - def field_body_fragment(fields: list[Field], field_number: int) -> list[algebra_pb2.Expression]: + def field_body_fragment(fields: list[Field], + field_number: int) -> list[algebra_pb2.Expression]: return [ string_literal("|"), if_then_else_operation( @@ -1584,7 +1593,8 @@ def header_line(fields: list[Field]) -> list[algebra_pb2.Expression]: concat( concat_func, flatten( - [field_header_fragment(fields, field_number) for field_number in range(len(fields))] + [field_header_fragment(fields, field_number) + for field_number in range(len(fields))] ) + [ string_literal("|\n"), @@ -1651,7 +1661,8 @@ def full_line(fields: list[Field]) -> list[algebra_pb2.Expression]: symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') symbol.output_fields.clear() symbol.output_fields.append(Field("show_string")) @@ -1693,7 +1704,8 @@ def convert_with_columns_relation( self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') proposed_expressions = [field_reference(i) for i in range(len(symbol.input_fields))] for alias in rel.aliases: if len(alias.name) != 1: @@ -1721,7 +1733,8 @@ def convert_with_columns_renamed_relation( input_rel = self.convert_relation(rel.input) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') self.update_field_references(rel.input.common.plan_id) symbol.output_fields.clear() if hasattr(rel, "renames"): @@ -1742,7 +1755,8 @@ def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Re self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') if 
rel.columns: column_names = [c.unresolved_attribute.unparsed_identifier for c in rel.columns] else: @@ -1766,7 +1780,8 @@ def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.R self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') if len(rel.column_names) != len(symbol.input_fields): raise ValueError( "column_names does not match the number of input fields at " @@ -1844,7 +1859,8 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge raise ValueError(f'Received an empty schema in plan id {self._current_plan_id}') symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_name in schema.names: symbol.output_fields.append(Field(field_name)) read.base_schema.CopyFrom(schema) @@ -1857,7 +1873,8 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: plan = self._sql_backend.convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_name in plan.relations[0].root.names: symbol.output_fields.append(Field(field_name)) # TODO -- Correctly capture all the used functions and extensions. @@ -1925,14 +1942,14 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re left_symbol = self._symbol_table.get_symbol(rel.left.common.plan_id) if not left_symbol: raise InternalError( - f'Could not find plan id {rel.left.common.plan_id} that we constructed earlier.') + f'Could not find plan id {rel.left.common.plan_id} constructed earlier.') left_fields = left_symbol.output_fields left_column_count = len(left_fields) left_column_reference = _find_reference_number(left_fields, column_name) right_symbol = self._symbol_table.get_symbol(rel.right.common.plan_id) if not right_symbol: raise InternalError( - f'Could not find plan id {rel.right.common.plan_id} that we constructed earlier.') + f'Could not find plan id {rel.right.common.plan_id} constructed earlier.') right_fields = right_symbol.output_fields right_column_reference = _find_reference_number(right_fields, column_name) join.expression.CopyFrom( @@ -1946,7 +1963,8 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re # Avoid emitting the join column twice. 
symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') if self._conversion_options.join_not_honoring_emit_workaround: project = algebra_pb2.ProjectRel(input=algebra_pb2.Rel(join=join)) for column_number in range(len(symbol.output_fields)): @@ -1962,10 +1980,8 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re return algebra_pb2.Rel(join=join) def find_additional_names(self, expr: spark_exprs_pb2.Expression) -> list[str]: - if expr.WhichOneof("expr_type") == "alias": - current_expr = expr.alias.expr - else: - current_expr = expr + """Find any extra names used by this expression.""" + current_expr = expr.alias.expr if expr.WhichOneof("expr_type") == "alias" else expr if current_expr.unresolved_function.function_name == "struct": new_names = [] for arg in current_expr.unresolved_function.arguments: @@ -1981,7 +1997,8 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_number, expr in enumerate(rel.expressions): if expr.WhichOneof("expr_type") == "unresolved_regex": regex = expr.unresolved_regex.col_name.replace("`", "") @@ -1990,7 +2007,8 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ for column in symbol.input_fields: if matcher.match(column.name): project.expressions.append( - field_reference(_find_reference_number(symbol.input_fields, column.name)) + field_reference( + _find_reference_number(symbol.input_fields, column.name)) ) symbol.generated_fields.append(column) symbol.output_fields.append(column) @@ -2035,7 +2053,8 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> aggregate.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') grouping = aggregate.groupings.add() for idx, field in enumerate(symbol.input_fields): grouping.grouping_expressions.append(field_reference(idx)) @@ -2118,7 +2137,8 @@ def convert_dropna_relation(self, rel: spark_relations_pb2.NADrop) -> algebra_pb filter_rel.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') if rel.cols: cols = [_find_reference_number(symbol.input_fields, col) for col in rel.cols] else: @@ -2215,7 +2235,8 @@ def convert_plan(self, plan: spark_pb2.Plan) -> plan_pb2.Plan: rel_root = algebra_pb2.RelRoot(input=self.convert_relation(plan.root)) symbol = self._symbol_table.get_symbol(plan.root.common.plan_id) if not symbol: - raise InternalError(f'Could not find plan id {self._current_plan_id} that we constructed earlier.') + raise 
InternalError( + f'Could not find plan id {self._current_plan_id} constructed earlier.') for field in symbol.output_fields: rel_root.names.extend(field.output_names()) result.relations.append(plan_pb2.PlanRel(root=rel_root)) diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index 7db567a..ed3293c 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -2,6 +2,7 @@ """Routines to convert SparkConnect plans to Substrait plans.""" import dataclasses +from typing import Self @dataclasses.dataclass @@ -17,13 +18,15 @@ def __init__(self, name: str, child_names=None): self.name = name self.child_names = child_names or [] - def alias(self, name: str): + def alias(self, name: str) -> Self: + """Create a copy with an alternate name.""" new_field = Field(name) new_field.child_names = self.child_names return new_field def output_names(self) -> list[str]: - return [self.name] + self.child_names + """Return all of the names used by this field (including subtypes).""" + return [self.name, *self.child_names] @dataclasses.dataclass From ce186ee36aea1536ce339581b7dc81c5bb71022c Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 8 Oct 2024 00:00:33 -0700 Subject: [PATCH 16/18] remove self typing --- src/gateway/converter/symbol_table.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index ed3293c..6b4e373 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -2,7 +2,6 @@ """Routines to convert SparkConnect plans to Substrait plans.""" import dataclasses -from typing import Self @dataclasses.dataclass @@ -18,7 +17,7 @@ def __init__(self, name: str, child_names=None): self.name = name self.child_names = child_names or [] - def alias(self, name: str) -> Self: + def alias(self, name: str): """Create a copy with an alternate name.""" new_field = Field(name) new_field.child_names = self.child_names From bb49c1979548a046aca345be9609023c2fb1f3f0 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 8 Oct 2024 21:07:52 -0700 Subject: [PATCH 17/18] changes from review --- src/backends/arrow_tools.py | 6 +++++ src/backends/tests/arrow_tools_test.py | 33 ++++++++++++++++++++++---- src/gateway/converter/symbol_table.py | 2 +- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/backends/arrow_tools.py b/src/backends/arrow_tools.py index 0871d52..39dc93a 100644 --- a/src/backends/arrow_tools.py +++ b/src/backends/arrow_tools.py @@ -41,6 +41,9 @@ def reapply_names(table: pa.Table, names: list[str]) -> pa.Table: remaining_names = names for column in iter(table.columns): + if not remaining_names: + raise ValueError('Insufficient number of names provided to reapply names.') + this_name = remaining_names.pop(0) new_array, remaining_names = _reapply_names_to_type(column, remaining_names) @@ -48,4 +51,7 @@ def reapply_names(table: pa.Table, names: list[str]) -> pa.Table: new_schema.append(pa.field(this_name, new_array.type)) + if remaining_names: + raise ValueError('Too many names provided to reapply names.') + return pa.Table.from_arrays(new_arrays, schema=pa.schema(new_schema)) diff --git a/src/backends/tests/arrow_tools_test.py b/src/backends/tests/arrow_tools_test.py index fb4a57c..e65cc1a 100644 --- a/src/backends/tests/arrow_tools_test.py +++ b/src/backends/tests/arrow_tools_test.py @@ -13,10 +13,13 @@ class TestCase: input: pa.Table names: list[str] expected: pa.table + fail: bool 
= False cases: list[TestCase] = [ - TestCase('empty table', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), + TestCase('empty table', pa.Table.from_arrays([]), [], pa.Table.from_arrays([])), + TestCase('too many names', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([]), + fail=True), TestCase('normal columns', pa.Table.from_pydict( {"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]}, @@ -28,6 +31,18 @@ class TestCase: "renamed_age": [99, None, 42, None]}, schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()}) )), + TestCase('too few names', + pa.Table.from_pydict( + {"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]}, + schema=pa.schema({"name": pa.string(), "age": pa.int32()}) + ), + ['renamed_name'], + pa.Table.from_pydict( + {"renamed_name": [None, "Joe", "Sarah", None], + "renamed_age": [99, None, 42, None]}, + schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()}) + ), + fail=True), TestCase('struct column', pa.Table.from_arrays( [pa.array([{"": 1, "b": "b"}], @@ -38,13 +53,14 @@ class TestCase: [pa.array([{"a": 1, "b": "b"}], type=pa.struct([("a", pa.int64()), ("b", pa.string())]))], names=["r"]) ), - TestCase('nested structs', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])), + # TODO -- Test nested structs. # TODO -- Test a list. # TODO -- Test a map. # TODO -- Test a mixture of complex and simple types. ] +@pytest.mark.interesting class TestArrowTools: """Tests the functionality of the arrow tools package.""" @@ -52,5 +68,14 @@ class TestArrowTools: "case", cases, ids=lambda case: case.name ) def test_reapply_names(self, case): - result = reapply_names(case.input, case.names) - assert result == case.expected + failed = False + try: + result = reapply_names(case.input, case.names) + except ValueError as _: + result = None + failed = True + if case.fail: + assert failed + else: + assert result == case.expected + diff --git a/src/gateway/converter/symbol_table.py b/src/gateway/converter/symbol_table.py index 6b4e373..184fb1b 100644 --- a/src/gateway/converter/symbol_table.py +++ b/src/gateway/converter/symbol_table.py @@ -24,7 +24,7 @@ def alias(self, name: str): return new_field def output_names(self) -> list[str]: - """Return all of the names used by this field (including subtypes).""" + """Return all the names used by this field (including subtypes).""" return [self.name, *self.child_names] From 92652f7f6554b35b8859c407fa678c42d7bacfbb Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 8 Oct 2024 21:14:08 -0700 Subject: [PATCH 18/18] removed checks that made pyright less unhappy --- src/backends/tests/arrow_tools_test.py | 1 - src/gateway/converter/spark_to_substrait.py | 69 --------------------- 2 files changed, 70 deletions(-) diff --git a/src/backends/tests/arrow_tools_test.py b/src/backends/tests/arrow_tools_test.py index e65cc1a..596822a 100644 --- a/src/backends/tests/arrow_tools_test.py +++ b/src/backends/tests/arrow_tools_test.py @@ -60,7 +60,6 @@ class TestCase: ] -@pytest.mark.interesting class TestArrowTools: """Tests the functionality of the arrow tools package.""" diff --git a/src/gateway/converter/spark_to_substrait.py b/src/gateway/converter/spark_to_substrait.py index f248e66..f067827 100644 --- a/src/gateway/converter/spark_to_substrait.py +++ b/src/gateway/converter/spark_to_substrait.py @@ -60,10 +60,6 @@ from gateway.converter.symbol_table import Field, SymbolTable -class InternalError(Exception): - """Raised when there is an internal gateway 
issue (should not occur).""" - - class ExpressionProcessingMode(Enum): """The mode of processing expressions.""" @@ -144,12 +140,7 @@ def lookup_function_by_name(self, name: str) -> ExtensionFunction: def update_field_references(self, plan_id: int) -> None: """Use the field references using the specified portion of the plan.""" source_symbol = self._symbol_table.get_symbol(plan_id) - if not source_symbol: - raise InternalError(f'Could not find plan id {plan_id} constructed earlier.') current_symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not current_symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') original_output_fields = current_symbol.output_fields for symbol in source_symbol.output_fields: new_symbol = symbol @@ -161,9 +152,6 @@ def update_field_references(self, plan_id: int) -> None: def find_field_by_name(self, field_name: str) -> int | None: """Look up the field name in the current set of field references.""" current_symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not current_symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') try: return _find_reference_number(current_symbol.input_fields, field_name) except ValueError: @@ -896,9 +884,6 @@ def convert_read_named_table_relation( schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') if self._conversion_options.use_duckdb_struct_name_behavior: for field_name in self.get_primary_names(schema): symbol.output_fields.append(Field(field_name)) @@ -1156,9 +1141,6 @@ def convert_read_data_source_relation(self, rel: spark_relations_pb2.Read) -> al arrow_schema = self._backend.describe_files(paths) schema = self.convert_arrow_schema(arrow_schema) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_name in schema.names: symbol.output_fields.append(Field(field_name)) if self._conversion_options.use_named_table_workaround: @@ -1229,9 +1211,6 @@ def create_common_relation(self, emit_overrides=None) -> algebra_pb2.RelCommon: if not self._conversion_options.use_emits_instead_of_direct: return algebra_pb2.RelCommon(direct=algebra_pb2.RelCommon.Direct()) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') emit = algebra_pb2.RelCommon.Emit() if emit_overrides: for field_number in emit_overrides: @@ -1314,9 +1293,6 @@ def convert_aggregate_relation(self, rel: spark_relations_pb2.Aggregate) -> alge self.update_field_references(rel.input.common.plan_id) aggregate.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') # Start tracking the parts of the expressions we are interested in. self._top_level_projects = [] @@ -1498,9 +1474,6 @@ def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> a # Now that we've processed the input, do the bookkeeping. 
self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') # Find the length of each column in every row. project1 = project_relation( @@ -1660,9 +1633,6 @@ def full_line(fields: list[Field]) -> list[algebra_pb2.Expression]: join2 = join_relation(project3, aggregate2) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') symbol.output_fields.clear() symbol.output_fields.append(Field("show_string")) @@ -1703,9 +1673,6 @@ def convert_with_columns_relation( project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') proposed_expressions = [field_reference(i) for i in range(len(symbol.input_fields))] for alias in rel.aliases: if len(alias.name) != 1: @@ -1732,9 +1699,6 @@ def convert_with_columns_renamed_relation( """Update the columns names based on the Spark with columns renamed relation.""" input_rel = self.convert_relation(rel.input) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') self.update_field_references(rel.input.common.plan_id) symbol.output_fields.clear() if hasattr(rel, "renames"): @@ -1754,9 +1718,6 @@ def convert_drop_relation(self, rel: spark_relations_pb2.Drop) -> algebra_pb2.Re project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') if rel.columns: column_names = [c.unresolved_attribute.unparsed_identifier for c in rel.columns] else: @@ -1779,9 +1740,6 @@ def convert_to_df_relation(self, rel: spark_relations_pb2.ToDF) -> algebra_pb2.R input_rel = self.convert_relation(rel.input) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') if len(rel.column_names) != len(symbol.input_fields): raise ValueError( "column_names does not match the number of input fields at " @@ -1858,9 +1816,6 @@ def convert_local_relation(self, rel: spark_relations_pb2.LocalRelation) -> alge if not schema: raise ValueError(f'Received an empty schema in plan id {self._current_plan_id}') symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_name in schema.names: symbol.output_fields.append(Field(field_name)) read.base_schema.CopyFrom(schema) @@ -1872,9 +1827,6 @@ def convert_sql_relation(self, rel: spark_relations_pb2.SQL) -> algebra_pb2.Rel: # TODO -- Handle multithreading in the case with a persistent backend. 
plan = self._sql_backend.convert_sql(rel.query) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_name in plan.relations[0].root.names: symbol.output_fields.append(Field(field_name)) # TODO -- Correctly capture all the used functions and extensions. @@ -1940,16 +1892,10 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re if rel.using_columns: column_name = rel.using_columns[0] left_symbol = self._symbol_table.get_symbol(rel.left.common.plan_id) - if not left_symbol: - raise InternalError( - f'Could not find plan id {rel.left.common.plan_id} constructed earlier.') left_fields = left_symbol.output_fields left_column_count = len(left_fields) left_column_reference = _find_reference_number(left_fields, column_name) right_symbol = self._symbol_table.get_symbol(rel.right.common.plan_id) - if not right_symbol: - raise InternalError( - f'Could not find plan id {rel.right.common.plan_id} constructed earlier.') right_fields = right_symbol.output_fields right_column_reference = _find_reference_number(right_fields, column_name) join.expression.CopyFrom( @@ -1962,9 +1908,6 @@ def convert_join_relation(self, rel: spark_relations_pb2.Join) -> algebra_pb2.Re # Avoid emitting the join column twice. symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') if self._conversion_options.join_not_honoring_emit_workaround: project = algebra_pb2.ProjectRel(input=algebra_pb2.Rel(join=join)) for column_number in range(len(symbol.output_fields)): @@ -1996,9 +1939,6 @@ def convert_project_relation(self, rel: spark_relations_pb2.Project) -> algebra_ project = algebra_pb2.ProjectRel(input=input_rel) self.update_field_references(rel.input.common.plan_id) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') for field_number, expr in enumerate(rel.expressions): if expr.WhichOneof("expr_type") == "unresolved_regex": regex = expr.unresolved_regex.col_name.replace("`", "") @@ -2052,9 +1992,6 @@ def convert_deduplicate_relation(self, rel: spark_relations_pb2.Deduplicate) -> self.update_field_references(rel.input.common.plan_id) aggregate.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') grouping = aggregate.groupings.add() for idx, field in enumerate(symbol.input_fields): grouping.grouping_expressions.append(field_reference(idx)) @@ -2136,9 +2073,6 @@ def convert_dropna_relation(self, rel: spark_relations_pb2.NADrop) -> algebra_pb self.update_field_references(rel.input.common.plan_id) filter_rel.common.CopyFrom(self.create_common_relation()) symbol = self._symbol_table.get_symbol(self._current_plan_id) - if not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') if rel.cols: cols = [_find_reference_number(symbol.input_fields, col) for col in rel.cols] else: @@ -2234,9 +2168,6 @@ def convert_plan(self, plan: spark_pb2.Plan) -> plan_pb2.Plan: if plan.HasField("root"): rel_root = algebra_pb2.RelRoot(input=self.convert_relation(plan.root)) symbol = self._symbol_table.get_symbol(plan.root.common.plan_id) - if 
not symbol: - raise InternalError( - f'Could not find plan id {self._current_plan_id} constructed earlier.') for field in symbol.output_fields: rel_root.names.extend(field.output_names()) result.relations.append(plan_pb2.PlanRel(root=rel_root))
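As a quick illustration of the bookkeeping these patches introduce — a minimal sketch, not part of any patch, assuming only the `Field` class from `src/gateway/converter/symbol_table.py` as defined above — a struct column produced by `struct(col('c_custkey'), col('c_name')).alias('test_struct')` is tracked as a single `Field` whose `child_names` carry the struct's member names, and `convert_plan` builds `RelRoot.names` by flattening `output_names()` across the output fields:

```python
# Import path matches the one used by spark_to_substrait.py in this series.
from gateway.converter.symbol_table import Field

# The project relation records the struct column as one Field; its child_names
# come from find_additional_names() walking the struct() arguments.
test_struct = Field("test_struct", child_names=["c_custkey", "c_name"])

# RelRoot.names receives the field's own name followed by its member names.
assert test_struct.output_names() == ["test_struct", "c_custkey", "c_name"]

# alias() swaps the top-level name but keeps the member names intact.
assert test_struct.alias("customer").output_names() == ["customer", "c_custkey", "c_name"]

# A plain column simply has an empty child_names list, so single-name
# behaviour for ordinary projections is unchanged.
assert Field("c_acctbal").output_names() == ["c_acctbal"]
```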