diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index d0a7c73ae3d5..dcecdb6743ff 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -49,6 +49,7 @@ half = { version = "2.0", default-features = false } hashbrown = { version = "0.12", default-features = false } csv_crate = { version = "1.1", default-features = false, optional = true, package="csv" } regex = { version = "1.5.6", default-features = false, features = ["std", "unicode"] } +regex-syntax = { version = "0.6.27", default-features = false, features = ["unicode"] } lazy_static = { version = "1.4", default-features = false } packed_simd = { version = "0.3", default-features = false, optional = true, package = "packed_simd_2" } chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 7733ce67a76e..e4187ef87155 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -35,7 +35,7 @@ use crate::datatypes::{ }; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use regex::{escape, Regex}; +use regex::Regex; use std::collections::HashMap; /// Helper function to perform boolean lambda function on values from two array accessors, this @@ -169,7 +169,7 @@ where let re = if let Some(ref regex) = map.get(pat) { regex } else { - let re_pattern = escape(pat).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(pat)?; let re = op(&re_pattern)?; map.insert(pat, re); map.get(pat).unwrap() @@ -248,7 +248,9 @@ pub fn like_utf8_scalar( bit_util::set_bit(bool_slice, i); } } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use starts_with let starts_with = &right[..right.len() - 1]; @@ -266,7 +268,7 @@ pub fn like_utf8_scalar( } } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -296,6 +298,43 @@ pub fn like_utf8_scalar( Ok(BooleanArray::from(data)) } +/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does: +/// +/// 1. Replace like wildcards for regex expressions as the pattern will be evaluated using regex match: `%` => `.*` and `_` => `.` +/// 2. Escape regex meta characters to match them and not be evaluated as regex special chars. For example: `.` => `\\.` +/// 3. Replace escaped like wildcards removing the escape characters to be able to match it as a regex. For example: `\\%` => `%` +fn replace_like_wildcards(pattern: &str) -> Result { + let mut result = String::new(); + let pattern = String::from(pattern); + let mut chars_iter = pattern.chars().peekable(); + while let Some(c) = chars_iter.next() { + if c == '\\' { + let next = chars_iter.peek(); + match next { + Some(next) if is_like_pattern(*next) => { + result.push(*next); + // Skipping the next char as it is already appended + chars_iter.next(); + } + _ => { + result.push('\\'); + result.push('\\'); + } + } + } else if regex_syntax::is_meta_character(c) { + result.push('\\'); + result.push(c); + } else if c == '%' { + result.push_str(".*"); + } else if c == '_' { + result.push('.'); + } else { + result.push(c); + } + } + Ok(result) +} + /// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / /// [`LargeStringArray`]. /// @@ -330,7 +369,9 @@ pub fn nlike_utf8_scalar( for i in 0..left.len() { result.append(left.value(i) != right); } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use ends_with for i in 0..left.len() { @@ -342,7 +383,7 @@ pub fn nlike_utf8_scalar( result.append(!left.value(i).ends_with(&right[1..])); } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -403,7 +444,9 @@ pub fn ilike_utf8_scalar( for i in 0..left.len() { result.append(left.value(i) == right); } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use ends_with for i in 0..left.len() { @@ -423,7 +466,7 @@ pub fn ilike_utf8_scalar( ); } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from ILIKE pattern: {}", @@ -484,7 +527,9 @@ pub fn nilike_utf8_scalar( for i in 0..left.len() { result.append(left.value(i) != right); } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use ends_with for i in 0..left.len() { @@ -506,7 +551,7 @@ pub fn nilike_utf8_scalar( ); } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from ILIKE pattern: {}", @@ -3740,6 +3785,50 @@ mod tests { vec![false, true, false, false] ); + test_utf8_scalar!( + test_utf8_scalar_like_escape, + vec!["a%", "a\\x"], + "a\\%", + like_utf8_scalar, + vec![true, false] + ); + + test_utf8!( + test_utf8_scalar_ilike_regex, + vec!["%%%"], + vec![r#"\%_\%"#], + ilike_utf8, + vec![true] + ); + + #[test] + fn test_replace_like_wildcards() { + let a_eq = "_%"; + let expected = "..*"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_leave_like_meta_chars() { + let a_eq = "\\%\\_"; + let expected = "%_"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_with_multiple_escape_chars() { + let a_eq = "\\\\%"; + let expected = "\\\\%"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_escape_regex_meta_char() { + let a_eq = "."; + let expected = "\\."; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + test_utf8!( test_utf8_array_eq, vec!["arrow", "arrow", "arrow", "arrow"],