Skip to content

Commit

Permalink
Use enum for value (#8)
Browse files Browse the repository at this point in the history
* Use enum for arm IDs

* update

* clippy

* msrv 1.65

* clippy
  • Loading branch information
vbkaisetsu authored Sep 17, 2023
1 parent b54f6ff commit b1b0969
Showing 4 changed files with 103 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ jobs:
strategy:
matrix:
rust:
- 1.60.0 # MSRV
- 1.65.0 # MSRV
- stable
- nightly
steps:
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
name = "trie-match"
version = "0.1.1"
edition = "2021"
rust-version = "1.60"
rust-version = "1.65"
authors = [
"Koichi Akabe <vbkaisetsu@gmail.com>",
]
110 changes: 72 additions & 38 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -41,9 +41,10 @@ extern crate proc_macro;
use std::collections::HashMap;

use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{format_ident, quote};
use syn::{
parse_macro_input, spanned::Spanned, Arm, Error, ExprLit, ExprMatch, Lit, Pat, PatOr, PatWild,
parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr,
PatWild,
};

use crate::trie::Sparse;
@@ -107,13 +108,16 @@ fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<String>>, Error> {
Ok(pats)
}

fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
let ExprMatch {
attrs, expr, arms, ..
} = input;
let mut map = HashMap::new();
/// Information collected from the arms of the input `match` expression.
///
/// Produced by `parse_match_arms` and destructured by `trie_match_inner`.
struct MatchInfo {
/// Body expression of each arm; one entry is pushed per processed arm.
bodies: Vec<Expr>,
/// Maps every string pattern to the index of the arm it was found in.
pattern_map: HashMap<String, usize>,
/// Index of the arm that holds the wildcard pattern. It is always set:
/// parsing fails with "non-exhaustive patterns" when no wildcard exists.
wildcard_idx: usize,
}

fn parse_match_arms(arms: Vec<Arm>) -> Result<MatchInfo, Error> {
let mut pattern_map = HashMap::new();
let mut wildcard_idx = None;
let mut built_arms = vec![];
let mut bodies = vec![];
for (
i,
Arm {
@@ -132,62 +136,92 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
return Err(Error::new(if_token.span(), "match guard not supported"));
}
let pat_strs = retrieve_match_patterns(&pat)?;
let i = u32::try_from(i).unwrap();
for pat_str in pat_strs {
if let Some(pat_str) = pat_str {
if map.contains_key(&pat_str) {
if pattern_map.contains_key(&pat_str) {
return Err(Error::new(pat.span(), "unreachable pattern"));
}
map.insert(pat_str, i);
pattern_map.insert(pat_str, i);
} else {
if wildcard_idx.is_some() {
return Err(Error::new(pat.span(), "unreachable pattern"));
}
wildcard_idx.replace(i);
}
}
built_arms.push(quote! { #i => #body });
bodies.push(*body);
}
if wildcard_idx.is_none() {
let Some(wildcard_idx) = wildcard_idx else {
return Err(Error::new(
Span::call_site(),
"non-exhaustive patterns: `_` not covered",
));
}
let wildcard_idx = wildcard_idx.unwrap();
};
Ok(MatchInfo {
bodies,
pattern_map,
wildcard_idx,
})
}

fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
let ExprMatch {
attrs, expr, arms, ..
} = input;

let MatchInfo {
bodies,
pattern_map,
wildcard_idx,
} = parse_match_arms(arms)?;

let mut trie = Sparse::new();
for (k, v) in map {
for (k, v) in pattern_map {
if v == wildcard_idx {
continue;
}
trie.add(k, v);
}
let (bases, out_checks) = trie.build_double_array_trie(wildcard_idx);
let (bases, checks, outs) = trie.build_double_array_trie(wildcard_idx);

let base = bases.iter();
let out_check = out_checks.iter();
let arm = built_arms.iter();
let out_check = outs.iter().zip(checks).map(|(out, check)| {
let out = format_ident!("V{out}");
quote! { (__TrieMatchValue::#out, #check) }
});
let arm = bodies.iter().enumerate().map(|(i, body)| {
let i = format_ident!("V{i}");
quote! { __TrieMatchValue::#i => #body }
});
let attr = attrs.iter();
let enumvalue = (0..bodies.len()).map(|i| format_ident!("V{i}"));
let wildcard_ident = format_ident!("V{wildcard_idx}");
Ok(quote! {
#( #attr )*
match (|query: &str| unsafe {
let bases: &'static [i32] = &[ #( #base, )* ];
let out_checks: &'static [u32] = &[ #( #out_check, )* ];
let mut pos = 0;
for &b in query.as_bytes() {
let base = *bases.get_unchecked(pos);
pos = base.wrapping_add(i32::from(b)) as usize;
if let Some(out_check) = out_checks.get(pos) {
if out_check & 0xff == u32::from(b) {
continue;
{
#[derive(Clone, Copy)]
enum __TrieMatchValue {
#( #enumvalue, )*
}
#( #attr )*
match (|query: &str| unsafe {
let bases: &'static [i32] = &[ #( #base, )* ];
let out_checks: &'static [(__TrieMatchValue, u8)] = &[ #( #out_check, )* ];
let mut pos = 0;
let mut base = bases[0];
for &b in query.as_bytes() {
pos = base.wrapping_add(i32::from(b)) as usize;
if let Some((_, check)) = out_checks.get(pos) {
if *check == b {
base = *bases.get_unchecked(pos);
continue;
}
}
return __TrieMatchValue::#wildcard_ident;
}
return #wildcard_idx;
out_checks.get_unchecked(pos).0
})( #expr ) {
#( #arm, )*
}
*out_checks.get_unchecked(pos) >> 8
})( #expr ) {
#( #arm, )*
// Safety: A query always matches one of the patterns because
// all patterns in the input match's AST are expanded. (Even
// mismatched cases are always captured by wildcard_idx.)
_ => unsafe { std::hint::unreachable_unchecked() },
}
})
}
45 changes: 29 additions & 16 deletions src/trie.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,34 @@
use std::collections::{BTreeMap, HashSet};

#[derive(Default, Debug)]
struct State {
#[derive(Debug)]
struct State<T> {
edges: BTreeMap<u8, usize>,
value: Option<u32>,
value: Option<T>,
}

/// Hand-written `Default` for `State<T>`; this impl places no bound on `T`.
impl<T> Default for State<T> {
    /// Returns a state with no outgoing edges and no associated value.
    fn default() -> Self {
        let edges = BTreeMap::new();
        Self { edges, value: None }
    }
}

/// Sparse trie.
pub struct Sparse {
states: Vec<State>,
pub struct Sparse<T> {
states: Vec<State<T>>,
}

impl Sparse {
impl<T> Sparse<T> {
/// Creates a trie holding a single default state at index 0, which is
/// the state `add` starts walking from.
pub fn new() -> Self {
    let initial = State::default();
    Self {
        states: vec![initial],
    }
}

/// Adds a new pattern.
pub fn add(&mut self, pattern: impl AsRef<[u8]>, value: u32) {
pub fn add(&mut self, pattern: impl AsRef<[u8]>, value: T) {
let pattern = pattern.as_ref();
let mut state_idx = 0;
for &b in pattern {
@@ -35,7 +44,7 @@ impl Sparse {
fn find_base(
search_start: i32,
is_used: &[bool],
state: &State,
state: &State<T>,
used_bases: &HashSet<i32>,
) -> Option<i32> {
let (&k, _) = state.edges.iter().next()?;
@@ -68,19 +77,22 @@ impl Sparse {
///
/// # Returns
///
/// The first item is a `base` array, and the second item is `out_check` array.
pub fn build_double_array_trie(&self, wildcard_idx: u32) -> (Vec<i32>, Vec<u32>) {
/// A tuple of a base array, a check array, and a value array.
pub fn build_double_array_trie(&self, wildcard_value: T) -> (Vec<i32>, Vec<u8>, Vec<T>)
where
T: Copy,
{
let mut bases = vec![i32::MAX];
let mut out_checks = vec![wildcard_idx << 8];
let mut checks = vec![0];
let mut values = vec![wildcard_value];
let mut is_used = vec![true];
let mut stack = vec![(0, 0)];
let mut used_bases = HashSet::new();
let mut search_start = 0;
while let Some((state_id, da_pos)) = stack.pop() {
let state = &self.states[state_id];
if let Some(val) = state.value {
let check = out_checks[da_pos] & 0xff;
out_checks[da_pos] = val << 8 | check;
values[da_pos] = val;
}
for &u in &is_used[usize::try_from(search_start).unwrap()..] {
if !u {
@@ -95,15 +107,16 @@ impl Sparse {
let child_da_pos = usize::try_from(base + i32::from(k)).unwrap();
if child_da_pos >= bases.len() {
bases.resize(child_da_pos + 1, i32::MAX);
out_checks.resize(child_da_pos + 1, wildcard_idx << 8);
checks.resize(child_da_pos + 1, 0);
values.resize(child_da_pos + 1, wildcard_value);
is_used.resize(child_da_pos + 1, false);
}
out_checks[child_da_pos] = wildcard_idx << 8 | u32::from(k);
checks[child_da_pos] = k;
is_used[child_da_pos] = true;
stack.push((v, child_da_pos));
}
}
}
(bases, out_checks)
(bases, checks, values)
}
}

0 comments on commit b1b0969

Please sign in to comment.