Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow null dtypes in UDFs if they match the schema #15699

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,36 @@ impl DataType {
_ => false,
}
}

/// Checks whether this dtype is compatible with the dtype expected by a schema.
///
/// A `Null` dtype in `self` (also when nested inside a list or struct field)
/// is accepted for any dtype in the schema. The reverse does not hold: a
/// `Null` in the schema only matches a `Null` in `self`.
///
/// # Returns
///
/// * `Ok(false)` - the dtypes match as is; no cast is required.
/// * `Ok(true)` - the dtypes are compatible, but a cast is required.
///
/// # Errors
///
/// Returns a `SchemaMismatch` error if the dtypes are incompatible.
pub fn matches_schema_type(&self, schema_type: &DataType) -> PolarsResult<bool> {
    match (self, schema_type) {
        // Lists are judged by their inner dtypes.
        (DataType::List(own_inner), DataType::List(expected_inner)) => {
            own_inner.matches_schema_type(expected_inner)
        },
        #[cfg(feature = "dtype-struct")]
        (DataType::Struct(own_fields), DataType::Struct(expected_fields)) => {
            // Fields are paired up by position. A single field that needs a
            // cast makes the whole struct need one; the first incompatible
            // pair aborts with its error.
            own_fields.iter().zip(expected_fields.iter()).try_fold(
                false,
                |must_cast, (own, expected)| -> PolarsResult<bool> {
                    Ok(must_cast | own.dtype.matches_schema_type(&expected.dtype)?)
                },
            )
        },
        // This arm must precede the `(Null, _)` arm below.
        (DataType::Null, DataType::Null) => Ok(false),
        #[cfg(feature = "dtype-decimal")]
        // Only the second component is compared (presumably the scale - TODO
        // confirm); a difference there means a cast is needed.
        (DataType::Decimal(_, own_s), DataType::Decimal(_, expected_s)) => {
            Ok(own_s != expected_s)
        },
        // A Null on our side can be cast to whatever the schema expects. The
        // opposite direction (schema is Null, we are not) is not handled here
        // and ends up in the mismatch arm at the bottom.
        (DataType::Null, _) => Ok(true),
        (own, expected) if own == expected => Ok(false),
        (own, expected) => {
            polars_bail!(SchemaMismatch: "type {:?} is incompatible with expected type {:?}", own, expected)
        },
    }
}
}

impl PartialEq<ArrowDataType> for DataType {
Expand Down Expand Up @@ -580,34 +610,6 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType>
})
}

/// Determines whether data of dtype `right` can be extended/appended onto
/// data of dtype `left`.
///
/// # Returns
///
/// * `Ok(false)` - can extend as is.
/// * `Ok(true)` - can extend, but a cast is required first.
///
/// # Errors
///
/// Returns a `SchemaMismatch` error if the dtypes cannot be combined.
pub(crate) fn can_extend_dtype(left: &DataType, right: &DataType) -> PolarsResult<bool> {
    match (left, right) {
        // Lists are judged by their inner dtypes.
        (DataType::List(lhs_inner), DataType::List(rhs_inner)) => {
            can_extend_dtype(lhs_inner, rhs_inner)
        },
        #[cfg(feature = "dtype-struct")]
        (DataType::Struct(lhs_fields), DataType::Struct(rhs_fields)) => lhs_fields
            .iter()
            .zip(rhs_fields.iter())
            // Fields are paired up by position and checked lazily, so the
            // first failing pair short-circuits with its error.
            .map(|(lhs, rhs)| can_extend_dtype(&lhs.dtype, &rhs.dtype))
            .try_fold(false, |must_cast, field| field.map(|cast| must_cast | cast)),
        (DataType::Null, DataType::Null) => Ok(false),
        #[cfg(feature = "dtype-decimal")]
        // Only the second component is compared (presumably the scale - TODO
        // confirm); a difference there means a cast is needed.
        (DataType::Decimal(_, lhs_s), DataType::Decimal(_, rhs_s)) => Ok(lhs_s != rhs_s),
        // A Null on the right can be cast to the left dtype. The reverse is
        // rejected below: the left dtype is kept as is, there is no promotion
        // to a supertype, and the left side is never cast to Null.
        (_, DataType::Null) => Ok(true),
        (lhs, rhs) => {
            polars_ensure!(lhs == rhs, SchemaMismatch: "cannot extend/append {:?} with {:?}", left, right);
            Ok(false)
        },
    }
}

#[cfg(feature = "dtype-categorical")]
pub fn create_enum_data_type(categories: Utf8ViewArray) -> DataType {
let rev_map = RevMapping::build_local(categories);
Expand Down
6 changes: 2 additions & 4 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ impl Series {
///
/// See [`ChunkedArray::append`] and [`ChunkedArray::extend`].
pub fn append(&mut self, other: &Series) -> PolarsResult<&mut Self> {
let must_cast = can_extend_dtype(self.dtype(), other.dtype())?;

let must_cast = other.dtype().matches_schema_type(self.dtype())?;
if must_cast {
let other = other.cast(self.dtype())?;
self._get_inner_mut().append(&other)?;
Expand All @@ -274,8 +273,7 @@ impl Series {
///
/// See [`ChunkedArray::extend`] and [`ChunkedArray::append`].
pub fn extend(&mut self, other: &Series) -> PolarsResult<&mut Self> {
let must_cast = can_extend_dtype(self.dtype(), other.dtype())?;

let must_cast = other.dtype().matches_schema_type(self.dtype())?;
if must_cast {
let other = other.cast(self.dtype())?;
self._get_inner_mut().extend(&other)?;
Expand Down
21 changes: 13 additions & 8 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,19 @@ impl SeriesUdf for PythonUdfExpression {
let func = unsafe { CALL_SERIES_UDF_PYTHON.unwrap() };

let output_type = self.output_type.clone().unwrap_or(DataType::Unknown);
let out = func(s[0].clone(), &self.python_function)?;

polars_ensure!(
matches!(output_type, DataType::Unknown) || out.dtype() == &output_type,
SchemaMismatch:
"expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
output_type, out.dtype(),
);
let mut out = func(s[0].clone(), &self.python_function)?;
if output_type != DataType::Unknown {
let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
polars_err!(
SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
output_type, out.dtype(),
)
})?;
if must_cast {
out = out.cast(&output_type)?;
}
}

Ok(Some(out))
}

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_append_to_an_enum() -> None:
def test_append_to_an_enum_with_new_category() -> None:
with pytest.raises(
pl.SchemaError,
match=("cannot extend/append Enum"),
match=("type Enum.*is incompatible with expected type Enum.*"),
):
pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])).append(
pl.Series(["d", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"]))
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def test_map_elements_infer_list() -> None:
assert df.select([pl.all().map_elements(lambda x: [x])]).dtypes == [pl.List] * 3


def test_map_elements_upcast_null_dtype_empty_list() -> None:
    """A UDF returning only empty lists yields the declared `List(Int64)` dtype."""
    declared_dtype = pl.List(pl.Int64)
    frame = pl.DataFrame({"a": [1, 2]})

    # Every element maps to an empty list, so the declared return dtype is the
    # only source of the inner type.
    result = frame.select(
        pl.col("a").map_elements(lambda _: [], return_dtype=declared_dtype)
    )

    expected = pl.DataFrame({"a": [[], []]}, schema={"a": declared_dtype})
    assert_frame_equal(result, expected)


def test_map_elements_arithmetic_consistency() -> None:
df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})
with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"):
Expand Down
Loading