Skip to content

Commit

Permalink
Add support for adding intervals to dates (#2031)
Browse files Browse the repository at this point in the history
  • Loading branch information
avantgardnerio authored Jul 15, 2022
1 parent 9d8f0c9 commit cb7e5b0
Show file tree
Hide file tree
Showing 4 changed files with 578 additions and 9 deletions.
164 changes: 155 additions & 9 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ use crate::compute::kernels::arity::unary;
use crate::compute::unary_dyn;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes;
use crate::datatypes::{ArrowNumericType, DataType};
use crate::datatypes::{
ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
};
use crate::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
Expand All @@ -55,14 +58,15 @@ use std::sync::Arc;
/// # Errors
///
/// This function errors if the arrays have different lengths
pub fn math_op<T, F>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
pub fn math_op<LT, RT, F>(
left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
op: F,
) -> Result<PrimitiveArray<T>>
) -> Result<PrimitiveArray<LT>>
where
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> T::Native,
LT: ArrowNumericType,
RT: ArrowNumericType,
F: Fn(LT::Native, RT::Native) -> LT::Native,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
Expand All @@ -87,7 +91,7 @@ where

let data = unsafe {
ArrayData::new_unchecked(
T::DATA_TYPE,
LT::DATA_TYPE,
left.len(),
None,
null_bit_buffer,
Expand All @@ -96,7 +100,7 @@ where
vec![],
)
};
Ok(PrimitiveArray::<T>::from(data))
Ok(PrimitiveArray::<LT>::from(data))
}

/// Helper function for operations where a valid `0` on the right array should
Expand Down Expand Up @@ -774,6 +778,54 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
DataType::Dictionary(_, _) => {
typed_dict_math_op!(left, right, |a, b| a + b, math_op_dict)
}
DataType::Date32 => {
let l = as_primitive_array::<Date32Type>(left);
match right.data_type() {
DataType::Interval(IntervalUnit::YearMonth) => {
let r = as_primitive_array::<IntervalYearMonthType>(right);
let res = math_op(l, r, Date32Type::add_year_months)?;
Ok(Arc::new(res))
}
DataType::Interval(IntervalUnit::DayTime) => {
let r = as_primitive_array::<IntervalDayTimeType>(right);
let res = math_op(l, r, Date32Type::add_day_time)?;
Ok(Arc::new(res))
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
let r = as_primitive_array::<IntervalMonthDayNanoType>(right);
let res = math_op(l, r, Date32Type::add_month_day_nano)?;
Ok(Arc::new(res))
}
_ => Err(ArrowError::CastError(format!(
"Cannot perform arithmetic operation between array of type {} and array of type {}",
left.data_type(), right.data_type()
))),
}
}
DataType::Date64 => {
let l = as_primitive_array::<Date64Type>(left);
match right.data_type() {
DataType::Interval(IntervalUnit::YearMonth) => {
let r = as_primitive_array::<IntervalYearMonthType>(right);
let res = math_op(l, r, Date64Type::add_year_months)?;
Ok(Arc::new(res))
}
DataType::Interval(IntervalUnit::DayTime) => {
let r = as_primitive_array::<IntervalDayTimeType>(right);
let res = math_op(l, r, Date64Type::add_day_time)?;
Ok(Arc::new(res))
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
let r = as_primitive_array::<IntervalMonthDayNanoType>(right);
let res = math_op(l, r, Date64Type::add_month_day_nano)?;
Ok(Arc::new(res))
}
_ => Err(ArrowError::CastError(format!(
"Cannot perform arithmetic operation between array of type {} and array of type {}",
left.data_type(), right.data_type()
))),
}
}
_ => typed_math_op!(left, right, |a, b| a + b, math_op),
}
}
Expand Down Expand Up @@ -1055,6 +1107,8 @@ where
mod tests {
use super::*;
use crate::array::Int32Array;
use crate::datatypes::Date64Type;
use chrono::NaiveDate;

#[test]
fn test_primitive_array_add() {
Expand All @@ -1068,6 +1122,98 @@ mod tests {
assert_eq!(17, c.value(4));
}

#[test]
fn test_date32_month_add() {
let a = Date32Array::from(vec![Date32Type::from_naive_date(
NaiveDate::from_ymd(2000, 1, 1),
)]);
let b =
IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 2)]);
let c = add_dyn(&a, &b).unwrap();
let c = c.as_any().downcast_ref::<Date32Array>().unwrap();
assert_eq!(
c.value(0),
Date32Type::from_naive_date(NaiveDate::from_ymd(2001, 3, 1))
);
}

