Skip to content

Commit

Permalink
Support byte string patterns (#12)
Browse files Browse the repository at this point in the history
* Support `&[u8]`

* clippy

* Support reference

* clippy

* fix README

* update doc

* Update lib.rs
  • Loading branch information
vbkaisetsu authored Sep 17, 2023
1 parent 4537eb7 commit 5b4e52f
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ to achieve efficient state-to-state traversal, and the time complexity becomes

The followings are different from the normal `match` expression:

* Only supports string comparison.
* Only supports strings, byte strings, and u8 slices as patterns.
* The wildcard is evaluated last. (The normal `match` expression does not
match patterns after the wildcard.)
* Pattern bindings are unavailable.
Expand Down
167 changes: 112 additions & 55 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
//!
//! The followings are different from the normal `match` expression:
//!
//! * Only supports string comparison.
//! * Only supports strings, byte strings, and u8 slices as patterns.
//! * The wildcard is evaluated last. (The normal `match` expression does not
//! match patterns after the wildcard.)
//! * Pattern bindings are unavailable.
Expand All @@ -44,73 +44,135 @@ use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::{
parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr,
PatWild,
PatReference, PatSlice, PatWild,
};

static ERROR_UNEXPECTED_PATTERN: &str =
"`trie_match` only supports string literals, byte string literals, and u8 slices as patterns";
static ERROR_ATTRIBUTE_NOT_SUPPORTED: &str = "attribute not supported here";
static ERROR_GUARD_NOT_SUPPORTED: &str = "match guard not supported";
static ERROR_UNREACHABLE_PATTERN: &str = "unreachable pattern";
static ERROR_PATTERN_NOT_COVERED: &str = "non-exhaustive patterns: `_` not covered";
static ERROR_EXPECTED_U8_LITERAL: &str = "expected `u8` integer literal";

use crate::trie::Sparse;

/// Converts a literal pattern into a byte sequence.
fn convert_literal_pattern(pat: &ExprLit) -> Result<Option<Vec<u8>>, Error> {
let ExprLit { attrs, lit } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
match lit {
Lit::Str(s) => Ok(Some(s.value().into())),
Lit::ByteStr(s) => Ok(Some(s.value())),
_ => Err(Error::new(lit.span(), ERROR_UNEXPECTED_PATTERN)),
}
}

/// Converts a slice pattern into a byte sequence.
fn convert_slice_pattern(pat: &PatSlice) -> Result<Option<Vec<u8>>, Error> {
let PatSlice { attrs, elems, .. } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
let mut result = vec![];
for elem in elems {
match elem {
Pat::Lit(ExprLit { attrs, lit }) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
match lit {
Lit::Int(i) => {
let int_type = i.suffix();
if int_type != "u8" && !int_type.is_empty() {
return Err(Error::new(i.span(), ERROR_EXPECTED_U8_LITERAL));
}
result.push(i.base10_parse::<u8>()?);
}
Lit::Byte(b) => {
result.push(b.value());
}
_ => {
return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
}
}
}
_ => {
return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
}
}
}
Ok(Some(result))
}

/// Checks a wildcard pattern and returns `None`.
///
/// The reason the type is `Result<Option<Vec<u8>>, Error>` instead of `Result<(), Error>` is for
/// consistency with other functions.
fn convert_wildcard_pattern(pat: &PatWild) -> Result<Option<Vec<u8>>, Error> {
let PatWild { attrs, .. } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
Ok(None)
}

/// Converts a reference pattern (e.g. `&[0, 1, ...]`) into a byte sequence.
fn convert_reference_pattern(pat: &PatReference) -> Result<Option<Vec<u8>>, Error> {
let PatReference { attrs, pat, .. } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
match &**pat {
Pat::Lit(pat) => convert_literal_pattern(pat),
Pat::Slice(pat) => convert_slice_pattern(pat),
Pat::Reference(pat) => convert_reference_pattern(pat),
_ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)),
}
}

/// Retrieves pattern strings from the given token.
///
/// None indicates a wild card pattern (`_`).
fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<String>>, Error> {
fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<Vec<u8>>>, Error> {
let mut pats = vec![];
match pat {
Pat::Lit(ExprLit {
lit: Lit::Str(s),
attrs,
}) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
}
pats.push(Some(s.value()));
}
Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?),
Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?),
Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?),
Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?),
Pat::Or(PatOr {
attrs,
leading_vert: None,
cases,
}) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
for pat in cases {
match pat {
Pat::Lit(ExprLit {
lit: Lit::Str(s),
attrs,
}) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
}
pats.push(Some(s.value()));
}
Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?),
Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?),
Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?),
Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?),
_ => {
return Err(Error::new(
pat.span(),
"`trie_match` only supports string literal patterns",
));
return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN));
}
}
}
}
Pat::Wild(PatWild { attrs, .. }) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
}
pats.push(None);
}
_ => {
return Err(Error::new(
pat.span(),
"`trie_match` only supports string literal patterns",
));
return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN));
}
}
Ok(pats)
}

