
Refactoring
vbkaisetsu committed Jun 10, 2022
1 parent c810560 commit 3604e2f
Showing 1 changed file with 42 additions and 31 deletions.
73 changes: 42 additions & 31 deletions src/lib.rs
@@ -78,15 +78,15 @@ impl Token {
 
     fn __repr__(&self, py: Python) -> PyResult<String> {
         let list = self.list.borrow(py);
-        let surface: String = list.surfaces[self.index].0.as_ref(py).extract()?;
+        let surface = list.surfaces[self.index].0.as_ref(py).to_str()?;
         let mut result = format!("Token {{ surface: {:?}, tags: [", surface);
         let pos = list.surfaces[self.index].2 - 1;
         for i in 0..list.n_tags {
             if i != 0 {
                 result += ", ";
             }
             if let Some(tag) = list.tags[pos * list.n_tags + i].as_ref() {
-                let tag: String = tag.as_ref(py).extract()?;
+                let tag = tag.as_ref(py).to_str()?;
                 write!(&mut result, "{:?}", tag).unwrap();
             } else {
                 result += "None";
@@ -192,6 +192,7 @@ struct Vaporetto {
     post_filters: Vec<Box<dyn SentenceFilter>>,
     word_cache: HashMap<String, Py<PyUnicode>>,
     tag_cache: RefCell<HashMap<String, Py<PyUnicode>>>,
+    string_buf: RefCell<String>,
 }
 
 impl Vaporetto {
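Note: the new `string_buf` field is a scratch buffer reused by `tokenize_to_string` (final hunk below). It sits behind a `RefCell` because `#[pymethods]` functions only receive `&self`, so interior mutability is the only way to reuse the allocation across calls; the GIL serializes callers, so the runtime borrow cannot conflict.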
Expand Down Expand Up @@ -236,8 +237,10 @@ impl Vaporetto {
post_filters,
word_cache,
tag_cache: RefCell::new(HashMap::new()),
string_buf: RefCell::new(String::new()),
})
}

fn tokenize_internal<'a>(&'a self, py: Python, s: &mut Sentence<'_, 'a>) {
let predictor = &self.predictor;
let normalizer = &self.normalizer;
Expand Down Expand Up @@ -278,11 +281,14 @@ impl Vaporetto {
wsconst: &str,
norm: bool,
) -> PyResult<Self> {
let mut f = Cursor::new(model);
let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap();
let mut buff = vec![];
decoder.read_to_end(&mut buff).unwrap();
let (model, _) = py.allow_threads(|| {
let mut f = Cursor::new(model);
let mut decoder = ruzstd::StreamingDecoder::new(&mut f)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
decoder
.read_to_end(&mut buff)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Model::read_slice(&buff).map_err(|e| PyValueError::new_err(e.to_string()))
})?;
Self::create_internal(py, model, predict_tags, wsconst, norm)
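Note: two things change in this hunk. Decompression and model parsing now run inside `py.allow_threads`, which releases the GIL for the duration of the closure so other Python threads can make progress during the CPU-bound work, and the `unwrap()` calls become `map_err` conversions into `PyValueError`, so a corrupt model surfaces as a Python exception rather than a panic. A minimal sketch of the pattern (`expensive_parse` is a hypothetical stand-in for the zstd decoding and model parsing):

    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;

    // Runs CPU-bound work with the GIL released, and converts Rust errors
    // into a Python ValueError instead of panicking.
    fn parse_model(py: Python, data: &[u8]) -> PyResult<usize> {
        py.allow_threads(|| {
            expensive_parse(data).map_err(|e| PyValueError::new_err(e.to_string()))
        })
    }

    // Hypothetical stand-in for StreamingDecoder + Model::read_slice.
    fn expensive_parse(data: &[u8]) -> std::io::Result<usize> {
        Ok(data.len())
    }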
Expand Down Expand Up @@ -329,29 +335,34 @@ impl Vaporetto {
self.tokenize_internal(py, &mut s);

// Creates TokenIterator
let mut surfaces = vec![];
for token in s.iter_tokens() {
let surface = self
.word_cache
.get(token.surface())
.map(|surf| surf.clone_ref(py))
.unwrap_or_else(|| PyUnicode::new(py, token.surface()).into());
surfaces.push((surface, token.start(), token.end()));
}
let mut tags = vec![];
let surfaces = s
.iter_tokens()
.map(|token| {
let surface = self
.word_cache
.get(token.surface())
.map(|surf| surf.clone_ref(py))
.unwrap_or_else(|| PyUnicode::new(py, token.surface()).into());
(surface, token.start(), token.end())
})
.collect();
let tag_cache = &mut self.tag_cache.borrow_mut();
for tag in s.tags() {
tags.push(tag.as_ref().map(|tag| {
tag_cache
.raw_entry_mut()
.from_key(tag.as_ref())
.or_insert_with(|| {
(tag.to_string(), PyUnicode::new(py, tag.as_ref()).into())
})
.1
.clone_ref(py)
}));
}
let tags = s
.tags()
.iter()
.map(|tag| {
tag.as_ref().map(|tag| {
tag_cache
.raw_entry_mut()
.from_key(tag.as_ref())
.or_insert_with(|| {
(tag.to_string(), PyUnicode::new(py, tag.as_ref()).into())
})
.1
.clone_ref(py)
})
})
.collect();
TokenList {
surfaces,
tags,
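Note: rewriting the push loops as iterator chains lets `collect()` preallocate when the iterator reports an exact length. The tag cache keeps using the raw-entry API: with a plain `entry()`, looking up a cached tag would require building an owned `String` key on every call even on a cache hit, whereas `raw_entry_mut().from_key(...)` hashes the borrowed `&str` directly and defers the `to_string()` allocation to `or_insert_with`, i.e. to actual cache misses. The API is unstable in the standard library (the `hash_raw_entry` nightly feature; hashbrown offers an equivalent on stable), so this sketch assumes a nightly toolchain:

    #![feature(hash_raw_entry)] // nightly std; hashbrown has a stable equivalent
    use std::collections::HashMap;

    // Interns `key`, allocating the owned String key only on a cache miss.
    fn intern(cache: &mut HashMap<String, u32>, key: &str, next_id: u32) -> u32 {
        *cache
            .raw_entry_mut()
            .from_key(key) // hashes the borrowed &str directly, no allocation
            .or_insert_with(|| (key.to_string(), next_id))
            .1
    }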
@@ -372,13 +383,13 @@ impl Vaporetto {
     /// :type text: str
     /// :type out: str
     #[pyo3(text_signature = "($self, text, /)")]
-    fn tokenize_to_string(&self, py: Python, text: &str) -> String {
-        let mut buf = String::new();
+    fn tokenize_to_string(&self, py: Python, text: &str) -> Py<PyUnicode> {
+        let buf = &mut self.string_buf.borrow_mut();
         if let Ok(mut s) = Sentence::from_raw(text) {
             self.tokenize_internal(py, &mut s);
-            s.write_tokenized_text(&mut buf);
+            s.write_tokenized_text(buf);
         }
-        buf
+        PyUnicode::new(py, buf).into()
     }
 }
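Note: `tokenize_to_string` previously built a fresh `String` per call and returned it, leaving PyO3 to convert it into a Python str (a second allocation and copy). The new version reuses the `string_buf` scratch buffer and constructs the `PyUnicode` directly, so the only per-call allocation is the returned Python string. This relies on vaporetto's `write_tokenized_text` replacing the buffer contents rather than appending, which the reuse here implies, and on the GIL serializing calls so the `RefCell` borrow cannot conflict. A minimal sketch of the scratch-buffer pattern under those assumptions:

    use std::cell::RefCell;

    // Scratch-buffer pattern: one String reused across calls instead of a
    // fresh allocation per call. Assumes the writer overwrites the buffer.
    struct Renderer {
        string_buf: RefCell<String>,
    }

    impl Renderer {
        fn render(&self, input: &str) -> String {
            let buf = &mut *self.string_buf.borrow_mut();
            buf.clear(); // stand-in: the real writer is assumed to reset the buffer
            buf.push_str(input); // stand-in for write_tokenized_text
            buf.clone() // the commit builds a PyUnicode from &buf instead of cloning
        }
    }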
