From ee14adf0992146c0341dcf193d57c5dbffd1ced7 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 22 Aug 2024 14:14:20 +0800 Subject: [PATCH] Add null tests --- datafusion/functions/src/regex/regexpcount.rs | 13 +++++- datafusion/sqllogictest/test_files/regexp.slt | 40 +++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8236fd673550..cc86c66c4792 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -131,6 +131,11 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { 2..=4 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; + + if values.is_empty() || regex.is_empty() { + return Ok(Arc::new(Int64Array::new_null(0))); + } + let regex_datum: &dyn Datum = if regex.len() != 1 { regex } else { @@ -139,7 +144,9 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { let start_scalar: Scalar<&Int64Array>; let start_array_datum: Option<&dyn Datum> = if arg_len > 2 { let start_array = as_primitive_array::(&args[2])?; - if start_array.len() != 1 { + if start_array.is_empty() { + None + } else if start_array.len() != 1 { Some(start_array as &dyn Datum) } else { start_scalar = Scalar::new(start_array); @@ -152,7 +159,9 @@ pub fn regexp_count(args: &[ArrayRef]) -> Result { let flags_scalar: Scalar<&GenericStringArray>; let flags_array_datum: Option<&dyn Datum> = if arg_len > 3 { let flags_array = as_generic_string_array::(&args[3])?; - if flags_array.len() != 1 { + if flags_array.is_empty() { + None + } else if flags_array.len() != 1 { Some(flags_array as &dyn Datum) } else { flags_scalar = Scalar::new(flags_array); diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 1d2cc90d9bd0..1c86ea5e6be5 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -614,6 +614,46 @@ SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), 1 1 +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + statement ok drop table t;