Skip to content

Commit

Permalink
ARROW-6090: [Rust] [DataFusion] Physical plan for HashAggregate
Browse files Browse the repository at this point in the history
This PR implements the HashAggregate execution plan.

Closes #5191 from andygrove/ARROW-6090 and squashes the following commits:

6004d08 <Andy Grove> Fix conflict from rebase
29f1197 <Andy Grove> manual rebase on ARROW-6563
8bcc84b <Andy Grove> Add support for aggregate queries without a GROUP BY clause
eb28a08 <Andy Grove> improved unit test
3d7407b <Andy Grove> Code cleanup and rebase
07b2489 <Andy Grove> Implement HashAggregate execution plan and SUM aggregate function

Authored-by: Andy Grove <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
  • Loading branch information
andygrove committed Sep 20, 2019
1 parent 00a3c47 commit f19ee70
Show file tree
Hide file tree
Showing 7 changed files with 939 additions and 30 deletions.
10 changes: 5 additions & 5 deletions rust/datafusion/src/datasource/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ pub struct CsvBatchIterator {

impl CsvBatchIterator {
#[allow(missing_docs)]
pub fn new(
pub fn try_new(
filename: &str,
schema: Arc<Schema>,
has_header: bool,
projection: &Option<Vec<usize>>,
batch_size: usize,
) -> Self {
let file = File::open(filename).unwrap();
) -> Result<Self> {
let file = File::open(filename)?;
let reader = csv::Reader::new(
file,
schema.clone(),
Expand All @@ -110,10 +110,10 @@ impl CsvBatchIterator {
None => schema,
};

Self {
Ok(Self {
schema: projected_schema,
reader,
}
})
}
}

Expand Down
3 changes: 2 additions & 1 deletion rust/datafusion/src/execution/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,8 @@ mod tests {
}

fn load_csv(filename: &str, schema: &Arc<Schema>) -> Rc<RefCell<dyn Relation>> {
let ds = CsvBatchIterator::new(filename, schema.clone(), true, &None, 10);
let ds =
CsvBatchIterator::try_new(filename, schema.clone(), true, &None, 10).unwrap();
Rc::new(RefCell::new(DataSourceRelation::new(Arc::new(Mutex::new(
ds,
)))))
Expand Down
133 changes: 111 additions & 22 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ use crate::execution::filter::FilterRelation;
use crate::execution::limit::LimitRelation;
use crate::execution::physical_plan::common;
use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::Column;
use crate::execution::physical_plan::expressions::{Column, Sum};
use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
use crate::execution::physical_plan::merge::MergeExec;
use crate::execution::physical_plan::projection::ProjectionExec;
use crate::execution::physical_plan::{ExecutionPlan, PhysicalExpr};
use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr};
use crate::execution::projection::ProjectRelation;
use crate::execution::relation::{DataSourceRelation, Relation};
use crate::execution::scalar_relation::ScalarRelation;
Expand Down Expand Up @@ -256,6 +257,29 @@ impl ExecutionContext {
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input)?))
}
LogicalPlan::Aggregate {
input,
group_expr,
aggr_expr,
schema,
} => {
let input = self.create_physical_plan(input, batch_size)?;
let input_schema = input.as_ref().schema().clone();
let group_expr = group_expr
.iter()
.map(|e| self.create_physical_expr(e, &input_schema))
.collect::<Result<Vec<_>>>()?;
let aggr_expr = aggr_expr
.iter()
.map(|e| self.create_aggregate_expr(e, &input_schema))
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(HashAggregateExec::try_new(
group_expr,
aggr_expr,
input,
schema.clone(),
)?))
}
_ => Err(ExecutionError::General(
"Unsupported logical plan variant".to_string(),
)),
Expand All @@ -276,6 +300,30 @@ impl ExecutionContext {
}
}

/// Create an aggregate expression from a logical expression.
///
/// Only `Expr::AggregateFunction` is handled. The function name is matched
/// case-insensitively and `sum` is the only function currently recognised; it
/// is built from the first entry of the argument list.
///
/// # Errors
///
/// * `ExecutionError::General` if `sum` is given an empty argument list.
/// * `ExecutionError::NotImplemented` for any other aggregate function name or
///   any other kind of expression.
/// * Any error returned by `create_physical_expr` for the argument.
pub fn create_aggregate_expr(
    &self,
    e: &Expr,
    input_schema: &Schema,
) -> Result<Arc<dyn AggregateExpr>> {
    match e {
        Expr::AggregateFunction { name, args, .. } => {
            match name.to_lowercase().as_ref() {
                "sum" => {
                    // Indexing `args[0]` would panic on an empty argument list;
                    // report it as a regular execution error instead.
                    let arg = args.first().ok_or_else(|| {
                        ExecutionError::General(
                            "Aggregate function 'sum' requires an argument"
                                .to_string(),
                        )
                    })?;
                    Ok(Arc::new(Sum::new(
                        self.create_physical_expr(arg, input_schema)?,
                    )))
                }
                other => Err(ExecutionError::NotImplemented(format!(
                    "Unsupported aggregate function '{}'",
                    other
                ))),
            }
        }
        _ => Err(ExecutionError::NotImplemented(
            "Unsupported aggregate expression".to_string(),
        )),
    }
}

