Add --tag-scores option to show tag scores (#107)
* Add --tag-scores option

* Update README

* fix
vbkaisetsu authored Sep 5, 2023
1 parent 5135720 commit dbbe794
Showing 5 changed files with 159 additions and 33 deletions.
17 changes: 14 additions & 3 deletions README-ja.md
@@ -205,9 +205,9 @@ Vaporetto は2種類のコーパス(フルアノテーションコーパスと
9:交代 -5794
```

### 品詞推定
### タグ予測

Vaporettoは実験的にタグ推定(品詞推定や読み推定)に対応しています。
Vaporettoは実験的にタグ予測(品詞予測や読み予測)に対応しています。

タグを学習するには、以下のように、データセットの各トークンに続けてスラッシュとタグを追加します。

@@ -226,7 +226,7 @@ Vaporettoは実験的にタグ推定(品詞推定や読み推定)に対応

データセットにタグが含まれる場合、 `train` コマンドは自動的にそれらを学習します。

推定時は、デフォルトではタグは推定されないため、必要に応じで `predict` コマンドに `--predict-tags` 引数を指定してください。
予測時は、デフォルトではタグは予測されないため、必要に応じて `predict` コマンドに `--predict-tags` 引数を指定してください。

`--tag-scores` 引数を指定すると、タグ予測の際に計算された各候補のスコアを表示できます。
タグの候補が1つしかない場合は、スコアが0と表示されます。

```
% echo "花が咲く" | cargo run --release -p predict -- --model path/to/bccwj-suw+unidic_pos+pron.model.zst --predict-tags --tag-scores
花/名詞-普通名詞-一般/ハナ が/助詞-格助詞/ガ 咲く/動詞-一般/サク
花 名詞-普通名詞-一般:18613,接尾辞-名詞的-一般:-18613 ハナ:19973,バナ:-20377,カ:-20480,ゲ:-20410
が 助詞-接続助詞:-20408,助詞-格助詞:23543,接続詞:-25332 ガ:0
咲く 動詞-一般:0 サク:0
```

## 各種トークナイザの速度比較

13 changes: 12 additions & 1 deletion README.md
@@ -211,7 +211,7 @@ Now `外国人参政権` is split into correct tokens.
9:交代 -5794
```

### Tagging
### Tag prediction

Vaporetto experimentally supports tagging (e.g., part-of-speech and pronunciation tags).

@@ -234,6 +234,17 @@ If the dataset contains tags, the `train` command automatically trains them.

In prediction, tags are not predicted by default, so you have to specify the `--predict-tags` argument to the `predict` command if necessary.

If you specify the `--tag-scores` argument, the scores of the tag candidates calculated during tag prediction are also displayed.
If there is only one candidate, its score is shown as 0.

```
% echo "花が咲く" | cargo run --release -p predict -- --model path/to/bccwj-suw+unidic_pos+pron.model.zst --predict-tags --tag-scores
花/名詞-普通名詞-一般/ハナ が/助詞-格助詞/ガ 咲く/動詞-一般/サク
花 名詞-普通名詞-一般:18613,接尾辞-名詞的-一般:-18613 ハナ:19973,バナ:-20377,カ:-20480,ゲ:-20410
が 助詞-接続助詞:-20408,助詞-格助詞:23543,接続詞:-25332 ガ:0
咲く 動詞-一般:0 サク:0
```
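
The same scores can also be read programmatically. The following is a minimal sketch of how the new `Predictor::store_tag_scores()` and `Token::tag_candidates()` APIs from this change can be used; it is not the shipped CLI code. It assumes the `tag-prediction` feature is enabled, that the model path is a placeholder, and that `Sentence::fill_tags()` is the public call that performs tag prediction after `Predictor::predict()`.

```rust
use std::fs::File;

use vaporetto::{Model, Predictor, Sentence};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load a zstd-compressed model, as predict/src/main.rs does.
    let mut f = zstd::Decoder::new(File::open("path/to/model.zst")?)?;
    let model = Model::read(&mut f)?;

    // Enable tag prediction and keep the per-candidate scores.
    let mut predictor = Predictor::new(model, true)?;
    predictor.store_tag_scores(true);

    let mut s = Sentence::from_raw("花が咲く")?;
    predictor.predict(&mut s);
    s.fill_tags(); // assumed public entry point for tag prediction

    for token in s.iter_tokens() {
        print!("{}", token.surface());
        // One candidate list per tag slot (e.g. POS, pronunciation).
        for candidates in token.tag_candidates() {
            let joined = candidates
                .iter()
                .map(|(tag, score)| format!("{tag}:{score}"))
                .collect::<Vec<_>>()
                .join(",");
            print!("\t{joined}");
        }
        println!();
    }
    Ok(())
}
```

Each candidate list corresponds to one tag slot in the same order as `Token::tags()`, which is why the CLI output above prints one tab-separated group per slot.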

## Speed Comparison of Various Tokenizers

Vaporetto is 8.7 times faster than KyTea.
35 changes: 33 additions & 2 deletions predict/src/main.rs
@@ -50,10 +50,14 @@ struct Args {
#[arg(long)]
wsconst: Vec<WsConst>,

/// Prints scores.
/// Prints boundary scores.
#[arg(long)]
scores: bool,

/// Prints tag scores.
#[arg(long)]
tag_scores: bool,

/// Do not normalize input strings before prediction.
#[arg(long)]
no_norm: bool,
@@ -70,6 +74,24 @@ fn print_scores(s: &Sentence, mut out: impl Write) -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

fn print_tag_scores(s: &Sentence, mut out: impl Write) -> Result<(), Box<dyn std::error::Error>> {
for token in s.iter_tokens() {
out.write_all(token.surface().as_bytes())?;
for cands in token.tag_candidates() {
out.write_all(b"\t")?;
for (i, (tag, score)) in cands.iter().enumerate() {
if i != 0 {
out.write_all(b",")?;
}
write!(out, "{tag}:{score}")?;
}
}
out.write_all(b"\n")?;
}
out.write_all(b"\n")?;
Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();

@@ -87,7 +109,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("Loading model file...");
let mut f = zstd::Decoder::new(File::open(args.model)?)?;
let model = Model::read(&mut f)?;
let predictor = Predictor::new(model, args.predict_tags)?;
let mut predictor = Predictor::new(model, args.predict_tags)?;
if args.tag_scores {
predictor.store_tag_scores(true);
}

let is_tty = atty::is(atty::Stream::Stdout);

@@ -114,6 +139,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}
out.write_all(b"\n")?;
if args.tag_scores {
print_tag_scores(&s, &mut out)?;
}
if is_tty {
out.flush()?;
}
@@ -143,6 +171,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
} else {
out.write_all(b"\n")?;
}
if args.tag_scores {
print_tag_scores(&s, &mut out)?;
}
if is_tty {
out.flush()?;
}
86 changes: 59 additions & 27 deletions vaporetto/src/predictor.rs
@@ -430,7 +430,10 @@ assert_eq!(
```
"
)]
pub struct Predictor(PredictorData);
pub struct Predictor {
data: PredictorData,
tag_scores: bool,
}

impl Predictor {
/// Creates a new predictor from the model.
@@ -487,30 +490,40 @@ impl Predictor {
#[cfg(feature = "tag-prediction")]
tag_type_ngram_model,
)?;
Ok(Self(PredictorData {
char_scorer,
type_scorer,
bias: model.0.bias,
Ok(Self {
data: PredictorData {
char_scorer,
type_scorer,
bias: model.0.bias,

#[cfg(feature = "tag-prediction")]
tag_predictor,
#[cfg(feature = "tag-prediction")]
n_tags,
},
tag_scores: false,
})
}

#[cfg(feature = "tag-prediction")]
tag_predictor,
#[cfg(feature = "tag-prediction")]
n_tags,
}))
/// Stores tag scores if the given `flag` is `true`.
#[cfg(feature = "tag-prediction")]
pub fn store_tag_scores(&mut self, flag: bool) {
self.tag_scores = flag;
}

/// Predicts word boundaries of the given sentence.
/// If necessary, this function also prepares for predicting tags.
pub fn predict<'a>(&'a self, sentence: &mut Sentence<'_, 'a>) {
sentence.score_padding = WEIGHT_FIXED_LEN - 1;
sentence.boundary_scores.clear();
sentence
.boundary_scores
.resize(sentence.score_padding * 2 + sentence.len() - 1, self.0.bias);
if let Some(scorer) = self.0.char_scorer.as_ref() {
sentence.boundary_scores.resize(
sentence.score_padding * 2 + sentence.len() - 1,
self.data.bias,
);
if let Some(scorer) = self.data.char_scorer.as_ref() {
scorer.add_scores(sentence);
}
if let Some(scorer) = self.0.type_scorer.as_ref() {
if let Some(scorer) = self.data.type_scorer.as_ref() {
scorer.add_scores(sentence);
}
for (b, s) in sentence
@@ -530,19 +543,25 @@ impl Predictor {
#[cfg(feature = "tag-prediction")]
pub(crate) fn predict_tags<'a>(&'a self, sentence: &mut Sentence<'_, 'a>) {
let tag_predictor = self
.0
.data
.tag_predictor
.as_ref()
.expect("this predictor is created with predict_tags = false");

if self.0.n_tags == 0 {
if self.data.n_tags == 0 {
return;
}
let mut scores = vec![];
let mut range_start = Some(0);
sentence.n_tags = self.0.n_tags;
sentence.n_tags = self.data.n_tags;
sentence.tags.clear();
sentence.tags.resize(sentence.len() * self.0.n_tags, None);
sentence
.tags
.resize(sentence.len() * self.data.n_tags, None);
sentence.tag_scores.clear();
if self.tag_scores {
sentence.tag_scores.resize(sentence.len(), None);
}
for (i, &b) in sentence.boundaries.iter().enumerate() {
if b == CharacterBoundary::Unknown {
range_start.take();
@@ -553,7 +572,7 @@ impl Predictor {
scores.clear();
scores.resize(tag_predictor.bias().len(), 0);
tag_predictor.bias().add_scores(&mut scores);
if let Some(scorer) = self.0.char_scorer.as_ref() {
if let Some(scorer) = self.data.char_scorer.as_ref() {
debug_assert!(i < sentence.char_pma_states.len());
// token_id is always smaller than tag_weight.len() because
// tag_predictor is created to contain such values in the new()
Expand All @@ -562,7 +581,7 @@ impl Predictor {
scorer.add_tag_scores(*token_id, i, sentence, &mut scores);
}
}
if let Some(scorer) = self.0.type_scorer.as_ref() {
if let Some(scorer) = self.data.type_scorer.as_ref() {
debug_assert!(i < sentence.type_pma_states.len());
// token_id is always smaller than tag_weight.len() because
// tag_predictor is created to contain such values in the new()
@@ -573,8 +592,12 @@ impl Predictor {
}
tag_predictor.predict(
&scores,
&mut sentence.tags[i * self.0.n_tags..(i + 1) * self.0.n_tags],
&mut sentence.tags[i * self.data.n_tags..(i + 1) * self.data.n_tags],
);
if !sentence.tag_scores.is_empty() {
sentence.tag_scores[i].replace((&tag_predictor.tags, scores));
scores = vec![];
}
}
}
range_start.replace(i + 1);
@@ -586,15 +609,15 @@ impl Predictor {
scores.clear();
scores.resize(tag_predictor.bias().len(), 0);
tag_predictor.bias().add_scores(&mut scores);
if let Some(scorer) = self.0.char_scorer.as_ref() {
if let Some(scorer) = self.data.char_scorer.as_ref() {
debug_assert!(sentence.len() <= sentence.char_pma_states.len());
// token_id is always smaller than tag_weight.len() because tag_predictor is
// created to contain such values in the new() function.
unsafe {
scorer.add_tag_scores(*token_id, sentence.len() - 1, sentence, &mut scores);
}
}
if let Some(scorer) = self.0.type_scorer.as_ref() {
if let Some(scorer) = self.data.type_scorer.as_ref() {
debug_assert!(sentence.len() <= sentence.type_pma_states.len());
// token_id is always smaller than tag_weight.len() because tag_predictor is
// created to contain such values in the new() function.
@@ -603,15 +626,18 @@ impl Predictor {
}
}
let i = sentence.len() - 1;
tag_predictor.predict(&scores, &mut sentence.tags[i * self.0.n_tags..]);
tag_predictor.predict(&scores, &mut sentence.tags[i * self.data.n_tags..]);
if !sentence.tag_scores.is_empty() {
sentence.tag_scores[i].replace((&tag_predictor.tags, scores));
}
}
}
}

/// Serializes the predictor into a Vec.
pub fn serialize_to_vec(&self) -> Result<Vec<u8>> {
let config = bincode::config::standard();
let result = bincode::encode_to_vec(&self.0, config)?;
let result = bincode::encode_to_vec(&self.data, config)?;
Ok(result)
}

@@ -625,7 +651,13 @@ impl Predictor {
let config = bincode::config::standard();
// Deserialization is unsafe because the automaton will not be verified.
let (predictor_data, size) = bincode::borrow_decode_from_slice(data, config)?;
Ok((Self(predictor_data), &data[size..]))
Ok((
Self {
data: predictor_data,
tag_scores: false,
},
&data[size..],
))
}
}

41 changes: 41 additions & 0 deletions vaporetto/src/sentence.rs
@@ -91,6 +91,8 @@ pub struct Sentence<'a, 'b> {
pub(crate) char_pma_states: Vec<u32>,
pub(crate) type_pma_states: Vec<u32>,
pub(crate) tags: Vec<Option<Cow<'b, str>>>,
#[allow(clippy::type_complexity)]
pub(crate) tag_scores: Vec<Option<(&'b [Vec<String>], Vec<i32>)>>,
pub(crate) n_tags: usize,
predictor: Option<&'b Predictor>,
str_to_char_pos: Vec<usize>,
@@ -120,6 +122,7 @@ impl<'a, 'b> Default for Sentence<'a, 'b> {
char_pma_states: vec![],
type_pma_states: vec![],
tags: vec![],
tag_scores: vec![],
n_tags: 0,
predictor: None,
str_to_char_pos: vec![],
@@ -232,6 +235,7 @@ impl<'a, 'b> Sentence<'a, 'b> {
type_pma_states: vec![],
predictor: None,
tags: vec![],
tag_scores: vec![],
n_tags: 0,
str_to_char_pos,
char_to_str_pos,
@@ -451,6 +455,7 @@ impl<'a, 'b> Sentence<'a, 'b> {
type_pma_states: vec![],
predictor: None,
tags,
tag_scores: vec![],
n_tags,
str_to_char_pos,
char_to_str_pos,
@@ -688,6 +693,7 @@ impl<'a, 'b> Sentence<'a, 'b> {
type_pma_states: vec![],
predictor: None,
tags,
tag_scores: vec![],
n_tags,
str_to_char_pos,
char_to_str_pos,
@@ -1203,6 +1209,41 @@ impl<'a, 'b> Token<'a, 'b> {
&self.sentence.tags[start..end]
}

/// Returns tag candidates with scores.
///
/// The return value is a two-dimensional array. The outer array index corresponds to the
/// return value of [`Token::tags()`]. The inner array is a candidate set, where each element
/// is a tuple of the tag name and its score.
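///
/// For example, when scores are stored, the candidates of a token can be printed as follows
/// (a sketch; it assumes the predictor was created with tag prediction enabled and
/// [`Predictor::store_tag_scores()`] was called with `true` before prediction):
///
/// ```ignore
/// for candidates in token.tag_candidates() {
///     for (tag, score) in candidates {
///         println!("{tag}: {score}");
///     }
/// }
/// ```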
///
/// # Panics
///
/// This function panics if tag scores have not been stored via [`Predictor::store_tag_scores()`].
#[cfg(feature = "tag-prediction")]
#[cfg_attr(docsrs, doc(cfg(feature = "tag-prediction")))]
pub fn tag_candidates(&self) -> Vec<Vec<(&'b str, i32)>> {
let mut results = vec![];
assert!(
!self.sentence.tag_scores.is_empty(),
"Predictor::store_tag_scores() must be set to true to use this function.",
);
if let Some((tags, scores)) = self.sentence.tag_scores[self.end - 1].as_ref() {
let mut i = 0;
for cands in *tags {
let mut inner = vec![];
if cands.len() == 1 {
inner.push((cands[0].as_str(), 0));
} else {
for cand in cands {
inner.push((cand.as_str(), scores[i]));
i += 1;
}
}
results.push(inner);
}
}
results
}

/// Returns the start position of this token in characters.
#[inline]
pub const fn start(&self) -> usize {
