diff --git a/crates/open_jtalk/Cargo.toml b/crates/open_jtalk/Cargo.toml index 25342e0..d85b8d3 100644 --- a/crates/open_jtalk/Cargo.toml +++ b/crates/open_jtalk/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.25" edition = "2021" [dependencies] +camino = "1.1.6" open_jtalk-sys = { path = "../open_jtalk-sys", version = "0.16.111" } thiserror = "1.0.31" diff --git a/crates/open_jtalk/src/mecab/mod.rs b/crates/open_jtalk/src/mecab/mod.rs index 8847d91..42fb31c 100644 --- a/crates/open_jtalk/src/mecab/mod.rs +++ b/crates/open_jtalk/src/mecab/mod.rs @@ -3,7 +3,16 @@ mod mecab_dict_index; pub use mecab_dict_index::*; use super::*; -use std::{ffi::CString, mem::MaybeUninit, path::Path}; +use camino::{Utf8Path, Utf8PathBuf}; +use std::{ffi::CString, mem::MaybeUninit}; + +#[derive(thiserror::Error, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub enum MecabLoadError { + #[error("`{function}` failed")] + Unsuccessful { function: &'static str }, + #[error("file name contained a NUL byte: {filename:?}")] + Nul { filename: Utf8PathBuf }, +} #[derive(Default)] pub struct Mecab(Option); @@ -38,35 +47,42 @@ impl Mecab { self.0.as_ref().unwrap() as *const open_jtalk_sys::Mecab as *mut open_jtalk_sys::Mecab } - pub fn load(&mut self, dic_dir: impl AsRef) -> bool { - let dic_dir = CString::new(dic_dir.as_ref().to_str().unwrap()).unwrap(); - unsafe { - bool_number_to_bool(open_jtalk_sys::Mecab_load( - self.as_raw_ptr(), - dic_dir.as_ptr(), - )) + pub fn load(&mut self, dic_dir: impl AsRef) -> Result<(), MecabLoadError> { + let dic_dir = c_filename(dic_dir.as_ref())?; + let success = bool_number_to_bool(unsafe { + open_jtalk_sys::Mecab_load(self.as_raw_ptr(), dic_dir.as_ptr()) + }); + if !success { + return Err(MecabLoadError::Unsuccessful { + function: "Mecab_load", + }); } - } - - /// # Panics - /// - /// 次の場合にパニックする。 - /// - /// - `dic_dir`または`userdic`が`\0`を含む。 - /// - `dic_dir`または`userdic`がUTF-8の文字列ではない。 - pub fn load_with_userdic(&mut self, dic_dir: &Path, userdic: Option<&Path>) -> bool { - let dic_dir = CString::new(dic_dir.to_str().unwrap()).unwrap(); - let userdic = &userdic.map(|userdic| CString::new(userdic.to_str().unwrap()).unwrap()); - unsafe { - bool_number_to_bool(open_jtalk_sys::Mecab_load_with_userdic( + Ok(()) + } + + pub fn load_with_userdic( + &mut self, + dic_dir: &Utf8Path, + userdic: Option<&Utf8Path>, + ) -> Result<(), MecabLoadError> { + let dic_dir = c_filename(dic_dir)?; + let userdic = &userdic.map(c_filename).transpose()?; + let success = bool_number_to_bool(unsafe { + open_jtalk_sys::Mecab_load_with_userdic( self.as_raw_ptr(), dic_dir.as_ptr(), match userdic { Some(userdic) => userdic.as_ptr(), None => std::ptr::null(), }, - )) + ) + }); + if !success { + return Err(MecabLoadError::Unsuccessful { + function: "Mecab_load_with_userdic", + }); } + Ok(()) } pub fn get_feature(&self) -> Option<&MecabFeature> { unsafe { @@ -113,11 +129,16 @@ impl Mecab { } } +fn c_filename(path: &Utf8Path) -> Result { + CString::new(path.as_str()).map_err(|_| MecabLoadError::Nul { + filename: path.to_owned(), + }) +} + #[cfg(test)] mod tests { - use std::{path::PathBuf, str::FromStr}; - use super::*; + use camino::Utf8Path; use pretty_assertions::{assert_eq, assert_ne}; use resources::Resource as _; @@ -139,11 +160,12 @@ mod tests { #[rstest] fn mecab_load_works() { let mut mecab = ManagedResource::::initialize(); - assert!(mecab.load( - PathBuf::from_str(std::env!("CARGO_MANIFEST_DIR")) - .unwrap() - .join("src/mecab/testdata/mecab_load"), - )); + mecab + .load( + Utf8Path::new(std::env!("CARGO_MANIFEST_DIR")) + .join("src/mecab/testdata/mecab_load"), + ) + .unwrap(); } #[rstest] @@ -156,11 +178,12 @@ mod tests { #[case("h^o-d+e=s/A:2+3+2/B:22-xx_xx/C:10_7+2/D:xx+xx_xx/E:5_5!0_xx-0/F:4_1#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:1_5/I:1-4@2+1&2-1|6+4/J:xx_xx/K:2+2-9",true)] fn mecab_analysis_works(#[case] input: &str, #[case] expected: bool) { let mut mecab = ManagedResource::::initialize(); - assert!(mecab.load( - PathBuf::from_str(std::env!("CARGO_MANIFEST_DIR")) - .unwrap() - .join("src/mecab/testdata/mecab_load"), - )); + mecab + .load( + Utf8Path::new(std::env!("CARGO_MANIFEST_DIR")) + .join("src/mecab/testdata/mecab_load"), + ) + .unwrap(); let s = text2mecab(input).unwrap(); assert_eq!(expected, mecab.analysis(s)); assert_ne!(0, mecab.get_size()); diff --git a/crates/open_jtalk/src/njd.rs b/crates/open_jtalk/src/njd.rs index 970f2f4..8928eb1 100644 --- a/crates/open_jtalk/src/njd.rs +++ b/crates/open_jtalk/src/njd.rs @@ -74,9 +74,8 @@ impl Njd { #[cfg(test)] mod tests { use super::*; + use camino::Utf8Path; use resources::Resource as _; - use std::path::PathBuf; - use std::str::FromStr; #[rstest] fn njd_initialize_and_clear_works() { let mut njd = Njd::default(); @@ -131,11 +130,12 @@ mod tests { let mut njd = ManagedResource::::initialize(); let mut mecab = ManagedResource::::initialize(); - assert!(mecab.load( - PathBuf::from_str(std::env!("CARGO_MANIFEST_DIR")) - .unwrap() - .join("src/mecab/testdata/mecab_load"), - )); + mecab + .load( + Utf8Path::new(std::env!("CARGO_MANIFEST_DIR")) + .join("src/mecab/testdata/mecab_load"), + ) + .unwrap(); let s = text2mecab("h^o-d+e=s/A:2+3+2/B:22-xx_xx/C:10_7+2/D:xx+xx_xx/E:5_5!0_xx-0/F:4_1#0_xx@1_1|1_4/G:xx_xx%xx_xx_xx/H:1_5/I:1-4@2+1&2-1|6+4/J:xx_xx/K:2+2-9").unwrap(); assert!(mecab.analysis(s)); njd.mecab2njd(mecab.get_feature().unwrap(), mecab.get_size());