struct MatchInfo {
bodies: Vec<Expr>,
pattern_map: HashMap<String, usize>,
pattern_map: HashMap<Vec<u8>, usize>,
wildcard_idx: usize,
}

Expand All @@ -130,32 +192,29 @@ fn parse_match_arms(arms: Vec<Arm>) -> Result<MatchInfo, Error> {
) in arms.into_iter().enumerate()
{
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
if let Some((if_token, _)) = guard {
return Err(Error::new(if_token.span(), "match guard not supported"));
return Err(Error::new(if_token.span(), ERROR_GUARD_NOT_SUPPORTED));
}
let pat_strs = retrieve_match_patterns(&pat)?;
for pat_str in pat_strs {
if let Some(pat_str) = pat_str {
if pattern_map.contains_key(&pat_str) {
return Err(Error::new(pat.span(), "unreachable pattern"));
let pat_bytes_set = retrieve_match_patterns(&pat)?;
for pat_bytes in pat_bytes_set {
if let Some(pat_bytes) = pat_bytes {
if pattern_map.contains_key(&pat_bytes) {
return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
}
pattern_map.insert(pat_str, i);
pattern_map.insert(pat_bytes, i);
} else {
if wildcard_idx.is_some() {
return Err(Error::new(pat.span(), "unreachable pattern"));
return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
}
wildcard_idx.replace(i);
}
}
bodies.push(*body);
}
let Some(wildcard_idx) = wildcard_idx else {
return Err(Error::new(
Span::call_site(),
"non-exhaustive patterns: `_` not covered",
));
return Err(Error::new(Span::call_site(), ERROR_PATTERN_NOT_COVERED));
};
Ok(MatchInfo {
bodies,
Expand All @@ -168,13 +227,11 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
let ExprMatch {
attrs, expr, arms, ..
} = input;

let MatchInfo {
bodies,
pattern_map,
wildcard_idx,
} = parse_match_arms(arms)?;

let mut trie = Sparse::new();
for (k, v) in pattern_map {
if v == wildcard_idx {
Expand Down Expand Up @@ -203,12 +260,12 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
#( #enumvalue, )*
}
#( #attr )*
match (|query: &str| unsafe {
match (|query: &[u8]| unsafe {
let bases: &'static [i32] = &[ #( #base, )* ];
let out_checks: &'static [(__TrieMatchValue, u8)] = &[ #( #out_check, )* ];
let mut pos = 0;
let mut base = bases[0];
for &b in query.as_bytes() {
for &b in query {
pos = base.wrapping_add(i32::from(b)) as usize;
if let Some((_, check)) = out_checks.get(pos) {
if *check == b {
Expand All @@ -219,7 +276,7 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
return __TrieMatchValue::#wildcard_ident;
}
out_checks.get_unchecked(pos).0
})( #expr ) {
})( ::std::convert::AsRef::<[u8]>::as_ref( #expr ) ) {
#( #arm, )*
}
}
Expand Down
56 changes: 56 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,59 @@ fn test_invalid_root_check_of_zero() {
};
assert_eq!(f("\u{0}\u{1}"), 0);
}

#[test]
fn test_bytes_literal() {
let f = |text: &[u8]| {
trie_match! {
match text {
b"abc" => 0,
_ => 1,
}
}
};
assert_eq!(f(b"abc"), 0);
assert_eq!(f(b"ab"), 1);
}

#[test]
fn test_slice_byte_literal() {
let f = |text: &[u8]| {
trie_match! {
match text {
[b'a', b'b', b'c'] => 0,
_ => 1,
}
}
};
assert_eq!(f(b"abc"), 0);
assert_eq!(f(b"ab"), 1);
}

#[test]
fn test_slice_numbers() {
let f = |text: &[u8]| {
trie_match! {
match text {
[0, 1, 2] => 0,
_ => 1,
}
}
};
assert_eq!(f(&[0, 1, 2]), 0);
assert_eq!(f(&[0, 1]), 1);
}

#[test]
fn test_slice_ref_numbers() {
let f = |text: &[u8]| {
trie_match! {
match text {
&[0, 1, 2] => 0,
_ => 1,
}
}
};
assert_eq!(f(&[0, 1, 2]), 0);
assert_eq!(f(&[0, 1]), 1);
}

0 comments on commit 5b4e52f

Please sign in to comment.