Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): list zip with #9367

Merged
merged 2 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions polars/polars-core/src/chunked_array/arithmetic/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl DecimalChunked {
.zip(rhs.downcast_iter())
.map(|(lhs, rhs)| kernel(lhs, rhs).map(|a| Box::new(a) as ArrayRef))
.collect::<PolarsResult<_>>()?;
lhs.copy_with_chunks(chunks, false, false)
unsafe { lhs.copy_with_chunks(chunks, false, false) }
}
// broadcast right path
(_, 1) => {
Expand All @@ -70,7 +70,7 @@ impl DecimalChunked {
.downcast_iter()
.map(|lhs| operation_lhs(lhs, rhs_val).map(|a| Box::new(a) as ArrayRef))
.collect::<PolarsResult<_>>()?;
lhs.copy_with_chunks(chunks, false, false)
unsafe { lhs.copy_with_chunks(chunks, false, false) }
}
}
}
Expand All @@ -83,7 +83,7 @@ impl DecimalChunked {
.downcast_iter()
.map(|rhs| operation_rhs(lhs_val, rhs).map(|a| Box::new(a) as ArrayRef))
.collect::<PolarsResult<_>>()?;
lhs.copy_with_chunks(chunks, false, false)
unsafe { lhs.copy_with_chunks(chunks, false, false) }
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
.zip(rhs.downcast_iter())
.map(|(lhs, rhs)| Box::new(kernel(lhs, rhs)) as ArrayRef)
.collect();
lhs.copy_with_chunks(chunks, false, false)
unsafe { lhs.copy_with_chunks(chunks, false, false) }
}
// broadcast right path
(_, 1) => {
Expand Down
5 changes: 4 additions & 1 deletion polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,10 @@ impl<T: PolarsDataType> ChunkedArray<T> {
}

/// Create a new ChunkedArray from self, where the chunks are replaced.
fn copy_with_chunks(
///
/// # Safety
/// The caller must ensure the dtypes of the chunks are correct
unsafe fn copy_with_chunks(
&self,
chunks: Vec<ArrayRef>,
keep_sorted: bool,
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/ops/chunkops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {
self.clone()
} else {
let chunks = inner_rechunk(&self.chunks);
self.copy_with_chunks(chunks, true, true)
unsafe { self.copy_with_chunks(chunks, true, true) }
}
}
}
Expand All @@ -114,7 +114,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {
#[inline]
pub fn slice(&self, offset: i64, length: usize) -> Self {
let (chunks, len) = slice(&self.chunks, offset, length, self.len());
let mut out = self.copy_with_chunks(chunks, true, true);
let mut out = unsafe { self.copy_with_chunks(chunks, true, true) };
out.length = len as IdxSize;
out
}
Expand Down
3 changes: 1 addition & 2 deletions polars/polars-core/src/chunked_array/ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,7 @@ impl ExplodeByOffsets for ListChunked {
}
process_range(start, last, &mut builder);
let arr = builder.finish(Some(&inner_type.to_arrow())).unwrap();
self.copy_with_chunks(vec![Box::new(arr)], true, true)
.into_series()
unsafe { self.copy_with_chunks(vec![Box::new(arr)], true, true) }.into_series()
}
}

Expand Down
6 changes: 3 additions & 3 deletions polars/polars-core/src/chunked_array/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ where
.zip(filter.downcast_iter())
.map(|(left, mask)| filter_fn(left, mask).unwrap())
.collect::<Vec<_>>();
Ok(self.copy_with_chunks(chunks, true, true))
unsafe { Ok(self.copy_with_chunks(chunks, true, true)) }
}
}

Expand All @@ -58,7 +58,7 @@ impl ChunkFilter<BooleanType> for BooleanChunked {
.zip(filter.downcast_iter())
.map(|(left, mask)| filter_fn(left, mask).unwrap())
.collect::<Vec<_>>();
Ok(self.copy_with_chunks(chunks, true, true))
unsafe { Ok(self.copy_with_chunks(chunks, true, true)) }
}
}