/// Execute a physical plan and collect the results in memory
pub fn collect(&self, plan: &dyn ExecutionPlan) -> Result<Vec<RecordBatch>> {
let partitions = plan.partitions()?;
Expand Down Expand Up @@ -499,24 +547,80 @@ impl SchemaProvider for ExecutionContextSchemaProvider {
mod tests {

use super::*;
use crate::test;
use std::fs::File;
use std::io::prelude::*;
use tempdir::TempDir;

#[test]
fn parallel_projection() -> Result<()> {
    let partition_count = 4;
    let batches = execute("SELECT c1, c2 FROM test", partition_count)?;

    // the projection yields exactly one batch for every partition
    assert_eq!(batches.len(), partition_count);

    // every batch carries the 2 projected columns and 10 rows
    batches.iter().for_each(|b| {
        assert_eq!(b.num_columns(), 2);
        assert_eq!(b.num_rows(), 10);
    });

    Ok(())
}

#[test]
fn aggregate() -> Result<()> {
    // aggregate without GROUP BY over a 4-partition table
    let batches = execute("SELECT SUM(c1), SUM(c2) FROM test", 4)?;
    assert_eq!(batches.len(), 1);

    let batch = &batches[0];
    let mut actual = test::format_batch(&batch);
    // sort so the comparison does not depend on row order
    actual.sort();

    let expected: Vec<&str> = vec!["60,220"];
    assert_eq!(actual, expected);

    Ok(())
}

#[test]
fn aggregate_grouped() -> Result<()> {
    // grouped aggregate over a 4-partition table
    let batches = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?;
    assert_eq!(batches.len(), 1);

    let batch = &batches[0];
    let mut actual = test::format_batch(&batch);
    // sort so the comparison does not depend on the order groups are emitted in
    actual.sort();

    let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"];
    assert_eq!(actual, expected);

    Ok(())
}

/// Run `sql` against a freshly created context with `partition_count`
/// partitions and return the collected record batches.
fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
    // the temp dir is bound for the whole function so it stays alive
    // until the results have been collected
    let tmp_dir = TempDir::new("execute")?;
    let mut ctx = create_ctx(&tmp_dir, partition_count)?;

    // SQL -> logical plan -> optimized logical plan
    let optimized = {
        let parsed = ctx.create_logical_plan(sql)?;
        ctx.optimize(&parsed)?
    };

    // optimized logical plan -> physical plan (batch size 1024) -> results
    let plan = ctx.create_physical_plan(&optimized, 1024)?;
    ctx.collect(plan.as_ref())
}

/// Generate a partitioned CSV file and register it with an execution context
fn create_ctx(tmp_dir: &TempDir, partition_count: usize) -> Result<ExecutionContext> {
let mut ctx = ExecutionContext::new();

// define schema for data source (csv file)
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt32, false),
Field::new("c2", DataType::UInt64, false),
]));

let tmp_dir = TempDir::new("parallel_projection")?;

// generate a partitioned file
let partition_count = 4;
for partition in 0..partition_count {
let filename = format!("partition-{}.csv", partition);
let file_path = tmp_dir.path().join(&filename);
Expand All @@ -532,22 +636,7 @@ mod tests {
// register csv file with the execution context
ctx.register_csv("test", tmp_dir.path().to_str().unwrap(), &schema, true);

let logical_plan = ctx.create_logical_plan("SELECT c1, c2 FROM test")?;

let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;

let results = ctx.collect(physical_plan.as_ref())?;

// there should be one batch per partition
assert_eq!(partition_count, results.len());

// each batch should contain 2 columns and 10 rows
for batch in &results {
assert_eq!(2, batch.num_columns());
assert_eq!(10, batch.num_rows());
}

Ok(())
Ok(ctx)
}

}
Loading

0 comments on commit f19ee70

Please sign in to comment.