From ee9c5896bc30904f8a9ec557eb79df26365f921a Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Wed, 5 Jul 2023 09:46:38 +0100 Subject: [PATCH] feat(python, rust): clearer message when stringcache-related errors occur (#9715) --- .../logical/categorical/ops/append.rs | 5 +--- polars/polars-core/src/frame/hash_join/mod.rs | 6 +---- polars/polars-error/src/lib.rs | 23 +++++++++++++++++++ py-polars/polars/exceptions.py | 5 ++++ py-polars/src/error.rs | 4 ++++ py-polars/src/lib.rs | 5 ++++ .../tests/unit/datatypes/test_categorical.py | 5 ++-- py-polars/tests/unit/test_cfg.py | 3 ++- 8 files changed, 44 insertions(+), 12 deletions(-) diff --git a/polars/polars-core/src/chunked_array/logical/categorical/ops/append.rs b/polars/polars-core/src/chunked_array/logical/categorical/ops/append.rs index 7ee05c3660c5..44587e704983 100644 --- a/polars/polars-core/src/chunked_array/logical/categorical/ops/append.rs +++ b/polars/polars-core/src/chunked_array/logical/categorical/ops/append.rs @@ -17,10 +17,7 @@ impl CategoricalChunked { }; if is_local_different_source { - polars_bail!( - ComputeError: - "cannot concat categoricals coming from a different source; consider setting a global StringCache" - ); + polars_bail!(string_cache_mismatch); } else { let len = self.len(); let new_rev_map = self.merge_categorical_map(other)?; diff --git a/polars/polars-core/src/frame/hash_join/mod.rs b/polars/polars-core/src/frame/hash_join/mod.rs index cd724a179c85..ec31166685c4 100644 --- a/polars/polars-core/src/frame/hash_join/mod.rs +++ b/polars/polars-core/src/frame/hash_join/mod.rs @@ -93,11 +93,7 @@ use crate::series::IsSorted; #[cfg(feature = "dtype-categorical")] pub fn _check_categorical_src(l: &DataType, r: &DataType) -> PolarsResult<()> { if let (DataType::Categorical(Some(l)), DataType::Categorical(Some(r))) = (l, r) { - polars_ensure!( - l.same_src(r), - ComputeError: "joins/or comparisons on categoricals can only happen if they were \ - created under the same global string 
cache" - ); + polars_ensure!(l.same_src(r), string_cache_mismatch); } Ok(()) } diff --git a/polars/polars-error/src/lib.rs b/polars/polars-error/src/lib.rs index 46e688bf6fd2..572922b61e64 100644 --- a/polars/polars-error/src/lib.rs +++ b/polars/polars-error/src/lib.rs @@ -56,6 +56,8 @@ pub enum PolarsError { SchemaMismatch(ErrString), #[error("lengths don't match: {0}")] ShapeMismatch(ErrString), + #[error("string caches don't match: {0}")] + StringCacheMismatch(ErrString), #[error("field not found: {0}")] StructFieldNotFound(ErrString), } @@ -91,6 +93,7 @@ impl PolarsError { SchemaFieldNotFound(msg) => SchemaFieldNotFound(func(msg).into()), SchemaMismatch(msg) => SchemaMismatch(func(msg).into()), ShapeMismatch(msg) => ShapeMismatch(func(msg).into()), + StringCacheMismatch(msg) => StringCacheMismatch(func(msg).into()), StructFieldNotFound(msg) => StructFieldNotFound(func(msg).into()), } } @@ -158,6 +161,26 @@ macro_rules! polars_err { (unpack) => { polars_err!(SchemaMismatch: "cannot unpack series, data types don't match") }; + (string_cache_mismatch) => { + polars_err!(StringCacheMismatch: r#" +cannot compare categoricals coming from different sources, consider setting a global StringCache. + +Help: if you're using Python, this may look something like: + + with pl.StringCache(): + # Initialize Categoricals. + df1 = pl.DataFrame({'a': ['1', '2']}, schema={'a': pl.Categorical}) + df2 = pl.DataFrame({'a': ['1', '3']}, schema={'a': pl.Categorical}) + # Your operations go here. 
+ pl.concat([df1, df2]) + +Alternatively, if the performance cost is acceptable, you could just set: + + import polars as pl + pl.enable_string_cache(True) + +on startup."#.trim_start()) + }; (duplicate = $name:expr) => { polars_err!(Duplicate: "column with name '{}' has more than one occurrences", $name) }; diff --git a/py-polars/polars/exceptions.py b/py-polars/polars/exceptions.py index c8be22c11b75..7f83f92ebafd 100644 --- a/py-polars/polars/exceptions.py +++ b/py-polars/polars/exceptions.py @@ -10,6 +10,7 @@ SchemaError, SchemaFieldNotFoundError, ShapeError, + StringCacheMismatchError, StructFieldNotFoundError, ) except ImportError: @@ -43,6 +44,9 @@ class SchemaFieldNotFoundError(Exception): # type: ignore[no-redef] class ShapeError(Exception): # type: ignore[no-redef] """Exception raised when trying to combine data structures with incompatible shapes.""" # noqa: W505 + class StringCacheMismatchError(Exception): # type: ignore[no-redef] + """Exception raised when string caches come from different sources.""" + class StructFieldNotFoundError(Exception): # type: ignore[no-redef] """Exception raised when a specified schema field is not found.""" @@ -96,6 +100,7 @@ class ChronoFormatWarning(Warning): "SchemaError", "SchemaFieldNotFoundError", "ShapeError", + "StringCacheMismatchError", "StructFieldNotFoundError", "TooManyRowsReturnedError", ] diff --git a/py-polars/src/error.rs b/py-polars/src/error.rs index 6aa5c6e995c1..3d8c81d7aecf 100644 --- a/py-polars/src/error.rs +++ b/py-polars/src/error.rs @@ -45,6 +45,9 @@ impl std::convert::From for PyErr { } PolarsError::SchemaMismatch(err) => SchemaError::new_err(err.to_string()), PolarsError::ShapeMismatch(err) => ShapeError::new_err(err.to_string()), + PolarsError::StringCacheMismatch(err) => { + StringCacheMismatchError::new_err(err.to_string()) + } PolarsError::StructFieldNotFound(name) => { StructFieldNotFoundError::new_err(name.to_string()) } @@ -75,6 +78,7 @@ create_exception!(exceptions, NoDataError, 
PyException); create_exception!(exceptions, SchemaError, PyException); create_exception!(exceptions, SchemaFieldNotFoundError, PyException); create_exception!(exceptions, ShapeError, PyException); +create_exception!(exceptions, StringCacheMismatchError, PyException); create_exception!(exceptions, StructFieldNotFoundError, PyException); #[macro_export] diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 04aa5d64f653..8fa7914bcb8b 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -232,6 +232,11 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { .unwrap(); m.add("ShapeError", py.get_type::<ShapeError>()) .unwrap(); + m.add( + "StringCacheMismatchError", + py.get_type::<StringCacheMismatchError>(), + ) + .unwrap(); m.add( "StructFieldNotFoundError", py.get_type::<StructFieldNotFoundError>(), diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index bf9cf7d84654..0348026924bf 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -7,6 +7,7 @@ import polars as pl from polars import StringCache +from polars.exceptions import StringCacheMismatchError from polars.testing import assert_frame_equal @@ -301,8 +302,8 @@ def test_err_on_categorical_asof_join_by_arg() -> None: ] ) with pytest.raises( - pl.ComputeError, - match=r"joins/or comparisons on categoricals can only happen if they were created under the same global string cache", + StringCacheMismatchError, + match="cannot compare categoricals coming from different sources", ): df1.join_asof(df2, on=pl.col("time").set_sorted(), by="cat") diff --git a/py-polars/tests/unit/test_cfg.py b/py-polars/tests/unit/test_cfg.py index a48f7e2685cc..7bebef9f6b5c 100644 --- a/py-polars/tests/unit/test_cfg.py +++ b/py-polars/tests/unit/test_cfg.py @@ -7,6 +7,7 @@ import polars as pl from polars.config import _get_float_fmt +from polars.exceptions import StringCacheMismatchError from polars.testing import assert_frame_equal if
TYPE_CHECKING: @@ -481,7 +482,7 @@ def test_string_cache() -> None: df1a = df1.with_columns(pl.col("a").cast(pl.Categorical)) df2a = df2.with_columns(pl.col("a").cast(pl.Categorical)) - with pytest.raises(pl.ComputeError): + with pytest.raises(StringCacheMismatchError): _ = df1a.join(df2a, on="a", how="inner") # now turn on the cache