Skip to content

Commit

Permalink
fix(rust, python): fix struct null_count if fields are null arrays (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and c-peters committed Jul 14, 2023
1 parent 5b76f98 commit bcd4082
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
36 changes: 30 additions & 6 deletions polars/polars-core/src/chunked_array/logical/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,23 +182,47 @@ impl StructChunked {
let mut null_count = 0;
let chunks_lens = self.fields()[0].chunks().len();

// fast path
// we early return if a column doesn't have nulls
for i in 0..chunks_lens {
for s in self.fields() {
let arr = &s.chunks()[i];
let has_nulls = arr.null_count() > 0 || matches!(s.dtype(), DataType::Null);
if !has_nulls {
self.null_count = 0;
return;
}
}
}

// slow path
// we bitand every null validity bitmask to determine
// in which rows all values are null
for i in 0..chunks_lens {
// If all fields are null we count it as null
// so we bitand every chunk
let mut validity_agg = None;

let mut all_null_array = true;
for s in self.fields() {
let arr = &s.chunks()[i];

match (&validity_agg, arr.validity()) {
(Some(agg), Some(validity)) => validity_agg = Some(validity.bitand(agg)),
(None, Some(validity)) => validity_agg = Some(validity.clone()),
_ => {}
if !matches!(s.dtype(), DataType::Null) {
all_null_array = false;
match (&validity_agg, arr.validity()) {
(Some(agg), Some(validity)) => validity_agg = Some(validity.bitand(agg)),
(None, Some(validity)) => validity_agg = Some(validity.clone()),
_ => {}
}
}
}
// we add the null count
if let Some(validity) = &validity_agg {
null_count += validity.unset_bits()
}
// all arrays are null arrays
// we add the length of the chunk to the null_count
else if all_null_array {
null_count += self.fields()[0].chunks()[i].len()
}
}
self.null_count = null_count
}
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,3 +904,9 @@ def test_struct_name_passed_in_agg_apply() -> None:
]
],
}


def test_struct_null_count_strict_cast() -> None:
s = pl.Series([{"a": None}]).cast(pl.Struct({"a": pl.Categorical}))
assert s.dtype == pl.Struct([pl.Field("a", pl.Categorical)])
assert s.to_list() == [{"a": None}]

0 comments on commit bcd4082

Please sign in to comment.