Skip to content

Commit

Permalink
fix: Make output dtype known for list.to_struct when fields are p…
Browse files Browse the repository at this point in the history
…assed (#19439)
  • Loading branch information
nameexhaustion authored Oct 27, 2024
1 parent e26a229 commit 98fcc3f
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 115 deletions.
3 changes: 0 additions & 3 deletions crates/polars-io/src/cloud/credential_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,6 @@ impl serde::Serialize for PlCredentialProvider {
{
use serde::ser::Error;

// TODO:
// * Add magic bytes here to indicate a python function
// * Check the Python version on deserialize
#[cfg(feature = "python")]
if let PlCredentialProvider::Python(v) = self {
return v.serialize(serializer);
Expand Down
238 changes: 188 additions & 50 deletions crates/polars-ops/src/chunked_array/list/to_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,82 +5,220 @@ use polars_utils::pl_str::PlSmallStr;

use super::*;

#[derive(Copy, Clone, Debug)]
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ListToStructArgs {
FixedWidth(Arc<[PlSmallStr]>),
InferWidth {
infer_field_strategy: ListToStructWidthStrategy,
get_index_name: Option<NameGenerator>,
/// If this is 0, it means unbounded.
max_fields: usize,
},
}

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ListToStructWidthStrategy {
FirstNonNull,
MaxWidth,
}

fn det_n_fields(ca: &ListChunked, n_fields: ListToStructWidthStrategy) -> usize {
match n_fields {
ListToStructWidthStrategy::MaxWidth => {
let mut max = 0;

ca.downcast_iter().for_each(|arr| {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
let len = (*o - last) as usize;
max = std::cmp::max(max, len);
last = *o;
impl ListToStructArgs {
pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult<DataType> {
let DataType::List(inner_dtype) = input_dtype else {
polars_bail!(
InvalidOperation:
"attempted list to_struct on non-list dtype: {}",
input_dtype
);
};
let inner_dtype = inner_dtype.as_ref();

match self {
Self::FixedWidth(names) => Ok(DataType::Struct(
names
.iter()
.map(|x| Field::new(x.clone(), inner_dtype.clone()))
.collect::<Vec<_>>(),
)),
Self::InferWidth {
get_index_name,
max_fields,
..
} if *max_fields > 0 => {
let get_index_name_func = get_index_name.as_ref().map_or(
&_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr,
|x| x.0.as_ref(),
);
Ok(DataType::Struct(
(0..*max_fields)
.map(|i| Field::new(get_index_name_func(i), inner_dtype.clone()))
.collect::<Vec<_>>(),
))
},
Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)),
}
}

fn det_n_fields(&self, ca: &ListChunked) -> usize {
match self {
Self::FixedWidth(v) => v.len(),
Self::InferWidth {
infer_field_strategy,
max_fields,
..
} => {
let inferred = match infer_field_strategy {
ListToStructWidthStrategy::MaxWidth => {
let mut max = 0;

ca.downcast_iter().for_each(|arr| {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
let len = (*o - last) as usize;
max = std::cmp::max(max, len);
last = *o;
}
});
max
},
ListToStructWidthStrategy::FirstNonNull => {
let mut len = 0;
for arr in ca.downcast_iter() {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
len = (*o - last) as usize;
if len > 0 {
break;
}
last = *o;
}
if len > 0 {
break;
}
}
len
},
};

if *max_fields > 0 {
inferred.min(*max_fields)
} else {
inferred
}
});
max
},
ListToStructWidthStrategy::FirstNonNull => {
let mut len = 0;
for arr in ca.downcast_iter() {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
len = (*o - last) as usize;
if len > 0 {
break;
}
last = *o;
},
}
}

fn set_output_names(&self, columns: &mut [Series]) {
match self {
Self::FixedWidth(v) => {
assert_eq!(columns.len(), v.len());

for (c, name) in columns.iter_mut().zip(v.iter()) {
c.rename(name.clone());
}
if len > 0 {
break;
},
Self::InferWidth { get_index_name, .. } => {
let get_index_name_func = get_index_name.as_ref().map_or(
&_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr,
|x| x.0.as_ref(),
);

for (i, c) in columns.iter_mut().enumerate() {
c.rename(get_index_name_func(i));
}
}
len
},
},
}
}
}

#[derive(Clone)]
pub struct NameGenerator(pub Arc<dyn Fn(usize) -> PlSmallStr + Send + Sync>);

impl NameGenerator {
pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self {
Self(Arc::new(func))
}
}

impl std::fmt::Debug for NameGenerator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"list::to_struct::NameGenerator function at 0x{:016x}",
self.0.as_ref() as *const _ as *const () as usize
)
}
}

impl Eq for NameGenerator {}

impl PartialEq for NameGenerator {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

pub type NameGenerator = Arc<dyn Fn(usize) -> PlSmallStr + Send + Sync>;
impl std::hash::Hash for NameGenerator {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
}
}

pub fn _default_struct_name_gen(idx: usize) -> PlSmallStr {
format_pl_smallstr!("field_{idx}")
}

