
feat: Support SQL filter clause for aggregate expressions, add SQL dialect support #5868

Merged · 15 commits · Apr 11, 2023
4 changes: 4 additions & 0 deletions datafusion/common/src/config.rs
@@ -187,6 +187,10 @@ config_namespace! {
/// When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted)
pub enable_ident_normalization: bool, default = true

/// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic,
/// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi.
pub dialect: String, default = "generic".to_string()

}
}
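
A sketch of how the new option could be exercised end to end. The fully-qualified key name `datafusion.sql_parser.dialect` is an assumption, inferred from the `sql_parser` config namespace shown above:

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Assumed key name, following the `sql_parser` namespace:
    ctx.sql("SET datafusion.sql_parser.dialect = 'PostgreSQL'")
        .await?;
    // Statements parsed after this point use PostgreSQL rules.
    Ok(())
}
```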

52 changes: 51 additions & 1 deletion datafusion/core/src/execution/context.rs
@@ -91,6 +91,11 @@ use datafusion_sql::{
planner::{ContextProvider, SqlToRel},
};
use parquet::file::properties::WriterProperties;
use sqlparser::dialect::{
AnsiDialect, BigQueryDialect, ClickHouseDialect, Dialect, GenericDialect,
HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, RedshiftSqlDialect,
SQLiteDialect, SnowflakeDialect,
};
use url::Url;

use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA};
@@ -1510,6 +1515,27 @@ impl SessionState {
Ok(statement)
}

/// Convert a SQL string into an AST Statement
pub fn sql_to_statement_with_dialect(
&self,
sql: &str,
dialect: &str,
) -> Result<datafusion_sql::parser::Statement> {
let dialect = create_dialect_from_str(dialect);
let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?;
if statements.len() > 1 {
return Err(DataFusionError::NotImplemented(
"The context currently only supports a single SQL statement".to_string(),
));
}
let statement = statements.pop_front().ok_or_else(|| {
DataFusionError::NotImplemented(
"The context requires a statement!".to_string(),
)
})?;
Ok(statement)
}
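
A minimal usage sketch for the new method, assuming an existing `SessionContext` named `ctx`:

```rust
let state = ctx.state();
// Parse with MySQL rules (e.g. backtick-quoted identifiers),
// overriding whatever dialect the session is configured with:
let statement =
    state.sql_to_statement_with_dialect("SELECT `a` FROM t", "mysql")?;
```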

/// Resolve all table references in the SQL statement.
pub fn resolve_table_references(
&self,
@@ -1624,7 +1650,8 @@ impl SessionState {
///
/// See [`SessionContext::sql`] for a higher-level interface that also handles DDL
pub async fn create_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
let statement = self.sql_to_statement(sql)?;
let dialect = self.config.options().sql_parser.dialect.as_str();
let statement = self.sql_to_statement_with_dialect(sql, dialect)?;
let plan = self.statement_to_plan(statement).await?;
Ok(plan)
}
@@ -1833,6 +1860,29 @@ impl From<&SessionState> for TaskContext {
}
}

fn create_dialect_from_str(dialect_name: &str) -> Box<dyn Dialect> {
match dialect_name.to_lowercase().as_str() {
"generic" => Box::new(GenericDialect),
"mysql" => Box::new(MySqlDialect {}),
"postgresql" | "postgres" => Box::new(PostgreSqlDialect {}),
"hive" => Box::new(HiveDialect {}),
"sqlite" => Box::new(SQLiteDialect {}),
"snowflake" => Box::new(SnowflakeDialect),
"redshift" => Box::new(RedshiftSqlDialect {}),
"mssql" => Box::new(MsSqlDialect {}),
"clickhouse" => Box::new(ClickHouseDialect {}),
"bigquery" => Box::new(BigQueryDialect),
"ansi" => Box::new(AnsiDialect {}),
_ => {
println!(
"Unsupported SQL dialect: {}. Using GenericDialect. Available dialects: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi.",
dialect_name
);
Box::new(GenericDialect)
}
}
}
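
Matching is case-insensitive, `postgres` is accepted as an alias, and unrecognized names fall back to `GenericDialect` after printing a warning:

```rust
let _pg  = create_dialect_from_str("PostgreSQL"); // PostgreSqlDialect
let _pg2 = create_dialect_from_str("postgres");   // alias, same dialect
let _gen = create_dialect_from_str("oracle");     // warns, falls back to GenericDialect
```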

#[cfg(test)]
mod tests {
use super::*;
@@ -123,6 +123,10 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
{
if partial_agg_exec.mode() == &AggregateMode::Partial
&& partial_agg_exec.group_expr().is_empty()
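// a FILTER clause means the aggregate no longer sees every input row,
// so pre-computed table statistics cannot stand in for its result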
&& partial_agg_exec
.aggr_expr()
.iter()
.all(|e| e.filter().is_none())
{
let stats = partial_agg_exec.input().statistics();
if stats.is_exact {
@@ -369,6 +373,7 @@ mod tests {
fn count_expr(&self) -> Arc<dyn AggregateExpr> {
Arc::new(Count::new(
self.column(),
None,
self.column_name(),
DataType::Int64,
))
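
The extra `None` threading through these constructors is the new optional filter slot. A hedged sketch of building `COUNT(a) FILTER (WHERE a > 5)` directly against the physical API, following the `(expr, filter, name, return_type)` order the diff shows (the `binary` helper and the schema are assumptions):

```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, col, lit, Count};
use datafusion_physical_expr::AggregateExpr;

fn filtered_count() -> Result<Arc<dyn AggregateExpr>> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
    // Assumed helper: builds the physical expression `a > 5`.
    let filter = binary(col("a", &schema)?, Operator::Gt, lit(5i64), &schema)?;
    Ok(Arc::new(Count::new(
        col("a", &schema)?,
        Some(filter),            // the new filter slot
        "COUNT(a)".to_string(),
        DataType::Int64,
    )))
}
```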
43 changes: 43 additions & 0 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -572,6 +572,28 @@ fn aggregate_expressions(
}
}

/// Returns the filter expressions to evaluate against a batch.
/// The result depends on `mode`:
/// * Partial: each aggregate's `AggregateExpr::filter`
/// * Final | FinalPartitioned: `None` for every aggregate, since filters
///   only apply while accumulating raw input rows
fn filter_expressions(
aggr_expr: &[Arc<dyn AggregateExpr>],
mode: &AggregateMode,
) -> Result<Vec<Option<Arc<dyn PhysicalExpr>>>> {
match mode {
AggregateMode::Partial => {
let filters = aggr_expr
.iter()
.map(|agg| agg.filter())
.collect::<Vec<Option<Arc<dyn PhysicalExpr>>>>();
Ok(filters)
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
Ok(vec![None; aggr_expr.len()])
}
}
}
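
These physical filters originate from the SQL `FILTER` clause this PR wires through; a quick end-to-end sketch (the table and column names are made up):

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("t", "t.csv", CsvReadOptions::new()).await?;
    // SUM(c1) only accumulates rows where c2 > 10;
    // COUNT(*) still counts every row:
    let df = ctx
        .sql("SELECT COUNT(*), SUM(c1) FILTER (WHERE c2 > 10) FROM t")
        .await?;
    df.show().await?;
    Ok(())
}
```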

/// uses `state_fields` to build a vec of physical column expressions required to merge the
/// AggregateExpr' accumulator's state.
///
@@ -667,6 +689,20 @@ fn evaluate_many(
.collect::<Result<Vec<_>>>()
}

fn evaluate_optional(
expr: &[Option<Arc<dyn PhysicalExpr>>],
batch: &RecordBatch,
) -> Result<Vec<Option<ArrayRef>>> {
expr.iter()
.map(|expr| {
expr.as_ref()
.map(|expr| expr.evaluate(batch))
.transpose()
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
})
.collect::<Result<Vec<_>>>()
}
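
The `Option<Result<_>>` to `Result<Option<_>>` flip in `evaluate_optional` is the standard `transpose` pattern; illustrated standalone:

```rust
// `transpose` lets `collect::<Result<Vec<_>>>()` short-circuit on the
// first evaluation error while `None` (no filter) entries pass through:
let some_ok: Option<Result<i32, String>> = Some(Ok(5));
assert_eq!(some_ok.transpose(), Ok(Some(5)));

let none: Option<Result<i32, String>> = None;
assert_eq!(none.transpose(), Ok(None));
```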

fn evaluate_group_by(
group_by: &PhysicalGroupBy,
batch: &RecordBatch,
@@ -792,6 +828,7 @@ mod tests {

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Count::new(
lit(1i8),
None,
"COUNT(1)".to_string(),
DataType::Int64,
))];
@@ -897,6 +934,7 @@ mod tests {

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("b", &input_schema)?,
None,
"AVG(b)".to_string(),
DataType::Float64,
))];
@@ -1127,6 +1165,7 @@ mod tests {
// something that allocates within the aggregator
let aggregates_v0: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Median::new(
col("a", &input_schema)?,
None,
"MEDIAN(a)".to_string(),
DataType::UInt32,
))];
@@ -1135,13 +1174,15 @@
let aggregates_v1: Vec<Arc<dyn AggregateExpr>> =
vec![Arc::new(ApproxDistinct::new(
col("a", &input_schema)?,
None,
"APPROX_DISTINCT(a)".to_string(),
DataType::UInt32,
))];

// use fast-path in `row_hash.rs`.
let aggregates_v2: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("b", &input_schema)?,
None,
"AVG(b)".to_string(),
DataType::Float64,
))];
@@ -1200,6 +1241,7 @@ mod tests {

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("a", &schema)?,
None,
"AVG(a)".to_string(),
DataType::Float64,
))];
@@ -1238,6 +1280,7 @@ mod tests {

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("b", &schema)?,
None,
"AVG(b)".to_string(),
DataType::Float64,
))];
27 changes: 20 additions & 7 deletions datafusion/core/src/physical_plan/aggregates/no_grouping.rs
@@ -19,8 +19,8 @@

use crate::execution::context::TaskContext;
use crate::physical_plan::aggregates::{
aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
AggregateMode,
aggregate_expressions, create_accumulators, filter_expressions, finalize_aggregation,
AccumulatorItem, AggregateMode,
};
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
@@ -33,6 +33,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};

use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use crate::physical_plan::filter::batch_filter;
use futures::stream::{Stream, StreamExt};

/// stream struct for aggregation without grouping columns
@@ -52,6 +53,7 @@ struct AggregateStreamInner {
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
accumulators: Vec<AccumulatorItem>,
reservation: MemoryReservation,
finished: bool,
@@ -69,6 +71,7 @@ impl AggregateStream {
partition: usize,
) -> Result<Self> {
let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
let filter_expressions = filter_expressions(&aggr_expr, &mode)?;
let accumulators = create_accumulators(&aggr_expr)?;

let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
@@ -80,6 +83,7 @@
input,
baseline_metrics,
aggregate_expressions,
filter_expressions,
accumulators,
reservation,
finished: false,
@@ -100,6 +104,7 @@ impl AggregateStream {
&batch,
&mut this.accumulators,
&this.aggregate_expressions,
&this.filter_expressions,
);

timer.done();
@@ -172,26 +177,34 @@ fn aggregate_batch(
batch: &RecordBatch,
accumulators: &mut [AccumulatorItem],
expressions: &[Vec<Arc<dyn PhysicalExpr>>],
filters: &[Option<Arc<dyn PhysicalExpr>>],
) -> Result<usize> {
let mut allocated = 0usize;

// 1.1 iterate accumulators and respective expressions together
// 1.2 evaluate expressions
// 1.3 update / merge accumulators with the expressions' values
// 1.2 filter the batch if necessary
// 1.3 evaluate expressions
// 1.4 update / merge accumulators with the expressions' values

// 1.1
accumulators
.iter_mut()
.zip(expressions)
.try_for_each(|(accum, expr)| {
.zip(filters)
.try_for_each(|((accum, expr), filter)| {
// 1.2
let batch = match filter {
Some(filter) => batch_filter(batch, filter)?,
None => batch.clone(),
};
// 1.3
let values = &expr
.iter()
.map(|e| e.evaluate(batch))
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;

// 1.3
// 1.4
let size_pre = accum.size();
let res = match mode {
AggregateMode::Partial => accum.update_batch(values),
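
`batch_filter` is shared with `FilterExec` (note the new `use crate::physical_plan::filter::batch_filter` import above). A minimal sketch of the equivalent logic, assuming the predicate evaluates to a `BooleanArray`:

```rust
use std::sync::Arc;
use arrow::array::BooleanArray;
use arrow::compute::filter_record_batch;
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_physical_expr::PhysicalExpr;

// Hypothetical stand-in for `batch_filter`: evaluate `predicate` against
// `batch` and keep only the rows where the result is `true`.
fn batch_filter_sketch(
    batch: &RecordBatch,
    predicate: &Arc<dyn PhysicalExpr>,
) -> Result<RecordBatch> {
    let mask = predicate.evaluate(batch)?.into_array(batch.num_rows());
    let mask = mask
        .as_any()
        .downcast_ref::<BooleanArray>()
        .ok_or_else(|| {
            DataFusionError::Internal(
                "aggregate FILTER must evaluate to booleans".to_string(),
            )
        })?;
    Ok(filter_record_batch(batch, mask)?)
}
```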