From 5b4e52feeaef59905093159c71d7829c9f929840 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 18 Sep 2023 07:10:19 +0900 Subject: [PATCH] Support byte string patterns (#12) * Support `&[u8]` * clippy * Support reference * clippy * fix README * update doc * Update lib.rs --- README.md | 2 +- src/lib.rs | 167 +++++++++++++++++++++++++++++++++---------------- tests/tests.rs | 56 +++++++++++++++++ 3 files changed, 169 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index c7c1b47..d9a8ea4 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ to achieve efficient state-to-state traversal, and the time complexity becomes The followings are different from the normal `match` expression: -* Only supports string comparison. +* Only supports strings, byte strings, and u8 slices as patterns. * The wildcard is evaluated last. (The normal `match` expression does not match patterns after the wildcard.) * Pattern bindings are unavailable. diff --git a/src/lib.rs b/src/lib.rs index 540df57..2980ef1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ //! //! The followings are different from the normal `match` expression: //! -//! * Only supports string comparison. +//! * Only supports strings, byte strings, and u8 slices as patterns. //! * The wildcard is evaluated last. (The normal `match` expression does not //! match patterns after the wildcard.) //! * Pattern bindings are unavailable. 
@@ -44,65 +44,127 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr, - PatWild, + PatReference, PatSlice, PatWild, }; +static ERROR_UNEXPECTED_PATTERN: &str = + "`trie_match` only supports string literals, byte string literals, and u8 slices as patterns"; +static ERROR_ATTRIBUTE_NOT_SUPPORTED: &str = "attribute not supported here"; +static ERROR_GUARD_NOT_SUPPORTED: &str = "match guard not supported"; +static ERROR_UNREACHABLE_PATTERN: &str = "unreachable pattern"; +static ERROR_PATTERN_NOT_COVERED: &str = "non-exhaustive patterns: `_` not covered"; +static ERROR_EXPECTED_U8_LITERAL: &str = "expected `u8` integer literal"; + use crate::trie::Sparse; +/// Converts a literal pattern into a byte sequence. +fn convert_literal_pattern(pat: &ExprLit) -> Result<Option<Vec<u8>>, Error> { + let ExprLit { attrs, lit } = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + match lit { + Lit::Str(s) => Ok(Some(s.value().into())), + Lit::ByteStr(s) => Ok(Some(s.value())), + _ => Err(Error::new(lit.span(), ERROR_UNEXPECTED_PATTERN)), + } +} + +/// Converts a slice pattern into a byte sequence. +fn convert_slice_pattern(pat: &PatSlice) -> Result<Option<Vec<u8>>, Error> { + let PatSlice { attrs, elems, .. 
} = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + let mut result = vec![]; + for elem in elems { + match elem { + Pat::Lit(ExprLit { attrs, lit }) => { + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + match lit { + Lit::Int(i) => { + let int_type = i.suffix(); + if int_type != "u8" && !int_type.is_empty() { + return Err(Error::new(i.span(), ERROR_EXPECTED_U8_LITERAL)); + } + result.push(i.base10_parse::<u8>()?); + } + Lit::Byte(b) => { + result.push(b.value()); + } + _ => { + return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL)); + } + } + } + _ => { + return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL)); + } + } + } + Ok(Some(result)) +} + +/// Checks a wildcard pattern and returns `None`. +/// +/// The reason the type is `Result<Option<Vec<u8>>, Error>` instead of `Result<(), Error>` is for +/// consistency with other functions. +fn convert_wildcard_pattern(pat: &PatWild) -> Result<Option<Vec<u8>>, Error> { + let PatWild { attrs, .. } = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + Ok(None) +} + +/// Converts a reference pattern (e.g. `&[0, 1, ...]`) into a byte sequence. +fn convert_reference_pattern(pat: &PatReference) -> Result<Option<Vec<u8>>, Error> { + let PatReference { attrs, pat, .. } = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + match &**pat { + Pat::Lit(pat) => convert_literal_pattern(pat), + Pat::Slice(pat) => convert_slice_pattern(pat), + Pat::Reference(pat) => convert_reference_pattern(pat), + _ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)), + } +} + /// Retrieves pattern strings from the given token. /// /// None indicates a wild card pattern (`_`). 
-fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<String>>, Error> { +fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<Vec<u8>>>, Error> { let mut pats = vec![]; match pat { - Pat::Lit(ExprLit { - lit: Lit::Str(s), - attrs, - }) => { - if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); - } - pats.push(Some(s.value())); - } + Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?), + Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?), + Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?), + Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?), Pat::Or(PatOr { attrs, leading_vert: None, cases, }) => { if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); } for pat in cases { match pat { - Pat::Lit(ExprLit { - lit: Lit::Str(s), - attrs, - }) => { - if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); - } - pats.push(Some(s.value())); - } + Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?), + Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?), + Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?), + Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?), _ => { - return Err(Error::new( - pat.span(), - "`trie_match` only supports string literal patterns", - )); + return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); } } } } - Pat::Wild(PatWild { attrs, .. 
}) => { - if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); - } - pats.push(None); - } _ => { - return Err(Error::new( - pat.span(), - "`trie_match` only supports string literal patterns", - )); + return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); } } Ok(pats) @@ -110,7 +172,7 @@ fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<String>>, Error> { struct MatchInfo { bodies: Vec<Expr>, - pattern_map: HashMap<String, usize>, + pattern_map: HashMap<Vec<u8>, usize>, wildcard_idx: usize, } @@ -130,21 +192,21 @@ fn parse_match_arms(arms: Vec<Arm>) -> Result<MatchInfo, Error> { ) in arms.into_iter().enumerate() { if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); } if let Some((if_token, _)) = guard { - return Err(Error::new(if_token.span(), "match guard not supported")); + return Err(Error::new(if_token.span(), ERROR_GUARD_NOT_SUPPORTED)); } - let pat_strs = retrieve_match_patterns(&pat)?; - for pat_str in pat_strs { - if let Some(pat_str) = pat_str { - if pattern_map.contains_key(&pat_str) { - return Err(Error::new(pat.span(), "unreachable pattern")); + let pat_bytes_set = retrieve_match_patterns(&pat)?; + for pat_bytes in pat_bytes_set { + if let Some(pat_bytes) = pat_bytes { + if pattern_map.contains_key(&pat_bytes) { + return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN)); } - pattern_map.insert(pat_str, i); + pattern_map.insert(pat_bytes, i); } else { if wildcard_idx.is_some() { - return Err(Error::new(pat.span(), "unreachable pattern")); + return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN)); } wildcard_idx.replace(i); } @@ -152,10 +214,7 @@ fn parse_match_arms(arms: Vec<Arm>) -> Result<MatchInfo, Error> { bodies.push(*body); } let Some(wildcard_idx) = wildcard_idx else { - return Err(Error::new( - Span::call_site(), - "non-exhaustive patterns: `_` not covered", - )); + return Err(Error::new(Span::call_site(), ERROR_PATTERN_NOT_COVERED)); }; 
Ok(MatchInfo { bodies, @@ -168,13 +227,11 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> { let ExprMatch { attrs, expr, arms, .. } = input; - let MatchInfo { bodies, pattern_map, wildcard_idx, } = parse_match_arms(arms)?; - let mut trie = Sparse::new(); for (k, v) in pattern_map { if v == wildcard_idx { @@ -203,12 +260,12 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> { #( #enumvalue, )* } #( #attr )* - match (|query: &str| unsafe { + match (|query: &[u8]| unsafe { let bases: &'static [i32] = &[ #( #base, )* ]; let out_checks: &'static [(__TrieMatchValue, u8)] = &[ #( #out_check, )* ]; let mut pos = 0; let mut base = bases[0]; - for &b in query.as_bytes() { + for &b in query { pos = base.wrapping_add(i32::from(b)) as usize; if let Some((_, check)) = out_checks.get(pos) { if *check == b { @@ -219,7 +276,7 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> { return __TrieMatchValue::#wildcard_ident; } out_checks.get_unchecked(pos).0 - })( #expr ) { + })( ::std::convert::AsRef::<[u8]>::as_ref( #expr ) ) { #( #arm, )* } } diff --git a/tests/tests.rs b/tests/tests.rs index 261ce8c..2b0378c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -144,3 +144,59 @@ fn test_invalid_root_check_of_zero() { }; assert_eq!(f("\u{0}\u{1}"), 0); } + +#[test] +fn test_bytes_literal() { + let f = |text: &[u8]| { + trie_match! { + match text { + b"abc" => 0, + _ => 1, + } + } + }; + assert_eq!(f(b"abc"), 0); + assert_eq!(f(b"ab"), 1); +} + +#[test] +fn test_slice_byte_literal() { + let f = |text: &[u8]| { + trie_match! { + match text { + [b'a', b'b', b'c'] => 0, + _ => 1, + } + } + }; + assert_eq!(f(b"abc"), 0); + assert_eq!(f(b"ab"), 1); +} + +#[test] +fn test_slice_numbers() { + let f = |text: &[u8]| { + trie_match! { + match text { + [0, 1, 2] => 0, + _ => 1, + } + } + }; + assert_eq!(f(&[0, 1, 2]), 0); + assert_eq!(f(&[0, 1]), 1); +} + +#[test] +fn test_slice_ref_numbers() { + let f = |text: &[u8]| { + trie_match! 
{ + match text { + &[0, 1, 2] => 0, + _ => 1, + } + } + }; + assert_eq!(f(&[0, 1, 2]), 0); + assert_eq!(f(&[0, 1]), 1); +}