Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into alamb/user_access_plan
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jun 9, 2024
2 parents 7799d65 + ad0dc2f commit 69509a4
Show file tree
Hide file tree
Showing 86 changed files with 1,832 additions and 1,939 deletions.
44 changes: 38 additions & 6 deletions datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::functions_aggregate::first_last::first_value_udaf;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries};
use datafusion::prelude::*;
Expand All @@ -32,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};

/// This example demonstrates the DataFusion [`Expr`] API.
///
Expand All @@ -44,11 +45,12 @@ use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
/// also comes with APIs for evaluation, simplification, and analysis.
///
/// The code in this example shows how to:
/// 1. Create [`Exprs`] using different APIs: [`main`]`
/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`]
/// 3. Simplify expressions: [`simplify_demo`]
/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`]
/// 5. Get the types of the expressions: [`expression_type_demo`]
/// 1. Create [`Expr`]s using different APIs: [`main`]
/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`]
/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`]
/// 4. Simplify expressions: [`simplify_demo`]
/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`]
/// 6. Get the types of the expressions: [`expression_type_demo`]
#[tokio::main]
async fn main() -> Result<()> {
// The easiest way to create expressions is to use the
Expand All @@ -63,6 +65,9 @@ async fn main() -> Result<()> {
));
assert_eq!(expr, expr2);

// See how to build aggregate functions with the expr_fn API
expr_fn_demo()?;

// See how to evaluate expressions
evaluate_demo()?;

Expand All @@ -78,6 +83,33 @@ async fn main() -> Result<()> {
Ok(())
}

/// Demonstrates building aggregate [`Expr`]s with DataFusion's `expr_fn` API.
///
/// Two expressions are constructed from the `first_value` aggregate
/// function and their `Display` output is checked:
///
/// 1. a plain call, `first_value(price)`
/// 2. the same call refined with an `ORDER BY` and a `FILTER` clause via
///    the `AggregateExt` builder methods
///
/// # Errors
///
/// Returns an error if building the refined aggregate fails.
fn expr_fn_demo() -> Result<()> {
    // Handle to the "first_value" aggregate function
    let first_value_fn = first_value_udaf();

    // The simplest form: `FIRST_VALUE(price)`.
    // Expressions like this can be passed to `DataFrame::aggregate` and
    // other APIs that take aggregate expressions.
    let simple_agg = first_value_fn.call(vec![col("price")]);
    assert_eq!(simple_agg.to_string(), "first_value(price)");

    // The AggregateExt trait adds builder methods for more complex
    // aggregates, such as
    // `FIRST_VALUE(price ORDER BY ts DESC NULLS LAST) FILTER (WHERE quantity > 100)`
    let sort_exprs = vec![col("ts").sort(false, false)]; // asc = false, nulls_first = false
    let predicate = col("quantity").gt(lit(100));
    let complex_agg = first_value_fn
        .call(vec![col("price")])
        .order_by(sort_exprs)
        .filter(predicate)
        .build()?; // finish the builder to obtain the aggregate `Expr`

    let expected_display =
        "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST]";
    assert_eq!(complex_agg.to_string(), expected_display);

    Ok(())
}

/// DataFusion can also evaluate arbitrary expressions on Arrow arrays.
fn evaluate_demo() -> Result<()> {
// For example, let's say you have some integers in an array
Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::median;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_aggregate::expr_fn::{median, stddev};

use async_trait::async_trait;

Expand Down Expand Up @@ -1820,7 +1820,7 @@ mod tests {

assert_batches_sorted_eq!(
["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
"| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
Expand Down Expand Up @@ -2395,7 +2395,7 @@ mod tests {
assert_batches_sorted_eq!(
[
"+----+-----------------------------+",
"| c1 | SUM(aggregate_test_100.c12) |",
"| c1 | sum(aggregate_test_100.c12) |",
"+----+-----------------------------+",
"| a | 10.238448667882977 |",
"| b | 7.797734760124923 |",
Expand All @@ -2411,7 +2411,7 @@ mod tests {
assert_batches_sorted_eq!(
[
"+----+---------------------+",
"| c1 | SUM(test_table.c12) |",
"| c1 | sum(test_table.c12) |",
"+----+---------------------+",
"| a | 10.238448667882977 |",
"| b | 7.797734760124923 |",
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/datasource/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ mod tests {

#[rustfmt::skip]
let expected = ["+--------------+",
"| SUM(aggr.c2) |",
"| sum(aggr.c2) |",
"+--------------+",
"| 285 |",
"+--------------+"];
Expand Down Expand Up @@ -956,7 +956,7 @@ mod tests {

#[rustfmt::skip]
let expected = ["+--------------+",
"| SUM(aggr.c3) |",
"| sum(aggr.c3) |",
"+--------------+",
"| 781 |",
"+--------------+"];
Expand Down Expand Up @@ -1122,7 +1122,7 @@ mod tests {

#[rustfmt::skip]
let expected = ["+---------------------+",
"| SUM(empty.column_1) |",
"| sum(empty.column_1) |",
"+---------------------+",
"| 10 |",
"+---------------------+"];
Expand Down Expand Up @@ -1161,7 +1161,7 @@ mod tests {

#[rustfmt::skip]
let expected = ["+-----------------------+",
"| SUM(one_col.column_1) |",
"| sum(one_col.column_1) |",
"+-----------------------+",
"| 50 |",
"+-----------------------+"];
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/file_format/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,15 +470,15 @@ mod tests {
ctx.register_json("json_parallel", table_path, options)
.await?;

let query = "SELECT SUM(a) FROM json_parallel;";
let query = "SELECT sum(a) FROM json_parallel;";

let result = ctx.sql(query).await?.collect().await?;
let actual_partitions = count_num_partitions(&ctx, query).await?;

#[rustfmt::skip]
let expected = [
"+----------------------+",
"| SUM(json_parallel.a) |",
"| sum(json_parallel.a) |",
"+----------------------+",
"| -7 |",
"+----------------------+"
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ mod tests {
)
.await?;
let results =
plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?;
plan_and_collect(&ctx, "SELECT sum(c1), sum(c2), COUNT(*) FROM test").await?;

assert_eq!(results.len(), 1);
let expected = [
"+--------------+--------------+----------+",
"| SUM(test.c1) | SUM(test.c2) | COUNT(*) |",
"| sum(test.c1) | sum(test.c2) | COUNT(*) |",
"+--------------+--------------+----------+",
"| 10 | 110 | 20 |",
"+--------------+--------------+----------+",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ mod tests {
use crate::physical_plan::{displayable, Partitioning};

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_physical_expr::expressions::{col, Count, Sum};
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{col, Count};
use datafusion_physical_plan::udaf::create_aggregate_expr;

/// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected
macro_rules! assert_optimized {
Expand Down Expand Up @@ -391,12 +393,17 @@ mod tests {
#[test]
fn aggregations_with_group_combined() -> Result<()> {
let schema = schema();
let aggr_expr = vec![Arc::new(Sum::new(
col("b", &schema)?,
"Sum(b)".to_string(),
DataType::Int64,
)) as _];

let aggr_expr = vec![create_aggregate_expr(
&sum_udaf(),
&[col("b", &schema)?],
&[],
&[],
&schema,
"Sum(b)",
false,
false,
)?];
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("c", &schema)?, "c".to_string())];

Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2586,7 +2586,7 @@ mod tests {
.downcast_ref::<AggregateExec>()
.expect("hash aggregate");
assert_eq!(
"SUM(aggregate_test_100.c2)",
"sum(aggregate_test_100.c2)",
final_hash_agg.schema().field(1).name()
);
// we need access to the input to the partial aggregate so that other projects can
Expand Down Expand Up @@ -2614,7 +2614,7 @@ mod tests {
.downcast_ref::<AggregateExec>()
.expect("hash aggregate");
assert_eq!(
"SUM(aggregate_test_100.c3)",
"sum(aggregate_test_100.c3)",
final_hash_agg.schema().field(2).name()
);
// we need access to the input to the partial aggregate so that other projects can
Expand Down
Loading

0 comments on commit 69509a4

Please sign in to comment.