diff --git a/polars/polars-core/src/schema.rs b/polars/polars-core/src/schema.rs index e4580b173a9c..5a3ad80f385c 100644 --- a/polars/polars-core/src/schema.rs +++ b/polars/polars-core/src/schema.rs @@ -276,6 +276,16 @@ impl Schema { self.inner.shift_remove(name) } + /// Remove a field by name, preserving order, and, if the field existed, return its dtype + /// + /// If the field does not exist, the schema is not modified and `None` is returned. + /// + /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a + /// faster, but not order-preserving, method, use [`remove`][Self::remove]. + pub fn shift_remove_index(&mut self, index: usize) -> Option<(SmartString, DataType)> { + self.inner.shift_remove_index(index) + } + /// Whether the schema contains a field named `name` pub fn contains(&self, name: &str) -> bool { self.get(name).is_some() diff --git a/polars/polars-lazy/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/polars/polars-lazy/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index b06bb93040b9..19133bb93ccd 100644 --- a/polars/polars-lazy/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/polars/polars-lazy/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -1,10 +1,12 @@ use std::any::Any; +use polars_arrow::export::arrow::array::BinaryArray; use polars_core::prelude::sort::_broadcast_descending; use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_plan::prelude::*; +use polars_row::decode::decode_rows_from_binary; use polars_row::SortField; use super::*; @@ -25,11 +27,79 @@ fn get_sort_fields(sort_idx: &[usize], sort_args: &SortArguments) -> Vec bool { + !sort_idx + .iter() + .any(|i| matches!(schema.get_at_index(*i).unwrap().1, DataType::Categorical(_))) +} +#[cfg(not(feature = "dtype-categorical"))] +fn sort_column_can_be_decoded(_schema: &Schema, _sort_idx: &[usize]) -> bool { + true +} + +fn sort_by_idx(values: &[V], idx: &[usize]) -> Vec { + assert_eq!(values.len(), idx.len()); + + let mut tmp = values + .iter() + .cloned() + .zip(idx.iter().copied()) + .collect::>(); + tmp.sort_unstable_by_key(|k| k.1); + tmp.into_iter().map(|k| k.0).collect() +} + +#[allow(clippy::too_many_arguments)] +fn finalize_dataframe( + df: &mut DataFrame, + sort_idx: &[usize], + sort_args: &SortArguments, + can_decode: bool, + sort_dtypes: Option<&[ArrowDataType]>, + rows: &mut Vec<&'static [u8]>, + sort_fields: &[SortField], + schema: &Schema, +) { unsafe { let cols = df.get_columns_mut(); // pop the encoded sort column - let _ = cols.pop(); + let encoded = cols.pop().unwrap(); + + // we decode the row-encoded binary column + // this will be decoded into multiple columns + // this are the columns we sorted by + // those need to be inserted at the `sort_idx` position + // in the `DataFrame`. + if can_decode { + let sort_dtypes = sort_dtypes.expect("should be set"); + let sort_dtypes = sort_by_idx(sort_dtypes, sort_idx); + + let encoded = encoded.binary().unwrap(); + assert_eq!(encoded.chunks().len(), 1); + let arr = encoded.downcast_iter().next().unwrap(); + + // safety + // temporary extend lifetime + // this is safe as the lifetime in rows stays bound to this scope + let arrays = { + let arr = + std::mem::transmute::<&'_ BinaryArray, &'static BinaryArray>(arr); + decode_rows_from_binary(arr, sort_fields, &sort_dtypes, rows) + }; + rows.clear(); + + let arrays = sort_by_idx(&arrays, sort_idx); + let mut sort_idx = sort_idx.to_vec(); + sort_idx.sort_unstable(); + + for (sort_idx, arr) in sort_idx.into_iter().zip(arrays) { + let (name, logical_dtype) = schema.get_at_index(sort_idx).unwrap(); + assert_eq!(logical_dtype.to_physical(), DataType::from(arr.data_type())); + let col = Series::from_chunks_and_dtype_unchecked(name, vec![arr], logical_dtype); + cols.insert(sort_idx, col); + } + } let first_sort_col = &mut cols[sort_idx[0]]; let flag = if sort_args.descending[0] { @@ -49,18 +119,45 @@ fn finalize_dataframe(df: &mut DataFrame, sort_idx: &[usize], sort_args: &SortAr /// Once the sorting is finished it adapts the result so that /// the encoded column is removed pub struct SortSinkMultiple { + output_schema: SchemaRef, sort_idx: Arc<[usize]>, sort_sink: Box, sort_args: SortArguments, // Needed for encoding sort_fields: Arc<[SortField]>, + sort_dtypes: Option>, // amortize allocs sort_column: Vec, + // if we can decode the sort columns, we will remove those + // columns and decode the binary row-format to restore the + // original columns. This ensures we don't need to keep + // redundant data around in memory or on disk + can_decode: bool, } impl SortSinkMultiple { - pub(crate) fn new(sort_args: SortArguments, schema: &Schema, sort_idx: Vec) -> Self { - let mut schema = schema.clone(); + pub(crate) fn new( + sort_args: SortArguments, + output_schema: SchemaRef, + sort_idx: Vec, + ) -> Self { + let can_decode = sort_column_can_be_decoded(&output_schema, &sort_idx); + let mut schema = (*output_schema).clone(); + + let mut sort_dtypes = None; + if can_decode { + let mut dtypes = Vec::with_capacity(sort_idx.len()); + + // we remove columns by index, but then the indices aren't correct anymore + // so we do it in the proper order and keep track of the indices removed + let mut sorted_sort_idx = sort_idx.to_vec(); + sorted_sort_idx.sort_unstable(); + // remove the sort indices as we will encode them into the sort binary + for (i, sort_i) in sorted_sort_idx.iter().enumerate() { + dtypes.push(schema.shift_remove_index(*sort_i - i).unwrap().1); + } + sort_dtypes = Some(dtypes.into()); + } schema.with_column(POLARS_SORT_COLUMN.into(), DataType::Binary); let sort_fields = get_sort_fields(&sort_idx, &sort_args); @@ -82,13 +179,16 @@ impl SortSinkMultiple { sort_args, sort_idx: Arc::from(sort_idx), sort_fields: Arc::from(sort_fields), + sort_dtypes, sort_column: vec![], + can_decode, + output_schema, } } fn encode(&mut self, chunk: &mut DataChunk) -> PolarsResult<()> { - let df = &chunk.data; - let cols = df.get_columns(); + let df = &mut chunk.data; + let cols = unsafe { df.get_columns_mut() }; self.sort_column.clear(); @@ -97,6 +197,23 @@ impl SortSinkMultiple { let arr = _get_rows_encoded_compat_array(s)?; self.sort_column.push(arr); } + + if self.can_decode { + // we remove columns by index, but then the aren't correct anymore + // so we do it in the proper order and keep track of the indices removed + let mut sorted_sort_idx = self.sort_idx.to_vec(); + sorted_sort_idx.sort_unstable(); + + sorted_sort_idx + .into_iter() + .enumerate() + .for_each(|(i, sort_idx)| { + // shifts all columns right from removed one to the left so + // therefore we subtract `i` as the shifted count + let _ = cols.remove(sort_idx - i); + }) + } + let rows_encoded = polars_row::convert_columns(&self.sort_column, &self.sort_fields); let column = unsafe { Series::from_chunks_and_dtype_unchecked( @@ -135,22 +252,44 @@ impl Sink for SortSinkMultiple { sort_fields: self.sort_fields.clone(), sort_args: self.sort_args.clone(), sort_column: vec![], + can_decode: self.can_decode, + sort_dtypes: self.sort_dtypes.clone(), + output_schema: self.output_schema.clone(), }) } fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { let out = self.sort_sink.finalize(context)?; + let sort_dtypes = self + .sort_dtypes + .take() + .map(|arr| arr.iter().map(|dt| dt.to_arrow()).collect::>()); + // we must adapt the finalized sink result so that the sort encoded column is dropped match out { FinalizedSink::Finished(mut df) => { - finalize_dataframe(&mut df, self.sort_idx.as_ref(), &self.sort_args); + finalize_dataframe( + &mut df, + self.sort_idx.as_ref(), + &self.sort_args, + self.can_decode, + sort_dtypes.as_deref(), + &mut vec![], + self.sort_fields.as_ref(), + &self.output_schema, + ); Ok(FinalizedSink::Finished(df)) } FinalizedSink::Source(source) => Ok(FinalizedSink::Source(Box::new(DropEncoded { source, sort_idx: self.sort_idx.clone(), sort_args: std::mem::take(&mut self.sort_args), + can_decode: self.can_decode, + sort_dtypes, + rows: vec![], + sort_fields: self.sort_fields.clone(), + output_schema: self.output_schema.clone(), }))), // SortSink should not produce this branch FinalizedSink::Operator(_) => unreachable!(), @@ -170,6 +309,11 @@ struct DropEncoded { source: Box, sort_idx: Arc<[usize]>, sort_args: SortArguments, + can_decode: bool, + sort_dtypes: Option>, + rows: Vec<&'static [u8]>, + sort_fields: Arc<[SortField]>, + output_schema: SchemaRef, } impl Source for DropEncoded { @@ -177,7 +321,16 @@ impl Source for DropEncoded { let mut result = self.source.get_batches(context); if let Ok(SourceResult::GotMoreData(data)) = &mut result { for chunk in data { - finalize_dataframe(&mut chunk.data, self.sort_idx.as_ref(), &self.sort_args) + finalize_dataframe( + &mut chunk.data, + self.sort_idx.as_ref(), + &self.sort_args, + self.can_decode, + self.sort_dtypes.as_deref(), + &mut self.rows, + self.sort_fields.as_ref(), + &self.output_schema, + ) } }; result diff --git a/polars/polars-lazy/polars-pipe/src/pipeline/convert.rs b/polars/polars-lazy/polars-pipe/src/pipeline/convert.rs index e0b1f45b2098..3927b216d5e4 100644 --- a/polars/polars-lazy/polars-pipe/src/pipeline/convert.rs +++ b/polars/polars-lazy/polars-pipe/src/pipeline/convert.rs @@ -220,7 +220,7 @@ where }) .collect::>>()?; - let sort_sink = SortSinkMultiple::new(args.clone(), &input_schema, sort_idx); + let sort_sink = SortSinkMultiple::new(args.clone(), input_schema, sort_idx); Box::new(sort_sink) as Box } } diff --git a/polars/polars-row/src/decode.rs b/polars/polars-row/src/decode.rs index 52dd6f1b1969..c968e6ae4b66 100644 --- a/polars/polars-row/src/decode.rs +++ b/polars/polars-row/src/decode.rs @@ -4,6 +4,22 @@ use super::*; use crate::fixed::{decode_bool, decode_primitive}; use crate::variable::decode_binary; +/// Decode `rows` into a arrow format +/// # Safety +/// This will not do any bound checks. Caller must ensure the `rows` are valid +/// encodings. +pub unsafe fn decode_rows_from_binary<'a>( + arr: &'a BinaryArray, + fields: &[SortField], + data_types: &[DataType], + rows: &mut Vec<&'a [u8]>, +) -> Vec { + assert_eq!(arr.null_count(), 0); + rows.clear(); + rows.extend(arr.values_iter()); + decode_rows(rows, fields, data_types) +} + /// Decode `rows` into a arrow format /// # Safety /// This will not do any bound checks. Caller must ensure the `rows` are valid diff --git a/polars/polars-row/src/fixed.rs b/polars/polars-row/src/fixed.rs index c72c40966fca..383b126f29ae 100644 --- a/polars/polars-row/src/fixed.rs +++ b/polars/polars-row/src/fixed.rs @@ -216,11 +216,11 @@ where let values = rows .iter() .map(|row| { - has_nulls |= *row.get_unchecked(0) == null_sentinel; + has_nulls |= *row.get_unchecked_release(0) == null_sentinel; // skip null sentinel let start = 1; let end = start + T::ENCODED_LEN - 1; - let slice = row.get_unchecked(start..end); + let slice = row.get_unchecked_release(start..end); let bytes = T::Encoded::from_slice(slice); T::decode(bytes) }) @@ -247,11 +247,11 @@ pub(super) unsafe fn decode_bool(rows: &mut [&[u8]], field: &SortField) -> Boole let values = rows .iter() .map(|row| { - has_nulls |= *row.get_unchecked(0) == null_sentinel; + has_nulls |= *row.get_unchecked_release(0) == null_sentinel; // skip null sentinel let start = 1; let end = start + bool::ENCODED_LEN - 1; - let slice = row.get_unchecked(start..end); + let slice = row.get_unchecked_release(start..end); let bytes = ::Encoded::from_slice(slice); bool::decode(bytes) }) @@ -271,12 +271,12 @@ pub(super) unsafe fn decode_bool(rows: &mut [&[u8]], field: &SortField) -> Boole } unsafe fn increment_row_counter(rows: &mut [&[u8]], fixed_size: usize) { for row in rows { - *row = row.get_unchecked(fixed_size..); + *row = row.get_unchecked_release(fixed_size..); } } pub(super) unsafe fn decode_nulls(rows: &[&[u8]], null_sentinel: u8) -> Bitmap { rows.iter() - .map(|row| *row.get_unchecked(0) != null_sentinel) + .map(|row| *row.get_unchecked_release(0) != null_sentinel) .collect() }