pub trait ToStruct: AsList {
fn to_struct(
&self,
n_fields: ListToStructWidthStrategy,
name_generator: Option<NameGenerator>,
) -> PolarsResult<StructChunked> {
fn to_struct(&self, args: &ListToStructArgs) -> PolarsResult<StructChunked> {
let ca = self.as_list();
let n_fields = det_n_fields(ca, n_fields);
let n_fields = args.det_n_fields(ca);

let name_generator = name_generator
.as_deref()
.unwrap_or(&_default_struct_name_gen);

let fields = POOL.install(|| {
let mut fields = POOL.install(|| {
(0..n_fields)
.into_par_iter()
.map(|i| {
ca.lst_get(i as i64, true).map(|mut s| {
s.rename(name_generator(i));
s
})
})
.map(|i| ca.lst_get(i as i64, true))
.collect::<PolarsResult<Vec<_>>>()
})?;

args.set_output_names(&mut fields);

StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())
}
}

impl ToStruct for ListChunked {}

#[cfg(feature = "serde")]
mod _serde_impl {
use super::*;

impl serde::Serialize for NameGenerator {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::Error;
Err(S::Error::custom(
"cannot serialize name generator function for to_struct, \
consider passing a list of field names instead.",
))
}
}

impl<'de> serde::Deserialize<'de> for NameGenerator {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
Err(D::Error::custom(
"invalid data: attempted to deserialize list::to_struct::NameGenerator",
))
}
}
}
15 changes: 14 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use polars_ops::chunked_array::list::*;
use super::*;
use crate::{map, map_as_slice, wrap};

#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ListFunction {
Concat,
Expand Down Expand Up @@ -56,6 +56,8 @@ pub enum ListFunction {
Join(bool),
#[cfg(feature = "dtype-array")]
ToArray(usize),
#[cfg(feature = "list_to_struct")]
ToStruct(ListToStructArgs),
}

impl ListFunction {
Expand Down Expand Up @@ -103,6 +105,8 @@ impl ListFunction {
#[cfg(feature = "dtype-array")]
ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
NUnique => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)),
}
}
}
Expand Down Expand Up @@ -174,6 +178,8 @@ impl Display for ListFunction {
Join(_) => "join",
#[cfg(feature = "dtype-array")]
ToArray(_) => "to_array",
#[cfg(feature = "list_to_struct")]
ToStruct(_) => "to_struct",
};
write!(f, "list.{name}")
}
Expand Down Expand Up @@ -235,6 +241,8 @@ impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
#[cfg(feature = "dtype-array")]
ToArray(width) => map!(to_array, width),
NUnique => map!(n_unique),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => map!(to_struct, &args),
}
}
}
Expand Down Expand Up @@ -650,6 +658,11 @@ pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult<Column> {
s.cast(&array_dtype)
}

#[cfg(feature = "list_to_struct")]
pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult<Column> {
Ok(s.list()?.to_struct(args)?.into_series().into())
}

pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
Ok(s.list()?.lst_n_unique()?.into_column())
}
48 changes: 2 additions & 46 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#[cfg(feature = "list_to_struct")]
use std::sync::RwLock;

use polars_core::prelude::*;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
Expand Down Expand Up @@ -281,50 +278,9 @@ impl ListNameSpace {
/// an `upper_bound` of struct fields that will be set.
/// If this is incorrectly downstream operation may fail. For instance an `all().sum()` expression
/// will look in the current schema to determine which columns to select.
pub fn to_struct(
self,
n_fields: ListToStructWidthStrategy,
name_generator: Option<NameGenerator>,
upper_bound: usize,
) -> Expr {
// heap allocate the output type and fill it later
let out_dtype = Arc::new(RwLock::new(None::<DataType>));

pub fn to_struct(self, args: ListToStructArgs) -> Expr {
self.0
.map(
move |s| {
s.list()?
.to_struct(n_fields, name_generator.clone())
.map(|s| Some(s.into_column()))
},
// we don't yet know the fields
GetOutput::map_dtype(move |dt: &DataType| {
polars_ensure!(matches!(dt, DataType::List(_)), SchemaMismatch: "expected 'List' as input to 'list.to_struct' got {}", dt);
let out = out_dtype.read().unwrap();
match out.as_ref() {
// dtype already set
Some(dt) => Ok(dt.clone()),
// dtype still unknown, set it
None => {
drop(out);
let mut lock = out_dtype.write().unwrap();

let inner = dt.inner_dtype().unwrap();
let fields = (0..upper_bound)
.map(|i| {
let name = _default_struct_name_gen(i);
Field::new(name, inner.clone())
})
.collect();
let dt = DataType::Struct(fields);

*lock = Some(dt.clone());
Ok(dt)
},
}
}),
)
.with_fmt("list.to_struct")
.map_private(FunctionExpr::ListExpr(ListFunction::ToStruct(args)))
}

#[cfg(feature = "is_in")]
Expand Down
Loading

0 comments on commit 98fcc3f

Please sign in to comment.