Commit
fix accepting state at EOS
mmoskal committed Jun 29, 2024
1 parent 0acf01b commit 69788eb
Showing 3 changed files with 27 additions and 11 deletions.
controllers/llguidance_ctrl/run_g.py: 9 changes (8 additions & 1 deletion)
@@ -208,6 +208,13 @@ def character_maker2(lm, id, description, valid_weapons):
 
 grm = "123" + gen(name="numbers", regex=r"\d*233", max_tokens=5)
 
 grm = greedy_grammar(body=lexeme("[0-9]+"),skip_regex=r"\s*") + "x"
+
+grm = "Here: 2 + 2 = " + guidance.json(name="num", schema={"type": "integer"})
+# grm = guidance.json(name="num", schema={"type": "integer"})
+# m = grm.match("123<s>")
+# print(m)
+# assert m["num"] == "123"
+
 # grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)
 
@@ -225,7 +232,7 @@ def character_maker2(lm, id, description, valid_weapons):
 # body = lexeme("[0-9]+")
 # )
 
-max_tokens = 30
+max_tokens = 7
 
 serialized = grm.ll_serialize()
 
controllers/llguidance_ctrl/src/earley/parser.rs: 13 changes (10 additions & 3 deletions)
@@ -207,6 +207,7 @@ pub struct Parser {
     byte_idx: usize,
     options: GenGrammarOptions,
     trie_gen_grammar: Option<CSymIdx>,
+    trie_gen_grammar_accepting: bool,
 }
 
 impl Scratch {
@@ -345,6 +346,7 @@ impl Parser {
             byte_idx: 0,
             options,
             trie_gen_grammar: None,
+            trie_gen_grammar_accepting: false,
             lexer_stack: vec![LexerState {
                 row_idx: 0,
                 lexer_state,
@@ -372,11 +374,15 @@
         Ok(r)
     }
 
-    pub fn compute_bias_after_gen_grammar(&mut self, trie: &TokTrie, symidx: CSymIdx) -> SimpleVob {
+    pub fn compute_bias_after_gen_grammar(
+        &mut self,
+        trie: &TokTrie,
+        symidx: CSymIdx,
+    ) -> (bool, SimpleVob) {
         self.trie_gen_grammar = Some(symidx);
         let r = self.compute_bias(trie, &[]);
         assert!(self.trie_gen_grammar.is_none());
-        r
+        (self.trie_gen_grammar_accepting, r)
     }
 
     pub fn compute_bias(&mut self, trie: &TokTrie, start: &[u8]) -> SimpleVob {
@@ -931,7 +937,8 @@ impl Parser {
 
     fn flush_gen_grammar(&mut self) {
         if let Some(idx) = self.trie_gen_grammar.take() {
-            self.scan_gen_grammar_inner(idx, vec![]);
+            let r = self.scan_gen_grammar_inner(idx, vec![]);
+            self.trie_gen_grammar_accepting = r && self.row_is_accepting();
         }
     }
 
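For orientation, a minimal self-contained sketch (not part of this commit) of the pattern the parser.rs change introduces, with hypothetical stand-in types replacing TokTrie, SimpleVob and CSymIdx: the accepting flag is captured once, at flush time, when the nested gen_grammar scan completes, and is returned alongside the token bias instead of being derived later.

// Hypothetical stand-ins; the real llguidance types (TokTrie, SimpleVob, CSymIdx) differ.
type SimpleVob = Vec<bool>;
type CSymIdx = usize;

struct Parser {
    trie_gen_grammar: Option<CSymIdx>,
    trie_gen_grammar_accepting: bool,
}

impl Parser {
    // Previously this returned only the bias; it now also reports whether the
    // nested grammar ended in an accepting state.
    fn compute_bias_after_gen_grammar(&mut self, symidx: CSymIdx) -> (bool, SimpleVob) {
        self.trie_gen_grammar = Some(symidx);
        let bias = self.compute_bias();
        assert!(self.trie_gen_grammar.is_none());
        (self.trie_gen_grammar_accepting, bias)
    }

    fn compute_bias(&mut self) -> SimpleVob {
        // The real implementation walks the token trie; here we only flush.
        self.flush_gen_grammar();
        vec![false; 8]
    }

    fn flush_gen_grammar(&mut self) {
        if let Some(idx) = self.trie_gen_grammar.take() {
            let scanned = self.scan_gen_grammar_inner(idx);
            // The fix: record the accepting state here, at flush time.
            self.trie_gen_grammar_accepting = scanned && self.row_is_accepting();
        }
    }

    fn scan_gen_grammar_inner(&mut self, _idx: CSymIdx) -> bool {
        true // pretend the scan succeeded
    }

    fn row_is_accepting(&self) -> bool {
        true // pretend the current Earley row is accepting
    }
}

fn main() {
    let mut p = Parser {
        trie_gen_grammar: None,
        trie_gen_grammar_accepting: false,
    };
    let (accepting, bias) = p.compute_bias_after_gen_grammar(0);
    println!("accepting={} bias_len={}", accepting, bias.len());
}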
controllers/llguidance_ctrl/src/tokenparser.rs: 16 changes (9 additions & 7 deletions)
@@ -211,6 +211,8 @@ impl TokenParser {
             None
         };
 
+        infoln!(self, "\n");
+
         let r = self.mid_process_inner(arg);
 
         if self.test_trace {
@@ -244,7 +246,6 @@
 
         self.mid_process_was_accepting = false;
 
-        infoln!(self, "\n");
         let trie = self.token_env.tok_trie();
 
         infoln!(
@@ -439,7 +440,7 @@
         {
             if self.parser_stack.is_empty() {
                 self.mid_process_was_accepting = inner_accepting;
-                infoln!(self, "only eos token allowed, stopping");
+                infoln!(self, "only eos token allowed, stopping; accepting: {}", inner_accepting);
                 return MidProcessResult::stop();
             } else {
                 infoln!(self, "pop_parser; tokens left {}", self.max_tokens_parser);
@@ -460,11 +461,12 @@
         for pentry in self.parser_stack.iter_mut() {
             if pentry.mask.is_none() {
                 assert!(token_prefix.is_empty());
-                let mask = pentry
+                let (is_accepting, mask) = pentry
                     .parser
                     .compute_bias_after_gen_grammar(trie, pentry.symidx);
                 infoln!(self, "bias for upper parser: {}", trie.token_set_dbg(&mask));
                 pentry.mask = Some(mask);
+                pentry.is_accepting = is_accepting;
             }
             let m = pentry.mask.as_ref().unwrap();
             pop_tokens.or_minus(m, &set);
@@ -482,8 +484,9 @@
 
         infoln!(
             self,
-            "bias: (pref: {:?}) {:?} {}",
+            "bias: (pref: {:?}; accpt: {}) {:?} {}",
             String::from_utf8_lossy(&token_prefix),
+            self.mid_process_was_accepting,
             start_time.elapsed(),
             self.token_env.tok_trie().token_set_dbg(&set)
         );
@@ -504,17 +507,16 @@
         let grm = Arc::clone(&self.compiled_grammars[gen_grammar.grammar.0]);
         let max_tokens = self.parser.grammar().sym_data(symidx).props.max_tokens;
         let parser = Parser::new(grm, gen_grammar)?;
-        let mut old_parser = std::mem::replace(&mut self.parser, parser);
+        let old_parser = std::mem::replace(&mut self.parser, parser);
         self.parser.stats = old_parser.stats.clone();
-        let is_accepting = old_parser.is_accepting();
         let mut entry = ParserStackEntry {
             parser: old_parser,
             parser_llm_tokens_offset: self.parser_llm_tokens_offset,
             previous_grm_bytes_len: self.previous_grm_bytes.len(),
             symidx,
             max_tokens_offset: self.max_tokens_total.saturating_sub(self.max_tokens_parser),
             mask: None,
-            is_accepting,
+            is_accepting: false, // computed with mask
         };
         self.max_tokens_parser = std::cmp::min(self.max_tokens_parser, max_tokens);
         self.parser_llm_tokens_offset = self.llm_tokens.len();
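A similarly hedged sketch of the tokenparser.rs side, again with simplified, hypothetical fields and helpers: the stack entry's is_accepting is no longer sampled from the old parser at push time, but filled in together with the mask, and when only the EOS token remains the stop path reports that flag.

// Hypothetical, simplified model; the real ParserStackEntry/TokenParser have more fields.
type SimpleVob = Vec<bool>;

struct ParserStackEntry {
    mask: Option<SimpleVob>,
    // No longer set at push time; computed together with the mask.
    is_accepting: bool,
}

struct TokenParser {
    parser_stack: Vec<ParserStackEntry>,
    mid_process_was_accepting: bool,
}

impl TokenParser {
    fn push_nested_parser(&mut self) {
        // Accepting state is unknown at push time, so it starts out false.
        self.parser_stack.push(ParserStackEntry {
            mask: None,
            is_accepting: false,
        });
    }

    // Stand-in for Parser::compute_bias_after_gen_grammar on the stacked parser.
    fn compute_bias_for_entry(_entry: &ParserStackEntry) -> (bool, SimpleVob) {
        (true, vec![true; 8])
    }

    fn fill_masks(&mut self) {
        for pentry in self.parser_stack.iter_mut() {
            if pentry.mask.is_none() {
                let (is_accepting, mask) = Self::compute_bias_for_entry(pentry);
                pentry.mask = Some(mask);
                pentry.is_accepting = is_accepting;
            }
        }
    }

    fn only_eos_allowed(&mut self, inner_accepting: bool) -> &'static str {
        if self.parser_stack.is_empty() {
            // Report the accepting state together with the stop decision.
            self.mid_process_was_accepting = inner_accepting;
            "stop"
        } else {
            "pop_parser"
        }
    }
}

fn main() {
    let mut tp = TokenParser {
        parser_stack: vec![],
        mid_process_was_accepting: false,
    };
    tp.push_nested_parser();
    tp.fill_masks();
    tp.parser_stack.pop();
    let result = tp.only_eos_allowed(true);
    println!("result={} accepting={}", result, tp.mid_process_was_accepting);
}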