Expand Down Expand Up @@ -87,7 +87,7 @@ impl ChunkFilter<BinaryType> for BinaryChunked {
.map(|(left, mask)| filter_fn(left, mask).unwrap())
.collect::<Vec<_>>();

Ok(self.copy_with_chunks(chunks, true, true))
unsafe { Ok(self.copy_with_chunks(chunks, true, true)) }
}
}

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {

pub(crate) fn coalesce_nulls(&self, other: &[ArrayRef]) -> Self {
let chunks = coalesce_nulls(&self.chunks, other);
self.copy_with_chunks(chunks, true, false)
unsafe { self.copy_with_chunks(chunks, true, false) }
}
}
pub fn is_not_null(name: &str, chunks: &[ArrayRef]) -> BooleanChunked {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ where
{
fn finish_from_array(&self, array: Box<dyn Array>) -> Self {
let keep_fast_explode = array.null_count() == 0;
self.copy_with_chunks(vec![array], false, keep_fast_explode)
unsafe { self.copy_with_chunks(vec![array], false, keep_fast_explode) }
}
}

Expand Down
104 changes: 29 additions & 75 deletions polars/polars-core/src/chunked_array/ops/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ macro_rules! impl_ternary_broadcast {
}};
}

fn zip_with<T: PolarsDataType>(
left: &ChunkedArray<T>,
right: &ChunkedArray<T>,
mask: &BooleanChunked,
) -> PolarsResult<ChunkedArray<T>> {
let (left, right, mask) = align_chunks_ternary(left, right, mask);
let chunks = left
.chunks()
.iter()
.zip(right.chunks())
.zip(mask.downcast_iter())
.map(|((left_c, right_c), mask_c)| {
let mask_c = prepare_mask(mask_c);
let arr = if_then_else(&mask_c, left_c.as_ref(), right_c.as_ref())?;
Ok(arr)
})
.collect::<PolarsResult<Vec<_>>>()?;
unsafe { Ok(left.copy_with_chunks(chunks, false, false)) }
}

impl<T> ChunkZip<T> for ChunkedArray<T>
where
T: PolarsNumericType,
Expand All @@ -79,18 +99,7 @@ where
if self.len() != mask.len() || other.len() != mask.len() {
impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, T)
} else {
let (left, right, mask) = align_chunks_ternary(self, other, mask);
let chunks = left
.downcast_iter()
.zip(right.downcast_iter())
.zip(mask.downcast_iter())
.map(|((left_c, right_c), mask_c)| {
let mask_c = prepare_mask(mask_c);
let arr = if_then_else(&mask_c, left_c, right_c)?;
Ok(arr)
})
.collect::<PolarsResult<Vec<_>>>()?;
unsafe { Ok(ChunkedArray::from_chunks(self.name(), chunks)) }
zip_with(self, other, mask)
}
}
}
Expand All @@ -105,39 +114,17 @@ impl ChunkZip<BooleanType> for BooleanChunked {
if self.len() != mask.len() || other.len() != mask.len() {
impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, BooleanType)
} else {
let (left, right, mask) = align_chunks_ternary(self, other, mask);
let chunks = left
.downcast_iter()
.zip(right.downcast_iter())
.zip(mask.downcast_iter())
.map(|((left_c, right_c), mask_c)| {
let mask_c = prepare_mask(mask_c);
let arr = if_then_else(&mask_c, left_c, right_c)?;
Ok(arr)
})
.collect::<PolarsResult<Vec<_>>>()?;
unsafe { Ok(ChunkedArray::from_chunks(self.name(), chunks)) }
zip_with(self, other, mask)
}
}
}