#[test]
fn test_date32_day_time_add() {
let a = Date32Array::from(vec![Date32Type::from_naive_date(
NaiveDate::from_ymd(2000, 1, 1),
)]);
let b = IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]);
let c = add_dyn(&a, &b).unwrap();
let c = c.as_any().downcast_ref::<Date32Array>().unwrap();
assert_eq!(
c.value(0),
Date32Type::from_naive_date(NaiveDate::from_ymd(2000, 1, 2))
);
}

#[test]
fn test_date32_month_day_nano_add() {
let a = Date32Array::from(vec![Date32Type::from_naive_date(
NaiveDate::from_ymd(2000, 1, 1),
)]);
let b =
IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNanoType::make_value(
1, 2, 3,
)]);
let c = add_dyn(&a, &b).unwrap();
let c = c.as_any().downcast_ref::<Date32Array>().unwrap();
assert_eq!(
c.value(0),
Date32Type::from_naive_date(NaiveDate::from_ymd(2000, 2, 3))
);
}

#[test]
fn test_date64_month_add() {
let a = Date64Array::from(vec![Date64Type::from_naive_date(
NaiveDate::from_ymd(2000, 1, 1),
)]);
let b =
IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 2)]);
let c = add_dyn(&a, &b).unwrap();
let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
c.value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd(2001, 3, 1))
);
}

#[test]
fn test_date64_day_time_add() {
let a = Date64Array::from(vec![Date64Type::from_naive_date(
NaiveDate::from_ymd(2000, 1, 1),
)]);
let b = IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]);
let c = add_dyn(&a, &b).unwrap();
let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
c.value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd(2000, 1, 2))
);
}

#[test]
fn test_date64_month_day_nano_add() {
let a = Date64Array::from(vec![Date64Type::from_naive_date(
NaiveDate::from_ymd(2000, 1, 1),
)]);
let b =
IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNanoType::make_value(
1, 2, 3,
)]);
let c = add_dyn(&a, &b).unwrap();
let c = c.as_any().downcast_ref::<Date64Array>().unwrap();
assert_eq!(
c.value(0),
Date64Type::from_naive_date(NaiveDate::from_ymd(2000, 2, 3))
);
}

#[test]
fn test_primitive_array_add_dyn() {
let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]);
Expand Down
182 changes: 182 additions & 0 deletions arrow/src/datatypes/delta.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// MIT License
//
// Copyright (c) 2020-2022 Oliver Margetts
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// Copied from chronoutil crate

//! Contains utility functions for shifting Date objects.
use chrono::Datelike;

/// Returns true if the year is a leap-year, as naively defined in the Gregorian calendar.
#[inline]
pub(crate) fn is_leap_year(year: i32) -> bool {
year % 4 == 0 && (year % 100 != 0 || year % 400 == 0)
}

// If the day lies within the month, this function has no effect. Otherwise, it shifts
// day backwards to the final day of the month.
// XXX: No attempt is made to handle days outside the 1-31 range.
#[inline]
fn normalise_day(year: i32, month: u32, day: u32) -> u32 {
if day <= 28 {
day
} else if month == 2 {
28 + is_leap_year(year) as u32
} else if day == 31 && (month == 4 || month == 6 || month == 9 || month == 11) {
30
} else {
day
}
}

