Skip to content

Commit

Permalink
ARROW-10354: [Rust][DataFusion] regexp_extract function to select reg…
Browse files Browse the repository at this point in the history
…ex groups from strings

Adds a regexp_extract compute kernel to select a substring based on a regular expression.

Some things I did that I may be doing wrong:

* I exposed `GenericStringBuilder`
* I build the resulting Array using a builder - this looks quite different from e.g. the substring kernel. Should I change it accordingly, e.g. because of performance considerations?
* To apply the new function in DataFusion, I did not see a better solution than to treat the pattern string as a `StringArray`, take its first record to compile the regex pattern, and apply that pattern to all values. Is there a way to declare that an argument must be a literal/scalar and cannot be filled by, e.g., another column? I consider my current implementation quite error-prone and would like to make it more robust.

Closes #9428 from sweb/ARROW-10354/regexp_extract

Authored-by: Florian Müller <[email protected]>
Signed-off-by: Andrew Lamb <[email protected]>
  • Loading branch information
sweb authored and alamb committed Apr 1, 2021
1 parent ec50bc1 commit 59bb69b
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 7 deletions.
1 change: 1 addition & 0 deletions rust/datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ unary_scalar_expr!(Lpad, lpad);
unary_scalar_expr!(Ltrim, ltrim);
unary_scalar_expr!(MD5, md5);
unary_scalar_expr!(OctetLength, octet_length);
unary_scalar_expr!(RegexpMatch, regexp_match);
unary_scalar_expr!(RegexpReplace, regexp_replace);
unary_scalar_expr!(Replace, replace);
unary_scalar_expr!(Repeat, repeat);
Expand Down
8 changes: 4 additions & 4 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ pub use expr::{
ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count,
count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list,
initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
octet_length, or, regexp_replace, repeat, replace, reverse, right, round, rpad,
rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with,
strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr,
ExprRewriter, ExpressionVisitor, Literal, Recursion,
octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right,
round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when,
Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
122 changes: 120 additions & 2 deletions rust/datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ pub enum BuiltinScalarFunction {
Trim,
/// upper
Upper,
/// regexp_match
RegexpMatch,
}

impl fmt::Display for BuiltinScalarFunction {
Expand Down Expand Up @@ -271,7 +273,7 @@ impl FromStr for BuiltinScalarFunction {
"translate" => BuiltinScalarFunction::Translate,
"trim" => BuiltinScalarFunction::Trim,
"upper" => BuiltinScalarFunction::Upper,

"regexp_match" => BuiltinScalarFunction::RegexpMatch,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -607,6 +609,20 @@ pub fn return_type(
));
}
}),
BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] {
DataType::LargeUtf8 => {
DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true)))
}
DataType::Utf8 => {
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
}
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The regexp_extract function can only accept strings.".to_string(),
));
}
}),

BuiltinScalarFunction::Abs
| BuiltinScalarFunction::Acos
Expand Down Expand Up @@ -853,6 +869,28 @@ pub fn create_physical_expr(
_ => unreachable!(),
},
},
BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_match,
i32,
"regexp_match"
);
make_scalar_function(func)(args)
}
DataType::LargeUtf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_match,
i64,
"regexp_match"
);
make_scalar_function(func)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function regexp_match",
other
))),
},
BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
Expand Down Expand Up @@ -1229,6 +1267,12 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
BuiltinScalarFunction::NullIf => {
Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec())
}
BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![
Signature::Exact(vec![DataType::Utf8, DataType::Utf8]),
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
]),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
Expand Down Expand Up @@ -1386,7 +1430,7 @@ mod tests {
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array,
Int32Array, StringArray, UInt32Array, UInt64Array,
Int32Array, ListArray, StringArray, UInt32Array, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
Expand Down Expand Up @@ -3646,4 +3690,78 @@ mod tests {
"PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
)
}

/// `regexp_match` over a Utf8 column with a literal pattern: checks the
/// declared output type and the first captured group of the first row.
#[test]
#[cfg(feature = "regex_expressions")]
fn test_regexp_match() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);

    // regexp_match(a, '.*-(\d*)')
    let input: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"]));
    let regex = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
    let expr = create_physical_expr(
        &BuiltinScalarFunction::RegexpMatch,
        &[col("a"), regex],
        &schema,
    )?;

    // the expression must declare a List<Utf8> output
    let expected_type =
        DataType::List(Box::new(Field::new("item", DataType::Utf8, true)));
    assert_eq!(expr.data_type(&schema)?, expected_type);

    // evaluate against a one-row batch holding the input column
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?;
    let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows());

    // unwrap the outer list, then the string array stored in its first row
    let lists = evaluated.as_any().downcast_ref::<ListArray>().unwrap();
    let groups = lists.value(0);
    let groups = groups.as_any().downcast_ref::<StringArray>().unwrap();

    // the only capture group is the digits after the dash
    assert_eq!(groups.value(0), "555");

    Ok(())
}

