Reuse Sentence
vbkaisetsu committed Jun 11, 2022
1 parent fe68a24 commit 157d187
Showing 1 changed file with 48 additions and 26 deletions.
src/lib.rs

@@ -193,6 +193,8 @@ struct Vaporetto {
     word_cache: HashMap<String, Py<PyUnicode>>,
     tag_cache: RefCell<HashMap<String, Py<PyUnicode>>>,
     string_buf: RefCell<String>,
+    sentence_buf1: RefCell<Sentence<'static, 'static>>,
+    sentence_buf2: RefCell<Sentence<'static, 'static>>,
 }

 impl Vaporetto {

@@ -238,35 +240,43 @@ impl Vaporetto {
             word_cache,
             tag_cache: RefCell::new(HashMap::new()),
             string_buf: RefCell::new(String::new()),
+            sentence_buf1: RefCell::new(Sentence::default()),
+            sentence_buf2: RefCell::new(Sentence::default()),
         })
     }
 
-    fn tokenize_internal<'a>(&'a self, py: Python, s: &mut Sentence<'_, 'a>) {
+    fn tokenize_internal<'a>(&'a self, s: &mut Sentence<'_, 'a>) {
         let predictor = &self.predictor;
         let normalizer = &self.normalizer;
         let post_filters = &self.post_filters;
         let predict_tags = self.predict_tags;
-        py.allow_threads(|| {
-            if let Some(normalizer) = normalizer {
-                let mut norm_s = Sentence::from_raw(normalizer.filter(s.as_raw_text())).unwrap();
-                predictor.predict(&mut norm_s);
-                post_filters
-                    .iter()
-                    .for_each(|filter| filter.filter(&mut norm_s));
-                s.boundaries_mut().copy_from_slice(norm_s.boundaries());
-                if predict_tags {
-                    norm_s.fill_tags();
-                    s.reset_tags(norm_s.n_tags());
-                    s.tags_mut().clone_from_slice(norm_s.tags());
-                }
-            } else {
-                predictor.predict(s);
-                post_filters.iter().for_each(|filter| filter.filter(s));
-                if predict_tags {
-                    s.fill_tags();
-                }
-            }
-        });
+        if let Some(normalizer) = normalizer {
+            // Sentence buffer requires lifetimes of text and predictor, but the Vaporetto struct
+            // cannot have such a Sentence, so we use transmute() to disguise lifetimes.
+            let norm_s = &mut self.sentence_buf2.borrow_mut();
+            let norm_s = unsafe {
+                std::mem::transmute::<&mut Sentence<'static, 'static>, &mut Sentence<'_, '_>>(
+                    norm_s,
+                )
+            };
+            norm_s
+                .update_raw(normalizer.filter(s.as_raw_text()))
+                .unwrap();
+            predictor.predict(norm_s);
+            post_filters.iter().for_each(|filter| filter.filter(norm_s));
+            s.boundaries_mut().copy_from_slice(norm_s.boundaries());
+            if predict_tags {
+                norm_s.fill_tags();
+                s.reset_tags(norm_s.n_tags());
+                s.tags_mut().clone_from_slice(norm_s.tags());
+            }
+        } else {
+            predictor.predict(s);
+            post_filters.iter().for_each(|filter| filter.filter(s));
+            if predict_tags {
+                s.fill_tags();
+            }
+        }
     }
 }

@@ -285,7 +295,7 @@ impl Vaporetto {
         let (model, _) = py.allow_threads(|| {
             let mut f = Cursor::new(model);
             let mut decoder = ruzstd::StreamingDecoder::new(&mut f)
-                .map_err(|e| PyValueError::new_err(e.to_string()))?;
+                .map_err(PyValueError::new_err)?;
             decoder
                 .read_to_end(&mut buff)
                 .map_err(|e| PyValueError::new_err(e.to_string()))?;

@@ -331,8 +341,14 @@ impl Vaporetto {
     /// :type out: vaporetto.TokenList
     #[pyo3(text_signature = "($self, text, /)")]
     fn tokenize(&self, py: Python, text: &str) -> TokenList {
-        if let Ok(mut s) = Sentence::from_raw(text) {
-            self.tokenize_internal(py, &mut s);
+        // Sentence buffer requires lifetimes of text and predictor, but the Vaporetto struct
+        // cannot have such a Sentence, so we use transmute() to disguise lifetimes.
+        let s = &mut self.sentence_buf1.borrow_mut();
+        let s = unsafe {
+            std::mem::transmute::<&mut Sentence<'static, 'static>, &mut Sentence<'_, '_>>(s)
+        };
+        if s.update_raw(text).is_ok() {
+            self.tokenize_internal(s);
 
             // Creates TokenIterator
             let surfaces = s

@@ -385,8 +401,14 @@ impl Vaporetto {
     #[pyo3(text_signature = "($self, text, /)")]
     fn tokenize_to_string(&self, py: Python, text: &str) -> Py<PyUnicode> {
         let buf = &mut self.string_buf.borrow_mut();
-        if let Ok(mut s) = Sentence::from_raw(text) {
-            self.tokenize_internal(py, &mut s);
+        // Sentence buffer requires lifetimes of text and predictor, but the Vaporetto struct
+        // cannot have such a Sentence, so we use transmute() to disguise lifetimes.
+        let s = &mut self.sentence_buf1.borrow_mut();
+        let s = unsafe {
+            std::mem::transmute::<&mut Sentence<'static, 'static>, &mut Sentence<'_, '_>>(s)
+        };
+        if s.update_raw(text).is_ok() {
+            self.tokenize_internal(s);
             s.write_tokenized_text(buf);
         }
         PyUnicode::new(py, buf).into()
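
The transmute() trick repeated above is the heart of this commit: Sentence<'a, 'b> borrows the input text and the predictor, so a Sentence cannot live as a field of the Vaporetto struct directly. Instead the buffers are stored as Sentence<'static, 'static> in RefCells, transmuted down to the caller's lifetimes on each call, and re-pointed at fresh text with update_raw() before anything reads them, so the owned allocations inside each Sentence are reused across calls. Below is a minimal, self-contained sketch of the same pattern; Parsed and Tokenizer are hypothetical stand-ins for Sentence and Vaporetto, not vaporetto's real API.

    use std::cell::RefCell;

    // Stand-in for Sentence<'a, 'b>: borrows the input text, but owns derived
    // data (here, char offsets) whose allocation is worth reusing across calls.
    struct Parsed<'a> {
        #[allow(dead_code)] // kept only to mirror Sentence's borrowed text field
        text: &'a str,
        offsets: Vec<usize>,
    }

    impl<'a> Parsed<'a> {
        // Re-point the buffer at new text, reusing the `offsets` allocation.
        fn update(&mut self, text: &'a str) {
            self.text = text;
            self.offsets.clear();
            self.offsets.extend(text.char_indices().map(|(i, _)| i));
        }
    }

    struct Tokenizer {
        // Stored as 'static because the struct cannot name the lifetime of
        // text that only exists during a single method call.
        buf: RefCell<Parsed<'static>>,
    }

    impl Tokenizer {
        fn new() -> Self {
            Tokenizer {
                buf: RefCell::new(Parsed { text: "", offsets: Vec::new() }),
            }
        }

        fn char_count(&self, text: &str) -> usize {
            let buf = &mut self.buf.borrow_mut();
            // SAFETY (same argument as the commit): the shortened-lifetime
            // reference never escapes this method, and `update` overwrites the
            // stale `text` from the previous call before anything reads it.
            let p = unsafe {
                std::mem::transmute::<&mut Parsed<'static>, &mut Parsed<'_>>(buf)
            };
            p.update(text);
            p.offsets.len()
        }
    }

    fn main() {
        let t = Tokenizer::new();
        assert_eq!(t.char_count("こんにちは"), 5);
        // The second call reuses the Vec allocation left by the first.
        assert_eq!(t.char_count("hi"), 2);
    }

The unsafe block is sound only under the invariant the commit relies on: the transmuted reference must not escape the method, and the dangling borrowed text left in the buffer after a call must be overwritten before it is ever read again.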
