Skip to content

Commit

Permalink
improved unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Sep 19, 2019
1 parent 3d7407b commit eb28a08
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 58 deletions.
89 changes: 35 additions & 54 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,71 +546,66 @@ impl SchemaProvider for ExecutionContextSchemaProvider {
mod tests {

use super::*;
use crate::test;
use std::fs::File;
use std::io::prelude::*;
use tempdir::TempDir;

/// Verify that a projection over a partitioned CSV source produces one
/// record batch per partition, each with the expected shape.
///
/// Prior version of this test built the schema and CSV files inline and
/// also contained leftover pre-refactor code (duplicate `results`
/// bindings and duplicated asserts); it now delegates setup to the
/// shared `execute` helper like the other tests in this module.
#[test]
fn parallel_projection() -> Result<()> {
    let partition_count = 4;
    let results = execute("SELECT c1, c2 FROM test", partition_count)?;

    // there should be one batch per partition
    assert_eq!(results.len(), partition_count);

    // each batch should contain 2 columns and 10 rows
    // (the generated files have 11 lines; the first is consumed as a header)
    for batch in &results {
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 10);
    }

    Ok(())
}

/// Verify that a GROUP BY aggregate over the partitioned CSV source
/// produces a single batch with the expected per-group sums.
///
/// The span previously carried two conflicting signature lines
/// (`fn aggregate()` / `fn aggregate_()`) from an unresolved diff;
/// only the post-commit `aggregate_` is kept.
#[test]
fn aggregate_() -> Result<()> {
    // 4 partitions; each partition's c2 column holds 1..=10 after the
    // header row is skipped, so SUM(c2) per group is 55 — TODO confirm
    // against the data generated in create_ctx
    let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?;
    assert_eq!(results.len(), 1);

    let batch = &results[0];

    // results are unordered, so sort the formatted rows before comparing
    let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"];
    let mut rows = test::format_batch(&batch);
    rows.sort();
    assert_eq!(rows, expected);

    Ok(())
}

/// Run the given SQL statement against a freshly created context backed
/// by a partitioned CSV data source, returning every result batch.
fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
    let tmp_dir = TempDir::new("execute")?;
    let mut ctx = create_ctx(&tmp_dir, partition_count)?;

    // plan the query: logical -> optimized logical -> physical
    let plan = ctx.create_logical_plan(sql)?;
    let plan = ctx.optimize(&plan)?;
    let plan = ctx.create_physical_plan(&plan, 1024)?;

    ctx.collect(plan.as_ref())
}

/// Generate a partitioned CSV file and register it with an execution context
fn create_ctx(tmp_dir: &TempDir, partition_count: usize) -> Result<ExecutionContext> {
let mut ctx = ExecutionContext::new();

// define schema for data source (csv file)
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt32, false),
Field::new("c2", DataType::UInt64, false),
]));

let tmp_dir = TempDir::new("aggregate")?;

// generate a partitioned file
let partition_count = 4;
for partition in 0..partition_count {
let filename = format!("partition-{}.csv", partition);
let file_path = tmp_dir.path().join(&filename);
Expand All @@ -626,21 +621,7 @@ mod tests {
// register csv file with the execution context
ctx.register_csv("test", tmp_dir.path().to_str().unwrap(), &schema, true);

let logical_plan =
ctx.create_logical_plan("SELECT c1, SUM(c2) FROM test GROUP BY c1")?;

let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;

let results = ctx.collect(physical_plan.as_ref())?;

assert_eq!(results.len(), 1);

let batch = &results[0];

assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 4);

Ok(())
Ok(ctx)
}

}
4 changes: 0 additions & 4 deletions rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,6 @@ impl BatchIterator for HashAggregateIterator {

// iterate over input and perform aggregation
while let Some(batch) = input.next()? {
//TODO add optimization for case where there are no grouping expressions and
// we can just perform vectorized operations on columns in the batch rather
// than iterating over each row

// evaluate the grouping expressions for this batch
let group_values = self
.group_expr
Expand Down
84 changes: 84 additions & 0 deletions rust/datafusion/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use crate::error::Result;
use crate::execution::context::ExecutionContext;
use crate::execution::physical_plan::ExecutionPlan;
use arrow::array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use std::env;
Expand Down Expand Up @@ -102,3 +103,86 @@ pub fn aggr_test_schema() -> Arc<Schema> {
Field::new("c13", DataType::Utf8, false),
]))
}

/// Format a batch as csv
pub fn format_batch(batch: &RecordBatch) -> Vec<String> {
let mut rows = vec![];
for row_index in 0..batch.num_rows() {
let mut s = String::new();
for column_index in 0..batch.num_columns() {
if column_index > 0 {
s.push(',');
}
let array = batch.column(column_index);
match array.data_type() {
DataType::Int8 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Int8Array>()
.unwrap()
.value(row_index)
)),
DataType::Int16 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Int16Array>()
.unwrap()
.value(row_index)
)),
DataType::Int32 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Int32Array>()
.unwrap()
.value(row_index)
)),
DataType::Int64 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Int64Array>()
.unwrap()
.value(row_index)
)),
DataType::UInt8 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::UInt8Array>()
.unwrap()
.value(row_index)
)),
DataType::UInt16 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::UInt16Array>()
.unwrap()
.value(row_index)
)),
DataType::UInt32 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::UInt32Array>()
.unwrap()
.value(row_index)
)),
DataType::UInt64 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::UInt64Array>()
.unwrap()
.value(row_index)
)),
_ => s.push('?'),
}
}
rows.push(s);
}
rows
}

0 comments on commit eb28a08

Please sign in to comment.