diff --git a/polars/polars-core/src/chunked_array/logical/struct_/mod.rs b/polars/polars-core/src/chunked_array/logical/struct_/mod.rs index 568c5d005f52..ff85426def6e 100644 --- a/polars/polars-core/src/chunked_array/logical/struct_/mod.rs +++ b/polars/polars-core/src/chunked_array/logical/struct_/mod.rs @@ -182,23 +182,47 @@ impl StructChunked { let mut null_count = 0; let chunks_lens = self.fields()[0].chunks().len(); + // fast path + // we early return if a column doesn't have nulls + for i in 0..chunks_lens { + for s in self.fields() { + let arr = &s.chunks()[i]; + let has_nulls = arr.null_count() > 0 || matches!(s.dtype(), DataType::Null); + if !has_nulls { + self.null_count = 0; + return; + } + } + } + + // slow path + // we bitand every null validity bitmask to determine + // in which rows all values are null for i in 0..chunks_lens { - // If all fields are null we count it as null - // so we bitand every chunk let mut validity_agg = None; + let mut all_null_array = true; for s in self.fields() { let arr = &s.chunks()[i]; - match (&validity_agg, arr.validity()) { - (Some(agg), Some(validity)) => validity_agg = Some(validity.bitand(agg)), - (None, Some(validity)) => validity_agg = Some(validity.clone()), - _ => {} + if !matches!(s.dtype(), DataType::Null) { + all_null_array = false; + match (&validity_agg, arr.validity()) { + (Some(agg), Some(validity)) => validity_agg = Some(validity.bitand(agg)), + (None, Some(validity)) => validity_agg = Some(validity.clone()), + _ => {} + } } } + // we add the null count if let Some(validity) = &validity_agg { null_count += validity.unset_bits() } + // all arrays are null arrays + // we add the length of the chunk to the null_count + else if all_null_array { + null_count += self.fields()[0].chunks()[i].len() + } } self.null_count = null_count } diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index a8cb14515510..442fcacfec88 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -904,3 +904,9 @@ def test_struct_name_passed_in_agg_apply() -> None: ] ], } + + +def test_struct_null_count_strict_cast() -> None: + s = pl.Series([{"a": None}]).cast(pl.Struct({"a": pl.Categorical})) + assert s.dtype == pl.Struct([pl.Field("a", pl.Categorical)]) + assert s.to_list() == [{"a": None}]