diff --git a/Cargo.toml b/Cargo.toml index ff231178a2b3..ec8e6621c2a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ exclude = ["datafusion-cli"] members = [ "datafusion/common", + "datafusion/common_runtime", "datafusion/core", "datafusion/expr", "datafusion/execution", @@ -72,6 +73,7 @@ ctor = "0.2.0" dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common_runtime", version = "36.0.0" } datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } diff --git a/README.md b/README.md index 634aa426bdff..e5ac9503be44 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ [API Docs](https://docs.rs/datafusion/latest/datafusion/) | [Chat](https://discord.com/channels/885562378132000778/885562378132000781) -logo +logo DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in [Rust](http://rustlang.org), using the [Apache Arrow](https://arrow.apache.org) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8ef6c6acbbcc..665c45651863 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1112,6 +1112,7 @@ dependencies = [ "chrono", "dashmap", "datafusion-common", + "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", "datafusion-functions", @@ -1193,6 +1194,13 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datafusion-common-runtime" +version = "36.0.0" +dependencies = [ + "tokio", +] + [[package]] name = "datafusion-execution" version = "36.0.0" @@ -1317,6 +1325,7 @@ dependencies = [ "async-trait", "chrono", "datafusion-common", + "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", diff --git a/datafusion/common_runtime/Cargo.toml b/datafusion/common_runtime/Cargo.toml new file mode 100644 index 000000000000..7ed8b2cf2975 --- /dev/null +++ b/datafusion/common_runtime/Cargo.toml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-common-runtime" +description = "Common Runtime functionality for DataFusion query engine" +keywords = ["arrow", "query", "sql"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "datafusion_common_runtime" +path = "src/lib.rs" + +[dependencies] +tokio = { workspace = true } diff --git a/datafusion/common_runtime/README.md b/datafusion/common_runtime/README.md new file mode 100644 index 000000000000..77100e52603c --- /dev/null +++ b/datafusion/common_runtime/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Common Runtime + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides common utilities. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/common_runtime/src/common.rs b/datafusion/common_runtime/src/common.rs new file mode 100644 index 000000000000..88b74448c7a8 --- /dev/null +++ b/datafusion/common_runtime/src/common.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::future::Future; + +use tokio::task::{JoinError, JoinSet}; + +/// Helper that provides a simple API to spawn a single task and join it. +/// Provides guarantees of aborting on `Drop` to keep it cancel-safe. +/// +/// Technically, it's just a wrapper of `JoinSet` (with size=1). +#[derive(Debug)] +pub struct SpawnedTask { + inner: JoinSet, +} + +impl SpawnedTask { + pub fn spawn(task: T) -> Self + where + T: Future, + T: Send + 'static, + R: Send, + { + let mut inner = JoinSet::new(); + inner.spawn(task); + Self { inner } + } + + pub fn spawn_blocking(task: T) -> Self + where + T: FnOnce() -> R, + T: Send + 'static, + R: Send, + { + let mut inner = JoinSet::new(); + inner.spawn_blocking(task); + Self { inner } + } + + pub async fn join(mut self) -> Result { + self.inner + .join_next() + .await + .expect("`SpawnedTask` instance always contains exactly 1 task") + } +} diff --git a/datafusion/common_runtime/src/lib.rs b/datafusion/common_runtime/src/lib.rs new file mode 100644 index 000000000000..e8624163f224 --- /dev/null +++ b/datafusion/common_runtime/src/lib.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod common; + +pub use common::SpawnedTask; diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 662d95a9323c..0c378d9d83f5 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -89,6 +89,7 @@ bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } dashmap = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } +datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index d7c31b9bd6b3..3bdf2af4552d 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1510,6 +1510,7 @@ mod tests { use arrow::array::{self, Int32Array}; use arrow::datatypes::DataType; use datafusion_common::{Constraint, Constraints}; + use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, @@ -2169,15 +2170,14 @@ mod tests { } #[tokio::test] - #[allow(clippy::disallowed_methods)] async fn sendable() { let df = test_table().await.unwrap(); // dataframes should be sendable between threads/tasks - let task = tokio::task::spawn(async move { + let task = SpawnedTask::spawn(async move { df.select_columns(&["c1"]) .expect("should be usable in a task") }); - task.await.expect("task completed successfully"); + task.join().await.expect("task completed successfully"); } #[tokio::test] diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 739850115370..4ea6c2a273f1 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -40,9 +40,9 @@ use arrow::datatypes::SchemaRef; use arrow::datatypes::{Fields, Schema}; use bytes::{BufMut, BytesMut}; use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; -use datafusion_physical_plan::common::SpawnedTask; use futures::{StreamExt, TryStreamExt}; use hashbrown::HashMap; use object_store::path::Path; diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index d70b4811da5b..396da96332f6 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -33,7 +33,7 @@ use arrow_array::{downcast_dictionary_array, RecordBatch, StringArray, StructArr use arrow_schema::{DataType, Schema}; use datafusion_common::cast::as_string_array; use datafusion_common::{exec_datafusion_err, DataFusionError}; - +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use futures::StreamExt; @@ -41,7 +41,6 @@ use object_store::path::Path; use rand::distributions::DistString; -use datafusion_physical_plan::common::SpawnedTask; use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; type RecordBatchReceiver = Receiver; diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 05406d3751c9..dd0e5ce6a40e 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -30,10 +30,10 @@ use crate::physical_plan::SendableRecordBatchStream; use arrow_array::RecordBatch; use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use bytes::Bytes; -use datafusion_physical_plan::common::SpawnedTask; use futures::try_join; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 6dc59e4a5c65..0d91b1cba34d 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -31,9 +31,9 @@ use async_trait::async_trait; use futures::StreamExt; use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{CreateExternalTable, Expr, TableType}; -use datafusion_physical_plan::common::SpawnedTask; use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 3aa4edfe3adc..2d964d29688c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2222,6 +2222,7 @@ mod tests { use crate::test_util::{plan_and_collect, populate_csv_partitions}; use crate::variable::VarType; use async_trait::async_trait; + use datafusion_common_runtime::SpawnedTask; use datafusion_expr::Expr; use std::env; use std::path::PathBuf; @@ -2321,7 +2322,6 @@ mod tests { } #[tokio::test] - #[allow(clippy::disallowed_methods)] async fn send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded // environment. Usecase is for concurrent planing. @@ -2332,7 +2332,7 @@ mod tests { let threads: Vec<_> = (0..2) .map(|_| ctx.clone()) .map(|ctx| { - tokio::spawn(async move { + SpawnedTask::spawn(async move { // Ensure we can create logical plan code on a separate thread. ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3") .await @@ -2341,7 +2341,7 @@ mod tests { .collect(); for handle in threads { - handle.await.unwrap().unwrap(); + handle.join().await.unwrap().unwrap(); } Ok(()) } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index d78d7a38a1c3..2b565ece7568 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -480,6 +480,11 @@ pub use parquet; /// re-export of [`datafusion_common`] crate pub mod common { pub use datafusion_common::*; + + /// re-export of [`datafusion_common_runtime`] crate + pub mod runtime { + pub use datafusion_common_runtime::*; + } } // Backwards compatibility @@ -524,7 +529,7 @@ pub mod functions { /// re-export of [`datafusion_functions_array`] crate, if "array_expressions" feature is enabled pub mod functions_array { #[cfg(feature = "array_expressions")] - pub use datafusion_functions::*; + pub use datafusion_functions_array::*; } #[cfg(test)] diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 1cab4d5c2f98..ee5e34bd703f 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -30,6 +30,7 @@ use datafusion::physical_plan::windows::{ use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; +use datafusion_common_runtime::SpawnedTask; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, @@ -123,8 +124,7 @@ async fn window_bounded_window_random_comparison() -> Result<()> { for i in 0..n { let idx = i % test_cases.len(); let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone(); - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests - let job = tokio::spawn(run_window_test( + let job = SpawnedTask::spawn(run_window_test( make_staggered_batches::(1000, n_distinct, i as u64), i as u64, pb_cols, @@ -134,7 +134,7 @@ async fn window_bounded_window_random_comparison() -> Result<()> { handles.push(job); } for job in handles { - job.await.unwrap()?; + job.join().await.unwrap()?; } } Ok(()) diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index 8039a211c9e4..f0ce61ee9bbb 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -78,10 +78,11 @@ impl GetFieldAccessSchema { Self::ListIndex{ key_dt } => { match (data_type, key_dt) { (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)), - (DataType::List(_), _) => plan_err!( - "Only ints are valid as an indexed field in a list" + (DataType::LargeList(lt), DataType::Int64) => Ok(Field::new("large_list", lt.data_type().clone(), true)), + (DataType::List(_), _) | (DataType::LargeList(_), _) => plan_err!( + "Only ints are valid as an indexed field in a List/LargeList" ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"), } } Self::ListRange { start_dt, stop_dt, stride_dt } => { @@ -89,7 +90,7 @@ impl GetFieldAccessSchema { (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), (DataType::LargeList(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("large_list", data_type.clone(), true)), (DataType::List(_), _, _, _) | (DataType::LargeList(_), _, _, _)=> plan_err!( - "Only ints are valid as an indexed field in a list" + "Only ints are valid as an indexed field in a List/LargeList" ), (other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"), } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 773387bf7421..c93090c4946f 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -252,14 +252,14 @@ impl PhysicalExpr for GetIndexedFieldExpr { GetFieldAccessExpr::ListIndex{key} => { let key = key.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), key.data_type()) { - (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ + (DataType::List(_), DataType::Int64) | (DataType::LargeList(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ array, key ])?)), - (DataType::List(_), key) => exec_err!( - "get indexed field is only possible on lists with int64 indexes. \ + (DataType::List(_), key) | (DataType::LargeList(_), key) => exec_err!( + "get indexed field is only possible on List/LargeList with int64 indexes. \ Tried with {key:?} index"), (dt, key) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ + "get indexed field is only possible on List/LargeList with int64 indexes or struct \ with utf8 indexes. Tried {dt:?} with {key:?} index"), } }, diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index b4621109d2b1..72ee4fb3ef7e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -43,6 +43,7 @@ arrow-schema = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } +datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 47cdf3e400e3..656bffd4a799 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -36,9 +36,8 @@ use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use futures::{Future, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex; -use tokio::task::{JoinError, JoinSet}; /// [`MemoryReservation`] used across query execution streams pub(crate) type SharedMemoryReservation = Arc>; @@ -172,46 +171,6 @@ pub fn compute_record_batch_statistics( } } -/// Helper that provides a simple API to spawn a single task and join it. -/// Provides guarantees of aborting on `Drop` to keep it cancel-safe. -/// -/// Technically, it's just a wrapper of `JoinSet` (with size=1). -#[derive(Debug)] -pub struct SpawnedTask { - inner: JoinSet, -} - -impl SpawnedTask { - pub fn spawn(task: T) -> Self - where - T: Future, - T: Send + 'static, - R: Send, - { - let mut inner = JoinSet::new(); - inner.spawn(task); - Self { inner } - } - - pub fn spawn_blocking(task: T) -> Self - where - T: FnOnce() -> R, - T: Send + 'static, - R: Send, - { - let mut inner = JoinSet::new(); - inner.spawn_blocking(task); - Self { inner } - } - - pub async fn join(mut self) -> Result { - self.inner - .join_next() - .await - .expect("`SpawnedTask` instance always contains exactly 1 task") - } -} - /// Transposes the given vector of vectors. pub fn transpose(original: Vec>) -> Vec> { match original.as_slice() { diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 006cd646b0ca..b527466493a8 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -264,7 +264,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// [`spawn`]: tokio::task::spawn /// [`JoinSet`]: tokio::task::JoinSet - /// [`SpawnedTask`]: crate::common::SpawnedTask + /// [`SpawnedTask`]: datafusion_common_runtime::SpawnedTask /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder /// /// # Implementation Examples diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index fe93ea131506..7ac70949f893 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -29,7 +29,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use crate::common::{transpose, SpawnedTask}; +use crate::common::transpose; use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; use crate::repartition::distributor_channels::{ @@ -42,6 +42,7 @@ use arrow::array::{ArrayRef, UInt64Builder}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; @@ -946,7 +947,6 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use futures::FutureExt; - use tokio::task::JoinHandle; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1060,10 +1060,9 @@ mod tests { } #[tokio::test] - #[allow(clippy::disallowed_methods)] async fn many_to_many_round_robin_within_tokio_task() -> Result<()> { - let join_handle: JoinHandle>>> = - tokio::spawn(async move { + let handle: SpawnedTask>>> = + SpawnedTask::spawn(async move { // define input partitions let schema = test_schema(); let partition = create_vec_batches(50); @@ -1074,7 +1073,7 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await }); - let output_partitions = join_handle.await.unwrap().unwrap(); + let output_partitions = handle.join().await.unwrap().unwrap(); assert_eq!(5, output_partitions.len()); assert_eq!(30, output_partitions[0].len()); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index f46958663252..5b0f2f354824 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -27,7 +27,7 @@ use std::io::BufReader; use std::path::{Path, PathBuf}; use std::sync::Arc; -use crate::common::{spawn_buffered, IPCWriter, SpawnedTask}; +use crate::common::{spawn_buffered, IPCWriter}; use crate::expressions::PhysicalSortExpr; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, @@ -46,6 +46,7 @@ use arrow::datatypes::SchemaRef; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{ human_readable_size, MemoryConsumer, MemoryReservation, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 41dc6fef1924..f17e39d02f06 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -751,6 +751,7 @@ message AggregateUDFExprNode { message ScalarUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + optional bytes fun_definition = 3; } enum BuiltInWindowFunction { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d4abb9ed9c6f..610c533d574c 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -16,6 +16,7 @@ // under the License. //! Serialization / Deserialization to Bytes +use crate::logical_plan::to_proto::serialize_expr; use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; @@ -87,8 +88,8 @@ pub trait Serializeable: Sized { impl Serializeable for Expr { fn to_bytes(&self) -> Result { let mut buffer = BytesMut::new(); - let protobuf: protobuf::LogicalExprNode = self - .try_into() + let extension_codec = DefaultLogicalExtensionCodec {}; + let protobuf: protobuf::LogicalExprNode = serialize_expr(self, &extension_codec) .map_err(|e| plan_datafusion_err!("Error encoding expr as protobuf: {e}"))?; protobuf @@ -177,7 +178,8 @@ impl Serializeable for Expr { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - logical_plan::from_proto::parse_expr(&protobuf, registry) + let extension_codec = DefaultLogicalExtensionCodec {}; + logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0def1e3b3586..83b19013c77d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -23372,6 +23372,9 @@ impl serde::Serialize for ScalarUdfExprNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarUDFExprNode", len)?; if !self.fun_name.is_empty() { struct_ser.serialize_field("funName", &self.fun_name)?; @@ -23379,6 +23382,10 @@ impl serde::Serialize for ScalarUdfExprNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } struct_ser.end() } } @@ -23392,12 +23399,15 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { "fun_name", "funName", "args", + "fun_definition", + "funDefinition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { FunName, Args, + FunDefinition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23421,6 +23431,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { match value { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23442,6 +23453,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { { let mut fun_name__ = None; let mut args__ = None; + let mut fun_definition__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::FunName => { @@ -23456,11 +23468,20 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(ScalarUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b24fcf6680f8..2eeee3d10e8c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -895,6 +895,8 @@ pub struct ScalarUdfExprNode { pub fun_name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 66d52adce71d..cb27b63b53cc 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -76,6 +76,8 @@ use datafusion_expr::{ expr::{Alias, Placeholder}, }; +use super::LogicalExtensionCodec; + #[derive(Debug)] pub enum Error { General(String), @@ -973,6 +975,7 @@ pub fn parse_i32_to_aggregate_function(value: &i32) -> Result Result { use protobuf::{logical_expr_node::ExprType, window_expr_node, ScalarFunction}; @@ -987,7 +990,7 @@ pub fn parse_expr( let operands = binary_expr .operands .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?; if operands.len() < 2 { @@ -1006,8 +1009,12 @@ pub fn parse_expr( .expect("Binary expression could not be reduced to a single expression.")) } ExprType::GetIndexedField(get_indexed_field) => { - let expr = - parse_required_expr(get_indexed_field.expr.as_deref(), registry, "expr")?; + let expr = parse_required_expr( + get_indexed_field.expr.as_deref(), + registry, + "expr", + codec, + )?; let field = match &get_indexed_field.field { Some(protobuf::get_indexed_field::Field::NamedStructField( named_struct_field, @@ -1024,6 +1031,7 @@ pub fn parse_expr( list_index.key.as_deref(), registry, "key", + codec, )?), } } @@ -1033,16 +1041,19 @@ pub fn parse_expr( list_range.start.as_deref(), registry, "start", + codec, )?), stop: Box::new(parse_required_expr( list_range.stop.as_deref(), registry, "stop", + codec, )?), stride: Box::new(parse_required_expr( list_range.stride.as_deref(), registry, "stride", + codec, )?), } } @@ -1067,12 +1078,12 @@ pub fn parse_expr( let partition_by = expr .partition_by .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; let mut order_by = expr .order_by .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; let window_frame = expr .window_frame @@ -1100,7 +1111,7 @@ pub fn parse_expr( datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), - vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], + vec![parse_required_expr(expr.expr.as_deref(), registry, "expr", codec)?], partition_by, order_by, window_frame, @@ -1112,9 +1123,10 @@ pub fn parse_expr( .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( @@ -1129,9 +1141,10 @@ pub fn parse_expr( } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = registry.udaf(udaf_name)?; - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, @@ -1145,9 +1158,10 @@ pub fn parse_expr( } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = registry.udwf(udwf_name)?; - let args = parse_optional_expr(expr.expr.as_deref(), registry)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = + parse_optional_expr(expr.expr.as_deref(), registry, codec)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, @@ -1168,15 +1182,16 @@ pub fn parse_expr( fun, expr.expr .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?, expr.distinct, - parse_optional_expr(expr.filter.as_deref(), registry)?.map(Box::new), - parse_vec_expr(&expr.order_by, registry)?, + parse_optional_expr(expr.filter.as_deref(), registry, codec)? + .map(Box::new), + parse_vec_expr(&expr.order_by, registry, codec)?, ))) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, alias .relation .first() @@ -1188,90 +1203,118 @@ pub fn parse_expr( is_null.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr")?, + parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", + codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr")?, + parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), registry, "expr", + codec, )?), between.negated, Box::new(parse_required_expr( between.low.as_deref(), registry, "expr", + codec, )?), Box::new(parse_required_expr( between.high.as_deref(), registry, "expr", + codec, )?), ))), ExprType::Like(like) => Ok(Expr::Like(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, false, ))), ExprType::Ilike(like) => Ok(Expr::Like(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, true, ))), ExprType::SimilarTo(like) => Ok(Expr::SimilarTo(Like::new( like.negated, - Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + like.expr.as_deref(), + registry, + "expr", + codec, + )?), Box::new(parse_required_expr( like.pattern.as_deref(), registry, "pattern", + codec, )?), parse_escape_char(&like.escape_char)?, false, @@ -1281,44 +1324,66 @@ pub fn parse_expr( .when_then_expr .iter() .map(|e| { - let when_expr = - parse_required_expr(e.when_expr.as_ref(), registry, "when_expr")?; - let then_expr = - parse_required_expr(e.then_expr.as_ref(), registry, "then_expr")?; + let when_expr = parse_required_expr( + e.when_expr.as_ref(), + registry, + "when_expr", + codec, + )?; + let then_expr = parse_required_expr( + e.then_expr.as_ref(), + registry, + "then_expr", + codec, + )?; Ok((Box::new(when_expr), Box::new(then_expr))) }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry)?.map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), registry, codec)? + .map(Box::new), ))) } ExprType::Cast(cast) => { - let expr = - Box::new(parse_required_expr(cast.expr.as_deref(), registry, "expr")?); + let expr = Box::new(parse_required_expr( + cast.expr.as_deref(), + registry, + "expr", + codec, + )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::Cast(Cast::new(expr, data_type))) } ExprType::TryCast(cast) => { - let expr = - Box::new(parse_required_expr(cast.expr.as_deref(), registry, "expr")?); + let expr = Box::new(parse_required_expr( + cast.expr.as_deref(), + registry, + "expr", + codec, + )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::TryCast(TryCast::new(expr, data_type))) } ExprType::Sort(sort) => Ok(Expr::Sort(Sort::new( - Box::new(parse_required_expr(sort.expr.as_deref(), registry, "expr")?), + Box::new(parse_required_expr( + sort.expr.as_deref(), + registry, + "expr", + codec, + )?), sort.asc, sort.nulls_first, ))), ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr")?, + parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { let exprs = unnest .exprs .iter() - .map(|e| parse_expr(e, registry)) + .map(|e| parse_expr(e, registry, codec)) .collect::, _>>()?; Ok(Expr::Unnest(Unnest { exprs })) } @@ -1327,11 +1392,12 @@ pub fn parse_expr( in_list.expr.as_deref(), registry, "expr", + codec, )?), in_list .list .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, in_list.negated, ))), @@ -1349,317 +1415,357 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), - ScalarFunction::Asinh => Ok(asinh(parse_expr(&args[0], registry)?)), - ScalarFunction::Acosh => Ok(acosh(parse_expr(&args[0], registry)?)), + ScalarFunction::Asinh => { + Ok(asinh(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Acosh => { + Ok(acosh(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Array => Ok(array( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ArrayAppend => Ok(array_append( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArraySort => Ok(array_sort( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayPopFront => { - Ok(array_pop_front(parse_expr(&args[0], registry)?)) + Ok(array_pop_front(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayPopBack => { - Ok(array_pop_back(parse_expr(&args[0], registry)?)) + Ok(array_pop_back(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayPrepend => Ok(array_prepend( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayConcat => Ok(array_concat( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ArrayExcept => Ok(array_except( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHasAll => Ok(array_has_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHasAny => Ok(array_has_any( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayHas => Ok(array_has( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayIntersect => Ok(array_intersect( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayPosition => Ok(array_position( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayPositions => Ok(array_positions( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRepeat => Ok(array_repeat( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRemove => Ok(array_remove( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayRemoveN => Ok(array_remove_n( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayRemoveAll => Ok(array_remove_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayReplace => Ok(array_replace( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayReplaceN => Ok(array_replace_n( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, - parse_expr(&args[3], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, + parse_expr(&args[3], registry, codec)?, )), ScalarFunction::ArrayReplaceAll => Ok(array_replace_all( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::ArrayReverse => { - Ok(array_reverse(parse_expr(&args[0], registry)?)) + Ok(array_reverse(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArraySlice => Ok(array_slice( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, - parse_expr(&args[3], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, + parse_expr(&args[3], registry, codec)?, )), ScalarFunction::Cardinality => { - Ok(cardinality(parse_expr(&args[0], registry)?)) + Ok(cardinality(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayLength => Ok(array_length( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayDims => { - Ok(array_dims(parse_expr(&args[0], registry)?)) + Ok(array_dims(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayDistinct => { - Ok(array_distinct(parse_expr(&args[0], registry)?)) + Ok(array_distinct(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayElement => Ok(array_element( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayEmpty => { - Ok(array_empty(parse_expr(&args[0], registry)?)) + Ok(array_empty(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayNdims => { - Ok(array_ndims(parse_expr(&args[0], registry)?)) + Ok(array_ndims(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::ArrayUnion => Ok(array_union( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::ArrayResize => Ok(array_resize( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), - ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), - ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), - ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry)?)), - ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry)?)), - ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry)?)), - ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry)?)), - ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry)?)), - ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry)?)), - ScalarFunction::Atanh => Ok(atanh(parse_expr(&args[0], registry)?)), - ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry)?)), - ScalarFunction::Degrees => Ok(degrees(parse_expr(&args[0], registry)?)), - ScalarFunction::Radians => Ok(radians(parse_expr(&args[0], registry)?)), - ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry)?)), - ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry)?)), - ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)), - ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)), + ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Atanh => { + Ok(atanh(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Degrees => { + Ok(degrees(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Radians => { + Ok(radians(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Log10 => { + Ok(log10(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Floor => { + Ok(floor(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Factorial => { - Ok(factorial(parse_expr(&args[0], registry)?)) + Ok(factorial(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)), + ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Round => Ok(round( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Trunc => Ok(trunc( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), - ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)), + ScalarFunction::Signum => { + Ok(signum(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::OctetLength => { - Ok(octet_length(parse_expr(&args[0], registry)?)) + Ok(octet_length(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Lower => { + Ok(lower(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Upper => { + Ok(upper(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::Ltrim => { + Ok(ltrim(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Rtrim => { + Ok(rtrim(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha224 => { + Ok(sha224(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Lower => Ok(lower(parse_expr(&args[0], registry)?)), - ScalarFunction::Upper => Ok(upper(parse_expr(&args[0], registry)?)), - ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry)?)), - ScalarFunction::Ltrim => Ok(ltrim(parse_expr(&args[0], registry)?)), - ScalarFunction::Rtrim => Ok(rtrim(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha224 => Ok(sha224(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha256 => Ok(sha256(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), - ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), - ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::Sha256 => { + Ok(sha256(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha384 => { + Ok(sha384(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Sha512 => { + Ok(sha512(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Digest => Ok(digest( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Ascii => Ok(ascii(parse_expr(&args[0], registry)?)), + ScalarFunction::Ascii => { + Ok(ascii(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::BitLength => { - Ok(bit_length(parse_expr(&args[0], registry)?)) + Ok(bit_length(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], registry)?)) + Ok(character_length(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry, codec)?)), + ScalarFunction::InitCap => { + Ok(initcap(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry)?)), - ScalarFunction::InitCap => Ok(initcap(parse_expr(&args[0], registry)?)), ScalarFunction::InStr => Ok(instr( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Gcd => Ok(gcd( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Lcm => Ok(lcm( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Left => Ok(left( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Random => Ok(random()), ScalarFunction::Uuid => Ok(uuid()), ScalarFunction::Repeat => Ok(repeat( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Replace => Ok(replace( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Reverse => Ok(reverse(parse_expr(&args[0], registry)?)), + ScalarFunction::Reverse => { + Ok(reverse(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::Right => Ok(right( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Concat => Ok(concat_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Lpad => Ok(lpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Rpad => Ok(rpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::RegexpLike => Ok(regexp_like( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::RegexpReplace => Ok(regexp_replace( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Btrim => Ok(btrim( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::SplitPart => Ok(split_part( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::StartsWith => Ok(starts_with( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::EndsWith => Ok(ends_with( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Substr => { if args.len() > 2 { assert_eq!(args.len(), 3); Ok(substring( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )) } else { Ok(substr( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )) } } ScalarFunction::Levenshtein => Ok(levenshtein( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), + ScalarFunction::ToHex => { + Ok(to_hex(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::MakeDate => { let args: Vec<_> = args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::>()?; Ok(Expr::ScalarFunction(expr::ScalarFunction::new( BuiltinScalarFunction::MakeDate, @@ -1669,7 +1775,7 @@ pub fn parse_expr( ScalarFunction::ToChar => { let args: Vec<_> = args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::>()?; Ok(Expr::ScalarFunction(expr::ScalarFunction::new( BuiltinScalarFunction::ToChar, @@ -1678,75 +1784,86 @@ pub fn parse_expr( } ScalarFunction::Now => Ok(now()), ScalarFunction::Translate => Ok(translate( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::Coalesce => Ok(coalesce( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::Pi => Ok(pi()), ScalarFunction::Power => Ok(power( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Log => Ok(log( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::FromUnixtime => { - Ok(from_unixtime(parse_expr(&args[0], registry)?)) + Ok(from_unixtime(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::Atan2 => Ok(atan2( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::CurrentDate => Ok(current_date()), ScalarFunction::CurrentTime => Ok(current_time()), - ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)), + ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), + ScalarFunction::Iszero => { + Ok(iszero(parse_expr(&args[0], registry, codec)?)) + } ScalarFunction::ArrowTypeof => { - Ok(arrow_typeof(parse_expr(&args[0], registry)?)) + Ok(arrow_typeof(parse_expr(&args[0], registry, codec)?)) + } + ScalarFunction::Flatten => { + Ok(flatten(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), ScalarFunction::StringToArray => Ok(string_to_array( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::OverLay => Ok(overlay( args.to_owned() .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), ScalarFunction::SubstrIndex => Ok(substr_index( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - parse_expr(&args[2], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, + parse_expr(&args[2], registry, codec)?, )), ScalarFunction::FindInSet => Ok(find_in_set( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, + parse_expr(&args[0], registry, codec)?, + parse_expr(&args[1], registry, codec)?, )), ScalarFunction::StructFun => { - Ok(struct_fun(parse_expr(&args[0], registry)?)) + Ok(struct_fun(parse_expr(&args[0], registry, codec)?)) } } } - ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { - let scalar_fn = registry.udf(fun_name.as_str())?; + ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { + fun_name, + args, + fun_definition, + }) => { + let scalar_fn = match fun_definition { + Some(buf) => codec.try_decode_udf(fun_name, buf)?, + None => registry.udf(fun_name.as_str())?, + }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, args.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))) } @@ -1757,11 +1874,11 @@ pub fn parse_expr( agg_fn, pb.args .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, false, - parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), - parse_vec_expr(&pb.order_by, registry)?, + parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), + parse_vec_expr(&pb.order_by, registry, codec)?, ))) } @@ -1772,7 +1889,7 @@ pub fn parse_expr( expr_list .expr .iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>() }) .collect::, Error>>()?, @@ -1780,13 +1897,13 @@ pub fn parse_expr( } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( expr.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))), ExprType::Rollup(RollupNode { expr }) => { Ok(Expr::GroupingSet(GroupingSet::Rollup( expr.iter() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| parse_expr(expr, registry, codec)) .collect::, Error>>()?, ))) } @@ -1854,10 +1971,13 @@ pub fn from_proto_binary_op(op: &str) -> Result { fn parse_vec_expr( p: &[protobuf::LogicalExprNode], registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, ) -> Result>, Error> { let res = p .iter() - .map(|elem| parse_expr(elem, registry).map_err(|e| plan_datafusion_err!("{}", e))) + .map(|elem| { + parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + }) .collect::>>()?; // Convert empty vector to None. Ok((!res.is_empty()).then_some(res)) @@ -1866,9 +1986,10 @@ fn parse_vec_expr( fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry).map(Some), + Some(expr) => parse_expr(expr, registry, codec).map(Some), None => Ok(None), } } @@ -1877,9 +1998,10 @@ fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, field: impl Into, + codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry), + Some(expr) => parse_expr(expr, registry, codec), None => Err(Error::required(field)), } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index f107af757a71..7c9ead27e3b5 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -17,6 +17,7 @@ use arrow::csv::WriterBuilder; use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; +use datafusion_expr::ScalarUDF; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -72,6 +73,8 @@ use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; +use self::to_proto::serialize_expr; + pub mod from_proto; pub mod to_proto; @@ -133,6 +136,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { node: Arc, buf: &mut Vec, ) -> Result<()>; + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug, Clone)] @@ -241,7 +252,9 @@ impl AsLogicalPlan for LogicalPlanNode { .chunks_exact(n_cols) .map(|r| { r.iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, from_proto::Error>>() }) .collect::, _>>() @@ -255,7 +268,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr: Vec = projection .expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let new_proj = project(input, expr)?; @@ -277,7 +290,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr: Expr = selection .expr .as_ref() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .transpose()? .ok_or_else(|| { DataFusionError::Internal("expression required".to_string()) @@ -291,7 +304,7 @@ impl AsLogicalPlan for LogicalPlanNode { let window_expr = window .window_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input).window(window_expr)?.build() } @@ -301,12 +314,12 @@ impl AsLogicalPlan for LogicalPlanNode { let group_expr = aggregate .group_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let aggr_expr = aggregate .aggr_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? @@ -328,7 +341,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let mut all_sort_orders = vec![]; @@ -336,7 +349,7 @@ impl AsLogicalPlan for LogicalPlanNode { let file_sort_order = order .logical_expr_nodes .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; all_sort_orders.push(file_sort_order) } @@ -436,7 +449,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let provider = extension_codec.try_decode_table_provider( &scan.custom_table_data, @@ -461,7 +474,7 @@ impl AsLogicalPlan for LogicalPlanNode { let sort_expr: Vec = sort .expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; LogicalPlanBuilder::from(input).sort(sort_expr)?.build() } @@ -483,7 +496,9 @@ impl AsLogicalPlan for LogicalPlanNode { }) => Partitioning::Hash( pb_hash_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, _>>()?, *partition_count as usize, ), @@ -527,7 +542,7 @@ impl AsLogicalPlan for LogicalPlanNode { let order_expr = expr .logical_expr_nodes .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; order_exprs.push(order_expr) } @@ -535,7 +550,7 @@ impl AsLogicalPlan for LogicalPlanNode { let mut column_defaults = HashMap::with_capacity(create_extern_table.column_defaults.len()); for (col_name, expr) in &create_extern_table.column_defaults { - let expr = from_proto::parse_expr(expr, ctx)?; + let expr = from_proto::parse_expr(expr, ctx, extension_codec)?; column_defaults.insert(col_name.clone(), expr); } @@ -663,12 +678,12 @@ impl AsLogicalPlan for LogicalPlanNode { let left_keys: Vec = join .left_join_key .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let right_keys: Vec = join .right_join_key .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { @@ -689,7 +704,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filter: Option = join .filter .as_ref() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; let builder = LogicalPlanBuilder::from(into_logical_plan!( @@ -769,12 +784,12 @@ impl AsLogicalPlan for LogicalPlanNode { let on_expr = distinct_on .on_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let select_expr = distinct_on .select_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; let sort_expr = match distinct_on.sort_expr.len() { 0 => None, @@ -782,7 +797,9 @@ impl AsLogicalPlan for LogicalPlanNode { distinct_on .sort_expr .iter() - .map(|expr| from_proto::parse_expr(expr, ctx)) + .map(|expr| { + from_proto::parse_expr(expr, ctx, extension_codec) + }) .collect::, _>>()?, ), }; @@ -944,7 +961,7 @@ impl AsLogicalPlan for LogicalPlanNode { let values_list = values .iter() .flatten() - .map(|v| v.try_into()) + .map(|v| serialize_expr(v, extension_codec)) .collect::, _>>()?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Values( @@ -982,7 +999,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters: Vec = filters .iter() - .map(|filter| filter.try_into()) + .map(|filter| serialize_expr(filter, extension_codec)) .collect::, _>>()?; if let Some(listing_table) = source.downcast_ref::() { @@ -1039,7 +1056,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr_vec = LogicalExprNodeCollection { logical_expr_nodes: order .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, }; exprs_vec.push(expr_vec); @@ -1120,7 +1137,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), expr: expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, optional_alias: None, }, @@ -1137,7 +1154,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some((&filter.predicate).try_into()?), + expr: Some(serialize_expr( + &filter.predicate, + extension_codec, + )?), }, ))), }) @@ -1172,7 +1192,7 @@ impl AsLogicalPlan for LogicalPlanNode { None => vec![], Some(sort_expr) => sort_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }; Ok(protobuf::LogicalPlanNode { @@ -1180,11 +1200,11 @@ impl AsLogicalPlan for LogicalPlanNode { protobuf::DistinctOnNode { on_expr: on_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, select_expr: select_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, sort_expr, input: Some(Box::new(input)), @@ -1206,7 +1226,7 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new(input)), window_expr: window_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }, ))), @@ -1229,11 +1249,11 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new(input)), group_expr: group_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, aggr_expr: aggr_expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, _>>()?, }, ))), @@ -1261,7 +1281,12 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let (left_join_key, right_join_key) = on .iter() - .map(|(l, r)| Ok((l.try_into()?, r.try_into()?))) + .map(|(l, r)| { + Ok(( + serialize_expr(l, extension_codec)?, + serialize_expr(r, extension_codec)?, + )) + }) .collect::, to_proto::Error>>()? .into_iter() .unzip(); @@ -1270,7 +1295,7 @@ impl AsLogicalPlan for LogicalPlanNode { join_constraint.to_owned().into(); let filter = filter .as_ref() - .map(|e| e.try_into()) + .map(|e| serialize_expr(e, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( @@ -1329,7 +1354,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let selection_expr: Vec = expr .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( @@ -1361,7 +1386,7 @@ impl AsLogicalPlan for LogicalPlanNode { PartitionMethod::Hash(protobuf::HashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.try_into()) + .map(|expr| serialize_expr(expr, extension_codec)) .collect::, to_proto::Error>>()?, partition_count: *partition_count as u64, }) @@ -1416,9 +1441,8 @@ impl AsLogicalPlan for LogicalPlanNode { let temp = LogicalExprNodeCollection { logical_expr_nodes: order .iter() - .map(|expr| expr.try_into()) - .collect::, to_proto::Error>>( - )?, + .map(|expr| serialize_expr(expr, extension_codec)) + .collect::, to_proto::Error>>()?, }; converted_order_exprs.push(temp); } @@ -1426,7 +1450,8 @@ impl AsLogicalPlan for LogicalPlanNode { let mut converted_column_defaults = HashMap::with_capacity(column_defaults.len()); for (col_name, expr) in column_defaults { - converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + converted_column_defaults + .insert(col_name.clone(), serialize_expr(expr, extension_codec)?); } let file_compression_type = diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 591ee796173f..d875848a284c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -56,6 +56,8 @@ use datafusion_expr::{ TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use super::LogicalExtensionCodec; + #[derive(Debug)] pub enum Error { General(String), @@ -480,615 +482,612 @@ impl TryFrom<&WindowFrame> for protobuf::WindowFrame { } } -impl TryFrom<&Expr> for protobuf::LogicalExprNode { - type Error = Error; +pub fn serialize_expr( + expr: &Expr, + codec: &dyn LogicalExtensionCodec, +) -> Result { + use protobuf::logical_expr_node::ExprType; + + let expr_node = match expr { + Expr::Column(c) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Column(c.into())), + }, + Expr::Alias(Alias { + expr, + relation, + name, + }) => { + let alias = Box::new(protobuf::AliasNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), + alias: name.to_owned(), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Alias(alias)), + } + } + Expr::Literal(value) => { + let pb_value: protobuf::ScalarValue = value.try_into()?; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Literal(pb_value)), + } + } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + // Try to linerize a nested binary expression tree of the same operator + // into a flat vector of expressions. + let mut exprs = vec![right.as_ref()]; + let mut current_expr = left.as_ref(); + while let Expr::BinaryExpr(BinaryExpr { + left, + op: current_op, + right, + }) = current_expr + { + if current_op == op { + exprs.push(right.as_ref()); + current_expr = left.as_ref(); + } else { + break; + } + } + exprs.push(current_expr); - fn try_from(expr: &Expr) -> Result { - use protobuf::logical_expr_node::ExprType; + let binary_expr = protobuf::BinaryExprNode { + // We need to reverse exprs since operands are expected to be + // linearized from left innermost to right outermost (but while + // traversing the chain we do the exact opposite). + operands: exprs + .into_iter() + .rev() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + op: format!("{op:?}"), + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::BinaryExpr(binary_expr)), + } + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + if *case_insensitive { + let pb = Box::new(protobuf::ILikeNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), + }); - let expr_node = match expr { - Expr::Column(c) => Self { - expr_type: Some(ExprType::Column(c.into())), - }, - Expr::Alias(Alias { - expr, - relation, - name, - }) => { - let alias = Box::new(protobuf::AliasNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - relation: relation - .to_owned() - .map(|r| vec![r.into()]) - .unwrap_or(vec![]), - alias: name.to_owned(), + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Ilike(pb)), + } + } else { + let pb = Box::new(protobuf::LikeNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), }); - Self { - expr_type: Some(ExprType::Alias(alias)), + + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Like(pb)), } } - Expr::Literal(value) => { - let pb_value: protobuf::ScalarValue = value.try_into()?; - Self { - expr_type: Some(ExprType::Literal(pb_value)), - } + } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }) => { + let pb = Box::new(protobuf::SimilarToNode { + negated: *negated, + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + pattern: Some(Box::new(serialize_expr(pattern.as_ref(), codec)?)), + escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::SimilarTo(pb)), } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // Try to linerize a nested binary expression tree of the same operator - // into a flat vector of expressions. - let mut exprs = vec![right.as_ref()]; - let mut current_expr = left.as_ref(); - while let Expr::BinaryExpr(BinaryExpr { - left, - op: current_op, - right, - }) = current_expr - { - if current_op == op { - exprs.push(right.as_ref()); - current_expr = left.as_ref(); - } else { - break; - } + } + Expr::WindowFunction(expr::WindowFunction { + ref fun, + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }) => { + let window_function = match fun { + WindowFunctionDefinition::AggregateFunction(fun) => { + protobuf::window_expr_node::WindowFunction::AggrFunction( + protobuf::AggregateFunction::from(fun).into(), + ) } - exprs.push(current_expr); - - let binary_expr = protobuf::BinaryExprNode { - // We need to reverse exprs since operands are expected to be - // linearized from left innermost to right outermost (but while - // traversing the chain we do the exact opposite). - operands: exprs - .into_iter() - .rev() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - op: format!("{op:?}"), - }; - Self { - expr_type: Some(ExprType::BinaryExpr(binary_expr)), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + protobuf::window_expr_node::WindowFunction::BuiltInFunction( + protobuf::BuiltInWindowFunction::from(fun).into(), + ) } - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { - if *case_insensitive { - let pb = Box::new(protobuf::ILikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char - .map(|ch| ch.to_string()) - .unwrap_or_default(), - }); - - Self { - expr_type: Some(ExprType::Ilike(pb)), - } - } else { - let pb = Box::new(protobuf::LikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char - .map(|ch| ch.to_string()) - .unwrap_or_default(), - }); - - Self { - expr_type: Some(ExprType::Like(pb)), - } + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), + ) } - } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { - let pb = Box::new(protobuf::SimilarToNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), - }); - Self { - expr_type: Some(ExprType::SimilarTo(pb)), + WindowFunctionDefinition::WindowUDF(window_udf) => { + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), + ) } + }; + let arg_expr: Option> = if !args.is_empty() { + let arg = &args[0]; + Some(Box::new(serialize_expr(arg, codec)?)) + } else { + None + }; + let partition_by = partition_by + .iter() + .map(|e| serialize_expr(e, codec)) + .collect::, _>>()?; + let order_by = order_by + .iter() + .map(|e| serialize_expr(e, codec)) + .collect::, _>>()?; + + let window_frame: Option = + Some(window_frame.try_into()?); + let window_expr = Box::new(protobuf::WindowExprNode { + expr: arg_expr, + window_function: Some(window_function), + partition_by, + order_by, + window_frame, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::WindowExpr(window_expr)), } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }) => { - let window_function = match fun { - WindowFunctionDefinition::AggregateFunction(fun) => { - protobuf::window_expr_node::WindowFunction::AggrFunction( - protobuf::AggregateFunction::from(fun).into(), - ) + } + Expr::AggregateFunction(expr::AggregateFunction { + ref func_def, + ref args, + ref distinct, + ref filter, + ref order_by, + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont } - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { - protobuf::window_expr_node::WindowFunction::BuiltInFunction( - protobuf::BuiltInWindowFunction::from(fun).into(), - ) + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight } - WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), - ) + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop } - WindowFunctionDefinition::WindowUDF(window_udf) => { - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), - ) + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::NthValue => { + protobuf::AggregateFunction::NthValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg } }; - let arg_expr: Option> = if !args.is_empty() { - let arg = &args[0]; - Some(Box::new(arg.try_into()?)) - } else { - None + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| serialize_expr(v, codec)) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e, codec)?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, _>>()?, + None => vec![], + }, }; - let partition_by = partition_by - .iter() - .map(|e| e.try_into()) - .collect::, _>>()?; - let order_by = order_by - .iter() - .map(|e| e.try_into()) - .collect::, _>>()?; - - let window_frame: Option = - Some(window_frame.try_into()?); - let window_expr = Box::new(protobuf::WindowExprNode { - expr: arg_expr, - window_function: Some(window_function), - partition_by, - order_by, - window_frame, - }); - Self { - expr_type: Some(ExprType::WindowExpr(window_expr)), + protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), } } - Expr::AggregateFunction(expr::AggregateFunction { - ref func_def, - ref args, - ref distinct, - ref filter, - ref order_by, - }) => { - match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } - AggregateFunction::NthValue => { - protobuf::AggregateFunction::NthValueAgg - } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg - } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args + AggregateFunctionDefinition::UDF(fun) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + order_by: match order_by { + Some(e) => e .iter() - .map(|v| v.try_into()) + .map(|expr| serialize_expr(expr, codec)) .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new( - aggregate_expr, - ))), - } - } - AggregateFunctionDefinition::UDF(fun) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), + None => vec![], + }, }, - AggregateFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + ))), + }, + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); - } - } } + }, - Expr::ScalarVariable(_, _) => { - return Err(Error::General( - "Proto serialization error: Scalar Variable not supported" - .to_string(), - )) - } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let args = args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?; - match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, - }, - )), - } + Expr::ScalarVariable(_, _) => { + return Err(Error::General( + "Proto serialization error: Scalar Variable not supported".to_string(), + )) + } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), } - ScalarFunctionDefinition::UDF(fun) => Self { + } + ScalarFunctionDefinition::UDF(fun) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udf(fun.as_ref(), &mut buf); + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; + + protobuf::LogicalExprNode { expr_type: Some(ExprType::ScalarUdfExpr( protobuf::ScalarUdfExprNode { fun_name: fun.name().to_string(), + fun_definition, args, }, )), - }, - ScalarFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + } + } + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); - } } } - Expr::Not(expr) => { - let expr = Box::new(protobuf::Not { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::NotExpr(expr)), - } + } + Expr::Not(expr) => { + let expr = Box::new(protobuf::Not { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::NotExpr(expr)), } - Expr::IsNull(expr) => { - let expr = Box::new(protobuf::IsNull { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNullExpr(expr)), - } + } + Expr::IsNull(expr) => { + let expr = Box::new(protobuf::IsNull { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNullExpr(expr)), } - Expr::IsNotNull(expr) => { - let expr = Box::new(protobuf::IsNotNull { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotNullExpr(expr)), - } + } + Expr::IsNotNull(expr) => { + let expr = Box::new(protobuf::IsNotNull { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotNullExpr(expr)), } - Expr::IsTrue(expr) => { - let expr = Box::new(protobuf::IsTrue { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsTrue(expr)), - } + } + Expr::IsTrue(expr) => { + let expr = Box::new(protobuf::IsTrue { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsTrue(expr)), } - Expr::IsFalse(expr) => { - let expr = Box::new(protobuf::IsFalse { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsFalse(expr)), - } + } + Expr::IsFalse(expr) => { + let expr = Box::new(protobuf::IsFalse { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsFalse(expr)), } - Expr::IsUnknown(expr) => { - let expr = Box::new(protobuf::IsUnknown { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsUnknown(expr)), - } + } + Expr::IsUnknown(expr) => { + let expr = Box::new(protobuf::IsUnknown { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsUnknown(expr)), } - Expr::IsNotTrue(expr) => { - let expr = Box::new(protobuf::IsNotTrue { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotTrue(expr)), - } + } + Expr::IsNotTrue(expr) => { + let expr = Box::new(protobuf::IsNotTrue { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotTrue(expr)), } - Expr::IsNotFalse(expr) => { - let expr = Box::new(protobuf::IsNotFalse { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotFalse(expr)), - } + } + Expr::IsNotFalse(expr) => { + let expr = Box::new(protobuf::IsNotFalse { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotFalse(expr)), } - Expr::IsNotUnknown(expr) => { - let expr = Box::new(protobuf::IsNotUnknown { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::IsNotUnknown(expr)), - } + } + Expr::IsNotUnknown(expr) => { + let expr = Box::new(protobuf::IsNotUnknown { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::IsNotUnknown(expr)), } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - let expr = Box::new(protobuf::BetweenNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - negated: *negated, - low: Some(Box::new(low.as_ref().try_into()?)), - high: Some(Box::new(high.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::Between(expr)), - } + } + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + let expr = Box::new(protobuf::BetweenNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + negated: *negated, + low: Some(Box::new(serialize_expr(low.as_ref(), codec)?)), + high: Some(Box::new(serialize_expr(high.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Between(expr)), } - Expr::Case(case) => { - let when_then_expr = case - .when_then_expr - .iter() - .map(|(w, t)| { - Ok(protobuf::WhenThen { - when_expr: Some(w.as_ref().try_into()?), - then_expr: Some(t.as_ref().try_into()?), - }) + } + Expr::Case(case) => { + let when_then_expr = case + .when_then_expr + .iter() + .map(|(w, t)| { + Ok(protobuf::WhenThen { + when_expr: Some(serialize_expr(w.as_ref(), codec)?), + then_expr: Some(serialize_expr(t.as_ref(), codec)?), }) - .collect::, Error>>()?; - let expr = Box::new(protobuf::CaseNode { - expr: match &case.expr { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - when_then_expr, - else_expr: match &case.else_expr { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - }); - Self { - expr_type: Some(ExprType::Case(expr)), - } + }) + .collect::, Error>>()?; + let expr = Box::new(protobuf::CaseNode { + expr: match &case.expr { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + when_then_expr, + else_expr: match &case.else_expr { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Case(expr)), } - Expr::Cast(Cast { expr, data_type }) => { - let expr = Box::new(protobuf::CastNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - arrow_type: Some(data_type.try_into()?), - }); - Self { - expr_type: Some(ExprType::Cast(expr)), - } + } + Expr::Cast(Cast { expr, data_type }) => { + let expr = Box::new(protobuf::CastNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + arrow_type: Some(data_type.try_into()?), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Cast(expr)), } - Expr::TryCast(TryCast { expr, data_type }) => { - let expr = Box::new(protobuf::TryCastNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - arrow_type: Some(data_type.try_into()?), - }); - Self { - expr_type: Some(ExprType::TryCast(expr)), - } + } + Expr::TryCast(TryCast { expr, data_type }) => { + let expr = Box::new(protobuf::TryCastNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + arrow_type: Some(data_type.try_into()?), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::TryCast(expr)), } - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = Box::new(protobuf::SortExprNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - asc: *asc, - nulls_first: *nulls_first, - }); - Self { - expr_type: Some(ExprType::Sort(expr)), - } + } + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => { + let expr = Box::new(protobuf::SortExprNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + asc: *asc, + nulls_first: *nulls_first, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Sort(expr)), } - Expr::Negative(expr) => { - let expr = Box::new(protobuf::NegativeNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - }); - Self { - expr_type: Some(ExprType::Negative(expr)), - } + } + Expr::Negative(expr) => { + let expr = Box::new(protobuf::NegativeNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Negative(expr)), } - Expr::Unnest(Unnest { exprs }) => { - let expr = protobuf::Unnest { - exprs: exprs.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - }; - Self { - expr_type: Some(ExprType::Unnest(expr)), - } + } + Expr::Unnest(Unnest { exprs }) => { + let expr = protobuf::Unnest { + exprs: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Unnest(expr)), } - Expr::InList(InList { - expr, - list, - negated, - }) => { - let expr = Box::new(protobuf::InListNode { - expr: Some(Box::new(expr.as_ref().try_into()?)), - list: list - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - negated: *negated, - }); - Self { - expr_type: Some(ExprType::InList(expr)), - } + } + Expr::InList(InList { + expr, + list, + negated, + }) => { + let expr = Box::new(protobuf::InListNode { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + list: list + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + negated: *negated, + }); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::InList(expr)), } - Expr::Wildcard { qualifier } => Self { - expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.clone().unwrap_or("".to_string()), - })), - }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); - } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let field = match field { - GetFieldAccess::NamedStructField { name } => { - protobuf::get_indexed_field::Field::NamedStructField( - protobuf::NamedStructField { - name: Some(name.try_into()?), - }, - ) - } - GetFieldAccess::ListIndex { key } => { - protobuf::get_indexed_field::Field::ListIndex(Box::new( - protobuf::ListIndex { - key: Some(Box::new(key.as_ref().try_into()?)), - }, - )) - } - GetFieldAccess::ListRange { - start, - stop, - stride, - } => protobuf::get_indexed_field::Field::ListRange(Box::new( - protobuf::ListRange { - start: Some(Box::new(start.as_ref().try_into()?)), - stop: Some(Box::new(stop.as_ref().try_into()?)), - stride: Some(Box::new(stride.as_ref().try_into()?)), + } + Expr::Wildcard { qualifier } => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { + qualifier: qualifier.clone().unwrap_or("".to_string()), + })), + }, + Expr::ScalarSubquery(_) + | Expr::InSubquery(_) + | Expr::Exists { .. } + | Expr::OuterReferenceColumn { .. } => { + // we would need to add logical plan operators to datafusion.proto to support this + // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 + return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); + } + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let field = match field { + GetFieldAccess::NamedStructField { name } => { + protobuf::get_indexed_field::Field::NamedStructField( + protobuf::NamedStructField { + name: Some(name.try_into()?), }, - )), - }; - - Self { - expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - expr: Some(Box::new(expr.as_ref().try_into()?)), - field: Some(field), + ) + } + GetFieldAccess::ListIndex { key } => { + protobuf::get_indexed_field::Field::ListIndex(Box::new( + protobuf::ListIndex { + key: Some(Box::new(serialize_expr(key.as_ref(), codec)?)), }, - ))), + )) } + GetFieldAccess::ListRange { + start, + stop, + stride, + } => protobuf::get_indexed_field::Field::ListRange(Box::new( + protobuf::ListRange { + start: Some(Box::new(serialize_expr(start.as_ref(), codec)?)), + stop: Some(Box::new(serialize_expr(stop.as_ref(), codec)?)), + stride: Some(Box::new(serialize_expr(stride.as_ref(), codec)?)), + }, + )), + }; + + protobuf::LogicalExprNode { + expr_type: Some(ExprType::GetIndexedField(Box::new( + protobuf::GetIndexedField { + expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), + field: Some(field), + }, + ))), } + } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self { - expr_type: Some(ExprType::Cube(CubeNode { - expr: exprs.iter().map(|expr| expr.try_into()).collect::, - Self::Error, - >>( - )?, - })), - }, - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => Self { - expr_type: Some(ExprType::Rollup(RollupNode { - expr: exprs.iter().map(|expr| expr.try_into()).collect::, - Self::Error, - >>( - )?, - })), - }, - Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => Self { + Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Cube(CubeNode { + expr: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + })), + }, + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::Rollup(RollupNode { + expr: exprs + .iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, + })), + }, + Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => { + protobuf::LogicalExprNode { expr_type: Some(ExprType::GroupingSet(GroupingSetNode { expr: exprs .iter() @@ -1096,29 +1095,29 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Ok(LogicalExprList { expr: expr_list .iter() - .map(|expr| expr.try_into()) - .collect::, Self::Error>>()?, + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>()?, }) }) - .collect::, Self::Error>>()?, + .collect::, Error>>()?, })), - }, - Expr::Placeholder(Placeholder { id, data_type }) => { - let data_type = match data_type { - Some(data_type) => Some(data_type.try_into()?), - None => None, - }; - Self { - expr_type: Some(ExprType::Placeholder(PlaceholderNode { - id: id.clone(), - data_type, - })), - } } - }; + } + Expr::Placeholder(Placeholder { id, data_type }) => { + let data_type = match data_type { + Some(data_type) => Some(data_type.try_into()?), + None => None, + }; + protobuf::LogicalExprNode { + expr_type: Some(ExprType::Placeholder(PlaceholderNode { + id: id.clone(), + data_type, + })), + } + } + }; - Ok(expr_node) - } + Ok(expr_node) } impl TryFrom<&ScalarValue> for protobuf::ScalarValue { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d2961875d89a..a20baeb4e941 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -60,6 +60,7 @@ use datafusion::physical_plan::{ WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::ScalarUDF; use prost::bytes::BufMut; use prost::Message; @@ -1911,6 +1912,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { ) -> Result>; fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()>; + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e3bd2cb1dc47..0ec44190ef7a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -28,6 +28,8 @@ use arrow::datatypes::{ }; use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; +use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; +use datafusion_proto::logical_plan::to_proto::serialize_expr; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -62,8 +64,8 @@ use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; -use datafusion_proto::logical_plan::from_proto; use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; use datafusion_proto::protobuf; #[cfg(feature = "json")] @@ -78,13 +80,15 @@ fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test // equality. -fn roundtrip_expr_test(initial_struct: T, ctx: SessionContext) -where - for<'a> &'a T: TryInto + Debug, - E: Debug, -{ - let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx).unwrap(); +fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) { + let extension_codec = DefaultLogicalExtensionCodec {}; + let proto: protobuf::LogicalExprNode = + match serialize_expr(&initial_struct, &extension_codec) { + Ok(p) => p, + Err(e) => panic!("Error serializing expression: {:?}", e), + }; + let round_trip: Expr = + from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -631,6 +635,12 @@ pub mod proto { #[prost(uint64, tag = "1")] pub k: u64, } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } } #[derive(PartialEq, Eq, Hash)] @@ -707,7 +717,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { let node = TopKPlanNode::new( proto.k as usize, input.clone(), - from_proto::parse_expr(expr, ctx)?, + from_proto::parse_expr(expr, ctx, self)?, ); Ok(Extension { @@ -725,7 +735,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { if let Some(exec) = node.node.as_any().downcast_ref::() { let proto = proto::TopKPlanProto { k: exec.k as u64, - expr: Some((&exec.expr).try_into()?), + expr: Some(serialize_expr(&exec.expr, self)?), }; proto.encode(buf).map_err(|e| { @@ -756,6 +766,109 @@ impl LogicalExtensionCodec for TopKExtensionCodec { } } +#[derive(Debug)] +struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, +} + +impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } +} + +/// Implement the ScalarUDFImpl trait for MyRegexUdf +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } +} + +#[derive(Debug)] +pub struct ScalarUDFExtensionCodec {} + +impl LogicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + not_impl_err!("No extension codec provided") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + internal_err!("unsupported plan type") + } + + fn try_encode_table_provider( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + internal_err!("unsupported plan type") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode regex_udf: {}", err)) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + let udf = binding.as_any().downcast_ref::().unwrap(); + let proto = proto::MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + Ok(()) + } +} + #[test] fn round_trip_scalar_values() { let should_pass: Vec = vec![ @@ -1664,6 +1777,30 @@ fn roundtrip_scalar_udf() { roundtrip_expr_test(test_expr, ctx); } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()), vec![])); + + let ctx = SessionContext::new(); + ctx.register_udf(udf); + + let extension_codec = ScalarUDFExtensionCodec {}; + let proto: protobuf::LogicalExprNode = + match serialize_expr(&test_expr, &extension_codec) { + Ok(p) => p, + Err(e) => panic!("Error serializing expression: {:?}", e), + }; + let round_trip: Expr = + from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + + assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); + + roundtrip_json_test(&proto); +} + #[test] fn roundtrip_grouping_sets() { let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index 7dd0333909ee..d4a1ab44a6ea 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -25,6 +25,8 @@ use datafusion::prelude::SessionContext; use datafusion_expr::{col, create_udf, lit, ColumnarValue}; use datafusion_expr::{Expr, Volatility}; use datafusion_proto::bytes::Serializeable; +use datafusion_proto::logical_plan::to_proto::serialize_expr; +use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; #[test] #[should_panic( @@ -252,7 +254,6 @@ fn test_expression_serialization_roundtrip() { use datafusion_expr::expr::ScalarFunction; use datafusion_expr::BuiltinScalarFunction; use datafusion_proto::logical_plan::from_proto::parse_expr; - use datafusion_proto::protobuf::LogicalExprNode; use strum::IntoEnumIterator; let ctx = SessionContext::new(); @@ -266,8 +267,9 @@ fn test_expression_serialization_roundtrip() { let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args)); - let proto = LogicalExprNode::try_from(&expr).unwrap(); - let deserialize = parse_expr(&proto, &ctx).unwrap(); + let extension_codec = DefaultLogicalExtensionCodec {}; + let proto = serialize_expr(&expr, &extension_codec).unwrap(); + let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize); diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 12c4c96d5236..c348f2cddc93 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -39,6 +39,7 @@ chrono = { workspace = true, optional = true } clap = { version = "4.4.8", features = ["derive", "env"] } datafusion = { workspace = true, default-features = true } datafusion-common = { workspace = true, default-features = true } +datafusion-common-runtime = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } itertools = { workspace = true } diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 41c33deec643..268d09681c72 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -28,6 +28,7 @@ use log::info; use sqllogictest::strict_column_validator; use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion_common_runtime::SpawnedTask; const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; @@ -88,8 +89,7 @@ async fn run_tests() -> Result<()> { // modifying shared state like `/tmp/`) let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) .map(|test_file| { - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests - tokio::task::spawn(async move { + SpawnedTask::spawn(async move { println!("Running {:?}", test_file.relative_path); if options.complete { run_complete_file(test_file).await?; @@ -100,6 +100,7 @@ async fn run_tests() -> Result<()> { } Ok(()) as Result<()> }) + .join() }) // run up to num_cpus streams in parallel .buffer_unordered(num_cpus::get()) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4e6cb4d59d14..5065d9b9a73b 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -883,6 +883,11 @@ select arrow_cast([1, 2, 3], 'LargeList(Int64)')[0:0], ---- [] [1, 2] [h, e, l, l, o] +query I +select arrow_cast([1, 2, 3], 'LargeList(Int64)')[1]; +---- +1 + # TODO: support multiple negative index # multiple index with columns #3 (negative index) # query II diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a3e97d6a7d82..48b5a0af7253 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -941,6 +941,25 @@ select round(sqrt(a), 5), round(sqrt(b), 5), round(sqrt(c), 5) from signed_integ NaN 10 NaN NaN 100 NaN +# sqrt scalar fraction +query RR rowsort +select sqrt(1.4), sqrt(2.0/3); +---- +1.18321595662 0.816496580928 + +# sqrt scalar cast +query R rowsort +select sqrt(cast(10e8 as double)); +---- +31622.776601683792 + + +# sqrt scalar negative +query R rowsort +select sqrt(-1); +---- +NaN + ## tan # tan scalar function diff --git a/docs/logos/Datafusion_Branding_Guideline.pdf b/docs/logos/Datafusion_Branding_Guideline.pdf new file mode 100644 index 000000000000..dcf0a09dba9f Binary files /dev/null and b/docs/logos/Datafusion_Branding_Guideline.pdf differ diff --git a/docs/logos/DataFUSION-Logo-Dark.svg b/docs/logos/old_logo/DataFUSION-Logo-Dark.svg similarity index 100% rename from docs/logos/DataFUSION-Logo-Dark.svg rename to docs/logos/old_logo/DataFUSION-Logo-Dark.svg diff --git a/docs/logos/DataFUSION-Logo-Dark@2x.png b/docs/logos/old_logo/DataFUSION-Logo-Dark@2x.png similarity index 100% rename from docs/logos/DataFUSION-Logo-Dark@2x.png rename to docs/logos/old_logo/DataFUSION-Logo-Dark@2x.png diff --git a/docs/logos/DataFUSION-Logo-Dark@4x.png b/docs/logos/old_logo/DataFUSION-Logo-Dark@4x.png similarity index 100% rename from docs/logos/DataFUSION-Logo-Dark@4x.png rename to docs/logos/old_logo/DataFUSION-Logo-Dark@4x.png diff --git a/docs/logos/DataFUSION-Logo-Light.svg b/docs/logos/old_logo/DataFUSION-Logo-Light.svg similarity index 100% rename from docs/logos/DataFUSION-Logo-Light.svg rename to docs/logos/old_logo/DataFUSION-Logo-Light.svg diff --git a/docs/logos/DataFUSION-Logo-Light@2x.png b/docs/logos/old_logo/DataFUSION-Logo-Light@2x.png similarity index 100% rename from docs/logos/DataFUSION-Logo-Light@2x.png rename to docs/logos/old_logo/DataFUSION-Logo-Light@2x.png diff --git a/docs/logos/DataFUSION-Logo-Light@4x.png b/docs/logos/old_logo/DataFUSION-Logo-Light@4x.png similarity index 100% rename from docs/logos/DataFUSION-Logo-Light@4x.png rename to docs/logos/old_logo/DataFUSION-Logo-Light@4x.png diff --git a/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf b/docs/logos/old_logo/DataFusion-LogoAndColorPaletteExploration_v01.pdf similarity index 100% rename from docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf rename to docs/logos/old_logo/DataFusion-LogoAndColorPaletteExploration_v01.pdf diff --git a/docs/logos/primary_mark/black.png b/docs/logos/primary_mark/black.png new file mode 100644 index 000000000000..053a798720d8 Binary files /dev/null and b/docs/logos/primary_mark/black.png differ diff --git a/docs/logos/primary_mark/black.svg b/docs/logos/primary_mark/black.svg new file mode 100644 index 000000000000..0b0a890e1eec --- /dev/null +++ b/docs/logos/primary_mark/black.svg @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/logos/primary_mark/black2x.png b/docs/logos/primary_mark/black2x.png new file mode 100644 index 000000000000..18ce390da8b2 Binary files /dev/null and b/docs/logos/primary_mark/black2x.png differ diff --git a/docs/logos/primary_mark/black4x.png b/docs/logos/primary_mark/black4x.png new file mode 100644 index 000000000000..cfcbd9c8ed59 Binary files /dev/null and b/docs/logos/primary_mark/black4x.png differ diff --git a/docs/logos/primary_mark/mixed.png b/docs/logos/primary_mark/mixed.png new file mode 100644 index 000000000000..4a24495f879a Binary files /dev/null and b/docs/logos/primary_mark/mixed.png differ diff --git a/docs/logos/primary_mark/mixed.svg b/docs/logos/primary_mark/mixed.svg new file mode 100644 index 000000000000..306450bbbf58 --- /dev/null +++ b/docs/logos/primary_mark/mixed.svg @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/logos/primary_mark/mixed2x.png b/docs/logos/primary_mark/mixed2x.png new file mode 100644 index 000000000000..16e1f5687127 Binary files /dev/null and b/docs/logos/primary_mark/mixed2x.png differ diff --git a/docs/logos/primary_mark/mixed4x.png b/docs/logos/primary_mark/mixed4x.png new file mode 100644 index 000000000000..ada80821508f Binary files /dev/null and b/docs/logos/primary_mark/mixed4x.png differ diff --git a/docs/logos/primary_mark/original.png b/docs/logos/primary_mark/original.png new file mode 100644 index 000000000000..687f946760b0 Binary files /dev/null and b/docs/logos/primary_mark/original.png differ diff --git a/docs/logos/primary_mark/original.svg b/docs/logos/primary_mark/original.svg new file mode 100644 index 000000000000..6ba0ece995a3 --- /dev/null +++ b/docs/logos/primary_mark/original.svg @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/logos/primary_mark/original2x.png b/docs/logos/primary_mark/original2x.png new file mode 100644 index 000000000000..a7402109b211 Binary files /dev/null and b/docs/logos/primary_mark/original2x.png differ diff --git a/docs/logos/primary_mark/original4x.png b/docs/logos/primary_mark/original4x.png new file mode 100644 index 000000000000..ae1000635cc6 Binary files /dev/null and b/docs/logos/primary_mark/original4x.png differ diff --git a/docs/logos/primary_mark/white.png b/docs/logos/primary_mark/white.png new file mode 100644 index 000000000000..cdb66f1f7c10 Binary files /dev/null and b/docs/logos/primary_mark/white.png differ diff --git a/docs/logos/primary_mark/white.svg b/docs/logos/primary_mark/white.svg new file mode 100644 index 000000000000..6f900590ce39 --- /dev/null +++ b/docs/logos/primary_mark/white.svg @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/logos/primary_mark/white2x.png b/docs/logos/primary_mark/white2x.png new file mode 100644 index 000000000000..d54606e667e4 Binary files /dev/null and b/docs/logos/primary_mark/white2x.png differ diff --git a/docs/logos/primary_mark/white4x.png b/docs/logos/primary_mark/white4x.png new file mode 100644 index 000000000000..bc867fb1b92b Binary files /dev/null and b/docs/logos/primary_mark/white4x.png differ diff --git a/docs/logos/standalone_logo/logo_black.png b/docs/logos/standalone_logo/logo_black.png new file mode 100644 index 000000000000..46cfd58e0d61 Binary files /dev/null and b/docs/logos/standalone_logo/logo_black.png differ diff --git a/docs/logos/standalone_logo/logo_black.svg b/docs/logos/standalone_logo/logo_black.svg new file mode 100644 index 000000000000..f82a47e1cf6d --- /dev/null +++ b/docs/logos/standalone_logo/logo_black.svg @@ -0,0 +1,4 @@ + + + + diff --git a/docs/logos/standalone_logo/logo_black2x.png b/docs/logos/standalone_logo/logo_black2x.png new file mode 100644 index 000000000000..34731a637736 Binary files /dev/null and b/docs/logos/standalone_logo/logo_black2x.png differ diff --git a/docs/logos/standalone_logo/logo_black4x.png b/docs/logos/standalone_logo/logo_black4x.png new file mode 100644 index 000000000000..6a6ee3c06fad Binary files /dev/null and b/docs/logos/standalone_logo/logo_black4x.png differ diff --git a/docs/logos/standalone_logo/logo_original.png b/docs/logos/standalone_logo/logo_original.png new file mode 100644 index 000000000000..381265e62d7b Binary files /dev/null and b/docs/logos/standalone_logo/logo_original.png differ diff --git a/docs/logos/standalone_logo/logo_original.svg b/docs/logos/standalone_logo/logo_original.svg new file mode 100644 index 000000000000..bf174719bcf2 --- /dev/null +++ b/docs/logos/standalone_logo/logo_original.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/docs/logos/standalone_logo/logo_original2x.png b/docs/logos/standalone_logo/logo_original2x.png new file mode 100644 index 000000000000..7d5b25bd2e8b Binary files /dev/null and b/docs/logos/standalone_logo/logo_original2x.png differ diff --git a/docs/logos/standalone_logo/logo_original4x.png b/docs/logos/standalone_logo/logo_original4x.png new file mode 100644 index 000000000000..10dd50978e37 Binary files /dev/null and b/docs/logos/standalone_logo/logo_original4x.png differ diff --git a/docs/logos/standalone_logo/logo_white.png b/docs/logos/standalone_logo/logo_white.png new file mode 100644 index 000000000000..a48ef890d6f4 Binary files /dev/null and b/docs/logos/standalone_logo/logo_white.png differ diff --git a/docs/logos/standalone_logo/logo_white.svg b/docs/logos/standalone_logo/logo_white.svg new file mode 100644 index 000000000000..9f1954ed82e5 --- /dev/null +++ b/docs/logos/standalone_logo/logo_white.svg @@ -0,0 +1,4 @@ + + + + diff --git a/docs/logos/standalone_logo/logo_white2x.png b/docs/logos/standalone_logo/logo_white2x.png new file mode 100644 index 000000000000..c26de0fe5a5c Binary files /dev/null and b/docs/logos/standalone_logo/logo_white2x.png differ diff --git a/docs/logos/standalone_logo/logo_white4x.png b/docs/logos/standalone_logo/logo_white4x.png new file mode 100644 index 000000000000..22bbb4892ed7 Binary files /dev/null and b/docs/logos/standalone_logo/logo_white4x.png differ diff --git a/docs/source/_static/images/2x_bgwhite_original.png b/docs/source/_static/images/2x_bgwhite_original.png new file mode 100644 index 000000000000..abb5fca6e461 Binary files /dev/null and b/docs/source/_static/images/2x_bgwhite_original.png differ diff --git a/docs/source/_static/images/DataFusion-Logo-Background-White.png b/docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.png similarity index 100% rename from docs/source/_static/images/DataFusion-Logo-Background-White.png rename to docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.png diff --git a/docs/source/_static/images/DataFusion-Logo-Background-White.svg b/docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.svg similarity index 100% rename from docs/source/_static/images/DataFusion-Logo-Background-White.svg rename to docs/source/_static/images/old_logo/DataFusion-Logo-Background-White.svg diff --git a/docs/source/_static/images/DataFusion-Logo-Dark.png b/docs/source/_static/images/old_logo/DataFusion-Logo-Dark.png similarity index 100% rename from docs/source/_static/images/DataFusion-Logo-Dark.png rename to docs/source/_static/images/old_logo/DataFusion-Logo-Dark.png diff --git a/docs/source/_static/images/DataFusion-Logo-Dark.svg b/docs/source/_static/images/old_logo/DataFusion-Logo-Dark.svg similarity index 100% rename from docs/source/_static/images/DataFusion-Logo-Dark.svg rename to docs/source/_static/images/old_logo/DataFusion-Logo-Dark.svg diff --git a/docs/source/_static/images/DataFusion-Logo-Light.png b/docs/source/_static/images/old_logo/DataFusion-Logo-Light.png similarity index 100% rename from docs/source/_static/images/DataFusion-Logo-Light.png rename to docs/source/_static/images/old_logo/DataFusion-Logo-Light.png diff --git a/docs/source/_static/images/DataFusion-Logo-Light.svg b/docs/source/_static/images/old_logo/DataFusion-Logo-Light.svg similarity index 100% rename from docs/source/_static/images/DataFusion-Logo-Light.svg rename to docs/source/_static/images/old_logo/DataFusion-Logo-Light.svg diff --git a/docs/source/_static/images/original.png b/docs/source/_static/images/original.png new file mode 100644 index 000000000000..687f946760b0 Binary files /dev/null and b/docs/source/_static/images/original.png differ diff --git a/docs/source/_static/images/original.svg b/docs/source/_static/images/original.svg new file mode 100644 index 000000000000..6ba0ece995a3 --- /dev/null +++ b/docs/source/_static/images/original.svg @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/images/original2x.png b/docs/source/_static/images/original2x.png new file mode 100644 index 000000000000..a7402109b211 Binary files /dev/null and b/docs/source/_static/images/original2x.png differ diff --git a/docs/source/_templates/docs-sidebar.html b/docs/source/_templates/docs-sidebar.html index 2b400b4dcade..7c3ecc3d802e 100644 --- a/docs/source/_templates/docs-sidebar.html +++ b/docs/source/_templates/docs-sidebar.html @@ -15,7 +15,7 @@ - + diff --git a/docs/source/conf.py b/docs/source/conf.py index becece330d1a..a203bfbb10d5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -100,7 +100,7 @@ # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] -html_logo = "_static/images/DataFusion-Logo-Background-White.png" +html_logo = "_static/images/2x_bgwhite_original.png" html_css_files = [ "theme_overrides.css" diff --git a/docs/source/index.rst b/docs/source/index.rst index 385371661716..f7c0873f3a5f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -15,7 +15,7 @@ .. specific language governing permissions and limitations .. under the License. -.. image:: _static/images/DataFusion-Logo-Background-White.png +.. image:: _static/images/2x_bgwhite_original.png :alt: DataFusion Logo =======================