Introduce TypePlanner for customizing type planning (#13294)
* introduce `plan_data_type` for ExprPlanner

* implement TypePlanner trait instead of extending ExprPlanner

* enhance the documentation
goldmedal authored Nov 13, 2024
1 parent cc96026 commit 4e1f839
Showing 6 changed files with 211 additions and 10 deletions.
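For orientation, here is a minimal usage sketch (not part of this commit) showing how an application could register a custom type planner; it mirrors the `custom_type_planner` test added below, and `MyTypePlanner` refers to the test helper defined in that diff:

use std::sync::Arc;
use datafusion::error::Result;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::SessionContext;

// Assumes `MyTypePlanner` (see the test added in this diff) is in scope.
async fn run_with_custom_type_planner() -> Result<()> {
    // Install the custom TypePlanner when building the session state.
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_type_planner(Arc::new(MyTypePlanner {}))
        .build();
    let ctx = SessionContext::new_with_state(state);
    // With the planner installed, `DATETIME` literals are planned as Arrow timestamps.
    let batches = ctx
        .sql("SELECT DATETIME '2021-01-01 00:00:00'")
        .await?
        .collect()
        .await?;
    println!("{batches:?}");
    Ok(())
}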
52 changes: 49 additions & 3 deletions datafusion/core/src/execution/context/mod.rs
@@ -1788,22 +1788,24 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {

#[cfg(test)]
mod tests {
use std::env;
use std::path::PathBuf;

use super::{super::options::CsvReadOptions, *};
use crate::assert_batches_eq;
use crate::execution::memory_pool::MemoryConsumer;
use crate::execution::runtime_env::RuntimeEnvBuilder;
use crate::test;
use crate::test_util::{plan_and_collect, populate_csv_partitions};
use arrow_schema::{DataType, TimeUnit};
use std::env;
use std::path::PathBuf;

use datafusion_common_runtime::SpawnedTask;

use crate::catalog::SchemaProvider;
use crate::execution::session_state::SessionStateBuilder;
use crate::physical_planner::PhysicalPlanner;
use async_trait::async_trait;
use datafusion_expr::planner::TypePlanner;
use sqlparser::ast;
use tempfile::TempDir;

#[tokio::test]
@@ -2200,6 +2202,29 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn custom_type_planner() -> Result<()> {
let state = SessionStateBuilder::new()
.with_default_features()
.with_type_planner(Arc::new(MyTypePlanner {}))
.build();
let ctx = SessionContext::new_with_state(state);
let result = ctx
.sql("SELECT DATETIME '2021-01-01 00:00:00'")
.await?
.collect()
.await?;
let expected = [
"+-----------------------------+",
"| Utf8(\"2021-01-01 00:00:00\") |",
"+-----------------------------+",
"| 2021-01-01T00:00:00 |",
"+-----------------------------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}

struct MyPhysicalPlanner {}

#[async_trait]
@@ -2260,4 +2285,25 @@ mod tests {

Ok(ctx)
}

#[derive(Debug)]
struct MyTypePlanner {}

impl TypePlanner for MyTypePlanner {
fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
match sql_type {
ast::DataType::Datetime(precision) => {
let precision = match precision {
Some(0) => TimeUnit::Second,
Some(3) => TimeUnit::Millisecond,
Some(6) => TimeUnit::Microsecond,
None | Some(9) => TimeUnit::Nanosecond,
_ => unreachable!(),
};
Ok(Some(DataType::Timestamp(precision, None)))
}
_ => Ok(None),
}
}
}
}
30 changes: 29 additions & 1 deletion datafusion/core/src/execution/session_state.rs
@@ -48,7 +48,7 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::planner::{ExprPlanner, TypePlanner};
use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry};
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::var_provider::{is_system_variables, VarType};
@@ -128,6 +128,8 @@ pub struct SessionState {
analyzer: Analyzer,
/// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?`
expr_planners: Vec<Arc<dyn ExprPlanner>>,
/// Provides support for customising the SQL type planning
type_planner: Option<Arc<dyn TypePlanner>>,
/// Responsible for optimizing a logical plan
optimizer: Optimizer,
/// Responsible for optimizing a physical execution plan
@@ -192,6 +194,7 @@ impl Debug for SessionState {
.field("table_factories", &self.table_factories)
.field("function_factory", &self.function_factory)
.field("expr_planners", &self.expr_planners)
.field("type_planner", &self.type_planner)
.field("query_planners", &self.query_planner)
.field("analyzer", &self.analyzer)
.field("optimizer", &self.optimizer)
@@ -955,6 +958,7 @@ pub struct SessionStateBuilder {
session_id: Option<String>,
analyzer: Option<Analyzer>,
expr_planners: Option<Vec<Arc<dyn ExprPlanner>>>,
type_planner: Option<Arc<dyn TypePlanner>>,
optimizer: Option<Optimizer>,
physical_optimizers: Option<PhysicalOptimizer>,
query_planner: Option<Arc<dyn QueryPlanner + Send + Sync>>,
@@ -984,6 +988,7 @@ impl SessionStateBuilder {
session_id: None,
analyzer: None,
expr_planners: None,
type_planner: None,
optimizer: None,
physical_optimizers: None,
query_planner: None,
@@ -1031,6 +1036,7 @@ impl SessionStateBuilder {
session_id: None,
analyzer: Some(existing.analyzer),
expr_planners: Some(existing.expr_planners),
type_planner: existing.type_planner,
optimizer: Some(existing.optimizer),
physical_optimizers: Some(existing.physical_optimizers),
query_planner: Some(existing.query_planner),
@@ -1125,6 +1131,12 @@ impl SessionStateBuilder {
self
}

/// Set the [`TypePlanner`] used to customize the behavior of the SQL planner.
pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
self.type_planner = Some(type_planner);
self
}

/// Set the [`PhysicalOptimizerRule`]s used to optimize plans.
pub fn with_physical_optimizer_rules(
mut self,
@@ -1318,6 +1330,7 @@ impl SessionStateBuilder {
session_id,
analyzer,
expr_planners,
type_planner,
optimizer,
physical_optimizers,
query_planner,
@@ -1346,6 +1359,7 @@ impl SessionStateBuilder {
session_id: session_id.unwrap_or(Uuid::new_v4().to_string()),
analyzer: analyzer.unwrap_or_default(),
expr_planners: expr_planners.unwrap_or_default(),
type_planner,
optimizer: optimizer.unwrap_or_default(),
physical_optimizers: physical_optimizers.unwrap_or_default(),
query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})),
@@ -1456,6 +1470,11 @@ impl SessionStateBuilder {
&mut self.expr_planners
}

/// Returns the current type_planner value
pub fn type_planner(&mut self) -> &mut Option<Arc<dyn TypePlanner>> {
&mut self.type_planner
}

/// Returns the current optimizer value
pub fn optimizer(&mut self) -> &mut Option<Optimizer> {
&mut self.optimizer
@@ -1578,6 +1597,7 @@ impl Debug for SessionStateBuilder {
.field("table_factories", &self.table_factories)
.field("function_factory", &self.function_factory)
.field("expr_planners", &self.expr_planners)
.field("type_planner", &self.type_planner)
.field("query_planners", &self.query_planner)
.field("analyzer_rules", &self.analyzer_rules)
.field("analyzer", &self.analyzer)
@@ -1619,6 +1639,14 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
&self.state.expr_planners
}

fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
if let Some(type_planner) = &self.state.type_planner {
Some(Arc::clone(type_planner))
} else {
None
}
}

fn get_table_source(
&self,
name: TableReference,
18 changes: 17 additions & 1 deletion datafusion/expr/src/planner.rs
@@ -25,6 +25,7 @@ use datafusion_common::{
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
Result, TableReference,
};
use sqlparser::ast;

use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};

@@ -66,6 +67,11 @@ pub trait ContextProvider {
&[]
}

/// Getter for the data type planner
fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
None
}

/// Getter for a UDF description
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
@@ -216,7 +222,7 @@ pub trait ExprPlanner: Debug + Send + Sync {
/// custom expressions.
#[derive(Debug, Clone)]
pub struct RawBinaryExpr {
pub op: sqlparser::ast::BinaryOperator,
pub op: ast::BinaryOperator,
pub left: Expr,
pub right: Expr,
}
@@ -249,3 +255,13 @@ pub enum PlannerResult<T> {
/// The raw expression could not be planned, and is returned unmodified
Original(T),
}

/// This trait allows users to customize the behavior of the data type planning
pub trait TypePlanner: Debug + Send + Sync {
/// Plan SQL type to DataFusion data type
///
/// Returns None if not possible
fn plan_type(&self, _sql_type: &ast::DataType) -> Result<Option<DataType>> {
Ok(None)
}
}
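
As an additional illustration (not part of this commit), a minimal sketch of another implementor, assuming a hypothetical `JsonTypePlanner` that plans the SQL `JSON` type as `Utf8`; returning `Ok(None)` defers to DataFusion's default conversion:

use arrow_schema::DataType;
use datafusion_common::Result;
use datafusion_expr::planner::TypePlanner;
use sqlparser::ast;

// Hypothetical example planner, not part of this commit.
#[derive(Debug)]
struct JsonTypePlanner;

impl TypePlanner for JsonTypePlanner {
    fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
        match sql_type {
            // Plan SQL JSON columns as plain UTF-8 strings.
            ast::DataType::JSON => Ok(Some(DataType::Utf8)),
            // Anything else falls back to the default conversion.
            _ => Ok(None),
        }
    }
}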
8 changes: 8 additions & 0 deletions datafusion/sql/src/planner.rs
@@ -401,6 +401,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
// First check if any of the registered type_planner can handle this type
if let Some(type_planner) = self.context_provider.get_type_planner() {
if let Some(data_type) = type_planner.plan_type(sql_type)? {
return Ok(data_type);
}
}

// If no type_planner can handle this type, use the default conversion
match sql_type {
SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => {
// Arrays may be multi-dimensional.
57 changes: 53 additions & 4 deletions datafusion/sql/tests/common/mod.rs
@@ -18,15 +18,16 @@
use std::any::Any;
#[cfg(test)]
use std::collections::HashMap;
use std::fmt::Display;
use std::fmt::{Debug, Display};
use std::{sync::Arc, vec};

use arrow_schema::*;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{plan_err, GetExt, Result, TableReference};
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference};
use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner};
use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
use datafusion_functions_nested::expr_fn::make_array;
use datafusion_sql::planner::ContextProvider;

struct MockCsvType {}
@@ -54,6 +55,7 @@ pub(crate) struct MockSessionState {
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
expr_planners: Vec<Arc<dyn ExprPlanner>>,
type_planner: Option<Arc<dyn TypePlanner>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
pub config_options: ConfigOptions,
}
@@ -64,6 +66,11 @@ impl MockSessionState {
self
}

pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
self.type_planner = Some(type_planner);
self
}

pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
self.scalar_functions
.insert(scalar_function.name().to_string(), scalar_function);
@@ -259,6 +266,14 @@ impl ContextProvider for MockContextProvider {
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&self.state.expr_planners
}

fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
if let Some(type_planner) = &self.state.type_planner {
Some(Arc::clone(type_planner))
} else {
None
}
}
}

struct EmptyTable {
@@ -280,3 +295,37 @@ impl TableSource for EmptyTable {
Arc::clone(&self.table_schema)
}
}

#[derive(Debug)]
pub struct CustomTypePlanner {}

impl TypePlanner for CustomTypePlanner {
fn plan_type(&self, sql_type: &sqlparser::ast::DataType) -> Result<Option<DataType>> {
match sql_type {
sqlparser::ast::DataType::Datetime(precision) => {
let precision = match precision {
Some(0) => TimeUnit::Second,
Some(3) => TimeUnit::Millisecond,
Some(6) => TimeUnit::Microsecond,
None | Some(9) => TimeUnit::Nanosecond,
_ => unreachable!(),
};
Ok(Some(DataType::Timestamp(precision, None)))
}
_ => Ok(None),
}
}
}

#[derive(Debug)]
pub struct CustomExprPlanner {}

impl ExprPlanner for CustomExprPlanner {
fn plan_array_literal(
&self,
exprs: Vec<Expr>,
_schema: &DFSchema,
) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Planned(make_array(exprs)))
}
}
