Skip to content

Commit

Permalink
Use PyArray_Check instead of downcasting to PyArray1<u8>
Browse files (browse the repository at this point in the history)
  • Loading branch information
messense committed Mar 22, 2022
1 parent 7551794 commit 0ae80bd
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions bindings/python/src/tokenizer.rs
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};

use numpy::PyArray1;
use numpy::{npyffi, PyArray1};
use pyo3::class::basic::CompareOp;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::AsPyPointer;
use tk::models::bpe::BPE;
use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
Expand Down Expand Up @@ -258,8 +259,24 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
struct PyArrayUnicode(Vec<String>);
impl FromPyObject<'_> for PyArrayUnicode {
fn extract(ob: &PyAny) -> PyResult<Self> {
let array = ob.downcast::<PyArray1<u8>>()?;
let arr = array.as_array_ptr();
if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
}
let arr = ob.as_ptr() as *mut npyffi::PyArrayObject;
if unsafe { (*arr).nd } != 1 {
return Err(exceptions::PyTypeError::new_err(
"Expected a 1 dimensional np.array",
));
}
if unsafe { (*arr).flags }
& (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
== 0
{
return Err(exceptions::PyTypeError::new_err(
"Expected a contiguous np.array",
));
}
let n_elem = unsafe { *(*arr).dimensions } as usize;
let (type_num, elsize, alignment, data) = unsafe {
let desc = (*arr).descr;
(
Expand All @@ -269,10 +286,8 @@ impl FromPyObject<'_> for PyArrayUnicode {
(*arr).data,
)
};
let n_elem = array.shape()[0];

// type_num == 19 => Unicode
if type_num != 19 {
if type_num != npyffi::types::NPY_TYPES::NPY_UNICODE as i32 {
return Err(exceptions::PyTypeError::new_err(
"Expected a np.array[dtype='U']",
));
Expand All @@ -289,8 +304,7 @@ impl FromPyObject<'_> for PyArrayUnicode {
bytes.as_ptr() as *const _,
elsize as isize / alignment as isize,
);
let gil = Python::acquire_gil();
let py = gil.python();
let py = ob.py();
let obj = PyObject::from_owned_ptr(py, unicode);
let s = obj.cast_as::<PyString>(py)?;
Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
Expand All @@ -310,32 +324,18 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {
struct PyArrayStr(Vec<String>);
impl FromPyObject<'_> for PyArrayStr {
fn extract(ob: &PyAny) -> PyResult<Self> {
let array = ob.downcast::<PyArray1<u8>>()?;
let arr = array.as_array_ptr();
let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) };
let n_elem = array.shape()[0];

if type_num != 17 {
return Err(exceptions::PyTypeError::new_err(
"Expected a np.array[dtype='O']",
));
}

unsafe {
let objects = std::slice::from_raw_parts(data as *const PyObject, n_elem);

let seq = objects
.iter()
.map(|obj| {
let gil = Python::acquire_gil();
let py = gil.python();
let s = obj.cast_as::<PyString>(py)?;
Ok(s.to_string_lossy().into_owned())
})
.collect::<PyResult<Vec<_>>>()?;
let array = ob.downcast::<PyArray1<PyObject>>()?;
let seq = array
.readonly()
.as_array()
.iter()
.map(|obj| {
let s = obj.cast_as::<PyString>(ob.py())?;
Ok(s.to_string_lossy().into_owned())
})
.collect::<PyResult<Vec<_>>>()?;

Ok(Self(seq))
}
Ok(Self(seq))
}
}
impl From<PyArrayStr> for tk::InputSequence<'_> {
Expand Down

0 comments on commit 0ae80bd

Please sign in to comment.