/// Shift a date by the given number of months.
/// Ambiguous month-ends are shifted backwards as necessary.
pub(crate) fn shift_months<D: Datelike>(date: D, months: i32) -> D {
let mut year = date.year() + (date.month() as i32 + months) / 12;
let mut month = (date.month() as i32 + months) % 12;
let mut day = date.day();

if month < 1 {
year -= 1;
month += 12;
}

day = normalise_day(year, month as u32, day);

// This is slow but guaranteed to succeed (short of interger overflow)
if day <= 28 {
date.with_day(day)
.unwrap()
.with_month(month as u32)
.unwrap()
.with_year(year)
.unwrap()
} else {
date.with_day(1)
.unwrap()
.with_month(month as u32)
.unwrap()
.with_year(year)
.unwrap()
.with_day(day)
.unwrap()
}
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;

use chrono::naive::{NaiveDate, NaiveDateTime, NaiveTime};

use super::*;

#[test]
fn test_leap_year_cases() {
let _leap_years: Vec<i32> = vec![
1904, 1908, 1912, 1916, 1920, 1924, 1928, 1932, 1936, 1940, 1944, 1948, 1952,
1956, 1960, 1964, 1968, 1972, 1976, 1980, 1984, 1988, 1992, 1996, 2000, 2004,
2008, 2012, 2016, 2020,
];
let leap_years_1900_to_2020: HashSet<i32> = _leap_years.into_iter().collect();

for year in 1900..2021 {
assert_eq!(is_leap_year(year), leap_years_1900_to_2020.contains(&year))
}
}

#[test]
fn test_shift_months() {
let base = NaiveDate::from_ymd(2020, 1, 31);

assert_eq!(shift_months(base, 0), NaiveDate::from_ymd(2020, 1, 31));
assert_eq!(shift_months(base, 1), NaiveDate::from_ymd(2020, 2, 29));
assert_eq!(shift_months(base, 2), NaiveDate::from_ymd(2020, 3, 31));
assert_eq!(shift_months(base, 3), NaiveDate::from_ymd(2020, 4, 30));
assert_eq!(shift_months(base, 4), NaiveDate::from_ymd(2020, 5, 31));
assert_eq!(shift_months(base, 5), NaiveDate::from_ymd(2020, 6, 30));
assert_eq!(shift_months(base, 6), NaiveDate::from_ymd(2020, 7, 31));
assert_eq!(shift_months(base, 7), NaiveDate::from_ymd(2020, 8, 31));
assert_eq!(shift_months(base, 8), NaiveDate::from_ymd(2020, 9, 30));
assert_eq!(shift_months(base, 9), NaiveDate::from_ymd(2020, 10, 31));
assert_eq!(shift_months(base, 10), NaiveDate::from_ymd(2020, 11, 30));
assert_eq!(shift_months(base, 11), NaiveDate::from_ymd(2020, 12, 31));
assert_eq!(shift_months(base, 12), NaiveDate::from_ymd(2021, 1, 31));
assert_eq!(shift_months(base, 13), NaiveDate::from_ymd(2021, 2, 28));

assert_eq!(shift_months(base, -1), NaiveDate::from_ymd(2019, 12, 31));
assert_eq!(shift_months(base, -2), NaiveDate::from_ymd(2019, 11, 30));
assert_eq!(shift_months(base, -3), NaiveDate::from_ymd(2019, 10, 31));
assert_eq!(shift_months(base, -4), NaiveDate::from_ymd(2019, 9, 30));
assert_eq!(shift_months(base, -5), NaiveDate::from_ymd(2019, 8, 31));
assert_eq!(shift_months(base, -6), NaiveDate::from_ymd(2019, 7, 31));
assert_eq!(shift_months(base, -7), NaiveDate::from_ymd(2019, 6, 30));
assert_eq!(shift_months(base, -8), NaiveDate::from_ymd(2019, 5, 31));
assert_eq!(shift_months(base, -9), NaiveDate::from_ymd(2019, 4, 30));
assert_eq!(shift_months(base, -10), NaiveDate::from_ymd(2019, 3, 31));
assert_eq!(shift_months(base, -11), NaiveDate::from_ymd(2019, 2, 28));
assert_eq!(shift_months(base, -12), NaiveDate::from_ymd(2019, 1, 31));
assert_eq!(shift_months(base, -13), NaiveDate::from_ymd(2018, 12, 31));

assert_eq!(shift_months(base, 1265), NaiveDate::from_ymd(2125, 6, 30));
}

#[test]
fn test_shift_months_with_overflow() {
let base = NaiveDate::from_ymd(2020, 12, 31);

assert_eq!(shift_months(base, 0), base);
assert_eq!(shift_months(base, 1), NaiveDate::from_ymd(2021, 1, 31));
assert_eq!(shift_months(base, 2), NaiveDate::from_ymd(2021, 2, 28));
assert_eq!(shift_months(base, 12), NaiveDate::from_ymd(2021, 12, 31));
assert_eq!(shift_months(base, 18), NaiveDate::from_ymd(2022, 6, 30));

assert_eq!(shift_months(base, -1), NaiveDate::from_ymd(2020, 11, 30));
assert_eq!(shift_months(base, -2), NaiveDate::from_ymd(2020, 10, 31));
assert_eq!(shift_months(base, -10), NaiveDate::from_ymd(2020, 2, 29));
assert_eq!(shift_months(base, -12), NaiveDate::from_ymd(2019, 12, 31));
assert_eq!(shift_months(base, -18), NaiveDate::from_ymd(2019, 6, 30));
}

#[test]
fn test_shift_months_datetime() {
let date = NaiveDate::from_ymd(2020, 1, 31);
let o_clock = NaiveTime::from_hms(1, 2, 3);

let base = NaiveDateTime::new(date, o_clock);

assert_eq!(
shift_months(base, 0).date(),
NaiveDate::from_ymd(2020, 1, 31)
);
assert_eq!(
shift_months(base, 1).date(),
NaiveDate::from_ymd(2020, 2, 29)
);
assert_eq!(
shift_months(base, 2).date(),
NaiveDate::from_ymd(2020, 3, 31)
);
assert_eq!(shift_months(base, 0).time(), o_clock);
assert_eq!(shift_months(base, 1).time(), o_clock);
assert_eq!(shift_months(base, 2).time(), o_clock);
}
}
Loading

0 comments on commit cb7e5b0

Please sign in to comment.