impl ChunkZip<Utf8Type> for Utf8Chunked {
fn zip_with(&self, mask: &BooleanChunked, other: &Utf8Chunked) -> PolarsResult<Utf8Chunked> {
if self.len() != mask.len() || other.len() != mask.len() {
impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, Utf8Type)
} else {
let (left, right, mask) = align_chunks_ternary(self, other, mask);
let chunks = left
.downcast_iter()
.zip(right.downcast_iter())
.zip(mask.downcast_iter())
.map(|((left_c, right_c), mask_c)| {
let mask_c = prepare_mask(mask_c);
let arr = if_then_else(&mask_c, left_c, right_c)?;
Ok(arr)
})
.collect::<PolarsResult<Vec<_>>>()?;
unsafe { Ok(ChunkedArray::from_chunks(self.name(), chunks)) }
unsafe {
self.as_binary()
.zip_with(mask, &other.as_binary())
.map(|ca| ca.to_utf8())
}
}
}
Expand All @@ -151,54 +138,21 @@ impl ChunkZip<BinaryType> for BinaryChunked {
if self.len() != mask.len() || other.len() != mask.len() {
impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, BinaryType)
} else {
let (left, right, mask) = align_chunks_ternary(self, other, mask);
let chunks = left
.downcast_iter()
.zip(right.downcast_iter())
.zip(mask.downcast_iter())
.map(|((left_c, right_c), mask_c)| {
let mask_c = prepare_mask(mask_c);
let arr = if_then_else(&mask_c, left_c, right_c)?;
Ok(arr)
})
.collect::<PolarsResult<Vec<_>>>()?;
unsafe { Ok(ChunkedArray::from_chunks(self.name(), chunks)) }
zip_with(self, other, mask)
}
}
}

impl ChunkZip<ListType> for ListChunked {
fn zip_with(&self, mask: &BooleanChunked, other: &ListChunked) -> PolarsResult<ListChunked> {
let (left, right, mask) = align_chunks_ternary(self, other, mask);
let chunks = left
.downcast_iter()
.zip(right.downcast_iter())
.zip(mask.downcast_iter())
.map(|((left_c, right_c), mask_c)| {
let mask_c = prepare_mask(mask_c);
let arr = if_then_else(&mask_c, left_c, right_c)?;
Ok(arr)
})
.collect::<PolarsResult<Vec<_>>>()?;
unsafe { Ok(ChunkedArray::from_chunks(self.name(), chunks)) }
zip_with(self, other, mask)
}
}

#[cfg(feature = "dtype-array")]
impl ChunkZip<FixedSizeListType> for ArrayChunked {
fn zip_with(&self, mask: &BooleanChunked, other: &ArrayChunked) -> PolarsResult<ArrayChunked> {
let (left, right, mask) = align_chunks_ternary(self, other, mask);
let chunks = left
.downcast_iter()
.zip(right.downcast_iter())
.zip(mask.downcast_iter())
.map(|((left_c, right_c), mask_c)| {
let mask_c = prepare_mask(mask_c);
let arr = if_then_else(&mask_c, left_c, right_c)?;
Ok(arr)
})
.collect::<PolarsResult<Vec<_>>>()?;
unsafe { Ok(ChunkedArray::from_chunks(self.name(), chunks)) }
zip_with(self, other, mask)
}
}

Expand Down
29 changes: 29 additions & 0 deletions py-polars/tests/unit/test_arity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import polars as pl
from polars.testing import assert_frame_equal

Expand Down Expand Up @@ -40,3 +42,30 @@ def test_expression_literal_series_order() -> None:

assert df.select(pl.col("a") + s).to_dict(False) == {"a": [2, 4, 6]}
assert df.select(pl.lit(s) + pl.col("a")).to_dict(False) == {"": [2, 4, 6]}


def test_list_zip_with_logical_type() -> None:
df = pl.DataFrame(
{
"start": [datetime(2023, 1, 1, 1, 1, 1), datetime(2023, 1, 1, 1, 1, 1)],
"stop": [datetime(2023, 1, 1, 1, 3, 1), datetime(2023, 1, 1, 1, 4, 1)],
"use": [1, 0],
}
)

df = df.with_columns(
pl.date_range(
pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"
).alias("interval_1"),
pl.date_range(
pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"
).alias("interval_2"),
)

out = df.select(
pl.when(pl.col("use") == 1)
.then(pl.col("interval_2"))
.otherwise(pl.col("interval_1"))
.alias("interval_new")
)
assert out.dtypes == [pl.List(pl.Datetime(time_unit="us", time_zone=None))]