Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement basic struct handling #91

Merged
merged 18 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ or Velox.
### Locally
To run the gateway locally - you need to setup a Python (Conda) environment.

To run the Spark tests you will need Java installed.

Ensure you have [Miniconda](https://docs.anaconda.com/miniconda/miniconda-install/) and [Rust/Cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html) installed.

Once that is done - run these steps from a bash terminal:
```bash
git clone --recursive https://github.com/<your-fork>/spark-substrait-gateway.git
git clone https://github.com/<your-fork>/spark-substrait-gateway.git
cd spark-substrait-gateway
conda init bash
. ~/.bashrc
Expand Down
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- setuptools >= 61.0.0
- setuptools_scm >= 6.2.0
- mypy-protobuf
- types-protobuf >= 4.25.0, < 5.0.0
- types-protobuf >= 5.0.0
- numpy < 2.0.0
- Faker
- pip:
Expand All @@ -27,7 +27,7 @@ dependencies:
- substrait == 0.21.0
- substrait-validator
- pytest-timeout
- protobuf >= 4.25.3, < 5.0.0
- protobuf >= 5.0.0
- cryptography == 43.0.*
- click == 8.1.*
- pyjwt == 2.8.*
Expand Down
51 changes: 51 additions & 0 deletions src/backends/arrow_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
"""Routines to manipulate arrow tables."""
import pyarrow as pa


def _reapply_names_to_type(array: pa.ChunkedArray, names: list[str]) -> (pa.Array, list[str]):
new_arrays = []
new_schema = []

if array.type.num_fields > len(names):
raise ValueError('Insufficient number of names provided to reapply names.')

remaining_names = names
if pa.types.is_list(array.type):
raise NotImplementedError('Reapplying names to lists not yet supported')
if pa.types.is_map(array.type):
raise NotImplementedError('Reapplying names to maps not yet supported')
if pa.types.is_struct(array.type):
field_num = 0
while field_num < array.type.num_fields:
field = array.chunks[0].field(field_num)
this_name = remaining_names.pop(0)

new_array, remaining_names = _reapply_names_to_type(field, remaining_names)
new_arrays.append(new_array)

new_schema.append(pa.field(this_name, new_array.type))

field_num += 1

return pa.StructArray.from_arrays(new_arrays, fields=new_schema), remaining_names
if array.type.num_fields != 0:
raise ValueError(f'Unsupported complex type: {array.type}')
return array, remaining_names


def reapply_names(table: pa.Table, names: list[str]) -> pa.Table:
"""Apply the provided names to the given table recursively."""
new_arrays = []
new_schema = []

remaining_names = names
for column in iter(table.columns):
this_name = remaining_names.pop(0)
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved

new_array, remaining_names = _reapply_names_to_type(column, remaining_names)
new_arrays.append(new_array)

new_schema.append(pa.field(this_name, new_array.type))

return pa.Table.from_arrays(new_arrays, schema=pa.schema(new_schema))
4 changes: 3 additions & 1 deletion src/backends/duckdb_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from substrait.gen.proto import plan_pb2

from backends.backend import Backend
from src.backends.arrow_tools import reapply_names
from transforms.rename_functions import RenameFunctionsForDuckDB


Expand Down Expand Up @@ -73,7 +74,8 @@ def _execute_plan(self, plan: plan_pb2.Plan) -> pa.lib.Table:
query_result = self._connection.from_substrait(proto=plan_data)
except Exception as err:
raise ValueError(f"DuckDB Execution Error: {err}") from err
return query_result.arrow()
arrow = query_result.arrow()
return reapply_names(arrow, plan.relations[0].root.names)

def register_table(
self,
Expand Down
56 changes: 56 additions & 0 deletions src/backends/tests/arrow_tools_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import pyarrow as pa
import pytest

from src.backends.arrow_tools import reapply_names


@dataclass
class TestCase:
name: str
input: pa.Table
names: list[str]
expected: pa.table


cases: list[TestCase] = [
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
TestCase('empty table', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])),
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
TestCase('normal columns',
pa.Table.from_pydict(
{"name": [None, "Joe", "Sarah", None], "age": [99, None, 42, None]},
schema=pa.schema({"name": pa.string(), "age": pa.int32()})
),
['renamed_name', 'renamed_age'],
pa.Table.from_pydict(
{"renamed_name": [None, "Joe", "Sarah", None],
"renamed_age": [99, None, 42, None]},
schema=pa.schema({"renamed_name": pa.string(), "renamed_age": pa.int32()})
)),
TestCase('struct column',
pa.Table.from_arrays(
[pa.array([{"": 1, "b": "b"}],
type=pa.struct([("", pa.int64()), ("b", pa.string())]))],
names=["r"]),
['r', 'a', 'b'],
pa.Table.from_arrays(
[pa.array([{"a": 1, "b": "b"}],
type=pa.struct([("a", pa.int64()), ("b", pa.string())]))], names=["r"])
),
TestCase('nested structs', pa.Table.from_arrays([]), ['a', 'b'], pa.Table.from_arrays([])),
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
# TODO -- Test a list.
# TODO -- Test a map.
# TODO -- Test a mixture of complex and simple types.
]


class TestArrowTools:
"""Tests the functionality of the arrow tools package."""

@pytest.mark.parametrize(
"case", cases, ids=lambda case: case.name
)
def test_reapply_names(self, case):
result = reapply_names(case.input, case.names)
assert result == case.expected
2 changes: 1 addition & 1 deletion src/gateway/converter/conversion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ConversionOptions:
"""Holds all the possible conversion options."""

def __init__(self, backend: BackendOptions = None):
def __init__(self, backend: BackendOptions):
"""Initialize the conversion options."""
self.use_named_table_workaround = False
self.needs_scheme_in_path_uris = False
Expand Down
Loading
Loading