Skip to content

Commit

Permalink
Implement HashAggregate exection plan and SUM aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Sep 19, 2019
1 parent 211e240 commit 07b2489
Show file tree
Hide file tree
Showing 7 changed files with 712 additions and 13 deletions.
10 changes: 5 additions & 5 deletions rust/datafusion/src/datasource/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ pub struct CsvBatchIterator {

impl CsvBatchIterator {
#[allow(missing_docs)]
pub fn new(
pub fn try_new(
filename: &str,
schema: Arc<Schema>,
has_header: bool,
projection: &Option<Vec<usize>>,
batch_size: usize,
) -> Self {
let file = File::open(filename).unwrap();
) -> Result<Self> {
let file = File::open(filename)?;
let reader = csv::Reader::new(
file,
schema.clone(),
Expand All @@ -110,10 +110,10 @@ impl CsvBatchIterator {
None => schema,
};

Self {
Ok(Self {
schema: projected_schema,
reader,
}
})
}
}

Expand Down
3 changes: 2 additions & 1 deletion rust/datafusion/src/execution/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,8 @@ mod tests {
}

fn load_csv(filename: &str, schema: &Arc<Schema>) -> Rc<RefCell<dyn Relation>> {
let ds = CsvBatchIterator::new(filename, schema.clone(), true, &None, 10);
let ds =
CsvBatchIterator::try_new(filename, schema.clone(), true, &None, 10).unwrap();
Rc::new(RefCell::new(DataSourceRelation::new(Arc::new(Mutex::new(
ds,
)))))
Expand Down
95 changes: 94 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::Column;
use crate::execution::physical_plan::merge::MergeExec;
use crate::execution::physical_plan::projection::ProjectionExec;
use crate::execution::physical_plan::{ExecutionPlan, PhysicalExpr};
use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr};
use crate::execution::projection::ProjectRelation;
use crate::execution::relation::{DataSourceRelation, Relation};
use crate::execution::scalar_relation::ScalarRelation;
Expand Down Expand Up @@ -256,6 +256,29 @@ impl ExecutionContext {
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input)?))
}
LogicalPlan::Aggregate {
input,
group_expr,
aggr_expr,
schema,
} => {
let input = self.create_physical_plan(input, batch_size)?;
let input_schema = input.as_ref().schema().clone();
let group_expr = group_expr
.iter()
.map(|e| self.create_physical_expr(e, &input_schema))
.collect::<Result<Vec<_>>>()?;
let aggr_expr = aggr_expr
.iter()
.map(|e| self.create_aggregate_expr(e, &input_schema))
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(HashAggregateExec::try_new(
group_expr,
aggr_expr,
input,
schema.clone(),
)?))
}
_ => Err(ExecutionError::General(
"Unsupported logical plan variant".to_string(),
)),
Expand All @@ -276,6 +299,30 @@ impl ExecutionContext {
}
}

/// Create an aggregate expression from a logical expression
pub fn create_aggregate_expr(
&self,
e: &Expr,
input_schema: &Schema,
) -> Result<Arc<dyn AggregateExpr>> {
match e {
Expr::AggregateFunction { name, args, .. } => {
match name.to_lowercase().as_ref() {
"sum" => Ok(Arc::new(Sum::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
other => Err(ExecutionError::NotImplemented(format!(
"Unsupported aggregate function '{}'",
other
))),
}
}
_ => Err(ExecutionError::NotImplemented(
"Unsupported aggregate expression".to_string(),
)),
}
}

/// Execute a physical plan and collect the results in memory
pub fn collect(&self, plan: &dyn ExecutionPlan) -> Result<Vec<RecordBatch>> {
let partitions = plan.partitions()?;
Expand Down Expand Up @@ -550,4 +597,50 @@ mod tests {
Ok(())
}

#[test]
fn aggregate() -> Result<()> {
let mut ctx = ExecutionContext::new();

// define schema for data source (csv file)
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt32, false),
]));

let tmp_dir = TempDir::new("aggregate")?;

// generate a partitioned file
let partition_count = 4;
for partition in 0..partition_count {
let filename = format!("partition-{}.csv", partition);
let file_path = tmp_dir.path().join(&filename);
let mut file = File::create(file_path)?;

// generate some data
for i in 0..=10 {
let data = format!("{},{}\n", partition, i);
file.write_all(data.as_bytes())?;
}
}

// register csv file with the execution context
ctx.register_csv("test", tmp_dir.path().to_str().unwrap(), &schema, true);

let logical_plan =
ctx.create_logical_plan("SELECT c1, SUM(c2) FROM test GROUP BY c1")?;

let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;

let results = ctx.collect(physical_plan.as_ref())?;

assert_eq!(results.len(), 1);

let batch = &results[0];

assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 4);

Ok(())
}

}
11 changes: 7 additions & 4 deletions rust/datafusion/src/execution/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ pub(super) fn compile_aggregate_expr(

macro_rules! binary_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap();
let rr = $RIGHT.as_any().downcast_ref::<$DT>().unwrap();
let ll = $LEFT.as_any().downcast_ref::<$DT>().expect("x");
let rr = $RIGHT.as_any().downcast_ref::<$DT>().expect("x");
Ok(Arc::new(compute::$OP(&ll, &rr)?))
}};
}
Expand Down Expand Up @@ -237,11 +237,14 @@ macro_rules! boolean_ops {
let left_values = $LEFT.invoke($BATCH)?;
let right_values = $RIGHT.invoke($BATCH)?;
Ok(Arc::new(compute::$OP(
left_values.as_any().downcast_ref::<BooleanArray>().unwrap(),
left_values
.as_any()
.downcast_ref::<BooleanArray>()
.expect("y"),
right_values
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap(),
.expect("z"),
)?))
}};
}
Expand Down
Loading

0 comments on commit 07b2489

Please sign in to comment.