/// `regexp_match` where both the value and the pattern are literals: checks
/// the declared output type and the first captured group of the first row.
#[test]
#[cfg(feature = "regex_expressions")]
fn test_regexp_match_all_literals() -> Result<()> {
    // the batch column is unrelated to the expression; both arguments are literals
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

    // regexp_match('aaa-555', '.*-(\d*)')
    let value = lit(ScalarValue::Utf8(Some("aaa-555".to_string())));
    let regex = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
    let batch_columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
    let expr = create_physical_expr(
        &BuiltinScalarFunction::RegexpMatch,
        &[value, regex],
        &schema,
    )?;

    // the expression must declare a List<Utf8> output
    let expected_type =
        DataType::List(Box::new(Field::new("item", DataType::Utf8, true)));
    assert_eq!(expr.data_type(&schema)?, expected_type);

    // evaluate against a one-row batch
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), batch_columns)?;
    let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows());

    // unwrap the outer list, then the string array stored in its first row
    let lists = evaluated.as_any().downcast_ref::<ListArray>().unwrap();
    let groups = lists.value(0);
    let groups = groups.as_any().downcast_ref::<StringArray>().unwrap();

    // the only capture group is the digits after the dash
    assert_eq!(groups.value(0), "555");

    Ok(())
}
}
15 changes: 15 additions & 0 deletions rust/datafusion/src/physical_plan/regex_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::sync::Arc;

use crate::error::{DataFusionError, Result};
use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait};
use arrow::compute;
use hashbrown::HashMap;
use regex::Regex;

Expand All @@ -43,6 +44,20 @@ macro_rules! downcast_string_arg {
}};
}

/// Extract the groups matched by a regular expression from a string column.
///
/// `args` is expected to contain, in order:
/// 1. the string values to match against,
/// 2. the regular expression pattern(s),
/// 3. (optional) the regular expression flags.
///
/// Every argument is downcast to `GenericStringArray<T>` via
/// `downcast_string_arg!` before being handed to the arrow
/// `compute::regexp_match` kernel.
///
/// # Errors
///
/// * `DataFusionError::Internal` if `args` has fewer than 2 or more than 3
///   elements.
/// * `DataFusionError::ArrowError` wrapping any error reported by the arrow
///   kernel.
pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args.len() {
        2 => compute::regexp_match(
            downcast_string_arg!(args[0], "string", T),
            downcast_string_arg!(args[1], "pattern", T),
            None,
        )
        .map_err(DataFusionError::ArrowError),
        3 => compute::regexp_match(
            downcast_string_arg!(args[0], "string", T),
            downcast_string_arg!(args[1], "pattern", T),
            // The flags are the THIRD argument; reading `args[1]` here would
            // pass the pattern itself as the flags and ignore the real ones.
            Some(downcast_string_arg!(args[2], "flags", T)),
        )
        .map_err(DataFusionError::ArrowError),
        other => Err(DataFusionError::Internal(format!(
            "regexp_match was called with {} arguments. It requires at least 2 and at most 3.",
            other
        ))),
    }
}

/// replace POSIX capture groups (like \1) with Rust Regex group (like ${1})
/// used by regexp_replace
fn regex_replace_posix_groups(replacement: &str) -> String {
Expand Down
6 changes: 5 additions & 1 deletion rust/datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ macro_rules! build_list {
for scalar_value in values {
match scalar_value {
ScalarValue::$SCALAR_TY(Some(v)) => {
builder.values().append_value(*v).unwrap()
builder.values().append_value(v.clone()).unwrap()
}
ScalarValue::$SCALAR_TY(None) => {
builder.values().append_null().unwrap();
Expand Down Expand Up @@ -335,6 +335,10 @@ impl ScalarValue {
DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size),
DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size),
DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size),
DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size),
DataType::LargeUtf8 => {
build_list!(LargeStringBuilder, LargeUtf8, values, size)
}
_ => panic!("Unexpected DataType for list"),
}),
ScalarValue::Date32(e) => match e {
Expand Down
11 changes: 11 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2560,6 +2560,17 @@ async fn test_in_list_scalar() -> Result<()> {
test_expression!("'2' IN ('a','b',NULL,1)", "NULL");
test_expression!("'1' NOT IN ('a','b',NULL,1)", "false");
test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL");
test_expression!("regexp_match('foobarbequebaz', '')", "[]");
test_expression!(
"regexp_match('foobarbequebaz', '(bar)(beque)')",
"[bar, beque]"
);
test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL");
test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]");
test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]");
test_expression!("regexp_match('aa', '.*-(\\d)')", "NULL");
test_expression!("regexp_match(NULL, '.*-(\\d)')", "NULL");
test_expression!("regexp_match('aaa-0', NULL)", "NULL");
Ok(())
}

Expand Down

0 comments on commit 59bb69b

Please sign in to comment.