Skip to content

Commit

Permalink
Fix concat simplifier for Utf8View types (#13346)
Browse files Browse the repository at this point in the history
* Add string view options to concat, fix simplifier for handling concat to return the same schema as without

* Set coersion ordering

* Add to simplification unit test to catch changes in type for concat

* Update coersion ordering

* Simplify computing merged type for concat
  • Loading branch information
timsaucer authored Nov 15, 2024
1 parent 7e69580 commit 1e96a0a
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 16 deletions.
24 changes: 19 additions & 5 deletions datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,12 @@ fn expr_test_schema() -> DFSchemaRef {
Field::new("c2", DataType::Boolean, true),
Field::new("c3", DataType::Int64, true),
Field::new("c4", DataType::UInt32, true),
Field::new("c5", DataType::Utf8View, true),
Field::new("c1_non_null", DataType::Utf8, false),
Field::new("c2_non_null", DataType::Boolean, false),
Field::new("c3_non_null", DataType::Int64, false),
Field::new("c4_non_null", DataType::UInt32, false),
Field::new("c5_non_null", DataType::Utf8View, false),
])
.to_dfschema_ref()
.unwrap()
Expand Down Expand Up @@ -665,20 +667,32 @@ fn test_simplify_concat_ws_with_null() {
}

#[test]
fn test_simplify_concat() {
fn test_simplify_concat() -> Result<()> {
let schema = expr_test_schema();
let null = lit(ScalarValue::Utf8(None));
let expr = concat(vec![
null.clone(),
col("c0"),
col("c1"),
lit("hello "),
null.clone(),
lit("rust"),
col("c1"),
lit(ScalarValue::Utf8View(Some("!".to_string()))),
col("c2"),
lit(""),
null,
col("c5"),
]);
let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
test_simplify(expr, expected)
let expr_datatype = expr.get_type(schema.as_ref())?;
let expected = concat(vec![
col("c1"),
lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))),
col("c2"),
col("c5"),
]);
let expected_datatype = expected.get_type(schema.as_ref())?;
assert_eq!(expr_datatype, expected_datatype);
test_simplify(expr, expected);
Ok(())
}
#[test]
fn test_simplify_cycles() {
Expand Down
87 changes: 76 additions & 11 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl ConcatFunc {
use DataType::*;
Self {
signature: Signature::variadic(
vec![Utf8, Utf8View, LargeUtf8],
vec![Utf8View, Utf8, LargeUtf8],
Volatility::Immutable,
),
}
Expand Down Expand Up @@ -110,8 +110,19 @@ impl ScalarUDFImpl for ConcatFunc {
if array_len.is_none() {
let mut result = String::new();
for arg in args {
if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg {
result.push_str(v);
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => {
result.push_str(v);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
other => plan_err!(
"Concat function does not support scalar type {:?}",
other
)?,
}
}

Expand Down Expand Up @@ -282,15 +293,37 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
let mut new_args = Vec::with_capacity(args.len());
let mut contiguous_scalar = "".to_string();

let return_type = {
let data_types: Vec<_> = args
.iter()
.filter_map(|expr| match expr {
Expr::Literal(l) => Some(l.data_type()),
_ => None,
})
.collect();
ConcatFunc::new().return_type(&data_types)
}?;

for arg in args.clone() {
match arg {
Expr::Literal(ScalarValue::Utf8(None)) => {}
Expr::Literal(ScalarValue::LargeUtf8(None)) => {
}
Expr::Literal(ScalarValue::Utf8View(None)) => { }

// filter out `null` args
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
// All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
// Concatenate it with the `contiguous_scalar`.
Expr::Literal(
ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)),
) => contiguous_scalar += &v,
Expr::Literal(ScalarValue::Utf8(Some(v))) => {
contiguous_scalar += &v;
}
Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => {
contiguous_scalar += &v;
}
Expr::Literal(ScalarValue::Utf8View(Some(v))) => {
contiguous_scalar += &v;
}

Expr::Literal(x) => {
return internal_err!(
"The scalar {x} should be casted to string type during the type coercion."
Expand All @@ -301,7 +334,12 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
// Then pushing this arg to the `new_args`.
arg => {
if !contiguous_scalar.is_empty() {
new_args.push(lit(contiguous_scalar));
match return_type {
DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
_ => unreachable!(),
}
contiguous_scalar = "".to_string();
}
new_args.push(arg);
Expand All @@ -310,7 +348,16 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
}

if !contiguous_scalar.is_empty() {
new_args.push(lit(contiguous_scalar));
match return_type {
DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
DataType::LargeUtf8 => {
new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
}
DataType::Utf8View => {
new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
}
_ => unreachable!(),
}
}

if !args.eq(&new_args) {
Expand Down Expand Up @@ -392,6 +439,17 @@ mod tests {
LargeUtf8,
LargeStringArray
);
test_function!(
ConcatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
],
Ok(Some("aacc")),
&str,
Utf8View,
StringViewArray
);

Ok(())
}
Expand All @@ -406,11 +464,18 @@ mod tests {
None,
Some("z"),
])));
let args = &[c0, c1, c2];
let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
Some("a"),
None,
Some("b"),
])));
let args = &[c0, c1, c2, c3, c4];

let result = ConcatFunc::new().invoke_batch(args, 3)?;
let expected =
Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef;
Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
as ArrayRef;
match &result {
ColumnarValue::Array(array) => {
assert_eq!(&expected, array);
Expand Down

0 comments on commit 1e96a0a

Please sign in to comment.