Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-6090: [Rust] [DataFusion] Physical plan for HashAggregate #5191

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions rust/datafusion/src/datasource/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ pub struct CsvBatchIterator {

impl CsvBatchIterator {
#[allow(missing_docs)]
pub fn new(
pub fn try_new(
filename: &str,
schema: Arc<Schema>,
has_header: bool,
projection: &Option<Vec<usize>>,
batch_size: usize,
) -> Self {
let file = File::open(filename).unwrap();
) -> Result<Self> {
let file = File::open(filename)?;
let reader = csv::Reader::new(
file,
schema.clone(),
Expand All @@ -110,10 +110,10 @@ impl CsvBatchIterator {
None => schema,
};

Self {
Ok(Self {
schema: projected_schema,
reader,
}
})
}
}

Expand Down
3 changes: 2 additions & 1 deletion rust/datafusion/src/execution/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,8 @@ mod tests {
}

fn load_csv(filename: &str, schema: &Arc<Schema>) -> Rc<RefCell<dyn Relation>> {
let ds = CsvBatchIterator::new(filename, schema.clone(), true, &None, 10);
let ds =
CsvBatchIterator::try_new(filename, schema.clone(), true, &None, 10).unwrap();
Copy link
Contributor

@paddyhoran paddyhoran Sep 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you propagate this error instead of using unwrap? This is the error from File::open, so it's reasonable that load functions, including load_csv, would have to deal with it.

If there is a strong reason why you don't want to, maybe just switch to expect.

Also, why are you not exposing has_header, etc.? In particular, why is 10 used as the batch size?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is test code to support the unit tests and not part of the actual product ... but you are right, it would be better to have this method return a Result. I will fix that tonight.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is test code to support the unit tests and not part of the actual product

Ahh, OK, sorry. It's hard to see the context while reviewing on GitHub. No need to change if it's test code.

Rc::new(RefCell::new(DataSourceRelation::new(Arc::new(Mutex::new(
ds,
)))))
Expand Down
130 changes: 109 additions & 21 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::Column;
use crate::execution::physical_plan::merge::MergeExec;
use crate::execution::physical_plan::projection::ProjectionExec;
use crate::execution::physical_plan::{ExecutionPlan, PhysicalExpr};
use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr};
use crate::execution::projection::ProjectRelation;
use crate::execution::relation::{DataSourceRelation, Relation};
use crate::execution::scalar_relation::ScalarRelation;
Expand Down Expand Up @@ -256,6 +256,29 @@ impl ExecutionContext {
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input)?))
}
LogicalPlan::Aggregate {
input,
group_expr,
aggr_expr,
schema,
} => {
let input = self.create_physical_plan(input, batch_size)?;
let input_schema = input.as_ref().schema().clone();
let group_expr = group_expr
.iter()
.map(|e| self.create_physical_expr(e, &input_schema))
.collect::<Result<Vec<_>>>()?;
let aggr_expr = aggr_expr
.iter()
.map(|e| self.create_aggregate_expr(e, &input_schema))
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(HashAggregateExec::try_new(
group_expr,
aggr_expr,
input,
schema.clone(),
)?))
}
_ => Err(ExecutionError::General(
"Unsupported logical plan variant".to_string(),
)),
Expand All @@ -276,6 +299,30 @@ impl ExecutionContext {
}
}

/// Create an aggregate expression from a logical expression.
///
/// Only `Expr::AggregateFunction` is handled, and of those only `sum`
/// (the function name is matched case-insensitively). The first argument of
/// the function is translated with `create_physical_expr` against
/// `input_schema`.
///
/// # Errors
///
/// Returns `ExecutionError::NotImplemented` for any other aggregate function
/// name or expression variant, `ExecutionError::General` when `sum` is given
/// no arguments, and propagates any error from `create_physical_expr`.
pub fn create_aggregate_expr(
    &self,
    e: &Expr,
    input_schema: &Schema,
) -> Result<Arc<dyn AggregateExpr>> {
    match e {
        Expr::AggregateFunction { name, args, .. } => {
            match name.to_lowercase().as_ref() {
                "sum" => {
                    // Use `first()` rather than `args[0]` so an aggregate with an
                    // empty argument list yields an error instead of a panic.
                    let arg = args.first().ok_or_else(|| {
                        ExecutionError::General(
                            "Aggregate function 'sum' requires an argument".to_string(),
                        )
                    })?;
                    Ok(Arc::new(Sum::new(
                        self.create_physical_expr(arg, input_schema)?,
                    )))
                }
                other => Err(ExecutionError::NotImplemented(format!(
                    "Unsupported aggregate function '{}'",
                    other
                ))),
            }
        }
        _ => Err(ExecutionError::NotImplemented(
            "Unsupported aggregate expression".to_string(),
        )),
    }
}

