Skip to content

Commit

Permalink
Embed wildcard index to the double-array (#6)
Browse files Browse the repository at this point in the history
* Embed wildcard index to the double-array

* clippy
  • Loading branch information
vbkaisetsu authored Sep 15, 2023
1 parent 647f55e commit b54f6ff
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
17 changes: 9 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ use syn::{

use crate::trie::Sparse;

/// Retrieves pattern strings from the given token.
///
/// `None` indicates a wildcard pattern (`_`).
fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<String>>, Error> {
let mut pats = vec![];
match pat {
Expand Down Expand Up @@ -156,7 +159,7 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
for (k, v) in map {
trie.add(k, v);
}
let (bases, out_checks) = trie.build_double_array_trie();
let (bases, out_checks) = trie.build_double_array_trie(wildcard_idx);

let base = bases.iter();
let out_check = out_checks.iter();
Expand All @@ -178,15 +181,13 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
}
return #wildcard_idx;
}
let out = *out_checks.get_unchecked(pos) >> 8;
if out != 0xffffff {
out
} else {
#wildcard_idx
}
*out_checks.get_unchecked(pos) >> 8
})( #expr ) {
#( #arm, )*
_ => unreachable!(),
// Safety: A query always matches one of the patterns because
// all patterns in the input match's AST are expanded. (Even
// mismatched cases are always captured by wildcard_idx.)
_ => unsafe { std::hint::unreachable_unchecked() },
}
})
}
Expand Down
23 changes: 18 additions & 5 deletions src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ struct State {
value: Option<u32>,
}

/// Sparse trie.
///
/// Patterns are inserted with `add` and the structure is later converted into a
/// compact double-array by `build_double_array_trie`.
pub struct Sparse {
// State table indexed by state id. Both `add` and `build_double_array_trie`
// begin at index 0, so that entry serves as the starting state.
states: Vec<State>,
}
Expand All @@ -17,6 +18,7 @@ impl Sparse {
}
}

/// Adds a new pattern.
pub fn add(&mut self, pattern: impl AsRef<[u8]>, value: u32) {
let pattern = pattern.as_ref();
let mut state_idx = 0;
Expand Down Expand Up @@ -57,17 +59,28 @@ impl Sparse {
Some(base_cand)
}

pub fn build_double_array_trie(&self) -> (Vec<i32>, Vec<u32>) {
/// Builds a compact double-array.
///
/// # Arguments
///
/// * `wildcard_idx` - The wildcard index that is used for invalid states. This value is returned
/// if the query matches no pattern.
///
/// # Returns
///
/// The first item is the `base` array, and the second item is the `out_check` array.
pub fn build_double_array_trie(&self, wildcard_idx: u32) -> (Vec<i32>, Vec<u32>) {
let mut bases = vec![i32::MAX];
let mut out_checks = vec![u32::MAX];
let mut out_checks = vec![wildcard_idx << 8];
let mut is_used = vec![true];
let mut stack = vec![(0, 0)];
let mut used_bases = HashSet::new();
let mut search_start = 0;
while let Some((state_id, da_pos)) = stack.pop() {
let state = &self.states[state_id];
if let Some(val) = state.value {
out_checks[da_pos] &= val << 8 | 0xff;
let check = out_checks[da_pos] & 0xff;
out_checks[da_pos] = val << 8 | check;
}
for &u in &is_used[usize::try_from(search_start).unwrap()..] {
if !u {
Expand All @@ -82,10 +95,10 @@ impl Sparse {
let child_da_pos = usize::try_from(base + i32::from(k)).unwrap();
if child_da_pos >= bases.len() {
bases.resize(child_da_pos + 1, i32::MAX);
out_checks.resize(child_da_pos + 1, u32::MAX);
out_checks.resize(child_da_pos + 1, wildcard_idx << 8);
is_used.resize(child_da_pos + 1, false);
}
out_checks[child_da_pos] = u32::MAX << 8 | u32::from(k);
out_checks[child_da_pos] = wildcard_idx << 8 | u32::from(k);
is_used[child_da_pos] = true;
stack.push((v, child_da_pos));
}
Expand Down

0 comments on commit b54f6ff

Please sign in to comment.