better handling of nested grammars
mmoskal committed Jun 28, 2024
1 parent 5a50242 commit 6164157
Showing 5 changed files with 150 additions and 47 deletions.
9 changes: 9 additions & 0 deletions controllers/aici_abi/src/svob.rs
@@ -276,6 +276,15 @@ impl SimpleVob {
        }
    }

    /// self |= other & !minus
    pub fn or_minus(&mut self, other: &SimpleVob, minus: &SimpleVob) {
        assert_eq!(self.size, other.size);
        assert_eq!(self.size, minus.size);
        for ((slf, oth), mn) in self.data.iter_mut().zip(other.data.iter()).zip(minus.data.iter()) {
            *slf |= *oth & !*mn;
        }
    }

    pub fn and(&mut self, other: &SimpleVob) {
        assert_eq!(self.size, other.size);
        for (idx, v) in self.data.iter_mut().zip(other.data.iter()) {
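A minimal standalone sketch (not part of the commit) of what or_minus computes, using plain u32 words instead of SimpleVob's internal storage; the or_minus_words helper is hypothetical and only illustrates the "self |= other & !minus" semantics:

fn or_minus_words(slf: &mut [u32], other: &[u32], minus: &[u32]) {
    // OR into `slf` every bit that is set in `other` but not in `minus`.
    assert_eq!(slf.len(), other.len());
    assert_eq!(slf.len(), minus.len());
    for ((s, o), m) in slf.iter_mut().zip(other).zip(minus) {
        *s |= *o & !*m;
    }
}

fn main() {
    let mut a = [0b0001u32]; // bits already set in "self"
    let b = [0b0111u32];     // bits set in "other"
    let m = [0b0010u32];     // bits to exclude ("minus")
    or_minus_words(&mut a, &b, &m);
    assert_eq!(a[0], 0b0101); // 0b0111 & !0b0010 = 0b0101, OR-ed into 0b0001
}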
13 changes: 13 additions & 0 deletions controllers/llguidance_ctrl/run_g.py
@@ -185,6 +185,19 @@ def character_maker2(lm, id, description, valid_weapons):
)
)

grm = "6 * 7 = " + greedy_grammar(
body = lexeme("[0-9]{1,3}")
) + "\n"
# assert grm.match("6 * 7 = 42\n")

grm = (
    "Dolphin name: "
    + commit_point(
        '"' + byte_range(b"A", b"Z") + one_or_more(byte_range(b"a", b"z")) + '"'
    )
    + ","
)


# g = zero_or_more("a") + "b"
# assert g.match("b")
110 changes: 70 additions & 40 deletions controllers/llguidance_ctrl/src/earley/parser.rs
@@ -205,6 +205,7 @@ pub struct Parser {
    token_idx: usize,
    byte_idx: usize,
    options: GenGrammarOptions,
    trie_gen_grammar: Option<CSymIdx>,
}

impl Scratch {
@@ -342,6 +343,7 @@ impl Parser {
            token_idx: 0,
            byte_idx: 0,
            options,
            trie_gen_grammar: None,
            lexer_stack: vec![LexerState {
                row_idx: 0,
                lexer_state,
@@ -369,6 +371,13 @@ impl Parser {
        Ok(r)
    }

    pub fn compute_bias_after_gen_grammar(&mut self, trie: &TokTrie, symidx: CSymIdx) -> SimpleVob {
        self.trie_gen_grammar = Some(symidx);
        let r = self.compute_bias(trie, &[]);
        assert!(self.trie_gen_grammar.is_none());
        r
    }

    pub fn compute_bias(&mut self, trie: &TokTrie, start: &[u8]) -> SimpleVob {
        let mut set = trie.alloc_token_set();

@@ -740,13 +749,34 @@ impl Parser {
        }
    }

    pub fn is_accepting(&mut self) -> bool {
        self.trie_started();
        let r = self.flush_lexer() && self.row_is_accepting();
        self.trie_finished();
    fn trie_started_inner(&mut self) {
        // debug!("trie_started: rows={} lexer={}", self.num_rows(), self.lexer_stack.len());
        self.assert_definitive();
        self.trie_lexer_stack = self.lexer_stack.len();
        self.scratch.definitive = false;
    }

    fn trie_finished_inner(&mut self) {
        // debug!("trie_finished: rows={} lexer={}", self.num_rows(), self.lexer_stack.len());
        assert!(self.scratch.definitive == false);
        assert!(self.row_infos.len() <= self.num_rows());
        // clean up stack
        self.pop_lexer_states(self.lexer_stack.len() - self.trie_lexer_stack);
        self.scratch.definitive = true;
        self.assert_definitive();
    }

    fn run_speculative<T>(&mut self, f: impl FnOnce(&mut Self) -> T) -> T {
        self.trie_started_inner();
        let r = f(self);
        self.trie_finished_inner();
        r
    }

    pub fn is_accepting(&mut self) -> bool {
        self.run_speculative(|s| s.flush_lexer() && s.row_is_accepting())
    }

    pub fn try_push_byte_definitive(&mut self, byte: Option<u8>) -> bool {
        assert!(self.scratch.definitive);

@@ -783,19 +813,19 @@ impl Parser {
}

pub fn model_variables(&mut self) -> Vec<ModelVariable> {
self.trie_started();
let mut vars = vec![];
if self.flush_lexer() {
for sym_data in self.after_dots_symdata() {
if let Some(ref mv) = sym_data.props.model_variable {
if !vars.contains(mv) {
vars.push(mv.clone());
self.run_speculative(|s| {
let mut vars = vec![];
if s.flush_lexer() {
for sym_data in s.after_dots_symdata() {
if let Some(ref mv) = sym_data.props.model_variable {
if !vars.contains(mv) {
vars.push(mv.clone());
}
}
}
}
}
self.trie_finished();
vars
vars
})
}

fn forced_byte(&mut self) -> Option<u8> {
@@ -806,23 +836,22 @@ impl Parser {

// self.print_row(self.num_rows() - 1);

let mut byte_sym = None;
self.trie_started();
for b in 0..=255 {
if self.try_push_byte(b) {
self.pop_bytes(1);
// debug!(" forced: {:?}", b as char);
if byte_sym.is_some() {
self.trie_finished();
// debug!(" forced multiple");
return None; // more than one option
} else {
byte_sym = Some(b);
self.run_speculative(|s| {
let mut byte_sym = None;
for b in 0..=255 {
if s.try_push_byte(b) {
s.pop_bytes(1);
// debug!(" forced: {:?}", b as char);
if byte_sym.is_some() {
// debug!(" forced multiple");
return None; // more than one option
} else {
byte_sym = Some(b);
}
}
}
}
self.trie_finished();
byte_sym
byte_sym
})
}

fn flush_lexer(&mut self) -> bool {
@@ -873,9 +902,18 @@ impl Parser {
        Some((msg, res_idx.unwrap(), res.unwrap()))
    }

    fn flush_gen_grammar(&mut self) {
        if let Some(idx) = self.trie_gen_grammar.take() {
            self.scan_gen_grammar_inner(idx, vec![]);
        }
    }

    pub fn scan_gen_grammar(&mut self, symidx: CSymIdx, inner_bytes: Vec<u8>) -> bool {
        self.assert_definitive();
        self.scan_gen_grammar_inner(symidx, inner_bytes)
    }

    fn scan_gen_grammar_inner(&mut self, symidx: CSymIdx, inner_bytes: Vec<u8>) -> bool {
        debug!(" scan gen_grammar: {}", self.grammar.sym_name(symidx));

        self.scratch.new_row(self.curr_row().last_item);
@@ -1424,20 +1462,12 @@ impl Recognizer for Parser {
    }

    fn trie_started(&mut self) {
        // debug!("trie_started: rows={} lexer={}", self.num_rows(), self.lexer_stack.len());
        self.assert_definitive();
        self.trie_lexer_stack = self.lexer_stack.len();
        self.scratch.definitive = false;
        self.trie_started_inner();
        self.flush_gen_grammar();
    }

    fn trie_finished(&mut self) {
        // debug!("trie_finished: rows={} lexer={}", self.num_rows(), self.lexer_stack.len());
        assert!(self.scratch.definitive == false);
        assert!(self.row_infos.len() <= self.num_rows());
        // clean up stack
        self.pop_lexer_states(self.lexer_stack.len() - self.trie_lexer_stack);
        self.scratch.definitive = true;
        self.assert_definitive();
        self.trie_finished_inner();
    }

#[inline(always)]
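The parser.rs changes above factor the old trie_started/trie_finished bracketing into run_speculative, so is_accepting, model_variables and forced_byte can probe the parser and then roll back any speculative work; trie_started itself now also calls flush_gen_grammar, which scans a pending gen_grammar symbol recorded by compute_bias_after_gen_grammar before the token-trie walk begins. A simplified standalone sketch of the save/run/restore pattern (toy type, not the actual Parser):

struct Toy {
    lexer_stack: Vec<u8>,
    definitive: bool,
}

impl Toy {
    fn run_speculative<T>(&mut self, f: impl FnOnce(&mut Self) -> T) -> T {
        let depth = self.lexer_stack.len(); // remember the current stack depth
        self.definitive = false;            // everything below is speculative
        let r = f(self);                    // probe: push bytes, test acceptance, etc.
        self.lexer_stack.truncate(depth);   // roll back whatever the probe pushed
        self.definitive = true;             // back to definitive mode
        r
    }

    fn probe(&mut self) -> usize {
        // the caller sees only the result; the stack is unchanged afterwards
        self.run_speculative(|s| {
            s.lexer_stack.push(0);
            s.lexer_stack.len()
        })
    }
}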
63 changes: 57 additions & 6 deletions controllers/llguidance_ctrl/src/tokenparser.rs
@@ -6,6 +6,7 @@ use crate::{
};
use aici_abi::{MidProcessArg, MidProcessResult, TokenId, TokenizerEnv};
use anyhow::Result;
use derivre::SimpleVob;
use serde_json::json;

macro_rules! infoln {
@@ -31,6 +32,8 @@ pub struct TokenParser {
    pub parser: Parser,
    pub log_level: isize,
    pub mid_process_start_time: std::time::Instant,
    // sampling any of these will pop the parser stack:
    pop_tokens: Option<SimpleVob>,
    test_trace: bool,
    parser_stack: Vec<ParserStackEntry>,
    parser_llm_tokens_offset: usize,
@@ -56,6 +59,8 @@ struct ParserStackEntry {
    previous_grm_bytes_len: usize,
    symidx: CSymIdx,
    max_tokens_offset: usize,
    mask: Option<SimpleVob>,
    is_accepting: bool,
}

impl TokenParser {
@@ -79,6 +84,7 @@ impl TokenParser {
            token_env,
            mid_process_start_time,
            mid_process_was_accepting: false,
            pop_tokens: None,
            parser,
            parser_llm_tokens_offset: 0,
            parser_stack: Vec::new(),
@@ -248,6 +254,23 @@ impl TokenParser {
            trie.tokens_dbg(&arg.tokens)
        );

        if arg.tokens.len() == 1 {
            if let Some(pop) = &self.pop_tokens {
                if pop.is_allowed(arg.tokens[0]) {
                    infoln!(self, "pop_tokens hit: {}", trie.token_set_dbg(pop));
                    let pentry = self.parser_stack.last().unwrap();
                    // if the top of parse stack allows this token, we should stop
                    // popping parsers in the next iteration - clear pop_tokens
                    if pentry.mask.as_ref().unwrap().is_allowed(arg.tokens[0]) {
                        self.pop_tokens = None;
                    }
                    self.pop_parser();
                    return self.mid_process_inner(arg);
                }
            }
        }
        self.pop_tokens = None;

        let mut has_eos = false;

        if arg.tokens.contains(&trie.eos_token()) {
@@ -389,7 +412,7 @@ impl TokenParser {
}
}

let inner_done = {
let (inner_done, inner_accepting) = {
let empty_token_prefix = token_prefix.is_empty();
let lexer_bytes = self.parser.has_pending_lexeme_bytes();
let is_accepting = self.parser.is_accepting();
@@ -402,20 +425,20 @@ impl TokenParser {
accept: {is_accepting}; \
empty_token_prefix: {empty_token_prefix}"
);
self.mid_process_was_accepting =
is_accepting && empty_token_prefix && self.parser_stack.is_empty();
inner_done
let inner_accepting = is_accepting && empty_token_prefix;
(inner_done, inner_accepting)
};

let trie = self.token_env.tok_trie();
// self.parser.print_row(self.parser.num_rows() - 1);
let set = self.parser.compute_bias(trie, &token_prefix);
let mut set = self.parser.compute_bias(trie, &token_prefix);

if inner_done
|| self.max_tokens_parser == 0
|| (set.num_set() == 1 && set.is_allowed(trie.eos_token()))
{
if self.parser_stack.is_empty() {
self.mid_process_was_accepting = inner_accepting;
infoln!(self, "only eos token allowed, stopping");
return MidProcessResult::stop();
} else {
@@ -430,6 +453,31 @@ impl TokenParser {
}
}

        if inner_accepting {
            let mut all_accepting = true;
            let mut pop_tokens = trie.alloc_token_set();
            for pentry in self.parser_stack.iter_mut() {
                if pentry.mask.is_none() {
                    assert!(token_prefix.is_empty());
                    let mask = pentry
                        .parser
                        .compute_bias_after_gen_grammar(trie, pentry.symidx);
                    infoln!(self, "bias for upper parser: {}", trie.token_set_dbg(&mask));
                    pentry.mask = Some(mask);
                }
                let m = pentry.mask.as_ref().unwrap();
                pop_tokens.or_minus(m, &set);
                set.or(m);
                if !pentry.is_accepting {
                    all_accepting = false;
                    break;
                }
            }
            infoln!(self, "pop_tokens: {}", trie.token_set_dbg(&pop_tokens));
            self.pop_tokens = Some(pop_tokens);
            self.mid_process_was_accepting = all_accepting;
        }

infoln!(
self,
"bias: (pref: {:?}) {:?} {}",
@@ -454,14 +502,17 @@ impl TokenParser {
        let grm = Arc::clone(&self.compiled_grammars[gen_grammar.grammar.0]);
        let max_tokens = self.parser.grammar().sym_data(symidx).props.max_tokens;
        let parser = Parser::new(grm, gen_grammar)?;
        let old_parser = std::mem::replace(&mut self.parser, parser);
        let mut old_parser = std::mem::replace(&mut self.parser, parser);
        self.parser.stats = old_parser.stats.clone();
        let is_accepting = old_parser.is_accepting();
        let mut entry = ParserStackEntry {
            parser: old_parser,
            parser_llm_tokens_offset: self.parser_llm_tokens_offset,
            previous_grm_bytes_len: self.previous_grm_bytes.len(),
            symidx,
            max_tokens_offset: self.max_tokens_total.saturating_sub(self.max_tokens_parser),
            mask: None,
            is_accepting,
        };
        self.max_tokens_parser = std::cmp::min(self.max_tokens_parser, max_tokens);
        self.parser_llm_tokens_offset = self.llm_tokens.len();
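In tokenparser.rs, when the nested (inner) parser could be finished (inner_accepting), the code walks the parser stack, computes each outer parser's token mask via compute_bias_after_gen_grammar, and splits it against the inner bias: tokens allowed only by an outer parser go into pop_tokens (sampling one of them makes the next mid_process_inner call pop the parser stack), while the union of all masks is what the model may sample. A simplified sketch of that mask arithmetic (u8 bit masks over a hypothetical 8-token vocabulary, standing in for SimpleVob):

fn main() {
    let set: u8 = 0b0000_1101;   // tokens the inner (current) parser allows
    let outer: u8 = 0b0011_0100; // tokens an outer (stacked) parser would allow next
    // pop_tokens.or_minus(m, &set): allowed by the outer parser but not the inner one;
    // sampling one of these tokens pops the parser stack.
    let pop_tokens = outer & !set;
    // set.or(m): the final bias lets the model continue either grammar.
    let set = set | outer;
    assert_eq!(pop_tokens, 0b0011_0000);
    assert_eq!(set, 0b0011_1101);
}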
2 changes: 1 addition & 1 deletion py/guidance
Submodule guidance updated 1 file
+16 −1 tests/unit/test_ll.py
