Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: add optimized BinaryViewArray comparison kernels #13839

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/polars-compute/src/comparisons/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl NotSimdPrimitive for u128 {}
impl NotSimdPrimitive for i128 {}

mod scalar;
mod view;

#[cfg(feature = "simd")]
mod simd;
Expand Down
109 changes: 1 addition & 108 deletions crates/polars-compute/src/comparisons/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use arrow::array::{
BinaryArray, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8Array, Utf8ViewArray,
};
use arrow::array::{BinaryArray, BooleanArray, PrimitiveArray, Utf8Array};
use arrow::bitmap::{self, Bitmap};
use arrow::types::NativeType;
use polars_utils::total_ord::{TotalEq, TotalOrd};
Expand Down Expand Up @@ -71,111 +69,6 @@ impl<T: NativeType + NotSimdPrimitive + TotalOrd> TotalOrdKernel for PrimitiveAr
}
}

impl TotalOrdKernel for BinaryViewArray {
type Scalar = [u8];

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
// TODO! speed-up by first comparing views
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_eq(&r))
.collect()
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_ne(&r))
.collect()
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_lt(&r))
.collect()
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_le(&r))
.collect()
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_eq(&other)).collect()
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_ne(&other)).collect()
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_lt(&other)).collect()
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_le(&other)).collect()
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_gt(&other)).collect()
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_ge(&other)).collect()
}
}

impl TotalOrdKernel for Utf8ViewArray {
type Scalar = str;

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_eq_kernel(&other.to_binview())
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_ne_kernel(&other.to_binview())
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_lt_kernel(&other.to_binview())
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_le_kernel(&other.to_binview())
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_eq_kernel_broadcast(other.as_bytes())
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ne_kernel_broadcast(other.as_bytes())
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_lt_kernel_broadcast(other.as_bytes())
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_le_kernel_broadcast(other.as_bytes())
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_gt_kernel_broadcast(other.as_bytes())
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ge_kernel_broadcast(other.as_bytes())
}
}

impl TotalOrdKernel for BinaryArray<i64> {
type Scalar = [u8];

Expand Down
244 changes: 244 additions & 0 deletions crates/polars-compute/src/comparisons/view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use arrow::array::{BinaryViewArray, Utf8ViewArray};
use arrow::bitmap::Bitmap;

use crate::comparisons::TotalOrdKernel;

// If s fits in 12 bytes, returns the view encoding it would have in a
// BinaryViewArray.
fn small_view_encoding(s: &[u8]) -> Option<u128> {
if s.len() > 12 {
return None;
}

let mut tmp = [0u8; 16];
tmp[0] = s.len() as u8;
tmp[4..4 + s.len()].copy_from_slice(s);
Some(u128::from_le_bytes(tmp))
}

// Loads (up to) the first 4 bytes of s as little-endian, padded with zeros.
fn load_prefix(s: &[u8]) -> u32 {
let start = &s[..s.len().min(4)];
let mut tmp = [0u8; 4];
tmp[..start.len()].copy_from_slice(start);
u32::from_le_bytes(tmp)
}

fn broadcast_inequality(
arr: &BinaryViewArray,
scalar: &[u8],
cmp_prefix: impl Fn(u32, u32) -> bool,
cmp_str: impl Fn(&[u8], &[u8]) -> bool,
) -> Bitmap {
let views = arr.views().as_slice();
let prefix = load_prefix(scalar);
let be_prefix = prefix.to_be();
Bitmap::from_trusted_len_iter((0..arr.len()).map(|i| unsafe {
let v_prefix = (*views.get_unchecked(i) >> 32) as u32;
if v_prefix != prefix {
cmp_prefix(v_prefix.to_be(), be_prefix)
} else {
cmp_str(arr.value_unchecked(i), scalar)
}
}))
}

impl TotalOrdKernel for BinaryViewArray {
type Scalar = [u8];

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
let a_len_prefix = av as u64;
let b_len_prefix = bv as u64;
if a_len_prefix != b_len_prefix {
return false;
}

let alen = av as u32;
if alen <= 12 {
// String is fully inlined, compare top 64 bits. Bottom bits were
// tested equal before, which also ensures the lengths are equal.
(av >> 64) as u64 == (bv >> 64) as u64
} else {
self.value_unchecked(i) == other.value_unchecked(i)
}
}))
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
let a_len_prefix = av as u64;
let b_len_prefix = bv as u64;
if a_len_prefix != b_len_prefix {
return true;
}

let alen = av as u32;
if alen <= 12 {
// String is fully inlined, compare top 64 bits. Bottom bits were
// tested equal before, which also ensures the lengths are equal.
(av >> 64) as u64 != (bv >> 64) as u64
} else {
self.value_unchecked(i) != other.value_unchecked(i)
}
}))
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
// Only check prefix.
let a_prefix = (av >> 32) as u32;
let b_prefix = (bv >> 32) as u32;
if a_prefix != b_prefix {
a_prefix.to_be() < b_prefix.to_be()
} else {
self.value_unchecked(i) < other.value_unchecked(i)
}
}))
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
// Only check prefix.
let a_prefix = (av >> 32) as u32;
let b_prefix = (bv >> 32) as u32;
if a_prefix != b_prefix {
a_prefix.to_be() < b_prefix.to_be()
} else {
self.value_unchecked(i) <= other.value_unchecked(i)
}
}))
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
if let Some(val) = small_view_encoding(other) {
Bitmap::from_trusted_len_iter(self.views().iter().map(|v| *v == val))
} else {
let slf_views = self.views().as_slice();
let prefix = u32::from_le_bytes(other[..4].try_into().unwrap());
let prefix_len = ((prefix as u64) << 32) | other.len() as u64;
Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let v_prefix_len = *slf_views.get_unchecked(i) as u64;
if v_prefix_len != prefix_len {
false
} else {
self.value_unchecked(i) == other
}
}))
}
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
if let Some(val) = small_view_encoding(other) {
Bitmap::from_trusted_len_iter(self.views().iter().map(|v| *v != val))
} else {
let slf_views = self.views().as_slice();
let prefix = u32::from_le_bytes(other[..4].try_into().unwrap());
let prefix_len = ((prefix as u64) << 32) | other.len() as u64;
Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let v_prefix_len = *slf_views.get_unchecked(i) as u64;
if v_prefix_len != prefix_len {
true
} else {
self.value_unchecked(i) != other
}
}))
}
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a < b, |a, b| a < b)
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a <= b, |a, b| a <= b)
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a > b, |a, b| a > b)
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a >= b, |a, b| a >= b)
}
}

impl TotalOrdKernel for Utf8ViewArray {
type Scalar = str;

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_eq_kernel(&other.to_binview())
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_ne_kernel(&other.to_binview())
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_lt_kernel(&other.to_binview())
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_le_kernel(&other.to_binview())
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_eq_kernel_broadcast(other.as_bytes())
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ne_kernel_broadcast(other.as_bytes())
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_lt_kernel_broadcast(other.as_bytes())
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_le_kernel_broadcast(other.as_bytes())
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_gt_kernel_broadcast(other.as_bytes())
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ge_kernel_broadcast(other.as_bytes())
}
}
3 changes: 3 additions & 0 deletions py-polars/tests/unit/operations/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> No
"",
"foo",
"bar",
"fooo",
"fooooooooooo",
"foooooooooooo",
"fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooom",
"foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo",
"fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooop",
Expand Down