From 23f651ede23a70640ed8f8871e127864aa3274d0 Mon Sep 17 00:00:00 2001 From: Mike Schmidt Date: Mon, 12 Feb 2024 10:09:29 -0700 Subject: [PATCH] feat!(pylace): Return `PyResults` for several methods vs unwrap Removed several unwraps and replaced them with `PyResult` returns. Some could not be removed because PyO3 currently only supports infallible type conversions: https://github.com/PyO3/pyo3/issues/1813. --- pylace/Cargo.lock | 18 ++++---- pylace/src/component.rs | 5 ++- pylace/src/df.rs | 4 ++ pylace/src/lib.rs | 76 ++++++++++++++++++++++------------ pylace/src/update_handler.rs | 1 - pylace/src/utils.rs | 80 +++++++++++++++++++----------------- 6 files changed, 110 insertions(+), 74 deletions(-) diff --git a/pylace/Cargo.lock b/pylace/Cargo.lock index 98b41b7a..2f8e8ed9 100644 --- a/pylace/Cargo.lock +++ b/pylace/Cargo.lock @@ -521,7 +521,7 @@ dependencies = [ [[package]] name = "lace" -version = "0.6.0" +version = "0.7.0" dependencies = [ "dirs", "indexmap", @@ -549,7 +549,7 @@ dependencies = [ [[package]] name = "lace_cc" -version = "0.5.0" +version = "0.6.0" dependencies = [ "enum_dispatch", "itertools", @@ -569,7 +569,7 @@ dependencies = [ [[package]] name = "lace_codebook" -version = "0.5.0" +version = "0.6.0" dependencies = [ "lace_consts", "lace_data", @@ -590,7 +590,7 @@ dependencies = [ [[package]] name = "lace_data" -version = "0.2.0" +version = "0.3.0" dependencies = [ "lace_utils", "serde", @@ -599,7 +599,7 @@ dependencies = [ [[package]] name = "lace_geweke" -version = "0.2.1" +version = "0.3.0" dependencies = [ "indicatif", "lace_stats", @@ -611,7 +611,7 @@ dependencies = [ [[package]] name = "lace_metadata" -version = "0.5.0" +version = "0.6.0" dependencies = [ "bincode", "hex", @@ -630,7 +630,7 @@ dependencies = [ [[package]] name = "lace_stats" -version = "0.2.1" +version = "0.3.0" dependencies = [ "itertools", "lace_consts", @@ -644,7 +644,7 @@ dependencies = [ [[package]] name = "lace_utils" -version = "0.2.0" +version = "0.3.0" dependencies = [ 
"rand", "rayon", @@ -1345,7 +1345,7 @@ checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" [[package]] name = "pylace" -version = "0.6.0" +version = "0.7.0" dependencies = [ "lace", "lace_utils", diff --git a/pylace/src/component.rs b/pylace/src/component.rs index dde2b944..67916a57 100644 --- a/pylace/src/component.rs +++ b/pylace/src/component.rs @@ -73,7 +73,10 @@ impl CategoricalParams { _ => format!( "[{}, ..., {}]", self.weights[0], - self.weights.last().unwrap() + self.weights + .last() + .map(|x| x.to_string()) + .unwrap_or_else(|| "-".to_string()) ), }; diff --git a/pylace/src/df.rs b/pylace/src/df.rs index 512f0678..b7d05c00 100644 --- a/pylace/src/df.rs +++ b/pylace/src/df.rs @@ -166,6 +166,8 @@ pub(crate) fn to_py_array( Ok(array.to_object(py)) } +// TODO: When https://github.com/PyO3/pyo3/issues/1813 is solved, implement a +// failable version. impl IntoPy for PySeries { fn into_py(self, py: Python<'_>) -> PyObject { let s = self.0.rechunk(); @@ -181,6 +183,8 @@ impl IntoPy for PySeries { } } +// TODO: When https://github.com/PyO3/pyo3/issues/1813 is solved, implement a +// failable version. 
impl IntoPy for PyDataFrame { fn into_py(self, py: Python<'_>) -> PyObject { let pyseries = self diff --git a/pylace/src/lib.rs b/pylace/src/lib.rs index 5374f6d9..fe65a227 100644 --- a/pylace/src/lib.rs +++ b/pylace/src/lib.rs @@ -16,8 +16,8 @@ use lace::prelude::ColMetadataList; use lace::{EngineUpdateConfig, FType, HasStates, OracleT}; use polars::prelude::{DataFrame, NamedFrom, Series}; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; -use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyType}; +use pyo3::{create_exception, prelude::*}; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; @@ -112,18 +112,20 @@ impl CoreEngine { /// Load a Engine from metadata #[classmethod] - fn load(_cls: &PyType, path: PathBuf) -> CoreEngine { + fn load(_cls: &PyType, path: PathBuf) -> PyResult { let (engine, rng) = { - let mut engine = lace::Engine::load(path).unwrap(); - let rng = Xoshiro256Plus::from_rng(&mut engine.rng).unwrap(); + let mut engine = lace::Engine::load(path) + .map_err(|e| EngineLoadError::new_err(e.to_string()))?; + let rng = Xoshiro256Plus::from_rng(&mut engine.rng) + .map_err(|e| EngineLoadError::new_err(e.to_string()))?; (engine, rng) }; - Self { + Ok(Self { col_indexer: Indexer::columns(&engine.codebook), row_indexer: Indexer::rows(&engine.codebook), rng, engine, - } + }) } /// Save the engine to `path` @@ -491,13 +493,17 @@ impl CoreEngine { let mut a = Vec::with_capacity(pairs.len()); let mut b = Vec::with_capacity(pairs.len()); - utils::pairs_list_iter(pairs, indexer).for_each(|res| { - let (ix_a, ix_b) = res.unwrap(); - let name_a = indexer.to_name[&ix_a].clone(); - let name_b = indexer.to_name[&ix_b].clone(); - a.push(name_a); - b.push(name_b); - }); + utils::pairs_list_iter(pairs, indexer) + .map(|res| { + let (ix_a, ix_b) = res?; + let name_a = indexer.to_name[&ix_a].clone(); + let name_b = indexer.to_name[&ix_b].clone(); + a.push(name_a); + b.push(name_b); + + Ok::<(), PyErr>(()) + }) + .collect::>()?; let a = 
Series::new("A", a); let b = Series::new("B", b); @@ -1047,7 +1053,7 @@ impl CoreEngine { transitions: Option>, save_path: Option, update_handler: Option, - ) { + ) -> PyResult<()> { use lace::update_handler::Timeout; use std::time::Duration; @@ -1084,11 +1090,13 @@ impl CoreEngine { config, (timeout, PyUpdateHandler::new(update_handler)), ) - .unwrap(); + .map_err(|e| EngineUpdateError::new_err(e.to_string())) } else { - self.engine.update(config, timeout).unwrap(); + self.engine + .update(config, timeout) + .map_err(|e| EngineUpdateError::new_err(e.to_string())) } - }); + }) } /// Append new rows to the table. @@ -1132,15 +1140,22 @@ impl CoreEngine { })?; // must add new row names to indexer - let row_names = df_vals.row_names.unwrap(); - (self.engine.n_rows()..).zip(row_names.iter()).for_each( + let row_names = df_vals.row_names.ok_or_else(|| { + PyValueError::new_err("Provided dataframe has no index (row names)") + })?; + (self.engine.n_rows()..).zip(row_names.iter()).map( |(ix, name)| { - // row names passed to 'append' should not exist - assert!(!self.row_indexer.to_ix.contains_key(name)); - self.row_indexer.to_ix.insert(name.to_owned(), ix); - self.row_indexer.to_name.insert(ix, name.to_owned()); + if self.row_indexer.to_ix.contains_key(name) { + Err(PyValueError::new_err( + format!("Duplicate ids/indices cannot be inserted. 
Duplicate `{name}`") + )) + } else { + self.row_indexer.to_ix.insert(name.to_owned(), ix); + self.row_indexer.to_name.insert(ix, name.to_owned()); + Ok(()) + } }, - ); + ).collect::>()?; let data = parts_to_insert_values( df_vals.col_names, @@ -1217,7 +1232,11 @@ impl CoreEngine { let data = parts_to_insert_values( col_names, - df_vals.row_names.unwrap(), + df_vals.row_names.ok_or_else(|| { + PyValueError::new_err( + "Provided dataframe has no index (row names)", + ) + })?, df_vals.values, ); @@ -1314,9 +1333,12 @@ pub fn infer_srs_metadata( .map(metadata::ColumnMetadata) } +create_exception!(lace, EngineLoadError, pyo3::exceptions::PyException); +create_exception!(lace, EngineUpdateError, pyo3::exceptions::PyException); + /// A Python module implemented in Rust. #[pymodule] -fn core(_py: Python, m: &PyModule) -> PyResult<()> { +fn core(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -1334,5 +1356,7 @@ fn core(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(infer_srs_metadata, m)?)?; m.add_function(wrap_pyfunction!(metadata::codebook_from_df, m)?)?; + m.add("EngineLoadError", py.get_type::())?; + m.add("EngineUpdateError", py.get_type::())?; Ok(()) } diff --git a/pylace/src/update_handler.rs b/pylace/src/update_handler.rs index c04791bc..023e18e3 100644 --- a/pylace/src/update_handler.rs +++ b/pylace/src/update_handler.rs @@ -1,4 +1,3 @@ -use std::io::Write; /// Update Handler and associated tooling for `CoreEngine.update` in Python. 
use std::sync::{Arc, Mutex}; diff --git a/pylace/src/utils.rs b/pylace/src/utils.rs index 0e45a1f0..55e3bc1f 100644 --- a/pylace/src/utils.rs +++ b/pylace/src/utils.rs @@ -418,7 +418,9 @@ impl Indexer { let name = self.to_name.remove(&ix).ok_or_else(|| { PyIndexError::new_err(format!("Index {ix} not found")) })?; - self.to_ix.remove(&name).unwrap(); + self.to_ix + .remove(&name) + .expect("Should exist as a consequence of the check above"); Ok(name) } } @@ -441,7 +443,7 @@ pub(crate) fn pairs_list_iter<'a>( } }) .unwrap_or_else(|_| { - let ixs: &PyTuple = item.downcast().unwrap(); + let ixs: &PyTuple = item.downcast()?; if ixs.len() != 2 { Err(PyErr::new::( "A pair consists of two items", @@ -606,8 +608,13 @@ pub(crate) fn dict_to_given( .iter() .map(|(key, value)| { value_to_index(key, indexer).and_then(|ix| { - value_to_datum(value, engine.ftype(ix).unwrap()) - .map(|x| (ix, x)) + value_to_datum( + value, + engine.ftype(ix).expect( + "Index from indexer ought to be valid.", + ), + ) + .map(|x| (ix, x)) }) }) .collect::>>()?; @@ -618,7 +625,7 @@ pub(crate) fn dict_to_given( } pub(crate) fn srs_to_strings(srs: &PyAny) -> PyResult> { - let list: &PyList = srs.call_method0("to_list").unwrap().extract().unwrap(); + let list: &PyList = srs.call_method0("to_list")?.extract()?; list.iter() .map(|x| x.extract::()) @@ -666,11 +673,11 @@ fn process_row_dict( _col_indexer: &Indexer, engine: &lace::Engine, suppl_types: Option<&HashMap>, -) -> Result, PyErr> { +) -> PyResult> { let mut row_data: Vec = Vec::with_capacity(row_dict.len()); for (name_any, value_any) in row_dict { - let col_name: &PyString = name_any.downcast().unwrap(); - let col_name = col_name.to_str().unwrap(); + let col_name: &PyString = name_any.downcast()?; + let col_name = col_name.to_str()?; let ftype = engine .codebook .col_metadata(col_name.to_owned()) @@ -698,7 +705,7 @@ fn values_to_data( ) -> PyResult>> { data.iter() .map(|row_any| { - let row_dict: &PyDict = row_any.downcast().unwrap(); + let 
row_dict: &PyDict = row_any.downcast()?; process_row_dict(row_dict, col_indexer, engine, suppl_types) }) .collect() @@ -722,35 +729,38 @@ fn df_to_values( ) -> PyResult { Python::with_gil(|py| { let (columns, data, row_names) = { - let columns = df.getattr("columns").unwrap(); - if columns.get_type().name().unwrap().contains("Index") { + let columns = df.getattr("columns")?; + if columns.get_type().name()?.contains("Index") { // Is a Pandas dataframe let index = df.getattr("index")?; let row_names = srs_to_strings(index).ok(); - let cols = - columns.call_method0("tolist").unwrap().to_object(py); + let cols = columns.call_method0("tolist")?.to_object(py); let kwargs = PyDict::new(py); - kwargs.set_item("orient", "records").unwrap(); - let data = df.call_method("to_dict", (), Some(kwargs)).unwrap(); + kwargs.set_item("orient", "records")?; + let data = df.call_method("to_dict", (), Some(kwargs))?; (cols, data, row_names) } else { // Is a Polars dataframe - let list = columns.downcast::().unwrap(); + let list = columns.downcast::()?; let index_col = { // Find all the index columns let mut index_col_names = list .iter() - .map(|s| s.extract::<&str>().unwrap()) - .filter_map(|s| { - if is_index_col(s) { - Some(String::from(s)) - } else { - None - } + .map(|s| s.extract::<&str>()) + .map(|s| { + s.map(|s| { + if is_index_col(s) { + Some(String::from(s)) + } else { + None + } + }) + .transpose() }) - .collect::>(); + .flatten() + .collect::>>()?; if index_col_names.is_empty() { Ok(None) @@ -769,7 +779,7 @@ fn df_to_values( let (df, row_names) = if let Some(ref index_name) = index_col { // remove the index column label - list.call_method1("remove", (index_name,)).unwrap(); + list.call_method1("remove", (index_name,))?; // Get the indices from the index if it exists let row_names = df.get_item(index_name) @@ -779,20 +789,20 @@ fn df_to_values( "Indices in index '{index_name}' are not strings: {err}")) })?; // remove the index column from the data - let df = 
df.call_method1("drop", (index_name,)).unwrap(); + let df = df.call_method1("drop", (index_name,))?; (df, Some(row_names)) } else { (df, None) }; - let data = df.call_method0("to_dicts").unwrap(); + let data = df.call_method0("to_dicts")?; (list.to_object(py), data, row_names) } }; - let data: &PyList = data.extract().unwrap(); - let columns: &PyList = columns.extract(py).unwrap(); + let data: &PyList = data.extract()?; + let columns: &PyList = columns.extract(py)?; // will return nothing if there are unknown column names let col_ixs = columns .iter() @@ -820,7 +830,7 @@ fn srs_to_column_values( engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult { - let data = srs.call_method0("to_frame").unwrap(); + let data = srs.call_method0("to_frame")?; df_to_values(data, indexer, engine, suppl_types) } @@ -831,11 +841,7 @@ fn srs_to_row_values( engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult { - let data = srs - .call_method0("to_frame") - .unwrap() - .call_method0("transpose") - .unwrap(); + let data = srs.call_method0("to_frame")?.call_method0("transpose")?; df_to_values(data, indexer, engine, suppl_types) } @@ -845,7 +851,7 @@ pub(crate) fn pandas_to_logp_values( indexer: &Indexer, engine: &lace::Engine, ) -> PyResult { - let type_name = xs.get_type().name().unwrap(); + let type_name = xs.get_type().name()?; match type_name { "DataFrame" => df_to_values(xs, indexer, engine, None), @@ -862,7 +868,7 @@ pub(crate) fn pandas_to_insert_values( engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult { - let type_name = xs.get_type().name().unwrap(); + let type_name = xs.get_type().name()?; match type_name { "DataFrame" => df_to_values(xs, col_indexer, engine, suppl_types),