From 5185091e15749833e259d07cd7c266ce0eb95782 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 19 Jan 2024 13:27:11 +0100 Subject: [PATCH] perf: add optimized BinaryViewArray comparison kernels --- crates/polars-compute/src/comparisons/mod.rs | 1 + .../polars-compute/src/comparisons/scalar.rs | 109 +------- crates/polars-compute/src/comparisons/view.rs | 244 ++++++++++++++++++ .../tests/unit/operations/test_comparison.py | 3 + 4 files changed, 249 insertions(+), 108 deletions(-) create mode 100644 crates/polars-compute/src/comparisons/view.rs diff --git a/crates/polars-compute/src/comparisons/mod.rs b/crates/polars-compute/src/comparisons/mod.rs index 9cac2713713d..a0baebad6b7d 100644 --- a/crates/polars-compute/src/comparisons/mod.rs +++ b/crates/polars-compute/src/comparisons/mod.rs @@ -84,6 +84,7 @@ impl NotSimdPrimitive for u128 {} impl NotSimdPrimitive for i128 {} mod scalar; +mod view; #[cfg(feature = "simd")] mod simd; diff --git a/crates/polars-compute/src/comparisons/scalar.rs b/crates/polars-compute/src/comparisons/scalar.rs index 6ad3c4cea011..bc338f3816a8 100644 --- a/crates/polars-compute/src/comparisons/scalar.rs +++ b/crates/polars-compute/src/comparisons/scalar.rs @@ -1,6 +1,4 @@ -use arrow::array::{ - BinaryArray, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8Array, Utf8ViewArray, -}; +use arrow::array::{BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; use arrow::bitmap::{self, Bitmap}; use arrow::types::NativeType; use polars_utils::total_ord::{TotalEq, TotalOrd}; @@ -71,111 +69,6 @@ impl TotalOrdKernel for PrimitiveAr } } -impl TotalOrdKernel for BinaryViewArray { - type Scalar = [u8]; - - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - debug_assert!(self.len() == other.len()); - // TODO! 
speed-up by first comparing views - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_eq(&r)) - .collect() - } - - fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - debug_assert!(self.len() == other.len()); - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_ne(&r)) - .collect() - } - - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - debug_assert!(self.len() == other.len()); - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_lt(&r)) - .collect() - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - debug_assert!(self.len() == other.len()); - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_le(&r)) - .collect() - } - - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_eq(&other)).collect() - } - - fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_ne(&other)).collect() - } - - fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_lt(&other)).collect() - } - - fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_le(&other)).collect() - } - - fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_gt(&other)).collect() - } - - fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_ge(&other)).collect() - } -} - -impl TotalOrdKernel for Utf8ViewArray { - type Scalar = str; - - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - self.to_binview().tot_eq_kernel(&other.to_binview()) - } - - fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - self.to_binview().tot_ne_kernel(&other.to_binview()) - } - - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - self.to_binview().tot_lt_kernel(&other.to_binview()) - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - 
self.to_binview().tot_le_kernel(&other.to_binview())
-    }
-
-    fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
-        self.to_binview().tot_eq_kernel_broadcast(other.as_bytes())
-    }
-
-    fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
-        self.to_binview().tot_ne_kernel_broadcast(other.as_bytes())
-    }
-
-    fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
-        self.to_binview().tot_lt_kernel_broadcast(other.as_bytes())
-    }
-
-    fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
-        self.to_binview().tot_le_kernel_broadcast(other.as_bytes())
-    }
-
-    fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
-        self.to_binview().tot_gt_kernel_broadcast(other.as_bytes())
-    }
-
-    fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
-        self.to_binview().tot_ge_kernel_broadcast(other.as_bytes())
-    }
-}
-
 impl TotalOrdKernel for BinaryArray {
     type Scalar = [u8];
 
diff --git a/crates/polars-compute/src/comparisons/view.rs b/crates/polars-compute/src/comparisons/view.rs
new file mode 100644
index 000000000000..01ecf816011a
--- /dev/null
+++ b/crates/polars-compute/src/comparisons/view.rs
@@ -0,0 +1,244 @@
+use arrow::array::{BinaryViewArray, Utf8ViewArray};
+use arrow::bitmap::Bitmap;
+
+use crate::comparisons::TotalOrdKernel;
+
+// Returns the inline view encoding s would have in a BinaryViewArray (length in
+// byte 0, data from byte 4, zero padded), or None if s is longer than 12 bytes.
+fn small_view_encoding(s: &[u8]) -> Option<u128> {
+    if s.len() > 12 {
+        return None;
+    }
+
+    let mut tmp = [0u8; 16]; // all zero, so the bytes not written below stay 0
+    tmp[0] = s.len() as u8; // lossless: the length is at most 12 here
+    tmp[4..4 + s.len()].copy_from_slice(s);
+    Some(u128::from_le_bytes(tmp)) // tmp[0] becomes the least significant byte
+}
+
+// Loads (up to) the first 4 bytes of s as little-endian, padded with zeros.
+fn load_prefix(s: &[u8]) -> u32 {
+    let start = &s[..s.len().min(4)]; // the whole of s when it is shorter than 4 bytes
+    let mut tmp = [0u8; 4]; // bytes not overwritten below stay zero
+    tmp[..start.len()].copy_from_slice(start);
+    u32::from_le_bytes(tmp) // s[0] lands in the least significant byte
+}
+
+fn broadcast_inequality(
+    arr: &BinaryViewArray,
+    scalar: &[u8],
+    cmp_prefix: impl Fn(u32, u32) -> bool, // decides when the 4-byte prefixes differ
+    cmp_str: impl Fn(&[u8], &[u8]) -> bool, // decides on the full values otherwise
+) -> Bitmap {
+    let views = arr.views().as_slice();
+    let prefix = load_prefix(scalar);
+    let be_prefix = prefix.to_be(); // NOTE(review): a byte swap only on little-endian targets
+    Bitmap::from_trusted_len_iter((0..arr.len()).map(|i| unsafe {
+        let v_prefix = (*views.get_unchecked(i) >> 32) as u32; // bits 32..64 of the view
+        if v_prefix != prefix {
+            cmp_prefix(v_prefix.to_be(), be_prefix) // swapped: byte 0 is most significant
+        } else {
+            cmp_str(arr.value_unchecked(i), scalar)
+        }
+    }))
+}
+
+impl TotalOrdKernel for BinaryViewArray {
+    type Scalar = [u8];
+
+    fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
+        debug_assert!(self.len() == other.len()); // NOTE(review): no check in release builds
+
+        let slf_views = self.views().as_slice();
+        let other_views = other.views().as_slice();
+
+        Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
+            let av = *slf_views.get_unchecked(i);
+            let bv = *other_views.get_unchecked(i);
+
+            // Low 64 bits of a view: length (bits 0..32) and prefix (bits 32..64).
+            let a_len_prefix = av as u64;
+            let b_len_prefix = bv as u64;
+            if a_len_prefix != b_len_prefix {
+                return false; // length or prefix differs
+            }
+
+            let alen = av as u32; // low 32 bits: the length
+            if alen <= 12 {
+                // String is fully inlined, so comparing the top 64 bits finishes the job.
+                // The low 64 bits matched above, which also means the lengths are equal.
+                (av >> 64) as u64 == (bv >> 64) as u64
+            } else {
+                self.value_unchecked(i) == other.value_unchecked(i)
+            }
+        }))
+    }
+
+    fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
+        debug_assert!(self.len() == other.len()); // NOTE(review): no check in release builds
+
+        let slf_views = self.views().as_slice();
+        let other_views = other.views().as_slice();
+
+        Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
+            let av = *slf_views.get_unchecked(i);
+            let bv = *other_views.get_unchecked(i);
+
+            // Low 64 bits of a view: length (bits 0..32) and prefix (bits 32..64).
+            let a_len_prefix = av as u64;
+            let b_len_prefix = bv as u64;
+            if a_len_prefix != b_len_prefix {
+                return true; // length or prefix differs
+            }
+
+            let alen = av as u32; // low 32 bits: the length
+            if alen <= 12 {
+                // String is fully inlined, so comparing the top 64 bits finishes the job.
+                // The low 64 bits matched above, which also means the lengths are equal.
+                (av >> 64) as u64 != (bv >> 64) as u64
+            } else {
+                self.value_unchecked(i) != other.value_unchecked(i)
+            }
+        }))
+    }
+
+    fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
+        debug_assert!(self.len() == other.len()); // NOTE(review): no check in release builds
+
+        let slf_views = self.views().as_slice();
+        let other_views = other.views().as_slice();
+
+        Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
+            let av = *slf_views.get_unchecked(i);
+            let bv = *other_views.get_unchecked(i);
+
+            // Bits 32..64 of a view hold the prefix; the length (bits 0..32) is ignored.
+            // Differing prefixes decide the order; equal ones need the full values.
+            let a_prefix = (av >> 32) as u32;
+            let b_prefix = (bv >> 32) as u32;
+            if a_prefix != b_prefix {
+                a_prefix.to_be() < b_prefix.to_be() // byte-swapped on little-endian targets
+            } else {
+                self.value_unchecked(i) < other.value_unchecked(i)
+            }
+        }))
+    }
+
+    fn tot_le_kernel(&self, other: &Self) -> Bitmap {
+        debug_assert!(self.len() == other.len()); // NOTE(review): no check in release builds
+
+        let slf_views = self.views().as_slice();
+        let other_views = other.views().as_slice();
+
+        Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
+            let av = *slf_views.get_unchecked(i);
+            let bv = *other_views.get_unchecked(i);
+
+            // Bits 32..64 of a view hold the prefix; the length (bits 0..32) is ignored.
+            // Differing prefixes decide the order; equal ones need the full values.
+            let a_prefix = (av >> 32) as u32;
+            let b_prefix = (bv >> 32) as u32;
+            if a_prefix != b_prefix {
+                a_prefix.to_be() < b_prefix.to_be() // prefixes differ, so < and <= agree
+            } else {
+                self.value_unchecked(i) <= other.value_unchecked(i)
+            }
+        }))
+    }
+
+    fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        if let Some(val) = small_view_encoding(other) {
+            Bitmap::from_trusted_len_iter(self.views().iter().map(|v| *v == val))
+        } else {
+            let slf_views = self.views().as_slice();
+            let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); // len > 12 here
+            let prefix_len = ((prefix as u64) << 32) | other.len() as u64; // low 64 view bits
+            Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
+                let v_prefix_len = *slf_views.get_unchecked(i) as u64;
+                if v_prefix_len != prefix_len {
+                    false // length or prefix differs
+                } else {
+                    self.value_unchecked(i) == other
+                }
+            }))
+        }
+    }
+
+    fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        if let Some(val) = small_view_encoding(other) {
+            Bitmap::from_trusted_len_iter(self.views().iter().map(|v| *v != val))
+        } else {
+            let slf_views = self.views().as_slice();
+            let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); // len > 12 here
+            let prefix_len = ((prefix as u64) << 32) | other.len() as u64; // low 64 view bits
+            Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
+                let v_prefix_len = *slf_views.get_unchecked(i) as u64;
+                if v_prefix_len != prefix_len {
+                    true // length or prefix differs
+                } else {
+                    self.value_unchecked(i) != other
+                }
+            }))
+        }
+    }
+
+    fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        broadcast_inequality(self, other, |a, b| a < b, |a, b| a < b)
+    }
+
+    fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        broadcast_inequality(self, other, |a, b| a <= b, |a, b| a <= b)
+    }
+
+    fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        broadcast_inequality(self, other, |a, b| a > b, |a, b| a > b)
+    }
+
+    fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        broadcast_inequality(self, other, |a, b| a >= b, |a, b| a >= b)
+    }
+}
+
+impl TotalOrdKernel for Utf8ViewArray { // every kernel forwards to the to_binview() kernels
+    type Scalar = str; // handed on as bytes via as_bytes() in the broadcast kernels
+
+    fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
+        TotalOrdKernel::tot_eq_kernel(&self.to_binview(), &other.to_binview())
+    }
+
+    fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
+        TotalOrdKernel::tot_ne_kernel(&self.to_binview(), &other.to_binview())
+    }
+
+    fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
+        TotalOrdKernel::tot_lt_kernel(&self.to_binview(), &other.to_binview())
+    }
+
+    fn tot_le_kernel(&self, other: &Self) -> Bitmap {
+        TotalOrdKernel::tot_le_kernel(&self.to_binview(), &other.to_binview())
+    }
+
+    fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        TotalOrdKernel::tot_eq_kernel_broadcast(&self.to_binview(), other.as_bytes())
+    }
+
+    fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        TotalOrdKernel::tot_ne_kernel_broadcast(&self.to_binview(), other.as_bytes())
+    }
+
+    fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        TotalOrdKernel::tot_lt_kernel_broadcast(&self.to_binview(), other.as_bytes())
+    }
+
+    fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        TotalOrdKernel::tot_le_kernel_broadcast(&self.to_binview(), other.as_bytes())
+    }
+
+    fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        TotalOrdKernel::tot_gt_kernel_broadcast(&self.to_binview(), other.as_bytes())
+    }
+
+    fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
+        TotalOrdKernel::tot_ge_kernel_broadcast(&self.to_binview(), other.as_bytes())
+    }
+}
diff --git a/py-polars/tests/unit/operations/test_comparison.py b/py-polars/tests/unit/operations/test_comparison.py
index 3980048d30f0..4f08bc31f795 100644
--- a/py-polars/tests/unit/operations/test_comparison.py
+++ b/py-polars/tests/unit/operations/test_comparison.py
@@ -320,6 +320,9 @@ def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> No
         "",
         "foo",
         "bar",
+        "fooo",
+        "fooooooooooo",
+        "foooooooooooo",
         "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooom",
         "foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo",
         "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooop",