From 59bb69b6a37fc01701fae10cab799bc9516a8c9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Thu, 1 Apr 2021 07:58:56 -0400 Subject: [PATCH] ARROW-10354: [Rust][DataFusion] regexp_extract function to select regex groups from strings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a regexp_extract compute kernel to select a substring based on a regular expression; in this patch it is wired into DataFusion as the `regexp_match` scalar function. Some things I did that I may be doing wrong: * I exposed `GenericStringBuilder` * I build the resulting Array using a builder - this looks quite different from e.g. the substring kernel. Should I change it accordingly, e.g. because of performance considerations? * In order to apply the new function in DataFusion, I did not see a better solution than to handle the pattern string as `StringArray` and take the first record to compile the regex pattern from it and apply it to all values. Is there a way to define that an argument has to be a literal/scalar and cannot be filled by e.g. another column? I consider my current implementation quite error-prone and would like to make this a bit more robust. 
Closes #9428 from sweb/ARROW-10354/regexp_extract Authored-by: Florian Müller Signed-off-by: Andrew Lamb --- rust/datafusion/src/logical_plan/expr.rs | 1 + rust/datafusion/src/logical_plan/mod.rs | 8 +- .../datafusion/src/physical_plan/functions.rs | 122 +++++++++++++++++- .../src/physical_plan/regex_expressions.rs | 15 +++ rust/datafusion/src/scalar.rs | 6 +- rust/datafusion/tests/sql.rs | 11 ++ 6 files changed, 156 insertions(+), 7 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 314f5d477b3c..991b16058b1d 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1090,6 +1090,7 @@ unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(RegexpMatch, regexp_match); unary_scalar_expr!(RegexpReplace, regexp_replace); unary_scalar_expr!(Replace, replace); unary_scalar_expr!(Repeat, repeat); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 0e7e61981b12..f9be1ff98300 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,10 +37,10 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, - strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr, - ExprRewriter, ExpressionVisitor, Literal, Recursion, + octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, + 
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, + Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 9dc54a4113fa..56365fec1dc8 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -198,6 +198,8 @@ pub enum BuiltinScalarFunction { Trim, /// upper Upper, + /// regexp_match + RegexpMatch, } impl fmt::Display for BuiltinScalarFunction { @@ -271,7 +273,7 @@ impl FromStr for BuiltinScalarFunction { "translate" => BuiltinScalarFunction::Translate, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, - + "regexp_match" => BuiltinScalarFunction::RegexpMatch, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -607,6 +609,20 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] { + DataType::LargeUtf8 => { + DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true))) + } + DataType::Utf8 => { + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + } + _ => { + // this error is internal as `data_types` should have captured this. 
+ return Err(DataFusionError::Internal( + "The regexp_match function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos @@ -853,6 +869,28 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i32, + "regexp_match" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i64, + "regexp_match" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + other + ))), + }, BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_regex_expressions_feature_flag!( @@ -1229,6 +1267,12 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } + BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]), + ]), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). 
@@ -1386,7 +1430,7 @@ mod tests { use arrow::{ array::{ Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array, - Int32Array, StringArray, UInt32Array, UInt64Array, + Int32Array, ListArray, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, @@ -3646,4 +3690,78 @@ mod tests { "PrimitiveArray\n[\n 1,\n 1,\n]", ) } + + #[test] + #[cfg(feature = "regex_expressions")] + fn test_regexp_match() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + // regexp_match(value, pattern) + let col_value: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"])); + let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let columns: Vec<ArrayRef> = vec![col_value]; + let expr = create_physical_expr( + &BuiltinScalarFunction::RegexpMatch, + &[col("a"), pattern], + &schema, + )?; + + // type is correct + assert_eq!( + expr.data_type(&schema)?, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + ); + + // evaluate works + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + + // downcast works + let result = result.as_any().downcast_ref::<ListArray>().unwrap(); + let first_row = result.value(0); + let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap(); + + // value is correct + let expected = "555".to_string(); + assert_eq!(first_row.value(0), expected); + + Ok(()) + } + + #[test] + #[cfg(feature = "regex_expressions")] + fn test_regexp_match_all_literals() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + // regexp_match(literal value, literal pattern) + let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); + let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))]; + let expr = create_physical_expr( + &BuiltinScalarFunction::RegexpMatch, + &[col_value, pattern], + &schema, + )?; + + 
// type is correct + assert_eq!( + expr.data_type(&schema)?, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + ); + + // evaluate works + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + + // downcast works + let result = result.as_any().downcast_ref::<ListArray>().unwrap(); + let first_row = result.value(0); + let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap(); + + // value is correct + let expected = "555".to_string(); + assert_eq!(first_row.value(0), expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/regex_expressions.rs b/rust/datafusion/src/physical_plan/regex_expressions.rs index 6482424e1053..b526e7259ef6 100644 --- a/rust/datafusion/src/physical_plan/regex_expressions.rs +++ b/rust/datafusion/src/physical_plan/regex_expressions.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait}; +use arrow::compute; use hashbrown::HashMap; use regex::Regex; @@ -43,6 +44,20 @@ macro_rules! downcast_string_arg { }}; } +/// return the regex capture groups of each string as a list; args are (string, pattern[, flags]) +pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + match args.len() { + 2 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), None) + .map_err(DataFusionError::ArrowError), + 3 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), Some(downcast_string_arg!(args[2], "flags", T))) + .map_err(DataFusionError::ArrowError), + other => Err(DataFusionError::Internal(format!( + "regexp_match was called with {} arguments. 
It requires at least 2 and at most 3.", + other + ))), + } +} + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index f0c7acfce3eb..b2367758493e 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -115,7 +115,7 @@ macro_rules! build_list { for scalar_value in values { match scalar_value { ScalarValue::$SCALAR_TY(Some(v)) => { - builder.values().append_value(*v).unwrap() + builder.values().append_value(v.clone()).unwrap() } ScalarValue::$SCALAR_TY(None) => { builder.values().append_null().unwrap(); @@ -335,6 +335,10 @@ impl ScalarValue { DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), + DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), + DataType::LargeUtf8 => { + build_list!(LargeStringBuilder, LargeUtf8, values, size) + } _ => panic!("Unexpected DataType for list"), }), ScalarValue::Date32(e) => match e { diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 15234a812463..8c2c35ef6f09 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2560,6 +2560,17 @@ async fn test_in_list_scalar() -> Result<()> { test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); + test_expression!("regexp_match('foobarbequebaz', '')", "[]"); + test_expression!( + "regexp_match('foobarbequebaz', '(bar)(beque)')", + "[bar, beque]" + ); + test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL"); + test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]"); + test_expression!("regexp_match('bb-1', '.*-(\\d)')", 
"[1]"); + test_expression!("regexp_match('aa', '.*-(\\d)')", "NULL"); + test_expression!("regexp_match(NULL, '.*-(\\d)')", "NULL"); + test_expression!("regexp_match('aaa-0', NULL)", "NULL"); Ok(()) }