Skip to content

Commit

Permalink
perf(rust, python): remove sort columns on multiple-key OOC sort (#9545)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 26, 2023
1 parent a73d4bb commit ce465de
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 15 deletions.
10 changes: 10 additions & 0 deletions polars/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,16 @@ impl Schema {
self.inner.shift_remove(name)
}

/// Remove the field at `index`, preserving the order of the remaining fields, and, if a field
/// existed at that position, return its name and dtype
///
/// If `index` is out of bounds, the schema is not modified and `None` is returned.
///
/// This method does a `shift_remove_index`, which preserves the order of the fields in the schema
/// but **is O(n)**. For a faster, but not order-preserving, removal by name, use
/// [`remove`][Self::remove].
pub fn shift_remove_index(&mut self, index: usize) -> Option<(SmartString, DataType)> {
    self.inner.shift_remove_index(index)
}

/// Whether the schema contains a field named `name`
pub fn contains(&self, name: &str) -> bool {
self.get(name).is_some()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::any::Any;

use polars_arrow::export::arrow::array::BinaryArray;
use polars_core::prelude::sort::_broadcast_descending;
use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_plan::prelude::*;
use polars_row::decode::decode_rows_from_binary;
use polars_row::SortField;

use super::*;
Expand All @@ -25,11 +27,79 @@ fn get_sort_fields(sort_idx: &[usize], sort_args: &SortArguments) -> Vec<SortFie
.collect()
}

fn finalize_dataframe(df: &mut DataFrame, sort_idx: &[usize], sort_args: &SortArguments) {
#[cfg(feature = "dtype-categorical")]
fn sort_column_can_be_decoded(schema: &Schema, sort_idx: &[usize]) -> bool {
    // The row-encoded sort data can only be decoded back into columns when
    // none of the sort keys is a Categorical column.
    sort_idx
        .iter()
        .all(|&i| !matches!(schema.get_at_index(i).unwrap().1, DataType::Categorical(_)))
}
// Fallback when the "dtype-categorical" feature is disabled: no categorical
// dtype can occur among the sort keys, so the sort columns can always be
// decoded back from the row encoding.
#[cfg(not(feature = "dtype-categorical"))]
fn sort_column_can_be_decoded(_schema: &Schema, _sort_idx: &[usize]) -> bool {
    true
}

/// Reorder `values` so they follow the ascending order of their paired entries
/// in `idx`: output position `k` holds the value whose index is the `k`-th
/// smallest. Panics if the two slices differ in length.
fn sort_by_idx<V: Clone>(values: &[V], idx: &[usize]) -> Vec<V> {
    assert_eq!(values.len(), idx.len());

    // Pair each index with its value, order the pairs by index, then strip
    // the indices off again.
    let mut pairs: Vec<(usize, V)> = idx.iter().copied().zip(values.iter().cloned()).collect();
    pairs.sort_unstable_by_key(|&(i, _)| i);
    pairs.into_iter().map(|(_, v)| v).collect()
}

#[allow(clippy::too_many_arguments)]
fn finalize_dataframe(
df: &mut DataFrame,
sort_idx: &[usize],
sort_args: &SortArguments,
can_decode: bool,
sort_dtypes: Option<&[ArrowDataType]>,
rows: &mut Vec<&'static [u8]>,
sort_fields: &[SortField],
schema: &Schema,
) {
unsafe {
let cols = df.get_columns_mut();
// pop the encoded sort column
let _ = cols.pop();
let encoded = cols.pop().unwrap();

// we decode the row-encoded binary column
// this will be decoded into multiple columns
// these are the columns we sorted by
// those need to be inserted at the `sort_idx` position
// in the `DataFrame`.
if can_decode {
let sort_dtypes = sort_dtypes.expect("should be set");
let sort_dtypes = sort_by_idx(sort_dtypes, sort_idx);

let encoded = encoded.binary().unwrap();
assert_eq!(encoded.chunks().len(), 1);
let arr = encoded.downcast_iter().next().unwrap();

// safety
// temporary extend lifetime
// this is safe as the lifetime in rows stays bound to this scope
let arrays = {
let arr =
std::mem::transmute::<&'_ BinaryArray<i64>, &'static BinaryArray<i64>>(arr);
decode_rows_from_binary(arr, sort_fields, &sort_dtypes, rows)
};
rows.clear();

let arrays = sort_by_idx(&arrays, sort_idx);
let mut sort_idx = sort_idx.to_vec();
sort_idx.sort_unstable();

for (sort_idx, arr) in sort_idx.into_iter().zip(arrays) {
let (name, logical_dtype) = schema.get_at_index(sort_idx).unwrap();
assert_eq!(logical_dtype.to_physical(), DataType::from(arr.data_type()));
let col = Series::from_chunks_and_dtype_unchecked(name, vec![arr], logical_dtype);
cols.insert(sort_idx, col);
}
}

let first_sort_col = &mut cols[sort_idx[0]];
let flag = if sort_args.descending[0] {
Expand All @@ -49,18 +119,45 @@ fn finalize_dataframe(df: &mut DataFrame, sort_idx: &[usize], sort_args: &SortAr
/// Once the sorting is finished it adapts the result so that
/// the encoded column is removed
pub struct SortSinkMultiple {
    // schema of the final output; does not contain the temporary encoded
    // sort column that is appended during the sort
    output_schema: SchemaRef,
    // positions of the sort key columns within `output_schema`
    sort_idx: Arc<[usize]>,
    // inner sink that performs the actual (single-key) sort on the
    // row-encoded binary column
    sort_sink: Box<dyn Sink>,
    sort_args: SortArguments,
    // Needed for encoding
    sort_fields: Arc<[SortField]>,
    // dtypes of the removed sort key columns; only `Some` when `can_decode`
    sort_dtypes: Option<Arc<[DataType]>>,
    // amortize allocs
    sort_column: Vec<ArrayRef>,
    // if we can decode the sort columns, we will remove those
    // columns and decode the binary row-format to restore the
    // original columns. This ensures we don't need to keep
    // redundant data around in memory or on disk
    can_decode: bool,
}

impl SortSinkMultiple {
pub(crate) fn new(sort_args: SortArguments, schema: &Schema, sort_idx: Vec<usize>) -> Self {
let mut schema = schema.clone();
pub(crate) fn new(
sort_args: SortArguments,
output_schema: SchemaRef,
sort_idx: Vec<usize>,
) -> Self {
let can_decode = sort_column_can_be_decoded(&output_schema, &sort_idx);
let mut schema = (*output_schema).clone();

let mut sort_dtypes = None;
if can_decode {
let mut dtypes = Vec::with_capacity(sort_idx.len());

// we remove columns by index, but then the indices aren't correct anymore
// so we do it in the proper order and keep track of the indices removed
let mut sorted_sort_idx = sort_idx.to_vec();
sorted_sort_idx.sort_unstable();
// remove the sort indices as we will encode them into the sort binary
for (i, sort_i) in sorted_sort_idx.iter().enumerate() {
dtypes.push(schema.shift_remove_index(*sort_i - i).unwrap().1);
}
sort_dtypes = Some(dtypes.into());
}
schema.with_column(POLARS_SORT_COLUMN.into(), DataType::Binary);
let sort_fields = get_sort_fields(&sort_idx, &sort_args);

Expand All @@ -82,13 +179,16 @@ impl SortSinkMultiple {
sort_args,
sort_idx: Arc::from(sort_idx),
sort_fields: Arc::from(sort_fields),
sort_dtypes,
sort_column: vec![],
can_decode,
output_schema,
}
}

fn encode(&mut self, chunk: &mut DataChunk) -> PolarsResult<()> {
let df = &chunk.data;
let cols = df.get_columns();
let df = &mut chunk.data;
let cols = unsafe { df.get_columns_mut() };

self.sort_column.clear();

Expand All @@ -97,6 +197,23 @@ impl SortSinkMultiple {
let arr = _get_rows_encoded_compat_array(s)?;
self.sort_column.push(arr);
}

if self.can_decode {
// we remove columns by index, but then they aren't correct anymore
// so we do it in the proper order and keep track of the indices removed
let mut sorted_sort_idx = self.sort_idx.to_vec();
sorted_sort_idx.sort_unstable();

sorted_sort_idx
.into_iter()
.enumerate()
.for_each(|(i, sort_idx)| {
// shifts all columns right from removed one to the left so
// therefore we subtract `i` as the shifted count
let _ = cols.remove(sort_idx - i);
})
}

let rows_encoded = polars_row::convert_columns(&self.sort_column, &self.sort_fields);
let column = unsafe {
Series::from_chunks_and_dtype_unchecked(
Expand Down Expand Up @@ -135,22 +252,44 @@ impl Sink for SortSinkMultiple {
sort_fields: self.sort_fields.clone(),
sort_args: self.sort_args.clone(),
sort_column: vec![],
can_decode: self.can_decode,
sort_dtypes: self.sort_dtypes.clone(),
output_schema: self.output_schema.clone(),
})
}

fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult<FinalizedSink> {
let out = self.sort_sink.finalize(context)?;

let sort_dtypes = self
.sort_dtypes
.take()
.map(|arr| arr.iter().map(|dt| dt.to_arrow()).collect::<Vec<_>>());

// we must adapt the finalized sink result so that the sort encoded column is dropped
match out {
FinalizedSink::Finished(mut df) => {
finalize_dataframe(&mut df, self.sort_idx.as_ref(), &self.sort_args);
finalize_dataframe(
&mut df,
self.sort_idx.as_ref(),
&self.sort_args,
self.can_decode,
sort_dtypes.as_deref(),
&mut vec![],
self.sort_fields.as_ref(),
&self.output_schema,
);
Ok(FinalizedSink::Finished(df))
}
FinalizedSink::Source(source) => Ok(FinalizedSink::Source(Box::new(DropEncoded {
source,
sort_idx: self.sort_idx.clone(),
sort_args: std::mem::take(&mut self.sort_args),
can_decode: self.can_decode,
sort_dtypes,
rows: vec![],
sort_fields: self.sort_fields.clone(),
output_schema: self.output_schema.clone(),
}))),
// SortSink should not produce this branch
FinalizedSink::Operator(_) => unreachable!(),
Expand All @@ -170,14 +309,28 @@ struct DropEncoded {
source: Box<dyn Source>,
sort_idx: Arc<[usize]>,
sort_args: SortArguments,
can_decode: bool,
sort_dtypes: Option<Vec<ArrowDataType>>,
rows: Vec<&'static [u8]>,
sort_fields: Arc<[SortField]>,
output_schema: SchemaRef,
}

impl Source for DropEncoded {
fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult<SourceResult> {
let mut result = self.source.get_batches(context);
if let Ok(SourceResult::GotMoreData(data)) = &mut result {
for chunk in data {
finalize_dataframe(&mut chunk.data, self.sort_idx.as_ref(), &self.sort_args)
finalize_dataframe(
&mut chunk.data,
self.sort_idx.as_ref(),
&self.sort_args,
self.can_decode,
self.sort_dtypes.as_deref(),
&mut self.rows,
self.sort_fields.as_ref(),
&self.output_schema,
)
}
};
result
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-pipe/src/pipeline/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ where
})
.collect::<PolarsResult<Vec<_>>>()?;

let sort_sink = SortSinkMultiple::new(args.clone(), &input_schema, sort_idx);
let sort_sink = SortSinkMultiple::new(args.clone(), input_schema, sort_idx);
Box::new(sort_sink) as Box<dyn Sink>
}
}
Expand Down
16 changes: 16 additions & 0 deletions polars/polars-row/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@ use super::*;
use crate::fixed::{decode_bool, decode_primitive};
use crate::variable::decode_binary;

/// Decode a row-encoded `BinaryArray` back into arrow arrays.
///
/// `rows` is a caller-provided scratch buffer: it is cleared, refilled with
/// slices borrowing from `arr`, and passed on to `decode_rows`; reusing the
/// same buffer across calls amortizes its allocation.
///
/// # Safety
/// This will not do any bound checks. Caller must ensure the `rows` are valid
/// encodings.
pub unsafe fn decode_rows_from_binary<'a>(
    arr: &'a BinaryArray<i64>,
    fields: &[SortField],
    data_types: &[DataType],
    rows: &mut Vec<&'a [u8]>,
) -> Vec<ArrayRef> {
    // the encoded column must have no validity bitmap; every row is expected
    // to be a present, valid encoding
    assert_eq!(arr.null_count(), 0);
    rows.clear();
    rows.extend(arr.values_iter());
    decode_rows(rows, fields, data_types)
}

/// Decode `rows` into arrow format
/// # Safety
/// This will not do any bound checks. Caller must ensure the `rows` are valid
Expand Down
12 changes: 6 additions & 6 deletions polars/polars-row/src/fixed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ where
let values = rows
.iter()
.map(|row| {
has_nulls |= *row.get_unchecked(0) == null_sentinel;
has_nulls |= *row.get_unchecked_release(0) == null_sentinel;
// skip null sentinel
let start = 1;
let end = start + T::ENCODED_LEN - 1;
let slice = row.get_unchecked(start..end);
let slice = row.get_unchecked_release(start..end);
let bytes = T::Encoded::from_slice(slice);
T::decode(bytes)
})
Expand All @@ -247,11 +247,11 @@ pub(super) unsafe fn decode_bool(rows: &mut [&[u8]], field: &SortField) -> Boole
let values = rows
.iter()
.map(|row| {
has_nulls |= *row.get_unchecked(0) == null_sentinel;
has_nulls |= *row.get_unchecked_release(0) == null_sentinel;
// skip null sentinel
let start = 1;
let end = start + bool::ENCODED_LEN - 1;
let slice = row.get_unchecked(start..end);
let slice = row.get_unchecked_release(start..end);
let bytes = <bool as FixedLengthEncoding>::Encoded::from_slice(slice);
bool::decode(bytes)
})
Expand All @@ -271,12 +271,12 @@ pub(super) unsafe fn decode_bool(rows: &mut [&[u8]], field: &SortField) -> Boole
}
unsafe fn increment_row_counter(rows: &mut [&[u8]], fixed_size: usize) {
for row in rows {
*row = row.get_unchecked(fixed_size..);
*row = row.get_unchecked_release(fixed_size..);
}
}

pub(super) unsafe fn decode_nulls(rows: &[&[u8]], null_sentinel: u8) -> Bitmap {
rows.iter()
.map(|row| *row.get_unchecked(0) != null_sentinel)
.map(|row| *row.get_unchecked_release(0) != null_sentinel)
.collect()
}

0 comments on commit ce465de

Please sign in to comment.