diff --git a/.gitattributes b/.gitattributes index 7ff0bbb6d959..bcdeffc09a11 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ +.github/ export-ignore datafusion/proto/src/generated/prost.rs linguist-generated datafusion/proto/src/generated/pbjson.rs linguist-generated diff --git a/.github/actions/setup-windows-builder/action.yaml b/.github/actions/setup-windows-builder/action.yaml index 9ab5c4a8b1bb..a26a34a3db93 100644 --- a/.github/actions/setup-windows-builder/action.yaml +++ b/.github/actions/setup-windows-builder/action.yaml @@ -38,8 +38,8 @@ runs: - name: Setup Rust toolchain shell: bash run: | - rustup update stable - rustup toolchain install stable + # Avoid self update to avoid CI failures: https://github.com/apache/arrow-datafusion/issues/9653 + rustup toolchain install stable --no-self-update rustup default stable rustup component add rustfmt - name: Configure rust runtime env diff --git a/.github/workflows/pr_benchmarks.yml b/.github/workflows/pr_benchmarks.yml index b7b85c9fcf14..5827c42e85ae 100644 --- a/.github/workflows/pr_benchmarks.yml +++ b/.github/workflows/pr_benchmarks.yml @@ -28,9 +28,10 @@ jobs: cd benchmarks mkdir data - # Setup the TPC-H data set with a scale factor of 10 + # Setup the TPC-H data sets for scale factors 1 and 10 ./bench.sh data tpch - + ./bench.sh data tpch10 + - name: Generate unique result names run: | echo "HEAD_LONG_SHA=$(git log -1 --format='%H')" >> "$GITHUB_ENV" @@ -44,6 +45,8 @@ jobs: cd benchmarks ./bench.sh run tpch + ./bench.sh run tpch_mem + ./bench.sh run tpch10 # For some reason this step doesn't seem to propagate the env var down into the script if [ -d "results/HEAD" ]; then @@ -64,6 +67,8 @@ jobs: cd benchmarks ./bench.sh run tpch + ./bench.sh run tpch_mem + ./bench.sh run tpch10 echo ${{ github.event.issue.number }} > pr diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 07c46351e9ac..6f6179fa52a2 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -65,42 +65,73 @@ jobs: # this key equals the ones on `linux-build-lib` for re-use key: cargo-cache-benchmark-${{ hashFiles('datafusion/**/Cargo.toml', 'benchmarks/Cargo.toml', 'datafusion-cli/Cargo.toml') }} - - name: Check workspace without default features + - name: Check datafusion without default features + # Some of the test binaries require the parquet feature still + #run: cargo check --all-targets --no-default-features -p datafusion run: cargo check --no-default-features -p datafusion - name: Check datafusion-common without default features - run: cargo check --tests --no-default-features -p datafusion-common + run: cargo check --all-targets --no-default-features -p datafusion-common + + - name: Check datafusion-functions + run: cargo check --all-targets --no-default-features -p datafusion-functions - name: Check workspace in debug mode run: cargo check - - name: Check workspace with all features + - name: Check workspace with avro,json features run: cargo check --workspace --benches --features avro,json + - name: Check Cargo.lock for datafusion-cli + run: | + # If this test fails, try running `cargo update` in the `datafusion-cli` directory + # and check in the updated Cargo.lock file. + cargo check --manifest-path datafusion-cli/Cargo.toml --locked + # Ensure that the datafusion crate can be built with only a subset of the function # packages enabled. 
- - name: Check function packages (array_expressions) + - name: Check datafusion (array_expressions) run: cargo check --no-default-features --features=array_expressions -p datafusion - - name: Check function packages (datetime_expressions) + - name: Check datafusion (crypto) + run: cargo check --no-default-features --features=crypto_expressions -p datafusion + + - name: Check datafusion (datetime_expressions) run: cargo check --no-default-features --features=datetime_expressions -p datafusion - - name: Check function packages (encoding_expressions) + - name: Check datafusion (encoding_expressions) run: cargo check --no-default-features --features=encoding_expressions -p datafusion - - name: Check function packages (math_expressions) + - name: Check datafusion (math_expressions) run: cargo check --no-default-features --features=math_expressions -p datafusion - - name: Check function packages (regex_expressions) + - name: Check datafusion (regex_expressions) run: cargo check --no-default-features --features=regex_expressions -p datafusion - - name: Check Cargo.lock for datafusion-cli - run: | - # If this test fails, try running `cargo update` in the `datafusion-cli` directory - # and check in the updated Cargo.lock file. - cargo check --manifest-path datafusion-cli/Cargo.toml --locked + - name: Check datafusion (string_expressions) + run: cargo check --no-default-features --features=string_expressions -p datafusion + + # Ensure that the datafusion-functions crate can be built with only a subset of the function + # packages enabled. + - name: Check datafusion-functions (crypto) + run: cargo check --all-targets --no-default-features --features=crypto_expressions -p datafusion-functions + + - name: Check datafusion-functions (datetime_expressions) + run: cargo check --all-targets --no-default-features --features=datetime_expressions -p datafusion-functions + + - name: Check datafusion-functions (encoding_expressions) + run: cargo check --all-targets --no-default-features --features=encoding_expressions -p datafusion-functions - # test the crate + - name: Check datafusion-functions (math_expressions) + run: cargo check --all-targets --no-default-features --features=math_expressions -p datafusion-functions + + - name: Check datafusion-functions (regex_expressions) + run: cargo check --all-targets --no-default-features --features=regex_expressions -p datafusion-functions + + - name: Check datafusion-functions (string_expressions) + run: cargo check --all-targets --no-default-features --features=string_expressions -p datafusion-functions + + # Run tests linux-test: name: cargo test (amd64) needs: [ linux-build-lib ] @@ -164,6 +195,25 @@ jobs: - name: Verify Working Directory Clean run: git diff --exit-code + depcheck: + name: circular dependency check + needs: [ linux-build-lib ] + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Check dependencies + run: | + cd dev/depcheck + cargo run + # Run `cargo test doc` (test documentation examples) linux-test-doc: name: cargo test doc (amd64) diff --git a/Cargo.toml b/Cargo.toml index 48e555bd5527..9df489724d46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,10 +16,10 @@ # under the License. 
[workspace] -exclude = ["datafusion-cli"] +exclude = ["datafusion-cli", "dev/depcheck"] members = [ "datafusion/common", - "datafusion/common_runtime", + "datafusion/common-runtime", "datafusion/core", "datafusion/expr", "datafusion/execution", @@ -49,7 +49,7 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" rust-version = "1.72" -version = "36.0.0" +version = "37.0.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -57,34 +57,34 @@ version = "36.0.0" # for the inherited dependency but cannot do the reverse (override from true to false). # # See for more detaiils: https://github.com/rust-lang/cargo/issues/11329 -arrow = { version = "50.0.0", features = ["prettyprint"] } -arrow-array = { version = "50.0.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "50.0.0", default-features = false } -arrow-flight = { version = "50.0.0", features = ["flight-sql-experimental"] } -arrow-ipc = { version = "50.0.0", default-features = false, features = ["lz4"] } -arrow-ord = { version = "50.0.0", default-features = false } -arrow-schema = { version = "50.0.0", default-features = false } -arrow-string = { version = "50.0.0", default-features = false } +arrow = { version = "51.0.0", features = ["prettyprint"] } +arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "51.0.0", default-features = false } +arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4"] } +arrow-ord = { version = "51.0.0", default-features = false } +arrow-schema = { version = "51.0.0", default-features = false } +arrow-string = { version = "51.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" chrono = { version = "0.4.34", default-features = false } ctor = "0.2.0" dashmap = "5.4.0" -datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } -datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common_runtime", version = "36.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } -datafusion-functions-array = { path = "datafusion/functions-array", version = "36.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "36.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "36.0.0", default-features = false } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "36.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "36.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "36.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "36.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "36.0.0" } +datafusion = { path = "datafusion/core", version = "37.0.0", default-features = false } +datafusion-common = { path = "datafusion/common", version = "37.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "37.0.0" } +datafusion-execution = { path = 
"datafusion/execution", version = "37.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "37.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "37.0.0" } +datafusion-functions-array = { path = "datafusion/functions-array", version = "37.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "37.0.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "37.0.0", default-features = false } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "37.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "37.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "37.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "37.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "37.0.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" @@ -93,9 +93,9 @@ indexmap = "2.0.0" itertools = "0.12" log = "^0.4" num_cpus = "1.13.0" -object_store = { version = "0.9.0", default-features = false } +object_store = { version = "0.9.1", default-features = false } parking_lot = "0.12" -parquet = { version = "50.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" rstest = "0.18.0" serde_json = "1" diff --git a/README.md b/README.md index abd727672aca..c3d7c6792990 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,11 @@ Optional features: [apache avro]: https://avro.apache.org/ [apache parquet]: https://parquet.apache.org/ -## Rust Version Compatibility +## Rust Version Compatibility Policy -Datafusion crate is tested with the [minimum required stable Rust version](https://github.com/search?q=repo%3Aapache%2Farrow-datafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support +each stable Rust version for 6 months after it is +[released](https://github.com/rust-lang/rust/blob/master/RELEASES.md). This +generally translates to support for the most recent 3 to 4 stable Rust versions. + +We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Farrow-datafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 039f4790acb0..a72400892752 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -314,7 +314,7 @@ run_tpch() { fi TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - RESULTS_FILE="${RESULTS_DIR}/tpch.json" + RESULTS_FILE="${RESULTS_DIR}/tpch_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch benchmark..." $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --format parquet -o ${RESULTS_FILE} @@ -329,7 +329,7 @@ run_tpch_mem() { fi TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - RESULTS_FILE="${RESULTS_DIR}/tpch_mem.json" + RESULTS_FILE="${RESULTS_DIR}/tpch_mem_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." 
# -m means in memory diff --git a/conbench/.flake8 b/conbench/.flake8 deleted file mode 100644 index e44b81084185..000000000000 --- a/conbench/.flake8 +++ /dev/null @@ -1,2 +0,0 @@ -[flake8] -ignore = E501 diff --git a/conbench/.gitignore b/conbench/.gitignore deleted file mode 100755 index aa44ee2adbd4..000000000000 --- a/conbench/.gitignore +++ /dev/null @@ -1,130 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - diff --git a/conbench/.isort.cfg b/conbench/.isort.cfg deleted file mode 100644 index f238bf7ea137..000000000000 --- a/conbench/.isort.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[settings] -profile = black diff --git a/conbench/README.md b/conbench/README.md deleted file mode 100644 index f655ac8bd297..000000000000 --- a/conbench/README.md +++ /dev/null @@ -1,252 +0,0 @@ - - -# DataFusion + Conbench Integration - - -## Quick start - -``` -$ cd ~/arrow-datafusion/conbench/ -$ conda create -y -n conbench python=3.9 -$ conda activate conbench -(conbench) $ pip install -r requirements.txt -(conbench) $ conbench datafusion -``` - -## Example output - -``` -{ - "batch_id": "3c82f9d23fce49328b78ba9fd963b254", - "context": { - "benchmark_language": "Rust" - }, - "github": { - "commit": "e8c198b9fac6cd8822b950b9f71898e47965488d", - "repository": "https://github.com/dianaclarke/arrow-datafusion" - }, - "info": {}, - "machine_info": { - "architecture_name": "x86_64", - "cpu_core_count": "8", - "cpu_frequency_max_hz": "2400000000", - "cpu_l1d_cache_bytes": "65536", - "cpu_l1i_cache_bytes": "131072", - "cpu_l2_cache_bytes": "4194304", - "cpu_l3_cache_bytes": "0", - "cpu_model_name": "Apple M1", - "cpu_thread_count": "8", - "gpu_count": "0", - "gpu_product_names": [], - "kernel_name": "20.6.0", - "memory_bytes": "17179869184", - "name": "diana", - "os_name": "macOS", - "os_version": "10.16" - }, - "run_id": "ec2a50b9380c470b96d7eb7d63ab5b77", - "stats": { - "data": [ - "0.001532", - "0.001394", - "0.001333", - "0.001356", - "0.001379", - "0.001361", - "0.001307", - "0.001348", - "0.001436", - "0.001397", - "0.001339", - "0.001523", - "0.001593", - "0.001415", - "0.001344", - "0.001312", - "0.001402", - "0.001362", - "0.001329", - "0.001330", - "0.001447", - "0.001413", - "0.001536", - "0.001330", - "0.001333", - "0.001338", - "0.001333", - "0.001331", - "0.001426", - "0.001575", - "0.001362", - "0.001343", - "0.001334", - "0.001383", - "0.001476", - "0.001356", - "0.001362", - "0.001334", - "0.001390", - "0.001497", - "0.001330", - "0.001347", - "0.001331", - "0.001468", - "0.001377", - "0.001351", - "0.001328", - "0.001509", - "0.001338", - "0.001355", - "0.001332", - "0.001485", - "0.001370", - "0.001366", - "0.001507", - "0.001358", - "0.001331", - "0.001463", - "0.001362", - "0.001336", - "0.001428", - "0.001343", - "0.001359", - "0.001905", - "0.001726", - "0.001411", - "0.001433", - "0.001391", - "0.001453", - "0.001346", - "0.001339", - "0.001420", - "0.001330", - "0.001422", - "0.001683", - "0.001426", - "0.001349", - "0.001342", - "0.001430", - "0.001330", - "0.001436", - "0.001331", - "0.001415", - "0.001332", - "0.001408", - "0.001343", - "0.001392", - "0.001371", - "0.001655", - "0.001354", - "0.001438", - "0.001347", - "0.001341", - "0.001374", - "0.001453", - "0.001352", - "0.001358", - "0.001398", - "0.001362", - "0.001454" - ], - "iqr": "0.000088", - "iterations": 100, - "max": "0.001905", - "mean": "0.001401", - "median": "0.001362", - "min": "0.001307", - "q1": "0.001340", - "q3": "0.001428", - "stdev": "0.000095", - "time_unit": "s", - "times": [], - "unit": "s" - }, - 
"tags": { - "name": "aggregate_query_group_by", - "suite": "aggregate_query_group_by" - }, - "timestamp": "2022-02-09T01:32:55.769468+00:00" -} -``` - -## Debug with test benchmark - -``` -(conbench) $ cd ~/arrow-datafusion/conbench/ -(conbench) $ conbench test --iterations=3 - -Benchmark result: -{ - "batch_id": "41a144761bc24d82b94efa70d6e460b3", - "context": { - "benchmark_language": "Python" - }, - "github": { - "commit": "e8c198b9fac6cd8822b950b9f71898e47965488d", - "repository": "https://github.com/dianaclarke/arrow-datafusion" - }, - "info": { - "benchmark_language_version": "Python 3.9.7" - }, - "machine_info": { - "architecture_name": "x86_64", - "cpu_core_count": "8", - "cpu_frequency_max_hz": "2400000000", - "cpu_l1d_cache_bytes": "65536", - "cpu_l1i_cache_bytes": "131072", - "cpu_l2_cache_bytes": "4194304", - "cpu_l3_cache_bytes": "0", - "cpu_model_name": "Apple M1", - "cpu_thread_count": "8", - "gpu_count": "0", - "gpu_product_names": [], - "kernel_name": "20.6.0", - "memory_bytes": "17179869184", - "name": "diana", - "os_name": "macOS", - "os_version": "10.16" - }, - "run_id": "71f46362db8844afacea82cba119cefc", - "stats": { - "data": [ - "0.000001", - "0.000001", - "0.000000" - ], - "iqr": "0.000000", - "iterations": 3, - "max": "0.000001", - "mean": "0.000001", - "median": "0.000001", - "min": "0.000000", - "q1": "0.000000", - "q3": "0.000001", - "stdev": "0.000001", - "time_unit": "s", - "times": [], - "unit": "s" - }, - "tags": { - "name": "test" - }, - "timestamp": "2022-02-09T01:36:45.823615+00:00" -} -``` - diff --git a/conbench/_criterion.py b/conbench/_criterion.py deleted file mode 100644 index 168a1b9b6cb1..000000000000 --- a/conbench/_criterion.py +++ /dev/null @@ -1,98 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import collections -import csv -import os -import pathlib -import subprocess - -import conbench.runner -from conbench.machine_info import github_info - - -def _result_in_seconds(row): - # sample_measured_value - The value of the measurement for this sample. 
- # Note that this is the measured value for the whole sample, not the - # time-per-iteration To calculate the time-per-iteration, use - # sample_measured_value/iteration_count - # -- https://bheisler.github.io/criterion.rs/book/user_guide/csv_output.html - count = int(row["iteration_count"]) - sample = float(row["sample_measured_value"]) - return sample / count / 10**9 - - -def _parse_benchmark_group(row): - parts = row["group"].split(",") - if len(parts) > 1: - suite, name = parts[0], ",".join(parts[1:]) - else: - suite, name = row["group"], row["group"] - return suite.strip(), name.strip() - - -def _read_results(src_dir): - results = collections.defaultdict(lambda: collections.defaultdict(list)) - path = pathlib.Path(os.path.join(src_dir, "target", "criterion")) - for path in list(path.glob("**/new/raw.csv")): - with open(path) as csv_file: - reader = csv.DictReader(csv_file) - for row in reader: - suite, name = _parse_benchmark_group(row) - results[suite][name].append(_result_in_seconds(row)) - return results - - -def _execute_command(command): - try: - print(command) - result = subprocess.run(command, capture_output=True, check=True) - except subprocess.CalledProcessError as e: - print(e.stderr.decode("utf-8")) - raise e - return result.stdout.decode("utf-8"), result.stderr.decode("utf-8") - - -class CriterionBenchmark(conbench.runner.Benchmark): - external = True - - def run(self, **kwargs): - src_dir = os.path.join(os.getcwd(), "..") - self._cargo_bench(src_dir) - results = _read_results(src_dir) - for suite in results: - self.conbench.mark_new_batch() - for name, data in results[suite].items(): - yield self._record_result(suite, name, data, kwargs) - - def _cargo_bench(self, src_dir): - os.chdir(src_dir) - _execute_command(["cargo", "bench"]) - - def _record_result(self, suite, name, data, options): - tags = {"suite": suite} - result = {"data": data, "unit": "s"} - context = {"benchmark_language": "Rust"} - github = github_info() - return self.conbench.record( - result, - name, - tags=tags, - context=context, - github=github, - options=options, - ) diff --git a/conbench/benchmarks.json b/conbench/benchmarks.json deleted file mode 100644 index bb7033547722..000000000000 --- a/conbench/benchmarks.json +++ /dev/null @@ -1,8 +0,0 @@ -[ - { - "command": "datafusion", - "flags": { - "language": "Rust" - } - } -] diff --git a/conbench/requirements-test.txt b/conbench/requirements-test.txt deleted file mode 100644 index 5e5647acd2d6..000000000000 --- a/conbench/requirements-test.txt +++ /dev/null @@ -1,3 +0,0 @@ -black -flake8 -isort diff --git a/conbench/requirements.txt b/conbench/requirements.txt deleted file mode 100644 index a877c7b44e9b..000000000000 --- a/conbench/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -conbench diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index deda497d9dd3..3be92221d3ee 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -39,9 +39,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"aa285343fba4d829d49985bdc541e3789cf6000ed0e84be7c039438df4a4e78c" +checksum = "219d05930b81663fd3b32e3bde8ce5bff3c4d23052a99f11a8fa50a3b47b2658" dependencies = [ "arrow-arith", "arrow-array", @@ -151,9 +151,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "753abd0a5290c1bcade7c6623a556f7d1659c5f4148b140b5b63ce7bd1a45705" +checksum = "0272150200c07a86a390be651abdd320a2d12e84535f0837566ca87ecd8f95e0" dependencies = [ "arrow-array", "arrow-buffer", @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d390feeb7f21b78ec997a4081a025baef1e2e0d6069e181939b61864c9779609" +checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" dependencies = [ "ahash", "arrow-buffer", @@ -183,9 +183,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69615b061701bcdffbc62756bc7e85c827d5290b472b580c972ebbbf690f5aa4" +checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" dependencies = [ "bytes", "half", @@ -194,28 +194,30 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e448e5dd2f4113bf5b74a1f26531708f5edcacc77335b7066f9398f4bcf4cdef" +checksum = "9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "base64 0.21.7", + "atoi", + "base64 0.22.0", "chrono", "comfy-table", "half", "lexical-core", "num", + "ryu", ] [[package]] name = "arrow-csv" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46af72211f0712612f5b18325530b9ad1bfbdc87290d5fbfd32a7da128983781" +checksum = "95cbcba196b862270bf2a5edb75927380a7f3a163622c61d40cbba416a6305f2" dependencies = [ "arrow-array", "arrow-buffer", @@ -232,9 +234,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67d644b91a162f3ad3135ce1184d0a31c28b816a581e08f29e8e9277a574c64e" +checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -244,9 +246,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03dea5e79b48de6c2e04f03f62b0afea7105be7b77d134f6c5414868feefb80d" +checksum = "a42ea853130f7e78b9b9d178cb4cd01dee0f78e64d96c2949dc0a915d6d9e19d" dependencies = [ "arrow-array", "arrow-buffer", @@ -259,9 +261,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8950719280397a47d37ac01492e3506a8a724b3fb81001900b866637a829ee0f" +checksum = "eaafb5714d4e59feae964714d724f880511500e3569cc2a94d02456b403a2a49" dependencies = [ "arrow-array", "arrow-buffer", @@ -270,7 +272,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.5", + "indexmap 2.2.6", "lexical-core", "num", "serde", @@ -279,9 +281,9 @@ dependencies = [ [[package]] name = "arrow-ord" 
-version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ed9630979034077982d8e74a942b7ac228f33dd93a93b615b4d02ad60c260be" +checksum = "e3e6b61e3dc468f503181dccc2fc705bdcc5f2f146755fa5b56d0a6c5943f412" dependencies = [ "arrow-array", "arrow-buffer", @@ -294,9 +296,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "007035e17ae09c4e8993e4cb8b5b96edf0afb927cd38e2dff27189b274d83dcf" +checksum = "848ee52bb92eb459b811fb471175ea3afcf620157674c8794f539838920f9228" dependencies = [ "ahash", "arrow-array", @@ -309,15 +311,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ff3e9c01f7cd169379d269f926892d0e622a704960350d09d331be3ec9e0029" +checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" [[package]] name = "arrow-select" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ce20973c1912de6514348e064829e50947e35977bb9d7fb637dc99ea9ffd78c" +checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" dependencies = [ "ahash", "arrow-array", @@ -329,15 +331,16 @@ dependencies = [ [[package]] name = "arrow-string" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f3b37f2aeece31a2636d1b037dabb69ef590e03bdc7eb68519b51ec86932a7" +checksum = "9373cb5a021aee58863498c37eb484998ef13377f69989c6c5ccfbd258236cdb" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "memchr", "num", "regex", "regex-syntax", @@ -378,13 +381,22 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", ] [[package]] @@ -400,9 +412,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] name = "aws-config" @@ -696,9 +708,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", @@ -739,9 +751,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = 
"cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "blake2" @@ -776,9 +788,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -820,9 +832,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "bytes-utils" @@ -873,9 +885,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.35" +version = "0.4.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a" +checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1080,7 +1092,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad291aa74992b9b7a7e88c38acbbf6ad7e107f1d90ee8775b7bc1fc3394f485c" dependencies = [ "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -1104,7 +1116,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "apache-avro", @@ -1133,7 +1145,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.3", - "indexmap 2.2.5", + "indexmap 2.2.6", "itertools", "log", "num-traits", @@ -1155,7 +1167,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "assert_cmd", @@ -1183,7 +1195,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "apache-avro", @@ -1203,14 +1215,14 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "36.0.0" +version = "37.0.0" dependencies = [ "tokio", ] [[package]] name = "datafusion-execution" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "chrono", @@ -1229,7 +1241,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "arrow", @@ -1237,6 +1249,7 @@ dependencies = [ "chrono", "datafusion-common", "paste", + "serde_json", "sqlparser", "strum 0.26.2", "strum_macros 0.26.2", @@ -1244,7 +1257,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "base64 0.22.0", @@ -1255,21 +1268,25 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "hashbrown 0.14.3", "hex", "itertools", "log", "md-5", "regex", "sha2", + "unicode-segmentation", + "uuid", ] [[package]] name = "datafusion-functions-array" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", "datafusion-common", "datafusion-execution", @@ -1282,7 +1299,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "async-trait", @@ -1298,7 +1315,7 @@ 
dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "arrow", @@ -1317,7 +1334,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "hex", - "indexmap 2.2.5", + "indexmap 2.2.6", "itertools", "log", "md-5", @@ -1326,13 +1343,11 @@ dependencies = [ "rand", "regex", "sha2", - "unicode-segmentation", - "uuid", ] [[package]] name = "datafusion-physical-plan" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "arrow", @@ -1349,7 +1364,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.3", - "indexmap 2.2.5", + "indexmap 2.2.6", "itertools", "log", "once_cell", @@ -1357,12 +1372,11 @@ dependencies = [ "pin-project-lite", "rand", "tokio", - "uuid", ] [[package]] name = "datafusion-sql" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "arrow-array", @@ -1518,9 +1532,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.0.1" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" [[package]] name = "fd-lock" @@ -1639,7 +1653,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -1713,9 +1727,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb" dependencies = [ "bytes", "fnv", @@ -1723,7 +1737,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 2.2.5", + "indexmap 2.2.6", "slab", "tokio", "tokio-util", @@ -1940,9 +1954,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.5" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -1983,9 +1997,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" @@ -2127,7 +2141,7 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "libc", "redox_syscall", ] @@ -2441,9 +2455,9 @@ dependencies = [ [[package]] name = "parquet" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "547b92ebf0c1177e3892f44c8f79757ee62e678d564a9834189725f2c5b7a750" +checksum = "096795d4f47f65fd3ee1ec5a98b77ab26d602f2cc785b0e4be5443add17ecc32" dependencies = [ "ahash", "arrow-array", @@ -2453,7 +2467,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "arrow-select", - "base64 0.21.7", + "base64 0.22.0", "brotli", "bytes", "chrono", @@ -2502,7 +2516,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 2.2.5", + "indexmap 2.2.6", ] [[package]] @@ -2560,7 +2574,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -2743,9 +2757,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -2772,15 +2786,15 @@ checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" [[package]] name = "reqwest" -version = "0.11.26" +version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ "base64 0.21.7", "bytes", @@ -2898,11 +2912,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.31" +version = "0.38.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", @@ -2966,9 +2980,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8" +checksum = "868e20fada228fefaf6b652e00cc73623d54f8171e7352c18bb281571f2d92da" [[package]] name = "rustls-webpki" @@ -3101,14 +3115,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.115" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" dependencies = [ "itoa", "ryu", @@ -3164,9 +3178,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "snafu" @@ -3236,7 +3250,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3282,7 +3296,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3295,7 +3309,7 @@ dependencies = [ "proc-macro2", "quote", 
"rustversion", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3317,9 +3331,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0" dependencies = [ "proc-macro2", "quote", @@ -3360,7 +3374,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" dependencies = [ "cfg-if", - "fastrand 2.0.1", + "fastrand 2.0.2", "rustix", "windows-sys 0.52.0", ] @@ -3403,7 +3417,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3498,7 +3512,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3595,7 +3609,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3640,7 +3654,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3719,9 +3733,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ "getrandom", "serde", @@ -3794,7 +3808,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", "wasm-bindgen-shared", ] @@ -3828,7 +3842,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4086,7 +4100,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index ad506762f0d0..18e14357314e 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." 
-version = "36.0.0" +version = "37.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -30,12 +30,12 @@ rust-version = "1.72" readme = "README.md" [dependencies] -arrow = "50.0.0" +arrow = "51.0.0" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "36.0.0", features = [ +datafusion = { path = "../datafusion/core", version = "37.0.0", features = [ "avro", "crypto_expressions", "datetime_expressions", @@ -52,7 +52,7 @@ futures = "0.3" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.9.0", features = ["aws", "gcp", "http"] } parking_lot = { version = "0.12" } -parquet = { version = "50.0.0", default-features = false } +parquet = { version = "51.0.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index a8ecb98637cb..0fbb7a5908b5 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -177,7 +177,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { // Register the store for this URL. Here we don't have access // to any command options so the only choice is to use an empty collection match scheme { - "s3" | "oss" => { + "s3" | "oss" | "cos" => { state = state.add_table_options_extension(AwsOptions::default()); } "gs" | "gcs" => { @@ -189,7 +189,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { &state, table_url.scheme(), url, - state.default_table_options(), + &state.default_table_options(), ) .await?; state.runtime_env().register_object_store(url, store); diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index b11f1c202284..53375ab4104f 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -22,6 +22,7 @@ use std::fs::File; use std::io::prelude::*; use std::io::BufReader; +use crate::helper::split_from_semicolon; use crate::print_format::PrintFormat; use crate::{ command::{Command, OutputFormat}, @@ -40,6 +41,7 @@ use datafusion::prelude::SessionContext; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; +use datafusion_common::FileType; use rustyline::error::ReadlineError; use rustyline::Editor; use tokio::signal; @@ -163,21 +165,24 @@ pub async fn exec_from_repl( } } Ok(line) => { - rl.add_history_entry(line.trim_end())?; - tokio::select! { - res = exec_and_print(ctx, print_options, line) => match res { - Ok(_) => {} - Err(err) => eprintln!("{err}"), - }, - _ = signal::ctrl_c() => { - println!("^C"); - continue - }, + let lines = split_from_semicolon(line); + for line in lines { + rl.add_history_entry(line.trim_end())?; + tokio::select! 
{ + res = exec_and_print(ctx, print_options, line) => match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + _ = signal::ctrl_c() => { + println!("^C"); + continue + }, + } + // dialect might have changed + rl.helper_mut().unwrap().set_dialect( + &ctx.task_ctx().session_config().options().sql_parser.dialect, + ); } - // dialect might have changed - rl.helper_mut().unwrap().set_dialect( - &ctx.task_ctx().session_config().options().sql_parser.dialect, - ); } Err(ReadlineError::Interrupted) => { println!("^C"); @@ -257,15 +262,23 @@ async fn create_plan( // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion // will raise Configuration errors. if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { - register_object_store_and_config_extensions(ctx, &cmd.location, &cmd.options) - .await?; + register_object_store_and_config_extensions( + ctx, + &cmd.location, + &cmd.options, + None, + ) + .await?; } if let LogicalPlan::Copy(copy_to) = &mut plan { + let format: FileType = (©_to.format_options).into(); + register_object_store_and_config_extensions( ctx, ©_to.output_url, ©_to.options, + Some(format), ) .await?; } @@ -303,6 +316,7 @@ pub(crate) async fn register_object_store_and_config_extensions( ctx: &SessionContext, location: &String, options: &HashMap, + format: Option, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -318,6 +332,9 @@ pub(crate) async fn register_object_store_and_config_extensions( // Clone and modify the default table options based on the provided options let mut table_options = ctx.state().default_table_options().clone(); + if let Some(format) = format { + table_options.set_file_format(format); + } table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options @@ -347,6 +364,7 @@ mod tests { &ctx, &cmd.location, &cmd.options, + None, ) .await?; } else { @@ -367,10 +385,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Copy(cmd) = &plan { + let format: FileType = (&cmd.format_options).into(); register_object_store_and_config_extensions( &ctx, &cmd.output_url, &cmd.options, + Some(format), ) .await?; } else { @@ -399,6 +419,7 @@ mod tests { let locations = vec![ "s3://bucket/path/file.parquet", "oss://bucket/path/file.parquet", + "cos://bucket/path/file.parquet", "gcs://bucket/path/file.parquet", ]; let mut ctx = SessionContext::new(); @@ -412,7 +433,7 @@ mod tests { ) })?; for location in locations { - let sql = format!("copy (values (1,2)) to '{}';", location); + let sql = format!("copy (values (1,2)) to '{}' STORED AS PARQUET;", location); let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail @@ -438,8 +459,8 @@ mod tests { let location = "s3://bucket/path/file.parquet"; // Missing region, use object_store defaults - let sql = format!("COPY (values (1,2)) TO '{location}' - (format parquet, 'aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')"); + let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET + OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')"); copy_to_table_test(location, &sql).await?; Ok(()) @@ -481,6 +502,21 @@ mod tests { Ok(()) } + #[tokio::test] + async fn create_object_store_table_cos() -> Result<()> { + let 
access_key_id = "fake_access_key_id"; + let secret_access_key = "fake_secret_access_key"; + let endpoint = "fake_endpoint"; + let location = "cos://bucket/path/file.parquet"; + + // Should be OK + let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET + OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'"); + create_external_table_test(location, &sql).await?; + + Ok(()) + } + #[tokio::test] async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index a8e149b4c5c6..8b196484ee2c 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -86,16 +86,23 @@ impl CliHelper { )))) } }; - - match DFParser::parse_sql_with_dialect(&sql, dialect.as_ref()) { - Ok(statements) if statements.is_empty() => Ok(ValidationResult::Invalid( - Some(" 🤔 You entered an empty statement".to_string()), - )), - Ok(_statements) => Ok(ValidationResult::Valid(None)), - Err(err) => Ok(ValidationResult::Invalid(Some(format!( - " 🤔 Invalid statement: {err}", - )))), + let lines = split_from_semicolon(sql); + for line in lines { + match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) { + Ok(statements) if statements.is_empty() => { + return Ok(ValidationResult::Invalid(Some( + " 🤔 You entered an empty statement".to_string(), + ))); + } + Ok(_statements) => {} + Err(err) => { + return Ok(ValidationResult::Invalid(Some(format!( + " 🤔 Invalid statement: {err}", + )))); + } + } } + Ok(ValidationResult::Valid(None)) } else if input.starts_with('\\') { // command Ok(ValidationResult::Valid(None)) @@ -197,6 +204,37 @@ pub fn unescape_input(input: &str) -> datafusion::error::Result { Ok(result) } +/// Splits a string which consists of multiple queries. 
+pub(crate) fn split_from_semicolon(sql: String) -> Vec { + let mut commands = Vec::new(); + let mut current_command = String::new(); + let mut in_single_quote = false; + let mut in_double_quote = false; + + for c in sql.chars() { + if c == '\'' && !in_double_quote { + in_single_quote = !in_single_quote; + } else if c == '"' && !in_single_quote { + in_double_quote = !in_double_quote; + } + + if c == ';' && !in_single_quote && !in_double_quote { + if !current_command.trim().is_empty() { + commands.push(format!("{};", current_command.trim())); + current_command.clear(); + } + } else { + current_command.push(c); + } + } + + if !current_command.trim().is_empty() { + commands.push(format!("{};", current_command.trim())); + } + + commands +} + #[cfg(test)] mod tests { use std::io::{BufRead, Cursor}; @@ -292,4 +330,39 @@ mod tests { Ok(()) } + + #[test] + fn test_split_from_semicolon() { + let sql = "SELECT 1; SELECT 2;"; + let expected = vec!["SELECT 1;", "SELECT 2;"]; + assert_eq!(split_from_semicolon(sql.to_string()), expected); + + let sql = r#"SELECT ";";"#; + let expected = vec![r#"SELECT ";";"#]; + assert_eq!(split_from_semicolon(sql.to_string()), expected); + + let sql = "SELECT ';';"; + let expected = vec!["SELECT ';';"]; + assert_eq!(split_from_semicolon(sql.to_string()), expected); + + let sql = r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#; + let expected = vec![ + "SELECT 1;", + "SELECT 'value;value';", + r#"SELECT 1 as "text;text";"#, + ]; + assert_eq!(split_from_semicolon(sql.to_string()), expected); + + let sql = ""; + let expected: Vec = Vec::new(); + assert_eq!(split_from_semicolon(sql.to_string()), expected); + + let sql = "SELECT 1"; + let expected = vec!["SELECT 1;"]; + assert_eq!(split_from_semicolon(sql.to_string()), expected); + + let sql = "SELECT 1; "; + let expected = vec!["SELECT 1;"]; + assert_eq!(split_from_semicolon(sql.to_string()), expected); + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 033c8f839ab2..94560cb9d8da 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::{Debug, Display}; use std::sync::Arc; -use datafusion::common::{config_namespace, exec_datafusion_err, exec_err, internal_err}; +use datafusion::common::{exec_datafusion_err, exec_err, internal_err}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; @@ -106,12 +106,27 @@ impl CredentialProvider for S3CredentialProvider { pub fn get_oss_object_store_builder( url: &Url, aws_options: &AwsOptions, +) -> Result { + get_object_store_builder(url, aws_options, true) +} + +pub fn get_cos_object_store_builder( + url: &Url, + aws_options: &AwsOptions, +) -> Result { + get_object_store_builder(url, aws_options, false) +} + +fn get_object_store_builder( + url: &Url, + aws_options: &AwsOptions, + virtual_hosted_style_request: bool, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env() - .with_virtual_hosted_style_request(true) + .with_virtual_hosted_style_request(virtual_hosted_style_request) .with_bucket_name(bucket_name) - // oss don't care about the "region" field + // oss/cos don't care about the "region" field .with_region("do_not_care"); if let (Some(access_key_id), Some(secret_access_key)) = @@ -122,7 +137,7 @@ pub fn get_oss_object_store_builder( .with_secret_access_key(secret_access_key); } - if let 
Some(endpoint) = &aws_options.oss.endpoint { + if let Some(endpoint) = &aws_options.endpoint { builder = builder.with_endpoint(endpoint); } @@ -171,14 +186,8 @@ pub struct AwsOptions { pub session_token: Option, /// AWS Region pub region: Option, - /// Object Storage Service options - pub oss: OssOptions, -} - -config_namespace! { - pub struct OssOptions { - pub endpoint: Option, default = None - } + /// OSS or COS Endpoint + pub endpoint: Option, } impl ExtensionOptions for AwsOptions { @@ -210,8 +219,8 @@ impl ExtensionOptions for AwsOptions { "region" => { self.region.set(rem, value)?; } - "oss" => { - self.oss.set(rem, value)?; + "oss" | "cos" => { + self.endpoint.set(rem, value)?; } _ => { return internal_err!("Config value \"{}\" not found on AwsOptions", rem); @@ -252,7 +261,7 @@ impl ExtensionOptions for AwsOptions { .visit(&mut v, "secret_access_key", ""); self.session_token.visit(&mut v, "session_token", ""); self.region.visit(&mut v, "region", ""); - self.oss.visit(&mut v, "oss", ""); + self.endpoint.visit(&mut v, "endpoint", ""); v.0 } } @@ -376,7 +385,7 @@ pub(crate) fn register_options(ctx: &SessionContext, scheme: &str) { // Match the provided scheme against supported cloud storage schemes: match scheme { // For Amazon S3 or Alibaba Cloud OSS - "s3" | "oss" => { + "s3" | "oss" | "cos" => { // Register AWS specific table options in the session context: ctx.register_table_options_extension(AwsOptions::default()) } @@ -415,6 +424,15 @@ pub(crate) async fn get_object_store( let builder = get_oss_object_store_builder(url, options)?; Arc::new(builder.build()?) } + "cos" => { + let Some(options) = table_options.extensions.get::() else { + return exec_err!( + "Given table options incompatible with the 'cos' scheme" + ); + }; + let builder = get_cos_object_store_builder(url, options)?; + Arc::new(builder.build()?) + } "gs" | "gcs" => { let Some(options) = table_options.extensions.get::() else { return exec_err!( diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 02cb0fb9c63e..93630c8d48f8 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -73,21 +73,22 @@ pub struct PrintOptions { pub color: bool, } -fn get_timing_info_str( +// Returns the query execution details formatted +fn get_execution_details_formatted( row_count: usize, maxrows: MaxRows, query_start_time: Instant, ) -> String { - let row_word = if row_count == 1 { "row" } else { "rows" }; let nrows_shown_msg = match maxrows { - MaxRows::Limited(nrows) if nrows < row_count => format!(" ({} shown)", nrows), + MaxRows::Limited(nrows) if nrows < row_count => { + format!("(First {nrows} displayed. Use --maxrows to adjust)") + } _ => String::new(), }; format!( - "{} {} in set{}. Query took {:.3} seconds.\n", + "{} row(s) fetched. 
{}\nElapsed {:.3} seconds.\n", row_count, - row_word, nrows_shown_msg, query_start_time.elapsed().as_secs_f64() ) @@ -107,7 +108,7 @@ impl PrintOptions { .print_batches(&mut writer, batches, self.maxrows, true)?; let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - let timing_info = get_timing_info_str( + let formatted_exec_details = get_execution_details_formatted( row_count, if self.format == PrintFormat::Table { self.maxrows @@ -118,7 +119,7 @@ impl PrintOptions { ); if !self.quiet { - writeln!(writer, "{timing_info}")?; + writeln!(writer, "{formatted_exec_details}")?; } Ok(()) @@ -154,11 +155,14 @@ impl PrintOptions { with_header = false; } - let timing_info = - get_timing_info_str(row_count, MaxRows::Unlimited, query_start_time); + let formatted_exec_details = get_execution_details_formatted( + row_count, + MaxRows::Unlimited, + query_start_time, + ); if !self.quiet { - writeln!(writer, "{timing_info}")?; + writeln!(writer, "{formatted_exec_details}")?; } Ok(()) diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index ad2a49fb352e..4966143782ba 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -74,7 +74,6 @@ serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -# 0.10 and 0.11 are incompatible. Need to upgrade tonic to 0.11 when upgrading to arrow 51 -tonic = "0.10" +tonic = "0.11" url = { workspace = true } -uuid = "1.2" +uuid = "1.7" diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index dbc8050555b9..7ca90463cf8c 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -42,36 +42,37 @@ cargo run --example csv_sql ## Single Process +- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) +- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) +- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file +- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file -- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) -- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 -- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame +- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory +- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out 
from a DataFrame - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and analyze `Expr`s - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es -- [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files +- [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass -- [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function -- [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) -- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) +- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) -- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) +- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser` +- [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function +- [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions ## Distributed diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 0b7e3d4c6442..ba0d2f3b30f8 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -226,6 +226,10 @@ impl DisplayAs for CustomExec { } impl ExecutionPlan for CustomExec { + fn name(&self) -> &'static str { + "CustomExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/deserialize_to_struct.rs index e999fc4dac3e..985cab703a5c 100644 --- a/datafusion-examples/examples/deserialize_to_struct.rs +++ 
b/datafusion-examples/examples/deserialize_to_struct.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::AsArray; +use arrow::datatypes::{Float64Type, Int32Type}; use datafusion::error::Result; use datafusion::prelude::*; -use serde::Deserialize; +use futures::StreamExt; /// This example shows that it is possible to convert query results into Rust structs . -/// It will collect the query results into RecordBatch, then convert it to serde_json::Value. -/// Then, serde_json::Value is turned into Rust's struct. -/// Any datatype with `Deserialize` implemeneted works. #[tokio::main] async fn main() -> Result<()> { let data_list = Data::new().await?; @@ -30,10 +29,10 @@ async fn main() -> Result<()> { Ok(()) } -#[derive(Deserialize, Debug)] +#[derive(Debug)] struct Data { #[allow(dead_code)] - int_col: i64, + int_col: i32, #[allow(dead_code)] double_col: f64, } @@ -41,35 +40,36 @@ struct Data { impl Data { pub async fn new() -> Result> { // this group is almost the same as the one you find it in parquet_sql.rs - let batches = { - let ctx = SessionContext::new(); + let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; - let df = ctx - .sql("SELECT int_col, double_col FROM alltypes_plain") - .await?; + let df = ctx + .sql("SELECT int_col, double_col FROM alltypes_plain") + .await?; - df.clone().show().await?; + df.clone().show().await?; - df.collect().await? - }; - let batches: Vec<_> = batches.iter().collect(); + let mut stream = df.execute_stream().await?; + let mut list = vec![]; + while let Some(b) = stream.next().await.transpose()? { + let int_col = b.column(0).as_primitive::(); + let float_col = b.column(1).as_primitive::(); - // converts it to serde_json type and then convert that into Rust type - let list = arrow::json::writer::record_batches_to_json_rows(&batches[..])? 
- .into_iter() - .map(|val| serde_json::from_value(serde_json::Value::Object(val))) - .take_while(|val| val.is_ok()) - .map(|val| val.unwrap()) - .collect(); + for (i, f) in int_col.values().iter().zip(float_col.values()) { + list.push(Data { + int_col: *i, + double_col: *f, + }) + } + } Ok(list) } diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index cb7b7c28d909..f9d1b8029f04 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -18,7 +18,7 @@ use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator}; use std::sync::Arc; -use arrow_flight::SchemaAsIpc; +use arrow_flight::{PollInfo, SchemaAsIpc}; use datafusion::arrow::error::ArrowError; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; @@ -177,6 +177,13 @@ impl FlightService for FlightServiceImpl { ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } + + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } } fn to_tonic_err(e: datafusion::error::DataFusionError) -> Status { diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs index 35d475623062..ed9457643b7d 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -307,6 +307,8 @@ impl FlightSqlService for FlightSqlServiceImpl { let endpoint = FlightEndpoint { ticket: Some(ticket), location: vec![], + expiration_time: None, + app_metadata: Default::default(), }; let endpoints = vec![endpoint]; @@ -329,6 +331,7 @@ impl FlightSqlService for FlightSqlServiceImpl { total_records: -1_i64, total_bytes: -1_i64, ordered: false, + app_metadata: Default::default(), }; let resp = Response::new(info); Ok(resp) diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index 6c033e6c8eef..a7c8558c6da8 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -16,8 +16,9 @@ // under the License. use datafusion::error::Result; -use datafusion::execution::config::SessionConfig; -use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionContext}; +use datafusion::execution::context::{ + FunctionFactory, RegisterFunction, SessionContext, SessionState, +}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{exec_err, internal_err, DataFusionError}; use datafusion_expr::simplify::ExprSimplifyResult; @@ -91,7 +92,7 @@ impl FunctionFactory for CustomFunctionFactory { /// the function instance. 
async fn create( &self, - _state: &SessionConfig, + _state: &SessionState, statement: CreateFunction, ) -> Result { let f: ScalarFunctionWrapper = statement.try_into()?; diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 1d84fc2d1e0a..3fa35049a8da 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -163,6 +163,11 @@ impl PruningStatistics for MyCatalog { None } + fn row_counts(&self, _column: &Column) -> Option { + // In this example, we know nothing about the number of rows in each file + None + } + fn contained( &self, _column: &Column, diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs new file mode 100644 index 000000000000..259f38216b80 --- /dev/null +++ b/datafusion-examples/examples/sql_dialect.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Display; + +use datafusion::error::Result; +use datafusion_sql::{ + parser::{CopyToSource, CopyToStatement, DFParser, Statement}, + sqlparser::{keywords::Keyword, parser::ParserError, tokenizer::Token}, +}; + +/// This example demonstrates how to use the DFParser to parse a statement in a custom way +/// +/// This technique can be used to implement a custom SQL dialect, for example. +#[tokio::main] +async fn main() -> Result<()> { + let mut my_parser = + MyParser::new("COPY source_table TO 'file.fasta' STORED AS FASTA")?; + + let my_statement = my_parser.parse_statement()?; + + match my_statement { + MyStatement::DFStatement(s) => println!("df: {}", s), + MyStatement::MyCopyTo(s) => println!("my_copy: {}", s), + } + + Ok(()) +} + +/// Here we define a Parser for our new SQL dialect that wraps the existing `DFParser` +struct MyParser<'a> { + df_parser: DFParser<'a>, +} + +impl MyParser<'_> { + fn new(sql: &str) -> Result { + let df_parser = DFParser::new(sql)?; + Ok(Self { df_parser }) + } + + /// Returns true if the next token is `COPY` keyword, false otherwise + fn is_copy(&self) -> bool { + matches!( + self.df_parser.parser.peek_token().token, + Token::Word(w) if w.keyword == Keyword::COPY + ) + } + + /// This is the entry point to our parser -- it handles `COPY` statements specially + /// but otherwise delegates to the existing DataFusion parser. 
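    // For reference: with the statement constructed in `main` above, this method
    // takes the COPY branch and (assuming the FASTA format word parses as written)
    // the example prints
    //
    //     my_copy: COPY source_table TO 'file.fasta' STORED AS FASTA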
+ pub fn parse_statement(&mut self) -> Result { + if self.is_copy() { + self.df_parser.parser.next_token(); // COPY + let df_statement = self.df_parser.parse_copy()?; + + if let Statement::CopyTo(s) = df_statement { + Ok(MyStatement::from(s)) + } else { + Ok(MyStatement::DFStatement(Box::from(df_statement))) + } + } else { + let df_statement = self.df_parser.parse_statement()?; + Ok(MyStatement::from(df_statement)) + } + } +} + +enum MyStatement { + DFStatement(Box), + MyCopyTo(MyCopyToStatement), +} + +impl Display for MyStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MyStatement::DFStatement(s) => write!(f, "{}", s), + MyStatement::MyCopyTo(s) => write!(f, "{}", s), + } + } +} + +impl From for MyStatement { + fn from(s: Statement) -> Self { + Self::DFStatement(Box::from(s)) + } +} + +impl From for MyStatement { + fn from(s: CopyToStatement) -> Self { + if s.stored_as == Some("FASTA".to_string()) { + Self::MyCopyTo(MyCopyToStatement::from(s)) + } else { + Self::DFStatement(Box::from(Statement::CopyTo(s))) + } + } +} + +struct MyCopyToStatement { + pub source: CopyToSource, + pub target: String, +} + +impl From for MyCopyToStatement { + fn from(s: CopyToStatement) -> Self { + Self { + source: s.source, + target: s.target, + } + } +} + +impl Display for MyCopyToStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "COPY {} TO '{}' STORED AS FASTA", + self.source, self.target + ) + } +} diff --git a/datafusion-examples/examples/to_char.rs b/datafusion-examples/examples/to_char.rs index e99f69fbcd55..f8ed68b46f19 100644 --- a/datafusion-examples/examples/to_char.rs +++ b/datafusion-examples/examples/to_char.rs @@ -125,14 +125,14 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+------------+", - "| t.values |", - "+------------+", - "| 2020-09-01 |", - "| 2020-09-02 |", - "| 2020-09-03 |", - "| 2020-09-04 |", - "+------------+", + "+-----------------------------------+", + "| arrow_cast(t.values,Utf8(\"Utf8\")) |", + "+-----------------------------------+", + "| 2020-09-01 |", + "| 2020-09-02 |", + "| 2020-09-03 |", + "| 2020-09-04 |", + "+-----------------------------------+", ], &result ); @@ -146,11 +146,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+-----------------------------------------------------------------+", - "| to_char(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", - "+-----------------------------------------------------------------+", - "| 03-08-2023 14:38:50 |", - "+-----------------------------------------------------------------+", + "+-------------------------------------------------------------------------------------------------------------+", + "| to_char(arrow_cast(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"Timestamp(Second, None)\")),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", + "+-------------------------------------------------------------------------------------------------------------+", + "| 03-08-2023 14:38:50 |", + "+-------------------------------------------------------------------------------------------------------------+", ], &result ); @@ -165,11 +165,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------+", - "| to_char(Int64(123456),Utf8(\"pretty\")) |", - "+---------------------------------------+", - "| 1 days 10 hours 17 mins 36 secs |", - "+---------------------------------------+", + "+----------------------------------------------------------------------------+", + "| 
to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"pretty\")) |", + "+----------------------------------------------------------------------------+", + "| 1 days 10 hours 17 mins 36 secs |", + "+----------------------------------------------------------------------------+", ], &result ); @@ -184,11 +184,30 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------+", - "| to_char(Int64(123456),Utf8(\"iso8601\")) |", - "+----------------------------------------+", - "| PT123456S |", - "+----------------------------------------+", + "+-----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"iso8601\")) |", + "+-----------------------------------------------------------------------------+", + "| PT123456S |", + "+-----------------------------------------------------------------------------+", + ], + &result + ); + + // output format is null + + let result = ctx + .sql("SELECT to_char(arrow_cast(123456, 'Duration(Second)'), null) as result") + .await? + .collect() + .await?; + + assert_batches_eq!( + &[ + "+--------+", + "| result |", + "+--------+", + "| |", + "+--------+", ], &result ); diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index 2d09782a3982..c111375e3058 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,6 +19,7 @@ # Changelog +- [37.0.0](../dev/changelog/37.0.0.md) - [36.0.0](../dev/changelog/36.0.0.md) - [35.0.0](../dev/changelog/35.0.0.md) - [34.0.0](../dev/changelog/34.0.0.md) diff --git a/datafusion/common_runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml similarity index 100% rename from datafusion/common_runtime/Cargo.toml rename to datafusion/common-runtime/Cargo.toml diff --git a/datafusion/common_runtime/README.md b/datafusion/common-runtime/README.md similarity index 100% rename from datafusion/common_runtime/README.md rename to datafusion/common-runtime/README.md diff --git a/datafusion/common_runtime/src/common.rs b/datafusion/common-runtime/src/common.rs similarity index 100% rename from datafusion/common_runtime/src/common.rs rename to datafusion/common-runtime/src/common.rs diff --git a/datafusion/common_runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs similarity index 100% rename from datafusion/common_runtime/src/lib.rs rename to datafusion/common-runtime/src/lib.rs diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 088f03e002ed..0dc0532bbb6f 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -24,17 +24,18 @@ use crate::{downcast_value, DataFusionError, Result}; use arrow::{ array::{ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, - DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, - Float64Array, GenericBinaryArray, GenericListArray, GenericStringArray, - Int32Array, Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, - IntervalYearMonthArray, LargeListArray, ListArray, MapArray, NullArray, - OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt32Array, UInt64Array, UInt8Array, UnionArray, + Decimal256Array, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, + Float32Array, Float64Array, GenericBinaryArray, GenericListArray, + GenericStringArray, Int32Array, Int64Array, IntervalDayTimeArray, + 
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray, ListArray, + MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt32Array, UInt64Array, + UInt8Array, UnionArray, }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; -use arrow_array::Decimal256Array; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { @@ -154,6 +155,26 @@ pub fn as_union_array(array: &dyn Array) -> Result<&UnionArray> { Ok(downcast_value!(array, UnionArray)) } +// Downcast ArrayRef to Time32SecondArray +pub fn as_time32_second_array(array: &dyn Array) -> Result<&Time32SecondArray> { + Ok(downcast_value!(array, Time32SecondArray)) +} + +// Downcast ArrayRef to Time32MillisecondArray +pub fn as_time32_millisecond_array(array: &dyn Array) -> Result<&Time32MillisecondArray> { + Ok(downcast_value!(array, Time32MillisecondArray)) +} + +// Downcast ArrayRef to Time64MicrosecondArray +pub fn as_time64_microsecond_array(array: &dyn Array) -> Result<&Time64MicrosecondArray> { + Ok(downcast_value!(array, Time64MicrosecondArray)) +} + +// Downcast ArrayRef to Time64NanosecondArray +pub fn as_time64_nanosecond_array(array: &dyn Array) -> Result<&Time64NanosecondArray> { + Ok(downcast_value!(array, Time64NanosecondArray)) +} + // Downcast ArrayRef to TimestampNanosecondArray pub fn as_timestamp_nanosecond_array( array: &dyn Array, diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 68b9ec9dab94..968d8215ca4d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1109,58 +1109,163 @@ macro_rules! extensions_options { } } +/// Represents the configuration options available for handling different table formats within a data processing application. +/// This struct encompasses options for various file formats including CSV, Parquet, and JSON, allowing for flexible configuration +/// of parsing and writing behaviors specific to each format. Additionally, it supports extending functionality through custom extensions. #[derive(Debug, Clone, Default)] pub struct TableOptions { + /// Configuration options for CSV file handling. This includes settings like the delimiter, + /// quote character, and whether the first row is considered as headers. pub csv: CsvOptions, + + /// Configuration options for Parquet file handling. This includes settings for compression, + /// encoding, and other Parquet-specific file characteristics. pub parquet: TableParquetOptions, + + /// Configuration options for JSON file handling. pub json: JsonOptions, + + /// The current file format that the table operations should assume. This option allows + /// for dynamic switching between the supported file types (e.g., CSV, Parquet, JSON). pub current_format: Option, - /// Optional extensions registered using [`Extensions::insert`] + + /// Optional extensions that can be used to extend or customize the behavior of the table + /// options. Extensions can be registered using `Extensions::insert` and might include + /// custom file handling logic, additional configuration parameters, or other enhancements. pub extensions: Extensions, } impl ConfigField for TableOptions { + /// Visits configuration settings for the current file format, or all formats if none is selected. 
+ /// + /// This method adapts the behavior based on whether a file format is currently selected in `current_format`. + /// If a format is selected, it visits only the settings relevant to that format. Otherwise, + /// it visits all available format settings. fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { - self.csv.visit(v, "csv", ""); - self.parquet.visit(v, "parquet", ""); - self.json.visit(v, "json", ""); + if let Some(file_type) = &self.current_format { + match file_type { + #[cfg(feature = "parquet")] + FileType::PARQUET => self.parquet.visit(v, "format", ""), + FileType::CSV => self.csv.visit(v, "format", ""), + FileType::JSON => self.json.visit(v, "format", ""), + _ => {} + } + } else { + self.csv.visit(v, "csv", ""); + self.parquet.visit(v, "parquet", ""); + self.json.visit(v, "json", ""); + } } + /// Sets a configuration value for a specific key within `TableOptions`. + /// + /// This method delegates setting configuration values to the specific file format configurations, + /// based on the current format selected. If no format is selected, it returns an error. + /// + /// # Parameters + /// + /// * `key`: The configuration key specifying which setting to adjust, prefixed with the format (e.g., "format.delimiter") + /// for CSV format. + /// * `value`: The value to set for the specified configuration key. + /// + /// # Returns + /// + /// A result indicating success or an error if the key is not recognized, if a format is not specified, + /// or if setting the configuration value fails for the specific format. fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; match key { - "csv" => self.csv.set(rem, value), - "parquet" => self.parquet.set(rem, value), - "json" => self.json.set(rem, value), + "format" => match format { + #[cfg(feature = "parquet")] + FileType::PARQUET => self.parquet.set(rem, value), + FileType::CSV => self.csv.set(rem, value), + FileType::JSON => self.json.set(rem, value), + _ => { + _config_err!("Config value \"{key}\" is not supported on {}", format) + } + }, _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } } impl TableOptions { - /// Creates a new [`ConfigOptions`] with default values + /// Constructs a new instance of `TableOptions` with default settings. + /// + /// # Returns + /// + /// A new `TableOptions` instance with default configuration values. pub fn new() -> Self { Self::default() } + /// Sets the file format for the table. + /// + /// # Parameters + /// + /// * `format`: The file format to use (e.g., CSV, Parquet). pub fn set_file_format(&mut self, format: FileType) { self.current_format = Some(format); } + /// Creates a new `TableOptions` instance initialized with settings from a given session config. + /// + /// # Parameters + /// + /// * `config`: A reference to the session `ConfigOptions` from which to derive initial settings. + /// + /// # Returns + /// + /// A new `TableOptions` instance with settings applied from the session config. 
pub fn default_from_session_config(config: &ConfigOptions) -> Self { - let mut initial = TableOptions::default(); - initial.parquet.global = config.execution.parquet.clone(); + let initial = TableOptions::default(); + initial.combine_with_session_config(config); initial } - /// Set extensions to provided value + /// Updates the current `TableOptions` with settings from a given session config. + /// + /// # Parameters + /// + /// * `config`: A reference to the session `ConfigOptions` whose settings are to be applied. + /// + /// # Returns + /// + /// A new `TableOptions` instance with updated settings from the session config. + pub fn combine_with_session_config(&self, config: &ConfigOptions) -> Self { + let mut clone = self.clone(); + clone.parquet.global = config.execution.parquet.clone(); + clone + } + + /// Sets the extensions for this `TableOptions` instance. + /// + /// # Parameters + /// + /// * `extensions`: The `Extensions` instance to set. + /// + /// # Returns + /// + /// A new `TableOptions` instance with the specified extensions applied. pub fn with_extensions(mut self, extensions: Extensions) -> Self { self.extensions = extensions; self } - /// Set a configuration option + /// Sets a specific configuration option. + /// + /// # Parameters + /// + /// * `key`: The configuration key (e.g., "format.delimiter"). + /// * `value`: The value to set for the specified key. + /// + /// # Returns + /// + /// A result indicating success or failure in setting the configuration option. pub fn set(&mut self, key: &str, value: &str) -> Result<()> { let (prefix, _) = key.split_once('.').ok_or_else(|| { DataFusionError::Configuration(format!( @@ -1168,28 +1273,7 @@ impl TableOptions { )) })?; - if prefix == "csv" || prefix == "json" || prefix == "parquet" { - if let Some(format) = &self.current_format { - match format { - FileType::CSV if prefix != "csv" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for CSV format" - ))) - } - #[cfg(feature = "parquet")] - FileType::PARQUET if prefix != "parquet" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for PARQUET format" - ))) - } - FileType::JSON if prefix != "json" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for JSON format" - ))) - } - _ => {} - } - } + if prefix == "format" { return ConfigField::set(self, key, value); } @@ -1202,6 +1286,15 @@ impl TableOptions { e.0.set(key, value) } + /// Initializes a new `TableOptions` from a hash map of string settings. + /// + /// # Parameters + /// + /// * `settings`: A hash map where each key-value pair represents a configuration setting. + /// + /// # Returns + /// + /// A result containing the new `TableOptions` instance or an error if any setting could not be applied. pub fn from_string_hash_map(settings: &HashMap) -> Result { let mut ret = Self::default(); for (k, v) in settings { @@ -1211,6 +1304,15 @@ impl TableOptions { Ok(ret) } + /// Modifies the current `TableOptions` instance with settings from a hash map. + /// + /// # Parameters + /// + /// * `settings`: A hash map where each key-value pair represents a configuration setting. + /// + /// # Returns + /// + /// A result indicating success or failure in applying the settings. 
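A minimal sketch of the `format.`-prefixed key scheme documented above, applied through `alter_with_string_hash_map` (declared just below); it assumes only the `datafusion_common` APIs shown in this diff:

```rust
use std::collections::HashMap;

use datafusion_common::config::TableOptions;
use datafusion_common::{FileType, Result};

fn configure_csv_options() -> Result<TableOptions> {
    let mut settings = HashMap::new();
    settings.insert("format.delimiter".to_owned(), "|".to_owned());

    let mut opts = TableOptions::new();
    // Without a selected file format, setting any "format." key is rejected
    // with "Specify a format for TableOptions".
    opts.set_file_format(FileType::CSV);
    opts.alter_with_string_hash_map(&settings)?;
    assert_eq!(opts.csv.delimiter, b'|');
    Ok(opts)
}
```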
pub fn alter_with_string_hash_map( &mut self, settings: &HashMap, @@ -1221,7 +1323,11 @@ impl TableOptions { Ok(()) } - /// Returns the [`ConfigEntry`] stored within this [`ConfigOptions`] + /// Retrieves all configuration entries from this `TableOptions`. + /// + /// # Returns + /// + /// A vector of `ConfigEntry` instances, representing all the configuration options within this `TableOptions`. pub fn entries(&self) -> Vec { struct Visitor(Vec); @@ -1249,9 +1355,7 @@ impl TableOptions { } let mut v = Visitor(vec![]); - self.visit(&mut v, "csv", ""); - self.visit(&mut v, "json", ""); - self.visit(&mut v, "parquet", ""); + self.visit(&mut v, "format", ""); v.0.extend(self.extensions.0.values().flat_map(|e| e.0.entries())); v.0 @@ -1556,6 +1660,7 @@ mod tests { use crate::config::{ ConfigEntry, ConfigExtension, ExtensionOptions, Extensions, TableOptions, }; + use crate::FileType; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -1609,12 +1714,13 @@ mod tests { } #[test] - fn alter_kafka_config() { + fn alter_test_extension_config() { let mut extension = Extensions::new(); extension.insert(TestExtensionConfig::default()); let mut table_config = TableOptions::new().with_extensions(extension); - table_config.set("parquet.write_batch_size", "10").unwrap(); - assert_eq!(table_config.parquet.global.write_batch_size, 10); + table_config.set_file_format(FileType::CSV); + table_config.set("format.delimiter", ";").unwrap(); + assert_eq!(table_config.csv.delimiter, b';'); table_config.set("test.bootstrap.servers", "asd").unwrap(); let kafka_config = table_config .extensions @@ -1626,11 +1732,25 @@ mod tests { ); } + #[test] + fn csv_u8_table_options() { + let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::CSV); + table_config.set("format.delimiter", ";").unwrap(); + assert_eq!(table_config.csv.delimiter as char, ';'); + table_config.set("format.escape", "\"").unwrap(); + assert_eq!(table_config.csv.escape.unwrap() as char, '"'); + table_config.set("format.escape", "\'").unwrap(); + assert_eq!(table_config.csv.escape.unwrap() as char, '\''); + } + + #[cfg(feature = "parquet")] #[test] fn parquet_table_options() { let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config - .set("parquet.bloom_filter_enabled::col1", "true") + .set("format.bloom_filter_enabled::col1", "true") .unwrap(); assert_eq!( table_config.parquet.column_specific_options["col1"].bloom_filter_enabled, @@ -1638,26 +1758,17 @@ mod tests { ); } - #[test] - fn csv_u8_table_options() { - let mut table_config = TableOptions::new(); - table_config.set("csv.delimiter", ";").unwrap(); - assert_eq!(table_config.csv.delimiter as char, ';'); - table_config.set("csv.escape", "\"").unwrap(); - assert_eq!(table_config.csv.escape.unwrap() as char, '"'); - table_config.set("csv.escape", "\'").unwrap(); - assert_eq!(table_config.csv.escape.unwrap() as char, '\''); - } - + #[cfg(feature = "parquet")] #[test] fn parquet_table_options_config_entry() { let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config - .set("parquet.bloom_filter_enabled::col1", "true") + .set("format.bloom_filter_enabled::col1", "true") .unwrap(); let entries = table_config.entries(); assert!(entries .iter() - .any(|item| item.key == "parquet.bloom_filter_enabled::col1")) + .any(|item| item.key == "format.bloom_filter_enabled::col1")) } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 
2642032c9a04..90fb0b035d35 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -97,10 +97,11 @@ pub type DFSchemaRef = Arc; /// ```rust /// use datafusion_common::{DFSchema, DFField}; /// use arrow_schema::Schema; +/// use std::collections::HashMap; /// -/// let df_schema = DFSchema::new(vec![ +/// let df_schema = DFSchema::new_with_metadata(vec![ /// DFField::new_unqualified("c1", arrow::datatypes::DataType::Int32, false), -/// ]).unwrap(); +/// ], HashMap::new()).unwrap(); /// let schema = Schema::from(df_schema); /// assert_eq!(schema.fields().len(), 1); /// ``` @@ -124,12 +125,6 @@ impl DFSchema { } } - #[deprecated(since = "7.0.0", note = "please use `new_with_metadata` instead")] - /// Create a new `DFSchema` - pub fn new(fields: Vec) -> Result { - Self::new_with_metadata(fields, HashMap::new()) - } - /// Create a new `DFSchema` pub fn new_with_metadata( fields: Vec, @@ -251,32 +246,6 @@ impl DFSchema { &self.fields[i] } - #[deprecated(since = "8.0.0", note = "please use `index_of_column_by_name` instead")] - /// Find the index of the column with the given unqualified name - pub fn index_of(&self, name: &str) -> Result { - for i in 0..self.fields.len() { - if self.fields[i].name() == name { - return Ok(i); - } else { - // Now that `index_of` is deprecated an error is thrown if - // a fully qualified field name is provided. - match &self.fields[i].qualifier { - Some(qualifier) => { - if (qualifier.to_string() + "." + self.fields[i].name()) == name { - return _plan_err!( - "Fully qualified field name '{name}' was supplied to `index_of` \ - which is deprecated. Please use `index_of_column_by_name` instead" - ); - } - } - None => (), - } - } - } - - Err(unqualified_field_not_found(name, self)) - } - pub fn index_of_column_by_name( &self, qualifier: Option<&TableReference>, @@ -771,6 +740,9 @@ pub trait ExprSchema: std::fmt::Debug { /// Returns the column's optional metadata. 
fn metadata(&self, col: &Column) -> Result<&HashMap>; + + /// Return the coulmn's datatype and nullability + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; } // Implement `ExprSchema` for `Arc` @@ -786,6 +758,10 @@ impl + std::fmt::Debug> ExprSchema for P { fn metadata(&self, col: &Column) -> Result<&HashMap> { ExprSchema::metadata(self.as_ref(), col) } + + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + self.as_ref().data_type_and_nullable(col) + } } impl ExprSchema for DFSchema { @@ -800,6 +776,11 @@ impl ExprSchema for DFSchema { fn metadata(&self, col: &Column) -> Result<&HashMap> { Ok(self.field_from_column(col)?.metadata()) } + + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + let field = self.field_from_column(col)?; + Ok((field.data_type(), field.is_nullable())) + } } /// DFField wraps an Arrow field and adds an optional qualifier @@ -1146,13 +1127,10 @@ mod tests { Ok(()) } - #[allow(deprecated)] #[test] fn helpful_error_messages() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let expected_help = "Valid fields are t1.c0, t1.c1."; - // Pertinent message parts - let expected_err_msg = "Fully qualified field name 't1.c0'"; assert_contains!( schema .field_with_qualified_name(&TableReference::bare("x"), "y") @@ -1167,11 +1145,12 @@ mod tests { .to_string(), expected_help ); - assert_contains!(schema.index_of("y").unwrap_err().to_string(), expected_help); - assert_contains!( - schema.index_of("t1.c0").unwrap_err().to_string(), - expected_err_msg - ); + assert!(schema.index_of_column_by_name(None, "y").unwrap().is_none()); + assert!(schema + .index_of_column_by_name(None, "t1.c0") + .unwrap() + .is_none()); + Ok(()) } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 1ecd5b62bee8..cafab6d334b3 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -63,7 +63,7 @@ pub enum DataFusionError { IoError(io::Error), /// Error when SQL is syntactically incorrect. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace SQL(ParserError, Option), /// Error when a feature is not yet implemented. /// @@ -101,7 +101,7 @@ pub enum DataFusionError { /// This error can be returned in cases such as when schema inference is not /// possible and when column names are not unique. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace /// Boxing the optional backtrace to prevent SchemaError(SchemaError, Box>), /// Error during execution of the query. 
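A small sketch of the `data_type_and_nullable` accessor added to `ExprSchema` above; the table and column names are hypothetical:

```rust
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{Column, DFSchema, ExprSchema, Result};

fn lookup_type_and_nullability() -> Result<()> {
    let schema = Schema::new(vec![Field::new("c0", DataType::Int32, true)]);
    let df_schema = DFSchema::try_from_qualified_schema("t1", &schema)?;

    // One lookup returns both pieces of information instead of two field searches.
    let (data_type, nullable) =
        df_schema.data_type_and_nullable(&Column::new(Some("t1"), "c0"))?;
    assert_eq!(data_type, &DataType::Int32);
    assert!(nullable);
    Ok(())
}
```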
@@ -601,6 +601,7 @@ pub use config_err as _config_err; pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; +pub use plan_datafusion_err as _plan_datafusion_err; pub use plan_err as _plan_err; pub use schema_err as _schema_err; diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index 812cb02a5f77..fc0bb7445645 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -20,6 +20,7 @@ use std::fmt::{self, Display}; use std::str::FromStr; +use crate::config::FormatOptions; use crate::error::{DataFusionError, Result}; /// The default file extension of arrow files @@ -55,6 +56,19 @@ pub enum FileType { JSON, } +impl From<&FormatOptions> for FileType { + fn from(value: &FormatOptions) -> Self { + match value { + FormatOptions::CSV(_) => FileType::CSV, + FormatOptions::JSON(_) => FileType::JSON, + #[cfg(feature = "parquet")] + FormatOptions::PARQUET(_) => FileType::PARQUET, + FormatOptions::AVRO => FileType::AVRO, + FormatOptions::ARROW => FileType::ARROW, + } + } +} + impl GetExt for FileType { fn get_ext(&self) -> String { match self { diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index a72b812adc8d..eb1ce1b364fd 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -35,7 +35,7 @@ mod tests { config::TableOptions, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - Result, + FileType, Result, }; use parquet::{ @@ -47,35 +47,36 @@ mod tests { #[test] fn test_writeroptions_parquet_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("parquet.max_row_group_size".to_owned(), "123".to_owned()); - option_map.insert("parquet.data_pagesize_limit".to_owned(), "123".to_owned()); - option_map.insert("parquet.write_batch_size".to_owned(), "123".to_owned()); - option_map.insert("parquet.writer_version".to_owned(), "2.0".to_owned()); + option_map.insert("format.max_row_group_size".to_owned(), "123".to_owned()); + option_map.insert("format.data_pagesize_limit".to_owned(), "123".to_owned()); + option_map.insert("format.write_batch_size".to_owned(), "123".to_owned()); + option_map.insert("format.writer_version".to_owned(), "2.0".to_owned()); option_map.insert( - "parquet.dictionary_page_size_limit".to_owned(), + "format.dictionary_page_size_limit".to_owned(), "123".to_owned(), ); option_map.insert( - "parquet.created_by".to_owned(), + "format.created_by".to_owned(), "df write unit test".to_owned(), ); option_map.insert( - "parquet.column_index_truncate_length".to_owned(), + "format.column_index_truncate_length".to_owned(), "123".to_owned(), ); option_map.insert( - "parquet.data_page_row_count_limit".to_owned(), + "format.data_page_row_count_limit".to_owned(), "123".to_owned(), ); - option_map.insert("parquet.bloom_filter_enabled".to_owned(), "true".to_owned()); - option_map.insert("parquet.encoding".to_owned(), "plain".to_owned()); - option_map.insert("parquet.dictionary_enabled".to_owned(), "true".to_owned()); - option_map.insert("parquet.compression".to_owned(), "zstd(4)".to_owned()); - option_map.insert("parquet.statistics_enabled".to_owned(), "page".to_owned()); - option_map.insert("parquet.bloom_filter_fpp".to_owned(), "0.123".to_owned()); - 
option_map.insert("parquet.bloom_filter_ndv".to_owned(), "123".to_owned()); + option_map.insert("format.bloom_filter_enabled".to_owned(), "true".to_owned()); + option_map.insert("format.encoding".to_owned(), "plain".to_owned()); + option_map.insert("format.dictionary_enabled".to_owned(), "true".to_owned()); + option_map.insert("format.compression".to_owned(), "zstd(4)".to_owned()); + option_map.insert("format.statistics_enabled".to_owned(), "page".to_owned()); + option_map.insert("format.bloom_filter_fpp".to_owned(), "0.123".to_owned()); + option_map.insert("format.bloom_filter_ndv".to_owned(), "123".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -131,54 +132,52 @@ mod tests { let mut option_map: HashMap = HashMap::new(); option_map.insert( - "parquet.bloom_filter_enabled::col1".to_owned(), + "format.bloom_filter_enabled::col1".to_owned(), "true".to_owned(), ); option_map.insert( - "parquet.bloom_filter_enabled::col2.nested".to_owned(), + "format.bloom_filter_enabled::col2.nested".to_owned(), "true".to_owned(), ); - option_map.insert("parquet.encoding::col1".to_owned(), "plain".to_owned()); - option_map.insert("parquet.encoding::col2.nested".to_owned(), "rle".to_owned()); + option_map.insert("format.encoding::col1".to_owned(), "plain".to_owned()); + option_map.insert("format.encoding::col2.nested".to_owned(), "rle".to_owned()); option_map.insert( - "parquet.dictionary_enabled::col1".to_owned(), + "format.dictionary_enabled::col1".to_owned(), "true".to_owned(), ); option_map.insert( - "parquet.dictionary_enabled::col2.nested".to_owned(), + "format.dictionary_enabled::col2.nested".to_owned(), "true".to_owned(), ); - option_map.insert("parquet.compression::col1".to_owned(), "zstd(4)".to_owned()); + option_map.insert("format.compression::col1".to_owned(), "zstd(4)".to_owned()); option_map.insert( - "parquet.compression::col2.nested".to_owned(), + "format.compression::col2.nested".to_owned(), "zstd(10)".to_owned(), ); option_map.insert( - "parquet.statistics_enabled::col1".to_owned(), + "format.statistics_enabled::col1".to_owned(), "page".to_owned(), ); option_map.insert( - "parquet.statistics_enabled::col2.nested".to_owned(), + "format.statistics_enabled::col2.nested".to_owned(), "none".to_owned(), ); option_map.insert( - "parquet.bloom_filter_fpp::col1".to_owned(), + "format.bloom_filter_fpp::col1".to_owned(), "0.123".to_owned(), ); option_map.insert( - "parquet.bloom_filter_fpp::col2.nested".to_owned(), + "format.bloom_filter_fpp::col2.nested".to_owned(), "0.456".to_owned(), ); + option_map.insert("format.bloom_filter_ndv::col1".to_owned(), "123".to_owned()); option_map.insert( - "parquet.bloom_filter_ndv::col1".to_owned(), - "123".to_owned(), - ); - option_map.insert( - "parquet.bloom_filter_ndv::col2.nested".to_owned(), + "format.bloom_filter_ndv::col2.nested".to_owned(), "456".to_owned(), ); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -271,16 +270,17 @@ mod tests { // for StatementOptions fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("csv.has_header".to_owned(), "true".to_owned()); - 
option_map.insert("csv.date_format".to_owned(), "123".to_owned()); - option_map.insert("csv.datetime_format".to_owned(), "123".to_owned()); - option_map.insert("csv.timestamp_format".to_owned(), "2.0".to_owned()); - option_map.insert("csv.time_format".to_owned(), "123".to_owned()); - option_map.insert("csv.null_value".to_owned(), "123".to_owned()); - option_map.insert("csv.compression".to_owned(), "gzip".to_owned()); - option_map.insert("csv.delimiter".to_owned(), ";".to_owned()); + option_map.insert("format.has_header".to_owned(), "true".to_owned()); + option_map.insert("format.date_format".to_owned(), "123".to_owned()); + option_map.insert("format.datetime_format".to_owned(), "123".to_owned()); + option_map.insert("format.timestamp_format".to_owned(), "2.0".to_owned()); + option_map.insert("format.time_format".to_owned(), "123".to_owned()); + option_map.insert("format.null_value".to_owned(), "123".to_owned()); + option_map.insert("format.compression".to_owned(), "gzip".to_owned()); + option_map.insert("format.delimiter".to_owned(), ";".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::CSV); table_config.alter_with_string_hash_map(&option_map)?; let csv_options = CsvWriterOptions::try_from(&table_config.csv)?; @@ -299,9 +299,10 @@ mod tests { // for StatementOptions fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("json.compression".to_owned(), "gzip".to_owned()); + option_map.insert("format.compression".to_owned(), "gzip".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::JSON); table_config.alter_with_string_hash_map(&option_map)?; let json_options = JsonWriterOptions::try_from(&table_config.json)?; diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index e8a350e8d389..28e73ba48f53 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -156,6 +156,7 @@ pub(crate) fn parse_encoding_string( "plain" => Ok(parquet::basic::Encoding::PLAIN), "plain_dictionary" => Ok(parquet::basic::Encoding::PLAIN_DICTIONARY), "rle" => Ok(parquet::basic::Encoding::RLE), + #[allow(deprecated)] "bit_packed" => Ok(parquet::basic::Encoding::BIT_PACKED), "delta_binary_packed" => Ok(parquet::basic::Encoding::DELTA_BINARY_PACKED), "delta_length_byte_array" => { diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index c614098713d6..8d61bad97b9f 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::error::_plan_err; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::error::{_plan_datafusion_err, _plan_err}; +use crate::{Result, ScalarValue}; use arrow_schema::DataType; use std::collections::HashMap; @@ -75,16 +75,12 @@ impl ParamValues { let idx = id[1..] .parse::() .map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) + _plan_datafusion_err!("Failed to parse placeholder id: {e}") })? 
.checked_sub(1); // value at the idx-th position in param_values should be the value for the placeholder let value = idx.and_then(|idx| list.get(idx)).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {id}" - )) + _plan_datafusion_err!("No value found for placeholder with id {id}") })?; Ok(value.clone()) } @@ -93,9 +89,7 @@ impl ParamValues { let name = &id[1..]; // value at the name position in param_values should be the value for the placeholder let value = map.get(name).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with name {id}" - )) + _plan_datafusion_err!("No value found for placeholder with name {id}") })?; Ok(value.clone()) } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5ace44f24b69..88d40a35585d 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -53,6 +53,8 @@ use arrow::{ }, }; use arrow_array::{ArrowNativeTypeOp, Scalar}; +use arrow_buffer::Buffer; +use arrow_schema::{UnionFields, UnionMode}; pub use struct_builder::ScalarStructBuilder; @@ -275,6 +277,11 @@ pub enum ScalarValue { DurationMicrosecond(Option), /// Duration in nanoseconds DurationNanosecond(Option), + /// A nested datatype that can represent slots of differing types. Components: + /// `.0`: a tuple of union `type_id` and the single value held by this Scalar + /// `.1`: the list of fields, zero-to-one of which will by set in `.0` + /// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came + Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), } @@ -375,6 +382,10 @@ impl PartialEq for ScalarValue { (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), (IntervalMonthDayNano(_), _) => false, + (Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => { + val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2) + } + (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, (Null, Null) => true, @@ -500,6 +511,14 @@ impl PartialOrd for ScalarValue { (DurationMicrosecond(_), _) => None, (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), (DurationNanosecond(_), _) => None, + (Union(v1, t1, m1), Union(v2, t2, m2)) => { + if t1.eq(t2) && m1.eq(m2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Union(_, _, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) if k1 == k2 { @@ -663,6 +682,11 @@ impl std::hash::Hash for ScalarValue { IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), IntervalMonthDayNano(v) => v.hash(state), + Union(v, t, m) => { + v.hash(state); + t.hash(state); + m.hash(state); + } Dictionary(k, v) => { k.hash(state); v.hash(state); @@ -1093,6 +1117,7 @@ impl ScalarValue { ScalarValue::DurationNanosecond(_) => { DataType::Duration(TimeUnit::Nanosecond) } + ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } @@ -1292,6 +1317,7 @@ impl ScalarValue { ScalarValue::DurationMillisecond(v) => v.is_none(), ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), + ScalarValue::Union(v, _, _) => v.is_none(), 
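            // To illustrate the new variant (a hedged sketch; `fields` here is hypothetical):
            // given `let fields = UnionFields::new(vec![0], vec![Field::new("i", DataType::Int32, true)]);`,
            // `ScalarValue::Union(Some((0, Box::new(ScalarValue::Int32(Some(7))))), fields.clone(), UnionMode::Sparse)`
            // is not null, while `ScalarValue::Union(None, fields, UnionMode::Sparse)` matches the arm above
            // and reports `is_null() == true`.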
ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -1650,7 +1676,11 @@ impl ScalarValue { | DataType::Duration(_) | DataType::Union(_, _) | DataType::Map(_, _) - | DataType::RunEndEncoded(_, _) => { + | DataType::RunEndEncoded(_, _) + | DataType::Utf8View + | DataType::BinaryView + | DataType::ListView(_) + | DataType::LargeListView(_) => { return _internal_err!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, @@ -1746,7 +1776,7 @@ impl ScalarValue { } /// Converts `Vec` where each element has type corresponding to - /// `data_type`, to a [`ListArray`]. + /// `data_type`, to a single element [`ListArray`]. /// /// Example /// ``` @@ -2083,6 +2113,39 @@ impl ScalarValue { e, size ), + ScalarValue::Union(value, fields, _mode) => match value { + Some((v_id, value)) => { + let mut field_type_ids = Vec::::with_capacity(fields.len()); + let mut child_arrays = + Vec::<(Field, ArrayRef)>::with_capacity(fields.len()); + for (f_id, field) in fields.iter() { + let ar = if f_id == *v_id { + value.to_array_of_size(size)? + } else { + let dt = field.data_type(); + new_null_array(dt, size) + }; + let field = (**field).clone(); + child_arrays.push((field, ar)); + field_type_ids.push(f_id); + } + let type_ids = repeat(*v_id).take(size).collect::>(); + let type_ids = Buffer::from_slice_ref(type_ids); + let value_offsets: Option = None; + let ar = UnionArray::try_new( + field_type_ids.as_slice(), + type_ids, + value_offsets, + child_arrays, + ) + .map_err(|e| DataFusionError::ArrowError(e, None))?; + Arc::new(ar) + } + None => { + let dt = self.data_type(); + new_null_array(&dt, size) + } + }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { @@ -2622,6 +2685,9 @@ impl ScalarValue { ScalarValue::DurationNanosecond(val) => { eq_array_primitive!(array, index, DurationNanosecondArray, val)? } + ScalarValue::Union(_, _, _) => { + return _not_impl_err!("Union is not supported yet") + } ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { DataType::Int8 => get_dict_value::(array, index)?, @@ -2699,6 +2765,15 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(arr) => arr.get_array_memory_size(), + ScalarValue::Union(vals, fields, _mode) => { + vals.as_ref() + .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .unwrap_or_default() + // `fields` is boxed, so it is NOT already included in `self` + + std::mem::size_of_val(fields) + + (std::mem::size_of::() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() + } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() @@ -3044,6 +3119,9 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( @@ -3160,6 +3238,10 @@ impl fmt::Display for ScalarValue { .join(",") )? 
} + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "{}:{}", id, val)?, + None => write!(f, "NULL")?, + }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; @@ -3275,6 +3357,10 @@ impl fmt::Debug for ScalarValue { ScalarValue::DurationNanosecond(_) => { write!(f, "DurationNanosecond(\"{self}\")") } + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "Union {}:{}", id, val), + None => write!(f, "Union(NULL)"), + }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } @@ -4442,18 +4528,19 @@ mod tests { assert_eq!(expected, data_type.try_into().unwrap()) } - // this test fails on aarch, so don't run it there - #[cfg(not(target_arch = "aarch64"))] #[test] fn size_of_scalar() { // Since ScalarValues are used in a non trivial number of places, // making it larger means significant more memory consumption // per distinct value. // + // Thus this test ensures that no code change makes ScalarValue larger + // // The alignment requirements differ across architectures and // thus the size of the enum appears to as well - assert_eq!(std::mem::size_of::(), 48); + // The value may also change depending on rust version + assert_eq!(std::mem::size_of::(), 64); } #[test] @@ -5769,7 +5856,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("s", arr as _)]).unwrap(); #[rustfmt::skip] - let expected = [ + let expected = [ "+---+", "| s |", "+---+", @@ -5803,7 +5890,7 @@ mod tests { &DataType::List(Arc::new(Field::new( "item", DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), - true + true, ))) ); } diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index a10e05a55c64..6cefef8d0eb5 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -221,7 +221,7 @@ pub struct Statistics { /// Total bytes of the table rows. pub total_byte_size: Precision, /// Statistics on a column level. It contains a [`ColumnStatistics`] for - /// each field in the schema of the the table to which the [`Statistics`] refer. + /// each field in the schema of the table to which the [`Statistics`] refer. 
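The `size_of_scalar` test above pins the in-memory size of `ScalarValue` (now 64 bytes with the new `Union` variant) so that accidental growth fails CI. Below is a minimal standalone sketch of the same guard pattern on a hypothetical enum, not DataFusion code; the `SmallValue` type and the 32-byte bound are illustrative assumptions for a 64-bit target.

```rust
/// Hypothetical stand-in for a value type whose per-instance size we want to bound.
#[allow(dead_code)]
enum SmallValue {
    Int(Option<i64>),
    Float(Option<f64>),
    Text(Option<String>),
}

#[cfg(test)]
mod tests {
    use super::SmallValue;

    #[test]
    fn size_of_small_value() {
        // Pinning the size in a test makes accidental growth (for example by
        // adding a large inline variant instead of boxing it) fail loudly
        // rather than silently increasing memory use per distinct value.
        // The exact size depends on target architecture and Rust version,
        // so this sketch asserts an upper bound instead of an exact value.
        assert!(std::mem::size_of::<SmallValue>() <= 32);
    }
}
```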
pub column_statistics: Vec, } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index a3570834fdb7..610784f91dec 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -50,6 +50,7 @@ default = [ "datetime_expressions", "encoding_expressions", "regex_expressions", + "string_expressions", "unicode_expressions", "compression", "parquet", @@ -66,10 +67,10 @@ regex_expressions = [ "datafusion-functions/regex_expressions", ] serde = ["arrow-schema/serde"] +string_expressions = ["datafusion-functions/string_expressions"] unicode_expressions = [ - "datafusion-physical-expr/unicode_expressions", - "datafusion-optimizer/unicode_expressions", "datafusion-sql/unicode_expressions", + "datafusion-functions/unicode_expressions", ] [dependencies] @@ -122,20 +123,20 @@ tempfile = { workspace = true } tokio = { workspace = true } tokio-util = { version = "0.7.4", features = ["io"], optional = true } url = { workspace = true } -uuid = { version = "1.0", features = ["v4"] } +uuid = { version = "1.7", features = ["v4"] } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] async-trait = { workspace = true } bigdecimal = { workspace = true } -cargo = "0.77.0" criterion = { version = "0.5", features = ["async_tokio"] } csv = "1.1.6" ctor = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } half = { workspace = true, default-features = true } +paste = "^1.0" postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } rand = { workspace = true, features = ["small_rng"] } @@ -187,6 +188,7 @@ name = "physical_plan" [[bench]] harness = false name = "parquet_query_sql" +required-features = ["parquet"] [[bench]] harness = false @@ -203,7 +205,3 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" - -[[bench]] -harness = false -name = "array_expression" diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/core/src/catalog/mod.rs index 8aeeaf9f72d8..d39fad8a5643 100644 --- a/datafusion/core/src/catalog/mod.rs +++ b/datafusion/core/src/catalog/mod.rs @@ -177,7 +177,7 @@ impl CatalogProviderList for MemoryCatalogProviderList { /// /// [`datafusion-cli`]: https://arrow.apache.org/datafusion/user-guide/cli.html /// [`DynamicFileCatalogProvider`]: https://github.com/apache/arrow-datafusion/blob/31b9b48b08592b7d293f46e75707aad7dadd7cbc/datafusion-cli/src/catalog.rs#L75 -/// [`catalog.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/external_dependency/catalog.rs +/// [`catalog.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/catalog.rs /// [delta-rs]: https://github.com/delta-io/delta-rs /// [`UnityCatalogProvider`]: https://github.com/delta-io/delta-rs/blob/951436ecec476ce65b5ed3b58b50fb0846ca7b91/crates/deltalake-core/src/data_catalog/unity/datafusion.rs#L111-L123 /// diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 5f192b83fdd9..eea5fc1127ce 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1001,16 +1001,6 @@ impl DataFrame { Arc::new(DataFrameTableProvider { plan: self.plan }) } - /// Return the optimized logical plan represented by this DataFrame. 
- /// - /// Note: This method should not be used outside testing, as it loses the snapshot - /// of the [`SessionState`] attached to this [`DataFrame`] and consequently subsequent - /// operations may take place against a different state - #[deprecated(since = "23.0.0", note = "Use DataFrame::into_optimized_plan")] - pub fn to_logical_plan(self) -> Result { - self.into_optimized_plan() - } - /// Return a DataFrame with the explanation of its plan so far. /// /// if `analyze` is specified, runs the plan and reports metrics @@ -1161,8 +1151,8 @@ impl DataFrame { "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), )); } - let table_options = self.session_state.default_table_options(); - let props = writer_options.unwrap_or_else(|| table_options.csv.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().csv); let plan = LogicalPlanBuilder::copy_to( self.plan, @@ -1210,9 +1200,8 @@ impl DataFrame { )); } - let table_options = self.session_state.default_table_options(); - - let props = writer_options.unwrap_or_else(|| table_options.json.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().json); let plan = LogicalPlanBuilder::copy_to( self.plan, diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index f4e8c9dfcd6f..7cc3201bf7e4 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -57,9 +57,8 @@ impl DataFrame { )); } - let table_options = self.session_state.default_table_options(); - - let props = writer_options.unwrap_or_else(|| table_options.parquet.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().parquet); let plan = LogicalPlanBuilder::copy_to( self.plan, @@ -75,6 +74,7 @@ impl DataFrame { #[cfg(test)] mod tests { + use std::collections::HashMap; use std::sync::Arc; use super::super::Result; @@ -82,9 +82,10 @@ mod tests { use crate::arrow::util::pretty; use crate::execution::context::SessionContext; use crate::execution::options::ParquetReadOptions; - use crate::test_util; + use crate::test_util::{self, register_aggregate_csv}; use datafusion_common::file_options::parquet_writer::parse_compression_string; + use datafusion_execution::config::SessionConfig; use datafusion_expr::{col, lit}; use object_store::local::LocalFileSystem; @@ -151,7 +152,7 @@ mod tests { .await?; // Check that file actually used the specified compression - let file = std::fs::File::open(tmp_dir.into_path().join("test.parquet"))?; + let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?; let reader = parquet::file::serialized_reader::SerializedFileReader::new(file) @@ -167,4 +168,54 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn write_parquet_with_small_rg_size() -> Result<()> { + // This test verifies writing a parquet file with small rg size + // relative to datafusion.execution.batch_size does not panic + let mut ctx = SessionContext::new_with_config( + SessionConfig::from_string_hash_map(HashMap::from_iter( + [("datafusion.execution.batch_size", "10")] + .iter() + .map(|(s1, s2)| (s1.to_string(), s2.to_string())), + ))?, + ); + register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; + let test_df = ctx.table("aggregate_test_100").await?; + + let output_path = "file://local/test.parquet"; + + for rg_size in 1..10 { + let df = test_df.clone(); + let tmp_dir = TempDir::new()?; + let local = 
Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + let ctx = &test_df.session_state; + ctx.runtime_env().register_object_store(&local_url, local); + let mut options = TableParquetOptions::default(); + options.global.max_row_group_size = rg_size; + options.global.allow_single_file_parallelism = true; + df.write_parquet( + output_path, + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + + // Check that file actually used the correct rg size + let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?; + + let reader = + parquet::file::serialized_reader::SerializedFileReader::new(file) + .unwrap(); + + let parquet_metadata = reader.metadata(); + + let written_rows = parquet_metadata.row_group(0).num_rows(); + + assert_eq!(written_rows as usize, rg_size); + } + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs index 761e6b62680f..039a6aacc07e 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -224,6 +224,12 @@ fn default_field_name(dt: &DataType) -> &str { DataType::RunEndEncoded(_, _) => { unimplemented!("RunEndEncoded support not implemented") } + DataType::Utf8View + | DataType::BinaryView + | DataType::ListView(_) + | DataType::LargeListView(_) => { + unimplemented!("View support not implemented") + } DataType::Decimal128(_, _) => "decimal", DataType::Decimal256(_, _) => "decimal", } diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index c538819e2684..c1fbe352d37b 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -43,6 +43,7 @@ use futures::stream::BoxStream; use futures::StreamExt; #[cfg(feature = "compression")] use futures::TryStreamExt; +use object_store::buffered::BufWriter; use tokio::io::AsyncWrite; #[cfg(feature = "compression")] use tokio_util::io::{ReaderStream, StreamReader}; @@ -148,11 +149,11 @@ impl FileCompressionType { }) } - /// Wrap the given `AsyncWrite` so that it performs compressed writes + /// Wrap the given `BufWriter` so that it performs compressed writes /// according to this `FileCompressionType`. pub fn convert_async_writer( &self, - w: Box, + w: BufWriter, ) -> Result> { Ok(match self.variant { #[cfg(feature = "compression")] @@ -169,7 +170,7 @@ impl FileCompressionType { "Compression feature is not enabled".to_owned(), )) } - UNCOMPRESSED => w, + UNCOMPRESSED => Box::new(w), }) } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 72dc289d4b64..5ee0f7186703 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -49,7 +49,7 @@ use object_store::{ObjectMeta, ObjectStore}; /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across -/// providers that support the the same file formats. +/// providers that support the same file formats. 
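The new `write_parquet_with_small_rg_size` test above shrinks `datafusion.execution.batch_size` so that row groups smaller than a batch are exercised. A sketch of just that configuration step, reusing the calls visible in the test (`SessionConfig::from_string_hash_map`, `SessionContext::new_with_config`); the import paths and the helper function name are assumptions for an external crate, and the option value is arbitrary.

```rust
use std::collections::HashMap;

use datafusion::execution::context::SessionContext;
use datafusion_common::Result;
use datafusion_execution::config::SessionConfig;

/// Illustrative helper (not part of the patch): build a context whose
/// execution batch size is deliberately small.
fn small_batch_context() -> Result<SessionContext> {
    let settings: HashMap<String, String> = HashMap::from_iter([(
        "datafusion.execution.batch_size".to_string(),
        "10".to_string(),
    )]);
    // from_string_hash_map parses the string settings into a SessionConfig,
    // as the test above does inline.
    let config = SessionConfig::from_string_hash_map(settings)?;
    Ok(SessionContext::new_with_config(config))
}
```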
/// /// [`TableProvider`]: crate::datasource::provider::TableProvider #[async_trait] diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index f66683c311c1..f5bd72495d66 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -461,7 +461,7 @@ pub trait ReadOptions<'a> { return Ok(Arc::new(s.to_owned())); } - self.to_listing_options(config, state.default_table_options().clone()) + self.to_listing_options(config, state.default_table_options()) .infer_schema(&state, &table_path) .await } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index c04c536e7ca6..bcf4e8a2c8e4 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -23,7 +23,7 @@ use std::fmt::Debug; use std::sync::Arc; use super::write::demux::start_demuxer_task; -use super::write::{create_writer, AbortableWrite, SharedBuffer}; +use super::write::{create_writer, SharedBuffer}; use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, @@ -56,6 +56,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; +use object_store::buffered::BufWriter; use parquet::arrow::arrow_writer::{ compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, ArrowLeafColumn, @@ -78,9 +79,6 @@ use hashbrown::HashMap; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -/// Size of the buffer for [`AsyncArrowWriter`]. -const PARQUET_WRITER_BUFFER_SIZE: usize = 10485760; - /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. const INITIAL_BUFFER_BYTES: usize = 1048576; @@ -616,20 +614,13 @@ impl ParquetSink { location: &Path, object_store: Arc, parquet_props: WriterProperties, - ) -> Result< - AsyncArrowWriter>, - > { - let (_, multipart_writer) = object_store - .put_multipart(location) - .await - .map_err(DataFusionError::ObjectStore)?; + ) -> Result> { + let buf_writer = BufWriter::new(object_store, location.clone()); let writer = AsyncArrowWriter::try_new( - multipart_writer, + buf_writer, self.get_writer_schema(), - PARQUET_WRITER_BUFFER_SIZE, Some(parquet_props), )?; - Ok(writer) } @@ -885,42 +876,47 @@ fn spawn_parquet_parallel_serialization_task( )?; let mut current_rg_rows = 0; - while let Some(rb) = data.recv().await { - if current_rg_rows + rb.num_rows() < max_row_group_rows { - send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) - .await?; - current_rg_rows += rb.num_rows(); - } else { - let rows_left = max_row_group_rows - current_rg_rows; - let a = rb.slice(0, rows_left); - send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) - .await?; + while let Some(mut rb) = data.recv().await { + // This loop allows the "else" block to repeatedly split the RecordBatch to handle the case + // when max_row_group_rows < execution.batch_size as an alternative to a recursive async + // function. 
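The comment introducing the rewritten serialization loop (in the parquet hunk above) explains that an incoming `RecordBatch` may need to be sliced repeatedly when `max_row_group_rows` is smaller than the execution batch size. Below is a standalone sketch of that slicing arithmetic using only arrow types; it is a simplification of the real loop (which also tracks rows already buffered in the current row group), and the helper name is illustrative.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};

/// Split a batch into chunks of at most `max_rows` rows, mirroring the
/// slice-and-continue idea of the parallel parquet serializer above.
fn split_into_row_groups(batch: &RecordBatch, max_rows: usize) -> Vec<RecordBatch> {
    let mut chunks = Vec::new();
    let mut offset = 0;
    while offset < batch.num_rows() {
        let len = max_rows.min(batch.num_rows() - offset);
        // `slice` is zero-copy: each chunk shares the underlying buffers.
        chunks.push(batch.slice(offset, len));
        offset += len;
    }
    chunks
}

fn main() {
    let col: ArrayRef = Arc::new(Int32Array::from((0..25).collect::<Vec<i32>>()));
    let batch = RecordBatch::try_from_iter(vec![("a", col)]).unwrap();
    let chunks = split_into_row_groups(&batch, 10);
    let sizes: Vec<usize> = chunks.iter().map(|c| c.num_rows()).collect();
    assert_eq!(sizes, vec![10, 10, 5]);
}
```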
+ loop { + if current_rg_rows + rb.num_rows() < max_row_group_rows { + send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) + .await?; + current_rg_rows += rb.num_rows(); + break; + } else { + let rows_left = max_row_group_rows - current_rg_rows; + let a = rb.slice(0, rows_left); + send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) + .await?; + + // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup + // on a separate task, so that we can immediately start on the next RG before waiting + // for the current one to finish. + drop(col_array_channels); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + max_row_group_rows, + ); + + serialize_tx.send(finalize_rg_task).await.map_err(|_| { + DataFusionError::Internal( + "Unable to send closed RG to concat task!".into(), + ) + })?; - // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup - // on a separate task, so that we can immediately start on the next RG before waiting - // for the current one to finish. - drop(col_array_channels); - let finalize_rg_task = spawn_rg_join_and_finalize_task( - column_writer_handles, - max_row_group_rows, - ); - - serialize_tx.send(finalize_rg_task).await.map_err(|_| { - DataFusionError::Internal( - "Unable to send closed RG to concat task!".into(), - ) - })?; + current_rg_rows = 0; + rb = rb.slice(rows_left, rb.num_rows() - rows_left); - let b = rb.slice(rows_left, rb.num_rows() - rows_left); - (column_writer_handles, col_array_channels) = - spawn_column_parallel_row_group_writer( - schema.clone(), - writer_props.clone(), - max_buffer_rb, - )?; - send_arrays_to_col_writers(&col_array_channels, &b, schema.clone()) - .await?; - current_rg_rows = b.num_rows(); + (column_writer_handles, col_array_channels) = + spawn_column_parallel_row_group_writer( + schema.clone(), + writer_props.clone(), + max_buffer_rb, + )?; + } } } @@ -947,7 +943,7 @@ async fn concatenate_parallel_row_groups( mut serialize_rx: Receiver>, schema: Arc, writer_props: Arc, - mut object_store_writer: AbortableWrite>, + mut object_store_writer: Box, ) -> Result { let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); @@ -989,7 +985,7 @@ async fn concatenate_parallel_row_groups( /// task then stitches these independent RowGroups together and streams this large /// single parquet file to an ObjectStore in multiple parts. async fn output_single_parquet_file_parallelized( - object_store_writer: AbortableWrite>, + object_store_writer: Box, data: Receiver, output_schema: Arc, parquet_props: &WriterProperties, diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 410a32a19cc1..42115fc7b93f 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -18,21 +18,18 @@ //! Module containing helper methods/traits related to enabling //! 
write support for the various file formats -use std::io::{Error, Write}; -use std::pin::Pin; +use std::io::Write; use std::sync::Arc; -use std::task::{Context, Poll}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::error::Result; use arrow_array::RecordBatch; -use datafusion_common::DataFusionError; use bytes::Bytes; -use futures::future::BoxFuture; +use object_store::buffered::BufWriter; use object_store::path::Path; -use object_store::{MultipartId, ObjectStore}; +use object_store::ObjectStore; use tokio::io::AsyncWrite; pub(crate) mod demux; @@ -69,79 +66,6 @@ impl Write for SharedBuffer { } } -/// Stores data needed during abortion of MultiPart writers -#[derive(Clone)] -pub(crate) struct MultiPart { - /// A shared reference to the object store - store: Arc, - multipart_id: MultipartId, - location: Path, -} - -impl MultiPart { - /// Create a new `MultiPart` - pub fn new( - store: Arc, - multipart_id: MultipartId, - location: Path, - ) -> Self { - Self { - store, - multipart_id, - location, - } - } -} - -/// A wrapper struct with abort method and writer -pub(crate) struct AbortableWrite { - writer: W, - multipart: MultiPart, -} - -impl AbortableWrite { - /// Create a new `AbortableWrite` instance with the given writer, and write mode. - pub(crate) fn new(writer: W, multipart: MultiPart) -> Self { - Self { writer, multipart } - } - - /// handling of abort for different write modes - pub(crate) fn abort_writer(&self) -> Result>> { - let multi = self.multipart.clone(); - Ok(Box::pin(async move { - multi - .store - .abort_multipart(&multi.location, &multi.multipart_id) - .await - .map_err(DataFusionError::ObjectStore) - })) - } -} - -impl AsyncWrite for AbortableWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_flush(cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) - } -} - /// A trait that defines the methods required for a RecordBatch serializer. pub trait BatchSerializer: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. @@ -150,19 +74,15 @@ pub trait BatchSerializer: Sync + Send { fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } -/// Returns an [`AbortableWrite`] which writes to the given object store location -/// with the specified compression +/// Returns an [`AsyncWrite`] which writes to the given object store location +/// with the specified compression. +/// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. +/// Users can configure automatic cleanup with their cloud provider. 
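The rewritten `create_writer` below replaces the multipart/`AbortableWrite` plumbing with `object_store::buffered::BufWriter`, which buffers and uploads internally. A minimal sketch of writing through a `BufWriter` to an in-memory store, assuming only the `object_store` and `tokio` crates; the file contents and path are arbitrary.

```rust
use std::sync::Arc;

use object_store::buffered::BufWriter;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::ObjectStore;
use tokio::io::AsyncWriteExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let location = Path::from("output/part-0.csv");

    // BufWriter decides internally whether to issue a single PUT or a
    // multipart upload, so there is no multipart id to track or abort.
    let mut writer = BufWriter::new(Arc::clone(&store), location.clone());
    writer.write_all(b"a,b\n1,2\n").await?;
    writer.shutdown().await?; // finalizes the upload

    // Read the object back to confirm the write completed.
    let bytes = store.get(&location).await?.bytes().await?;
    assert_eq!(bytes.as_ref(), b"a,b\n1,2\n");
    Ok(())
}
```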
pub(crate) async fn create_writer( file_compression_type: FileCompressionType, location: &Path, object_store: Arc, -) -> Result>> { - let (multipart_id, writer) = object_store - .put_multipart(location) - .await - .map_err(DataFusionError::ObjectStore)?; - Ok(AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - MultiPart::new(object_store, multipart_id, location.clone()), - )) +) -> Result> { + let buf_writer = BufWriter::new(object_store, location.clone()); + file_compression_type.convert_async_writer(buf_writer) } diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index b7f268959311..3ae2122de827 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use super::demux::start_demuxer_task; -use super::{create_writer, AbortableWrite, BatchSerializer}; +use super::{create_writer, BatchSerializer}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; @@ -39,7 +39,7 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::JoinSet; -type WriterType = AbortableWrite>; +type WriterType = Box; type SerializerType = Arc; /// Serializes a single data stream in parallel and writes to an ObjectStore @@ -49,7 +49,7 @@ type SerializerType = Arc; pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, serializer: Arc, - mut writer: AbortableWrite>, + mut writer: WriterType, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); @@ -173,19 +173,9 @@ pub(crate) async fn stateless_serialize_and_write_files( // Finalize or abort writers as appropriate for mut writer in finished_writers.into_iter() { - match any_errors { - true => { - let abort_result = writer.abort_writer(); - if abort_result.is_err() { - any_abort_errors = true; - } - } - false => { - writer.shutdown() + writer.shutdown() .await .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! 
Partial results may have been written to ObjectStore!"))?; - } - } } if any_errors { diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 2a2551236e1b..c1e337b5c44a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -118,7 +118,7 @@ impl ListingTableConfig { } } - fn infer_format(path: &str) -> Result<(Arc, String)> { + fn infer_file_type(path: &str) -> Result<(FileType, String)> { let err_msg = format!("Unable to infer file type from path: {path}"); let mut exts = path.rsplit('.'); @@ -139,20 +139,7 @@ impl ListingTableConfig { .get_ext_with_compression(file_compression_type.to_owned()) .map_err(|_| DataFusionError::Internal(err_msg))?; - let file_format: Arc = match file_type { - FileType::ARROW => Arc::new(ArrowFormat), - FileType::AVRO => Arc::new(AvroFormat), - FileType::CSV => Arc::new( - CsvFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::JSON => Arc::new( - JsonFormat::default().with_file_compression_type(file_compression_type), - ), - #[cfg(feature = "parquet")] - FileType::PARQUET => Arc::new(ParquetFormat::default()), - }; - - Ok((file_format, ext)) + Ok((file_type, ext)) } /// Infer `ListingOptions` based on `table_path` suffix. @@ -173,10 +160,27 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let (format, file_extension) = - ListingTableConfig::infer_format(file.location.as_ref())?; + let (file_type, file_extension) = + ListingTableConfig::infer_file_type(file.location.as_ref())?; + + let mut table_options = state.default_table_options(); + table_options.set_file_format(file_type.clone()); + let file_format: Arc = match file_type { + FileType::CSV => { + Arc::new(CsvFormat::default().with_options(table_options.csv)) + } + #[cfg(feature = "parquet")] + FileType::PARQUET => { + Arc::new(ParquetFormat::default().with_options(table_options.parquet)) + } + FileType::AVRO => Arc::new(AvroFormat), + FileType::JSON => { + Arc::new(JsonFormat::default().with_options(table_options.json)) + } + FileType::ARROW => Arc::new(ArrowFormat), + }; - let listing_options = ListingOptions::new(format) + let listing_options = ListingOptions::new(file_format) .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()); diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index d9149bcc20e0..eb95dc7b1d24 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::fs; - use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; @@ -117,37 +115,6 @@ impl ListingTableUrl { } } - /// Get object store for specified input_url - /// if input_url is actually not a url, we assume it is a local file path - /// if we have a local path, create it if not exists so ListingTableUrl::parse works - #[deprecated(note = "Use parse")] - pub fn parse_create_local_if_not_exists( - s: impl AsRef, - is_directory: bool, - ) -> Result { - let s = s.as_ref(); - let is_valid_url = Url::parse(s).is_ok(); - - match is_valid_url { - true => ListingTableUrl::parse(s), - false => { - let path = std::path::PathBuf::from(s); - if !path.exists() { - if is_directory { - fs::create_dir_all(path)?; - } else { - // ensure parent directory exists - if let Some(parent) = path.parent() { - fs::create_dir_all(parent)?; - } - fs::File::create(path)?; - } - } - ListingTableUrl::parse(s) - } - } - } - /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path #[cfg(not(target_arch = "wasm32"))] fn parse_path(s: &str) -> Result { diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 4e126bbba9f9..b616e0181cfc 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -34,7 +34,6 @@ use crate::datasource::TableProvider; use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::config::TableOptions; use datafusion_common::{arrow_datafusion_err, DataFusionError, FileType}; use datafusion_expr::CreateExternalTable; @@ -58,8 +57,7 @@ impl TableProviderFactory for ListingTableFactory { state: &SessionState, cmd: &CreateExternalTable, ) -> datafusion_common::Result> { - let mut table_options = - TableOptions::default_from_session_config(state.config_options()); + let mut table_options = state.default_table_options(); let file_type = FileType::from_str(cmd.file_type.as_str()).map_err(|_| { DataFusionError::Execution(format!("Unknown FileType {}", cmd.file_type)) })?; @@ -227,7 +225,7 @@ mod tests { let name = OwnedTableReference::bare("foo".to_string()); let mut options = HashMap::new(); - options.insert("csv.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); let cmd = CreateExternalTable { name, location: csv_file.path().to_str().unwrap().to_string(), diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 3c76ee635855..608a46144da3 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -33,7 +33,7 @@ use crate::physical_plan::{ common, DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, SendableRecordBatchStream, }; -use crate::physical_planner::create_physical_sort_expr; +use crate::physical_planner::create_physical_sort_exprs; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -231,16 +231,11 @@ impl TableProvider for MemTable { let file_sort_order = sort_order .iter() .map(|sort_exprs| { - sort_exprs - .iter() - .map(|expr| { - create_physical_sort_expr( - expr, - &df_schema, - state.execution_props(), - ) - }) - .collect::>>() + create_physical_sort_exprs( + sort_exprs, + &df_schema, + state.execution_props(), + ) }) .collect::>>()?; exec = 
exec.with_sort_information(file_sort_order); diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 96b3adf968b8..1e8775731015 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -122,6 +122,10 @@ impl DisplayAs for ArrowExec { } impl ExecutionPlan for ArrowExec { + fn name(&self) -> &'static str { + "ArrowExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 2ccd83de80cb..4e5140e82d3f 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -99,6 +99,10 @@ impl DisplayAs for AvroExec { } impl ExecutionPlan for AvroExec { + fn name(&self) -> &'static str { + "AvroExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 5fcb9f483952..831ef4520567 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -44,6 +44,7 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; use futures::{ready, StreamExt, TryStreamExt}; +use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -159,6 +160,10 @@ impl DisplayAs for CsvExec { } impl ExecutionPlan for CsvExec { + fn name(&self) -> &'static str { + "CsvExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -471,7 +476,7 @@ pub async fn plan_to_csv( let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); //only write headers on first iteration let mut write_headers = true; @@ -481,15 +486,12 @@ pub async fn plan_to_csv( .build(buffer); writer.write(&batch)?; buffer = writer.into_inner(); - multipart_writer.write_all(&buffer).await?; + buf_writer.write_all(&buffer).await?; buffer.clear(); //prevent writing headers more than once write_headers = false; } - multipart_writer - .shutdown() - .await - .map_err(DataFusionError::from) + buf_writer.shutdown().await.map_err(DataFusionError::from) }); } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 068426e0fdcb..a5afda47527f 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -43,6 +43,7 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; use futures::{ready, StreamExt, TryStreamExt}; +use object_store::buffered::BufWriter; use object_store::{self, GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -126,6 +127,10 @@ impl DisplayAs for NdJsonExec { } impl ExecutionPlan for NdJsonExec { + fn name(&self) -> &'static str { + "NdJsonExec" + } + fn as_any(&self) -> &dyn Any { self } @@ -149,6 +154,9 @@ impl ExecutionPlan for NdJsonExec { target_partitions: usize, config: &datafusion_common::config::ConfigOptions, ) -> Result>> { 
+ if self.file_compression_type == FileCompressionType::GZIP { + return Ok(None); + } let repartition_file_min_size = config.optimizer.repartition_file_min_size; let preserve_order_within_groups = self.properties().output_ordering().is_some(); let file_groups = &self.base_config.file_groups; @@ -338,21 +346,18 @@ pub async fn plan_to_json( let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); while let Some(batch) = stream.next().await.transpose()? { let mut writer = json::LineDelimitedWriter::new(buffer); writer.write(&batch)?; buffer = writer.into_inner(); - multipart_writer.write_all(&buffer).await?; + buf_writer.write_all(&buffer).await?; buffer.clear(); } - multipart_writer - .shutdown() - .await - .map_err(DataFusionError::from) + buf_writer.shutdown().await.map_err(DataFusionError::from) }); } @@ -394,11 +399,14 @@ mod tests { use arrow::datatypes::{Field, SchemaBuilder}; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; use datafusion_common::FileType; - + use flate2::write::GzEncoder; + use flate2::Compression; use futures::StreamExt; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use rstest::*; + use std::fs::File; + use std::io; use tempfile::TempDir; use url::Url; @@ -886,4 +894,48 @@ mod tests { Ok(()) } + fn compress_file(path: &str, output_path: &str) -> io::Result<()> { + let input_file = File::open(path)?; + let mut reader = BufReader::new(input_file); + + let output_file = File::create(output_path)?; + let writer = std::io::BufWriter::new(output_file); + + let mut encoder = GzEncoder::new(writer, Compression::default()); + io::copy(&mut reader, &mut encoder)?; + + encoder.finish()?; + Ok(()) + } + #[tokio::test] + async fn test_disable_parallel_for_json_gz() -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(4); + let ctx = SessionContext::new_with_config(config); + let path = format!("{TEST_DATA_BASE}/1.json"); + let compressed_path = format!("{}.gz", &path); + compress_file(&path, &compressed_path)?; + let read_option = NdJsonReadOptions::default() + .file_compression_type(FileCompressionType::GZIP) + .file_extension("gz"); + let df = ctx.read_json(compressed_path.clone(), read_option).await?; + let res = df.collect().await; + fs::remove_file(&compressed_path)?; + assert_batches_eq!( + &[ + "+-----+------------------+---------------+------+", + "| a | b | c | d |", + "+-----+------------------+---------------+------+", + "| 1 | [2.0, 1.3, -6.1] | [false, true] | 4 |", + "| -10 | [2.0, 1.3, -6.1] | [true, true] | 4 |", + "| 2 | [2.0, , -6.1] | [false, ] | text |", + "| | | | |", + "+-----+------------------+---------------+------+", + ], + &res? 
+ ); + Ok(()) + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs index a17a3c6d9752..c2a7e4345a5b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs @@ -29,8 +29,12 @@ use crate::physical_plan::metrics::{ pub struct ParquetFileMetrics { /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, + /// Number of row groups whose bloom filters were checked and matched + pub row_groups_matched_bloom_filter: Count, /// Number of row groups pruned by bloom filters pub row_groups_pruned_bloom_filter: Count, + /// Number of row groups whose statistics were checked and matched + pub row_groups_matched_statistics: Count, /// Number of row groups pruned by statistics pub row_groups_pruned_statistics: Count, /// Total number of bytes scanned @@ -56,10 +60,18 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .counter("predicate_evaluation_errors", partition); + let row_groups_matched_bloom_filter = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("row_groups_matched_bloom_filter", partition); + let row_groups_pruned_bloom_filter = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("row_groups_pruned_bloom_filter", partition); + let row_groups_matched_statistics = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("row_groups_matched_statistics", partition); + let row_groups_pruned_statistics = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("row_groups_pruned_statistics", partition); @@ -85,7 +97,9 @@ impl ParquetFileMetrics { Self { predicate_evaluation_errors, + row_groups_matched_bloom_filter, row_groups_pruned_bloom_filter, + row_groups_matched_statistics, row_groups_pruned_statistics, bytes_scanned, pushdown_rows_filtered, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 2cfbb578da66..377dad5cee6c 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -52,6 +52,7 @@ use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use log::debug; +use object_store::buffered::BufWriter; use object_store::path::Path; use object_store::ObjectStore; use parquet::arrow::arrow_reader::ArrowReaderOptions; @@ -314,6 +315,10 @@ impl DisplayAs for ParquetExec { } impl ExecutionPlan for ParquetExec { + fn name(&self) -> &'static str { + "ParquetExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -698,15 +703,11 @@ pub async fn plan_to_parquet( let propclone = writer_properties.clone(); let storeref = store.clone(); - let (_, multipart_writer) = storeref.put_multipart(&file).await?; + let buf_writer = BufWriter::new(storeref, file.clone()); let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let mut writer = AsyncArrowWriter::try_new( - multipart_writer, - plan.schema(), - 10485760, - propclone, - )?; + let mut writer = + AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; while let Some(next_batch) = stream.next().await { let batch = next_batch?; writer.write(&batch).await?; @@ -1870,7 
+1871,7 @@ mod tests { assert_contains!( &display, - "pruning_predicate=c1_min@0 != bar OR bar != c1_max@1" + "pruning_predicate=CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != bar OR bar != c1_max@1 END" ); assert_contains!(&display, r#"predicate=c1@0 != bar"#); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 064a8e1fff33..c7706f3458d0 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -547,6 +547,10 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { } } + fn row_counts(&self, _column: &datafusion_common::Column) -> Option { + None + } + fn contained( &self, _column: &datafusion_common::Column, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index ef2eb775e037..a82c5d97a2b7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -94,6 +94,7 @@ pub(crate) fn prune_row_groups_by_statistics( metrics.predicate_evaluation_errors.add(1); } } + metrics.row_groups_matched_statistics.add(1); } filtered.push(idx) @@ -166,6 +167,9 @@ pub(crate) async fn prune_row_groups_by_bloom_filters< if prune_group { metrics.row_groups_pruned_bloom_filter.add(1); } else { + if !stats.column_sbbf.is_empty() { + metrics.row_groups_matched_bloom_filter.add(1); + } filtered.push(*idx); } } @@ -195,6 +199,10 @@ impl PruningStatistics for BloomFilterStatistics { None } + fn row_counts(&self, _column: &Column) -> Option { + None + } + /// Use bloom filters to determine if we are sure this column can not /// possibly contain `values` /// @@ -217,6 +225,8 @@ impl PruningStatistics for BloomFilterStatistics { .map(|value| { match value { ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Binary(Some(v)) => sbbf.check(v), + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), ScalarValue::Boolean(Some(v)) => sbbf.check(v), ScalarValue::Float64(Some(v)) => sbbf.check(v), ScalarValue::Float32(Some(v)) => sbbf.check(v), @@ -328,6 +338,10 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { scalar.to_array().ok() } + fn row_counts(&self, _column: &Column) -> Option { + None + } + fn contained( &self, _column: &Column, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 4e472606da51..aac5aff80f16 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -105,14 +105,20 @@ macro_rules! get_statistic { let s = std::str::from_utf8(s.$bytes_func()) .map(|s| s.to_string()) .ok(); + if s.is_none() { + log::debug!( + "Utf8 statistics is a non-UTF8 value, ignoring it." + ); + } Some(ScalarValue::Utf8(s)) } } } - // type not supported yet + // type not fully supported yet ParquetStatistics::FixedLenByteArray(s) => { match $target_arrow_type { - // just support the decimal data type + // just support specific logical data types, there are others each + // with their own ordering Some(DataType::Decimal128(precision, scale)) => { Some(ScalarValue::Decimal128( Some(from_bytes_to_i128(s.$bytes_func())), @@ -120,6 +126,23 @@ macro_rules! 
get_statistic { *scale, )) } + Some(DataType::FixedSizeBinary(size)) => { + let value = s.$bytes_func().to_vec(); + let value = if value.len().try_into() == Ok(*size) { + Some(value) + } else { + log::debug!( + "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", + size, + value.len(), + ); + None + }; + Some(ScalarValue::FixedSizeBinary( + *size, + value, + )) + } _ => None, } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 32c1c60ec564..31f390607f04 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -384,9 +384,9 @@ impl SessionContext { self.state.read().config.clone() } - /// Return a copied version of config for this Session + /// Return a copied version of table options for this Session pub fn copied_table_options(&self) -> TableOptions { - self.state.read().default_table_options().clone() + self.state.read().default_table_options() } /// Creates a [`DataFrame`] from SQL query text. @@ -794,7 +794,7 @@ impl SessionContext { let function_factory = &state.function_factory; match function_factory { - Some(f) => f.create(state.config(), stmt).await?, + Some(f) => f.create(&state, stmt).await?, _ => Err(DataFusionError::Configuration( "Function factory has not been configured".into(), ))?, @@ -1181,49 +1181,6 @@ impl SessionContext { } } - /// Returns the set of available tables in the default catalog and - /// schema. - /// - /// Use [`table`] to get a specific table. - /// - /// [`table`]: SessionContext::table - #[deprecated( - since = "23.0.0", - note = "Please use the catalog provider interface (`SessionContext::catalog`) to examine available catalogs, schemas, and tables" - )] - pub fn tables(&self) -> Result> { - Ok(self - .state - .read() - // a bare reference will always resolve to the default catalog and schema - .schema_for_ref(TableReference::Bare { table: "".into() })? - .table_names() - .iter() - .cloned() - .collect()) - } - - /// Optimizes the logical plan by applying optimizer rules. - #[deprecated( - since = "23.0.0", - note = "Use SessionState::optimize to ensure a consistent state for planning and execution" - )] - pub fn optimize(&self, plan: &LogicalPlan) -> Result { - self.state.read().optimize(plan) - } - - /// Creates a physical plan from a logical plan. - #[deprecated( - since = "23.0.0", - note = "Use SessionState::create_physical_plan or DataFrame::create_physical_plan to ensure a consistent state for planning and execution" - )] - pub async fn create_physical_plan( - &self, - logical_plan: &LogicalPlan, - ) -> Result> { - self.state().create_physical_plan(logical_plan).await - } - /// Get a new TaskContext to run in this session pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) @@ -1331,7 +1288,7 @@ pub trait FunctionFactory: Sync + Send { /// Handles creation of user defined function specified in [CreateFunction] statement async fn create( &self, - state: &SessionConfig, + state: &SessionState, statement: CreateFunction, ) -> Result; } @@ -1793,11 +1750,7 @@ impl SessionState { .0 .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - DFStatement::CopyTo(CopyToStatement { - source, - target: _, - options: _, - }) => match source { + DFStatement::CopyTo(CopyToStatement { source, .. 
}) => match source { CopyToSource::Relation(table_name) => { visitor.insert(table_name); } @@ -2006,8 +1959,9 @@ impl SessionState { } /// return the TableOptions options with its extensions - pub fn default_table_options(&self) -> &TableOptions { - &self.table_option_namespace + pub fn default_table_options(&self) -> TableOptions { + self.table_option_namespace + .combine_with_session_config(self.config_options()) } /// Get a new TaskContext to run in this session diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 54fe6e8406fd..a58f8698d6ce 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -392,7 +392,7 @@ fn adjust_input_keys_ordering( let expr = proj.expr(); // For Projection, we need to transform the requirements to the columns before the Projection // And then to push down the requirements - // Construct a mapping from new name to the the orginal Column + // Construct a mapping from new name to the orginal Column let new_required = map_columns_before_projection(&requirements.data, expr); if new_required.len() == requirements.data.len() { requirements.children[0].data = new_required; @@ -1369,6 +1369,10 @@ pub(crate) mod tests { } impl ExecutionPlan for SortRequiredExec { + fn name(&self) -> &'static str { + "SortRequiredExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index bd71b3e8ed80..829d523c990c 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -133,6 +133,10 @@ impl DisplayAs for OutputRequirementExec { } impl ExecutionPlan for OutputRequirementExec { + fn name(&self) -> &'static str { + "OutputRequirementExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } @@ -216,7 +220,7 @@ impl PhysicalOptimizerRule for OutputRequirements { } } -/// This functions adds ancillary `OutputRequirementExec` to the the physical plan, so that +/// This functions adds ancillary `OutputRequirementExec` to the physical plan, so that /// global requirements are not lost during optimization. 
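With this patch, `default_table_options` returns an owned `TableOptions` combined with the session's config options, and `SessionContext::copied_table_options` exposes the same copy. A small usage sketch, reusing only calls visible in this patch (`copied_table_options`, `set_file_format`, `alter_with_string_hash_map`); the helper function name and import paths are assumptions, and error handling is simplified.

```rust
use std::collections::HashMap;

use datafusion::execution::context::SessionContext;
use datafusion_common::config::TableOptions;
use datafusion_common::{FileType, Result};

/// Illustrative helper (not part of the patch): derive CSV table options for
/// a session without mutating the session itself.
fn csv_table_options_for(ctx: &SessionContext) -> Result<TableOptions> {
    // An owned copy: changes here do not affect the session state.
    let mut options = ctx.copied_table_options();
    // Select the format first so "format.*" keys are routed to the CSV options.
    options.set_file_format(FileType::CSV);
    options.alter_with_string_hash_map(&HashMap::from_iter([(
        "format.delimiter".to_string(),
        "|".to_string(),
    )]))?;
    Ok(options)
}
```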
fn require_top_ordering(plan: Arc) -> Result> { let (new_plan, is_changed) = require_top_ordering_helper(plan)?; diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index e8f3bf01ecaa..ed445e6d48b8 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -322,7 +322,7 @@ fn try_swapping_with_output_req( projection: &ProjectionExec, output_req: &OutputRequirementExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down: + // If the projection does not narrow the schema, we should not try to push it down: if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -372,7 +372,7 @@ fn try_swapping_with_output_req( fn try_swapping_with_coalesce_partitions( projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down: + // If the projection does not narrow the schema, we should not try to push it down: if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -387,7 +387,7 @@ fn try_swapping_with_filter( projection: &ProjectionExec, filter: &FilterExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down: + // If the projection does not narrow the schema, we should not try to push it down: if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -412,7 +412,7 @@ fn try_swapping_with_repartition( projection: &ProjectionExec, repartition: &RepartitionExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down. + // If the projection does not narrow the schema, we should not try to push it down. if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -454,7 +454,7 @@ fn try_swapping_with_sort( projection: &ProjectionExec, sort: &SortExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down. + // If the projection does not narrow the schema, we should not try to push it down. if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -1082,7 +1082,7 @@ fn join_table_borders( (far_right_left_col_ind, far_left_right_col_ind) } -/// Tries to update the equi-join `Column`'s of a join as if the the input of +/// Tries to update the equi-join `Column`'s of a join as if the input of /// the join was replaced by a projection. fn update_join_on( proj_left_exprs: &[(Column, String)], @@ -1152,7 +1152,7 @@ fn new_columns_for_join_on( (new_columns.len() == hash_join_on.len()).then_some(new_columns) } -/// Tries to update the column indices of a [`JoinFilter`] as if the the input of +/// Tries to update the column indices of a [`JoinFilter`] as if the input of /// the join was replaced by a projection. 
fn update_join_filter( projection_left_exprs: &[(Column, String)], @@ -1287,6 +1287,7 @@ fn new_join_children( #[cfg(test)] mod tests { use super::*; + use std::any::Any; use std::sync::Arc; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -1313,7 +1314,10 @@ mod tests { use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_expr::{ + ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, + }; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; @@ -1329,6 +1333,42 @@ mod tests { use itertools::Itertools; + /// Mocked UDF + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } + #[test] fn test_update_matching_exprs() -> Result<()> { let exprs: Vec> = vec![ @@ -1345,7 +1385,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1412,7 +1454,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1482,7 +1526,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1549,7 +1595,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index d2126f90eca9..80bb5ad42e81 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -53,7 +53,7 @@ use log::trace; /// /// 1. Minimum and maximum values for columns /// -/// 2. Null counts for columns +/// 2. Null counts and row counts for columns /// /// 3. 
Whether the values in a column are contained in a set of literals /// @@ -100,7 +100,8 @@ pub trait PruningStatistics { /// these statistics. /// /// This value corresponds to the size of the [`ArrayRef`] returned by - /// [`Self::min_values`], [`Self::max_values`], and [`Self::null_counts`]. + /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], + /// and [`Self::row_counts`]. fn num_containers(&self) -> usize; /// Return the number of null values for the named column as an @@ -111,6 +112,14 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; + /// Return the number of rows for the named column in each container + /// as an `Option`. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn row_counts(&self, column: &Column) -> Option; + /// Returns [`BooleanArray`] where each row represents information known /// about specific literal `values` in a column. /// @@ -268,7 +277,7 @@ pub trait PruningStatistics { /// 3. [`PruningStatistics`] that provides information about columns in that /// schema, for multiple “containers”. For each column in each container, it /// provides optional information on contained values, min_values, max_values, -/// and null_counts counts. +/// null_counts counts, and row_counts counts. /// /// **Outputs**: /// A (non null) boolean value for each container: @@ -306,17 +315,23 @@ pub trait PruningStatistics { /// * `false`: there are no rows that could possibly match the predicate, /// **PRUNES** the container /// -/// For example, given a column `x`, the `x_min` and `x_max` and `x_null_count` -/// represent the minimum and maximum values, and the null count of column `x`, -/// provided by the `PruningStatistics`. Here are some examples of the rewritten -/// predicates: +/// For example, given a column `x`, the `x_min`, `x_max`, `x_null_count`, and +/// `x_row_count` represent the minimum and maximum values, the null count of +/// column `x`, and the row count of column `x`, provided by the `PruningStatistics`. +/// `x_null_count` and `x_row_count` are used to handle the case where the column `x` +/// is known to be all `NULL`s. Note this is different from knowing nothing about +/// the column `x`, which confusingly is encoded by returning `NULL` for the min/max +/// values from [`PruningStatistics::max_values`] and [`PruningStatistics::min_values`]. 
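Editor's note: the hunks above add a `row_counts` method to the `PruningStatistics` trait (this diff rendering elides generic parameters, so return types show up as bare `Option` and `Result`). As a rough illustration of what an implementer now has to provide, here is a minimal, hypothetical sketch of the trait for a single column whose per-container statistics are kept as pre-built Arrow arrays. The struct name `InMemoryStats`, the import paths, and the exact signatures are assumptions inferred from the surrounding diff, not code from this patch.

```rust
use std::collections::HashSet;

use datafusion::arrow::array::{Array, ArrayRef, BooleanArray};
use datafusion::common::{Column, ScalarValue};
use datafusion::physical_optimizer::pruning::PruningStatistics;

/// Hypothetical per-column statistics for several containers (e.g. row
/// groups), kept as pre-built Arrow arrays of length `num_containers`.
struct InMemoryStats {
    column_name: String,
    min: ArrayRef,
    max: ArrayRef,
    null_counts: ArrayRef, // UInt64 values, one per container
    row_counts: ArrayRef,  // UInt64 values, one per container
}

impl PruningStatistics for InMemoryStats {
    fn min_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == self.column_name).then(|| self.min.clone())
    }

    fn max_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == self.column_name).then(|| self.max.clone())
    }

    fn num_containers(&self) -> usize {
        // Must match the length of every array returned by the other methods.
        self.row_counts.len()
    }

    fn null_counts(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == self.column_name).then(|| self.null_counts.clone())
    }

    // New in this patch: lets the rewritten predicate compare
    // `x_null_count = x_row_count` to detect all-NULL containers.
    fn row_counts(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == self.column_name).then(|| self.row_counts.clone())
    }

    fn contained(
        &self,
        _column: &Column,
        _values: &HashSet<ScalarValue>,
    ) -> Option<BooleanArray> {
        None // no bloom-filter-like information in this sketch
    }
}
```

Returning `None` from any method (or returning arrays containing nulls) is how an implementation says "unknown"; the pruning machinery then substitutes nulls, which is exactly the situation the new `x_null_count = x_row_count` guard has to coexist with.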
+/// +/// Here are some examples of the rewritten predicates: /// /// Original Predicate | Rewritten Predicate /// ------------------ | -------------------- -/// `x = 5` | `x_min <= 5 AND 5 <= x_max` -/// `x < 5` | `x_max < 5` -/// `x = 5 AND y = 10` | `x_min <= 5 AND 5 <= x_max AND y_min <= 10 AND 10 <= y_max` -/// `x IS NULL` | `x_null_count > 0` +/// `x = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END` +/// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END` +/// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END` +/// `x IS NULL` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_null_count > 0 END` +/// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END` /// /// ## Predicate Evaluation /// The PruningPredicate works in two passes @@ -326,28 +341,47 @@ pub trait PruningStatistics { /// LiteralGuarantees are not satisfied /// /// **Second Pass**: Evaluates the rewritten expression using the -/// min/max/null_counts values for each column for each container. For any +/// min/max/null_counts/row_counts values for each column for each container. For any /// container that this expression evaluates to `false`, it rules out those /// containers. /// -/// For example, given the predicate, `x = 5 AND y = 10`, if we know `x` is -/// between `1 and 100` and we know that `y` is between `4` and `7`, the input -/// statistics might look like +/// +/// ### Example 1 +/// +/// Given the predicate, `x = 5 AND y = 10`, the rewritten predicate would look like: +/// +/// ```sql +/// CASE +/// WHEN x_null_count = x_row_count THEN false +/// ELSE x_min <= 5 AND 5 <= x_max +/// END +/// AND +/// CASE +/// WHEN y_null_count = y_row_count THEN false +/// ELSE y_min <= 10 AND 10 <= y_max +/// END +/// ``` +/// +/// If we know that for a given container, `x` is between `1 and 100` and we know that +/// `y` is between `4` and `7`, we know nothing about the null count and row count of +/// `x` and `y`, the input statistics might look like: /// /// Column | Value /// -------- | ----- /// `x_min` | `1` /// `x_max` | `100` +/// `x_null_count` | `null` +/// `x_row_count` | `null` /// `y_min` | `4` /// `y_max` | `7` +/// `y_null_count` | `null` +/// `y_row_count` | `null` /// -/// The rewritten predicate would look like -/// -/// `x_min <= 5 AND 5 <= x_max AND y_min <= 10 AND 10 <= y_max` -/// -/// When these values are substituted in to the rewritten predicate and +/// When these statistics values are substituted in to the rewritten predicate and /// simplified, the result is `false`: /// +/// * `CASE WHEN null = null THEN false ELSE 1 <= 5 AND 5 <= 100 END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END` +/// * `null = null` is `null` which is not true, so the `CASE` expression will use the `ELSE` clause /// * `1 <= 5 AND 5 <= 100 AND 4 <= 10 AND 10 <= 7` /// * `true AND true AND true AND false` /// * `false` @@ -364,6 +398,52 @@ pub trait PruningStatistics { /// more analysis, for example by actually reading the data and evaluating the /// predicate row by row. 
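Editor's note: to make the three-valued logic in Example 1 above concrete, the following standalone sketch models SQL NULL with `Option` and traces the same substitution by hand. It is purely illustrative and uses nothing from DataFusion; the point it demonstrates is that `null = null` evaluates to NULL (not true), so the `CASE` falls through to its `ELSE` branch, and the container is pruned only because `10 <= 7` is false.

```rust
/// Tiny model of SQL three-valued logic, used only to trace the worked
/// example above; `None` plays the role of SQL NULL.
fn sql_and(a: Option<bool>, b: Option<bool>) -> Option<bool> {
    match (a, b) {
        (Some(false), _) | (_, Some(false)) => Some(false),
        (Some(true), Some(true)) => Some(true),
        _ => None,
    }
}

fn sql_eq(a: Option<i64>, b: Option<i64>) -> Option<bool> {
    match (a, b) {
        (Some(a), Some(b)) => Some(a == b),
        _ => None, // any comparison involving NULL is NULL
    }
}

fn sql_lte(a: Option<i64>, b: Option<i64>) -> Option<bool> {
    match (a, b) {
        (Some(a), Some(b)) => Some(a <= b),
        _ => None,
    }
}

/// CASE WHEN <cond> THEN false ELSE <else_branch> END:
/// the THEN branch is taken only when the condition is *true*,
/// so a NULL condition falls through to the ELSE branch.
fn case_when_false_else(cond: Option<bool>, else_branch: Option<bool>) -> Option<bool> {
    if cond == Some(true) {
        Some(false)
    } else {
        else_branch
    }
}

fn main() {
    // Statistics from Example 1: min/max known, null/row counts unknown.
    let (x_min, x_max) = (Some(1), Some(100));
    let (x_null_count, x_row_count): (Option<i64>, Option<i64>) = (None, None);
    let (y_min, y_max) = (Some(4), Some(7));
    let (y_null_count, y_row_count): (Option<i64>, Option<i64>) = (None, None);

    // CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END
    let x_clause = case_when_false_else(
        sql_eq(x_null_count, x_row_count),
        sql_and(sql_lte(x_min, Some(5)), sql_lte(Some(5), x_max)),
    );
    // CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END
    let y_clause = case_when_false_else(
        sql_eq(y_null_count, y_row_count),
        sql_and(sql_lte(y_min, Some(10)), sql_lte(Some(10), y_max)),
    );

    // x: 1 <= 5 AND 5 <= 100  => true
    // y: 4 <= 10 AND 10 <= 7  => false
    // overall: true AND false => false, so the container is pruned
    assert_eq!(sql_and(x_clause, y_clause), Some(false));
}
```

The same helpers trace Example 2 below: with `x_null_count = x_row_count = 100`, the first `CASE` condition is true, its THEN branch yields `false`, and the overall conjunction is `false` regardless of the unknown min/max for `x`.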
/// +/// ### Example 2 +/// +/// Given the same predicate, `x = 5 AND y = 10`, the rewritten predicate would +/// look like the same as example 1: +/// +/// ```sql +/// CASE +/// WHEN x_null_count = x_row_count THEN false +/// ELSE x_min <= 5 AND 5 <= x_max +/// END +/// AND +/// CASE +/// WHEN y_null_count = y_row_count THEN false +/// ELSE y_min <= 10 AND 10 <= y_max +/// END +/// ``` +/// +/// If we know that for another given container, `x_min` is NULL and `x_max` is +/// NULL (the min/max values are unknown), `x_null_count` is `100` and `x_row_count` +/// is `100`; we know that `y` is between `4` and `7`, but we know nothing about +/// the null count and row count of `y`. The input statistics might look like: +/// +/// Column | Value +/// -------- | ----- +/// `x_min` | `null` +/// `x_max` | `null` +/// `x_null_count` | `100` +/// `x_row_count` | `100` +/// `y_min` | `4` +/// `y_max` | `7` +/// `y_null_count` | `null` +/// `y_row_count` | `null` +/// +/// When these statistics values are substituted in to the rewritten predicate and +/// simplified, the result is `false`: +/// +/// * `CASE WHEN 100 = 100 THEN false ELSE null <= 5 AND 5 <= null END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END` +/// * Since `100 = 100` is `true`, the `CASE` expression will use the `THEN` clause, i.e. `false` +/// * The other `CASE` expression will use the `ELSE` clause, i.e. `4 <= 10 AND 10 <= 7` +/// * `false AND true` +/// * `false` +/// +/// Returning `false` means the container can be pruned, which matches the +/// intuition that `x = 5 AND y = 10` can’t be true for all values in `x` +/// are known to be NULL. +/// /// # Related Work /// /// [`PruningPredicate`] implements the type of min/max pruning described in @@ -744,6 +824,22 @@ impl RequiredColumns { "null_count", ) } + + /// rewrite col --> col_row_count + fn row_count_column_expr( + &mut self, + column: &phys_expr::Column, + column_expr: &Arc, + field: &Field, + ) -> Result> { + self.stat_column_expr( + column, + column_expr, + field, + StatisticsType::RowCount, + "row_count", + ) + } } impl From> for RequiredColumns { @@ -794,6 +890,7 @@ fn build_statistics_record_batch( StatisticsType::Min => statistics.min_values(&column), StatisticsType::Max => statistics.max_values(&column), StatisticsType::NullCount => statistics.null_counts(&column), + StatisticsType::RowCount => statistics.row_counts(&column), }; let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers)); @@ -903,6 +1000,46 @@ impl<'a> PruningExpressionBuilder<'a> { self.required_columns .max_column_expr(&self.column, &self.column_expr, self.field) } + + /// This function is to simply retune the `null_count` physical expression no matter what the + /// predicate expression is + /// + /// i.e., x > 5 => x_null_count, + /// cast(x as int) < 10 => x_null_count, + /// try_cast(x as float) < 10.0 => x_null_count + fn null_count_column_expr(&mut self) -> Result> { + // Retune to [`phys_expr::Column`] + let column_expr = Arc::new(self.column.clone()) as _; + + // null_count is DataType::UInt64, which is different from the column's data type (i.e. 
self.field) + let null_count_field = &Field::new(self.field.name(), DataType::UInt64, true); + + self.required_columns.null_count_column_expr( + &self.column, + &column_expr, + null_count_field, + ) + } + + /// This function is to simply retune the `row_count` physical expression no matter what the + /// predicate expression is + /// + /// i.e., x > 5 => x_row_count, + /// cast(x as int) < 10 => x_row_count, + /// try_cast(x as float) < 10.0 => x_row_count + fn row_count_column_expr(&mut self) -> Result> { + // Retune to [`phys_expr::Column`] + let column_expr = Arc::new(self.column.clone()) as _; + + // row_count is DataType::UInt64, which is different from the column's data type (i.e. self.field) + let row_count_field = &Field::new(self.field.name(), DataType::UInt64, true); + + self.required_columns.row_count_column_expr( + &self.column, + &column_expr, + row_count_field, + ) + } } /// This function is designed to rewrite the column_expr to @@ -1320,14 +1457,56 @@ fn build_statistics_expr( ); } }; + let statistics_expr = wrap_case_expr(statistics_expr, expr_builder)?; Ok(statistics_expr) } +/// Wrap the statistics expression in a case expression. +/// This is necessary to handle the case where the column is known +/// to be all nulls. +/// +/// For example: +/// +/// `x_min <= 10 AND 10 <= x_max` +/// +/// will become +/// +/// ```sql +/// CASE +/// WHEN x_null_count = x_row_count THEN false +/// ELSE x_min <= 10 AND 10 <= x_max +/// END +/// ```` +/// +/// If the column is known to be all nulls, then the expression +/// `x_null_count = x_row_count` will be true, which will cause the +/// case expression to return false. Therefore, prune out the container. +fn wrap_case_expr( + statistics_expr: Arc, + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + // x_null_count = x_row_count + let when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new( + expr_builder.null_count_column_expr()?, + Operator::Eq, + expr_builder.row_count_column_expr()?, + )); + let then = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false)))); + + // CASE WHEN x_null_count = x_row_count THEN false ELSE END + Ok(Arc::new(phys_expr::CaseExpr::try_new( + None, + vec![(when_null_count_eq_row_count, then)], + Some(statistics_expr), + )?)) +} + #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(crate) enum StatisticsType { Min, Max, NullCount, + RowCount, } #[cfg(test)] @@ -1361,6 +1540,7 @@ mod tests { max: Option, /// Optional values null_counts: Option, + row_counts: Option, /// Optional known values (e.g. mimic a bloom filter) /// (value, contained) /// If present, all BooleanArrays must be the same size as min/max @@ -1440,6 +1620,10 @@ mod tests { self.null_counts.clone() } + fn row_counts(&self) -> Option { + self.row_counts.clone() + } + /// return an iterator over all arrays in this statistics fn arrays(&self) -> Vec { let contained_arrays = self @@ -1451,6 +1635,7 @@ mod tests { self.min.as_ref().cloned(), self.max.as_ref().cloned(), self.null_counts.as_ref().cloned(), + self.row_counts.as_ref().cloned(), ] .into_iter() .flatten() @@ -1509,6 +1694,20 @@ mod tests { self } + /// Add row counts. There must be the same number of row counts as + /// there are containers + fn with_row_counts( + mut self, + counts: impl IntoIterator>, + ) -> Self { + let row_counts: ArrayRef = + Arc::new(counts.into_iter().collect::()); + + self.assert_invariants(); + self.row_counts = Some(row_counts); + self + } + /// Add contained information. 
pub fn with_contained( mut self, @@ -1576,6 +1775,28 @@ mod tests { self } + /// Add row counts for the specified columm. + /// There must be the same number of row counts as + /// there are containers + fn with_row_counts( + mut self, + name: impl Into, + counts: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .unwrap_or_default() + .with_row_counts(counts); + + // put stats back in + self.stats.insert(col, container_stats); + self + } + /// Add contained information for the specified columm. fn with_contained( mut self, @@ -1628,6 +1849,13 @@ mod tests { .unwrap_or(None) } + fn row_counts(&self, column: &Column) -> Option { + self.stats + .get(column) + .map(|container_stats| container_stats.row_counts()) + .unwrap_or(None) + } + fn contained( &self, column: &Column, @@ -1663,6 +1891,10 @@ mod tests { None } + fn row_counts(&self, _column: &Column) -> Option { + None + } + fn contained( &self, _column: &Column, @@ -1853,7 +2085,7 @@ mod tests { #[test] fn row_group_predicate_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1"; + let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 END"; // test column on the left let expr = col("c1").eq(lit(1)); @@ -1873,7 +2105,7 @@ mod tests { #[test] fn row_group_predicate_not_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 != 1 OR 1 != c1_max@1"; + let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != 1 OR 1 != c1_max@1 END"; // test column on the left let expr = col("c1").not_eq(lit(1)); @@ -1893,7 +2125,8 @@ mod tests { #[test] fn row_group_predicate_gt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max@0 > 1"; + let expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 > 1 END"; // test column on the left let expr = col("c1").gt(lit(1)); @@ -1913,7 +2146,7 @@ mod tests { #[test] fn row_group_predicate_gt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max@0 >= 1"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 >= 1 END"; // test column on the left let expr = col("c1").gt_eq(lit(1)); @@ -1932,7 +2165,8 @@ mod tests { #[test] fn row_group_predicate_lt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 < 1"; + let expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; // test column on the left let expr = col("c1").lt(lit(1)); @@ -1952,7 +2186,7 @@ mod tests { #[test] fn row_group_predicate_lt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 <= 1"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 <= 1 END"; // test column on the left let expr = col("c1").lt_eq(lit(1)); @@ -1977,7 +2211,8 @@ mod tests { ]); // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); - let expected_expr = "c1_min@0 < 1"; + let 
expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2043,7 +2278,7 @@ mod tests { #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "c1_min@0 < true"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < true END"; // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated @@ -2066,7 +2301,21 @@ mod tests { let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); - let expected_expr = "c1_min@0 < 1 AND (c2_min@1 <= 2 AND 2 <= c2_max@2 OR c2_min@1 <= 3 AND 3 <= c2_max@2)"; + let expected_expr = "\ + CASE \ + WHEN c1_null_count@1 = c1_row_count@2 THEN false \ + ELSE c1_min@0 < 1 \ + END \ + AND (\ + CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@3 <= 2 AND 2 <= c2_max@4 \ + END \ + OR CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@3 <= 3 AND 3 <= c2_max@4 \ + END\ + )"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2080,10 +2329,30 @@ mod tests { c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); + // c1 < 1 should add c1_null_count + let c1_null_count_field = Field::new("c1_null_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[1], + ( + phys_expr::Column::new("c1", 0), + StatisticsType::NullCount, + c1_null_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); + // c1 < 1 should add c1_row_count + let c1_row_count_field = Field::new("c1_row_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[2], + ( + phys_expr::Column::new("c1", 0), + StatisticsType::RowCount, + c1_row_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); // c2 = 2 should add c2_min and c2_max let c2_min_field = Field::new("c2_min", DataType::Int32, false); assert_eq!( - required_columns.columns[1], + required_columns.columns[3], ( phys_expr::Column::new("c2", 1), StatisticsType::Min, @@ -2092,15 +2361,35 @@ mod tests { ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); assert_eq!( - required_columns.columns[2], + required_columns.columns[4], ( phys_expr::Column::new("c2", 1), StatisticsType::Max, c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); + // c2 = 2 should add c2_null_count + let c2_null_count_field = Field::new("c2_null_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[5], + ( + phys_expr::Column::new("c2", 1), + StatisticsType::NullCount, + c2_null_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); + // c2 = 2 should add c2_row_count + let c2_row_count_field = Field::new("c2_row_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[6], + ( + phys_expr::Column::new("c2", 1), + StatisticsType::RowCount, + c2_row_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); // c2 = 3 shouldn't add any new statistics fields - assert_eq!(required_columns.columns.len(), 3); + assert_eq!(required_columns.columns.len(), 7); Ok(()) } @@ -2117,7 +2406,18 
@@ mod tests { vec![lit(1), lit(2), lit(3)], false, )); - let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 3 AND 3 <= c1_max@1 \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2153,9 +2453,19 @@ mod tests { vec![lit(1), lit(2), lit(3)], true, )); - let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ - AND (c1_min@0 != 2 OR 2 != c1_max@1) \ - AND (c1_min@0 != 3 OR 3 != c1_max@1)"; + let expected_expr = "\ + CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 1 OR 1 != c1_max@1 \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 2 OR 2 != c1_max@1 \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 3 OR 3 != c1_max@1 \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2201,7 +2511,24 @@ mod tests { // test c1 in(1, 2) and c2 BETWEEN 4 AND 5 let expr3 = expr1.and(expr2); - let expected_expr = "(c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_max@2 >= 4 AND c2_min@3 <= 5"; + let expected_expr = "\ + (\ + CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ + END\ + ) AND CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_max@4 >= 4 \ + END \ + AND CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@7 <= 5 \ + END"; let predicate_expr = test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2228,9 +2555,12 @@ mod tests { #[test] fn row_group_predicate_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = - "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ + END"; + // test cast(c1 as int64) = 1 // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); let predicate_expr = @@ -2243,7 +2573,10 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); - let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; + let expected_expr = "CASE \ + WHEN c1_null_count@1 = c1_row_count@2 THEN false \ + ELSE TRY_CAST(c1_max@0 AS Int64) > 1 \ + END"; // test column on the left let expr = @@ -2275,7 +2608,18 @@ mod tests { ], false, )); - let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; + let expected_expr = "CASE \ + WHEN 
c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64) \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2289,10 +2633,18 @@ mod tests { ], true, )); - let expected_expr = - "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ - AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ - AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64) \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64) \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64) \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2819,7 +3171,7 @@ mod tests { let expected_ret = &[false, true, true, true, false]; prune_with_expr( - // i IS NULL, with actual null statistcs + // i IS NULL, with actual null statistics col("i").is_null(), &schema, &statistics, @@ -2827,6 +3179,78 @@ mod tests { ); } + #[test] + fn prune_int32_column_is_known_all_null() { + let (schema, statistics) = int32_setup(); + + // Expression "i < 0" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> no rows can pass (not keep) + // i [-11, -1] ==> all rows must pass (must keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> no rows can pass (not keep) + let expected_ret = &[true, false, true, true, false]; + + prune_with_expr( + // i < 0 + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // provide row counts for each column + let statistics = statistics.with_row_counts( + "i", + vec![ + Some(10), // 10 rows of data + Some(9), // 9 rows of data + None, // unknown row counts + Some(4), + Some(10), + ], + ); + + // pruning result is still the same if we only know row counts + prune_with_expr( + // i < 0, with only row counts statistics + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // provide null counts for each column + let statistics = statistics.with_null_counts( + "i", + vec![ + Some(0), // no nulls + Some(1), // 1 null + None, // unknown nulls + Some(4), // 4 nulls, which is the same as the row counts, i.e. 
this column is all null (don't keep) + Some(0), // 0 nulls (max=null too which means no known max) + ], + ); + + // Expression "i < 0" with actual null and row counts statistics + // col | min, max | row counts | null counts | + // ----+--------------+------------+-------------+ + // i | [-5, 5] | 10 | 0 | ==> Some rows could pass (must keep) + // i | [1, 11] | 9 | 1 | ==> No rows can pass (not keep) + // i | [-11,-1] | Unknown | Unknown | ==> All rows must pass (must keep) + // i | [NULL, NULL] | 4 | 4 | ==> The column is all null (not keep) + // i | [1, NULL] | 10 | 0 | ==> No rows can pass (not keep) + let expected_ret = &[true, false, true, false, false]; + + prune_with_expr( + // i < 0, with actual null and row counts statistics + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + } + #[test] fn prune_cast_column_scalar() { // The data type of column i is INT32 diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 96f5e1c3ffd3..0a1730e944d3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -43,7 +43,7 @@ use crate::logical_expr::{ Repartition, Union, UserDefinedLogicalNode, }; use crate::logical_expr::{Limit, Values}; -use crate::physical_expr::create_physical_expr; +use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; @@ -96,6 +96,7 @@ use datafusion_sql::utils::window_expr_common_partition_keys; use async_trait::async_trait; use datafusion_common::config::FormatOptions; +use datafusion_physical_expr::LexOrdering; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; @@ -595,7 +596,7 @@ impl DefaultPhysicalPlanner { table_partition_cols, overwrite: false, }; - let mut table_options = session_state.default_table_options().clone(); + let mut table_options = session_state.default_table_options(); let sink_format: Arc = match format_options { FormatOptions::CSV(options) => { table_options.csv = options.clone(); @@ -742,13 +743,13 @@ impl DefaultPhysicalPlanner { ); } - let logical_input_schema = input.schema(); + let logical_schema = logical_plan.schema(); let window_expr = window_expr .iter() .map(|e| { create_window_expr( e, - logical_input_schema, + logical_schema, session_state.execution_props(), ) }) @@ -958,14 +959,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Sort(Sort { expr, input, fetch, .. 
}) => { let physical_input = self.create_initial_plan(input, session_state).await?; let input_dfschema = input.as_ref().schema(); - let sort_expr = expr - .iter() - .map(|e| create_physical_sort_expr( - e, - input_dfschema, - session_state.execution_props(), - )) - .collect::>>()?; + let sort_expr = create_physical_sort_exprs(expr, input_dfschema, session_state.execution_props())?; let new_sort = SortExec::new(sort_expr, physical_input) .with_fetch(*fetch); Ok(Arc::new(new_sort)) @@ -1578,11 +1572,11 @@ pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool { pub fn create_window_expr_with_name( e: &Expr, name: impl Into, - logical_input_schema: &DFSchema, + logical_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { let name = name.into(); - let physical_input_schema: &Schema = &logical_input_schema.into(); + let physical_schema: &Schema = &logical_schema.into(); match e { Expr::WindowFunction(WindowFunction { fun, @@ -1592,20 +1586,11 @@ pub fn create_window_expr_with_name( window_frame, null_treatment, }) => { - let args = args - .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) - .collect::>>()?; - let partition_by = partition_by - .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) - .collect::>>()?; - let order_by = order_by - .iter() - .map(|e| { - create_physical_sort_expr(e, logical_input_schema, execution_props) - }) - .collect::>>()?; + let args = create_physical_exprs(args, logical_schema, execution_props)?; + let partition_by = + create_physical_exprs(partition_by, logical_schema, execution_props)?; + let order_by = + create_physical_sort_exprs(order_by, logical_schema, execution_props)?; if !is_window_frame_bound_valid(window_frame) { return plan_err!( @@ -1625,7 +1610,7 @@ pub fn create_window_expr_with_name( &partition_by, &order_by, window_frame, - physical_input_schema, + physical_schema, ignore_nulls, ) } @@ -1636,7 +1621,7 @@ pub fn create_window_expr_with_name( /// Create a window expression from a logical expression or an alias pub fn create_window_expr( e: &Expr, - logical_input_schema: &DFSchema, + logical_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" @@ -1644,7 +1629,7 @@ pub fn create_window_expr( Expr::Alias(Alias { expr, name, .. 
}) => (name.clone(), expr.as_ref()), _ => (e.display_name()?, e), }; - create_window_expr_with_name(e, name, logical_input_schema, execution_props) + create_window_expr_with_name(e, name, logical_schema, execution_props) } type AggregateExprWithOptionalArgs = ( @@ -1672,10 +1657,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( order_by, null_treatment, }) => { - let args = args - .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) - .collect::>>()?; + let args = + create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { Some(e) => Some(create_physical_expr( e, @@ -1685,17 +1668,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), + Some(e) => Some(create_physical_sort_exprs( + e, + logical_input_schema, + execution_props, + )?), None => None, }; let ignore_nulls = null_treatment @@ -1782,6 +1759,18 @@ pub fn create_physical_sort_expr( } } +/// Create vector of physical sort expression from a vector of logical expression +pub fn create_physical_sort_exprs( + exprs: &[Expr], + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result { + exprs + .iter() + .map(|expr| create_physical_sort_expr(expr, input_dfschema, execution_props)) + .collect::>>() +} + impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// @@ -2627,6 +2616,10 @@ mod tests { } impl ExecutionPlan for NoOpExecutionPlan { + fn name(&self) -> &'static str { + "NoOpExecutionPlan" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 7a466a666d8d..8113d799a184 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -165,7 +165,7 @@ impl TestParquetFile { // run coercion on the filters to coerce types etc. 
let props = ExecutionProps::new(); let context = SimplifyContext::new(&props).with_schema(df_schema.clone()); - let parquet_options = ctx.state().default_table_options().parquet.clone(); + let parquet_options = ctx.copied_table_options().parquet; if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); let filter = simplifier.coerce(filter, df_schema.clone()).unwrap(); diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index cea701492910..4371cce856ce 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -37,6 +37,7 @@ use datafusion::assert_batches_eq; use datafusion_common::DFSchema; use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast, ExprSchemable}; +use datafusion_functions::unicode::expr_fn::character_length; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -367,7 +368,7 @@ async fn test_fn_lpad_with_string() -> Result<()> { #[tokio::test] async fn test_fn_ltrim() -> Result<()> { - let expr = ltrim(lit(" a b c ")); + let expr = ltrim(vec![lit(" a b c ")]); let expected = [ "+-----------------------------------------+", @@ -384,7 +385,7 @@ async fn test_fn_ltrim() -> Result<()> { #[tokio::test] async fn test_fn_ltrim_with_columns() -> Result<()> { - let expr = ltrim(col("a")); + let expr = ltrim(vec![col("a")]); let expected = [ "+---------------+", diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 59905d859dc8..8df16e7944d2 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -46,7 +46,7 @@ use tokio::task::JoinSet; /// same results #[tokio::test(flavor = "multi_thread")] async fn streaming_aggregate_test() { - let test_cases = vec![ + let test_cases = [ vec!["a"], vec!["b", "a"], vec!["c", "a"], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 00c65995a5ff..2514324a9541 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,6 +22,7 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{Field, Schema}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ @@ -39,6 +40,7 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use itertools::Itertools; use test_utils::add_empty_batches; use hashbrown::HashMap; @@ -273,6 +275,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { window_frame.is_causal() }; + let extended_schema = + schema_add_window_fields(&args, &schema, &window_fn, fn_name)?; + let window_expr = create_window_expr( &window_fn, fn_name.to_string(), @@ -280,7 +285,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { &partitionby_exprs, &orderby_exprs, Arc::new(window_frame), - schema.as_ref(), + &extended_schema, false, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( @@ -678,6 +683,8 @@ async fn run_window_test( exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; } + let extended_schema = 
schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?; + let usual_window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( &window_fn, @@ -686,7 +693,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - schema.as_ref(), + &extended_schema, false, )?], exec1, @@ -704,7 +711,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - schema.as_ref(), + &extended_schema, false, )?], exec2, @@ -747,6 +754,32 @@ async fn run_window_test( Ok(()) } +// The planner has fully updated schema before calling the `create_window_expr` +// Replicate the same for this test +fn schema_add_window_fields( + args: &[Arc], + schema: &Arc, + window_fn: &WindowFunctionDefinition, + fn_name: &str, +) -> Result> { + let data_types = args + .iter() + .map(|e| e.clone().as_ref().data_type(schema)) + .collect::>>()?; + let window_expr_return_type = window_fn.return_type(&data_types)?; + let mut window_fields = schema + .fields() + .iter() + .map(|f| f.as_ref().clone()) + .collect_vec(); + window_fields.extend_from_slice(&[Field::new( + fn_name, + window_expr_return_type, + true, + )]); + Ok(Arc::new(Schema::new(window_fields))) +} + /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index f9696955769e..60010bdddfb8 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +//! Tests for the DataFusion SQL query planner that require functions from the +//! datafusion-functions crate. + use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -42,12 +45,18 @@ fn init() { let _ = env_logger::try_init(); } +#[test] +fn select_arrow_cast() { + let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; + let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ + \n EmptyRelation"; + quick_test(sql, expected); +} #[test] fn timestamp_nano_ts_none_predicates() -> Result<()> { let sql = "SELECT col_int32 FROM test WHERE col_ts_nano_none < (now() - interval '1 hour')"; - let plan = test_sql(sql)?; // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned @@ -55,7 +64,7 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { "Projection: test.col_int32\ \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ \n TableScan: test projection=[col_int32, col_ts_nano_none]"; - assert_eq!(expected, format!("{plan:?}")); + quick_test(sql, expected); Ok(()) } @@ -74,6 +83,11 @@ fn timestamp_nano_ts_utc_predicates() { assert_eq!(expected, format!("{plan:?}")); } +fn quick_test(sql: &str, expected_plan: &str) { + let plan = test_sql(sql).unwrap(); + assert_eq!(expected_plan, format!("{:?}", plan)); +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... 
@@ -81,12 +95,9 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan - let now_udf = datetime::functions() - .iter() - .find(|f| f.name() == "now") - .unwrap() - .to_owned(); - let context_provider = MyContextProvider::default().with_udf(now_udf); + let context_provider = MyContextProvider::default() + .with_udf(datetime::now()) + .with_udf(datafusion_functions::core::arrow_cast()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 7649b6acd45c..1da86a0363a5 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,9 +19,10 @@ use arrow::array::Decimal128Array; use arrow::{ array::{ - Array, ArrayRef, Date32Array, Date64Array, Float64Array, Int32Array, StringArray, + Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -62,14 +63,16 @@ fn init() { enum Scenario { Timestamps, Dates, - Int32, + Int, Int32Range, + UInt, Float64, Decimal, DecimalBloomFilterInt32, DecimalBloomFilterInt64, DecimalLargePrecision, DecimalLargePrecisionBloomFilter, + ByteArray, PeriodsInColumnNames, } @@ -117,16 +120,33 @@ impl TestOutput { self.metric_value("predicate_evaluation_errors") } + /// The number of row_groups matched by bloom filter + fn row_groups_matched_bloom_filter(&self) -> Option { + self.metric_value("row_groups_matched_bloom_filter") + } + /// The number of row_groups pruned by bloom filter fn row_groups_pruned_bloom_filter(&self) -> Option { self.metric_value("row_groups_pruned_bloom_filter") } + /// The number of row_groups matched by statistics + fn row_groups_matched_statistics(&self) -> Option { + self.metric_value("row_groups_matched_statistics") + } + /// The number of row_groups pruned by statistics fn row_groups_pruned_statistics(&self) -> Option { self.metric_value("row_groups_pruned_statistics") } + /// The number of row_groups matched by bloom filter or statistics + fn row_groups_matched(&self) -> Option { + self.row_groups_matched_bloom_filter() + .zip(self.row_groups_matched_statistics()) + .map(|(a, b)| a + b) + } + /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option { self.row_groups_pruned_bloom_filter() @@ -368,15 +388,64 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { .unwrap() } -/// Return record batch with i32 sequence +/// Return record batch with i8, i16, i32, and i64 sequences /// /// Columns are named -/// "i" -> Int32Array -fn make_int32_batch(start: i32, end: i32) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); - let v: Vec = (start..end).collect(); - let array = Arc::new(Int32Array::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array.clone()]).unwrap() +/// "i8" -> Int8Array +/// "i16" -> Int16Array +/// "i32" -> Int32Array +/// "i64" -> Int64Array +fn make_int_batches(start: i8, end: i8) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("i8", DataType::Int8, true), + Field::new("i16", DataType::Int16, true), + Field::new("i32", DataType::Int32, true), + 
Field::new("i64", DataType::Int64, true), + ])); + let v8: Vec = (start..end).collect(); + let v16: Vec = (start as _..end as _).collect(); + let v32: Vec = (start as _..end as _).collect(); + let v64: Vec = (start as _..end as _).collect(); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int8Array::from(v8)) as ArrayRef, + Arc::new(Int16Array::from(v16)) as ArrayRef, + Arc::new(Int32Array::from(v32)) as ArrayRef, + Arc::new(Int64Array::from(v64)) as ArrayRef, + ], + ) + .unwrap() +} + +/// Return record batch with i8, i16, i32, and i64 sequences +/// +/// Columns are named +/// "u8" -> UInt8Array +/// "u16" -> UInt16Array +/// "u32" -> UInt32Array +/// "u64" -> UInt64Array +fn make_uint_batches(start: u8, end: u8) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("u8", DataType::UInt8, true), + Field::new("u16", DataType::UInt16, true), + Field::new("u32", DataType::UInt32, true), + Field::new("u64", DataType::UInt64, true), + ])); + let v8: Vec = (start..end).collect(); + let v16: Vec = (start as _..end as _).collect(); + let v32: Vec = (start as _..end as _).collect(); + let v64: Vec = (start as _..end as _).collect(); + RecordBatch::try_new( + schema, + vec![ + Arc::new(UInt8Array::from(v8)) as ArrayRef, + Arc::new(UInt16Array::from(v16)) as ArrayRef, + Arc::new(UInt32Array::from(v32)) as ArrayRef, + Arc::new(UInt64Array::from(v64)) as ArrayRef, + ], + ) + .unwrap() } fn make_int32_range(start: i32, end: i32) -> RecordBatch { @@ -489,6 +558,51 @@ fn make_date_batch(offset: Duration) -> RecordBatch { .unwrap() } +/// returns a batch with two columns (note "service.name" is the name +/// of the column. It is *not* a table named service.name +/// +/// name | service.name +fn make_bytearray_batch( + name: &str, + string_values: Vec<&str>, + binary_values: Vec<&[u8]>, + fixedsize_values: Vec<&[u8; 3]>, +) -> RecordBatch { + let num_rows = string_values.len(); + let name: StringArray = std::iter::repeat(Some(name)).take(num_rows).collect(); + let service_string: StringArray = string_values.iter().map(Some).collect(); + let service_binary: BinaryArray = binary_values.iter().map(Some).collect(); + let service_fixedsize: FixedSizeBinaryArray = fixedsize_values + .iter() + .map(|value| Some(value.as_slice())) + .collect::>() + .into(); + + let schema = Schema::new(vec![ + Field::new("name", name.data_type().clone(), true), + // note the column name has a period in it! + Field::new("service_string", service_string.data_type().clone(), true), + Field::new("service_binary", service_binary.data_type().clone(), true), + Field::new( + "service_fixedsize", + service_fixedsize.data_type().clone(), + true, + ), + ]); + let schema = Arc::new(schema); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(name), + Arc::new(service_string), + Arc::new(service_binary), + Arc::new(service_fixedsize), + ], + ) + .unwrap() +} + /// returns a batch with two columns (note "service.name" is the name /// of the column. 
It is *not* a table named service.name /// @@ -526,17 +640,25 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_date_batch(TimeDelta::try_days(3600).unwrap()), ] } - Scenario::Int32 => { + Scenario::Int => { vec![ - make_int32_batch(-5, 0), - make_int32_batch(-4, 1), - make_int32_batch(0, 5), - make_int32_batch(5, 10), + make_int_batches(-5, 0), + make_int_batches(-4, 1), + make_int_batches(0, 5), + make_int_batches(5, 10), ] } Scenario::Int32Range => { vec![make_int32_range(0, 10), make_int32_range(200000, 300000)] } + Scenario::UInt => { + vec![ + make_uint_batches(0, 5), + make_uint_batches(1, 6), + make_uint_batches(5, 10), + make_uint_batches(250, 255), + ] + } Scenario::Float64 => { vec![ make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), @@ -587,6 +709,66 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_decimal_batch(vec![100000, 200000, 300000, 400000, 600000], 38, 5), ] } + Scenario::ByteArray => { + // frontends first, then backends. All in order, except frontends 4 and 7 + // are swapped to cause a statistics false positive on the 'fixed size' column. + vec![ + make_bytearray_batch( + "all frontends", + vec![ + "frontend one", + "frontend two", + "frontend three", + "frontend seven", + "frontend five", + ], + vec![ + b"frontend one", + b"frontend two", + b"frontend three", + b"frontend seven", + b"frontend five", + ], + vec![b"fe1", b"fe2", b"fe3", b"fe7", b"fe5"], + ), + make_bytearray_batch( + "mixed", + vec![ + "frontend six", + "frontend four", + "backend one", + "backend two", + "backend three", + ], + vec![ + b"frontend six", + b"frontend four", + b"backend one", + b"backend two", + b"backend three", + ], + vec![b"fe6", b"fe4", b"be1", b"be2", b"be3"], + ), + make_bytearray_batch( + "all backends", + vec![ + "backend four", + "backend five", + "backend six", + "backend seven", + "backend eight", + ], + vec![ + b"backend four", + b"backend five", + b"backend six", + b"backend seven", + b"backend eight", + ], + vec![b"be4", b"be5", b"be6", b"be7", b"be8"], + ), + ] + } Scenario::PeriodsInColumnNames => { vec![ // all frontend diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 3a43428f5bcf..da9617f13ee9 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -371,112 +371,263 @@ async fn prune_date64() { assert_eq!(output.result_rows, 1, "{}", output.description()); } -#[tokio::test] -// null count min max -// page-0 0 -5 -1 -// page-1 0 -4 0 -// page-2 0 0 4 -// page-3 0 5 9 -async fn prune_int32_lt() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i < 1", - Some(0), - Some(5), - 11, - ) - .await; - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - test_prune( - Scenario::Int32, - "SELECT * FROM t where -i > -1", - Some(0), - Some(5), - 11, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_gt() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i > 8", - Some(0), - Some(15), - 1, - ) - .await; - - test_prune( - Scenario::Int32, - "SELECT * FROM t where -i < -8", - Some(0), - Some(15), - 1, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_eq() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i = 1", - Some(0), - Some(15), - 1, - ) - .await; -} -#[tokio::test] -async fn prune_int32_scalar_fun_and_eq() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where abs(i) = 1 and i = 1", - Some(0), - Some(15), - 1, - ) - .await; +macro_rules! 
int_tests { + ($bits:expr) => { + paste::item! { + #[tokio::test] + // null count min max + // page-0 0 -5 -1 + // page-1 0 -4 0 + // page-2 0 0 4 + // page-3 0 5 9 + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} < 1", $bits), + Some(0), + Some(5), + 11, + ) + .await; + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} > -1", $bits), + Some(0), + Some(5), + 11, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} > 8", $bits), + Some(0), + Some(15), + 1, + ) + .await; + + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} < -8", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} = 1", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where abs(i{}) = 1 and i{} = 1", $bits, $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where abs(i{}) = 1", $bits), + Some(0), + Some(0), + 3, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{}+1 = 1", $bits), + Some(0), + Some(0), + 2, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where 1-i{} > 1", $bits), + Some(0), + Some(0), + 9, + ) + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where in (1)" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} in (1)", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} not in (1)", $bits), + Some(0), + Some(0), + 19, + ) + .await; + } + } + } } -#[tokio::test] -async fn prune_int32_scalar_fun() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where abs(i) = 1", - Some(0), - Some(0), - 3, - ) - .await; +int_tests!(8); +int_tests!(16); +int_tests!(32); +int_tests!(64); + +macro_rules! uint_tests { + ($bits:expr) => { + paste::item! 
{ + #[tokio::test] + // null count min max + // page-0 0 0 4 + // page-1 0 1 5 + // page-2 0 5 9 + // page-3 0 250 254 + async fn []() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} < 6", $bits), + Some(0), + Some(5), + 11, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} > 253", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} = 6", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", $bits, $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where power(u{}, 2) = 25", $bits), + Some(0), + Some(0), + 2, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{}+1 = 6", $bits), + Some(0), + Some(0), + 2, + ) + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where in (1)" + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} in (6)", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where not in (6)" prune nothing + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} not in (6)", $bits), + Some(0), + Some(0), + 19, + ) + .await; + } + } + } } -#[tokio::test] -async fn prune_int32_complex_expr() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i+1 = 1", - Some(0), - Some(0), - 2, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_complex_expr_subtract() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where 1-i > 1", - Some(0), - Some(0), - 9, - ) - .await; -} +uint_tests!(8); +uint_tests!(16); +uint_tests!(32); +uint_tests!(64); #[tokio::test] // null count min max @@ -556,37 +707,6 @@ async fn prune_f64_complex_expr_subtract() { .await; } -#[tokio::test] -// null count min max -// page-0 0 -5 -1 -// page-1 0 -4 0 -// page-2 0 0 4 -// page-3 0 5 9 -async fn prune_int32_eq_in_list() { - // result of sql "SELECT * FROM t where in (1)" - test_prune( - Scenario::Int32, - "SELECT * FROM t where i in (1)", - Some(0), - Some(15), - 1, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_eq_in_list_negated() { - // result of sql "SELECT * FROM t where not in (1)" prune nothing - test_prune( - Scenario::Int32, - "SELECT * FROM t where i not in (1)", - Some(0), - Some(0), - 19, - ) - .await; -} - #[tokio::test] async fn prune_decimal_lt() { // The data type of decimal_col is decimal(9,2) diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index fa53b9c56cec..b70102f78a96 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -29,7 +29,9 @@ struct RowGroupPruningTest { scenario: Scenario, query: String, expected_errors: Option, + expected_row_group_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, + expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, expected_results: usize, } @@ -40,7 +42,9 @@ impl RowGroupPruningTest { scenario: Scenario::Timestamps, // or another default query: String::new(), expected_errors: 
None, + expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, expected_results: 0, } @@ -64,12 +68,24 @@ impl RowGroupPruningTest { self } + // Set the expected matched row groups by statistics + fn with_matched_by_stats(mut self, matched_by_stats: Option) -> Self { + self.expected_row_group_matched_by_statistics = matched_by_stats; + self + } + // Set the expected pruned row groups by statistics fn with_pruned_by_stats(mut self, pruned_by_stats: Option) -> Self { self.expected_row_group_pruned_by_statistics = pruned_by_stats; self } + // Set the expected matched row groups by bloom filter + fn with_matched_by_bloom_filter(mut self, matched_by_bf: Option) -> Self { + self.expected_row_group_matched_by_bloom_filter = matched_by_bf; + self + } + // Set the expected pruned row groups by bloom filter fn with_pruned_by_bloom_filter(mut self, pruned_by_bf: Option) -> Self { self.expected_row_group_pruned_by_bloom_filter = pruned_by_bf; @@ -90,20 +106,36 @@ impl RowGroupPruningTest { .await; println!("{}", output.description()); - assert_eq!(output.predicate_evaluation_errors(), self.expected_errors); + assert_eq!( + output.predicate_evaluation_errors(), + self.expected_errors, + "mismatched predicate_evaluation" + ); + assert_eq!( + output.row_groups_matched_statistics(), + self.expected_row_group_matched_by_statistics, + "mismatched row_groups_matched_statistics", + ); assert_eq!( output.row_groups_pruned_statistics(), - self.expected_row_group_pruned_by_statistics + self.expected_row_group_pruned_by_statistics, + "mismatched row_groups_pruned_statistics", + ); + assert_eq!( + output.row_groups_matched_bloom_filter(), + self.expected_row_group_matched_by_bloom_filter, + "mismatched row_groups_matched_bloom_filter", ); assert_eq!( output.row_groups_pruned_bloom_filter(), - self.expected_row_group_pruned_by_bloom_filter + self.expected_row_group_pruned_by_bloom_filter, + "mismatched row_groups_pruned_bloom_filter", ); assert_eq!( output.result_rows, self.expected_results, - "{}", - output.description() + "mismatched expected rows: {}", + output.description(), ); } } @@ -114,7 +146,9 @@ async fn prune_timestamps_nanos() { .with_scenario(Scenario::Timestamps) .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -129,7 +163,9 @@ async fn prune_timestamps_micros() { "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -144,7 +180,9 @@ async fn prune_timestamps_millis() { "SELECT * FROM t where micros < to_timestamp_millis('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -159,7 +197,9 @@ async fn prune_timestamps_seconds() { "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + 
.with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -172,7 +212,9 @@ async fn prune_date32() { .with_scenario(Scenario::Dates) .with_query("SELECT * FROM t where date32 < cast('2020-01-02' as date)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -201,6 +243,7 @@ async fn prune_date64() { println!("{}", output.description()); // This should prune out groups without error assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_matched(), Some(1)); assert_eq!(output.row_groups_pruned(), Some(3)); assert_eq!(output.result_rows, 1, "{}", output.description()); } @@ -211,7 +254,9 @@ async fn prune_disabled() { .with_scenario(Scenario::Timestamps) .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -230,6 +275,7 @@ async fn prune_disabled() { // This should not prune any assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_matched(), Some(0)); assert_eq!(output.row_groups_pruned(), Some(0)); assert_eq!( output.result_rows, @@ -239,91 +285,191 @@ async fn prune_disabled() { ); } -#[tokio::test] -async fn prune_int32_lt() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i < 1") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; +// $bits: number of bits of the integer to test (8, 16, 32, 64) +// $correct_bloom_filters: if false, replicates the +// https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass +// if and only if Bloom filters on Int8 and Int16 columns are still buggy. +macro_rules! int_tests { + ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + paste::item! 
{ + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} < 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where -i > -1") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; -} + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where -i{} > -1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + } -#[tokio::test] -async fn prune_int32_eq() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i = 1") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(3)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; -} -#[tokio::test] -async fn prune_int32_scalar_fun_and_eq() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i = 1") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(3)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; -} + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) + .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) + .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .test_row_group_prune() + .await; + } + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) + .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) + .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .test_row_group_prune() + .await; + } -#[tokio::test] -async fn prune_int32_scalar_fun() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where abs(i) = 1") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(3) - .test_row_group_prune() - .await; -} + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where abs(i{}) = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + 
.with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + } -#[tokio::test] -async fn prune_int32_complex_expr() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i+1 = 1") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{}+1 = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where 1-i{} > 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(9) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where in (1)" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) + .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) + .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where in (1000)", prune all + // test whether statistics works + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(4)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} not in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(19) + .test_row_group_prune() + .await; + } + } + }; } +int_tests!(8, correct_bloom_filters: false); +int_tests!(16, correct_bloom_filters: false); +int_tests!(32, correct_bloom_filters: true); +int_tests!(64, correct_bloom_filters: true); + #[tokio::test] -async fn prune_int32_complex_expr_subtract() { +async fn prune_int32_eq_large_in_list() { + // result of sql "SELECT * FROM t where i in (2050...2582)", prune all RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where 1-i > 1") + .with_scenario(Scenario::Int32Range) + .with_query( + format!( + "SELECT * FROM t where i in ({})", + (200050..200082).join(",") + ) + .as_str(), + ) 
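The `uint_tests!` and `int_tests!` macros above use `paste::item!` to stamp out one full set of test functions per integer bit width, and the `correct_bloom_filters:` flag bakes the currently expected (buggy, see issue #9779) results into the generated tests so they start failing loudly once the bug is fixed. Below is a minimal, self-contained sketch of that pattern; the macro name, test body, and column naming are invented for illustration and assume the `paste` crate is available as a dependency.

```rust
// Sketch of the paste::item! pattern used by uint_tests!/int_tests! above:
// one invocation per bit width generates a complete #[test] function whose
// name embeds the width, so a single macro body covers u8/u16/u32/u64
// columns without copy-pasting test code.
macro_rules! width_tests {
    ($bits:expr) => {
        paste::item! {
            #[test]
            fn [<builds_query_for_u $bits>]() {
                // Build the same kind of width-specific predicate the tests above use.
                let sql = format!("SELECT * FROM t where u{} < 6", $bits);
                assert!(sql.contains(format!("u{}", $bits).as_str()));
            }
        }
    };
}

width_tests!(8);
width_tests!(16);
width_tests!(32);
width_tests!(64);
```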
.with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(9) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) .test_row_group_prune() .await; } @@ -334,7 +480,9 @@ async fn prune_f64_lt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where f < 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -343,7 +491,9 @@ async fn prune_f64_lt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where -f > -1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -358,7 +508,9 @@ async fn prune_f64_scalar_fun_and_gt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -372,7 +524,9 @@ async fn prune_f64_scalar_fun() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where abs(f-1) <= 0.000001") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -386,7 +540,9 @@ async fn prune_f64_complex_expr() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where f+1 > 1.1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -400,76 +556,15 @@ async fn prune_f64_complex_expr_subtract() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where 1-f > 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() .await; } -#[tokio::test] -async fn prune_int32_eq_in_list() { - // result of sql "SELECT * FROM t where in (1)" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i in (1)") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(3)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; -} - -#[tokio::test] -async fn prune_int32_eq_in_list_2() { - // result of sql "SELECT * FROM t where in (1000)", prune all - // test whether statistics works - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i in (1000)") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(0) - .test_row_group_prune() - .await; -} - -#[tokio::test] -async fn prune_int32_eq_large_in_list() { - // result of sql "SELECT * FROM t where i in (2050...2582)", prune all - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32Range) - .with_query( - format!( - "SELECT * 
FROM t where i in ({})", - (200050..200082).join(",") - ) - .as_str(), - ) - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_pruned_by_bloom_filter(Some(1)) - .with_expected_rows(0) - .test_row_group_prune() - .await; -} - -#[tokio::test] -async fn prune_int32_eq_in_list_negated() { - // result of sql "SELECT * FROM t where not in (1)" prune nothing - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i not in (1)") - .with_expected_errors(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(19) - .test_row_group_prune() - .await; -} - #[tokio::test] async fn prune_decimal_lt() { // The data type of decimal_col is decimal(9,2) @@ -479,7 +574,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col < 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -488,7 +585,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) .test_row_group_prune() @@ -497,7 +596,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col < 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -506,7 +607,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) .test_row_group_prune() @@ -522,7 +625,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col = 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -531,7 +636,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col = 4.00") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -541,7 +648,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col = 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -550,7 +659,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col = 4.00") .with_expected_errors(Some(0)) + 
.with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -567,7 +678,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -576,7 +689,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -585,7 +700,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -594,7 +711,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -605,7 +724,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalBloomFilterInt32) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -616,7 +737,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalBloomFilterInt64) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -627,11 +750,318 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalLargePrecisionBloomFilter) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'backend one'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_eq_no_match() { + 
RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'backend nine'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'frontend nine'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'frontend five' < 'frontend nine' < 'frontend two' + // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string != 'backend one'", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string < 'backend one'", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string < 'backend zero'", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('backend nine' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + 
.with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('frontend nine' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'frontend five' < 'frontend nine' < 'frontend two' + // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary != CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary < CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary < CAST('backend zero' AS bytea)", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('fe6' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) .test_row_group_prune() .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('fe6' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('be9' AS bytea), 'FixedSizeBinary(3)')", + ) + 
.with_expected_errors(Some(0)) + // false positive on 'mixed' batch: 'be1' < 'be9' < 'fe4' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize != ARROW_CAST(CAST('be1' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize < ARROW_CAST(CAST('be3' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize < ARROW_CAST(CAST('be9' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; } #[tokio::test] @@ -644,7 +1074,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(7) .test_row_group_prune() @@ -653,7 +1085,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"name\" != 'HTTP GET / DISPATCH'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -662,7 +1096,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend' AND \"name\" != 'HTTP GET / DISPATCH'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 14bc7a3d4f68..84b791a3de05 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -321,83 +321,3 @@ async fn test_accumulator_row_accumulator() -> Result<()> { Ok(()) } - -#[tokio::test] -async fn 
test_first_value() -> Result<()> { - let session_ctx = SessionContext::new(); - session_ctx - .sql("CREATE TABLE abc AS VALUES (null,2,3), (4,5,6)") - .await? - .collect() - .await?; - - let results1 = session_ctx - .sql("SELECT FIRST_VALUE(column1) ignore nulls FROM abc") - .await? - .collect() - .await?; - let expected1 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| 4 |", - "+--------------------------+", - ]; - assert_batches_eq!(expected1, &results1); - - let results2 = session_ctx - .sql("SELECT FIRST_VALUE(column1) respect nulls FROM abc") - .await? - .collect() - .await?; - let expected2 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| |", - "+--------------------------+", - ]; - assert_batches_eq!(expected2, &results2); - - Ok(()) -} - -#[tokio::test] -async fn test_first_value_with_sort() -> Result<()> { - let session_ctx = SessionContext::new(); - session_ctx - .sql("CREATE TABLE abc AS VALUES (null,2,3), (null,1,6), (4, 5, 5), (1, 4, 7), (2, 3, 8)") - .await? - .collect() - .await?; - - let results1 = session_ctx - .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) ignore nulls FROM abc") - .await? - .collect() - .await?; - let expected1 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| 2 |", - "+--------------------------+", - ]; - assert_batches_eq!(expected1, &results1); - - let results2 = session_ctx - .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) respect nulls FROM abc") - .await? - .collect() - .await?; - let expected2 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| |", - "+--------------------------+", - ]; - assert_batches_eq!(expected2, &results2); - - Ok(()) -} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 695b3ba745e2..30b11fe2a0ee 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -737,7 +737,9 @@ async fn parquet_explain_analyze() { // should contain aggregated stats assert_contains!(&formatted, "output_rows=8"); + assert_contains!(&formatted, "row_groups_matched_bloom_filter=0"); assert_contains!(&formatted, "row_groups_pruned_bloom_filter=0"); + assert_contains!(&formatted, "row_groups_matched_statistics=1"); assert_contains!(&formatted, "row_groups_pruned_statistics=0"); } @@ -754,7 +756,9 @@ async fn parquet_explain_analyze_verbose() { .to_string(); // should contain the raw per file stats (with the label) + assert_contains!(&formatted, "row_groups_matched_bloom_filter{partition=0"); assert_contains!(&formatted, "row_groups_pruned_bloom_filter{partition=0"); + assert_contains!(&formatted, "row_groups_matched_statistics{partition=0"); assert_contains!(&formatted, "row_groups_pruned_statistics{partition=0"); } diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index d7adc9611b2f..b3a819fbc331 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -16,6 +16,7 @@ // under the License. 
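The `row_groups_matched_*` counters asserted in the `explain_analyze` tests above surface through `EXPLAIN ANALYZE` like any other ParquetExec metric. The following is a rough sketch of inspecting them outside the test harness, assuming a tokio runtime; the parquet path, table name, and query are placeholders.

```rust
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::prelude::*;

// Sketch: run EXPLAIN ANALYZE over a parquet-backed table and inspect the
// row-group pruning metrics named in the assertions above. Whether the
// matched/pruned counters are non-zero depends on the data and the predicate.
#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    ctx.register_parquet("t", "data/example.parquet", ParquetReadOptions::default())
        .await?;

    let batches = ctx
        .sql("EXPLAIN ANALYZE SELECT * FROM t WHERE id = 1")
        .await?
        .collect()
        .await?;

    // Look for e.g. `row_groups_matched_statistics=...` and
    // `row_groups_pruned_bloom_filter=...` in the formatted plan.
    println!("{}", pretty_format_batches(&batches)?);
    Ok(())
}
```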
use datafusion::prelude::*; + use tempfile::TempDir; #[tokio::test] @@ -27,7 +28,7 @@ async fn unsupported_ddl_returns_error() { // disallow ddl let options = SQLOptions::new().with_allow_ddl(false); - let sql = "create view test_view as select * from test"; + let sql = "CREATE VIEW test_view AS SELECT * FROM test"; let df = ctx.sql_with_options(sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -46,7 +47,7 @@ async fn unsupported_dml_returns_error() { let options = SQLOptions::new().with_allow_dml(false); - let sql = "insert into test values (1)"; + let sql = "INSERT INTO test VALUES (1)"; let df = ctx.sql_with_options(sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -67,7 +68,10 @@ async fn unsupported_copy_returns_error() { let options = SQLOptions::new().with_allow_dml(false); - let sql = format!("copy (values(1)) to '{}'", tmpfile.to_string_lossy()); + let sql = format!( + "COPY (values(1)) TO '{}' STORED AS parquet", + tmpfile.to_string_lossy() + ); let df = ctx.sql_with_options(&sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -106,7 +110,7 @@ async fn ddl_can_not_be_planned_by_session_state() { let state = ctx.state(); // can not create a logical plan for catalog DDL - let sql = "drop table test"; + let sql = "DROP TABLE test"; let plan = state.create_logical_plan(sql).await.unwrap(); let physical_plan = state.create_physical_plan(&plan).await; assert_eq!( diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 4db97c75cb33..e8d2c3764e0c 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -73,9 +73,6 @@ async fn tpcds_logical_q9() -> Result<()> { create_logical_plan(9).await } -#[ignore] -// Schema error: No field named 'c'.'c_customer_sk'. -// issue: https://github.com/apache/arrow-datafusion/issues/4794 #[tokio::test] async fn tpcds_logical_q10() -> Result<()> { create_logical_plan(10).await @@ -201,9 +198,6 @@ async fn tpcds_logical_q34() -> Result<()> { create_logical_plan(34).await } -#[ignore] -// Schema error: No field named 'c'.'c_customer_sk'. -// issue: https://github.com/apache/arrow-datafusion/issues/4794 #[tokio::test] async fn tpcds_logical_q35() -> Result<()> { create_logical_plan(35).await @@ -577,7 +571,7 @@ async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await } -#[ignore] // FieldNotFound +#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q10() -> Result<()> { create_physical_plan(10).await @@ -703,7 +697,7 @@ async fn tpcds_physical_q34() -> Result<()> { create_physical_plan(34).await } -#[ignore] // FieldNotFound +#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q35() -> Result<()> { create_physical_plan(35).await @@ -734,7 +728,8 @@ async fn tpcds_physical_q40() -> Result<()> { create_physical_plan(40).await } -#[ignore] // Physical plan does not support logical expression () +#[ignore] +// Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: (..) 
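The `sql_api.rs` tests above exercise `SQLOptions`, which lets a caller forbid DDL and DML statements before planning (the COPY test is rejected under the same DML switch). A minimal usage sketch follows, using only the methods those tests call; the view name is a placeholder.

```rust
use datafusion::execution::context::SQLOptions;
use datafusion::prelude::*;

// Sketch: restrict a SessionContext so untrusted SQL can only read data.
#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();

    // Plain queries are still allowed.
    let read_only = SQLOptions::new().with_allow_ddl(false).with_allow_dml(false);
    ctx.sql_with_options("SELECT 1", read_only).await?.show().await?;

    // Statements that would change the catalog or the data are rejected at plan time.
    let read_only = SQLOptions::new().with_allow_ddl(false).with_allow_dml(false);
    let err = ctx
        .sql_with_options("CREATE VIEW v AS SELECT 1", read_only)
        .await
        .unwrap_err();
    println!("rejected as expected: {err}");
    Ok(())
}
```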
#[tokio::test] async fn tpcds_physical_q41() -> Result<()> { create_physical_plan(41).await diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 3f40c55a3ed7..a58a8cf51681 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -184,11 +184,11 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let expected = [ - "+-------------+", - "| SUM(t.time) |", - "+-------------+", - "| 19000 |", - "+-------------+", + "+---------------------------------------+", + "| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |", + "+---------------------------------------+", + "| 19000 |", + "+---------------------------------------+", ]; assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index b525e4fc6341..86be887198ae 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -747,7 +747,7 @@ struct CustomFunctionFactory {} impl FunctionFactory for CustomFunctionFactory { async fn create( &self, - _state: &SessionConfig, + _state: &SessionState, statement: CreateFunction, ) -> Result { let f: ScalarFunctionWrapper = statement.try_into()?; diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 312aef953e9c..0a7a87c7d81a 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -22,7 +22,10 @@ use std::{ sync::Arc, }; -use datafusion_common::{config::ConfigOptions, Result, ScalarValue}; +use datafusion_common::{ + config::{ConfigExtension, ConfigOptions}, + Result, ScalarValue, +}; /// Configuration options for [`SessionContext`]. /// @@ -198,6 +201,12 @@ impl SessionConfig { self } + /// Insert new [ConfigExtension] + pub fn with_option_extension(mut self, extension: T) -> Self { + self.options_mut().extensions.insert(extension); + self + } + /// Get [`target_partitions`] /// /// [`target_partitions`]: datafusion_common::config::ExecutionOptions::target_partitions @@ -434,9 +443,9 @@ impl SessionConfig { /// converted to strings. /// /// Note that this method will eventually be deprecated and - /// replaced by [`config_options`]. + /// replaced by [`options`]. /// - /// [`config_options`]: Self::config_options + /// [`options`]: Self::options pub fn to_props(&self) -> HashMap { let mut map = HashMap::new(); // copy configs from config_options @@ -447,18 +456,6 @@ impl SessionConfig { map } - /// Return a handle to the configuration options. - #[deprecated(since = "21.0.0", note = "use options() instead")] - pub fn config_options(&self) -> &ConfigOptions { - &self.options - } - - /// Return a mutable handle to the configuration options. - #[deprecated(since = "21.0.0", note = "use options_mut() instead")] - pub fn config_options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options - } - /// Add extensions. /// /// Extensions can be used to attach extra data to the session config -- e.g. tracing information or caches. 
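The new `SessionConfig::with_option_extension` added above pairs with `datafusion_common::extensions_options!` and the `ConfigExtension` trait. Below is a minimal sketch of registering and reading a custom extension; `MyAppConfig`, its prefix, and its field are made up for illustration.

```rust
use datafusion_common::config::ConfigExtension;
use datafusion_common::extensions_options;
use datafusion_execution::config::SessionConfig;

// A hypothetical application-specific option group.
extensions_options! {
    pub struct MyAppConfig {
        /// Whether the (hypothetical) application cache is enabled
        pub cache_enabled: bool, default = true
    }
}

impl ConfigExtension for MyAppConfig {
    const PREFIX: &'static str = "my_app";
}

fn main() -> datafusion_common::Result<()> {
    // Register the extension, then read and write it through the normal options API.
    let mut config = SessionConfig::new().with_option_extension(MyAppConfig::default());
    config.options_mut().set("my_app.cache_enabled", "false")?;

    let my_app = config.options().extensions.get::<MyAppConfig>().unwrap();
    assert!(!my_app.cache_enabled);
    Ok(())
}
```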
diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index cae410655d10..4216ce95f35e 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -20,10 +20,7 @@ use std::{ sync::Arc, }; -use datafusion_common::{ - config::{ConfigOptions, Extensions}, - plan_datafusion_err, DataFusionError, Result, -}; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::{ @@ -102,39 +99,6 @@ impl TaskContext { } } - /// Create a new task context instance, by first copying all - /// name/value pairs from `task_props` into a `SessionConfig`. - #[deprecated( - since = "21.0.0", - note = "Construct SessionConfig and call TaskContext::new() instead" - )] - pub fn try_new( - task_id: String, - session_id: String, - task_props: HashMap, - scalar_functions: HashMap>, - aggregate_functions: HashMap>, - runtime: Arc, - extensions: Extensions, - ) -> Result { - let mut config = ConfigOptions::new().with_extensions(extensions); - for (k, v) in task_props { - config.set(&k, &v)?; - } - let session_config = SessionConfig::from(config); - let window_functions = HashMap::new(); - - Ok(Self::new( - Some(task_id), - session_id, - session_config, - scalar_functions, - aggregate_functions, - window_functions, - runtime, - )) - } - /// Return the SessionConfig associated with this [TaskContext] pub fn session_config(&self) -> &SessionConfig { &self.session_config @@ -160,7 +124,7 @@ impl TaskContext { self.runtime.clone() } - /// Update the [`ConfigOptions`] + /// Update the [`SessionConfig`] pub fn with_session_config(mut self, session_config: SessionConfig) -> Self { self.session_config = session_config; self @@ -229,7 +193,10 @@ impl FunctionRegistry for TaskContext { #[cfg(test)] mod tests { use super::*; - use datafusion_common::{config::ConfigExtension, extensions_options}; + use datafusion_common::{ + config::{ConfigExtension, ConfigOptions, Extensions}, + extensions_options, + }; extensions_options! { struct TestExtension { diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 621a320230f2..6f6147d36883 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -43,6 +43,7 @@ arrow-array = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } paste = "^1.0" +serde_json = { workspace = true } sqlparser = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } strum_macros = "0.26.0" diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 574de3e7082a..85f8c74f3737 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -218,19 +218,6 @@ impl FromStr for AggregateFunction { } } -/// Returns the datatype of the aggregate function. -/// This is used to get the returned data type for aggregate expr. -#[deprecated( - since = "27.0.0", - note = "please use `AggregateFunction::return_type` instead" -)] -pub fn return_type( - fun: &AggregateFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - impl AggregateFunction { /// Returns the datatype of the aggregate function given its argument types /// @@ -328,15 +315,6 @@ pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result { avg_sum_type(&coerced_data_types[0]) } -/// the signatures supported by the function `fun`. 
-#[deprecated( - since = "27.0.0", - note = "please use `AggregateFunction::signature` instead" -)] -pub fn signature(fun: &AggregateFunction) -> Signature { - fun.signature() -} - impl AggregateFunction { /// the signatures supported by the function `fun`. pub fn signature(&self) -> Signature { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fe3397b1af52..a1b3b717392e 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -37,16 +37,6 @@ use strum_macros::EnumIter; #[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)] pub enum BuiltinScalarFunction { // math functions - /// atan - Atan, - /// atan2 - Atan2, - /// acosh - Acosh, - /// asinh - Asinh, - /// atanh - Atanh, /// cbrt Cbrt, /// ceil @@ -71,14 +61,8 @@ pub enum BuiltinScalarFunction { Lcm, /// iszero Iszero, - /// ln, Natural logarithm - Ln, /// log, same as log10 Log, - /// log10 - Log10, - /// log2 - Log2, /// nanvl Nanvl, /// pi @@ -102,25 +86,7 @@ pub enum BuiltinScalarFunction { /// cot Cot, - // array functions - /// array_replace - ArrayReplace, - /// array_replace_n - ArrayReplaceN, - /// array_replace_all - ArrayReplaceAll, - // string functions - /// ascii - Ascii, - /// bit_length - BitLength, - /// btrim - Btrim, - /// character_length - CharacterLength, - /// chr - Chr, /// concat Concat, /// concat_ws @@ -129,56 +95,8 @@ pub enum BuiltinScalarFunction { EndsWith, /// initcap InitCap, - /// left - Left, - /// lpad - Lpad, - /// lower - Lower, - /// ltrim - Ltrim, - /// octet_length - OctetLength, /// random Random, - /// repeat - Repeat, - /// replace - Replace, - /// reverse - Reverse, - /// right - Right, - /// rpad - Rpad, - /// rtrim - Rtrim, - /// split_part - SplitPart, - /// starts_with - StartsWith, - /// strpos - Strpos, - /// substr - Substr, - /// to_hex - ToHex, - /// translate - Translate, - /// trim - Trim, - /// upper - Upper, - /// uuid - Uuid, - /// overlay - OverLay, - /// levenshtein - Levenshtein, - /// substr_index - SubstrIndex, - /// find_in_set - FindInSet, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -231,11 +149,6 @@ impl BuiltinScalarFunction { pub fn volatility(&self) -> Volatility { match self { // Immutable scalar builtins - BuiltinScalarFunction::Atan => Volatility::Immutable, - BuiltinScalarFunction::Atan2 => Volatility::Immutable, - BuiltinScalarFunction::Acosh => Volatility::Immutable, - BuiltinScalarFunction::Asinh => Volatility::Immutable, - BuiltinScalarFunction::Atanh => Volatility::Immutable, BuiltinScalarFunction::Ceil => Volatility::Immutable, BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, @@ -247,10 +160,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd => Volatility::Immutable, BuiltinScalarFunction::Iszero => Volatility::Immutable, BuiltinScalarFunction::Lcm => Volatility::Immutable, - BuiltinScalarFunction::Ln => Volatility::Immutable, BuiltinScalarFunction::Log => Volatility::Immutable, - BuiltinScalarFunction::Log10 => Volatility::Immutable, - BuiltinScalarFunction::Log2 => Volatility::Immutable, BuiltinScalarFunction::Nanvl => Volatility::Immutable, BuiltinScalarFunction::Pi => Volatility::Immutable, BuiltinScalarFunction::Power => Volatility::Immutable, @@ -262,46 +172,14 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - 
BuiltinScalarFunction::ArrayReplace => Volatility::Immutable, - BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, - BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, - BuiltinScalarFunction::Ascii => Volatility::Immutable, - BuiltinScalarFunction::BitLength => Volatility::Immutable, - BuiltinScalarFunction::Btrim => Volatility::Immutable, - BuiltinScalarFunction::CharacterLength => Volatility::Immutable, - BuiltinScalarFunction::Chr => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, - BuiltinScalarFunction::Left => Volatility::Immutable, - BuiltinScalarFunction::Lpad => Volatility::Immutable, - BuiltinScalarFunction::Lower => Volatility::Immutable, - BuiltinScalarFunction::Ltrim => Volatility::Immutable, - BuiltinScalarFunction::OctetLength => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Repeat => Volatility::Immutable, - BuiltinScalarFunction::Replace => Volatility::Immutable, - BuiltinScalarFunction::Reverse => Volatility::Immutable, - BuiltinScalarFunction::Right => Volatility::Immutable, - BuiltinScalarFunction::Rpad => Volatility::Immutable, - BuiltinScalarFunction::Rtrim => Volatility::Immutable, - BuiltinScalarFunction::SplitPart => Volatility::Immutable, - BuiltinScalarFunction::StartsWith => Volatility::Immutable, - BuiltinScalarFunction::Strpos => Volatility::Immutable, - BuiltinScalarFunction::Substr => Volatility::Immutable, - BuiltinScalarFunction::ToHex => Volatility::Immutable, - BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::Trim => Volatility::Immutable, - BuiltinScalarFunction::Upper => Volatility::Immutable, - BuiltinScalarFunction::OverLay => Volatility::Immutable, - BuiltinScalarFunction::Levenshtein => Volatility::Immutable, - BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, - BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Volatile builtin functions BuiltinScalarFunction::Random => Volatility::Volatile, - BuiltinScalarFunction::Uuid => Volatility::Volatile, } } @@ -322,20 +200,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. 
match self { - BuiltinScalarFunction::ArrayReplace => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::Ascii => Ok(Int32), - BuiltinScalarFunction::BitLength => { - utf8_to_int_type(&input_expr_types[0], "bit_length") - } - BuiltinScalarFunction::Btrim => { - utf8_to_str_type(&input_expr_types[0], "btrim") - } - BuiltinScalarFunction::CharacterLength => { - utf8_to_int_type(&input_expr_types[0], "character_length") - } - BuiltinScalarFunction::Chr => Ok(Utf8), BuiltinScalarFunction::Coalesce => { // COALESCE has multiple args and they might get coerced, get a preview of this let coerced_types = data_types(input_expr_types, &self.signature()); @@ -346,66 +210,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lower => { - utf8_to_str_type(&input_expr_types[0], "lower") - } - BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), - BuiltinScalarFunction::Ltrim => { - utf8_to_str_type(&input_expr_types[0], "ltrim") - } - BuiltinScalarFunction::OctetLength => { - utf8_to_int_type(&input_expr_types[0], "octet_length") - } BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), - BuiltinScalarFunction::Uuid => Ok(Utf8), - BuiltinScalarFunction::Repeat => { - utf8_to_str_type(&input_expr_types[0], "repeat") - } - BuiltinScalarFunction::Replace => { - utf8_to_str_type(&input_expr_types[0], "replace") - } - BuiltinScalarFunction::Reverse => { - utf8_to_str_type(&input_expr_types[0], "reverse") - } - BuiltinScalarFunction::Right => { - utf8_to_str_type(&input_expr_types[0], "right") - } - BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), - BuiltinScalarFunction::Rtrim => { - utf8_to_str_type(&input_expr_types[0], "rtrim") - } - BuiltinScalarFunction::SplitPart => { - utf8_to_str_type(&input_expr_types[0], "split_part") - } - BuiltinScalarFunction::StartsWith => Ok(Boolean), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::Strpos => { - utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") - } - BuiltinScalarFunction::Substr => { - utf8_to_str_type(&input_expr_types[0], "substr") - } - BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { - Int8 | Int16 | Int32 | Int64 => Utf8, - _ => { - return plan_err!("The to_hex function can only accept integers."); - } - }), - BuiltinScalarFunction::SubstrIndex => { - utf8_to_str_type(&input_expr_types[0], "substr_index") - } - BuiltinScalarFunction::FindInSet => { - utf8_to_int_type(&input_expr_types[0], "find_in_set") - } - BuiltinScalarFunction::Translate => { - utf8_to_str_type(&input_expr_types[0], "translate") - } - BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), - BuiltinScalarFunction::Upper => { - utf8_to_str_type(&input_expr_types[0], "upper") - } BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Gcd @@ -416,11 +223,6 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, - BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - BuiltinScalarFunction::Log => match &input_expr_types[0] { Float32 => Ok(Float32), _ => Ok(Float64), @@ -433,27 +235,12 @@ impl BuiltinScalarFunction { 
BuiltinScalarFunction::Iszero => Ok(Boolean), - BuiltinScalarFunction::OverLay => { - utf8_to_str_type(&input_expr_types[0], "overlay") - } - - BuiltinScalarFunction::Levenshtein => { - utf8_to_int_type(&input_expr_types[0], "levenshtein") - } - - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Cos | BuiltinScalarFunction::Cosh | BuiltinScalarFunction::Degrees | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 | BuiltinScalarFunction::Radians | BuiltinScalarFunction::Round | BuiltinScalarFunction::Signum @@ -477,11 +264,6 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), - BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), - BuiltinScalarFunction::ArrayReplaceAll => { - Signature::any(3, self.volatility()) - } BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { Signature::variadic(vec![Utf8], self.volatility()) @@ -489,58 +271,11 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::Ascii - | BuiltinScalarFunction::BitLength - | BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Reverse - | BuiltinScalarFunction::Upper => { + BuiltinScalarFunction::InitCap => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::Btrim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim - | BuiltinScalarFunction::Trim => Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], - self.volatility(), - ), - BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { - Signature::uniform(1, vec![Int64], self.volatility()) - } - BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), - Exact(vec![Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, Int64, LargeUtf8]), - ], - self.volatility(), - ) - } - BuiltinScalarFunction::Left - | BuiltinScalarFunction::Repeat - | BuiltinScalarFunction::Right => Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], - self.volatility(), - ), - BuiltinScalarFunction::SplitPart => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, Utf8, Int64]), - Exact(vec![Utf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - self.volatility(), - ), - BuiltinScalarFunction::EndsWith - | BuiltinScalarFunction::Strpos - | BuiltinScalarFunction::StartsWith => Signature::one_of( + BuiltinScalarFunction::EndsWith => Signature::one_of( vec![ Exact(vec![Utf8, Utf8]), Exact(vec![Utf8, LargeUtf8]), @@ -549,35 +284,8 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - - BuiltinScalarFunction::Substr => Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), - ], - self.volatility(), - ), - - 
BuiltinScalarFunction::SubstrIndex => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - self.volatility(), - ), - BuiltinScalarFunction::FindInSet => Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], - self.volatility(), - ), - - BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { - Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) - } BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), - BuiltinScalarFunction::Uuid => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Power => Signature::one_of( vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], self.volatility(), @@ -600,10 +308,7 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Atan2 => Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - self.volatility(), - ), + BuiltinScalarFunction::Log => Signature::one_of( vec![ Exact(vec![Float32]), @@ -623,33 +328,13 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { Signature::uniform(2, vec![Int64], self.volatility()) } - BuiltinScalarFunction::OverLay => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - self.volatility(), - ), - BuiltinScalarFunction::Levenshtein => Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], - self.volatility(), - ), - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Cbrt + BuiltinScalarFunction::Cbrt | BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Cos | BuiltinScalarFunction::Cosh | BuiltinScalarFunction::Degrees | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 | BuiltinScalarFunction::Radians | BuiltinScalarFunction::Signum | BuiltinScalarFunction::Sin @@ -676,18 +361,11 @@ impl BuiltinScalarFunction { pub fn monotonicity(&self) -> Option { if matches!( &self, - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Degrees | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 | BuiltinScalarFunction::Radians | BuiltinScalarFunction::Round | BuiltinScalarFunction::Signum @@ -708,11 +386,6 @@ impl BuiltinScalarFunction { /// Returns all names that can be used to call this function pub fn aliases(&self) -> &'static [&'static str] { match self { - BuiltinScalarFunction::Acosh => &["acosh"], - BuiltinScalarFunction::Asinh => &["asinh"], - BuiltinScalarFunction::Atan => &["atan"], - BuiltinScalarFunction::Atanh => &["atanh"], - BuiltinScalarFunction::Atan2 => &["atan2"], BuiltinScalarFunction::Cbrt => &["cbrt"], BuiltinScalarFunction::Ceil => &["ceil"], BuiltinScalarFunction::Cos => &["cos"], @@ -725,10 +398,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd => &["gcd"], 
BuiltinScalarFunction::Iszero => &["iszero"], BuiltinScalarFunction::Lcm => &["lcm"], - BuiltinScalarFunction::Ln => &["ln"], BuiltinScalarFunction::Log => &["log"], - BuiltinScalarFunction::Log10 => &["log10"], - BuiltinScalarFunction::Log2 => &["log2"], BuiltinScalarFunction::Nanvl => &["nanvl"], BuiltinScalarFunction::Pi => &["pi"], BuiltinScalarFunction::Power => &["power", "pow"], @@ -744,51 +414,10 @@ impl BuiltinScalarFunction { // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], - // string functions - BuiltinScalarFunction::Ascii => &["ascii"], - BuiltinScalarFunction::BitLength => &["bit_length"], - BuiltinScalarFunction::Btrim => &["btrim"], - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] - } BuiltinScalarFunction::Concat => &["concat"], BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], - BuiltinScalarFunction::Chr => &["chr"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lower => &["lower"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Ltrim => &["ltrim"], - BuiltinScalarFunction::OctetLength => &["octet_length"], - BuiltinScalarFunction::Repeat => &["repeat"], - BuiltinScalarFunction::Replace => &["replace"], - BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::Rtrim => &["rtrim"], - BuiltinScalarFunction::SplitPart => &["split_part"], - BuiltinScalarFunction::StartsWith => &["starts_with"], - BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], - BuiltinScalarFunction::Substr => &["substr"], - BuiltinScalarFunction::ToHex => &["to_hex"], - BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Trim => &["trim"], - BuiltinScalarFunction::Upper => &["upper"], - BuiltinScalarFunction::Uuid => &["uuid"], - BuiltinScalarFunction::Levenshtein => &["levenshtein"], - BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], - BuiltinScalarFunction::FindInSet => &["find_in_set"], - - // hashing functions - BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], - BuiltinScalarFunction::ArrayReplaceN => { - &["array_replace_n", "list_replace_n"] - } - BuiltinScalarFunction::ArrayReplaceAll => { - &["array_replace_all", "list_replace_all"] - } - BuiltinScalarFunction::OverLay => &["overlay"], } } } @@ -853,9 +482,6 @@ macro_rules! get_optimal_return_type { // `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); -// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. -get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index 831edc078d6a..87c3c063b91a 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -26,11 +26,14 @@ use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; /// Represents the result of evaluating an expression: either a single -/// `ScalarValue` or an [`ArrayRef`]. +/// [`ScalarValue`] or an [`ArrayRef`]. 
/// /// While a [`ColumnarValue`] can always be converted into an array /// for convenience, it is often much more performant to provide an /// optimized path for scalar values. +/// +/// See [`ColumnarValue::values_to_arrays`] for a function that converts +/// multiple columnar values into arrays of the same length. #[derive(Clone, Debug)] pub enum ColumnarValue { /// Array of values @@ -59,8 +62,13 @@ impl ColumnarValue { } } - /// Convert a columnar value into an ArrayRef. [`Self::Scalar`] is - /// converted by repeating the same scalar multiple times. + /// Convert a columnar value into an Arrow [`ArrayRef`] with the specified + /// number of rows. [`Self::Scalar`] is converted by repeating the same + /// scalar multiple times which is not as efficient as handling the scalar + /// directly. + /// + /// See [`Self::values_to_arrays`] to convert multiple columnar values into + /// arrays of the same length. /// /// # Errors /// diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0da05d96f67e..7ede4cd8ffc9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -703,27 +703,6 @@ pub fn find_df_window_func(name: &str) -> Option { } } -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunctionDefinition, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunctionDefinition) -> Signature { - fun.signature() -} - // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -887,13 +866,6 @@ impl Expr { create_name(self) } - /// Returns the name of this expression as it should appear in a schema. This name - /// will not include any CAST expressions. - #[deprecated(since = "14.0.0", note = "please use `display_name` instead")] - pub fn name(&self) -> Result { - self.display_name() - } - /// Returns a full and complete string representation of this expression. 
pub fn canonical_name(&self) -> String { format!("{self}") diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c5ad2a9b3ce4..a2015787040f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -541,10 +541,6 @@ scalar_expr!(Cos, cos, num, "cosine"); scalar_expr!(Cot, cot, num, "cotangent"); scalar_expr!(Sinh, sinh, num, "hyperbolic sine"); scalar_expr!(Cosh, cosh, num, "hyperbolic cosine"); -scalar_expr!(Atan, atan, num, "inverse tangent"); -scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine"); -scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine"); -scalar_expr!(Atanh, atanh, num, "inverse hyperbolic tangent"); scalar_expr!(Factorial, factorial, num, "factorial"); scalar_expr!( Floor, @@ -570,114 +566,11 @@ scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) "); scalar_expr!(Exp, exp, num, "exponential"); scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor"); scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple"); -scalar_expr!(Log2, log2, num, "base 2 logarithm of number"); -scalar_expr!(Log10, log10, num, "base 10 logarithm of number"); -scalar_expr!(Ln, ln, num, "natural logarithm (base e) of number"); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); -scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); -scalar_expr!( - ToHex, - to_hex, - num, - "returns the hexdecimal representation of an integer" -); -scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); -scalar_expr!( - ArrayReplace, - array_replace, - array from to, - "replaces the first occurrence of the specified element with another specified element." -); -scalar_expr!( - ArrayReplaceN, - array_replace_n, - array from to max, - "replaces the first `max` occurrences of the specified element with another specified element." -); -scalar_expr!( - ArrayReplaceAll, - array_replace_all, - array from to, - "replaces all occurrences of the specified element with another specified element." 
-); - -// string functions -scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); -scalar_expr!( - BitLength, - bit_length, - string, - "the number of bits in the `string`" -); -scalar_expr!( - CharacterLength, - character_length, - string, - "the number of characters in the `string`" -); -scalar_expr!( - Chr, - chr, - code_point, - "converts the Unicode code point to a UTF8 character" -); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); -scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); -scalar_expr!(Lower, lower, string, "convert the string to lower case"); -scalar_expr!( - Ltrim, - ltrim, - string, - "removes all characters, spaces by default, from the beginning of a string" -); -scalar_expr!( - OctetLength, - octet_length, - string, - "returns the number of bytes of a string" -); -scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`"); -scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times"); -scalar_expr!(Reverse, reverse, string, "reverses the `string`"); -scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`"); -scalar_expr!( - Rtrim, - rtrim, - string, - "removes all characters, spaces by default, from the end of a string" -); -scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); -scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); -scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); -scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); -scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); -scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); -scalar_expr!( - Trim, - trim, - string, - "removes all characters, space by default from the string" -); -scalar_expr!(Upper, upper, string, "converts the string to upper case"); -//use vec as parameter -nary_scalar_expr!( - Lpad, - lpad, - "fill up a string to the length by prepending the characters" -); -nary_scalar_expr!( - Rpad, - rpad, - "fill up a string to the length by appending the characters" -); -nary_scalar_expr!( - Btrim, - btrim, - "removes all characters, spaces by default, from both sides of a string" -); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -686,12 +579,6 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); -nary_scalar_expr!( - OverLay, - overlay, - "replace the substring of string that starts at the start'th character and extends for count characters with new substring" -); - scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); scalar_expr!( Iszero, @@ -700,10 +587,6 @@ scalar_expr!( "returns true if a given number is +0.0 or -0.0 otherwise returns false" ); 
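For context on the `scalar_expr!` invocations that remain above (for example `InitCap` and `EndsWith`): each call expands to a small fluent helper that wraps the named `BuiltinScalarFunction` in an `Expr`. A rough, hand-written sketch of the kind of function the macro generates (illustrative only; the paths and exact expansion are assumptions, the real definition is the macro itself):

```rust
use datafusion_expr::expr::{ScalarFunction, ScalarFunctionDefinition};
use datafusion_expr::{BuiltinScalarFunction, Expr};

/// Roughly what `scalar_expr!(InitCap, initcap, string, "...")` produces:
/// a one-argument builder that callers combine with `col`/`lit`.
pub fn initcap(string: Expr) -> Expr {
    Expr::ScalarFunction(ScalarFunction {
        func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::InitCap),
        args: vec![string],
    })
}
```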
-scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); -scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); -scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); - /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None) @@ -1091,10 +974,6 @@ mod test { test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Sinh, sinh); test_unary_scalar_expr!(Cosh, cosh); - test_unary_scalar_expr!(Atan, atan); - test_unary_scalar_expr!(Asinh, asinh); - test_unary_scalar_expr!(Acosh, acosh); - test_unary_scalar_expr!(Atanh, atanh); test_unary_scalar_expr!(Factorial, factorial); test_unary_scalar_expr!(Floor, floor); test_unary_scalar_expr!(Ceil, ceil); @@ -1106,69 +985,12 @@ mod test { test_nary_scalar_expr!(Trunc, trunc, num, precision); test_unary_scalar_expr!(Signum, signum); test_unary_scalar_expr!(Exp, exp); - test_unary_scalar_expr!(Log2, log2); - test_unary_scalar_expr!(Log10, log10); - test_unary_scalar_expr!(Ln, ln); - test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(Ascii, ascii, input); - test_scalar_expr!(BitLength, bit_length, string); - test_nary_scalar_expr!(Btrim, btrim, string); - test_nary_scalar_expr!(Btrim, btrim, string, characters); - test_scalar_expr!(CharacterLength, character_length, string); - test_scalar_expr!(Chr, chr, string); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); - test_scalar_expr!(Left, left, string, count); - test_scalar_expr!(Lower, lower, string); - test_nary_scalar_expr!(Lpad, lpad, string, count); - test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(Ltrim, ltrim, string); - test_scalar_expr!(OctetLength, octet_length, string); - test_scalar_expr!(Replace, replace, string, from, to); - test_scalar_expr!(Repeat, repeat, string, count); - test_scalar_expr!(Reverse, reverse, string); - test_scalar_expr!(Right, right, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count, characters); - test_scalar_expr!(Rtrim, rtrim, string); - test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); - test_scalar_expr!(StartsWith, starts_with, string, characters); test_scalar_expr!(EndsWith, ends_with, string, characters); - test_scalar_expr!(Strpos, strpos, string, substring); - test_scalar_expr!(Substr, substr, string, position); - test_scalar_expr!(Substr, substring, string, position, count); - test_scalar_expr!(ToHex, to_hex, string); - test_scalar_expr!(Translate, translate, string, from, to); - test_scalar_expr!(Trim, trim, string); - test_scalar_expr!(Upper, upper, string); - - test_scalar_expr!(ArrayReplace, array_replace, array, from, to); - test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max); - test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to); - - test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); - test_nary_scalar_expr!(OverLay, overlay, string, characters, position); - test_scalar_expr!(Levenshtein, 
levenshtein, string1, string2); - test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); - test_scalar_expr!(FindInSet, find_in_set, string, stringlist); - } - - #[test] - fn uuid_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }) = uuid() - { - let name = BuiltinScalarFunction::Uuid; - assert_eq!(name, fun); - assert_eq!(0, args.len()); - } else { - unreachable!(); - } } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 357b1aed7dde..7a227a91c455 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -74,31 +74,6 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { .data() } -/// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions -/// in the `expr` expression tree. -#[deprecated( - since = "20.0.0", - note = "use normalize_col_with_schemas_and_ambiguity_check instead" -)] -#[allow(deprecated)] -pub fn normalize_col_with_schemas( - expr: Expr, - schemas: &[&Arc], - using_columns: &[HashSet], -) -> Result { - expr.transform(&|expr| { - Ok({ - if let Expr::Column(c) = expr { - let col = c.normalize_with_schemas(schemas, using_columns)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(expr) - } - }) - }) - .data() -} - /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage pub fn normalize_col_with_schemas_and_ambiguity_check( expr: Expr, @@ -398,31 +373,13 @@ mod test { ); } - #[test] - #[allow(deprecated)] - fn normalize_cols_priority() { - let expr = col("a") + col("b"); - // Schemas with multiple matches for column a, first takes priority - let schema_a = make_schema_with_empty_metadata(vec![make_field("tableA", "a")]); - let schema_b = make_schema_with_empty_metadata(vec![make_field("tableB", "b")]); - let schema_a2 = make_schema_with_empty_metadata(vec![make_field("tableA2", "a")]); - let schemas = vec![schema_a2, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); - } - #[test] fn normalize_cols_non_exist() { // test normalizing columns when the name doesn't exist let expr = col("a") + col("b"); let schema_a = make_schema_with_empty_metadata(vec![make_field("\"tableA\"", "a")]); - let schemas = vec![schema_a]; + let schemas = [schema_a]; let schemas = schemas.iter().collect::>(); let error = diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1d83fbe8c0e0..f1ac22d584ee 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -50,6 +50,10 @@ pub trait ExprSchemable { /// cast to a type with respect to a schema fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; + + /// given a schema, return the type and nullability of the expr + fn data_type_and_nullable(&self, schema: &dyn ExprSchema) + -> Result<(DataType, bool)>; } impl ExprSchemable for Expr { @@ -370,32 +374,90 @@ impl ExprSchemable for Expr { } } + /// Returns the datatype and nullability of the expression based on [ExprSchema]. + /// + /// Note: [`DFSchema`] implements [ExprSchema]. 
+ /// + /// [`DFSchema`]: datafusion_common::DFSchema + /// + /// # Errors + /// + /// This function errors when it is not possible to compute its + /// datatype or nullability. + fn data_type_and_nullable( + &self, + schema: &dyn ExprSchema, + ) -> Result<(DataType, bool)> { + match self { + Expr::Alias(Alias { expr, name, .. }) => match &**expr { + Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { + None => schema + .data_type_and_nullable(&Column::from_name(name)) + .map(|(d, n)| (d.clone(), n)), + Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)), + }, + _ => expr.data_type_and_nullable(schema), + }, + Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => { + expr.data_type_and_nullable(schema) + } + Expr::Column(c) => schema + .data_type_and_nullable(c) + .map(|(d, n)| (d.clone(), n)), + Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), + Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), + Expr::Literal(l) => Ok((l.data_type(), l.is_null())), + Expr::IsNull(_) + | Expr::IsNotNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + | Expr::Exists { .. } => Ok((DataType::Boolean, false)), + Expr::ScalarSubquery(subquery) => Ok(( + subquery.subquery.schema().field(0).data_type().clone(), + subquery.subquery.schema().field(0).is_nullable(), + )), + Expr::BinaryExpr(BinaryExpr { + ref left, + ref right, + ref op, + }) => { + let left = left.data_type_and_nullable(schema)?; + let right = right.data_type_and_nullable(schema)?; + Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1)) + } + _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), + } + } + /// Returns a [arrow::datatypes::Field] compatible with this expression. /// /// So for example, a projected expression `col(c1) + col(c2)` is /// placed in an output field **named** col("c1 + c2") fn to_field(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Column(c) => Ok(DFField::new( - c.relation.clone(), - &c.name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - ) - .with_metadata(self.metadata(input_schema)?)), - Expr::Alias(Alias { relation, name, .. }) => Ok(DFField::new( - relation.clone(), - name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - ) - .with_metadata(self.metadata(input_schema)?)), - _ => Ok(DFField::new_unqualified( - &self.display_name()?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - ) - .with_metadata(self.metadata(input_schema)?)), + Expr::Column(c) => { + let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; + Ok( + DFField::new(c.relation.clone(), &c.name, data_type, nullable) + .with_metadata(self.metadata(input_schema)?), + ) + } + Expr::Alias(Alias { relation, name, .. 
}) => { + let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; + Ok(DFField::new(relation.clone(), name, data_type, nullable) + .with_metadata(self.metadata(input_schema)?)) + } + _ => { + let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; + Ok( + DFField::new_unqualified(&self.display_name()?, data_type, nullable) + .with_metadata(self.metadata(input_schema)?), + ) + } } } @@ -704,5 +766,9 @@ mod tests { fn metadata(&self, _col: &Column) -> Result<&HashMap> { Ok(&self.metadata) } + + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + Ok((self.data_type(col)?, self.nullable(col)?)) + } } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index a3760eeb357d..adf4dd3fef20 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,9 +17,7 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{ - Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature, -}; +use crate::{Accumulator, ColumnarValue, PartitionEvaluator}; use arrow::datatypes::DataType; use datafusion_common::Result; use std::sync::Arc; @@ -53,24 +51,3 @@ pub type PartitionEvaluatorFactory = /// its state, given its return datatype. pub type StateTypeFunction = Arc Result>> + Send + Sync>; - -/// Returns the datatype of the scalar function -#[deprecated( - since = "27.0.0", - note = "please use `BuiltinScalarFunction::return_type` instead" -)] -pub fn return_type( - fun: &BuiltinScalarFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -/// Return the [`Signature`] supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltinScalarFunction::signature` instead" -)] -pub fn signature(fun: &BuiltinScalarFunction) -> Signature { - fun.signature() -} diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 01e6af948762..f47249d76d5b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -51,9 +51,9 @@ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, - DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, - TableReference, ToDFSchema, UnnestOptions, + get_target_functional_dependencies, not_impl_err, plan_datafusion_err, plan_err, + Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; /// Default table name for unnamed table @@ -132,14 +132,26 @@ impl LogicalPlanBuilder { ) -> Result { // TODO: we need to do a bunch of validation here. Maybe more. 
        if is_distinct {
-            return Err(DataFusionError::NotImplemented(
-                "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported".to_string(),
-            ));
+            return not_impl_err!(
+                "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
+            );
+        }
+        // Ensure that the static term and the recursive term have the same number of fields
+        let static_fields_len = self.plan.schema().fields().len();
+        let recursive_fields_len = recursive_term.schema().fields().len();
+        if static_fields_len != recursive_fields_len {
+            return plan_err!(
+                "Non-recursive term and recursive term must have the same number of columns ({} != {})",
+                static_fields_len, recursive_fields_len
+            );
         }
+        // Ensure that the recursive term has the same field types as the static term
+        let coerced_recursive_term =
+            coerce_plan_expr_for_schema(&recursive_term, self.plan.schema())?;
         Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery {
             name,
             static_term: Arc::new(self.plan.clone()),
-            recursive_term: Arc::new(recursive_term),
+            recursive_term: Arc::new(coerced_recursive_term),
             is_distinct,
         })))
     }
diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs
index e0cb44626e24..edc3afd55d63 100644
--- a/datafusion/expr/src/logical_plan/display.rs
+++ b/datafusion/expr/src/logical_plan/display.rs
@@ -16,14 +16,22 @@
 // under the License.

 //! This module provides logic for displaying LogicalPlans in various styles
+use std::collections::HashMap;
 use std::fmt;
-use crate::LogicalPlan;
+use crate::{
+    expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr,
+    Filter, Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery,
+    Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan,
+    Unnest, Values, Window,
+};
+use crate::dml::CopyTo;
 use arrow::datatypes::Schema;
 use datafusion_common::display::GraphvizBuilder;
 use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
 use datafusion_common::DataFusionError;
+use serde_json::json;

 /// Formats plans with a single line per node. For example:
 ///
@@ -221,6 +229,490 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> {
     }
 }

+/// Formats plans to display as postgresql plan json format.
+///
+/// There are already many existing visualizers for this format, for example [dalibo](https://explain.dalibo.com/).
+/// Unfortunately, there is no formal spec for this format, but it is widely used in the PostgreSQL community.
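(Aside: the doc comment continues below with a sample of the JSON shape. For orientation, the visitor is exposed to users through the `LogicalPlan::display_pg_json` method added later in this diff; a minimal usage sketch, with the plan construction omitted:)

```rust
use datafusion_expr::LogicalPlan;

/// Render any plan as PostgreSQL-style plan JSON, e.g. for pasting into
/// visualizers such as explain.dalibo.com.
fn print_pg_json(plan: &LogicalPlan) {
    println!("{}", plan.display_pg_json());
}
```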
+/// +/// Here is an example of the format: +/// +/// ```json +/// [ +/// { +/// "Plan": { +/// "Node Type": "Sort", +/// "Output": [ +/// "question_1.id", +/// "question_1.title", +/// "question_1.text", +/// "question_1.file", +/// "question_1.type", +/// "question_1.source", +/// "question_1.exam_id" +/// ], +/// "Sort Key": [ +/// "question_1.id" +/// ], +/// "Plans": [ +/// { +/// "Node Type": "Seq Scan", +/// "Parent Relationship": "Left", +/// "Relation Name": "question", +/// "Schema": "public", +/// "Alias": "question_1", +/// "Output": [ +/// "question_1.id", +/// "question_1.title", +/// "question_1.text", +/// "question_1.file", +/// "question_1.type", +/// "question_1.source", +/// "question_1.exam_id" +/// ], +/// "Filter": "(question_1.exam_id = 1)" +/// } +/// ] +/// } +/// } +/// ] +/// ``` +pub struct PgJsonVisitor<'a, 'b> { + f: &'a mut fmt::Formatter<'b>, + + /// A mapping from plan node id to the plan node json representation. + objects: HashMap, + + next_id: u32, + + /// If true, includes summarized schema information + with_schema: bool, + + /// Holds the ids (as generated from `graphviz_builder` of all + /// parent nodes + parent_ids: Vec, +} + +impl<'a, 'b> PgJsonVisitor<'a, 'b> { + pub fn new(f: &'a mut fmt::Formatter<'b>) -> Self { + Self { + f, + objects: HashMap::new(), + next_id: 0, + with_schema: false, + parent_ids: Vec::new(), + } + } + + /// Sets a flag which controls if the output schema is displayed + pub fn with_schema(&mut self, with_schema: bool) { + self.with_schema = with_schema; + } + + /// Converts a logical plan node to a json object. + fn to_json_value(node: &LogicalPlan) -> serde_json::Value { + match node { + LogicalPlan::EmptyRelation(_) => { + json!({ + "Node Type": "EmptyRelation", + }) + } + LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => { + json!({ + "Node Type": "RecursiveQuery", + "Is Distinct": is_distinct, + }) + } + LogicalPlan::Values(Values { ref values, .. }) => { + let str_values = values + .iter() + // limit to only 5 values to avoid horrible display + .take(5) + .map(|row| { + let item = row + .iter() + .map(|expr| expr.to_string()) + .collect::>() + .join(", "); + format!("({item})") + }) + .collect::>() + .join(", "); + + let elipse = if values.len() > 5 { "..." } else { "" }; + + let values_str = format!("{}{}", str_values, elipse); + json!({ + "Node Type": "Values", + "Values": values_str + }) + } + LogicalPlan::TableScan(TableScan { + ref source, + ref table_name, + ref filters, + ref fetch, + .. 
+ }) => { + let mut object = json!({ + "Node Type": "TableScan", + "Relation Name": table_name.table(), + }); + + if let Some(s) = table_name.schema() { + object["Schema"] = serde_json::Value::String(s.to_string()); + } + + if let Some(c) = table_name.catalog() { + object["Catalog"] = serde_json::Value::String(c.to_string()); + } + + if !filters.is_empty() { + let mut full_filter = vec![]; + let mut partial_filter = vec![]; + let mut unsupported_filters = vec![]; + let filters: Vec<&Expr> = filters.iter().collect(); + + if let Ok(results) = source.supports_filters_pushdown(&filters) { + filters.iter().zip(results.iter()).for_each( + |(x, res)| match res { + TableProviderFilterPushDown::Exact => full_filter.push(x), + TableProviderFilterPushDown::Inexact => { + partial_filter.push(x) + } + TableProviderFilterPushDown::Unsupported => { + unsupported_filters.push(x) + } + }, + ); + } + + if !full_filter.is_empty() { + object["Full Filters"] = serde_json::Value::String( + expr_vec_fmt!(full_filter).to_string(), + ); + }; + if !partial_filter.is_empty() { + object["Partial Filters"] = serde_json::Value::String( + expr_vec_fmt!(partial_filter).to_string(), + ); + } + if !unsupported_filters.is_empty() { + object["Unsupported Filters"] = serde_json::Value::String( + expr_vec_fmt!(unsupported_filters).to_string(), + ); + } + } + + if let Some(f) = fetch { + object["Fetch"] = serde_json::Value::Number((*f).into()); + } + + object + } + LogicalPlan::Projection(Projection { ref expr, .. }) => { + json!({ + "Node Type": "Projection", + "Expressions": expr.iter().map(|e| e.to_string()).collect::>() + }) + } + LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => { + json!({ + "Node Type": "Projection", + "Operation": op.name(), + "Table Name": table_name.table() + }) + } + LogicalPlan::Copy(CopyTo { + input: _, + output_url, + format_options, + partition_by: _, + options, + }) => { + let op_str = options + .iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join(", "); + json!({ + "Node Type": "CopyTo", + "Output URL": output_url, + "Format Options": format!("{}", format_options), + "Options": op_str + }) + } + LogicalPlan::Ddl(ddl) => { + json!({ + "Node Type": "Ddl", + "Operation": format!("{}", ddl.display()) + }) + } + LogicalPlan::Filter(Filter { + predicate: ref expr, + .. + }) => { + json!({ + "Node Type": "Filter", + "Condition": format!("{}", expr) + }) + } + LogicalPlan::Window(Window { + ref window_expr, .. + }) => { + json!({ + "Node Type": "WindowAggr", + "Expressions": expr_vec_fmt!(window_expr) + }) + } + LogicalPlan::Aggregate(Aggregate { + ref group_expr, + ref aggr_expr, + .. + }) => { + json!({ + "Node Type": "Aggregate", + "Group By": expr_vec_fmt!(group_expr), + "Aggregates": expr_vec_fmt!(aggr_expr) + }) + } + LogicalPlan::Sort(Sort { expr, fetch, .. }) => { + let mut object = json!({ + "Node Type": "Sort", + "Sort Key": expr_vec_fmt!(expr), + }); + + if let Some(fetch) = fetch { + object["Fetch"] = serde_json::Value::Number((*fetch).into()); + } + + object + } + LogicalPlan::Join(Join { + on: ref keys, + filter, + join_constraint, + join_type, + .. 
+ }) => { + let join_expr: Vec = + keys.iter().map(|(l, r)| format!("{l} = {r}")).collect(); + let filter_expr = filter + .as_ref() + .map(|expr| format!(" Filter: {expr}")) + .unwrap_or_else(|| "".to_string()); + json!({ + "Node Type": format!("{} Join", join_type), + "Join Constraint": format!("{:?}", join_constraint), + "Join Keys": join_expr.join(", "), + "Filter": format!("{}", filter_expr) + }) + } + LogicalPlan::CrossJoin(_) => { + json!({ + "Node Type": "Cross Join" + }) + } + LogicalPlan::Repartition(Repartition { + partitioning_scheme, + .. + }) => match partitioning_scheme { + Partitioning::RoundRobinBatch(n) => { + json!({ + "Node Type": "Repartition", + "Partitioning Scheme": "RoundRobinBatch", + "Partition Count": n + }) + } + Partitioning::Hash(expr, n) => { + let hash_expr: Vec = + expr.iter().map(|e| format!("{e}")).collect(); + + json!({ + "Node Type": "Repartition", + "Partitioning Scheme": "Hash", + "Partition Count": n, + "Partitioning Key": hash_expr + }) + } + Partitioning::DistributeBy(expr) => { + let dist_by_expr: Vec = + expr.iter().map(|e| format!("{e}")).collect(); + json!({ + "Node Type": "Repartition", + "Partitioning Scheme": "DistributeBy", + "Partitioning Key": dist_by_expr + }) + } + }, + LogicalPlan::Limit(Limit { + ref skip, + ref fetch, + .. + }) => { + let mut object = serde_json::json!( + { + "Node Type": "Limit", + "Skip": skip, + } + ); + if let Some(f) = fetch { + object["Fetch"] = serde_json::Value::Number((*f).into()); + }; + object + } + LogicalPlan::Subquery(Subquery { .. }) => { + json!({ + "Node Type": "Subquery" + }) + } + LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => { + json!({ + "Node Type": "Subquery", + "Alias": alias.table(), + }) + } + LogicalPlan::Statement(statement) => { + json!({ + "Node Type": "Statement", + "Statement": format!("{}", statement.display()) + }) + } + LogicalPlan::Distinct(distinct) => match distinct { + Distinct::All(_) => { + json!({ + "Node Type": "DistinctAll" + }) + } + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + }) => { + let mut object = json!({ + "Node Type": "DistinctOn", + "On": expr_vec_fmt!(on_expr), + "Select": expr_vec_fmt!(select_expr), + }); + if let Some(sort_expr) = sort_expr { + object["Sort"] = serde_json::Value::String( + expr_vec_fmt!(sort_expr).to_string(), + ); + } + + object + } + }, + LogicalPlan::Explain { .. } => { + json!({ + "Node Type": "Explain" + }) + } + LogicalPlan::Analyze { .. } => { + json!({ + "Node Type": "Analyze" + }) + } + LogicalPlan::Union(_) => { + json!({ + "Node Type": "Union" + }) + } + LogicalPlan::Extension(e) => { + json!({ + "Node Type": e.node.name(), + "Detail": format!("{:?}", e.node) + }) + } + LogicalPlan::Prepare(Prepare { + name, data_types, .. + }) => { + json!({ + "Node Type": "Prepare", + "Name": name, + "Data Types": format!("{:?}", data_types) + }) + } + LogicalPlan::DescribeTable(DescribeTable { .. }) => { + json!({ + "Node Type": "DescribeTable" + }) + } + LogicalPlan::Unnest(Unnest { column, .. 
}) => { + json!({ + "Node Type": "Unnest", + "Column": format!("{}", column) + }) + } + } + } +} + +impl<'a, 'b> TreeNodeVisitor for PgJsonVisitor<'a, 'b> { + type Node = LogicalPlan; + + fn f_down( + &mut self, + node: &LogicalPlan, + ) -> datafusion_common::Result { + let id = self.next_id; + self.next_id += 1; + let mut object = Self::to_json_value(node); + + object["Plans"] = serde_json::Value::Array(vec![]); + + if self.with_schema { + object["Output"] = serde_json::Value::Array( + node.schema() + .fields() + .iter() + .map(|f| f.name().to_string()) + .map(serde_json::Value::String) + .collect(), + ); + }; + + self.objects.insert(id, object); + self.parent_ids.push(id); + Ok(TreeNodeRecursion::Continue) + } + + fn f_up( + &mut self, + _node: &Self::Node, + ) -> datafusion_common::Result { + let id = self.parent_ids.pop().unwrap(); + + let current_node = self.objects.remove(&id).ok_or_else(|| { + DataFusionError::Internal("Missing current node!".to_string()) + })?; + + if let Some(parent_id) = self.parent_ids.last() { + let parent_node = self + .objects + .get_mut(parent_id) + .expect("Missing parent node!"); + let plans = parent_node + .get_mut("Plans") + .and_then(|p| p.as_array_mut()) + .expect("Plans should be an array"); + + plans.push(current_node); + } else { + // This is the root node + let plan = serde_json::json!([{"Plan": current_node}]); + write!( + self.f, + "{}", + serde_json::to_string_pretty(&plan) + .map_err(|e| DataFusionError::External(Box::new(e)))? + )?; + } + + Ok(TreeNodeRecursion::Continue) + } +} + #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index f87ca45f14be..b55256ca17de 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -53,10 +53,11 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; - /// Returns all expressions in the current logical plan node. This - /// should not include expressions of any inputs (aka - /// non-recursively). These expressions are used for optimizer - /// passes and rewrites. + /// Returns all expressions in the current logical plan node. This should + /// not include expressions of any inputs (aka non-recursively). + /// + /// These expressions are used for optimizer + /// passes and rewrites. See [`LogicalPlan::expressions`] for more details. fn expressions(&self) -> Vec; /// A list of output columns (e.g. the names of columns in @@ -97,6 +98,24 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { inputs: &[LogicalPlan], ) -> Arc; + /// Returns the necessary input columns for this node required to compute + /// the columns in the output schema + /// + /// This is used for projection push-down when DataFusion has determined that + /// only a subset of the output columns of this node are needed by its parents. + /// This API is used to tell DataFusion which, if any, of the input columns are no longer + /// needed. + /// + /// Return `None`, the default, if this information can not be determined. + /// Returns `Some(_)` with the column indices for each child of this node that are + /// needed to compute `output_columns` + fn necessary_children_exprs( + &self, + _output_columns: &[usize], + ) -> Option>> { + None + } + /// Update the hash `state` with this node requirements from /// [`Hash`]. 
/// @@ -242,6 +261,24 @@ pub trait UserDefinedLogicalNodeCore: // but the doc comments have not been updated. #[allow(clippy::wrong_self_convention)] fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self; + + /// Returns the necessary input columns for this node required to compute + /// the columns in the output schema + /// + /// This is used for projection push-down when DataFusion has determined that + /// only a subset of the output columns of this node are needed by its parents. + /// This API is used to tell DataFusion which, if any, of the input columns are no longer + /// needed. + /// + /// Return `None`, the default, if this information can not be determined. + /// Returns `Some(_)` with the column indices for each child of this node that are + /// needed to compute `output_columns` + fn necessary_children_exprs( + &self, + _output_columns: &[usize], + ) -> Option>> { + None + } } /// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode` @@ -283,6 +320,13 @@ impl UserDefinedLogicalNode for T { Arc::new(self.from_template(exprs, inputs)) } + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + self.necessary_children_exprs(output_columns) + } + fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c6f280acb255..0bf5b8dffaa2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -54,6 +54,7 @@ use datafusion_common::{ }; // backwards compatibility +use crate::display::PgJsonVisitor; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -217,56 +218,6 @@ impl LogicalPlan { } } - /// Get all meaningful schemas of a plan and its children plan. - #[deprecated(since = "20.0.0")] - pub fn all_schemas(&self) -> Vec<&DFSchemaRef> { - match self { - // return self and children schemas - LogicalPlan::Window(_) - | LogicalPlan::Projection(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) => { - let mut schemas = vec![self.schema()]; - self.inputs().iter().for_each(|input| { - schemas.push(input.schema()); - }); - schemas - } - // just return self.schema() - LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Values(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Union(_) - | LogicalPlan::Extension(_) - | LogicalPlan::TableScan(_) => { - vec![self.schema()] - } - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { - // return only the schema of the static term - static_term.all_schemas() - } - // return children schemas - LogicalPlan::Limit(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Prepare(_) => { - self.inputs().iter().map(|p| p.schema()).collect() - } - // return empty - LogicalPlan::Statement(_) | LogicalPlan::DescribeTable(_) => vec![], - } - } - /// Returns the (fixed) output schema for explain plans pub fn explain_schema() -> SchemaRef { SchemaRef::new(Schema::new(vec![ @@ -284,9 +235,17 @@ impl LogicalPlan { ]) } - /// returns all expressions (non-recursively) in the current - /// logical plan node. 
This does not include expressions in any
-    /// children
+    /// Returns all expressions (non-recursively) evaluated by the current
+    /// logical plan node. This does not include expressions in any children
+    ///
+    /// The returned expressions do not necessarily represent or even
+    /// contribute to the output schema of this node. For example,
+    /// `LogicalPlan::Filter` returns the filter expression even though the
+    /// output of a Filter has the same columns as the input.
+    ///
+    /// The expressions do contain all the columns that are used by this plan,
+    /// so if there are columns not referenced by these expressions then
+    /// DataFusion's optimizer attempts to optimize them away.
     pub fn expressions(self: &LogicalPlan) -> Vec<Expr> {
         let mut exprs = vec![];
         self.inspect_expressions(|e| {
@@ -1344,6 +1303,26 @@ impl LogicalPlan {
         Wrapper(self)
     }

+    /// Return a displayable structure that produces the plan in postgresql JSON format.
+    ///
+    /// Users can use this format to visualize the plan in existing plan visualization tools, for example [dalibo](https://explain.dalibo.com/)
+    pub fn display_pg_json(&self) -> impl Display + '_ {
+        // Boilerplate structure to wrap LogicalPlan with something
+        // that can be formatted
+        struct Wrapper<'a>(&'a LogicalPlan);
+        impl<'a> Display for Wrapper<'a> {
+            fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+                let mut visitor = PgJsonVisitor::new(f);
+                visitor.with_schema(true);
+                match self.0.visit(&mut visitor) {
+                    Ok(_) => Ok(()),
+                    Err(_) => Err(fmt::Error),
+                }
+            }
+        }
+        Wrapper(self)
+    }
+
     /// Return a `format`able structure that produces lines meant for
     /// graphical display using the `DOT` language. This format can be
     /// visualized using software from
     ///
@@ -2410,7 +2389,7 @@ impl DistinctOn {

 /// Aggregates its input based on a set of grouping and aggregate
 /// expressions (e.g. SUM).
-#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] pub struct Aggregate { @@ -2823,6 +2802,67 @@ digraph { Ok(()) } + #[test] + fn test_display_pg_json() -> Result<()> { + let plan = display_plan()?; + + let expected_pg_json = r#"[ + { + "Plan": { + "Expressions": [ + "employee_csv.id" + ], + "Node Type": "Projection", + "Output": [ + "id" + ], + "Plans": [ + { + "Condition": "employee_csv.state IN ()", + "Node Type": "Filter", + "Output": [ + "id", + "state" + ], + "Plans": [ + { + "Node Type": "Subquery", + "Output": [ + "state" + ], + "Plans": [ + { + "Node Type": "TableScan", + "Output": [ + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] + }, + { + "Node Type": "TableScan", + "Output": [ + "id", + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] + } + ] + } + } +]"#; + + let pg_json = format!("{}", plan.display_pg_json()); + + assert_eq!(expected_pg_json, pg_json); + Ok(()) + } + /// Tests for the Visitor trait and walking logical plan nodes #[derive(Debug, Default)] struct OkVisitor { @@ -3079,14 +3119,6 @@ digraph { empty_schema: DFSchemaRef, } - impl NoChildExtension { - fn empty() -> Self { - Self { - empty_schema: Arc::new(DFSchema::empty()), - } - } - } - impl UserDefinedLogicalNode for NoChildExtension { fn as_any(&self) -> &dyn std::any::Any { unimplemented!() @@ -3129,18 +3161,6 @@ digraph { } } - #[test] - #[allow(deprecated)] - fn test_extension_all_schemas() { - let plan = LogicalPlan::Extension(Extension { - node: Arc::new(NoChildExtension::empty()), - }); - - let schemas = plan.all_schemas(); - assert_eq!(1, schemas.len()); - assert_eq!(0, schemas[0].fields().len()); - } - #[test] fn test_replace_invalid_placeholder() { // test empty placeholder diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3002a745055f..56266a05170b 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -326,8 +326,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// For the best performance, the implementations of `invoke` should handle /// the common case when one or more of their arguments are constant values - /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`] - /// and treating all arguments as arrays will work, but will be slower. + /// (aka [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. fn invoke(&self, args: &[ColumnarValue]) -> Result; /// Returns any aliases (alternate names) for this function. 
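To make the trade-off described in the `invoke` documentation concrete, here is a minimal, hypothetical implementation body that takes the simpler route and materializes every argument with `ColumnarValue::values_to_arrays`; the helper name and the placeholder computation are invented for illustration:

```rust
use std::sync::Arc;

use arrow::array::ArrayRef;
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;

/// Hypothetical `invoke` body: convert all arguments (scalars included) into
/// arrays of the same length, compute on the arrays, and return an array.
/// Simpler than special-casing `ColumnarValue::Scalar`, but it gives up the
/// fast path for constant arguments.
fn invoke_via_arrays(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    // Scalars are expanded by repeating them to the common row count.
    let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(args)?;

    // ... a real implementation would compute its result from `arrays` ...
    let result: ArrayRef = Arc::clone(&arrays[0]); // placeholder (assumes >= 1 argument)

    Ok(ColumnarValue::Array(result))
}
```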
diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index 99239ffb3bdc..6ef9c6b055af 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -40,6 +40,7 @@ path = "src/lib.rs" arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } @@ -48,3 +49,10 @@ datafusion-functions = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" + +[dev-dependencies] +criterion = { version = "0.5", features = ["async_tokio"] } + +[[bench]] +harness = false +name = "array_expression" diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/functions-array/benches/array_expression.rs similarity index 52% rename from datafusion/core/benches/array_expression.rs rename to datafusion/functions-array/benches/array_expression.rs index 95bc93e0e353..48b829793cef 100644 --- a/datafusion/core/benches/array_expression.rs +++ b/datafusion/functions-array/benches/array_expression.rs @@ -18,52 +18,34 @@ #[macro_use] extern crate criterion; extern crate arrow; -extern crate datafusion; -mod data_utils; use crate::criterion::Criterion; -use arrow_array::cast::AsArray; -use arrow_array::types::Int64Type; -use arrow_array::{ArrayRef, Int64Array, ListArray}; -use datafusion_physical_expr::array_expressions; -use std::sync::Arc; +use datafusion_expr::lit; +use datafusion_functions_array::expr_fn::{array_replace_all, make_array}; fn criterion_benchmark(c: &mut Criterion) { // Construct large arrays for benchmarking let array_len = 100000000; - let array = (0..array_len).map(|_| Some(2_i64)).collect::>(); - let list_array = ListArray::from_iter_primitive::(vec![ - Some(array.clone()), - Some(array.clone()), - Some(array), - ]); - let from_array = Int64Array::from_value(2, 3); - let to_array = Int64Array::from_value(-2, 3); + let array = (0..array_len).map(|_| lit(2_i64)).collect::>(); + let list_array = make_array(vec![make_array(array); 3]); + let from_array = make_array(vec![lit(2_i64); 3]); + let to_array = make_array(vec![lit(-2_i64); 3]); - let args = vec![ - Arc::new(list_array) as ArrayRef, - Arc::new(from_array) as ArrayRef, - Arc::new(to_array) as ArrayRef, - ]; - - let array = (0..array_len).map(|_| Some(-2_i64)).collect::>(); - let expected_array = ListArray::from_iter_primitive::(vec![ - Some(array.clone()), - Some(array.clone()), - Some(array), - ]); + let expected_array = list_array.clone(); // Benchmark array functions c.bench_function("array_replace", |b| { b.iter(|| { assert_eq!( - array_expressions::array_replace_all(args.as_slice()) - .unwrap() - .as_list::(), - criterion::black_box(&expected_array) + array_replace_all( + list_array.clone(), + from_array.clone(), + to_array.clone() + ), + *criterion::black_box(&expected_array) ) }) }); diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index 17c0ad1619d6..4e4ebaf035fc 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array functions. +//! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions. 
use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; use arrow::datatypes::DataType; @@ -85,11 +85,11 @@ impl ScalarUDFImpl for ArrayHas { &self.signature } - fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, _: &[DataType]) -> Result { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; if args.len() != 2 { @@ -147,11 +147,11 @@ impl ScalarUDFImpl for ArrayHasAll { &self.signature } - fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, _: &[DataType]) -> Result { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; if args.len() != 2 { return exec_err!("array_has_all needs two arguments"); @@ -204,11 +204,11 @@ impl ScalarUDFImpl for ArrayHasAny { &self.signature } - fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, _: &[DataType]) -> Result { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; if args.len() != 2 { diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-array/src/cardinality.rs new file mode 100644 index 000000000000..ed9f8d01f973 --- /dev/null +++ b/datafusion/functions-array/src/cardinality.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for cardinality function. 
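Before the implementation itself, a rough sketch of how the new function is reached from the expression API. The `expr_fn::cardinality` helper is what the `make_udf_function!` invocation below generates; treat the exact import paths as assumptions:

```rust
use datafusion_expr::{lit, Expr};
use datafusion_functions_array::expr_fn::{cardinality, make_array};

fn main() {
    // cardinality([1, 2, 3]) evaluates to 3: the total number of elements,
    // counted across all nesting levels for lists of lists.
    let expr: Expr = cardinality(make_array(vec![lit(1_i64), lit(2_i64), lit(3_i64)]));
    println!("{expr}");
}
```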
+ +use crate::utils::make_scalar_function; +use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array}; +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::Result; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + Cardinality, + cardinality, + array, + "returns the total number of elements in the array.", + cardinality_udf +); + +impl Cardinality { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("cardinality")], + } + } +} + +#[derive(Debug)] +pub(super) struct Cardinality { + signature: Signature, + aliases: Vec, +} +impl ScalarUDFImpl for Cardinality { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "cardinality" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + _ => { + return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(cardinality_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Cardinality SQL function +pub fn cardinality_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } + + match &args[0].data_type() { + List(_) => { + let list_array = as_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + LargeList(_) => { + let list_array = as_large_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + other => { + exec_err!("cardinality does not support type '{:?}'", other) + } + } +} + +fn generic_list_cardinality( + array: &GenericListArray, +) -> Result { + let result = array + .iter() + .map(|arr| match crate::utils::compute_array_dims(arr)? { + Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), + None => Ok(None), + }) + .collect::>()?; + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-array/src/concat.rs index a8e7d1008f46..cb76192e29c2 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-array/src/concat.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// Includes `array append`, `array prepend`, and `array concat` functions +//! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. 
use std::{any::Any, cmp::Ordering, sync::Arc}; @@ -39,7 +39,7 @@ use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function make_udf_function!( ArrayAppend, array_append, - array element, // arg name + array element, // arg name "appends an element to the end of an array.", // doc array_append_udf // internal function name ); @@ -283,9 +283,9 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .collect::>(); // Concatenated array on i-th row - let concated_array = arrow::compute::concat(elements.as_slice())?; - array_lengths.push(concated_array.len()); - arrays.push(concated_array); + let concatenated_array = arrow::compute::concat(elements.as_slice())?; + array_lengths.push(concatenated_array.len()); + arrays.push(concatenated_array); valid.append(true); } } diff --git a/datafusion/functions-array/src/dimension.rs b/datafusion/functions-array/src/dimension.rs new file mode 100644 index 000000000000..569eff66f7f4 --- /dev/null +++ b/datafusion/functions-array/src/dimension.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_dims and array_ndims functions. 
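A minimal sketch of the two dimension UDFs defined in this new module, under the same `datafusion`/`tokio` assumptions: `array_dims` reports the length of each dimension and `array_ndims` reports the nesting depth.

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // For a 2 x 3 nested array, array_dims returns [2, 3] and array_ndims returns 2.
    let nested = "make_array(make_array(1, 2, 3), make_array(4, 5, 6))";
    let sql = format!("SELECT array_dims({nested}) AS dims, array_ndims({nested}) AS ndims");
    let df = ctx.sql(&sql).await?;
    df.show().await?;
    Ok(())
}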
+ +use arrow::array::{ + Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, +}; +use arrow::datatypes::{DataType, UInt64Type}; +use std::any::Any; + +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, plan_err, Result}; + +use crate::utils::{compute_array_dims, make_scalar_function}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; +use arrow_schema::Field; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::sync::Arc; + +make_udf_function!( + ArrayDims, + array_dims, + array, + "returns an array of the array's dimensions.", + array_dims_udf +); + +#[derive(Debug)] +pub(super) struct ArrayDims { + signature: Signature, + aliases: Vec, +} + +impl ArrayDims { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec!["array_dims".to_string(), "list_dims".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayDims { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_dims" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => { + List(Arc::new(Field::new("item", UInt64, true))) + } + _ => { + return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_dims_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + ArrayNdims, + array_ndims, + array, + "returns the number of dimensions of the array.", + array_ndims_udf +); + +#[derive(Debug)] +pub(super) struct ArrayNdims { + signature: Signature, + aliases: Vec, +} +impl ArrayNdims { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("array_ndims"), String::from("list_ndims")], + } + } +} + +impl ScalarUDFImpl for ArrayNdims { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_ndims" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + _ => { + return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_ndims_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_dims SQL function +pub fn array_dims_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_dims needs one argument"); + } + + let data = match args[0].data_type() { + List(_) => { + let array = as_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + LargeList(_) => { + let array = as_large_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? 
+ } + array_type => { + return exec_err!("array_dims does not support type '{array_type:?}'"); + } + }; + + let result = ListArray::from_iter_primitive::(data); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Array_ndims SQL function +pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); + } + + fn general_list_ndims( + array: &GenericListArray, + ) -> Result { + let mut data = Vec::new(); + let ndims = datafusion_common::utils::list_ndims(array.data_type()); + + for arr in array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) + } + } + + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } + match args[0].data_type() { + List(_) => { + let array = as_list_array(&args[0])?; + general_list_ndims::(array) + } + LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_list_ndims::(array) + } + array_type => exec_err!("array_ndims does not support type {array_type:?}"), + } +} diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs new file mode 100644 index 000000000000..d5fa174eee5f --- /dev/null +++ b/datafusion/functions-array/src/empty.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_empty function. 
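A short sketch of the `empty` UDF's behavior, under the same assumptions; the zero-argument `make_array()` call is used here purely to build an empty list:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // empty(...) is true only when the list on that row has no elements.
    let df = ctx
        .sql("SELECT empty(make_array()) AS is_empty, empty(make_array(1)) AS non_empty")
        .await?;
    df.show().await?;
    Ok(())
}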
+ +use crate::utils::make_scalar_function; +use arrow_array::{ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow_schema::DataType; +use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; +use datafusion_common::cast::{as_generic_list_array, as_null_array}; +use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayEmpty, + array_empty, + array, + "returns true for an empty array or false for a non-empty array.", + array_empty_udf +); + +#[derive(Debug)] +pub(super) struct ArrayEmpty { + signature: Signature, + aliases: Vec, +} +impl ArrayEmpty { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![ + "empty".to_string(), + "array_empty".to_string(), + "list_empty".to_string(), + ], + } + } +} + +impl ScalarUDFImpl for ArrayEmpty { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "empty" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, + _ => { + return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_empty_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_empty SQL function +pub fn array_empty_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_empty expects one argument"); + } + + if as_null_array(&args[0]).is_ok() { + // Make sure to return Boolean type. + return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); + } + let array_type = args[0].data_type(); + + match array_type { + List(_) => general_array_empty::(&args[0]), + LargeList(_) => general_array_empty::(&args[0]), + _ => exec_err!("array_empty does not support type '{array_type:?}'."), + } +} + +fn general_array_empty(array: &ArrayRef) -> Result { + let array = as_generic_list_array::(array)?; + let builder = array + .iter() + .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) + .collect::(); + Ok(Arc::new(builder)) +} diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-array/src/except.rs index 1faaf80e69f6..444c7c758771 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-array/src/except.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. -//! implementation kernel for array_except function +//! [`ScalarUDFImpl`] definitions for array_except function. 
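For `array_except`, a minimal sketch under the same assumptions; the result keeps the elements of the first list that do not appear in the second:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // [1, 2, 3] except [2, 4] leaves [1, 3].
    let df = ctx
        .sql("SELECT array_except(make_array(1, 2, 3), make_array(2, 4)) AS difference")
        .await?;
    df.show().await?;
    Ok(())
}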
-use crate::utils::check_datatypes; +use crate::utils::{check_datatypes, make_scalar_function}; use arrow::row::{RowConverter, SortField}; use arrow_array::cast::AsArray; use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, FieldRef}; -use datafusion_common::{exec_err, internal_err}; +use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -66,16 +66,15 @@ impl ScalarUDFImpl for ArrayExcept { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { match (&arg_types[0].clone(), &arg_types[1].clone()) { (DataType::Null, _) | (_, DataType::Null) => Ok(arg_types[0].clone()), (dt, _) => Ok(dt.clone()), } } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_except_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_except_inner)(args) } fn aliases(&self) -> &[String] { @@ -84,7 +83,7 @@ impl ScalarUDFImpl for ArrayExcept { } /// Array_except SQL function -pub fn array_except_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_except_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_except needs two arguments"); } @@ -118,7 +117,7 @@ fn general_except( l: &GenericListArray, r: &GenericListArray, field: &FieldRef, -) -> datafusion_common::Result> { +) -> Result> { let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; let l_values = l.values().to_owned(); diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-array/src/extract.rs index 86eeaea3c9b4..0dbd106b6f18 100644 --- a/datafusion/functions-array/src/extract.rs +++ b/datafusion/functions-array/src/extract.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// Array Element and Array Slice +//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front and array_pop_back functions. 
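A compact sketch of the extraction UDFs handled by this module, under the same assumptions; element indices are 1-based and array_slice is inclusive of both bounds:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // array_element(..., 2) -> 20, array_slice(..., 2, 4) -> [2, 3, 4], array_pop_back(...) -> [1, 2].
    let df = ctx
        .sql("SELECT array_element(make_array(10, 20, 30), 2) AS elem, array_slice(make_array(1, 2, 3, 4, 5), 2, 4) AS sliced, array_pop_back(make_array(1, 2, 3)) AS popped")
        .await?;
    df.show().await?;
    Ok(())
}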
use arrow::array::Array; use arrow::array::ArrayRef; @@ -27,15 +27,14 @@ use arrow::array::MutableArrayData; use arrow::array::OffsetSizeTrait; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::Field; use datafusion_common::cast::as_int64_array; use datafusion_common::cast::as_large_list_array; use datafusion_common::cast::as_list_array; -use datafusion_common::exec_err; -use datafusion_common::internal_datafusion_err; -use datafusion_common::plan_err; -use datafusion_common::DataFusionError; -use datafusion_common::Result; +use datafusion_common::{ + exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, +}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -110,7 +109,6 @@ impl ScalarUDFImpl for ArrayElement { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; match &arg_types[0] { List(field) | LargeList(field) @@ -137,18 +135,18 @@ impl ScalarUDFImpl for ArrayElement { /// /// For example: /// > array_element(\[1, 2, 3], 2) -> 2 -fn array_element_inner(args: &[ArrayRef]) -> datafusion_common::Result { +fn array_element_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_element needs two arguments"); } match &args[0].data_type() { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; let indexes = as_int64_array(&args[1])?; general_array_element::(array, indexes) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; let indexes = as_int64_array(&args[1])?; general_array_element::(array, indexes) @@ -163,7 +161,7 @@ fn array_element_inner(args: &[ArrayRef]) -> datafusion_common::Result fn general_array_element( array: &GenericListArray, indexes: &Int64Array, -) -> datafusion_common::Result +) -> Result where i64: TryInto, { @@ -175,10 +173,7 @@ where let mut mutable = MutableArrayData::with_capacities(vec![&original_data], true, capacity); - fn adjusted_array_index( - index: i64, - len: O, - ) -> datafusion_common::Result> + fn adjusted_array_index(index: i64, len: O) -> Result> where i64: TryInto, { @@ -302,11 +297,11 @@ fn array_slice_inner(args: &[ArrayRef]) -> Result { let array_data_type = args[0].data_type(); match array_data_type { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; general_array_slice::(array, from_array, to_array, stride) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; let from_array = as_int64_array(&args[1])?; let to_array = as_int64_array(&args[2])?; @@ -545,11 +540,11 @@ impl ScalarUDFImpl for ArrayPopFront { fn array_pop_front_inner(args: &[ArrayRef]) -> Result { let array_data_type = args[0].data_type(); match array_data_type { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; general_pop_front_list::(array) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; general_pop_front_list::(array) } @@ -627,11 +622,11 @@ fn array_pop_back_inner(args: &[ArrayRef]) -> Result { let array_data_type = args[0].data_type(); match array_data_type { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; general_pop_back_list::(array) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; general_pop_back_list::(array) } diff --git 
a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-array/src/flatten.rs new file mode 100644 index 000000000000..e2b50c6c02cc --- /dev/null +++ b/datafusion/functions-array/src/flatten.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for flatten function. + +use crate::utils::make_scalar_function; +use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow_buffer::OffsetBuffer; +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; +use datafusion_common::cast::{ + as_generic_list_array, as_large_list_array, as_list_array, +}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + Flatten, + flatten, + array, + "flattens an array of arrays into a single array.", + flatten_udf +); + +#[derive(Debug)] +pub(super) struct Flatten { + signature: Signature, + aliases: Vec, +} +impl Flatten { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("flatten")], + } + } +} + +impl ScalarUDFImpl for Flatten { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "flatten" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + fn get_base_type(data_type: &DataType) -> Result { + match data_type { + List(field) | FixedSizeList(field, _) + if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => + { + get_base_type(field.data_type()) + } + LargeList(field) if matches!(field.data_type(), LargeList(_)) => { + get_base_type(field.data_type()) + } + Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), + FixedSizeList(field, _) => Ok(List(field.clone())), + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ), + } + } + + let data_type = get_base_type(&arg_types[0])?; + Ok(data_type) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(flatten_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Flatten SQL function +pub fn flatten_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("flatten expects one argument"); + } + + let array_type = args[0].data_type(); + match array_type { + List(_) => { + let list_arr = as_list_array(&args[0])?; + let flattened_array = flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + let list_arr = as_large_list_array(&args[0])?; + let flattened_array = 
flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) + } + Null => Ok(args[0].clone()), + _ => { + exec_err!("flatten does not support type '{array_type:?}'") + } + } +} + +fn flatten_internal( + list_arr: GenericListArray, + indexes: Option>, +) -> Result> { + let (field, offsets, values, _) = list_arr.clone().into_parts(); + let data_type = field.data_type(); + + match data_type { + // Recursively get the base offsets for flattened array + List(_) | LargeList(_) => { + let sub_list = as_generic_list_array::(&values)?; + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + flatten_internal::(sub_list.clone(), Some(offsets)) + } else { + flatten_internal::(sub_list.clone(), Some(offsets)) + } + } + // Reach the base level, create a new list array + _ => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + let list_arr = GenericListArray::::new(field, offsets, values, None); + Ok(list_arr) + } else { + Ok(list_arr.clone()) + } + } + } +} + +// Create new offsets that are equivalent to `flatten` the array. +fn get_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes + .iter() + .map(|i| buffer[i.to_usize().unwrap()]) + .collect(); + OffsetBuffer::new(offsets.into()) +} diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs deleted file mode 100644 index 15cdf8f279ae..000000000000 --- a/datafusion/functions-array/src/kernels.rs +++ /dev/null @@ -1,1209 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
implementation kernels for array functions - -use arrow::array::{ - Array, ArrayRef, BooleanArray, Capacities, Date32Array, Float32Array, Float64Array, - GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeListArray, - LargeStringArray, ListArray, ListBuilder, MutableArrayData, OffsetSizeTrait, - StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array, -}; -use arrow::compute; -use arrow::datatypes::{ - DataType, Date32Type, Field, IntervalMonthDayNanoType, UInt64Type, -}; -use arrow_array::new_null_array; -use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; -use arrow_schema::FieldRef; -use arrow_schema::SortOptions; - -use datafusion_common::cast::{ - as_date32_array, as_generic_list_array, as_generic_string_array, as_int64_array, - as_interval_mdn_array, as_large_list_array, as_list_array, as_null_array, - as_string_array, -}; -use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_datafusion_err, DataFusionError, Result, - ScalarValue, -}; - -use std::any::type_name; -use std::sync::Arc; - -macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast to {}", - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - -macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - -macro_rules! 
call_array_function { - ($DATATYPE:expr, false) => { - match $DATATYPE { - DataType::Utf8 => array_function!(StringArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), - } - }; - ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ - match $DATATYPE { - DataType::List(_) => array_function!(ListArray), - DataType::Utf8 => array_function!(StringArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), - } - }}; -} - -/// Array_to_string SQL function -pub(super) fn array_to_string(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("array_to_string expects two or three arguments"); - } - - let arr = &args[0]; - - let delimiters = as_string_array(&args[1])?; - let delimiters: Vec> = delimiters.iter().collect(); - - let mut null_string = String::from(""); - let mut with_null_string = false; - if args.len() == 3 { - null_string = as_string_array(&args[2])?.value(0).to_string(); - with_null_string = true; - } - - fn compute_array_to_string( - arg: &mut String, - arr: ArrayRef, - delimiter: String, - null_string: String, - with_null_string: bool, - ) -> datafusion_common::Result<&mut String> { - match arr.data_type() { - DataType::List(..) => { - let list_array = as_list_array(&arr)?; - for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } - - Ok(arg) - } - DataType::LargeList(..) => { - let list_array = as_large_list_array(&arr)?; - for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } - - Ok(arg) - } - DataType::Null => Ok(arg), - data_type => { - macro_rules! 
array_function { - ($ARRAY_TYPE:ident) => { - to_string!( - arg, - arr, - &delimiter, - &null_string, - with_null_string, - $ARRAY_TYPE - ) - }; - } - call_array_function!(data_type, false) - } - } - } - - fn generate_string_array( - list_arr: &GenericListArray, - delimiters: Vec>, - null_string: String, - with_null_string: bool, - ) -> datafusion_common::Result { - let mut res: Vec> = Vec::new(); - for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - let mut arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - arr, - delimiter.to_string(), - null_string.clone(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); - } - } else { - res.push(None); - } - } - - Ok(StringArray::from(res)) - } - - let arr_type = arr.data_type(); - let string_arr = match arr_type { - DataType::List(_) | DataType::FixedSizeList(_, _) => { - let list_array = as_list_array(&arr)?; - generate_string_array::( - list_array, - delimiters, - null_string, - with_null_string, - )? - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&arr)?; - generate_string_array::( - list_array, - delimiters, - null_string, - with_null_string, - )? - } - _ => { - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - // delimiter length is 1 - assert_eq!(delimiters.len(), 1); - let delimiter = delimiters[0].unwrap(); - let s = compute_array_to_string( - &mut arg, - arr.clone(), - delimiter.to_string(), - null_string, - with_null_string, - )? - .clone(); - - if !s.is_empty() { - let s = s.strip_suffix(delimiter).unwrap().to_string(); - res.push(Some(s)); - } else { - res.push(Some(s)); - } - StringArray::from(res) - } - }; - - Ok(Arc::new(string_arr)) -} - -/// Splits string at occurrences of delimiter and returns an array of parts -/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' -pub fn string_to_array(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("string_to_array expects two or three arguments"); - } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - - let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - )); - - match args.len() { - 2 => { - string_array.iter().zip(delimiter_array.iter()).for_each( - |(string, delimiter)| { - match (string, delimiter) { - (Some(string), Some("")) => { - list_builder.values().append_value(string); - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - list_builder.values().append_value(s); - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - list_builder.values().append_value(c); - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value - } - }, - ); - } - - 3 => { - let null_value_array = as_generic_string_array::(&args[2])?; - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(null_value_array.iter()) - .for_each(|((string, delimiter), null_value)| { - match (string, delimiter) { - (Some(string), Some("")) => { - if Some(string) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(string); - } - list_builder.append(true); - } - (Some(string), 
Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - if Some(s) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(s); - } - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - if Some(c.as_str()) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(c); - } - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value - } - }); - } - _ => { - return exec_err!( - "Expect string_to_array function to take two or three parameters" - ) - } - } - - let list_array = list_builder.finish(); - Ok(Arc::new(list_array) as ArrayRef) -} - -/// Generates an array of integers from start to stop with a given step. -/// -/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. -/// It returns a `Result` representing the resulting ListArray after the operation. -/// -/// # Arguments -/// -/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. -/// -/// # Examples -/// -/// gen_range(3) => [0, 1, 2] -/// gen_range(1, 4) => [1, 2, 3] -/// gen_range(1, 7, 2) => [1, 3, 5] -pub(super) fn gen_range(args: &[ArrayRef], include_upper: bool) -> Result { - let (start_array, stop_array, step_array) = match args.len() { - 1 => (None, as_int64_array(&args[0])?, None), - 2 => ( - Some(as_int64_array(&args[0])?), - as_int64_array(&args[1])?, - None, - ), - 3 => ( - Some(as_int64_array(&args[0])?), - as_int64_array(&args[1])?, - Some(as_int64_array(&args[2])?), - ), - _ => return exec_err!("gen_range expects 1 to 3 arguments"), - }; - - let mut values = vec![]; - let mut offsets = vec![0]; - let mut valid = BooleanBufferBuilder::new(stop_array.len()); - for (idx, stop) in stop_array.iter().enumerate() { - match retrieve_range_args(start_array, stop, step_array, idx) { - Some((_, _, 0)) => { - return exec_err!( - "step can't be 0 for function {}(start [, stop, step])", - if include_upper { - "generate_series" - } else { - "range" - } - ); - } - Some((start, stop, step)) => { - // Below, we utilize `usize` to represent steps. - // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`. - let step_abs = usize::try_from(step.unsigned_abs()).map_err(|_| { - not_impl_datafusion_err!("step {} can't fit into usize", step) - })?; - values.extend( - gen_range_iter(start, stop, step < 0, include_upper) - .step_by(step_abs), - ); - offsets.push(values.len() as i32); - valid.append(true); - } - // If any of the arguments is NULL, append a NULL value to the result. - None => { - offsets.push(values.len() as i32); - valid.append(false); - } - }; - } - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Int64, true)), - OffsetBuffer::new(offsets.into()), - Arc::new(Int64Array::from(values)), - Some(NullBuffer::new(valid.finish())), - )?); - Ok(arr) -} - -/// Get the (start, stop, step) args for the range and generate_series function. -/// If any of the arguments is NULL, returns None. 
-fn retrieve_range_args( - start_array: Option<&Int64Array>, - stop: Option, - step_array: Option<&Int64Array>, - idx: usize, -) -> Option<(i64, i64, i64)> { - // Default start value is 0 if not provided - let start = - start_array.map_or(Some(0), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; - let stop = stop?; - // Default step value is 1 if not provided - let step = - step_array.map_or(Some(1), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; - Some((start, stop, step)) -} - -/// Returns an iterator of i64 values from start to stop -fn gen_range_iter( - start: i64, - stop: i64, - decreasing: bool, - include_upper: bool, -) -> Box> { - match (decreasing, include_upper) { - // Decreasing range, stop is inclusive - (true, true) => Box::new((stop..=start).rev()), - // Decreasing range, stop is exclusive - (true, false) => { - if stop == i64::MAX { - // start is never greater than stop, and stop is exclusive, - // so the decreasing range must be empty. - Box::new(std::iter::empty()) - } else { - // Increase the stop value by one to exclude it. - // Since stop is not i64::MAX, `stop + 1` will not overflow. - Box::new((stop + 1..=start).rev()) - } - } - // Increasing range, stop is inclusive - (false, true) => Box::new(start..=stop), - // Increasing range, stop is exclusive - (false, false) => Box::new(start..stop), - } -} - -/// Returns the length of each array dimension -fn compute_array_dims(arr: Option) -> Result>>> { - let mut value = match arr { - Some(arr) => arr, - None => return Ok(None), - }; - if value.is_empty() { - return Ok(None); - } - let mut res = vec![Some(value.len() as u64)]; - - loop { - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - res.push(Some(value.len() as u64)); - } - _ => return Ok(Some(res)), - } - } -} - -fn generic_list_cardinality( - array: &GenericListArray, -) -> Result { - let result = array - .iter() - .map(|arr| match compute_array_dims(arr)? { - Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), - None => Ok(None), - }) - .collect::>()?; - Ok(Arc::new(result) as ArrayRef) -} - -/// Cardinality SQL function -pub fn cardinality(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("cardinality expects one argument"); - } - - match &args[0].data_type() { - DataType::List(_) => { - let list_array = as_list_array(&args[0])?; - generic_list_cardinality::(list_array) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&args[0])?; - generic_list_cardinality::(list_array) - } - other => { - exec_err!("cardinality does not support type '{:?}'", other) - } - } -} - -/// Array_dims SQL function -pub fn array_dims(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_dims needs one argument"); - } - - let data = match args[0].data_type() { - DataType::List(_) => { - let array = as_list_array(&args[0])?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - DataType::LargeList(_) => { - let array = as_large_list_array(&args[0])?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? 
- } - array_type => { - return exec_err!("array_dims does not support type '{array_type:?}'"); - } - }; - - let result = ListArray::from_iter_primitive::(data); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Array_ndims SQL function -pub fn array_ndims(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_ndims needs one argument"); - } - - fn general_list_ndims( - array: &GenericListArray, - ) -> Result { - let mut data = Vec::new(); - let ndims = datafusion_common::utils::list_ndims(array.data_type()); - - for arr in array.iter() { - if arr.is_some() { - data.push(Some(ndims)) - } else { - data.push(None) - } - } - - Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) - } - match args[0].data_type() { - DataType::List(_) => { - let array = as_list_array(&args[0])?; - general_list_ndims::(array) - } - DataType::LargeList(_) => { - let array = as_large_list_array(&args[0])?; - general_list_ndims::(array) - } - array_type => exec_err!("array_ndims does not support type {array_type:?}"), - } -} -pub fn gen_range_date( - args: &[ArrayRef], - include_upper: bool, -) -> datafusion_common::Result { - if args.len() != 3 { - return exec_err!("arguments length does not match"); - } - let (start_array, stop_array, step_array) = ( - Some(as_date32_array(&args[0])?), - as_date32_array(&args[1])?, - Some(as_interval_mdn_array(&args[2])?), - ); - - let mut values = vec![]; - let mut offsets = vec![0]; - for (idx, stop) in stop_array.iter().enumerate() { - let mut stop = stop.unwrap_or(0); - let start = start_array.as_ref().map(|x| x.value(idx)).unwrap_or(0); - let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); - let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); - let neg = months < 0 || days < 0; - if !include_upper { - stop = Date32Type::subtract_month_day_nano(stop, step); - } - let mut new_date = start; - loop { - if neg && new_date < stop || !neg && new_date > stop { - break; - } - values.push(new_date); - new_date = Date32Type::add_month_day_nano(new_date, step); - } - offsets.push(values.len() as i32); - } - - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Date32, true)), - OffsetBuffer::new(offsets.into()), - Arc::new(Date32Array::from(values)), - None, - )?); - Ok(arr) -} - -/// Array_empty SQL function -pub fn array_empty(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_empty expects one argument"); - } - - if as_null_array(&args[0]).is_ok() { - // Make sure to return Boolean type. 
- return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); - } - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => general_array_empty::(&args[0]), - DataType::LargeList(_) => general_array_empty::(&args[0]), - _ => exec_err!("array_empty does not support type '{array_type:?}'."), - } -} - -fn general_array_empty(array: &ArrayRef) -> Result { - let array = as_generic_list_array::(array)?; - let builder = array - .iter() - .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) - .collect::(); - Ok(Arc::new(builder)) -} - -/// Returns the length of a concrete array dimension -fn compute_array_length( - arr: Option, - dimension: Option, -) -> Result> { - let mut current_dimension: i64 = 1; - let mut value = match arr { - Some(arr) => arr, - None => return Ok(None), - }; - let dimension = match dimension { - Some(value) => { - if value < 1 { - return Ok(None); - } - - value - } - None => return Ok(None), - }; - - loop { - if current_dimension == dimension { - return Ok(Some(value.len() as u64)); - } - - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - current_dimension += 1; - } - DataType::LargeList(..) => { - value = downcast_arg!(value, LargeListArray).value(0); - current_dimension += 1; - } - _ => return Ok(None), - } - } -} - -/// Dispatch array length computation based on the offset type. -fn general_array_length(array: &[ArrayRef]) -> Result { - let list_array = as_generic_list_array::(&array[0])?; - let dimension = if array.len() == 2 { - as_int64_array(&array[1])?.clone() - } else { - Int64Array::from_value(1, list_array.len()) - }; - - let result = list_array - .iter() - .zip(dimension.iter()) - .map(|(arr, dim)| compute_array_length(arr, dim)) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) -} - -/// Array_repeat SQL function -pub fn array_repeat(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_repeat expects two arguments"); - } - - let element = &args[0]; - let count_array = as_int64_array(&args[1])?; - - match element.data_type() { - DataType::List(_) => { - let list_array = as_list_array(element)?; - general_list_repeat::(list_array, count_array) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(element)?; - general_list_repeat::(list_array, count_array) - } - _ => general_repeat::(element, count_array), - } -} - -/// For each element of `array[i]` repeat `count_array[i]` times. -/// -/// Assumption for the input: -/// 1. `count[i] >= 0` -/// 2. 
`array.len() == count_array.len()` -/// -/// For example, -/// ```text -/// array_repeat( -/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] -/// ) -/// ``` -fn general_repeat( - array: &ArrayRef, - count_array: &Int64Array, -) -> Result { - let data_type = array.data_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (row_index, &count) in count_vec.iter().enumerate() { - let repeated_array = if array.is_null(row_index) { - new_null_array(data_type, count) - } else { - let original_data = array.to_data(); - let capacity = Capacities::Array(count); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); - - for _ in 0..count { - mutable.extend(0, row_index, row_index + 1); - } - - let data = mutable.freeze(); - arrow_array::make_array(data) - }; - new_values.push(repeated_array); - } - - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), - OffsetBuffer::from_lengths(count_vec), - values, - None, - )?)) -} - -/// Handle List version of `general_repeat` -/// -/// For each element of `list_array[i]` repeat `count_array[i]` times. -/// -/// For example, -/// ```text -/// array_repeat( -/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] -/// ) -/// ``` -fn general_list_repeat( - list_array: &GenericListArray, - count_array: &Int64Array, -) -> Result { - let data_type = list_array.data_type(); - let value_type = list_array.value_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { - let list_arr = match list_array_row { - Some(list_array_row) => { - let original_data = list_array_row.to_data(); - let capacity = Capacities::Array(original_data.len() * count); - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data], - false, - capacity, - ); - - for _ in 0..count { - mutable.extend(0, 0, original_data.len()); - } - - let data = mutable.freeze(); - let repeated_array = arrow_array::make_array(data); - - let list_arr = GenericListArray::::try_new( - Arc::new(Field::new("item", value_type.clone(), true)), - OffsetBuffer::::from_lengths(vec![original_data.len(); count]), - repeated_array, - None, - )?; - Arc::new(list_arr) as ArrayRef - } - None => new_null_array(data_type, count), - }; - new_values.push(list_arr); - } - - let lengths = new_values.iter().map(|a| a.len()).collect::>(); - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; - - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), - OffsetBuffer::::from_lengths(lengths), - values, - None, - )?)) -} - -/// Array_length SQL function -pub fn array_length(args: &[ArrayRef]) -> Result { - if args.len() != 1 && args.len() != 2 { - return exec_err!("array_length expects one or two arguments"); - } - - match &args[0].data_type() { - DataType::List(_) => general_array_length::(args), - DataType::LargeList(_) => general_array_length::(args), - array_type => exec_err!("array_length does not support type '{array_type:?}'"), - } -} - -/// array_resize SQL function -pub fn array_resize(arg: &[ArrayRef]) -> Result { - if 
arg.len() < 2 || arg.len() > 3 { - return exec_err!("array_resize needs two or three arguments"); - } - - let new_len = as_int64_array(&arg[1])?; - let new_element = if arg.len() == 3 { - Some(arg[2].clone()) - } else { - None - }; - - match &arg[0].data_type() { - DataType::List(field) => { - let array = as_list_array(&arg[0])?; - general_list_resize::(array, new_len, field, new_element) - } - DataType::LargeList(field) => { - let array = as_large_list_array(&arg[0])?; - general_list_resize::(array, new_len, field, new_element) - } - array_type => exec_err!("array_resize does not support type '{array_type:?}'."), - } -} - -/// array_resize keep the original array and append the default element to the end -fn general_list_resize( - array: &GenericListArray, - count_array: &Int64Array, - field: &FieldRef, - default_element: Option, -) -> Result -where - O: TryInto, -{ - let data_type = array.value_type(); - - let values = array.values(); - let original_data = values.to_data(); - - // create default element array - let default_element = if let Some(default_element) = default_element { - default_element - } else { - let null_scalar = ScalarValue::try_from(&data_type)?; - null_scalar.to_array_of_size(original_data.len())? - }; - let default_value_data = default_element.to_data(); - - // create a mutable array to store the original data - let capacity = Capacities::Array(original_data.len() + default_value_data.len()); - let mut offsets = vec![O::usize_as(0)]; - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &default_value_data], - false, - capacity, - ); - - for (row_index, offset_window) in array.offsets().windows(2).enumerate() { - let count = count_array.value(row_index).to_usize().ok_or_else(|| { - internal_datafusion_err!("array_resize: failed to convert size to usize") - })?; - let count = O::usize_as(count); - let start = offset_window[0]; - if start + count > offset_window[1] { - let extra_count = - (start + count - offset_window[1]).try_into().map_err(|_| { - internal_datafusion_err!( - "array_resize: failed to convert size to i64" - ) - })?; - let end = offset_window[1]; - mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); - // append default element - for _ in 0..extra_count { - mutable.extend(1, row_index, row_index + 1); - } - } else { - let end = start + count; - mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); - }; - offsets.push(offsets[row_index] + count); - } - - let data = mutable.freeze(); - Ok(Arc::new(GenericListArray::::try_new( - field.clone(), - OffsetBuffer::::new(offsets.into()), - arrow_array::make_array(data), - None, - )?)) -} - -/// Array_sort SQL function -pub fn array_sort(args: &[ArrayRef]) -> Result { - if args.is_empty() || args.len() > 3 { - return exec_err!("array_sort expects one to three arguments"); - } - - let sort_option = match args.len() { - 1 => None, - 2 => { - let sort = as_string_array(&args[1])?.value(0); - Some(SortOptions { - descending: order_desc(sort)?, - nulls_first: true, - }) - } - 3 => { - let sort = as_string_array(&args[1])?.value(0); - let nulls_first = as_string_array(&args[2])?.value(0); - Some(SortOptions { - descending: order_desc(sort)?, - nulls_first: order_nulls_first(nulls_first)?, - }) - } - _ => return exec_err!("array_sort expects 1 to 3 arguments"), - }; - - let list_array = as_list_array(&args[0])?; - let row_count = list_array.len(); - - let mut array_lengths = vec![]; - let mut arrays = vec![]; - let mut valid = 
BooleanBufferBuilder::new(row_count); - for i in 0..row_count { - if list_array.is_null(i) { - array_lengths.push(0); - valid.append(false); - } else { - let arr_ref = list_array.value(i); - let arr_ref = arr_ref.as_ref(); - - let sorted_array = compute::sort(arr_ref, sort_option)?; - array_lengths.push(sorted_array.len()); - arrays.push(sorted_array); - valid.append(true); - } - } - - // Assume all arrays have the same data type - let data_type = list_array.value_type(); - let buffer = valid.finish(); - - let elements = arrays - .iter() - .map(|a| a.as_ref()) - .collect::>(); - - let list_arr = ListArray::new( - Arc::new(Field::new("item", data_type, true)), - OffsetBuffer::from_lengths(array_lengths), - Arc::new(compute::concat(elements.as_slice())?), - Some(NullBuffer::new(buffer)), - ); - Ok(Arc::new(list_arr)) -} - -fn order_desc(modifier: &str) -> Result { - match modifier.to_uppercase().as_str() { - "DESC" => Ok(true), - "ASC" => Ok(false), - _ => exec_err!("the second parameter of array_sort expects DESC or ASC"), - } -} - -fn order_nulls_first(modifier: &str) -> Result { - match modifier.to_uppercase().as_str() { - "NULLS FIRST" => Ok(true), - "NULLS LAST" => Ok(false), - _ => exec_err!( - "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" - ), - } -} - -// Create new offsets that are euqiavlent to `flatten` the array. -fn get_offsets_for_flatten( - offsets: OffsetBuffer, - indexes: OffsetBuffer, -) -> OffsetBuffer { - let buffer = offsets.into_inner(); - let offsets: Vec = indexes - .iter() - .map(|i| buffer[i.to_usize().unwrap()]) - .collect(); - OffsetBuffer::new(offsets.into()) -} - -fn flatten_internal( - list_arr: GenericListArray, - indexes: Option>, -) -> Result> { - let (field, offsets, values, _) = list_arr.clone().into_parts(); - let data_type = field.data_type(); - - match data_type { - // Recursively get the base offsets for flattened array - DataType::List(_) | DataType::LargeList(_) => { - let sub_list = as_generic_list_array::(&values)?; - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - flatten_internal::(sub_list.clone(), Some(offsets)) - } else { - flatten_internal::(sub_list.clone(), Some(offsets)) - } - } - // Reach the base level, create a new list array - _ => { - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = GenericListArray::::new(field, offsets, values, None); - Ok(list_arr) - } else { - Ok(list_arr.clone()) - } - } - } -} - -/// Flatten SQL function -pub fn flatten(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("flatten expects one argument"); - } - - let array_type = args[0].data_type(); - match array_type { - DataType::List(_) => { - let list_arr = as_list_array(&args[0])?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) - } - DataType::LargeList(_) => { - let list_arr = as_large_list_array(&args[0])?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) - } - DataType::Null => Ok(args[0].clone()), - _ => { - exec_err!("flatten does not support type '{array_type:?}'") - } - } -} - -/// array_reverse SQL function -pub fn array_reverse(arg: &[ArrayRef]) -> Result { - if arg.len() != 1 { - return exec_err!("array_reverse needs one argument"); - } - - match &arg[0].data_type() { - DataType::List(field) => { - let array = as_list_array(&arg[0])?; - general_array_reverse::(array, 
field) - } - DataType::LargeList(field) => { - let array = as_large_list_array(&arg[0])?; - general_array_reverse::(array, field) - } - DataType::Null => Ok(arg[0].clone()), - array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), - } -} - -fn general_array_reverse( - array: &GenericListArray, - field: &FieldRef, -) -> Result -where - O: TryFrom, -{ - let values = array.values(); - let original_data = values.to_data(); - let capacity = Capacities::Array(original_data.len()); - let mut offsets = vec![O::usize_as(0)]; - let mut nulls = vec![]; - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); - - for (row_index, offset_window) in array.offsets().windows(2).enumerate() { - // skip the null value - if array.is_null(row_index) { - nulls.push(false); - offsets.push(offsets[row_index] + O::one()); - mutable.extend(0, 0, 1); - continue; - } else { - nulls.push(true); - } - - let start = offset_window[0]; - let end = offset_window[1]; - - let mut index = end - O::one(); - let mut cnt = 0; - - while index >= start { - mutable.extend(0, index.to_usize().unwrap(), index.to_usize().unwrap() + 1); - index = index - O::one(); - cnt += 1; - } - offsets.push(offsets[row_index] + O::usize_as(cnt)); - } - - let data = mutable.freeze(); - Ok(Arc::new(GenericListArray::::try_new( - field.clone(), - OffsetBuffer::::new(offsets.into()), - arrow_array::make_array(data), - Some(nulls.into()), - )?)) -} diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-array/src/length.rs new file mode 100644 index 000000000000..9bbd11950d21 --- /dev/null +++ b/datafusion/functions-array/src/length.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_length function. 
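A minimal sketch of `array_length` under the same assumptions; with one argument it returns the length of the first dimension, and an optional second argument selects a deeper dimension:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // The outer list holds 2 inner lists, and dimension 2 has length 3.
    let nested = "make_array(make_array(1, 2, 3), make_array(4, 5, 6))";
    let sql = format!("SELECT array_length({nested}) AS len, array_length({nested}, 2) AS len_dim2");
    let df = ctx.sql(&sql).await?;
    df.show().await?;
    Ok(())
}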
+ +use crate::utils::{downcast_arg, make_scalar_function}; +use arrow_array::{ + Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array, +}; +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; +use core::any::type_name; +use datafusion_common::cast::{as_generic_list_array, as_int64_array}; +use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayLength, + array_length, + array, + "returns the length of the array dimension.", + array_length_udf +); + +#[derive(Debug)] +pub(super) struct ArrayLength { + signature: Signature, + aliases: Vec, +} +impl ArrayLength { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("array_length"), String::from("list_length")], + } + } +} + +impl ScalarUDFImpl for ArrayLength { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + _ => { + return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_length_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_length SQL function +pub fn array_length_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!("array_length expects one or two arguments"); + } + + match &args[0].data_type() { + List(_) => general_array_length::(args), + LargeList(_) => general_array_length::(args), + array_type => exec_err!("array_length does not support type '{array_type:?}'"), + } +} + +/// Dispatch array length computation based on the offset type. +fn general_array_length(array: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() + } else { + Int64Array::from_value(1, list_array.len()) + }; + + let result = list_array + .iter() + .zip(dimension.iter()) + .map(|(arr, dim)| compute_array_length(arr, dim)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns the length of a concrete array dimension +fn compute_array_length( + arr: Option, + dimension: Option, +) -> Result> { + let mut current_dimension: i64 = 1; + let mut value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + let dimension = match dimension { + Some(value) => { + if value < 1 { + return Ok(None); + } + + value + } + None => return Ok(None), + }; + + loop { + if current_dimension == dimension { + return Ok(Some(value.len() as u64)); + } + + match value.data_type() { + List(..) => { + value = downcast_arg!(value, ListArray).value(0); + current_dimension += 1; + } + LargeList(..) 
=> { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; + } + _ => return Ok(None), + } + } +} diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 2c19dfad6222..7c261f958bf0 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -29,16 +29,26 @@ pub mod macros; mod array_has; +mod cardinality; mod concat; -mod core; +mod dimension; +mod empty; mod except; mod extract; -mod kernels; +mod flatten; +mod length; +mod make_array; mod position; +mod range; mod remove; +mod repeat; +mod replace; +mod resize; +mod reverse; mod rewrite; mod set_ops; -mod udf; +mod sort; +mod string; mod utils; use datafusion_common::Result; @@ -52,49 +62,52 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; pub use super::concat::array_prepend; - pub use super::core::make_array; + pub use super::dimension::array_dims; + pub use super::dimension::array_ndims; + pub use super::empty::array_empty; pub use super::except::array_except; pub use super::extract::array_element; pub use super::extract::array_pop_back; pub use super::extract::array_pop_front; pub use super::extract::array_slice; + pub use super::flatten::flatten; + pub use super::length::array_length; + pub use super::make_array::make_array; pub use super::position::array_position; pub use super::position::array_positions; + pub use super::range::gen_series; + pub use super::range::range; pub use super::remove::array_remove; pub use super::remove::array_remove_all; pub use super::remove::array_remove_n; + pub use super::repeat::array_repeat; + pub use super::replace::array_replace; + pub use super::replace::array_replace_all; + pub use super::replace::array_replace_n; + pub use super::resize::array_resize; + pub use super::reverse::array_reverse; pub use super::set_ops::array_distinct; pub use super::set_ops::array_intersect; pub use super::set_ops::array_union; - pub use super::udf::array_dims; - pub use super::udf::array_empty; - pub use super::udf::array_length; - pub use super::udf::array_ndims; - pub use super::udf::array_repeat; - pub use super::udf::array_resize; - pub use super::udf::array_reverse; - pub use super::udf::array_sort; - pub use super::udf::array_to_string; - pub use super::udf::cardinality; - pub use super::udf::flatten; - pub use super::udf::gen_series; - pub use super::udf::range; - pub use super::udf::string_to_array; + pub use super::sort::array_sort; + pub use super::string::array_to_string; + pub use super::string::string_to_array; } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ - udf::array_to_string_udf(), - udf::string_to_array_udf(), - udf::range_udf(), - udf::gen_series_udf(), - udf::array_dims_udf(), - udf::cardinality_udf(), - udf::array_ndims_udf(), + string::array_to_string_udf(), + string::string_to_array_udf(), + range::range_udf(), + range::gen_series_udf(), + dimension::array_dims_udf(), + cardinality::cardinality_udf(), + dimension::array_ndims_udf(), concat::array_append_udf(), concat::array_prepend_udf(), concat::array_concat_udf(), @@ -103,25 +116,28 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { extract::array_pop_back_udf(), 
extract::array_pop_front_udf(), extract::array_slice_udf(), - core::make_array_udf(), + make_array::make_array_udf(), array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), - udf::array_empty_udf(), - udf::array_length_udf(), - udf::flatten_udf(), - udf::array_sort_udf(), - udf::array_repeat_udf(), - udf::array_resize_udf(), - udf::array_reverse_udf(), + empty::array_empty_udf(), + length::array_length_udf(), + flatten::flatten_udf(), + sort::array_sort_udf(), + repeat::array_repeat_udf(), + resize::array_resize_udf(), + reverse::array_reverse_udf(), set_ops::array_distinct_udf(), set_ops::array_intersect_udf(), set_ops::array_union_udf(), position::array_position_udf(), position::array_positions_udf(), remove::array_remove_udf(), - remove::array_remove_n_udf(), remove::array_remove_all_udf(), + remove::array_remove_n_udf(), + replace::array_replace_n_udf(), + replace::array_replace_all_udf(), + replace::array_replace_udf(), ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/datafusion/functions-array/src/core.rs b/datafusion/functions-array/src/make_array.rs similarity index 91% rename from datafusion/functions-array/src/core.rs rename to datafusion/functions-array/src/make_array.rs index 4c84b7018c99..8eaae09f28f5 100644 --- a/datafusion/functions-array/src/core.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// core array function like `make_array` +//! [`ScalarUDFImpl`] definitions for `make_array` function. use std::{any::Any, sync::Arc}; @@ -24,9 +24,9 @@ use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, NullArray, OffsetSizeTrait, }; use arrow_buffer::OffsetBuffer; +use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; -use datafusion_common::Result; -use datafusion_common::{plan_err, utils::array_into_list_array}; +use datafusion_common::{plan_err, utils::array_into_list_array, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use datafusion_expr::{ @@ -73,7 +73,7 @@ impl ScalarUDFImpl for MakeArray { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { match arg_types.len() { 0 => Ok(DataType::List(Arc::new(Field::new( "item", @@ -89,14 +89,12 @@ impl ScalarUDFImpl for MakeArray { } } - Ok(DataType::List(Arc::new(Field::new( - "item", expr_type, true, - )))) + Ok(List(Arc::new(Field::new("item", expr_type, true)))) } } } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(make_array_inner)(args) } @@ -109,10 +107,10 @@ impl ScalarUDFImpl for MakeArray { /// Constructs an array using the input `data` as `ArrayRef`. /// Returns a reference-counted `Array` instance result. 
pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { - let mut data_type = DataType::Null; + let mut data_type = Null; for arg in arrays { let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&DataType::Null) { + if !arg_data_type.equals_datatype(&Null) { data_type = arg_data_type.clone(); break; } @@ -120,12 +118,11 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: - DataType::Null => { - let array = - new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); + Null => { + let array = new_null_array(&Null, arrays.iter().map(|a| a.len()).sum()); Ok(Arc::new(array_into_list_array(array))) } - DataType::LargeList(..) => array_array::(arrays, data_type), + LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), } } diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs index 4988e0ded106..a5a7a7405aa9 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-array/src/position.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_position function. +//! [`ScalarUDFImpl`] definitions for array_position and array_positions functions. use arrow_schema::DataType::{LargeList, List, UInt64}; use arrow_schema::{DataType, Field}; @@ -27,15 +27,16 @@ use std::sync::Arc; use arrow_array::types::UInt64Type; use arrow_array::{ - Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, - UInt32Array, UInt64Array, + Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, }; use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, internal_err}; +use datafusion_common::{exec_err, internal_err, Result}; use itertools::Itertools; +use crate::utils::{compare_element_to_list, make_scalar_function}; + make_udf_function!( ArrayPosition, array_position, @@ -77,16 +78,12 @@ impl ScalarUDFImpl for ArrayPosition { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(UInt64) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_position_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_position_inner)(args) } fn aliases(&self) -> &[String] { @@ -95,7 +92,7 @@ impl ScalarUDFImpl for ArrayPosition { } /// Array_position SQL function -pub fn array_position_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_position_inner(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("array_position expects two or three arguments"); } @@ -105,9 +102,7 @@ pub fn array_position_inner(args: &[ArrayRef]) -> datafusion_common::Result exec_err!("array_position does not support type '{array_type:?}'."), } } -fn general_position_dispatch( - args: &[ArrayRef], -) -> datafusion_common::Result { +fn general_position_dispatch(args: &[ArrayRef]) -> Result { let list_array = as_generic_list_array::(&args[0])?; let element_array = &args[1]; @@ -145,7 +140,7 @@ fn generic_position( list_array: &GenericListArray, element_array: &ArrayRef, arr_from: Vec, // 0-indexed -) -> 
datafusion_common::Result { +) -> Result { let mut data = Vec::with_capacity(list_array.len()); for (row_index, (list_array_row, &from)) in @@ -173,107 +168,6 @@ fn generic_position( Ok(Arc::new(UInt64Array::from(data))) } -/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. -/// -/// # Arguments -/// -/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. -/// -/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. -/// -/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. -/// -/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. -/// -/// # Returns -/// -/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. -/// -/// # Example -/// -/// ```text -/// compare_element_to_list( -/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] -/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] -/// -/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] -/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] -/// ) -/// ``` -fn compare_element_to_list( - list_array_row: &dyn Array, - element_array: &dyn Array, - row_index: usize, - eq: bool, -) -> datafusion_common::Result { - if list_array_row.data_type() != element_array.data_type() { - return exec_err!( - "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", - list_array_row.data_type(), - element_array.data_type() - ); - } - - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let res = match element_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - DataType::LargeList(_) => { - // compare each element of the from array - let element_array_row_inner = - as_large_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_large_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - _ => { - let element_arr = Scalar::new(element_array_row); - // use not_distinct so we can compare NULL - if eq { - arrow::compute::kernels::cmp::not_distinct(&list_array_row, &element_arr)? 
- } else { - arrow::compute::kernels::cmp::distinct(&list_array_row, &element_arr)? - } - } - }; - - Ok(res) -} - make_udf_function!( ArrayPositions, array_positions, @@ -311,16 +205,12 @@ impl ScalarUDFImpl for ArrayPositions { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_positions_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_positions_inner)(args) } fn aliases(&self) -> &[String] { @@ -329,7 +219,7 @@ impl ScalarUDFImpl for ArrayPositions { } /// Array_positions SQL function -pub fn array_positions_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_positions_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_positions expects two arguments"); } @@ -337,12 +227,12 @@ pub fn array_positions_inner(args: &[ArrayRef]) -> datafusion_common::Result { + List(_) => { let arr = as_list_array(&args[0])?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) } - DataType::LargeList(_) => { + LargeList(_) => { let arr = as_large_list_array(&args[0])?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) @@ -356,7 +246,7 @@ pub fn array_positions_inner(args: &[ArrayRef]) -> datafusion_common::Result( list_array: &GenericListArray, element_array: &ArrayRef, -) -> datafusion_common::Result { +) -> Result { let mut data = Vec::with_capacity(list_array.len()); for (row_index, list_array_row) in list_array.iter().enumerate() { diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs new file mode 100644 index 000000000000..1c9e0c878e6e --- /dev/null +++ b/datafusion/functions-array/src/range.rs @@ -0,0 +1,328 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for range and gen_series functions. 
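Editor's note: the new `range.rs` module added below implements `range` and `generate_series` for Int64 and Date32 inputs. As a quick orientation, here is a simplified, integer-only sketch of the behaviour encoded in `gen_range_inner`/`gen_range_iter`: `range` excludes the upper bound, `generate_series` includes it, and a negative step produces a decreasing sequence. `series` is a hypothetical helper name used only for this illustration; the real code additionally handles NULL inputs and guards the `stop + 1` overflow case.

```rust
// Sketch of the integer range semantics, assuming a non-zero step and no NULLs.
fn series(start: i64, stop: i64, step: i64, include_upper: bool) -> Vec<i64> {
    assert!(step != 0, "step can't be 0 for range/generate_series");
    // Same shape as gen_range_iter: pick the base iterator by direction and
    // upper-bound inclusivity, then thin it out with step_by.
    let base: Box<dyn Iterator<Item = i64>> = match (step < 0, include_upper) {
        (true, true) => Box::new((stop..=start).rev()),      // decreasing, inclusive
        (true, false) => Box::new((stop + 1..=start).rev()), // decreasing, exclusive
        (false, true) => Box::new(start..=stop),             // increasing, inclusive
        (false, false) => Box::new(start..stop),             // increasing, exclusive
    };
    base.step_by(step.unsigned_abs() as usize).collect()
}

fn main() {
    assert_eq!(series(1, 4, 1, false), vec![1, 2, 3]);     // range(1, 4)
    assert_eq!(series(1, 7, 2, false), vec![1, 3, 5]);     // range(1, 7, 2)
    assert_eq!(series(1, 4, 1, true), vec![1, 2, 3, 4]);   // generate_series(1, 4)
    assert_eq!(series(5, 1, -2, true), vec![5, 3, 1]);     // generate_series(5, 1, -2)
}
```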
+ +use arrow::array::{Array, ArrayRef, Int64Array, ListArray}; +use arrow::datatypes::{DataType, Field}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use std::any::Any; + +use crate::utils::make_scalar_function; +use arrow_array::types::{Date32Type, IntervalMonthDayNanoType}; +use arrow_array::Date32Array; +use arrow_schema::DataType::{Date32, Int64, Interval, List}; +use arrow_schema::IntervalUnit::MonthDayNano; +use datafusion_common::cast::{as_date32_array, as_int64_array, as_interval_mdn_array}; +use datafusion_common::{exec_err, not_impl_datafusion_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::Expr; +use datafusion_expr::{ + ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use std::sync::Arc; + +make_udf_function!( + Range, + range, + start stop step, + "create a list of values in the range between start and stop", + range_udf +); +#[derive(Debug)] +pub(super) struct Range { + signature: Signature, + aliases: Vec, +} +impl Range { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![Int64]), + TypeSignature::Exact(vec![Int64, Int64]), + TypeSignature::Exact(vec![Int64, Int64, Int64]), + TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("range")], + } + } +} +impl ScalarUDFImpl for Range { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "range" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + Int64 => make_scalar_function(|args| gen_range_inner(args, false))(args), + Date32 => make_scalar_function(|args| gen_range_date(args, false))(args), + _ => { + exec_err!("unsupported type for range") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + GenSeries, + gen_series, + start stop step, + "create a list of values in the range between start and stop, include upper bound", + gen_series_udf +); +#[derive(Debug)] +pub(super) struct GenSeries { + signature: Signature, + aliases: Vec, +} +impl GenSeries { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![Int64]), + TypeSignature::Exact(vec![Int64, Int64]), + TypeSignature::Exact(vec![Int64, Int64, Int64]), + TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("generate_series")], + } + } +} +impl ScalarUDFImpl for GenSeries { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "generate_series" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + Int64 => make_scalar_function(|args| gen_range_inner(args, true))(args), + Date32 => make_scalar_function(|args| gen_range_date(args, true))(args), + _ => { + exec_err!("unsupported type for range") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Generates an array of integers from start to stop with a given step. 
+/// +/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. +/// It returns a `Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub(super) fn gen_range_inner( + args: &[ArrayRef], + include_upper: bool, +) -> Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return exec_err!("gen_range expects 1 to 3 arguments"), + }; + + let mut values = vec![]; + let mut offsets = vec![0]; + let mut valid = BooleanBufferBuilder::new(stop_array.len()); + for (idx, stop) in stop_array.iter().enumerate() { + match retrieve_range_args(start_array, stop, step_array, idx) { + Some((_, _, 0)) => { + return exec_err!( + "step can't be 0 for function {}(start [, stop, step])", + if include_upper { + "generate_series" + } else { + "range" + } + ); + } + Some((start, stop, step)) => { + // Below, we utilize `usize` to represent steps. + // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`. + let step_abs = usize::try_from(step.unsigned_abs()).map_err(|_| { + not_impl_datafusion_err!("step {} can't fit into usize", step) + })?; + values.extend( + gen_range_iter(start, stop, step < 0, include_upper) + .step_by(step_abs), + ); + offsets.push(values.len() as i32); + valid.append(true); + } + // If any of the arguments is NULL, append a NULL value to the result. + None => { + offsets.push(values.len() as i32); + valid.append(false); + } + }; + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + Some(NullBuffer::new(valid.finish())), + )?); + Ok(arr) +} + +/// Get the (start, stop, step) args for the range and generate_series function. +/// If any of the arguments is NULL, returns None. +fn retrieve_range_args( + start_array: Option<&Int64Array>, + stop: Option, + step_array: Option<&Int64Array>, + idx: usize, +) -> Option<(i64, i64, i64)> { + // Default start value is 0 if not provided + let start = + start_array.map_or(Some(0), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; + let stop = stop?; + // Default step value is 1 if not provided + let step = + step_array.map_or(Some(1), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; + Some((start, stop, step)) +} + +/// Returns an iterator of i64 values from start to stop +fn gen_range_iter( + start: i64, + stop: i64, + decreasing: bool, + include_upper: bool, +) -> Box> { + match (decreasing, include_upper) { + // Decreasing range, stop is inclusive + (true, true) => Box::new((stop..=start).rev()), + // Decreasing range, stop is exclusive + (true, false) => { + if stop == i64::MAX { + // start is never greater than stop, and stop is exclusive, + // so the decreasing range must be empty. + Box::new(std::iter::empty()) + } else { + // Increase the stop value by one to exclude it. + // Since stop is not i64::MAX, `stop + 1` will not overflow. 
+ Box::new((stop + 1..=start).rev()) + } + } + // Increasing range, stop is inclusive + (false, true) => Box::new(start..=stop), + // Increasing range, stop is exclusive + (false, false) => Box::new(start..stop), + } +} + +fn gen_range_date(args: &[ArrayRef], include_upper: bool) -> Result { + if args.len() != 3 { + return exec_err!("arguments length does not match"); + } + let (start_array, stop_array, step_array) = ( + Some(as_date32_array(&args[0])?), + as_date32_array(&args[1])?, + Some(as_interval_mdn_array(&args[2])?), + ); + + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let mut stop = stop.unwrap_or(0); + let start = start_array.as_ref().map(|x| x.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); + let neg = months < 0 || days < 0; + if !include_upper { + stop = Date32Type::subtract_month_day_nano(stop, step); + } + let mut new_date = start; + loop { + if neg && new_date < stop || !neg && new_date > stop { + break; + } + values.push(new_date); + new_date = Date32Type::add_month_day_nano(new_date, step); + } + offsets.push(values.len() as i32); + } + + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", Date32, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Date32Array::from(values)), + None, + )?); + Ok(arr) +} diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 91c76a6708dc..21e373081054 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -18,6 +18,7 @@ //! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions. 
use crate::utils; +use crate::utils::make_scalar_function; use arrow_array::cast::AsArray; use arrow_array::{ new_empty_array, Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, @@ -25,7 +26,7 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, Field}; use datafusion_common::cast::as_int64_array; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -58,6 +59,7 @@ impl ScalarUDFImpl for ArrayRemove { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_remove" } @@ -66,13 +68,12 @@ impl ScalarUDFImpl for ArrayRemove { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_remove_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_remove_inner)(args) } fn aliases(&self) -> &[String] { @@ -107,6 +108,7 @@ impl ScalarUDFImpl for ArrayRemoveN { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_remove_n" } @@ -115,13 +117,12 @@ impl ScalarUDFImpl for ArrayRemoveN { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_remove_n_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_remove_n_inner)(args) } fn aliases(&self) -> &[String] { @@ -159,6 +160,7 @@ impl ScalarUDFImpl for ArrayRemoveAll { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_remove_all" } @@ -167,13 +169,12 @@ impl ScalarUDFImpl for ArrayRemoveAll { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_remove_all_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_remove_all_inner)(args) } fn aliases(&self) -> &[String] { @@ -182,7 +183,7 @@ impl ScalarUDFImpl for ArrayRemoveAll { } /// Array_remove SQL function -pub fn array_remove_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_remove_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_remove expects two arguments"); } @@ -192,7 +193,7 @@ pub fn array_remove_inner(args: &[ArrayRef]) -> datafusion_common::Result datafusion_common::Result { +pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result { if args.len() != 3 { return exec_err!("array_remove_n expects three arguments"); } @@ -202,7 +203,7 @@ pub fn array_remove_n_inner(args: &[ArrayRef]) -> datafusion_common::Result datafusion_common::Result { +pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_remove_all expects two arguments"); } @@ -215,7 +216,7 @@ 
fn array_remove_internal( array: &ArrayRef, element_array: &ArrayRef, arr_n: Vec, -) -> datafusion_common::Result { +) -> Result { match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); @@ -252,7 +253,7 @@ fn general_remove( list_array: &GenericListArray, element_array: &ArrayRef, arr_n: Vec, -) -> datafusion_common::Result { +) -> Result { let data_type = list_array.value_type(); let mut new_values = vec![]; // Build up the offsets for the final output array diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-array/src/repeat.rs new file mode 100644 index 000000000000..89b766bdcdfc --- /dev/null +++ b/datafusion/functions-array/src/repeat.rs @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_repeat function. + +use crate::utils::make_scalar_function; +use arrow::array::{Capacities, MutableArrayData}; +use arrow::compute; +use arrow_array::{ + new_null_array, Array, ArrayRef, GenericListArray, Int64Array, ListArray, + OffsetSizeTrait, +}; +use arrow_buffer::OffsetBuffer; +use arrow_schema::DataType::{LargeList, List}; +use arrow_schema::{DataType, Field}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayRepeat, + array_repeat, + element count, // arg name + "returns an array containing element `count` times.", // doc + array_repeat_udf // internal function name +); +#[derive(Debug)] +pub(super) struct ArrayRepeat { + signature: Signature, + aliases: Vec, +} + +impl ArrayRepeat { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("array_repeat"), String::from("list_repeat")], + } + } +} + +impl ScalarUDFImpl for ArrayRepeat { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_repeat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_repeat_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_repeat SQL function +pub fn array_repeat_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + + let element = &args[0]; + let count_array = as_int64_array(&args[1])?; + + match element.data_type() { + List(_) => { + 
let list_array = as_list_array(element)?; + general_list_repeat::(list_array, count_array) + } + LargeList(_) => { + let list_array = as_large_list_array(element)?; + general_list_repeat::(list_array, count_array) + } + _ => general_repeat::(element, count_array), + } +} + +/// For each element of `array[i]` repeat `count_array[i]` times. +/// +/// Assumption for the input: +/// 1. `count[i] >= 0` +/// 2. `array.len() == count_array.len()` +/// +/// For example, +/// ```text +/// array_repeat( +/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] +/// ) +/// ``` +fn general_repeat( + array: &ArrayRef, + count_array: &Int64Array, +) -> Result { + let data_type = array.data_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (row_index, &count) in count_vec.iter().enumerate() { + let repeated_array = if array.is_null(row_index) { + new_null_array(data_type, count) + } else { + let original_data = array.to_data(); + let capacity = Capacities::Array(count); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + for _ in 0..count { + mutable.extend(0, row_index, row_index + 1); + } + + let data = mutable.freeze(); + arrow_array::make_array(data) + }; + new_values.push(repeated_array); + } + + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(count_vec), + values, + None, + )?)) +} + +/// Handle List version of `general_repeat` +/// +/// For each element of `list_array[i]` repeat `count_array[i]` times. +/// +/// For example, +/// ```text +/// array_repeat( +/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] +/// ) +/// ``` +fn general_list_repeat( + list_array: &GenericListArray, + count_array: &Int64Array, +) -> Result { + let data_type = list_array.data_type(); + let value_type = list_array.value_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { + let list_arr = match list_array_row { + Some(list_array_row) => { + let original_data = list_array_row.to_data(); + let capacity = Capacities::Array(original_data.len() * count); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + capacity, + ); + + for _ in 0..count { + mutable.extend(0, 0, original_data.len()); + } + + let data = mutable.freeze(); + let repeated_array = arrow_array::make_array(data); + + let list_arr = GenericListArray::::try_new( + Arc::new(Field::new("item", value_type.clone(), true)), + OffsetBuffer::::from_lengths(vec![original_data.len(); count]), + repeated_array, + None, + )?; + Arc::new(list_arr) as ArrayRef + } + None => new_null_array(data_type, count), + }; + new_values.push(list_arr); + } + + let lengths = new_values.iter().map(|a| a.len()).collect::>(); + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::::from_lengths(lengths), + values, + None, + )?)) +} diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs new file mode 
100644 index 000000000000..c32305bb454b --- /dev/null +++ b/datafusion/functions-array/src/replace.rs @@ -0,0 +1,365 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions. + +use arrow::array::{ + Array, ArrayRef, AsArray, Capacities, MutableArrayData, OffsetSizeTrait, +}; +use arrow::datatypes::DataType; + +use arrow_array::GenericListArray; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use arrow_schema::Field; +use datafusion_common::cast::as_int64_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::Expr; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::compare_element_to_list; +use crate::utils::make_scalar_function; + +use std::any::Any; +use std::sync::Arc; + +// Create static instances of ScalarUDFs for each function +make_udf_function!(ArrayReplace, + array_replace, + array from to, + "replaces the first occurrence of the specified element with another specified element.", + array_replace_udf +); +make_udf_function!(ArrayReplaceN, + array_replace_n, + array from to max, + "replaces the first `max` occurrences of the specified element with another specified element.", + array_replace_n_udf +); +make_udf_function!(ArrayReplaceAll, + array_replace_all, + array from to, + "replaces all occurrences of the specified element with another specified element.", + array_replace_all_udf +); + +#[derive(Debug)] +pub(super) struct ArrayReplace { + signature: Signature, + aliases: Vec, +} + +impl ArrayReplace { + pub fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + aliases: vec![String::from("array_replace"), String::from("list_replace")], + } + } +} + +impl ScalarUDFImpl for ArrayReplace { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_replace" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + Ok(args[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_replace_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +#[derive(Debug)] +pub(super) struct ArrayReplaceN { + signature: Signature, + aliases: Vec, +} + +impl ArrayReplaceN { + pub fn new() -> Self { + Self { + signature: Signature::any(4, Volatility::Immutable), + aliases: vec![ + String::from("array_replace_n"), + String::from("list_replace_n"), + ], + } + } +} + +impl ScalarUDFImpl for ArrayReplaceN { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_replace_n" + } + + fn signature(&self) -> &Signature { + 
&self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + Ok(args[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_replace_n_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +#[derive(Debug)] +pub(super) struct ArrayReplaceAll { + signature: Signature, + aliases: Vec, +} + +impl ArrayReplaceAll { + pub fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + aliases: vec![ + String::from("array_replace_all"), + String::from("list_replace_all"), + ], + } + } +} + +impl ScalarUDFImpl for ArrayReplaceAll { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_replace_all" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + Ok(args[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_replace_all_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences +/// of `from_array[i]`, `to_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `from_array` and `to_array`. This function also handles nested arrays +/// (\[`ListArray`\] of \[`ListArray`\]s) +/// +/// For example, when called to replace a list array (where each element is a +/// list of int32s, the second and third argument are int32 arrays, and the +/// fourth argument is the number of occurrences to replace +/// +/// ```text +/// general_replace( +/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) +/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) +/// ) +/// ``` +fn general_replace( + list_array: &GenericListArray, + from_array: &ArrayRef, + to_array: &ArrayRef, + arr_n: Vec, +) -> Result { + // Build up the offsets for the final output array + let mut offsets: Vec = vec![O::usize_as(0)]; + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // First array is the original array, second array is the element to replace with. 
+ let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + + let mut valid = BooleanBufferBuilder::new(list_array.len()); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append(false); + continue; + } + + let start = offset_window[0]; + let end = offset_window[1]; + + let list_array_row = list_array.value(row_index); + + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = + compare_element_to_list(&list_array_row, &from_array, row_index, true)?; + + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); + let n = arr_n[row_index]; + let mut counter = 0; + + // All elements are false, no need to replace, just copy original data + if eq_array.false_count() == eq_array.len() { + mutable.extend( + original_idx.to_usize().unwrap(), + start.to_usize().unwrap(), + end.to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + continue; + } + + for (i, to_replace) in eq_array.iter().enumerate() { + let i = O::usize_as(i); + if let Some(true) = to_replace { + mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); + counter += 1; + if counter == n { + // copy original data for any matches past n + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + end.to_usize().unwrap(), + ); + break; + } + } else { + // copy original data for false / null matches + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + ); + } + } + + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + } + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + Some(NullBuffer::new(valid.finish())), + )?)) +} + +pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace expects three arguments"); + } + + // replace at most one occurrence for each element + let arr_n = vec![1; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => exec_err!("array_replace does not support type '{array_type:?}'."), + } +} + +pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { + if args.len() != 4 { + return exec_err!("array_replace_n expects four arguments"); + } + + // replace the specified number of occurrences + let arr_n = as_int64_array(&args[3])?.values().to_vec(); + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_n does not support type '{array_type:?}'.") + } + } +} + +pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return 
exec_err!("array_replace_all expects three arguments"); + } + + // replace all occurrences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_all does not support type '{array_type:?}'.") + } + } +} diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs new file mode 100644 index 000000000000..c5855d054494 --- /dev/null +++ b/datafusion/functions-array/src/resize.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_resize function. + +use crate::utils::make_scalar_function; +use arrow::array::{Capacities, MutableArrayData}; +use arrow_array::{ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait}; +use arrow_buffer::{ArrowNativeType, OffsetBuffer}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List}; +use arrow_schema::{DataType, FieldRef}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayResize, + array_resize, + array size value, + "returns an array with the specified size filled with the given value.", + array_resize_udf +); + +#[derive(Debug)] +pub(super) struct ArrayResize { + signature: Signature, + aliases: Vec, +} + +impl ArrayResize { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec!["array_resize".to_string(), "list_resize".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayResize { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_resize" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + List(field) | FixedSizeList(field, _) => Ok(List(field.clone())), + LargeList(field) => Ok(LargeList(field.clone())), + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_resize_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// array_resize SQL function +pub(crate) fn array_resize_inner(arg: 
&[ArrayRef]) -> Result { + if arg.len() < 2 || arg.len() > 3 { + return exec_err!("array_resize needs two or three arguments"); + } + + let new_len = as_int64_array(&arg[1])?; + let new_element = if arg.len() == 3 { + Some(arg[2].clone()) + } else { + None + }; + + match &arg[0].data_type() { + List(field) => { + let array = as_list_array(&arg[0])?; + general_list_resize::(array, new_len, field, new_element) + } + LargeList(field) => { + let array = as_large_list_array(&arg[0])?; + general_list_resize::(array, new_len, field, new_element) + } + array_type => exec_err!("array_resize does not support type '{array_type:?}'."), + } +} + +/// array_resize keep the original array and append the default element to the end +fn general_list_resize( + array: &GenericListArray, + count_array: &Int64Array, + field: &FieldRef, + default_element: Option, +) -> Result +where + O: TryInto, +{ + let data_type = array.value_type(); + + let values = array.values(); + let original_data = values.to_data(); + + // create default element array + let default_element = if let Some(default_element) = default_element { + default_element + } else { + let null_scalar = ScalarValue::try_from(&data_type)?; + null_scalar.to_array_of_size(original_data.len())? + }; + let default_value_data = default_element.to_data(); + + // create a mutable array to store the original data + let capacity = Capacities::Array(original_data.len() + default_value_data.len()); + let mut offsets = vec![O::usize_as(0)]; + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &default_value_data], + false, + capacity, + ); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let count = count_array.value(row_index).to_usize().ok_or_else(|| { + internal_datafusion_err!("array_resize: failed to convert size to usize") + })?; + let count = O::usize_as(count); + let start = offset_window[0]; + if start + count > offset_window[1] { + let extra_count = + (start + count - offset_window[1]).try_into().map_err(|_| { + internal_datafusion_err!( + "array_resize: failed to convert size to i64" + ) + })?; + let end = offset_window[1]; + mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); + // append default element + for _ in 0..extra_count { + mutable.extend(1, row_index, row_index + 1); + } + } else { + let end = start + count; + mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); + }; + offsets.push(offsets[row_index] + count); + } + + let data = mutable.freeze(); + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) +} diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs new file mode 100644 index 000000000000..8324c407bd86 --- /dev/null +++ b/datafusion/functions-array/src/reverse.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_reverse function. + +use crate::utils::make_scalar_function; +use arrow::array::{Capacities, MutableArrayData}; +use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow_buffer::OffsetBuffer; +use arrow_schema::DataType::{LargeList, List, Null}; +use arrow_schema::{DataType, FieldRef}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayReverse, + array_reverse, + array, + "reverses the order of elements in the array.", + array_reverse_udf +); + +#[derive(Debug)] +pub(super) struct ArrayReverse { + signature: Signature, + aliases: Vec, +} + +impl ArrayReverse { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + aliases: vec!["array_reverse".to_string(), "list_reverse".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayReverse { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_reverse" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_reverse_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// array_reverse SQL function +pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { + if arg.len() != 1 { + return exec_err!("array_reverse needs one argument"); + } + + match &arg[0].data_type() { + List(field) => { + let array = as_list_array(&arg[0])?; + general_array_reverse::(array, field) + } + LargeList(field) => { + let array = as_large_list_array(&arg[0])?; + general_array_reverse::(array, field) + } + Null => Ok(arg[0].clone()), + array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), + } +} + +fn general_array_reverse( + array: &GenericListArray, + field: &FieldRef, +) -> Result +where + O: TryFrom, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut offsets = vec![O::usize_as(0)]; + let mut nulls = vec![]; + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + // skip the null value + if array.is_null(row_index) { + nulls.push(false); + offsets.push(offsets[row_index] + O::one()); + mutable.extend(0, 0, 1); + continue; + } else { + nulls.push(true); + } + + let start = offset_window[0]; + let end = offset_window[1]; + + let mut index = end - O::one(); + let mut cnt = 0; + + while index >= start { + mutable.extend(0, index.to_usize().unwrap(), index.to_usize().unwrap() + 1); + index = index - O::one(); + cnt += 1; + } + offsets.push(offsets[row_index] + O::usize_as(cnt)); + } + + let data = mutable.freeze(); + 
Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + Some(nulls.into()), + )?)) +} diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 6a91e9078232..d231dce4cb68 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -23,6 +23,7 @@ use crate::extract::{array_element, array_slice}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::list_ndims; +use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::FunctionRewrite; @@ -42,7 +43,7 @@ impl FunctionRewrite for ArrayFunctionRewriter { expr: Expr, schema: &DFSchema, _config: &ConfigOptions, - ) -> datafusion_common::Result> { + ) -> Result> { let transformed = match expr { // array1 @> array2 -> array_has_all(array1, array2) Expr::BinaryExpr(BinaryExpr { left, op, right }) diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-array/src/set_ops.rs index df5bc91a2689..5f3087fafd6f 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-array/src/set_ops.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. -//! Array Intersection, Union, and Distinct functions +//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions. -use crate::core::make_array_inner; +use crate::make_array::make_array_inner; use crate::utils::make_scalar_function; use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::expr::ScalarFunction; @@ -48,7 +49,7 @@ make_udf_function!( ArrayIntersect, array_intersect, first_array second_array, - "Returns an array of the elements in the intersection of array1 and array2.", + "returns an array of the elements in the intersection of array1 and array2.", array_intersect_udf ); @@ -56,7 +57,7 @@ make_udf_function!( ArrayDistinct, array_distinct, array, - "return distinct values from the array after removing duplicates.", + "returns distinct values from the array after removing duplicates.", array_distinct_udf ); @@ -79,6 +80,7 @@ impl ScalarUDFImpl for ArrayUnion { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_union" } @@ -89,8 +91,8 @@ impl ScalarUDFImpl for ArrayUnion { fn return_type(&self, arg_types: &[DataType]) -> Result { match (&arg_types[0], &arg_types[1]) { - (&DataType::Null, dt) => Ok(dt.clone()), - (dt, DataType::Null) => Ok(dt.clone()), + (&Null, dt) => Ok(dt.clone()), + (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -126,6 +128,7 @@ impl ScalarUDFImpl for ArrayIntersect { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_intersect" } @@ -136,12 +139,8 @@ impl ScalarUDFImpl for ArrayIntersect { fn return_type(&self, arg_types: &[DataType]) -> Result { match (arg_types[0].clone(), arg_types[1].clone()) { - (DataType::Null, DataType::Null) | (DataType::Null, _) => Ok(DataType::Null), - (_, 
DataType::Null) => Ok(DataType::List(Arc::new(Field::new( - "item", - DataType::Null, - true, - )))), + (Null, Null) | (Null, _) => Ok(Null), + (_, Null) => Ok(List(Arc::new(Field::new("item", Null, true)))), (dt, _) => Ok(dt), } } @@ -174,6 +173,7 @@ impl ScalarUDFImpl for ArrayDistinct { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_distinct" } @@ -183,7 +183,6 @@ impl ScalarUDFImpl for ArrayDistinct { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; match &arg_types[0] { List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( "item", @@ -218,17 +217,17 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result { } // handle null - if args[0].data_type() == &DataType::Null { + if args[0].data_type() == &Null { return Ok(args[0].clone()); } // handle for list & largelist match args[0].data_type() { - DataType::List(field) => { + List(field) => { let array = as_list_array(&args[0])?; general_array_distinct(array, field) } - DataType::LargeList(field) => { + LargeList(field) => { let array = as_large_list_array(&args[0])?; general_array_distinct(array, field) } @@ -257,10 +256,10 @@ fn generic_set_lists( field: Arc, set_op: SetOp, ) -> Result { - if matches!(l.value_type(), DataType::Null) { + if matches!(l.value_type(), Null) { let field = Arc::new(Field::new("item", r.value_type(), true)); return general_array_distinct::(r, &field); - } else if matches!(r.value_type(), DataType::Null) { + } else if matches!(r.value_type(), Null) { let field = Arc::new(Field::new("item", l.value_type(), true)); return general_array_distinct::(l, &field); } @@ -331,43 +330,43 @@ fn general_set_op( set_op: SetOp, ) -> Result { match (array1.data_type(), array2.data_type()) { - (DataType::Null, DataType::List(field)) => { + (Null, List(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&DataType::Null)); + return Ok(new_empty_array(&Null)); } let array = as_list_array(&array2)?; general_array_distinct::(array, field) } - (DataType::List(field), DataType::Null) => { + (List(field), Null) => { if set_op == SetOp::Intersect { return make_array_inner(&[]); } let array = as_list_array(&array1)?; general_array_distinct::(array, field) } - (DataType::Null, DataType::LargeList(field)) => { + (Null, LargeList(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&DataType::Null)); + return Ok(new_empty_array(&Null)); } let array = as_large_list_array(&array2)?; general_array_distinct::(array, field) } - (DataType::LargeList(field), DataType::Null) => { + (LargeList(field), Null) => { if set_op == SetOp::Intersect { return make_array_inner(&[]); } let array = as_large_list_array(&array1)?; general_array_distinct::(array, field) } - (DataType::Null, DataType::Null) => Ok(new_empty_array(&DataType::Null)), + (Null, Null) => Ok(new_empty_array(&Null)), - (DataType::List(field), DataType::List(_)) => { + (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; generic_set_lists::(array1, array2, field.clone(), set_op) } - (DataType::LargeList(field), DataType::LargeList(_)) => { + (LargeList(field), LargeList(_)) => { let array1 = as_large_list_array(&array1)?; let array2 = as_large_list_array(&array2)?; generic_set_lists::(array1, array2, field.clone(), set_op) diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs new file mode 100644 index 000000000000..af78712065fc --- /dev/null +++ b/datafusion/functions-array/src/sort.rs @@ 
-0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_sort function. + +use crate::utils::make_scalar_function; +use arrow::compute; +use arrow_array::{Array, ArrayRef, ListArray}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List}; +use arrow_schema::{DataType, Field, SortOptions}; +use datafusion_common::cast::{as_list_array, as_string_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArraySort, + array_sort, + array desc null_first, + "returns sorted array.", + array_sort_udf +); + +#[derive(Debug)] +pub(super) struct ArraySort { + signature: Signature, + aliases: Vec, +} + +impl ArraySort { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec!["array_sort".to_string(), "list_sort".to_string()], + } + } +} + +impl ScalarUDFImpl for ArraySort { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_sort" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + LargeList(field) => Ok(LargeList(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_sort_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_sort SQL function +pub fn array_sort_inner(args: &[ArrayRef]) -> Result { + if args.is_empty() || args.len() > 3 { + return exec_err!("array_sort expects one to three arguments"); + } + + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_string_array(&args[1])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: true, + }) + } + 3 => { + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, + }) + } + _ => return exec_err!("array_sort expects 1 to 3 arguments"), + }; + + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for 
i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); + + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); + } + } + + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); + + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) +} + +fn order_desc(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => exec_err!("the second parameter of array_sort expects DESC or ASC"), + } +} + +fn order_nulls_first(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => exec_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), + } +} diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs new file mode 100644 index 000000000000..38059035005b --- /dev/null +++ b/datafusion/functions-array/src/string.rs @@ -0,0 +1,479 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_to_string and string_to_array functions. + +use arrow::array::{ + Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericListArray, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, ListBuilder, + OffsetSizeTrait, StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; +use arrow::datatypes::{DataType, Field}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{Expr, TypeSignature}; + +use datafusion_common::{plan_err, DataFusionError, Result}; + +use std::any::{type_name, Any}; + +use crate::utils::{downcast_arg, make_scalar_function}; +use arrow_schema::DataType::{FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8}; +use datafusion_common::cast::{ + as_generic_string_array, as_large_list_array, as_list_array, as_string_array, +}; +use datafusion_common::exec_err; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::sync::Arc; + +macro_rules! 
to_string { + ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + for x in arr { + match x { + Some(x) => { + $ARG.push_str(&x.to_string()); + $ARG.push_str($DELIMITER); + } + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMITER); + } + } + } + } + Ok($ARG) + }}; +} + +macro_rules! call_array_function { + ($DATATYPE:expr, false) => { + match $DATATYPE { + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), + } + }; + ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ + match $DATATYPE { + DataType::List(_) => array_function!(ListArray), + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), + } + }}; +} + +// Create static instances of ScalarUDFs for each function +make_udf_function!( + ArrayToString, + array_to_string, + array delimiter, // arg name + "converts each element to its text representation.", // doc + array_to_string_udf // internal function name +); +#[derive(Debug)] +pub(super) struct ArrayToString { + signature: Signature, + aliases: Vec, +} + +impl ArrayToString { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![ + String::from("array_to_string"), + String::from("list_to_string"), + String::from("array_join"), + String::from("list_join"), + ], + } + } +} + +impl ScalarUDFImpl for ArrayToString { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_to_string" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, + _ => { + return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_to_string_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + StringToArray, + string_to_array, + string delimiter null_string, // arg name + "splits a `string` based on a `delimiter` and returns an array of parts. 
Any parts matching the optional `null_string` will be replaced with `NULL`", // doc + string_to_array_udf // internal function name +); +#[derive(Debug)] +pub(super) struct StringToArray { + signature: Signature, + aliases: Vec, +} + +impl StringToArray { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), + TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + aliases: vec![ + String::from("string_to_array"), + String::from("string_to_list"), + ], + } + } +} + +impl ScalarUDFImpl for StringToArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "string_to_array" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(match arg_types[0] { + Utf8 | LargeUtf8 => { + List(Arc::new(Field::new("item", arg_types[0].clone(), true))) + } + _ => { + return plan_err!( + "The string_to_array function can only accept Utf8 or LargeUtf8." + ); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + Utf8 => make_scalar_function(string_to_array_inner::)(args), + LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), + other => { + exec_err!("unsupported type for string_to_array function as {other}") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_to_string SQL function +pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + + let arr = &args[0]; + + let delimiters = as_string_array(&args[1])?; + let delimiters: Vec> = delimiters.iter().collect(); + + let mut null_string = String::from(""); + let mut with_null_string = false; + if args.len() == 3 { + null_string = as_string_array(&args[2])?.value(0).to_string(); + with_null_string = true; + } + + fn compute_array_to_string( + arg: &mut String, + arr: ArrayRef, + delimiter: String, + null_string: String, + with_null_string: bool, + ) -> Result<&mut String> { + match arr.data_type() { + List(..) => { + let list_array = as_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + + Ok(arg) + } + LargeList(..) => { + let list_array = as_large_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + + Ok(arg) + } + Null => Ok(arg), + data_type => { + macro_rules! array_function { + ($ARRAY_TYPE:ident) => { + to_string!( + arg, + arr, + &delimiter, + &null_string, + with_null_string, + $ARRAY_TYPE + ) + }; + } + call_array_function!(data_type, false) + } + } + } + + fn generate_string_array( + list_arr: &GenericListArray, + delimiters: Vec>, + null_string: String, + with_null_string: bool, + ) -> Result { + let mut res: Vec> = Vec::new(); + for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + let mut arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? 
+ .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); + } else { + res.push(Some(s)); + } + } else { + res.push(None); + } + } + + Ok(StringArray::from(res)) + } + + let arr_type = arr.data_type(); + let string_arr = match arr_type { + List(_) | FixedSizeList(_, _) => { + let list_array = as_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + _ => { + let mut arg = String::from(""); + let mut res: Vec> = Vec::new(); + // delimiter length is 1 + assert_eq!(delimiters.len(), 1); + let delimiter = delimiters[0].unwrap(); + let s = compute_array_to_string( + &mut arg, + arr.clone(), + delimiter.to_string(), + null_string, + with_null_string, + )? + .clone(); + + if !s.is_empty() { + let s = s.strip_suffix(delimiter).unwrap().to_string(); + res.push(Some(s)); + } else { + res.push(Some(s)); + } + StringArray::from(res) + } + }; + + Ok(Arc::new(string_arr)) +} + +/// String_to_array SQL function +/// Splits string at occurrences of delimiter and returns an array of parts +/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' +pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("string_to_array expects two or three arguments"); + } + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + + let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( + string_array.len(), + string_array.get_buffer_memory_size(), + )); + + match args.len() { + 2 => { + string_array.iter().zip(delimiter_array.iter()).for_each( + |(string, delimiter)| { + match (string, delimiter) { + (Some(string), Some("")) => { + list_builder.values().append_value(string); + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + list_builder.values().append_value(s); + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + list_builder.values().append_value(c); + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }, + ); + } + + 3 => { + let null_value_array = as_generic_string_array::(&args[2])?; + string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(s); + } + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }); + } + _ => { + return exec_err!( + "Expect string_to_array function to take two or three parameters" + ) + } + } + + let list_array = 
list_builder.finish(); + Ok(Arc::new(list_array) as ArrayRef) +} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs deleted file mode 100644 index e0793900c6b3..000000000000 --- a/datafusion/functions-array/src/udf.rs +++ /dev/null @@ -1,869 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ScalarUDFImpl`] definitions for array functions. - -use arrow::array::{NullArray, StringArray}; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use arrow::datatypes::IntervalUnit::MonthDayNano; -use arrow_schema::DataType::{LargeUtf8, List, Utf8}; -use datafusion_common::exec_err; -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::TypeSignature; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -// Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayToString, - array_to_string, - array delimiter, // arg name - "converts each element to its text representation.", // doc - array_to_string_udf // internal function name -); -#[derive(Debug)] -pub(super) struct ArrayToString { - signature: Signature, - aliases: Vec, -} - -impl ArrayToString { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("array_to_string"), - String::from("list_to_string"), - String::from("array_join"), - String::from("list_join"), - ], - } - } -} - -impl ScalarUDFImpl for ArrayToString { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_to_string" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, - _ => { - return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_to_string(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!(StringToArray, - string_to_array, - string delimiter null_string, // arg name - "splits a `string` based on a `delimiter` and returns an array of parts. 
Any parts matching the optional `null_string` will be replaced with `NULL`", // doc - string_to_array_udf // internal function name -); -#[derive(Debug)] -pub(super) struct StringToArray { - signature: Signature, - aliases: Vec, -} - -impl StringToArray { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("string_to_array"), - String::from("string_to_list"), - ], - } - } -} - -impl ScalarUDFImpl for StringToArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "string_to_array" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { - List(Arc::new(Field::new("item", arg_types[0].clone(), true))) - } - _ => { - return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." - ); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let mut args = ColumnarValue::values_to_arrays(args)?; - // Case: delimiter is NULL, needs to be handled as well. - if args[1].as_any().is::() { - args[1] = Arc::new(StringArray::new_null(args[1].len())); - }; - - match args[0].data_type() { - Utf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - LargeUtf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - other => { - exec_err!("unsupported type for string_to_array function as {other}") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Range, - range, - start stop step, - "create a list of values in the range between start and stop", - range_udf -); -#[derive(Debug)] -pub(super) struct Range { - signature: Signature, - aliases: Vec, -} -impl Range { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("range")], - } - } -} -impl ScalarUDFImpl for Range { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "range" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - match args[0].data_type() { - arrow::datatypes::DataType::Int64 => { - crate::kernels::gen_range(&args, false).map(ColumnarValue::Array) - } - arrow::datatypes::DataType::Date32 => { - crate::kernels::gen_range_date(&args, false).map(ColumnarValue::Array) - } - _ => { - exec_err!("unsupported type for range") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - GenSeries, - gen_series, - start stop step, - "create a list of values in the range between start and stop, include upper bound", - gen_series_udf -); -#[derive(Debug)] -pub(super) struct GenSeries { - signature: Signature, - aliases: Vec, -} -impl GenSeries { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - 
TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("generate_series")], - } - } -} -impl ScalarUDFImpl for GenSeries { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "generate_series" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - match args[0].data_type() { - arrow::datatypes::DataType::Int64 => { - crate::kernels::gen_range(&args, true).map(ColumnarValue::Array) - } - arrow::datatypes::DataType::Date32 => { - crate::kernels::gen_range_date(&args, true).map(ColumnarValue::Array) - } - _ => { - exec_err!("unsupported type for range") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayDims, - array_dims, - array, - "returns an array of the array's dimensions.", - array_dims_udf -); - -#[derive(Debug)] -pub(super) struct ArrayDims { - signature: Signature, - aliases: Vec, -} - -impl ArrayDims { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec!["array_dims".to_string(), "list_dims".to_string()], - } - } -} - -impl ScalarUDFImpl for ArrayDims { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_dims" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => { - List(Arc::new(Field::new("item", UInt64, true))) - } - _ => { - return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_dims(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArraySort, - array_sort, - array desc null_first, - "returns sorted array.", - array_sort_udf -); - -#[derive(Debug)] -pub(super) struct ArraySort { - signature: Signature, - aliases: Vec, -} - -impl ArraySort { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_sort".to_string(), "list_sort".to_string()], - } - } -} - -impl ScalarUDFImpl for ArraySort { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_sort" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_sort(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayResize, - 
array_resize, - array size value, - "returns an array with the specified size filled with the given value.", - array_resize_udf -); - -#[derive(Debug)] -pub(super) struct ArrayResize { - signature: Signature, - aliases: Vec, -} - -impl ArrayResize { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_resize".to_string(), "list_resize".to_string()], - } - } -} - -impl ScalarUDFImpl for ArrayResize { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_resize" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(field.clone())), - LargeList(field) => Ok(LargeList(field.clone())), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_resize(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Cardinality, - cardinality, - array, - "returns the total number of elements in the array.", - cardinality_udf -); - -impl Cardinality { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("cardinality")], - } - } -} - -#[derive(Debug)] -pub(super) struct Cardinality { - signature: Signature, - aliases: Vec, -} -impl ScalarUDFImpl for Cardinality { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "cardinality" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::cardinality(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayNdims, - array_ndims, - array, - "returns the number of dimensions of the array.", - array_ndims_udf -); - -#[derive(Debug)] -pub(super) struct ArrayNdims { - signature: Signature, - aliases: Vec, -} -impl ArrayNdims { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("array_ndims"), String::from("list_ndims")], - } - } -} - -impl ScalarUDFImpl for ArrayNdims { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_ndims" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_ndims(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayEmpty, - array_empty, - array, - "returns true for an empty array or false for a non-empty 
array.", - array_empty_udf -); - -#[derive(Debug)] -pub(super) struct ArrayEmpty { - signature: Signature, - aliases: Vec, -} -impl ArrayEmpty { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("empty")], - } - } -} - -impl ScalarUDFImpl for ArrayEmpty { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "empty" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, - _ => { - return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_empty(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayRepeat, - array_repeat, - element count, // arg name - "returns an array containing element `count` times.", // doc - array_repeat_udf // internal function name -); -#[derive(Debug)] -pub(super) struct ArrayRepeat { - signature: Signature, - aliases: Vec, -} - -impl ArrayRepeat { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_repeat"), String::from("list_repeat")], - } - } -} - -impl ScalarUDFImpl for ArrayRepeat { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_repeat" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_repeat(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayLength, - array_length, - array, - "returns the length of the array dimension.", - array_length_udf -); - -#[derive(Debug)] -pub(super) struct ArrayLength { - signature: Signature, - aliases: Vec, -} -impl ArrayLength { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_length"), String::from("list_length")], - } - } -} - -impl ScalarUDFImpl for ArrayLength { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_length" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_length(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Flatten, - flatten, - array, - "flattens an array of arrays into a single array.", - flatten_udf -); - -#[derive(Debug)] -pub(super) struct Flatten { - signature: Signature, - aliases: Vec, -} -impl Flatten { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: 
vec![String::from("flatten")], - } - } -} - -impl ScalarUDFImpl for Flatten { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "flatten" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - List(field) | FixedSizeList(field, _) - if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => - { - get_base_type(field.data_type()) - } - LargeList(field) if matches!(field.data_type(), LargeList(_)) => { - get_base_type(field.data_type()) - } - Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(field.clone())), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - let data_type = get_base_type(&arg_types[0])?; - Ok(data_type) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::flatten(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayReverse, - array_reverse, - array, - "reverses the order of elements in the array.", - array_reverse_udf -); - -#[derive(Debug)] -pub(super) struct ArrayReverse { - signature: Signature, - aliases: Vec, -} - -impl crate::udf::ArrayReverse { - pub fn new() -> Self { - Self { - signature: Signature::any(1, Volatility::Immutable), - aliases: vec!["array_reverse".to_string(), "list_reverse".to_string()], - } - } -} - -impl ScalarUDFImpl for crate::udf::ArrayReverse { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_reserse" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_reverse(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index ad613163c6af..d86e4fe2ab7b 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -20,15 +20,32 @@ use std::sync::Arc; use arrow::{array::ArrayRef, datatypes::DataType}; + use arrow_array::{ - Array, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, + Array, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, + UInt32Array, }; use arrow_buffer::OffsetBuffer; use arrow_schema::Field; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; + +use core::any::type_name; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +macro_rules! downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast to {}", + type_name::<$ARRAY_TYPE>() + )) + })? 
+ }}; +} +pub(crate) use downcast_arg; + pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); if !args.iter().all(|arg| { @@ -202,9 +219,9 @@ pub(crate) fn compare_element_to_list( let element_arr = Scalar::new(element_array_row); // use not_distinct so we can compare NULL if eq { - arrow::compute::kernels::cmp::not_distinct(&list_array_row, &element_arr)? + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? } else { - arrow::compute::kernels::cmp::distinct(&list_array_row, &element_arr)? + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? } } }; @@ -212,6 +229,30 @@ pub(crate) fn compare_element_to_list( Ok(res) } +/// Returns the length of each array dimension +pub(crate) fn compute_array_dims( + arr: Option, +) -> Result>>> { + let mut value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + if value.is_empty() { + return Ok(None); + } + let mut res = vec![Some(value.len() as u64)]; + + loop { + match value.data_type() { + DataType::List(..) => { + value = downcast_arg!(value, ListArray).value(0); + res.push(Some(value.len() as u64)); + } + _ => return Ok(Some(res)), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 5a6da5345d7c..425ac207c33e 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -31,6 +31,7 @@ rust-version = { workspace = true } [features] # enable core functions core_expressions = [] +crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] # enable datetime functions datetime_expressions = [] # Enable encoding by default so the doctests work. In general don't automatically enable all packages. @@ -41,6 +42,8 @@ default = [ "math_expressions", "regex_expressions", "crypto_expressions", + "string_expressions", + "unicode_expressions", ] # enable encode/decode functions encoding_expressions = ["base64", "hex"] @@ -48,7 +51,11 @@ encoding_expressions = ["base64", "hex"] math_expressions = [] # enable regular expressions regex_expressions = ["regex"] -crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] +# enable string functions +string_expressions = ["uuid"] +# enable unicode functions +unicode_expressions = ["hashbrown", "unicode-segmentation"] + [lib] name = "datafusion_functions" path = "src/lib.rs" @@ -65,12 +72,16 @@ datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } +hashbrown = { version = "0.14", features = ["raw"], optional = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } +unicode-segmentation = { version = "^1.7.1", optional = true } +uuid = { version = "1.7", features = ["v4"], optional = true } + [dev-dependencies] criterion = "0.5" rand = { workspace = true } @@ -80,15 +91,19 @@ tokio = { workspace = true, features = ["macros", "rt", "sync"] } [[bench]] harness = false name = "to_timestamp" +required-features = ["datetime_expressions"] [[bench]] harness = false name = "regx" +required-features = ["regex_expressions"] [[bench]] harness = false name = "make_date" +required-features = ["datetime_expressions"] [[bench]] harness = false name = "to_char" +required-features = ["datetime_expressions"] diff 
--git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 5831e263b4eb..da4882381e76 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -44,7 +44,7 @@ fn data(rng: &mut ThreadRng) -> StringArray { } fn regex(rng: &mut ThreadRng) -> StringArray { - let samples = vec![ + let samples = [ ".*([A-Z]{1}).*".to_string(), "^(A).*".to_string(), r#"[\p{Letter}-]+"#.to_string(), @@ -60,7 +60,7 @@ fn regex(rng: &mut ThreadRng) -> StringArray { } fn flags(rng: &mut ThreadRng) -> StringArray { - let samples = vec![Some("i".to_string()), Some("im".to_string()), None]; + let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); for _ in 0..1000 { let sample = samples.choose(rng).unwrap(); @@ -103,20 +103,6 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("regexp_match_1000", |b| { - let mut rng = rand::thread_rng(); - let data = Arc::new(data(&mut rng)) as ArrayRef; - let regex = Arc::new(regex(&mut rng)) as ArrayRef; - let flags = Arc::new(flags(&mut rng)) as ArrayRef; - - b.iter(|| { - black_box( - regexp_match::(&[data.clone(), regex.clone(), flags.clone()]) - .expect("regexp_match should work on valid values"), - ) - }) - }); - c.bench_function("regexp_replace_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 45a40f175da4..d9a153e64abc 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -64,7 +64,7 @@ fn data(rng: &mut ThreadRng) -> Date32Array { } fn patterns(rng: &mut ThreadRng) -> StringArray { - let samples = vec![ + let samples = [ "%Y:%m:%d".to_string(), "%d-%m-%Y".to_string(), "%d%m%Y".to_string(), diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs similarity index 90% rename from datafusion/sql/src/expr/arrow_cast.rs rename to datafusion/functions/src/core/arrow_cast.rs index 9a0d61f41c01..b6c1b5eb9a38 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -15,87 +15,135 @@ // specific language governing permissions and limitations // under the License. -//! Implementation of the `arrow_cast` function that allows -//! casting to arbitrary arrow types (rather than SQL types) +//! 
[`ArrowCastFunc`]: Implementation of the `arrow_cast` +use std::any::Any; use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; use datafusion_common::{ - plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, + ScalarValue, }; -use datafusion_common::plan_err; -use datafusion_expr::{Expr, ExprSchemable}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; -pub const ARROW_CAST_NAME: &str = "arrow_cast"; - -/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// Implements casting to arbitrary arrow types (rather than SQL types) +/// +/// Note that the `arrow_cast` function is somewhat special in that its +/// return depends only on the *value* of its second argument (not its type) /// -/// This function is not a [`BuiltinScalarFunction`] because the -/// return type of [`BuiltinScalarFunction`] depends only on the -/// *types* of the arguments. However, the type of `arrow_type` depends on -/// the *value* of its second argument. +/// It is implemented by calling the same underlying arrow `cast` kernel as +/// normal SQL casts. /// -/// Use the `cast` function to cast to SQL type (which is then mapped -/// to the corresponding arrow type). For example to cast to `int` -/// (which is then mapped to the arrow type `Int32`) +/// For example to cast to `int` using SQL (which is then mapped to the arrow +/// type `Int32`) /// /// ```sql /// select cast(column_x as int) ... /// ``` /// -/// Use the `arrow_cast` functiont to cast to a specfic arrow type +/// You can use the `arrow_cast` functiont to cast to a specific arrow type /// /// For example /// ```sql /// select arrow_cast(column_x, 'Float64') /// ``` -/// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction -pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { +#[derive(Debug)] +pub(super) struct ArrowCastFunc { + signature: Signature, +} + +impl ArrowCastFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ArrowCastFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "arrow_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default + // implementation + internal_err!("arrow_cast should return type from exprs") + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + _schema: &dyn ExprSchema, + _arg_types: &[DataType], + ) -> Result { + data_type_from_args(args) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + internal_err!("arrow_cast should have been simplified to cast") + } + + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + // convert this into a real cast + let target_type = data_type_from_args(&args)?; + // remove second (type) argument + args.pop().unwrap(); + let arg = args.pop().unwrap(); + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + arg + } else { + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: 
Box::new(arg), + data_type: target_type, + }) + }; + // return the newly written argument to DataFusion + Ok(ExprSimplifyResult::Simplified(new_expr)) + } +} + +/// Returns the requested type from the arguments +fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } - let arg1 = args.pop().unwrap(); - let arg0 = args.pop().unwrap(); - - // arg1 must be a string - let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { - v - } else { + let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { return plan_err!( - "arrow_cast requires its second argument to be a constant string, got {arg1}" + "arrow_cast requires its second argument to be a constant string, got {:?}", + &args[1] ); }; - - // do the actual lookup to the appropriate data type - let data_type = parse_data_type(&data_type_string)?; - - arg0.cast_to(&data_type, schema) + parse_data_type(val) } /// Parses `str` into a `DataType`. /// -/// `parse_data_type` is the the reverse of [`DataType`]'s `Display` +/// `parse_data_type` is the reverse of [`DataType`]'s `Display` /// impl, and maintains the invariant that /// `parse_data_type(data_type.to_string()) == data_type` /// -/// Example: -/// ``` -/// # use datafusion_sql::parse_data_type; -/// # use arrow_schema::DataType; -/// let display_value = "Int32"; -/// -/// // "Int32" is the Display value of `DataType` -/// assert_eq!(display_value, &format!("{}", DataType::Int32)); -/// -/// // parse_data_type coverts "Int32" back to `DataType`: -/// let data_type = parse_data_type(display_value).unwrap(); -/// assert_eq!(data_type, DataType::Int32); -/// ``` -/// /// Remove if added to arrow: -pub fn parse_data_type(val: &str) -> Result { +fn parse_data_type(val: &str) -> Result { Parser::new(val).parse() } @@ -647,8 +695,6 @@ impl Display for Token { #[cfg(test)] mod test { - use arrow_schema::{IntervalUnit, TimeUnit}; - use super::*; #[test] @@ -844,7 +890,6 @@ mod test { assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); } } - println!(" Ok"); } } } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 73cc4d18bf9f..85d2410251c5 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -17,27 +17,33 @@ //! "core" DataFusion functions +mod arrow_cast; mod arrowtypeof; mod getfield; +mod named_struct; mod nullif; mod nvl; mod nvl2; mod r#struct; // create UDFs +make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); make_udf_function!(nvl::NVLFunc, NVL, nvl); make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof); make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); +make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), + (arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. 
This can be used to cast to a specific `arrow_type`."), (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"), (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."), (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), (r#struct, args, "Returns a struct with the given arguments"), + (named_struct, args, "Returns a struct with the given names and arguments pairs"), (get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct") ); diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs new file mode 100644 index 000000000000..327a41baa741 --- /dev/null +++ b/datafusion/functions/src/core/named_struct.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::StructArray; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +/// put values in a struct array. +fn named_struct_expr(args: &[ColumnarValue]) -> Result { + // do not accept 0 arguments. + if args.is_empty() { + return exec_err!( + "named_struct requires at least one pair of arguments, got 0 instead" + ); + } + + if args.len() % 2 != 0 { + return exec_err!( + "named_struct requires an even number of arguments, got {} instead", + args.len() + ); + } + + let (names, values): (Vec<_>, Vec<_>) = args + .chunks_exact(2) + .enumerate() + .map(|(i, chunk)| { + + let name_column = &chunk[0]; + + let name = match name_column { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, + _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2) + }; + + Ok((name, chunk[1].clone())) + }) + .collect::>>()? 
+ .into_iter() + .unzip(); + + let arrays = ColumnarValue::values_to_arrays(&values)?; + + let fields = names + .into_iter() + .zip(arrays) + .map(|(name, value)| { + ( + Arc::new(Field::new(name, value.data_type().clone(), true)), + value, + ) + }) + .collect::>(); + + Ok(ColumnarValue::Array(Arc::new(StructArray::from(fields)))) +} + +#[derive(Debug)] +pub(super) struct NamedStructFunc { + signature: Signature, +} + +impl NamedStructFunc { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for NamedStructFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "named_struct" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!( + "named_struct: return_type called instead of return_type_from_exprs" + ) + } + + fn return_type_from_exprs( + &self, + args: &[datafusion_expr::Expr], + schema: &dyn datafusion_common::ExprSchema, + _arg_types: &[DataType], + ) -> Result { + // do not accept 0 arguments. + if args.is_empty() { + return exec_err!( + "named_struct requires at least one pair of arguments, got 0 instead" + ); + } + + if args.len() % 2 != 0 { + return exec_err!( + "named_struct requires an even number of arguments, got {} instead", + args.len() + ); + } + + let return_fields = args + .chunks_exact(2) + .enumerate() + .map(|(i, chunk)| { + let name = &chunk[0]; + let value = &chunk[1]; + + if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name { + Ok(Field::new(name, value.get_type(schema)?, true)) + } else { + exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2) + } + }) + .collect::>>()?; + Ok(DataType::Struct(Fields::from(return_fields))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + named_struct_expr(args) + } +} diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 2a8622f0a1ec..ac300e0abde3 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -47,17 +47,10 @@ fn array_struct(args: &[ArrayRef]) -> Result { Ok(Arc::new(StructArray::from(vec))) } + /// put values in a struct array. 
fn struct_expr(args: &[ColumnarValue]) -> Result { - let arrays = args - .iter() - .map(|x| { - Ok(match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array()?.clone(), - }) - }) - .collect::>>()?; + let arrays = ColumnarValue::values_to_arrays(args)?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } #[derive(Debug)] diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index 007ffd35ca3a..f0689ffd64e9 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -22,8 +22,9 @@ use arrow::array::{ }; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; +use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::LocalResult::Single; -use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; +use chrono::{DateTime, TimeZone, Utc}; use itertools::Either; use datafusion_common::cast::as_generic_string_array; @@ -84,12 +85,15 @@ pub(crate) fn string_to_datetime_formatted( )) }; + let mut parsed = Parsed::new(); + parse(&mut parsed, s, StrftimeItems::new(format)).map_err(|e| err(&e.to_string()))?; + // attempt to parse the string assuming it has a timezone - let dt = DateTime::parse_from_str(s, format); + let dt = parsed.to_datetime(); if let Err(e) = &dt { // no timezone or other failure, try without a timezone - let ndt = NaiveDateTime::parse_from_str(s, format); + let ndt = parsed.to_naive_datetime_with_offset(0); if let Err(e) = &ndt { return Err(err(&e.to_string())); } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 1f00f5bc3137..b41f7e13cff2 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -18,18 +18,19 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::types::ArrowTemporalType; -use arrow::array::{Array, ArrayRef, ArrowNumericType, Float64Array, PrimitiveArray}; -use arrow::compute::cast; -use arrow::compute::kernels::temporal; -use arrow::datatypes::DataType::{Date32, Date64, Float64, Timestamp, Utf8}; +use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::compute::{binary, cast, date_part, DatePart}; +use arrow::datatypes::DataType::{ + Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, +}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_timestamp_microsecond_array, - as_timestamp_millisecond_array, as_timestamp_nanosecond_array, - as_timestamp_second_array, + as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, + as_timestamp_microsecond_array, as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, }; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; @@ -70,6 +71,10 @@ impl DatePartFunc { ]), Exact(vec![Utf8, Date64]), Exact(vec![Utf8, Date32]), + Exact(vec![Utf8, Time32(Second)]), + Exact(vec![Utf8, Time32(Millisecond)]), + Exact(vec![Utf8, Time64(Microsecond)]), + Exact(vec![Utf8, Time64(Nanosecond)]), ], Volatility::Immutable, ), @@ -78,46 +83,6 @@ impl DatePartFunc { } } -macro_rules! 
extract_date_part { - ($ARRAY: expr, $FN:expr) => { - match $ARRAY.data_type() { - DataType::Date32 => { - let array = as_date32_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - DataType::Date64 => { - let array = as_date64_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - DataType::Timestamp(time_unit, _) => match time_unit { - TimeUnit::Second => { - let array = as_timestamp_second_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Millisecond => { - let array = as_timestamp_millisecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Microsecond => { - let array = as_timestamp_microsecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Nanosecond => { - let array = as_timestamp_nanosecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - }, - datatype => exec_err!("Extract does not support datatype {:?}", datatype), - } - }; -} - impl ScalarUDFImpl for DatePartFunc { fn as_any(&self) -> &dyn Any { self @@ -139,16 +104,15 @@ impl ScalarUDFImpl for DatePartFunc { if args.len() != 2 { return exec_err!("Expected two arguments in DATE_PART"); } - let (date_part, array) = (&args[0], &args[1]); + let (part, array) = (&args[0], &args[1]); - let date_part = - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = date_part { - v - } else { - return exec_err!( - "First argument of `DATE_PART` must be non-null scalar Utf8" - ); - }; + let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { + v + } else { + return exec_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ); + }; let is_scalar = matches!(array, ColumnarValue::Scalar(_)); @@ -157,28 +121,28 @@ impl ScalarUDFImpl for DatePartFunc { ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; - let arr = match date_part.to_lowercase().as_str() { - "year" => extract_date_part!(&array, temporal::year), - "quarter" => extract_date_part!(&array, temporal::quarter), - "month" => extract_date_part!(&array, temporal::month), - "week" => extract_date_part!(&array, temporal::week), - "day" => extract_date_part!(&array, temporal::day), - "doy" => extract_date_part!(&array, temporal::doy), - "dow" => extract_date_part!(&array, temporal::num_days_from_sunday), - "hour" => extract_date_part!(&array, temporal::hour), - "minute" => extract_date_part!(&array, temporal::minute), - "second" => extract_date_part!(&array, seconds), - "millisecond" => extract_date_part!(&array, millis), - "microsecond" => extract_date_part!(&array, micros), - "nanosecond" => extract_date_part!(&array, nanos), - "epoch" => extract_date_part!(&array, epoch), - _ => exec_err!("Date part '{date_part}' not supported"), - }?; + let arr = match part.to_lowercase().as_str() { + "year" => date_part_f64(array.as_ref(), DatePart::Year)?, + "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, + "month" => date_part_f64(array.as_ref(), DatePart::Month)?, + "week" => date_part_f64(array.as_ref(), DatePart::Week)?, + "day" => date_part_f64(array.as_ref(), DatePart::Day)?, + "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, + "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, + "hour" => date_part_f64(array.as_ref(), DatePart::Hour)?, + "minute" => 
date_part_f64(array.as_ref(), DatePart::Minute)?, + "second" => seconds(array.as_ref(), Second)?, + "millisecond" => seconds(array.as_ref(), Millisecond)?, + "microsecond" => seconds(array.as_ref(), Microsecond)?, + "nanosecond" => seconds(array.as_ref(), Nanosecond)?, + "epoch" => epoch(array.as_ref())?, + _ => return exec_err!("Date part '{part}' not supported"), + }; Ok(if is_scalar { - ColumnarValue::Scalar(ScalarValue::try_from_array(&arr?, 0)?) + ColumnarValue::Scalar(ScalarValue::try_from_array(arr.as_ref(), 0)?) } else { - ColumnarValue::Array(arr?) + ColumnarValue::Array(arr) }) } @@ -187,83 +151,60 @@ impl ScalarUDFImpl for DatePartFunc { } } -fn to_ticks(array: &PrimitiveArray, frac: i32) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - let zipped = temporal::second(array)? - .values() - .iter() - .zip(temporal::nanosecond(array)?.values().iter()) - .map(|o| (*o.0 as f64 + (*o.1 as f64) / 1_000_000_000.0) * (frac as f64)) - .collect::>(); - - Ok(Float64Array::from(zipped)) +/// Invoke [`date_part`] and cast the result to Float64 +fn date_part_f64(array: &dyn Array, part: DatePart) -> Result { + Ok(cast(date_part(array, part)?.as_ref(), &Float64)?) } -fn seconds(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1) -} - -fn millis(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000) -} - -fn micros(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000_000) +/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the +/// result to a total number of seconds, milliseconds, microseconds or +/// nanoseconds +fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { + let sf = match unit { + Second => 1_f64, + Millisecond => 1_000_f64, + Microsecond => 1_000_000_f64, + Nanosecond => 1_000_000_000_f64, + }; + let secs = date_part(array, DatePart::Second)?; + // This assumes array is primitive and not a dictionary + let secs = as_int32_array(secs.as_ref())?; + let subsecs = date_part(array, DatePart::Nanosecond)?; + let subsecs = as_int32_array(subsecs.as_ref())?; + + let r: Float64Array = binary(secs, subsecs, |secs, subsecs| { + (secs as f64 + (subsecs as f64 / 1_000_000_000_f64)) * sf + })?; + Ok(Arc::new(r)) } -fn nanos(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000_000_000) -} +fn epoch(array: &dyn Array) -> Result { + const SECONDS_IN_A_DAY: f64 = 86400_f64; -fn epoch(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - let b = match array.data_type() { - Timestamp(tu, _) => { - let scale = match tu { - Second => 1, - Millisecond => 1_000, - Microsecond => 1_000_000, - Nanosecond => 1_000_000_000, - } as f64; - array.unary(|n| { - let n: i64 = n.into(); - n as f64 / scale - }) + let f: Float64Array = match array.data_type() { + Timestamp(Second, _) => as_timestamp_second_array(array)?.unary(|x| x as f64), + Timestamp(Millisecond, _) => { + as_timestamp_millisecond_array(array)?.unary(|x| x as f64 / 1_000_f64) + } + Timestamp(Microsecond, _) => { + as_timestamp_microsecond_array(array)?.unary(|x| x as f64 / 1_000_000_f64) + } + Timestamp(Nanosecond, _) => { + as_timestamp_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) + } + Date32 => as_date32_array(array)?.unary(|x| x as f64 * 
SECONDS_IN_A_DAY), + Date64 => as_date64_array(array)?.unary(|x| x as f64 / 1_000_f64), + Time32(Second) => as_time32_second_array(array)?.unary(|x| x as f64), + Time32(Millisecond) => { + as_time32_millisecond_array(array)?.unary(|x| x as f64 / 1_000_f64) + } + Time64(Microsecond) => { + as_time64_microsecond_array(array)?.unary(|x| x as f64 / 1_000_000_f64) } - Date32 => { - let seconds_in_a_day = 86400_f64; - array.unary(|n| { - let n: i64 = n.into(); - n as f64 * seconds_in_a_day - }) + Time64(Nanosecond) => { + as_time64_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) } - Date64 => array.unary(|n| { - let n: i64 = n.into(); - n as f64 / 1_000_f64 - }), - _ => return exec_err!("Can not convert {:?} to epoch", array.data_type()), + d => return exec_err!("Can not convert {d:?} to epoch"), }; - Ok(b) + Ok(Arc::new(f)) } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 3ca098b1f99b..ef5c45a5ad9c 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::sync::Arc; use arrow::array::cast::AsArray; -use arrow::array::{Array, ArrayRef, StringArray}; +use arrow::array::{new_null_array, Array, ArrayRef, StringArray}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{ Date32, Date64, Duration, Time32, Time64, Timestamp, Utf8, @@ -109,7 +109,6 @@ impl ScalarUDFImpl for ToCharFunc { } match &args[1] { - // null format, use default formats ColumnarValue::Scalar(ScalarValue::Utf8(None)) | ColumnarValue::Scalar(ScalarValue::Null) => { _to_char_scalar(args[0].clone(), None) @@ -175,6 +174,18 @@ fn _to_char_scalar( let data_type = &expression.data_type(); let is_scalar_expression = matches!(&expression, ColumnarValue::Scalar(_)); let array = expression.into_array(1)?; + + if format.is_none() { + if is_scalar_expression { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } else { + return Ok(ColumnarValue::Array(new_null_array( + &DataType::Utf8, + array.len(), + ))); + } + } + let format_options = match _build_format_options(data_type, format) { Ok(value) => value, Err(value) => return value, @@ -202,7 +213,7 @@ fn _to_char_scalar( fn _to_char_array(args: &[ColumnarValue]) -> Result { let arrays = ColumnarValue::values_to_arrays(args)?; - let mut results: Vec = vec![]; + let mut results: Vec> = vec![]; let format_array = arrays[1].as_string::(); let data_type = arrays[0].data_type(); @@ -212,6 +223,10 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { } else { Some(format_array.value(idx)) }; + if format.is_none() { + results.push(None); + continue; + } let format_options = match _build_format_options(data_type, format) { Ok(value) => value, Err(value) => return value, @@ -221,7 +236,7 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { let formatter = ArrayFormatter::try_new(arrays[0].as_ref(), &format_options)?; let result = formatter.value(idx).try_to_string(); match result { - Ok(value) => results.push(value), + Ok(value) => results.push(Some(value)), Err(e) => return exec_err!("{}", e), } } @@ -230,9 +245,12 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { ColumnarValue::Array(_) => Ok(ColumnarValue::Array(Arc::new(StringArray::from( results, )) as ArrayRef)), - ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - results.first().unwrap().to_string(), - )))), + ColumnarValue::Scalar(_) => match results.first().unwrap() { + Some(value) => 
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + value.to_string(), + )))), + None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }, } } diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 3a2eab8e5f05..2a00839dc532 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -84,6 +84,10 @@ use log::debug; #[macro_use] pub mod macros; +#[cfg(feature = "string_expressions")] +pub mod string; +make_stub_package!(string, "string_expressions"); + /// Core datafusion expressions /// Enabled via feature flag `core_expressions` #[cfg(feature = "core_expressions")] @@ -120,6 +124,12 @@ make_stub_package!(regex, "regex_expressions"); pub mod crypto; make_stub_package!(crypto, "crypto_expressions"); +#[cfg(feature = "unicode_expressions")] +pub mod unicode; +make_stub_package!(unicode, "unicode_expressions"); + +mod utils; + /// Fluent-style API for creating `Expr`s pub mod expr_fn { #[cfg(feature = "core_expressions")] @@ -134,6 +144,10 @@ pub mod expr_fn { pub use super::math::expr_fn::*; #[cfg(feature = "regex_expressions")] pub use super::regex::expr_fn::*; + #[cfg(feature = "string_expressions")] + pub use super::string::expr_fn::*; + #[cfg(feature = "unicode_expressions")] + pub use super::unicode::expr_fn::*; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -144,7 +158,9 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { .chain(encoding::functions()) .chain(math::functions()) .chain(regex::functions()) - .chain(crypto::functions()); + .chain(crypto::functions()) + .chain(unicode::functions()) + .chain(string::functions()); all_functions.try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index e735523df621..4907d74fe941 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -156,15 +156,18 @@ macro_rules! downcast_arg { /// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument +/// $MONOTONIC_FUNC: the monotonicity of the function macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident) => { + ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $MONOTONICITY:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + use datafusion_expr::{ + ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, + }; use std::any::Any; use std::sync::Arc; @@ -208,6 +211,10 @@ macro_rules! make_math_unary_udf { } } + fn monotonicity(&self) -> Result> { + Ok($MONOTONICITY) + } + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; @@ -243,3 +250,31 @@ macro_rules! make_math_unary_udf { } }; } + +#[macro_export] +macro_rules! 
make_function_inputs2 { + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE1>() + }}; +} diff --git a/datafusion/functions/src/math/atan2.rs b/datafusion/functions/src/math/atan2.rs new file mode 100644 index 000000000000..b090c6c454fd --- /dev/null +++ b/datafusion/functions/src/math/atan2.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `atan2()`. 
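
The `make_function_inputs2!` macro above just zips two downcast arrays and applies a binary closure, propagating nulls. A minimal hand-expanded sketch of that pattern for `Float64Array` with `f64::atan2` (the helper name below is illustrative, not part of this patch):

```rust
use arrow::array::{Array, Float64Array};

/// Hand-expanded equivalent of make_function_inputs2! for Float64Array + atan2.
fn atan2_f64(y: &Float64Array, x: &Float64Array) -> Float64Array {
    y.iter()
        .zip(x.iter())
        .map(|(y, x)| match (y, x) {
            (Some(y), Some(x)) => Some(y.atan2(x)),
            _ => None, // a null on either side yields a null result
        })
        .collect()
}

fn main() {
    let y = Float64Array::from(vec![Some(2.0), None]);
    let x = Float64Array::from(vec![Some(1.0), Some(4.0)]);
    let result = atan2_f64(&y, &x);
    assert_eq!(result.value(0), 2.0_f64.atan2(1.0));
    assert!(result.is_null(1));
}
```
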
+ +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use crate::make_function_inputs2; +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub(super) struct Atan2 { + signature: Signature, +} + +impl Atan2 { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for Atan2 { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "atan2" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use self::DataType::*; + match &arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(atan2, vec![])(args) + } +} + +/// Atan2 SQL function +pub fn atan2(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float64Array, + { f64::atan2 } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float32Array, + { f32::atan2 } + )) as ArrayRef), + + other => exec_err!("Unsupported data type {other:?} for function atan2"), + } +} + +#[cfg(test)] +mod test { + use super::*; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + + #[test] + fn test_atan2_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("failed to initialize function atan2"); + let floats = + as_float64_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); + } + + #[test] + fn test_atan2_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("failed to initialize function atan2"); + let floats = + as_float32_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 27deb7d68427..2ee1fffa1625 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -18,16 +18,27 @@ //! 
"math" DataFusion functions mod abs; +mod atan2; mod nans; // Create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(atan2::Atan2, ATAN2, atan2); -make_math_unary_udf!(TanhFunc, TANH, tanh, tanh); -make_math_unary_udf!(AcosFunc, ACOS, acos, acos); -make_math_unary_udf!(AsinFunc, ASIN, asin, asin); -make_math_unary_udf!(TanFunc, TAN, tan, tan); +make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); +make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); +make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); + +make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None); +make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None); +make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None); +make_math_unary_udf!(TanFunc, TAN, tan, tan, None); + +make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); +make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); +make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); +make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( @@ -37,6 +48,9 @@ export_functions!( "returns true if a given number is +NaN or -NaN otherwise returns false" ), (abs, num, "returns the absolute value of a given number"), + (log2, num, "base 2 logarithm of a number"), + (log10, num, "base 10 logarithm of a number"), + (ln, num, "natural logarithm (base e) of a number"), ( acos, num, @@ -48,5 +62,10 @@ export_functions!( "returns the arc sine or inverse sine of a number" ), (tan, num, "returns the tangent of a number"), - (tanh, num, "returns the hyperbolic tangent of a number") + (tanh, num, "returns the hyperbolic tangent of a number"), + (atanh, num, "returns inverse hyperbolic tangent"), + (asinh, num, "returns inverse hyperbolic sine"), + (acosh, num, "returns inverse hyperbolic cosine"), + (atan, num, "returns inverse tangent"), + (atan2, y x, "returns inverse tangent of a division given in the argument") ); diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs new file mode 100644 index 000000000000..9a07f4c19cf1 --- /dev/null +++ b/datafusion/functions/src/string/ascii.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::utils::make_scalar_function; +use arrow::array::Int32Array; +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +/// Returns the numeric code of the first character of the argument. +/// ascii('x') = 120 +pub fn ascii(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + let mut chars = string.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct AsciiFunc { + signature: Signature, +} +impl AsciiFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for AsciiFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ascii" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(Int32) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(ascii::, vec![])(args), + DataType::LargeUtf8 => { + return make_scalar_function(ascii::, vec![])(args); + } + _ => internal_err!("Unsupported data type"), + } + } +} diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs new file mode 100644 index 000000000000..6a200471d42d --- /dev/null +++ b/datafusion/functions/src/string/bit_length.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
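
For reference, the `ascii()` kernel added above maps every string to the numeric code of its first character (0 for the empty string, nulls preserved). A standalone sketch of that mapping; the helper name is illustrative:

```rust
use arrow::array::{Array, Int32Array, StringArray};

fn ascii_codes(input: &StringArray) -> Int32Array {
    input
        .iter()
        .map(|s| s.map(|s| s.chars().next().map_or(0, |c| c as i32)))
        .collect()
}

fn main() {
    let strings = StringArray::from(vec![Some("x"), Some(""), None]);
    let codes = ascii_codes(&strings);
    assert_eq!(codes.value(0), 120); // 'x' has code point 120
    assert_eq!(codes.value(1), 0);   // empty string maps to 0
    assert!(codes.is_null(2));       // null stays null
}
```
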
+ +use std::any::Any; + +use arrow::compute::kernels::length::bit_length; +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::utf8_to_int_type; + +#[derive(Debug)] +pub(super) struct BitLengthFunc { + signature: Signature, +} + +impl BitLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bit_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "bit_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "bit_length function requires 1 argument, got {}", + args.len() + ); + } + + match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| (x.len() * 8) as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), + )), + _ => unreachable!(), + }, + } + } +} diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs new file mode 100644 index 000000000000..573a23d07021 --- /dev/null +++ b/datafusion/functions/src/string/btrim.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. 
+/// btrim('xyxtrimyyx', 'xyz') = 'trim' +fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +#[derive(Debug)] +pub(super) struct BTrimFunc { + signature: Signature, + aliases: Vec, +} + +impl BTrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + aliases: vec![String::from("trim")], + } + } +} + +impl ScalarUDFImpl for BTrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "btrim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "btrim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(btrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(btrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function btrim"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs new file mode 100644 index 000000000000..d1f8dc398a2b --- /dev/null +++ b/datafusion/functions/src/string/chr.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::array::StringArray; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; +use arrow::datatypes::DataType::Utf8; + +use datafusion_common::cast::as_int64_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::make_scalar_function; + +/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. 
+/// chr(65) = 'A' +pub fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| { + if integer == 0 { + exec_err!("null character not permitted.") + } else { + match core::char::from_u32(integer as u32) { + Some(integer) => Ok(integer.to_string()), + None => { + exec_err!("requested character too large for encoding.") + } + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct ChrFunc { + signature: Signature, +} + +impl ChrFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ChrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "chr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(chr, vec![])(args) + } +} diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs new file mode 100644 index 000000000000..276aad121df2 --- /dev/null +++ b/datafusion/functions/src/string/common.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
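
The `chr()` kernel above is the inverse mapping: a Unicode code point to a one-character string, rejecting 0 and out-of-range values. A dependency-free sketch of that per-value logic (the helper name and error type are illustrative):

```rust
fn chr_one(code: i64) -> Result<String, String> {
    if code == 0 {
        return Err("null character not permitted.".to_string());
    }
    char::from_u32(code as u32)
        .map(|c| c.to_string())
        .ok_or_else(|| "requested character too large for encoding.".to_string())
}

fn main() {
    assert_eq!(chr_one(65).unwrap(), "A");
    assert_eq!(chr_one(128640).unwrap(), "🚀"); // U+1F680
    assert!(chr_one(0).is_err());
    assert!(chr_one(0x110000).is_err()); // beyond the Unicode range
}
```
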
+ +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::Result; +use datafusion_common::{exec_err, ScalarValue}; +use datafusion_expr::ColumnarValue; + +pub(crate) enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +pub(crate) fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + + match args.len() { + 1 => { + let result = string_array + .iter() + .map(|string| string.map(|string: &str| func(string, " "))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let characters_array = as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (Some(string), Some(characters)) => Some(func(string, characters)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." 
+ ) + } + } +} + +/// applies a unary expression to `args[0]` that is expected to be downcastable to +/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) +/// # Errors +/// This function errors when: +/// * the number of arguments is not 1 +/// * the first argument is not castable to a `GenericStringArray` +pub(crate) fn unary_string_function<'a, T, O, F, R>( + args: &[&'a dyn Array], + op: F, + name: &str, +) -> Result> +where + R: AsRef, + O: OffsetSizeTrait, + T: OffsetSizeTrait, + F: Fn(&'a str) -> R, +{ + if args.len() != 1 { + return exec_err!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + name + ); + } + + let string_array = as_generic_string_array::(args[0])?; + + // first map is the iterator, second is for the `Option<_>` + Ok(string_array.iter().map(|string| string.map(&op)).collect()) +} + +pub(crate) fn handle<'a, F, R>( + args: &'a [ColumnarValue], + op: F, + name: &str, +) -> Result +where + R: AsRef, + F: Fn(&'a str) -> R, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i32, + i32, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i64, + i64, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + } +} diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs new file mode 100644 index 000000000000..8f497e73e393 --- /dev/null +++ b/datafusion/functions/src/string/levenshtein.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use crate::utils::{make_scalar_function, utf8_to_int_type}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct LevenshteinFunc { + signature: Signature, +} + +impl LevenshteinFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LevenshteinFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "levenshtein") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(levenshtein::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(levenshtein::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function levenshtein") + } + } + } +} + +///Returns the Levenshtein distance between the two given strings. +/// LEVENSHTEIN('kitten', 'sitting') = 3 +pub fn levenshtein(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!( + "levenshtein function requires two arguments, got {}", + args.len() + ); + } + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i64) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." 
+ ) + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Int32Array, StringArray}; + + use datafusion_common::cast::as_int32_array; + + use super::*; + + #[test] + fn to_levenshtein() -> Result<()> { + let string1_array = + Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); + let string2_array = + Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); + let res = levenshtein::(&[string1_array, string2_array]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function levenshtein"); + let expected = Int32Array::from(vec![2, 3, 2, 3]); + assert_eq!(&expected, result); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs new file mode 100644 index 000000000000..327772bd808d --- /dev/null +++ b/datafusion/functions/src/string/lower.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; + +#[derive(Debug)] +pub(super) struct LowerFunc { + signature: Signature, +} + +impl LowerFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LowerFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lower" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "lower") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + handle(args, |string| string.to_lowercase(), "lower") + } +} diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs new file mode 100644 index 000000000000..e6926e5bd56e --- /dev/null +++ b/datafusion/functions/src/string/ltrim.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// ltrim('zzzytest', 'xyz') = 'test' +fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +#[derive(Debug)] +pub(super) struct LtrimFunc { + signature: Signature, +} + +impl LtrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LtrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ltrim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "ltrim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(ltrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(ltrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function ltrim"), + } + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs new file mode 100644 index 000000000000..81639c45f7ff --- /dev/null +++ b/datafusion/functions/src/string/mod.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
"string" DataFusion functions + +use std::sync::Arc; + +use datafusion_expr::ScalarUDF; + +mod ascii; +mod bit_length; +mod btrim; +mod chr; +mod common; +mod levenshtein; +mod lower; +mod ltrim; +mod octet_length; +mod overlay; +mod repeat; +mod replace; +mod rtrim; +mod split_part; +mod starts_with; +mod to_hex; +mod upper; +mod uuid; + +// create UDFs +make_udf_function!(ascii::AsciiFunc, ASCII, ascii); +make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); +make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); +make_udf_function!(chr::ChrFunc, CHR, chr); +make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); +make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); +make_udf_function!(lower::LowerFunc, LOWER, lower); +make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); +make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay); +make_udf_function!(repeat::RepeatFunc, REPEAT, repeat); +make_udf_function!(replace::ReplaceFunc, REPLACE, replace); +make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); +make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); +make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); +make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); +make_udf_function!(upper::UpperFunc, UPPER, upper); +make_udf_function!(uuid::UuidFunc, UUID, uuid); + +pub mod expr_fn { + use datafusion_expr::Expr; + + #[doc = "Returns the numeric code of the first character of the argument."] + pub fn ascii(arg1: Expr) -> Expr { + super::ascii().call(vec![arg1]) + } + + #[doc = "Returns the number of bits in the `string`"] + pub fn bit_length(arg: Expr) -> Expr { + super::bit_length().call(vec![arg]) + } + + #[doc = "Removes all characters, spaces by default, from both sides of a string"] + pub fn btrim(args: Vec) -> Expr { + super::btrim().call(args) + } + + #[doc = "Converts the Unicode code point to a UTF8 character"] + pub fn chr(arg: Expr) -> Expr { + super::chr().call(vec![arg]) + } + + #[doc = "Returns the Levenshtein distance between the two given strings"] + pub fn levenshtein(arg1: Expr, arg2: Expr) -> Expr { + super::levenshtein().call(vec![arg1, arg2]) + } + + #[doc = "Converts a string to lowercase."] + pub fn lower(arg1: Expr) -> Expr { + super::lower().call(vec![arg1]) + } + + #[doc = "Removes all characters, spaces by default, from the beginning of a string"] + pub fn ltrim(args: Vec) -> Expr { + super::ltrim().call(args) + } + + #[doc = "returns the number of bytes of a string"] + pub fn octet_length(args: Vec) -> Expr { + super::octet_length().call(args) + } + + #[doc = "replace the substring of string that starts at the start'th character and extends for count characters with new substring"] + pub fn overlay(args: Vec) -> Expr { + super::overlay().call(args) + } + + #[doc = "Repeats the `string` to `n` times"] + pub fn repeat(string: Expr, n: Expr) -> Expr { + super::repeat().call(vec![string, n]) + } + + #[doc = "Replaces all occurrences of `from` with `to` in the `string`"] + pub fn replace(string: Expr, from: Expr, to: Expr) -> Expr { + super::replace().call(vec![string, from, to]) + } + + #[doc = "Removes all characters, spaces by default, from the end of a string"] + pub fn rtrim(args: Vec) -> Expr { + super::rtrim().call(args) + } + + #[doc = "Splits a string based on a delimiter and picks out the desired field based on the index."] + pub fn split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr { + super::split_part().call(vec![string, delimiter, index]) + } + + 
#[doc = "Returns true if string starts with prefix."] + pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr { + super::starts_with().call(vec![arg1, arg2]) + } + + #[doc = "Converts an integer to a hexadecimal string."] + pub fn to_hex(arg1: Expr) -> Expr { + super::to_hex().call(vec![arg1]) + } + + #[doc = "Removes all characters, spaces by default, from both sides of a string"] + pub fn trim(args: Vec) -> Expr { + super::btrim().call(args) + } + + #[doc = "Converts a string to uppercase."] + pub fn upper(arg1: Expr) -> Expr { + super::upper().call(vec![arg1]) + } + + #[doc = "returns uuid v4 as a string value"] + pub fn uuid() -> Expr { + super::uuid().call(vec![]) + } +} + +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![ + ascii(), + bit_length(), + btrim(), + chr(), + levenshtein(), + lower(), + ltrim(), + octet_length(), + overlay(), + repeat(), + replace(), + rtrim(), + split_part(), + starts_with(), + to_hex(), + upper(), + uuid(), + ] +} diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs new file mode 100644 index 000000000000..639bf6cb48a9 --- /dev/null +++ b/datafusion/functions/src/string/octet_length.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
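
The `functions()` list above is what `register_all` chains into the registry. A small sketch that only enumerates it; the path assumes the `string_expressions` feature is enabled:

```rust
use datafusion_functions::string;

fn main() {
    for udf in string::functions() {
        // e.g. "btrim" with alias "trim"
        println!("{} (aliases: {:?})", udf.name(), udf.aliases());
    }
}
```
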
+ +use std::any::Any; + +use arrow::compute::kernels::length::length; +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::utf8_to_int_type; + +#[derive(Debug)] +pub(super) struct OctetLengthFunc { + signature: Signature, +} + +impl OctetLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for OctetLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "octet_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "octet_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "octet_length function requires 1 argument, got {}", + args.len() + ); + } + + match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| x.len() as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), + )), + _ => unreachable!(), + }, + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{Array, Int32Array, StringArray}; + use arrow::datatypes::DataType::Int32; + + use datafusion_common::ScalarValue; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::octet_length::OctetLengthFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], + exec_err!( + "The OCTET_LENGTH function can only accept strings, but got Int32." 
+ ), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Array(Arc::new(StringArray::from(vec![ + String::from("chars"), + String::from("chars2"), + ])))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))) + ], + exec_err!("octet_length function requires 1 argument, got 2"), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("chars") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("josé") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("") + )))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs new file mode 100644 index 000000000000..8b9cc03afc4d --- /dev/null +++ b/datafusion/functions/src/string/overlay.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
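Editor's note: the `octet_length` kernel above counts bytes, not characters, which is why the "josé" test expects 5. A standalone illustration of the distinction:

    fn main() {
        let s = "josé";
        assert_eq!(s.len(), 5);           // UTF-8 bytes: what octet_length reports ('é' is 2 bytes)
        assert_eq!(s.chars().count(), 4); // Unicode characters: what character_length reports
    }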
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct OverlayFunc { + signature: Signature, +} + +impl OverlayFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for OverlayFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "overlay" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "overlay") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(overlay::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(overlay::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function overlay"), + } + } +} + +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead +pub fn overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + 
let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Int64Array, StringArray}; + + use super::*; + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs new file mode 100644 index 000000000000..f4319af0a5c4 --- /dev/null +++ b/datafusion/functions/src/string/repeat.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
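Editor's note: to make the index arithmetic in the `overlay` kernel above concrete, here is a single-value sketch of the four-argument branch. Like the kernel, it slices by byte index, so it assumes the split points fall on character boundaries:

    // overlay('Txxxxas' placing 'hom' from 2 for 4) = 'Thomas'
    fn overlay_one(s: &str, repl: &str, start: i64, len: i64) -> String {
        let s_len = s.chars().count() as i64;
        let len = len.min(s_len); // the kernel clamps the replace length to the string length
        let mut out = String::new();
        // SQL positions are 1-based, so the kept prefix ends at start - 1
        if start > 1 && start - 1 < s_len {
            out.push_str(&s[..(start - 1) as usize]);
        }
        out.push_str(repl);
        // resume the original string after start + len - 1, or drop the tail entirely
        if start + len - 1 < s_len {
            out.push_str(&s[(start + len - 1) as usize..]);
        }
        out
    }

    fn main() {
        assert_eq!(overlay_one("Txxxxas", "hom", 2, 4), "Thomas");
        assert_eq!(overlay_one("xyz", "ijk", 1, 2), "ijkz");
    }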
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct RepeatFunc { + signature: Signature, +} + +impl RepeatFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RepeatFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "repeat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "repeat") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(repeat::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(repeat::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function repeat"), + } + } +} + +/// Repeats string the specified number of times. +/// repeat('Pg', 4) = 'PgPgPgPg' +fn repeat(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let number_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(number_array.iter()) + .map(|(string, number)| match (string, number) { + (Some(string), Some(number)) => Some(string.repeat(number as usize)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::repeat::RepeatFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(Some("PgPgPgPg")), + &str, + Utf8, + StringArray + ); + + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs new file mode 100644 index 000000000000..e869ac205440 --- /dev/null +++ b/datafusion/functions/src/string/replace.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct ReplaceFunc { + signature: Signature, +} + +impl ReplaceFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ReplaceFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "replace" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "replace") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(replace::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(replace::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function replace") + } + } + } +} + +/// Replaces all occurrences in string of substring from with substring to. +/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' +fn replace(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let from_array = as_generic_string_array::(&args[1])?; + let to_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +mod test {} diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs new file mode 100644 index 000000000000..d04d15ce8847 --- /dev/null +++ b/datafusion/functions/src/string/rtrim.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
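Editor's note: the `replace` kernel above is `str::replace` applied element-wise over the three input arrays (the trailing `mod test {}` is an empty placeholder). A one-line check of the documented behaviour:

    fn main() {
        // replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef'
        assert_eq!("abcdefabcdef".replace("cd", "XX"), "abXXefabXXef");
    }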
+ +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// rtrim('testxxzx', 'xyz') = 'test' +fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + +#[derive(Debug)] +pub(super) struct RtrimFunc { + signature: Signature, +} + +impl RtrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RtrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rtrim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "rtrim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(rtrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(rtrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function rtrim"), + } + } +} diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs new file mode 100644 index 000000000000..0aa968a1ef5b --- /dev/null +++ b/datafusion/functions/src/string/split_part.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
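Editor's note: `rtrim` above delegates to the shared `general_trim` helper in `string/common.rs` with `TrimType::Right`; that helper is not part of this hunk, so the following is only a scalar-level sketch of the expected behaviour (one argument strips trailing whitespace, two arguments strip the given character set):

    fn rtrim_one(s: &str, chars: Option<&str>) -> String {
        match chars {
            // rtrim('test   ') = 'test'
            None => s.trim_end().to_string(),
            // rtrim('testxxzx', 'xyz') = 'test'
            Some(set) => s.trim_end_matches(|c| set.contains(c)).to_string(),
        }
    }

    fn main() {
        assert_eq!(rtrim_one("testxxzx", Some("xyz")), "test");
        assert_eq!(rtrim_one("test   ", None), "test");
    }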
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct SplitPartFunc { + signature: Signature, +} + +impl SplitPartFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, Utf8, Int64]), + Exact(vec![Utf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SplitPartFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "split_part" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "split_part") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(split_part::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(split_part::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function split_part") + } + } + } +} + +/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). +/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' +fn split_part(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let n_array = as_int64_array(&args[2])?; + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(n_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + if n <= 0 { + exec_err!("field position must be greater than zero") + } else { + let split_string: Vec<&str> = string.split(delimiter).collect(); + match split_string.get(n as usize - 1) { + Some(s) => Ok(Some(*s)), + None => Ok(Some("")), + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::ScalarValue; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::split_part::SplitPartFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("def")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + 
"abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + exec_err!("field position must be greater than zero"), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs new file mode 100644 index 000000000000..f1b03907f8d8 --- /dev/null +++ b/datafusion/functions/src/string/starts_with.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +/// Returns true if string starts with prefix. +/// starts_with('alphabet', 'alph') = 't' +pub fn starts_with(args: &[ArrayRef]) -> Result { + let left = as_generic_string_array::(&args[0])?; + let right = as_generic_string_array::(&args[1])?; + + let result = arrow::compute::kernels::comparison::starts_with(left, right)?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct StartsWithFunc { + signature: Signature, +} +impl StartsWithFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StartsWithFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "starts_with" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(starts_with::, vec![])(args), + DataType::LargeUtf8 => { + return make_scalar_function(starts_with::, vec![])(args); + } + _ => internal_err!("Unsupported data type"), + } + } +} diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs new file mode 100644 index 000000000000..ab320c68d493 --- /dev/null +++ b/datafusion/functions/src/string/to_hex.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, +}; + +use datafusion_common::cast::as_primitive_array; +use datafusion_common::Result; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +/// Converts the number to its equivalent hexadecimal representation. +/// to_hex(2147483647) = '7fffffff' +pub fn to_hex(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let integer_array = as_primitive_array::(&args[0])?; + + let result = integer_array + .iter() + .map(|integer| { + if let Some(value) = integer { + if let Some(value_usize) = value.to_usize() { + Ok(Some(format!("{value_usize:x}"))) + } else if let Some(value_isize) = value.to_isize() { + Ok(Some(format!("{value_isize:x}"))) + } else { + exec_err!("Unsupported data type {integer:?} for function to_hex") + } + } else { + Ok(None) + } + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct ToHexFunc { + signature: Signature, +} +impl ToHexFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ToHexFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_hex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match arg_types[0] { + Int8 | Int16 | Int32 | Int64 => Utf8, + _ => { + return plan_err!("The to_hex function can only accept integers."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Int32 => make_scalar_function(to_hex::, vec![])(args), + DataType::Int64 => make_scalar_function(to_hex::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function to_hex"), + } + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{Int32Array, StringArray}, + datatypes::Int32Type, + }; + + use datafusion_common::cast::as_string_array; + + use super::*; + + #[test] + // Test to_hex function for zero + fn to_hex_zero() -> Result<()> { + let array = vec![0].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("0")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } + + #[test] + // Test to_hex function for positive number + fn to_hex_positive_number() -> Result<()> { + let array = vec![100].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = 
as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("64")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } + + #[test] + // Test to_hex function for negative number + fn to_hex_negative_number() -> Result<()> { + let array = vec![-1].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("ffffffffffffffff")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs new file mode 100644 index 000000000000..066174abf277 --- /dev/null +++ b/datafusion/functions/src/string/upper.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +#[derive(Debug)] +pub(super) struct UpperFunc { + signature: Signature, +} + +impl UpperFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for UpperFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "upper") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + handle(args, |string| string.to_uppercase(), "upper") + } +} diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs new file mode 100644 index 000000000000..791ad6d3c4f3 --- /dev/null +++ b/datafusion/functions/src/string/uuid.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
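Editor's note: the negative-number test for `to_hex` above expects "ffffffffffffffff" because the kernel falls back to formatting the value as an `isize` when `to_usize()` fails, and Rust's `{:x}` renders signed integers as their two's-complement bit pattern (so the expected string assumes a 64-bit `isize`):

    fn main() {
        assert_eq!(format!("{:x}", 2147483647i64), "7fffffff");  // to_hex(2147483647)
        assert_eq!(format!("{:x}", -1i64), "ffffffffffffffff");  // to_hex(-1) on Int64
    }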
+ +use std::any::Any; +use std::iter; +use std::sync::Arc; + +use arrow::array::GenericStringArray; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Utf8; +use uuid::Uuid; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub(super) struct UuidFunc { + signature: Signature, +} + +impl UuidFunc { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for UuidFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "uuid" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + /// Prints random (v4) uuid values per row + /// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let len: usize = match &args[0] { + ColumnarValue::Array(array) => array.len(), + _ => return exec_err!("Expect uuid function to take no param"), + }; + + let values = iter::repeat_with(|| Uuid::new_v4().to_string()).take(len); + let array = GenericStringArray::::from_iter_values(values); + Ok(ColumnarValue::Array(Arc::new(array))) + } +} diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs new file mode 100644 index 000000000000..51331bf9a586 --- /dev/null +++ b/datafusion/functions/src/unicode/character_length.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
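Editor's note: the `uuid` implementation above takes no SQL arguments but still reads `args[0]` to learn how many rows to produce, so it relies on the engine handing zero-argument scalar functions a placeholder array sized to the batch (a scalar in that position produces the "Expect uuid function to take no param" error). The generation itself is just the `uuid` crate's v4 constructor applied once per row:

    use std::iter;
    use uuid::Uuid;

    fn main() {
        // One random version-4 UUID string per output row, as in the invoke() body above
        let values: Vec<String> = iter::repeat_with(|| Uuid::new_v4().to_string())
            .take(3)
            .collect();
        assert_eq!(values.len(), 3);
        assert_eq!(values[0].len(), 36); // hyphenated form, e.g. a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11
    }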
+ +use crate::utils::{make_scalar_function, utf8_to_int_type}; +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub(super) struct CharacterLengthFunc { + signature: Signature, + aliases: Vec, +} + +impl CharacterLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + aliases: vec![String::from("length"), String::from("char_length")], + } + } +} + +impl ScalarUDFImpl for CharacterLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "character_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "character_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(character_length::, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(character_length::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function character_length") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns number of characters in the string. +/// character_length('josé') = 4 +/// The implementation counts UTF-8 code points to count the number of characters +fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + T::Native::from_usize(string.chars().count()) + .expect("should not fail as string.chars will always return integer") + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::unicode::character_length::CharacterLengthFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("chars") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("josé") + )))], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("") + )))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + CharacterLengthFunc::new(), + 
&[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé"))))], + internal_err!( + "function character_length requires compilation with feature flag: unicode_expressions." + ), + i32, + Int32, + Int32Array + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs new file mode 100644 index 000000000000..7e0306d49454 --- /dev/null +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_int_type}; + +#[derive(Debug)] +pub(super) struct FindInSetFunc { + signature: Signature, +} + +impl FindInSetFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for FindInSetFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "find_in_set" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "find_in_set") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(find_in_set::, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(find_in_set::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + } + } +} + +///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings +///A string list is a string composed of substrings separated by , characters. +pub fn find_in_set(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + if args.len() != 2 { + return exec_err!( + "find_in_set was called with {} arguments. 
It requires 2.", + args.len() + ); + } + + let str_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + let str_list_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = str_array + .iter() + .zip(str_list_array.iter()) + .map(|(string, str_list)| match (string, str_list) { + (Some(string), Some(str_list)) => { + let mut res = 0; + let str_set: Vec<&str> = str_list.split(',').collect(); + for (idx, str) in str_set.iter().enumerate() { + if str == &string { + res = idx + 1; + break; + } + } + T::Native::from_usize(res) + } + _ => None, + }) + .collect::>(); + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs new file mode 100644 index 000000000000..473589fdc8aa --- /dev/null +++ b/datafusion/functions/src/unicode/left.rs @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct LeftFunc { + signature: Signature, +} + +impl LeftFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LeftFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "left" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "left") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(left::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function left"), + } + } +} + +/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
+/// left('abcde', 2) = 'ab' +/// The implementation uses UTF-8 code points as characters +pub fn left(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let n_array = as_int64_array(&args[1])?; + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => match n.cmp(&0) { + Ordering::Less => { + let len = string.chars().count() as i64; + Some(if n.abs() < len { + string.chars().take((len + n) as usize).collect::() + } else { + "".to_string() + }) + } + Ordering::Equal => Some("".to_string()), + Ordering::Greater => { + Some(string.chars().take(n as usize).collect::()) + } + }, + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::left::LeftFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ab")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ], + Ok(Some("abc")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(-200i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + internal_err!( + "function left requires compilation with feature flag: unicode_expressions." 
+ ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs new file mode 100644 index 000000000000..76a8e68cca25 --- /dev/null +++ b/datafusion/functions/src/unicode/lpad.rs @@ -0,0 +1,369 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct LPadFunc { + signature: Signature, +} + +impl LPadFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LPadFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lpad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "lpad") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(lpad::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function lpad"), + } + } +} + +/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). 
+/// lpad('hi', 5, 'xy') = 'xyxhi' +pub fn lpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else { + let mut s: String = " ".repeat(length - graphemes.len()); + s.push_str(string); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + let fill_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector.push( + *fill_chars.get(l % fill_chars.len()).unwrap(), + ); + } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => exec_err!( + "lpad was called with {other} arguments. It requires at least 2 and at most 3." 
+ ), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::lpad::LPadFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some(" josé")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("xyxhi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(21i64)), + ColumnarValue::Scalar(ScalarValue::from("abcdef")), + ], + Ok(Some("abcdefabcdefabcdefahi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from(" ")), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("")), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("xyxyxyjosé")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + 
ColumnarValue::Scalar(ScalarValue::from("éñ")), + ], + Ok(Some("éñéñéñjosé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + internal_err!( + "function lpad requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs new file mode 100644 index 000000000000..eba4cd5048eb --- /dev/null +++ b/datafusion/functions/src/unicode/mod.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! "unicode" DataFusion functions + +use std::sync::Arc; + +use datafusion_expr::ScalarUDF; + +mod character_length; +mod find_in_set; +mod left; +mod lpad; +mod reverse; +mod right; +mod rpad; +mod strpos; +mod substr; +mod substrindex; +mod translate; + +// create UDFs +make_udf_function!( + character_length::CharacterLengthFunc, + CHARACTER_LENGTH, + character_length +); +make_udf_function!(find_in_set::FindInSetFunc, FIND_IN_SET, find_in_set); +make_udf_function!(left::LeftFunc, LEFT, left); +make_udf_function!(lpad::LPadFunc, LPAD, lpad); +make_udf_function!(right::RightFunc, RIGHT, right); +make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); +make_udf_function!(rpad::RPadFunc, RPAD, rpad); +make_udf_function!(strpos::StrposFunc, STRPOS, strpos); +make_udf_function!(substr::SubstrFunc, SUBSTR, substr); +make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); +make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); + +pub mod expr_fn { + use datafusion_expr::Expr; + + #[doc = "the number of characters in the `string`"] + pub fn char_length(string: Expr) -> Expr { + character_length(string) + } + + #[doc = "the number of characters in the `string`"] + pub fn character_length(string: Expr) -> Expr { + super::character_length().call(vec![string]) + } + + #[doc = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"] + pub fn find_in_set(string: Expr, strlist: Expr) -> Expr { + super::find_in_set().call(vec![string, strlist]) + } + + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn instr(string: Expr, substring: Expr) -> Expr { + strpos(string, substring) + } + + #[doc = "the number of characters in the `string`"] + pub fn length(string: Expr) -> Expr { + character_length(string) + } + + #[doc = "returns the first `n` characters in the `string`"] + pub fn left(string: Expr, n: Expr) -> Expr { + super::left().call(vec![string, n]) + } + + #[doc = "fill up a string to the 
length by prepending the characters"] + pub fn lpad(args: Vec) -> Expr { + super::lpad().call(args) + } + + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn position(string: Expr, substring: Expr) -> Expr { + strpos(string, substring) + } + + #[doc = "reverses the `string`"] + pub fn reverse(string: Expr) -> Expr { + super::reverse().call(vec![string]) + } + + #[doc = "returns the last `n` characters in the `string`"] + pub fn right(string: Expr, n: Expr) -> Expr { + super::right().call(vec![string, n]) + } + + #[doc = "fill up a string to the length by appending the characters"] + pub fn rpad(args: Vec) -> Expr { + super::rpad().call(args) + } + + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn strpos(string: Expr, substring: Expr) -> Expr { + super::strpos().call(vec![string, substring]) + } + + #[doc = "substring from the `position` to the end"] + pub fn substr(string: Expr, position: Expr) -> Expr { + super::substr().call(vec![string, position]) + } + + #[doc = "substring from the `position` with `length` characters"] + pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr { + super::substr().call(vec![string, position, length]) + } + + #[doc = "Returns the substring from str before count occurrences of the delimiter"] + pub fn substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr { + super::substr_index().call(vec![string, delimiter, count]) + } + + #[doc = "replaces the characters in `from` with the counterpart in `to`"] + pub fn translate(string: Expr, from: Expr, to: Expr) -> Expr { + super::translate().call(vec![string, from, to]) + } +} + +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![ + character_length(), + find_in_set(), + left(), + lpad(), + reverse(), + right(), + rpad(), + strpos(), + substr(), + substr_index(), + translate(), + ] +} diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs new file mode 100644 index 000000000000..42ca6e0d17c3 --- /dev/null +++ b/datafusion/functions/src/unicode/reverse.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct ReverseFunc { + signature: Signature, +} + +impl ReverseFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ReverseFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "reverse" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "reverse") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(reverse::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(reverse::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function reverse") + } + } + } +} + +/// Reverses the order of the characters in the string. +/// reverse('abcde') = 'edcba' +/// The implementation uses UTF-8 code points as characters +pub fn reverse(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| string.map(|string: &str| string.chars().rev().collect::())) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::reverse::ReverseFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("abcde"))], + Ok(Some("edcba")), + &str, + Utf8, + StringArray + ); + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("loẅks"))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("loẅks"))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("abcde"))], + internal_err!( + "function reverse requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs new file mode 100644 index 000000000000..d1bd976342b2 --- /dev/null +++ b/datafusion/functions/src/unicode/right.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::{max, Ordering}; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct RightFunc { + signature: Signature, +} + +impl RightFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RightFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "right" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "right") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(right::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(right::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function right"), + } + } +} + +/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. 
+/// right('abcde', 2) = 'de' +/// The implementation uses UTF-8 code points as characters +pub fn right(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let n_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => match n.cmp(&0) { + Ordering::Less => Some( + string + .chars() + .skip(n.unsigned_abs() as usize) + .collect::(), + ), + Ordering::Equal => Some("".to_string()), + Ordering::Greater => Some( + string + .chars() + .skip(max(string.chars().count() as i64 - n, 0) as usize) + .collect::(), + ), + }, + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::right::RightFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("de")), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ], + Ok(Some("cde")), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(-200i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + internal_err!( + "function right requires compilation with feature flag: unicode_expressions." 
+ ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs new file mode 100644 index 000000000000..070278c90b2f --- /dev/null +++ b/datafusion/functions/src/unicode/rpad.rs @@ -0,0 +1,361 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct RPadFunc { + signature: Signature, +} + +impl RPadFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RPadFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rpad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "rpad") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(rpad::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(rpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function rpad"), + } + } +} + +/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. 
+/// rpad('hi', 5, 'xy') = 'hixyx' +pub fn rpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else { + let mut s = string.to_string(); + s.push_str(" ".repeat(length - graphemes.len()).as_str()); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + let fill_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector + .push(*fill_chars.get(l % fill_chars.len()).unwrap()); + } + s.push_str(char_vector.iter().collect::().as_str()); + Ok(Some(s)) + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => exec_err!( + "rpad was called with {other} arguments. It requires at least 2 and at most 3." 
+ ), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::rpad::RPadFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("josé ")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("hixyx")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(21i64)), + ColumnarValue::Scalar(ScalarValue::from("abcdef")), + ], + Ok(Some("hiabcdefabcdefabcdefa")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from(" ")), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("")), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("joséxyxyxy")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + 
ColumnarValue::Scalar(ScalarValue::from("éñ")), + ], + Ok(Some("josééñéñéñ")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + internal_err!( + "function rpad requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs new file mode 100644 index 000000000000..1e8bfa37d40e --- /dev/null +++ b/datafusion/functions/src/unicode/strpos.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_int_type}; + +#[derive(Debug)] +pub(super) struct StrposFunc { + signature: Signature, + aliases: Vec, +} + +impl StrposFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("instr"), String::from("position")], + } + } +} + +impl ScalarUDFImpl for StrposFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "strpos" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "strpos/instr/position") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(strpos::, vec![])(args), + DataType::LargeUtf8 => { + make_scalar_function(strpos::, vec![])(args) + } + other => exec_err!("Unsupported data type {other:?} for function strpos"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) 
+/// strpos('high', 'ig') = 2 +/// The implementation uses UTF-8 code points as characters +fn strpos(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + + let substring_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(substring_array.iter()) + .map(|(string, substring)| match (string, substring) { + (Some(string), Some(substring)) => { + // the find method returns the byte index of the substring + // Next, we count the number of the chars until that byte + T::Native::from_usize( + string + .find(substring) + .map(|x| string[..x].chars().count() + 1) + .unwrap_or(0), + ) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs new file mode 100644 index 000000000000..403157e2a85a --- /dev/null +++ b/datafusion/functions/src/unicode/substr.rs @@ -0,0 +1,392 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::max; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct SubstrFunc { + signature: Signature, +} + +impl SubstrFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, Int64, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SubstrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "substr") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(substr::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(substr::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function substr"), + } + } +} + +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) 
+/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +/// The implementation uses UTF-8 code points as characters +pub fn substr(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let start_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(start_array.iter()) + .map(|(string, start)| match (string, start) { + (Some(string), Some(start)) => { + if start <= 0 { + Some(string.to_string()) + } else { + Some(string.chars().skip(start as usize - 1).collect()) + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let start_array = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(start_array.iter()) + .zip(count_array.iter()) + .map(|((string, start), count)| match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + exec_err!( + "negative substring length not allowed: substr(, {start}, {count})" + ) + } else { + let skip = max(0, start - 1); + let count = max(0, count + (if start < 1 {start - 1} else {0})); + Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!("substr was called with {other} arguments. It requires 2 or 3.") + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{exec_err, Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substr::SubstrFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ], + Ok(Some("joséésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(30i64)), + ], + Ok(Some("")), + &str, 
+ Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(20i64)), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from 5 (10 + -5) + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from -1 (4 + -5) + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::Scalar(ScalarValue::from(4i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + // starting from 0 (5 + -5) + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::from(20i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ], + exec_err!("negative substring length not allowed: substr(, 1, -1)"), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("és")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + internal_err!( + "function substr requires compilation with feature flag: unicode_expressions." 
+ ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs new file mode 100644 index 000000000000..77e8116fff4c --- /dev/null +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct SubstrIndexFunc { + signature: Signature, + aliases: Vec, +} + +impl SubstrIndexFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("substring_index")], + } + } +} + +impl ScalarUDFImpl for SubstrIndexFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substr_index" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "substr_index") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(substr_index::, vec![])(args), + DataType::LargeUtf8 => { + make_scalar_function(substr_index::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function substr_index") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!( + "substr_index was called with {} arguments. 
It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + // In MySQL, these cases will return an empty string. + if n == 0 || string.is_empty() || delimiter.is_empty() { + return Some(String::new()); + } + + let splitted: Box> = if n > 0 { + Box::new(string.split(delimiter)) + } else { + Box::new(string.rsplit(delimiter)) + }; + let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); + // The length of the substring covered by substr_index. + let length = splitted + .take(occurrences) // at least 1 element, since n != 0 + .map(|s| s.len() + delimiter.len()) + .sum::() + - delimiter.len(); + if n > 0 { + Some(string[..length].to_owned()) + } else { + Some(string[string.len() - length..].to_owned()) + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs new file mode 100644 index 000000000000..bc1836700304 --- /dev/null +++ b/datafusion/functions/src/unicode/translate.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use hashbrown::HashMap; +use unicode_segmentation::UnicodeSegmentation; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct TranslateFunc { + signature: Signature, +} + +impl TranslateFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TranslateFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "translate" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "translate") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(translate::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(translate::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function translate") + } + } + } +} + +/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. +/// translate('12345', '143', 'ax') = 'a2x5' +fn translate(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let from_array = as_generic_string_array::(&args[1])?; + let to_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => { + // create a hashmap of [char, index] to change from O(n) to O(1) for from list + let from_map: HashMap<&str, usize> = from + .graphemes(true) + .collect::>() + .iter() + .enumerate() + .map(|(index, c)| (c.to_owned(), index)) + .collect(); + + let to = to.graphemes(true).collect::>(); + + Some( + string + .graphemes(true) + .collect::>() + .iter() + .flat_map(|c| match from_map.get(*c) { + Some(n) => to.get(*n).copied(), + None => Some(*c), + }) + .collect::>() + .concat(), + ) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::translate::TranslateFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(Some("a2x5")), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ 
+ ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), + ColumnarValue::Scalar(ScalarValue::from("éñí")), + ColumnarValue::Scalar(ScalarValue::from("óü")), + ], + Ok(Some("ó2ü5")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")), + ], + internal_err!( + "function translate requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs new file mode 100644 index 000000000000..9b7144b483bd --- /dev/null +++ b/datafusion/functions/src/utils.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; +use std::sync::Arc; + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! 
get_optimal_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + **value_type + ); + } + }, + data_type => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + data_type + ); + } + }) + } + }; +} + +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. +get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. +get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + +/// Creates a scalar function implementation for the given function. +/// * `inner` - the function to be executed +/// * `hints` - hints to be used when expanding scalars to arrays +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) + .map(|(arg, hint)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hint { + Hint::AcceptsSingular => 1, + Hint::Pad => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>>()?; + + let result = (inner)(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + }) +} + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! 
test_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + let expected: Result> = $EXPECTED; + let func = $FUNC; + + let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let return_type = func.return_type(&type_array); + + match expected { + Ok(expected) => { + assert_eq!(return_type.is_ok(), true); + assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + + let result = func.invoke($ARGS); + assert_eq!(result.is_ok(), true); + + let len = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let inferred_length = len.unwrap_or(1); + let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if return_type.is_err() { + match return_type { + Ok(_) => assert!(false, "expected error"), + Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } + } + } + else { + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke($ARGS) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + } + + pub(crate) use test_function; +} diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 861715b351a6..1d64a22f1463 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -34,9 +34,8 @@ path = "src/lib.rs" [features] crypto_expressions = ["datafusion-physical-expr/crypto_expressions"] -default = ["unicode_expressions", "crypto_expressions", "regex_expressions"] +default = ["crypto_expressions", "regex_expressions"] regex_expressions = ["datafusion-physical-expr/regex_expressions"] -unicode_expressions = ["datafusion-physical-expr/unicode_expressions"] [dependencies] arrow = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index f8dcf460a469..c76c1c8a7bd0 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -47,8 +47,8 @@ use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan, Operator, - Projection, ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame, - WindowFrameBound, WindowFrameUnits, + ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; #[derive(Default)] @@ -76,7 +76,7 @@ fn analyze_internal( plan: &LogicalPlan, ) -> Result { // optimize child plans first - let mut new_inputs = plan + let new_inputs = plan .inputs() .iter() .map(|p| analyze_internal(external_schema, p)) @@ -110,14 +110,7 @@ fn analyze_internal( }) .collect::>>()?; - // TODO: with_new_exprs can't change the schema, so we need to do this here - match &plan { - LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new( - new_expr, - Arc::new(new_inputs.swap_remove(0)), - )?)), - _ => 
plan.with_new_exprs(new_expr, new_inputs), - } + plan.with_new_exprs(new_expr, new_inputs) } pub(crate) struct TypeCoercionRewriter { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 7b8eccad5133..25c25c63f0b7 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,6 +17,7 @@ //! Eliminate common sub-expression. +use std::collections::hash_map::Entry; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -29,21 +30,86 @@ use datafusion_common::tree_node::{ TreeNodeVisitor, }; use datafusion_common::{ - internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, Result, + internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; use datafusion_expr::{col, Expr, ExprSchemable}; -/// A map from expression's identifier to tuple including -/// - the expression itself (cloned) -/// - counter -/// - DataType of this expression. -type ExprSet = HashMap; +/// Set of expressions generated by the [`ExprIdentifierVisitor`] +/// and consumed by the [`CommonSubexprRewriter`]. +#[derive(Default)] +struct ExprSet { + /// A map from expression's identifier (stringified expr) to tuple including: + /// - the expression itself (cloned) + /// - counter + /// - DataType of this expression. + /// - symbol used as the identifier in the alias. + map: HashMap, +} + +impl ExprSet { + fn expr_identifier(expr: &Expr) -> Identifier { + format!("{expr}") + } + + fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> { + self.map.get(key) + } + + fn entry( + &mut self, + key: Identifier, + ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> { + self.map.entry(key) + } + + fn populate_expr_set( + &mut self, + expr: &[Expr], + input_schema: DFSchemaRef, + expr_mask: ExprMask, + ) -> Result<()> { + expr.iter().try_for_each(|e| { + self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?; + + Ok(()) + }) + } + + /// Go through an expression tree and generate identifier for every node in this tree. + fn expr_to_identifier( + &mut self, + expr: &Expr, + input_schema: DFSchemaRef, + expr_mask: ExprMask, + ) -> Result<()> { + expr.visit(&mut ExprIdentifierVisitor { + expr_set: self, + input_schema, + visit_stack: vec![], + node_count: 0, + expr_mask, + })?; -/// Identifier type. Current implementation use describe of an expression (type String) as -/// Identifier. + Ok(()) + } +} + +impl From> for ExprSet { + fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self { + let mut expr_set = Self::default(); + entries.into_iter().for_each(|(k, v)| { + expr_set.map.insert(k, v); + }); + expr_set + } +} + +/// Identifier for each subexpression. +/// +/// Note that the current implementation uses the `Display` of an expression +/// (a `String`) as `Identifier`. /// /// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no /// collision (as low as possible)" @@ -53,31 +119,48 @@ type ExprSet = HashMap; /// here is not such a good choose. type Identifier = String; -/// Perform Common Sub-expression Elimination optimization. +/// Performs Common Sub-expression Elimination optimization. 
+/// +/// This optimization improves query performance by computing expressions that +/// appear more than once and reusing those results rather than re-computing the +/// same value /// -/// Currently only common sub-expressions within one logical plan will +/// Currently only common sub-expressions within a single `LogicalPlan` are /// be eliminated. +/// +/// # Example +/// +/// Given a projection that computes the same expensive expression +/// multiple times such as parsing as string as a date with `to_date` twice: +/// +/// ```text +/// ProjectionExec(expr=[extract (day from to_date(c1)), extract (year from to_date(c1))]) +/// ``` +/// +/// This optimization will rewrite the plan to compute the common expression once +/// using a new `ProjectionExec` and then rewrite the original expressions to +/// refer to that new column. +/// +/// ```text +/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here +/// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once +/// ``` pub struct CommonSubexprEliminate {} impl CommonSubexprEliminate { fn rewrite_exprs_list( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result>> { exprs_list .iter() - .zip(arrays_list.iter()) - .map(|(exprs, arrays)| { + .map(|exprs| { exprs .iter() .cloned() - .zip(arrays.iter()) - .map(|(expr, id_array)| { - replace_common_expr(expr, id_array, expr_set, affected_id) - }) + .map(|expr| replace_common_expr(expr, expr_set, affected_id)) .collect::>>() }) .collect::>>() @@ -86,7 +169,6 @@ impl CommonSubexprEliminate { fn rewrite_expr( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], input: &LogicalPlan, expr_set: &ExprSet, config: &dyn OptimizerConfig, @@ -94,7 +176,7 @@ impl CommonSubexprEliminate { let mut affected_id = BTreeSet::::new(); let rewrite_exprs = - self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?; + self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?; let mut new_input = self .try_optimize(input, config)? @@ -112,8 +194,7 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result { let mut window_exprs = vec![]; - let mut arrays_per_window = vec![]; - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); // Get all window expressions inside the consecutive window operators. // Consecutive window expressions may refer to same complex expression. @@ -132,30 +213,18 @@ impl CommonSubexprEliminate { plan = input.as_ref().clone(); let input_schema = Arc::clone(input.schema()); - let arrays = - to_arrays(&window_expr, input_schema, &mut expr_set, ExprMask::Normal)?; + expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?; window_exprs.push(window_expr); - arrays_per_window.push(arrays); } let mut window_exprs = window_exprs .iter() .map(|expr| expr.as_slice()) .collect::>(); - let arrays_per_window = arrays_per_window - .iter() - .map(|arrays| arrays.as_slice()) - .collect::>(); - assert_eq!(window_exprs.len(), arrays_per_window.len()); - let (mut new_expr, new_input) = self.rewrite_expr( - &window_exprs, - &arrays_per_window, - &plan, - &expr_set, - config, - )?; + let (mut new_expr, new_input) = + self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?; assert_eq!(window_exprs.len(), new_expr.len()); // Construct consecutive window operator, with their corresponding new window expressions. 
@@ -192,46 +261,36 @@ impl CommonSubexprEliminate { input, .. } = aggregate; - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); - // rewrite inputs + // build expr_set, with groupby and aggr let input_schema = Arc::clone(input.schema()); - let group_arrays = to_arrays( + expr_set.populate_expr_set( group_expr, Arc::clone(&input_schema), - &mut expr_set, ExprMask::Normal, )?; - let aggr_arrays = - to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?; + expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?; - let (mut new_expr, new_input) = self.rewrite_expr( - &[group_expr, aggr_expr], - &[&group_arrays, &aggr_arrays], - input, - &expr_set, - config, - )?; + // rewrite inputs + let (mut new_expr, new_input) = + self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?; // note the reversed pop order. let new_aggr_expr = pop_expr(&mut new_expr)?; let new_group_expr = pop_expr(&mut new_expr)?; // create potential projection on top - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); let new_input_schema = Arc::clone(new_input.schema()); - let aggr_arrays = to_arrays( + expr_set.populate_expr_set( &new_aggr_expr, new_input_schema.clone(), - &mut expr_set, ExprMask::NormalAndAggregates, )?; + let mut affected_id = BTreeSet::::new(); - let mut rewritten = self.rewrite_exprs_list( - &[&new_aggr_expr], - &[&aggr_arrays], - &expr_set, - &mut affected_id, - )?; + let mut rewritten = + self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?; let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { @@ -251,9 +310,9 @@ impl CommonSubexprEliminate { for id in affected_id { match expr_set.get(&id) { - Some((expr, _, _)) => { + Some((expr, _, _, symbol)) => { // todo: check `nullable` - agg_exprs.push(expr.clone().alias(&id)); + agg_exprs.push(expr.clone().alias(symbol.as_str())); } _ => { return internal_err!("expr_set invalid state"); @@ -271,8 +330,7 @@ impl CommonSubexprEliminate { agg_exprs.push(expr.alias(&name)); proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - let id = - ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten); + let id = ExprSet::expr_identifier(&expr_rewritten); let out_name = expr_rewritten.to_field(&new_input_schema)?.qualified_name(); agg_exprs.push(expr_rewritten.alias(&id)); @@ -306,13 +364,13 @@ impl CommonSubexprEliminate { let inputs = plan.inputs(); let input = inputs[0]; let input_schema = Arc::clone(input.schema()); - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); // Visit expr list and build expr identifier to occuring count map (`expr_set`). 
- let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?; + expr_set.populate_expr_set(&expr, input_schema, ExprMask::Normal)?; let (mut new_expr, new_input) = - self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?; + self.rewrite_expr(&[&expr], input, &expr_set, config)?; plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) } @@ -398,28 +456,6 @@ fn pop_expr(new_expr: &mut Vec>) -> Result> { .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string())) } -fn to_arrays( - expr: &[Expr], - input_schema: DFSchemaRef, - expr_set: &mut ExprSet, - expr_mask: ExprMask, -) -> Result>> { - expr.iter() - .map(|e| { - let mut id_array = vec![]; - expr_to_identifier( - e, - expr_set, - &mut id_array, - Arc::clone(&input_schema), - expr_mask, - )?; - - Ok(id_array) - }) - .collect::>>() -} - /// Build the "intermediate" projection plan that evaluates the extracted common expressions. fn build_common_expr_project_plan( input: LogicalPlan, @@ -431,11 +467,11 @@ fn build_common_expr_project_plan( for id in affected_id { match expr_set.get(&id) { - Some((expr, _, data_type)) => { + Some((expr, _, data_type, symbol)) => { // todo: check `nullable` let field = DFField::new_unqualified(&id, data_type.clone(), true); fields_set.insert(field.name().to_owned()); - project_exprs.push(expr.clone().alias(&id)); + project_exprs.push(expr.clone().alias(symbol.as_str())); } _ => { return internal_err!("expr_set invalid state"); @@ -535,15 +571,15 @@ impl ExprMask { /// This visitor implementation use a stack `visit_stack` to track traversal, which /// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called /// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. -/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem` +/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` /// before the first `EnterMark` is considered to be sub-tree of the leaving node. /// /// This visitor also records identifier in `id_array`. Makes the following traverse /// pass can get the identifier of a node without recalculate it. We assign each node /// in the expr tree a series number, start from 1, maintained by `series_number`. -/// Series number represents the order we left (`post_visit`) a node. Has the property +/// Series number represents the order we left (`f_up()`) a node. Has the property /// that child node's series number always smaller than parent's. While `id_array` is -/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to +/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to /// get the index of `id_array` for each node. /// /// `Expr` without sub-expr (column, literal etc.) will not have identifier @@ -551,17 +587,13 @@ impl ExprMask { struct ExprIdentifierVisitor<'a> { // param expr_set: &'a mut ExprSet, - /// series number (usize) and identifier. - id_array: &'a mut Vec<(usize, Identifier)>, /// input schema for the node that we're optimizing, so we can determine the correct datatype /// for each subexpression input_schema: DFSchemaRef, // inner states visit_stack: Vec, - /// increased in pre_visit, start from 0. + /// increased in fn_down, start from 0. node_count: usize, - /// increased in post_visit, start from 1. - series_number: usize, /// which expression should be skipped? 
expr_mask: ExprMask, } @@ -571,31 +603,29 @@ enum VisitRecord { /// `usize` is the monotone increasing series number assigned in pre_visit(). /// Starts from 0. Is used to index the identifier array `id_array` in post_visit(). EnterMark(usize), + /// the node's children were skipped => jump to f_up on same node + JumpMark(usize), /// Accumulated identifier of sub expression. ExprItem(Identifier), } impl ExprIdentifierVisitor<'_> { - fn desc_expr(expr: &Expr) -> String { - format!("{expr}") - } - /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. - fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { + fn pop_enter_mark(&mut self) -> (usize, Identifier) { let mut desc = String::new(); while let Some(item) = self.visit_stack.pop() { match item { - VisitRecord::EnterMark(idx) => { - return Some((idx, desc)); + VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => { + return (idx, desc); } - VisitRecord::ExprItem(s) => { - desc.push_str(&s); + VisitRecord::ExprItem(id) => { + desc.push_str(&id); } } } - None + unreachable!("Enter mark should paired with node number"); } } @@ -606,81 +636,51 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(TreeNodeRecursion::Jump); + self.visit_stack + .push(VisitRecord::JumpMark(self.node_count)); + return Ok(TreeNodeRecursion::Jump); // go to f_up } + self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; - // put placeholder - self.id_array.push((0, "".to_string())); + Ok(TreeNodeRecursion::Continue) } fn f_up(&mut self, expr: &Expr) -> Result { - self.series_number += 1; + let (_idx, sub_expr_identifier) = self.pop_enter_mark(); - let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else { - return Ok(TreeNodeRecursion::Continue); - }; // skip exprs should not be recognize. if self.expr_mask.ignores(expr) { - self.id_array[idx].0 = self.series_number; - let desc = Self::desc_expr(expr); - self.visit_stack.push(VisitRecord::ExprItem(desc)); + let curr_expr_identifier = ExprSet::expr_identifier(expr); + self.visit_stack + .push(VisitRecord::ExprItem(curr_expr_identifier)); return Ok(TreeNodeRecursion::Continue); } - let mut desc = Self::desc_expr(expr); - desc.push_str(&sub_expr_desc); + let curr_expr_identifier = ExprSet::expr_identifier(expr); + let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}"); - self.id_array[idx] = (self.series_number, desc.clone()); - self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); + self.visit_stack + .push(VisitRecord::ExprItem(alias_symbol.clone())); let data_type = expr.get_type(&self.input_schema)?; self.expr_set - .entry(desc) - .or_insert_with(|| (expr.clone(), 0, data_type)) + .entry(curr_expr_identifier) + .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol)) .1 += 1; Ok(TreeNodeRecursion::Continue) } } -/// Go through an expression tree and generate identifier for every node in this tree. 
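// Illustrative sketch (not part of the patch): the stack discipline behind
// `ExprIdentifierVisitor`, shown on a toy binary tree instead of `Expr`. Entering a node
// pushes an `EnterMark`; leaving it pops every `ExprItem` its children pushed, back to
// that mark, concatenates them with the node's own label, and pushes the result so the
// parent can pick it up in turn. `Record` and `Node` are hypothetical types for this
// sketch only.
enum Record {
    EnterMark,
    Item(String),
}

enum Node {
    Leaf(String),
    Bin(String, Box<Node>, Box<Node>),
}

fn visit(node: &Node, stack: &mut Vec<Record>) -> String {
    // f_down: mark where this node's children begin on the stack.
    stack.push(Record::EnterMark);
    if let Node::Bin(_, l, r) = node {
        visit(l, stack);
        visit(r, stack);
    }
    // f_up: accumulate the children's identifiers until the matching EnterMark.
    let mut desc = String::new();
    while let Some(rec) = stack.pop() {
        match rec {
            Record::EnterMark => break,
            Record::Item(id) => desc.push_str(&id),
        }
    }
    let label = match node {
        Node::Leaf(s) | Node::Bin(s, _, _) => s.clone(),
    };
    let id = format!("{label}{desc}");
    stack.push(Record::Item(id.clone()));
    id
}

fn main() {
    let tree = Node::Bin(
        "a + b".into(),
        Box::new(Node::Leaf("a".into())),
        Box::new(Node::Leaf("b".into())),
    );
    // Children are popped in reverse push order, so "b" precedes "a" in the suffix.
    assert_eq!(visit(&tree, &mut Vec::new()), "a + bba");
}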
-fn expr_to_identifier( - expr: &Expr, - expr_set: &mut ExprSet, - id_array: &mut Vec<(usize, Identifier)>, - input_schema: DFSchemaRef, - expr_mask: ExprMask, -) -> Result<()> { - expr.visit(&mut ExprIdentifierVisitor { - expr_set, - id_array, - input_schema, - visit_stack: vec![], - node_count: 0, - series_number: 0, - expr_mask, - })?; - - Ok(()) -} - /// Rewrite expression by replacing detected common sub-expression with /// the corresponding temporary column name. That column contains the /// evaluate result of replaced expression. struct CommonSubexprRewriter<'a> { expr_set: &'a ExprSet, - id_array: &'a [(usize, Identifier)], /// Which identifier is replaced. affected_id: &'a mut BTreeSet, - - /// the max series number we have rewritten. Other expression nodes - /// with smaller series number is already replaced and shouldn't - /// do anything with them. - max_series_number: usize, - /// current node's information's index in `id_array`. - curr_index: usize, } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { @@ -693,88 +693,42 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { if expr.short_circuits() || is_volatile_expression(&expr)? { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } - if self.curr_index >= self.id_array.len() - || self.max_series_number > self.id_array[self.curr_index].0 - { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); - } - let curr_id = &self.id_array[self.curr_index].1; - // skip `Expr`s without identifier (empty identifier). - if curr_id.is_empty() { - self.curr_index += 1; - return Ok(Transformed::no(expr)); - } + let curr_id = &ExprSet::expr_identifier(&expr); + + // lookup previously visited expression match self.expr_set.get(curr_id) { - Some((_, counter, _)) => { + Some((_, counter, _, symbol)) => { + // if has a commonly used (a.k.a. 1+ use) expr if *counter > 1 { self.affected_id.insert(curr_id.clone()); - // This expr tree is finished. - if self.curr_index >= self.id_array.len() { - return Ok(Transformed::new( - expr, - false, - TreeNodeRecursion::Jump, - )); - } - - let (series_number, id) = &self.id_array[self.curr_index]; - self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - internal_datafusion_err!("expr_set invalid state") - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(Transformed::new( - expr, - false, - TreeNodeRecursion::Jump, - )); - } - - self.max_series_number = *series_number; - // step index to skip all sub-node (which has smaller series number). - while self.curr_index < self.id_array.len() - && *series_number > self.id_array[self.curr_index].0 - { - self.curr_index += 1; - } - let expr_name = expr.display_name()?; // Alias this `Column` expr to it original "expr name", // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. 
Ok(Transformed::new( - col(id).alias(expr_name), + col(symbol).alias(expr_name), true, TreeNodeRecursion::Jump, )) } else { - self.curr_index += 1; Ok(Transformed::no(expr)) } } - _ => internal_err!("expr_set invalid state"), + None => Ok(Transformed::no(expr)), } } } fn replace_common_expr( expr: Expr, - id_array: &[(usize, Identifier)], expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result { expr.rewrite(&mut CommonSubexprRewriter { expr_set, - id_array, affected_id, - max_series_number: 0, - curr_index: 0, }) .data() } @@ -810,73 +764,6 @@ mod test { assert_eq!(expected, formatted_plan); } - #[test] - fn id_array_visitor() -> Result<()> { - let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2); - - let schema = Arc::new(DFSchema::new_with_metadata( - vec![ - DFField::new_unqualified("a", DataType::Int64, false), - DFField::new_unqualified("c", DataType::Int64, false), - ], - Default::default(), - )?); - - // skip aggregates - let mut id_array = vec![]; - expr_to_identifier( - &expr, - &mut HashMap::new(), - &mut id_array, - Arc::clone(&schema), - ExprMask::Normal, - )?; - - let expected = vec![ - (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), - (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), - (4, ""), - (3, "a + Int32(1)Int32(1)a"), - (1, ""), - (2, ""), - (6, ""), - (5, ""), - (8, "") - ] - .into_iter() - .map(|(number, id)| (number, id.into())) - .collect::>(); - assert_eq!(expected, id_array); - - // include aggregates - let mut id_array = vec![]; - expr_to_identifier( - &expr, - &mut HashMap::new(), - &mut id_array, - Arc::clone(&schema), - ExprMask::NormalAndAggregates, - )?; - - let expected = vec![ - (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), - (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), - (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"), - (3, "a + Int32(1)Int32(1)a"), - (1, ""), - (2, ""), - (6, "AVG(c)c"), - (5, ""), - (8, "") - ] - .into_iter() - .map(|(number, id)| (number, id.into())) - .collect::>(); - assert_eq!(expected, id_array); - - Ok(()) - } - #[test] fn tpch_q1_simplified() -> Result<()> { // SQL: @@ -1121,24 +1008,28 @@ mod test { let table_scan = test_table_scan().unwrap(); let affected_id: BTreeSet = ["c+a".to_string(), "b+a".to_string()].into_iter().collect(); - let expr_set_1 = [ + let expr_set_1 = vec![ ( "c+a".to_string(), - (col("c") + col("a"), 1, DataType::UInt32), + (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()), ), ( "b+a".to_string(), - (col("b") + col("a"), 1, DataType::UInt32), + (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()), ), ] - .into_iter() - .collect(); - let expr_set_2 = [ - ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), - ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)), + .into(); + let expr_set_2 = vec![ + ( + "c+a".to_string(), + (col("c+a"), 1, DataType::UInt32, "c+a".to_string()), + ), + ( + "b+a".to_string(), + (col("b+a"), 1, DataType::UInt32, "b+a".to_string()), + ), ] - .into_iter() - .collect(); + .into(); let project = build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1) .unwrap(); @@ -1164,30 +1055,48 @@ mod test { ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()] .into_iter() .collect(); - let expr_set_1 = [ + let expr_set_1 = vec![ ( "test1.c+test1.a".to_string(), - (col("test1.c") + col("test1.a"), 1, DataType::UInt32), + ( + col("test1.c") + 
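// Illustrative sketch (not part of the patch): the overall effect of the rewrite on the
// doc-comment example earlier in this file, using plain strings instead of `Expr` trees.
// The common sub-expression is evaluated once in a lower projection under a new symbol,
// and every use above is replaced by a reference to that column, aliased back to the
// original expression name so the output schema is unchanged. `common_1` is a made-up
// symbol for illustration.
fn rewrite(exprs: &[&str], common: &str, symbol: &str) -> (Vec<String>, Vec<String>) {
    // Lower projection: compute the shared expression exactly once.
    let lower = vec![format!("{common} AS {symbol}")];
    // Upper projection: reuse the new column, keep the original output names.
    let upper = exprs
        .iter()
        .map(|e| format!("{} AS {e}", e.replace(common, symbol)))
        .collect();
    (upper, lower)
}

fn main() {
    let (upper, lower) = rewrite(
        &["extract(day, to_date(c1))", "extract(year, to_date(c1))"],
        "to_date(c1)",
        "common_1",
    );
    assert_eq!(lower, vec!["to_date(c1) AS common_1"]);
    assert_eq!(upper[0], "extract(day, common_1) AS extract(day, to_date(c1))");
}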
col("test1.a"), + 1, + DataType::UInt32, + "test1.c+test1.a".to_string(), + ), ), ( "test1.b+test1.a".to_string(), - (col("test1.b") + col("test1.a"), 1, DataType::UInt32), + ( + col("test1.b") + col("test1.a"), + 1, + DataType::UInt32, + "test1.b+test1.a".to_string(), + ), ), ] - .into_iter() - .collect(); - let expr_set_2 = [ + .into(); + let expr_set_2 = vec![ ( "test1.c+test1.a".to_string(), - (col("test1.c+test1.a"), 1, DataType::UInt32), + ( + col("test1.c+test1.a"), + 1, + DataType::UInt32, + "test1.c+test1.a".to_string(), + ), ), ( "test1.b+test1.a".to_string(), - (col("test1.b+test1.a"), 1, DataType::UInt32), + ( + col("test1.b+test1.a"), + 1, + DataType::UInt32, + "test1.b+test1.a".to_string(), + ), ), ] - .into_iter() - .collect(); + .into(); let project = build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1) .unwrap(); diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 08ee38f64abd..b942f187c331 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -31,7 +31,8 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::SchemaRef; use datafusion_common::{ - get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, + get_required_group_by_exprs_indices, internal_err, Column, DFSchema, DFSchemaRef, + JoinType, Result, }; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::{ @@ -162,14 +163,40 @@ fn optimize_projections( .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) .collect::>() } + LogicalPlan::Extension(extension) => { + let necessary_children_indices = if let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices) + { + necessary_children_indices + } else { + // Requirements from parent cannot be routed down to user defined logical plan safely + return Ok(None); + }; + let children = extension.node.inputs(); + if children.len() != necessary_children_indices.len() { + return internal_err!("Inconsistent length between children and necessary children indices. \ + Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ + consistent with actual children length for the node."); + } + // Expressions used by node. + let exprs = plan.expressions(); + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + let child_schema = child.schema(); + let child_req_indices = + indices_referred_by_exprs(child_schema, exprs.iter())?; + Ok((merge_slices(&necessary_indices, &child_req_indices), false)) + }) + .collect::>>()? + } LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::Extension(_) | LogicalPlan::DescribeTable(_) => { // These operators have no inputs, so stop the optimization process. - // TODO: Add support for `LogicalPlan::Extension`. 
return Ok(None); } LogicalPlan::Projection(proj) => { @@ -899,21 +926,161 @@ fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result #[cfg(test)] mod tests { + use std::fmt::Formatter; use std::sync::Arc; use crate::optimize_projections::OptimizeProjections; - use crate::test::{assert_optimized_plan_eq, test_table_scan}; + use crate::test::{ + assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name, + }; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Result, TableReference}; + use datafusion_common::{Column, DFSchemaRef, JoinType, Result, TableReference}; use datafusion_expr::{ - binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not, - table_scan, try_cast, when, Expr, Like, LogicalPlan, Operator, + binary_expr, build_join_schema, col, count, lit, + logical_plan::builder::LogicalPlanBuilder, not, table_scan, try_cast, when, + BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + UserDefinedLogicalNodeCore, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } + #[derive(Debug, Hash, PartialEq, Eq)] + struct NoOpUserDefined { + exprs: Vec, + schema: DFSchemaRef, + input: Arc, + } + + impl NoOpUserDefined { + fn new(schema: DFSchemaRef, input: Arc) -> Self { + Self { + exprs: vec![], + schema, + input, + } + } + + fn with_exprs(mut self, exprs: Vec) -> Self { + self.exprs = exprs; + self + } + } + + impl UserDefinedLogicalNodeCore for NoOpUserDefined { + fn name(&self) -> &str { + "NoOpUserDefined" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoOpUserDefined") + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + Self { + exprs: exprs.to_vec(), + input: Arc::new(inputs[0].clone()), + schema: self.schema.clone(), + } + } + + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + // Since schema is same. Output columns requires their corresponding version in the input columns. 
+ Some(vec![output_columns.to_vec()]) + } + } + + #[derive(Debug, Hash, PartialEq, Eq)] + struct UserDefinedCrossJoin { + exprs: Vec, + schema: DFSchemaRef, + left_child: Arc, + right_child: Arc, + } + + impl UserDefinedCrossJoin { + fn new(left_child: Arc, right_child: Arc) -> Self { + let left_schema = left_child.schema(); + let right_schema = right_child.schema(); + let schema = Arc::new( + build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(), + ); + Self { + exprs: vec![], + schema, + left_child, + right_child, + } + } + } + + impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin { + fn name(&self) -> &str { + "UserDefinedCrossJoin" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.left_child, &self.right_child] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "UserDefinedCrossJoin") + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert_eq!(inputs.len(), 2); + Self { + exprs: exprs.to_vec(), + left_child: Arc::new(inputs[0].clone()), + right_child: Arc::new(inputs[1].clone()), + schema: self.schema.clone(), + } + } + + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + let left_child_len = self.left_child.schema().fields().len(); + let mut left_reqs = vec![]; + let mut right_reqs = vec![]; + for &out_idx in output_columns { + if out_idx < left_child_len { + left_reqs.push(out_idx); + } else { + // Output indices further than the left_child_len + // comes from right children + right_reqs.push(out_idx - left_child_len) + } + } + Some(vec![left_reqs, right_reqs]) + } + } + #[test] fn merge_two_projection() -> Result<()> { let table_scan = test_table_scan()?; @@ -1192,4 +1359,112 @@ mod tests { \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } + + // Since only column `a` is referred at the output. Scan should only contain projection=[a]. + // User defined node should be able to propagate necessary expressions by its parent to its child. + #[test] + fn test_user_defined_logical_plan_node() -> Result<()> { + let table_scan = test_table_scan()?; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + )), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Only column `a` is referred at the output. However, User defined node itself uses column `b` + // during its operation. Hence, scan should contain projection=[a, b]. + // User defined node should be able to propagate necessary expressions by its parent, as well as its own + // required expressions. + #[test] + fn test_user_defined_logical_plan_node2() -> Result<()> { + let table_scan = test_table_scan()?; + let exprs = vec![Expr::Column(Column::from_qualified_name("b"))]; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new( + NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + ) + .with_exprs(exprs), + ), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? 
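// Illustrative sketch (not part of the patch): the index routing that
// `necessary_children_exprs` performs for a two-child node whose output schema is the
// left child's fields followed by the right child's fields, as in the
// `UserDefinedCrossJoin` above. Output column indices below `left_len` are required from
// the left child; the rest are re-based and required from the right child.
fn split_required_indices(output_columns: &[usize], left_len: usize) -> (Vec<usize>, Vec<usize>) {
    let mut left = vec![];
    let mut right = vec![];
    for &idx in output_columns {
        if idx < left_len {
            left.push(idx);
        } else {
            right.push(idx - left_len);
        }
    }
    (left, right)
}

fn main() {
    // Left child has 3 columns (a, b, c); the parent needs output columns 0, 2 and 3.
    // That maps to [a, c] from the left child and [a] from the right child.
    assert_eq!(split_required_indices(&[0, 2, 3], 3), (vec![0, 2], vec![0]));
}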
+ .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a, b]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` + // during its operation. Hence, scan should contain projection=[a, b, c]. + // User defined node should be able to propagate necessary expressions by its parent, as well as its own + // required expressions. Expressions doesn't have to be just column. Requirements from complex expressions + // should be propagated also. + #[test] + fn test_user_defined_logical_plan_node3() -> Result<()> { + let table_scan = test_table_scan()?; + let left_expr = Expr::Column(Column::from_qualified_name("b")); + let right_expr = Expr::Column(Column::from_qualified_name("c")); + let binary_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(left_expr), + Operator::Plus, + Box::new(right_expr), + )); + let exprs = vec![binary_expr]; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new( + NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + ) + .with_exprs(exprs), + ), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a, b, c]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Columns `l.a`, `l.c`, `r.a` is referred at the output. + // User defined node should be able to propagate necessary expressions by its parent, to its children. + // Even if it has multiple children. + // left child should have `projection=[a, c]`, and right side should have `projection=[a]`. + #[test] + fn test_user_defined_logical_plan_node4() -> Result<()> { + let left_table = test_table_scan_with_name("l")?; + let right_table = test_table_scan_with_name("r")?; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new(UserDefinedCrossJoin::new( + Arc::new(left_table.clone()), + Arc::new(right_table.clone()), + )), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])? 
+ .build()?; + + let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\ + \n UserDefinedCrossJoin\ + \n TableScan: l projection=[a, c]\ + \n TableScan: r projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5b5bca75ddb0..1cbe7decf15b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,7 +21,7 @@ use std::borrow::Cow; use std::collections::HashSet; use std::ops::Not; -use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; +use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -175,7 +175,6 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); - let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); if self.canonicalize { @@ -190,8 +189,6 @@ impl ExprSimplifier { .data()? .rewrite(&mut simplifier) .data()? - .rewrite(&mut inlist_simplifier) - .data()? .rewrite(&mut guarantee_rewriter) .data()? // run both passes twice to try an minimize simplifications that we missed @@ -408,11 +405,12 @@ struct ConstEvaluator<'a> { input_batch: RecordBatch, } +#[allow(dead_code)] /// The simplify result of ConstEvaluator enum ConstSimplifyResult { // Expr was simplifed and contains the new expression Simplified(ScalarValue), - // Evalaution encountered an error, contains the original expression + // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -1452,13 +1450,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Operator::Or, right, }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { - let left = as_inlist(left.as_ref()); - let right = as_inlist(right.as_ref()); - - let lhs = left.unwrap(); - let rhs = right.unwrap(); - let lhs = lhs.into_owned(); - let rhs = rhs.into_owned(); + let lhs = to_inlist(*left).unwrap(); + let rhs = to_inlist(*right).unwrap(); let mut seen: HashSet = HashSet::new(); let list = lhs .list @@ -1473,7 +1466,123 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { negated: false, }; - return Ok(Transformed::yes(Expr::InList(merged_inlist))); + Transformed::yes(Expr::InList(merged_inlist)) + } + + // Simplify expressions that is guaranteed to be true or false to a literal boolean expression + // + // Rules: + // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists + // Intersection: + // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` + // 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` + // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` + // Union: + // 4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` + // # This rule is handled by `or_in_list_simplifier.rs` + // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` + // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression + // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` + // 7. 
`a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` + // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + false, + false, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_intersection(l1, l2, false).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_union(l1, l2, true).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + false, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_except(l1, l2).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + false, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_except(l2, l1).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_intersection(l1, l2, true).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } } // no additional rewrites possible @@ -1482,6 +1591,22 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } +// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 +fn are_inlist_and_eq_and_match_neg( + left: &Expr, + right: &Expr, + is_left_neg: bool, + is_right_neg: bool, +) -> bool { + match (left, right) { + (Expr::InList(l), Expr::InList(r)) => { + l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg + } + _ => false, + } +} + +// TODO: We might not need this after defer pattern for Box is stabilized. 
https://github.com/rust-lang/rust/issues/87121 fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); @@ -1519,6 +1644,78 @@ fn as_inlist(expr: &Expr) -> Option> { } } +fn to_inlist(expr: Expr) -> Option { + match expr { + Expr::InList(inlist) => Some(inlist), + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Literal(_)) => Some(InList { + expr: left, + list: vec![*right], + negated: false, + }), + (Expr::Literal(_), Expr::Column(_)) => Some(InList { + expr: right, + list: vec![*left], + negated: false, + }), + _ => None, + }, + _ => None, + } +} + +/// Return the union of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { + // extend the list in l1 with the elements in l2 that are not already in l1 + let l1_items: HashSet<_> = l1.list.iter().collect(); + + // keep all l2 items that do not also appear in l1 + let keep_l2: Vec<_> = l2 + .list + .into_iter() + .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) + .collect(); + + l1.list.extend(keep_l2); + l1.negated = negated; + Ok(Expr::InList(l1)) +} + +/// Return the intersection of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // remove all items from l1 that are not in l2 + l1.list.retain(|e| l2_items.contains(e)); + + // e in () is always false + // e not in () is always true + if l1.list.is_empty() { + return Ok(lit(negated)); + } + Ok(Expr::InList(l1)) +} + +/// Return the all items in l1 that are not in l2 +/// maintaining the order of the elements in the two lists +fn inlist_except(mut l1: InList, l2: InList) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // keep only items from l1 that are not in l2 + l1.list.retain(|e| !l2_items.contains(e)); + + if l1.list.is_empty() { + return Ok(lit(false)); + } + Ok(Expr::InList(l1)) +} + #[cfg(test)] mod tests { use std::{ diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 5d1cf27827a9..9dcb8ed15563 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -19,12 +19,10 @@ use super::THRESHOLD_INLINE_INLIST; -use std::collections::HashSet; - use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; -use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; +use datafusion_expr::Expr; pub(super) struct ShortenInListSimplifier {} @@ -97,121 +95,3 @@ impl TreeNodeRewriter for ShortenInListSimplifier { Ok(Transformed::no(expr)) } } - -pub(super) struct InListSimplifier {} - -impl InListSimplifier { - pub(super) fn new() -> Self { - Self {} - } -} - -impl TreeNodeRewriter for InListSimplifier { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - // Simplify expressions that is guaranteed to be true or false to a literal boolean expression - // - // Rules: - // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists - // Intersection: - // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` - // 2. 
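// Illustrative sketch (not part of the patch): the order-preserving set algebra behind
// the `IN` / `NOT IN` rewrites above, on plain `Vec<i32>` lists instead of `InList`
// expressions. Order is preserved by filtering or extending the left list in place
// rather than collecting both sides into a set.
use std::collections::HashSet;

fn union(mut l1: Vec<i32>, l2: Vec<i32>) -> Vec<i32> {
    let seen: HashSet<i32> = l1.iter().copied().collect();
    l1.extend(l2.into_iter().filter(|v| !seen.contains(v)));
    l1
}

fn intersection(mut l1: Vec<i32>, l2: &[i32]) -> Vec<i32> {
    let keep: HashSet<i32> = l2.iter().copied().collect();
    l1.retain(|v| keep.contains(v));
    l1 // an empty result means `a IN ()`, which folds to `false`
}

fn except(mut l1: Vec<i32>, l2: &[i32]) -> Vec<i32> {
    let drop: HashSet<i32> = l2.iter().copied().collect();
    l1.retain(|v| !drop.contains(v));
    l1
}

fn main() {
    // `a IN (1,2,3) AND a IN (2,3,4)`        -> `a IN (2,3)`
    assert_eq!(intersection(vec![1, 2, 3], &[2, 3, 4]), vec![2, 3]);
    // `a NOT IN (1,2,3) AND a NOT IN (4,5)`  -> `a NOT IN (1,2,3,4,5)`
    assert_eq!(union(vec![1, 2, 3], vec![4, 5]), vec![1, 2, 3, 4, 5]);
    // `a NOT IN (1,2,3,4) AND a IN (1,2,3,4,5)` -> `a IN (5)`, i.e. `a = 5`
    assert_eq!(except(vec![1, 2, 3, 4, 5], &[1, 2, 3, 4]), vec![5]);
}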
`a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` - // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` - // Union: - // 4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` - // # This rule is handled by `or_in_list_simplifier.rs` - // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` - // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression - // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` - // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` - // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr.clone() { - match (*left, op, *right) { - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && !l2.negated => - { - return inlist_intersection(l1, l2, false).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_union(l1, l2, true).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && l2.negated => - { - return inlist_except(l1, l2).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && !l2.negated => - { - return inlist_except(l2, l1).map(Transformed::yes); - } - (Expr::InList(l1), Operator::Or, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_intersection(l1, l2, true).map(Transformed::yes); - } - (left, op, right) => { - // put the expression back together - return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { - left: Box::new(left), - op, - right: Box::new(right), - }))); - } - } - } - - Ok(Transformed::no(expr)) - } -} - -/// Return the union of two inlist expressions -/// maintaining the order of the elements in the two lists -fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { - // extend the list in l1 with the elements in l2 that are not already in l1 - let l1_items: HashSet<_> = l1.list.iter().collect(); - - // keep all l2 items that do not also appear in l1 - let keep_l2: Vec<_> = l2 - .list - .into_iter() - .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) - .collect(); - - l1.list.extend(keep_l2); - l1.negated = negated; - Ok(Expr::InList(l1)) -} - -/// Return the intersection of two inlist expressions -/// maintaining the order of the elements in the two lists -fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // remove all items from l1 that are not in l2 - l1.list.retain(|e| l2_items.contains(e)); - - // e in () is always false - // e not in () is always true - if l1.list.is_empty() { - return Ok(lit(negated)); - } - Ok(Expr::InList(l1)) -} - -/// Return the all items in l1 that are not in l2 -/// maintaining the order of the elements in the two lists -fn inlist_except(mut l1: InList, l2: InList) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // keep only items from l1 that are not in l2 - l1.list.retain(|e| !l2_items.contains(e)); - - if l1.list.is_empty() { - return Ok(lit(false)); - } - Ok(Expr::InList(l1)) -} diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d63ad9bb4a3a..baca00bea724 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -37,12 +37,10 @@ 
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] default = [ "crypto_expressions", "regex_expressions", - "unicode_expressions", "encoding_expressions", ] encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] -unicode_expressions = ["unicode-segmentation"] [dependencies] ahash = { version = "0.8", default-features = false, features = [ @@ -73,8 +71,6 @@ petgraph = "0.6.2" rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } -unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] criterion = "0.5" diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 846431034c96..cee679863870 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -370,13 +370,16 @@ pub fn create_aggregate_expr( ) .with_ignore_nulls(ignore_nulls), ), - (AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - ordering_req.to_vec(), - ordering_types, - )), + (AggregateFunction::LastValue, _) => Arc::new( + expressions::LastValue::new( + input_phy_exprs[0].clone(), + name, + input_phy_types[0].clone(), + ordering_req.to_vec(), + ordering_types, + ) + .with_ignore_nulls(ignore_nulls), + ), (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 71782fcc5f9b..9c5605f495ea 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -35,7 +35,7 @@ use arrow_array::types::{ TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::Accumulator; use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator; @@ -47,7 +47,7 @@ use crate::binary_map::OutputType; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -/// Expression for a COUNT(DISTINCT) aggregation. +/// Expression for a `COUNT(DISTINCT)` aggregation. #[derive(Debug)] pub struct DistinctCount { /// Column name @@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount { use TimeUnit::*; Ok(match &self.state_data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), @@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount { OutputType::Binary, )), + // Use the generic accumulator based on `ScalarValue` for all other types _ => Box::new(DistinctCountAccumulator { values: HashSet::default(), state_data_type: self.state_data_type.clone(), @@ -183,7 +185,11 @@ impl PartialEq for DistinctCount { } /// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. Some types have specialized accumulators that are (much) +/// [`ScalarValue`]. 
+/// +/// It stores intermediate results as a `ListArray` +/// +/// Note that many types have specialized accumulators that are (much) /// more efficient such as [`PrimitiveDistinctCountAccumulator`] and /// [`BytesDistinctCountAccumulator`] #[derive(Debug)] @@ -193,8 +199,9 @@ struct DistinctCountAccumulator { } impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * number of batches - // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types + // calculating the size for fixed length values, taking first batch size * + // number of batches This method is faster than .full_size(), however it is + // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -207,7 +214,8 @@ impl DistinctCountAccumulator { + std::mem::size_of::() } - // calculates the size as accurate as possible, call to this method is expensive + // calculates the size as accurately as possible. Note that calling this + // method is expensive fn full_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -221,6 +229,7 @@ impl DistinctCountAccumulator { } impl Accumulator for DistinctCountAccumulator { + /// Returns the distinct values seen so far as (one element) ListArray. fn state(&mut self) -> Result> { let scalars = self.values.iter().cloned().collect::>(); let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); @@ -246,6 +255,11 @@ impl Accumulator for DistinctCountAccumulator { }) } + /// Merges multiple sets of distinct values into the current set. + /// + /// The input to this function is a `ListArray` with **multiple** rows, + /// where each row contains the values from a partial aggregate's phase (e.g. + /// the result of calling `Self::state` on multiple accumulators). fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); @@ -253,8 +267,15 @@ impl Accumulator for DistinctCountAccumulator { assert_eq!(states.len(), 1, "array_agg states must be singleton!"); let array = &states[0]; let list_array = array.as_list::(); - let inner_array = list_array.value(0); - self.update_batch(&[inner_array]) + for inner_array in list_array.iter() { + let Some(inner_array) = inner_array else { + return internal_err!( + "Intermediate results of COUNT DISTINCT should always be non null" + ); + }; + self.update_batch(&[inner_array])?; + } + Ok(()) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 17dd3ef1206d..6d6e32a14987 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -393,6 +393,7 @@ pub struct LastValue { expr: Arc, ordering_req: LexOrdering, requirement_satisfied: bool, + ignore_nulls: bool, } impl LastValue { @@ -412,9 +413,15 @@ impl LastValue { expr, ordering_req, requirement_satisfied, + ignore_nulls: false, } } + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } + /// Returns the name of the aggregate expression. 
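// Illustrative sketch (not part of the patch): why `merge_batch` must walk every row of
// the intermediate state rather than only row 0. Each row is the distinct set produced
// by one partial accumulator; merging only the first row would silently drop the rest.
// Plain `HashSet`s and `Vec`s stand in for the `ScalarValue` sets and `ListArray` rows.
use std::collections::HashSet;

fn merge_states(states: &[Vec<i64>]) -> HashSet<i64> {
    let mut distinct = HashSet::new();
    for partial in states {
        // One "row" per partial aggregate; every one of them must be folded in.
        distinct.extend(partial.iter().copied());
    }
    distinct
}

fn main() {
    let partials = vec![vec![1, 2, 3], vec![3, 4], vec![5]];
    let merged = merge_states(&partials);
    assert_eq!(merged.len(), 5); // COUNT(DISTINCT) over all partials = 5
}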
pub fn name(&self) -> &str { &self.name @@ -483,6 +490,7 @@ impl AggregateExpr for LastValue { &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), + self.ignore_nulls, ) .map(|acc| { Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ @@ -528,6 +536,7 @@ impl AggregateExpr for LastValue { &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), + self.ignore_nulls, ) .map(|acc| { Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ @@ -561,6 +570,8 @@ struct LastValueAccumulator { ordering_req: LexOrdering, // Stores whether incoming data already satisfies the ordering requirement. requirement_satisfied: bool, + // Ignore null values. + ignore_nulls: bool, } impl LastValueAccumulator { @@ -569,6 +580,7 @@ impl LastValueAccumulator { data_type: &DataType, ordering_dtypes: &[DataType], ordering_req: LexOrdering, + ignore_nulls: bool, ) -> Result { let orderings = ordering_dtypes .iter() @@ -581,6 +593,7 @@ impl LastValueAccumulator { orderings, ordering_req, requirement_satisfied, + ignore_nulls, }) } @@ -597,7 +610,17 @@ impl LastValueAccumulator { }; if self.requirement_satisfied { // Get last entry according to the order of data: - return Ok((!value.is_empty()).then_some(value.len() - 1)); + if self.ignore_nulls { + // If ignoring nulls, find the last non-null value. + for i in (0..value.len()).rev() { + if !value.is_null(i) { + return Ok(Some(i)); + } + } + return Ok(None); + } else { + return Ok((!value.is_empty()).then_some(value.len() - 1)); + } } let sort_columns = ordering_values .iter() @@ -611,8 +634,20 @@ impl LastValueAccumulator { } }) .collect::>(); - let indices = lexsort_to_indices(&sort_columns, Some(1))?; - Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + + if self.ignore_nulls { + let indices = lexsort_to_indices(&sort_columns, None)?; + // If ignoring nulls, find the last non-null value. 
+ for index in indices.iter().flatten() { + if !value.is_null(index as usize) { + return Ok(Some(index as usize)); + } + } + Ok(None) + } else { + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + } } fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { @@ -746,7 +781,7 @@ mod tests { let mut first_accumulator = FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -814,13 +849,13 @@ mod tests { // LastValueAccumulator let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; last_accumulator.update_batch(&[arrs[0].clone()])?; let state1 = last_accumulator.state()?; let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; last_accumulator.update_batch(&[arrs[1].clone()])?; let state2 = last_accumulator.state()?; @@ -836,7 +871,7 @@ mod tests { } let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; last_accumulator.merge_batch(&states)?; let merged_state = last_accumulator.state()?; diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs deleted file mode 100644 index c3c0f4c82282..000000000000 --- a/datafusion/physical-expr/src/array_expressions.rs +++ /dev/null @@ -1,423 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Array expressions - -use std::sync::Arc; - -use arrow::array::*; -use arrow::buffer::OffsetBuffer; -use arrow::datatypes::{DataType, Field}; -use arrow_buffer::NullBuffer; - -use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; -use datafusion_common::utils::array_into_list_array; -use datafusion_common::{exec_err, plan_err, Result}; - -/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. -/// -/// # Arguments -/// -/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. 
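// Illustrative sketch (not part of the patch): the `ignore_nulls` behaviour of the
// last-value accumulator on already-ordered input, with `Option<i64>` rows standing in
// for an Arrow array and its null mask. The scan walks backwards and stops at the first
// non-null entry; if every row is null there is no last value to report.
fn last_non_null_index(values: &[Option<i64>]) -> Option<usize> {
    (0..values.len()).rev().find(|&i| values[i].is_some())
}

fn main() {
    let col = [Some(1), Some(7), None, None];
    // With ignore_nulls the answer is index 1 (value 7), not the trailing null.
    assert_eq!(last_non_null_index(&col), Some(1));
    assert_eq!(last_non_null_index(&[None, None]), None);
}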
-/// -/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. -/// -/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. -/// -/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. -/// -/// # Returns -/// -/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. -/// -/// # Example -/// -/// ```text -/// compare_element_to_list( -/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] -/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] -/// -/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] -/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] -/// ) -/// ``` -fn compare_element_to_list( - list_array_row: &dyn Array, - element_array: &dyn Array, - row_index: usize, - eq: bool, -) -> Result { - if list_array_row.data_type() != element_array.data_type() { - return exec_err!( - "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", - list_array_row.data_type(), - element_array.data_type() - ); - } - - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let res = match element_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - DataType::LargeList(_) => { - // compare each element of the from array - let element_array_row_inner = - as_large_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_large_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - _ => { - let element_arr = Scalar::new(element_array_row); - // use not_distinct so we can compare NULL - if eq { - arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? - } else { - arrow_ord::cmp::distinct(&list_array_row, &element_arr)? - } - } - }; - - Ok(res) -} - -/// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` or 'LargeListArray' depending on the offset size. 
-/// -/// # Example (non nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are non nested -/// would return a single new `ListArray`, where each row was a list -/// of 2 elements: -/// -/// ```text -/// ┌─────────┐ ┌─────────┐ ┌──────────────┐ -/// │ ┌─────┐ │ │ ┌─────┐ │ │ ┌──────────┐ │ -/// │ │ A │ │ │ │ X │ │ │ │ [A, X] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │NULL │ │ │ │ Y │ │──────────▶│ │[NULL, Y] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │ C │ │ │ │ Z │ │ │ │ [C, Z] │ │ -/// │ └─────┘ │ │ └─────┘ │ │ └──────────┘ │ -/// └─────────┘ └─────────┘ └──────────────┘ -/// col1 col2 output -/// ``` -/// -/// # Example (nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are lists -/// would return a single new `ListArray`, where each row was a list -/// of the corresponding elements of col1 and col2. -/// -/// ``` text -/// ┌──────────────┐ ┌──────────────┐ ┌─────────────────────────────┐ -/// │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌────────────────────────┐ │ -/// │ │ [A, X] │ │ │ │ [] │ │ │ │ [[A, X], []] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────┤ │ -/// │ │[NULL, Y] │ │ │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────│ │ -/// │ │ [C, Z] │ │ │ │ NULL │ │ │ │ [[C, Z], NULL] │ │ -/// │ └──────────┘ │ │ └──────────┘ │ │ └────────────────────────┘ │ -/// └──────────────┘ └──────────────┘ └─────────────────────────────┘ -/// col1 col2 output -/// ``` -fn array_array( - args: &[ArrayRef], - data_type: DataType, -) -> Result { - // do not accept 0 arguments. - if args.is_empty() { - return plan_err!("Array requires at least one argument"); - } - - let mut data = vec![]; - let mut total_len = 0; - for arg in args { - let arg_data = if arg.as_any().is::() { - ArrayData::new_empty(&data_type) - } else { - arg.to_data() - }; - total_len += arg_data.len(); - data.push(arg_data); - } - - let mut offsets: Vec = Vec::with_capacity(total_len); - offsets.push(O::usize_as(0)); - - let capacity = Capacities::Array(total_len); - let data_ref = data.iter().collect::>(); - let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); - - let num_rows = args[0].len(); - for row_idx in 0..num_rows { - for (arr_idx, arg) in args.iter().enumerate() { - if !arg.as_any().is::() - && !arg.is_null(row_idx) - && arg.is_valid(row_idx) - { - mutable.extend(arr_idx, row_idx, row_idx + 1); - } else { - mutable.extend_nulls(1); - } - } - offsets.push(O::usize_as(mutable.len())); - } - let data = mutable.freeze(); - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type, true)), - OffsetBuffer::new(offsets.into()), - arrow_array::make_array(data), - None, - )?)) -} - -/// `make_array` SQL function -pub fn make_array(arrays: &[ArrayRef]) -> Result { - let mut data_type = DataType::Null; - for arg in arrays { - let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&DataType::Null) { - data_type = arg_data_type.clone(); - break; - } - } - - match data_type { - // Either an empty array or all nulls: - DataType::Null => { - let array = - new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); - Ok(Arc::new(array_into_list_array(array))) - } - DataType::LargeList(..) => array_array::(arrays, data_type), - _ => array_array::(arrays, data_type), - } -} - -/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences -/// of `from_array[i]`, `to_array[i]`. 
-/// -/// The type of each **element** in `list_array` must be the same as the type of -/// `from_array` and `to_array`. This function also handles nested arrays -/// ([`ListArray`] of [`ListArray`]s) -/// -/// For example, when called to replace a list array (where each element is a -/// list of int32s, the second and third argument are int32 arrays, and the -/// fourth argument is the number of occurrences to replace -/// -/// ```text -/// general_replace( -/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) -/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) -/// ) -/// ``` -fn general_replace( - list_array: &GenericListArray, - from_array: &ArrayRef, - to_array: &ArrayRef, - arr_n: Vec, -) -> Result { - // Build up the offsets for the final output array - let mut offsets: Vec = vec![O::usize_as(0)]; - let values = list_array.values(); - let original_data = values.to_data(); - let to_data = to_array.to_data(); - let capacity = Capacities::Array(original_data.len()); - - // First array is the original array, second array is the element to replace with. - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &to_data], - false, - capacity, - ); - - let mut valid = BooleanBufferBuilder::new(list_array.len()); - - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - if list_array.is_null(row_index) { - offsets.push(offsets[row_index]); - valid.append(false); - continue; - } - - let start = offset_window[0]; - let end = offset_window[1]; - - let list_array_row = list_array.value(row_index); - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let eq_array = - compare_element_to_list(&list_array_row, &from_array, row_index, true)?; - - let original_idx = O::usize_as(0); - let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; - let mut counter = 0; - - // All elements are false, no need to replace, just copy original data - if eq_array.false_count() == eq_array.len() { - mutable.extend( - original_idx.to_usize().unwrap(), - start.to_usize().unwrap(), - end.to_usize().unwrap(), - ); - offsets.push(offsets[row_index] + (end - start)); - valid.append(true); - continue; - } - - for (i, to_replace) in eq_array.iter().enumerate() { - let i = O::usize_as(i); - if let Some(true) = to_replace { - mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); - counter += 1; - if counter == n { - // copy original data for any matches past n - mutable.extend( - original_idx.to_usize().unwrap(), - (start + i).to_usize().unwrap() + 1, - end.to_usize().unwrap(), - ); - break; - } - } else { - // copy original data for false / null matches - mutable.extend( - original_idx.to_usize().unwrap(), - (start + i).to_usize().unwrap(), - (start + i).to_usize().unwrap() + 1, - ); - } - } - - offsets.push(offsets[row_index] + (end - start)); - valid.append(true); - } - - let data = mutable.freeze(); - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::::new(offsets.into()), - arrow_array::make_array(data), - Some(NullBuffer::new(valid.finish())), - )?)) -} - -pub fn array_replace(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace expects three arguments"); - } - - // replace at most one occurence for each element - let arr_n = vec![1; args[0].len()]; - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let 
list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => exec_err!("array_replace does not support type '{array_type:?}'."), - } -} - -pub fn array_replace_n(args: &[ArrayRef]) -> Result { - if args.len() != 4 { - return exec_err!("array_replace_n expects four arguments"); - } - - // replace the specified number of occurences - let arr_n = as_int64_array(&args[3])?.values().to_vec(); - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => { - exec_err!("array_replace_n does not support type '{array_type:?}'.") - } - } -} - -pub fn array_replace_all(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace_all expects three arguments"); - } - - // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; args[0].len()]; - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => { - exec_err!("array_replace_all does not support type '{array_type:?}'.") - } - } -} diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr/src/binary_map.rs index b661f0a74148..6c3a452a8611 100644 --- a/datafusion/physical-expr/src/binary_map.rs +++ b/datafusion/physical-expr/src/binary_map.rs @@ -280,7 +280,7 @@ where /// # Returns /// /// The payload value for the entry, either the existing value or - /// the the newly inserted value + /// the newly inserted value /// /// # Safety: /// diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 280535f5e6be..58519c61cf1f 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -535,7 +535,7 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { - let entries = vec![ + let entries = [ EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), // This group is meaningless should be removed EquivalenceClass::new(vec![lit(3), lit(3)]), @@ -543,11 +543,11 @@ mod tests { ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. 
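To make the intent of this test concrete, here is a minimal, self-contained sketch of the reduction the test expects, using plain `Vec`s rather than DataFusion's `EquivalenceClass`/`EquivalenceGroup` types (which operate on expressions, not integers): duplicates inside a class are dropped, and a class left with a single member carries no equivalence information, so it is removed entirely.

    // Conceptual sketch only; DataFusion's actual implementation works on
    // physical expressions, not integers.
    fn remove_redundant(classes: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        classes
            .into_iter()
            .map(|mut class| {
                // Deduplicate members of each class (sort first so dedup
                // sees duplicates as neighbours).
                class.sort();
                class.dedup();
                class
            })
            // A class with fewer than two members states no equivalence at all.
            .filter(|class| class.len() > 1)
            .collect()
    }

    fn main() {
        // Mirrors the literals used in the test: {1, 1, 2}, {3, 3}, {4, 5, 6}.
        let input = vec![vec![1, 1, 2], vec![3, 3], vec![4, 5, 6]];
        assert_eq!(remove_redundant(input), vec![vec![1, 2], vec![4, 5, 6]]);
    }
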
- let expected = vec![ + let expected = [ EquivalenceClass::new(vec![lit(1), lit(2)]), EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), ]; - let mut eq_groups = EquivalenceGroup::new(entries); + let mut eq_groups = EquivalenceGroup::new(entries.to_vec()); eq_groups.remove_redundant_entries(); let eq_groups = eq_groups.classes; diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index c7cb9e5f530e..1364d3a8c028 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -746,7 +746,7 @@ mod tests { // Generate a data that satisfies properties given let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = vec![ + let col_exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, col("c", &test_schema)?, @@ -815,7 +815,7 @@ mod tests { Operator::Plus, col("b", &test_schema)?, )) as Arc; - let exprs = vec![ + let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, col("c", &test_schema)?, diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a08e85b24162..5eb9d6eb1b86 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -1793,7 +1793,7 @@ mod tests { Operator::Plus, col("b", &test_schema)?, )) as Arc; - let exprs = vec![ + let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, col("c", &test_schema)?, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 994c17309ec0..a1e471bdd422 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -30,26 +30,29 @@ //! an argument i32 is passed to a function that supports f64, the //! argument is automatically is coerced to f64. -use crate::sort_properties::SortProperties; -use crate::{ - array_expressions, conditional_expressions, math_expressions, string_expressions, - PhysicalExpr, ScalarFunctionExpr, -}; +use std::ops::Neg; +use std::sync::Arc; + use arrow::{ array::ArrayRef, - compute::kernels::length::{bit_length, length}, - datatypes::{DataType, Int32Type, Int64Type, Schema}, + datatypes::{DataType, Schema}, }; use arrow_array::Array; + use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; +use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, }; -use std::ops::Neg; -use std::sync::Arc; + +use crate::sort_properties::SortProperties; +use crate::{ + conditional_expressions, math_expressions, string_expressions, PhysicalExpr, + ScalarFunctionExpr, +}; /// Create a physical (function) expression. /// This function errors when `args`' can't be coerced to a valid argument type of the function. 
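The hunks that follow stop building the concrete scalar implementation eagerly in `create_physical_expr` and instead store a `ScalarFunctionDefinition`, resolving the callable only when the expression is evaluated (as `scalar_function.rs` later in this patch does via `create_physical_fun`). Below is a minimal, hypothetical sketch of that deferral pattern; the names `FunctionDefinition`, `ScalarFnExpr`, and `resolve` are stand-ins, not DataFusion's API.

    // Hypothetical stand-in types; DataFusion's real ScalarFunctionDefinition
    // and ScalarFunctionExpr carry much more (argument expressions, return
    // type, monotonicity, ...).
    type Impl = fn(&[i64]) -> i64;

    enum FunctionDefinition {
        BuiltIn(&'static str),
    }

    fn resolve(def: &FunctionDefinition) -> Result<Impl, String> {
        match def {
            FunctionDefinition::BuiltIn(name) if *name == "abs" => {
                let f: Impl = |args| args[0].abs();
                Ok(f)
            }
            FunctionDefinition::BuiltIn(name) => Err(format!("unknown function {name}")),
        }
    }

    struct ScalarFnExpr {
        def: FunctionDefinition,
    }

    impl ScalarFnExpr {
        // The lookup happens here, at evaluation time, not when the physical
        // expression is planned.
        fn evaluate(&self, args: &[i64]) -> Result<i64, String> {
            let f = resolve(&self.def)?;
            Ok(f(args))
        }
    }

    fn main() {
        let expr = ScalarFnExpr {
            def: FunctionDefinition::BuiltIn("abs"),
        };
        assert_eq!(expr.evaluate(&[-5]), Ok(5));
    }
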
@@ -57,7 +60,7 @@ pub fn create_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, - execution_props: &ExecutionProps, + _execution_props: &ExecutionProps, ) -> Result> { let input_expr_types = input_phy_exprs .iter() @@ -69,14 +72,12 @@ pub fn create_physical_expr( let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = - create_physical_fun(fun, execution_props)?; - let monotonicity = fun.monotonicity(); + let fun_def = ScalarFunctionDefinition::BuiltIn(*fun); Ok(Arc::new(ScalarFunctionExpr::new( &format!("{fun}"), - fun_expr, + fun_def, input_phy_exprs.to_vec(), data_type, monotonicity, @@ -84,26 +85,6 @@ pub fn create_physical_expr( ))) } -#[cfg(feature = "unicode_expressions")] -macro_rules! invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => {{ - use crate::unicode_expressions; - unicode_expressions::$FUNC::<$T> - }}; -} - -#[cfg(not(feature = "unicode_expressions"))] -macro_rules! invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => { - |_: &[ArrayRef]| -> Result { - internal_err!( - "function {} requires compilation with feature flag: unicode_expressions.", - $NAME - ) - } - }; -} - #[derive(Debug, Clone, Copy)] pub enum Hint { /// Indicates the argument needs to be padded if it is scalar @@ -195,14 +176,9 @@ where /// Create a physical scalar function. pub fn create_physical_fun( fun: &BuiltinScalarFunction, - _execution_props: &ExecutionProps, ) -> Result { Ok(match fun { // math functions - BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), - BuiltinScalarFunction::Acosh => Arc::new(math_expressions::acosh), - BuiltinScalarFunction::Asinh => Arc::new(math_expressions::asinh), - BuiltinScalarFunction::Atanh => Arc::new(math_expressions::atanh), BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil), BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos), BuiltinScalarFunction::Cosh => Arc::new(math_expressions::cosh), @@ -221,9 +197,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Lcm => { Arc::new(|args| make_scalar_function_inner(math_expressions::lcm)(args)) } - BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), - BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), - BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), BuiltinScalarFunction::Nanvl => { Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) } @@ -244,84 +217,13 @@ pub fn create_physical_fun( BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args)) } - BuiltinScalarFunction::Atan2 => { - Arc::new(|args| make_scalar_function_inner(math_expressions::atan2)(args)) - } BuiltinScalarFunction::Log => { Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args)) } BuiltinScalarFunction::Cot => { Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } - - // array functions - BuiltinScalarFunction::ArrayReplace => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace)(args) - }), - BuiltinScalarFunction::ArrayReplaceN => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace_n)(args) - }), - BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace_all)(args) - }), - // string functions - BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - 
make_scalar_function_inner(string_expressions::ascii::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::ascii::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function ascii"), - }), - BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| (x.len() * 8) as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), - )), - _ => unreachable!(), - }, - }), - BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function btrim"), - }), - BuiltinScalarFunction::CharacterLength => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - other => exec_err!( - "Unsupported data type {other:?} for function character_length" - ), - }) - } - BuiltinScalarFunction::Chr => { - Arc::new(|args| make_scalar_function_inner(string_expressions::chr)(args)) - } BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { @@ -338,140 +240,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function initcap") } }), - BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function left"), - }), - BuiltinScalarFunction::Lower => Arc::new(string_expressions::lower), - BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function lpad"), - }), - BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::ltrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::ltrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function ltrim"), - }), - BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), - 
ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| x.len() as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), - )), - _ => unreachable!(), - }, - }), - BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::repeat::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::repeat::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function repeat"), - }), - BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::replace::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::replace::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function replace") - } - }), - BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function reverse") - } - }), - BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function right"), - }), - BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function rpad"), - }), - BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::rtrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::rtrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function rtrim"), - }), - BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::split_part::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::split_part::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function split_part") - } - }), - BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::starts_with::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::starts_with::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function starts_with") - } - }), BuiltinScalarFunction::EndsWith => Arc::new(|args| 
match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::ends_with::)(args) @@ -483,141 +251,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function ends_with") } }), - BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int32Type, "strpos" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int64Type, "strpos" - ); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function strpos"), - }), - BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function substr"), - }), - BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { - DataType::Int32 => { - make_scalar_function_inner(string_expressions::to_hex::)(args) - } - DataType::Int64 => { - make_scalar_function_inner(string_expressions::to_hex::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function to_hex"), - }), - BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i32, - "translate" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i64, - "translate" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function translate") - } - }), - BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function trim"), - }), - BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), - BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), - BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::overlay::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::overlay::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function overlay"), - }), - BuiltinScalarFunction::Levenshtein => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function_inner( - string_expressions::levenshtein::, - )(args), - DataType::LargeUtf8 => make_scalar_function_inner( - string_expressions::levenshtein::, - )(args), - other => { - exec_err!("Unsupported data type {other:?} for function levenshtein") - } - }) - } - BuiltinScalarFunction::SubstrIndex => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - substr_index, - i32, - "substr_index" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let 
func = invoke_if_unicode_expressions_feature_flag!( - substr_index, - i64, - "substr_index" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function substr_index") - } - }) - } - BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - find_in_set, - Int32Type, - "find_in_set" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - find_in_set, - Int64Type, - "find_in_set" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function find_in_set") - } - }), }) } @@ -684,9 +317,6 @@ fn func_order_in_one_dimension( #[cfg(test)] mod tests { - use super::*; - use crate::expressions::lit; - use crate::expressions::try_cast; use arrow::{ array::{ Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int32Array, @@ -695,12 +325,18 @@ mod tests { datatypes::Field, record_batch::RecordBatch, }; + use datafusion_common::cast::as_uint64_array; - use datafusion_common::{exec_err, internal_err, plan_err}; + use datafusion_common::{internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::Signature; + use crate::expressions::lit; + use crate::expressions::try_cast; + + use super::*; + /// $FUNC function to test /// $ARGS arguments (vec) to pass to function /// $EXPECTED a Result> where Result allows testing errors and Option allows testing Null @@ -752,1104 +388,132 @@ mod tests { #[test] fn test_functions() -> Result<()> { - test_function!(Ascii, &[lit("x")], Ok(Some(120)), i32, Int32, Int32Array); - test_function!(Ascii, &[lit("ésoj")], Ok(Some(233)), i32, Int32, Int32Array); - test_function!( - Ascii, - &[lit("💯")], - Ok(Some(128175)), - i32, - Int32, - Int32Array - ); - test_function!( - Ascii, - &[lit("💯a")], - Ok(Some(128175)), - i32, - Int32, - Int32Array - ); - test_function!(Ascii, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); - test_function!( - Ascii, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); - test_function!( - BitLength, - &[lit("chars")], - Ok(Some(40)), - i32, - Int32, - Int32Array - ); - test_function!( - BitLength, - &[lit("josé")], - Ok(Some(40)), - i32, - Int32, - Int32Array - ); - test_function!(BitLength, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); - test_function!( - Btrim, - &[lit(" trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); test_function!( - Btrim, - &[lit(" trim")], - Ok(Some("trim")), + Concat, + &[lit("aa"), lit("bb"), lit("cc"),], + Ok(Some("aabbcc")), &str, Utf8, StringArray ); test_function!( - Btrim, - &[lit("trim ")], - Ok(Some("trim")), + Concat, + &[lit("aa"), lit(ScalarValue::Utf8(None)), lit("cc"),], + Ok(Some("aacc")), &str, Utf8, StringArray ); test_function!( - Btrim, - &[lit("\n trim \n")], - Ok(Some("\n trim \n")), + Concat, + &[lit(ScalarValue::Utf8(None))], + Ok(Some("")), &str, Utf8, StringArray ); test_function!( - Btrim, - &[lit("xyxtrimyyx"), lit("xyz"),], - Ok(Some("trim")), + ConcatWithSeparator, + &[lit("|"), lit("aa"), lit("bb"), lit("cc"),], + Ok(Some("aa|bb|cc")), &str, Utf8, StringArray ); test_function!( - Btrim, - &[lit("\nxyxtrimyyx\n"), lit("xyz\n"),], - Ok(Some("trim")), + ConcatWithSeparator, + &[lit("|"), lit(ScalarValue::Utf8(None)),], + Ok(Some("")), &str, 
Utf8, StringArray ); test_function!( - Btrim, - &[lit(ScalarValue::Utf8(None)), lit("xyz"),], + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(None)), + lit("aa"), + lit("bb"), + lit("cc"), + ], Ok(None), &str, Utf8, StringArray ); test_function!( - Btrim, - &[lit("xyxtrimyyx"), lit(ScalarValue::Utf8(None)),], - Ok(None), + ConcatWithSeparator, + &[lit("|"), lit("aa"), lit(ScalarValue::Utf8(None)), lit("cc"),], + Ok(Some("aa|cc")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( - CharacterLength, - &[lit("chars")], - Ok(Some(5)), - i32, - Int32, - Int32Array + Exp, + &[lit(ScalarValue::Int32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array ); - #[cfg(feature = "unicode_expressions")] test_function!( - CharacterLength, - &[lit("josé")], - Ok(Some(4)), - i32, - Int32, - Int32Array + Exp, + &[lit(ScalarValue::UInt32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array ); - #[cfg(feature = "unicode_expressions")] test_function!( - CharacterLength, - &[lit("")], - Ok(Some(0)), - i32, - Int32, - Int32Array + Exp, + &[lit(ScalarValue::UInt64(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array ); - #[cfg(feature = "unicode_expressions")] test_function!( - CharacterLength, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array + Exp, + &[lit(ScalarValue::Float64(Some(1.0)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array ); - #[cfg(not(feature = "unicode_expressions"))] test_function!( - CharacterLength, - &[lit("josé")], - internal_err!( - "function character_length requires compilation with feature flag: unicode_expressions." - ), - i32, - Int32, - Int32Array + Exp, + &[lit(ScalarValue::Float32(Some(1.0)))], + Ok(Some((1.0_f32).exp())), + f32, + Float32, + Float32Array ); test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(128175)))], - Ok(Some("💯")), + InitCap, + &[lit("hi THOMAS")], + Ok(Some("Hi Thomas")), &str, Utf8, StringArray ); + test_function!(InitCap, &[lit("")], Ok(Some("")), &str, Utf8, StringArray); + test_function!(InitCap, &[lit("")], Ok(Some("")), &str, Utf8, StringArray); test_function!( - Chr, - &[lit(ScalarValue::Int64(None))], + InitCap, + &[lit(ScalarValue::Utf8(None))], Ok(None), &str, Utf8, StringArray ); test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(120)))], - Ok(Some("x")), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(128175)))], - Ok(Some("💯")), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(0)))], - exec_err!("null character not permitted."), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(i64::MAX)))], - exec_err!("requested character too large for encoding."), - &str, - Utf8, - StringArray - ); - test_function!( - Concat, - &[lit("aa"), lit("bb"), lit("cc"),], - Ok(Some("aabbcc")), - &str, - Utf8, - StringArray - ); - test_function!( - Concat, - &[lit("aa"), lit(ScalarValue::Utf8(None)), lit("cc"),], - Ok(Some("aacc")), - &str, - Utf8, - StringArray - ); - test_function!( - Concat, - &[lit(ScalarValue::Utf8(None))], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - ConcatWithSeparator, - &[lit("|"), lit("aa"), lit("bb"), lit("cc"),], - Ok(Some("aa|bb|cc")), - &str, - Utf8, - StringArray - ); - test_function!( - ConcatWithSeparator, - &[lit("|"), 
lit(ScalarValue::Utf8(None)),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - ConcatWithSeparator, - &[ - lit(ScalarValue::Utf8(None)), - lit("aa"), - lit("bb"), - lit("cc"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - ConcatWithSeparator, - &[lit("|"), lit("aa"), lit(ScalarValue::Utf8(None)), lit("cc"),], - Ok(Some("aa|cc")), - &str, - Utf8, - StringArray - ); - test_function!( - Exp, - &[lit(ScalarValue::Int32(Some(1)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::UInt32(Some(1)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::UInt64(Some(1)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::Float64(Some(1.0)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::Float32(Some(1.0)))], - Ok(Some((1.0_f32).exp())), - f32, - Float32, - Float32Array - ); - test_function!( - InitCap, - &[lit("hi THOMAS")], - Ok(Some("Hi Thomas")), - &str, - Utf8, - StringArray - ); - test_function!(InitCap, &[lit("")], Ok(Some("")), &str, Utf8, StringArray); - test_function!(InitCap, &[lit("")], Ok(Some("")), &str, Utf8, StringArray); - test_function!( - InitCap, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),], - Ok(Some("ab")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),], - Ok(Some("abcde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),], - Ok(Some("abc")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("joséé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("joséé")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Left, - &[ - lit("abcde"), - lit(ScalarValue::Int8(Some(2))), - ], - internal_err!( - "function left requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some(" josé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),], - Ok(Some("xyxhi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(21))), lit("abcdef"),], - Ok(Some("abcdefabcdefabcdefahi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),], - Ok(Some("hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - lit("xy"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit("hi"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),], - Ok(Some("xyxyxyjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),], - Ok(Some("éñéñéñjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Lpad, - &[ - lit("josé"), - lit(ScalarValue::Int64(Some(5))), - ], - internal_err!( - "function lpad requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit(" trim ")], - Ok(Some("trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("trim ")], - Ok(Some("trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("\n trim ")], - Ok(Some("\n trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - OctetLength, - &[lit("chars")], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - test_function!( - OctetLength, - &[lit("josé")], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - test_function!(OctetLength, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); - test_function!( - OctetLength, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); - test_function!( - Repeat, - &[lit("Pg"), lit(ScalarValue::Int64(Some(4))),], - Ok(Some("PgPgPgPg")), - &str, - Utf8, - StringArray - ); - test_function!( - Repeat, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(4))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Repeat, - &[lit("Pg"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("abcde")], - Ok(Some("edcba")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Reverse, - &[lit("abcde")], - internal_err!( - "function reverse requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),], - Ok(Some("de")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),], - Ok(Some("abcde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),], - Ok(Some("cde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Right, - &[ - lit("abcde"), - lit(ScalarValue::Int8(Some(2))), - ], - internal_err!( - "function right requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("josé ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("hi ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),], - Ok(Some("hixyx")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(21))), lit("abcdef"),], - Ok(Some("hiabcdefabcdefabcdefa")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),], - Ok(Some("hi ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),], - Ok(Some("hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - lit("xy"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit("hi"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),], - Ok(Some("joséxyxyxy")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),], - Ok(Some("josééñéñéñ")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Rpad, - &[ - lit("josé"), - lit(ScalarValue::Int64(Some(5))), - ], - internal_err!( - "function rpad requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim ")], - Ok(Some(" trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim \n")], - Ok(Some(" trim \n")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim")], - Ok(Some(" trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit("trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - SplitPart, - &[ - lit("abc~@~def~@~ghi"), - lit("~@~"), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("def")), - &str, - Utf8, - StringArray - ); - test_function!( - SplitPart, - &[ - lit("abc~@~def~@~ghi"), - lit("~@~"), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - SplitPart, - &[ - lit("abc~@~def~@~ghi"), - lit("~@~"), - lit(ScalarValue::Int64(Some(-1))), - ], - exec_err!("field position must be greater than zero"), - &str, - Utf8, - StringArray - ); - test_function!( - StartsWith, - &[lit("alphabet"), lit("alph"),], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit("alphabet"), lit("blph"),], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit(ScalarValue::Utf8(None)), lit("alph"),], - Ok(None), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit("alphabet"), lit(ScalarValue::Utf8(None)),], - Ok(None), - bool, - Boolean, - BooleanArray - ); - test_function!( - EndsWith, - &[lit("alphabet"), lit("alph"),], - Ok(Some(false)), - bool, - Boolean, - BooleanArray + EndsWith, + &[lit("alphabet"), lit("alph"),], + Ok(Some(false)), + bool, + Boolean, + BooleanArray ); test_function!( EndsWith, @@ -1875,350 +539,7 @@ mod tests { Boolean, BooleanArray ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("ésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-5))),], - Ok(Some("joséésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(1))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(2))),], - Ok(Some("lphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(3))),], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(30))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = 
"unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("ph")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(0))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from 5 (10 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(10))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from -1 (4 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(4))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - // starting from 0 (5 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(None)), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(1))), - lit(ScalarValue::Int64(Some(-1))), - ], - exec_err!("negative substring length not allowed: substr(, 1, -1)"), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("joséésoj"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("és")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(0))), - ], - internal_err!( - "function substr requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit("143"), lit("ax"),], - Ok(Some("a2x5")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit(ScalarValue::Utf8(None)), lit("143"), lit("ax"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit(ScalarValue::Utf8(None)), lit("ax"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit("143"), lit(ScalarValue::Utf8(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("é2íñ5"), lit("éñí"), lit("óü"),], - Ok(Some("ó2ü5")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Translate, - &[ - lit("12345"), - lit("143"), - lit("ax"), - ], - internal_err!( - "function translate requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit(" trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit("upper")], - Ok(Some("UPPER")), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit("UPPER")], - Ok(Some("UPPER")), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); + Ok(()) } @@ -2228,7 +549,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); // pick some arbitrary functions to test - let funs = [BuiltinScalarFunction::Concat, BuiltinScalarFunction::Repeat]; + let funs = [BuiltinScalarFunction::Concat]; for fun in funs.iter() { let expr = create_physical_expr_with_type_coercion( @@ -2261,11 +582,7 @@ mod tests { let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let funs = [ - BuiltinScalarFunction::Pi, - BuiltinScalarFunction::Random, - BuiltinScalarFunction::Uuid, - ]; + let funs = [BuiltinScalarFunction::Pi, BuiltinScalarFunction::Random]; for fun in funs.iter() { create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index e8b80ee4e1e6..7819d5116160 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,7 +17,6 @@ pub mod aggregate; pub mod analysis; -pub mod array_expressions; pub mod binary_map; pub mod conditional_expressions; pub mod equivalence; @@ -34,8 +33,6 @@ pub mod sort_properties; pub mod string_expressions; pub mod tree_node; pub mod udf; -#[cfg(feature = "unicode_expressions")] -pub mod unicode_expressions; pub mod utils; pub mod window; @@ -54,7 +51,7 @@ pub use physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, PhysicalExpr, PhysicalExprRef, }; -pub use planner::create_physical_expr; +pub use planner::{create_physical_expr, create_physical_exprs}; pub use 
scalar_function::ScalarFunctionExpr; pub use sort_expr::{ LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalSortExpr, diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index db8855cb5400..5339c12f6e93 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -492,31 +492,6 @@ pub fn power(args: &[ArrayRef]) -> Result { } } -/// Atan2 SQL function -pub fn atan2(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::atan2 } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::atan2 } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function atan2"), - } -} - /// Log SQL function pub fn log(args: &[ArrayRef]) -> Result { // Support overloaded log(base, x) and log(x) which defaults to log(10, x) @@ -725,42 +700,6 @@ mod tests { assert_eq!(floats.value(3), 625); } - #[test] - fn test_atan2_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float64_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); - assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); - } - - #[test] - fn test_atan2_f32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float32_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); - assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); - } - #[test] fn test_log_f64() { let args: Vec = vec![ diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 241f01a4170a..0dbea09ffb51 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -168,20 +168,15 @@ pub fn create_physical_expr( } else { None }; - let when_expr = case + let (when_expr, then_expr): (Vec<&Expr>, Vec<&Expr>) = case .when_then_expr .iter() - .map(|(w, _)| { - create_physical_expr(w.as_ref(), input_dfschema, execution_props) - }) - .collect::>>()?; - let then_expr = case - .when_then_expr - .iter() - .map(|(_, t)| { - create_physical_expr(t.as_ref(), input_dfschema, execution_props) - }) - .collect::>>()?; + .map(|(w, t)| (w.as_ref(), t.as_ref())) + .unzip(); + let when_expr = + create_physical_exprs(when_expr, input_dfschema, execution_props)?; + let then_expr = + create_physical_exprs(then_expr, input_dfschema, execution_props)?; let when_then_expr: Vec<(Arc, Arc)> = when_expr .iter() @@ -248,10 +243,8 @@ pub fn create_physical_expr( } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let physical_args = args - 
.iter() - .map(|e| create_physical_expr(e, input_dfschema, execution_props)) - .collect::>>()?; + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -310,12 +303,8 @@ pub fn create_physical_expr( let value_expr = create_physical_expr(expr, input_dfschema, execution_props)?; - let list_exprs = list - .iter() - .map(|expr| { - create_physical_expr(expr, input_dfschema, execution_props) - }) - .collect::>>()?; + let list_exprs = + create_physical_exprs(list, input_dfschema, execution_props)?; expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, @@ -325,17 +314,32 @@ pub fn create_physical_expr( } } +/// Create vector of Physical Expression from a vector of logical expression +pub fn create_physical_exprs<'a, I>( + exprs: I, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result>> +where + I: IntoIterator, +{ + exprs + .into_iter() + .map(|expr| create_physical_expr(expr, input_dfschema, execution_props)) + .collect::>>() +} + #[cfg(test)] mod tests { use super::*; use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{col, left, Literal}; + use datafusion_expr::{col, lit}; #[test] fn test_create_physical_expr_scalar_input_output() -> Result<()> { - let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit())); + let expr = col("letter").eq(lit("A")); let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]); let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?; diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 1c9f0e609c3c..d34084236690 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,22 +34,22 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::functions::out_ordering; +use crate::functions::{create_physical_fun, out_ordering}; use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, - ScalarFunctionImplementation, + ScalarFunctionDefinition, }; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, name: String, args: Vec>, return_type: DataType, @@ -79,7 +79,7 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, args: Vec>, return_type: DataType, monotonicity: Option, @@ -96,7 +96,7 @@ impl ScalarFunctionExpr { } /// Get the scalar function implementation - pub fn fun(&self) -> &ScalarFunctionImplementation { + pub fn fun(&self) -> &ScalarFunctionDefinition { &self.fun } @@ -172,8 +172,18 @@ impl PhysicalExpr for ScalarFunctionExpr { }; // evaluate the function - let fun = self.fun.as_ref(); - (fun)(&inputs) + match self.fun { + ScalarFunctionDefinition::BuiltIn(ref fun) => { + let fun = create_physical_fun(fun)?; + (fun)(&inputs) + } + ScalarFunctionDefinition::UDF(ref fun) => 
fun.invoke(&inputs), + ScalarFunctionDefinition::Name(_) => { + internal_err!( + "Name function must be resolved to one of the other variants prior to physical planning" + ) + } + } } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index ace7ef2888a3..2185b7c5b4a1 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -22,151 +22,22 @@ //! String expressions use std::sync::Arc; -use std::{ - fmt::{Display, Formatter}, - iter, -}; use arrow::{ array::{ Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, OffsetSizeTrait, StringArray, }, - datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, + datatypes::DataType, }; -use uuid::Uuid; -use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use datafusion_common::{ - cast::{ - as_generic_string_array, as_int64_array, as_primitive_array, as_string_array, - }, + cast::{as_generic_string_array, as_string_array}, exec_err, ScalarValue, }; use datafusion_expr::ColumnarValue; -/// applies a unary expression to `args[0]` that is expected to be downcastable to -/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) -/// # Errors -/// This function errors when: -/// * the number of arguments is not 1 -/// * the first argument is not castable to a `GenericStringArray` -pub(crate) fn unary_string_function<'a, T, O, F, R>( - args: &[&'a dyn Array], - op: F, - name: &str, -) -> Result> -where - R: AsRef, - O: OffsetSizeTrait, - T: OffsetSizeTrait, - F: Fn(&'a str) -> R, -{ - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } - - let string_array = as_generic_string_array::(args[0])?; - - // first map is the iterator, second is for the `Option<_>` - Ok(string_array.iter().map(|string| string.map(&op)).collect()) -} - -fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result -where - R: AsRef, - F: Fn(&'a str) -> R, -{ - match &args[0] { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i32, - i32, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i64, - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - } -} - -/// Returns the numeric code of the first character of the argument. -/// ascii('x') = 120 -pub fn ascii(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - let mut chars = string.chars(); - chars.next().map_or(0, |v| v as i32) - }) - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns the character with the given code. 
chr(0) is disallowed because text data types cannot store that character. -/// chr(65) = 'A' -pub fn chr(args: &[ArrayRef]) -> Result { - let integer_array = as_int64_array(&args[0])?; - - // first map is the iterator, second is for the `Option<_>` - let result = integer_array - .iter() - .map(|integer: Option| { - integer - .map(|integer| { - if integer == 0 { - exec_err!("null character not permitted.") - } else { - match core::char::from_u32(integer as u32) { - Some(integer) => Ok(integer.to_string()), - None => { - exec_err!("requested character too large for encoding.") - } - } - } - }) - .transpose() - }) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) -} - /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' pub fn concat(args: &[ColumnarValue]) -> Result { @@ -342,168 +213,6 @@ pub fn instr(args: &[ArrayRef]) -> Result { } } -/// Converts the string to all lower case. -/// lower('TOM') = 'tom' -pub fn lower(args: &[ColumnarValue]) -> Result { - handle(args, |string| string.to_lowercase(), "lower") -} - -enum TrimType { - Left, - Right, - Both, -} - -impl Display for TrimType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - TrimType::Left => write!(f, "ltrim"), - TrimType::Right => write!(f, "rtrim"), - TrimType::Both => write!(f, "btrim"), - } - } -} - -fn general_trim( - args: &[ArrayRef], - trim_type: TrimType, -) -> Result { - let func = match trim_type { - TrimType::Left => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_start_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Right => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Both => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>( - str::trim_start_matches::<&[char]>(input, pattern.as_ref()), - pattern.as_ref(), - ) - }, - }; - - let string_array = as_generic_string_array::(&args[0])?; - - match args.len() { - 1 => { - let result = string_array - .iter() - .map(|string| string.map(|string: &str| func(string, " "))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." - ) - } - } -} - -/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Both) -} - -/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. -/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Left) -} - -/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. 
-/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Right) -} - -/// Repeats string the specified number of times. -/// repeat('Pg', 4) = 'PgPgPgPg' -pub fn repeat(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let number_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(number_array.iter()) - .map(|(string, number)| match (string, number) { - (Some(string), Some(number)) => Some(string.repeat(number as usize)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Replaces all occurrences in string of substring from with substring to. -/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' -pub fn replace(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let from_array = as_generic_string_array::(&args[1])?; - let to_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). -/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' -pub fn split_part(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let n_array = as_int64_array(&args[2])?; - let result = string_array - .iter() - .zip(delimiter_array.iter()) - .zip(n_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - if n <= 0 { - exec_err!("field position must be greater than zero") - } else { - let split_string: Vec<&str> = string.split(delimiter).collect(); - match split_string.get(n as usize - 1) { - Some(s) => Ok(Some(*s)), - None => Ok(Some("")), - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) -} - /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' pub fn starts_with(args: &[ArrayRef]) -> Result { @@ -525,267 +234,3 @@ pub fn ends_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } - -/// Converts the number to its equivalent hexadecimal representation. -/// to_hex(2147483647) = '7fffffff' -pub fn to_hex(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let integer_array = as_primitive_array::(&args[0])?; - - let result = integer_array - .iter() - .map(|integer| { - if let Some(value) = integer { - if let Some(value_usize) = value.to_usize() { - Ok(Some(format!("{value_usize:x}"))) - } else if let Some(value_isize) = value.to_isize() { - Ok(Some(format!("{value_isize:x}"))) - } else { - exec_err!("Unsupported data type {integer:?} for function to_hex") - } - } else { - Ok(None) - } - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) -} - -/// Converts the string to all upper case. 
-/// upper('tom') = 'TOM' -pub fn upper(args: &[ColumnarValue]) -> Result { - handle(args, |string| string.to_uppercase(), "upper") -} - -/// Prints random (v4) uuid values per row -/// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' -pub fn uuid(args: &[ColumnarValue]) -> Result { - let len: usize = match &args[0] { - ColumnarValue::Array(array) => array.len(), - _ => return exec_err!("Expect uuid function to take no param"), - }; - - let values = iter::repeat_with(|| Uuid::new_v4().to_string()).take(len); - let array = GenericStringArray::::from_iter_values(values); - Ok(ColumnarValue::Array(Arc::new(array))) -} - -/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) -/// Replaces a substring of string1 with string2 starting at the integer bit -/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas -/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead -pub fn overlay(args: &[ArrayRef]) -> Result { - match args.len() { - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .map(|((string, characters), start_pos)| { - match (string, characters, start_pos) { - (Some(string), Some(characters), Some(start_pos)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = characters_len as i64; - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 4 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - let len_num = as_int64_array(&args[3])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .zip(len_num.iter()) - .map(|(((string, characters), start_pos), len)| { - match (string, characters, start_pos, len) { - (Some(string), Some(characters), Some(start_pos), Some(len)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = len.min(string_len as i64); - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("overlay was called with {other} arguments. 
It requires 3 or 4.") - } - } -} - -///Returns the Levenshtein distance between the two given strings. -/// LEVENSHTEIN('kitten', 'sitting') = 3 -pub fn levenshtein(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!( - "levenshtein function requires two arguments, got {}", - args.len() - ); - } - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; - match args[0].data_type() { - DataType::Utf8 => { - let result = str1_array - .iter() - .zip(str2_array.iter()) - .map(|(string1, string2)| match (string1, string2) { - (Some(string1), Some(string2)) => { - Some(datafusion_strsim::levenshtein(string1, string2) as i32) - } - _ => None, - }) - .collect::(); - Ok(Arc::new(result) as ArrayRef) - } - DataType::LargeUtf8 => { - let result = str1_array - .iter() - .zip(str2_array.iter()) - .map(|(string1, string2)| match (string1, string2) { - (Some(string1), Some(string2)) => { - Some(datafusion_strsim::levenshtein(string1, string2) as i64) - } - _ => None, - }) - .collect::(); - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." - ) - } - } -} - -#[cfg(test)] -mod tests { - use arrow::{array::Int32Array, datatypes::Int32Type}; - use arrow_array::Int64Array; - - use datafusion_common::cast::as_int32_array; - - use crate::string_expressions; - - use super::*; - - #[test] - // Test to_hex function for zero - fn to_hex_zero() -> Result<()> { - let array = vec![0].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("0")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - - #[test] - // Test to_hex function for positive number - fn to_hex_positive_number() -> Result<()> { - let array = vec![100].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("64")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - - #[test] - // Test to_hex function for negative number - fn to_hex_negative_number() -> Result<()> { - let array = vec![-1].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("ffffffffffffffff")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - - #[test] - fn to_overlay() -> Result<()> { - let string = - Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); - let replace_string = - Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); - let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start - let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len - - let res = overlay::(&[string, replace_string, start, end]).unwrap(); - let result = as_generic_string_array::(&res).unwrap(); - let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); - assert_eq!(&expected, result); - - Ok(()) - } - - #[test] - fn to_levenshtein() -> Result<()> { - let string1_array = - Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); - let string2_array = - Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); - let res = 
levenshtein::(&[string1_array, string2_array]).unwrap(); - let result = - as_int32_array(&res).expect("failed to initialized function levenshtein"); - let expected = Int32Array::from(vec![2, 3, 2, 3]); - assert_eq!(&expected, result); - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index ede3e5badbb1..4fc94bfa15ec 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -20,7 +20,9 @@ use crate::{PhysicalExpr, ScalarFunctionExpr}; use arrow_schema::Schema; use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; -use datafusion_expr::{type_coercion::functions::data_types, Expr}; +use datafusion_expr::{ + type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, +}; use std::sync::Arc; /// Create a physical expression of the UDF. @@ -45,9 +47,10 @@ pub fn create_physical_expr( let return_type = fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun(), + fun_def, input_phy_exprs.to_vec(), return_type, fun.monotonicity()?, diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs deleted file mode 100644 index 8ec9e062d9b7..000000000000 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ /dev/null @@ -1,551 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Some of these functions reference the Postgres documentation -// or implementation to ensure compatibility and are subject to -// the Postgres license. - -//! Unicode expressions - -use std::cmp::{max, Ordering}; -use std::sync::Arc; - -use arrow::{ - array::{ArrayRef, GenericStringArray, OffsetSizeTrait, PrimitiveArray}, - datatypes::{ArrowNativeType, ArrowPrimitiveType}, -}; -use hashbrown::HashMap; -use unicode_segmentation::UnicodeSegmentation; - -use datafusion_common::{ - cast::{as_generic_string_array, as_int64_array}, - exec_err, Result, -}; - -/// Returns number of characters in the string. 
-/// character_length('josé') = 4 -/// The implementation counts UTF-8 code points to count the number of characters -pub fn character_length(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - T::Native::from_usize(string.chars().count()) - .expect("should not fail as string.chars will always return integer") - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. -/// left('abcde', 2) = 'ab' -/// The implementation uses UTF-8 code points as characters -pub fn left(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let n_array = as_int64_array(&args[1])?; - let result = string_array - .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => { - let len = string.chars().count() as i64; - Some(if n.abs() < len { - string.chars().take((len + n) as usize).collect::() - } else { - "".to_string() - }) - } - Ordering::Equal => Some("".to_string()), - Ordering::Greater => { - Some(string.chars().take(n as usize).collect::()) - } - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). -/// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - 
graphemes.len() { - char_vector.push( - *fill_chars.get(l % fill_chars.len()).unwrap(), - ); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "lpad was called with {other} arguments. It requires at least 2 and at most 3." - ), - } -} - -/// Reverses the order of the characters in the string. -/// reverse('abcde') = 'edcba' -/// The implementation uses UTF-8 code points as characters -pub fn reverse(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.chars().rev().collect::())) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. -/// right('abcde', 2) = 'de' -/// The implementation uses UTF-8 code points as characters -pub fn right(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let n_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => Some( - string - .chars() - .skip(n.unsigned_abs() as usize) - .collect::(), - ), - Ordering::Equal => Some("".to_string()), - Ordering::Greater => Some( - string - .chars() - .skip(max(string.chars().count() as i64 - n, 0) as usize) - .collect::(), - ), - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. 
-/// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s = string.to_string(); - s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); - } - s.push_str(char_vector.iter().collect::().as_str()); - Ok(Some(s)) - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "rpad was called with {other} arguments. It requires at least 2 and at most 3." - ), - } -} - -/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) -/// strpos('high', 'ig') = 2 -/// The implementation uses UTF-8 code points as characters -pub fn strpos(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let substring_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(substring_array.iter()) - .map(|(string, substring)| match (string, substring) { - (Some(string), Some(substring)) => { - // the find method returns the byte index of the substring - // Next, we count the number of the chars until that byte - T::Native::from_usize( - string - .find(substring) - .map(|x| string[..x].chars().count() + 1) - .unwrap_or(0), - ) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) 
-/// substr('alphabet', 3) = 'phabet' -/// substr('alphabet', 3, 2) = 'ph' -/// The implementation uses UTF-8 code points as characters -pub fn substr(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { - (Some(string), Some(start)) => { - if start <= 0 { - Some(string.to_string()) - } else { - Some(string.chars().skip(start as usize - 1).collect()) - } - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - let count_array = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( - "negative substring length not allowed: substr(, {start}, {count})" - ) - } else { - let skip = max(0, start - 1); - let count = max(0, count + (if start < 1 {start - 1} else {0})); - Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("substr was called with {other} arguments. It requires 2 or 3.") - } - } -} - -/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. -/// translate('12345', '143', 'ax') = 'a2x5' -pub fn translate(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let from_array = as_generic_string_array::(&args[1])?; - let to_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => { - // create a hashmap of [char, index] to change from O(n) to O(1) for from list - let from_map: HashMap<&str, usize> = from - .graphemes(true) - .collect::>() - .iter() - .enumerate() - .map(|(index, c)| (c.to_owned(), index)) - .collect(); - - let to = to.graphemes(true).collect::>(); - - Some( - string - .graphemes(true) - .collect::>() - .iter() - .flat_map(|c| match from_map.get(*c) { - Some(n) => to.get(*n).copied(), - None => Some(*c), - }) - .collect::>() - .concat(), - ) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. -/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www -/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache -/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org -/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org -pub fn substr_index(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!( - "substr_index was called with {} arguments. 
It requires 3.", - args.len() - ); - } - - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let count_array = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(delimiter_array.iter()) - .zip(count_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - // In MySQL, these cases will return an empty string. - if n == 0 || string.is_empty() || delimiter.is_empty() { - return Some(String::new()); - } - - let splitted: Box> = if n > 0 { - Box::new(string.split(delimiter)) - } else { - Box::new(string.rsplit(delimiter)) - }; - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - // The length of the substring covered by substr_index. - let length = splitted - .take(occurrences) // at least 1 element, since n != 0 - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len(); - if n > 0 { - Some(string[..length].to_owned()) - } else { - Some(string[string.len() - length..].to_owned()) - } - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings -///A string list is a string composed of substrings separated by , characters. -pub fn find_in_set(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - if args.len() != 2 { - return exec_err!( - "find_in_set was called with {} arguments. It requires 2.", - args.len() - ); - } - - let str_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - let str_list_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = str_array - .iter() - .zip(str_list_array.iter()) - .map(|(string, str_list)| match (string, str_list) { - (Some(string), Some(str_list)) => { - let mut res = 0; - let str_set: Vec<&str> = str_list.split(',').collect(); - for (idx, str) in str_set.iter().enumerate() { - if str == &string { - res = idx + 1; - break; - } - } - T::Native::from_usize(res) - } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) -} diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index e913f39333f9..9de71c2d604c 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -225,31 +225,38 @@ impl PartitionEvaluator for NthValueEvaluator { } // Extract valid indices if ignoring nulls. - let (slice, valid_indices) = if self.ignore_nulls { + let valid_indices = if self.ignore_nulls { + // Calculate valid indices, inside the window frame boundaries let slice = arr.slice(range.start, n_range); - let valid_indices = - slice.nulls().unwrap().valid_indices().collect::>(); + let valid_indices = slice + .nulls() + .map(|nulls| { + nulls + .valid_indices() + // Add offset `range.start` to valid indices, to point correct index in the original arr. 
+ .map(|idx| idx + range.start) + .collect::>() + }) + .unwrap_or_default(); if valid_indices.is_empty() { return ScalarValue::try_from(arr.data_type()); } - (Some(slice), Some(valid_indices)) + Some(valid_indices) } else { - (None, None) + None }; match self.state.kind { NthValueKind::First => { - if let Some(slice) = &slice { - let valid_indices = valid_indices.unwrap(); - ScalarValue::try_from_array(slice, valid_indices[0]) + if let Some(valid_indices) = &valid_indices { + ScalarValue::try_from_array(arr, valid_indices[0]) } else { ScalarValue::try_from_array(arr, range.start) } } NthValueKind::Last => { - if let Some(slice) = &slice { - let valid_indices = valid_indices.unwrap(); + if let Some(valid_indices) = &valid_indices { ScalarValue::try_from_array( - slice, + arr, valid_indices[valid_indices.len() - 1], ) } else { @@ -264,15 +271,11 @@ impl PartitionEvaluator for NthValueEvaluator { if index >= n_range { // Outside the range, return NULL: ScalarValue::try_from(arr.data_type()) - } else if self.ignore_nulls { - let valid_indices = valid_indices.unwrap(); + } else if let Some(valid_indices) = valid_indices { if index >= valid_indices.len() { return ScalarValue::try_from(arr.data_type()); } - ScalarValue::try_from_array( - &slice.unwrap(), - valid_indices[index], - ) + ScalarValue::try_from_array(&arr, valid_indices[index]) } else { ScalarValue::try_from_array(arr, range.start + index) } @@ -282,14 +285,13 @@ impl PartitionEvaluator for NthValueEvaluator { if n_range < reverse_index { // Outside the range, return NULL: ScalarValue::try_from(arr.data_type()) - } else if self.ignore_nulls { - let valid_indices = valid_indices.unwrap(); + } else if let Some(valid_indices) = valid_indices { if reverse_index > valid_indices.len() { return ScalarValue::try_from(arr.data_type()); } let new_index = valid_indices[valid_indices.len() - reverse_index]; - ScalarValue::try_from_array(&slice.unwrap(), new_index) + ScalarValue::try_from_array(&arr, new_index) } else { ScalarValue::try_from_array( arr, diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 72ee4fb3ef7e..1ba32bff746e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -58,7 +58,6 @@ parking_lot = { workspace = true } pin-project-lite = "^0.2.7" rand = { workspace = true } tokio = { workspace = true } -uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] rstest = { workspace = true } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 65987e01553d..e263876b07d5 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -636,6 +636,10 @@ impl DisplayAs for AggregateExec { } impl ExecutionPlan for AggregateExec { + fn name(&self) -> &'static str { + "AggregateExec" + } + /// Return a reference to Any that can be used for down-casting fn as_any(&self) -> &dyn Any { self @@ -1658,6 +1662,10 @@ mod tests { } impl ExecutionPlan for TestYieldingExec { + fn name(&self) -> &'static str { + "TestYieldingExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 4f1914b12c96..556103e1e222 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -40,7 +40,7 @@ pub(crate) enum GroupOrdering { } impl GroupOrdering { - /// Create a `GroupOrdering` for 
the the specified ordering + /// Create a `GroupOrdering` for the specified ordering pub fn try_new( input_schema: &Schema, mode: &InputOrderMode, diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 83a73ee992fb..c420581c4323 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -111,6 +111,10 @@ impl DisplayAs for AnalyzeExec { } impl ExecutionPlan for AnalyzeExec { + fn name(&self) -> &'static str { + "AnalyzeExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 0b9ecebbb1e8..bc7c4a3d0673 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -104,6 +104,10 @@ impl DisplayAs for CoalesceBatchesExec { } impl ExecutionPlan for CoalesceBatchesExec { + fn name(&self) -> &'static str { + "CoalesceBatchesExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 1e58260a5344..1c725ce31f14 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -89,6 +89,10 @@ impl DisplayAs for CoalescePartitionsExec { } impl ExecutionPlan for CoalescePartitionsExec { + fn name(&self) -> &'static str { + "CoalescePartitionsExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index f4a2cba68e16..59c54199333e 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -349,12 +349,6 @@ pub fn can_project( } } -/// Returns the total number of bytes of memory occupied physically by this batch. 
-#[deprecated(since = "28.0.0", note = "RecordBatch::get_array_memory_size")] -pub fn batch_byte_size(batch: &RecordBatch) -> usize { - batch.get_array_memory_size() -} - #[cfg(test)] mod tests { use std::ops::Not; diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 4b7b35e53e1b..ca93ce5e7b83 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -489,6 +489,10 @@ mod tests { } impl ExecutionPlan for TestStatsExecPlan { + fn name(&self) -> &'static str { + "TestStatsExecPlan" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 4ff79cdaae70..8e8eb4d25e32 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -101,6 +101,10 @@ impl DisplayAs for EmptyExec { } impl ExecutionPlan for EmptyExec { + fn name(&self) -> &'static str { + "EmptyExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 320ee37bed95..649946993229 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -98,6 +98,10 @@ impl DisplayAs for ExplainExec { } impl ExecutionPlan for ExplainExec { + fn name(&self) -> &'static str { + "ExplainExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 4155b00820f4..a9201f435ad8 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -29,7 +29,7 @@ use super::{ }; use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - Column, DisplayFormatType, ExecutionPlan, + DisplayFormatType, ExecutionPlan, }; use arrow::compute::filter_record_batch; @@ -159,6 +159,27 @@ impl FilterExec { }) } + fn extend_constants( + input: &Arc, + predicate: &Arc, + ) -> Vec> { + let mut res_constants = Vec::new(); + let input_eqs = input.equivalence_properties(); + + let conjunctions = split_conjunction(predicate); + for conjunction in conjunctions { + if let Some(binary) = conjunction.as_any().downcast_ref::() { + if binary.op() == &Operator::Eq { + if input_eqs.is_expr_constant(binary.left()) { + res_constants.push(binary.right().clone()) + } else if input_eqs.is_expr_constant(binary.right()) { + res_constants.push(binary.left().clone()) + } + } + } + } + res_constants + } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties( input: &Arc, @@ -171,9 +192,7 @@ impl FilterExec { let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { - let lhs_expr = Arc::new(lhs.clone()) as _; - let rhs_expr = Arc::new(rhs.clone()) as _; - eq_properties.add_equal_conditions(&lhs_expr, &rhs_expr) + eq_properties.add_equal_conditions(lhs, rhs) } // Add the columns that have only one viable value (singleton) after // filtering to constants. 
@@ -181,8 +200,12 @@ .into_iter() .filter(|column| stats.column_statistics[column.index()].is_singleton()) .map(|column| Arc::new(column) as _); + // This is for statistics eq_properties = eq_properties.add_constants(constants); - + // This is for logical constants (for example: given a = '1', a can be marked as a constant) + // TODO: how to handle other ways of expressing equality (for example, c1 BETWEEN 0 AND 0) + eq_properties = + eq_properties.add_constants(Self::extend_constants(input, predicate)); Ok(PlanProperties::new( eq_properties, input.output_partitioning().clone(), // Output Partitioning @@ -206,6 +229,10 @@ } impl ExecutionPlan for FilterExec { + fn name(&self) -> &'static str { + "FilterExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -376,34 +403,33 @@ impl RecordBatchStream for FilterExecStream { /// Return the equals Column-Pairs and Non-equals Column-Pairs fn collect_columns_from_predicate(predicate: &Arc<dyn PhysicalExpr>) -> EqualAndNonEqual { - let mut eq_predicate_columns = Vec::<(&Column, &Column)>::new(); - let mut ne_predicate_columns = Vec::<(&Column, &Column)>::new(); + let mut eq_predicate_columns = Vec::<PhysicalExprPairRef>::new(); + let mut ne_predicate_columns = Vec::<PhysicalExprPairRef>::new(); let predicates = split_conjunction(predicate); predicates.into_iter().for_each(|p| { if let Some(binary) = p.as_any().downcast_ref::<BinaryExpr>() { - if let (Some(left_column), Some(right_column)) = ( - binary.left().as_any().downcast_ref::<Column>(), - binary.right().as_any().downcast_ref::<Column>(), - ) { - match binary.op() { - Operator::Eq => { - eq_predicate_columns.push((left_column, right_column)) - } - Operator::NotEq => { - ne_predicate_columns.push((left_column, right_column)) - } - _ => {} + match binary.op() { + Operator::Eq => { + eq_predicate_columns.push((binary.left(), binary.right())) } + Operator::NotEq => { + ne_predicate_columns.push((binary.left(), binary.right())) + } + _ => {} } } }); (eq_predicate_columns, ne_predicate_columns) } + +/// Pair of `Arc<dyn PhysicalExpr>`s +pub type PhysicalExprPairRef<'a> = (&'a Arc<dyn PhysicalExpr>, &'a Arc<dyn PhysicalExpr>); + /// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates pub type EqualAndNonEqual<'a> = - (Vec<(&'a Column, &'a Column)>, Vec<(&'a Column, &'a Column)>); + (Vec<PhysicalExprPairRef<'a>>, Vec<PhysicalExprPairRef<'a>>); #[cfg(test)] mod tests { @@ -416,7 +442,9 @@ use crate::test::exec::StatisticsExec; use crate::ExecutionPlan; + use crate::empty::EmptyExec; use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{UnionFields, UnionMode}; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_expr::Operator; @@ -451,14 +479,16 @@ )?; let (equal_pairs, ne_pairs) = collect_columns_from_predicate(&predicate); + assert_eq!(2, equal_pairs.len()); + assert!(equal_pairs[0].0.eq(&col("c2", &schema)?)); + assert!(equal_pairs[0].1.eq(&lit(4u32))); - assert_eq!(1, equal_pairs.len()); - assert_eq!(equal_pairs[0].0.name(), "c2"); - assert_eq!(equal_pairs[0].1.name(), "c9"); + assert!(equal_pairs[1].0.eq(&col("c2", &schema)?)); + assert!(equal_pairs[1].1.eq(&col("c9", &schema)?)); assert_eq!(1, ne_pairs.len()); - assert_eq!(ne_pairs[0].0.name(), "c1"); - assert_eq!(ne_pairs[0].1.name(), "c13"); + assert!(ne_pairs[0].0.eq(&col("c1", &schema)?)); + assert!(ne_pairs[0].1.eq(&col("c13", &schema)?)); Ok(()) } @@ -1065,4 +1095,37 @@ assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) } + + #[test] + fn test_equivalence_properties_union_type() -> Result<()> { + let
union_type = DataType::Union( + UnionFields::new( + vec![0, 1], + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + ), + UnionMode::Sparse, + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", union_type, true), + ])); + + let exec = FilterExec::try_new( + binary( + binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?, + Operator::And, + binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?, + &schema, + )?, + Arc::new(EmptyExec::new(schema.clone())), + )?; + + exec.statistics().unwrap(); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 16c929b78144..f0233264f280 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -206,6 +206,10 @@ impl DisplayAs for FileSinkExec { } impl ExecutionPlan for FileSinkExec { + fn name(&self) -> &'static str { + "FileSinkExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 9f8dc0ce56b0..19d34f8048e3 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -194,6 +194,10 @@ impl DisplayAs for CrossJoinExec { } impl ExecutionPlan for CrossJoinExec { + fn name(&self) -> &'static str { + "CrossJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index a1c50a2113ba..1c0181c2e116 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -611,6 +611,10 @@ fn project_index_to_exprs( } impl ExecutionPlan for HashJoinExec { + fn name(&self) -> &'static str { + "HashJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 2c16fff52750..796b8602b22f 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -273,6 +273,10 @@ impl DisplayAs for NestedLoopJoinExec { } impl ExecutionPlan for NestedLoopJoinExec { + fn name(&self) -> &'static str { + "NestedLoopJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 7b70a2952b4c..21630087f2ca 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -262,6 +262,10 @@ impl DisplayAs for SortMergeJoinExec { } impl ExecutionPlan for SortMergeJoinExec { + fn name(&self) -> &'static str { + "SortMergeJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 79b8c813d860..453b217f7fc7 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -385,6 +385,10 @@ impl DisplayAs for SymmetricHashJoinExec { } impl ExecutionPlan for SymmetricHashJoinExec { + fn name(&self) -> &'static str { + "SymmetricHashJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git 
a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 1cb2b100e2d6..a3d20b97d1ab 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -825,27 +825,27 @@ fn estimate_join_cardinality( right_stats: Statistics, on: &JoinOn, ) -> Option { + let (left_col_stats, right_col_stats) = on + .iter() + .map(|(left, right)| { + match ( + left.as_any().downcast_ref::(), + right.as_any().downcast_ref::(), + ) { + (Some(left), Some(right)) => ( + left_stats.column_statistics[left.index()].clone(), + right_stats.column_statistics[right.index()].clone(), + ), + _ => ( + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ), + } + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - let (left_col_stats, right_col_stats) = on - .iter() - .map(|(left, right)| { - match ( - left.as_any().downcast_ref::(), - right.as_any().downcast_ref::(), - ) { - (Some(left), Some(right)) => ( - left_stats.column_statistics[left.index()].clone(), - right_stats.column_statistics[right.index()].clone(), - ), - _ => ( - ColumnStatistics::new_unknown(), - ColumnStatistics::new_unknown(), - ), - } - }) - .unzip::<_, _, Vec<_>, Vec<_>>(); - let ij_cardinality = estimate_inner_join_cardinality( Statistics { num_rows: left_stats.num_rows.clone(), @@ -888,10 +888,38 @@ fn estimate_join_cardinality( }) } - JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti => None, + // For SemiJoins estimation result is either zero, in cases when inputs + // are non-overlapping according to statistics, or equal to number of rows + // for outer input + JoinType::LeftSemi | JoinType::RightSemi => { + let (outer_stats, inner_stats) = match join_type { + JoinType::LeftSemi => (left_stats, right_stats), + _ => (right_stats, left_stats), + }; + let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) { + Some(estimation) => *estimation.get_value()?, + None => *outer_stats.num_rows.get_value()?, + }; + + Some(PartialJoinStatistics { + num_rows: cardinality, + column_statistics: outer_stats.column_statistics, + }) + } + + // For AntiJoins estimation always equals to outer statistics, as + // non-overlapping inputs won't affect estimation + JoinType::LeftAnti | JoinType::RightAnti => { + let outer_stats = match join_type { + JoinType::LeftAnti => left_stats, + _ => right_stats, + }; + + Some(PartialJoinStatistics { + num_rows: *outer_stats.num_rows.get_value()?, + column_statistics: outer_stats.column_statistics, + }) + } } } @@ -903,6 +931,11 @@ fn estimate_inner_join_cardinality( left_stats: Statistics, right_stats: Statistics, ) -> Option> { + // Immediatedly return if inputs considered as non-overlapping + if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) { + return Some(estimation); + }; + // The algorithm here is partly based on the non-histogram selectivity estimation // from Spark's Catalyst optimizer. let mut join_selectivity = Precision::Absent; @@ -911,30 +944,13 @@ fn estimate_inner_join_cardinality( .iter() .zip(right_stats.column_statistics.iter()) { - // If there is no overlap in any of the join columns, this means the join - // itself is disjoint and the cardinality is 0. Though we can only assume - // this when the statistics are exact (since it is a very strong assumption). - if left_stat.min_value.get_value()? > right_stat.max_value.get_value()? 
{ - return Some( - if left_stat.min_value.is_exact().unwrap_or(false) - && right_stat.max_value.is_exact().unwrap_or(false) - { - Precision::Exact(0) - } else { - Precision::Inexact(0) - }, - ); - } - if left_stat.max_value.get_value()? < right_stat.min_value.get_value()? { - return Some( - if left_stat.max_value.is_exact().unwrap_or(false) - && right_stat.min_value.is_exact().unwrap_or(false) - { - Precision::Exact(0) - } else { - Precision::Inexact(0) - }, - ); + // Break if any of statistics bounds are undefined + if left_stat.min_value.get_value().is_none() + || left_stat.max_value.get_value().is_none() + || right_stat.min_value.get_value().is_none() + || right_stat.max_value.get_value().is_none() + { + return None; } let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat); @@ -968,6 +984,58 @@ fn estimate_inner_join_cardinality( } } +/// Estimates if inputs are non-overlapping, using input statistics. +/// If inputs are disjoint, returns zero estimation, otherwise returns None +fn estimate_disjoint_inputs( + left_stats: &Statistics, + right_stats: &Statistics, +) -> Option> { + for (left_stat, right_stat) in left_stats + .column_statistics + .iter() + .zip(right_stats.column_statistics.iter()) + { + // If there is no overlap in any of the join columns, this means the join + // itself is disjoint and the cardinality is 0. Though we can only assume + // this when the statistics are exact (since it is a very strong assumption). + let left_min_val = left_stat.min_value.get_value(); + let right_max_val = right_stat.max_value.get_value(); + if left_min_val.is_some() + && right_max_val.is_some() + && left_min_val > right_max_val + { + return Some( + if left_stat.min_value.is_exact().unwrap_or(false) + && right_stat.max_value.is_exact().unwrap_or(false) + { + Precision::Exact(0) + } else { + Precision::Inexact(0) + }, + ); + } + + let left_max_val = left_stat.max_value.get_value(); + let right_min_val = right_stat.min_value.get_value(); + if left_max_val.is_some() + && right_min_val.is_some() + && left_max_val < right_min_val + { + return Some( + if left_stat.max_value.is_exact().unwrap_or(false) + && right_stat.min_value.is_exact().unwrap_or(false) + { + Precision::Exact(0) + } else { + Precision::Inexact(0) + }, + ); + } + } + + None +} + /// Estimate the number of maximum distinct values that can be present in the /// given column from its statistics. If distinct_count is available, uses it /// directly. Otherwise, if the column is numeric and has min/max values, it @@ -1716,9 +1784,11 @@ mod tests { #[test] fn test_inner_join_cardinality_single_column() -> Result<()> { let cases: Vec<(PartialStats, PartialStats, Option>)> = vec![ - // ----------------------------------------------------------------------------- - // | left(rows, min, max, distinct), right(rows, min, max, distinct), expected | - // ----------------------------------------------------------------------------- + // ------------------------------------------------ + // | left(rows, min, max, distinct, null_count), | + // | right(rows, min, max, distinct, null_count), | + // | expected, | + // ------------------------------------------------ // Cardinality computation // ======================= @@ -1824,6 +1894,11 @@ mod tests { None, ), // Non overlapping min/max (when exact=False). 
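// For instance, in the first case below only one bound is known on each side:
// the left side's max (Inexact(4)) lies below the right side's min (Inexact(5)),
// so the inputs cannot overlap and the estimate is Inexact(0) even though the
// remaining bounds are Absent.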
+ ( + (10, Absent, Inexact(4), Absent, Absent), + (10, Inexact(5), Absent, Absent, Absent), + Some(Inexact(0)), + ), ( (10, Inexact(0), Inexact(10), Absent, Absent), (10, Inexact(11), Inexact(20), Absent, Absent), @@ -2106,6 +2181,204 @@ mod tests { Ok(()) } + #[test] + fn test_anti_semi_join_cardinality() -> Result<()> { + let cases: Vec<(JoinType, PartialStats, PartialStats, Option)> = vec![ + // ------------------------------------------------ + // | join_type , | + // | left(rows, min, max, distinct, null_count), | + // | right(rows, min, max, distinct, null_count), | + // | expected, | + // ------------------------------------------------ + + // Cardinality computation + // ======================= + ( + JoinType::LeftSemi, + (50, Inexact(10), Inexact(20), Absent, Absent), + (10, Inexact(15), Inexact(25), Absent, Absent), + Some(50), + ), + ( + JoinType::RightSemi, + (50, Inexact(10), Inexact(20), Absent, Absent), + (10, Inexact(15), Inexact(25), Absent, Absent), + Some(10), + ), + ( + JoinType::LeftSemi, + (10, Absent, Absent, Absent, Absent), + (50, Absent, Absent, Absent, Absent), + Some(10), + ), + ( + JoinType::LeftSemi, + (50, Inexact(10), Inexact(20), Absent, Absent), + (10, Inexact(30), Inexact(40), Absent, Absent), + Some(0), + ), + ( + JoinType::LeftSemi, + (50, Inexact(10), Absent, Absent, Absent), + (10, Absent, Inexact(5), Absent, Absent), + Some(0), + ), + ( + JoinType::LeftSemi, + (50, Absent, Inexact(20), Absent, Absent), + (10, Inexact(30), Absent, Absent, Absent), + Some(0), + ), + ( + JoinType::LeftAnti, + (50, Inexact(10), Inexact(20), Absent, Absent), + (10, Inexact(15), Inexact(25), Absent, Absent), + Some(50), + ), + ( + JoinType::RightAnti, + (50, Inexact(10), Inexact(20), Absent, Absent), + (10, Inexact(15), Inexact(25), Absent, Absent), + Some(10), + ), + ( + JoinType::LeftAnti, + (10, Absent, Absent, Absent, Absent), + (50, Absent, Absent, Absent, Absent), + Some(10), + ), + ( + JoinType::LeftAnti, + (50, Inexact(10), Inexact(20), Absent, Absent), + (10, Inexact(30), Inexact(40), Absent, Absent), + Some(50), + ), + ( + JoinType::LeftAnti, + (50, Inexact(10), Absent, Absent, Absent), + (10, Absent, Inexact(5), Absent, Absent), + Some(50), + ), + ( + JoinType::LeftAnti, + (50, Absent, Inexact(20), Absent, Absent), + (10, Inexact(30), Absent, Absent, Absent), + Some(50), + ), + ]; + + let join_on = vec![( + Arc::new(Column::new("l_col", 0)) as _, + Arc::new(Column::new("r_col", 0)) as _, + )]; + + for (join_type, outer_info, inner_info, expected) in cases { + let outer_num_rows = outer_info.0; + let outer_col_stats = vec![create_column_stats( + outer_info.1, + outer_info.2, + outer_info.3, + outer_info.4, + )]; + + let inner_num_rows = inner_info.0; + let inner_col_stats = vec![create_column_stats( + inner_info.1, + inner_info.2, + inner_info.3, + inner_info.4, + )]; + + let output_cardinality = estimate_join_cardinality( + &join_type, + Statistics { + num_rows: Inexact(outer_num_rows), + total_byte_size: Absent, + column_statistics: outer_col_stats, + }, + Statistics { + num_rows: Inexact(inner_num_rows), + total_byte_size: Absent, + column_statistics: inner_col_stats, + }, + &join_on, + ) + .map(|cardinality| cardinality.num_rows); + + assert_eq!( + output_cardinality, expected, + "failure for join_type: {}", + join_type + ); + } + + Ok(()) + } + + #[test] + fn test_semi_join_cardinality_absent_rows() -> Result<()> { + let dummy_column_stats = + vec![create_column_stats(Absent, Absent, Absent, Absent)]; + let join_on = vec![( + Arc::new(Column::new("l_col", 
0)) as _, + Arc::new(Column::new("r_col", 0)) as _, + )]; + + let absent_outer_estimation = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + Statistics { + num_rows: Exact(10), + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + &join_on, + ); + assert!( + absent_outer_estimation.is_none(), + "Expected \"None\" esimated SemiJoin cardinality for absent outer num_rows" + ); + + let absent_inner_estimation = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(500), + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + &join_on, + ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows"); + + assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows esimated SemiJoin cardinality for absent inner num_rows"); + + let absent_inner_estimation = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + &join_on, + ); + assert!(absent_inner_estimation.is_none(), "Expected \"None\" esimated SemiJoin cardinality for absent outer and inner num_rows"); + + Ok(()) + } + #[test] fn test_calculate_join_output_ordering() -> Result<()> { let options = SortOptions::default(); diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 6334a4a211d4..3e8e439c9a38 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -33,7 +33,6 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::utils::DataPtr; use datafusion_common::Result; use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, }; @@ -113,6 +112,15 @@ pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; /// [`required_input_distribution`]: ExecutionPlan::required_input_distribution /// [`required_input_ordering`]: ExecutionPlan::required_input_ordering pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { + /// Short name for the ExecutionPlan, such as 'ParquetExec'. + fn name(&self) -> &'static str { + let full_name = std::any::type_name::(); + let maybe_start_idx = full_name.rfind(':'); + match maybe_start_idx { + Some(start_idx) => &full_name[start_idx + 1..], + None => "UNKNOWN", + } + } /// Returns the execution plan as [`Any`] so that it can be /// downcast to a specific implementation. 
fn as_any(&self) -> &dyn Any; @@ -778,4 +786,135 @@ pub fn get_plan_string(plan: &Arc) -> Vec { #[allow(clippy::single_component_path_imports)] use rstest_reuse; +#[cfg(test)] +mod tests { + use std::any::Any; + use std::sync::Arc; + + use arrow_schema::{Schema, SchemaRef}; + use datafusion_common::{Result, Statistics}; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + + use crate::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; + + #[derive(Debug)] + pub struct EmptyExec; + + impl EmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for EmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for EmptyExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn statistics(&self) -> Result { + unimplemented!() + } + } + + #[derive(Debug)] + pub struct RenamedEmptyExec; + + impl RenamedEmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for RenamedEmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for RenamedEmptyExec { + fn name(&self) -> &'static str { + "MyRenamedEmptyExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn statistics(&self) -> Result { + unimplemented!() + } + } + + #[test] + fn test_execution_plan_name() { + let schema1 = Arc::new(Schema::empty()); + let default_name_exec = EmptyExec::new(schema1); + assert_eq!(default_name_exec.name(), "EmptyExec"); + + let schema2 = Arc::new(Schema::empty()); + let renamed_exec = RenamedEmptyExec::new(schema2); + assert_eq!(renamed_exec.name(), "MyRenamedEmptyExec"); + } +} + pub mod test; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 9fa15cbf64e2..fab483b0da7d 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -111,6 +111,10 @@ impl DisplayAs for GlobalLimitExec { } impl ExecutionPlan for GlobalLimitExec { + fn name(&self) -> &'static str { + "GlobalLimitExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -317,6 +321,10 @@ impl DisplayAs for LocalLimitExec { } impl ExecutionPlan for LocalLimitExec { + fn name(&self) -> &'static str { + "LocalLimitExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 795ec3c7315e..883cdb540a9e 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -103,6 +103,10 @@ impl DisplayAs for MemoryExec { } impl ExecutionPlan for MemoryExec { + fn name(&self) -> &'static str { + "MemoryExec" + } + /// Return a reference 
to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 3880cf3d77af..c047ff5122fe 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -119,6 +119,10 @@ impl DisplayAs for PlaceholderRowExec { } impl ExecutionPlan for PlaceholderRowExec { + fn name(&self) -> &'static str { + "PlaceholderRowExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 8fe82e7de3eb..f72815c01a9e 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -180,6 +180,10 @@ impl DisplayAs for ProjectionExec { } impl ExecutionPlan for ProjectionExec { + fn name(&self) -> &'static str { + "ProjectionExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 68abc9653a8b..ba7d1a54548a 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -108,6 +108,10 @@ impl RecursiveQueryExec { } impl ExecutionPlan for RecursiveQueryExec { + fn name(&self) -> &'static str { + "RecursiveQueryExec" + } + fn as_any(&self) -> &dyn Any { self } @@ -309,10 +313,9 @@ impl RecursiveQueryStream { // Downstream plans should not expect any partitioning. let partition = 0; - self.recursive_stream = Some( - self.recursive_term - .execute(partition, self.task_context.clone())?, - ); + let recursive_plan = reset_plan_states(self.recursive_term.clone())?; + self.recursive_stream = + Some(recursive_plan.execute(partition, self.task_context.clone())?); self.poll_next(cx) } } @@ -343,6 +346,25 @@ fn assign_work_table( .data() } +/// Some plans will change their internal states after execution, making them unable to be executed again. +/// This function uses `ExecutionPlan::with_new_children` to fork a new plan with initial states. +/// +/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. +/// However, if the data of the left table is derived from the work table, it will become outdated +/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. +fn reset_plan_states(plan: Arc) -> Result> { + plan.transform_up(&|plan| { + // WorkTableExec's states have already been updated correctly. 
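// For example, for the CrossJoinExec-over-work-table plan described in the doc
// comment above, only the CrossJoinExec node is re-created via `with_new_children`,
// while the WorkTableExec leaf is returned unchanged (Transformed::no), so the shared
// work table it reads from stays intact for the next iteration.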
+ if plan.as_any().is::() { + Ok(Transformed::no(plan)) + } else { + let new_plan = plan.clone().with_new_children(plan.children())?; + Ok(Transformed::yes(new_plan)) + } + }) + .data() +} + impl Stream for RecursiveQueryStream { type Item = Result; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 7ac70949f893..c0dbf5164e19 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -406,6 +406,10 @@ impl DisplayAs for RepartitionExec { } impl ExecutionPlan for RepartitionExec { + fn name(&self) -> &'static str { + "RepartitionExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 500df6153fdb..d24bc5a670e5 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -226,6 +226,10 @@ impl DisplayAs for PartialSortExec { } impl ExecutionPlan for PartialSortExec { + fn name(&self) -> &'static str { + "PartialSortExec" + } + fn as_any(&self) -> &dyn Any { self } @@ -578,7 +582,7 @@ mod tests { #[tokio::test] async fn test_partial_sort2() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); - let source_tables = vec![ + let source_tables = [ test::build_table_scan_i32( ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]), ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]), diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index db352bb2c86f..a6f47d3d2fc9 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -733,16 +733,6 @@ pub struct SortExec { } impl SortExec { - /// Create a new sort execution plan - #[deprecated(since = "22.0.0", note = "use `new` and `with_fetch`")] - pub fn try_new( - expr: Vec, - input: Arc, - fetch: Option, - ) -> Result { - Ok(Self::new(expr, input).with_fetch(fetch)) - } - /// Create a new sort execution plan that produces a single, /// sorted output partition. 
pub fn new(expr: Vec, input: Arc) -> Self { @@ -758,23 +748,6 @@ impl SortExec { } } - /// Create a new sort execution plan with the option to preserve - /// the partitioning of the input plan - #[deprecated( - since = "22.0.0", - note = "use `new`, `with_fetch` and `with_preserve_partioning` instead" - )] - pub fn new_with_partitioning( - expr: Vec, - input: Arc, - preserve_partitioning: bool, - fetch: Option, - ) -> Self { - Self::new(expr, input) - .with_fetch(fetch) - .with_preserve_partitioning(preserve_partitioning) - } - /// Whether this `SortExec` preserves partitioning of the children pub fn preserve_partitioning(&self) -> bool { self.preserve_partitioning @@ -887,6 +860,10 @@ impl DisplayAs for SortExec { } impl ExecutionPlan for SortExec { + fn name(&self) -> &'static str { + "SortExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 556615f64de6..edef022b0c00 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -144,6 +144,10 @@ impl DisplayAs for SortPreservingMergeExec { } impl ExecutionPlan for SortPreservingMergeExec { + fn name(&self) -> &'static str { + "SortPreservingMergeExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 7b062ab8741f..d7e254c42fe1 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -191,6 +191,10 @@ impl DisplayAs for StreamingTableExec { #[async_trait] impl ExecutionPlan for StreamingTableExec { + fn name(&self) -> &'static str { + "StreamingTableExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 7eaac74a5449..69901aa2fa37 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -183,6 +183,10 @@ impl DisplayAs for UnionExec { } impl ExecutionPlan for UnionExec { + fn name(&self) -> &'static str { + "UnionExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -370,6 +374,10 @@ impl DisplayAs for InterleaveExec { } impl ExecutionPlan for InterleaveExec { + fn name(&self) -> &'static str { + "InterleaveExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -740,7 +748,7 @@ mod tests { let col_e = &col("e", &schema)?; let col_f = &col("f", &schema)?; let options = SortOptions::default(); - let test_cases = vec![ + let test_cases = [ //-----------TEST CASE 1----------// ( // First child orderings diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 886b718e6efe..324e2ea2d773 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -112,6 +112,10 @@ impl DisplayAs for UnnestExec { } impl ExecutionPlan for UnnestExec { + fn name(&self) -> &'static str { + "UnnestExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 8868a59008b7..63e8c32349ab 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -154,6 +154,10 @@ impl DisplayAs for ValuesExec { } 
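Relating to the SortExec constructor removals above: the behaviour of the deleted `try_new` and `new_with_partitioning` is now composed from `new`, `with_fetch`, and `with_preserve_partitioning`. A minimal migration sketch; the `build_top_k_sort` helper, the fetch limit of 10, and the in-scope `sort_exprs`/`input` are illustrative placeholders rather than names taken from this diff:

use std::sync::Arc;

use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::ExecutionPlan;

// Illustrative sketch only: helper name and the fetch limit are placeholders.
// Previously: SortExec::new_with_partitioning(sort_exprs, input, true, Some(10))
// Now the same plan is built by chaining the setters on `SortExec::new`.
fn build_top_k_sort(
    sort_exprs: Vec<PhysicalSortExpr>,
    input: Arc<dyn ExecutionPlan>,
) -> Arc<dyn ExecutionPlan> {
    let sort = SortExec::new(sort_exprs, input)
        .with_fetch(Some(10))
        .with_preserve_partitioning(true);
    Arc::new(sort)
}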
impl ExecutionPlan for ValuesExec { + fn name(&self) -> &'static str { + "ValuesExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 70b6182d81e7..75e203891cad 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -237,6 +237,10 @@ impl DisplayAs for BoundedWindowAggExec { } impl ExecutionPlan for BoundedWindowAggExec { + fn name(&self) -> &'static str { + "BoundedWindowAggExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index da2b24487d02..21f42f41fb5c 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -174,20 +174,15 @@ fn create_built_in_window_expr( name: String, ignore_nulls: bool, ) -> Result> { - // need to get the types into an owned vec for some reason - let input_types: Vec<_> = args - .iter() - .map(|arg| arg.data_type(input_schema)) - .collect::>()?; + // derive the output datatype from incoming schema + let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); - // figure out the output type - let data_type = &fun.return_type(&input_types)?; Ok(match fun { - BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)), - BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)), + BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)), + BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)), + BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)), + BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)), + BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), BuiltInWindowFunction::Ntile => { let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { DataFusionError::Execution( @@ -201,13 +196,13 @@ fn create_built_in_window_expr( if n.is_unsigned() { let n: u64 = n.try_into()?; - Arc::new(Ntile::new(name, n, data_type)) + Arc::new(Ntile::new(name, n, out_data_type)) } else { let n: i64 = n.try_into()?; if n <= 0 { return exec_err!("NTILE requires a positive integer"); } - Arc::new(Ntile::new(name, n as u64, data_type)) + Arc::new(Ntile::new(name, n as u64, out_data_type)) } } BuiltInWindowFunction::Lag => { @@ -216,10 +211,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lag( name, - data_type.clone(), + out_data_type.clone(), arg, shift_offset, default_value, @@ -232,10 +227,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lead( name, - data_type.clone(), + 
out_data_type.clone(), arg, shift_offset, default_value, @@ -252,18 +247,28 @@ fn create_built_in_window_expr( Arc::new(NthValue::nth( name, arg, - data_type.clone(), + out_data_type.clone(), n, ignore_nulls, )?) } BuiltInWindowFunction::FirstValue => { let arg = args[0].clone(); - Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls)) + Arc::new(NthValue::first( + name, + arg, + out_data_type.clone(), + ignore_nulls, + )) } BuiltInWindowFunction::LastValue => { let arg = args[0].clone(); - Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls)) + Arc::new(NthValue::last( + name, + arg, + out_data_type.clone(), + ignore_nulls, + )) } }) } diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index e300eee49d31..46ba21bd797e 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -172,6 +172,10 @@ impl DisplayAs for WindowAggExec { } impl ExecutionPlan for WindowAggExec { + fn name(&self) -> &'static str { + "WindowAggExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index f6fc0334dfc5..dfdb624a5625 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -157,6 +157,10 @@ impl DisplayAs for WorkTableExec { } impl ExecutionPlan for WorkTableExec { + fn name(&self) -> &'static str { + "WorkTableExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index f5297aefcd1c..bec2b8c53a7a 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -54,6 +54,7 @@ serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +datafusion-functions = { workspace = true, default-features = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 597094758584..e959cad2a810 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -544,17 +544,17 @@ enum ScalarFunction { unknown = 0; // 1 was Acos // 2 was Asin - Atan = 3; - Ascii = 4; + // 3 was Atan + // 4 was Ascii Ceil = 5; Cos = 6; // 7 was Digest Exp = 8; Floor = 9; - Ln = 10; + // 10 was Ln Log = 11; - Log10 = 12; - Log2 = 13; + // 12 was Log10 + // 13 was Log2 Round = 14; Signum = 15; Sin = 16; @@ -563,61 +563,61 @@ enum ScalarFunction { Trunc = 19; // 20 was Array // RegexpMatch = 21; - BitLength = 22; - Btrim = 23; - CharacterLength = 24; - Chr = 25; + // 22 was BitLength + // 23 was Btrim + // 24 was CharacterLength + // 25 was Chr Concat = 26; ConcatWithSeparator = 27; // 28 was DatePart // 29 was DateTrunc InitCap = 30; - Left = 31; - Lpad = 32; - Lower = 33; - Ltrim = 34; + // 31 was Left + // 32 was Lpad + // 33 was Lower + // 34 was Ltrim // 35 was MD5 // 36 was NullIf - OctetLength = 37; + // 37 was OctetLength Random = 38; // 39 was RegexpReplace - Repeat = 40; - Replace = 41; - Reverse = 42; - Right = 43; - Rpad = 44; - Rtrim = 45; + // 40 was Repeat + // 41 was Replace + // 42 was Reverse + // 43 was Right + // 44 was Rpad + // 45 was Rtrim // 46 was SHA224 // 47 was SHA256 // 
48 was SHA384 // 49 was SHA512 - SplitPart = 50; - StartsWith = 51; - Strpos = 52; - Substr = 53; - ToHex = 54; + // 50 was SplitPart + // StartsWith = 51; + // 52 was Strpos + // 53 was Substr + // ToHex = 54; // 55 was ToTimestamp // 56 was ToTimestampMillis // 57 was ToTimestampMicros // 58 was ToTimestampSeconds // 59 was Now - Translate = 60; - Trim = 61; - Upper = 62; + // 60 was Translate + // Trim = 61; + // Upper = 62; Coalesce = 63; Power = 64; // 65 was StructFun // 66 was FromUnixtime - Atan2 = 67; + // 67 Atan2 // 68 was DateBin // 69 was ArrowTypeof // 70 was CurrentDate // 71 was CurrentTime - Uuid = 72; + // 72 was Uuid Cbrt = 73; - Acosh = 74; - Asinh = 75; - Atanh = 76; + // 74 Acosh + // 75 was Asinh + // 76 was Atanh Sinh = 77; Cosh = 78; // Tanh = 79; @@ -637,7 +637,7 @@ enum ScalarFunction { // 93 was ArrayPositions // 94 was ArrayPrepend // 95 was ArrayRemove - ArrayReplace = 96; + // 96 was ArrayReplace // 97 was ArrayToString // 98 was Cardinality // 99 was ArrayElement @@ -647,9 +647,9 @@ enum ScalarFunction { // 105 was ArrayHasAny // 106 was ArrayHasAll // 107 was ArrayRemoveN - ArrayReplaceN = 108; + // 108 was ArrayReplaceN // 109 was ArrayRemoveAll - ArrayReplaceAll = 110; + // 110 was ArrayReplaceAll Nanvl = 111; // 112 was Flatten // 113 was IsNan @@ -660,13 +660,13 @@ enum ScalarFunction { // 118 was ToTimestampNanos // 119 was ArrayIntersect // 120 was ArrayUnion - OverLay = 121; + // 121 was OverLay // 122 is Range // 123 is ArrayExcept // 124 was ArrayPopFront - Levenshtein = 125; - SubstrIndex = 126; - FindInSet = 127; + // 125 was Levenshtein + // 126 was SubstrIndex + // 127 was FindInSet // 128 was ArraySort // 129 was ArrayDistinct // 130 was ArrayResize @@ -988,6 +988,20 @@ message IntervalMonthDayNanoValue { int64 nanos = 3; } +message UnionField { + int32 field_id = 1; + Field field = 2; +} + +message UnionValue { + // Note that a null union value must have one or more fields, so we + // encode a null UnionValue as one with value_id == 128 + int32 value_id = 1; + ScalarValue value = 2; + repeated UnionField fields = 3; + UnionMode mode = 4; +} + message ScalarFixedSizeBinary{ bytes values = 1; int32 length = 2; @@ -1042,6 +1056,7 @@ message ScalarValue{ ScalarTime64Value time64_value = 30; IntervalMonthDayNanoValue interval_month_day_nano = 31; ScalarFixedSizeBinary fixed_size_binary_value = 34; + UnionValue union_value = 42; } } @@ -1458,6 +1473,7 @@ message PhysicalExprNode { message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; + optional bytes fun_definition = 3; ArrowType return_type = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index cb9633338e8f..d900d0031df3 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20391,6 +20391,9 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.return_type.is_some() { len += 1; } @@ -20401,6 +20404,10 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.return_type.as_ref() { struct_ser.serialize_field("returnType", v)?; } @@ -20416,6 +20423,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalScalarUdfNode { const FIELDS: &[&str] = &[ "name", "args", + "fun_definition", + "funDefinition", "return_type", "returnType", ]; @@ -20424,6 +20433,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { enum GeneratedField { Name, Args, + FunDefinition, ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -20448,6 +20458,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { match value { "name" => Ok(GeneratedField::Name), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -20470,6 +20481,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { let mut name__ = None; let mut args__ = None; + let mut fun_definition__ = None; let mut return_type__ = None; while let Some(k) = map_.next_key()? { match k { @@ -20485,6 +20497,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::ReturnType => { if return_type__.is_some() { return Err(serde::de::Error::duplicate_field("returnType")); @@ -20496,6 +20516,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { Ok(PhysicalScalarUdfNode { name: name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, return_type: return_type__, }) } @@ -22893,56 +22914,23 @@ impl serde::Serialize for ScalarFunction { { let variant = match self { Self::Unknown => "unknown", - Self::Atan => "Atan", - Self::Ascii => "Ascii", Self::Ceil => "Ceil", Self::Cos => "Cos", Self::Exp => "Exp", Self::Floor => "Floor", - Self::Ln => "Ln", Self::Log => "Log", - Self::Log10 => "Log10", - Self::Log2 => "Log2", Self::Round => "Round", Self::Signum => "Signum", Self::Sin => "Sin", Self::Sqrt => "Sqrt", Self::Trunc => "Trunc", - Self::BitLength => "BitLength", - Self::Btrim => "Btrim", - Self::CharacterLength => "CharacterLength", - Self::Chr => "Chr", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", - Self::Left => "Left", - Self::Lpad => "Lpad", - Self::Lower => "Lower", - Self::Ltrim => "Ltrim", - Self::OctetLength => "OctetLength", Self::Random => "Random", - Self::Repeat => "Repeat", - Self::Replace => "Replace", - Self::Reverse => "Reverse", - Self::Right => "Right", - Self::Rpad => "Rpad", - Self::Rtrim => "Rtrim", - Self::SplitPart => "SplitPart", - Self::StartsWith => "StartsWith", - Self::Strpos => "Strpos", - Self::Substr => "Substr", - Self::ToHex => "ToHex", - Self::Translate => "Translate", - Self::Trim => "Trim", - Self::Upper => "Upper", Self::Coalesce => "Coalesce", Self::Power => "Power", - Self::Atan2 => "Atan2", - Self::Uuid => "Uuid", Self::Cbrt => "Cbrt", - Self::Acosh => "Acosh", - Self::Asinh => "Asinh", - Self::Atanh => "Atanh", Self::Sinh => "Sinh", Self::Cosh => "Cosh", Self::Pi => "Pi", @@ -22951,16 +22939,9 @@ impl serde::Serialize for ScalarFunction { Self::Factorial => "Factorial", Self::Lcm => "Lcm", Self::Gcd => "Gcd", - Self::ArrayReplace => "ArrayReplace", Self::Cot => "Cot", - Self::ArrayReplaceN => "ArrayReplaceN", - Self::ArrayReplaceAll => "ArrayReplaceAll", 
Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", - Self::OverLay => "OverLay", - Self::Levenshtein => "Levenshtein", - Self::SubstrIndex => "SubstrIndex", - Self::FindInSet => "FindInSet", Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) @@ -22974,56 +22955,23 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { const FIELDS: &[&str] = &[ "unknown", - "Atan", - "Ascii", "Ceil", "Cos", "Exp", "Floor", - "Ln", "Log", - "Log10", - "Log2", "Round", "Signum", "Sin", "Sqrt", "Trunc", - "BitLength", - "Btrim", - "CharacterLength", - "Chr", "Concat", "ConcatWithSeparator", "InitCap", - "Left", - "Lpad", - "Lower", - "Ltrim", - "OctetLength", "Random", - "Repeat", - "Replace", - "Reverse", - "Right", - "Rpad", - "Rtrim", - "SplitPart", - "StartsWith", - "Strpos", - "Substr", - "ToHex", - "Translate", - "Trim", - "Upper", "Coalesce", "Power", - "Atan2", - "Uuid", "Cbrt", - "Acosh", - "Asinh", - "Atanh", "Sinh", "Cosh", "Pi", @@ -23032,16 +22980,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Factorial", "Lcm", "Gcd", - "ArrayReplace", "Cot", - "ArrayReplaceN", - "ArrayReplaceAll", "Nanvl", "Iszero", - "OverLay", - "Levenshtein", - "SubstrIndex", - "FindInSet", "EndsWith", ]; @@ -23084,56 +23025,23 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { match value { "unknown" => Ok(ScalarFunction::Unknown), - "Atan" => Ok(ScalarFunction::Atan), - "Ascii" => Ok(ScalarFunction::Ascii), "Ceil" => Ok(ScalarFunction::Ceil), "Cos" => Ok(ScalarFunction::Cos), "Exp" => Ok(ScalarFunction::Exp), "Floor" => Ok(ScalarFunction::Floor), - "Ln" => Ok(ScalarFunction::Ln), "Log" => Ok(ScalarFunction::Log), - "Log10" => Ok(ScalarFunction::Log10), - "Log2" => Ok(ScalarFunction::Log2), "Round" => Ok(ScalarFunction::Round), "Signum" => Ok(ScalarFunction::Signum), "Sin" => Ok(ScalarFunction::Sin), "Sqrt" => Ok(ScalarFunction::Sqrt), "Trunc" => Ok(ScalarFunction::Trunc), - "BitLength" => Ok(ScalarFunction::BitLength), - "Btrim" => Ok(ScalarFunction::Btrim), - "CharacterLength" => Ok(ScalarFunction::CharacterLength), - "Chr" => Ok(ScalarFunction::Chr), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), - "Left" => Ok(ScalarFunction::Left), - "Lpad" => Ok(ScalarFunction::Lpad), - "Lower" => Ok(ScalarFunction::Lower), - "Ltrim" => Ok(ScalarFunction::Ltrim), - "OctetLength" => Ok(ScalarFunction::OctetLength), "Random" => Ok(ScalarFunction::Random), - "Repeat" => Ok(ScalarFunction::Repeat), - "Replace" => Ok(ScalarFunction::Replace), - "Reverse" => Ok(ScalarFunction::Reverse), - "Right" => Ok(ScalarFunction::Right), - "Rpad" => Ok(ScalarFunction::Rpad), - "Rtrim" => Ok(ScalarFunction::Rtrim), - "SplitPart" => Ok(ScalarFunction::SplitPart), - "StartsWith" => Ok(ScalarFunction::StartsWith), - "Strpos" => Ok(ScalarFunction::Strpos), - "Substr" => Ok(ScalarFunction::Substr), - "ToHex" => Ok(ScalarFunction::ToHex), - "Translate" => Ok(ScalarFunction::Translate), - "Trim" => Ok(ScalarFunction::Trim), - "Upper" => Ok(ScalarFunction::Upper), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), - "Atan2" => Ok(ScalarFunction::Atan2), - "Uuid" => Ok(ScalarFunction::Uuid), "Cbrt" => Ok(ScalarFunction::Cbrt), - "Acosh" => Ok(ScalarFunction::Acosh), - "Asinh" => Ok(ScalarFunction::Asinh), - "Atanh" => Ok(ScalarFunction::Atanh), "Sinh" => Ok(ScalarFunction::Sinh), "Cosh" => Ok(ScalarFunction::Cosh), "Pi" => Ok(ScalarFunction::Pi), @@ -23142,16 +23050,9 @@ impl<'de> 
serde::Deserialize<'de> for ScalarFunction { "Factorial" => Ok(ScalarFunction::Factorial), "Lcm" => Ok(ScalarFunction::Lcm), "Gcd" => Ok(ScalarFunction::Gcd), - "ArrayReplace" => Ok(ScalarFunction::ArrayReplace), "Cot" => Ok(ScalarFunction::Cot), - "ArrayReplaceN" => Ok(ScalarFunction::ArrayReplaceN), - "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), - "OverLay" => Ok(ScalarFunction::OverLay), - "Levenshtein" => Ok(ScalarFunction::Levenshtein), - "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), - "FindInSet" => Ok(ScalarFunction::FindInSet), "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } @@ -24041,6 +23942,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::FixedSizeBinaryValue(v) => { struct_ser.serialize_field("fixedSizeBinaryValue", v)?; } + scalar_value::Value::UnionValue(v) => { + struct_ser.serialize_field("unionValue", v)?; + } } } struct_ser.end() @@ -24125,6 +24029,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano", "fixed_size_binary_value", "fixedSizeBinaryValue", + "union_value", + "unionValue", ]; #[allow(clippy::enum_variant_names)] @@ -24165,6 +24071,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Time64Value, IntervalMonthDayNano, FixedSizeBinaryValue, + UnionValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24222,6 +24129,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), + "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24471,6 +24379,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) +; + } + GeneratedField::UnionValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("unionValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) ; } } @@ -26930,6 +26845,117 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_id != 0 { + len += 1; + } + if self.field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionField", len)?; + if self.field_id != 0 { + struct_ser.serialize_field("fieldId", &self.field_id)?; + } + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_id", + "fieldId", + "field", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldId, + Field, + } 
+ impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldId" | "field_id" => Ok(GeneratedField::FieldId), + "field" => Ok(GeneratedField::Field), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_id__ = None; + let mut field__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldId => { + if field_id__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldId")); + } + field_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + } + } + Ok(UnionField { + field_id: field_id__.unwrap_or_default(), + field: field__, + }) + } + } + deserializer.deserialize_struct("datafusion.UnionField", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UnionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -27092,6 +27118,153 @@ impl<'de> serde::Deserialize<'de> for UnionNode { deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value_id != 0 { + len += 1; + } + if self.value.is_some() { + len += 1; + } + if !self.fields.is_empty() { + len += 1; + } + if self.mode != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionValue", len)?; + if self.value_id != 0 { + struct_ser.serialize_field("valueId", &self.value_id)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + if !self.fields.is_empty() { + struct_ser.serialize_field("fields", &self.fields)?; + } + if self.mode != 0 { + let v = UnionMode::try_from(self.mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; + struct_ser.serialize_field("mode", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value_id", + "valueId", + "value", + "fields", + "mode", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ValueId, + Value, + Fields, + Mode, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + 
where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "valueId" | "value_id" => Ok(GeneratedField::ValueId), + "value" => Ok(GeneratedField::Value), + "fields" => Ok(GeneratedField::Fields), + "mode" => Ok(GeneratedField::Mode), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value_id__ = None; + let mut value__ = None; + let mut fields__ = None; + let mut mode__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ValueId => { + if value_id__.is_some() { + return Err(serde::de::Error::duplicate_field("valueId")); + } + value_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } + GeneratedField::Mode => { + if mode__.is_some() { + return Err(serde::de::Error::duplicate_field("mode")); + } + mode__ = Some(map_.next_value::()? 
as i32); + } + } + } + Ok(UnionValue { + value_id: value_id__.unwrap_or_default(), + value: value__, + fields: fields__.unwrap_or_default(), + mode: mode__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.UnionValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UniqueConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f5ef6c1f74f0..753abb4e2756 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1225,6 +1225,28 @@ pub struct IntervalMonthDayNanoValue { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionField { + #[prost(int32, tag = "1")] + pub field_id: i32, + #[prost(message, optional, tag = "2")] + pub field: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionValue { + /// Note that a null union value must have one or more fields, so we + /// encode a null UnionValue as one with value_id == 128 + #[prost(int32, tag = "1")] + pub value_id: i32, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub fields: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "4")] + pub mode: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -1236,7 +1258,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -1320,6 +1342,8 @@ pub mod scalar_value { IntervalMonthDayNano(super::IntervalMonthDayNanoValue), #[prost(message, tag = "34")] FixedSizeBinaryValue(super::ScalarFixedSizeBinary), + #[prost(message, tag = "42")] + UnionValue(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2092,6 +2116,8 @@ pub struct PhysicalScalarUdfNode { pub name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(message, optional, tag = "4")] pub return_type: ::core::option::Option, } @@ -2817,17 +2843,17 @@ pub enum ScalarFunction { Unknown = 0, /// 1 was Acos /// 2 was Asin - Atan = 3, - Ascii = 4, + /// 3 was Atan + /// 4 was Ascii Ceil = 5, Cos = 6, /// 7 was Digest Exp = 8, Floor = 9, - Ln = 10, + /// 10 was Ln Log = 11, - Log10 = 12, - Log2 = 13, + /// 12 was Log10 + /// 13 was Log2 Round = 14, Signum = 15, Sin = 16, @@ -2836,61 +2862,61 @@ pub enum ScalarFunction { Trunc = 19, /// 20 was Array /// RegexpMatch = 21; - BitLength = 22, - Btrim = 23, - CharacterLength = 24, - Chr = 25, + /// 22 was BitLength + /// 23 was Btrim + /// 24 was CharacterLength + /// 25 was Chr Concat = 26, ConcatWithSeparator = 27, /// 28 was DatePart /// 29 was DateTrunc InitCap = 30, - Left 
= 31, - Lpad = 32, - Lower = 33, - Ltrim = 34, + /// 31 was Left + /// 32 was Lpad + /// 33 was Lower + /// 34 was Ltrim /// 35 was MD5 /// 36 was NullIf - OctetLength = 37, + /// 37 was OctetLength Random = 38, /// 39 was RegexpReplace - Repeat = 40, - Replace = 41, - Reverse = 42, - Right = 43, - Rpad = 44, - Rtrim = 45, + /// 40 was Repeat + /// 41 was Replace + /// 42 was Reverse + /// 43 was Right + /// 44 was Rpad + /// 45 was Rtrim /// 46 was SHA224 /// 47 was SHA256 /// 48 was SHA384 /// 49 was SHA512 - SplitPart = 50, - StartsWith = 51, - Strpos = 52, - Substr = 53, - ToHex = 54, + /// 50 was SplitPart + /// StartsWith = 51; + /// 52 was Strpos + /// 53 was Substr + /// ToHex = 54; /// 55 was ToTimestamp /// 56 was ToTimestampMillis /// 57 was ToTimestampMicros /// 58 was ToTimestampSeconds /// 59 was Now - Translate = 60, - Trim = 61, - Upper = 62, + /// 60 was Translate + /// Trim = 61; + /// Upper = 62; Coalesce = 63, Power = 64, /// 65 was StructFun /// 66 was FromUnixtime - Atan2 = 67, + /// 67 Atan2 /// 68 was DateBin /// 69 was ArrowTypeof /// 70 was CurrentDate /// 71 was CurrentTime - Uuid = 72, + /// 72 was Uuid Cbrt = 73, - Acosh = 74, - Asinh = 75, - Atanh = 76, + /// 74 Acosh + /// 75 was Asinh + /// 76 was Atanh Sinh = 77, Cosh = 78, /// Tanh = 79; @@ -2910,7 +2936,7 @@ pub enum ScalarFunction { /// 93 was ArrayPositions /// 94 was ArrayPrepend /// 95 was ArrayRemove - ArrayReplace = 96, + /// 96 was ArrayReplace /// 97 was ArrayToString /// 98 was Cardinality /// 99 was ArrayElement @@ -2920,9 +2946,9 @@ pub enum ScalarFunction { /// 105 was ArrayHasAny /// 106 was ArrayHasAll /// 107 was ArrayRemoveN - ArrayReplaceN = 108, + /// 108 was ArrayReplaceN /// 109 was ArrayRemoveAll - ArrayReplaceAll = 110, + /// 110 was ArrayReplaceAll Nanvl = 111, /// 112 was Flatten /// 113 was IsNan @@ -2933,13 +2959,13 @@ pub enum ScalarFunction { /// 118 was ToTimestampNanos /// 119 was ArrayIntersect /// 120 was ArrayUnion - OverLay = 121, + /// 121 was OverLay /// 122 is Range /// 123 is ArrayExcept /// 124 was ArrayPopFront - Levenshtein = 125, - SubstrIndex = 126, - FindInSet = 127, + /// 125 was Levenshtein + /// 126 was SubstrIndex + /// 127 was FindInSet /// 128 was ArraySort /// 129 was ArrayDistinct /// 130 was ArrayResize @@ -2961,56 +2987,23 @@ impl ScalarFunction { pub fn as_str_name(&self) -> &'static str { match self { ScalarFunction::Unknown => "unknown", - ScalarFunction::Atan => "Atan", - ScalarFunction::Ascii => "Ascii", ScalarFunction::Ceil => "Ceil", ScalarFunction::Cos => "Cos", ScalarFunction::Exp => "Exp", ScalarFunction::Floor => "Floor", - ScalarFunction::Ln => "Ln", ScalarFunction::Log => "Log", - ScalarFunction::Log10 => "Log10", - ScalarFunction::Log2 => "Log2", ScalarFunction::Round => "Round", ScalarFunction::Signum => "Signum", ScalarFunction::Sin => "Sin", ScalarFunction::Sqrt => "Sqrt", ScalarFunction::Trunc => "Trunc", - ScalarFunction::BitLength => "BitLength", - ScalarFunction::Btrim => "Btrim", - ScalarFunction::CharacterLength => "CharacterLength", - ScalarFunction::Chr => "Chr", ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", - ScalarFunction::Left => "Left", - ScalarFunction::Lpad => "Lpad", - ScalarFunction::Lower => "Lower", - ScalarFunction::Ltrim => "Ltrim", - ScalarFunction::OctetLength => "OctetLength", ScalarFunction::Random => "Random", - ScalarFunction::Repeat => "Repeat", - ScalarFunction::Replace => "Replace", - ScalarFunction::Reverse => 
"Reverse", - ScalarFunction::Right => "Right", - ScalarFunction::Rpad => "Rpad", - ScalarFunction::Rtrim => "Rtrim", - ScalarFunction::SplitPart => "SplitPart", - ScalarFunction::StartsWith => "StartsWith", - ScalarFunction::Strpos => "Strpos", - ScalarFunction::Substr => "Substr", - ScalarFunction::ToHex => "ToHex", - ScalarFunction::Translate => "Translate", - ScalarFunction::Trim => "Trim", - ScalarFunction::Upper => "Upper", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", - ScalarFunction::Atan2 => "Atan2", - ScalarFunction::Uuid => "Uuid", ScalarFunction::Cbrt => "Cbrt", - ScalarFunction::Acosh => "Acosh", - ScalarFunction::Asinh => "Asinh", - ScalarFunction::Atanh => "Atanh", ScalarFunction::Sinh => "Sinh", ScalarFunction::Cosh => "Cosh", ScalarFunction::Pi => "Pi", @@ -3019,16 +3012,9 @@ impl ScalarFunction { ScalarFunction::Factorial => "Factorial", ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", - ScalarFunction::ArrayReplace => "ArrayReplace", ScalarFunction::Cot => "Cot", - ScalarFunction::ArrayReplaceN => "ArrayReplaceN", - ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", - ScalarFunction::OverLay => "OverLay", - ScalarFunction::Levenshtein => "Levenshtein", - ScalarFunction::SubstrIndex => "SubstrIndex", - ScalarFunction::FindInSet => "FindInSet", ScalarFunction::EndsWith => "EndsWith", } } @@ -3036,56 +3022,23 @@ impl ScalarFunction { pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "unknown" => Some(Self::Unknown), - "Atan" => Some(Self::Atan), - "Ascii" => Some(Self::Ascii), "Ceil" => Some(Self::Ceil), "Cos" => Some(Self::Cos), "Exp" => Some(Self::Exp), "Floor" => Some(Self::Floor), - "Ln" => Some(Self::Ln), "Log" => Some(Self::Log), - "Log10" => Some(Self::Log10), - "Log2" => Some(Self::Log2), "Round" => Some(Self::Round), "Signum" => Some(Self::Signum), "Sin" => Some(Self::Sin), "Sqrt" => Some(Self::Sqrt), "Trunc" => Some(Self::Trunc), - "BitLength" => Some(Self::BitLength), - "Btrim" => Some(Self::Btrim), - "CharacterLength" => Some(Self::CharacterLength), - "Chr" => Some(Self::Chr), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), - "Left" => Some(Self::Left), - "Lpad" => Some(Self::Lpad), - "Lower" => Some(Self::Lower), - "Ltrim" => Some(Self::Ltrim), - "OctetLength" => Some(Self::OctetLength), "Random" => Some(Self::Random), - "Repeat" => Some(Self::Repeat), - "Replace" => Some(Self::Replace), - "Reverse" => Some(Self::Reverse), - "Right" => Some(Self::Right), - "Rpad" => Some(Self::Rpad), - "Rtrim" => Some(Self::Rtrim), - "SplitPart" => Some(Self::SplitPart), - "StartsWith" => Some(Self::StartsWith), - "Strpos" => Some(Self::Strpos), - "Substr" => Some(Self::Substr), - "ToHex" => Some(Self::ToHex), - "Translate" => Some(Self::Translate), - "Trim" => Some(Self::Trim), - "Upper" => Some(Self::Upper), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), - "Atan2" => Some(Self::Atan2), - "Uuid" => Some(Self::Uuid), "Cbrt" => Some(Self::Cbrt), - "Acosh" => Some(Self::Acosh), - "Asinh" => Some(Self::Asinh), - "Atanh" => Some(Self::Atanh), "Sinh" => Some(Self::Sinh), "Cosh" => Some(Self::Cosh), "Pi" => Some(Self::Pi), @@ -3094,16 +3047,9 @@ impl ScalarFunction { "Factorial" => Some(Self::Factorial), "Lcm" => Some(Self::Lcm), "Gcd" => Some(Self::Gcd), - "ArrayReplace" => Some(Self::ArrayReplace), "Cot" => Some(Self::Cot), - "ArrayReplaceN" => 
Some(Self::ArrayReplaceN), - "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), - "OverLay" => Some(Self::OverLay), - "Levenshtein" => Some(Self::Levenshtein), - "SubstrIndex" => Some(Self::SubstrIndex), - "FindInSet" => Some(Self::FindInSet), "EndsWith" => Some(Self::EndsWith), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3822b74bc18c..f9e2dc5596ac 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,18 +17,6 @@ use std::sync::Arc; -use crate::protobuf::{ - self, - plan_type::PlanTypeEnum::{ - AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, - InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, - OptimizedPhysicalPlan, - }, - AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, - OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, -}; - use arrow::{ array::AsArray, buffer::Buffer, @@ -38,6 +26,7 @@ use arrow::{ }, ipc::{reader::read_record_batch, root_as_message}, }; + use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, @@ -48,24 +37,31 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, array_replace, array_replace_all, array_replace_n, ascii, asinh, atan, atan2, - atanh, bit_length, btrim, cbrt, ceil, character_length, chr, coalesce, concat_expr, - concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, + ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, - log10, log2, + factorial, floor, gcd, initcap, iszero, lcm, log, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, nanvl, octet_length, overlay, pi, power, radians, random, repeat, - replace, reverse, right, round, rpad, rtrim, signum, sin, sinh, split_part, sqrt, - starts_with, strpos, substr, substr_index, substring, to_hex, translate, trim, trunc, - upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, - GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, trunc, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use crate::protobuf::{ + self, + plan_type::PlanTypeEnum::{ + AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, + OptimizedPhysicalPlan, + }, + AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, + OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, +}; + use super::LogicalExtensionCodec; #[derive(Debug)] @@ -327,11 +323,7 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { 
DataType::FixedSizeList(Arc::new(list_type), list_size) } arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( - strct - .sub_field_types - .iter() - .map(Field::try_from) - .collect::>()?, + parse_proto_fields_to_fields(&strct.sub_field_types)?.into(), ), arrow_type::ArrowTypeEnum::Union(union) => { let union_mode = protobuf::UnionMode::try_from(union.union_mode) @@ -340,11 +332,7 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { protobuf::UnionMode::Dense => UnionMode::Dense, protobuf::UnionMode::Sparse => UnionMode::Sparse, }; - let union_fields = union - .union_types - .iter() - .map(TryInto::try_into) - .collect::, _>>()?; + let union_fields = parse_proto_fields_to_fields(&union.union_types)?; // Default to index based type ids if not provided let type_ids: Vec<_> = match union.type_ids.is_empty() { @@ -440,16 +428,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Sin => Self::Sin, ScalarFunction::Cos => Self::Cos, ScalarFunction::Cot => Self::Cot, - ScalarFunction::Atan => Self::Atan, ScalarFunction::Sinh => Self::Sinh, ScalarFunction::Cosh => Self::Cosh, - ScalarFunction::Asinh => Self::Asinh, - ScalarFunction::Acosh => Self::Acosh, - ScalarFunction::Atanh => Self::Atanh, ScalarFunction::Exp => Self::Exp, ScalarFunction::Log => Self::Log, - ScalarFunction::Ln => Self::Ln, - ScalarFunction::Log10 => Self::Log10, ScalarFunction::Degrees => Self::Degrees, ScalarFunction::Radians => Self::Radians, ScalarFunction::Factorial => Self::Factorial, @@ -459,51 +441,17 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Ceil => Self::Ceil, ScalarFunction::Round => Self::Round, ScalarFunction::Trunc => Self::Trunc, - ScalarFunction::OctetLength => Self::OctetLength, ScalarFunction::Concat => Self::Concat, - ScalarFunction::Lower => Self::Lower, - ScalarFunction::Upper => Self::Upper, - ScalarFunction::Trim => Self::Trim, - ScalarFunction::Ltrim => Self::Ltrim, - ScalarFunction::Rtrim => Self::Rtrim, - ScalarFunction::ArrayReplace => Self::ArrayReplace, - ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, - ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, - ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, - ScalarFunction::Ascii => Self::Ascii, - ScalarFunction::BitLength => Self::BitLength, - ScalarFunction::Btrim => Self::Btrim, - ScalarFunction::CharacterLength => Self::CharacterLength, - ScalarFunction::Chr => Self::Chr, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, - ScalarFunction::Left => Self::Left, - ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, - ScalarFunction::Repeat => Self::Repeat, - ScalarFunction::Replace => Self::Replace, - ScalarFunction::Reverse => Self::Reverse, - ScalarFunction::Right => Self::Right, - ScalarFunction::Rpad => Self::Rpad, - ScalarFunction::SplitPart => Self::SplitPart, - ScalarFunction::StartsWith => Self::StartsWith, - ScalarFunction::Strpos => Self::Strpos, - ScalarFunction::Substr => Self::Substr, - ScalarFunction::ToHex => Self::ToHex, - ScalarFunction::Uuid => Self::Uuid, - ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, ScalarFunction::Power => Self::Power, - ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, - ScalarFunction::OverLay => 
Self::OverLay, - ScalarFunction::Levenshtein => Self::Levenshtein, - ScalarFunction::SubstrIndex => Self::SubstrIndex, - ScalarFunction::FindInSet => Self::FindInSet, } } } @@ -771,6 +719,38 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), )), + Value::UnionValue(val) => { + let mode = match val.mode { + 0 => UnionMode::Sparse, + 1 => UnionMode::Dense, + id => Err(Error::unknown("UnionMode", id))?, + }; + let ids = val + .fields + .iter() + .map(|f| f.field_id as i8) + .collect::>(); + let fields = val + .fields + .iter() + .map(|f| f.field.clone()) + .collect::>>(); + let fields = fields.ok_or_else(|| Error::required("UnionField"))?; + let fields = parse_proto_fields_to_fields(&fields)?; + let fields = UnionFields::new(ids, fields); + let v_id = val.value_id as i8; + let val = match &val.value { + None => None, + Some(val) => { + let val: ScalarValue = val + .as_ref() + .try_into() + .map_err(|_| Error::General("Invalid Scalar".to_string()))?; + Some((v_id, Box::new(val))) + } + }; + Self::Union(val, fields, mode) + } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) } @@ -927,11 +907,7 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { let op = from_proto_binary_op(&binary_expr.op)?; - let operands = binary_expr - .operands - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?; + let operands = parse_exprs(&binary_expr.operands, registry, codec)?; if operands.len() < 2 { return Err(proto_error( @@ -1015,16 +991,8 @@ pub fn parse_expr( .window_function .as_ref() .ok_or_else(|| Error::required("window_function"))?; - let partition_by = expr - .partition_by - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?; - let mut order_by = expr - .order_by - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?; + let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; + let mut order_by = parse_exprs(&expr.order_by, registry, codec)?; let window_frame = expr .window_frame .as_ref() @@ -1120,10 +1088,7 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, - expr.expr - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?, + parse_exprs(&expr.expr, registry, codec)?, expr.distinct, parse_optional_expr(expr.filter.as_deref(), registry, codec)? 
.map(Box::new), @@ -1321,11 +1286,7 @@ pub fn parse_expr( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let exprs = unnest - .exprs - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?; + let exprs = parse_exprs(&unnest.exprs, registry, codec)?; Ok(Expr::Unnest(Unnest { exprs })) } ExprType::InList(in_list) => Ok(Expr::InList(InList::new( @@ -1335,11 +1296,7 @@ pub fn parse_expr( "expr", codec, )?), - in_list - .list - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, + parse_exprs(&in_list.list, registry, codec)?, in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { @@ -1356,38 +1313,12 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), - ScalarFunction::Asinh => { - Ok(asinh(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Acosh => { - Ok(acosh(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::ArrayReplace => Ok(array_replace( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::ArrayReplaceN => Ok(array_replace_n( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - parse_expr(&args[3], registry, codec)?, - )), - ScalarFunction::ArrayReplaceAll => Ok(array_replace_all( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Atanh => { - Ok(atanh(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Degrees => { Ok(degrees(parse_expr(&args[0], registry, codec)?)) @@ -1395,11 +1326,6 @@ pub fn parse_expr( ScalarFunction::Radians => { Ok(radians(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Log10 => { - Ok(log10(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Floor => { Ok(floor(parse_expr(&args[0], registry, codec)?)) } @@ -1407,47 +1333,11 @@ pub fn parse_expr( Ok(factorial(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Round => Ok(round( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::Trunc => Ok(trunc( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), + ScalarFunction::Round => Ok(round(parse_exprs(args, registry, codec)?)), + ScalarFunction::Trunc => Ok(trunc(parse_exprs(args, registry, codec)?)), ScalarFunction::Signum => { Ok(signum(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::OctetLength 
=> { - Ok(octet_length(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Lower => { - Ok(lower(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Upper => { - Ok(upper(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Ltrim => { - Ok(ltrim(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Rtrim => { - Ok(rtrim(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Ascii => { - Ok(ascii(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::BitLength => { - Ok(bit_length(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry, codec)?)), ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } @@ -1459,108 +1349,20 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Left => Ok(left( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Random => Ok(random()), - ScalarFunction::Uuid => Ok(uuid()), - ScalarFunction::Repeat => Ok(repeat( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Replace => Ok(replace( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::Reverse => { - Ok(reverse(parse_expr(&args[0], registry, codec)?)) + ScalarFunction::Concat => { + Ok(concat_expr(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Right => Ok(right( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Concat => Ok(concat_expr( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::Lpad => Ok(lpad( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::Rpad => Ok(rpad( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::Btrim => Ok(btrim( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::SplitPart => Ok(split_part( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::StartsWith => Ok(starts_with( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::EndsWith => Ok(ends_with( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Substr => { - if args.len() > 2 { - assert_eq!(args.len(), 3); - Ok(substring( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )) - } else { - Ok(substr( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )) - } + 
ScalarFunction::ConcatWithSeparator => { + Ok(concat_ws_expr(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Levenshtein => Ok(levenshtein( + ScalarFunction::EndsWith => Ok(ends_with( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::ToHex => { - Ok(to_hex(parse_expr(&args[0], registry, codec)?)) + ScalarFunction::Coalesce => { + Ok(coalesce(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Translate => Ok(translate( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::Coalesce => Ok(coalesce( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), ScalarFunction::Pi => Ok(pi()), ScalarFunction::Power => Ok(power( parse_expr(&args[0], registry, codec)?, @@ -1570,10 +1372,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Atan2 => Ok(atan2( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( parse_expr(&args[0], registry, codec)?, @@ -1582,21 +1380,6 @@ pub fn parse_expr( ScalarFunction::Iszero => { Ok(iszero(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::OverLay => Ok(overlay( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::SubstrIndex => Ok(substr_index( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::FindInSet => Ok(find_in_set( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { @@ -1610,9 +1393,7 @@ pub fn parse_expr( }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, - args.iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>()?, + parse_exprs(args, registry, codec)?, ))) } ExprType::AggregateUdfExpr(pb) => { @@ -1620,10 +1401,7 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, - pb.args - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>()?, + parse_exprs(&pb.args, registry, codec)?, false, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, @@ -1633,28 +1411,16 @@ pub fn parse_expr( ExprType::GroupingSet(GroupingSetNode { expr }) => { Ok(Expr::GroupingSet(GroupingSets( expr.iter() - .map(|expr_list| { - expr_list - .expr - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>() - }) + .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) .collect::, Error>>()?, ))) } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( - expr.iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>()?, + parse_exprs(expr, registry, codec)?, ))), - ExprType::Rollup(RollupNode { expr }) => { - Ok(Expr::GroupingSet(GroupingSet::Rollup( - expr.iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>()?, - ))) - } + ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( + GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), + )), ExprType::Placeholder(PlaceholderNode { id, data_type }) => match 
data_type { None => Ok(Expr::Placeholder(Placeholder::new(id.clone(), None))), Some(data_type) => Ok(Expr::Placeholder(Placeholder::new( @@ -1665,6 +1431,24 @@ pub fn parse_expr( } } +/// Parse a vector of `protobuf::LogicalExprNode`s. +pub fn parse_exprs<'a, I>( + protos: I, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + let res = protos + .into_iter() + .map(|elem| { + parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + }) + .collect::>>()?; + Ok(res) +} + /// Parse an optional escape_char for Like, ILike, SimilarTo fn parse_escape_char(s: &str) -> Result> { match s.len() { @@ -1721,12 +1505,7 @@ fn parse_vec_expr( registry: &dyn FunctionRegistry, codec: &dyn LogicalExtensionCodec, ) -> Result>, Error> { - let res = p - .iter() - .map(|elem| { - parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) - }) - .collect::>>()?; + let res = parse_exprs(p, registry, codec)?; // Convert empty vector to None. Ok((!res.is_empty()).then_some(res)) } @@ -1757,3 +1536,16 @@ fn parse_required_expr( fn proto_error>(message: S) -> Error { Error::General(message.into()) } + +/// Converts a vector of `protobuf::Field`s to `Arc`s. +fn parse_proto_fields_to_fields<'a, I>( + fields: I, +) -> std::result::Result, Error> +where + I: IntoIterator, +{ + fields + .into_iter() + .map(Field::try_from) + .collect::>() +} diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7a17d2a2b405..3ee69066e1aa 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -19,6 +19,8 @@ //! DataFusion logical plans to be serialized and transmitted between //! processes. 
+use std::sync::Arc; + use crate::protobuf::{ self, arrow_type::ArrowTypeEnum, @@ -30,6 +32,7 @@ use crate::protobuf::{ }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, + UnionField, UnionValue, }; use arrow::{ @@ -185,10 +188,7 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { field_type: Some(Box::new(item_type.as_ref().try_into()?)), })), DataType::Struct(struct_fields) => Self::Struct(protobuf::Struct { - sub_field_types: struct_fields - .iter() - .map(|field| field.as_ref().try_into()) - .collect::, Error>>()?, + sub_field_types: convert_arc_fields_to_proto_fields(struct_fields)?, }), DataType::Union(fields, union_mode) => { let union_mode = match union_mode { @@ -196,10 +196,7 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { UnionMode::Dense => protobuf::UnionMode::Dense, }; Self::Union(protobuf::Union { - union_types: fields - .iter() - .map(|(_, field)| field.as_ref().try_into()) - .collect::, Error>>()?, + union_types: convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?, union_mode: union_mode.into(), type_ids: fields.iter().map(|(x, _)| x as i32).collect(), }) @@ -230,6 +227,9 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { "Proto serialization error: The RunEndEncoded data type is not yet supported".to_owned() )) } + DataType::Utf8View | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { + return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) + } }; Ok(res) @@ -258,11 +258,7 @@ impl TryFrom<&Schema> for protobuf::Schema { fn try_from(schema: &Schema) -> Result { Ok(Self { - columns: schema - .fields() - .iter() - .map(|f| f.as_ref().try_into()) - .collect::, Error>>()?, + columns: convert_arc_fields_to_proto_fields(schema.fields())?, metadata: schema.metadata.clone(), }) } @@ -273,11 +269,7 @@ impl TryFrom for protobuf::Schema { fn try_from(schema: SchemaRef) -> Result { Ok(Self { - columns: schema - .fields() - .iter() - .map(|f| f.as_ref().try_into()) - .collect::, Error>>()?, + columns: convert_arc_fields_to_proto_fields(schema.fields())?, metadata: schema.metadata.clone(), }) } @@ -482,6 +474,19 @@ impl TryFrom<&WindowFrame> for protobuf::WindowFrame { } } +pub fn serialize_exprs<'a, I>( + exprs: I, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + exprs + .into_iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>() +} + pub fn serialize_expr( expr: &Expr, codec: &dyn LogicalExtensionCodec, @@ -539,11 +544,7 @@ pub fn serialize_expr( // We need to reverse exprs since operands are expected to be // linearized from left innermost to right outermost (but while // traversing the chain we do the exact opposite). 
- operands: exprs - .into_iter() - .rev() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + operands: serialize_exprs(exprs.into_iter().rev(), codec)?, op: format!("{op:?}"), }; protobuf::LogicalExprNode { @@ -635,14 +636,8 @@ pub fn serialize_expr( } else { None }; - let partition_by = partition_by - .iter() - .map(|e| serialize_expr(e, codec)) - .collect::, _>>()?; - let order_by = order_by - .iter() - .map(|e| serialize_expr(e, codec)) - .collect::, _>>()?; + let partition_by = serialize_exprs(partition_by, codec)?; + let order_by = serialize_exprs(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); @@ -740,20 +735,14 @@ pub fn serialize_expr( let aggregate_expr = protobuf::AggregateExprNode { aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| serialize_expr(v, codec)) - .collect::, _>>()?, + expr: serialize_exprs(args, codec)?, distinct: *distinct, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e, codec)?)), None => None, }, order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, _>>()?, + Some(e) => serialize_exprs(e, codec)?, None => vec![], }, }; @@ -765,19 +754,13 @@ pub fn serialize_expr( expr_type: Some(ExprType::AggregateUdfExpr(Box::new( protobuf::AggregateUdfExprNode { fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + args: serialize_exprs(args, codec)?, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), None => None, }, order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, _>>()?, + Some(e) => serialize_exprs(e, codec)?, None => vec![], }, }, @@ -797,10 +780,7 @@ pub fn serialize_expr( )) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let args = args - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?; + let args = serialize_exprs(args, codec)?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { let fun: protobuf::ScalarFunction = fun.try_into()?; @@ -993,10 +973,7 @@ pub fn serialize_expr( } Expr::Unnest(Unnest { exprs }) => { let expr = protobuf::Unnest { - exprs: exprs - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + exprs: serialize_exprs(exprs, codec)?, }; protobuf::LogicalExprNode { expr_type: Some(ExprType::Unnest(expr)), @@ -1009,10 +986,7 @@ pub fn serialize_expr( }) => { let expr = Box::new(protobuf::InListNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - list: list - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + list: serialize_exprs(list, codec)?, negated: *negated, }); protobuf::LogicalExprNode { @@ -1073,18 +1047,12 @@ pub fn serialize_expr( Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { - expr: exprs - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + expr: serialize_exprs(exprs, codec)?, })), }, Expr::GroupingSet(GroupingSet::Rollup(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Rollup(RollupNode { - expr: exprs - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + expr: serialize_exprs(exprs, codec)?, })), }, Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => { @@ -1094,10 +1062,7 @@ pub fn serialize_expr( .iter() .map(|expr_list| { Ok(LogicalExprList { - expr: expr_list - 
.iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + expr: serialize_exprs(expr_list, codec)?, }) }) .collect::, Error>>()?, @@ -1402,6 +1367,34 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }; Ok(protobuf::ScalarValue { value: Some(value) }) } + + ScalarValue::Union(val, df_fields, mode) => { + let mut fields = Vec::::with_capacity(df_fields.len()); + for (id, field) in df_fields.iter() { + let field_id = id as i32; + let field = Some(field.as_ref().try_into()?); + let field = UnionField { field_id, field }; + fields.push(field); + } + let mode = match mode { + UnionMode::Sparse => 0, + UnionMode::Dense => 1, + }; + let value = match val { + None => None, + Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)), + }; + let val = UnionValue { + value_id: val.as_ref().map(|(id, _v)| *id as i32).unwrap_or(0), + value, + fields, + mode, + }; + let val = Value::UnionValue(Box::new(val)); + let val = protobuf::ScalarValue { value: Some(val) }; + Ok(val) + } + ScalarValue::Dictionary(index_type, val) => { let value: protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { @@ -1429,68 +1422,28 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Sinh => Self::Sinh, BuiltinScalarFunction::Cosh => Self::Cosh, - BuiltinScalarFunction::Atan => Self::Atan, - BuiltinScalarFunction::Asinh => Self::Asinh, - BuiltinScalarFunction::Acosh => Self::Acosh, - BuiltinScalarFunction::Atanh => Self::Atanh, BuiltinScalarFunction::Exp => Self::Exp, BuiltinScalarFunction::Factorial => Self::Factorial, BuiltinScalarFunction::Gcd => Self::Gcd, BuiltinScalarFunction::Lcm => Self::Lcm, BuiltinScalarFunction::Log => Self::Log, - BuiltinScalarFunction::Ln => Self::Ln, - BuiltinScalarFunction::Log10 => Self::Log10, BuiltinScalarFunction::Degrees => Self::Degrees, BuiltinScalarFunction::Radians => Self::Radians, BuiltinScalarFunction::Floor => Self::Floor, BuiltinScalarFunction::Ceil => Self::Ceil, BuiltinScalarFunction::Round => Self::Round, BuiltinScalarFunction::Trunc => Self::Trunc, - BuiltinScalarFunction::OctetLength => Self::OctetLength, BuiltinScalarFunction::Concat => Self::Concat, - BuiltinScalarFunction::Lower => Self::Lower, - BuiltinScalarFunction::Upper => Self::Upper, - BuiltinScalarFunction::Trim => Self::Trim, - BuiltinScalarFunction::Ltrim => Self::Ltrim, - BuiltinScalarFunction::Rtrim => Self::Rtrim, - BuiltinScalarFunction::ArrayReplace => Self::ArrayReplace, - BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, - BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, - BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, - BuiltinScalarFunction::Ascii => Self::Ascii, - BuiltinScalarFunction::BitLength => Self::BitLength, - BuiltinScalarFunction::Btrim => Self::Btrim, - BuiltinScalarFunction::CharacterLength => Self::CharacterLength, - BuiltinScalarFunction::Chr => Self::Chr, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, - BuiltinScalarFunction::Left => Self::Left, - BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Uuid => Self::Uuid, - BuiltinScalarFunction::Repeat => Self::Repeat, - BuiltinScalarFunction::Replace => Self::Replace, - BuiltinScalarFunction::Reverse => Self::Reverse, - BuiltinScalarFunction::Right => 
Self::Right, - BuiltinScalarFunction::Rpad => Self::Rpad, - BuiltinScalarFunction::SplitPart => Self::SplitPart, - BuiltinScalarFunction::StartsWith => Self::StartsWith, - BuiltinScalarFunction::Strpos => Self::Strpos, - BuiltinScalarFunction::Substr => Self::Substr, - BuiltinScalarFunction::ToHex => Self::ToHex, - BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, BuiltinScalarFunction::Power => Self::Power, - BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, - BuiltinScalarFunction::OverLay => Self::OverLay, - BuiltinScalarFunction::Levenshtein => Self::Levenshtein, - BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, - BuiltinScalarFunction::FindInSet => Self::FindInSet, }; Ok(scalar_function) @@ -1669,3 +1622,16 @@ fn encode_scalar_nested_value( _ => unreachable!(), } } + +/// Converts a vector of `Arc`s to `protobuf::Field`s +fn convert_arc_fields_to_proto_fields<'a, I>( + fields: I, +) -> Result, Error> +where + I: IntoIterator>, +{ + fields + .into_iter() + .map(|field| field.as_ref().try_into()) + .collect::, Error>>() +} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 184c048c1bdd..aaca4dc48236 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -59,9 +59,12 @@ use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; use chrono::{TimeZone, Utc}; +use datafusion_expr::ScalarFunctionDefinition; use object_store::path::Path; use object_store::ObjectMeta; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { Column::new(&c.name, c.index as usize) @@ -80,9 +83,10 @@ pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -105,20 +109,12 @@ pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { proto .iter() .map(|sort_expr| { - if let Some(expr) = &sort_expr.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; - let options = SortOptions { - descending: !sort_expr.asc, - nulls_first: sort_expr.nulls_first, - }; - Ok(PhysicalSortExpr { expr, options }) - } else { - Err(proto_error("Unexpected empty physical expression")) - } + parse_physical_sort_expr(sort_expr, registry, input_schema, codec) }) .collect::>>() } @@ -137,23 +133,15 @@ pub fn parse_physical_window_expr( registry: &dyn FunctionRegistry, input_schema: &Schema, ) -> Result> { - let window_node_expr = proto - .args - .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) - .collect::>>()?; + let codec = DefaultPhysicalExtensionCodec {}; + let window_node_expr = + parse_physical_exprs(&proto.args, registry, input_schema, &codec)?; - let partition_by = proto - 
.partition_by - .iter() - .map(|p| parse_physical_expr(p, registry, input_schema)) - .collect::>>()?; + let partition_by = + parse_physical_exprs(&proto.partition_by, registry, input_schema, &codec)?; - let order_by = proto - .order_by - .iter() - .map(|o| parse_physical_sort_expr(o, registry, input_schema)) - .collect::>>()?; + let order_by = + parse_physical_sort_exprs(&proto.order_by, registry, input_schema, &codec)?; let window_frame = proto .window_frame @@ -179,6 +167,21 @@ pub fn parse_physical_window_expr( ) } +pub fn parse_physical_exprs<'a, I>( + protos: I, + registry: &dyn FunctionRegistry, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, +) -> Result>> +where + I: IntoIterator, +{ + protos + .into_iter() + .map(|p| parse_physical_expr(p, registry, input_schema, codec)) + .collect::>>() +} + /// Parses a physical expression from a protobuf. /// /// # Arguments @@ -191,6 +194,7 @@ pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { let expr_type = proto .expr_type @@ -268,17 +272,14 @@ pub fn parse_physical_expr( "expr", input_schema, )?, - e.list - .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) - .collect::, _>>()?, + parse_physical_exprs(&e.list, registry, input_schema, codec)?, &e.negated, input_schema, )?, ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, e.when_then_expr .iter() @@ -301,7 +302,7 @@ pub fn parse_physical_expr( .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -331,11 +332,7 @@ pub fn parse_physical_expr( ) })?; - let args = e - .args - .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) - .collect::, _>>()?; + let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?; // TODO Do not create new the ExecutionProps let execution_props = ExecutionProps::new(); @@ -348,19 +345,18 @@ pub fn parse_physical_expr( )? } ExprType::ScalarUdf(e) => { - let udf = registry.udf(e.name.as_str())?; + let udf = match &e.fun_definition { + Some(buf) => codec.try_decode_udf(&e.name, buf)?, + None => registry.udf(e.name.as_str())?, + }; let signature = udf.signature(); - let scalar_fun = udf.fun().clone(); + let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone()); - let args = e - .args - .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) - .collect::, _>>()?; + let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?; Arc::new(ScalarFunctionExpr::new( e.name.as_str(), - scalar_fun, + scalar_fun_def, args, convert_required!(e.return_type)?, None, @@ -394,7 +390,8 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, ) -> Result> { - expr.map(|e| parse_physical_expr(e, registry, input_schema)) + let codec = DefaultPhysicalExtensionCodec {}; + expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .transpose()? 
.ok_or_else(|| { DataFusionError::Internal(format!("Missing required field {field:?}")) @@ -439,11 +436,13 @@ pub fn parse_protobuf_hash_partitioning( ) -> Result> { match partitioning { Some(hash_part) => { - let expr = hash_part - .hash_expr - .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) - .collect::>, _>>()?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = parse_physical_exprs( + &hash_part.hash_expr, + registry, + input_schema, + &codec, + )?; Ok(Some(Partitioning::Hash( expr, @@ -503,24 +502,13 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { - let sort_expr = node_collection - .physical_sort_expr_nodes - .iter() - .map(|node| { - let expr = node - .expr - .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, &schema)) - .unwrap()?; - Ok(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !node.asc, - nulls_first: node.nulls_first, - }, - }) - }) - .collect::>>()?; + let codec = DefaultPhysicalExtensionCodec {}; + let sort_expr = parse_physical_sort_exprs( + &node_collection.physical_sort_expr_nodes, + registry, + &schema, + &codec, + )?; output_ordering.push(sort_expr); } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 004948da938f..00dacffe06c2 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -20,6 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use self::from_proto::parse_physical_window_expr; +use self::to_proto::serialize_physical_expr; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::convert_required; @@ -47,7 +48,7 @@ use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; -use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; @@ -138,7 +139,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr(expr, registry, input.schema().as_ref())?, + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + )?, name.to_string(), )) }) @@ -156,7 +162,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .transpose()? 
.ok_or_else(|| { @@ -208,6 +219,7 @@ impl AsExecutionPlan for PhysicalPlanNode { expr, registry, base_config.file_schema.as_ref(), + extension_codec, ) }) .transpose()?; @@ -254,7 +266,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .hash_expr .iter() .map(|e| { - parse_physical_expr(e, registry, input.schema().as_ref()) + parse_physical_expr( + e, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>, _>>()?; @@ -329,7 +346,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>>>()?; @@ -396,8 +418,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -406,8 +433,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -434,7 +466,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| { expr.expr .as_ref() - .map(|e| parse_physical_expr(e, registry, &physical_schema)) + .map(|e| { + parse_physical_expr( + e, + registry, + &physical_schema, + extension_codec, + ) + }) .transpose() }) .collect::, _>>()?; @@ -451,9 +490,9 @@ impl AsExecutionPlan for PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect(); + .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); let ordering_req: Vec = agg_node.ordering_req.iter() - .map(|e| parse_physical_sort_expr(e, registry, &physical_schema).unwrap()).collect(); + .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); agg_node.aggregate_function.as_ref().map(|func| { match func { AggregateFunction::AggrFunction(i) => { @@ -524,11 +563,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -555,6 +596,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -635,11 +677,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -666,6 +710,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -691,6 +736,7 @@ impl AsExecutionPlan for PhysicalPlanNode 
{ &sym_join.left_sort_exprs, registry, &left_schema, + extension_codec, )?; let left_sort_exprs = if left_sort_exprs.is_empty() { None @@ -702,6 +748,7 @@ impl AsExecutionPlan for PhysicalPlanNode { &sym_join.right_sort_exprs, registry, &right_schema, + extension_codec, )?; let right_sort_exprs = if right_sort_exprs.is_empty() { None @@ -805,7 +852,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -852,7 +899,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -916,6 +963,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -972,14 +1020,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .sort_order .as_ref() .map(|collection| { - collection - .physical_sort_expr_nodes - .iter() - .map(|proto| { - parse_physical_sort_expr(proto, registry, &sink_schema) - .map(Into::into) - }) - .collect::>>() + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + registry, + &sink_schema, + extension_codec, + ) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) }) .transpose()?; Ok(Arc::new(FileSinkExec::new( @@ -1003,14 +1050,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .sort_order .as_ref() .map(|collection| { - collection - .physical_sort_expr_nodes - .iter() - .map(|proto| { - parse_physical_sort_expr(proto, registry, &sink_schema) - .map(Into::into) - }) - .collect::>>() + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + registry, + &sink_schema, + extension_codec, + ) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) }) .transpose()?; Ok(Arc::new(FileSinkExec::new( @@ -1034,14 +1080,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .sort_order .as_ref() .map(|collection| { - collection - .physical_sort_expr_nodes - .iter() - .map(|proto| { - parse_physical_sort_expr(proto, registry, &sink_schema) - .map(Into::into) - }) - .collect::>>() + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + registry, + &sink_schema, + extension_codec, + ) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) }) .transpose()?; Ok(Arc::new(FileSinkExec::new( @@ -1088,7 +1133,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| expr.0.clone().try_into()) + .map(|expr| serialize_physical_expr(expr.0.clone(), extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1128,7 +1173,10 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(exec.predicate().clone().try_into()?), + expr: Some(serialize_physical_expr( + exec.predicate().clone(), + extension_codec, + )?), default_filter_selectivity: exec.default_selectivity() as u32, }, ))), @@ -1183,8 +1231,8 @@ impl AsExecutionPlan for 
PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1196,7 +1244,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1254,8 +1305,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1267,7 +1318,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1304,7 +1358,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1321,7 +1378,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1423,14 +1483,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .group_expr() .null_expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1512,7 +1572,7 @@ impl AsExecutionPlan for PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| pred.clone().try_into()) + .map(|pred| serialize_physical_expr(pred.clone(), extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1559,7 +1619,9 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.clone().try_into()) + .map(|expr| { + serialize_physical_expr(expr.clone(), extension_codec) + }) .collect::>>()?, partition_count: *partition_count as u64, }) @@ -1592,7 +1654,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: 
!expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1658,7 +1723,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1695,7 +1763,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1743,7 +1814,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1773,7 +1844,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1816,7 +1887,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ba77b30b7f8d..e1574f48fb8e 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,7 +22,6 @@ use std::{ sync::Arc, }; -use crate::logical_plan::csv_writer_options_to_proto; use crate::protobuf::{ self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, @@ -31,13 +30,10 @@ use crate::protobuf::{ #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::datasource::{ - file_format::csv::CsvSink, - file_format::json::JsonSink, - listing::{FileRange, PartitionedFile}, - physical_plan::FileScanConfig, - physical_plan::FileSinkConfig, -}; + +use datafusion_expr::ScalarFunctionDefinition; + +use crate::logical_plan::csv_writer_options_to_proto; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; @@ -46,16 +42,24 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, - Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, - Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, - TryCastExpr, Variance, VariancePop, WindowShift, + InListExpr, IsNotNullExpr, IsNullExpr, LastValue, 
Literal, Max, Median, Min, + NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, + RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, + Variance, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::{ + datasource::{ + file_format::{csv::CsvSink, json::JsonSink}, + listing::{FileRange, PartitionedFile}, + physical_plan::{FileScanConfig, FileSinkConfig}, + }, + physical_plan::expressions::LikeExpr, +}; use datafusion_common::config::{ ColumnOptions, CsvOptions, FormatOptions, JsonOptions, ParquetOptions, TableParquetOptions, @@ -68,22 +72,17 @@ use datafusion_common::{ DataFusionError, JoinSide, Result, }; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; fn try_from(a: Arc) -> Result { - let expressions: Vec = a - .expressions() - .iter() - .map(|e| e.clone().try_into()) - .collect::>>()?; + let codec = DefaultPhysicalExtensionCodec {}; + let expressions = serialize_physical_exprs(a.expressions(), &codec)?; - let ordering_req: Vec = a - .order_bys() - .unwrap_or(&[]) - .iter() - .map(|e| e.clone().try_into()) - .collect::>>()?; + let ordering_req = a.order_bys().unwrap_or(&[]).to_vec(); + let ordering_req = serialize_physical_sort_exprs(ordering_req, &codec)?; if let Some(a) = a.as_any().downcast_ref::() { let name = a.fun().name().to_string(); @@ -237,23 +236,13 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; + let codec = DefaultPhysicalExtensionCodec {}; + let args = serialize_physical_exprs(args, &codec)?; + let partition_by = + serialize_physical_exprs(window_expr.partition_by().to_vec(), &codec)?; - let args = args - .into_iter() - .map(|e| e.try_into()) - .collect::>>()?; - - let partition_by = window_expr - .partition_by() - .iter() - .map(|p| p.clone().try_into()) - .collect::>>()?; - - let order_by = window_expr - .order_by() - .iter() - .map(|o| o.clone().try_into()) - .collect::>>()?; + let order_by = + serialize_physical_sort_exprs(window_expr.order_by().to_vec(), &codec)?; let window_frame: protobuf::WindowFrame = window_frame .as_ref() @@ -374,195 +363,274 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } -impl TryFrom> for protobuf::PhysicalExprNode { - type Error = DataFusionError; +pub fn serialize_physical_sort_exprs( + sort_exprs: I, + codec: &dyn PhysicalExtensionCodec, +) -> Result, DataFusionError> +where + I: IntoIterator, +{ + sort_exprs + .into_iter() + .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec)) + .collect() +} - fn try_from(value: Arc) -> Result { - let expr = value.as_any(); +pub fn serialize_physical_sort_expr( + sort_expr: PhysicalSortExpr, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let PhysicalSortExpr { expr, options } = sort_expr; + let expr = serialize_physical_expr(expr, codec)?; + Ok(PhysicalSortExprNode { + expr: Some(Box::new(expr)), + asc: !options.descending, + nulls_first: options.nulls_first, + }) +} - if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Column( - protobuf::PhysicalColumn { - name: 
expr.name().to_string(), - index: expr.index() as u32, - }, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(expr.left().to_owned().try_into()?)), - r: Some(Box::new(expr.right().to_owned().try_into()?)), - op: format!("{:?}", expr.op()), - }); +pub fn serialize_physical_exprs( + values: I, + codec: &dyn PhysicalExtensionCodec, +) -> Result, DataFusionError> +where + I: IntoIterator>, +{ + values + .into_iter() + .map(|value| serialize_physical_expr(value, codec)) + .collect() +} - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( - binary_expr, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::Case( - Box::new( - protobuf::PhysicalCaseNode { - expr: expr - .expr() - .map(|exp| exp.clone().try_into().map(Box::new)) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr) - }) - .collect::, - Self::Error, - >>()?, - else_expr: expr - .else_expr() - .map(|a| a.clone().try_into().map(Box::new)) - .transpose()?, - }, - ), - ), - ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr( - Box::new(protobuf::PhysicalNot { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( - Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::InList( - Box::new( - protobuf::PhysicalInListNode { - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - list: expr - .list() - .iter() - .map(|a| a.clone().try_into()) - .collect::, - Self::Error, +/// Serialize a `PhysicalExpr` to default protobuf representation. 
+/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]) +pub fn serialize_physical_expr( + value: Arc, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let expr = value.as_any(); + + if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: expr.name().to_string(), + index: expr.index() as u32, + }, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { + l: Some(Box::new(serialize_physical_expr( + expr.left().clone(), + codec, + )?)), + r: Some(Box::new(serialize_physical_expr( + expr.right().clone(), + codec, + )?)), + op: format!("{:?}", expr.op()), + }); + + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( + binary_expr, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::Case( + Box::new( + protobuf::PhysicalCaseNode { + expr: expr + .expr() + .map(|exp| { + serialize_physical_expr(exp.clone(), codec) + .map(Box::new) + }) + .transpose()?, + when_then_expr: expr + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + try_parse_when_then_expr(when_expr, then_expr, codec) + }) + .collect::, + DataFusionError, >>()?, - negated: expr.negated(), - }, - ), + else_expr: expr + .else_expr() + .map(|a| { + serialize_physical_expr(a.clone(), codec) + .map(Box::new) + }) + .transpose()?, + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Negative( - Box::new(protobuf::PhysicalNegativeNode { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( - lit.value().try_into()?, - )), - }) - } else if let Some(cast) = expr.downcast_ref::() { + ), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( + protobuf::PhysicalNot { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( + protobuf::PhysicalInListNode { + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + list: serialize_physical_exprs(expr.list().to_vec(), codec)?, + negated: expr.negated(), + }, + ))), + }) 
+ } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( + protobuf::PhysicalNegativeNode { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(lit) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( + lit.value().try_into()?, + )), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let args = serialize_physical_exprs(expr.args().to_vec(), codec)?; + if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { + let fun: protobuf::ScalarFunction = (&fun).try_into()?; + Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( - protobuf::PhysicalCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction( + protobuf::PhysicalScalarFunctionNode { + name: expr.name().to_string(), + fun: fun.into(), + args, + return_type: Some(expr.return_type().try_into()?), }, - ))), - }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast( - Box::new(protobuf::PhysicalTryCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), - }), )), }) - } else if let Some(expr) = expr.downcast_ref::() { - let args: Vec = expr - .args() - .iter() - .map(|e| e.to_owned().try_into()) - .collect::, _>>()?; - if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { - let fun: protobuf::ScalarFunction = (&fun).try_into()?; - - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::ScalarFunction( - protobuf::PhysicalScalarFunctionNode { - name: expr.name().to_string(), - fun: fun.into(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - ), - ), - }) - } else { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( - protobuf::PhysicalScalarUdfNode { - name: expr.name().to_string(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - )), - }) + } else { + let mut buf = Vec::new(); + match expr.fun() { + ScalarFunctionDefinition::UDF(udf) => { + codec.try_encode_udf(udf, &mut buf)?; + } + _ => { + return not_impl_err!( + "Proto serialization error: Trying to serialize a unresolved function" + ); + } } - } else if let Some(expr) = expr.downcast_ref::() { + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; Ok(protobuf::PhysicalExprNode { - expr_type: 
Some(protobuf::physical_expr_node::ExprType::LikeExpr( - Box::new(protobuf::PhysicalLikeExprNode { - negated: expr.negated(), - case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - pattern: Some(Box::new(expr.pattern().to_owned().try_into()?)), - }), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( + protobuf::PhysicalScalarUdfNode { + name: expr.name().to_string(), + args, + fun_definition, + return_type: Some(expr.return_type().try_into()?), + }, )), }) - } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( + protobuf::PhysicalLikeExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( + expr.pattern().to_owned(), + codec, + )?)), + }, + ))), + }) + } else { + internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } } fn try_parse_when_then_expr( when_expr: &Arc, then_expr: &Arc, + codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(when_expr.clone().try_into()?), - then_expr: Some(then_expr.clone().try_into()?), + when_expr: Some(serialize_physical_expr(when_expr.clone(), codec)?), + then_expr: Some(serialize_physical_expr(then_expr.clone(), codec)?), }) } @@ -683,6 +751,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { fn try_from( conf: &FileScanConfig, ) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let file_groups = conf .file_groups .iter() @@ -691,18 +760,8 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let mut output_orderings = vec![]; for order in &conf.output_ordering { - let expr_node_vec = order - .iter() - .map(|sort_expr| { - let expr = sort_expr.expr.clone().try_into()?; - Ok(PhysicalSortExprNode { - expr: Some(Box::new(expr)), - asc: !sort_expr.options.descending, - nulls_first: sort_expr.options.nulls_first, - }) - }) - .collect::>>()?; - output_orderings.push(expr_node_vec) + let ordering = serialize_physical_sort_exprs(order.to_vec(), &codec)?; + output_orderings.push(ordering) } // Fields must be added to the schema so that they can persist in the protobuf @@ -757,10 +816,11 @@ impl TryFrom>> for protobuf::MaybeFilter { type Error = DataFusionError; fn try_from(expr: Option>) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(expr.try_into()?), + expr: Some(serialize_physical_expr(expr, &codec)?), }), } } @@ -786,8 +846,9 @@ impl TryFrom for protobuf::PhysicalSortExprNode { type Error = DataFusionError; fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { + let codec = DefaultPhysicalExtensionCodec {}; Ok(PhysicalSortExprNode { - expr: Some(Box::new(sort_expr.expr.try_into()?)), + expr: Some(Box::new(serialize_physical_expr(sort_expr.expr, &codec)?)), asc: !sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, }) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 479f80fbdddf..3a47f556c0f3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -34,8 +34,8 @@ use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DFField, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, + internal_datafusion_err, internal_err, not_impl_err, plan_err, DFField, DFSchema, + DFSchemaRef, DataFusionError, FileType, Result, ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -44,8 +44,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - col, create_udaf, lit, Accumulator, AggregateFunction, - BuiltinScalarFunction::{Sqrt, Substr}, + col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, ColumnarValue, Expr, ExprSchemable, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, @@ -60,6 +59,7 @@ use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; use datafusion_proto::protobuf; +use datafusion::execution::FunctionRegistry; use prost::Message; #[cfg(feature = "json")] @@ -314,10 +314,9 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let ctx = SessionContext::new(); let input = create_csv_scan(&ctx).await?; - - let mut table_options = - TableOptions::default_from_session_config(ctx.state().config_options()); - table_options.set("csv.delimiter", ";")?; + let mut table_options = ctx.copied_table_options(); + table_options.set_file_format(FileType::CSV); + table_options.set("format.delimiter", ";")?; let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), @@ -605,6 +604,14 @@ async fn roundtrip_expr_api() -> Result<()> { make_array(vec![lit(3), lit(3), lit(2), lit(3), lit(1)]), lit(3), ), + array_replace(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + array_replace_n( + make_array(vec![lit(1), lit(2), lit(3)]), + lit(2), + lit(4), + lit(1), + ), + array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), ]; // ensure expressions created with the expr api can be round tripped @@ -1856,17 +1863,28 @@ fn roundtrip_cube() { #[test] fn roundtrip_substr() { + let ctx = SessionContext::new(); + + let fun = ctx + .state() + .udf("substr") + .map_err(|e| { + internal_datafusion_err!("Unable to find expected 'substr' function: {e:?}") + }) + .unwrap(); + // substr(string, position) - let test_expr = - Expr::ScalarFunction(ScalarFunction::new(Substr, vec![col("col"), lit(1_i64)])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + fun.clone(), + vec![col("col"), lit(1_i64)], + )); // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new( - Substr, + let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new_udf( + fun, vec![col("col"), lit(1_i64), lit(1_i64)], )); - let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx.clone()); roundtrip_expr_test(test_expr_with_count, ctx); } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7f0c6286a19d..4924128ae190 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -32,7 +33,7 @@ use datafusion::datasource::physical_plan::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, FileSinkConfig, ParquetExec, }; -use datafusion::execution::context::ExecutionProps; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; @@ -49,7 +50,6 @@ use datafusion::physical_plan::expressions::{ NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, }; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::functions; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, @@ -73,13 +73,19 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature, - SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::physical_plan::{ + AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +use prost::Message; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. 
Note that this often isn't sufficient to guarantee that no information is @@ -603,14 +609,11 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); - let execution_props = ExecutionProps::new(); - - let fun_expr = - functions::create_physical_fun(&BuiltinScalarFunction::Sin, &execution_props)?; + let fun_def = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Sin); let expr = ScalarFunctionExpr::new( "sin", - fun_expr, + fun_def, vec![col("a", &schema)?], DataType::Float64, None, @@ -646,9 +649,11 @@ fn roundtrip_scalar_udf() -> Result<()> { scalar_fn.clone(), ); + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone())); + let expr = ScalarFunctionExpr::new( "dummy", - scalar_fn, + fun_def, vec![col("a", &schema)?], DataType::Int64, None, @@ -665,6 +670,134 @@ fn roundtrip_scalar_udf() -> Result<()> { roundtrip_test_with_context(Arc::new(project), ctx) } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + #[derive(Debug)] + struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, + } + + impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } + } + + /// Implement the ScalarUDFImpl trait for MyRegexUdf + impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } + + #[derive(Debug)] + pub struct ScalarUDFExtensionCodec {} + + impl PhysicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!("No extension codec provided") + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode regex_udf: {}", + err + )) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + } + Ok(()) + } + } + + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = ScalarFunctionExpr::new( + udf.name(), + ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), + vec![], + DataType::Int32, + None, + false, + ); + let fmt_expr = format!("{test_expr:?}"); + let ctx = SessionContext::new(); + + ctx.register_udf(udf.clone()); + let extension_codec = 
ScalarUDFExtensionCodec {}; + let proto: protobuf::PhysicalExprNode = + match serialize_physical_expr(Arc::new(test_expr), &extension_codec) { + Ok(proto) => proto, + Err(e) => panic!("failed to serialize expr: {e:?}"), + }; + let field_a = Field::new("a", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![field_a])); + let round_trip = + parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap(); + assert_eq!(fmt_expr, format!("{round_trip:?}")); +} #[test] fn roundtrip_distinct_count() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d4a1ab44a6ea..972382b841d5 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -260,10 +260,7 @@ fn test_expression_serialization_roundtrip() { let lit = Expr::Literal(ScalarValue::Utf8(None)); for builtin_fun in BuiltinScalarFunction::iter() { // default to 4 args (though some exprs like substr have error checking) - let num_args = match builtin_fun { - BuiltinScalarFunction::Substr => 3, - _ => 4, - }; + let num_args = 4; let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args)); diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index ca2c1a240c21..b9f6dc259eb7 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -49,6 +49,7 @@ strum = { version = "0.26.1", features = ["derive"] } [dev-dependencies] ctor = { workspace = true } +datafusion-functions = { workspace = true, default-features = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ffc951a6fa66..582404b29749 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -34,8 +34,6 @@ use sqlparser::ast::{ use std::str::FromStr; use strum::IntoEnumIterator; -use super::arrow_cast::ARROW_CAST_NAME; - /// Suggest a valid function based on an invalid input function name pub fn suggest_valid_function( input_function_name: &str, @@ -249,12 +247,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { null_treatment, ))); }; - - // Special case arrow_cast (as its type is dependent on its argument value) - if name == ARROW_CAST_NAME { - let args = self.function_args_to_expr(args, schema, planner_context)?; - return super::arrow_cast::create_arrow_cast(args, schema); - } } // Could not find the relevant function, so return an error diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a6f1c78c7250..064578ad51d6 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,21 +15,11 @@ // specific language governing permissions and limitations // under the License. 
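The try_encode_udf / try_decode_udf hooks exercised expression-by-expression in roundtrip_scalar_udf_extension_codec above can also be threaded through whole-plan serialization. A minimal sketch, reusing the test's ScalarUDFExtensionCodec and assuming the existing bytes helpers in datafusion_proto::bytes (the function name here is illustrative):

use std::sync::Arc;

use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::Result;
use datafusion_proto::bytes::{
    physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
};

// Round-trip an arbitrary physical plan through protobuf, letting the codec
// handle any ScalarUDF (such as MyRegexUdf) embedded in the plan.
// `ScalarUDFExtensionCodec` refers to the codec defined in the test above.
fn roundtrip_plan_with_udf_codec(
    plan: Arc<dyn ExecutionPlan>,
    ctx: &SessionContext,
) -> Result<()> {
    let codec = ScalarUDFExtensionCodec {};
    // serialize_physical_expr calls codec.try_encode_udf for every UDF-backed
    // ScalarFunctionExpr, storing the encoded bytes in `fun_definition`.
    let bytes = physical_plan_to_bytes_with_extension_codec(plan.clone(), &codec)?;
    // On the way back, codec.try_decode_udf rebuilds the ScalarUDF from those bytes.
    let restored = physical_plan_from_bytes_with_extension_codec(&bytes, ctx, &codec)?;
    assert_eq!(format!("{plan:?}"), format!("{restored:?}"));
    Ok(())
}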
-pub(crate) mod arrow_cast; -mod binary_op; -mod function; -mod grouping_set; -mod identifier; -mod json_access; -mod order_by; -mod subquery; -mod substring; -mod unary_op; -mod value; - -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; use arrow_schema::TimeUnit; +use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; +use sqlparser::parser::ParserError::ParserError; + use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, DFSchema, Result, ScalarValue, @@ -39,10 +29,22 @@ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, - Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, + Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Literal, Operator, + TryCast, }; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; -use sqlparser::parser::ParserError::ParserError; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + +mod binary_op; +mod function; +mod grouping_set; +mod identifier; +mod json_access; +mod order_by; +mod subquery; +mod substring; +mod unary_op; +mod value; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn sql_expr_to_logical_expr( @@ -603,18 +605,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let args = values .into_iter() - .map(|value| { - self.sql_expr_to_logical_expr(value, input_schema, planner_context) + .enumerate() + .map(|(i, value)| { + let args = if let SQLExpr::Named { expr, name } = value { + [ + name.value.lit(), + self.sql_expr_to_logical_expr( + *expr, + input_schema, + planner_context, + )?, + ] + } else { + [ + format!("c{i}").lit(), + self.sql_expr_to_logical_expr( + value, + input_schema, + planner_context, + )?, + ] + }; + + Ok(args) }) - .collect::>>()?; - let struct_func = self + .collect::>>()? 
+ .into_iter() + .flatten() + .collect(); + + let named_struct_func = self .context_provider - .get_function_meta("struct") + .get_function_meta("named_struct") .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'struct' function") - })?; + internal_datafusion_err!("Unable to find expected 'named_struct' function") + })?; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - struct_func, + named_struct_func, args, ))) } @@ -744,13 +772,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = match trim_where { - Some(TrimWhereField::Leading) => BuiltinScalarFunction::Ltrim, - Some(TrimWhereField::Trailing) => BuiltinScalarFunction::Rtrim, - Some(TrimWhereField::Both) => BuiltinScalarFunction::Btrim, - None => BuiltinScalarFunction::Trim, - }; - let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; let args = match (trim_what, trim_characters) { (Some(to_trim), None) => { @@ -775,7 +796,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } (None, None) => Ok(vec![arg]), }?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + + let fun_name = match trim_where { + Some(TrimWhereField::Leading) => "ltrim", + Some(TrimWhereField::Trailing) => "rtrim", + Some(TrimWhereField::Both) => "btrim", + None => "trim", + }; + let fun = self + .context_provider + .get_function_meta(fun_name) + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected '{fun_name}' function") + })?; + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_overlay_to_expr( @@ -787,7 +822,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = BuiltinScalarFunction::OverLay; + let fun = self + .context_provider + .get_function_meta("overlay") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'overlay' function") + })?; let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; let what_arg = self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; @@ -801,7 +841,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => vec![arg, what_arg, from_arg], }; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_position_to_expr( &self, @@ -810,12 +850,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = BuiltinScalarFunction::Strpos; + let fun = self + .context_provider + .get_function_meta("strpos") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'strpos' function") + })?; let substr = self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; let args = vec![fullstr, substr]; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_agg_with_filter_to_expr( &self, diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index a5d1abf0f265..f58c6f3b94d0 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -16,10 +16,10 @@ // under the License. 
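The substring change just below follows the same pattern as the trim, overlay, and strpos rewrites above: instead of hard-coding a BuiltinScalarFunction variant, the planner resolves the function by name through the ContextProvider and wraps the planned arguments in a UDF call. A condensed sketch of that shared shape (plan_named_call is an illustrative helper, not part of this change):

use std::sync::Arc;

use datafusion_common::{internal_datafusion_err, Result};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{Expr, ScalarUDF};
use datafusion_sql::planner::ContextProvider;

// Resolve `fun_name` through the registry-backed ContextProvider and wrap the
// already-planned arguments in a UDF call expression.
fn plan_named_call<P: ContextProvider>(
    provider: &P,
    fun_name: &str,
    args: Vec<Expr>,
) -> Result<Expr> {
    let fun: Arc<ScalarUDF> = provider.get_function_meta(fun_name).ok_or_else(|| {
        internal_datafusion_err!("Unable to find expected '{fun_name}' function")
    })?;
    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args)))
}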
use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::plan_err; +use datafusion_common::{internal_datafusion_err, plan_err}; use datafusion_common::{DFSchema, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{BuiltinScalarFunction, Expr}; +use datafusion_expr::Expr; use sqlparser::ast::Expr as SQLExpr; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -68,9 +68,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - Ok(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::Substr, - args, - ))) + let fun = self + .context_provider + .get_function_meta("substr") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'substr' function") + })?; + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } } diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index e8e07eebe22d..12d6a4669634 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -42,5 +42,4 @@ pub mod utils; mod values; pub use datafusion_common::{ResolvedTableReference, TableReference}; -pub use expr::arrow_cast::parse_data_type; pub use sqlparser; diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index effc1d096cfd..67fa1325eea7 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -17,21 +17,20 @@ //! [`DFParser`]: DataFusion SQL Parser based on [`sqlparser`] +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::str::FromStr; + use datafusion_common::parsers::CompressionTypeVariant; -use sqlparser::ast::{OrderByExpr, Query, Value}; -use sqlparser::tokenizer::Word; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, ObjectName, Statement as SQLStatement, - TableConstraint, + ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, + Statement as SQLStatement, TableConstraint, Value, }, dialect::{keywords::Keyword, Dialect, GenericDialect}, parser::{Parser, ParserError}, - tokenizer::{Token, TokenWithLocation, Tokenizer}, + tokenizer::{Token, TokenWithLocation, Tokenizer, Word}, }; -use std::collections::VecDeque; -use std::fmt; -use std::{collections::HashMap, str::FromStr}; // Use `Parser::expected` instead, if possible macro_rules! parser_err { @@ -102,6 +101,12 @@ pub struct CopyToStatement { pub source: CopyToSource, /// The URL to where the data is heading pub target: String, + /// Partition keys + pub partitioned_by: Vec, + /// Indicates whether there is a header row (e.g. CSV) + pub has_header: bool, + /// File type (Parquet, NDJSON, CSV etc.) + pub stored_as: Option, /// Target specific options pub options: Vec<(String, Value)>, } @@ -111,15 +116,27 @@ impl fmt::Display for CopyToStatement { let Self { source, target, + partitioned_by, + stored_as, options, + .. 
} = self; write!(f, "COPY {source} TO {target}")?; + if let Some(file_type) = stored_as { + write!(f, " STORED AS {}", file_type)?; + } + if !partitioned_by.is_empty() { + write!(f, " PARTITIONED BY ({})", partitioned_by.join(", "))?; + } + + if self.has_header { + write!(f, " WITH HEADER ROW")?; + } if !options.is_empty() { let opts: Vec<_> = options.iter().map(|(k, v)| format!("{k} {v}")).collect(); - // print them in sorted order - write!(f, " ({})", opts.join(", "))?; + write!(f, " OPTIONS ({})", opts.join(", "))?; } Ok(()) @@ -158,7 +175,7 @@ pub(crate) type LexOrdering = Vec; /// [ WITH HEADER ROW ] /// [ DELIMITER ] /// [ COMPRESSION TYPE ] -/// [ PARTITIONED BY () ] +/// [ PARTITIONED BY ( | ) ] /// [ WITH ORDER () /// [ OPTIONS () ] /// LOCATION @@ -243,6 +260,15 @@ impl fmt::Display for Statement { } } +fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { + if field.is_some() { + return Err(ParserError::ParserError(format!( + "{name} specified more than once", + ))); + } + Ok(()) +} + /// Datafusion SQL Parser based on [`sqlparser`] /// /// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s [`Parser`]. @@ -252,7 +278,7 @@ impl fmt::Display for Statement { /// `CREATE EXTERNAL TABLE` have special syntax in DataFusion. See /// [`Statement`] for a list of this special syntax pub struct DFParser<'a> { - parser: Parser<'a>, + pub parser: Parser<'a>, } impl<'a> DFParser<'a> { @@ -370,21 +396,79 @@ impl<'a> DFParser<'a> { CopyToSource::Relation(table_name) }; - self.parser.expect_keyword(Keyword::TO)?; + #[derive(Default)] + struct Builder { + stored_as: Option, + target: Option, + partitioned_by: Option>, + has_header: Option, + options: Option>, + } - let target = self.parser.parse_literal_string()?; + let mut builder = Builder::default(); - // check for options in parens - let options = if self.parser.peek_token().token == Token::LParen { - self.parse_value_options()? 
- } else { - vec![] + loop { + if let Some(keyword) = self.parser.parse_one_of_keywords(&[ + Keyword::STORED, + Keyword::TO, + Keyword::PARTITIONED, + Keyword::OPTIONS, + Keyword::WITH, + ]) { + match keyword { + Keyword::STORED => { + self.parser.expect_keyword(Keyword::AS)?; + ensure_not_set(&builder.stored_as, "STORED AS")?; + builder.stored_as = Some(self.parse_file_format()?); + } + Keyword::TO => { + ensure_not_set(&builder.target, "TO")?; + builder.target = Some(self.parser.parse_literal_string()?); + } + Keyword::WITH => { + self.parser.expect_keyword(Keyword::HEADER)?; + self.parser.expect_keyword(Keyword::ROW)?; + ensure_not_set(&builder.has_header, "WITH HEADER ROW")?; + builder.has_header = Some(true); + } + Keyword::PARTITIONED => { + self.parser.expect_keyword(Keyword::BY)?; + ensure_not_set(&builder.partitioned_by, "PARTITIONED BY")?; + builder.partitioned_by = Some(self.parse_partitions()?); + } + Keyword::OPTIONS => { + ensure_not_set(&builder.options, "OPTIONS")?; + builder.options = Some(self.parse_value_options()?); + } + _ => { + unreachable!() + } + } + } else { + let token = self.parser.next_token(); + if token == Token::EOF || token == Token::SemiColon { + break; + } else { + return Err(ParserError::ParserError(format!( + "Unexpected token {token}" + ))); + } + } + } + + let Some(target) = builder.target else { + return Err(ParserError::ParserError( + "Missing TO clause in COPY statement".into(), + )); }; Ok(Statement::CopyTo(CopyToStatement { source, target, - options, + partitioned_by: builder.partitioned_by.unwrap_or(vec![]), + has_header: builder.has_header.unwrap_or(false), + stored_as: builder.stored_as, + options: builder.options.unwrap_or(vec![]), })) } @@ -609,7 +693,7 @@ impl<'a> DFParser<'a> { self.parser .parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parser.parse_object_name(true)?; - let (columns, constraints) = self.parse_columns()?; + let (mut columns, constraints) = self.parse_columns()?; #[derive(Default)] struct Builder { @@ -624,15 +708,6 @@ impl<'a> DFParser<'a> { } let mut builder = Builder::default(); - fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { - if field.is_some() { - return Err(ParserError::ParserError(format!( - "{name} specified more than once", - ))); - } - Ok(()) - } - loop { if let Some(keyword) = self.parser.parse_one_of_keywords(&[ Keyword::STORED, @@ -679,7 +754,30 @@ impl<'a> DFParser<'a> { Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; ensure_not_set(&builder.table_partition_cols, "PARTITIONED BY")?; - builder.table_partition_cols = Some(self.parse_partitions()?); + // Expects either list of column names (col_name [, col_name]*) + // or list of column definitions (col_name datatype [, col_name datatype]* ) + // use the token after the name to decide which parsing rule to use + // Note that mixing both names and definitions is not allowed + let peeked = self.parser.peek_nth_token(2); + if peeked == Token::Comma || peeked == Token::RParen { + // list of column names + builder.table_partition_cols = Some(self.parse_partitions()?) 
+ } else { + // list of column defs + let (cols, cons) = self.parse_columns()?; + builder.table_partition_cols = Some( + cols.iter().map(|col| col.name.to_string()).collect(), + ); + + columns.extend(cols); + + if !cons.is_empty() { + return Err(ParserError::ParserError( + "Constraints on Partition Columns are not supported" + .to_string(), + )); + } + } } Keyword::OPTIONS => { ensure_not_set(&builder.options, "OPTIONS")?; @@ -1092,9 +1190,37 @@ mod tests { }); expect_parse_ok(sql, expected)?; - // Error cases: partition column does not support type + // positive case: column definiton allowed in 'partition by' clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int) LOCATION 'foo.csv'"; + let expected = Statement::CreateExternalTable(CreateExternalTable { + name: "t".into(), + columns: vec![ + make_column_def("c1", DataType::Int(None)), + make_column_def("p1", DataType::Int(None)), + ], + file_type: "CSV".to_string(), + has_header: false, + delimiter: ',', + location: "foo.csv".into(), + table_partition_cols: vec!["p1".to_string()], + order_exprs: vec![], + if_not_exists: false, + file_compression_type: UNCOMPRESSED, + unbounded: false, + options: HashMap::new(), + constraints: vec![], + }); + expect_parse_ok(sql, expected)?; + + // negative case: mixed column defs and column names in `PARTITIONED BY` clause + let sql = + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int, c1) LOCATION 'foo.csv'"; + expect_parse_error(sql, "sql parser error: Expected a data type name, found: )"); + + // negative case: mixed column defs and column names in `PARTITIONED BY` clause + let sql = + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 int) LOCATION 'foo.csv'"; expect_parse_error(sql, "sql parser error: Expected ',' or ')' after partition definition, found: int"); // positive case: additional options (one entry) can be specified @@ -1321,10 +1447,13 @@ mod tests { #[test] fn copy_to_table_to_table() -> Result<(), ParserError> { // positive case - let sql = "COPY foo TO bar"; + let sql = "COPY foo TO bar STORED AS CSV"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("CSV".to_owned()), options: vec![], }); @@ -1335,10 +1464,22 @@ mod tests { #[test] fn explain_copy_to_table_to_table() -> Result<(), ParserError> { let cases = vec![ - ("EXPLAIN COPY foo TO bar", false, false), - ("EXPLAIN ANALYZE COPY foo TO bar", true, false), - ("EXPLAIN VERBOSE COPY foo TO bar", false, true), - ("EXPLAIN ANALYZE VERBOSE COPY foo TO bar", true, true), + ("EXPLAIN COPY foo TO bar STORED AS PARQUET", false, false), + ( + "EXPLAIN ANALYZE COPY foo TO bar STORED AS PARQUET", + true, + false, + ), + ( + "EXPLAIN VERBOSE COPY foo TO bar STORED AS PARQUET", + false, + true, + ), + ( + "EXPLAIN ANALYZE VERBOSE COPY foo TO bar STORED AS PARQUET", + true, + true, + ), ]; for (sql, analyze, verbose) in cases { println!("sql: {sql}, analyze: {analyze}, verbose: {verbose}"); @@ -1346,6 +1487,9 @@ mod tests { let expected_copy = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("PARQUET".to_owned()), options: vec![], }); let expected = Statement::Explain(ExplainStatement { @@ -1375,10 +1519,13 @@ mod tests { panic!("Expected query, got {statement:?}"); }; - let sql = "COPY (SELECT 1) TO bar"; + let sql = "COPY (SELECT 1) TO bar STORED AS 
CSV WITH HEADER ROW"; let expected = Statement::CopyTo(CopyToStatement { source: CopyToSource::Query(query), target: "bar".to_string(), + partitioned_by: vec![], + has_header: true, + stored_as: Some("CSV".to_owned()), options: vec![], }); assert_eq!(verified_stmt(sql), expected); @@ -1387,10 +1534,31 @@ mod tests { #[test] fn copy_to_options() -> Result<(), ParserError> { - let sql = "COPY foo TO bar (row_group_size 55)"; + let sql = "COPY foo TO bar STORED AS CSV OPTIONS (row_group_size 55)"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("CSV".to_owned()), + options: vec![( + "row_group_size".to_string(), + Value::Number("55".to_string(), false), + )], + }); + assert_eq!(verified_stmt(sql), expected); + Ok(()) + } + + #[test] + fn copy_to_partitioned_by() -> Result<(), ParserError> { + let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS (row_group_size 55)"; + let expected = Statement::CopyTo(CopyToStatement { + source: object_name("foo"), + target: "bar".to_string(), + partitioned_by: vec!["a".to_string()], + has_header: false, + stored_as: Some("CSV".to_owned()), options: vec![( "row_group_size".to_string(), Value::Number("55".to_string(), false), @@ -1404,24 +1572,24 @@ mod tests { fn copy_to_multi_options() -> Result<(), ParserError> { // order of options is preserved let sql = - "COPY foo TO bar (format parquet, row_group_size 55, compression snappy)"; + "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy)"; let expected_options = vec![ ( - "format".to_string(), - Value::UnQuotedString("parquet".to_string()), - ), - ( - "row_group_size".to_string(), + "format.row_group_size".to_string(), Value::Number("55".to_string(), false), ), ( - "compression".to_string(), + "format.compression".to_string(), Value::UnQuotedString("snappy".to_string()), ), ]; - let options = if let Statement::CopyTo(copy_to) = verified_stmt(sql) { + let mut statements = DFParser::parse_sql(sql).unwrap(); + assert_eq!(statements.len(), 1); + let only_statement = statements.pop_front().unwrap(); + + let options = if let Statement::CopyTo(copy_to) = only_statement { copy_to.options } else { panic!("Expected copy"); @@ -1460,7 +1628,10 @@ mod tests { } let only_statement = statements.pop_front().unwrap(); - assert_eq!(canonical, only_statement.to_string()); + assert_eq!( + canonical.to_uppercase(), + only_statement.to_string().to_uppercase() + ); only_statement } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ea8edd0771c8..eda8398c432b 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -25,6 +25,7 @@ use datafusion_common::{ }; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, + Operator, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, @@ -221,37 +222,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let skip = match skip { - Some(skip_expr) => match self.sql_to_expr( - skip_expr.value, - input.schema(), - &mut PlannerContext::new(), - )? 
{ - Expr::Literal(ScalarValue::Int64(Some(s))) => { - if s < 0 { - return plan_err!("Offset must be >= 0, '{s}' was provided."); - } - Ok(s as usize) - } - _ => plan_err!("Unexpected expression in OFFSET clause"), - }?, - _ => 0, - }; + Some(skip_expr) => { + let expr = self.sql_to_expr( + skip_expr.value, + input.schema(), + &mut PlannerContext::new(), + )?; + let n = get_constant_result(&expr, "OFFSET")?; + convert_usize_with_check(n, "OFFSET") + } + _ => Ok(0), + }?; let fetch = match fetch { Some(limit_expr) if limit_expr != sqlparser::ast::Expr::Value(Value::Null) => { - let n = match self.sql_to_expr( + let expr = self.sql_to_expr( limit_expr, input.schema(), &mut PlannerContext::new(), - )? { - Expr::Literal(ScalarValue::Int64(Some(n))) if n >= 0 => { - Ok(n as usize) - } - _ => plan_err!("LIMIT must not be negative"), - }?; - Some(n) + )?; + let n = get_constant_result(&expr, "LIMIT")?; + Some(convert_usize_with_check(n, "LIMIT")?) } _ => None, }; @@ -283,3 +276,47 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } } + +/// Retrieves the constant result of an expression, evaluating it if possible. +/// +/// This function takes an expression and an argument name as input and returns +/// a `Result` indicating either the constant result of the expression or an +/// error if the expression cannot be evaluated. +/// +/// # Arguments +/// +/// * `expr` - An `Expr` representing the expression to evaluate. +/// * `arg_name` - The name of the argument for error messages. +/// +/// # Returns +/// +/// * `Result` - An `Ok` variant containing the constant result if evaluation is successful, +/// or an `Err` variant containing an error message if evaluation fails. +/// +/// tracks a more general solution +fn get_constant_result(expr: &Expr, arg_name: &str) -> Result { + match expr { + Expr::Literal(ScalarValue::Int64(Some(s))) => Ok(*s), + Expr::BinaryExpr(binary_expr) => { + let lhs = get_constant_result(&binary_expr.left, arg_name)?; + let rhs = get_constant_result(&binary_expr.right, arg_name)?; + let res = match binary_expr.op { + Operator::Plus => lhs + rhs, + Operator::Minus => lhs - rhs, + Operator::Multiply => lhs * rhs, + _ => return plan_err!("Unsupported operator for {arg_name} clause"), + }; + Ok(res) + } + _ => plan_err!("Unexpected expression in {arg_name} clause"), + } +} + +/// Converts an `i64` to `usize`, performing a boundary check. +fn convert_usize_with_check(n: i64, arg_name: &str) -> Result { + if n < 0 { + plan_err!("{arg_name} must be >= 0, '{n}' was provided.") + } else { + Ok(n as usize) + } +} diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 412c3b753ed5..7717f75d16b8 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -813,20 +813,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn copy_to_plan(&self, statement: CopyToStatement) -> Result { // determine if source is table or query and handle accordingly let copy_source = statement.source; - let input = match copy_source { + let (input, input_schema, table_ref) = match copy_source { CopyToSource::Relation(object_name) => { - let table_ref = - self.object_name_to_table_reference(object_name.clone())?; - let table_source = self.context_provider.get_table_source(table_ref)?; - LogicalPlanBuilder::scan( - object_name_to_string(&object_name), - table_source, - None, - )? - .build()? 
+ let table_name = object_name_to_string(&object_name); + let table_ref = self.object_name_to_table_reference(object_name)?; + let table_source = + self.context_provider.get_table_source(table_ref.clone())?; + let plan = + LogicalPlanBuilder::scan(table_name, table_source, None)?.build()?; + let input_schema = plan.schema().clone(); + (plan, input_schema, Some(table_ref)) } CopyToSource::Query(query) => { - self.query_to_plan(query, &mut PlannerContext::new())? + let plan = self.query_to_plan(query, &mut PlannerContext::new())?; + let input_schema = plan.schema().clone(); + (plan, input_schema, None) } }; @@ -849,11 +850,57 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return plan_err!("Unsupported Value in COPY statement {}", value); } }; - options.insert(key.to_lowercase(), value_string.to_lowercase()); + if !(key.contains('.') || key == "format") { + // If config does not belong to any namespace, assume it is + // a format option and apply the format prefix for backwards + // compatibility. + + let renamed_key = format!("format.{}", key); + options.insert(renamed_key.to_lowercase(), value_string.to_lowercase()); + } else { + options.insert(key.to_lowercase(), value_string.to_lowercase()); + } } - let file_type = try_infer_file_type(&mut options, &statement.target)?; - let partition_by = take_partition_by(&mut options); + let file_type = if let Some(file_type) = statement.stored_as { + FileType::from_str(&file_type).map_err(|_| { + DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) + })? + } else if let Some(format) = options.remove("format") { + // try to infer file format from the "format" key in options + FileType::from_str(&format) + .map_err(|e| DataFusionError::Configuration(format!("{}", e)))? + } else { + let e = || { + DataFusionError::Configuration( + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) + }; + // try to infer file format from file extension + let extension: &str = &Path::new(&statement.target) + .extension() + .ok_or_else(e)? + .to_str() + .ok_or_else(e)? + .to_lowercase(); + + FileType::from_str(extension).map_err(|e| { + DataFusionError::Configuration(format!( + "{}. Use STORED AS to define file format.", + e + )) + })? + }; + + let partition_by = statement + .partitioned_by + .iter() + .map(|col| input_schema.field_with_name(table_ref.as_ref(), col)) + .collect::>>()? + .into_iter() + .map(|f| f.name().to_owned()) + .collect(); Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(input), @@ -1469,82 +1516,3 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .is_ok() } } - -/// Infers the file type for a given target based on provided options or file extension. -/// -/// This function tries to determine the file type based on the 'format' option present -/// in the provided options hashmap. If 'format' is not explicitly set, the function attempts -/// to infer the file type from the file extension of the target. It returns an error if neither -/// the format option is set nor the file extension can be determined or parsed. -/// -/// # Arguments -/// -/// * `options` - A mutable reference to a HashMap containing options where the file format -/// might be specified under the 'format' key. -/// * `target` - A string slice representing the path to the file for which the file type needs to be inferred. 
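To summarize the replacement logic above: copy_to_plan now resolves the output format from an explicit STORED AS clause first, then from a "format" entry in OPTIONS, and finally from the extension of the target path, while bare option keys are rewritten into the format.<key> namespace for backwards compatibility. Two illustrative statements as a sketch (the table `t` is assumed to be registered; paths and options are examples only):

use datafusion::prelude::SessionContext;
use datafusion_common::Result;

async fn copy_examples(ctx: &SessionContext) -> Result<()> {
    // Format comes from STORED AS; "delimiter" is namespaced to "format.delimiter".
    ctx.sql("COPY t TO 'out/t.csv' STORED AS CSV OPTIONS (delimiter ';')")
        .await?;
    // Neither STORED AS nor a "format" option: the type is inferred from ".parquet".
    ctx.sql("COPY t TO 'out/t.parquet'").await?;
    Ok(())
}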
-/// -/// # Returns -/// -/// Returns `Result` which is Ok if the file type could be successfully inferred, -/// otherwise returns an error in case of failure to determine or parse the file format or extension. -/// -/// # Errors -/// -/// This function returns an error in two cases: -/// - If the 'format' option is not set and the file extension cannot be retrieved from `target`. -/// - If the file extension is found but cannot be converted into a valid string. -/// -pub fn try_infer_file_type( - options: &mut HashMap, - target: &str, -) -> Result { - let explicit_format = options.remove("format"); - let format = match explicit_format { - Some(s) => FileType::from_str(&s), - None => { - // try to infer file format from file extension - let extension: &str = &Path::new(target) - .extension() - .ok_or(DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension!" - .to_string(), - ))? - .to_str() - .ok_or(DataFusionError::Configuration( - "Format not explicitly set and failed to parse file extension!" - .to_string(), - ))? - .to_lowercase(); - - FileType::from_str(extension) - } - }?; - - Ok(format) -} - -/// Extracts and parses the 'partition_by' option from a provided options hashmap. -/// -/// This function looks for a 'partition_by' key in the options hashmap. If found, -/// it splits the value by commas, trims each resulting string, and replaces double -/// single quotes with a single quote. It returns a vector of partition column names. -/// -/// # Arguments -/// -/// * `options` - A mutable reference to a HashMap containing options where 'partition_by' -/// might be specified. -/// -/// # Returns -/// -/// Returns a `Vec` containing partition column names. If the 'partition_by' option -/// is not present, returns an empty vector. -pub fn take_partition_by(options: &mut HashMap) -> Vec { - let partition_by = options.remove("partition_by"); - match partition_by { - Some(part_cols) => part_cols - .split(',') - .map(|s| s.trim().replace("''", "'")) - .collect::>(), - None => vec![], - } -} diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9680177d736f..a29b5014b1ce 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -24,7 +24,9 @@ use datafusion_expr::{ expr::{AggregateFunctionDefinition, Alias, InList, ScalarFunction, WindowFunction}, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, }; -use sqlparser::ast::{self, Function, FunctionArg, Ident}; +use sqlparser::ast::{ + self, Expr as AstExpr, Function, FunctionArg, Ident, UnaryOperator, +}; use super::Unparser; @@ -52,21 +54,64 @@ impl Unparser<'_> { match expr { Expr::InList(InList { expr, - list: _, - negated: _, + list, + negated, }) => { - not_impl_err!("Unsupported expression: {expr:?}") + let list_expr = list + .iter() + .map(|e| self.expr_to_sql(e)) + .collect::>>()?; + Ok(ast::Expr::InList { + expr: Box::new(self.expr_to_sql(expr)?), + list: list_expr, + negated: *negated, + }) } - Expr::ScalarFunction(ScalarFunction { .. 
}) => { - not_impl_err!("Unsupported expression: {expr:?}") + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let func_name = func_def.name(); + + let args = args + .iter() + .map(|e| { + if matches!(e, Expr::Wildcard { qualifier: None }) { + Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) + } else { + self.expr_to_sql(e).map(|e| { + FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) + }) + } + }) + .collect::>>()?; + + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args, + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: vec![], + })) } Expr::Between(Between { expr, - negated: _, - low: _, - high: _, + negated, + low, + high, }) => { - not_impl_err!("Unsupported expression: {expr:?}") + let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_low = self.expr_to_sql(low)?; + let sql_high = self.expr_to_sql(high)?; + Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql( + sql_parser_expr, + *negated, + sql_low, + sql_high, + )))) } Expr::Column(col) => self.col_to_sql(col), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { @@ -78,10 +123,38 @@ impl Unparser<'_> { } Expr::Case(Case { expr, - when_then_expr: _, - else_expr: _, + when_then_expr, + else_expr, }) => { - not_impl_err!("Unsupported expression: {expr:?}") + let conditions = when_then_expr + .iter() + .map(|(w, _)| self.expr_to_sql(w)) + .collect::>>()?; + let results = when_then_expr + .iter() + .map(|(_, t)| self.expr_to_sql(t)) + .collect::>>()?; + let operand = match expr.as_ref() { + Some(e) => match self.expr_to_sql(e) { + Ok(sql_expr) => Some(Box::new(sql_expr)), + Err(_) => None, + }, + None => None, + }; + let else_result = match else_expr.as_ref() { + Some(e) => match self.expr_to_sql(e) { + Ok(sql_expr) => Some(Box::new(sql_expr)), + Err(_) => None, + }, + None => None, + }; + + Ok(ast::Expr::Case { + operand, + conditions, + results, + else_result, + }) } Expr::Cast(Cast { expr, data_type }) => { let inner_expr = self.expr_to_sql(expr)?; @@ -104,14 +177,17 @@ impl Unparser<'_> { not_impl_err!("Unsupported expression: {expr:?}") } Expr::Like(Like { - negated: _, + negated, expr, - pattern: _, - escape_char: _, + pattern, + escape_char, case_insensitive: _, - }) => { - not_impl_err!("Unsupported expression: {expr:?}") - } + }) => Ok(ast::Expr::Like { + negated: *negated, + expr: Box::new(self.expr_to_sql(expr)?), + pattern: Box::new(self.expr_to_sql(pattern)?), + escape_char: *escape_char, + }), Expr::AggregateFunction(agg) => { let func_name = if let AggregateFunctionDefinition::BuiltIn(built_in) = &agg.func_def @@ -181,6 +257,25 @@ impl Unparser<'_> { negated: insubq.negated, }) } + Expr::IsNotNull(expr) => { + Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) + } + Expr::IsTrue(expr) => { + Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?))) + } + Expr::IsFalse(expr) => { + Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) + } + Expr::IsUnknown(expr) => { + Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) + } + Expr::Not(expr) => { + let sql_parser_expr = self.expr_to_sql(expr)?; + Ok(AstExpr::UnaryOp { + op: UnaryOperator::Not, + expr: Box::new(sql_parser_expr), + }) + } _ => not_impl_err!("Unsupported expression: {expr:?}"), } } @@ -216,6 +311,21 @@ impl Unparser<'_> { } } + pub(super) fn between_op_to_sql( + &self, + expr: ast::Expr, + negated: bool, + low: ast::Expr, + high: ast::Expr, + ) -> ast::Expr { + 
ast::Expr::Between { + expr: Box::new(expr), + negated, + low: Box::new(low), + high: Box::new(high), + } + } + fn op_to_sql(&self, op: &Operator) -> Result { match op { Operator::Eq => Ok(ast::BinaryOperator::Eq), @@ -456,6 +566,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), + ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"), } } @@ -491,11 +602,15 @@ impl Unparser<'_> { DataType::Binary => todo!(), DataType::FixedSizeBinary(_) => todo!(), DataType::LargeBinary => todo!(), + DataType::BinaryView => todo!(), DataType::Utf8 => Ok(ast::DataType::Varchar(None)), DataType::LargeUtf8 => Ok(ast::DataType::Text), + DataType::Utf8View => todo!(), DataType::List(_) => todo!(), DataType::FixedSizeList(_, _) => todo!(), DataType::LargeList(_) => todo!(), + DataType::ListView(_) => todo!(), + DataType::LargeListView(_) => todo!(), DataType::Struct(_) => todo!(), DataType::Union(_, _) => todo!(), DataType::Dictionary(_, _) => todo!(), @@ -509,13 +624,53 @@ impl Unparser<'_> { #[cfg(test)] mod tests { + use std::any::Any; + use datafusion_common::TableReference; - use datafusion_expr::{col, expr::AggregateFunction, lit}; + use datafusion_expr::{ + case, col, expr::AggregateFunction, lit, not, ColumnarValue, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, + }; use crate::unparser::dialect::CustomDialect; use super::*; + /// Mocked UDF + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } // See sql::tests for E2E tests. 
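(Editorial aside, not part of this patch.) The test module above drives the newly supported expressions through `expr_to_sql`; a minimal standalone sketch of the same round trip, assuming only the `col`/`lit` builders from `datafusion_expr` and the public `expr_to_sql` entry point that the integration tests import, could look like this:

// Editorial sketch, not part of the diff: unparse one of the newly
// supported expressions back to SQL text.
use datafusion_common::Result;
use datafusion_expr::{col, lit};
use datafusion_sql::unparser::expr_to_sql;

fn roundtrip_is_true() -> Result<()> {
    // (a + b) > 4 IS TRUE  unparses to  (("a" + "b") > 4) IS TRUE
    let expr = (col("a") + col("b")).gt(lit(4)).is_true();
    let sql = expr_to_sql(&expr)?;
    assert_eq!(sql.to_string(), r#"(("a" + "b") > 4) IS TRUE"#);
    Ok(())
}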
#[test] @@ -530,6 +685,13 @@ mod tests { .gt(lit(4)), r#"("a"."b"."c" > 4)"#, ), + ( + case(col("a")) + .when(lit(1), lit(true)) + .when(lit(0), lit(false)) + .otherwise(lit(ScalarValue::Null))?, + r#"CASE "a" WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#, + ), ( Expr::Cast(Cast { expr: Box::new(col("a")), @@ -544,6 +706,28 @@ mod tests { }), r#"CAST("a" AS INTEGER UNSIGNED)"#, ), + ( + col("a").in_list(vec![lit(1), lit(2), lit(3)], false), + r#""a" IN (1, 2, 3)"#, + ), + ( + col("a").in_list(vec![lit(1), lit(2), lit(3)], true), + r#""a" NOT IN (1, 2, 3)"#, + ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]), + r#"dummy_udf("a", "b")"#, + ), + ( + Expr::Like(Like { + negated: true, + expr: Box::new(col("a")), + pattern: Box::new(lit("foo")), + escape_char: Some('o'), + case_insensitive: true, + }), + r#""a" NOT LIKE 'foo' ESCAPE 'o'"#, + ), ( Expr::Literal(ScalarValue::Date64(Some(0))), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, @@ -594,6 +778,24 @@ mod tests { }), "COUNT(DISTINCT *)", ), + (col("a").is_not_null(), r#""a" IS NOT NULL"#), + ( + (col("a") + col("b")).gt(lit(4)).is_true(), + r#"(("a" + "b") > 4) IS TRUE"#, + ), + ( + (col("a") + col("b")).gt(lit(4)).is_false(), + r#"(("a" + "b") > 4) IS FALSE"#, + ), + ( + (col("a") + col("b")).gt(lit(4)).is_unknown(), + r#"(("a" + "b") > 4) IS UNKNOWN"#, + ), + (not(col("a")), r#"NOT "a""#), + ( + Expr::between(col("a"), lit(1), lit(7)), + r#"("a" BETWEEN 1 AND 7)"#, + ), ]; for (expr, expected) in tests { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e1f5135efda9..c9b0a8a04c7e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -124,9 +124,12 @@ impl Unparser<'_> { match plan { LogicalPlan::TableScan(scan) => { let mut builder = TableRelationBuilder::default(); - builder.name(ast::ObjectName(vec![ - self.new_ident(scan.table_name.table().to_string()) - ])); + let mut table_parts = vec![]; + if let Some(schema_name) = scan.table_name.schema() { + table_parts.push(self.new_ident(schema_name.to_string())); + } + table_parts.push(self.new_ident(scan.table_name.table().to_string())); + builder.name(ast::ObjectName(table_parts)); relation.table(builder); Ok(()) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index c9c2bdd694b5..101c31039c7e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,27 +22,36 @@ use std::{sync::Arc, vec}; use arrow_schema::TimeUnit::Nanosecond; use arrow_schema::*; -use datafusion_sql::planner::PlannerContext; -use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; -use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; - +use datafusion_common::config::ConfigOptions; use datafusion_common::{ - config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, + assert_contains, plan_err, DFSchema, DataFusionError, ParamValues, Result, + ScalarValue, TableReference, }; -use datafusion_common::{plan_err, DFSchema, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource, Volatility, WindowUDF, }; +use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; use datafusion_sql::{ parser::DFParser, - planner::{ContextProvider, ParserOptions, SqlToRel}, + planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, }; +use datafusion_functions::unicode; use 
rstest::rstest; +use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; +#[test] +fn test_schema_support() { + quick_test( + "SELECT * FROM s1.test", + "Projection: s1.test.t_date32, s1.test.t_date64\ + \n TableScan: s1.test", + ); +} + #[test] fn parse_decimals() { let test_data = [ @@ -80,7 +89,7 @@ fn parse_decimals() { fn parse_ident_normalization() { let test_data = [ ( - "SELECT LENGTH('str')", + "SELECT CHARACTER_LENGTH('str')", "Ok(Projection: character_length(Utf8(\"str\"))\n EmptyRelation)", false, ), @@ -389,7 +398,7 @@ fn plan_rollback_transaction_chained() { #[test] fn plan_copy_to() { - let sql = "COPY test_decimal to 'output.csv'"; + let sql = "COPY test_decimal to 'output.csv' STORED AS CSV"; let plan = r#" CopyTo: format=csv output_url=output.csv options: () TableScan: test_decimal @@ -410,6 +419,18 @@ Explain quick_test(sql, plan); } +#[test] +fn plan_explain_copy_to_format() { + let sql = "EXPLAIN COPY test_decimal to 'output.tbl' STORED AS CSV"; + let plan = r#" +Explain + CopyTo: format=csv output_url=output.tbl options: () + TableScan: test_decimal + "# + .trim(); + quick_test(sql, plan); +} + #[test] fn plan_copy_to_query() { let sql = "COPY (select * from test_decimal limit 10) to 'output.csv'"; @@ -423,6 +444,18 @@ CopyTo: format=csv output_url=output.csv options: () quick_test(sql, plan); } +#[test] +fn plan_copy_stored_as_priority() { + let sql = "COPY (select * from (values (1))) to 'output/' STORED AS CSV OPTIONS (format json)"; + let plan = r#" +CopyTo: format=csv output_url=output/ options: (format json) + Projection: column1 + Values: (Int64(1)) + "# + .trim(); + quick_test(sql, plan); +} + #[test] fn plan_insert() { let sql = @@ -2566,15 +2599,6 @@ fn approx_median_window() { quick_test(sql, expected); } -#[test] -fn select_arrow_cast() { - let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 'LargeUtf8')"; - let expected = "\ - Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\ - \n EmptyRelation"; - quick_test(sql, expected); -} - #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; @@ -2665,11 +2689,17 @@ fn logical_plan_with_dialect_and_options( options: ParserOptions, ) -> Result { let context = MockContextProvider::default() + .with_udf(unicode::character_length().as_ref().clone()) .with_udf(make_udf( "nullif", vec![DataType::Int32, DataType::Int32], DataType::Int32, )) + .with_udf(make_udf( + "arrow_cast", + vec![DataType::Int64, DataType::Utf8], + DataType::Float64, + )) .with_udf(make_udf( "date_trunc", vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)], @@ -4337,6 +4367,40 @@ fn test_prepare_statement_to_plan_value_list() { prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_unknown_list_param() { + let sql = "SELECT id from person where id = $2"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with id $2" + ); +} + +#[test] +fn test_prepare_statement_unknown_hash_param() { + let sql = "SELECT id from person where id = $bar"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::Map(HashMap::new()); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No 
value found for placeholder with name $bar" + ); +} + +#[test] +fn test_prepare_statement_bad_list_idx() { + let sql = "SELECT id from person where id = $foo"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); +} + #[test] fn test_table_alias() { let sql = "select * from (\ @@ -4446,26 +4510,27 @@ fn test_field_not_found_window_function() { #[test] fn test_parse_escaped_string_literal_value() { - let sql = r"SELECT length('\r\n') AS len"; + let sql = r"SELECT character_length('\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\\r\\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\r\n') AS len"; + let sql = r"SELECT character_length(E'\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\r\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; + let sql = + r"SELECT character_length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; let expected = "Projection: character_length(Utf8(\"%\")) AS len, Utf8(\"\u{004b}\") AS hex, Utf8(\"\u{0001}\") AS unicode\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\000') AS len"; + let sql = r"SELECT character_length(E'\000') AS len"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 15\")" + "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 25\")" ) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 19bcf6024b50..4929ab485d6d 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3376,3 +3376,55 @@ SELECT FIRST_VALUE(column1 ORDER BY column2) IGNORE NULLS FROM t; statement ok DROP TABLE t; + +# Test for ignore null in LAST_VALUE +statement ok +CREATE TABLE t AS VALUES (3), (4), (null::bigint); + +query I +SELECT LAST_VALUE(column1) FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1) RESPECT NULLS FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1) IGNORE NULLS FROM t; +---- +4 + +statement ok +DROP TABLE t; + +# Test for ignore null with ORDER BY in LAST_VALUE +statement ok +CREATE TABLE t AS VALUES (3, 3), (4, 4), (null::bigint, 1), (null::bigint, 2); + +query I +SELECT column1 FROM t ORDER BY column2 DESC; +---- +4 +3 +NULL +NULL + +query I +SELECT LAST_VALUE(column1 ORDER BY column2 DESC) FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1 ORDER BY column2 DESC) RESPECT NULLS FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t; +---- +3 + +statement ok +DROP TABLE t; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ad979a316709..3456963aacfc 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6116,7 +6116,7 @@ from fixed_size_flatten_table; [1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] [1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] -## empty +## empty (aliases: `array_empty`, `list_empty`) # empty scalar function #1 
query B select empty(make_array(1)); @@ -6207,6 +6207,75 @@ NULL false false +## array_empty (aliases: `empty`, `list_empty`) +# array_empty scalar function #1 +query B +select array_empty(make_array(1)); +---- +false + +query B +select array_empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + +# array_empty scalar function #2 +query B +select array_empty(make_array()); +---- +true + +query B +select array_empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + +# array_empty scalar function #3 +query B +select array_empty(make_array(NULL)); +---- +false + +query B +select array_empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + +## list_empty (aliases: `empty`, `array_empty`) +# list_empty scalar function #1 +query B +select list_empty(make_array(1)); +---- +false + +query B +select list_empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + +# list_empty scalar function #2 +query B +select list_empty(make_array()); +---- +true + +query B +select list_empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + +# list_empty scalar function #3 +query B +select list_empty(make_array(NULL)); +---- +false + +query B +select list_empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + +# string_to_array scalar function query ? SELECT string_to_array('abcxxxdef', 'xxx') ---- diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt new file mode 100644 index 000000000000..24c99fc849b6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +## Tests for basic array queries + +# Make a table with multiple input partitions +statement ok +CREATE TABLE data AS + SELECT * FROM (VALUES + ([1,2,3], [4,5], 1) + ) + UNION ALL + SELECT * FROM (VALUES + ([2,3], [2,3], 1), + ([1,2,3], NULL, 1) + ) +; + +query ??I rowsort +SELECT * FROM data; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +########### +# Filtering +########### + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 = [1,2,3]; + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 = column2 + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 != [1,2,3]; + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 != column2 + +########### +# Aggregates +########### + +query error Internal error: Min/Max accumulator not implemented for type List +SELECT min(column1) FROM data; + +query error Internal error: Min/Max accumulator not implemented for type List +SELECT max(column1) FROM data; + +query I +SELECT count(column1) FROM data; +---- +3 + +# note single count distincts are rewritten to use a group by +query I +SELECT count(distinct column1) FROM data; +---- +2 + +query I +SELECT count(distinct column2) FROM data; +---- +2 + + +# note multiple count distincts are not rewritten +query II +SELECT count(distinct column1), count(distinct column2) FROM data; +---- +2 2 + + +########### +# GROUP BY +########### + + +query I +SELECT count(column1) FROM data GROUP BY column3; +---- +3 + +# note single count distincts are rewritten to use a group by +query I +SELECT count(distinct column1) FROM data GROUP BY column3; +---- +2 + +query I +SELECT count(distinct column2) FROM data GROUP BY column3; +---- +2 + +# note multiple count distincts are not rewritten +query II +SELECT count(distinct column1), count(distinct column2) FROM data GROUP BY column3; +---- +2 2 + + +########### +# ORDER BY +########### + +query ??I +SELECT * FROM data ORDER BY column2; +---- +[2, 3] [2, 3] 1 +[1, 2, 3] [4, 5] 1 +[1, 2, 3] NULL 1 + +query ??I +SELECT * FROM data ORDER BY column2 DESC; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I +SELECT * FROM data ORDER BY column2 DESC NULLS LAST; +---- +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 +[1, 2, 3] NULL 1 + +# multi column +query ??I +SELECT * FROM data ORDER BY 
column1, column2; +---- +[1, 2, 3] [4, 5] 1 +[1, 2, 3] NULL 1 +[2, 3] [2, 3] 1 + +query ??I +SELECT * FROM data ORDER BY column1, column3, column2; +---- +[1, 2, 3] [4, 5] 1 +[1, 2, 3] NULL 1 +[2, 3] [2, 3] 1 + + +statement ok +drop table data diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 8b3bd7eac95d..3e8694f3b2c2 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -92,10 +92,11 @@ SELECT arrow_cast('1', 'Int16') 1 # Basic error test -query error Error during planning: arrow_cast needs 2 arguments, 1 provided +query error DataFusion error: Error during planning: No function matches the given name and argument types 'arrow_cast\(Utf8\)'. You might need to add explicit type casts. SELECT arrow_cast('1') -query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\) + +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown @@ -315,7 +316,7 @@ select arrow_cast(interval '30 minutes', 'Duration(Second)'); ---- 0 days 0 hours 30 mins 0 secs -query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) +query error DataFusion error: This feature is not implemented: Unsupported CAST from Utf8 to Duration\(Second\) select arrow_cast('30 minutes', 'Duration(Second)'); @@ -336,7 +337,7 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( ---- 2000-01-01T00:00:00+08:00 -statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index df23a993ebce..fca892dfcdad 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -21,13 +21,13 @@ create table source_table(col1 integer, col2 varchar) as values (1, 'Foo'), (2, # Copy to directory as multiple files query IT -COPY source_table TO 'test_files/scratch/copy/table/' (format parquet, 'parquet.compression' 'zstd(10)'); +COPY source_table TO 'test_files/scratch/copy/table/' STORED AS parquet OPTIONS ('format.compression' 'zstd(10)'); ---- 2 # Copy to directory as partitioned files query IT -COPY source_table TO 'test_files/scratch/copy/partitioned_table1/' (format parquet, 'parquet.compression' 'zstd(10)', partition_by 'col2'); +COPY source_table TO 'test_files/scratch/copy/partitioned_table1/' STORED AS parquet PARTITIONED BY (col2) OPTIONS ('format.compression' 'zstd(10)'); ---- 2 @@ -54,8 +54,8 @@ select * from validate_partitioned_parquet_bar order by col1; # Copy to directory as partitioned files query ITT -COPY (values (1, 'a', 'x'), (2, 'b', 'y'), (3, 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table2/' -(format parquet, partition_by 'column2, column3', 'parquet.compression' 'zstd(10)'); +COPY (values (1, 'a', 'x'), (2, 'b', 'y'), (3, 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table2/' STORED AS parquet PARTITIONED BY (column2, column3) +OPTIONS ('format.compression' 'zstd(10)'); ---- 
3 @@ -82,8 +82,8 @@ select * from validate_partitioned_parquet_a_x order by column1; # Copy to directory as partitioned files query TTT -COPY (values ('1', 'a', 'x'), ('2', 'b', 'y'), ('3', 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table3/' -(format parquet, 'parquet.compression' 'zstd(10)', partition_by 'column1, column3'); +COPY (values ('1', 'a', 'x'), ('2', 'b', 'y'), ('3', 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table3/' STORED AS parquet PARTITIONED BY (column1, column3) +OPTIONS ('format.compression' 'zstd(10)'); ---- 3 @@ -111,49 +111,53 @@ a statement ok create table test ("'test'" varchar, "'test2'" varchar, "'test3'" varchar); -query TTT -insert into test VALUES ('a', 'x', 'aa'), ('b','y', 'bb'), ('c', 'z', 'cc') ----- -3 - -query T -select "'test'" from test ----- -a -b -c - -# Note to place a single ' inside of a literal string escape by putting two '' -query TTT -copy test to 'test_files/scratch/copy/escape_quote' (format csv, partition_by '''test2'',''test3''') ----- -3 - -statement ok -CREATE EXTERNAL TABLE validate_partitioned_escape_quote STORED AS CSV -LOCATION 'test_files/scratch/copy/escape_quote/' PARTITIONED BY ("'test2'", "'test3'"); - +# https://github.com/apache/arrow-datafusion/issues/9714 +## Until the partition by parsing uses ColumnDef, this test is meaningless since it becomes an overfit. Even in +## CREATE EXTERNAL TABLE, there is a schema mismatch, this should be an issue. +# +#query TTT +#insert into test VALUES ('a', 'x', 'aa'), ('b','y', 'bb'), ('c', 'z', 'cc') +#---- +#3 +# +#query T +#select "'test'" from test +#---- +#a +#b +#c +# +# # Note to place a single ' inside of a literal string escape by putting two '' +#query TTT +#copy test to 'test_files/scratch/copy/escape_quote' STORED AS CSV; +#---- +#3 +# +#statement ok +#CREATE EXTERNAL TABLE validate_partitioned_escape_quote STORED AS CSV +#LOCATION 'test_files/scratch/copy/escape_quote/' PARTITIONED BY ("'test2'", "'test3'"); +# # This triggers a panic (index out of bounds) # https://github.com/apache/arrow-datafusion/issues/9269 #query #select * from validate_partitioned_escape_quote; query TT -EXPLAIN COPY source_table TO 'test_files/scratch/copy/table/' (format parquet, 'parquet.compression' 'zstd(10)'); +EXPLAIN COPY source_table TO 'test_files/scratch/copy/table/' STORED AS PARQUET OPTIONS ('format.compression' 'zstd(10)'); ---- logical_plan -CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: (parquet.compression zstd(10)) +CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: (format.compression zstd(10)) --TableScan: source_table projection=[col1, col2] physical_plan FileSinkExec: sink=ParquetSink(file_groups=[]) --MemoryExec: partitions=1, partition_sizes=[1] # Error case -query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! +query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! Use STORED AS to define file format. 
EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' query TT -EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' (format parquet) +EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' STORED AS PARQUET ---- logical_plan CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: () @@ -164,7 +168,7 @@ FileSinkExec: sink=ParquetSink(file_groups=[]) # Copy more files to directory via query query IT -COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table/' (format parquet); +COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table/' STORED AS PARQUET; ---- 4 @@ -185,7 +189,7 @@ select * from validate_parquet; query ? copy (values (struct(timestamp '2021-01-01 01:00:01', 1)), (struct(timestamp '2022-01-01 01:00:01', 2)), (struct(timestamp '2023-01-03 01:00:01', 3)), (struct(timestamp '2024-01-01 01:00:01', 4))) -to 'test_files/scratch/copy/table_nested2/' (format parquet); +to 'test_files/scratch/copy/table_nested2/' STORED AS PARQUET; ---- 4 @@ -204,7 +208,7 @@ query ?? COPY (values (struct ('foo', (struct ('foo', make_array(struct('a',1), struct('b',2))))), make_array(timestamp '2023-01-01 01:00:01',timestamp '2023-01-01 01:00:01')), (struct('bar', (struct ('foo', make_array(struct('aa',10), struct('bb',20))))), make_array(timestamp '2024-01-01 01:00:01', timestamp '2024-01-01 01:00:01'))) -to 'test_files/scratch/copy/table_nested/' (format parquet); +to 'test_files/scratch/copy/table_nested/' STORED AS PARQUET; ---- 2 @@ -221,7 +225,7 @@ select * from validate_parquet_nested; query ? copy (values ([struct('foo', 1), struct('bar', 2)])) to 'test_files/scratch/copy/array_of_struct/' -(format parquet); +STORED AS PARQUET; ---- 1 @@ -236,8 +240,7 @@ select * from validate_array_of_struct; query ? 
copy (values (struct('foo', [1,2,3], struct('bar', [2,3,4])))) -to 'test_files/scratch/copy/struct_with_array/' -(format parquet); +to 'test_files/scratch/copy/struct_with_array/' STORED AS PARQUET; ---- 1 @@ -255,31 +258,32 @@ select * from validate_struct_with_array; query IT COPY source_table TO 'test_files/scratch/copy/table_with_options/' -(format parquet, -'parquet.compression' snappy, -'parquet.compression::col1' 'zstd(5)', -'parquet.compression::col2' snappy, -'parquet.max_row_group_size' 12345, -'parquet.data_pagesize_limit' 1234, -'parquet.write_batch_size' 1234, -'parquet.writer_version' 2.0, -'parquet.dictionary_page_size_limit' 123, -'parquet.created_by' 'DF copy.slt', -'parquet.column_index_truncate_length' 123, -'parquet.data_page_row_count_limit' 1234, -'parquet.bloom_filter_enabled' true, -'parquet.bloom_filter_enabled::col1' false, -'parquet.bloom_filter_fpp::col2' 0.456, -'parquet.bloom_filter_ndv::col2' 456, -'parquet.encoding' plain, -'parquet.encoding::col1' DELTA_BINARY_PACKED, -'parquet.dictionary_enabled::col2' true, -'parquet.dictionary_enabled' false, -'parquet.statistics_enabled' page, -'parquet.statistics_enabled::col2' none, -'parquet.max_statistics_size' 123, -'parquet.bloom_filter_fpp' 0.001, -'parquet.bloom_filter_ndv' 100 +STORED AS PARQUET +OPTIONS ( +'format.compression' snappy, +'format.compression::col1' 'zstd(5)', +'format.compression::col2' snappy, +'format.max_row_group_size' 12345, +'format.data_pagesize_limit' 1234, +'format.write_batch_size' 1234, +'format.writer_version' 2.0, +'format.dictionary_page_size_limit' 123, +'format.created_by' 'DF copy.slt', +'format.column_index_truncate_length' 123, +'format.data_page_row_count_limit' 1234, +'format.bloom_filter_enabled' true, +'format.bloom_filter_enabled::col1' false, +'format.bloom_filter_fpp::col2' 0.456, +'format.bloom_filter_ndv::col2' 456, +'format.encoding' plain, +'format.encoding::col1' DELTA_BINARY_PACKED, +'format.dictionary_enabled::col2' true, +'format.dictionary_enabled' false, +'format.statistics_enabled' page, +'format.statistics_enabled::col2' none, +'format.max_statistics_size' 123, +'format.bloom_filter_fpp' 0.001, +'format.bloom_filter_ndv' 100 ) ---- 2 @@ -312,7 +316,7 @@ select * from validate_parquet_single; # copy from table to folder of compressed json files query IT -COPY source_table to 'test_files/scratch/copy/table_json_gz' (format json, 'json.compression' gzip); +COPY source_table to 'test_files/scratch/copy/table_json_gz' STORED AS JSON OPTIONS ('format.compression' gzip); ---- 2 @@ -328,7 +332,7 @@ select * from validate_json_gz; # copy from table to folder of compressed csv files query IT -COPY source_table to 'test_files/scratch/copy/table_csv' (format csv, 'csv.has_header' false, 'csv.compression' gzip); +COPY source_table to 'test_files/scratch/copy/table_csv' STORED AS CSV OPTIONS ('format.has_header' false, 'format.compression' gzip); ---- 2 @@ -360,7 +364,7 @@ select * from validate_single_csv; # Copy from table to folder of json query IT -COPY source_table to 'test_files/scratch/copy/table_json' (format json); +COPY source_table to 'test_files/scratch/copy/table_json' STORED AS JSON; ---- 2 @@ -376,7 +380,7 @@ select * from validate_json; # Copy from table to single json file query IT -COPY source_table to 'test_files/scratch/copy/table.json'; +COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON ; ---- 2 @@ -394,12 +398,12 @@ select * from validate_single_json; query IT COPY source_table to 'test_files/scratch/copy/table_csv_with_options' 
-(format csv, -'csv.has_header' false, -'csv.compression' uncompressed, -'csv.datetime_format' '%FT%H:%M:%S.%9f', -'csv.delimiter' ';', -'csv.null_value' 'NULLVAL'); +STORED AS CSV OPTIONS ( +'format.has_header' false, +'format.compression' uncompressed, +'format.datetime_format' '%FT%H:%M:%S.%9f', +'format.delimiter' ';', +'format.null_value' 'NULLVAL'); ---- 2 @@ -417,7 +421,7 @@ select * from validate_csv_with_options; # Copy from table to single arrow file query IT -COPY source_table to 'test_files/scratch/copy/table.arrow'; +COPY source_table to 'test_files/scratch/copy/table.arrow' STORED AS ARROW; ---- 2 @@ -437,7 +441,7 @@ select * from validate_arrow_file; query T? COPY (values ('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) -to 'test_files/scratch/copy/table_dict.arrow'; +to 'test_files/scratch/copy/table_dict.arrow' STORED AS ARROW; ---- 2 @@ -456,7 +460,7 @@ d bar # Copy from table to folder of json query IT -COPY source_table to 'test_files/scratch/copy/table_arrow' (format arrow); +COPY source_table to 'test_files/scratch/copy/table_arrow' STORED AS ARROW; ---- 2 @@ -470,17 +474,94 @@ select * from validate_arrow; 1 Foo 2 Bar +# Format Options Support without the 'format.' prefix + +# Copy with format options for Parquet without the 'format.' prefix +query IT +COPY source_table TO 'test_files/scratch/copy/format_table.parquet' +OPTIONS ( + compression snappy, + 'compression::col1' 'zstd(5)' +); +---- +2 + +# Copy with format options for JSON without the 'format.' prefix +query IT +COPY source_table to 'test_files/scratch/copy/format_table' +STORED AS JSON OPTIONS (compression gzip); +---- +2 + +# Copy with format options for CSV without the 'format.' prefix +query IT +COPY source_table to 'test_files/scratch/copy/format_table.csv' +OPTIONS ( + has_header false, + compression xz, + datetime_format '%FT%H:%M:%S.%9f', + delimiter ';', + null_value 'NULLVAL' +); +---- +2 + +# Copy with unknown format options without the 'format.' prefix to ensure error is sensible +query error DataFusion error: Invalid or Unsupported Configuration: Config value "unknown_option" not found on CsvOptions +COPY source_table to 'test_files/scratch/copy/format_table2.csv' +OPTIONS ( + unknown_option false, +); + + +# Format Options Support with format in OPTIONS i.e. COPY { table_name | query } TO 'file_name' OPTIONS (format , ...) 
+ +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format parquet); +---- +1 + +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format parquet, compression 'zstd(10)'); +---- +1 + +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format json, compression gzip); +---- +1 + +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS ( + format csv, + has_header false, + compression xz, + datetime_format '%FT%H:%M:%S.%9f', + delimiter ';', + null_value 'NULLVAL' +); +---- +1 + +query error DataFusion error: Invalid or Unsupported Configuration: This feature is not implemented: Unknown FileType: NOTVALIDFORMAT +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format notvalidformat, compression 'zstd(5)'); + # Error cases: # Copy from table with options query error DataFusion error: Invalid or Unsupported Configuration: Config value "row_group_size" not found on JsonOptions -COPY source_table to 'test_files/scratch/copy/table.json' ('json.row_group_size' 55); +COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON OPTIONS ('format.row_group_size' 55); # Incomplete statement query error DataFusion error: SQL error: ParserError\("Expected \), found: EOF"\) COPY (select col2, sum(col1) from source_table # Copy from table with non literal -query error DataFusion error: SQL error: ParserError\("Expected ',' or '\)' after option definition, found: \+"\) +query error DataFusion error: SQL error: ParserError\("Unexpected token \("\) COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 3b85dd9e986f..a200217af6e1 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -100,9 +100,17 @@ CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH LOCATION 'foo.csv'; statement error DataFusion error: SQL error: ParserError\("Unexpected token FOOBAR"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV FOOBAR BARBAR BARFOO LOCATION 'foo.csv'; +# Missing partition column +statement error DataFusion error: Arrow error: Schema error: Unable to get field named "c2". 
Valid fields: \["c1"\] +create EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c2) LOCATION 'foo.csv' + +# Duplicate Column in `PARTITIONED BY` clause +statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name c1 +create EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV PARTITIONED BY (c1 int) LOCATION 'foo.csv' + # Conflicting options -statement error DataFusion error: Invalid or Unsupported Configuration: Key "parquet.column_index_truncate_length" is not applicable for CSV format +statement error DataFusion error: Invalid or Unsupported Configuration: Config value "column_index_truncate_length" not found on CsvOptions CREATE EXTERNAL TABLE csv_table (column1 int) STORED AS CSV LOCATION 'foo.csv' -OPTIONS ('csv.delimiter' ';', 'parquet.column_index_truncate_length' '123') +OPTIONS ('format.delimiter' ';', 'format.column_index_truncate_length' '123') \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/create_function.slt b/datafusion/sqllogictest/test_files/create_function.slt index baa40ac64afc..4f0c53c36ca1 100644 --- a/datafusion/sqllogictest/test_files/create_function.slt +++ b/datafusion/sqllogictest/test_files/create_function.slt @@ -47,7 +47,7 @@ select abs(-1); statement ok DROP FUNCTION abs; -# now the the query errors +# now the query errors query error Invalid function 'abs'. select abs(-1); diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index 7b299c0cf143..ab6847afb6a5 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -23,7 +23,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.quote' '~') +OPTIONS ('format.quote' '~') LOCATION '../core/tests/data/quote.csv'; statement ok @@ -33,7 +33,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.escape' '\') +OPTIONS ('format.escape' '\') LOCATION '../core/tests/data/escape.csv'; query TT @@ -71,7 +71,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.escape' '"') +OPTIONS ('format.escape' '"') LOCATION '../core/tests/data/escape.csv'; # TODO: Validate this with better data. 
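(Editorial aside, not part of this patch.) The sqllogictest changes above move CSV options into the `format.`-prefixed namespace and switch COPY to `STORED AS CSV`; a minimal sketch of driving the same syntax from Rust, assuming the `datafusion` crate's `SessionContext`, a tokio runtime, and hypothetical table name and file paths:

// Editorial sketch, not part of the diff. Table name and paths are
// hypothetical; the SQL mirrors the test cases above.
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // CSV options are now namespaced under `format.` instead of `csv.`
    ctx.sql(
        "CREATE EXTERNAL TABLE quoted (c1 VARCHAR, c2 VARCHAR) \
         STORED AS CSV WITH HEADER ROW DELIMITER ',' \
         OPTIONS ('format.quote' '~') \
         LOCATION 'data/quote.csv'",
    )
    .await?;
    // the tests now use STORED AS <format> rather than (FORMAT CSV, ...)
    ctx.sql("COPY quoted TO 'out/quoted.csv' STORED AS CSV")
        .await?
        .collect()
        .await?;
    Ok(())
}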
@@ -117,14 +117,14 @@ CREATE TABLE src_table_2 ( query ITII COPY src_table_1 TO 'test_files/scratch/csv_files/csv_partitions/1.csv' -(FORMAT CSV); +STORED AS CSV; ---- 4 query ITII COPY src_table_2 TO 'test_files/scratch/csv_files/csv_partitions/2.csv' -(FORMAT CSV); +STORED AS CSV; ---- 4 diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 67d2beb622d3..4a83ebf348d8 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -40,11 +40,6 @@ ProjectionExec: expr=[1 as a, 2 as b, 3 as c] --PlaceholderRowExec - -# enable recursive CTEs -statement ok -set datafusion.execution.enable_recursive_ctes = true; - # trivial recursive CTE works query I rowsort WITH RECURSIVE nodes AS ( @@ -652,3 +647,101 @@ WITH RECURSIVE my_cte AS ( WHERE my_cte.a<5 ) SELECT a FROM my_cte; + + +# Test issue: https://github.com/apache/arrow-datafusion/issues/9680 +query I +WITH RECURSIVE recursive_cte AS ( + SELECT 1 as val + UNION ALL + ( + WITH sub_cte AS ( + SELECT 2 as val + ) + SELECT + 2 as val + FROM recursive_cte + CROSS JOIN sub_cte + WHERE recursive_cte.val < 2 + ) +) +SELECT * FROM recursive_cte; +---- +1 +2 + +# Test issue: https://github.com/apache/arrow-datafusion/issues/9680 +# 'recursive_cte' should be on the left of the cross join, as this is the test purpose of the above query. +query TT +explain WITH RECURSIVE recursive_cte AS ( + SELECT 1 as val + UNION ALL + ( + WITH sub_cte AS ( + SELECT 2 as val + ) + SELECT + 2 as val + FROM recursive_cte + CROSS JOIN sub_cte + WHERE recursive_cte.val < 2 + ) +) +SELECT * FROM recursive_cte; +---- +logical_plan +Projection: recursive_cte.val +--SubqueryAlias: recursive_cte +----RecursiveQuery: is_distinct=false +------Projection: Int64(1) AS val +--------EmptyRelation +------Projection: Int64(2) AS val +--------CrossJoin: +----------Filter: recursive_cte.val < Int64(2) +------------TableScan: recursive_cte +----------SubqueryAlias: sub_cte +------------Projection: Int64(2) AS val +--------------EmptyRelation +physical_plan +RecursiveQueryExec: name=recursive_cte, is_distinct=false +--ProjectionExec: expr=[1 as val] +----PlaceholderRowExec +--ProjectionExec: expr=[2 as val] +----CrossJoinExec +------CoalescePartitionsExec +--------CoalesceBatchesExec: target_batch_size=8182 +----------FilterExec: val@0 < 2 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------WorkTableExec: name=recursive_cte +------ProjectionExec: expr=[2 as val] +--------PlaceholderRowExec + +# Test issue: https://github.com/apache/arrow-datafusion/issues/9794 +# Non-recursive term and recursive term have different types +query IT +WITH RECURSIVE my_cte AS( + SELECT 1::int AS a + UNION ALL + SELECT a::bigint+2 FROM my_cte WHERE a<3 +) SELECT *, arrow_typeof(a) FROM my_cte; +---- +1 Int32 +3 Int32 + +# Test issue: https://github.com/apache/arrow-datafusion/issues/9794 +# Non-recursive term and recursive term have different number of columns +query error DataFusion error: Error during planning: Non\-recursive term and recursive term must have the same number of columns \(1 != 3\) +WITH RECURSIVE my_cte AS ( + SELECT 1::bigint AS a + UNION ALL + SELECT a+2, 'a','c' FROM my_cte WHERE a<3 +) SELECT * FROM my_cte; + +# Test issue: https://github.com/apache/arrow-datafusion/issues/9794 +# Non-recursive term and recursive term have different types, and cannot be casted +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'abc' to value 
of Int64 type +WITH RECURSIVE my_cte AS ( + SELECT 1 AS a + UNION ALL + SELECT 'abc' FROM my_cte WHERE CAST(a AS text) !='abc' +) SELECT * FROM my_cte; diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 002aade2528e..af7bf5cb16e8 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -280,3 +280,70 @@ ORDER BY 2023-12-20T01:20:00 1000 f2 foo 2023-12-20T01:30:00 1000 f1 32.0 2023-12-20T01:30:00 1000 f2 foo + +# Cleanup +statement ok +drop view m1; + +statement ok +drop view m2; + +###### +# Create a table using UNION ALL to get 2 partitions (very important) +###### +statement ok +create table m3_source as + select * from (values('foo', 'bar', 1)) + UNION ALL + select * from (values('foo', 'baz', 1)); + +###### +# Now, create a table with the same data, but column2 has type `Dictionary(Int32)` to trigger the fallback code +###### +statement ok +create table m3 as + select + column1, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2", + column3 +from m3_source; + +# there are two values in column2 +query T?I rowsort +SELECT * +FROM m3; +---- +foo bar 1 +foo baz 1 + +# There is 1 distinct value in column1 +query I +SELECT count(distinct column1) +FROM m3 +GROUP BY column3; +---- +1 + +# There are 2 distinct values in column2 +query I +SELECT count(distinct column2) +FROM m3 +GROUP BY column3; +---- +2 + +# Should still get the same results when querying in the same query +query II +SELECT count(distinct column1), count(distinct column2) +FROM m3 +GROUP BY column3; +---- +1 2 + + +# Cleanup +statement ok +drop table m3; + +statement ok +drop table m3_source; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index b7ad36dace16..4653250cf93f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -390,8 +390,8 @@ query TT explain select struct(1, 2.3, 'abc'); ---- logical_plan -Projection: Struct({c0:1,c1:2.3,c2:abc}) AS struct(Int64(1),Float64(2.3),Utf8("abc")) +Projection: Struct({c0:1,c1:2.3,c2:abc}) AS named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc")) --EmptyRelation physical_plan -ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))] +ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))] --PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 73fb5eec97d5..2e0cbf50cab9 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -335,6 +335,21 @@ SELECT ascii(NULL) ---- NULL +query I +SELECT ascii('ésoj') +---- +233 + +query I +SELECT ascii('💯') +---- +128175 + +query I +SELECT ascii('💯a') +---- +128175 + query I SELECT bit_length('') ---- @@ -400,6 +415,12 @@ SELECT chr(CAST(NULL AS int)) ---- NULL +statement error DataFusion error: Execution error: null character not permitted. +SELECT chr(CAST(0 AS int)) + +statement error DataFusion error: Execution error: requested character too large for encoding. 
+SELECT chr(CAST(9223372036854775807 AS bigint)) + query T SELECT concat('a','b','c') ---- @@ -939,6 +960,293 @@ SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') ---- 12123456780 +# test_date_part_time + +## time32 seconds +query R +SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query R +SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +query R +SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +## time32 milliseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50.123 + +query R +SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50.123 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +## time64 microseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query R +SELECT date_part('minute', 
arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50.123456 + +query R +SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50.123456 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123.456 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123.456 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +## time64 nanoseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50.123456789 + +query R +SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50.123456789 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123.456789 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123.456789 + +# just some floating point stuff happening in the result here +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456789 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456789 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + # test_extract_epoch query R @@ -1918,3 +2226,102 @@ false true false true NULL NULL NULL NULL false false true true false false true false + + +############# +## Common Subexpr Eliminate Tests +############# + +statement ok +CREATE TABLE doubles ( + f64 DOUBLE +) as VALUES + (10.1) +; + +# common subexpr with alias +query RRR rowsort +select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) 
from doubles; +---- +10.1 0 1.570796326795 + +# common subexpr with coalesce (short-circuited) +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 1.471623942989 + +# common subexpr with coalesce (short-circuited) and alias +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 1.471623942989 + +# common subexpr with case (short-circuited) +query RRR rowsort +select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles; +---- +10.1 0.09900990099 1.471623942989 + + +statement ok +CREATE TABLE t1( + time TIMESTAMP, + load1 DOUBLE, + load2 DOUBLE, + host VARCHAR +) AS VALUES + (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'), + (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'), + (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'), + (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL) +; + +# struct scalar function with columns +query ? +select struct(time,load1,load2,host) from t1; +---- +{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: host1} +{c0: 2018-05-22T19:53:26, c1: 2.2, c2: 202.0, c3: host2} +{c0: 2018-05-22T19:53:26, c1: 3.3, c2: 303.0, c3: host3} +{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: } + +# can have an aggregate function with an inner coalesce +query TR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 +host2 2.2 +host3 3.3 + +# can have an aggregate function with an inner CASE WHEN +query TR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 101 +host2 202 +host3 303 + +# can have 2 projections with aggr(short_circuited), with different short-circuited expr +query TRR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303 + +# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. 
CASE WHEN) +query TRR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303 + +# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce) +query TRR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303 diff --git a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt new file mode 100644 index 000000000000..05e622db8a02 --- /dev/null +++ b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
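+ +# Note (descriptive summary of this new test file): the external table below is declared +# WITH ORDER ("date", "ticker", "time"), and the queries filter one ordered column to a +# constant, so the remaining ORDER BY is already satisfied by the declared ordering — the +# physical plans shown below therefore contain a SortPreservingMergeExec but no SortExec.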
+ +# prepare table +statement ok +CREATE UNBOUNDED EXTERNAL TABLE data ( + "date" VARCHAR, + "ticker" VARCHAR, + "time" VARCHAR, +) STORED AS CSV +WITH ORDER ("date", "ticker", "time") +LOCATION './a.parquet'; + + +# query +query TT +explain SELECT * FROM data +WHERE ticker = 'A' +ORDER BY "date", "time"; +---- +logical_plan +Sort: data.date ASC NULLS LAST, data.time ASC NULLS LAST +--Filter: data.ticker = Utf8("A") +----TableScan: data projection=[date, ticker, time] +physical_plan +SortPreservingMergeExec: [date@0 ASC NULLS LAST,time@2 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: ticker@1 = A +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------StreamingTableExec: partition_sizes=1, projection=[date, ticker, time], infinite_source=true, output_ordering=[date@0 ASC NULLS LAST, ticker@1 ASC NULLS LAST, time@2 ASC NULLS LAST] + +# query +query TT +explain SELECT * FROM data +WHERE date = 'A' +ORDER BY "ticker", "time"; +---- +logical_plan +Sort: data.ticker ASC NULLS LAST, data.time ASC NULLS LAST +--Filter: data.date = Utf8("A") +----TableScan: data projection=[date, ticker, time] +physical_plan +SortPreservingMergeExec: [ticker@1 ASC NULLS LAST,time@2 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: date@0 = A +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------StreamingTableExec: partition_sizes=1, projection=[date, ticker, time], infinite_source=true, output_ordering=[date@0 ASC NULLS LAST, ticker@1 ASC NULLS LAST, time@2 ASC NULLS LAST] diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 080f7c209634..40617ebf8bd9 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4506,28 +4506,28 @@ CREATE TABLE src_table ( query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/0.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/1.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/2.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/3.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index dc60bafaa8db..4b9af3bdeafb 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -42,7 +42,7 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' statement ok -create table dictionary_encoded_values as values +create table dictionary_encoded_values as values ('a', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('b', arrow_cast('bar', 'Dictionary(Int32, Utf8)')); query TTT @@ -55,13 +55,13 @@ statement ok CREATE EXTERNAL TABLE dictionary_encoded_parquet_partitioned( a varchar, b varchar, -) +) STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' PARTITIONED BY (b); query TT -insert into dictionary_encoded_parquet_partitioned +insert into dictionary_encoded_parquet_partitioned select * from dictionary_encoded_values ---- 2 @@ -76,13 +76,13 @@ statement ok CREATE EXTERNAL TABLE dictionary_encoded_arrow_partitioned( a varchar, b varchar, 
-) +) STORED AS arrow LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' PARTITIONED BY (b); query TT -insert into dictionary_encoded_arrow_partitioned +insert into dictionary_encoded_arrow_partitioned select * from dictionary_encoded_values ---- 2 @@ -90,7 +90,7 @@ select * from dictionary_encoded_values statement ok CREATE EXTERNAL TABLE dictionary_encoded_arrow_test_readback( a varchar, -) +) STORED AS arrow LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/b=bar/'; @@ -185,6 +185,30 @@ select * from partitioned_insert_test_verify; 1 2 +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_hive(c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned' +PARTITIONED BY (a string, b string); + +query ITT +INSERT INTO partitioned_insert_test_hive VALUES (3,30,300); +---- +1 + +query ITT +SELECT * FROM partitioned_insert_test_hive order by a,b,c; +---- +1 10 100 +1 10 200 +1 20 100 +2 20 100 +1 20 200 +2 20 200 +3 30 300 + + statement ok CREATE EXTERNAL TABLE partitioned_insert_test_json(a string, b string) diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 92093ba13eba..0d98c41d0028 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -320,7 +320,7 @@ SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); 0 # The aggregate does not need to be computed because the input statistics are exact and -# the number of rows is less than or equal to the the "fetch+skip" value (LIMIT+OFFSET). +# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET). query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); ---- diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index b7cd1243cb0f..3cc52666d533 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -45,7 +45,7 @@ CREATE TABLE src_table ( query ITID COPY (SELECT * FROM src_table LIMIT 3) TO 'test_files/scratch/parquet/test_table/0.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -53,7 +53,7 @@ TO 'test_files/scratch/parquet/test_table/0.parquet' query ITID COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) TO 'test_files/scratch/parquet/test_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -128,7 +128,7 @@ SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] query ITID COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -281,7 +281,7 @@ LIMIT 10; query ITID COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 4c9254beef6b..33c9ff7c3eed 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -781,4 +781,4 @@ logical_plan EmptyRelation physical_plan EmptyExec statement ok -drop table t; +drop table t; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index 391a6739b060..594c52f12d75 100644 --- 
a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -25,7 +25,7 @@ set datafusion.execution.target_partitions = 4; statement ok COPY (VALUES (1, 2), (2, 5), (3, 2), (4, 5), (5, 0)) TO 'test_files/scratch/repartition/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; statement ok CREATE EXTERNAL TABLE parquet_table(column1 int, column2 int) diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 15fe670a454c..f9699a5fda8f 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -35,7 +35,7 @@ set datafusion.optimizer.repartition_file_min_size = 1; # Note filename 2.parquet to test sorting (on local file systems it is often listed before 1.parquet) statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; statement ok CREATE EXTERNAL TABLE parquet_table(column1 int) @@ -61,7 +61,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -86,7 +86,7 @@ set datafusion.optimizer.enable_round_robin_repartition = true; # create a second parquet file statement ok COPY (VALUES (100), (200)) TO 'test_files/scratch/repartition_scan/parquet_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ## Still expect to see the scan read the file as "4" groups with even sizes. One group should read ## parts of both files. @@ -102,7 +102,7 @@ SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --SortExec: expr=[column1@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=8192 ------FilterExec: column1@0 != 42 ---------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # Cleanup statement ok @@ -158,7 +158,7 @@ DROP TABLE parquet_table_with_order; # create a single csv file statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/csv_table/1.csv' -(FORMAT csv, 'csv.has_header' true); +STORED AS CSV WITH HEADER ROW; statement ok CREATE EXTERNAL TABLE csv_table(column1 int) @@ -202,7 +202,7 @@ DROP TABLE csv_table; # create a single json file statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/json_table/1.json' -(FORMAT json); +STORED AS JSON; statement ok CREATE EXTERNAL TABLE json_table (column1 int) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a77a2bf4059c..20c8b3d25fdd 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -2087,7 +2087,7 @@ select position('' in '') 1 -query error DataFusion error: Error during planning: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. +query error DataFusion error: Execution error: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. 
select position(1 in 1) diff --git a/datafusion/sqllogictest/test_files/schema_evolution.slt b/datafusion/sqllogictest/test_files/schema_evolution.slt index aee0e97edc1e..5572c4a5ffef 100644 --- a/datafusion/sqllogictest/test_files/schema_evolution.slt +++ b/datafusion/sqllogictest/test_files/schema_evolution.slt @@ -31,7 +31,7 @@ COPY ( SELECT column1 as a, column2 as b FROM ( VALUES ('foo', 1), ('foo', 2), ('foo', 3) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File2 has only b @@ -40,7 +40,7 @@ COPY ( SELECT column1 as b FROM ( VALUES (10) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File3 has a column from 'z' which does not appear in the table # but also values from a which do appear in the table @@ -49,7 +49,7 @@ COPY ( SELECT column1 as z, column2 as a FROM ( VALUES ('bar', 'foo'), ('blarg', 'foo') ) ) TO 'test_files/scratch/schema_evolution/parquet_table/3.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File4 has data for b and a (reversed) and d statement ok @@ -57,7 +57,7 @@ COPY ( SELECT column1 as b, column2 as a, column3 as c FROM ( VALUES (100, 'foo', 10.5), (200, 'foo', 12.6), (300, 'bzz', 13.7) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/4.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # The logical distribution of `a`, `b` and `c` in the files is like this: # diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 3d3e73e81637..ad4b0df1a546 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -550,9 +550,31 @@ select * from (select 1 a union all select 2) b order by a limit 1; 1 # select limit clause invalid -statement error DataFusion error: Error during planning: LIMIT must not be negative +statement error DataFusion error: Error during planning: LIMIT must be >= 0, '\-1' was provided\. select * from (select 1 a union all select 2) b order by a limit -1; +# select limit with basic arithmetic +query I +select * from (select 1 a union all select 2) b order by a limit 1+1; +---- +1 +2 + +# select limit with basic arithmetic +query I +select * from (values (1)) LIMIT 10*100; +---- +1 + +# More complex expressions in the limit is not supported yet. +# See issue: https://github.com/apache/arrow-datafusion/issues/9821 +statement error DataFusion error: Error during planning: Unsupported operator for LIMIT clause +select * from (values (1)) LIMIT 100/10; + +# More complex expressions in the limit is not supported yet. +statement error DataFusion error: Error during planning: Unexpected expression in LIMIT clause +select * from (values (1)) LIMIT cast(column1 as tinyint); + # select limit clause query I select * from (select 1 a union all select 2) b order by a limit null; @@ -1364,6 +1386,27 @@ AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true +# FilterExec can track equality of non-column expressions. +# plan below shouldn't have a SortExec because given column 'a' is ordered. +# 'CAST(ROUND(b) as INT)' is also ordered. After filter is applied. 
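+# (in other words: after FilterExec applies CAST(ROUND(b) AS INT) = a, the cast expression +# is equal to the ordered column 'a', so it inherits a's ordering and the plan below only +# needs a SortPreservingMergeExec rather than a SortExec)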
+query TT +EXPLAIN SELECT * +FROM annotated_data_finite2 +WHERE CAST(ROUND(b) as INT) = a +ORDER BY CAST(ROUND(b) as INT); +---- +logical_plan +Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC NULLS LAST +--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a +----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a] +physical_plan +SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1 +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + + statement ok drop table annotated_data_finite2; diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 936dedcc896e..2e0b699f6dd6 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -23,11 +23,12 @@ statement ok CREATE TABLE values( a INT, b FLOAT, - c VARCHAR + c VARCHAR, + n VARCHAR, ) AS VALUES - (1, 1.1, 'a'), - (2, 2.2, 'b'), - (3, 3.3, 'c') + (1, 1.1, 'a', NULL), + (2, 2.2, 'b', NULL), + (3, 3.3, 'c', NULL) ; # struct[i] @@ -50,6 +51,18 @@ select struct(1, 3.14, 'e'); ---- {c0: 1, c1: 3.14, c2: e} +# struct scalar function with named values +query ? +select struct(1 as "name0", 3.14 as name1, 'e', true as 'name3'); +---- +{name0: 1, name1: 3.14, c2: e, name3: true} + +# struct scalar function with mixed named and unnamed values +query ? +select struct(1, 3.14 as name1, 'e', true); +---- +{c0: 1, name1: 3.14, c2: e, c3: true} + # struct scalar function with columns #1 query ? select struct(a, b, c) from values; @@ -58,16 +71,112 @@ select struct(a, b, c) from values; {c0: 2, c1: 2.2, c2: b} {c0: 3, c1: 3.3, c2: c} +# struct scalar function with columns and scalars +query ? +select struct(a, 'foo') from values; +---- +{c0: 1, c1: foo} +{c0: 2, c1: foo} +{c0: 3, c1: foo} + + # explain struct scalar function with columns #1 query TT explain select struct(a, b, c) from values; ---- logical_plan -Projection: struct(values.a, values.b, values.c) +Projection: named_struct(Utf8("c0"), values.a, Utf8("c1"), values.b, Utf8("c2"), values.c) --TableScan: values projection=[a, b, c] physical_plan -ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)] +ProjectionExec: expr=[named_struct(c0, a@0, c1, b@1, c2, c@2) as named_struct(Utf8("c0"),values.a,Utf8("c1"),values.b,Utf8("c2"),values.c)] --MemoryExec: partitions=1, partition_sizes=[1] +# error on 0 arguments +query error DataFusion error: Error during planning: No function matches the given name and argument types 'named_struct\(\)'. You might need to add explicit type casts. 
+select named_struct(); + +# error on odd number of arguments #1 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead +select named_struct('a'); + +# error on odd number of arguments #2 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead +select named_struct(1); + +# error on odd number of arguments #3 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead +select named_struct(values.a) from values; + +# error on odd number of arguments #4 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 3 instead +select named_struct('a', 1, 'b'); + +# error on even argument not a string literal #1 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(1\) instead at position 0 +select named_struct(1, 'a'); + +# error on even argument not a string literal #2 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(0\) instead at position 2 +select named_struct('corret', 1, 0, 'wrong'); + +# error on even argument not a string literal #3 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.a instead at position 0 +select named_struct(values.a, 'a') from values; + +# error on even argument not a string literal #4 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.c instead at position 0 +select named_struct(values.c, 'c') from values; + +# named_struct with mixed scalar and array values #1 +query ? +select named_struct('scalar', 27, 'array', values.a, 'null', NULL) from values; +---- +{scalar: 27, array: 1, null: } +{scalar: 27, array: 2, null: } +{scalar: 27, array: 3, null: } + +# named_struct with mixed scalar and array values #2 +query ? +select named_struct('array', values.a, 'scalar', 27, 'null', NULL) from values; +---- +{array: 1, scalar: 27, null: } +{array: 2, scalar: 27, null: } +{array: 3, scalar: 27, null: } + +# named_struct with mixed scalar and array values #3 +query ? +select named_struct('null', NULL, 'array', values.a, 'scalar', 27) from values; +---- +{null: , array: 1, scalar: 27} +{null: , array: 2, scalar: 27} +{null: , array: 3, scalar: 27} + +# named_struct with mixed scalar and array values #4 +query ? +select named_struct('null_array', values.n, 'array', values.a, 'scalar', 27, 'null', NULL) from values; +---- +{null_array: , array: 1, scalar: 27, null: } +{null_array: , array: 2, scalar: 27, null: } +{null_array: , array: 3, scalar: 27, null: } + +# named_struct arrays only +query ? +select named_struct('field_a', a, 'field_b', b) from values; +---- +{field_a: 1, field_b: 1.1} +{field_a: 2, field_b: 2.2} +{field_a: 3, field_b: 3.3} + +# named_struct scalars only +query ? 
+select named_struct('field_a', 1, 'field_b', 2); +---- +{field_a: 1, field_b: 2} + statement ok drop table values; + +query T +select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3)); +---- +Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index f718bbf14cbc..f0e04b522a78 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2661,7 +2661,7 @@ PT123456S query T select to_char(arrow_cast(123456, 'Duration(Second)'), null); ---- -PT123456S +NULL query error DataFusion error: Execution error: Cast error: Format error SELECT to_char(timestamps, '%X%K') from formats; @@ -2672,14 +2672,22 @@ SELECT to_char('2000-02-03'::date, '%X%K'); query T SELECT to_char(timestamps, null) from formats; ---- -2024-01-01T06:00:00Z -2025-01-01T23:59:58Z +NULL +NULL query T SELECT to_char(null, '%d-%m-%Y'); ---- (empty) +query T +SELECT to_char(column1, column2) +FROM +(VALUES ('2024-01-01 06:00:00'::timestamp, null), ('2025-01-01 23:59:58'::timestamp, '%d:%m:%Y %H-%M-%S')); +---- +NULL +01:01:2025 23-59-58 + statement ok drop table formats; diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 7475dfc1e37b..cc79685c9429 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -36,7 +36,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.27.0" +substrait = "0.28.0" [dev-dependencies] tokio = { workspace = true } diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index e8698253edb5..6b81e33dfc37 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -27,6 +27,7 @@ use substrait::proto::Plan; use std::fs::OpenOptions; use std::io::{Read, Write}; +#[allow(clippy::suspicious_open_options)] pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let protobuf_out = serialize_bytes(sql, ctx).await; let mut file = OpenOptions::new().create(true).write(true).open(path)?; diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 5163c99bd5ac..7d324d074c9d 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -731,13 +731,13 @@ } }, "node_modules/body-parser": { - "version": "1.20.1", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.1.tgz", - "integrity": "sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==", + "version": "1.20.2", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", + "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", "dev": true, "dependencies": { "bytes": "3.1.2", - "content-type": "~1.0.4", + "content-type": "~1.0.5", "debug": "2.6.9", "depd": "2.0.0", "destroy": "1.2.0", @@ -745,7 +745,7 @@ "iconv-lite": "0.4.24", "on-finished": "2.4.1", 
"qs": "6.11.0", - "raw-body": "2.5.1", + "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" }, @@ -892,13 +892,19 @@ } }, "node_modules/call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "dependencies": { - "function-bind": "^1.1.1", - "get-intrinsic": "^1.0.2" + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1109,9 +1115,9 @@ } }, "node_modules/cookie": { - "version": "0.5.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", - "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", "dev": true, "engines": { "node": ">= 0.6" @@ -1204,6 +1210,23 @@ "node": ">= 10" } }, + "node_modules/define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "dependencies": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/define-lazy-prop": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", @@ -1323,6 +1346,27 @@ "node": ">=4" } }, + "node_modules/es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.2.4" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/es-module-lexer": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", @@ -1435,17 +1479,17 @@ } }, "node_modules/express": { - "version": "4.18.2", - "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz", - "integrity": "sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==", + "version": "4.19.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz", + "integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==", "dev": true, "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": 
"1.20.1", + "body-parser": "1.20.2", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.5.0", + "cookie": "0.6.0", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", @@ -1666,9 +1710,9 @@ } }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -1742,21 +1786,28 @@ } }, "node_modules/function-bind": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", - "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", - "dev": true + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } }, "node_modules/get-intrinsic": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", - "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "dependencies": { - "function-bind": "^1.1.1", - "has": "^1.0.3", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", "has-proto": "^1.0.1", - "has-symbols": "^1.0.3" + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1832,6 +1883,18 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.1.3" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/graceful-fs": { "version": "4.2.11", "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", @@ -1865,10 +1928,22 @@ "node": ">=8" } }, + "node_modules/has-property-descriptors": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", + "dev": true, + "dependencies": { + "es-define-property": "^1.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/has-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.1.tgz", - "integrity": "sha512-7qE+iP+O+bgF9clE5+UoBFzE65mlBiVj3tKCrlNQ0Ogwm0BjpT/gK4SlLYDMybDh5I3TCTKnPPa0oMG7JDYrhg==", + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": 
"sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", "dev": true, "engines": { "node": ">= 0.4" @@ -1889,6 +1964,18 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/hpack.js": { "version": "2.1.6", "resolved": "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz", @@ -2648,9 +2735,9 @@ } }, "node_modules/object-inspect": { - "version": "1.12.3", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.1", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz", + "integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==", "dev": true, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -2987,9 +3074,9 @@ } }, "node_modules/raw-body": { - "version": "2.5.1", - "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", - "integrity": "sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==", + "version": "2.5.2", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz", + "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", "dev": true, "dependencies": { "bytes": "3.1.2", @@ -3357,6 +3444,23 @@ "node": ">= 0.8.0" } }, + "node_modules/set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -3406,14 +3510,18 @@ } }, "node_modules/side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "dependencies": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -3945,9 +4053,9 @@ } }, "node_modules/webpack-dev-middleware": { - "version": "5.3.3", - "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.3.tgz", - "integrity": 
"sha512-hj5CYrY0bZLB+eTO+x/j67Pkrquiy7kWepMHmUMoPsmcUaeEnQJqFzHJOyxgWlq746/wUuA64p9ta34Kyb01pA==", + "version": "5.3.4", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", + "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", "dev": true, "dependencies": { "colorette": "^2.0.10", @@ -4868,13 +4976,13 @@ "dev": true }, "body-parser": { - "version": "1.20.1", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.1.tgz", - "integrity": "sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==", + "version": "1.20.2", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", + "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", "dev": true, "requires": { "bytes": "3.1.2", - "content-type": "~1.0.4", + "content-type": "~1.0.5", "debug": "2.6.9", "depd": "2.0.0", "destroy": "1.2.0", @@ -4882,7 +4990,7 @@ "iconv-lite": "0.4.24", "on-finished": "2.4.1", "qs": "6.11.0", - "raw-body": "2.5.1", + "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" }, @@ -4992,13 +5100,16 @@ } }, "call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "get-intrinsic": "^1.0.2" + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.1" } }, "caniuse-lite": { @@ -5144,9 +5255,9 @@ "dev": true }, "cookie": { - "version": "0.5.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", - "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", "dev": true }, "cookie-signature": { @@ -5220,6 +5331,17 @@ "execa": "^5.0.0" } }, + "define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "requires": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + } + }, "define-lazy-prop": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", @@ -5308,6 +5430,21 @@ "integrity": "sha512-ZtUjZO6l5mwTHvc1L9+1q5p/R3wTopcfqMW8r5t8SJSKqeVI/LtajORwRFEKpEFuekjD0VBjwu1HMxL4UalIRw==", "dev": true }, + "es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "requires": { + "get-intrinsic": "^1.2.4" + } + }, + "es-errors": { + "version": "1.3.0", + "resolved": 
"https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true + }, "es-module-lexer": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", @@ -5395,17 +5532,17 @@ } }, "express": { - "version": "4.18.2", - "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz", - "integrity": "sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==", + "version": "4.19.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz", + "integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==", "dev": true, "requires": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": "1.20.1", + "body-parser": "1.20.2", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.5.0", + "cookie": "0.6.0", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", @@ -5580,9 +5717,9 @@ } }, "follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "forwarded": { @@ -5626,21 +5763,22 @@ "optional": true }, "function-bind": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", - "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", "dev": true }, "get-intrinsic": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", - "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "has": "^1.0.3", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", "has-proto": "^1.0.1", - "has-symbols": "^1.0.3" + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" } }, "get-stream": { @@ -5692,6 +5830,15 @@ "slash": "^3.0.0" } }, + "gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "requires": { + "get-intrinsic": "^1.1.3" + } + }, "graceful-fs": { "version": "4.2.11", "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", @@ -5719,10 +5866,19 @@ "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", "dev": true }, + "has-property-descriptors": { + "version": "1.0.2", + "resolved": 
"https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", + "dev": true, + "requires": { + "es-define-property": "^1.0.0" + } + }, "has-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.1.tgz", - "integrity": "sha512-7qE+iP+O+bgF9clE5+UoBFzE65mlBiVj3tKCrlNQ0Ogwm0BjpT/gK4SlLYDMybDh5I3TCTKnPPa0oMG7JDYrhg==", + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", "dev": true }, "has-symbols": { @@ -5731,6 +5887,15 @@ "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", "dev": true }, + "hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "requires": { + "function-bind": "^1.1.2" + } + }, "hpack.js": { "version": "2.1.6", "resolved": "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz", @@ -6284,9 +6449,9 @@ } }, "object-inspect": { - "version": "1.12.3", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.1", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz", + "integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==", "dev": true }, "obuf": { @@ -6523,9 +6688,9 @@ "dev": true }, "raw-body": { - "version": "2.5.1", - "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", - "integrity": "sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==", + "version": "2.5.2", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz", + "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", "dev": true, "requires": { "bytes": "3.1.2", @@ -6813,6 +6978,20 @@ "send": "0.18.0" } }, + "set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "requires": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + } + }, "setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -6850,14 +7029,15 @@ "dev": true }, "side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "requires": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - 
"object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" } }, "signal-exit": { @@ -7247,9 +7427,9 @@ } }, "webpack-dev-middleware": { - "version": "5.3.3", - "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.3.tgz", - "integrity": "sha512-hj5CYrY0bZLB+eTO+x/j67Pkrquiy7kWepMHmUMoPsmcUaeEnQJqFzHJOyxgWlq746/wUuA64p9ta34Kyb01pA==", + "version": "5.3.4", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", + "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", "dev": true, "requires": { "colorette": "^2.0.10", diff --git a/dev/changelog/13.0.0.md b/dev/changelog/13.0.0.md index 0f35903e2600..14b42a052ef9 100644 --- a/dev/changelog/13.0.0.md +++ b/dev/changelog/13.0.0.md @@ -87,7 +87,7 @@ - Optimizer rule 'projection_push_down' failed due to unexpected error: Error during planning: Aggregate schema has wrong number of fields. Expected 3 got 8 [\#3704](https://github.com/apache/arrow-datafusion/issues/3704) - Optimizer regressions in `unwrap_cast_in_comparison` [\#3690](https://github.com/apache/arrow-datafusion/issues/3690) - Internal error when evaluating a predicate = "The type of Dictionary\(Int16, Utf8\) = Int64 of binary physical should be same" [\#3685](https://github.com/apache/arrow-datafusion/issues/3685) -- Specialized regexp_replace should early-abort when the the input arrays are empty [\#3647](https://github.com/apache/arrow-datafusion/issues/3647) +- Specialized regexp_replace should early-abort when the input arrays are empty [\#3647](https://github.com/apache/arrow-datafusion/issues/3647) - Internal error: Failed to coerce types Decimal128\(10, 2\) and Boolean in BETWEEN expression [\#3646](https://github.com/apache/arrow-datafusion/issues/3646) - Internal error: Failed to coerce types Decimal128\(10, 2\) and Boolean in BETWEEN expression [\#3645](https://github.com/apache/arrow-datafusion/issues/3645) - Type coercion error: The type of Boolean AND Decimal128\(10, 2\) of binary physical should be same [\#3644](https://github.com/apache/arrow-datafusion/issues/3644) diff --git a/dev/changelog/37.0.0.md b/dev/changelog/37.0.0.md new file mode 100644 index 000000000000..b1fcd5fdf008 --- /dev/null +++ b/dev/changelog/37.0.0.md @@ -0,0 +1,347 @@ + + +## [37.0.0](https://github.com/apache/arrow-datafusion/tree/37.0.0) (2024-03-28) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/36.0.0...37.0.0) + +**Breaking changes:** + +- refactor: Change `SchemaProvider::table` to return `Result` rather than `Option<..>` [#9307](https://github.com/apache/arrow-datafusion/pull/9307) (crepererum) +- feat: issue_9285: port builtin reg function into datafusion-function-\* (1/3 regexpmatch) [#9329](https://github.com/apache/arrow-datafusion/pull/9329) (Lordworms) +- Cache common plan properties to eliminate recursive calls in physical plan [#9346](https://github.com/apache/arrow-datafusion/pull/9346) (mustafasrepo) +- Consolidate `TreeNode` transform and rewrite APIs [#8891](https://github.com/apache/arrow-datafusion/pull/8891) (peter-toth) +- Extend argument types for udf `return_type_from_exprs` [#9522](https://github.com/apache/arrow-datafusion/pull/9522) (jayzhan211) +- Systematic Configuration in 'Create External Table' and 'Copy To' Options [#9382](https://github.com/apache/arrow-datafusion/pull/9382) (metesynnada) +- Move trim functions (btrim, 
ltrim, rtrim) to datafusion_functions, make expr_fn API consistent [#9730](https://github.com/apache/arrow-datafusion/pull/9730) (Omega359) + +**Performance related:** + +- perf: improve to_field performance [#9722](https://github.com/apache/arrow-datafusion/pull/9722) (haohuaijin) + +**Implemented enhancements:** + +- feat: support for defining ARRAY columns in `CREATE TABLE` [#9381](https://github.com/apache/arrow-datafusion/pull/9381) (jonahgao) +- feat: support `unnest` in FROM clause [#9355](https://github.com/apache/arrow-datafusion/pull/9355) (jonahgao) +- feat: support nvl2 function [#9364](https://github.com/apache/arrow-datafusion/pull/9364) (guojidan) +- feat: issue #9224 substitute tlide in table path [#9259](https://github.com/apache/arrow-datafusion/pull/9259) (Lordworms) +- feat: replace std Instant with wasm-compatible wrapper [#9189](https://github.com/apache/arrow-datafusion/pull/9189) (waynexia) +- feat: support `unnest` with additional columns [#9400](https://github.com/apache/arrow-datafusion/pull/9400) (jonahgao) +- feat: Support `EscapedStringLiteral`, update sqlparser to `0.44.0` [#9268](https://github.com/apache/arrow-datafusion/pull/9268) (JasonLi-cn) +- feat: add support for fixed list wildcard in type signature [#9312](https://github.com/apache/arrow-datafusion/pull/9312) (universalmind303) +- feat: Add projection to HashJoinExec. [#9236](https://github.com/apache/arrow-datafusion/pull/9236) (my-vegetable-has-exploded) +- feat: function name hints for UDFs [#9407](https://github.com/apache/arrow-datafusion/pull/9407) (SteveLauC) +- feat: Introduce convert Expr to SQL string API and basic feature [#9517](https://github.com/apache/arrow-datafusion/pull/9517) (backkem) +- feat: implement more expr_to_sql functionality [#9578](https://github.com/apache/arrow-datafusion/pull/9578) (devinjdangelo) +- feat: implement aggregation and subquery plans to SQL [#9606](https://github.com/apache/arrow-datafusion/pull/9606) (devinjdangelo) +- feat: track memory usage for recursive CTE, enable recursive CTEs by default [#9619](https://github.com/apache/arrow-datafusion/pull/9619) (jonahgao) +- feat: Between expr to sql string [#9803](https://github.com/apache/arrow-datafusion/pull/9803) (sebastian2296) +- feat: Expose `array_empty` and `list_empty` functions as alias of `empty` function [#9807](https://github.com/apache/arrow-datafusion/pull/9807) (erenavsarogullari) +- feat: Not expr to string [#9802](https://github.com/apache/arrow-datafusion/pull/9802) (sebastian2296) +- feat: pass SessionState not SessionConfig to FunctionFactory::create [#9837](https://github.com/apache/arrow-datafusion/pull/9837) (tshauck) + +**Fixed bugs:** + +- fix: use `JoinSet` to make spawned tasks cancel-safe [#9318](https://github.com/apache/arrow-datafusion/pull/9318) (DDtKey) +- fix: nvl function's return type [#9357](https://github.com/apache/arrow-datafusion/pull/9357) (guojidan) +- fix: panic in isnan() when no args are given [#9377](https://github.com/apache/arrow-datafusion/pull/9377) (SteveLauC) +- fix: using test data sample for catalog example [#9372](https://github.com/apache/arrow-datafusion/pull/9372) (korowa) +- fix: sort_batch function unsupported mixed types with list [#9410](https://github.com/apache/arrow-datafusion/pull/9410) (JasonLi-cn) +- fix: casting to ARRAY types failed [#9441](https://github.com/apache/arrow-datafusion/pull/9441) (jonahgao) +- fix: reading from partitioned `json` & `arrow` tables [#9431](https://github.com/apache/arrow-datafusion/pull/9431) (korowa) +- fix: 
coalesce function should return correct data type [#9459](https://github.com/apache/arrow-datafusion/pull/9459) (viirya) +- fix: `generate_series` and `range` panic on edge cases [#9503](https://github.com/apache/arrow-datafusion/pull/9503) (jonahgao) +- fix: `substr_index` not handling negative occurrence correctly [#9475](https://github.com/apache/arrow-datafusion/pull/9475) (jonahgao) +- fix: support two argument TRIM [#9521](https://github.com/apache/arrow-datafusion/pull/9521) (tshauck) +- fix: incorrect null handling in `range` and `generate_series` [#9574](https://github.com/apache/arrow-datafusion/pull/9574) (jonahgao) +- fix: recursive cte hangs on joins [#9687](https://github.com/apache/arrow-datafusion/pull/9687) (jonahgao) +- fix: parallel parquet can underflow when max_record_batch_rows < execution.batch_size [#9737](https://github.com/apache/arrow-datafusion/pull/9737) (devinjdangelo) +- fix: change placeholder errors from Internal to Plan [#9745](https://github.com/apache/arrow-datafusion/pull/9745) (erratic-pattern) +- fix: ensure mutual compatibility of the two input schemas from recursive CTEs [#9795](https://github.com/apache/arrow-datafusion/pull/9795) (jonahgao) + +**Documentation updates:** + +- docs: put flatten in top fn list [#9376](https://github.com/apache/arrow-datafusion/pull/9376) (SteveLauC) +- Update documentation so list_to_string alias to point to array_to_string [#9374](https://github.com/apache/arrow-datafusion/pull/9374) (monkwire) +- Uplift keys/dependencies to use more workspace inheritance [#9293](https://github.com/apache/arrow-datafusion/pull/9293) (Jefffrey) +- docs: update contributor guide (migration to sqllogictest is done) [#9408](https://github.com/apache/arrow-datafusion/pull/9408) (SteveLauC) +- Move the to_timestamp\* functions to datafusion-functions [#9388](https://github.com/apache/arrow-datafusion/pull/9388) (Omega359) +- NEW Logo [#9385](https://github.com/apache/arrow-datafusion/pull/9385) (pinarbayata) +- Minor: docs: rm duplicate words. [#9449](https://github.com/apache/arrow-datafusion/pull/9449) (my-vegetable-has-exploded) +- Update contributor guide with updated scalar function howto [#9438](https://github.com/apache/arrow-datafusion/pull/9438) (Omega359) +- docs: fix extraneous char in array functions table of contents [#9560](https://github.com/apache/arrow-datafusion/pull/9560) (tshauck) +- doc: Add missing doc link [#9631](https://github.com/apache/arrow-datafusion/pull/9631) (Weijun-H) +- chore: remove repetitive word `the the` --> `the` in docs / comments [#9673](https://github.com/apache/arrow-datafusion/pull/9673) (InventiveCoder) +- Update example-usage.md to remove reference to simd and rust nightly. 
[#9677](https://github.com/apache/arrow-datafusion/pull/9677) (Omega359) +- Minor: Improve documentation for `LogicalPlan::expressions` [#9698](https://github.com/apache/arrow-datafusion/pull/9698) (alamb) +- Add Minimum Supported Rust Version policy to docs [#9681](https://github.com/apache/arrow-datafusion/pull/9681) (alamb) +- doc: Updated known users list and usage dependency description [#9718](https://github.com/apache/arrow-datafusion/pull/9718) (comphead) + +**Merged pull requests:** + +- refactor: Change `SchemaProvider::table` to return `Result` rather than `Option<..>` [#9307](https://github.com/apache/arrow-datafusion/pull/9307) (crepererum) +- fix write_partitioned_parquet_results test case bug [#9360](https://github.com/apache/arrow-datafusion/pull/9360) (guojidan) +- fix: use `JoinSet` to make spawned tasks cancel-safe [#9318](https://github.com/apache/arrow-datafusion/pull/9318) (DDtKey) +- Update nix requirement from 0.27.1 to 0.28.0 [#9344](https://github.com/apache/arrow-datafusion/pull/9344) (dependabot[bot]) +- Replace usages of internal_err with exec_err where appropriate [#9241](https://github.com/apache/arrow-datafusion/pull/9241) (Omega359) +- feat : Support for deregistering user defined functions [#9239](https://github.com/apache/arrow-datafusion/pull/9239) (mobley-trent) +- fix: nvl function's return type [#9357](https://github.com/apache/arrow-datafusion/pull/9357) (guojidan) +- refactor: move acos() to function crate [#9297](https://github.com/apache/arrow-datafusion/pull/9297) (SteveLauC) +- docs: put flatten in top fn list [#9376](https://github.com/apache/arrow-datafusion/pull/9376) (SteveLauC) +- Update documentation so list_to_string alias to point to array_to_string [#9374](https://github.com/apache/arrow-datafusion/pull/9374) (monkwire) +- feat: issue_9285: port builtin reg function into datafusion-function-\* (1/3 regexpmatch) [#9329](https://github.com/apache/arrow-datafusion/pull/9329) (Lordworms) +- Add test to verify issue #9161 [#9265](https://github.com/apache/arrow-datafusion/pull/9265) (jonahgao) +- refactor: fix error macros hygiene (always import `DataFusionError`) [#9366](https://github.com/apache/arrow-datafusion/pull/9366) (crepererum) +- feat: support for defining ARRAY columns in `CREATE TABLE` [#9381](https://github.com/apache/arrow-datafusion/pull/9381) (jonahgao) +- fix: panic in isnan() when no args are given [#9377](https://github.com/apache/arrow-datafusion/pull/9377) (SteveLauC) +- feat: support `unnest` in FROM clause [#9355](https://github.com/apache/arrow-datafusion/pull/9355) (jonahgao) +- feat: support nvl2 function [#9364](https://github.com/apache/arrow-datafusion/pull/9364) (guojidan) +- refactor: move asin() to function crate [#9379](https://github.com/apache/arrow-datafusion/pull/9379) (SteveLauC) +- fix: using test data sample for catalog example [#9372](https://github.com/apache/arrow-datafusion/pull/9372) (korowa) +- delete tail space, fix `error: unused import: DataFusionError` [#9386](https://github.com/apache/arrow-datafusion/pull/9386) (Tangruilin) +- Run cargo-fmt on `datafusion-functions/core` [#9367](https://github.com/apache/arrow-datafusion/pull/9367) (alamb) +- Cache common plan properties to eliminate recursive calls in physical plan [#9346](https://github.com/apache/arrow-datafusion/pull/9346) (mustafasrepo) +- Run cargo-fmt on all of `datafusion-functions` [#9390](https://github.com/apache/arrow-datafusion/pull/9390) (alamb) +- feat: issue #9224 substitute tlide in table path 
[#9259](https://github.com/apache/arrow-datafusion/pull/9259) (Lordworms) +- port range function and change gen_series logic [#9352](https://github.com/apache/arrow-datafusion/pull/9352) (Lordworms) +- [MINOR]: Generate physical plan, instead of logical plan in the bench test [#9383](https://github.com/apache/arrow-datafusion/pull/9383) (mustafasrepo) +- Add `to_date` function [#9019](https://github.com/apache/arrow-datafusion/pull/9019) (Tangruilin) +- Minor: clarify performance in docs for `ScalarUDF`, `ScalarUDAF` and `ScalarUDWF` [#9384](https://github.com/apache/arrow-datafusion/pull/9384) (alamb) +- feat: replace std Instant with wasm-compatible wrapper [#9189](https://github.com/apache/arrow-datafusion/pull/9189) (waynexia) +- Uplift keys/dependencies to use more workspace inheritance [#9293](https://github.com/apache/arrow-datafusion/pull/9293) (Jefffrey) +- Improve documentation for ExecutionPlanProperties, use consistent field name [#9389](https://github.com/apache/arrow-datafusion/pull/9389) (alamb) +- Doc: Workaround for Running cargo test locally without signficant memory [#9402](https://github.com/apache/arrow-datafusion/pull/9402) (devinjdangelo) +- feat: support `unnest` with additional columns [#9400](https://github.com/apache/arrow-datafusion/pull/9400) (jonahgao) +- Minor: improve the display name of `unnest` expressions [#9412](https://github.com/apache/arrow-datafusion/pull/9412) (jonahgao) +- Minor: Move function signature check to planning stage [#9401](https://github.com/apache/arrow-datafusion/pull/9401) (2010YOUY01) +- chore(deps): update substrait requirement from 0.24.0 to 0.25.1 [#9406](https://github.com/apache/arrow-datafusion/pull/9406) (dependabot[bot]) +- docs: update contributor guide (migration to sqllogictest is done) [#9408](https://github.com/apache/arrow-datafusion/pull/9408) (SteveLauC) +- Move the to_timestamp\* functions to datafusion-functions [#9388](https://github.com/apache/arrow-datafusion/pull/9388) (Omega359) +- Minor: Support LargeList List Range indexing and fix large list handling in ConstEvaluator [#9393](https://github.com/apache/arrow-datafusion/pull/9393) (jayzhan211) +- NEW Logo [#9385](https://github.com/apache/arrow-datafusion/pull/9385) (pinarbayata) +- Handle serde for ScalarUDF [#9395](https://github.com/apache/arrow-datafusion/pull/9395) (yyy1000) +- Minior: Add tests with `sqrt` with negative argument [#9426](https://github.com/apache/arrow-datafusion/pull/9426) (caicancai) +- Move SpawnedTask from datafusion_physical_plan to new `datafusion_common_runtime` crate [#9414](https://github.com/apache/arrow-datafusion/pull/9414) (mustafasrepo) +- Re-export datafusion-functions-array [#9433](https://github.com/apache/arrow-datafusion/pull/9433) (andygrove) +- Minor: Support LargeList for ListIndex [#9424](https://github.com/apache/arrow-datafusion/pull/9424) (PsiACE) +- move ArrayDims, ArrayNdims and Cardinality to datafusion-function-crate [#9425](https://github.com/apache/arrow-datafusion/pull/9425) (Weijun-H) +- refactor: make instr() an alias of strpos() [#9396](https://github.com/apache/arrow-datafusion/pull/9396) (SteveLauC) +- Add test case for invalid tz in timestamp literal [#9429](https://github.com/apache/arrow-datafusion/pull/9429) (MohamedAbdeen21) +- Minor: simplify call [#9434](https://github.com/apache/arrow-datafusion/pull/9434) (alamb) +- Support IGNORE NULLS for LEAD window function [#9419](https://github.com/apache/arrow-datafusion/pull/9419) (comphead) +- fix sqllogicaltest result 
[#9444](https://github.com/apache/arrow-datafusion/pull/9444) (jackwener) +- Minor: docs: rm duplicate words. [#9449](https://github.com/apache/arrow-datafusion/pull/9449) (my-vegetable-has-exploded) +- minor: fix cargo clippy some warning [#9442](https://github.com/apache/arrow-datafusion/pull/9442) (jackwener) +- port regexp_like function and port related tests [#9397](https://github.com/apache/arrow-datafusion/pull/9397) (Lordworms) +- fix: sort_batch function unsupported mixed types with list [#9410](https://github.com/apache/arrow-datafusion/pull/9410) (JasonLi-cn) +- refactor: add `join_unwind` to `SpawnedTask` [#9422](https://github.com/apache/arrow-datafusion/pull/9422) (DDtKey) +- Ignore null LEAD support for small batch sizes. [#9445](https://github.com/apache/arrow-datafusion/pull/9445) (mustafasrepo) +- fix: casting to ARRAY types failed [#9441](https://github.com/apache/arrow-datafusion/pull/9441) (jonahgao) +- fix: reading from partitioned `json` & `arrow` tables [#9431](https://github.com/apache/arrow-datafusion/pull/9431) (korowa) +- feat: Support `EscapedStringLiteral`, update sqlparser to `0.44.0` [#9268](https://github.com/apache/arrow-datafusion/pull/9268) (JasonLi-cn) +- Minor: fix LEAD test description [#9451](https://github.com/apache/arrow-datafusion/pull/9451) (comphead) +- Consolidate `TreeNode` transform and rewrite APIs [#8891](https://github.com/apache/arrow-datafusion/pull/8891) (peter-toth) +- Support `Date32` arguments for `generate_series` [#9420](https://github.com/apache/arrow-datafusion/pull/9420) (Lordworms) +- Minor: change doc for range [#9455](https://github.com/apache/arrow-datafusion/pull/9455) (Lordworms) +- doc: add missing function index in scalar_expression.md [#9462](https://github.com/apache/arrow-datafusion/pull/9462) (Weijun-H) +- build: Update bigdecimal version in `Cargo.toml` [#9471](https://github.com/apache/arrow-datafusion/pull/9471) (comphead) +- chore(deps): update base64 requirement from 0.21 to 0.22 [#9446](https://github.com/apache/arrow-datafusion/pull/9446) (dependabot[bot]) +- Port regexp_replace functions and related tests [#9454](https://github.com/apache/arrow-datafusion/pull/9454) (Lordworms) +- Update contributor guide with updated scalar function howto [#9438](https://github.com/apache/arrow-datafusion/pull/9438) (Omega359) +- feat: add support for fixed list wildcard in type signature [#9312](https://github.com/apache/arrow-datafusion/pull/9312) (universalmind303) +- Add a `ScalarUDFImpl::simplfy()` API, move `SimplifyInfo` et al to datafusion_expr [#9304](https://github.com/apache/arrow-datafusion/pull/9304) (jayzhan211) +- Implement IGNORE NULLS for FIRST_VALUE [#9411](https://github.com/apache/arrow-datafusion/pull/9411) (huaxingao) +- Add plugable handler for `CREATE FUNCTION` [#9333](https://github.com/apache/arrow-datafusion/pull/9333) (milenkovicm) +- Enable configurable display of partition sizes in the explain statement [#9474](https://github.com/apache/arrow-datafusion/pull/9474) (jayzhan211) +- Reduce casts for LEAD/LAG [#9468](https://github.com/apache/arrow-datafusion/pull/9468) (comphead) +- [CI build] fix chrono suggestions [#9486](https://github.com/apache/arrow-datafusion/pull/9486) (comphead) +- Make regex dependency optional in datafusion-functions, add CI checks for function packages [#9473](https://github.com/apache/arrow-datafusion/pull/9473) (alamb) +- fix: coalesce function should return correct data type [#9459](https://github.com/apache/arrow-datafusion/pull/9459) (viirya) +- LEAD/LAG calculate 
default value once [#9485](https://github.com/apache/arrow-datafusion/pull/9485) (comphead) +- chore: simplify the return type of `validate_data_types()` [#9491](https://github.com/apache/arrow-datafusion/pull/9491) (waynexia) +- minor: use arrow-rs casting from Float to Timestamp [#9500](https://github.com/apache/arrow-datafusion/pull/9500) (comphead) +- chore(deps): update substrait requirement from 0.25.1 to 0.27.0 [#9502](https://github.com/apache/arrow-datafusion/pull/9502) (dependabot[bot]) +- fix: `generate_series` and `range` panic on edge cases [#9503](https://github.com/apache/arrow-datafusion/pull/9503) (jonahgao) +- Fix undeterministic behaviour of schema nullability of lag window query [#9508](https://github.com/apache/arrow-datafusion/pull/9508) (mustafasrepo) +- Add `to_unixtime` function [#9077](https://github.com/apache/arrow-datafusion/pull/9077) (Tangruilin) +- Minor: fixed transformed state in UDF Simplify [#9484](https://github.com/apache/arrow-datafusion/pull/9484) (alamb) +- test: port strpos test in physical_expr/src/functions to sqllogictest [#9439](https://github.com/apache/arrow-datafusion/pull/9439) (SteveLauC) +- Port ArrayHas family to `functions-array` [#9496](https://github.com/apache/arrow-datafusion/pull/9496) (jayzhan211) +- port array_empty and array_length to datafusion-function-array crate [#9510](https://github.com/apache/arrow-datafusion/pull/9510) (Weijun-H) +- fix: `substr_index` not handling negative occurrence correctly [#9475](https://github.com/apache/arrow-datafusion/pull/9475) (jonahgao) +- [minor] extract collect file statistics method and add doc [#9490](https://github.com/apache/arrow-datafusion/pull/9490) (Ted-Jiang) +- test: sqllogictests for multiple tables join [#9480](https://github.com/apache/arrow-datafusion/pull/9480) (korowa) +- Add support for ignore nulls for LEAD, LAG in WindowAggExec [#9498](https://github.com/apache/arrow-datafusion/pull/9498) (Lordworms) +- Minior: Improve log expr description [#9516](https://github.com/apache/arrow-datafusion/pull/9516) (caicancai) +- port flatten to datafusion-function-array [#9523](https://github.com/apache/arrow-datafusion/pull/9523) (Weijun-H) +- feat: Add projection to HashJoinExec. 
[#9236](https://github.com/apache/arrow-datafusion/pull/9236) (my-vegetable-has-exploded) +- Add example for `FunctionFactory` [#9482](https://github.com/apache/arrow-datafusion/pull/9482) (milenkovicm) +- Move date_part, date_trunc, date_bin functions to datafusion-functions [#9435](https://github.com/apache/arrow-datafusion/pull/9435) (Omega359) +- fix: support two argument TRIM [#9521](https://github.com/apache/arrow-datafusion/pull/9521) (tshauck) +- Remove physical expr of ListIndex and ListRange, convert to `array_element` and `array_slice` functions [#9492](https://github.com/apache/arrow-datafusion/pull/9492) (jayzhan211) +- feat: function name hints for UDFs [#9407](https://github.com/apache/arrow-datafusion/pull/9407) (SteveLauC) +- Minor: Improve documentation for registering `AnalyzerRule` [#9520](https://github.com/apache/arrow-datafusion/pull/9520) (alamb) +- Extend argument types for udf `return_type_from_exprs` [#9522](https://github.com/apache/arrow-datafusion/pull/9522) (jayzhan211) +- move make_array array_append array_prepend array_concat function to datafusion-functions-array crate [#9504](https://github.com/apache/arrow-datafusion/pull/9504) (guojidan) +- Port `StringToArray` to `function-arrays` subcrate [#9543](https://github.com/apache/arrow-datafusion/pull/9543) (erenavsarogullari) +- Minor: remove `..` pattern matching in sql planner [#9531](https://github.com/apache/arrow-datafusion/pull/9531) (alamb) +- Minor: Fix document Interval syntax [#9542](https://github.com/apache/arrow-datafusion/pull/9542) (yyy1000) +- Port `struct` to datafusion-functions [#9546](https://github.com/apache/arrow-datafusion/pull/9546) (yyy1000) +- UDAF and UDWF support aliases [#9489](https://github.com/apache/arrow-datafusion/pull/9489) (lewiszlw) +- docs: fix extraneous char in array functions table of contents [#9560](https://github.com/apache/arrow-datafusion/pull/9560) (tshauck) +- [MINOR]: Fix undeterministic test [#9559](https://github.com/apache/arrow-datafusion/pull/9559) (mustafasrepo) +- Port `arrow_typeof` to datafusion-function [#9524](https://github.com/apache/arrow-datafusion/pull/9524) (yyy1000) +- feat: Introduce convert Expr to SQL string API and basic feature [#9517](https://github.com/apache/arrow-datafusion/pull/9517) (backkem) +- Port `ArraySort` to `function-arrays` subcrate [#9551](https://github.com/apache/arrow-datafusion/pull/9551) (erenavsarogullari) +- refactor: unify some plan optimization in CommonSubexprEliminate [#9556](https://github.com/apache/arrow-datafusion/pull/9556) (jackwener) +- Port `ArrayDistinct` to `functions-array` subcrate [#9549](https://github.com/apache/arrow-datafusion/pull/9549) (erenavsarogullari) +- Minor: add a sql_planner benchmarks to reflecte select many field on a huge table [#9536](https://github.com/apache/arrow-datafusion/pull/9536) (haohuaijin) +- Support IGNORE NULLS for FIRST/LAST window function [#9470](https://github.com/apache/arrow-datafusion/pull/9470) (huaxingao) +- Systematic Configuration in 'Create External Table' and 'Copy To' Options [#9382](https://github.com/apache/arrow-datafusion/pull/9382) (metesynnada) +- fix: incorrect null handling in `range` and `generate_series` [#9574](https://github.com/apache/arrow-datafusion/pull/9574) (jonahgao) +- Update README.md [#9572](https://github.com/apache/arrow-datafusion/pull/9572) (Abdullahsab3) +- Port tan, tanh to datafusion-functions [#9535](https://github.com/apache/arrow-datafusion/pull/9535) (ongchi) +- feat(9493): provide access to FileMetaData for files 
written with ParquetSink [#9548](https://github.com/apache/arrow-datafusion/pull/9548) (wiedld) +- Export datafusion-functions UDFs publically [#9585](https://github.com/apache/arrow-datafusion/pull/9585) (alamb) +- Update the comment and Add a check [#9571](https://github.com/apache/arrow-datafusion/pull/9571) (colommar) +- Port `ArrayRepeat` to `functions-array` subcrate [#9568](https://github.com/apache/arrow-datafusion/pull/9568) (erenavsarogullari) +- Fix ApproxPercentileAccumulator on zero values [#9582](https://github.com/apache/arrow-datafusion/pull/9582) (Dandandan) +- Add `FunctionRewrite` API, Move Array specific rewrites to `datafusion_functions_array` [#9583](https://github.com/apache/arrow-datafusion/pull/9583) (alamb) +- Move from_unixtime, now, current_date, current_time functions to datafusion-functions [#9537](https://github.com/apache/arrow-datafusion/pull/9537) (Omega359) +- minor: update Debug trait impl for WindowsFrame [#9587](https://github.com/apache/arrow-datafusion/pull/9587) (comphead) +- Initial support LogicalPlan to SQL String [#9596](https://github.com/apache/arrow-datafusion/pull/9596) (backkem) +- refactor: use a common macro to define math UDFs [#9598](https://github.com/apache/arrow-datafusion/pull/9598) (jonahgao) +- Move all `crypto` related functions to `datafusion-functions` [#9590](https://github.com/apache/arrow-datafusion/pull/9590) (Lordworms) +- Remove physical expr of NamedStructField, convert to `get_field` function call [#9563](https://github.com/apache/arrow-datafusion/pull/9563) (yyy1000) +- Add `/benchmark` github command to comparison benchmark between base and pr commit [#9461](https://github.com/apache/arrow-datafusion/pull/9461) (gruuya) +- support unnest as subexpression [#9592](https://github.com/apache/arrow-datafusion/pull/9592) (YjyJeff) +- feat: implement more expr_to_sql functionality [#9578](https://github.com/apache/arrow-datafusion/pull/9578) (devinjdangelo) +- Port `ArrayResize` to `functions-array` subcrate [#9570](https://github.com/apache/arrow-datafusion/pull/9570) (erenavsarogullari) +- Move make_date, to_char to datafusion-functions [#9601](https://github.com/apache/arrow-datafusion/pull/9601) (Omega359) +- Fix to_timestamp benchmark [#9608](https://github.com/apache/arrow-datafusion/pull/9608) (Omega359) +- feat: implement aggregation and subquery plans to SQL [#9606](https://github.com/apache/arrow-datafusion/pull/9606) (devinjdangelo) +- Port ArrayElem/Slice/PopFront/Back into `functions-array` [#9615](https://github.com/apache/arrow-datafusion/pull/9615) (jayzhan211) +- Minor: Remove datafusion-functions-array dependency from datafusion-optimizer [#9621](https://github.com/apache/arrow-datafusion/pull/9621) (alamb) +- Enable TTY during bench data generation [#9626](https://github.com/apache/arrow-datafusion/pull/9626) (gruuya) +- Remove constant expressions from SortExprs in the SortExec [#9618](https://github.com/apache/arrow-datafusion/pull/9618) (mustafasrepo) +- Try fixing missing results name in the benchmark step [#9632](https://github.com/apache/arrow-datafusion/pull/9632) (gruuya) +- feat: track memory usage for recursive CTE, enable recursive CTEs by default [#9619](https://github.com/apache/arrow-datafusion/pull/9619) (jonahgao) +- doc: Add missing doc link [#9631](https://github.com/apache/arrow-datafusion/pull/9631) (Weijun-H) +- Add explicit move of PR bench results if they were placed in HEAD dir [#9636](https://github.com/apache/arrow-datafusion/pull/9636) (gruuya) +- Add `array_reverse` function to 
datafusion-function-\* crate [#9630](https://github.com/apache/arrow-datafusion/pull/9630) (Weijun-H) +- Move parts of `InListSimplifier` simplify rules to `Simplifier` [#9628](https://github.com/apache/arrow-datafusion/pull/9628) (jayzhan211) +- Port Array Union and Intersect to `functions-array` [#9629](https://github.com/apache/arrow-datafusion/pull/9629) (jayzhan211) +- Port `ArrayPosition` and `ArrayPositions` to `functions-array` subcrate [#9617](https://github.com/apache/arrow-datafusion/pull/9617) (erenavsarogullari) +- Optimize make_date (#9089) [#9600](https://github.com/apache/arrow-datafusion/pull/9600) (vojtechtoman) +- Support AT TIME ZONE clause [#9647](https://github.com/apache/arrow-datafusion/pull/9647) (tinfoil-knight) +- Window Linear Mode use smaller buffers [#9597](https://github.com/apache/arrow-datafusion/pull/9597) (mustafasrepo) +- Port `ArrayExcept` to `functions-array` subcrate [#9634](https://github.com/apache/arrow-datafusion/pull/9634) (erenavsarogullari) +- chore: improve array expression doc and clean up array_expression.rs [#9650](https://github.com/apache/arrow-datafusion/pull/9650) (Weijun-H) +- Minor: remove clone in `exprlist_to_fields` [#9657](https://github.com/apache/arrow-datafusion/pull/9657) (jayzhan211) +- Port `ArrayRemove`, `ArrayRemoveN`, `ArrayRemoveAll` to `functions-array` subcrate [#9656](https://github.com/apache/arrow-datafusion/pull/9656) (erenavsarogullari) +- Minor: Remove redundant dependencies from `datafusion-functions/Cargo.toml` [#9622](https://github.com/apache/arrow-datafusion/pull/9622) (alamb) +- Support IGNORE NULLS for NTH_VALUE window function [#9625](https://github.com/apache/arrow-datafusion/pull/9625) (huaxingao) +- Improve Robustness of Unparser Testing and Implementation [#9623](https://github.com/apache/arrow-datafusion/pull/9623) (devinjdangelo) +- Adding Constant Check for FilterExec [#9649](https://github.com/apache/arrow-datafusion/pull/9649) (Lordworms) +- chore(deps-dev): bump follow-redirects from 1.15.4 to 1.15.6 in /datafusion/wasmtest/datafusion-wasm-app [#9609](https://github.com/apache/arrow-datafusion/pull/9609) (dependabot[bot]) +- move array_replace family functions to datafusion-function-array crate [#9651](https://github.com/apache/arrow-datafusion/pull/9651) (Weijun-H) +- chore: remove repetitive word `the the` --> `the` in docs / comments [#9673](https://github.com/apache/arrow-datafusion/pull/9673) (InventiveCoder) +- Update example-usage.md to remove reference to simd and rust nightly. 
[#9677](https://github.com/apache/arrow-datafusion/pull/9677) (Omega359) +- [MINOR]: Remove some `.unwrap`s from nth_value.rs file [#9674](https://github.com/apache/arrow-datafusion/pull/9674) (mustafasrepo) +- minor: Remove deprecated methods [#9627](https://github.com/apache/arrow-datafusion/pull/9627) (comphead) +- Migrate `arrow_cast` to a UDF [#9610](https://github.com/apache/arrow-datafusion/pull/9610) (alamb) +- parquet: Add row*groups_matched*{statistics,bloom_filter} statistics [#9640](https://github.com/apache/arrow-datafusion/pull/9640) (progval) +- Make COPY TO align with CREATE EXTERNAL TABLE [#9604](https://github.com/apache/arrow-datafusion/pull/9604) (metesynnada) +- Support "A column is known to be entirely NULL" in `PruningPredicate` [#9223](https://github.com/apache/arrow-datafusion/pull/9223) (appletreeisyellow) +- Suppress self update for windows CI runner [#9661](https://github.com/apache/arrow-datafusion/pull/9661) (jayzhan211) +- add schema to SQL ast builder [#9624](https://github.com/apache/arrow-datafusion/pull/9624) (sardination) +- core/tests/parquet/row_group_pruning.rs: Add tests for strings [#9642](https://github.com/apache/arrow-datafusion/pull/9642) (progval) +- Fix incorrect results with multiple `COUNT(DISTINCT..)` aggregates on dictionaries [#9679](https://github.com/apache/arrow-datafusion/pull/9679) (alamb) +- parquet: Add support for Bloom filters on binary columns [#9644](https://github.com/apache/arrow-datafusion/pull/9644) (progval) +- Update Arrow/Parquet to `51.0.0`, tonic to `0.11` [#9613](https://github.com/apache/arrow-datafusion/pull/9613) (tustvold) +- Move inlist rule to expr_simplifier [#9692](https://github.com/apache/arrow-datafusion/pull/9692) (jayzhan211) +- Support Serde for ScalarUDF in Physical Expressions [#9436](https://github.com/apache/arrow-datafusion/pull/9436) (yyy1000) +- Support Union types in `ScalarValue` [#9683](https://github.com/apache/arrow-datafusion/pull/9683) (avantgardnerio) +- parquet: Add support for row group pruning on FixedSizeBinary [#9646](https://github.com/apache/arrow-datafusion/pull/9646) (progval) +- Minor: Improve documentation for `LogicalPlan::expressions` [#9698](https://github.com/apache/arrow-datafusion/pull/9698) (alamb) +- Make builtin window function output datatype to be derived from schema [#9686](https://github.com/apache/arrow-datafusion/pull/9686) (comphead) +- refactor: Extract `array_to_string` and `string_to_array` from `functions-array` subcrate' s `kernels` and `udf` containers [#9704](https://github.com/apache/arrow-datafusion/pull/9704) (erenavsarogullari) +- Add Minimum Supported Rust Version policy to docs [#9681](https://github.com/apache/arrow-datafusion/pull/9681) (alamb) +- doc: Add DataFusion profiling documentation for MacOS [#9711](https://github.com/apache/arrow-datafusion/pull/9711) (comphead) +- Minor: add ticket reference to commented out test [#9715](https://github.com/apache/arrow-datafusion/pull/9715) (alamb) +- Minor: Rename path from `common_runtime` to `common-runtime` [#9717](https://github.com/apache/arrow-datafusion/pull/9717) (alamb) +- Use object_store:BufWriter to replace put_multipart [#9648](https://github.com/apache/arrow-datafusion/pull/9648) (yyy1000) +- Fix COPY TO failing on passing format options through CLI [#9709](https://github.com/apache/arrow-datafusion/pull/9709) (tinfoil-knight) +- fix: recursive cte hangs on joins [#9687](https://github.com/apache/arrow-datafusion/pull/9687) (jonahgao) +- Move `starts_with`, `to_hex`,` trim`, `upper` to 
datafusion-functions (and add string_expressions) [#9541](https://github.com/apache/arrow-datafusion/pull/9541) (Tangruilin) +- Support for `extract(x from time)` / `date_part` from time types [#8693](https://github.com/apache/arrow-datafusion/pull/8693) (Jefffrey) +- doc: Updated known users list and usage dependency description [#9718](https://github.com/apache/arrow-datafusion/pull/9718) (comphead) +- Minor: improve documentation for `CommonSubexprEliminate` [#9700](https://github.com/apache/arrow-datafusion/pull/9700) (alamb) +- build: modify code to comply with latest clippy requirement [#9725](https://github.com/apache/arrow-datafusion/pull/9725) (comphead) +- Minor: return internal error rather than panic on unexpected error in COUNT DISTINCT [#9712](https://github.com/apache/arrow-datafusion/pull/9712) (alamb) +- fix(9678): short circuiting prevented population of visited stack, for common subexpr elimination optimization [#9685](https://github.com/apache/arrow-datafusion/pull/9685) (wiedld) +- perf: improve to_field performance [#9722](https://github.com/apache/arrow-datafusion/pull/9722) (haohuaijin) +- Minor: Run ScalarValue size test on aarch again [#9728](https://github.com/apache/arrow-datafusion/pull/9728) (alamb) +- Move trim functions (btrim, ltrim, rtrim) to datafusion_functions, make expr_fn API consistent [#9730](https://github.com/apache/arrow-datafusion/pull/9730) (Omega359) +- make format prefix optional for format options in COPY [#9723](https://github.com/apache/arrow-datafusion/pull/9723) (tinfoil-knight) +- refactor: Extract `range` and `gen_series` functions from `functions-array` subcrate' s `kernels` and `udf` containers [#9720](https://github.com/apache/arrow-datafusion/pull/9720) (erenavsarogullari) +- Move ascii function to datafusion_functions [#9740](https://github.com/apache/arrow-datafusion/pull/9740) (PsiACE) +- adding expr to string for IsNotNull IsTrue IsFalse and IsUnkown [#9739](https://github.com/apache/arrow-datafusion/pull/9739) (Lordworms) +- fix: parallel parquet can underflow when max_record_batch_rows < execution.batch_size [#9737](https://github.com/apache/arrow-datafusion/pull/9737) (devinjdangelo) +- support format in options of COPY command [#9744](https://github.com/apache/arrow-datafusion/pull/9744) (tinfoil-knight) +- Move lower, octet_length to datafusion-functions [#9747](https://github.com/apache/arrow-datafusion/pull/9747) (Omega359) +- Fixed missing trim() in rust api [#9749](https://github.com/apache/arrow-datafusion/pull/9749) (Omega359) +- refactor: Extract `array_length`, `array_reverse` and `array_sort` functions from `functions-array` subcrate' s `kernels` and `udf` containers [#9751](https://github.com/apache/arrow-datafusion/pull/9751) (erenavsarogullari) +- refactor: Extract `array_empty` and `array_repeat` functions from `functions-array` subcrate' s `kernels` and `udf` containers [#9762](https://github.com/apache/arrow-datafusion/pull/9762) (erenavsarogullari) +- Minor: remove an outdated TODO in `TypeCoercion` [#9752](https://github.com/apache/arrow-datafusion/pull/9752) (jonahgao) +- refactor: Extract `array_resize` and `cardinality` functions from `functions-array` subcrate' s `kernels` and `udf` containers [#9766](https://github.com/apache/arrow-datafusion/pull/9766) (erenavsarogullari) +- fix: change placeholder errors from Internal to Plan [#9745](https://github.com/apache/arrow-datafusion/pull/9745) (erratic-pattern) +- Move levenshtein, uuid, overlay to datafusion-functions 
[#9760](https://github.com/apache/arrow-datafusion/pull/9760) (Omega359) +- improve null handling for to_char [#9689](https://github.com/apache/arrow-datafusion/pull/9689) (tinfoil-knight) +- Add Expr->String for ScalarFunction and InList [#9759](https://github.com/apache/arrow-datafusion/pull/9759) (yyy1000) +- Move repeat, replace, split_part to datafusion_functions [#9784](https://github.com/apache/arrow-datafusion/pull/9784) (Omega359) +- refactor: Extract `array_dims`, `array_ndims` and `flatten` functions from `functions-array` subcrate' s `kernels` and `udf` containers [#9786](https://github.com/apache/arrow-datafusion/pull/9786) (erenavsarogullari) +- Minor: Improve documentation about `ColumnarValues::values_to_array` [#9774](https://github.com/apache/arrow-datafusion/pull/9774) (alamb) +- Fix panic in `struct` function with mixed scalar/array arguments [#9775](https://github.com/apache/arrow-datafusion/pull/9775) (alamb) +- refactor: Apply minor refactorings to `functions-array` crate [#9788](https://github.com/apache/arrow-datafusion/pull/9788) (erenavsarogullari) +- Move bit_length and chr functions to datafusion_functions [#9782](https://github.com/apache/arrow-datafusion/pull/9782) (PsiACE) +- Support tencent cloud COS storage in `datafusion-cli` [#9734](https://github.com/apache/arrow-datafusion/pull/9734) (harveyyue) +- Make it easier to register configuration extension ... [#9781](https://github.com/apache/arrow-datafusion/pull/9781) (milenkovicm) +- Expr to Sql : Case [#9798](https://github.com/apache/arrow-datafusion/pull/9798) (yyy1000) +- feat: Between expr to sql string [#9803](https://github.com/apache/arrow-datafusion/pull/9803) (sebastian2296) +- feat: Expose `array_empty` and `list_empty` functions as alias of `empty` function [#9807](https://github.com/apache/arrow-datafusion/pull/9807) (erenavsarogullari) +- Support Expr `Like` to sql [#9805](https://github.com/apache/arrow-datafusion/pull/9805) (Weijun-H) +- feat: Not expr to string [#9802](https://github.com/apache/arrow-datafusion/pull/9802) (sebastian2296) +- [Minor]: Move some repetitive codes to functions(proto) [#9811](https://github.com/apache/arrow-datafusion/pull/9811) (mustafasrepo) +- Implement IGNORE NULLS for LAST_VALUE [#9801](https://github.com/apache/arrow-datafusion/pull/9801) (huaxingao) +- [MINOR]: Move some repetitive codes to functions [#9810](https://github.com/apache/arrow-datafusion/pull/9810) (mustafasrepo) +- fix: ensure mutual compatibility of the two input schemas from recursive CTEs [#9795](https://github.com/apache/arrow-datafusion/pull/9795) (jonahgao) +- Add support for constant expression evaluation in limit [#9790](https://github.com/apache/arrow-datafusion/pull/9790) (mustafasrepo) +- Projection Pushdown through user defined LogicalPlan nodes. 
[#9690](https://github.com/apache/arrow-datafusion/pull/9690) (mustafasrepo) +- chore(deps): update substrait requirement from 0.27.0 to 0.28.0 [#9809](https://github.com/apache/arrow-datafusion/pull/9809) (dependabot[bot]) +- Run TPC-H SF10 during PR benchmarks [#9822](https://github.com/apache/arrow-datafusion/pull/9822) (gruuya) +- Expose `parser` on DFParser to enable user controlled parsing [#9729](https://github.com/apache/arrow-datafusion/pull/9729) (tshauck) +- Disable parallel reading for gziped ndjson file [#9799](https://github.com/apache/arrow-datafusion/pull/9799) (Lordworms) +- Optimize to_timestamp (with format) (#9090) [#9833](https://github.com/apache/arrow-datafusion/pull/9833) (vojtechtoman) +- Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function [#9825](https://github.com/apache/arrow-datafusion/pull/9825) (Omega359) +- [Minor] Update TCPDS tests, remove some #[ignore]d tests [#9829](https://github.com/apache/arrow-datafusion/pull/9829) (Dandandan) +- doc: Adding baseline benchmark example [#9827](https://github.com/apache/arrow-datafusion/pull/9827) (comphead) +- Add name method to execution plan [#9793](https://github.com/apache/arrow-datafusion/pull/9793) (matthewmturner) +- chore(deps-dev): bump express from 4.18.2 to 4.19.2 in /datafusion/wasmtest/datafusion-wasm-app [#9826](https://github.com/apache/arrow-datafusion/pull/9826) (dependabot[bot]) +- feat: pass SessionState not SessionConfig to FunctionFactory::create [#9837](https://github.com/apache/arrow-datafusion/pull/9837) (tshauck) diff --git a/dev/changelog/7.0.0.md b/dev/changelog/7.0.0.md index e63c2a4455c9..4d2606d7bfbe 100644 --- a/dev/changelog/7.0.0.md +++ b/dev/changelog/7.0.0.md @@ -56,7 +56,7 @@ - Keep all datafusion's packages up to date with Dependabot [\#1472](https://github.com/apache/arrow-datafusion/issues/1472) - ExecutionContext support init ExecutionContextState with `new(state: Arc>)` method [\#1439](https://github.com/apache/arrow-datafusion/issues/1439) - support the decimal scalar value [\#1393](https://github.com/apache/arrow-datafusion/issues/1393) -- Documentation for using scalar functions with the the DataFrame API [\#1364](https://github.com/apache/arrow-datafusion/issues/1364) +- Documentation for using scalar functions with the DataFrame API [\#1364](https://github.com/apache/arrow-datafusion/issues/1364) - Support `boolean == boolean` and `boolean != boolean` operators [\#1159](https://github.com/apache/arrow-datafusion/issues/1159) - Support DataType::Decimal\(15, 2\) in TPC-H benchmark [\#174](https://github.com/apache/arrow-datafusion/issues/174) - Make `MemoryStream` public [\#150](https://github.com/apache/arrow-datafusion/issues/150) diff --git a/conbench/benchmarks.py b/dev/depcheck/Cargo.toml similarity index 60% rename from conbench/benchmarks.py rename to dev/depcheck/Cargo.toml index f80b3add90f9..cb4e77eabb22 100644 --- a/conbench/benchmarks.py +++ b/dev/depcheck/Cargo.toml @@ -15,27 +15,11 @@ # specific language governing permissions and limitations # under the License. 
-import conbench.runner +# Circular dependency checker for DataFusion +[package] +name = "depcheck" -import _criterion +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -@conbench.runner.register_benchmark -class TestBenchmark(conbench.runner.Benchmark): - name = "test" - - def run(self, **kwargs): - yield self.conbench.benchmark( - self._f(), - self.name, - options=kwargs, - ) - - def _f(self): - return lambda: 1 + 1 - - -@conbench.runner.register_benchmark -class CargoBenchmarks(_criterion.CriterionBenchmark): - name = "datafusion" - description = "Run Arrow DataFusion micro benchmarks." +[dependencies] +cargo = "0.78.1" diff --git a/dev/depcheck/README.md b/dev/depcheck/README.md new file mode 100644 index 000000000000..4a628cdd88e9 --- /dev/null +++ b/dev/depcheck/README.md @@ -0,0 +1,26 @@ + + +This directory contains a tool that ensures there are no circular dependencies +in the DataFusion codebase. + +Specifically, it checks that no crate's tests depend on another crate which +depends on the first, which prevents publishing to crates.io, for example + +[issue 9272]: https://github.com/apache/arrow-datafusion/issues/9277: diff --git a/datafusion/core/tests/depcheck.rs b/dev/depcheck/src/main.rs similarity index 75% rename from datafusion/core/tests/depcheck.rs rename to dev/depcheck/src/main.rs index 94448818691e..b52074c9b1d3 100644 --- a/datafusion/core/tests/depcheck.rs +++ b/dev/depcheck/src/main.rs @@ -15,18 +15,38 @@ // specific language governing permissions and limitations // under the License. +extern crate cargo; + +use cargo::CargoResult; /// Check for circular dependencies between DataFusion crates use std::collections::{HashMap, HashSet}; use std::env; use std::path::Path; use cargo::util::config::Config; -#[test] -fn test_deps() -> Result<(), Box> { + +/// Verifies that there are no circular dependencies between DataFusion crates +/// (which prevents publishing on crates.io) by parsing the Cargo.toml files and +/// checking the dependency graph. +/// +/// See https://github.com/apache/arrow-datafusion/issues/9278 for more details +fn main() -> CargoResult<()> { let config = Config::default()?; + // This is the path for the depcheck binary let path = env::var("CARGO_MANIFEST_DIR").unwrap(); - let dir = Path::new(&path); - let root_cargo_toml = dir.join("Cargo.toml"); + let root_cargo_toml = Path::new(&path) + // dev directory + .parent() + .expect("Can not find dev directory") + // project root directory + .parent() + .expect("Can not find project root directory") + .join("Cargo.toml"); + + println!( + "Checking for circular dependencies in {}", + root_cargo_toml.display() + ); let workspace = cargo::core::Workspace::new(&root_cargo_toml, &config)?; let (_, resolve) = cargo::ops::resolve_ws(&workspace)?; @@ -50,7 +70,7 @@ fn test_deps() -> Result<(), Box> { check_circular_deps(root_package, dep, &package_deps, &mut seen); } } - + println!("No circular dependencies found"); Ok(()) } diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 8678aa534baf..7b5e71bc3a1c 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -44,7 +44,7 @@ request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https ## Mailing list We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. 
Other -than the the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. +than the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. ([subscribe](mailto:dev-subscribe@arrow.apache.org), [unsubscribe](mailto:dev-unsubscribe@arrow.apache.org), [archives](https://lists.apache.org/list.html?dev@arrow.apache.org)). diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 9d3a177be6bd..eadf4147c57e 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -237,6 +237,25 @@ If the environment variable `PARQUET_FILE` is set, the benchmark will run querie The benchmark will automatically remove any generated parquet file on exit, however, if interrupted (e.g. by CTRL+C) it will not. This can be useful for analysing the particular file after the fact, or preserving it to use with `PARQUET_FILE` in subsequent runs. +### Comparing Baselines + +By default, Criterion.rs will compare the measurements against the previous run (if any). Sometimes it's useful to keep a set of measurements around for several runs. For example, you might want to make multiple changes to the code while comparing against the master branch. For this situation, Criterion.rs supports custom baselines. + +``` + git checkout main + cargo bench --bench sql_planner -- --save-baseline main + git checkout YOUR_BRANCH + cargo bench --bench sql_planner -- --baseline main +``` + +Note: For MacOS it may be required to run `cargo bench` with `sudo` + +``` +sudo cargo bench ... +``` + +More information on [Baselines](https://bheisler.github.io/criterion.rs/book/user_guide/command_line_options.html#baselines) + ### Upstream Benchmark Suites Instructions and tooling for running upstream benchmark suites against DataFusion can be found in [benchmarks](https://github.com/apache/arrow-datafusion/tree/main/benchmarks). diff --git a/docs/source/index.rst b/docs/source/index.rst index f7c0873f3a5f..919a7ad7036f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -79,7 +79,7 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for .. toctree:: :maxdepth: 1 :caption: Library User Guide - + library-user-guide/index library-user-guide/using-the-sql-api library-user-guide/working-with-exprs @@ -89,6 +89,7 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for library-user-guide/adding-udfs library-user-guide/custom-table-providers library-user-guide/extending-operators + library-user-guide/profiling .. _toc.contributor-guide: diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index f433e026e0a2..ad210724103d 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -204,7 +204,7 @@ let df = ctx.sql(&sql).await.unwrap(); ## Adding a Window UDF -Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the the proximal rows is helpful, but adds some complexity to the implementation. +Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the proximal rows is helpful, but adds some complexity to the implementation. 
For example, we will declare a user defined window function that computes a moving average. diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md index 06cd2765d161..d30e26f1964a 100644 --- a/docs/source/library-user-guide/catalogs.md +++ b/docs/source/library-user-guide/catalogs.md @@ -19,7 +19,7 @@ # Catalogs, Schemas, and Tables -This section describes how to create and manage catalogs, schemas, and tables in DataFusion. For those wanting to dive into the code quickly please see the [example](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/external_dependency/catalog.rs). +This section describes how to create and manage catalogs, schemas, and tables in DataFusion. For those wanting to dive into the code quickly please see the [example](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/catalog.rs). ## General Concepts diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index 9da207da68f3..11024f77e0d0 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -46,6 +46,10 @@ struct CustomExec { } impl ExecutionPlan for CustomExec { + fn name(&self) -> &str { + "CustomExec" + } + fn execute( &self, _partition: usize, diff --git a/docs/source/library-user-guide/profiling.md b/docs/source/library-user-guide/profiling.md new file mode 100644 index 000000000000..a20489496f0c --- /dev/null +++ b/docs/source/library-user-guide/profiling.md @@ -0,0 +1,63 @@ + + +# Profiling Cookbook + +This section contains examples of how to perform CPU profiling for Apache Arrow DataFusion on different operating systems. 
+ +## MacOS + +### Building a flamegraph + +- [cargo-flamegraph](https://github.com/flamegraph-rs/flamegraph) + +Test: + +```bash +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --unit-test datafusion -- dataframe::tests::test_array_agg +``` + +Benchmark: + +```bash +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --bench sql_planner -- --bench +``` + +Open `flamegraph.svg` file with the browser + +- dtrace with DataFusion CLI + +```bash +git clone https://github.com/brendangregg/FlameGraph.git /tmp/fg +cd datafusion-cli +CARGO_PROFILE_RELEASE_DEBUG=true cargo build --release +echo "select * from table;" >> test.sql +sudo dtrace -c './target/debug/datafusion-cli -f test.sql' -o out.stacks -n 'profile-997 /execname == "datafusion-cli"/ { @[ustack(100)] = count(); }' +/tmp/fg/FlameGraph/stackcollapse.pl out.stacks | /tmp/fg/FlameGraph/flamegraph.pl > flamegraph.svg +``` + +Open `flamegraph.svg` file with the browser + +### CPU profiling with XCode Instruments + +[Video: how to CPU profile DataFusion with XCode Instruments](https://youtu.be/P3dXH61Kr5U) + +## Linux + +## Windows diff --git a/docs/source/user-guide/cli.md b/docs/source/user-guide/cli.md index a94e2427eaa2..da4c9870545a 100644 --- a/docs/source/user-guide/cli.md +++ b/docs/source/user-guide/cli.md @@ -312,9 +312,9 @@ select count(*) from hits; CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS( - 'access_key_id' '******', - 'secret_access_key' '******', - 'region' 'us-east-2' + 'aws.access_key_id' '******', + 'aws.secret_access_key' '******', + 'aws.region' 'us-east-2' ) LOCATION 's3://bucket/path/file.parquet'; ``` @@ -365,9 +365,9 @@ Details of the environment variables that can be used are: CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS( - 'access_key_id' '******', - 'secret_access_key' '******', - 'endpoint' 'https://bucket.oss-cn-hangzhou.aliyuncs.com' + 'aws.access_key_id' '******', + 'aws.secret_access_key' '******', + 'aws.oss.endpoint' 'https://bucket.oss-cn-hangzhou.aliyuncs.com' ) LOCATION 'oss://bucket/path/file.parquet'; ``` @@ -380,6 +380,29 @@ The supported OPTIONS are: Note that the `endpoint` format of oss needs to be: `https://{bucket}.{oss-region-endpoint}` +## Registering COS Data Sources + +[Tencent cloud COS](https://cloud.tencent.com/product/cos) data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. + +```sql +CREATE EXTERNAL TABLE test +STORED AS PARQUET +OPTIONS( + 'aws.access_key_id' '******', + 'aws.secret_access_key' '******', + 'aws.cos.endpoint' 'https://cos.ap-singapore.myqcloud.com' +) +LOCATION 'cos://bucket/path/file.parquet'; +``` + +The supported OPTIONS are: + +- access_key_id +- secret_access_key +- endpoint + +Note that the `endpoint` format of urls must be: `https://cos.{cos-region-endpoint}` + ## Registering GCS Data Sources [Google Cloud Storage](https://cloud.google.com/storage) data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. 
@@ -388,7 +411,7 @@ Note that the `endpoint` format of oss needs to be: `https://{bucket}.{oss-regio CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS( - 'service_account_path' '/tmp/gcs.json', + 'gcp.service_account_path' '/tmp/gcs.json', ) LOCATION 'gs://bucket/path/file.parquet'; ``` diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 492be93caf0c..a95f2f802dfb 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -64,7 +64,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 36.0.0 | Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 37.0.0 | Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | | datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 1c5c8f49a16a..31b599ac3308 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -23,20 +23,20 @@ In this example some simple processing is performed on the [`example.csv`](https Even [`more code examples`](https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples) attached to the project. -## Add DataFusion as a dependency +## Add published DataFusion dependency Find latest available Datafusion version on [DataFusion's crates.io] page. 
Add the dependency to your `Cargo.toml` file: ```toml -datafusion = "31" +datafusion = "latest_version" tokio = "1.0" ``` -## Add DataFusion latest codebase as a dependency +## Add latest non published DataFusion dependency -Cargo supports adding dependency directly from Github which allows testing out latest DataFusion codebase without waiting the code to be released to crates.io -according to the [DataFusion release schedule](https://github.com/apache/arrow-datafusion/blob/main/dev/release/README.md#release-process) +DataFusion changes are published to `crates.io` according to [release schedule](https://github.com/apache/arrow-datafusion/blob/main/dev/release/README.md#release-process) +In case if it is required to test out DataFusion changes which are merged but yet to be published, Cargo supports adding dependency directly to Github branch ```toml datafusion = { git = "https://github.com/apache/arrow-datafusion", branch = "main"} @@ -240,17 +240,11 @@ async fn main() -> datafusion::error::Result<()> { } ``` -Finally, in order to build with the `simd` optimization `cargo nightly` is required. - -```shell -rustup toolchain install nightly -``` - Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally with `native` or at least `avx2`. ```shell -RUSTFLAGS='-C target-cpu=native' cargo +nightly run --release +RUSTFLAGS='-C target-cpu=native' cargo run --release ``` ## Enable backtraces diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 005d2ec94229..a5fc13491677 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -217,6 +217,7 @@ select log(-1), log(0), sqrt(-1); | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | | array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | +| empty(array) | Returns true for an empty array or false for a non-empty array. `empty([1]) -> false` | | flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. 
`array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index ae2684699726..be15848407a2 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -96,6 +96,7 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine +- [Comet](https://github.com/apache/arrow-datafusion-comet) Apache Spark native query execution plugin - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python @@ -115,6 +116,7 @@ Here are some active projects using DataFusion: - [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await - [ROAPI](https://github.com/roapi/roapi) - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database +- [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine - [Synnada](https://synnada.ai/) Streaming-first framework for data products - [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar - [ZincObserve](https://github.com/zinclabs/zincobserve) Distributed cloud native observability platform @@ -139,12 +141,13 @@ Here are some less active projects that used DataFusion: [kamu]: https://github.com/kamu-data/kamu-cli [greptime db]: https://github.com/GreptimeTeam/greptimedb [horaedb]: https://github.com/apache/incubator-horaedb -[influxdb iox]: https://github.com/influxdata/influxdb_iox +[influxdb]: https://github.com/influxdata/influxdb [parseable]: https://github.com/parseablehq/parseable [prql-query]: https://github.com/prql/prql-query [qv]: https://github.com/timvw/qv [roapi]: https://github.com/roapi/roapi [seafowl]: https://github.com/splitgraph/seafowl +[spice.ai]: https://github.com/spiceai/spiceai [synnada]: https://synnada.ai/ [tensorbase]: https://github.com/tensorbase/tensorbase [vegafusion]: https://vegafusion.io/ @@ -166,6 +169,7 @@ provide integrations with other systems, some of which are described below: - [datafusion-bigtable](https://github.com/datafusion-contrib/datafusion-bigtable) - [datafusion-catalogprovider-glue](https://github.com/datafusion-contrib/datafusion-catalogprovider-glue) +- [datafusion-federation](https://github.com/datafusion-contrib/datafusion-federation) ## Why DataFusion? diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 405e77a21b26..79c36092fd3d 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -25,11 +25,14 @@ and modifying data in tables. ## COPY Copies the contents of a table or query to file(s). Supported file -formats are `parquet`, `csv`, and `json` and can be inferred based on -filename if writing to a single file. +formats are `parquet`, `csv`, `json`, and `arrow`.
-COPY { table_name | query } TO 'file_name' [ ( option [, ... ] ) ]
+COPY { table_name | query } 
+TO 'file_name'
+[ STORED AS format ]
+[ PARTITIONED BY column_name [, ...] ]
+[ OPTIONS( option [, ... ] ) ]
 
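+The clauses can be combined in a single statement. The following sketch is illustrative only
+(the directory, columns, and option value are placeholders reused from the examples below):
+
+```sql
+-- Illustrative only: 'dir_name', column1/column2, and MAX_ROW_GROUP_SIZE are
+-- placeholder names taken from the examples later on this page.
+COPY source_table
+TO 'dir_name'
+STORED AS PARQUET
+PARTITIONED BY (column1, column2)
+OPTIONS (MAX_ROW_GROUP_SIZE 10000000);
+```
+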
 For a detailed list of valid OPTIONS, see [Write Options](write_options).
@@ -49,7 +52,7 @@ Copy the contents of `source_table` to one or more Parquet formatted files
 in the `dir_name` directory:

 ```sql
-> COPY source_table TO 'dir_name' (FORMAT parquet);
+> COPY source_table TO 'dir_name' STORED AS PARQUET;
 +-------+
 | count |
 +-------+
@@ -61,7 +64,7 @@ Copy the contents of `source_table` to multiple directories
 of hive-style partitioned parquet files:

 ```sql
-> COPY source_table TO 'dir_name' (FORMAT parquet, partition_by 'column1, column2');
+> COPY source_table TO 'dir_name' STORED AS parquet, PARTITIONED BY (column1, column2);
 +-------+
 | count |
 +-------+
@@ -74,7 +77,7 @@ results (maintaining the order) to a parquet file named
 `output.parquet` with a maximum parquet row group size of 10MB:

 ```sql
-> COPY (SELECT * from source ORDER BY time) TO 'output.parquet' (ROW_GROUP_LIMIT_BYTES 10000000);
+> COPY (SELECT * from source ORDER BY time) TO 'output.parquet' OPTIONS (MAX_ROW_GROUP_SIZE 10000000);
 +-------+
 | count |
 +-------+
@@ -82,6 +85,12 @@ results (maintaining the order) to a parquet file named
 +-------+
 ```

+The output format is determined by the first match of the following rules:
+
+1. Value of `STORED AS`
+2. Value of the `OPTION (FORMAT ..)`
+3. Filename extension (e.g. `foo.parquet` implies `PARQUET` format)
+
 ## INSERT

 Insert values into a table.
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index b63fa9950ae0..e2e129a2e2d1 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -731,12 +731,15 @@ btrim(str[, trim_str])
   Can be a constant, column, or function, and any combination of string operators.
 - **trim_str**: String expression to trim from the beginning and end of the input string.
   Can be a constant, column, or function, and any combination of arithmetic operators.
-  _Default is whitespace characters_.
+  _Default is whitespace characters._

 **Related functions**:
 [ltrim](#ltrim),
-[rtrim](#rtrim),
-[trim](#trim)
+[rtrim](#rtrim)
+
+#### Aliases
+
+- trim

 ### `char_length`
@@ -919,26 +922,25 @@ lpad(str, n[, padding_str])

 ### `ltrim`

-Removes leading spaces from a string.
+Trims the specified trim string from the beginning of a string.
+If no trim string is provided, all whitespace is removed from the start
+of the input string.

 ```
-ltrim(str)
+ltrim(str[, trim_str])
 ```

 #### Arguments

 - **str**: String expression to operate on.
   Can be a constant, column, or function, and any combination of string operators.
+- **trim_str**: String expression to trim from the beginning of the input string.
+  Can be a constant, column, or function, and any combination of arithmetic operators.
+  _Default is whitespace characters._

 **Related functions**:
 [btrim](#btrim),
-[rtrim](#rtrim),
-[trim](#trim)
-
-#### Arguments
-
-- **str**: String expression to operate on.
-  Can be a constant, column, or function, and any combination of string operators.
+[rtrim](#rtrim)

 ### `octet_length`
@@ -1040,21 +1042,25 @@ rpad(str, n[, padding_str])

 ### `rtrim`

-Removes trailing spaces from a string.
+Trims the specified trim string from the end of a string.
+If no trim string is provided, all whitespace is removed from the end
+of the input string.

 ```
-rtrim(str)
+rtrim(str[, trim_str])
 ```

 #### Arguments

 - **str**: String expression to operate on.
   Can be a constant, column, or function, and any combination of string operators.
+- **trim_str**: String expression to trim from the end of the input string.
+  Can be a constant, column, or function, and any combination of arithmetic operators.
+  _Default is whitespace characters._

 **Related functions**:
 [btrim](#btrim),
-[ltrim](#ltrim),
-[trim](#trim)
+[ltrim](#ltrim)

 ### `split_part`
@@ -1154,21 +1160,7 @@ to_hex(int)

 ### `trim`

-Removes leading and trailing spaces from a string.
-
-```
-trim(str)
-```
-
-#### Arguments
-
-- **str**: String expression to operate on.
-  Can be a constant, column, or function, and any combination of string operators.
-
-**Related functions**:
-[btrim](#btrim),
-[ltrim](#ltrim),
-[rtrim](#rtrim)
+_Alias of [btrim](#btrim)._

 ### `upper`
@@ -1624,34 +1616,19 @@ _Alias of [date_part](#date_part)._

 ### `extract`

 Returns a sub-field from a time value as an integer.
-Similar to `date_part`, but with different arguments.

 ```
 extract(field FROM source)
 ```

-#### Arguments
-
-- **field**: Part or field of the date to return.
-  The following date fields are supported:
+Equivalent to calling `date_part('field', source)`. For example, these are equivalent:

-  - year
-  - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_
-  - month
-  - week _(week of the year)_
-  - day _(day of the month)_
-  - hour
-  - minute
-  - second
-  - millisecond
-  - microsecond
-  - nanosecond
-  - dow _(day of the week)_
-  - doy _(day of the year)_
-  - epoch _(seconds since Unix epoch)_
+```sql
+extract(day FROM '2024-04-13'::date)
+date_part('day', '2024-04-13'::date)
+```

-- **source**: Source time expression to operate on.
-  Can be a constant, column, or function.
+See [date_part](#date_part).

 ### `make_date`
@@ -1954,6 +1931,7 @@ from_unixtime(expression)

 - [array_has_all](#array_has_all)
 - [array_has_any](#array_has_any)
 - [array_element](#array_element)
+- [array_empty](#array_empty)
 - [array_except](#array_except)
 - [array_extract](#array_extract)
 - [array_fill](#array_fill)
@@ -3032,6 +3010,11 @@ empty(array)
 +------------------+
 ```

+#### Aliases
+
+- array_empty
+- list_empty
+
 ### `generate_series`

 Similar to the range function, but it includes the upper bound.
@@ -3061,10 +3044,6 @@ generate_series(start, stop, step)

 _Alias of [array_append](#array_append)._

-### `list_sort`
-
-_Alias of [array_sort](#array_sort)._
-
 ### `list_cat`

 _Alias of [array_concat](#array_concat)._
@@ -3085,6 +3064,10 @@ _Alias of [array_dims](#array_distinct)._

 _Alias of [array_element](#array_element)._

+### `list_empty`
+
+_Alias of [empty](#empty)._
+
 ### `list_except`

 _Alias of [array_element](#array_except)._
@@ -3193,13 +3176,17 @@ _Alias of [array_reverse](#array_reverse)._

 _Alias of [array_slice](#array_slice)._

+### `list_sort`
+
+_Alias of [array_sort](#array_sort)._
+
 ### `list_to_string`

 _Alias of [array_to_string](#array_to_string)._

 ### `list_union`

-_Alias of [array_to_string](#array_union)._
+_Alias of [array_union](#array_union)._

 ### `make_array`

 Returns an Arrow array using the specified input expressions.

 ```
 make_array(expression1[, ..., expression_n])
 ```

+### `array_empty`
+
+_Alias of [empty](#empty)._
+
 #### Arguments

 - **expression_n**: Expression to include in the output array.
@@ -3321,11 +3312,12 @@ are not allowed

 ## Struct Functions

 - [struct](#struct)
+- [named_struct](#named_struct)

 ### `struct`

-Returns an Arrow struct using the specified input expressions.
-Fields in the returned struct use the `cN` naming convention.
+Returns an Arrow struct using the specified input expressions, which may optionally be named.
+Fields in the returned struct use the supplied name or, if unnamed, the `cN` naming convention.
 For example: `c0`, `c1`, `c2`, etc.

 ```
@@ -3333,7 +3325,7 @@ struct(expression1[, ..., expression_n])
 ```

 For example, this query converts two columns `a` and `b` to a single column with
-a struct type of fields `c0` and `c1`:
+a struct type of fields `field_a` and `c1`:

 ```
 select * from t;
 +---+---+
 | a | b |
 +---+---+
 | 1 | 2 |
 | 3 | 4 |
 +---+---+

-select struct(a, b) from t;
-+-----------------+
-| struct(t.a,t.b) |
-+-----------------+
-| {c0: 1, c1: 2} |
-| {c0: 3, c1: 4} |
-+-----------------+
+select struct(a as field_a, b) from t;
++--------------------------------------------------+
+| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) |
++--------------------------------------------------+
+| {field_a: 1, c1: 2} |
+| {field_a: 3, c1: 4} |
++--------------------------------------------------+
 ```

 #### Arguments

 - **expression_n**: Expression to include in the output struct.
+  Can be a constant, column, or function, any combination of arithmetic or
+  string operators, or a named expression of any of the above.
+
+### `named_struct`
+
+Returns an Arrow struct using the specified pairs of names and input expressions.
+
+```
+named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])
+```
+
+For example, this query converts two columns `a` and `b` to a single column with
+a struct type of fields `field_a` and `field_b`:
+
+```
+select * from t;
++---+---+
+| a | b |
++---+---+
+| 1 | 2 |
+| 3 | 4 |
++---+---+
+
+select named_struct('field_a', a, 'field_b', b) from t;
++-------------------------------------------------------+
+| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) |
++-------------------------------------------------------+
+| {field_a: 1, field_b: 2} |
+| {field_a: 3, field_b: 4} |
++-------------------------------------------------------+
+```
+
+#### Arguments
+
+- **expression_n_name**: Name of the column field.
+  Must be a constant string.
+- **expression_n_input**: Expression to include in the output struct.
   Can be a constant, column, or function, and any combination of arithmetic or
   string operators.
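+
+Because the field names are constant strings, `named_struct` can also be called with
+literal values only. The following is an illustrative sketch (not taken from the
+examples above; output not shown):
+
+```sql
+-- Hypothetical example: produces a single struct value such as {x: 1, y: 2}
+select named_struct('x', 1, 'y', 2);
+```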