Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Added dynamic version of negation #685

Merged
merged 1 commit into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub mod decimal;
pub mod time;

use crate::{
array::{Array, PrimitiveArray},
array::{Array, DictionaryArray, PrimitiveArray},
bitmap::Bitmap,
datatypes::{DataType, IntervalUnit, TimeUnit},
scalar::{PrimitiveScalar, Scalar},
Expand Down Expand Up @@ -400,6 +400,72 @@ pub fn can_rem(lhs: &DataType, rhs: &DataType) -> bool {
)
}

macro_rules! with_match_negatable {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
DaysMs => __with_ty__! { days_ms },
MonthDayNano => __with_ty__! { months_days_ns },
UInt8 | UInt16 | UInt32 | UInt64=> todo!(),
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
})}

/// Negates an [`Array`].
/// # Panic
/// This function panics iff either
/// * the opertion is not supported for the logical type (use [`can_neg`] to check)
/// * the operation overflows
pub fn neg(array: &dyn Array) -> Box<dyn Array> {
use crate::datatypes::PhysicalType::*;
match array.data_type().to_physical_type() {
Primitive(primitive) => with_match_negatable!(primitive, |$T| {
let array = array.as_any().downcast_ref().unwrap();

let result = basic::negate::<$T>(array);
Box::new(result) as Box<dyn Array>
}),
Dictionary(key) => match_integer_type!(key, |$T| {
let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();

let values = neg(array.values().as_ref()).into();

Box::new(DictionaryArray::<$T>::from_data(array.keys().clone(), values)) as Box<dyn Array>
}),
_ => todo!(),
}
}

/// Whether [`neg`] is supported for a given [`DataType`]
pub fn can_neg(data_type: &DataType) -> bool {
if let DataType::Dictionary(_, values) = data_type.to_logical_type() {
return can_neg(values.as_ref());
}

use crate::datatypes::PhysicalType::*;
use crate::datatypes::PrimitiveType::*;
matches!(
data_type.to_physical_type(),
Primitive(Int8)
| Primitive(Int16)
| Primitive(Int32)
| Primitive(Int64)
| Primitive(Float64)
| Primitive(Float32)
| Primitive(DaysMs)
| Primitive(MonthDayNano)
)
}

/// Defines basic addition operation for primitive arrays
pub trait ArrayAdd<Rhs>: Sized {
/// Adds itself to `rhs`
Expand Down
20 changes: 19 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! represent chunks of bits (e.g. `u8`, `u16`), and [`BitChunkIter`], that can be used to
//! iterate over bitmaps in [`BitChunk`]s.
//! Finally, this module also contains traits used to compile code optimized for SIMD instructions at [`mod@simd`].
use std::convert::TryFrom;
use std::{convert::TryFrom, ops::Neg};

mod bit_chunk;
pub use bit_chunk::{BitChunk, BitChunkIter};
Expand Down Expand Up @@ -399,3 +399,21 @@ impl months_days_ns {
self.2
}
}

impl Neg for days_ms {
type Output = Self;

#[inline(always)]
fn neg(self) -> Self::Output {
Self([-self.0[0], -self.0[0]])
}
}

impl Neg for months_days_ns {
type Output = Self;

#[inline(always)]
fn neg(self) -> Self::Output {
Self(-self.0, -self.1, -self.2)
}
}
24 changes: 23 additions & 1 deletion tests/it/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod basic;
mod decimal;
mod time;

use arrow2::array::{new_empty_array, Int32Array};
use arrow2::array::*;
use arrow2::compute::arithmetics::*;
use arrow2::datatypes::DataType::*;
use arrow2::datatypes::{IntervalUnit, TimeUnit};
Expand Down Expand Up @@ -84,3 +84,25 @@ fn consistency() {
}
});
}

#[test]
fn test_neg() {
let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
let result = neg(&a);
let expected = Int32Array::from(&[None, Some(-6), None, Some(-6)]);
assert_eq!(expected, result.as_ref());
}

#[test]
fn test_neg_dict() {
let a = DictionaryArray::<u8>::from_data(
UInt8Array::from_slice(&[0, 0, 1]),
std::sync::Arc::new(Int8Array::from_slice(&[1, 2])),
);
let result = neg(&a);
let expected = DictionaryArray::<u8>::from_data(
UInt8Array::from_slice(&[0, 0, 1]),
std::sync::Arc::new(Int8Array::from_slice(&[-1, -2])),
);
assert_eq!(expected, result.as_ref());
}