/// Execute a physical plan and collect the results in memory
pub fn collect(&self, plan: &dyn ExecutionPlan) -> Result<Vec<RecordBatch>> {
let partitions = plan.partitions()?;
Expand Down Expand Up @@ -499,24 +546,80 @@ impl SchemaProvider for ExecutionContextSchemaProvider {
mod tests {

use super::*;
use crate::test;
use std::fs::File;
use std::io::prelude::*;
use tempdir::TempDir;

/// A two-column projection should produce one batch per partition.
#[test]
fn parallel_projection() -> Result<()> {
    let expected_partitions = 4;
    let batches = execute("SELECT c1, c2 FROM test", expected_partitions)?;

    // one batch is expected for every partition
    assert_eq!(batches.len(), expected_partitions);

    // every batch carries the 2 projected columns and 10 rows
    batches.iter().for_each(|b| {
        assert_eq!(b.num_columns(), 2);
        assert_eq!(b.num_rows(), 10);
    });

    Ok(())
}

/// An ungrouped SUM query should return a single batch holding one row of totals.
#[test]
fn aggregate() -> Result<()> {
    let batches = execute("SELECT SUM(c1), SUM(c2) FROM test", 4)?;
    assert_eq!(batches.len(), 1);

    // Sort the formatted rows so the comparison does not depend on row order.
    let mut actual = test::format_batch(&batches[0]);
    actual.sort();

    let expected: Vec<&str> = vec!["60,220"];
    assert_eq!(actual, expected);

    Ok(())
}

/// A SUM grouped by `c1` should return a single batch with one row per group.
#[test]
fn aggregate_grouped() -> Result<()> {
    let batches = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?;
    assert_eq!(batches.len(), 1);

    // Sort the formatted rows so the comparison does not depend on group order.
    let mut actual = test::format_batch(&batches[0]);
    actual.sort();

    let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"];
    assert_eq!(actual, expected);

    Ok(())
}

/// Run `sql` against a freshly created context whose `test` table is split
/// into `partition_count` partitions, and collect the resulting batches.
fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
    // The temporary directory stays alive until the results are collected.
    let scratch_dir = TempDir::new("execute")?;
    let mut context = create_ctx(&scratch_dir, partition_count)?;

    // SQL -> logical plan -> optimized logical plan -> physical plan
    let parsed_plan = context.create_logical_plan(sql)?;
    let optimized_plan = context.optimize(&parsed_plan)?;
    let exec_plan = context.create_physical_plan(&optimized_plan, 1024)?;

    context.collect(exec_plan.as_ref())
}

/// Generate a partitioned CSV file and register it with an execution context
fn create_ctx(tmp_dir: &TempDir, partition_count: usize) -> Result<ExecutionContext> {
let mut ctx = ExecutionContext::new();

// define schema for data source (csv file)
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt32, false),
Field::new("c2", DataType::UInt64, false),
]));

let tmp_dir = TempDir::new("parallel_projection")?;

// generate a partitioned file
let partition_count = 4;
for partition in 0..partition_count {
let filename = format!("partition-{}.csv", partition);
let file_path = tmp_dir.path().join(&filename);
Expand All @@ -532,22 +635,7 @@ mod tests {
// register csv file with the execution context
ctx.register_csv("test", tmp_dir.path().to_str().unwrap(), &schema, true);

let logical_plan = ctx.create_logical_plan("SELECT c1, c2 FROM test")?;

let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;

let results = ctx.collect(physical_plan.as_ref())?;

// there should be one batch per partition
assert_eq!(partition_count, results.len());

// each batch should contain 2 columns and 10 rows
for batch in &results {
assert_eq!(2, batch.num_columns());
assert_eq!(10, batch.num_rows());
}

Ok(())
Ok(ctx)
}

}
Loading