Skip to content

Commit

Permalink
feat: Support ArrayIndex (GetIndexedExpr) on dynamic key expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed May 11, 2022
1 parent 3dcf38b commit 7b7901f
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 86 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl ExprRewritable for Expr {
}
Expr::GetIndexedField { expr, key } => Expr::GetIndexedField {
expr: rewrite_boxed(expr, rewriter)?,
key,
key: rewrite_boxed(key, rewriter)?,
},
};

Expand Down
7 changes: 5 additions & 2 deletions datafusion/core/src/logical_plan/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,11 @@ impl ExprVisitable for Expr {
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
| Expr::Sort { expr, .. } => expr.accept(visitor),
Expr::GetIndexedField { expr, key } => {
let visitor = expr.accept(visitor)?;
key.accept(visitor)
}
Expr::Column(_)
| Expr::OuterColumn(_, _)
| Expr::ScalarVariable(_, _)
Expand Down
10 changes: 6 additions & 4 deletions datafusion/core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,10 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
| Expr::Alias(expr, ..)
| Expr::Not(expr)
| Expr::Negative(expr)
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
| Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
Expr::GetIndexedField { expr, key } => {
Ok(vec![expr.as_ref().to_owned(), key.as_ref().to_owned()])
}
Expr::ScalarFunction { args, .. }
| Expr::ScalarUDF { args, .. }
| Expr::TableUDF { args, .. }
Expand Down Expand Up @@ -547,9 +549,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
"QualifiedWildcard expressions are not valid in a logical query plan"
.to_owned(),
)),
Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField {
Expr::GetIndexedField { .. } => Ok(Expr::GetIndexedField {
expr: Box::new(expressions[0].clone()),
key: key.clone(),
key: Box::new(expressions[1].clone()),
}),
}
}
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
}
Expr::GetIndexedField { expr, key } => {
let expr = create_physical_name(expr, false)?;
let key = create_physical_name(key, false)?;
Ok(format!("{}[{}]", expr, key))
}
Expr::ScalarFunction { fun, args, .. } => {
Expand Down Expand Up @@ -1093,7 +1094,7 @@ pub fn create_physical_expr(
)?),
Expr::GetIndexedField { expr, key } => Ok(Arc::new(GetIndexedFieldExpr::new(
create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,
key.clone(),
create_physical_expr(key, input_dfschema, input_schema, execution_props)?,
))),

Expr::ScalarFunction { fun, args } => {
Expand Down
50 changes: 9 additions & 41 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,25 +133,7 @@ impl SqlToRelContext {
}
}

/// Converts a SQL index/key expression into the `ScalarValue` used as the
/// key of an indexed-field access.
///
/// Supported inputs:
/// * a numeric literal, parsed as a signed 64-bit integer (`ScalarValue::Int64`)
/// * a single-quoted string literal (`ScalarValue::Utf8`)
/// * a bare identifier, whose text is used as a string key (`ScalarValue::Utf8`)
///
/// # Errors
/// Returns `DataFusionError::SQL` when the numeric literal cannot be parsed
/// as an `i64`, or when `key` is any other kind of SQL expression.
fn plan_key(key: SQLExpr) -> Result<ScalarValue> {
    let scalar = match key {
        SQLExpr::Value(Value::Number(s, _)) => {
            // The literal comes straight from user SQL, so a non-integer or
            // out-of-range value must surface as a planner error, not a panic.
            let index = s.parse::<i64>().map_err(|e| {
                DataFusionError::SQL(ParserError(format!(
                    "Invalid index key {:?}: {}",
                    s, e
                )))
            })?;
            ScalarValue::Int64(Some(index))
        }
        SQLExpr::Value(Value::SingleQuotedString(s)) => ScalarValue::Utf8(Some(s)),
        SQLExpr::Identifier(ident) => ScalarValue::Utf8(Some(ident.value)),
        _ => {
            return Err(DataFusionError::SQL(ParserError(format!(
                "Unsuported index key expression: {:?}",
                key
            ))))
        }
    };

    Ok(scalar)
}

fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {
fn plan_indexed(expr: Expr, mut keys: Vec<Expr>) -> Result<Expr> {
let key = keys.pop().ok_or_else(|| {
DataFusionError::SQL(ParserError(
"Internal error: Missing index key expression".to_string(),
Expand All @@ -166,7 +148,7 @@ fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {

Ok(Expr::GetIndexedField {
expr: Box::new(expr),
key: plan_key(key)?,
key: Box::new(key),
})
}

Expand Down Expand Up @@ -1704,26 +1686,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

SQLExpr::MapAccess { ref column, keys } => {
if let SQLExpr::Identifier(ref id) = column.as_ref() {
plan_indexed(col(&id.value), keys)
} else {
Err(DataFusionError::NotImplemented(format!(
"map access requires an identifier, found column {} instead",
column
)))
}
}

SQLExpr::ArrayIndex { obj, indexs } => {
if let SQLExpr::Identifier(ref id) = obj.as_ref() {
plan_indexed(col(&id.value), indexs)
} else {
Err(DataFusionError::NotImplemented(format!(
"array index access requires an identifier, found column {} instead",
obj
)))
}
let expr = self.sql_expr_to_logical_expr(*obj, schema)?;

plan_indexed(expr, indexs.into_iter()
.map(|e| self.sql_expr_to_logical_expr(e, schema))
.collect::<Result<Vec<_>>>()?)
}

SQLExpr::CompoundIdentifier(ids) => {
Expand Down Expand Up @@ -1754,7 +1722,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// Access to a field of a column which is a structure, example: SELECT my_struct.key
Ok(Expr::GetIndexedField {
expr: Box::new(Expr::Column(field.qualified_column())),
key: ScalarValue::Utf8(Some(name)),
key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(name)))),
})
} else {
// table.column identifier
Expand Down Expand Up @@ -2104,7 +2072,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::DotExpr { expr, field } => {
Ok(Expr::GetIndexedField {
expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?),
key: ScalarValue::Utf8(Some(field.value)),
key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(field.value)))),
})
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/sql/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ where
Expr::QualifiedWildcard { .. } => Ok(expr.clone()),
Expr::GetIndexedField { expr, key } => Ok(Expr::GetIndexedField {
expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?),
key: key.clone(),
key: Box::new(clone_with_replacement(key.as_ref(), replacement_fn)?),
}),
},
}
Expand Down
84 changes: 82 additions & 2 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,12 +596,92 @@ async fn query_nested_get_indexed_field() -> Result<()> {
"+----------+",
];
assert_batches_eq!(expected, &actual);

// nested with scalar values
let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+",
];
assert_batches_eq!(expected, &actual);

// nested with dynamic expr in key
assert_batches_eq!(expected, &actual);
let sql = "SELECT some_list[1 - 1][1 - 1] as i0 FROM ints LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+",
];
assert_batches_eq!(expected, &actual);

Ok(())
}

/// Checks that a list column can be indexed with a key that is evaluated per
/// row: first with a key read from another column of an in-memory table, then
/// with both the list and the key produced by a subquery.
#[tokio::test]
async fn query_get_indexed_array_dynamic_key() -> Result<()> {
    let ctx = SessionContext::new();

    // Two columns: `arr` (list of Int64) and `key` (Int64).
    let list_dt = Box::new(Field::new("item", DataType::Int64, true));
    let schema = Arc::new(Schema::new(vec![
        Field::new("arr", DataType::List(list_dt), false),
        Field::new("key", DataType::Int64, false),
    ]));

    let array_ints_builder = PrimitiveBuilder::<Int64Type>::new(3);
    let mut arr_builder = ListBuilder::new(array_ints_builder);
    let mut key_builder = PrimitiveBuilder::<Int64Type>::new(3);

    // One row per tuple: the list contents and the key stored next to it.
    for (int_vec, key) in vec![
        (vec![0, 1, 2, 3], 1),
        (vec![4, 5, 6, 7], 2),
        (vec![8, 9, 10, 11], 3),
    ] {
        for n in int_vec {
            arr_builder.values().append_value(n)?;
        }

        key_builder.append_value(key)?;
        // Close the current list entry after its values have been appended.
        arr_builder.append(true)?;
    }

    let data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(arr_builder.finish()),
            Arc::new(key_builder.finish()),
        ],
    )?;
    let table = MemTable::try_new(schema, vec![vec![data]])?;
    let table_a = Arc::new(table);

    ctx.register_table("array_and_keys", table_a)?;

    // Key comes from a column of the same row.
    // The comparison is line-exact, so every data row is padded to the
    // column width established by the border and header lines.
    let sql = "SELECT arr[key], key FROM array_and_keys";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+----------------------------------------+-----+",
        "| array_and_keys.arr[array_and_keys.key] | key |",
        "+----------------------------------------+-----+",
        "| 0                                      | 1   |",
        "| 4                                      | 2   |",
        "| 8                                      | 3   |",
        "+----------------------------------------+-----+",
    ];
    assert_batches_eq!(expected, &actual);

    // All dynamic
    let sql = "SELECT r.value[r.key] FROM (SELECT array[1,2,3] as value, 1 as key UNION ALL SELECT array[4,5,6] as value, 2 as key) as r";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+----------------+",
        "| r.value[r.key] |",
        "+----------------+",
        "| 1              |",
        "| 5              |",
        "+----------------+",
    ];
    assert_batches_eq!(expected, &actual);

    Ok(())
}

Expand Down Expand Up @@ -634,7 +714,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
ctx.register_table("structs", table_a)?;

// Original column is micros, convert to millis and check timestamp
let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3";
let sql = "SELECT some_struct['bar'] as l0 FROM structs LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------+",
Expand All @@ -661,7 +741,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
];
assert_batches_eq!(expected, &actual);

let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3";
let sql = "SELECT some_struct['bar'][0] as i0 FROM structs LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ pub enum Expr {
/// the expression to take the field from
expr: Box<Expr>,
/// The key to take: a field name or list index, given as an expression
key: ScalarValue,
key: Box<Expr>,
},
/// Whether an expression is between a given range.
Between {
Expand Down
Loading

0 comments on commit 7b7901f

Please sign in to comment.