Skip to content

Commit

Permalink
fix: add tests
Browse files: browse the repository at this point in the history
  • Loading branch information
ChieloNewctle committed May 9, 2024
1 parent 3f2439e commit 568be06
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
10 changes: 9 additions & 1 deletion python/mtc_token_healing/mtc_token_healing.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from typing import Sequence

# Token ids are plain Python integers.
TokenId = int

# The classes below are declared without any members in this stub.
class BestChoice: ...
class CountInfo: ...
class InferRequest: ...
class InferResponse: ...
class Prediction: ...
# NOTE(review): a fuller declaration of the same class follows further down;
# this bare one looks like the pre-change line of the diff -- confirm that
# only one of the two exists in the real file.
class VocabPrefixAutomaton: ...

class VocabPrefixAutomaton:
    """Typing stub for the automaton built over a vocabulary of token strings."""

    def __init__(self, vocab: Sequence[str]) -> None:
        """Build the automaton from ``vocab``.

        Args:
            vocab: The token strings. The accompanying test indexes this
                sequence with the values returned by ``get_order()``, so a
                token is presumably identified by its position here.
        """
        ...

    def get_order(self) -> Sequence[int]:
        """Return the vocabulary indices in the automaton's order.

        In the accompanying test, walking ``vocab`` through these indices
        yields the tokens in strictly ascending string order.
        """
        ...

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary.

        Exposed as a read-only attribute (``automaton.vocab_size``), not as a
        method call.
        """
        ...

# Declared without any members in this stub.
class ReorderedTokenId: ...
class SearchTree: ...
15 changes: 15 additions & 0 deletions python/tests/test_vocab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from mtc_token_healing import VocabPrefixAutomaton


def test_vocab_simple():
    """Check the size and token ordering reported for a small vocabulary."""
    vocab = ["bcd", "abc", "cc", "hello", "world", " ", "yes", "no", "."]
    # Expected result of get_order(): indices into ``vocab`` arranged so that
    # the tokens appear in strictly ascending string order.
    order = [5, 8, 1, 0, 2, 3, 7, 4, 6]

    # Validate the fixture itself first, so a typo in ``order`` is reported
    # as a broken fixture rather than as an automaton failure.
    assert len(vocab) == len(order)
    assert all(vocab[a] < vocab[b] for a, b in zip(order, order[1:]))

    automaton = VocabPrefixAutomaton(vocab)

    assert automaton.vocab_size == len(vocab)
    # get_order() is typed as Sequence[int]; normalise to a list so the
    # comparison does not depend on the concrete container type returned.
    assert list(automaton.get_order()) == order
2 changes: 1 addition & 1 deletion src/vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ mod _pyo3 {
Self(Arc::new(VocabPrefixAutomaton::new(vocab)))
}

// Python-facing accessor named `vocab_size`; returns the length of `self.vocab`.
// NOTE(review): the two attribute lines below look like the before/after of
// a diff hunk (`#[pyo3(name = ...)]` replaced by `#[getter(...)]`) -- confirm
// that only one of them is present in the real file.
#[pyo3(name = "vocab_size")]
#[getter("vocab_size")]
fn vocab_size_py(&self) -> usize {
self.vocab.len()
}
Expand Down

0 comments on commit 568be06

Please sign in to comment.