From abf967dcaff34f0a7663dec2cad67a25b6bf04ee Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sun, 21 Feb 2021 05:35:40 -0500 Subject: [PATCH 01/54] ARROW-11651: [Rust][DataFusion] Implement Postgres String Functions: Length Functions Splitting up https://github.com/apache/arrow/pull/9243 This implements the following functions: - String functions - [x] bit_Length - [x] char_length - [x] character_length - [x] length - [x] octet_length Closes #9509 from seddonm1/length-functions Lead-authored-by: Mike Seddon Co-authored-by: Jorge C. Leitao Signed-off-by: Andrew Lamb --- rust/arrow/Cargo.toml | 4 + rust/arrow/benches/bit_length_kernel.rs | 46 +++ rust/arrow/src/compute/kernels/length.rs | 268 +++++++++++++++--- rust/datafusion/Cargo.toml | 1 + rust/datafusion/README.md | 6 +- rust/datafusion/src/logical_plan/expr.rs | 22 +- rust/datafusion/src/logical_plan/mod.rs | 11 +- .../datafusion/src/physical_plan/functions.rs | 263 +++++++++++++---- .../src/physical_plan/string_expressions.rs | 29 +- rust/datafusion/src/prelude.rs | 6 +- 10 files changed, 528 insertions(+), 128 deletions(-) create mode 100644 rust/arrow/benches/bit_length_kernel.rs diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 0b14b5bfae8e1..5ab1f8cc02b30 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -114,6 +114,10 @@ harness = false name = "length_kernel" harness = false +[[bench]] +name = "bit_length_kernel" +harness = false + [[bench]] name = "sort_kernel" harness = false diff --git a/rust/arrow/benches/bit_length_kernel.rs b/rust/arrow/benches/bit_length_kernel.rs new file mode 100644 index 0000000000000..51d3134571260 --- /dev/null +++ b/rust/arrow/benches/bit_length_kernel.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +extern crate arrow; + +use arrow::{array::*, compute::kernels::length::bit_length}; + +fn bench_bit_length(array: &StringArray) { + criterion::black_box(bit_length(array).unwrap()); +} + +fn add_benchmark(c: &mut Criterion) { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // double ["hello", " ", "world", "!"] 10 times + let mut values = vec!["one", "on", "o", ""]; + for _ in 0..10 { + values = double_vec(values); + } + let array = StringArray::from(values); + + c.bench_function("bit_length", |b| b.iter(|| bench_bit_length(&array))); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index 740bb2b68c8a0..ed1fda4a06203 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -17,26 +17,33 @@ //! 
Defines kernel for length of a string array -use crate::{array::*, buffer::Buffer}; use crate::{ - datatypes::DataType, + array::*, + buffer::Buffer, + datatypes::{ArrowNativeType, ArrowPrimitiveType}, +}; +use crate::{ + datatypes::{DataType, Int32Type, Int64Type}, error::{ArrowError, Result}, }; use std::sync::Arc; -#[allow(clippy::unnecessary_wraps)] -fn length_string(array: &Array, data_type: DataType) -> Result +fn unary_offsets_string( + array: &GenericStringArray, + data_type: DataType, + op: F, +) -> ArrayRef where - OffsetSize: OffsetSizeTrait, + O: StringOffsetSizeTrait + ArrowNativeType, + F: Fn(O) -> O, { // note: offsets are stored as u8, but they can be interpreted as OffsetSize let offsets = &array.data_ref().buffers()[0]; // this is a 30% improvement over iterating over u8s and building OffsetSize, which // justifies the usage of `unsafe`. - let slice: &[OffsetSize] = - &unsafe { offsets.typed_data::() }[array.offset()..]; + let slice: &[O] = &unsafe { offsets.typed_data::() }[array.offset()..]; - let lengths = slice.windows(2).map(|offset| offset[1] - offset[0]); + let lengths = slice.windows(2).map(|offset| op(offset[1] - offset[0])); // JUSTIFICATION // Benefit @@ -60,18 +67,45 @@ where vec![buffer], vec![], ); - Ok(make_array(Arc::new(data))) + make_array(Arc::new(data)) } -/// Returns an array of Int32/Int64 denoting the number of characters in each string in the array. +fn octet_length( + array: &dyn Array, +) -> ArrayRef +where + T::Native: StringOffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + unary_offsets_string::(array, T::DATA_TYPE, |x| x) +} + +fn bit_length_impl( + array: &dyn Array, +) -> ArrayRef +where + T::Native: StringOffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let bits_in_bytes = O::from_usize(8).unwrap(); + unary_offsets_string::(array, T::DATA_TYPE, |x| x * bits_in_bytes) +} + +/// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array. /// /// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 /// * length of null is null. /// * length is in number of bytes pub fn length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => length_string::(array, DataType::Int32), - DataType::LargeUtf8 => length_string::(array, DataType::Int64), + DataType::Utf8 => Ok(octet_length::(array)), + DataType::LargeUtf8 => Ok(octet_length::(array)), _ => Err(ArrowError::ComputeError(format!( "length not supported for {:?}", array.data_type() @@ -79,11 +113,27 @@ pub fn length(array: &Array) -> Result { } } +/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array. +/// +/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 +/// * bit_length of null is null. 
+/// * bit_length is in number of bits +pub fn bit_length(array: &Array) -> Result { + match array.data_type() { + DataType::Utf8 => Ok(bit_length_impl::(array)), + DataType::LargeUtf8 => Ok(bit_length_impl::(array)), + _ => Err(ArrowError::ComputeError(format!( + "bit_length not supported for {:?}", + array.data_type() + ))), + } +} + #[cfg(test)] mod tests { use super::*; - fn cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { fn double_vec(v: Vec) -> Vec { [&v[..], &v[..]].concat() } @@ -105,34 +155,38 @@ mod tests { } #[test] - fn test_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = StringArray::from(input); - let result = length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value, result.value(i)); - }); - Ok(()) - }) + fn length_test_string() -> Result<()> { + length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) } #[test] - fn test_large_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = LargeStringArray::from(input); - let result = length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value as i64, result.value(i)); - }); - Ok(()) - }) - } - - fn null_cases() -> Vec<(Vec>, usize, Vec>)> { + fn length_test_large_string() -> Result<()> { + length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn length_null_cases() -> Vec<(Vec>, usize, Vec>)> { vec![( vec![Some("one"), None, Some("three"), Some("four")], 4, @@ -141,8 +195,8 @@ mod tests { } #[test] - fn null_string() -> Result<()> { - null_cases() + fn length_null_string() -> Result<()> { + length_null_cases() .into_iter() .try_for_each(|(input, len, expected)| { let array = StringArray::from(input); @@ -157,8 +211,8 @@ mod tests { } #[test] - fn null_large_string() -> Result<()> { - null_cases() + fn length_null_large_string() -> Result<()> { + length_null_cases() .into_iter() .try_for_each(|(input, len, expected)| { let array = LargeStringArray::from(input); @@ -179,7 +233,7 @@ mod tests { /// Tests that length is not valid for u64. 
#[test] - fn wrong_type() { + fn length_wrong_type() { let array: UInt64Array = vec![1u64].into(); assert!(length(&array).is_err()); @@ -187,7 +241,7 @@ mod tests { /// Tests with an offset #[test] - fn offsets() -> Result<()> { + fn length_offsets() -> Result<()> { let a = StringArray::from(vec!["hello", " ", "world"]); let b = make_array( ArrayData::builder(DataType::Utf8) @@ -203,4 +257,130 @@ mod tests { Ok(()) } + + fn bit_length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // a large array + let mut values = vec!["one", "on", "o", ""]; + let mut expected = vec![24, 16, 8, 0]; + for _ in 0..10 { + values = double_vec(values); + expected = double_vec(expected); + } + + vec![ + (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]), + (vec!["💖"], 1, vec![32]), + (vec!["josé"], 1, vec![40]), + (values, 4096, expected), + ] + } + + #[test] + fn bit_length_test_string() -> Result<()> { + bit_length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) + } + + #[test] + fn bit_length_test_large_string() -> Result<()> { + bit_length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn bit_length_null_cases() -> Vec<(Vec>, usize, Vec>)> + { + vec![( + vec![Some("one"), None, Some("three"), Some("four")], + 4, + vec![Some(24), None, Some(40), Some(32)], + )] + } + + #[test] + fn bit_length_null_string() -> Result<()> { + bit_length_null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + let expected: Int32Array = expected.into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + #[test] + fn bit_length_null_large_string() -> Result<()> { + bit_length_null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + // convert to i64 + let expected: Int64Array = expected + .iter() + .map(|e| e.map(|e| e as i64)) + .collect::>() + .into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + /// Tests that bit_length is not valid for u64. 
+ #[test] + fn bit_length_wrong_type() { + let array: UInt64Array = vec![1u64].into(); + + assert!(bit_length(&array).is_err()); + } + + /// Tests with an offset + #[test] + fn bit_length_offsets() -> Result<()> { + let a = StringArray::from(vec!["hello", " ", "world"]); + let b = make_array( + ArrayData::builder(DataType::Utf8) + .len(2) + .offset(1) + .buffers(a.data_ref().buffers().to_vec()) + .build(), + ); + let result = bit_length(b.as_ref())?; + + let expected = Int32Array::from(vec![8, 40]); + assert_eq!(expected.data(), result.data()); + + Ok(()) + } } diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index ea556662e3283..11cc63bbdc308 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -64,6 +64,7 @@ log = "^0.4" md-5 = "^0.9.1" sha2 = "^0.9.1" ordered-float = "2.0" +unicode-segmentation = "^1.7.1" [dev-dependencies] rand = "0.8" diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 7a122506e6791..b4cb04321e7b1 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -57,7 +57,11 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] UDAFs (user-defined aggregate functions) - [x] Common math functions - String functions - - [x] Length + - [x] bit_Length + - [x] char_length + - [x] character_length + - [x] length + - [x] octet_length - [x] Concatenate - Miscellaneous/Boolean functions - [x] nullif diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index ffed843d2ca37..16c01edb0668a 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -850,6 +850,8 @@ macro_rules! unary_scalar_expr { } // generate methods for creating the supported unary expressions + +// math functions unary_scalar_expr!(Sqrt, sqrt); unary_scalar_expr!(Sin, sin); unary_scalar_expr!(Cos, cos); @@ -867,24 +869,22 @@ unary_scalar_expr!(Exp, exp); unary_scalar_expr!(Log, ln); unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); + +// string functions +unary_scalar_expr!(BitLength, bit_length); +unary_scalar_expr!(CharacterLength, character_length); +unary_scalar_expr!(CharacterLength, length); unary_scalar_expr!(Lower, lower); -unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Ltrim, ltrim); -unary_scalar_expr!(Rtrim, rtrim); -unary_scalar_expr!(Upper, upper); unary_scalar_expr!(MD5, md5); +unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(Rtrim, rtrim); unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); - -/// returns the length of a string in bytes -pub fn length(e: Expr) -> Expr { - Expr::ScalarFunction { - fun: functions::BuiltinScalarFunction::Length, - args: vec![e], - } -} +unary_scalar_expr!(Trim, trim); +unary_scalar_expr!(Upper, upper); /// returns the concatenation of string expressions pub fn concat(args: Vec) -> Expr { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index fbad5e2660662..6244387e180ea 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -34,11 +34,12 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, - combine_filters, concat, cos, count, count_distinct, create_udaf, create_udf, exp, - 
exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max, - md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, sum, - tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, + abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, case, ceil, + character_length, col, combine_filters, concat, cos, count, count_distinct, + create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, + log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, + sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, + ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index c5cd01f93c596..baacf9492708e 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -45,9 +45,9 @@ use crate::{ }; use arrow::{ array::ArrayRef, - compute::kernels::length::length, + compute::kernels::length::{bit_length, length}, datatypes::TimeUnit, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Int32Type, Int64Type, Schema}, record_batch::RecordBatch, }; use fmt::{Debug, Formatter}; @@ -118,8 +118,6 @@ pub enum BuiltinScalarFunction { Abs, /// signum Signum, - /// length - Length, /// concat Concat, /// lower @@ -150,6 +148,12 @@ pub enum BuiltinScalarFunction { SHA384, /// SHA512, SHA512, + /// bit_length + BitLength, + /// character_length + CharacterLength, + /// octet_length + OctetLength, } impl fmt::Display for BuiltinScalarFunction { @@ -180,9 +184,6 @@ impl FromStr for BuiltinScalarFunction { "truc" => BuiltinScalarFunction::Trunc, "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, - "length" => BuiltinScalarFunction::Length, - "char_length" => BuiltinScalarFunction::Length, - "character_length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, "lower" => BuiltinScalarFunction::Lower, "trim" => BuiltinScalarFunction::Trim, @@ -198,6 +199,11 @@ impl FromStr for BuiltinScalarFunction { "sha256" => BuiltinScalarFunction::SHA256, "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, + "bit_length" => BuiltinScalarFunction::BitLength, + "octet_length" => BuiltinScalarFunction::OctetLength, + "length" => BuiltinScalarFunction::CharacterLength, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -231,16 +237,6 @@ pub fn return_type( // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { - BuiltinScalarFunction::Length => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, - _ => { - // this error is internal as `data_types` should have captured this. 
- return Err(DataFusionError::Internal( - "The length function can only accept strings.".to_string(), - )); - } - }), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, @@ -357,6 +353,36 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The bit_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The character_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The octet_length function can only accept strings.".to_string(), + )); + } + }), _ => Ok(DataType::Float64), } } @@ -392,7 +418,41 @@ pub fn create_physical_expr( BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, - BuiltinScalarFunction::Length => |args| match &args[0] { + BuiltinScalarFunction::Concat => string_expressions::concatenate, + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Trim => string_expressions::trim, + BuiltinScalarFunction::Ltrim => string_expressions::ltrim, + BuiltinScalarFunction::Rtrim => string_expressions::rtrim, + BuiltinScalarFunction::Upper => string_expressions::upper, + BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, + BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Array => array_expressions::array, + BuiltinScalarFunction::BitLength => |args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| (x.len() * 8) as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), + )), + _ => unreachable!(), + }, + }, + BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function( + string_expressions::character_length::, + )(args), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::character_length::, + )(args), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function character_length", + other, + ))), + }, + BuiltinScalarFunction::OctetLength => |args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32), @@ -402,17 +462,7 @@ pub fn create_physical_expr( )), _ => unreachable!(), }, - ColumnarValue::Array(v) => 
Ok(ColumnarValue::Array(length(v.as_ref())?)), }, - BuiltinScalarFunction::Concat => string_expressions::concatenate, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Trim => string_expressions::trim, - BuiltinScalarFunction::Ltrim => string_expressions::ltrim, - BuiltinScalarFunction::Rtrim => string_expressions::rtrim, - BuiltinScalarFunction::Upper => string_expressions::upper, - BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, - BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, - BuiltinScalarFunction::Array => array_expressions::array, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -439,7 +489,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), BuiltinScalarFunction::Upper | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::Length + | BuiltinScalarFunction::BitLength + | BuiltinScalarFunction::CharacterLength + | BuiltinScalarFunction::OctetLength | BuiltinScalarFunction::Trim | BuiltinScalarFunction::Ltrim | BuiltinScalarFunction::Rtrim @@ -617,48 +669,135 @@ mod tests { }; use arrow::{ array::{ - ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, + Array, ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, }; - fn generic_test_math(value: ScalarValue, expected: &str) -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - let arg = lit(value); - - let expr = create_physical_expr(&BuiltinScalarFunction::Exp, &[arg], &schema)?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Float64); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - - // downcast works - let result = result.as_any().downcast_ref::().unwrap(); - - // value is correct - assert_eq!(result.value(0).to_string(), expected); - - Ok(()) + /// $FUNC function to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result> where Result allows testing errors and Option allows testing Null + /// $EXPECTED_TYPE is the expected value type + /// $DATA_TYPE is the function to test result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! 
test_function { + ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => { + // used to provide type annotation + let expected: Result> = $EXPECTED; + + // any type works here: we evaluate against a literal of `value` + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + + let expr = + create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema)?; + + // type is correct + assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + + match expected { + Ok(expected) => { + let result = expr.evaluate(&batch)?; + let result = result.into_array(batch.num_rows()); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + // evaluate is expected error - cannot use .expect_err() due to Debug not being implemented + match expr.evaluate(&batch) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert_eq!(error.to_string(), expected_error.to_string()); + } + } + } + }; + }; } #[test] - fn test_math_function() -> Result<()> { - // 2.71828182845904523536... : https://oeis.org/A001113 - let exp_f64 = "2.718281828459045"; - let exp_f32 = "2.7182817459106445"; - generic_test_math(ScalarValue::from(1i32), exp_f64)?; - generic_test_math(ScalarValue::from(1u32), exp_f64)?; - generic_test_math(ScalarValue::from(1u64), exp_f64)?; - generic_test_math(ScalarValue::from(1f64), exp_f64)?; - generic_test_math(ScalarValue::from(1f32), exp_f32)?; + fn test_functions() -> Result<()> { + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Int32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::UInt32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::UInt64(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Float64(Some(1.0)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Float32(Some(1.0)))], + Ok(Some((1.0_f32).exp() as f64)), + f64, + Float64, + Float64Array + ); Ok(()) } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index a4ccef08681e6..81d2c67eec63b 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -24,9 +24,13 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ - array::{Array, GenericStringArray, StringArray, StringOffsetSizeTrait}, - datatypes::DataType, + array::{ + Array, ArrayRef, GenericStringArray, PrimitiveArray, StringArray, + 
StringOffsetSizeTrait, + }, + datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use unicode_segmentation::UnicodeSegmentation; use super::ColumnarValue; @@ -115,6 +119,27 @@ where } } +/// Returns number of characters in the string. +/// character_length('josé') = 4 +pub fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| { + x.map(|x: &str| T::Native::from_usize(x.graphemes(true).count()).unwrap()) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + /// concatenate string columns together. pub fn concatenate(args: &[ColumnarValue]) -> Result { // downcast all arguments to strings diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 4575de19c6607..26e03c7453e98 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,8 +28,8 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, in_list, length, lit, lower, ltrim, max, - md5, min, rtrim, sha224, sha256, sha384, sha512, sum, trim, upper, JoinType, - Partitioning, + array, avg, bit_length, character_length, col, concat, count, create_udf, in_list, + length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, sha224, sha256, + sha384, sha512, sum, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; From def1965d84812d5a37671c25c8dee57b4316b3c2 Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sun, 21 Feb 2021 05:39:41 -0500 Subject: [PATCH 02/54] ARROW-11687: [Rust][DataFusion] RepartitionExec Hanging @andygrove I found an interesting defect where the final partition of the `RepartitionExec::execute` thread spawner was consistently not being spawned via `tokio::spawn`. This meant that `RepartitionStream::poll_next` was sitting waiting forever for data that never arrived. I am unable to reproduce via DataFusion tests. It looks like a race condition where the `JoinHandle` was not being `await`ed and something strange going on with the internals of tokio like lazy evaluation? This PR fixes the problem. 
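For reference, here is a minimal, self-contained sketch of the pattern the fix adopts: keep the `JoinHandle` returned by `tokio::spawn` and `.await` it, so the spawned task is driven to completion and any panic or error surfaces instead of being silently dropped. This is illustrative only, not the actual `RepartitionExec` code; `run_partition`, the error type, and the tokio runtime setup are made up for the sketch:

```
use tokio::task::JoinHandle;

// Stand-in for streaming one input partition into the output channels.
async fn run_partition(i: usize) -> Result<(), String> {
    println!("repartitioning input partition {}", i);
    Ok(())
}

#[tokio::main]
async fn main() -> Result<(), String> {
    for i in 0..4 {
        // Before the fix the handle was discarded (`let _: JoinHandle<_> = tokio::spawn(...)`);
        // after the fix it is kept and awaited.
        let join_handle: JoinHandle<Result<(), String>> =
            tokio::spawn(async move { run_partition(i).await });

        // The outer error is a tokio JoinError (panic or cancellation); the patch
        // maps it into DataFusionError::Execution in the same way.
        join_handle.await.map_err(|e| e.to_string())??;
    }
    Ok(())
}
```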
Closes #9523 from seddonm1/tokio-race-condition Authored-by: Mike Seddon Signed-off-by: Andrew Lamb --- rust/datafusion/src/physical_plan/repartition.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/physical_plan/repartition.rs b/rust/datafusion/src/physical_plan/repartition.rs index 20e7122de1268..edabfde27c4b5 100644 --- a/rust/datafusion/src/physical_plan/repartition.rs +++ b/rust/datafusion/src/physical_plan/repartition.rs @@ -125,7 +125,7 @@ impl ExecutionPlan for RepartitionExec { let input = self.input.clone(); let mut channels = channels.clone(); let partitioning = self.partitioning.clone(); - let _: JoinHandle> = tokio::spawn(async move { + let join_handle: JoinHandle> = tokio::spawn(async move { let mut stream = input.execute(i).await?; let mut counter = 0; while let Some(result) = stream.next().await { @@ -157,6 +157,10 @@ impl ExecutionPlan for RepartitionExec { } Ok(()) }); + join_handle + .await + .map(|_| ()) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; } } From 4718c070b12069ece1ec7a818ff6abc3991329b4 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 21 Feb 2021 05:41:34 -0500 Subject: [PATCH 03/54] ARROW-11690: [Rust][DataFusion] Avoid expr copies while using builder methods This is part of a larger body of work I would like to do to DataFusion to make it more efficient and idomatic Rust. See https://issues.apache.org/jira/browse/ARROW-11689 for more context. The theme is to make the plan and expression rewriting phases of DataFusion more efficient by avoiding copies This particular PR avoids deep cloning `Expr`s when building up new exprs. While this is technically a backwards incompatible change, given there was only a single place in the datafusion codebase that needs to be updated, I think the impact will be minimal. The basic principle is if the function needs to `clone` one of its arguments, the caller should be given the choice of when to do that. Often, the caller has no more need of the object and thus can give up ownership without issue Closes #9527 from alamb/alamb/less_expr_clone Authored-by: Andrew Lamb Signed-off-by: Andrew Lamb --- rust/datafusion/src/logical_plan/expr.rs | 99 ++++++++++---------- rust/datafusion/src/physical_plan/parquet.rs | 8 +- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 16c01edb0668a..2aa4f2e45fca0 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -348,13 +348,13 @@ impl Expr { /// /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. 
- pub fn cast_to(&self, cast_to_type: &DataType, schema: &DFSchema) -> Result { + pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { - Ok(self.clone()) + Ok(self) } else if can_cast_types(&this_type, cast_to_type) { Ok(Expr::Cast { - expr: Box::new(self.clone()), + expr: Box::new(self), data_type: cast_to_type.clone(), }) } else { @@ -365,75 +365,78 @@ impl Expr { } } - /// Equal - pub fn eq(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::Eq, other) + /// Return `self == other` + pub fn eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::Eq, other) } - /// Not equal - pub fn not_eq(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::NotEq, other) + /// Return `self != other` + pub fn not_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotEq, other) } - /// Greater than - pub fn gt(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::Gt, other) + /// Return `self > other` + pub fn gt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Gt, other) } - /// Greater than or equal to - pub fn gt_eq(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::GtEq, other) + /// Return `self >= other` + pub fn gt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::GtEq, other) } - /// Less than - pub fn lt(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::Lt, other) + /// Return `self < other` + pub fn lt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Lt, other) } - /// Less than or equal to - pub fn lt_eq(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::LtEq, other) + /// Return `self <= other` + pub fn lt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::LtEq, other) } - /// And - pub fn and(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::And, other) + /// Return `self && other` + pub fn and(self, other: Expr) -> Expr { + binary_expr(self, Operator::And, other) } - /// Or - pub fn or(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::Or, other) + /// Return `self || other` + pub fn or(self, other: Expr) -> Expr { + binary_expr(self, Operator::Or, other) } - /// Not - pub fn not(&self) -> Expr { - Expr::Not(Box::new(self.clone())) + /// Return `!self` + #[allow(clippy::should_implement_trait)] + pub fn not(self) -> Expr { + Expr::Not(Box::new(self)) } - /// Calculate the modulus of two expressions - pub fn modulus(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::Modulus, other) + /// Calculate the modulus of two expressions. 
+ /// Return `self % other` + pub fn modulus(self, other: Expr) -> Expr { + binary_expr(self, Operator::Modulus, other) } - /// like (string) another expression - pub fn like(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::Like, other) + /// Return `self LIKE other` + pub fn like(self, other: Expr) -> Expr { + binary_expr(self, Operator::Like, other) } - /// not like another expression - pub fn not_like(&self, other: Expr) -> Expr { - binary_expr(self.clone(), Operator::NotLike, other) + /// Return `self NOT LIKE other` + pub fn not_like(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotLike, other) } - /// Alias - pub fn alias(&self, name: &str) -> Expr { - Expr::Alias(Box::new(self.clone()), name.to_owned()) + /// Return `self AS name` alias expression + pub fn alias(self, name: &str) -> Expr { + Expr::Alias(Box::new(self), name.to_owned()) } - /// InList - pub fn in_list(&self, list: Vec, negated: bool) -> Expr { + /// Return `self IN ` if `negated` is false, otherwise + /// return `self NOT IN `.a + pub fn in_list(self, list: Vec, negated: bool) -> Expr { Expr::InList { - expr: Box::new(self.clone()), + expr: Box::new(self), list, negated, } @@ -445,9 +448,9 @@ impl Expr { /// # use datafusion::logical_plan::col; /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST /// ``` - pub fn sort(&self, asc: bool, nulls_first: bool) -> Expr { + pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { Expr::Sort { - expr: Box::new(self.clone()), + expr: Box::new(self), asc, nulls_first, } @@ -784,7 +787,7 @@ pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { } } -/// Whether it can be represented as a literal expression +/// Trait for converting a type to a [`Literal`] literal expression. pub trait Literal { /// convert the value to a Literal expression fn lit(&self) -> Expr; diff --git a/rust/datafusion/src/physical_plan/parquet.rs b/rust/datafusion/src/physical_plan/parquet.rs index f224d28895902..6ab26c2c9e094 100644 --- a/rust/datafusion/src/physical_plan/parquet.rs +++ b/rust/datafusion/src/physical_plan/parquet.rs @@ -298,7 +298,7 @@ pub struct RowGroupPredicateBuilder { } impl RowGroupPredicateBuilder { - /// Try to create a new instance of PredicateExpressionBuilder. + /// Try to create a new instance of PredicateExpressionBuilder. /// This will translate the filter expression into a statistics predicate expression /// (for example (column / 2) = 4 becomes (column_min / 2) <= 4 && 4 <= (column_max / 2)), /// then convert it to a DataFusion PhysicalExpression and cache it for later use by build_row_group_predicate. @@ -340,11 +340,11 @@ impl RowGroupPredicateBuilder { }) } - /// Generate a predicate function used to filter row group metadata. + /// Generate a predicate function used to filter row group metadata. /// This function takes a list of all row groups as parameter, /// so that DataFusion's physical expressions can be re-used by /// generating a RecordBatch, containing statistics arrays, - /// on which the physical predicate expression is executed to generate a row group filter array. + /// on which the physical predicate expression is executed to generate a row group filter array. /// The generated filter array is then used in the returned closure to filter row groups. 
pub fn build_row_group_predicate( &self, @@ -611,7 +611,7 @@ fn build_predicate_expression( let max_column_expr = expr_builder.max_column_expr()?; min_column_expr .lt_eq(expr_builder.scalar_expr().clone()) - .and(expr_builder.scalar_expr().lt_eq(max_column_expr)) + .and(expr_builder.scalar_expr().clone().lt_eq(max_column_expr)) } Operator::Gt => { // column > literal => (min, max) > literal => max > literal From aebabca047a8adeb9f6d4fc81e29d12cc42e629b Mon Sep 17 00:00:00 2001 From: Andre Braga Reis Date: Sun, 21 Feb 2021 05:59:32 -0500 Subject: [PATCH 04/54] ARROW-11572: [Rust] Add a kernel for division by single scalar This PR proposes a `divide_scalar` kernel that divides numeric arrays by a single scalar. Benchmarks show ~40-50% gains: ``` # features = [] divide 512 time: [2.3210 us 2.3345 us 2.3490 us] divide_scalar 512 time: [1.4374 us 1.4425 us 1.4485 us] (-38%) divide_nulls 512 time: [2.1718 us 2.1799 us 2.1894 us] divide_scalar_nulls 512 time: [1.3888 us 1.3959 us 1.4036 us] (-36%) # features = ["simd"] divide 512 time: [1.0221 us 1.0348 us 1.0481 us] divide_scalar 512 time: [468.04 ns 471.36 ns 475.19 ns] (-54%) divide_nulls 512 time: [960.20 ns 964.30 ns 969.15 ns] divide_scalar_nulls 512 time: [471.33 ns 476.41 ns 482.09 ns] (-51%) ``` The speedups are due to: - checking for `DivideByZero` only once; - not having to combine two null bitmaps; - using `Simd::splat()` to fill the divisor chunks. Tests are pretty bare right now, if you think this is worth merging I'll write a few more. Closes #9454 from abreis/divide-scalar Authored-by: Andre Braga Reis Signed-off-by: Andrew Lamb --- rust/arrow/benches/arithmetic_kernels.rs | 15 +- rust/arrow/src/buffer/immutable.rs | 2 +- rust/arrow/src/compute/kernels/arithmetic.rs | 150 ++++++++++++++++++- 3 files changed, 162 insertions(+), 5 deletions(-) diff --git a/rust/arrow/benches/arithmetic_kernels.rs b/rust/arrow/benches/arithmetic_kernels.rs index a1e6ad97664ec..721157e2846a6 100644 --- a/rust/arrow/benches/arithmetic_kernels.rs +++ b/rust/arrow/benches/arithmetic_kernels.rs @@ -18,15 +18,16 @@ #[macro_use] extern crate criterion; use criterion::Criterion; +use rand::Rng; use std::sync::Arc; extern crate arrow; -use arrow::compute::kernels::arithmetic::*; use arrow::compute::kernels::limit::*; use arrow::util::bench_util::*; use arrow::{array::*, datatypes::Float32Type}; +use arrow::{compute::kernels::arithmetic::*, util::test_util::seedable_rng}; fn create_array(size: usize, with_nulls: bool) -> ArrayRef { let null_density = if with_nulls { 0.5 } else { 0.0 }; @@ -58,6 +59,11 @@ fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) { criterion::black_box(divide(&arr_a, &arr_b).unwrap()); } +fn bench_divide_scalar(array: &ArrayRef, divisor: f32) { + let array = array.as_any().downcast_ref::().unwrap(); + criterion::black_box(divide_scalar(&array, divisor).unwrap()); +} + fn bench_limit(arr_a: &ArrayRef, max: usize) { criterion::black_box(limit(arr_a, max)); } @@ -65,6 +71,7 @@ fn bench_limit(arr_a: &ArrayRef, max: usize) { fn add_benchmark(c: &mut Criterion) { let arr_a = create_array(512, false); let arr_b = create_array(512, false); + let scalar = seedable_rng().gen(); c.bench_function("add 512", |b| b.iter(|| bench_add(&arr_a, &arr_b))); c.bench_function("subtract 512", |b| { @@ -74,6 +81,9 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_multiply(&arr_a, &arr_b)) }); c.bench_function("divide 512", |b| b.iter(|| bench_divide(&arr_a, &arr_b))); + c.bench_function("divide_scalar 512", |b| { + b.iter(|| 
bench_divide_scalar(&arr_a, scalar)) + }); c.bench_function("limit 512, 512", |b| b.iter(|| bench_limit(&arr_a, 512))); let arr_a_nulls = create_array(512, false); @@ -84,6 +94,9 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("divide_nulls_512", |b| { b.iter(|| bench_divide(&arr_a_nulls, &arr_b_nulls)) }); + c.bench_function("divide_scalar_nulls_512", |b| { + b.iter(|| bench_divide_scalar(&arr_a_nulls, scalar)) + }); } criterion_group!(benches, add_benchmark); diff --git a/rust/arrow/src/buffer/immutable.rs b/rust/arrow/src/buffer/immutable.rs index df5690c06bf74..e96bc003c8b5e 100644 --- a/rust/arrow/src/buffer/immutable.rs +++ b/rust/arrow/src/buffer/immutable.rs @@ -293,7 +293,7 @@ impl Buffer { /// Creates a [`Buffer`] from an [`Iterator`] with a trusted (upper) length or errors /// if any of the items of the iterator is an error. - /// Prefer this to `collect` whenever possible, as it is faster ~60% faster. + /// Prefer this to `collect` whenever possible, as it is ~60% faster. /// # Safety /// This method assumes that the iterator's size is correct and is undefined behavior /// to use it on an iterator that reports an incorrect length. diff --git a/rust/arrow/src/compute/kernels/arithmetic.rs b/rust/arrow/src/compute/kernels/arithmetic.rs index 067756662cf06..a40e5ea430817 100644 --- a/rust/arrow/src/compute/kernels/arithmetic.rs +++ b/rust/arrow/src/compute/kernels/arithmetic.rs @@ -256,6 +256,34 @@ where Ok(PrimitiveArray::::from(Arc::new(data))) } +/// Scalar-divisor version of `math_divide`. +fn math_divide_scalar( + array: &PrimitiveArray, + divisor: T::Native, +) -> Result> +where + T: ArrowNumericType, + T::Native: Div + Zero, +{ + if divisor.is_zero() { + return Err(ArrowError::DivideByZero); + } + + let values = array.values().iter().map(|value| *value / divisor); + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + + let data = ArrayData::new( + T::DATA_TYPE, + array.len(), + None, + array.data_ref().null_buffer().cloned(), + 0, + vec![buffer], + vec![], + ); + Ok(PrimitiveArray::::from(Arc::new(data))) +} + /// SIMD vectorized version of `math_op` above. #[cfg(simd)] fn simd_math_op( @@ -387,9 +415,38 @@ where Ok(()) } -/// SIMD vectorized version of `divide`, the divide kernel needs it's own implementation as there -/// is a need to handle situations where a divide by `0` occurs. This is complicated by `NULL` -/// slots and padding. +/// Scalar-divisor version of `simd_checked_divide_remainder`. +#[cfg(simd)] +#[inline] +fn simd_checked_divide_scalar_remainder( + array_chunks: ChunksExact, + divisor: T::Native, + result_chunks: ChunksExactMut, +) -> Result<()> +where + T::Native: Zero + Div, +{ + if divisor.is_zero() { + return Err(ArrowError::DivideByZero); + } + + let result_remainder = result_chunks.into_remainder(); + let array_remainder = array_chunks.remainder(); + + result_remainder + .iter_mut() + .zip(array_remainder.iter()) + .for_each(|(result_scalar, array_scalar)| { + *result_scalar = *array_scalar / divisor; + }); + + Ok(()) +} + +/// SIMD vectorized version of `divide`. +/// +/// The divide kernels need their own implementation as there is a need to handle situations +/// where a divide by `0` occurs. This is complicated by `NULL` slots and padding. #[cfg(simd)] fn simd_divide( left: &PrimitiveArray, @@ -506,6 +563,52 @@ where Ok(PrimitiveArray::::from(Arc::new(data))) } +/// SIMD vectorized version of `divide_scalar`. 
+#[cfg(simd)] +fn simd_divide_scalar( + array: &PrimitiveArray, + divisor: T::Native, +) -> Result> +where + T: ArrowNumericType, + T::Native: One + Zero + Div, +{ + if divisor.is_zero() { + return Err(ArrowError::DivideByZero); + } + + let lanes = T::lanes(); + let buffer_size = array.len() * std::mem::size_of::(); + let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); + + let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + let mut array_chunks = array.values().chunks_exact(lanes); + + result_chunks + .borrow_mut() + .zip(array_chunks.borrow_mut()) + .for_each(|(result_slice, array_slice)| { + let simd_left = T::load(array_slice); + let simd_right = T::init(divisor); + + let simd_result = T::bin_op(simd_left, simd_right, |a, b| a / b); + T::write(simd_result, result_slice); + }); + + simd_checked_divide_scalar_remainder::(array_chunks, divisor, result_chunks)?; + + let data = ArrayData::new( + T::DATA_TYPE, + array.len(), + None, + array.data_ref().null_buffer().cloned(), + 0, + vec![result.into()], + vec![], + ); + Ok(PrimitiveArray::::from(Arc::new(data))) +} + /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. pub fn add( @@ -622,6 +725,28 @@ where return math_divide(&left, &right); } +/// Divide every value in an array by a scalar. If any value in the array is null then the +/// result is also null. If the scalar is zero then the result of this operation will be +/// `Err(ArrowError::DivideByZero)`. +pub fn divide_scalar( + array: &PrimitiveArray, + divisor: T::Native, +) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: Add + + Sub + + Mul + + Div + + Zero + + One, +{ + #[cfg(simd)] + return simd_divide_scalar(&array, divisor); + #[cfg(not(simd))] + return math_divide_scalar(&array, divisor); +} + #[cfg(test)] mod tests { use super::*; @@ -709,6 +834,15 @@ mod tests { assert_eq!(9, c.value(4)); } + #[test] + fn test_primitive_array_divide_scalar() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = divide_scalar(&a, b).unwrap(); + let expected = Int32Array::from(vec![5, 4, 3, 2, 0]); + assert_eq!(c, expected); + } + #[test] fn test_primitive_array_divide_sliced() { let a = Int32Array::from(vec![0, 0, 0, 15, 15, 8, 1, 9, 0]); @@ -740,6 +874,16 @@ mod tests { assert_eq!(true, c.is_null(5)); } + #[test] + fn test_primitive_array_divide_scalar_with_nulls() { + let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]); + let b = 3; + let c = divide_scalar(&a, b).unwrap(); + let expected = + Int32Array::from(vec![Some(5), None, Some(2), Some(0), Some(3), None]); + assert_eq!(c, expected); + } + #[test] fn test_primitive_array_divide_with_nulls_sliced() { let a = Int32Array::from(vec![ From 924449eba36acda22ccb319e8de8921c090a4cd2 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Sun, 21 Feb 2021 07:06:20 -0500 Subject: [PATCH 05/54] ARROW-11426: [Rust][DataFusion] EXTRACT support This PR starts implementing support for the `EXTRACT` syntax / execution, to retrieve date parts (hours, minutes, days, etc.) from temporal data types, with the following syntax: `EXTRACT (HOUR FROM dt)` See https://www.postgresql.org/docs/13/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT for reference This is just a first implementation, in following PRs we can extend the support to different date parts, time zones, etc. 
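Usage ends up looking like the following. This is a hedged sketch against the existing `ExecutionContext` API rather than code added by this PR; it assumes a tokio runtime (as the DataFusion examples do), the literals are arbitrary, and only the `hour`/`year` fields are implemented so far:

```
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = ExecutionContext::new();

    // EXTRACT(<field> FROM <expr>) is planned as a call to date_part, so the two
    // spellings below are equivalent; both currently return Int32.
    let df = ctx.sql(
        "SELECT EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS hr, \
                date_part('year', CAST('2000-01-01' AS DATE)) AS yr",
    )?;
    let results = df.collect().await?;
    println!("{:?}", results);
    Ok(())
}
```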
Closes #9359 from Dandandan/temporal_sql Authored-by: Heres, Daniel Signed-off-by: Andrew Lamb --- rust/datafusion/src/logical_plan/expr.rs | 4 +- rust/datafusion/src/logical_plan/mod.rs | 1 - .../src/physical_plan/datetime_expressions.rs | 109 ++++++++++++++++-- .../src/physical_plan/expressions/mod.rs | 1 - .../datafusion/src/physical_plan/functions.rs | 27 +++++ .../src/physical_plan/type_coercion.rs | 51 +++++--- rust/datafusion/src/sql/planner.rs | 7 ++ rust/datafusion/tests/sql.rs | 20 +++- 8 files changed, 192 insertions(+), 28 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 2aa4f2e45fca0..245ca3aaaa895 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1144,9 +1144,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { let expr = create_name(expr, input_schema)?; let list = list.iter().map(|expr| create_name(expr, input_schema)); if *negated { - Ok(format!("{:?} NOT IN ({:?})", expr, list)) + Ok(format!("{} NOT IN ({:?})", expr, list)) } else { - Ok(format!("{:?} IN ({:?})", expr, list)) + Ok(format!("{} IN ({:?})", expr, list)) } } other => Err(DataFusionError::NotImplemented(format!( diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 6244387e180ea..0de0a032520bc 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -29,7 +29,6 @@ mod extension; mod operators; mod plan; mod registry; - pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; diff --git a/rust/datafusion/src/physical_plan/datetime_expressions.rs b/rust/datafusion/src/physical_plan/datetime_expressions.rs index 8642e3b40e3fd..3d363ce97d216 100644 --- a/rust/datafusion/src/physical_plan/datetime_expressions.rs +++ b/rust/datafusion/src/physical_plan/datetime_expressions.rs @@ -16,27 +16,30 @@ // under the License. //! DateTime expressions - use std::sync::Arc; +use super::ColumnarValue; use crate::{ error::{DataFusionError, Result}, scalar::{ScalarType, ScalarValue}, }; -use arrow::temporal_conversions::timestamp_ns_to_datetime; +use arrow::{ + array::{Array, ArrayRef, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait}, + datatypes::{ArrowPrimitiveType, DataType, TimestampNanosecondType}, +}; use arrow::{ array::{ - Array, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait, - TimestampNanosecondArray, + Date32Array, Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, }, - datatypes::{ArrowPrimitiveType, DataType, TimestampNanosecondType}, + compute::kernels::temporal, + datatypes::TimeUnit, + temporal_conversions::timestamp_ns_to_datetime, }; use chrono::prelude::*; use chrono::Duration; use chrono::LocalResult; -use super::ColumnarValue; - #[inline] /// Accepts a string in RFC3339 / ISO8601 standard format and some /// variants and converts it to a nanosecond precision timestamp. @@ -344,6 +347,98 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { }) } +macro_rules! extract_date_part { + ($ARRAY: expr, $FN:expr) => { + match $ARRAY.data_type() { + DataType::Date32 => { + let array = $ARRAY.as_any().downcast_ref::().unwrap(); + Ok($FN(array)?) + } + DataType::Date64 => { + let array = $ARRAY.as_any().downcast_ref::().unwrap(); + Ok($FN(array)?) 
+ } + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + TimeUnit::Millisecond => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + TimeUnit::Microsecond => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + TimeUnit::Nanosecond => { + let array = $ARRAY + .as_any() + .downcast_ref::() + .unwrap(); + Ok($FN(array)?) + } + }, + datatype => Err(DataFusionError::Internal(format!( + "Extract does not support datatype {:?}", + datatype + ))), + } + }; +} + +/// DATE_PART SQL function +pub fn date_part(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Execution( + "Expected two arguments in DATE_PART".to_string(), + )); + } + let (date_part, array) = (&args[0], &args[1]); + + let date_part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = date_part { + v + } else { + return Err(DataFusionError::Execution( + "First argument of `DATE_PART` must be non-null scalar Utf8".to_string(), + )); + }; + + let is_scalar = matches!(array, ColumnarValue::Scalar(_)); + + let array = match array { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array(), + }; + + let arr = match date_part.to_lowercase().as_str() { + "hour" => extract_date_part!(array, temporal::hour), + "year" => extract_date_part!(array, temporal::year), + _ => Err(DataFusionError::Execution(format!( + "Date part '{}' not supported", + date_part + ))), + }?; + + Ok(if is_scalar { + ColumnarValue::Scalar(ScalarValue::try_from_array( + &(Arc::new(arr) as ArrayRef), + 0, + )?) + } else { + ColumnarValue::Array(Arc::new(arr)) + }) +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/rust/datafusion/src/physical_plan/expressions/mod.rs b/rust/datafusion/src/physical_plan/expressions/mod.rs index bf47aa1cfe838..fe5fea1e2e4da 100644 --- a/rust/datafusion/src/physical_plan/expressions/mod.rs +++ b/rust/datafusion/src/physical_plan/expressions/mod.rs @@ -58,7 +58,6 @@ pub use negative::{negative, NegativeExpr}; pub use not::{not, NotExpr}; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use sum::{sum_return_type, Sum}; - /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{}[{}]", name, state_name) diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index baacf9492708e..51941188bb440 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -71,6 +71,8 @@ pub enum Signature { Exact(Vec), /// fixed number of arguments of arbitrary types Any(usize), + /// One of a list of signatures + OneOf(Vec), } /// Scalar function @@ -138,6 +140,8 @@ pub enum BuiltinScalarFunction { NullIf, /// Date truncate DateTrunc, + /// Date part + DatePart, /// MD5 MD5, /// SHA224 @@ -192,6 +196,7 @@ impl FromStr for BuiltinScalarFunction { "upper" => BuiltinScalarFunction::Upper, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "date_trunc" => BuiltinScalarFunction::DateTrunc, + "date_part" => BuiltinScalarFunction::DatePart, "array" => BuiltinScalarFunction::Array, "nullif" => BuiltinScalarFunction::NullIf, "md5" => BuiltinScalarFunction::MD5, @@ -294,6 +299,7 @@ pub fn return_type( BuiltinScalarFunction::DateTrunc => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } + 
BuiltinScalarFunction::DatePart => Ok(DataType::Int32), BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( Box::new(Field::new("item", arg_types[0].clone(), true)), arg_types.len() as i32, @@ -463,6 +469,7 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::DatePart => datetime_expressions::date_part, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -507,6 +514,26 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { DataType::Utf8, DataType::Timestamp(TimeUnit::Nanosecond, None), ]), + BuiltinScalarFunction::DatePart => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Date32]), + Signature::Exact(vec![DataType::Utf8, DataType::Date64]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Second, None), + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Microsecond, None), + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Millisecond, None), + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Nanosecond, None), + ]), + ]), BuiltinScalarFunction::Array => { Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) } diff --git a/rust/datafusion/src/physical_plan/type_coercion.rs b/rust/datafusion/src/physical_plan/type_coercion.rs index a84707a48dfa5..ae920cb870f78 100644 --- a/rust/datafusion/src/physical_plan/type_coercion.rs +++ b/rust/datafusion/src/physical_plan/type_coercion.rs @@ -29,7 +29,7 @@ //! i64. However, i64 -> i32 is never performed as there are i64 //! values which can not be represented by i32 values. -use std::sync::Arc; +use std::{sync::Arc, vec}; use arrow::datatypes::{DataType, Schema, TimeUnit}; @@ -68,6 +68,32 @@ pub fn data_types( current_types: &[DataType], signature: &Signature, ) -> Result> { + let valid_types = get_valid_types(signature, current_types)?; + + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + for valid_types in valid_types { + if let Some(types) = maybe_data_types(&valid_types, ¤t_types) { + return Ok(types); + } + } + + // none possible -> Error + Err(DataFusionError::Plan(format!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, signature + ))) +} + +fn get_valid_types( + signature: &Signature, + current_types: &[DataType], +) -> Result>> { let valid_types = match signature { Signature::Variadic(valid_types) => valid_types .iter() @@ -95,23 +121,16 @@ pub fn data_types( } vec![(0..*number).map(|i| current_types[i].clone()).collect()] } - }; - - if valid_types.contains(¤t_types.to_owned()) { - return Ok(current_types.to_vec()); - } - - for valid_types in valid_types { - if let Some(types) = maybe_data_types(&valid_types, ¤t_types) { - return Ok(types); + Signature::OneOf(types) => { + let mut r = vec![]; + for s in types { + r.extend(get_valid_types(s, current_types)?); + } + r } - } + }; - // none possible -> Error - Err(DataFusionError::Plan(format!( - "Coercion from {:?} to the signature {:?} failed.", - current_types, signature - ))) + Ok(valid_types) } /// Try to coerce current_types into valid_types. 
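To make the new `Signature::OneOf` resolution easier to follow, here is a small self-contained sketch of the idea (the `DataType` and `Signature` enums and the `valid_types` helper below are simplified stand-ins, not DataFusion's actual definitions): each alternative is flattened into a list of candidate argument-type vectors, and coercion accepts the arguments as-is when the current types match one of them.

```rust
// Simplified stand-ins for illustration only -- not DataFusion's real types.
#[derive(Clone, Debug, PartialEq)]
enum DataType {
    Utf8,
    Date32,
    Date64,
    TimestampNanosecond,
}

#[derive(Clone, Debug)]
enum Signature {
    Exact(Vec<DataType>),
    OneOf(Vec<Signature>),
}

// Recursively collect every concrete argument-type list a signature accepts,
// mirroring the shape of the `get_valid_types` helper introduced above.
fn valid_types(signature: &Signature) -> Vec<Vec<DataType>> {
    match signature {
        Signature::Exact(types) => vec![types.clone()],
        Signature::OneOf(alternatives) => {
            alternatives.iter().flat_map(valid_types).collect()
        }
    }
}

fn main() {
    // Shaped like the new `date_part` signature: a Utf8 part name plus one of
    // several date/timestamp argument types.
    let date_part_like = Signature::OneOf(vec![
        Signature::Exact(vec![DataType::Utf8, DataType::Date32]),
        Signature::Exact(vec![DataType::Utf8, DataType::Date64]),
        Signature::Exact(vec![DataType::Utf8, DataType::TimestampNanosecond]),
    ]);

    let current = vec![DataType::Utf8, DataType::Date32];
    // Accept the arguments unchanged if any candidate matches exactly;
    // otherwise the real code tries implicit coercions and finally errors.
    let accepted = valid_types(&date_part_like)
        .iter()
        .any(|candidate| candidate == &current);
    assert!(accepted);
}
```

Keeping `OneOf` as a plain list of nested signatures means existing variants such as `Exact` or `Variadic` can be reused as alternatives without adding any new coercion logic.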
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index fc56052b29f01..f985b50653688 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -726,6 +726,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::Boolean(n)) => Ok(lit(*n)), SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::DatePart, + args: vec![ + Expr::Literal(ScalarValue::Utf8(Some(format!("{}", field)))), + self.sql_expr_to_logical_expr(expr)?, + ], + }), SQLExpr::Value(Value::Interval { value, diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index d5a278d9301eb..2f780b662b86c 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1717,7 +1717,7 @@ fn make_timestamp_nano_table() -> Result> { } #[tokio::test] -async fn to_timstamp() -> Result<()> { +async fn to_timestamp() -> Result<()> { let mut ctx = ExecutionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?); @@ -2134,6 +2134,24 @@ async fn crypto_expressions() -> Result<()> { Ok(()) } +#[tokio::test] +async fn extract_date_part() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT + date_part('hour', CAST('2020-01-01' AS DATE)) AS hr1, + EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE)) AS hr2, + EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS hr3, + date_part('YEAR', CAST('2000-01-01' AS DATE)) AS year1, + EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS year2 + "; + + let actual = execute(&mut ctx, sql).await; + + let expected = vec![vec!["0", "0", "12", "2000", "2020"]]; + assert_eq!(expected, actual); + Ok(()) +} + #[tokio::test] async fn in_list_array() -> Result<()> { let mut ctx = ExecutionContext::new(); From 0b838ccd668116f7b77bdb3fd388d98192fea195 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 22 Feb 2021 14:02:48 +0100 Subject: [PATCH 06/54] ARROW-11595: [C++][NIGHTLY:test-conda-cpp-valgrind] Avoid branching on potentially indeterminate values in GenerateBitsUnrolled Comparison kernels generate an output bitmap for all array values, including those masked by a null bit. This should be fine since the indeterminate bits are also masked in the output but valgrind still triggers on the branching in GenerateBitsUnrolled. Fix: replace branching with equivalent multiplication. Closes #9471 from bkietz/11595-GenerateBitsUnrolled-trig Authored-by: Benjamin Kietzman Signed-off-by: Antoine Pitrou --- cpp/src/arrow/util/bitmap_generate.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/util/bitmap_generate.h b/cpp/src/arrow/util/bitmap_generate.h index 5a146f64db3e4..129fa91323141 100644 --- a/cpp/src/arrow/util/bitmap_generate.h +++ b/cpp/src/arrow/util/bitmap_generate.h @@ -77,7 +77,7 @@ void GenerateBitsUnrolled(uint8_t* bitmap, int64_t start_offset, int64_t length, if (bit_mask != 0x01) { current_byte = *cur & BitUtil::kPrecedingBitmask[start_bit_offset]; while (bit_mask != 0 && remaining > 0) { - current_byte = g() ? (current_byte | bit_mask) : current_byte; + current_byte |= g() * bit_mask; bit_mask = static_cast(bit_mask << 1); --remaining; } @@ -100,7 +100,7 @@ void GenerateBitsUnrolled(uint8_t* bitmap, int64_t start_offset, int64_t length, current_byte = 0; bit_mask = 0x01; while (remaining_bits-- > 0) { - current_byte = g() ? 
(current_byte | bit_mask) : current_byte; + current_byte |= g() * bit_mask; bit_mask = static_cast(bit_mask << 1); } *cur++ = current_byte; From 879e32d5f561df46e8f646fea8a3b3dca97d2500 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 22 Feb 2021 08:27:38 -0500 Subject: [PATCH 07/54] ARROW-11722: [Rust] Improve error message in FFI cast. While trying to cast from pyarrow to rust I encountered the following error message: `CDataInterface("The datatype \"{}\" is still not supported in Rust implementation")`, which is not very informative. This PR is very small and prints the datatype format passed to the function in the error message. Closes #9540 from ritchie46/improve_ffi_error Authored-by: Ritchie Vink Signed-off-by: Andrew Lamb --- rust/arrow/src/ffi.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rust/arrow/src/ffi.rs b/rust/arrow/src/ffi.rs index 528a5188670de..ffaf7423b6196 100644 --- a/rust/arrow/src/ffi.rs +++ b/rust/arrow/src/ffi.rs @@ -193,11 +193,11 @@ fn to_datatype(format: &str) -> Result { "ttm" => DataType::Time32(TimeUnit::Millisecond), "ttu" => DataType::Time64(TimeUnit::Microsecond), "ttn" => DataType::Time64(TimeUnit::Nanosecond), - _ => { - return Err(ArrowError::CDataInterface( - "The datatype \"{}\" is still not supported in Rust implementation" - .to_string(), - )) + dt => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{}\" is not supported in the Rust implementation", + dt + ))) } }) } From 39b23b7c4e75c4c1a151e254c3d27d68f2d129bb Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Mon, 22 Feb 2021 09:58:16 -0500 Subject: [PATCH 08/54] ARROW-11721: [Rust] json schema inference to return Schema instead of SchemaRef Looks like there is no particular reason to return inferred schema as `SchemaRef`, so I think it's more flexible and performant to return `Schema` instead, then let caller decide whether it should be wrapped into a `SchemaRef` or not. 
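For downstream code the change looks like this; a minimal sketch of caller-side usage (not code from this patch), wrapping the inferred `Schema` only when sharing is actually needed:

```rust
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::error::Result;
use arrow::json::reader::infer_json_schema_from_seekable;

fn inferred_schema_ref() -> Result<SchemaRef> {
    let file = File::open("test/data/mixed_arrays.json")?;
    let mut reader = BufReader::new(file);
    // Now returns a plain `Schema` (and seeks the reader back to the start).
    let schema = infer_json_schema_from_seekable(&mut reader, None)?;
    // Wrapping into a `SchemaRef` is an explicit, opt-in step for the caller.
    Ok(Arc::new(schema))
}
```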
Closes #9538 from houqp/qp_json Authored-by: Qingping Hou Signed-off-by: Andrew Lamb --- rust/arrow/src/json/reader.rs | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/rust/arrow/src/json/reader.rs b/rust/arrow/src/json/reader.rs index d2cc80cc1f9e6..fd1a83751b587 100644 --- a/rust/arrow/src/json/reader.rs +++ b/rust/arrow/src/json/reader.rs @@ -161,10 +161,8 @@ fn generate_fields(spec: &HashMap) -> Result> { } /// Generate schema from JSON field names and inferred data types -fn generate_schema(spec: HashMap) -> Result { - let fields = generate_fields(&spec)?; - let schema = Schema::new(fields); - Ok(Arc::new(schema)) +fn generate_schema(spec: HashMap) -> Result { + Ok(Schema::new(generate_fields(&spec)?)) } /// JSON file reader that produces a serde_json::Value iterator from a Read trait @@ -266,7 +264,7 @@ impl<'a, R: Read> Iterator for ValueIter<'a, R> { pub fn infer_json_schema_from_seekable( reader: &mut BufReader, max_read_records: Option, -) -> Result { +) -> Result { let schema = infer_json_schema(reader, max_read_records); // return the reader seek back to the start reader.seek(SeekFrom::Start(0))?; @@ -303,7 +301,7 @@ pub fn infer_json_schema_from_seekable( pub fn infer_json_schema( reader: &mut BufReader, max_read_records: Option, -) -> Result { +) -> Result { infer_json_schema_from_iterator(ValueIter::new(reader, max_read_records)) } @@ -528,7 +526,7 @@ fn collect_field_types_from_object( /// The reason we diverge here is because we don't have utilities to deal with JSON data once it's /// interpreted as Strings. We should match Spark's behavior once we added more JSON parsing /// kernels in the future. -pub fn infer_json_schema_from_iterator(value_iter: I) -> Result +pub fn infer_json_schema_from_iterator(value_iter: I) -> Result where I: Iterator>, { @@ -559,12 +557,13 @@ where /// use arrow::json::reader::{Decoder, ValueIter, infer_json_schema}; /// use std::fs::File; /// use std::io::{BufReader, Seek, SeekFrom}; +/// use std::sync::Arc; /// /// let mut reader = /// BufReader::new(File::open("test/data/mixed_arrays.json").unwrap()); /// let inferred_schema = infer_json_schema(&mut reader, None).unwrap(); /// let batch_size = 1024; -/// let decoder = Decoder::new(inferred_schema, batch_size, None); +/// let decoder = Decoder::new(Arc::new(inferred_schema), batch_size, None); /// /// // seek back to start so that the original file is usable again /// reader.seek(SeekFrom::Start(0)).unwrap(); @@ -1551,7 +1550,10 @@ impl ReaderBuilder { // check if schema should be inferred let schema = match self.schema { Some(schema) => schema, - None => infer_json_schema_from_seekable(&mut buf_reader, self.max_records)?, + None => Arc::new(infer_json_schema_from_seekable( + &mut buf_reader, + self.max_records, + )?), }; Ok(Reader::from_buf_reader( @@ -1923,7 +1925,7 @@ mod tests { file.seek(SeekFrom::Start(0)).unwrap(); let reader = BufReader::new(GzDecoder::new(&file)); - let mut reader = Reader::from_buf_reader(reader, schema, 64, None); + let mut reader = Reader::from_buf_reader(reader, Arc::new(schema), 64, None); let batch_gz = reader.next().unwrap().unwrap(); for batch in vec![batch, batch_gz] { @@ -2591,13 +2593,13 @@ mod tests { BufReader::new(File::open("test/data/mixed_arrays.json").unwrap()); let inferred_schema = infer_json_schema_from_seekable(&mut reader, None).unwrap(); - assert_eq!(inferred_schema, Arc::new(schema.clone())); + assert_eq!(inferred_schema, schema); let file = 
File::open("test/data/mixed_arrays.json.gz").unwrap(); let mut reader = BufReader::new(GzDecoder::new(&file)); let inferred_schema = infer_json_schema(&mut reader, None).unwrap(); - assert_eq!(inferred_schema, Arc::new(schema)); + assert_eq!(inferred_schema, schema); } #[test] @@ -2629,7 +2631,7 @@ mod tests { ) .unwrap(); - assert_eq!(inferred_schema, Arc::new(schema)); + assert_eq!(inferred_schema, schema); } #[test] @@ -2671,7 +2673,7 @@ mod tests { ) .unwrap(); - assert_eq!(inferred_schema, Arc::new(schema)); + assert_eq!(inferred_schema, schema); } #[test] @@ -2707,7 +2709,7 @@ mod tests { ) .unwrap(); - assert_eq!(inferred_schema, Arc::new(schema)); + assert_eq!(inferred_schema, schema); } #[test] From 1f129a17645bf8f695be1b401f6dd18831bc0c5d Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 22 Feb 2021 13:41:17 -0500 Subject: [PATCH 09/54] ARROW-11694: [C++] Fix Take() with no validity bitmap but unknown null count Also ensure that SortIndices sets the result null_count to 0. Closes #9546 from pitrou/ARROW-9006-take-empty-null-bitmap Authored-by: Antoine Pitrou Signed-off-by: Benjamin Kietzman --- cpp/src/arrow/compute/exec.cc | 3 + .../arrow/compute/kernels/util_internal.cc | 2 +- .../compute/kernels/vector_selection_test.cc | 82 +++++++++++++++---- .../arrow/compute/kernels/vector_sort_test.cc | 2 + 4 files changed, 71 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index ecf3d6962f552..6443c96e91868 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -473,6 +473,9 @@ class KernelExecutorImpl : public KernelExecutor { if (validity_preallocated_) { ARROW_ASSIGN_OR_RAISE(out->buffers[0], kernel_ctx_->AllocateBitmap(length)); } + if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { + out->null_count = 0; + } for (size_t i = 0; i < data_preallocated_.size(); ++i) { const auto& prealloc = data_preallocated_[i]; if (prealloc.bit_width >= 0) { diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc index 93badbd3b253e..1656ed2aaf34a 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ -53,7 +53,7 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { arg.data += arr.offset * arg.bit_width / 8; } // This may be kUnknownNullCount - arg.null_count = arr.null_count.load(); + arg.null_count = (arg.is_valid != nullptr) ? 
arr.null_count.load() : 0; return arg; } diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index 0785a1d602e38..cf52870ed89d0 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -947,10 +947,37 @@ uint64_t GetMaxIndex(int64_t values_length) { return static_cast(values_length - 1); } +class TestTakeKernel : public ::testing::Test { + public: + void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& values, + const std::shared_ptr& indices) { + ASSERT_EQ(values->null_count(), 0); + ASSERT_EQ(indices->null_count(), 0); + auto expected = (*Take(values, indices)).make_array(); + + auto new_values = MakeArray(values->data()->Copy()); + new_values->data()->buffers[0].reset(); + new_values->data()->null_count = kUnknownNullCount; + auto new_indices = MakeArray(indices->data()->Copy()); + new_indices->data()->buffers[0].reset(); + new_indices->data()->null_count = kUnknownNullCount; + auto result = (*Take(new_values, new_indices)).make_array(); + + AssertArraysEqual(*expected, *result); + } + + void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& type, + const std::string& values, + const std::string& indices) { + TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values), + ArrayFromJSON(int16(), indices)); + } +}; + template -class TestTakeKernel : public ::testing::Test {}; +class TestTakeKernelTyped : public TestTakeKernel {}; -TEST(TestTakeKernel, TakeNull) { +TEST_F(TestTakeKernel, TakeNull) { AssertTakeNull("[null, null, null]", "[0, 1, 0]", "[null, null, null]"); AssertTakeNull("[null, null, null]", "[0, 2]", "[null, null]"); @@ -961,13 +988,13 @@ TEST(TestTakeKernel, TakeNull) { TakeJSON(boolean(), "[null, null, null]", int8(), "[0, -1, 0]", &arr)); } -TEST(TestTakeKernel, InvalidIndexType) { +TEST_F(TestTakeKernel, InvalidIndexType) { std::shared_ptr arr; ASSERT_RAISES(NotImplemented, TakeJSON(null(), "[null, null, null]", float32(), "[0.0, 1.0, 0.1]", &arr)); } -TEST(TestTakeKernel, DefaultOptions) { +TEST_F(TestTakeKernel, DefaultOptions) { auto indices = ArrayFromJSON(int8(), "[null, 2, 0, 3]"); auto values = ArrayFromJSON(int8(), "[7, 8, 9, null]"); ASSERT_OK_AND_ASSIGN(auto no_options_provided, CallFunction("take", {values, indices})); @@ -979,12 +1006,14 @@ TEST(TestTakeKernel, DefaultOptions) { AssertDatumsEqual(explicit_defaults, no_options_provided); } -TEST(TestTakeKernel, TakeBoolean) { +TEST_F(TestTakeKernel, TakeBoolean) { AssertTakeBoolean("[7, 8, 9]", "[]", "[]"); AssertTakeBoolean("[true, false, true]", "[0, 1, 0]", "[true, false, true]"); AssertTakeBoolean("[null, false, true]", "[0, 1, 0]", "[null, false, null]"); AssertTakeBoolean("[true, false, true]", "[null, 1, 0]", "[null, false, true]"); + TestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]", "[1, 0, 0]"); + std::shared_ptr arr; ASSERT_RAISES(IndexError, TakeJSON(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr)); @@ -993,7 +1022,7 @@ TEST(TestTakeKernel, TakeBoolean) { } template -class TestTakeKernelWithNumeric : public TestTakeKernel { +class TestTakeKernelWithNumeric : public TestTakeKernelTyped { protected: void AssertTake(const std::string& values, const std::string& indices, const std::string& expected) { @@ -1022,7 +1051,7 @@ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { } template -class TestTakeKernelWithString : public TestTakeKernel { +class TestTakeKernelWithString : public 
TestTakeKernelTyped { public: std::shared_ptr value_type() { return TypeTraits::type_singleton(); @@ -1057,6 +1086,9 @@ TYPED_TEST(TestTakeKernelWithString, TakeString) { this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]"); this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])"); + this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), R"(["a", "b", "c"])", + "[0, 1, 0]"); + std::shared_ptr type = this->value_type(); std::shared_ptr arr; ASSERT_RAISES(IndexError, @@ -1072,7 +1104,7 @@ TYPED_TEST(TestTakeKernelWithString, TakeDictionary) { this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]"); } -class TestTakeKernelFSB : public TestTakeKernel { +class TestTakeKernelFSB : public TestTakeKernelTyped { public: std::shared_ptr value_type() { return fixed_size_binary(3); } @@ -1087,6 +1119,9 @@ TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) { this->AssertTake(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\", null]"); this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[null, 1, 0]", R"([null, "bbb", "aaa"])"); + this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), + R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]"); + std::shared_ptr type = this->value_type(); std::shared_ptr arr; ASSERT_RAISES(IndexError, @@ -1095,7 +1130,7 @@ TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) { int64(), "[2, 5]", &arr)); } -class TestTakeKernelWithList : public TestTakeKernel {}; +class TestTakeKernelWithList : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithList, TakeListInt32) { std::string list_json = "[[], [1,2], null, [3]]"; @@ -1107,6 +1142,9 @@ TEST_F(TestTakeKernelWithList, TakeListInt32) { CheckTake(list(int32()), list_json, "[0, 1, 2, 3]", list_json); CheckTake(list(int32()), list_json, "[0, 0, 0, 0, 0, 0, 1]", "[[], [], [], [], [], [], [1, 2]]"); + + this->TestNoValidityBitmapButUnknownNullCount(list(int32()), "[[], [1,2], [3]]", + "[0, 1, 0]"); } TEST_F(TestTakeKernelWithList, TakeListListInt32) { @@ -1134,9 +1172,12 @@ TEST_F(TestTakeKernelWithList, TakeListListInt32) { CheckTake(type, list_json, "[0, 1, 2, 3]", list_json); CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]", "[[], [], [], [], [], [], [[1], [2, null, 2], []]]"); + + this->TestNoValidityBitmapButUnknownNullCount( + type, "[[[1], [2, null, 2], []], [[3, null]]]", "[0, 1, 0]"); } -class TestTakeKernelWithLargeList : public TestTakeKernel {}; +class TestTakeKernelWithLargeList : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) { std::string list_json = "[[], [1,2], null, [3]]"; @@ -1144,7 +1185,7 @@ TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) { CheckTake(large_list(int32()), list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]"); } -class TestTakeKernelWithFixedSizeList : public TestTakeKernel {}; +class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; @@ -1160,9 +1201,13 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { CheckTake( fixed_size_list(int32(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]", "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]"); + + this->TestNoValidityBitmapButUnknownNullCount(fixed_size_list(int32(), 3), + "[[1, null, 3], [4, 5, 6], [7, 8, null]]", + "[0, 1, 0]"); } -class TestTakeKernelWithMap : public TestTakeKernel {}; +class TestTakeKernelWithMap : public 
TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { std::string map_json = R"([ @@ -1196,7 +1241,7 @@ TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { ])"); } -class TestTakeKernelWithStruct : public TestTakeKernel {}; +class TestTakeKernelWithStruct : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithStruct, TakeStruct) { auto struct_type = struct_({field("a", int32()), field("b", utf8())}); @@ -1229,9 +1274,12 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) { {"a": 2, "b": "hello"}, {"a": 2, "b": "hello"} ])"); + + this->TestNoValidityBitmapButUnknownNullCount( + struct_type, R"([{"a": 1}, {"a": 2, "b": "hello"}])", "[0, 1, 0]"); } -class TestTakeKernelWithUnion : public TestTakeKernel {}; +class TestTakeKernelWithUnion : public TestTakeKernelTyped {}; // TODO: Restore Union take functionality TEST_F(TestTakeKernelWithUnion, DISABLED_TakeUnion) { @@ -1385,7 +1433,7 @@ TEST_F(TestPermutationsWithTake, InvertPermutation) { } } -class TestTakeKernelWithRecordBatch : public TestTakeKernel { +class TestTakeKernelWithRecordBatch : public TestTakeKernelTyped { public: void AssertTake(const std::shared_ptr& schm, const std::string& batch_json, const std::string& indices, const std::string& expected_batch) { @@ -1444,7 +1492,7 @@ TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) { ])"); } -class TestTakeKernelWithChunkedArray : public TestTakeKernel { +class TestTakeKernelWithChunkedArray : public TestTakeKernelTyped { public: void AssertTake(const std::shared_ptr& type, const std::vector& values, const std::string& indices, @@ -1501,7 +1549,7 @@ TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) { {"[0, 1, 0]", "[5, 1]"}, &arr)); } -class TestTakeKernelWithTable : public TestTakeKernel { +class TestTakeKernelWithTable : public TestTakeKernelTyped
{ public: void AssertTake(const std::shared_ptr& schm, const std::vector& table_json, const std::string& filter, diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 0c9cad508efc7..cbeaacf39a89c 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -132,6 +132,8 @@ class TestNthToIndices : public TestBase { protected: void AssertNthToIndicesArray(const std::shared_ptr values, int n) { ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, NthToIndices(*values, n)); + // null_count field should have been initialized to 0, for convenience + ASSERT_EQ(offsets->data()->null_count, 0); ASSERT_OK(offsets->ValidateFull()); Validate(*checked_pointer_cast(values), n, *checked_pointer_cast(offsets)); From 4abf8a6784350816b6a44e9a13388c45f5f5cdda Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 22 Feb 2021 15:51:48 -0800 Subject: [PATCH 10/54] ARROW-11223: [Java] Fix: BaseVariableWidthVector/BaseLargeVariableWidthVector setNull() and getBufferSizeFor() trigger offset buffer overflow Fix: BaseVariableWidthVector/BaseLargeVariableWidthVector setNull() and getBufferSizeFor() trigger offset buffer overflow: the issue is caused by the different allocated size of a validity buffer which is N/8 from the offset buffer (N+1)*OFFSET_WIDTH. When the ration of null versus non-null values is large and the offset buffer gets filled up before the validity buffer, the getBufferSizeFor call will trigger the overflow error. Reproduce it by: ~~~java BaseVariableWidthVector v1 = new VarCharVector("var1", allocator)) v1.setInitialCapacity(512); v1.allocateNew(); int numNullValues1 = v1.getValueCapacity() + 1; for (int i = 0; i < numNullValues1; i++) { v1.setNull(i); } v1.getBufferSizeFor(numNullValues1) // Raise error of buffer overflow ~~~ Closes #9187 from WeichenXu123/fix_bug1 Authored-by: Weichen Xu Signed-off-by: Bryan Cutler --- .../vector/BaseLargeVariableWidthVector.java | 9 ++++--- .../arrow/vector/BaseVariableWidthVector.java | 9 ++++--- .../arrow/vector/TestVectorReAlloc.java | 24 +++++++++++++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java index 583d3bb7fbbe4..64c79483288e9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java @@ -187,7 +187,8 @@ public double getDensity() { } /** - * Get the current value capacity for the vector. + * Get the current capacity which does not exceed either validity buffer or offset buffer. + * Note: Here the `getValueCapacity` has no relationship with the value buffer. * @return number of elements that vector can hold. 
*/ @Override @@ -903,7 +904,8 @@ public int getLastSet() { */ @Override public void setIndexDefined(int index) { - while (index >= getValidityBufferValueCapacity()) { + // We need to check and realloc both validity and offset buffer + while (index >= getValueCapacity()) { reallocValidityAndOffsetBuffers(); } BitVectorHelper.setBit(validityBuffer, index); @@ -1056,7 +1058,8 @@ public void setSafe(int index, ByteBuffer value, int start, int length) { * @param index position of element */ public void setNull(int index) { - while (index >= getValidityBufferValueCapacity()) { + // We need to check and realloc both validity and offset buffer + while (index >= getValueCapacity()) { reallocValidityAndOffsetBuffers(); } BitVectorHelper.unsetBit(validityBuffer, index); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index ebfe99554f3b7..7fd191967334c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -205,7 +205,8 @@ public double getDensity() { } /** - * Get the current value capacity for the vector. + * Get the current capacity which does not exceed either validity buffer or offset buffer. + * Note: Here the `getValueCapacity` has no relationship with the value buffer. * @return number of elements that vector can hold. */ @Override @@ -941,7 +942,8 @@ public long getStartEnd(int index) { */ @Override public void setIndexDefined(int index) { - while (index >= getValidityBufferValueCapacity()) { + // We need to check and realloc both validity and offset buffer + while (index >= getValueCapacity()) { reallocValidityAndOffsetBuffers(); } BitVectorHelper.setBit(validityBuffer, index); @@ -1094,7 +1096,8 @@ public void setSafe(int index, ByteBuffer value, int start, int length) { * @param index position of element */ public void setNull(int index) { - while (index >= getValidityBufferValueCapacity()) { + // We need to check and realloc both validity and offset buffer + while (index >= getValueCapacity()) { reallocValidityAndOffsetBuffers(); } BitVectorHelper.unsetBit(validityBuffer, index); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java index 2d3f82d9956ab..fae50c9dffcf9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java @@ -151,6 +151,30 @@ public void testStructType() { } } + @Test + public void testVariableWidthTypeSetNullValues() { + // Test ARROW-11223 bug is fixed + try (final BaseVariableWidthVector v1 = new VarCharVector("var1", allocator)) { + v1.setInitialCapacity(512); + v1.allocateNew(); + int numNullValues1 = v1.getValueCapacity() + 1; + for (int i = 0; i < numNullValues1; i++) { + v1.setNull(i); + } + Assert.assertTrue(v1.getBufferSizeFor(numNullValues1) > 0); + } + + try (final BaseLargeVariableWidthVector v2 = new LargeVarCharVector("var2", allocator)) { + v2.setInitialCapacity(512); + v2.allocateNew(); + int numNullValues2 = v2.getValueCapacity() + 1; + for (int i = 0; i < numNullValues2; i++) { + v2.setNull(i); + } + Assert.assertTrue(v2.getBufferSizeFor(numNullValues2) > 0); + } + } + @Test public void testFixedAllocateAfterReAlloc() throws Exception { try (final IntVector vector = new IntVector("", allocator)) 
{ From 06c795c948b594c16d3a48289519ce036a285aad Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Tue, 23 Feb 2021 10:44:39 +0100 Subject: [PATCH 11/54] ARROW-11724: [C++] Resolve namespace collisions with protobuf 3.15 Renames our alias to `afpb` to not clash with `protobuf`'s definition. Original error: ``` ../src/arrow/flight/server.cc:63:11: error: redefinition of 'pb' as an alias for a different namespace namespace pb = arrow::flight::protocol; ^ /Users/uwe/mambaforge/conda-bld/arrow-cpp-ext_1613940663852/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/include/google/protobuf/port.h:44:11: note: previously defined as an alias for 'protobuf_future_namespace_placeholder' namespace pb = ::protobuf_future_namespace_placeholder; ^ 1 error generated. ``` Closes #9542 from xhochy/ARROW-11724 Authored-by: Uwe L. Korn Signed-off-by: Antoine Pitrou --- cpp/src/arrow/flight/client.cc | 4 ++-- cpp/src/arrow/flight/flight_test.cc | 4 ++-- cpp/src/arrow/flight/serialization_internal.cc | 4 ++-- cpp/src/arrow/flight/server.cc | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index f42fbdaa9cf17..724f999fabedc 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -58,12 +58,12 @@ #include "arrow/flight/serialization_internal.h" #include "arrow/flight/types.h" -namespace pb = arrow::flight::protocol; - namespace arrow { namespace flight { +namespace pb = arrow::flight::protocol; + const char* kWriteSizeDetailTypeId = "flight::FlightWriteSizeStatusDetail"; FlightCallOptions::FlightCallOptions() diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 663fead62ce7d..2cea4c490bb49 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -51,11 +51,11 @@ #include "arrow/flight/middleware_internal.h" #include "arrow/flight/test_util.h" -namespace pb = arrow::flight::protocol; - namespace arrow { namespace flight { +namespace pb = arrow::flight::protocol; + const char kValidUsername[] = "flight_username"; const char kValidPassword[] = "flight_password"; const char kInvalidUsername[] = "invalid_flight_username"; diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc index cf5f4140c3709..8c6b737c7e5d0 100644 --- a/cpp/src/arrow/flight/serialization_internal.cc +++ b/cpp/src/arrow/flight/serialization_internal.cc @@ -53,14 +53,14 @@ #include "arrow/util/bit_util.h" #include "arrow/util/logging.h" -namespace pb = arrow::flight::protocol; - static constexpr int64_t kInt32Max = std::numeric_limits::max(); namespace arrow { namespace flight { namespace internal { +namespace pb = arrow::flight::protocol; + using arrow::ipc::IpcPayload; using google::protobuf::internal::WireFormatLite; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 4e35950c8717e..d1b4e0fa81636 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -60,11 +60,11 @@ using ServerContext = grpc::ServerContext; template using ServerWriter = grpc::ServerWriter; -namespace pb = arrow::flight::protocol; - namespace arrow { namespace flight { +namespace pb = arrow::flight::protocol; + // Macro that runs interceptors before returning the given status #define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS) \ do { \ From 
f60c0b8e30532ca8a607e1e330087b8f9e0a0673 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 23 Feb 2021 10:51:27 +0100 Subject: [PATCH 12/54] ARROW-11501: [C++] endianness check does not work on Solaris This includes the change that @kiszk suggested, also applied to the vendored fast_float header as well. That patch has been upstreamed (https://github.com/fastfloat/fast_float/pull/59). I have confirmed that this now compiles on Solaris. Closes #9549 from nealrichardson/sun-endian Authored-by: Neal Richardson Signed-off-by: Antoine Pitrou --- cpp/src/arrow/util/endian.h | 2 ++ cpp/src/arrow/vendored/fast_float/README.md | 1 + cpp/src/arrow/vendored/fast_float/float_common.h | 4 ++++ 3 files changed, 7 insertions(+) diff --git a/cpp/src/arrow/util/endian.h b/cpp/src/arrow/util/endian.h index 81577e9091f25..0cb2e44d275fa 100644 --- a/cpp/src/arrow/util/endian.h +++ b/cpp/src/arrow/util/endian.h @@ -22,6 +22,8 @@ #else #if defined(__APPLE__) || defined(__FreeBSD__) #include // IWYU pragma: keep +#elif defined(sun) || defined(__sun) +#include // IWYU pragma: keep #else #include // IWYU pragma: keep #endif diff --git a/cpp/src/arrow/vendored/fast_float/README.md b/cpp/src/arrow/vendored/fast_float/README.md index d0b249e4973a7..876d6b08a3326 100644 --- a/cpp/src/arrow/vendored/fast_float/README.md +++ b/cpp/src/arrow/vendored/fast_float/README.md @@ -5,3 +5,4 @@ See https://github.com/lemire/fast_float Changes: - enclosed in `arrow_vendored` namespace. +- changeset e0bd5735300e761d8553a24b0525dd1e856fa4ca has been applied (Solaris endian header) \ No newline at end of file diff --git a/cpp/src/arrow/vendored/fast_float/float_common.h b/cpp/src/arrow/vendored/fast_float/float_common.h index 6127fe69492da..0147468c955db 100644 --- a/cpp/src/arrow/vendored/fast_float/float_common.h +++ b/cpp/src/arrow/vendored/fast_float/float_common.h @@ -32,6 +32,10 @@ #else #if defined(__APPLE__) || defined(__FreeBSD__) #include +// Start: addition to Arrow for Solaris support +#elif defined(sun) || defined(__sun) +#include // IWYU pragma: keep +// End #else #include #endif From 9caca11aecb9e9083b47bf2bdda2647afc48cee3 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 23 Feb 2021 14:21:42 +0100 Subject: [PATCH 13/54] ARROW-11737: [C++] Patch vendored xxhash for Solaris cf. https://github.com/Cyan4973/xxHash/pull/498 and https://github.com/Cyan4973/xxHash/pull/502 Closes #9552 from nealrichardson/sun-xxhash Authored-by: Neal Richardson Signed-off-by: Antoine Pitrou --- cpp/src/arrow/vendored/xxhash/README.md | 1 + cpp/src/arrow/vendored/xxhash/xxhash.h | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/vendored/xxhash/README.md b/cpp/src/arrow/vendored/xxhash/README.md index a24fa68d868c3..6f942ede1e154 100644 --- a/cpp/src/arrow/vendored/xxhash/README.md +++ b/cpp/src/arrow/vendored/xxhash/README.md @@ -19,3 +19,4 @@ The files in this directory are vendored from xxHash git tag v0.8.0 (https://github.com/Cyan4973/xxHash). 
+Includes https://github.com/Cyan4973/xxHash/pull/502 for Solaris compatibility \ No newline at end of file diff --git a/cpp/src/arrow/vendored/xxhash/xxhash.h b/cpp/src/arrow/vendored/xxhash/xxhash.h index 2d56d23c5d0be..99b2b4b380037 100644 --- a/cpp/src/arrow/vendored/xxhash/xxhash.h +++ b/cpp/src/arrow/vendored/xxhash/xxhash.h @@ -2091,7 +2091,10 @@ XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(const XXH64_canonical_t* src /* === Compiler specifics === */ -#if defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */ +/* Patch from https://github.com/Cyan4973/xxHash/pull/498 */ +#if ((defined(sun) || defined(__sun)) && __cplusplus) /* Solaris includes __STDC_VERSION__ with C++. Tested with GCC 5.5 */ +# define XXH_RESTRICT /* disable */ +#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */ # define XXH_RESTRICT restrict #else /* Note: it might be useful to define __restrict or __restrict__ for some C++ compilers */ From 153577da120d6e11b339d8378041248b66e135be Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 23 Feb 2021 09:36:03 -0500 Subject: [PATCH 14/54] ARROW-11730: [C++] Add implicit convenience constructors for constructing Future from Status/Result This enables use of macros like ARROW_RETURN_NOT_OK and ARROW_ASSIGN_OR_RAISE. Closes #9547 from lidavidm/arrow-11730-future-ctor Lead-authored-by: David Li Co-authored-by: David Li Signed-off-by: Benjamin Kietzman --- cpp/src/arrow/util/future.h | 22 ++++++++++++++++++++++ cpp/src/arrow/util/future_test.cc | 28 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index ee053cf3096ce..f7960b9064bc6 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -493,6 +493,28 @@ class ARROW_MUST_USE_TYPE Future { }); } + /// \brief Implicit constructor to create a finished future from a value + Future(ValueType val) : Future() { // NOLINT runtime/explicit + impl_ = FutureImpl::MakeFinished(FutureState::SUCCESS); + SetResult(std::move(val)); + } + + /// \brief Implicit constructor to create a future from a Result, enabling use + /// of macros like ARROW_ASSIGN_OR_RAISE. + Future(Result res) : Future() { // NOLINT runtime/explicit + if (ARROW_PREDICT_TRUE(res.ok())) { + impl_ = FutureImpl::MakeFinished(FutureState::SUCCESS); + } else { + impl_ = FutureImpl::MakeFinished(FutureState::FAILURE); + } + SetResult(std::move(res)); + } + + /// \brief Implicit constructor to create a future from a Status, enabling use + /// of macros like ARROW_RETURN_NOT_OK. + Future(Status s) // NOLINT runtime/explicit + : Future(Result(std::move(s))) {} + protected: template struct Callback { diff --git a/cpp/src/arrow/util/future_test.cc b/cpp/src/arrow/util/future_test.cc index 97b643316a7f7..0436007a88be6 100644 --- a/cpp/src/arrow/util/future_test.cc +++ b/cpp/src/arrow/util/future_test.cc @@ -392,6 +392,34 @@ TEST(FutureSyncTest, GetStatusFuture) { } } +// Ensure the implicit convenience constructors behave as desired. 
+TEST(FutureSyncTest, ImplicitConstructors) { + { + auto fut = ([]() -> Future { + return arrow::Status::Invalid("Invalid"); + })(); + AssertFailed(fut); + ASSERT_RAISES(Invalid, fut.result()); + } + { + auto fut = ([]() -> Future { + return arrow::Result(arrow::Status::Invalid("Invalid")); + })(); + AssertFailed(fut); + ASSERT_RAISES(Invalid, fut.result()); + } + { + auto fut = ([]() -> Future { return MoveOnlyDataType(42); })(); + AssertSuccessful(fut); + } + { + auto fut = ([]() -> Future { + return arrow::Result(MoveOnlyDataType(42)); + })(); + AssertSuccessful(fut); + } +} + TEST(FutureRefTest, ChainRemoved) { // Creating a future chain should not prevent the futures from being deleted if the // entire chain is deleted From 922fb984d20080c2dc3b6b7da72c0122e0f24bae Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Tue, 23 Feb 2021 16:22:30 +0100 Subject: [PATCH 15/54] ARROW-11541: [C++][Compute] Implement tdigest kernel Implement approximate quantile kernel using t-digest util. Closes #9435 from cyb70289/quantile-approx Lead-authored-by: Yibo Cai Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api_aggregate.cc | 5 + cpp/src/arrow/compute/api_aggregate.h | 36 +++++ .../compute/kernels/aggregate_benchmark.cc | 113 ++++++++++--- .../compute/kernels/aggregate_tdigest.cc | 153 ++++++++++++++++++ .../arrow/compute/kernels/aggregate_test.cc | 121 ++++++++++---- cpp/src/arrow/compute/registry.cc | 1 + cpp/src/arrow/compute/registry_internal.h | 1 + cpp/src/arrow/util/tdigest.cc | 14 +- cpp/src/arrow/util/tdigest.h | 19 ++- cpp/src/arrow/util/tdigest_test.cc | 19 ++- docs/source/cpp/compute.rst | 2 + python/pyarrow/_compute.pyx | 19 +++ python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 8 + python/pyarrow/tests/test_compute.py | 18 +++ 16 files changed, 470 insertions(+), 61 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/aggregate_tdigest.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 4403def994932..382a851c159a5 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -366,6 +366,7 @@ if(ARROW_COMPUTE) compute/kernels/aggregate_basic.cc compute/kernels/aggregate_mode.cc compute/kernels/aggregate_quantile.cc + compute/kernels/aggregate_tdigest.cc compute/kernels/aggregate_var_std.cc compute/kernels/codegen_internal.cc compute/kernels/scalar_arithmetic.cc diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc index 586eac2eeaec3..5afa104896085 100644 --- a/cpp/src/arrow/compute/api_aggregate.cc +++ b/cpp/src/arrow/compute/api_aggregate.cc @@ -68,5 +68,10 @@ Result Quantile(const Datum& value, const QuantileOptions& options, return CallFunction("quantile", {value}, &options, ctx); } +Result TDigest(const Datum& value, const TDigestOptions& options, + ExecContext* ctx) { + return CallFunction("tdigest", {value}, &options, ctx); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 335186122fdc8..eef1587bb732b 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -127,6 +127,28 @@ struct ARROW_EXPORT QuantileOptions : public FunctionOptions { enum Interpolation interpolation; }; +/// \brief Control TDigest approximate quantile kernel behavior +/// +/// By default, returns the median value. 
+struct ARROW_EXPORT TDigestOptions : public FunctionOptions { + explicit TDigestOptions(double q = 0.5, uint32_t delta = 100, + uint32_t buffer_size = 500) + : q{q}, delta{delta}, buffer_size{buffer_size} {} + + explicit TDigestOptions(std::vector q, uint32_t delta = 100, + uint32_t buffer_size = 500) + : q{std::move(q)}, delta{delta}, buffer_size{buffer_size} {} + + static TDigestOptions Defaults() { return TDigestOptions{}; } + + /// quantile must be between 0 and 1 inclusive + std::vector q; + /// compression parameter, default 100 + uint32_t delta; + /// input buffer size, default 500 + uint32_t buffer_size; +}; + /// @} /// \brief Count non-null (or null) values in an array. @@ -270,5 +292,19 @@ Result Quantile(const Datum& value, const QuantileOptions& options = QuantileOptions::Defaults(), ExecContext* ctx = NULLPTR); +/// \brief Calculate the approximate quantiles of a numeric array with T-Digest algorithm +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see TDigestOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as an array +/// +/// \since 4.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result TDigest(const Datum& value, + const TDigestOptions& options = TDigestOptions::Defaults(), + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc b/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc index db5db543013d8..c90dd03c06ed7 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc @@ -461,11 +461,23 @@ VARIANCE_KERNEL_BENCHMARK(VarianceKernelDouble, DoubleType); // Quantile // +static std::vector deciles() { + return {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}; +} + +static std::vector centiles() { + std::vector q(101); + for (int i = 0; i <= 100; ++i) { + q[i] = i / 100.0; + } + return q; +} + template -void QuantileKernelBench(benchmark::State& state, int min, int max) { +void QuantileKernel(benchmark::State& state, int min, int max, std::vector q) { using CType = typename TypeTraits::CType; - QuantileOptions options; + QuantileOptions options(std::move(q)); RegressionArgs args(state); const int64_t array_size = args.size / sizeof(CType); auto rand = random::RandomArrayGenerator(1926); @@ -474,29 +486,90 @@ void QuantileKernelBench(benchmark::State& state, int min, int max) { for (auto _ : state) { ABORT_NOT_OK(Quantile(array, options).status()); } + state.SetItemsProcessed(state.iterations() * array_size); +} + +template +void QuantileKernelMedian(benchmark::State& state, int min, int max) { + QuantileKernel(state, min, max, {0.5}); +} + +template +void QuantileKernelMedianWide(benchmark::State& state) { + QuantileKernel(state, 0, 1 << 24, {0.5}); +} + +template +void QuantileKernelMedianNarrow(benchmark::State& state) { + QuantileKernel(state, -30000, 30000, {0.5}); +} + +template +void QuantileKernelDecilesWide(benchmark::State& state) { + QuantileKernel(state, 0, 1 << 24, deciles()); +} + +template +void QuantileKernelDecilesNarrow(benchmark::State& state) { + QuantileKernel(state, -30000, 30000, deciles()); +} + +template +void QuantileKernelCentilesWide(benchmark::State& state) { + QuantileKernel(state, 0, 1 << 24, centiles()); +} + +template +void QuantileKernelCentilesNarrow(benchmark::State& state) { + QuantileKernel(state, -30000, 30000, centiles()); } -static void 
QuantileKernelBenchArgs(benchmark::internal::Benchmark* bench) { +static void QuantileKernelArgs(benchmark::internal::Benchmark* bench) { BenchmarkSetArgsWithSizes(bench, {1 * 1024 * 1024}); } -#define QUANTILE_KERNEL_BENCHMARK_WIDE(FuncName, Type) \ - static void FuncName(benchmark::State& state) { \ - QuantileKernelBench(state, 0, 1 << 24); \ - } \ - BENCHMARK(FuncName)->Apply(QuantileKernelBenchArgs) - -#define QUANTILE_KERNEL_BENCHMARK_NARROW(FuncName, Type) \ - static void FuncName(benchmark::State& state) { \ - QuantileKernelBench(state, -30000, 30000); \ - } \ - BENCHMARK(FuncName)->Apply(QuantileKernelBenchArgs) - -QUANTILE_KERNEL_BENCHMARK_WIDE(QuantileKernelInt32Wide, Int32Type); -QUANTILE_KERNEL_BENCHMARK_NARROW(QuantileKernelInt32Narrow, Int32Type); -QUANTILE_KERNEL_BENCHMARK_WIDE(QuantileKernelInt64Wide, Int64Type); -QUANTILE_KERNEL_BENCHMARK_NARROW(QuantileKernelInt64Narrow, Int64Type); -QUANTILE_KERNEL_BENCHMARK_WIDE(QuantileKernelDouble, DoubleType); +BENCHMARK_TEMPLATE(QuantileKernelMedianNarrow, Int32Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelMedianWide, Int32Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelMedianNarrow, Int64Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelMedianWide, Int64Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelMedianWide, DoubleType)->Apply(QuantileKernelArgs); + +BENCHMARK_TEMPLATE(QuantileKernelDecilesNarrow, Int32Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelDecilesWide, Int32Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelDecilesWide, DoubleType)->Apply(QuantileKernelArgs); + +BENCHMARK_TEMPLATE(QuantileKernelCentilesNarrow, Int32Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelCentilesWide, Int32Type)->Apply(QuantileKernelArgs); +BENCHMARK_TEMPLATE(QuantileKernelCentilesWide, DoubleType)->Apply(QuantileKernelArgs); + +static void TDigestKernelDouble(benchmark::State& state, std::vector q) { + TDigestOptions options{std::move(q)}; + RegressionArgs args(state); + const int64_t array_size = args.size / sizeof(double); + auto rand = random::RandomArrayGenerator(1926); + auto array = rand.Numeric(array_size, 0, 1 << 24, args.null_proportion); + + for (auto _ : state) { + ABORT_NOT_OK(TDigest(array, options).status()); + } + state.SetItemsProcessed(state.iterations() * array_size); +} + +static void TDigestKernelDoubleMedian(benchmark::State& state) { + TDigestKernelDouble(state, {0.5}); +} + +static void TDigestKernelDoubleDeciles(benchmark::State& state) { + TDigestKernelDouble(state, deciles()); +} + +static void TDigestKernelDoubleCentiles(benchmark::State& state) { + TDigestKernelDouble(state, centiles()); +} + +BENCHMARK(TDigestKernelDoubleMedian)->Apply(QuantileKernelArgs); +BENCHMARK(TDigestKernelDoubleDeciles)->Apply(QuantileKernelArgs); +BENCHMARK(TDigestKernelDoubleCentiles)->Apply(QuantileKernelArgs); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc new file mode 100644 index 0000000000000..fc8f43b0ae2d5 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/tdigest.h" + +namespace arrow { +namespace compute { +namespace internal { + +namespace { + +using arrow::internal::TDigest; +using arrow::internal::VisitSetBitRunsVoid; + +template +struct TDigestImpl : public ScalarAggregator { + using ThisType = TDigestImpl; + using ArrayType = typename TypeTraits::ArrayType; + using CType = typename ArrowType::c_type; + + explicit TDigestImpl(const TDigestOptions& options) + : q{options.q}, tdigest{options.delta, options.buffer_size} {} + + void Consume(KernelContext*, const ExecBatch& batch) override { + const ArrayData& data = *batch[0].array(); + const CType* values = data.GetValues(1); + + if (data.length > data.GetNullCount()) { + VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length, + [&](int64_t pos, int64_t len) { + for (int64_t i = 0; i < len; ++i) { + this->tdigest.NanAdd(values[pos + i]); + } + }); + } + } + + void MergeFrom(KernelContext*, KernelState&& src) override { + auto& other = checked_cast(src); + std::vector other_tdigest; + other_tdigest.push_back(std::move(other.tdigest)); + this->tdigest.Merge(&other_tdigest); + } + + void Finalize(KernelContext* ctx, Datum* out) override { + const int64_t out_length = this->tdigest.is_empty() ? 
0 : this->q.size(); + auto out_data = ArrayData::Make(float64(), out_length, 0); + out_data->buffers.resize(2, nullptr); + + if (out_length > 0) { + KERNEL_ASSIGN_OR_RAISE(out_data->buffers[1], ctx, + ctx->Allocate(out_length * sizeof(double))); + double* out_buffer = out_data->template GetMutableValues(1); + for (int64_t i = 0; i < out_length; ++i) { + out_buffer[i] = this->tdigest.Quantile(this->q[i]); + } + } + + *out = Datum(std::move(out_data)); + } + + const std::vector& q; + TDigest tdigest; +}; + +struct TDigestInitState { + std::unique_ptr state; + KernelContext* ctx; + const DataType& in_type; + const TDigestOptions& options; + + TDigestInitState(KernelContext* ctx, const DataType& in_type, + const TDigestOptions& options) + : ctx(ctx), in_type(in_type), options(options) {} + + Status Visit(const DataType&) { + return Status::NotImplemented("No tdigest implemented"); + } + + Status Visit(const HalfFloatType&) { + return Status::NotImplemented("No tdigest implemented"); + } + + template + enable_if_t::value, Status> Visit(const Type&) { + state.reset(new TDigestImpl(options)); + return Status::OK(); + } + + std::unique_ptr Create() { + ctx->SetStatus(VisitTypeInline(in_type, this)); + return std::move(state); + } +}; + +std::unique_ptr TDigestInit(KernelContext* ctx, const KernelInitArgs& args) { + TDigestInitState visitor(ctx, *args.inputs[0].type, + static_cast(*args.options)); + return visitor.Create(); +} + +void AddTDigestKernels(KernelInit init, + const std::vector>& types, + ScalarAggregateFunction* func) { + for (const auto& ty : types) { + auto sig = KernelSignature::Make({InputType::Array(ty)}, float64()); + AddAggKernel(std::move(sig), init, func); + } +} + +const FunctionDoc tdigest_doc{ + "Approximate quantiles of a numeric array with T-Digest algorithm", + ("By default, 0.5 quantile (median) is returned.\n" + "Nulls and NaNs are ignored.\n" + "An empty array is returned if there is no valid data point."), + {"array"}, + "TDigestOptions"}; + +std::shared_ptr AddTDigestAggKernels() { + static auto default_tdigest_options = TDigestOptions::Defaults(); + auto func = std::make_shared( + "tdigest", Arity::Unary(), &tdigest_doc, &default_tdigest_options); + AddTDigestKernels(TDigestInit, NumericTypes(), func.get()); + return func; +} + +} // namespace + +void RegisterScalarAggregateTDigest(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunction(AddTDigestAggKernels())); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 47872565ab66b..e772d474909df 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -1292,7 +1292,7 @@ TYPED_TEST(TestVarStdKernelRandom, Basics) { double var_population, var_sample; using ArrayType = typename TypeTraits::ArrayType; - auto typed_array = std::static_pointer_cast(array->Slice(0, total_size)); + auto typed_array = checked_pointer_cast(array->Slice(0, total_size)); std::tie(var_population, var_sample) = WelfordVar(*typed_array); this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population); @@ -1313,7 +1313,7 @@ TEST_F(TestVarStdKernelIntegerLength, Basics) { // auto array = rand.Numeric(4000000000, min, min + 100000, 0.1); double var_population, var_sample; - auto int32_array = std::static_pointer_cast(array); + auto int32_array = checked_pointer_cast(array); std::tie(var_population, var_sample) = WelfordVar(*int32_array); 
this->AssertVarStdIs(*array, VarianceOptions{0}, var_population); @@ -1343,22 +1343,22 @@ class TestPrimitiveQuantileKernel : public ::testing::Test { ASSERT_OK(out_array->ValidateFull()); ASSERT_EQ(out_array->length(), options.q.size()); ASSERT_EQ(out_array->null_count(), 0); - ASSERT_EQ(out_array->type(), expected[0][i].type()); + AssertTypeEqual(out_array->type(), expected[0][i].type()); - if (out_array->type() == float64()) { + if (out_array->type()->Equals(float64())) { const double* quantiles = out_array->data()->GetValues(1); for (int64_t j = 0; j < out_array->length(); ++j) { const auto& numeric_scalar = - std::static_pointer_cast(expected[j][i].scalar()); + checked_pointer_cast(expected[j][i].scalar()); ASSERT_TRUE((quantiles[j] == numeric_scalar->value) || (std::isnan(quantiles[j]) && std::isnan(numeric_scalar->value))); } } else { - ASSERT_EQ(out_array->type(), type_singleton()); + AssertTypeEqual(out_array->type(), type_singleton()); const CType* quantiles = out_array->data()->GetValues(1); for (int64_t j = 0; j < out_array->length(); ++j) { const auto& numeric_scalar = - std::static_pointer_cast>(expected[j][i].scalar()); + checked_pointer_cast>(expected[j][i].scalar()); ASSERT_EQ(quantiles[j], numeric_scalar->value); } } @@ -1530,42 +1530,86 @@ TEST_F(TestInt64QuantileKernel, Int64) { #undef O #ifndef __MINGW32__ -class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel { +class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel { public: void CheckQuantiles(int64_t array_size, int64_t num_quantiles) { - auto rand = random::RandomArrayGenerator(0x5487658); + std::shared_ptr array; + std::vector quantiles; // small value range to exercise input array with equal values and histogram approach - const auto array = rand.Numeric(array_size, -100, 200, 0.1); + GenerateTestData(array_size, num_quantiles, -100, 200, &array, &quantiles); + this->AssertQuantilesAre(array, QuantileOptions{quantiles}, + NaiveQuantile(*array, quantiles, interpolations_)); + } + + void CheckTDigests(const std::vector& chunk_sizes, int64_t num_quantiles) { + int total_size = 0; + for (int size : chunk_sizes) { + total_size += size; + } + std::shared_ptr array; std::vector quantiles; - random_real(num_quantiles, 0x5487658, 0.0, 1.0, &quantiles); - // make sure to exercise 0 and 1 quantiles - *std::min_element(quantiles.begin(), quantiles.end()) = 0; - *std::max_element(quantiles.begin(), quantiles.end()) = 1; + GenerateTestData(total_size, num_quantiles, 100, 123456789, &array, &quantiles); - this->AssertQuantilesAre(array, QuantileOptions{quantiles}, - NaiveQuantile(*array, quantiles)); + total_size = 0; + ArrayVector array_vector; + for (int size : chunk_sizes) { + array_vector.emplace_back(array->Slice(total_size, size)); + total_size += size; + } + auto chunked = *ChunkedArray::Make(array_vector); + + TDigestOptions options(quantiles); + ASSERT_OK_AND_ASSIGN(Datum out, TDigest(chunked, options)); + const auto& out_array = out.make_array(); + ASSERT_OK(out_array->ValidateFull()); + ASSERT_EQ(out_array->length(), quantiles.size()); + ASSERT_EQ(out_array->null_count(), 0); + AssertTypeEqual(out_array->type(), float64()); + + // linear interpolated exact quantile as reference + std::vector> exact = + NaiveQuantile(*array, quantiles, {QuantileOptions::LINEAR}); + const double* approx = out_array->data()->GetValues(1); + for (size_t i = 0; i < quantiles.size(); ++i) { + const auto& exact_scalar = checked_pointer_cast(exact[i][0].scalar()); + const double tolerance = 
std::fabs(exact_scalar->value) * 0.05; + EXPECT_NEAR(approx[i], exact_scalar->value, tolerance) << quantiles[i]; + } } private: - std::vector> NaiveQuantile(const Array& array, - const std::vector& quantiles) { + void GenerateTestData(int64_t array_size, int64_t num_quantiles, int min, int max, + std::shared_ptr* array, std::vector* quantiles) { + auto rand = random::RandomArrayGenerator(0x5487658); + *array = rand.Float64(array_size, min, max, /*null_prob=*/0.1, /*nan_prob=*/0.2); + + random_real(num_quantiles, 0x5487658, 0.0, 1.0, quantiles); + // make sure to exercise 0 and 1 quantiles + *std::min_element(quantiles->begin(), quantiles->end()) = 0; + *std::max_element(quantiles->begin(), quantiles->end()) = 1; + } + + std::vector> NaiveQuantile( + const Array& array, const std::vector& quantiles, + const std::vector& interpolations) { // copy and sort input array - std::vector input(array.length() - array.null_count()); - const int32_t* values = array.data()->GetValues(1); + std::vector input(array.length() - array.null_count()); + const double* values = array.data()->GetValues(1); const auto bitmap = array.null_bitmap_data(); int64_t index = 0; for (int64_t i = 0; i < array.length(); ++i) { - if (BitUtil::GetBit(bitmap, i)) { + if (BitUtil::GetBit(bitmap, i) && !std::isnan(values[i])) { input[index++] = values[i]; } } + input.resize(index); std::sort(input.begin(), input.end()); std::vector> output(quantiles.size(), - std::vector(interpolations_.size())); - for (uint64_t i = 0; i < interpolations_.size(); ++i) { - const auto interp = interpolations_[i]; + std::vector(interpolations.size())); + for (uint64_t i = 0; i < interpolations.size(); ++i) { + const auto interp = interpolations[i]; for (uint64_t j = 0; j < quantiles.size(); ++j) { output[j][i] = GetQuantile(input, quantiles[j], interp); } @@ -1573,7 +1617,7 @@ class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel { return output; } - Datum GetQuantile(const std::vector& input, double q, + Datum GetQuantile(const std::vector& input, double q, enum QuantileOptions::Interpolation interp) { const double index = (input.size() - 1) * q; const uint64_t lower_index = static_cast(index); @@ -1594,14 +1638,14 @@ class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel { } case QuantileOptions::LINEAR: if (fraction == 0) { - return Datum(static_cast(input[lower_index])); + return Datum(input[lower_index]); } else { return Datum(fraction * input[lower_index + 1] + (1 - fraction) * input[lower_index]); } case QuantileOptions::MIDPOINT: if (fraction == 0) { - return Datum(static_cast(input[lower_index])); + return Datum(input[lower_index]); } else { return Datum(input[lower_index] / 2.0 + input[lower_index + 1] / 2.0); } @@ -1625,7 +1669,30 @@ TEST_F(TestRandomQuantileKernel, Histogram) { // exercise histogram approach: size >= 65536, range <= 65536 this->CheckQuantiles(/*array_size=*/80000, /*num_quantiles=*/100); } + +TEST_F(TestRandomQuantileKernel, TDigest) { + this->CheckTDigests(/*chunk_sizes=*/{12345, 6789, 8765, 4321}, /*num_quantiles=*/100); +} #endif +class TestTDigestKernel : public ::testing::Test {}; + +TEST_F(TestTDigestKernel, AllNullsOrNaNs) { + const std::vector> tests = { + {"[]"}, + {"[null, null]", "[]", "[null]"}, + {"[NaN]", "[NaN, NaN]", "[]"}, + {"[null, NaN, null]"}, + {"[NaN, NaN]", "[]", "[null]"}, + }; + + for (const auto& json : tests) { + auto chunked = ChunkedArrayFromJSON(float64(), json); + ASSERT_OK_AND_ASSIGN(Datum out, TDigest(chunked, TDigestOptions())); + 
ASSERT_OK(out.make_array()->ValidateFull()); + ASSERT_EQ(out.array()->length, 0); + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index b1e0d48ccdc2d..9385c5c2a16ee 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -130,6 +130,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarAggregateBasic(registry.get()); RegisterScalarAggregateMode(registry.get()); RegisterScalarAggregateQuantile(registry.get()); + RegisterScalarAggregateTDigest(registry.get()); RegisterScalarAggregateVariance(registry.get()); // Vector functions diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 4e39eeb820405..3b0f4475328b4 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -45,6 +45,7 @@ void RegisterVectorSort(FunctionRegistry* registry); void RegisterScalarAggregateBasic(FunctionRegistry* registry); void RegisterScalarAggregateMode(FunctionRegistry* registry); void RegisterScalarAggregateQuantile(FunctionRegistry* registry); +void RegisterScalarAggregateTDigest(FunctionRegistry* registry); void RegisterScalarAggregateVariance(FunctionRegistry* registry); } // namespace internal diff --git a/cpp/src/arrow/util/tdigest.cc b/cpp/src/arrow/util/tdigest.cc index 68385a2a578e1..5550c98c9729a 100644 --- a/cpp/src/arrow/util/tdigest.cc +++ b/cpp/src/arrow/util/tdigest.cc @@ -332,6 +332,8 @@ class TDigest::TDigestImpl { return Lerp(td[ci_left].mean, td[ci_right].mean, diff); } + double total_weight() const { return total_weight_; } + private: // must be delcared before merger_, see constructor initialization list const uint32_t delta_; @@ -352,6 +354,8 @@ TDigest::TDigest(uint32_t delta, uint32_t buffer_size) : impl_(new TDigestImpl(d } TDigest::~TDigest() = default; +TDigest::TDigest(TDigest&&) = default; +TDigest& TDigest::operator=(TDigest&&) = default; void TDigest::Reset() { input_.resize(0); @@ -368,14 +372,14 @@ void TDigest::Dump() { impl_->Dump(); } -void TDigest::Merge(std::vector>* tdigests) { +void TDigest::Merge(std::vector* tdigests) { MergeInput(); std::vector tdigest_impls; tdigest_impls.reserve(tdigests->size()); for (auto& td : *tdigests) { - td->MergeInput(); - tdigest_impls.push_back(td->impl_.get()); + td.MergeInput(); + tdigest_impls.push_back(td.impl_.get()); } impl_->Merge(tdigest_impls); } @@ -385,6 +389,10 @@ double TDigest::Quantile(double q) { return impl_->Quantile(q); } +bool TDigest::is_empty() const { + return input_.size() == 0 && impl_->total_weight() == 0; +} + void TDigest::MergeInput() { if (input_.size() > 0) { impl_->MergeInput(input_); // will mutate input_ diff --git a/cpp/src/arrow/util/tdigest.h b/cpp/src/arrow/util/tdigest.h index 329926a153697..88ca6bae2c567 100644 --- a/cpp/src/arrow/util/tdigest.h +++ b/cpp/src/arrow/util/tdigest.h @@ -40,6 +40,8 @@ class ARROW_EXPORT TDigest { public: explicit TDigest(uint32_t delta = 100, uint32_t buffer_size = 500); ~TDigest(); + TDigest(TDigest&&); + TDigest& operator=(TDigest&&); // reset and re-use this tdigest void Reset(); @@ -52,6 +54,7 @@ class ARROW_EXPORT TDigest { // buffer a single data point, consume internal buffer if full // this function is intensively called and performance critical + // call it only if you are sure no NAN exists in input data void Add(double value) { DCHECK(!std::isnan(value)) << "cannot add NAN"; if (ARROW_PREDICT_FALSE(input_.size() == input_.capacity())) { 
@@ -60,12 +63,26 @@ class ARROW_EXPORT TDigest { input_.push_back(value); } + // skip NAN on adding + template + typename std::enable_if::value>::type NanAdd(T value) { + if (!std::isnan(value)) Add(value); + } + + template + typename std::enable_if::value>::type NanAdd(T value) { + Add(static_cast(value)); + } + // merge with other t-digests, called infrequently - void Merge(std::vector>* tdigests); + void Merge(std::vector* tdigests); // calculate quantile double Quantile(double q); + // check if this tdigest contains no valid data points + bool is_empty() const; + private: // merge input data with current tdigest void MergeInput(); diff --git a/cpp/src/arrow/util/tdigest_test.cc b/cpp/src/arrow/util/tdigest_test.cc index f17024eaac5be..e9a3924f812e7 100644 --- a/cpp/src/arrow/util/tdigest_test.cc +++ b/cpp/src/arrow/util/tdigest_test.cc @@ -32,7 +32,6 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/testing/util.h" -#include "arrow/util/make_unique.h" #include "arrow/util/tdigest.h" namespace arrow { @@ -51,7 +50,7 @@ TEST(TDigestTest, SingleValue) { } TEST(TDigestTest, FewValues) { - // exact quantile at 0.1 intervanl, test sorted and unsorted input + // exact quantile at 0.1 interval, test sorted and unsorted input std::vector> values_vector = { {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {4, 1, 9, 0, 3, 2, 5, 6, 8, 7, 10}, @@ -152,13 +151,13 @@ void TestMerge(const std::vector>& values_vector, uint32_t d const std::vector quantiles = {0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 1}; - std::vector> tds; + std::vector tds; for (const auto& values : values_vector) { - auto td = make_unique(delta); + TDigest td(delta); for (double value : values) { - td->Add(value); + td.Add(value); } - ASSERT_OK(td->Validate()); + ASSERT_OK(td.Validate()); tds.push_back(std::move(td)); } @@ -181,13 +180,13 @@ void TestMerge(const std::vector>& values_vector, uint32_t d // merge into a non empty tdigest { - std::unique_ptr td = std::move(tds[0]); + TDigest td = std::move(tds[0]); tds.erase(tds.begin(), tds.begin() + 1); - td->Merge(&tds); - ASSERT_OK(td->Validate()); + td.Merge(&tds); + ASSERT_OK(td.Validate()); for (size_t i = 0; i < quantiles.size(); ++i) { const double tolerance = std::max(std::fabs(expected[i]) * error_ratio, 0.1); - EXPECT_NEAR(td->Quantile(quantiles[i]), expected[i], tolerance) << quantiles[i]; + EXPECT_NEAR(td.Quantile(quantiles[i]), expected[i], tolerance) << quantiles[i]; } } } diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index bb96dce799319..7c2eae1e63f8d 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -207,6 +207,8 @@ Aggregations +--------------------------+------------+--------------------+-----------------------+--------------------------------------------+ | sum | Unary | Numeric | Scalar Numeric (4) | | +--------------------------+------------+--------------------+-----------------------+--------------------------------------------+ +| tdigest | Unary | Numeric | Scalar Float64 | :struct:`TDigestOptions` | ++--------------------------+------------+--------------------+-----------------------+--------------------------------------------+ | variance | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | +--------------------------+------------+--------------------+-----------------------+--------------------------------------------+ diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 0e5c6779d7c56..e5a19288b876c 100644 --- 
a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -965,3 +965,22 @@ class QuantileOptions(_QuantileOptions): if not isinstance(q, (list, tuple, np.ndarray)): q = [q] self._set_options(q, interpolation) + + +cdef class _TDigestOptions(FunctionOptions): + cdef: + unique_ptr[CTDigestOptions] tdigest_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.tdigest_options.get() + + def _set_options(self, quantiles, delta, buffer_size): + self.tdigest_options.reset( + new CTDigestOptions(quantiles, delta, buffer_size)) + + +class TDigestOptions(_TDigestOptions): + def __init__(self, *, q=0.5, delta=100, buffer_size=500): + if not isinstance(q, (list, tuple, np.ndarray)): + q = [q] + self._set_options(q, delta, buffer_size) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index e1e64a6b7447f..616b2de89ec24 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -43,6 +43,7 @@ SortOptions, StrptimeOptions, TakeOptions, + TDigestOptions, TrimOptions, VarianceOptions, # Functions diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 6c1c7f671c769..e10ef1e3a5e7c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1890,6 +1890,14 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: vector[double] q CQuantileInterp interpolation + cdef cppclass CTDigestOptions \ + "arrow::compute::TDigestOptions"(CFunctionOptions): + CTDigestOptions(vector[double] q, + unsigned int delta, unsigned int buffer_size) + vector[double] q + unsigned int delta + unsigned int buffer_size + enum DatumType" arrow::Datum::type": DatumType_NONE" arrow::Datum::NONE" DatumType_SCALAR" arrow::Datum::SCALAR" diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 06a0269b54d87..673c1387c4749 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -1199,3 +1199,21 @@ def test_quantile(): pc.quantile(arr, q=1.1) with pytest.raises(ValueError, match="'zzz' is not a valid interpolation"): pc.quantile(arr, interpolation='zzz') + + +def test_tdigest(): + arr = pa.array([1, 2, 3, 4]) + result = pc.tdigest(arr) + assert result.to_pylist() == [2.5] + + arr = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]) + result = pc.tdigest(arr) + assert result.to_pylist() == [2.5] + + arr = pa.array([1, 2, 3, 4]) + result = pc.tdigest(arr, q=[0, 0.5, 1]) + assert result.to_pylist() == [1, 2.5, 4] + + arr = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]) + result = pc.tdigest(arr, q=[0, 0.5, 1]) + assert result.to_pylist() == [1, 2.5, 4] From 6a5ed0a91c36177905854a27a63d73c79bfa2ba6 Mon Sep 17 00:00:00 2001 From: Andre Braga Reis Date: Tue, 23 Feb 2021 14:53:40 -0500 Subject: [PATCH 16/54] ARROW-11725: [Rust][DataFusion] Make use of the new divide_scalar kernel in arrow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a small PR to make DataFusion use the just-merged `divide_scalar` arrow kernel (#9454). Performance-wise: * on the `arrow` side, this specialized kernel is ~40-50% faster than the standard `divide`, mostly due to not having to check for divide-by-zero on every row; * on the `datafusion` side, it can now skip the `scalar.to_array_of_size(num_rows)` allocation, which should be a decent win for operations on large arrays. 
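As a concrete illustration of the call-site difference described above, here is a minimal, hypothetical sketch of calling the new `divide_scalar` kernel directly from the `arrow` crate; the array contents and assertions are made up for the example and are not part of this patch.

```rust
use arrow::array::{Array, Int32Array};
use arrow::compute::kernels::arithmetic::divide_scalar;
use arrow::error::Result;

fn divide_column_by_two() -> Result<()> {
    // A column of values; nulls are carried through by the kernel.
    let array = Int32Array::from(vec![Some(8), Some(18), None, Some(32)]);

    // The divisor is validated once up front (returning an error if it is
    // zero) instead of being checked on every row, as `divide` must do.
    let result = divide_scalar(&array, 2)?;

    assert_eq!(result.value(0), 4);
    assert_eq!(result.value(1), 9);
    assert!(result.is_null(2));
    Ok(())
}
```
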
The eventual goal is to have `op_scalar` variants for every arithmetic operation — `divide` will show the biggest performance gains but all variants should save DataFusion a (possibly expensive) allocation. Closes #9543 from abreis/datafusion-divide-scalar Authored-by: Andre Braga Reis Signed-off-by: Andrew Lamb --- .../src/physical_plan/expressions/binary.rs | 46 +++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/rust/datafusion/src/physical_plan/expressions/binary.rs b/rust/datafusion/src/physical_plan/expressions/binary.rs index 0d503508d63db..9e048c9d4fd82 100644 --- a/rust/datafusion/src/physical_plan/expressions/binary.rs +++ b/rust/datafusion/src/physical_plan/expressions/binary.rs @@ -18,7 +18,9 @@ use std::{any::Any, sync::Arc}; use arrow::array::*; -use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract}; +use arrow::compute::kernels::arithmetic::{ + add, divide, divide_scalar, multiply, subtract, +}; use arrow::compute::kernels::boolean::{and, or}; use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ @@ -162,10 +164,10 @@ macro_rules! compute_op { macro_rules! binary_string_array_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let result = match $LEFT.data_type() { + let result: Result> = match $LEFT.data_type() { DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray), other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for scalar operation on string array", other ))), }; @@ -178,7 +180,7 @@ macro_rules! binary_string_array_op { match $LEFT.data_type() { DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray), other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for binary operation on string arrays", other ))), } @@ -202,19 +204,44 @@ macro_rules! binary_primitive_array_op { DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array), DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array), other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for binary operation on primitive arrays", other ))), } }}; } +/// Invoke a compute kernel on an array and a scalar +/// The binary_primitive_array_op_scalar macro only evaluates for primitive +/// types like integers and floats. +macro_rules! 
binary_primitive_array_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + let result: Result> = match $LEFT.data_type() { + DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), + DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), + DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), + DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array), + DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array), + DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array), + DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array), + DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array), + DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), + DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for scalar operation on primitive array", + other + ))), + }; + Some(result) + }}; +} + /// The binary_array_op_scalar macro includes types that extend beyond the primitive, /// such as Utf8 strings. #[macro_export] macro_rules! binary_array_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let result = match $LEFT.data_type() { + let result: Result> = match $LEFT.data_type() { DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), @@ -233,7 +260,7 @@ macro_rules! binary_array_op_scalar { compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for scalar operation on dyn array", other ))), }; @@ -268,7 +295,7 @@ macro_rules! binary_array_op { compute_op!($LEFT, $RIGHT, $OP, Date64Array) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for binary operation on dyn arrays", other ))), } @@ -424,6 +451,9 @@ impl PhysicalExpr for BinaryExpr { Operator::NotLike => { binary_string_array_op_scalar!(array, scalar.clone(), nlike) } + Operator::Divide => { + binary_primitive_array_op_scalar!(array, scalar.clone(), divide) + } // if scalar operation is not supported - fallback to array implementation _ => None, } From b0a2e1b5f35aae383a1ce75bea6bea142f8fa510 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Tue, 23 Feb 2021 12:45:36 -0800 Subject: [PATCH 17/54] ARROW-11743: [R] Use pkgdown's new found ability to autolink Jiras Closes #9555 from jonkeane/new-pkgdown Lead-authored-by: Jonathan Keane Co-authored-by: Neal Richardson Signed-off-by: Neal Richardson --- r/NEWS.md | 18 +++++++++--------- r/_pkgdown.yml | 2 ++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/r/NEWS.md b/r/NEWS.md index 5ec1b14106eff..65c4e2205cca3 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -43,7 +43,7 @@ * Option `arrow.skip_nul` (default `FALSE`, as in `base::scan()`) allows conversion of Arrow string (`utf8()`) type data containing embedded nul `\0` characters to R. If set to `TRUE`, nuls will be stripped and a warning is emitted if any are found. * `arrow_info()` for an overview of various run-time and build-time Arrow configurations, useful for debugging * Set environment variable `ARROW_DEFAULT_MEMORY_POOL` before loading the Arrow package to change memory allocators. 
Windows packages are built with `mimalloc`; most others are built with both `jemalloc` (used by default) and `mimalloc`. These alternative memory allocators are generally much faster than the system memory allocator, so they are used by default when available, but sometimes it is useful to turn them off for debugging purposes. To disable them, set `ARROW_DEFAULT_MEMORY_POOL=system`. -* List columns that have attributes on each element are now also included with the metadata that is saved when creating Arrow tables. This allows `sf` tibbles to faithfully preserved and roundtripped ([ARROW-10386](https://issues.apache.org/jira/browse/ARROW-10386)). +* List columns that have attributes on each element are now also included with the metadata that is saved when creating Arrow tables. This allows `sf` tibbles to faithfully preserved and roundtripped (ARROW-10386). * R metadata that exceeds 100Kb is now compressed before being written to a table; see `schema()` for more details. ## Bug fixes @@ -52,8 +52,8 @@ * C++ functions now trigger garbage collection when needed * `write_parquet()` can now write RecordBatches * Reading a Table from a RecordBatchStreamReader containing 0 batches no longer crashes -* `readr`'s `problems` attribute is removed when converting to Arrow RecordBatch and table to prevent large amounts of metadata from accumulating inadvertently ([ARROW-10624](https://issues.apache.org/jira/browse/ARROW-10624)) -* Fixed reading of compressed Feather files written with Arrow 0.17 ([ARROW-10850](https://issues.apache.org/jira/browse/ARROW-10850)) +* `readr`'s `problems` attribute is removed when converting to Arrow RecordBatch and table to prevent large amounts of metadata from accumulating inadvertently (ARROW-10624) +* Fixed reading of compressed Feather files written with Arrow 0.17 (ARROW-10850) * `SubTreeFileSystem` gains a useful print method and no longer errors when printing ## Packaging and installation @@ -284,22 +284,22 @@ See `vignette("install", package = "arrow")` for details. * The R6 classes that wrap the C++ classes are now documented and exported and have been renamed to be more R-friendly. Users of the high-level R interface in this package are not affected. Those who want to interact with the Arrow C++ API more directly should work with these objects and methods. As part of this change, many functions that instantiated these R6 objects have been removed in favor of `Class$create()` methods. Notably, `arrow::array()` and `arrow::table()` have been removed in favor of `Array$create()` and `Table$create()`, eliminating the package startup message about masking `base` functions. For more information, see the new `vignette("arrow")`. * Due to a subtle change in the Arrow message format, data written by the 0.15 version libraries may not be readable by older versions. If you need to send data to a process that uses an older version of Arrow (for example, an Apache Spark server that hasn't yet updated to Arrow 0.15), you can set the environment variable `ARROW_PRE_0_15_IPC_FORMAT=1`. -* The `as_tibble` argument in the `read_*()` functions has been renamed to `as_data_frame` ([ARROW-6337](https://issues.apache.org/jira/browse/ARROW-6337), @jameslamb) +* The `as_tibble` argument in the `read_*()` functions has been renamed to `as_data_frame` (ARROW-6337, @jameslamb) * The `arrow::Column` class has been removed, as it was removed from the C++ library ## New features * `Table` and `RecordBatch` objects have S3 methods that enable you to work with them more like `data.frame`s. 
Extract columns, subset, and so on. See `?Table` and `?RecordBatch` for examples. -* Initial implementation of bindings for the C++ File System API. ([ARROW-6348](https://issues.apache.org/jira/browse/ARROW-6348)) -* Compressed streams are now supported on Windows ([ARROW-6360](https://issues.apache.org/jira/browse/ARROW-6360)), and you can also specify a compression level ([ARROW-6533](https://issues.apache.org/jira/browse/ARROW-6533)) +* Initial implementation of bindings for the C++ File System API. (ARROW-6348) +* Compressed streams are now supported on Windows (ARROW-6360), and you can also specify a compression level (ARROW-6533) ## Other upgrades * Parquet file reading is much, much faster, thanks to improvements in the Arrow C++ library. * `read_csv_arrow()` supports more parsing options, including `col_names`, `na`, `quoted_na`, and `skip` -* `read_parquet()` and `read_feather()` can ingest data from a `raw` vector ([ARROW-6278](https://issues.apache.org/jira/browse/ARROW-6278)) -* File readers now properly handle paths that need expanding, such as `~/file.parquet` ([ARROW-6323](https://issues.apache.org/jira/browse/ARROW-6323)) -* Improved support for creating types in a schema: the types' printed names (e.g. "double") are guaranteed to be valid to use in instantiating a schema (e.g. `double()`), and time types can be created with human-friendly resolution strings ("ms", "s", etc.). ([ARROW-6338](https://issues.apache.org/jira/browse/ARROW-6338), [ARROW-6364](https://issues.apache.org/jira/browse/ARROW-6364)) +* `read_parquet()` and `read_feather()` can ingest data from a `raw` vector (ARROW-6278) +* File readers now properly handle paths that need expanding, such as `~/file.parquet` (ARROW-6323) +* Improved support for creating types in a schema: the types' printed names (e.g. "double") are guaranteed to be valid to use in instantiating a schema (e.g. `double()`), and time types can be created with human-friendly resolution strings ("ms", "s", etc.). (ARROW-6338, ARROW-6364) # arrow 0.14.1 diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index ba71aea1bf610..af10006ec0036 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -155,5 +155,7 @@ reference: - install_pyarrow repo: + jira_projects: [ARROW] url: source: https://github.com/apache/arrow/blob/master/r/ + issue: https://issues.apache.org/jira/browse/ From 3c7fb4ca18d2dc1cb592a129bc1754a4031e9ad2 Mon Sep 17 00:00:00 2001 From: Diana Clarke Date: Tue, 23 Feb 2021 17:00:22 -0500 Subject: [PATCH 18/54] ARROW-11573: [Developer][Archery] Google benchmark now reports run type See: https://issues.apache.org/jira/browse/ARROW-11573 Google Benchmark now reports run type [1], so the following code and comment can be updated. ``` Observations are found when running with `--benchmark_repetitions`. Sadly, the format mixes values and aggregates, e.g. RegressionSumKernel/32768/0 1 us 1 us 25.8077GB/s RegressionSumKernel/32768/0 1 us 1 us 25.7066GB/s RegressionSumKernel/32768/0 1 us 1 us 25.1481GB/s RegressionSumKernel/32768/0 1 us 1 us 25.846GB/s RegressionSumKernel/32768/0 1 us 1 us 25.6453GB/s RegressionSumKernel/32768/0_mean 1 us 1 us 25.6307GB/s RegressionSumKernel/32768/0_median 1 us 1 us 25.7066GB/s RegressionSumKernel/32768/0_stddev 0 us 0 us 288.046MB/s As from benchmark v1.4.1 (2019-04-24), the only way to differentiate an actual run from the aggregates, is to match on the benchmark name. The aggregates will be appended with `_$agg_name`. This class encapsulate the logic to separate runs from aggregate . 
This is hopefully avoided in benchmark's master version with a separate json attribute. ``` ``` @property def is_agg(self): """ Indicate if the observation is a run or an aggregate. """ suffixes = ["_mean", "_median", "_stddev"] return any(map(lambda x: self._name.endswith(x), suffixes)) ``` Here's example output (note the aggregate vs the actual observation): ``` {'aggregate_name': 'mean', 'cpu_time': 9818703.124999983, 'items_per_second': 26700744.55186333, 'iterations': 3, 'name': 'TakeStringRandomIndicesWithNulls/262144/0_mean', 'null_percent': 0.0, 'real_time': 10138621.349445505, 'repetitions': 0, 'run_name': 'TakeStringRandomIndicesWithNulls/262144/0', 'run_type': 'aggregate', 'size': 262144.0, 'threads': 1, 'time_unit': 'ns'}, {'cpu_time': 9718937.499999996, 'items_per_second': 26972495.707478322, 'iterations': 64, 'name': 'TakeStringRandomIndicesWithNulls/262144/0', 'null_percent': 0.0, 'real_time': 10297947.859726265, 'repetition_index': 2, 'repetitions': 0, 'run_name': 'TakeStringRandomIndicesWithNulls/262144/0', 'run_type': 'iteration', 'size': 262144.0, 'threads': 1, 'time_unit': 'ns'}, ``` [1] https://github.com/google/benchmark/commit/8688c5c4cfa1527ceca2136b2a738d9712a01890 Closes #9457 from dianaclarke/ARROW-11573 Authored-by: Diana Clarke Signed-off-by: Benjamin Kietzman --- dev/archery/archery/benchmark/google.py | 30 +++++++--------- dev/archery/tests/test_benchmarks.py | 47 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 17 deletions(-) diff --git a/dev/archery/archery/benchmark/google.py b/dev/archery/archery/benchmark/google.py index 81429ef82c9da..20d0d8fa9a6aa 100644 --- a/dev/archery/archery/benchmark/google.py +++ b/dev/archery/archery/benchmark/google.py @@ -67,8 +67,11 @@ def results(self, repetitions=1): class GoogleBenchmarkObservation: """ Represents one run of a single (google c++) benchmark. - Observations are found when running with `--benchmark_repetitions`. Sadly, - the format mixes values and aggregates, e.g. + Aggregates are reported by Google Benchmark executables alongside + other observations whenever repetitions are specified (with + `--benchmark_repetitions` on the bare benchmark, or with the + archery option `--repetitions`). Aggregate observations are not + included in `GoogleBenchmark.runs`. RegressionSumKernel/32768/0 1 us 1 us 25.8077GB/s RegressionSumKernel/32768/0 1 us 1 us 25.7066GB/s @@ -78,32 +81,25 @@ class GoogleBenchmarkObservation: RegressionSumKernel/32768/0_mean 1 us 1 us 25.6307GB/s RegressionSumKernel/32768/0_median 1 us 1 us 25.7066GB/s RegressionSumKernel/32768/0_stddev 0 us 0 us 288.046MB/s - - As from benchmark v1.4.1 (2019-04-24), the only way to differentiate an - actual run from the aggregates, is to match on the benchmark name. The - aggregates will be appended with `_$agg_name`. - - This class encapsulate the logic to separate runs from aggregate . This is - hopefully avoided in benchmark's master version with a separate json - attribute. 
""" - def __init__(self, name, real_time, cpu_time, time_unit, size=None, - bytes_per_second=None, items_per_second=None, **counters): + def __init__(self, name, real_time, cpu_time, time_unit, run_type, + size=None, bytes_per_second=None, items_per_second=None, + **counters): self._name = name self.real_time = real_time self.cpu_time = cpu_time self.time_unit = time_unit + self.run_type = run_type self.size = size self.bytes_per_second = bytes_per_second self.items_per_second = items_per_second self.counters = counters @property - def is_agg(self): + def is_aggregate(self): """ Indicate if the observation is a run or an aggregate. """ - suffixes = ["_mean", "_median", "_stddev"] - return any(map(lambda x: self._name.endswith(x), suffixes)) + return self.run_type == "aggregate" @property def is_realtime(self): @@ -113,7 +109,7 @@ def is_realtime(self): @property def name(self): name = self._name - return name.rsplit("_", maxsplit=1)[0] if self.is_agg else name + return name.rsplit("_", maxsplit=1)[0] if self.is_aggregate else name @property def time(self): @@ -153,7 +149,7 @@ def __init__(self, name, runs): """ self.name = name # exclude google benchmark aggregate artifacts - _, runs = partition(lambda b: b.is_agg, runs) + _, runs = partition(lambda b: b.is_aggregate, runs) self.runs = sorted(runs, key=lambda b: b.value) unit = self.runs[0].unit less_is_better = not unit.endswith("per_second") diff --git a/dev/archery/tests/test_benchmarks.py b/dev/archery/tests/test_benchmarks.py index e6d10121b9fd9..fefdf1eb3da00 100644 --- a/dev/archery/tests/test_benchmarks.py +++ b/dev/archery/tests/test_benchmarks.py @@ -15,8 +15,14 @@ # specific language governing permissions and limitations # under the License. +import json + from archery.benchmark.core import Benchmark, median from archery.benchmark.compare import BenchmarkComparator +from archery.benchmark.google import ( + GoogleBenchmark, GoogleBenchmarkObservation +) +from archery.utils.codec import JsonEncoder def test_benchmark_comparator(): @@ -50,3 +56,44 @@ def test_benchmark_median(): assert False except ValueError: pass + + +def test_omits_aggregates(): + name = "AllocateDeallocate/size:1048576/real_time" + google_aggregate = { + "aggregate_name": "mean", + "cpu_time": 1757.428694267678, + "iterations": 3, + "name": "AllocateDeallocate/size:1048576/real_time_mean", + "real_time": 1849.3869337041162, + "repetitions": 0, + "run_name": "AllocateDeallocate/size:1048576/real_time", + "run_type": "aggregate", + "threads": 1, + "time_unit": "ns", + } + google_result = { + "cpu_time": 1778.6004847419827, + "iterations": 352765, + "name": name, + "real_time": 1835.3137357788837, + "repetition_index": 0, + "repetitions": 0, + "run_name": "AllocateDeallocate/size:1048576/real_time", + "run_type": "iteration", + "threads": 1, + "time_unit": "ns", + } + archery_result = { + "name": name, + "unit": "ns", + "less_is_better": True, + "values": [1778.6004847419827], + } + assert google_aggregate["run_type"] == "aggregate" + assert google_result["run_type"] == "iteration" + observation1 = GoogleBenchmarkObservation(**google_aggregate) + observation2 = GoogleBenchmarkObservation(**google_result) + benchmark = GoogleBenchmark(name, [observation1, observation2]) + result = json.dumps(benchmark, cls=JsonEncoder) + assert json.loads(result) == archery_result From 73cd0bc4102b75ceff9845d62f4125e2948181f7 Mon Sep 17 00:00:00 2001 From: Ryan Jennings Date: Tue, 23 Feb 2021 17:42:35 -0500 Subject: [PATCH 19/54] ARROW-11753: [Rust][DataFusion] Add tests for when 
Datafusion qualified field names resolved Adds tests that prove ARROW-11432 is resolved. Does not resolve the issue. The added test is currently ignored, however [#ignore] can be removed on the resolution of bug or testing. Closes #9544 from TurnOfACard/arrow-11432-tests Authored-by: Ryan Jennings Signed-off-by: Andrew Lamb --- rust/datafusion/tests/sql.rs | 69 ++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 2f780b662b86c..7a0666635a2e4 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1320,6 +1320,44 @@ fn create_join_context( Ok(ctx) } +fn create_join_context_qualified() -> Result { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, true), + Field::new("b", DataType::UInt32, true), + Field::new("c", DataType::UInt32, true), + ])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), + Arc::new(UInt32Array::from(vec![10, 20, 30, 40])), + Arc::new(UInt32Array::from(vec![50, 60, 70, 80])), + ], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table)); + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, true), + Field::new("b", DataType::UInt32, true), + Field::new("c", DataType::UInt32, true), + ])); + let t2_data = RecordBatch::try_new( + t2_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 9, 4])), + Arc::new(UInt32Array::from(vec![100, 200, 300, 400])), + Arc::new(UInt32Array::from(vec![500, 600, 700, 800])), + ], + )?; + let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(t2_table)); + + Ok(ctx) +} + #[tokio::test] async fn csv_explain() { let mut ctx = ExecutionContext::new(); @@ -2237,3 +2275,34 @@ async fn in_list_scalar() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +// TODO Tests to prove correct implementation of INNER JOIN's with qualified names. +// https://issues.apache.org/jira/projects/ARROW/issues/ARROW-11432. +#[tokio::test] +#[ignore] +async fn inner_join_qualified_names() -> Result<()> { + // Setup the statements that test qualified names function correctly. + let equivalent_sql = [ + "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c + FROM t1 + INNER JOIN t2 ON t1.a = t2.a + ORDER BY t1.a", + "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c + FROM t1 + INNER JOIN t2 ON t2.a = t1.a + ORDER BY t1.a", + ]; + + let expected = vec![ + vec!["1", "10", "50", "1", "100", "500"], + vec!["2", "20", "60", "2", "20", "600"], + vec!["4", "40", "80", "4", "400", "800"], + ]; + + for sql in equivalent_sql.iter() { + let mut ctx = create_join_context_qualified()?; + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } + Ok(()) +} From 5bea62493d919dcb97ca0f22bcb7ebfc239cee25 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 23 Feb 2021 17:52:24 -0500 Subject: [PATCH 20/54] ARROW-11688: [Rust] Casts between Utf8 and LargeUtf8 This PR makes it possible to cast between `utf8` and `large-utf8` arrays. 
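For reference, a minimal sketch of exercising the new cast through the public `cast` kernel; the input strings are illustrative only.

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray};
use arrow::compute::kernels::cast::cast;
use arrow::datatypes::DataType;
use arrow::error::Result;

fn widen_offsets() -> Result<()> {
    // A utf8 array with i32 offsets.
    let utf8: ArrayRef =
        Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")]));

    // Cast to large-utf8: the value bytes are reused and only the offsets are
    // rewritten as i64. The reverse direction returns an error if an offset
    // would overflow i32.
    let large = cast(&utf8, &DataType::LargeUtf8)?;
    let large = large.as_any().downcast_ref::<LargeStringArray>().unwrap();

    assert_eq!(large.value(0), "foo");
    assert!(large.is_null(1));
    assert_eq!(large.value(2), "bar");
    Ok(())
}
```
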
Closes #9526 from ritchie46/cast_utf8_to_largeutf8 Authored-by: Ritchie Vink Signed-off-by: Andrew Lamb --- rust/arrow/src/compute/kernels/cast.rs | 91 ++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index b5fc09f999ced..25592c657ae88 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -82,6 +82,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (_, Boolean) => DataType::is_numeric(from_type), (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8, + (Utf8, LargeUtf8) => true, + (LargeUtf8, Utf8) => true, (Utf8, Date32) => true, (Utf8, Date64) => true, (Utf8, Timestamp(TimeUnit::Nanosecond, None)) => true, @@ -361,6 +363,7 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { ))), }, (Utf8, _) => match to_type { + LargeUtf8 => cast_str_container::(&**array), UInt8 => cast_string_to_numeric::(array), UInt16 => cast_string_to_numeric::(array), UInt32 => cast_string_to_numeric::(array), @@ -428,6 +431,7 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { ))), }, (_, Utf8) => match from_type { + LargeUtf8 => cast_str_container::(&**array), UInt8 => cast_numeric_to_string::(array), UInt16 => cast_numeric_to_string::(array), UInt32 => cast_numeric_to_string::(array), @@ -1297,6 +1301,53 @@ fn cast_list_inner( Ok(Arc::new(list) as ArrayRef) } +/// Helper function to cast from `Utf8` to `LargeUtf8` and vice versa. If the `LargeUtf8` is too large for +/// a `Utf8` array it will return an Error. +fn cast_str_container(array: &dyn Array) -> Result +where + OffsetSizeFrom: StringOffsetSizeTrait + ToPrimitive, + OffsetSizeTo: StringOffsetSizeTrait + NumCast + ArrowNativeType, +{ + let str_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let list_data = array.data(); + let str_values_buf = str_array.value_data(); + + let offsets = unsafe { list_data.buffers()[0].typed_data::() }; + + let mut offset_builder = BufferBuilder::::new(offsets.len()); + offsets.iter().try_for_each::<_, Result<_>>(|offset| { + let offset = OffsetSizeTo::from(*offset).ok_or_else(|| { + ArrowError::ComputeError( + "large-utf8 array too large to cast to utf8-array".into(), + ) + })?; + offset_builder.append(offset); + Ok(()) + })?; + + let offset_buffer = offset_builder.finish(); + + let dtype = if matches!(std::mem::size_of::(), 8) { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + + let mut builder = ArrayData::builder(dtype) + .len(array.len()) + .add_buffer(offset_buffer) + .add_buffer(str_values_buf); + + if let Some(buf) = list_data.null_buffer() { + builder = builder.null_bit_buffer(buf.clone()) + } + let data = builder.build(); + Ok(Arc::new(GenericStringArray::::from(data))) +} + /// Cast the container type of List/Largelist array but not the inner types. /// This function can leave the value data intact and only has to cast the offset dtypes. 
fn cast_list_container( @@ -1778,6 +1829,46 @@ mod tests { assert_eq!(out, vec![Some("1"), Some("2"), Some("3")]); } + #[test] + fn test_str_to_str_casts() { + for data in vec![ + vec![Some("foo"), Some("bar"), Some("ham")], + vec![Some("foo"), None, Some("bar")], + ] { + let a = Arc::new(LargeStringArray::from(data.clone())) as ArrayRef; + let to = cast(&a, &DataType::Utf8).unwrap(); + let expect = a + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + let out = to + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(expect, out); + + let a = Arc::new(StringArray::from(data)) as ArrayRef; + let to = cast(&a, &DataType::LargeUtf8).unwrap(); + let expect = a + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + let out = to + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(expect, out); + } + } + #[test] fn test_cast_from_f64() { let f64_values: Vec = vec![ From a83bc1792f53791ea2000a972a3c868d29b6f875 Mon Sep 17 00:00:00 2001 From: Max Burke Date: Wed, 24 Feb 2021 15:30:07 +0200 Subject: [PATCH 21/54] ARROW-11452: [Rust] Fix issue with Parquet Arrow reader not following type path Not sure where the test data file should go, but I've attached it. [structs.parquet.zip](https://github.com/apache/arrow/files/5906689/structs.parquet.zip) Closes #9390 from maxburke/ARROW-11452 Lead-authored-by: Max Burke Co-authored-by: Neville Dipale Signed-off-by: Neville Dipale --- cpp/submodules/parquet-testing | 2 +- rust/parquet/src/arrow/array_reader.rs | 59 ++++++++++++++++++++++---- rust/parquet/src/arrow/arrow_reader.rs | 19 +++++++++ rust/parquet/src/schema/types.rs | 4 ++ 4 files changed, 75 insertions(+), 9 deletions(-) diff --git a/cpp/submodules/parquet-testing b/cpp/submodules/parquet-testing index e31fe1a02c9e9..8e7badc6a3817 160000 --- a/cpp/submodules/parquet-testing +++ b/cpp/submodules/parquet-testing @@ -1 +1 @@ -Subproject commit e31fe1a02c9e9f271e4bfb8002d403c52f1ef8eb +Subproject commit 8e7badc6a3817a02e06d17b5d8ab6b6dc356e890 diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs index 70187850e497e..dcdfbcbe7b00d 100644 --- a/rust/parquet/src/arrow/array_reader.rs +++ b/rust/parquet/src/arrow/array_reader.rs @@ -1095,6 +1095,7 @@ where for c in column_indices { let column = parquet_schema.column(c).self_type() as *const Type; + leaves.insert(column, c); let root = parquet_schema.get_column_root_ptr(c); @@ -1395,12 +1396,11 @@ impl<'a> ArrayReaderBuilder { self.file_reader.clone(), )?); - let arrow_type = self - .arrow_schema - .field_with_name(cur_type.name()) - .ok() - .map(|f| f.data_type()) - .cloned(); + let arrow_type: Option = match self.get_arrow_field(&cur_type, context) + { + Some(f) => Some(f.data_type().clone()), + _ => None, + }; match cur_type.get_physical_type() { PhysicalType::BOOLEAN => Ok(Box::new(PrimitiveArrayReader::::new( @@ -1631,9 +1631,13 @@ impl<'a> ArrayReaderBuilder { let mut children_reader = Vec::with_capacity(cur_type.get_fields().len()); for child in cur_type.get_fields() { + let mut struct_context = context.clone(); if let Some(child_reader) = self.dispatch(child.clone(), context)? 
{ - let field = match self.arrow_schema.field_with_name(child.name()) { - Ok(f) => f.to_owned(), + // TODO: this results in calling get_arrow_field twice, it could be reused + // from child_reader above, by making child_reader carry its `Field` + struct_context.path.append(vec![child.name().to_string()]); + let field = match self.get_arrow_field(child, &struct_context) { + Some(f) => f.clone(), _ => Field::new( child.name(), child_reader.get_data_type().clone(), @@ -1657,6 +1661,45 @@ impl<'a> ArrayReaderBuilder { Ok(None) } } + + fn get_arrow_field( + &self, + cur_type: &Type, + context: &'a ArrayReaderBuilderContext, + ) -> Option<&Field> { + let parts: Vec<&str> = context + .path + .parts() + .iter() + .map(|x| -> &str { x }) + .collect::>(); + + // If the parts length is one it'll have the top level "schema" type. If + // it's two then it'll be a top-level type that we can get from the arrow + // schema directly. + if parts.len() <= 2 { + self.arrow_schema.field_with_name(cur_type.name()).ok() + } else { + // If it's greater than two then we need to traverse the type path + // until we find the actual field we're looking for. + let mut field: Option<&Field> = None; + + for (i, part) in parts.iter().enumerate().skip(1) { + if i == 1 { + field = self.arrow_schema.field_with_name(part).ok(); + } else if let Some(f) = field { + if let ArrowType::Struct(fields) = f.data_type() { + field = fields.iter().find(|f| f.name() == part) + } else { + field = None + } + } else { + field = None + } + } + field + } + } } #[cfg(test)] diff --git a/rust/parquet/src/arrow/arrow_reader.rs b/rust/parquet/src/arrow/arrow_reader.rs index 288e043b64291..7bbe8de1d6459 100644 --- a/rust/parquet/src/arrow/arrow_reader.rs +++ b/rust/parquet/src/arrow/arrow_reader.rs @@ -649,4 +649,23 @@ mod tests { } } } + + #[test] + fn test_read_structs() { + // This particular test file has columns of struct types where there is + // a column that has the same name as one of the struct fields + // (see: ARROW-11452) + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/nested_structs.rust.parquet", testdata); + let parquet_file_reader = + SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_file_reader)); + let record_batch_reader = arrow_reader + .get_record_reader(60) + .expect("Failed to read into array!"); + + for batch in record_batch_reader { + batch.unwrap(); + } + } } diff --git a/rust/parquet/src/schema/types.rs b/rust/parquet/src/schema/types.rs index 27768fbb63eff..5c35e1cde2c03 100644 --- a/rust/parquet/src/schema/types.rs +++ b/rust/parquet/src/schema/types.rs @@ -561,6 +561,10 @@ impl ColumnPath { pub fn append(&mut self, mut tail: Vec) { self.parts.append(&mut tail); } + + pub fn parts(&self) -> &[String] { + &self.parts + } } impl fmt::Display for ColumnPath { From 4b0375b295e6e1863c77d84d9244b09d93cc7cd5 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Wed, 24 Feb 2021 15:32:43 +0200 Subject: [PATCH 22/54] ARROW-11718: [Rust] Don't write IPC footers on drop As discussed in #9520, these destructors shouldn't be there at all. 
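Because the writers no longer finalize themselves when dropped, callers now have to close them explicitly. A minimal sketch of the resulting usage pattern (file name and data are illustrative):

```rust
use std::fs::File;
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::Result;
use arrow::ipc::writer::FileWriter;
use arrow::record_batch::RecordBatch;

fn write_ipc_file() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    let file = File::create("example.arrow").expect("failed to create file");
    let mut writer = FileWriter::try_new(file, &schema)?;
    writer.write(&batch)?;

    // The footer is no longer written implicitly on drop; forgetting this
    // call now leaves an unreadable file instead of a silently finished one.
    writer.finish()?;
    Ok(())
}
```
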
Closes #9536 from sfackler/no-ipc-drop Authored-by: Steven Fackler Signed-off-by: Neville Dipale --- rust/arrow/src/ipc/writer.rs | 34 ++++++++----------- .../src/bin/arrow-json-integration-test.rs | 2 ++ 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index 7f06fa186f99a..c6b28944d2407 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -434,6 +434,12 @@ impl FileWriter { /// Write footer and closing tag, then mark the writer as done pub fn finish(&mut self) -> Result<()> { + if self.finished { + return Err(ArrowError::IoError( + "Cannot write footer to file writer as it is closed".to_string(), + )); + } + // write EOS write_continuation(&mut self.writer, &self.write_options, 0)?; @@ -463,15 +469,6 @@ impl FileWriter { } } -/// Finish the file if it is not 'finished' when it goes out of scope -impl Drop for FileWriter { - fn drop(&mut self) { - if !self.finished { - let _ = self.finish(); - } - } -} - pub struct StreamWriter { /// The object to write to writer: BufWriter, @@ -537,6 +534,12 @@ impl StreamWriter { /// Write continuation bytes, and mark the stream as done pub fn finish(&mut self) -> Result<()> { + if self.finished { + return Err(ArrowError::IoError( + "Cannot write footer to stream writer as it is closed".to_string(), + )); + } + write_continuation(&mut self.writer, &self.write_options, 0)?; self.finished = true; @@ -545,15 +548,6 @@ impl StreamWriter { } } -/// Finish the stream if it is not 'finished' when it goes out of scope -impl Drop for StreamWriter { - fn drop(&mut self) { - if !self.finished { - let _ = self.finish(); - } - } -} - /// Stores the encoded data, which is an ipc::Message, and optional Arrow data pub struct EncodedData { /// An encoded ipc::Message @@ -776,7 +770,7 @@ mod tests { let mut writer = FileWriter::try_new(file, &schema).unwrap(); writer.write(&batch).unwrap(); - // this is inside a block to test the implicit finishing of the file on `Drop` + writer.finish().unwrap(); } { @@ -826,7 +820,7 @@ mod tests { FileWriter::try_new_with_options(file, &schema, options).unwrap(); writer.write(&batch).unwrap(); - // this is inside a block to test the implicit finishing of the file on `Drop` + writer.finish().unwrap(); } { diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs index 52517bc8dc9a1..257802028b207 100644 --- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs @@ -78,6 +78,8 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> writer.write(&b)?; } + writer.finish()?; + Ok(()) } From 13b8db691a6b57f5e9641d785c2870746b4e2a13 Mon Sep 17 00:00:00 2001 From: Diana Clarke Date: Wed, 24 Feb 2021 09:26:49 -0500 Subject: [PATCH 23/54] ARROW-11746: [Developer][Archery] Fix prefer real time check See: https://issues.apache.org/jira/browse/ARROW-11746 Google Benchmark adds `/real_time` to the end of the benchmark name to indicate if the `real_time` observation should be preferred over the `cpu_time` observation. https://github.com/google/benchmark/blob/af72911f2fe6b8114564614d2db17a449f8c4af0/src/benchmark_register.cc#L222 Example: `AllocateDeallocate/size:1048576/real_time` Archery is looking for `"/realtime"`, not `"/real_time"` though. ``` @property def is_realtime(self): """ Indicate if the preferred value is realtime instead of cputime. 
""" return self.name.find("/realtime") != -1 ``` Closes #9557 from dianaclarke/ARROW-11746 Authored-by: Diana Clarke Signed-off-by: Benjamin Kietzman --- dev/archery/archery/benchmark/google.py | 2 +- dev/archery/tests/test_benchmarks.py | 57 ++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/dev/archery/archery/benchmark/google.py b/dev/archery/archery/benchmark/google.py index 20d0d8fa9a6aa..f5958b17864e2 100644 --- a/dev/archery/archery/benchmark/google.py +++ b/dev/archery/archery/benchmark/google.py @@ -104,7 +104,7 @@ def is_aggregate(self): @property def is_realtime(self): """ Indicate if the preferred value is realtime instead of cputime. """ - return self.name.find("/realtime") != -1 + return self.name.find("/real_time") != -1 @property def name(self): diff --git a/dev/archery/tests/test_benchmarks.py b/dev/archery/tests/test_benchmarks.py index fefdf1eb3da00..0566805842a06 100644 --- a/dev/archery/tests/test_benchmarks.py +++ b/dev/archery/tests/test_benchmarks.py @@ -58,6 +58,61 @@ def test_benchmark_median(): pass +def assert_benchmark(name, google_result, archery_result): + observation = GoogleBenchmarkObservation(**google_result) + benchmark = GoogleBenchmark(name, [observation]) + result = json.dumps(benchmark, cls=JsonEncoder) + assert json.loads(result) == archery_result + + +def test_prefer_real_time(): + name = "AllocateDeallocate/size:1048576/real_time" + google_result = { + "cpu_time": 1778.6004847419827, + "iterations": 352765, + "name": name, + "real_time": 1835.3137357788837, + "repetition_index": 0, + "repetitions": 0, + "run_name": "AllocateDeallocate/size:1048576/real_time", + "run_type": "iteration", + "threads": 1, + "time_unit": "ns", + } + archery_result = { + "name": name, + "unit": "ns", + "less_is_better": True, + "values": [1835.3137357788837], + } + assert name.endswith("/real_time") + assert_benchmark(name, google_result, archery_result) + + +def test_prefer_cpu_time(): + name = "AllocateDeallocate/size:1048576" + google_result = { + "cpu_time": 1778.6004847419827, + "iterations": 352765, + "name": name, + "real_time": 1835.3137357788837, + "repetition_index": 0, + "repetitions": 0, + "run_name": "AllocateDeallocate/size:1048576", + "run_type": "iteration", + "threads": 1, + "time_unit": "ns", + } + archery_result = { + "name": name, + "unit": "ns", + "less_is_better": True, + "values": [1778.6004847419827], + } + assert not name.endswith("/real_time") + assert_benchmark(name, google_result, archery_result) + + def test_omits_aggregates(): name = "AllocateDeallocate/size:1048576/real_time" google_aggregate = { @@ -88,7 +143,7 @@ def test_omits_aggregates(): "name": name, "unit": "ns", "less_is_better": True, - "values": [1778.6004847419827], + "values": [1835.3137357788837], } assert google_aggregate["run_type"] == "aggregate" assert google_result["run_type"] == "iteration" From 732fbf6178c622fe4d90ef34eb28d94f527e5287 Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Wed, 24 Feb 2021 09:54:54 -0500 Subject: [PATCH 24/54] ARROW-11727: [C++][FlightRPC] Estimate latency quantiles with TDigest Current code uses P-Square algorithm from boost accumulator library to estimate latency quantiles. P-Square is very bad at estimating skewed quantiles like 0.99. This patch replaces boost accumulator with our own TDigest utility, which gives much more accurate estimations. Evaluate 0.99 latency quantile accuracy of TDigest and Boost. Exact value is obtained by storing and sorting all data points. 
| Exact | TDigest | Boost-P2 | | ----- | ------- | -------- | | 86 | 93 | 2130 | | 175 | 235 | 1526 | | 151 | 165 | 1926 | | 147 | 153 | 302 | | 251 | 313 | 561 | Closes #9558 from cyb70289/flight-tdigest Authored-by: Yibo Cai Signed-off-by: David Li --- cpp/src/arrow/flight/flight_benchmark.cc | 25 ++++++------------------ cpp/src/arrow/util/tdigest.cc | 13 ++++++++++++ cpp/src/arrow/util/tdigest.h | 4 ++++ 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc index 52e71f6c33d40..06068e608addc 100644 --- a/cpp/src/arrow/flight/flight_benchmark.cc +++ b/cpp/src/arrow/flight/flight_benchmark.cc @@ -21,12 +21,6 @@ #include #include -#include -#include -#include -#include -#include - #include #include "arrow/io/memory.h" @@ -34,6 +28,7 @@ #include "arrow/record_batch.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/stopwatch.h" +#include "arrow/util/tdigest.h" #include "arrow/util/thread_pool.h" #include "arrow/flight/api.h" @@ -59,7 +54,6 @@ DEFINE_int32(records_per_batch, 4096, "Total records per batch within stream"); DEFINE_bool(test_put, false, "Test DoPut instead of DoGet"); namespace perf = arrow::flight::perf; -namespace acc = boost::accumulators; namespace arrow { @@ -75,17 +69,12 @@ struct PerformanceResult { }; struct PerformanceStats { - using accumulator_type = acc::accumulator_set< - double, acc::stats>; - - PerformanceStats() : latencies(acc::extended_p_square_probabilities = quantiles) {} std::mutex mutex; int64_t total_batches = 0; int64_t total_records = 0; int64_t total_bytes = 0; const std::array quantiles = {0.5, 0.95, 0.99}; - accumulator_type latencies; + mutable arrow::internal::TDigest latencies; void Update(int64_t total_batches, int64_t total_records, int64_t total_bytes) { std::lock_guard lock(this->mutex); @@ -99,17 +88,15 @@ struct PerformanceStats { // A better approach may be calculate per-thread quantiles and merge. void AddLatency(uint64_t elapsed_nanos) { std::lock_guard lock(this->mutex); - latencies(elapsed_nanos); + latencies.Add(static_cast(elapsed_nanos)); } // ns -> us - uint64_t max_latency() const { return acc::max(latencies) / 1000; } + uint64_t max_latency() const { return latencies.Max() / 1000; } - uint64_t mean_latency() const { return acc::mean(latencies) / 1000; } + uint64_t mean_latency() const { return latencies.Mean() / 1000; } - uint64_t quantile_latency(double q) const { - return acc::quantile(latencies, acc::quantile_probability = q) / 1000; - } + uint64_t quantile_latency(double q) const { return latencies.Quantile(q) / 1000; } }; Status WaitForReady(FlightClient* client) { diff --git a/cpp/src/arrow/util/tdigest.cc b/cpp/src/arrow/util/tdigest.cc index 5550c98c9729a..b23bca397ec94 100644 --- a/cpp/src/arrow/util/tdigest.cc +++ b/cpp/src/arrow/util/tdigest.cc @@ -332,6 +332,14 @@ class TDigest::TDigestImpl { return Lerp(td[ci_left].mean, td[ci_right].mean, diff); } + double Mean() const { + double sum = 0; + for (const auto& centroid : tdigests_[current_]) { + sum += centroid.mean * centroid.weight; + } + return total_weight_ == 0 ? 
NAN : sum / total_weight_; + } + double total_weight() const { return total_weight_; } private: @@ -389,6 +397,11 @@ double TDigest::Quantile(double q) { return impl_->Quantile(q); } +double TDigest::Mean() { + MergeInput(); + return impl_->Mean(); +} + bool TDigest::is_empty() const { return input_.size() == 0 && impl_->total_weight() == 0; } diff --git a/cpp/src/arrow/util/tdigest.h b/cpp/src/arrow/util/tdigest.h index 88ca6bae2c567..ae42ce48e7d28 100644 --- a/cpp/src/arrow/util/tdigest.h +++ b/cpp/src/arrow/util/tdigest.h @@ -80,6 +80,10 @@ class ARROW_EXPORT TDigest { // calculate quantile double Quantile(double q); + double Min() { return Quantile(0); } + double Max() { return Quantile(1); } + double Mean(); + // check if this tdigest contains no valid data points bool is_empty() const; From 8e8a0009c44b96e1eb34f87e9fd631ac9309479d Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 24 Feb 2021 10:34:31 -0500 Subject: [PATCH 25/54] ARROW-10438: [C++][Dataset] Partitioning::Format on nulls Tested and added support for partitioning with nulls. I had to make some changes to the hash kernels. You can now specify how you want DictionaryEncode to treat nulls. The MASK option will continue the current behavior (null not in dictionary, null value in indices) and the ENCODE option will put `null` in the dictionary and there will be no null values in the indices array. Partitioning on nulls will depend on the partitioning scheme. For directory partitioning null is allowed on inner fields but it is not allowed on an outer field if an inner field is defined. In other words, if the schema is a(int32), b(int32), c(int32) then the following are allowed ``` / (a=null, b=null, c=null) /32 (a=32, b=null, c=null) /32/57 (a=32, b=57, c=null) ``` There is no way to specify `a=null, b=57, c=null`. This does mean that partition directories can contain a mix of files and nested partition directories (e.g. /32 might contain file.parquet and the directory /57). Alternatively we could just forbid nulls in the directory partitioning scheme. For the hive scheme we need to be compatible with other tools that read/write hive. Those tools use a fallback value which defaults to `__HIVE_DEFAULT_PARTITION__`. So by default you would have directories that look like... ``` /a=__HIVE_DEFAULT_PARTITION__/b=__HIVE_DEFAULT_PARTITION__/c=__HIVE_DEFAULT_PARTITION__ ``` The null fallback value is configurable as a string passed to HivePartitioning::HivePartitioning or HivePartitioning::MakeFactory. ARROW-11649 has been created for extending this null fallback configuration to R. 
Closes #9323 from westonpace/feature/arrow-10438 Lead-authored-by: Weston Pace Co-authored-by: Benjamin Kietzman Signed-off-by: Benjamin Kietzman --- cpp/src/arrow/compute/api_vector.cc | 5 +- cpp/src/arrow/compute/api_vector.h | 35 ++- cpp/src/arrow/compute/kernels/vector_hash.cc | 116 +++++++--- .../arrow/compute/kernels/vector_hash_test.cc | 65 ++++++ cpp/src/arrow/dataset/expression.cc | 26 ++- cpp/src/arrow/dataset/expression.h | 4 + cpp/src/arrow/dataset/expression_test.cc | 42 +++- cpp/src/arrow/dataset/partition.cc | 98 +++++--- cpp/src/arrow/dataset/partition.h | 23 +- cpp/src/arrow/dataset/partition_test.cc | 211 ++++++++++++++++-- cpp/src/arrow/dataset/projector.cc | 16 +- cpp/src/arrow/python/arrow_to_pandas.cc | 8 +- python/pyarrow/_compute.pyx | 26 +++ python/pyarrow/_dataset.pyx | 30 ++- python/pyarrow/array.pxi | 5 +- python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 14 ++ python/pyarrow/includes/libarrow_dataset.pxd | 7 +- python/pyarrow/public-api.pxi | 3 + python/pyarrow/table.pxi | 5 +- python/pyarrow/tests/test_dataset.py | 166 ++++++++++++-- 21 files changed, 783 insertions(+), 123 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index f5ab46ac603c3..0082d48112dc1 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -74,8 +74,9 @@ Result> Unique(const Datum& value, ExecContext* ctx) { return result.make_array(); } -Result DictionaryEncode(const Datum& value, ExecContext* ctx) { - return CallFunction("dictionary_encode", {value}, ctx); +Result DictionaryEncode(const Datum& value, const DictionaryEncodeOptions& options, + ExecContext* ctx) { + return CallFunction("dictionary_encode", {value}, &options, ctx); } const char kValuesFieldName[] = "values"; diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 9e9cad9e5d9bf..d67568e15671a 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -63,6 +63,24 @@ enum class SortOrder { Descending, }; +/// \brief Options for the dictionary encode function +struct DictionaryEncodeOptions : public FunctionOptions { + /// Configure how null values will be encoded + enum NullEncodingBehavior { + /// the null value will be added to the dictionary with a proper index + ENCODE, + /// the null value will be masked in the indices array + MASK + }; + + explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK) + : null_encoding_behavior(null_encoding) {} + + static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); } + + NullEncodingBehavior null_encoding_behavior = MASK; +}; + /// \brief One sort key for PartitionNthIndices (TODO) and SortIndices struct ARROW_EXPORT SortKey { explicit SortKey(std::string name, SortOrder order = SortOrder::Ascending) @@ -289,14 +307,29 @@ Result> ValueCounts(const Datum& value, ExecContext* ctx = NULLPTR); /// \brief Dictionary-encode values in an array-like object +/// +/// Any nulls encountered in the dictionary will be handled according to the +/// specified null encoding behavior. 
+/// +/// For example, given values ["a", "b", null, "a", null] the output will be +/// (null_encoding == ENCODE) Indices: [0, 1, 2, 0, 2] / Dict: ["a", "b", null] +/// (null_encoding == MASK) Indices: [0, 1, null, 0, null] / Dict: ["a", "b"] +/// +/// If the input is already dictionary encoded this function is a no-op unless +/// it needs to modify the null_encoding (TODO) +/// /// \param[in] data array-like input /// \param[in] ctx the function execution context, optional +/// \param[in] options configures null encoding behavior /// \return result with same shape and type as input /// /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result DictionaryEncode(const Datum& data, ExecContext* ctx = NULLPTR); +Result DictionaryEncode( + const Datum& data, + const DictionaryEncodeOptions& options = DictionaryEncodeOptions::Defaults(), + ExecContext* ctx = NULLPTR); // ---------------------------------------------------------------------- // Deprecated functions diff --git a/cpp/src/arrow/compute/kernels/vector_hash.cc b/cpp/src/arrow/compute/kernels/vector_hash.cc index 34d18c24a0c43..de4d3ee302280 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash.cc @@ -58,7 +58,10 @@ class UniqueAction final : public ActionBase { using ActionBase::ActionBase; static constexpr bool with_error_status = false; - static constexpr bool with_memo_visit_null = true; + + UniqueAction(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : ActionBase(type, pool) {} Status Reset() { return Status::OK(); } @@ -76,6 +79,8 @@ class UniqueAction final : public ActionBase { template void ObserveNotFound(Index index) {} + bool ShouldEncodeNulls() { return true; } + Status Flush(Datum* out) { return Status::OK(); } Status FlushFinal(Datum* out) { return Status::OK(); } @@ -89,9 +94,9 @@ class ValueCountsAction final : ActionBase { using ActionBase::ActionBase; static constexpr bool with_error_status = true; - static constexpr bool with_memo_visit_null = true; - ValueCountsAction(const std::shared_ptr& type, MemoryPool* pool) + ValueCountsAction(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) : ActionBase(type, pool), count_builder_(pool) {} Status Reserve(const int64_t length) { @@ -147,6 +152,8 @@ class ValueCountsAction final : ActionBase { } } + bool ShouldEncodeNulls() const { return true; } + private: Int64Builder count_builder_; }; @@ -159,10 +166,14 @@ class DictEncodeAction final : public ActionBase { using ActionBase::ActionBase; static constexpr bool with_error_status = false; - static constexpr bool with_memo_visit_null = false; - DictEncodeAction(const std::shared_ptr& type, MemoryPool* pool) - : ActionBase(type, pool), indices_builder_(pool) {} + DictEncodeAction(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : ActionBase(type, pool), indices_builder_(pool) { + if (auto options_ptr = static_cast(options)) { + encode_options_ = *options_ptr; + } + } Status Reset() { indices_builder_.Reset(); @@ -173,12 +184,16 @@ class DictEncodeAction final : public ActionBase { template void ObserveNullFound(Index index) { - indices_builder_.UnsafeAppendNull(); + if (encode_options_.null_encoding_behavior == DictionaryEncodeOptions::MASK) { + indices_builder_.UnsafeAppendNull(); + } else { + indices_builder_.UnsafeAppend(index); + } } template void ObserveNullNotFound(Index index) { - indices_builder_.UnsafeAppendNull(); + ObserveNullFound(index); } template @@ 
-191,6 +206,10 @@ class DictEncodeAction final : public ActionBase { ObserveFound(index); } + bool ShouldEncodeNulls() { + return encode_options_.null_encoding_behavior == DictionaryEncodeOptions::ENCODE; + } + Status Flush(Datum* out) { std::shared_ptr result; RETURN_NOT_OK(indices_builder_.FinishInternal(&result)); @@ -202,10 +221,14 @@ class DictEncodeAction final : public ActionBase { private: Int32Builder indices_builder_; + DictionaryEncodeOptions encode_options_; }; class HashKernel : public KernelState { public: + HashKernel() : options_(nullptr) {} + explicit HashKernel(const FunctionOptions* options) : options_(options) {} + // Reset for another run. virtual Status Reset() = 0; @@ -229,6 +252,7 @@ class HashKernel : public KernelState { virtual Status Append(const ArrayData& arr) = 0; protected: + const FunctionOptions* options_; std::mutex lock_; }; @@ -237,12 +261,12 @@ class HashKernel : public KernelState { // (NullType has a separate implementation) template + bool with_error_status = Action::with_error_status> class RegularHashKernel : public HashKernel { public: - RegularHashKernel(const std::shared_ptr& type, MemoryPool* pool) - : pool_(pool), type_(type), action_(type, pool) {} + RegularHashKernel(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : HashKernel(options), pool_(pool), type_(type), action_(type, options, pool) {} Status Reset() override { memo_table_.reset(new MemoTable(pool_, 0)); @@ -282,7 +306,7 @@ class RegularHashKernel : public HashKernel { &unused_memo_index); }, [this]() { - if (with_memo_visit_null) { + if (action_.ShouldEncodeNulls()) { auto on_found = [this](int32_t memo_index) { action_.ObserveNullFound(memo_index); }; @@ -318,16 +342,14 @@ class RegularHashKernel : public HashKernel { [this]() { // Null Status s = Status::OK(); - if (with_memo_visit_null) { - auto on_found = [this](int32_t memo_index) { - action_.ObserveNullFound(memo_index); - }; - auto on_not_found = [this, &s](int32_t memo_index) { - action_.ObserveNullNotFound(memo_index, &s); - }; + auto on_found = [this](int32_t memo_index) { + action_.ObserveNullFound(memo_index); + }; + auto on_not_found = [this, &s](int32_t memo_index) { + action_.ObserveNullNotFound(memo_index, &s); + }; + if (action_.ShouldEncodeNulls()) { memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found)); - } else { - action_.ObserveNullNotFound(-1); } return s; }); @@ -345,18 +367,23 @@ class RegularHashKernel : public HashKernel { // ---------------------------------------------------------------------- // Hash kernel implementation for nulls -template +template class NullHashKernel : public HashKernel { public: - NullHashKernel(const std::shared_ptr& type, MemoryPool* pool) - : pool_(pool), type_(type), action_(type, pool) {} + NullHashKernel(const std::shared_ptr& type, const FunctionOptions* options, + MemoryPool* pool) + : pool_(pool), type_(type), action_(type, options, pool) {} Status Reset() override { return action_.Reset(); } - Status Append(const ArrayData& arr) override { + Status Append(const ArrayData& arr) override { return DoAppend(arr); } + + template + enable_if_t DoAppend(const ArrayData& arr) { RETURN_NOT_OK(action_.Reserve(arr.length)); for (int64_t i = 0; i < arr.length; ++i) { if (i == 0) { + seen_null_ = true; action_.ObserveNullNotFound(0); } else { action_.ObserveNullFound(0); @@ -365,12 +392,31 @@ class NullHashKernel : public HashKernel { return Status::OK(); } + template + enable_if_t DoAppend(const ArrayData& arr) { + Status s = 
Status::OK(); + RETURN_NOT_OK(action_.Reserve(arr.length)); + for (int64_t i = 0; i < arr.length; ++i) { + if (seen_null_ == false && i == 0) { + seen_null_ = true; + action_.ObserveNullNotFound(0, &s); + } else { + action_.ObserveNullFound(0); + } + } + return s; + } + Status Flush(Datum* out) override { return action_.Flush(out); } Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } Status GetDictionary(std::shared_ptr* out) override { - // TODO(wesm): handle null being a valid dictionary value - auto null_array = std::make_shared(0); + std::shared_ptr null_array; + if (seen_null_) { + null_array = std::make_shared(1); + } else { + null_array = std::make_shared(0); + } *out = null_array->data(); return Status::OK(); } @@ -380,6 +426,7 @@ class NullHashKernel : public HashKernel { protected: MemoryPool* pool_; std::shared_ptr type_; + bool seen_null_ = false; Action action_; }; @@ -451,8 +498,8 @@ struct HashKernelTraits> { template std::unique_ptr HashInitImpl(KernelContext* ctx, const KernelInitArgs& args) { using HashKernelType = typename HashKernelTraits::HashKernel; - auto result = ::arrow::internal::make_unique(args.inputs[0].type, - ctx->memory_pool()); + auto result = ::arrow::internal::make_unique( + args.inputs[0].type, args.options, ctx->memory_pool()); ctx->SetStatus(result->Reset()); return std::move(result); } @@ -507,6 +554,8 @@ KernelInit GetHashInit(Type::type type_id) { } } +using DictionaryEncodeState = OptionsWrapper; + template std::unique_ptr DictionaryHashInit(KernelContext* ctx, const KernelInitArgs& args) { @@ -639,9 +688,11 @@ const FunctionDoc value_counts_doc( "Nulls in the input are ignored."), {"array"}); +const auto kDefaultDictionaryEncodeOptions = DictionaryEncodeOptions::Defaults(); const FunctionDoc dictionary_encode_doc( "Dictionary-encode array", - ("Return a dictionary-encoded version of the input array."), {"array"}); + ("Return a dictionary-encoded version of the input array."), {"array"}, + "DictionaryEncodeOptions"); } // namespace @@ -691,7 +742,8 @@ void RegisterVectorHash(FunctionRegistry* registry) { // Unique and ValueCounts output unchunked arrays base.output_chunked = true; auto dict_encode = std::make_shared("dictionary_encode", Arity::Unary(), - &dictionary_encode_doc); + &dictionary_encode_doc, + &kDefaultDictionaryEncodeOptions); AddHashKernels(dict_encode.get(), base, OutputType(DictEncodeOutput)); // Calling dictionary_encode on dictionary input not supported, but if it diff --git a/cpp/src/arrow/compute/kernels/vector_hash_test.cc b/cpp/src/arrow/compute/kernels/vector_hash_test.cc index e9ae4a64d979e..179792e2141c4 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash_test.cc @@ -305,6 +305,11 @@ TEST_F(TestHashKernel, ValueCountsBoolean) { ArrayFromJSON(boolean(), "[false]"), ArrayFromJSON(int64(), "[2]")); } +TEST_F(TestHashKernel, ValueCountsNull) { + CheckValueCounts(ArrayFromJSON(null(), "[null, null, null]"), + ArrayFromJSON(null(), "[null]"), ArrayFromJSON(int64(), "[3]")); +} + TEST_F(TestHashKernel, DictEncodeBoolean) { CheckDictEncode(boolean(), {true, true, false, true, false}, {true, false, true, true, true}, {true, false}, {}, @@ -542,6 +547,12 @@ TEST_F(TestHashKernel, UniqueDecimal) { {true, false, true, true}, expected, {1, 0, 1}); } +TEST_F(TestHashKernel, UniqueNull) { + CheckUnique(null(), {nullptr, nullptr}, {false, true}, + {nullptr}, {false}); + CheckUnique(null(), {}, {}, {}, {}); +} + TEST_F(TestHashKernel, ValueCountsDecimal) { 
std::vector values{12, 12, 11, 12}; std::vector expected{12, 0, 11}; @@ -586,6 +597,33 @@ TEST_F(TestHashKernel, DictionaryUniqueAndValueCounts) { auto different_dictionaries = *ChunkedArray::Make({input, input2}); ASSERT_RAISES(Invalid, Unique(different_dictionaries)); ASSERT_RAISES(Invalid, ValueCounts(different_dictionaries)); + + // Dictionary with encoded nulls + auto dict_with_null = ArrayFromJSON(int64(), "[10, null, 30, 40]"); + input = std::make_shared(dict_ty, indices, dict_with_null); + ex_uniques = std::make_shared(dict_ty, ex_indices, dict_with_null); + CheckUnique(input, ex_uniques); + + CheckValueCounts(input, ex_uniques, ex_counts); + + // Dictionary with masked nulls + auto indices_with_null = + ArrayFromJSON(index_ty, "[3, 0, 0, 0, null, null, 3, 0, null, 3, 0, null]"); + auto ex_indices_with_null = ArrayFromJSON(index_ty, "[3, 0, null]"); + ex_uniques = std::make_shared(dict_ty, ex_indices_with_null, dict); + input = std::make_shared(dict_ty, indices_with_null, dict); + CheckUnique(input, ex_uniques); + + CheckValueCounts(input, ex_uniques, ex_counts); + + // Dictionary with encoded AND masked nulls + auto some_indices_with_null = + ArrayFromJSON(index_ty, "[3, 0, 0, 0, 1, 1, 3, 0, null, 3, 0, null]"); + ex_uniques = + std::make_shared(dict_ty, ex_indices_with_null, dict_with_null); + input = std::make_shared(dict_ty, indices_with_null, dict_with_null); + CheckUnique(input, ex_uniques); + CheckValueCounts(input, ex_uniques, ex_counts); } } @@ -656,6 +694,33 @@ TEST_F(TestHashKernel, ZeroLengthDictionaryEncode) { ASSERT_OK(dict_result.ValidateFull()); } +TEST_F(TestHashKernel, NullEncodingSchemes) { + auto values = ArrayFromJSON(uint8(), "[1, 1, null, 2, null]"); + + // Masking should put null in the indices array + auto expected_mask_indices = ArrayFromJSON(int32(), "[0, 0, null, 1, null]"); + auto expected_mask_dictionary = ArrayFromJSON(uint8(), "[1, 2]"); + auto dictionary_type = dictionary(int32(), uint8()); + std::shared_ptr expected = std::make_shared( + dictionary_type, expected_mask_indices, expected_mask_dictionary); + + ASSERT_OK_AND_ASSIGN(Datum datum_result, DictionaryEncode(values)); + std::shared_ptr result = datum_result.make_array(); + AssertArraysEqual(*expected, *result); + + // Encoding should put null in the dictionary + auto expected_encoded_indices = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]"); + auto expected_encoded_dict = ArrayFromJSON(uint8(), "[1, null, 2]"); + expected = std::make_shared(dictionary_type, expected_encoded_indices, + expected_encoded_dict); + + auto options = DictionaryEncodeOptions::Defaults(); + options.null_encoding_behavior = DictionaryEncodeOptions::ENCODE; + ASSERT_OK_AND_ASSIGN(datum_result, DictionaryEncode(values, options)); + result = datum_result.make_array(); + AssertArraysEqual(*expected, *result); +} + TEST_F(TestHashKernel, ChunkedArrayZeroChunk) { // ARROW-6857 auto chunked_array = std::make_shared(ArrayVector{}, utf8()); diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 56339430ee921..5ddb270451aa6 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -95,6 +95,8 @@ namespace { std::string PrintDatum(const Datum& datum) { if (datum.is_scalar()) { + if (!datum.scalar()->is_valid) return "null"; + switch (datum.type()->id()) { case Type::STRING: case Type::LARGE_STRING: @@ -110,6 +112,7 @@ std::string PrintDatum(const Datum& datum) { default: break; } + return datum.scalar()->ToString(); } return datum.ToString(); @@ -698,16 +701,25 
@@ Status ExtractKnownFieldValuesImpl( return !(ref && lit); } + if (call->function_name == "is_null") { + auto ref = call->arguments[0].field_ref(); + return !ref; + } + return true; }); for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) { auto call = CallNotNull(*it); - auto ref = call->arguments[0].field_ref(); - auto lit = call->arguments[1].literal(); - - known_values->emplace(*ref, *lit); + if (call->function_name == "equal") { + auto ref = call->arguments[0].field_ref(); + auto lit = call->arguments[1].literal(); + known_values->emplace(*ref, *lit); + } else if (call->function_name == "is_null") { + auto ref = call->arguments[0].field_ref(); + known_values->emplace(*ref, Datum(std::make_shared())); + } } conjunction_members->erase(unconsumed_end, conjunction_members->end()); @@ -756,7 +768,7 @@ Result ReplaceFieldsWithKnownValues( DictionaryScalar::Make(std::move(index), std::move(dictionary))); } } - ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(it->second, expr.type())); + ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type())); return literal(std::move(lit)); } } @@ -1222,6 +1234,10 @@ Expression greater_equal(Expression lhs, Expression rhs) { return call("greater_equal", {std::move(lhs), std::move(rhs)}); } +Expression is_null(Expression lhs) { return call("is_null", {std::move(lhs)}); } + +Expression is_valid(Expression lhs) { return call("is_valid", {std::move(lhs)}); } + Expression and_(Expression lhs, Expression rhs) { return call("and_kleene", {std::move(lhs), std::move(rhs)}); } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 13c714b2d72cd..8bdcb4a0ffa6f 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -236,6 +236,10 @@ ARROW_DS_EXPORT Expression greater(Expression lhs, Expression rhs); ARROW_DS_EXPORT Expression greater_equal(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression is_null(Expression lhs); + +ARROW_DS_EXPORT Expression is_valid(Expression lhs); + ARROW_DS_EXPORT Expression and_(Expression lhs, Expression rhs); ARROW_DS_EXPORT Expression and_(const std::vector&); ARROW_DS_EXPORT Expression or_(Expression lhs, Expression rhs); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 2f0110255ec42..c837c5be8930d 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -240,6 +240,10 @@ TEST(Expression, Equality) { call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32()))); } +Expression null_literal(const std::shared_ptr& type) { + return Expression(MakeNullScalar(type)); +} + TEST(Expression, Hash) { std::unordered_set set; @@ -250,6 +254,9 @@ TEST(Expression, Hash) { EXPECT_FALSE(set.emplace(literal(1)).second) << "already inserted"; EXPECT_TRUE(set.emplace(literal(3)).second); + EXPECT_TRUE(set.emplace(null_literal(int32())).second); + EXPECT_FALSE(set.emplace(null_literal(int32())).second) << "already inserted"; + EXPECT_TRUE(set.emplace(null_literal(float32())).second); // NB: no validation on construction; we couldn't execute // add with zero arguments EXPECT_TRUE(set.emplace(call("add", {})).second); @@ -258,7 +265,7 @@ TEST(Expression, Hash) { // NB: unbound expressions don't check for availability in any registry EXPECT_TRUE(set.emplace(call("widgetify", {})).second); - EXPECT_EQ(set.size(), 6); + EXPECT_EQ(set.size(), 8); } TEST(Expression, IsScalarExpression) { @@ -603,6 +610,8 @@ TEST(Expression, FoldConstants) { // call against 
literals (3 + 2 == 5) ExpectFoldsTo(call("add", {literal(3), literal(2)}), literal(5)); + ExpectFoldsTo(call("equal", {literal(3), literal(3)}), literal(true)); + // call against literal and field_ref ExpectFoldsTo(call("add", {literal(3), field_ref("i32")}), call("add", {literal(3), field_ref("i32")})); @@ -722,7 +731,7 @@ TEST(Expression, ExtractKnownFieldValues) { TEST(Expression, ReplaceFieldsWithKnownValues) { auto ExpectReplacesTo = [](Expression expr, - std::unordered_map known_values, + const std::unordered_map& known_values, Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); @@ -765,6 +774,19 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { }), literal(2), })); + + std::unordered_map i32_valid_str_null{ + {"i32", Datum(3)}, {"str", MakeNullScalar(utf8())}}; + + ExpectReplacesTo(is_null(field_ref("i32")), i32_valid_str_null, is_null(literal(3))); + + ExpectReplacesTo(is_valid(field_ref("i32")), i32_valid_str_null, is_valid(literal(3))); + + ExpectReplacesTo(is_null(field_ref("str")), i32_valid_str_null, + is_null(null_literal(utf8()))); + + ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null, + is_valid(null_literal(utf8()))); } struct { @@ -1013,6 +1035,22 @@ TEST(Expression, SimplifyWithGuarantee) { Simplify{greater(field_ref("dict_i32"), literal(int64_t(1)))} .WithGuarantee(equal(field_ref("dict_i32"), literal(0))) .Expect(false); + + Simplify{equal(field_ref("i32"), literal(7))} + .WithGuarantee(equal(field_ref("i32"), literal(7))) + .Expect(literal(true)); + + Simplify{equal(field_ref("i32"), literal(7))} + .WithGuarantee(not_(equal(field_ref("i32"), literal(7)))) + .Expect(equal(field_ref("i32"), literal(7))); + + Simplify{is_null(field_ref("i32"))} + .WithGuarantee(is_null(field_ref("i32"))) + .Expect(literal(true)); + + Simplify{is_valid(field_ref("i32"))} + .WithGuarantee(is_valid(field_ref("i32"))) + .Expect(is_valid(field_ref("i32"))); } TEST(Expression, SimplifyThenExecute) { diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index d6a3723d055fd..522dbbeb5d22b 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -92,7 +92,11 @@ inline Expression ConjunctionFromGroupingRow(Scalar* row) { std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { const std::string& name = row->type->field(static_cast(i))->name(); - equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); + if (values->at(i)->is_valid) { + equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); + } else { + equality_expressions[i] = is_null(field_ref(name)); + } } return and_(std::move(equality_expressions)); } @@ -147,7 +151,9 @@ Result KeyValuePartitioning::ConvertKey(const Key& key) const { std::shared_ptr converted; - if (field->type()->id() == Type::DICTIONARY) { + if (!key.value.has_value()) { + return is_null(field_ref(field->name())); + } else if (field->type()->id() == Type::DICTIONARY) { if (dictionaries_.empty() || dictionaries_[field_index] == nullptr) { return Status::Invalid("No dictionary provided for dictionary field ", field->ToString()); @@ -164,16 +170,16 @@ Result KeyValuePartitioning::ConvertKey(const Key& key) const { } // look up the partition value in the dictionary - ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(value.dictionary->type(), key.value)); + ARROW_ASSIGN_OR_RAISE(converted, 
Scalar::Parse(value.dictionary->type(), *key.value)); ARROW_ASSIGN_OR_RAISE(auto index, compute::IndexIn(converted, value.dictionary)); value.index = index.scalar(); if (!value.index->is_valid) { return Status::Invalid("Dictionary supplied for field ", field->ToString(), - " does not contain '", key.value, "'"); + " does not contain '", *key.value, "'"); } converted = std::make_shared(std::move(value), field->type()); } else { - ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), key.value)); + ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), *key.value)); } return equal(field_ref(field->name()), literal(std::move(converted))); @@ -207,8 +213,18 @@ Result KeyValuePartitioning::Format(const Expression& expr) const { const auto& field = schema_->field(match[0]); if (!value->type->Equals(field->type())) { - return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, - ") is invalid for ", field->ToString()); + if (value->is_valid) { + auto maybe_converted = compute::Cast(value, field->type()); + if (!maybe_converted.ok()) { + return Status::TypeError("Error converting scalar ", value->ToString(), + " (of type ", *value->type, + ") to a partition key for ", field->ToString(), ": ", + maybe_converted.status().message()); + } + value = maybe_converted->scalar(); + } else { + value = MakeNullScalar(field->type()); + } } if (value->type->id() == Type::DICTIONARY) { @@ -252,7 +268,7 @@ Result DirectoryPartitioning::FormatValues( std::vector segments(static_cast(schema_->num_fields())); for (int i = 0; i < schema_->num_fields(); ++i) { - if (values[i] != nullptr) { + if (values[i] != nullptr && values[i]->is_valid) { segments[i] = values[i]->ToString(); continue; } @@ -287,8 +303,13 @@ class KeyValuePartitioningFactory : public PartitioningFactory { return it_inserted.first->second; } - Status InsertRepr(const std::string& name, util::string_view repr) { - return InsertRepr(GetOrInsertField(name), repr); + Status InsertRepr(const std::string& name, util::optional repr) { + auto field_index = GetOrInsertField(name); + if (repr.has_value()) { + return InsertRepr(field_index, *repr); + } else { + return Status::OK(); + } } Status InsertRepr(int index, util::string_view repr) { @@ -309,7 +330,7 @@ class KeyValuePartitioningFactory : public PartitioningFactory { RETURN_NOT_OK(repr_memos_[index]->GetArrayData(0, &reprs)); if (reprs->length == 0) { - return Status::Invalid("No segments were available for field '", name, + return Status::Invalid("No non-null segments were available for field '", name, "'; couldn't infer type"); } @@ -410,13 +431,19 @@ std::shared_ptr DirectoryPartitioning::MakeFactory( } util::optional HivePartitioning::ParseKey( - const std::string& segment) { + const std::string& segment, const std::string& null_fallback) { auto name_end = string_view(segment).find_first_of('='); + // Not round-trippable if (name_end == string_view::npos) { return util::nullopt; } - return Key{segment.substr(0, name_end), segment.substr(name_end + 1)}; + auto name = segment.substr(0, name_end); + auto value = segment.substr(name_end + 1); + if (value == null_fallback) { + return Key{name, util::nullopt}; + } + return Key{name, value}; } std::vector HivePartitioning::ParseKeys( @@ -424,7 +451,7 @@ std::vector HivePartitioning::ParseKeys( std::vector keys; for (const auto& segment : fs::internal::SplitAbstractPath(path)) { - if (auto key = ParseKey(segment)) { + if (auto key = ParseKey(segment, null_fallback_)) { keys.push_back(std::move(*key)); } } @@ -439,11 
+466,11 @@ Result HivePartitioning::FormatValues(const ScalarVector& values) c const std::string& name = schema_->field(i)->name(); if (values[i] == nullptr) { - if (!NextValid(values, i)) break; - + segments[i] = ""; + } else if (!values[i]->is_valid) { // If no key is available just provide a placeholder segment to maintain the // field_index <-> path nesting relation - segments[i] = name; + segments[i] = name + "=" + null_fallback_; } else { segments[i] = name + "=" + values[i]->ToString(); } @@ -454,8 +481,8 @@ Result HivePartitioning::FormatValues(const ScalarVector& values) c class HivePartitioningFactory : public KeyValuePartitioningFactory { public: - explicit HivePartitioningFactory(PartitioningFactoryOptions options) - : KeyValuePartitioningFactory(options) {} + explicit HivePartitioningFactory(HivePartitioningFactoryOptions options) + : KeyValuePartitioningFactory(options), null_fallback_(options.null_fallback) {} std::string type_name() const override { return "hive"; } @@ -463,7 +490,7 @@ class HivePartitioningFactory : public KeyValuePartitioningFactory { const std::vector& paths) override { for (auto path : paths) { for (auto&& segment : fs::internal::SplitAbstractPath(path)) { - if (auto key = HivePartitioning::ParseKey(segment)) { + if (auto key = HivePartitioning::ParseKey(segment, null_fallback_)) { RETURN_NOT_OK(InsertRepr(key->name, key->value)); } } @@ -486,16 +513,18 @@ class HivePartitioningFactory : public KeyValuePartitioningFactory { // drop fields which aren't in field_names_ auto out_schema = SchemaFromColumnNames(schema, field_names_); - return std::make_shared(std::move(out_schema), dictionaries_); + return std::make_shared(std::move(out_schema), dictionaries_, + null_fallback_); } } private: + const std::string null_fallback_; std::vector field_names_; }; std::shared_ptr HivePartitioning::MakeFactory( - PartitioningFactoryOptions options) { + HivePartitioningFactoryOptions options) { return std::shared_ptr(new HivePartitioningFactory(options)); } @@ -578,10 +607,6 @@ class StructDictionary { Encoded out{nullptr, std::make_shared()}; for (const auto& column : columns) { - if (column->null_count() != 0) { - return Status::NotImplemented("Grouping on a field with nulls"); - } - RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); } @@ -625,8 +650,27 @@ class StructDictionary { private: Status AddOne(Datum column, std::shared_ptr* fused_indices) { + if (column.type()->id() == Type::DICTIONARY) { + if (column.null_count() != 0) { + // TODO(ARROW-11732) Optimize this by allowign DictionaryEncode to transfer a + // null-masked dictionary to a null-encoded dictionary. At the moment we decode + // and then encode causing one extra copy, and a potentially expansive decoding + // copy at that. 
+ ARROW_ASSIGN_OR_RAISE( + auto decoded_dictionary, + compute::Cast( + column, + std::static_pointer_cast(column.type())->value_type(), + compute::CastOptions())); + column = decoded_dictionary; + } + } if (column.type()->id() != Type::DICTIONARY) { - ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(std::move(column))); + compute::DictionaryEncodeOptions options; + options.null_encoding_behavior = + compute::DictionaryEncodeOptions::NullEncodingBehavior::ENCODE; + ARROW_ASSIGN_OR_RAISE(column, + compute::DictionaryEncode(std::move(column), options)); } auto dict_column = column.array_as(); diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index 944434e64f772..42e1b4c409764 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -92,6 +92,11 @@ struct PartitioningFactoryOptions { bool infer_dictionary = false; }; +struct HivePartitioningFactoryOptions : PartitioningFactoryOptions { + /// The hive partitioning scheme maps null to a hard coded fallback string. + std::string null_fallback; +}; + /// \brief PartitioningFactory provides creation of a partitioning when the /// specific schema must be inferred from available paths (no explicit schema is known). class ARROW_DS_EXPORT PartitioningFactory { @@ -119,7 +124,8 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { /// An unconverted equality expression consisting of a field name and the representation /// of a scalar value struct Key { - std::string name, value; + std::string name; + util::optional value; }; static Status SetDefaultValuesFromKeys(const Expression& expr, @@ -175,6 +181,8 @@ class ARROW_DS_EXPORT DirectoryPartitioning : public KeyValuePartitioning { Result FormatValues(const ScalarVector& values) const override; }; +static constexpr char kDefaultHiveNullFallback[] = "__HIVE_DEFAULT_PARTITION__"; + /// \brief Multi-level, directory based partitioning /// originating from Apache Hive with all data files stored in the /// leaf directories. Data is partitioned by static values of a @@ -188,17 +196,22 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { public: // If a field in schema is of dictionary type, the corresponding element of dictionaries // must be contain the dictionary of values for that field. 
- explicit HivePartitioning(std::shared_ptr schema, ArrayVector dictionaries = {}) - : KeyValuePartitioning(std::move(schema), std::move(dictionaries)) {} + explicit HivePartitioning(std::shared_ptr schema, ArrayVector dictionaries = {}, + std::string null_fallback = kDefaultHiveNullFallback) + : KeyValuePartitioning(std::move(schema), std::move(dictionaries)), + null_fallback_(null_fallback) {} std::string type_name() const override { return "hive"; } + std::string null_fallback() const { return null_fallback_; } - static util::optional ParseKey(const std::string& segment); + static util::optional ParseKey(const std::string& segment, + const std::string& null_fallback); static std::shared_ptr MakeFactory( - PartitioningFactoryOptions = {}); + HivePartitioningFactoryOptions = {}); private: + const std::string null_fallback_; std::vector ParseKeys(const std::string& path) const override; Result FormatValues(const ScalarVector& values) const override; diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 286848d9ae920..75e60f994f047 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -27,6 +27,7 @@ #include #include "arrow/compute/api_scalar.h" +#include "arrow/compute/api_vector.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/filesystem/path_util.h" @@ -77,6 +78,39 @@ class TestPartitioning : public ::testing::Test { ASSERT_OK_AND_ASSIGN(partitioning_, factory_->Finish(actual)); } + void AssertPartition(const std::shared_ptr partitioning, + const std::shared_ptr full_batch, + const RecordBatchVector& expected_batches, + const std::vector& expected_expressions) { + ASSERT_OK_AND_ASSIGN(auto partition_results, partitioning->Partition(full_batch)); + std::shared_ptr rest = full_batch; + ASSERT_EQ(partition_results.batches.size(), expected_batches.size()); + auto max_index = std::min(partition_results.batches.size(), expected_batches.size()); + for (std::size_t partition_index = 0; partition_index < max_index; + partition_index++) { + std::shared_ptr actual_batch = + partition_results.batches[partition_index]; + AssertBatchesEqual(*expected_batches[partition_index], *actual_batch); + Expression actual_expression = partition_results.expressions[partition_index]; + ASSERT_EQ(expected_expressions[partition_index], actual_expression); + } + } + + void AssertPartition(const std::shared_ptr partitioning, + const std::shared_ptr schema, + const std::string& record_batch_json, + const std::shared_ptr partitioned_schema, + const std::vector& expected_record_batch_strs, + const std::vector& expected_expressions) { + auto record_batch = RecordBatchFromJSON(schema, record_batch_json); + RecordBatchVector expected_batches; + for (const auto& expected_record_batch_str : expected_record_batch_strs) { + expected_batches.push_back( + RecordBatchFromJSON(partitioned_schema, expected_record_batch_str)); + } + AssertPartition(partitioning, record_batch, expected_batches, expected_expressions); + } + void AssertInspectError(const std::vector& paths) { ASSERT_RAISES(Invalid, factory_->Inspect(paths)); } @@ -103,6 +137,30 @@ class TestPartitioning : public ::testing::Test { std::shared_ptr written_schema_; }; +TEST_F(TestPartitioning, Partition) { + auto partition_schema = schema({field("a", int32()), field("b", utf8())}); + auto schema_ = schema({field("a", int32()), field("b", utf8()), field("c", uint32())}); + auto remaining_schema = schema({field("c", uint32())}); + 
auto partitioning = std::make_shared(partition_schema); + std::string json = R"([{"a": 3, "b": "x", "c": 0}, + {"a": 3, "b": "x", "c": 1}, + {"a": 1, "b": null, "c": 2}, + {"a": null, "b": null, "c": 3}, + {"a": null, "b": "z", "c": 4}, + {"a": null, "b": null, "c": 5} + ])"; + std::vector expected_batches = {R"([{"c": 0}, {"c": 1}])", R"([{"c": 2}])", + R"([{"c": 3}, {"c": 5}])", + R"([{"c": 4}])"}; + std::vector expected_expressions = { + and_(equal(field_ref("a"), literal(3)), equal(field_ref("b"), literal("x"))), + and_(equal(field_ref("a"), literal(1)), is_null(field_ref("b"))), + and_(is_null(field_ref("a")), is_null(field_ref("b"))), + and_(is_null(field_ref("a")), equal(field_ref("b"), literal("z")))}; + AssertPartition(partitioning, schema_, json, remaining_schema, expected_batches, + expected_expressions); +} + TEST_F(TestPartitioning, DirectoryPartitioning) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); @@ -136,6 +194,10 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormat) { equal(field_ref("alpha"), literal(0))), "0/hello"); AssertFormat(equal(field_ref("alpha"), literal(0)), "0"); + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), is_null(field_ref("beta"))), + "0"); + AssertFormatError( + and_(is_null(field_ref("alpha")), equal(field_ref("beta"), literal("hello")))); AssertFormatError(equal(field_ref("beta"), literal("hello"))); AssertFormat(literal(true), ""); @@ -209,6 +271,8 @@ TEST_F(TestPartitioning, DictionaryInference) { // successful dictionary inference AssertInspect({"/a/0"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/a/0", "/a/1"}, {DictStr("alpha"), DictInt("beta")}); + AssertInspect({"/a/0", "/a"}, {DictStr("alpha"), DictInt("beta")}); + AssertInspect({"/0/a", "/1"}, {DictInt("alpha"), DictStr("beta")}); AssertInspect({"/a/0", "/b/0", "/a/1", "/b/1"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/a/-", "/b/-", "/a/_", "/b/_"}, {DictStr("alpha"), DictStr("beta")}); } @@ -246,13 +310,15 @@ TEST_F(TestPartitioning, DiscoverSchemaSegfault) { TEST_F(TestPartitioning, HivePartitioning) { partitioning_ = std::make_shared( - schema({field("alpha", int32()), field("beta", float32())})); + schema({field("alpha", int32()), field("beta", float32())}), ArrayVector(), "xyz"); AssertParse("/alpha=0/beta=3.25", and_(equal(field_ref("alpha"), literal(0)), equal(field_ref("beta"), literal(3.25f)))); AssertParse("/beta=3.25/alpha=0", and_(equal(field_ref("beta"), literal(3.25f)), equal(field_ref("alpha"), literal(0)))); AssertParse("/alpha=0", equal(field_ref("alpha"), literal(0))); + AssertParse("/alpha=xyz/beta=3.25", and_(is_null(field_ref("alpha")), + equal(field_ref("beta"), literal(3.25f)))); AssertParse("/beta=3.25", equal(field_ref("beta"), literal(3.25f))); AssertParse("", literal(true)); @@ -271,7 +337,7 @@ TEST_F(TestPartitioning, HivePartitioning) { TEST_F(TestPartitioning, HivePartitioningFormat) { partitioning_ = std::make_shared( - schema({field("alpha", int32()), field("beta", float32())})); + schema({field("alpha", int32()), field("beta", float32())}), ArrayVector(), "xyz"); written_schema_ = partitioning_->schema(); @@ -282,9 +348,16 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { equal(field_ref("alpha"), literal(0))), "alpha=0/beta=3.25"); AssertFormat(equal(field_ref("alpha"), literal(0)), "alpha=0"); - AssertFormat(equal(field_ref("beta"), literal(3.25f)), "alpha/beta=3.25"); + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), is_null(field_ref("beta"))), + 
"alpha=0/beta=xyz"); + AssertFormat( + and_(is_null(field_ref("alpha")), equal(field_ref("beta"), literal(3.25f))), + "alpha=xyz/beta=3.25"); AssertFormat(literal(true), ""); + AssertFormat(and_(is_null(field_ref("alpha")), is_null(field_ref("beta"))), + "alpha=xyz/beta=xyz"); + ASSERT_OK_AND_ASSIGN(written_schema_, written_schema_->AddField(0, field("gamma", utf8()))); AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), @@ -300,7 +373,9 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { } TEST_F(TestPartitioning, DiscoverHiveSchema) { - factory_ = HivePartitioning::MakeFactory(); + auto options = HivePartitioningFactoryOptions(); + options.null_fallback = "xyz"; + factory_ = HivePartitioning::MakeFactory(options); // type is int32 if possible AssertInspect({"/alpha=0/beta=1"}, {Int("alpha"), Int("beta")}); @@ -313,6 +388,12 @@ TEST_F(TestPartitioning, DiscoverHiveSchema) { // (...so ensure your partitions are ordered the same for all paths) AssertInspect({"/alpha=0/beta=1", "/beta=2/alpha=3"}, {Int("alpha"), Int("beta")}); + // Null fallback strings shouldn't interfere with type inference + AssertInspect({"/alpha=xyz/beta=x", "/alpha=7/beta=xyz"}, {Int("alpha"), Str("beta")}); + + // Cannot infer if the only values are null + AssertInspectError({"/alpha=xyz"}); + // If there are too many digits fall back to string AssertInspect({"/alpha=3760212050"}, {Str("alpha")}); @@ -322,8 +403,9 @@ TEST_F(TestPartitioning, DiscoverHiveSchema) { } TEST_F(TestPartitioning, HiveDictionaryInference) { - PartitioningFactoryOptions options; + HivePartitioningFactoryOptions options; options.infer_dictionary = true; + options.null_fallback = "xyz"; factory_ = HivePartitioning::MakeFactory(options); // type is still int32 if possible @@ -335,6 +417,8 @@ TEST_F(TestPartitioning, HiveDictionaryInference) { // successful dictionary inference AssertInspect({"/alpha=a/beta=0"}, {DictStr("alpha"), DictInt("beta")}); AssertInspect({"/alpha=a/beta=0", "/alpha=a/1"}, {DictStr("alpha"), DictInt("beta")}); + AssertInspect({"/alpha=a/beta=0", "/alpha=xyz/beta=xyz"}, + {DictStr("alpha"), DictInt("beta")}); AssertInspect( {"/alpha=a/beta=0", "/alpha=b/beta=0", "/alpha=a/beta=1", "/alpha=b/beta=1"}, {DictStr("alpha"), DictInt("beta")}); @@ -343,8 +427,19 @@ TEST_F(TestPartitioning, HiveDictionaryInference) { {DictStr("alpha"), DictStr("beta")}); } +TEST_F(TestPartitioning, HiveNullFallbackPassedOn) { + HivePartitioningFactoryOptions options; + options.null_fallback = "xyz"; + factory_ = HivePartitioning::MakeFactory(options); + + EXPECT_OK_AND_ASSIGN(auto schema, factory_->Inspect({"/alpha=a/beta=0"})); + EXPECT_OK_AND_ASSIGN(auto partitioning, factory_->Finish(schema)); + ASSERT_EQ("xyz", + std::static_pointer_cast(partitioning)->null_fallback()); +} + TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) { - PartitioningFactoryOptions options; + HivePartitioningFactoryOptions options; options.infer_dictionary = true; factory_ = HivePartitioning::MakeFactory(options); @@ -369,6 +464,55 @@ TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) { AssertParseError("/alpha=yosemite"); // not in inspected dictionary } +TEST_F(TestPartitioning, SetDefaultValuesConcrete) { + auto small_schm = schema({field("c", int32())}); + auto schm = schema({field("a", int32()), field("b", utf8())}); + auto full_schm = schema({field("a", int32()), field("b", utf8()), field("c", int32())}); + RecordBatchProjector record_batch_projector(full_schm); + HivePartitioning part(schm); + ARROW_EXPECT_OK(part.SetDefaultValuesFromKeys( + 
and_(equal(field_ref("a"), literal(10)), is_valid(field_ref("b"))), + &record_batch_projector)); + + auto in_rb = RecordBatchFromJSON(small_schm, R"([{"c": 0}, + {"c": 1}, + {"c": 2}, + {"c": 3} + ])"); + + EXPECT_OK_AND_ASSIGN(auto out_rb, record_batch_projector.Project(*in_rb)); + auto expected_rb = RecordBatchFromJSON(full_schm, R"([{"a": 10, "b": null, "c": 0}, + {"a": 10, "b": null, "c": 1}, + {"a": 10, "b": null, "c": 2}, + {"a": 10, "b": null, "c": 3} + ])"); + AssertBatchesEqual(*expected_rb, *out_rb); +} + +TEST_F(TestPartitioning, SetDefaultValuesNull) { + auto small_schm = schema({field("c", int32())}); + auto schm = schema({field("a", int32()), field("b", utf8())}); + auto full_schm = schema({field("a", int32()), field("b", utf8()), field("c", int32())}); + RecordBatchProjector record_batch_projector(full_schm); + HivePartitioning part(schm); + ARROW_EXPECT_OK(part.SetDefaultValuesFromKeys( + and_(is_null(field_ref("a")), is_null(field_ref("b"))), &record_batch_projector)); + + auto in_rb = RecordBatchFromJSON(small_schm, R"([{"c": 0}, + {"c": 1}, + {"c": 2}, + {"c": 3} + ])"); + + EXPECT_OK_AND_ASSIGN(auto out_rb, record_batch_projector.Project(*in_rb)); + auto expected_rb = RecordBatchFromJSON(full_schm, R"([{"a": null, "b": null, "c": 0}, + {"a": null, "b": null, "c": 1}, + {"a": null, "b": null, "c": 2}, + {"a": null, "b": null, "c": 3} + ])"); + AssertBatchesEqual(*expected_rb, *out_rb); +} + TEST_F(TestPartitioning, EtlThenHive) { FieldVector etl_fields{field("year", int16()), field("month", int8()), field("day", int8()), field("hour", int8())}; @@ -467,13 +611,13 @@ class RangePartitioning : public Partitioning { std::vector ranges; for (auto segment : fs::internal::SplitAbstractPath(path)) { - auto key = HivePartitioning::ParseKey(segment); + auto key = HivePartitioning::ParseKey(segment, ""); if (!key) { return Status::Invalid("can't parse '", segment, "' as a range"); } std::smatch matches; - RETURN_NOT_OK(DoRegex(key->value, &matches)); + RETURN_NOT_OK(DoRegex(*key->value, &matches)); auto& min_cmp = matches[1] == "[" ? 
greater_equal : greater; std::string min_repr = matches[2]; @@ -600,20 +744,45 @@ TEST(GroupTest, Basics) { } TEST(GroupTest, WithNulls) { - auto has_nulls = checked_pointer_cast( - ArrayFromJSON(struct_({field("a", utf8()), field("b", int32())}), R"([ - {"a": "ex", "b": 0}, - {"a": null, "b": 0}, - {"a": "why", "b": 0}, - {"a": "ex", "b": 1}, - {"a": "why", "b": 0}, - {"a": "ex", "b": 1}, - {"a": "ex", "b": 0}, - {"a": "why", "b": null} - ])")); - ASSERT_RAISES(NotImplemented, MakeGroupings(*has_nulls)); + AssertGrouping({field("a", utf8()), field("b", int32())}, + R"([ + {"a": "ex", "b": 0, "id": 0}, + {"a": null, "b": 0, "id": 1}, + {"a": null, "b": 0, "id": 2}, + {"a": "ex", "b": 1, "id": 3}, + {"a": null, "b": null, "id": 4}, + {"a": "ex", "b": 1, "id": 5}, + {"a": "ex", "b": 0, "id": 6}, + {"a": "why", "b": null, "id": 7} + ])", + R"([ + {"a": "ex", "b": 0, "ids": [0, 6]}, + {"a": null, "b": 0, "ids": [1, 2]}, + {"a": "ex", "b": 1, "ids": [3, 5]}, + {"a": null, "b": null, "ids": [4]}, + {"a": "why", "b": null, "ids": [7]} + ])"); - has_nulls = checked_pointer_cast( + AssertGrouping({field("a", dictionary(int32(), utf8())), field("b", int32())}, + R"([ + {"a": "ex", "b": 0, "id": 0}, + {"a": null, "b": 0, "id": 1}, + {"a": null, "b": 0, "id": 2}, + {"a": "ex", "b": 1, "id": 3}, + {"a": null, "b": null, "id": 4}, + {"a": "ex", "b": 1, "id": 5}, + {"a": "ex", "b": 0, "id": 6}, + {"a": "why", "b": null, "id": 7} + ])", + R"([ + {"a": "ex", "b": 0, "ids": [0, 6]}, + {"a": null, "b": 0, "ids": [1, 2]}, + {"a": "ex", "b": 1, "ids": [3, 5]}, + {"a": null, "b": null, "ids": [4]}, + {"a": "why", "b": null, "ids": [7]} + ])"); + + auto has_nulls = checked_pointer_cast( ArrayFromJSON(struct_({field("a", utf8()), field("b", int32())}), R"([ {"a": "ex", "b": 0}, null, diff --git a/cpp/src/arrow/dataset/projector.cc b/cpp/src/arrow/dataset/projector.cc index 2ba679ce6e7ae..ba0eb2ddff552 100644 --- a/cpp/src/arrow/dataset/projector.cc +++ b/cpp/src/arrow/dataset/projector.cc @@ -23,6 +23,7 @@ #include #include "arrow/array.h" +#include "arrow/compute/cast.h" #include "arrow/dataset/type_fwd.h" #include "arrow/record_batch.h" #include "arrow/result.h" @@ -88,9 +89,18 @@ Status RecordBatchProjector::SetDefaultValue(FieldRef ref, auto field_type = to_->field(index)->type(); if (!field_type->Equals(scalar->type)) { - return Status::TypeError("field ", to_->field(index)->ToString(), - " cannot be materialized from scalar of type ", - *scalar->type); + if (scalar->is_valid) { + auto maybe_converted = compute::Cast(scalar, field_type); + if (!maybe_converted.ok()) { + return Status::TypeError("Field ", to_->field(index)->ToString(), + " cannot be materialized from scalar of type ", + *scalar->type, + ". 
Cast error: ", maybe_converted.status().message()); + } + scalar = maybe_converted->scalar(); + } else { + scalar = MakeNullScalar(field_type); + } } scalars_[index] = std::move(scalar); diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index 092452850301f..1c47f9742de31 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -2183,7 +2183,9 @@ Status ConvertCategoricals(const PandasOptions& options, ChunkedArrayVector* arr "only zero-copy conversions allowed"); } compute::ExecContext ctx(options.pool); - ARROW_ASSIGN_OR_RAISE(Datum out, DictionaryEncode((*arrays)[i], &ctx)); + ARROW_ASSIGN_OR_RAISE( + Datum out, DictionaryEncode((*arrays)[i], + compute::DictionaryEncodeOptions::Defaults(), &ctx)); (*arrays)[i] = out.chunked_array(); (*fields)[i] = (*fields)[i]->WithType((*arrays)[i]->type()); return Status::OK(); @@ -2232,7 +2234,9 @@ Status ConvertChunkedArrayToPandas(const PandasOptions& options, "only zero-copy conversions allowed"); } compute::ExecContext ctx(options.pool); - ARROW_ASSIGN_OR_RAISE(Datum out, DictionaryEncode(arr, &ctx)); + ARROW_ASSIGN_OR_RAISE( + Datum out, + DictionaryEncode(arr, compute::DictionaryEncodeOptions::Defaults(), &ctx)); arr = out.chunked_array(); } diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index e5a19288b876c..3cb152aa381eb 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -649,6 +649,32 @@ class FilterOptions(_FilterOptions): self._set_options(null_selection_behavior) +cdef class _DictionaryEncodeOptions(FunctionOptions): + cdef: + unique_ptr[CDictionaryEncodeOptions] dictionary_encode_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.dictionary_encode_options.get() + + def _set_options(self, null_encoding_behavior): + if null_encoding_behavior == 'encode': + self.dictionary_encode_options.reset( + new CDictionaryEncodeOptions( + CDictionaryEncodeNullEncodingBehavior_ENCODE)) + elif null_encoding_behavior == 'mask': + self.dictionary_encode_options.reset( + new CDictionaryEncodeOptions( + CDictionaryEncodeNullEncodingBehavior_MASK)) + else: + raise ValueError('"{}" is not a valid null_encoding_behavior' + .format(null_encoding_behavior)) + + +class DictionaryEncodeOptions(_DictionaryEncodeOptions): + def __init__(self, null_encoding_behavior='mask'): + self._set_options(null_encoding_behavior) + + cdef class _TakeOptions(FunctionOptions): cdef: unique_ptr[CTakeOptions] take_options diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index c67dbc99d77bc..1c4e5d302c570 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -206,6 +206,10 @@ cdef class Expression(_Weakrefable): """Checks whether the expression is not-null (valid)""" return Expression._call("is_valid", [self]) + def is_null(self): + """Checks whether the expression is null""" + return Expression._call("is_null", [self]) + def cast(self, type, bint safe=True): """Explicitly change the expression's data type""" cdef shared_ptr[CCastOptions] c_options @@ -1546,7 +1550,7 @@ cdef class DirectoryPartitioning(Partitioning): Returns ------- - DirectoryPartitioningFactory + PartitioningFactory To be used in the FileSystemFactoryOptions. """ cdef: @@ -1590,6 +1594,8 @@ cdef class HivePartitioning(Partitioning): corresponding entry of `dictionaries` must be an array containing every value which may be taken by the corresponding column or an error will be raised in parsing. 
+ null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + If any field is None then this fallback will be used as a label Returns ------- @@ -1608,13 +1614,19 @@ cdef class HivePartitioning(Partitioning): cdef: CHivePartitioning* hive_partitioning - def __init__(self, Schema schema not None, dictionaries=None): + def __init__(self, + Schema schema not None, + dictionaries=None, + null_fallback="__HIVE_DEFAULT_PARTITION__"): + cdef: shared_ptr[CHivePartitioning] c_partitioning + c_string c_null_fallback = tobytes(null_fallback) c_partitioning = make_shared[CHivePartitioning]( pyarrow_unwrap_schema(schema), - _partitioning_dictionaries(schema, dictionaries) + _partitioning_dictionaries(schema, dictionaries), + c_null_fallback ) self.init( c_partitioning) @@ -1623,7 +1635,9 @@ cdef class HivePartitioning(Partitioning): self.hive_partitioning = sp.get() @staticmethod - def discover(infer_dictionary=False, max_partition_dictionary_size=0): + def discover(infer_dictionary=False, + max_partition_dictionary_size=0, + null_fallback="__HIVE_DEFAULT_PARTITION__"): """ Discover a HivePartitioning. @@ -1639,6 +1653,10 @@ cdef class HivePartitioning(Partitioning): Synonymous with infer_dictionary for backwards compatibility with 1.0: setting this to -1 or None is equivalent to passing infer_dictionary=True. + null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + When inferring a schema for partition fields this value will be + replaced by null. The default is set to __HIVE_DEFAULT_PARTITION__ + for compatibility with Spark Returns ------- @@ -1646,7 +1664,7 @@ cdef class HivePartitioning(Partitioning): To be used in the FileSystemFactoryOptions. """ cdef: - CPartitioningFactoryOptions c_options + CHivePartitioningFactoryOptions c_options if max_partition_dictionary_size in {-1, None}: infer_dictionary = True @@ -1657,6 +1675,8 @@ cdef class HivePartitioning(Partitioning): if infer_dictionary: c_options.infer_dictionary = True + c_options.null_fallback = tobytes(null_fallback) + return PartitioningFactory.wrap( CHivePartitioning.MakeFactory(c_options)) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index ae9e213b98dd1..a832b00b1ebde 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -842,11 +842,12 @@ cdef class Array(_PandasConvertible): """ return _pc().call_function('unique', [self]) - def dictionary_encode(self): + def dictionary_encode(self, null_encoding='mask'): """ Compute dictionary-encoded representation of array. 
""" - return _pc().call_function('dictionary_encode', [self]) + options = _pc().DictionaryEncodeOptions(null_encoding) + return _pc().call_function('dictionary_encode', [self], options) def value_counts(self): """ diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 616b2de89ec24..3d7f5ecb4c3d1 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -30,6 +30,7 @@ ArraySortOptions, CastOptions, CountOptions, + DictionaryEncodeOptions, FilterOptions, MatchSubstringOptions, MinMaxOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e10ef1e3a5e7c..ba3c3ad7d2b66 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1802,6 +1802,20 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CFilterOptions(CFilterNullSelectionBehavior null_selection) CFilterNullSelectionBehavior null_selection_behavior + enum CDictionaryEncodeNullEncodingBehavior \ + "arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior": + CDictionaryEncodeNullEncodingBehavior_ENCODE \ + "arrow::compute::DictionaryEncodeOptions::ENCODE" + CDictionaryEncodeNullEncodingBehavior_MASK \ + "arrow::compute::DictionaryEncodeOptions::MASK" + + cdef cppclass CDictionaryEncodeOptions \ + "arrow::compute::DictionaryEncodeOptions"(CFunctionOptions): + CDictionaryEncodeOptions() + CDictionaryEncodeOptions( + CDictionaryEncodeNullEncodingBehavior null_encoding) + CDictionaryEncodeNullEncodingBehavior null_encoding + cdef cppclass CTakeOptions \ " arrow::compute::TakeOptions"(CFunctionOptions): CTakeOptions(c_bool boundscheck) diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 29f9738dedc6d..93bc0edddc1fe 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -274,6 +274,11 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: "arrow::dataset::PartitioningFactoryOptions": c_bool infer_dictionary + cdef cppclass CHivePartitioningFactoryOptions \ + "arrow::dataset::HivePartitioningFactoryOptions": + c_bool infer_dictionary, + c_string null_fallback + cdef cppclass CPartitioningFactory "arrow::dataset::PartitioningFactory": pass @@ -293,7 +298,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: @staticmethod shared_ptr[CPartitioningFactory] MakeFactory( - CPartitioningFactoryOptions) + CHivePartitioningFactoryOptions) cdef cppclass CPartitioningOrFactory \ "arrow::dataset::PartitioningOrFactory": diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index aa738f9aaea2d..998af512c55aa 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -251,6 +251,9 @@ cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar): if data_type == NULL: raise ValueError('Scalar data type was NULL') + if data_type.id() == _Type_NA: + return _NULL + if data_type.id() not in _scalar_classes: raise ValueError('Scalar type not supported') diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index c6b0b4180b6b0..3f1fc28ee6046 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -276,7 +276,7 @@ cdef class ChunkedArray(_PandasConvertible): """ return _pc().cast(self, target_type, safe=safe) - def dictionary_encode(self): + def dictionary_encode(self, null_encoding='mask'): """ Compute dictionary-encoded representation of array @@ -285,7 +285,8 @@ cdef class 
ChunkedArray(_PandasConvertible): pyarrow.ChunkedArray Same chunking as the input, all chunks share a common dictionary. """ - return _pc().call_function('dictionary_encode', [self]) + options = _pc().DictionaryEncodeOptions(null_encoding) + return _pc().call_function('dictionary_encode', [self], options) def flatten(self, MemoryPool memory_pool=None): """ diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 796f6d998e8fc..57179f391de9a 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -16,6 +16,8 @@ # under the License. import contextlib +import os +import posixpath import pathlib import pickle import textwrap @@ -381,11 +383,16 @@ def test_partitioning(): with pytest.raises(pa.ArrowInvalid): partitioning.parse('/prefix/3/aaa') + expr = partitioning.parse('/3') + expected = ds.field('group') == 3 + assert expr.equals(expected) + partitioning = ds.HivePartitioning( pa.schema([ pa.field('alpha', pa.int64()), pa.field('beta', pa.int64()) - ]) + ]), + null_fallback='xyz' ) expr = partitioning.parse('/alpha=0/beta=3') expected = ( @@ -394,6 +401,12 @@ def test_partitioning(): ) assert expr.equals(expected) + expr = partitioning.parse('/alpha=xyz/beta=3') + expected = ( + (ds.field('alpha').is_null() & (ds.field('beta') == ds.scalar(3))) + ) + assert expr.equals(expected) + for shouldfail in ['/alpha=one/beta=2', '/alpha=one', '/beta=two']: with pytest.raises(pa.ArrowInvalid): partitioning.parse(shouldfail) @@ -412,7 +425,7 @@ def test_expression_serialization(): d.is_valid(), a.cast(pa.int32(), safe=False), a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), ds.field('i64') > 5, ds.field('i64') == 5, - ds.field('i64') == 7] + ds.field('i64') == 7, ds.field('i64').is_null()] for expr in all_exprs: assert isinstance(expr, ds.Expression) restored = pickle.loads(pickle.dumps(expr)) @@ -468,6 +481,9 @@ def test_partition_keys(): assert ds._get_partition_keys(nope) == {} assert ds._get_partition_keys(a & nope) == {'a': 'a'} + null = ds.field('a').is_null() + assert ds._get_partition_keys(null) == {'a': None} + def test_parquet_read_options(): opts1 = ds.ParquetReadOptions() @@ -1239,6 +1255,57 @@ def test_partitioning_factory_dictionary(mockfs, infer_dictionary): assert inferred_schema.field('key').type == pa.string() +def test_dictionary_partitioning_outer_nulls_raises(tempdir): + table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']}) + part = ds.partitioning( + pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())])) + with pytest.raises(pa.ArrowInvalid): + ds.write_dataset(table, tempdir, format='parquet', partitioning=part) + + +def _has_subdirs(basedir): + elements = os.listdir(basedir) + return any([os.path.isdir(os.path.join(basedir, el)) for el in elements]) + + +def _do_list_all_dirs(basedir, path_so_far, result): + for f in os.listdir(basedir): + true_nested = os.path.join(basedir, f) + if os.path.isdir(true_nested): + norm_nested = posixpath.join(path_so_far, f) + if _has_subdirs(true_nested): + _do_list_all_dirs(true_nested, norm_nested, result) + else: + result.append(norm_nested) + + +def _list_all_dirs(basedir): + result = [] + _do_list_all_dirs(basedir, '', result) + return result + + +def _check_dataset_directories(tempdir, expected_directories): + actual_directories = set(_list_all_dirs(tempdir)) + assert actual_directories == set(expected_directories) + + +def test_dictionary_partitioning_inner_nulls(tempdir): + table = pa.table({'a': ['x', 'y', 'z'], 'b': ['x', 'y', None]}) 
+ part = ds.partitioning( + pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())])) + ds.write_dataset(table, tempdir, format='parquet', partitioning=part) + _check_dataset_directories(tempdir, ['x/x', 'y/y', 'z']) + + +def test_hive_partitioning_nulls(tempdir): + table = pa.table({'a': ['x', None, 'z'], 'b': ['x', 'y', None]}) + part = ds.HivePartitioning(pa.schema( + [pa.field('a', pa.string()), pa.field('b', pa.string())]), None, 'xyz') + ds.write_dataset(table, tempdir, format='parquet', partitioning=part) + _check_dataset_directories(tempdir, ['a=x/b=x', 'a=xyz/b=y', 'a=z/b=xyz']) + + def test_partitioning_function(): schema = pa.schema([("year", pa.int16()), ("month", pa.int8())]) names = ["year", "month"] @@ -1600,25 +1667,48 @@ def test_open_dataset_non_existing_file(): @pytest.mark.parquet @pytest.mark.parametrize('partitioning', ["directory", "hive"]) +@pytest.mark.parametrize('null_fallback', ['xyz', None]) +@pytest.mark.parametrize('infer_dictionary', [False, True]) @pytest.mark.parametrize('partition_keys', [ (["A", "B", "C"], [1, 2, 3]), ([1, 2, 3], ["A", "B", "C"]), (["A", "B", "C"], ["D", "E", "F"]), ([1, 2, 3], [4, 5, 6]), + ([1, None, 3], ["A", "B", "C"]), + ([1, 2, 3], ["A", None, "C"]), + ([None, 2, 3], [None, 2, 3]), ]) -def test_open_dataset_partitioned_dictionary_type(tempdir, partitioning, - partition_keys): +def test_partition_discovery( + tempdir, partitioning, null_fallback, infer_dictionary, partition_keys +): # ARROW-9288 / ARROW-9476 import pyarrow.parquet as pq - table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5}) + + table = pa.table({'a': range(9), 'b': [0.0] * 4 + [1.0] * 5}) + + has_null = None in partition_keys[0] or None in partition_keys[1] + if partitioning == "directory" and has_null: + # Directory partitioning can't handle the first part being null + return if partitioning == "directory": partitioning = ds.DirectoryPartitioning.discover( - ["part1", "part2"], infer_dictionary=True) + ["part1", "part2"], infer_dictionary=infer_dictionary) fmt = "{0}/{1}" + null_value = None else: - partitioning = ds.HivePartitioning.discover(infer_dictionary=True) + if null_fallback: + partitioning = ds.HivePartitioning.discover( + infer_dictionary=infer_dictionary, null_fallback=null_fallback + ) + else: + partitioning = ds.HivePartitioning.discover( + infer_dictionary=infer_dictionary) fmt = "part1={0}/part2={1}" + if null_fallback: + null_value = null_fallback + else: + null_value = "__HIVE_DEFAULT_PARTITION__" basepath = tempdir / "dataset" basepath.mkdir() @@ -1626,19 +1716,23 @@ def test_open_dataset_partitioned_dictionary_type(tempdir, partitioning, part_keys1, part_keys2 = partition_keys for part1 in part_keys1: for part2 in part_keys2: - path = basepath / fmt.format(part1, part2) + path = basepath / \ + fmt.format(part1 or null_value, part2 or null_value) path.mkdir(parents=True) pq.write_table(table, path / "test.parquet") dataset = ds.dataset(str(basepath), partitioning=partitioning) - def dict_type(key): - value_type = pa.string() if isinstance(key, str) else pa.int32() - return pa.dictionary(pa.int32(), value_type) + def expected_type(key): + if infer_dictionary: + value_type = pa.string() if isinstance(key, str) else pa.int32() + return pa.dictionary(pa.int32(), value_type) + else: + return pa.string() if isinstance(key, str) else pa.int32() expected_schema = table.schema.append( - pa.field("part1", dict_type(part_keys1[0])) + pa.field("part1", expected_type(part_keys1[0])) ).append( - pa.field("part2", dict_type(part_keys2[0])) + 
pa.field("part2", expected_type(part_keys2[0])) ) assert dataset.schema.equals(expected_schema) @@ -2304,6 +2398,52 @@ def test_dataset_project_only_partition_columns(tempdir): assert all_cols.column('part').equals(part_only.column('part')) +@pytest.mark.parquet +@pytest.mark.pandas +def test_write_to_dataset_given_null_just_works(tempdir): + import pyarrow.parquet as pq + + schema = pa.schema([ + pa.field('col', pa.int64()), + pa.field('part', pa.dictionary(pa.int32(), pa.string())) + ]) + table = pa.table({'part': [None, None, 'a', 'a'], + 'col': list(range(4))}, schema=schema) + + path = str(tempdir / 'test_dataset') + pq.write_to_dataset(table, path, partition_cols=[ + 'part'], use_legacy_dataset=False) + + actual_table = pq.read_table(tempdir / 'test_dataset') + # column.equals can handle the difference in chunking but not the fact + # that `part` will have different dictionaries for the two chunks + assert actual_table.column('part').to_pylist( + ) == table.column('part').to_pylist() + assert actual_table.column('col').equals(table.column('col')) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_legacy_write_to_dataset_drops_null(tempdir): + import pyarrow.parquet as pq + + schema = pa.schema([ + pa.field('col', pa.int64()), + pa.field('part', pa.dictionary(pa.int32(), pa.string())) + ]) + table = pa.table({'part': ['a', 'a', None, None], + 'col': list(range(4))}, schema=schema) + expected = pa.table( + {'part': ['a', 'a'], 'col': list(range(2))}, schema=schema) + + path = str(tempdir / 'test_dataset') + pq.write_to_dataset(table, path, partition_cols=[ + 'part'], use_legacy_dataset=True) + + actual = pq.read_table(tempdir / 'test_dataset') + assert actual == expected + + @pytest.mark.parquet @pytest.mark.pandas def test_dataset_project_null_column(tempdir): From 66be4c26d1d9746bf29136f9710b7c93c56c0ff7 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 24 Feb 2021 18:34:13 +0100 Subject: [PATCH 26/54] ARROW-11741: [C++] Fix decimal casts on big endian platforms Closes #9554 from pitrou/ARROW-11741-cast-decimal-big-endian Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/compute/kernels/codegen_internal.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 11e03bba2873a..8c49e796623e7 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -663,11 +663,14 @@ struct ScalarUnaryNotNullStateful { static void Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0, Datum* out) { ArrayData* out_arr = out->mutable_array(); - auto out_data = out_arr->GetMutableValues(1); + // Decimal128 data buffers are not safely reinterpret_cast-able on big-endian + using endian_agnostic = std::array; + auto out_data = out_arr->GetMutableValues(1); VisitArrayValuesInline( arg0, [&](Arg0Value v) { - *out_data++ = functor.op.template Call(ctx, v); + functor.op.template Call(ctx, v).ToBytes( + out_data++->data()); }, [&]() { ++out_data; }); } From 2fb458d35ce36ae02488b07cff32ec730d2000c1 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 24 Feb 2021 20:12:40 +0100 Subject: [PATCH 27/54] ARROW-11665: [C++][Python] Improve docstrings for decimal and union types Also add Python factory functions `pa.sparse_union` and `pa.dense_union`. 
Closes #9560 from pitrou/ARROW-11665-py-decimal-doc Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/type.h | 58 +++++++++- cpp/src/arrow/type_fwd.h | 5 +- python/pyarrow/__init__.py | 6 +- python/pyarrow/public-api.pxi | 4 +- python/pyarrow/tests/test_misc.py | 29 +++-- python/pyarrow/tests/test_types.py | 31 +++-- python/pyarrow/types.pxi | 174 ++++++++++++++++++++++++----- 7 files changed, 252 insertions(+), 55 deletions(-) diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 0672354ab6cd2..fafe333852582 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -626,8 +626,10 @@ class ARROW_EXPORT LargeListType : public BaseListType { /// \brief Concrete type class for map data /// /// Map data is nested data where each value is a variable number of -/// key-item pairs. Maps can be recursively nested, for example -/// map(utf8, map(utf8, int32)). +/// key-item pairs. Its physical representation is the same as +/// a list of `{key, item}` structs. +/// +/// Maps can be recursively nested, for example map(utf8, map(utf8, int32)). class ARROW_EXPORT MapType : public ListType { public: static constexpr Type::type type_id = Type::MAP; @@ -894,6 +896,19 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { }; /// \brief Concrete type class for 128-bit decimal data +/// +/// Arrow decimals are fixed-point decimal numbers encoded as a scaled +/// integer. The precision is the number of significant digits that the +/// decimal type can represent; the scale is the number of digits after +/// the decimal point (note the scale can be negative). +/// +/// As an example, `Decimal128Type(7, 3)` can exactly represent the numbers +/// 1234.567 and -1234.567 (encoded internally as the 128-bit integers +/// 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567. +/// +/// Decimal128Type has a maximum precision of 38 significant digits +/// (also available as Decimal128Type::kMaxPrecision). +/// If higher precision is needed, consider using Decimal256Type. class ARROW_EXPORT Decimal128Type : public DecimalType { public: static constexpr Type::type type_id = Type::DECIMAL128; @@ -915,6 +930,18 @@ class ARROW_EXPORT Decimal128Type : public DecimalType { }; /// \brief Concrete type class for 256-bit decimal data +/// +/// Arrow decimals are fixed-point decimal numbers encoded as a scaled +/// integer. The precision is the number of significant digits that the +/// decimal type can represent; the scale is the number of digits after +/// the decimal point (note the scale can be negative). +/// +/// Decimal256Type has a maximum precision of 76 significant digits. +/// (also available as Decimal256Type::kMaxPrecision). +/// +/// For most use cases, the maximum precision offered by Decimal128Type +/// is sufficient, and it will result in a more compact and more efficient +/// encoding. 
class ARROW_EXPORT Decimal256Type : public DecimalType {
 public:
  static constexpr Type::type type_id = Type::DECIMAL256;
@@ -935,7 +962,7 @@ class ARROW_EXPORT Decimal256Type : public DecimalType {
   static constexpr int32_t kByteWidth = 32;
 };
 
-/// \brief Concrete type class for union data
+/// \brief Base type class for union data
 class ARROW_EXPORT UnionType : public NestedType {
  public:
   static constexpr int8_t kMaxTypeCode = 127;
@@ -983,6 +1010,17 @@ class ARROW_EXPORT UnionType : public NestedType {
   std::vector child_ids_;
 };
 
+/// \brief Concrete type class for sparse union data
+///
+/// A sparse union is a nested type where each logical value is taken from
+/// a single child. A buffer of 8-bit type ids indicates which child
+/// a given logical value is to be taken from.
+///
+/// In a sparse union, each child array should have the same length as the
+/// union array, regardless of the actual number of union values that
+/// refer to it.
+///
+/// Note that, unlike most other types, unions don't have a top-level validity bitmap.
 class ARROW_EXPORT SparseUnionType : public UnionType {
  public:
  static constexpr Type::type type_id = Type::SPARSE_UNION;
@@ -999,6 +1037,20 @@ class ARROW_EXPORT SparseUnionType : public UnionType {
   std::string name() const override { return "sparse_union"; }
 };
 
+/// \brief Concrete type class for dense union data
+///
+/// A dense union is a nested type where each logical value is taken from
+/// a single child, at a specific offset. A buffer of 8-bit type ids
+/// indicates which child a given logical value is to be taken from,
+/// and a buffer of 32-bit offsets indicates at which physical position
+/// in the given child array the logical value is to be taken from.
+///
+/// Unlike a sparse union, a dense union allows encoding only the child array
+/// values which are actually referred to by the union array. This is
+/// counterbalanced by the additional footprint of the offsets buffer, and
+/// the additional indirection cost when looking up values.
+///
+/// Note that, unlike most other types, unions don't have a top-level validity bitmap.
 class ARROW_EXPORT DenseUnionType : public UnionType {
  public:
   static constexpr Type::type type_id = Type::DENSE_UNION;
diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h
index 14329675c8f1b..230c1ff6cb61e 100644
--- a/cpp/src/arrow/type_fwd.h
+++ b/cpp/src/arrow/type_fwd.h
@@ -438,7 +438,10 @@ std::shared_ptr ARROW_EXPORT date64();
 ARROW_EXPORT
 std::shared_ptr fixed_size_binary(int32_t byte_width);
 
-/// \brief Create a Decimal128Type or Decimal256Type instance depending on the precision
+/// \brief Create a DecimalType instance depending on the precision
+///
+/// If the precision is greater than 38, a Decimal256Type is returned,
+/// otherwise a Decimal128Type.
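For reference, the same cutoff can be seen from Python with the existing `pa.decimal128` and `pa.decimal256` factories. This is only an illustrative sketch, not part of the patch: `decimal128` covers at most 38 significant digits, while `decimal256` goes up to 76.

```python
import decimal
import pyarrow as pa

# Maximum precisions documented above: 38 for decimal128, 76 for decimal256.
narrow = pa.decimal128(38, 10)
wide = pa.decimal256(76, 10)

# decimal128(7, 3) stores 1234.567 as the scaled integer 1234567.
arr = pa.array([decimal.Decimal("1234.567")], type=pa.decimal128(7, 3))
print(arr.type)  # decimal128(7, 3)
```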
ARROW_EXPORT std::shared_ptr decimal(int32_t precision, int32_t scale); diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 98da7f45d265e..995bed9f1950a 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -96,12 +96,14 @@ def show_versions(): binary, string, utf8, large_binary, large_string, large_utf8, decimal128, decimal256, - list_, large_list, map_, struct, union, dictionary, + list_, large_list, map_, struct, + union, sparse_union, dense_union, + dictionary, field, type_for_alias, DataType, DictionaryType, StructType, ListType, LargeListType, MapType, FixedSizeListType, - UnionType, + UnionType, SparseUnionType, DenseUnionType, TimestampType, Time32Type, Time64Type, DurationType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 998af512c55aa..c427fb9f5db0f 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -94,9 +94,9 @@ cdef api object pyarrow_wrap_data_type( elif type.get().id() == _Type_STRUCT: out = StructType.__new__(StructType) elif type.get().id() == _Type_SPARSE_UNION: - out = UnionType.__new__(UnionType) + out = SparseUnionType.__new__(SparseUnionType) elif type.get().id() == _Type_DENSE_UNION: - out = UnionType.__new__(UnionType) + out = DenseUnionType.__new__(DenseUnionType) elif type.get().id() == _Type_TIMESTAMP: out = TimestampType.__new__(TimestampType) elif type.get().id() == _Type_DURATION: diff --git a/python/pyarrow/tests/test_misc.py b/python/pyarrow/tests/test_misc.py index bda5cb17eb0d6..a7a9f25b8c980 100644 --- a/python/pyarrow/tests/test_misc.py +++ b/python/pyarrow/tests/test_misc.py @@ -84,18 +84,21 @@ def test_runtime_info(): pa.Buffer, pa.Array, pa.Tensor, - pa.lib.DataType, - pa.lib.ListType, - pa.lib.LargeListType, - pa.lib.FixedSizeListType, - pa.lib.UnionType, - pa.lib.StructType, - pa.lib.Time32Type, - pa.lib.Time64Type, - pa.lib.TimestampType, - pa.lib.Decimal128Type, - pa.lib.DictionaryType, - pa.lib.FixedSizeBinaryType, + pa.DataType, + pa.ListType, + pa.LargeListType, + pa.FixedSizeListType, + pa.UnionType, + pa.SparseUnionType, + pa.DenseUnionType, + pa.StructType, + pa.Time32Type, + pa.Time64Type, + pa.TimestampType, + pa.Decimal128Type, + pa.Decimal256Type, + pa.DictionaryType, + pa.FixedSizeBinaryType, pa.NullArray, pa.NumericArray, pa.IntegerArray, @@ -125,6 +128,7 @@ def test_runtime_info(): pa.Time64Array, pa.DurationArray, pa.Decimal128Array, + pa.Decimal256Array, pa.StructArray, pa.Scalar, pa.BooleanScalar, @@ -140,6 +144,7 @@ def test_runtime_info(): pa.FloatScalar, pa.DoubleScalar, pa.Decimal128Scalar, + pa.Decimal256Scalar, pa.Date32Scalar, pa.Date64Scalar, pa.Time32Scalar, diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index e387e4b24b080..698ba8df0cc3d 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -17,6 +17,7 @@ from collections import OrderedDict from collections.abc import Iterator +from functools import partial import datetime import sys @@ -561,31 +562,45 @@ def check_fields(ty, fields): pa.field('y', pa.binary())] type_codes = [5, 9] - for mode in ('sparse', pa.lib.UnionMode_SPARSE): - ty = pa.union(fields, mode=mode) + sparse_factories = [ + partial(pa.union, mode='sparse'), + partial(pa.union, mode=pa.lib.UnionMode_SPARSE), + pa.sparse_union, + ] + + dense_factories = [ + partial(pa.union, mode='dense'), + partial(pa.union, 
mode=pa.lib.UnionMode_DENSE), + pa.dense_union, + ] + + for factory in sparse_factories: + ty = factory(fields) + assert isinstance(ty, pa.SparseUnionType) assert ty.mode == 'sparse' check_fields(ty, fields) assert ty.type_codes == [0, 1] - ty = pa.union(fields, mode=mode, type_codes=type_codes) + ty = factory(fields, type_codes=type_codes) assert ty.mode == 'sparse' check_fields(ty, fields) assert ty.type_codes == type_codes # Invalid number of type codes with pytest.raises(ValueError): - pa.union(fields, mode=mode, type_codes=type_codes[1:]) + factory(fields, type_codes=type_codes[1:]) - for mode in ('dense', pa.lib.UnionMode_DENSE): - ty = pa.union(fields, mode=mode) + for factory in dense_factories: + ty = factory(fields) + assert isinstance(ty, pa.DenseUnionType) assert ty.mode == 'dense' check_fields(ty, fields) assert ty.type_codes == [0, 1] - ty = pa.union(fields, mode=mode, type_codes=type_codes) + ty = factory(fields, type_codes=type_codes) assert ty.mode == 'dense' check_fields(ty, fields) assert ty.type_codes == type_codes # Invalid number of type codes with pytest.raises(ValueError): - pa.union(fields, mode=mode, type_codes=type_codes[1:]) + factory(fields, type_codes=type_codes[1:]) for mode in ('unknown', 2): with pytest.raises(ValueError, match='Invalid union mode'): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 56ea4e950a8c1..184e3dd8a7c77 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -442,7 +442,7 @@ cdef class StructType(DataType): cdef class UnionType(DataType): """ - Concrete class for struct data types. + Base class for union data types. """ cdef void init(self, const shared_ptr[CDataType]& type) except *: @@ -492,6 +492,18 @@ cdef class UnionType(DataType): return union, (list(self), self.mode, self.type_codes) +cdef class SparseUnionType(UnionType): + """ + Concrete class for sparse union types. + """ + + +cdef class DenseUnionType(UnionType): + """ + Concrete class for dense union types. + """ + + cdef class TimestampType(DataType): """ Concrete class for timestamp data types. @@ -2110,11 +2122,28 @@ def float64(): cpdef DataType decimal128(int precision, int scale=0): """ - Create decimal type with precision and scale and 128bit width. + Create decimal type with precision and scale and 128-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + As an example, ``decimal128(7, 3)`` can exactly represent the numbers + 1234.567 and -1234.567 (encoded internally as the 128-bit integers + 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567. + + ``decimal128(5, -3)`` can exactly represent the number 12345000 + (encoded internally as the 128-bit integer 12345), but neither + 123450000 nor 1234500. + + If you need a precision higher than 38 significant digits, consider + using ``decimal256``. Parameters ---------- precision : int + Must be between 1 and 38 scale : int Returns @@ -2130,11 +2159,22 @@ cpdef DataType decimal128(int precision, int scale=0): cpdef DataType decimal256(int precision, int scale=0): """ - Create decimal type with precision and scale and 256bit width. + Create decimal type with precision and scale and 256-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. 
The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + For most use cases, the maximum precision offered by ``decimal128`` + is sufficient, and it will result in a more compact and more efficient + encoding. ``decimal256`` is useful if you need a precision higher + than 38 significant digits. Parameters ---------- precision : int + Must be between 1 and 76 scale : int Returns @@ -2386,25 +2426,119 @@ def struct(fields): return pyarrow_wrap_data_type(struct_type) -def union(children_fields, mode, type_codes=None): +cdef _extract_union_params(child_fields, type_codes, + vector[shared_ptr[CField]]* c_fields, + vector[int8_t]* c_type_codes): + cdef: + Field child_field + + for child_field in child_fields: + c_fields[0].push_back(child_field.sp_field) + + if type_codes is not None: + if len(type_codes) != (c_fields.size()): + raise ValueError("type_codes should have the same length " + "as fields") + for code in type_codes: + c_type_codes[0].push_back(code) + else: + c_type_codes[0] = range(c_fields.size()) + + +def sparse_union(child_fields, type_codes=None): + """ + Create SparseUnionType from child fields. + + A sparse union is a nested type where each logical value is taken from + a single child. A buffer of 8-bit type ids indicates which child + a given logical value is to be taken from. + + In a sparse union, each child array should have the same length as the + union array, regardless of the actual number of union values that + refer to it. + + Parameters + ---------- + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + type_codes : list of integers, default None + + Returns + ------- + type : SparseUnionType + """ + cdef: + vector[shared_ptr[CField]] c_fields + vector[int8_t] c_type_codes + + _extract_union_params(child_fields, type_codes, + &c_fields, &c_type_codes) + + return pyarrow_wrap_data_type( + CMakeSparseUnionType(move(c_fields), move(c_type_codes))) + + +def dense_union(child_fields, type_codes=None): """ - Create UnionType from children fields. + Create DenseUnionType from child fields. - A union is defined by an ordered sequence of types; each slot in the union - can have a value chosen from these types. + A dense union is a nested type where each logical value is taken from + a single child, at a specific offset. A buffer of 8-bit type ids + indicates which child a given logical value is to be taken from, + and a buffer of 32-bit offsets indicates at which physical position + in the given child array the logical value is to be taken from. + + Unlike a sparse union, a dense union allows encoding only the child array + values which are actually referred to by the union array. This is + counterbalanced by the additional footprint of the offsets buffer, and + the additional indirection cost when looking up values. Parameters ---------- - fields : sequence of Field values + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. 
+ type_codes : list of integers, default None + + Returns + ------- + type : DenseUnionType + """ + cdef: + vector[shared_ptr[CField]] c_fields + vector[int8_t] c_type_codes + + _extract_union_params(child_fields, type_codes, + &c_fields, &c_type_codes) + + return pyarrow_wrap_data_type( + CMakeDenseUnionType(move(c_fields), move(c_type_codes))) + + +def union(child_fields, mode, type_codes=None): + """ + Create UnionType from child fields. + + A union is a nested type where each logical value is taken from a + single child. A buffer of 8-bit type ids indicates which child + a given logical value is to be taken from. + + Unions come in two flavors: sparse and dense + (see also `pyarrow.sparse_union` and `pyarrow.dense_union`). + + Parameters + ---------- + child_fields : sequence of Field values Each field must have a UTF8-encoded name, and these field names are part of the type metadata. mode : str - Either 'dense' or 'sparse'. + Must be 'sparse' or 'dense' type_codes : list of integers, default None Returns ------- - type : DataType + type : UnionType """ cdef: Field child_field @@ -2424,24 +2558,10 @@ def union(children_fields, mode, type_codes=None): else: raise ValueError("Invalid union mode {0!r}".format(mode)) - for child_field in children_fields: - c_fields.push_back(child_field.sp_field) - - if type_codes is not None: - if len(type_codes) != (c_fields.size()): - raise ValueError("type_codes should have the same length " - "as fields") - for code in type_codes: - c_type_codes.push_back(code) - else: - c_type_codes = range(c_fields.size()) - - if mode == UnionMode_SPARSE: - union_type = CMakeSparseUnionType(c_fields, c_type_codes) + if mode == _UnionMode_SPARSE: + return sparse_union(child_fields, type_codes) else: - union_type = CMakeDenseUnionType(c_fields, c_type_codes) - - return pyarrow_wrap_data_type(union_type) + return dense_union(child_fields, type_codes) cdef dict _type_aliases = { From 4de992c60ba433ad9b15ca1c41e6ec40bc542c2a Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Wed, 24 Feb 2021 14:38:25 -0500 Subject: [PATCH 28/54] ARROW-11738: [Rust][DataFusion] Fix Concat and Trim Functions This PR is a child of https://github.com/apache/arrow/pull/9243 It does a few things that are hard to separate: - fixes the behavior of `concat` and `trim` functions to be in line with the Postgres implementations - restructures some of the code base (mainly sorting and adding tests) to facilitate easier testing and implementation of the remainder of https://github.com/apache/arrow/pull/9243 @alamb @jorgecarleitao please review but merging will be dependent on https://github.com/apache/arrow/pull/9507 Closes #9551 from seddonm1/concat Authored-by: Mike Seddon Signed-off-by: Andrew Lamb --- rust/datafusion/README.md | 8 +- rust/datafusion/src/logical_plan/expr.rs | 12 +- rust/datafusion/src/logical_plan/mod.rs | 12 +- .../datafusion/src/physical_plan/functions.rs | 952 ++++++++++++++---- .../src/physical_plan/string_expressions.rs | 423 ++++++-- rust/datafusion/src/prelude.rs | 6 +- rust/datafusion/tests/sql.rs | 497 +++++---- 7 files changed, 1441 insertions(+), 469 deletions(-) diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index b4cb04321e7b1..5dcab04399e98 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -58,11 +58,17 @@ DataFusion includes a simple command-line interactive SQL utility. 
See the [CLI - [x] Common math functions - String functions - [x] bit_Length + - [x] btrim - [x] char_length - [x] character_length + - [x] concat + - [x] concat_ws - [x] length + - [x] ltrim - [x] octet_length - - [x] Concatenate + - [x] rtrim + - [x] substr + - [x] trim - Miscellaneous/Boolean functions - [x] nullif - Common date/time functions diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 245ca3aaaa895..6dadefea54810 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -875,8 +875,11 @@ unary_scalar_expr!(Log10, log10); // string functions unary_scalar_expr!(BitLength, bit_length); +unary_scalar_expr!(Btrim, btrim); unary_scalar_expr!(CharacterLength, character_length); unary_scalar_expr!(CharacterLength, length); +unary_scalar_expr!(Concat, concat); +unary_scalar_expr!(ConcatWithSeparator, concat_ws); unary_scalar_expr!(Lower, lower); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); @@ -886,17 +889,10 @@ unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); +unary_scalar_expr!(Substr, substr); unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Upper, upper); -/// returns the concatenation of string expressions -pub fn concat(args: Vec) -> Expr { - Expr::ScalarFunction { - fun: functions::BuiltinScalarFunction::Concat, - args, - } -} - /// returns an array of fixed size with each argument on it. pub fn array(args: Vec) -> Expr { Expr::ScalarFunction { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 0de0a032520bc..99c35fafd547c 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -33,12 +33,12 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, case, ceil, - character_length, col, combine_filters, concat, cos, count, count_distinct, - create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, - log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, - sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, - ExpressionVisitor, Literal, Recursion, + abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, btrim, case, ceil, + character_length, col, combine_filters, concat, concat_ws, cos, count, + count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, + length, lit, ln, log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, + rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, trim, + trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 51941188bb440..1c82d0fea45a4 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -86,78 +86,87 @@ pub type ReturnTypeFunction = /// Enum of all built-in scalar functions #[derive(Debug, Clone, PartialEq, Eq)] pub enum BuiltinScalarFunction { - /// sqrt - Sqrt, - /// sin - Sin, - /// cos - Cos, - /// tan - Tan, - /// asin - Asin, + // math functions + /// abs + Abs, /// acos 
Acos, + /// asin + Asin, /// atan Atan, + /// ceil + Ceil, + /// cos + Cos, /// exp Exp, + /// floor + Floor, /// log, also known as ln Log, - /// log2 - Log2, /// log10 Log10, - /// floor - Floor, - /// ceil - Ceil, + /// log2 + Log2, /// round Round, - /// trunc - Trunc, - /// abs - Abs, /// signum Signum, + /// sin + Sin, + /// sqrt + Sqrt, + /// tan + Tan, + /// trunc + Trunc, + + // string functions + /// construct an array from columns + Array, + /// bit_length + BitLength, + /// btrim + Btrim, + /// character_length + CharacterLength, /// concat Concat, + /// concat_ws + ConcatWithSeparator, + /// Date part + DatePart, + /// Date truncate + DateTrunc, /// lower Lower, - /// upper - Upper, - /// trim - Trim, /// trim left Ltrim, - /// trim right - Rtrim, - /// to_timestamp - ToTimestamp, - /// construct an array from columns - Array, - /// SQL NULLIF() - NullIf, - /// Date truncate - DateTrunc, - /// Date part - DatePart, /// MD5 MD5, + /// SQL NULLIF() + NullIf, + /// octet_length + OctetLength, + /// trim right + Rtrim, /// SHA224 SHA224, - /// SHA256, + /// SHA256 SHA256, /// SHA384 SHA384, - /// SHA512, + /// SHA512 SHA512, - /// bit_length - BitLength, - /// character_length - CharacterLength, - /// octet_length - OctetLength, + /// substr + Substr, + /// to_timestamp + ToTimestamp, + /// trim + Trim, + /// upper + Upper, } impl fmt::Display for BuiltinScalarFunction { @@ -171,44 +180,51 @@ impl FromStr for BuiltinScalarFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name { - "sqrt" => BuiltinScalarFunction::Sqrt, - "sin" => BuiltinScalarFunction::Sin, - "cos" => BuiltinScalarFunction::Cos, - "tan" => BuiltinScalarFunction::Tan, - "asin" => BuiltinScalarFunction::Asin, + // math functions + "abs" => BuiltinScalarFunction::Abs, "acos" => BuiltinScalarFunction::Acos, + "asin" => BuiltinScalarFunction::Asin, "atan" => BuiltinScalarFunction::Atan, + "ceil" => BuiltinScalarFunction::Ceil, + "cos" => BuiltinScalarFunction::Cos, "exp" => BuiltinScalarFunction::Exp, + "floor" => BuiltinScalarFunction::Floor, "log" => BuiltinScalarFunction::Log, - "log2" => BuiltinScalarFunction::Log2, "log10" => BuiltinScalarFunction::Log10, - "floor" => BuiltinScalarFunction::Floor, - "ceil" => BuiltinScalarFunction::Ceil, + "log2" => BuiltinScalarFunction::Log2, "round" => BuiltinScalarFunction::Round, - "truc" => BuiltinScalarFunction::Trunc, - "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, + "sin" => BuiltinScalarFunction::Sin, + "sqrt" => BuiltinScalarFunction::Sqrt, + "tan" => BuiltinScalarFunction::Tan, + "trunc" => BuiltinScalarFunction::Trunc, + + // string functions + "array" => BuiltinScalarFunction::Array, + "bit_length" => BuiltinScalarFunction::BitLength, + "btrim" => BuiltinScalarFunction::Btrim, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, "concat" => BuiltinScalarFunction::Concat, + "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, + "date_part" => BuiltinScalarFunction::DatePart, + "date_trunc" => BuiltinScalarFunction::DateTrunc, + "length" => BuiltinScalarFunction::CharacterLength, "lower" => BuiltinScalarFunction::Lower, - "trim" => BuiltinScalarFunction::Trim, "ltrim" => BuiltinScalarFunction::Ltrim, - "rtrim" => BuiltinScalarFunction::Rtrim, - "upper" => BuiltinScalarFunction::Upper, - "to_timestamp" => BuiltinScalarFunction::ToTimestamp, - "date_trunc" => BuiltinScalarFunction::DateTrunc, - "date_part" => 
BuiltinScalarFunction::DatePart, - "array" => BuiltinScalarFunction::Array, - "nullif" => BuiltinScalarFunction::NullIf, "md5" => BuiltinScalarFunction::MD5, + "nullif" => BuiltinScalarFunction::NullIf, + "octet_length" => BuiltinScalarFunction::OctetLength, + "rtrim" => BuiltinScalarFunction::Rtrim, "sha224" => BuiltinScalarFunction::SHA224, "sha256" => BuiltinScalarFunction::SHA256, "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, - "bit_length" => BuiltinScalarFunction::BitLength, - "octet_length" => BuiltinScalarFunction::OctetLength, - "length" => BuiltinScalarFunction::CharacterLength, - "char_length" => BuiltinScalarFunction::CharacterLength, - "character_length" => BuiltinScalarFunction::CharacterLength, + "substr" => BuiltinScalarFunction::Substr, + "to_timestamp" => BuiltinScalarFunction::ToTimestamp, + "trim" => BuiltinScalarFunction::Trim, + "upper" => BuiltinScalarFunction::Upper, + _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -242,80 +258,98 @@ pub fn return_type( // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { - BuiltinScalarFunction::Concat => Ok(DataType::Utf8), - BuiltinScalarFunction::Lower => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::LargeUtf8, - DataType::Utf8 => DataType::Utf8, + BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( + Box::new(Field::new("item", arg_types[0].clone(), true)), + arg_types.len() as i32, + )), + BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The upper function can only accept strings.".to_string(), + "The bit_length function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Ltrim => Ok(match arg_types[0] { + BuiltinScalarFunction::Btrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The ltrim function can only accept strings.".to_string(), + "The btrim function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Rtrim => Ok(match arg_types[0] { + BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The character_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Concat => Ok(DataType::Utf8), + BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8), + BuiltinScalarFunction::DatePart => Ok(DataType::Int32), + BuiltinScalarFunction::DateTrunc => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. 
return Err(DataFusionError::Internal( - "The rtrim function can only accept strings.".to_string(), + "The upper function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Trim => Ok(match arg_types[0] { + BuiltinScalarFunction::Ltrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The trim function can only accept strings.".to_string(), + "The ltrim function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Upper => Ok(match arg_types[0] { + BuiltinScalarFunction::MD5 => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The upper function can only accept strings.".to_string(), + "The md5 function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::ToTimestamp => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - BuiltinScalarFunction::DateTrunc => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - BuiltinScalarFunction::DatePart => Ok(DataType::Int32), - BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( - Box::new(Field::new("item", arg_types[0].clone(), true)), - arg_types.len() as i32, - )), BuiltinScalarFunction::NullIf => { // NULLIF has two args and they might get coerced, get a preview of this let coerced_types = data_types(arg_types, &signature(fun)); coerced_types.map(|typs| typs[0].clone()) } - BuiltinScalarFunction::MD5 => Ok(match arg_types[0] { + BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The octet_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Rtrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The md5 function can only accept strings.".to_string(), + "The rtrim function can only accept strings.".to_string(), )); } }), @@ -359,37 +393,57 @@ pub fn return_type( )); } }), - BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, + BuiltinScalarFunction::Substr => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The bit_length function can only accept strings.".to_string(), + "The substr function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, + BuiltinScalarFunction::ToTimestamp => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + BuiltinScalarFunction::Trim => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. 
return Err(DataFusionError::Internal( - "The character_length function can only accept strings.".to_string(), + "The trim function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, + BuiltinScalarFunction::Upper => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The octet_length function can only accept strings.".to_string(), + "The upper function can only accept strings.".to_string(), )); } }), - _ => Ok(DataType::Float64), + + BuiltinScalarFunction::Abs + | BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Log + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Round + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Trunc => Ok(DataType::Float64), } } @@ -401,37 +455,26 @@ pub fn create_physical_expr( input_schema: &Schema, ) -> Result> { let fun_expr: ScalarFunctionImplementation = Arc::new(match fun { - BuiltinScalarFunction::Sqrt => math_expressions::sqrt, - BuiltinScalarFunction::Sin => math_expressions::sin, - BuiltinScalarFunction::Cos => math_expressions::cos, - BuiltinScalarFunction::Tan => math_expressions::tan, - BuiltinScalarFunction::Asin => math_expressions::asin, + // math functions + BuiltinScalarFunction::Abs => math_expressions::abs, BuiltinScalarFunction::Acos => math_expressions::acos, + BuiltinScalarFunction::Asin => math_expressions::asin, BuiltinScalarFunction::Atan => math_expressions::atan, + BuiltinScalarFunction::Ceil => math_expressions::ceil, + BuiltinScalarFunction::Cos => math_expressions::cos, BuiltinScalarFunction::Exp => math_expressions::exp, + BuiltinScalarFunction::Floor => math_expressions::floor, BuiltinScalarFunction::Log => math_expressions::ln, - BuiltinScalarFunction::Log2 => math_expressions::log2, BuiltinScalarFunction::Log10 => math_expressions::log10, - BuiltinScalarFunction::Floor => math_expressions::floor, - BuiltinScalarFunction::Ceil => math_expressions::ceil, + BuiltinScalarFunction::Log2 => math_expressions::log2, BuiltinScalarFunction::Round => math_expressions::round, - BuiltinScalarFunction::Trunc => math_expressions::trunc, - BuiltinScalarFunction::Abs => math_expressions::abs, BuiltinScalarFunction::Signum => math_expressions::signum, - BuiltinScalarFunction::NullIf => nullif_func, - BuiltinScalarFunction::MD5 => crypto_expressions::md5, - BuiltinScalarFunction::SHA224 => crypto_expressions::sha224, - BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, - BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, - BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, - BuiltinScalarFunction::Concat => string_expressions::concatenate, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Trim => string_expressions::trim, - BuiltinScalarFunction::Ltrim => string_expressions::ltrim, - BuiltinScalarFunction::Rtrim => string_expressions::rtrim, - BuiltinScalarFunction::Upper => string_expressions::upper, - BuiltinScalarFunction::ToTimestamp 
=> datetime_expressions::to_timestamp, - BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Sin => math_expressions::sin, + BuiltinScalarFunction::Sqrt => math_expressions::sqrt, + BuiltinScalarFunction::Tan => math_expressions::tan, + BuiltinScalarFunction::Trunc => math_expressions::trunc, + + // string functions BuiltinScalarFunction::Array => array_expressions::array, BuiltinScalarFunction::BitLength => |args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), @@ -445,6 +488,18 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::Btrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function btrim", + other, + ))), + }, BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { DataType::Utf8 => make_scalar_function( string_expressions::character_length::, @@ -457,6 +512,27 @@ pub fn create_physical_expr( other, ))), }, + BuiltinScalarFunction::Concat => string_expressions::concat, + BuiltinScalarFunction::ConcatWithSeparator => { + |args| make_scalar_function(string_expressions::concat_ws)(args) + } + BuiltinScalarFunction::DatePart => datetime_expressions::date_part, + BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ltrim", + other, + ))), + }, + BuiltinScalarFunction::MD5 => crypto_expressions::md5, + BuiltinScalarFunction::NullIf => nullif_func, BuiltinScalarFunction::OctetLength => |args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { @@ -469,7 +545,48 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, - BuiltinScalarFunction::DatePart => datetime_expressions::date_part, + BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function rtrim", + other, + ))), + }, + BuiltinScalarFunction::SHA224 => crypto_expressions::sha224, + BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, + BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, + BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, + BuiltinScalarFunction::Substr => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::substr::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::substr::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function substr", + other, + ))), + }, + BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, + BuiltinScalarFunction::Trim => |args| match args[0].data_type() { + DataType::Utf8 => { + 
make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function trim", + other, + ))), + }, + BuiltinScalarFunction::Upper => string_expressions::upper, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -493,22 +610,31 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { // for now, the list is small, as we do not have many built-in functions. match fun { - BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), - BuiltinScalarFunction::Upper - | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::BitLength + BuiltinScalarFunction::Array => { + Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) + } + BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { + Signature::Variadic(vec![DataType::Utf8]) + } + BuiltinScalarFunction::BitLength | BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Trim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim + | BuiltinScalarFunction::Lower | BuiltinScalarFunction::MD5 + | BuiltinScalarFunction::OctetLength | BuiltinScalarFunction::SHA224 | BuiltinScalarFunction::SHA256 | BuiltinScalarFunction::SHA384 - | BuiltinScalarFunction::SHA512 => { + | BuiltinScalarFunction::SHA512 + | BuiltinScalarFunction::Trim + | BuiltinScalarFunction::Upper => { Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) } + BuiltinScalarFunction::Btrim + | BuiltinScalarFunction::Ltrim + | BuiltinScalarFunction::Rtrim => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8]), + Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), + ]), BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::DateTrunc => Signature::Exact(vec![ DataType::Utf8, @@ -534,9 +660,12 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { DataType::Timestamp(TimeUnit::Nanosecond, None), ]), ]), - BuiltinScalarFunction::Array => { - Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) - } + BuiltinScalarFunction::Substr => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + Signature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64, DataType::Int64]), + ]), BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } @@ -753,6 +882,106 @@ mod tests { #[test] fn test_functions() -> Result<()> { + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(40)), + i32, + Int32, + Int32Array + ); + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(40)), + i32, + Int32, + Int32Array + ); + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + 
Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some("\n trim \n".to_string())))], + Ok(Some("\n trim \n")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("xyxtrimyyx".to_string()))), + lit(ScalarValue::Utf8(Some("xyz".to_string()))), + ], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("\nxyxtrimyyx\n".to_string()))), + lit(ScalarValue::Utf8(Some("xyz\n".to_string()))), + ], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("xyz".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("xyxtrimyyx".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); test_function!( CharacterLength, &[lit(ScalarValue::Utf8(Some("chars".to_string())))], @@ -785,6 +1014,88 @@ mod tests { Int32, Int32Array ); + test_function!( + Concat, + &[ + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aabbcc")), + &str, + Utf8, + StringArray + ); + test_function!( + Concat, + &[ + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8, + StringArray + ); + test_function!( + Concat, + &[lit(ScalarValue::Utf8(None))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aa|bb|cc")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aa|cc")), + &str, + Utf8, + StringArray + ); test_function!( Exp, &[lit(ScalarValue::Int32(Some(1)))], @@ -825,42 +1136,331 @@ mod tests { Float64, Float64Array ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("\n trim ".to_string())))], + Ok(Some("\n trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, 
+ &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(30))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Int64(Some(-1))), + ], + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("és")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + 
Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim \n".to_string())))], + Ok(Some(" trim \n")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); Ok(()) } - fn test_concat(value: ScalarValue, expected: &str) -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - // concat(value, value) - let expr = create_physical_expr( - &BuiltinScalarFunction::Concat, - &[lit(value.clone()), lit(value)], - &schema, - )?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Utf8); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - - // downcast works - let result = result.as_any().downcast_ref::().unwrap(); - - // value is correct - assert_eq!(result.value(0).to_string(), expected); - - Ok(()) - } - - #[test] - fn test_concat_utf8() -> Result<()> { - test_concat(ScalarValue::Utf8(Some("aa".to_string())), "aaaa") - } - #[test] fn test_concat_error() -> Result<()> { let result = return_type(&BuiltinScalarFunction::Concat, &[]); diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 81d2c67eec63b..7ab0f9f215be8 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +// Some of these functions reference the Postgres documentation +// or implementation to ensure compatibility and are subject to +// the Postgres license. + //! String expressions use std::sync::Arc; @@ -25,7 +29,7 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, GenericStringArray, PrimitiveArray, StringArray, + Array, ArrayRef, GenericStringArray, Int64Array, PrimitiveArray, StringArray, StringOffsetSizeTrait, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, @@ -119,6 +123,71 @@ where } } +macro_rules! 
downcast_vec { + ($ARGS:expr, $ARRAY_TYPE:ident) => {{ + $ARGS + .iter() + .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { + Some(array) => Ok(array), + _ => Err(DataFusionError::Internal("failed to downcast".to_string())), + }) + }}; +} + +/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_start_matches(' ').trim_end_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_start_matches(&chars[..]) + .trim_end_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "btrim was called with {} arguments. It requires at most 2.", + other + ))), + } +} + /// Returns number of characters in the string. /// character_length('josé') = 4 pub fn character_length(args: &[ArrayRef]) -> Result @@ -140,16 +209,15 @@ where Ok(Arc::new(result) as ArrayRef) } -/// concatenate string columns together. -pub fn concatenate(args: &[ColumnarValue]) -> Result { - // downcast all arguments to strings - //let args = downcast_vec!(args, StringArray).collect::>>()?; +/// Concatenates the text representations of all the arguments. NULL arguments are ignored. +/// concat('abcde', 2, NULL, 22) = 'abcde222' +pub fn concat(args: &[ColumnarValue]) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal( - "Concatenate was called with 0 arguments. It requires at least one." - .to_string(), - )); + return Err(DataFusionError::Internal(format!( + "concat was called with {} arguments. It requires at least 1.", + args.len() + ))); } // first, decide whether to return a scalar or a vector. 
@@ -158,42 +226,30 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { _ => None, }); if let Some(size) = return_array.next() { - let iter = (0..size).map(|index| { - let mut owned_string: String = "".to_owned(); - - // if any is null, the result is null - let mut is_null = false; - for arg in args { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(value) = maybe_value { - owned_string.push_str(value); - } else { - is_null = true; - break; // short-circuit as we already know the result + let result = (0..size) + .map(|index| { + let mut owned_string: String = "".to_owned(); + for arg in args { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(value) = maybe_value { + owned_string.push_str(value); + } } - } - ColumnarValue::Array(v) => { - if v.is_null(index) { - is_null = true; - break; // short-circuit as we already know the result - } else { - let v = v.as_any().downcast_ref::().unwrap(); - owned_string.push_str(&v.value(index)); + ColumnarValue::Array(v) => { + if v.is_valid(index) { + let v = v.as_any().downcast_ref::().unwrap(); + owned_string.push_str(&v.value(index)); + } } + _ => unreachable!(), } - _ => unreachable!(), } - } - if is_null { - None - } else { Some(owned_string) - } - }); - let array = iter.collect::(); + }) + .collect::(); - Ok(ColumnarValue::Array(Arc::new(array))) + Ok(ColumnarValue::Array(Arc::new(result))) } else { // short avenue with only scalars let initial = Some("".to_string()); @@ -203,9 +259,7 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => { inner.push_str(v); } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - acc = None; - } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} _ => unreachable!(""), }; }; @@ -215,27 +269,284 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { } } -/// lower +/// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. +/// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' +pub fn concat_ws(args: &[ArrayRef]) -> Result { + // downcast all arguments to strings + let args = downcast_vec!(args, StringArray).collect::>>()?; + + // do not accept 0 or 1 arguments. + if args.len() < 2 { + return Err(DataFusionError::Internal(format!( + "concat_ws was called with {} arguments. It requires at least 2.", + args.len() + ))); + } + + // first map is the iterator, second is for the `Option<_>` + let result = args[0] + .iter() + .enumerate() + .map(|(index, x)| { + x.map(|sep: &str| { + let mut owned_string: String = "".to_owned(); + for arg_index in 1..args.len() { + let arg = &args[arg_index]; + if !arg.is_null(index) { + owned_string.push_str(&arg.value(index)); + // if not last push separator + if arg_index != args.len() - 1 { + owned_string.push_str(&sep); + } + } + } + owned_string + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Removes the longest string containing only characters in characters (a space by default) from the start of string. 
+/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_start_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_start_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "ltrim was called with {} arguments. It requires at most 2.", + other + ))), + } +} + +/// Converts the string to all lower case. +/// lower('TOM') = 'tom' pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |x| x.to_ascii_lowercase(), "lower") } -/// upper -pub fn upper(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.to_ascii_uppercase(), "upper") -} +/// Removes the longest string containing only characters in characters (a space by default) from the end of string. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_end_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); -/// trim -pub fn trim(args: &[ColumnarValue]) -> Result { - handle(args, |x: &str| x.trim(), "trim") + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_end_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "rtrim was called with {} arguments. It requires at most two.", + other + ))), + } } -/// ltrim -pub fn ltrim(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.trim_start(), "ltrim") +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) 
+/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +pub fn substr(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; + + let start_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast start to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if start_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let start: i64 = start_array.value(i); + + if start <= 0 { + x.to_string() + } else { + let graphemes = x.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + if graphemes.len() < start_pos { + "".to_string() + } else { + graphemes[start_pos..].concat() + } + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; + + let start_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast start to Int64Array".to_string(), + ) + })?; + + let count_array: &Int64Array = args[2] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast count to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if start_array.is_null(i) || count_array.is_null(i) { + Ok(None) + } else { + x.map(|x: &str| { + let start: i64 = start_array.value(i); + let count = count_array.value(i); + + if count < 0 { + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )) + } else if start <= 0 { + Ok(x.to_string()) + } else { + let graphemes = x.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + let count_usize = count as usize; + if graphemes.len() < start_pos { + Ok("".to_string()) + } else if graphemes.len() < start_pos + count_usize { + Ok(graphemes[start_pos..].concat()) + } else { + Ok(graphemes[start_pos..start_pos + count_usize] + .concat()) + } + } + }) + .transpose() + } + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "substr was called with {} arguments. It requires 2 or 3.", + other + ))), + } } -/// rtrim -pub fn rtrim(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.trim_end(), "rtrim") +/// Converts the string to all upper case. 
+/// upper('tom') = 'TOM' +pub fn upper(args: &[ColumnarValue]) -> Result { + handle(args, |x| x.to_ascii_uppercase(), "upper") } diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 26e03c7453e98..d60f0c32a4dfc 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,8 +28,8 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, bit_length, character_length, col, concat, count, create_udf, in_list, - length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, sha224, sha256, - sha384, sha512, sum, trim, upper, JoinType, Partitioning, + array, avg, bit_length, btrim, character_length, col, concat, concat_ws, count, + create_udf, in_list, length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, + sha224, sha256, sha384, sha512, substr, sum, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 7a0666635a2e4..587fe299bd8ec 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1639,7 +1639,7 @@ async fn query_concat() -> Result<()> { let expected = vec![ vec!["-hi-0"], vec!["a-hi-1"], - vec!["NULL"], + vec!["aa-hi-"], vec!["aaa-hi-3"], ]; assert_eq!(expected, actual); @@ -1886,7 +1886,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation let sql = "SELECT concat(d1, '-foo') FROM test"; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]]; + let expected = vec![vec!["one-foo"], vec!["-foo"], vec!["three-foo"]]; assert_eq!(expected, actual); // aggregation @@ -2023,170 +2023,290 @@ async fn csv_group_by_date() -> Result<()> { Ok(()) } -#[tokio::test] -async fn string_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - char_length('tom') AS char_length - ,char_length(NULL) AS char_length_null - ,character_length('tom') AS character_length - ,character_length(NULL) AS character_length_null - ,lower('TOM') AS lower - ,lower(NULL) AS lower_null - ,upper('tom') AS upper - ,upper(NULL) AS upper_null - ,trim(' tom ') AS trim - ,trim(NULL) AS trim_null - ,ltrim(' tom ') AS trim_left - ,rtrim(' tom ') AS trim_right - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "3", "NULL", "3", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL", "tom ", - " tom", - ]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn boolean_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - true AS val_1, - false AS val_2 - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec!["true", "false"]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn interval_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - (interval '1') as interval_1, - (interval '1 second') as interval_2, - (interval '500 milliseconds') as interval_3, - (interval '5 second') as interval_4, - (interval '1 minute') as interval_5, - (interval '0.5 minute') as interval_6, - (interval '.5 minute') as interval_7, - (interval '5 minute') as interval_8, - (interval '5 minute 1 second') as interval_9, - (interval '1 hour') as interval_10, - (interval '5 hour') as interval_11, - (interval '1 day') as interval_12, - (interval '1 day 1') as interval_13, - 
(interval '0.5') as interval_14, - (interval '0.5 day 1') as interval_15, - (interval '0.49 day') as interval_16, - (interval '0.499 day') as interval_17, - (interval '0.4999 day') as interval_18, - (interval '0.49999 day') as interval_19, - (interval '0.49999999999 day') as interval_20, - (interval '5 day') as interval_21, - (interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds') as interval_22, - (interval '0.5 month') as interval_23, - (interval '1 month') as interval_24, - (interval '5 month') as interval_25, - (interval '13 month') as interval_26, - (interval '0.5 year') as interval_27, - (interval '1 year') as interval_28, - (interval '2 year') as interval_29 - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs", - "0 years 0 mons 0 days 0 hours 0 mins 5.00 secs", - "0 years 0 mons 0 days 0 hours 1 mins 0.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs", - "0 years 0 mons 0 days 0 hours 5 mins 0.00 secs", - "0 years 0 mons 0 days 0 hours 5 mins 1.00 secs", - "0 years 0 mons 0 days 1 hours 0 mins 0.00 secs", - "0 years 0 mons 0 days 5 hours 0 mins 0.00 secs", - "0 years 0 mons 1 days 0 hours 0 mins 0.00 secs", - "0 years 0 mons 1 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs", - "0 years 0 mons 0 days 12 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 11 hours 45 mins 36.00 secs", - "0 years 0 mons 0 days 11 hours 58 mins 33.596 secs", - "0 years 0 mons 0 days 11 hours 59 mins 51.364 secs", - "0 years 0 mons 0 days 11 hours 59 mins 59.136 secs", - "0 years 0 mons 0 days 12 hours 0 mins 0.00 secs", - "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs", - "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs", - "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs", - "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs", - "0 years 5 mons 0 days 0 hours 0 mins 0.00 secs", - "1 years 1 mons 0 days 0 hours 0 mins 0.00 secs", - "0 years 6 mons 0 days 0 hours 0 mins 0.00 secs", - "1 years 0 mons 0 days 0 hours 0 mins 0.00 secs", - "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs", - ]]; - assert_eq!(expected, actual); +macro_rules! 
test_expression { + ($SQL:expr, $EXPECTED:expr) => { + let mut ctx = ExecutionContext::new(); + let sql = format!("SELECT {}", $SQL); + let actual = execute(&mut ctx, sql.as_str()).await; + assert_eq!($EXPECTED, actual[0][0]); + }; +} + +#[tokio::test] +async fn test_string_expressions() -> Result<()> { + test_expression!("bit_length('')", "0"); + test_expression!("bit_length('chars')", "40"); + test_expression!("bit_length('josé')", "40"); + test_expression!("bit_length(NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); + test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); + test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); + test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); + test_expression!("btrim(NULL, 'xyz')", "NULL"); + test_expression!("char_length('')", "0"); + test_expression!("char_length('chars')", "5"); + test_expression!("char_length(NULL)", "NULL"); + test_expression!("character_length('')", "0"); + test_expression!("character_length('chars')", "5"); + test_expression!("character_length('josé')", "4"); + test_expression!("character_length(NULL)", "NULL"); + test_expression!("concat('a','b','c')", "abc"); + test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); + test_expression!("concat(NULL)", ""); + test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); + test_expression!("concat_ws('|','a','b','c')", "a|b|c"); + test_expression!("concat_ws('|',NULL)", ""); + test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); + test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); + test_expression!("ltrim(' zzzytest ')", "zzzytest "); + test_expression!("ltrim('zzzytest', 'xyz')", "test"); + test_expression!("ltrim(NULL, 'xyz')", "NULL"); + test_expression!("lower('')", ""); + test_expression!("lower('TOM')", "tom"); + test_expression!("lower(NULL)", "NULL"); + test_expression!("octet_length('')", "0"); + test_expression!("octet_length('chars')", "5"); + test_expression!("octet_length('josé')", "5"); + test_expression!("octet_length(NULL)", "NULL"); + test_expression!("rtrim(' testxxzx ')", " testxxzx"); + test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); + test_expression!("rtrim('testxxzx', 'xyz')", "test"); + test_expression!("rtrim(NULL, 'xyz')", "NULL"); + test_expression!("substr('alphabet', -3)", "alphabet"); + test_expression!("substr('alphabet', 0)", "alphabet"); + test_expression!("substr('alphabet', 1)", "alphabet"); + test_expression!("substr('alphabet', 2)", "lphabet"); + test_expression!("substr('alphabet', 3)", "phabet"); + test_expression!("substr('alphabet', 30)", ""); + test_expression!("substr('alphabet', CAST(NULL AS int))", "NULL"); + test_expression!("substr('alphabet', 3, 2)", "ph"); + test_expression!("substr('alphabet', 3, 20)", "phabet"); + test_expression!("substr('alphabet', CAST(NULL AS int), 20)", "NULL"); + test_expression!("substr('alphabet', 3, CAST(NULL AS int))", "NULL"); + test_expression!("trim(' tom ')", "tom"); + test_expression!("trim(' tom')", "tom"); + test_expression!("trim('')", ""); + test_expression!("trim('tom ')", "tom"); + test_expression!("upper('')", ""); + test_expression!("upper('tom')", "TOM"); + test_expression!("upper(NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn test_boolean_expressions() -> Result<()> { + test_expression!("true", "true"); + test_expression!("false", "false"); + Ok(()) +} + +#[tokio::test] +async fn test_interval_expressions() -> Result<()> { + 
test_expression!( + "interval '1'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '1 second'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '500 milliseconds'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + "interval '5 second'", + "0 years 0 mons 0 days 0 hours 0 mins 5.00 secs" + ); + test_expression!( + "interval '0.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '5 minute'", + "0 years 0 mons 0 days 0 hours 5 mins 0.00 secs" + ); + test_expression!( + "interval '5 minute 1 second'", + "0 years 0 mons 0 days 0 hours 5 mins 1.00 secs" + ); + test_expression!( + "interval '1 hour'", + "0 years 0 mons 0 days 1 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 hour'", + "0 years 0 mons 0 days 5 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day'", + "0 years 0 mons 1 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day 1'", + "0 years 0 mons 1 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.5'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + "interval '0.5 day 1'", + "0 years 0 mons 0 days 12 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.49 day'", + "0 years 0 mons 0 days 11 hours 45 mins 36.00 secs" + ); + test_expression!( + "interval '0.499 day'", + "0 years 0 mons 0 days 11 hours 58 mins 33.596 secs" + ); + test_expression!( + "interval '0.4999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 51.364 secs" + ); + test_expression!( + "interval '0.49999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 59.136 secs" + ); + test_expression!( + "interval '0.49999999999 day'", + "0 years 0 mons 0 days 12 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day'", + "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", + "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs" + ); + test_expression!( + "interval '0.5 month'", + "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 month'", + "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 month'", + "0 years 5 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '13 month'", + "1 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '0.5 year'", + "0 years 6 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 year'", + "1 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '2 year'", + "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); Ok(()) } #[tokio::test] -async fn crypto_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - md5('tom') AS md5_tom, - md5('') AS md5_empty_str, - md5(null) AS md5_null, - sha224('tom') AS sha224_tom, - sha224('') AS sha224_empty_str, - sha224(null) AS sha224_null, - sha256('tom') AS sha256_tom, - sha256('') AS sha256_empty_str, - sha384('tom') AS sha348_tom, - sha384('') AS sha384_empty_str, - sha512('tom') AS sha512_tom, - sha512('') AS sha512_empty_str - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "34b7da764b21d298ef307d04d8152dc5", - "d41d8cd98f00b204e9800998ecf8427e", 
- "NULL", - "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d", - "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f", - "NULL", - "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343", - "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", - "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e", - "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e" - ]]; - assert_eq!(expected, actual); +async fn test_crypto_expressions() -> Result<()> { + test_expression!("md5('tom')", "34b7da764b21d298ef307d04d8152dc5"); + test_expression!("md5('')", "d41d8cd98f00b204e9800998ecf8427e"); + test_expression!("md5(NULL)", "NULL"); + test_expression!( + "sha224('tom')", + "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" + ); + test_expression!( + "sha224('')", + "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" + ); + test_expression!("sha224(NULL)", "NULL"); + test_expression!( + "sha256('tom')", + "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" + ); + test_expression!( + "sha256('')", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + test_expression!("sha256(NULL)", "NULL"); + test_expression!("sha384('tom')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); + test_expression!("sha384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); + test_expression!("sha384(NULL)", "NULL"); + test_expression!("sha512('tom')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); + test_expression!("sha512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); + test_expression!("sha512(NULL)", "NULL"); + Ok(()) +} +#[tokio::test] +async fn test_extract_date_part() -> Result<()> { + test_expression!("date_part('hour', CAST('2020-01-01' AS DATE))", "0"); + test_expression!("EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE))", "0"); + test_expression!( + "EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "12" + ); + test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000"); + test_expression!( + "EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "2020" + ); Ok(()) } #[tokio::test] -async fn extract_date_part() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - date_part('hour', CAST('2020-01-01' AS DATE)) AS hr1, - EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE)) AS hr2, - EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS hr3, - date_part('YEAR', CAST('2000-01-01' AS DATE)) AS year1, - EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS year2 - "; - - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec!["0", "0", "12", "2000", "2020"]]; - assert_eq!(expected, actual); +async fn test_in_list_scalar() -> Result<()> { + test_expression!("'a' IN ('a','b')", "true"); + test_expression!("'c' IN ('a','b')", "false"); + test_expression!("'c' NOT IN ('a','b')", "true"); + test_expression!("'a' NOT 
IN ('a','b')", "false"); + test_expression!("NULL IN ('a','b')", "NULL"); + test_expression!("NULL NOT IN ('a','b')", "NULL"); + test_expression!("'a' IN ('a','b',NULL)", "true"); + test_expression!("'c' IN ('a','b',NULL)", "NULL"); + test_expression!("'a' NOT IN ('a','b',NULL)", "false"); + test_expression!("'c' NOT IN ('a','b',NULL)", "NULL"); + test_expression!("0 IN (0,1,2)", "true"); + test_expression!("3 IN (0,1,2)", "false"); + test_expression!("3 NOT IN (0,1,2)", "true"); + test_expression!("0 NOT IN (0,1,2)", "false"); + test_expression!("NULL IN (0,1,2)", "NULL"); + test_expression!("NULL NOT IN (0,1,2)", "NULL"); + test_expression!("0 IN (0,1,2,NULL)", "true"); + test_expression!("3 IN (0,1,2,NULL)", "NULL"); + test_expression!("0 NOT IN (0,1,2,NULL)", "false"); + test_expression!("3 NOT IN (0,1,2,NULL)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); + test_expression!("NULL IN (0.0,0.1,0.2)", "NULL"); + test_expression!("NULL NOT IN (0.0,0.1,0.2)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("'1' IN ('a','b',1)", "true"); + test_expression!("'2' IN ('a','b',1)", "false"); + test_expression!("'2' NOT IN ('a','b',1)", "true"); + test_expression!("'1' NOT IN ('a','b',1)", "false"); + test_expression!("NULL IN ('a','b',1)", "NULL"); + test_expression!("NULL NOT IN ('a','b',1)", "NULL"); + test_expression!("'1' IN ('a','b',NULL,1)", "true"); + test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); + test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); + test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); Ok(()) } @@ -2215,67 +2335,6 @@ async fn in_list_array() -> Result<()> { Ok(()) } -#[tokio::test] -async fn in_list_scalar() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - 'a' IN ('a','b') AS utf8_in_true - ,'c' IN ('a','b') AS utf8_in_false - ,'c' NOT IN ('a','b') AS utf8_not_in_true - ,'a' NOT IN ('a','b') AS utf8_not_in_false - ,NULL IN ('a','b') AS utf8_in_null - ,NULL NOT IN ('a','b') AS utf8_not_in_null - ,'a' IN ('a','b',NULL) AS utf8_in_null_true - ,'c' IN ('a','b',NULL) AS utf8_in_null_null - ,'a' NOT IN ('a','b',NULL) AS utf8_not_in_null_false - ,'c' NOT IN ('a','b',NULL) AS utf8_not_in_null_null - - ,0 IN (0,1,2) AS int64_in_true - ,3 IN (0,1,2) AS int64_in_false - ,3 NOT IN (0,1,2) AS int64_not_in_true - ,0 NOT IN (0,1,2) AS int64_not_in_false - ,NULL IN (0,1,2) AS int64_in_null - ,NULL NOT IN (0,1,2) AS int64_not_in_null - ,0 IN (0,1,2,NULL) AS int64_in_null_true - ,3 IN (0,1,2,NULL) AS int64_in_null_null - ,0 NOT IN (0,1,2,NULL) AS int64_not_in_null_false - ,3 NOT IN (0,1,2,NULL) AS int64_not_in_null_null - - ,0.0 IN (0.0,0.1,0.2) AS float64_in_true - ,0.3 IN (0.0,0.1,0.2) AS float64_in_false - ,0.3 NOT IN (0.0,0.1,0.2) AS float64_not_in_true - ,0.0 NOT IN (0.0,0.1,0.2) AS float64_not_in_false - ,NULL IN (0.0,0.1,0.2) AS float64_in_null - ,NULL NOT IN (0.0,0.1,0.2) AS float64_not_in_null - ,0.0 IN (0.0,0.1,0.2,NULL) AS float64_in_null_true - ,0.3 IN (0.0,0.1,0.2,NULL) AS float64_in_null_null - ,0.0 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_false - ,0.3 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_null - - ,'1' IN ('a','b',1) 
AS utf8_cast_in_true - ,'2' IN ('a','b',1) AS utf8_cast_in_false - ,'2' NOT IN ('a','b',1) AS utf8_cast_not_in_true - ,'1' NOT IN ('a','b',1) AS utf8_cast_not_in_false - ,NULL IN ('a','b',1) AS utf8_cast_in_null - ,NULL NOT IN ('a','b',1) AS utf8_cast_not_in_null - ,'1' IN ('a','b',NULL,1) AS utf8_cast_in_null_true - ,'2' IN ('a','b',NULL,1) AS utf8_cast_in_null_null - ,'1' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_false - ,'2' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_null - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", "false", - "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", - "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", - "NULL", "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", - "true", "NULL", "false", "NULL", - ]]; - assert_eq!(expected, actual); - Ok(()) -} - // TODO Tests to prove correct implementation of INNER JOIN's with qualified names. // https://issues.apache.org/jira/projects/ARROW/issues/ARROW-11432. #[tokio::test] From 81e9417eb68171e03a304097ae86e1fd83307130 Mon Sep 17 00:00:00 2001 From: Diana Clarke Date: Wed, 24 Feb 2021 15:27:21 -0500 Subject: [PATCH 29/54] ARROW-11575: [Developer][Archery] Expose execution time in benchmark results See: https://issues.apache.org/jira/browse/ARROW-11575 Google Benchmark reports both cpu time & real time in each benchmark observation. For example: ``` {'cpu_time': 9718937.499999996, 'items_per_second': 26972495.707478322, 'iterations': 64, 'name': 'TakeStringRandomIndicesWithNulls/262144/0', 'null_percent': 0.0, 'real_time': 10297947.859726265, 'repetition_index': 2, 'repetitions': 0, 'run_name': 'TakeStringRandomIndicesWithNulls/262144/0', 'run_type': 'iteration', 'size': 262144.0, 'threads': 1, 'time_unit': 'ns'}, ``` Currently, Archery doesn't expose the execution time in its json results though. For example: ``` { "name": "TakeStringRandomIndicesWithNulls/262144/2", "unit": "items_per_second", "less_is_better": false, "values": [ 20900887.666890558, 21737551.30809738, 21872425.314689018 ] } ``` This pull request updates Archery to expose the real time as well. For example: ``` { "name": "TakeStringRandomIndicesWithNulls/262144/2", "unit": "items_per_second", "less_is_better": false, "values": [ 20900887.666890558, 21737551.30809738, 21872425.314689018 ], "time_unit": "ns", "times": [ 34939132.454438195, 44459594.18080747, 46606865.63566384 ] } ``` Motivation: I am persisting these results and would also like to store the execution time to debug slow benchmarks. 
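As a rough sketch of how the persisted results could be consumed once the execution times are stored (this helper is hypothetical, not part of Archery, and assumes the results are saved as a JSON list of entries shaped like the example above):

```
import json
import statistics


def slowest_benchmarks(path, top=10):
    # Assumed layout: a JSON list of benchmark entries carrying the
    # "values", "time_unit" and "times" fields shown above.
    with open(path) as f:
        benchmarks = json.load(f)
    # Rank by median real time so the slow benchmarks surface first.
    ranked = sorted(
        benchmarks,
        key=lambda b: statistics.median(b["times"]),
        reverse=True,
    )
    for b in ranked[:top]:
        print(b["name"], statistics.median(b["times"]), b["time_unit"])
```

Ranking on the raw times rather than the throughput values is exactly the slow-benchmark debugging described above.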
Closes #9458 from dianaclarke/ARROW-11575 Authored-by: Diana Clarke Signed-off-by: Benjamin Kietzman --- .../archery/{utils => benchmark}/codec.py | 2 + dev/archery/archery/benchmark/core.py | 5 +- dev/archery/archery/benchmark/google.py | 4 +- dev/archery/archery/cli.py | 2 +- dev/archery/tests/test_benchmarks.py | 153 ++++++++++++++++-- 5 files changed, 150 insertions(+), 16 deletions(-) rename dev/archery/archery/{utils => benchmark}/codec.py (97%) diff --git a/dev/archery/archery/utils/codec.py b/dev/archery/archery/benchmark/codec.py similarity index 97% rename from dev/archery/archery/utils/codec.py rename to dev/archery/archery/benchmark/codec.py index 86bcbe1bfeede..359dea9b9f3d1 100644 --- a/dev/archery/archery/utils/codec.py +++ b/dev/archery/archery/benchmark/codec.py @@ -48,6 +48,8 @@ def encode(b): "unit": b.unit, "less_is_better": b.less_is_better, "values": b.values, + "time_unit": b.time_unit, + "times": b.times, } @staticmethod diff --git a/dev/archery/archery/benchmark/core.py b/dev/archery/archery/benchmark/core.py index 8246105758b00..5a92271a35391 100644 --- a/dev/archery/archery/benchmark/core.py +++ b/dev/archery/archery/benchmark/core.py @@ -27,11 +27,14 @@ def median(values): class Benchmark: - def __init__(self, name, unit, less_is_better, values, counters=None): + def __init__(self, name, unit, less_is_better, values, time_unit, + times, counters=None): self.name = name self.unit = unit self.less_is_better = less_is_better self.values = sorted(values) + self.time_unit = time_unit + self.times = sorted(times) self.median = median(self.values) self.counters = counters or {} diff --git a/dev/archery/archery/benchmark/google.py b/dev/archery/archery/benchmark/google.py index f5958b17864e2..c1644dcbd9cea 100644 --- a/dev/archery/archery/benchmark/google.py +++ b/dev/archery/archery/benchmark/google.py @@ -152,11 +152,13 @@ def __init__(self, name, runs): _, runs = partition(lambda b: b.is_aggregate, runs) self.runs = sorted(runs, key=lambda b: b.value) unit = self.runs[0].unit + time_unit = self.runs[0].time_unit less_is_better = not unit.endswith("per_second") values = [b.value for b in self.runs] + times = [b.real_time for b in self.runs] # Slight kludge to extract the UserCounters for each benchmark self.counters = self.runs[0].counters - super().__init__(name, unit, less_is_better, values) + super().__init__(name, unit, less_is_better, values, time_unit, times) def __repr__(self): return "GoogleBenchmark[name={},runs={}]".format(self.names, self.runs) diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py index 564a22a8987cd..74e2373821c01 100644 --- a/dev/archery/archery/cli.py +++ b/dev/archery/archery/cli.py @@ -25,10 +25,10 @@ import pathlib import sys +from .benchmark.codec import JsonEncoder from .benchmark.compare import RunnerComparator, DEFAULT_THRESHOLD from .benchmark.runner import BenchmarkRunner, CppBenchmarkRunner from .lang.cpp import CppCMakeDefinition, CppConfiguration -from .utils.codec import JsonEncoder from .utils.lint import linter, python_numpydoc, LintValidationException from .utils.logger import logger, ctx as log_ctx from .utils.source import ArrowSources, InvalidArrowSource diff --git a/dev/archery/tests/test_benchmarks.py b/dev/archery/tests/test_benchmarks.py index 0566805842a06..b763ea3c86fba 100644 --- a/dev/archery/tests/test_benchmarks.py +++ b/dev/archery/tests/test_benchmarks.py @@ -17,32 +17,36 @@ import json +from archery.benchmark.codec import JsonEncoder from archery.benchmark.core import Benchmark, median from 
archery.benchmark.compare import BenchmarkComparator from archery.benchmark.google import ( GoogleBenchmark, GoogleBenchmarkObservation ) -from archery.utils.codec import JsonEncoder def test_benchmark_comparator(): unit = "micros" assert not BenchmarkComparator( - Benchmark("contender", unit, True, [10]), - Benchmark("baseline", unit, True, [20])).regression + Benchmark("contender", unit, True, [10], unit, [1]), + Benchmark("baseline", unit, True, [20], unit, [1]), + ).regression assert BenchmarkComparator( - Benchmark("contender", unit, False, [10]), - Benchmark("baseline", unit, False, [20])).regression + Benchmark("contender", unit, False, [10], unit, [1]), + Benchmark("baseline", unit, False, [20], unit, [1]), + ).regression assert BenchmarkComparator( - Benchmark("contender", unit, True, [20]), - Benchmark("baseline", unit, True, [10])).regression + Benchmark("contender", unit, True, [20], unit, [1]), + Benchmark("baseline", unit, True, [10], unit, [1]), + ).regression assert not BenchmarkComparator( - Benchmark("contender", unit, False, [20]), - Benchmark("baseline", unit, False, [10])).regression + Benchmark("contender", unit, False, [20], unit, [1]), + Benchmark("baseline", unit, False, [10], unit, [1]), + ).regression def test_benchmark_median(): @@ -65,6 +69,123 @@ def assert_benchmark(name, google_result, archery_result): assert json.loads(result) == archery_result +def test_items_per_second(): + name = "ArrayArrayKernel/32768/0" + google_result = { + "cpu_time": 116292.58886653671, + "items_per_second": 281772039.9844759, + "iterations": 5964, + "name": name, + "null_percent": 0.0, + "real_time": 119811.77313729875, + "repetition_index": 0, + "repetitions": 0, + "run_name": name, + "run_type": "iteration", + "size": 32768.0, + "threads": 1, + "time_unit": "ns", + } + archery_result = { + "name": name, + "unit": "items_per_second", + "less_is_better": False, + "values": [281772039.9844759], + "time_unit": "ns", + "times": [119811.77313729875], + } + assert "items_per_second" in google_result + assert "bytes_per_second" not in google_result + assert_benchmark(name, google_result, archery_result) + + +def test_bytes_per_second(): + name = "BufferOutputStreamLargeWrites/real_time" + google_result = { + "bytes_per_second": 1890209037.3405428, + "cpu_time": 17018127.659574457, + "iterations": 47, + "name": name, + "real_time": 17458386.53190963, + "repetition_index": 1, + "repetitions": 0, + "run_name": name, + "run_type": "iteration", + "threads": 1, + "time_unit": "ns", + } + archery_result = { + "name": name, + "unit": "bytes_per_second", + "less_is_better": False, + "values": [1890209037.3405428], + "time_unit": "ns", + "times": [17458386.53190963], + } + assert "items_per_second" not in google_result + assert "bytes_per_second" in google_result + assert_benchmark(name, google_result, archery_result) + + +def test_both_items_and_bytes_per_second(): + name = "ArrayArrayKernel/32768/0" + google_result = { + "bytes_per_second": 281772039.9844759, + "cpu_time": 116292.58886653671, + "items_per_second": 281772039.9844759, + "iterations": 5964, + "name": name, + "null_percent": 0.0, + "real_time": 119811.77313729875, + "repetition_index": 0, + "repetitions": 0, + "run_name": name, + "run_type": "iteration", + "size": 32768.0, + "threads": 1, + "time_unit": "ns", + } + # Note that bytes_per_second trumps items_per_second + archery_result = { + "name": name, + "unit": "bytes_per_second", + "less_is_better": False, + "values": [281772039.9844759], + "time_unit": "ns", + "times": 
[119811.77313729875], + } + assert "items_per_second" in google_result + assert "bytes_per_second" in google_result + assert_benchmark(name, google_result, archery_result) + + +def test_neither_items_nor_bytes_per_second(): + name = "AllocateDeallocate/size:1048576/real_time" + google_result = { + "cpu_time": 1778.6004847419827, + "iterations": 352765, + "name": name, + "real_time": 1835.3137357788837, + "repetition_index": 0, + "repetitions": 0, + "run_name": name, + "run_type": "iteration", + "threads": 1, + "time_unit": "ns", + } + archery_result = { + "name": name, + "unit": "ns", + "less_is_better": True, + "values": [1835.3137357788837], + "time_unit": "ns", + "times": [1835.3137357788837], + } + assert "items_per_second" not in google_result + assert "bytes_per_second" not in google_result + assert_benchmark(name, google_result, archery_result) + + def test_prefer_real_time(): name = "AllocateDeallocate/size:1048576/real_time" google_result = { @@ -74,7 +195,7 @@ def test_prefer_real_time(): "real_time": 1835.3137357788837, "repetition_index": 0, "repetitions": 0, - "run_name": "AllocateDeallocate/size:1048576/real_time", + "run_name": name, "run_type": "iteration", "threads": 1, "time_unit": "ns", @@ -84,6 +205,8 @@ def test_prefer_real_time(): "unit": "ns", "less_is_better": True, "values": [1835.3137357788837], + "time_unit": "ns", + "times": [1835.3137357788837], } assert name.endswith("/real_time") assert_benchmark(name, google_result, archery_result) @@ -98,7 +221,7 @@ def test_prefer_cpu_time(): "real_time": 1835.3137357788837, "repetition_index": 0, "repetitions": 0, - "run_name": "AllocateDeallocate/size:1048576", + "run_name": name, "run_type": "iteration", "threads": 1, "time_unit": "ns", @@ -108,6 +231,8 @@ def test_prefer_cpu_time(): "unit": "ns", "less_is_better": True, "values": [1778.6004847419827], + "time_unit": "ns", + "times": [1835.3137357788837], } assert not name.endswith("/real_time") assert_benchmark(name, google_result, archery_result) @@ -122,7 +247,7 @@ def test_omits_aggregates(): "name": "AllocateDeallocate/size:1048576/real_time_mean", "real_time": 1849.3869337041162, "repetitions": 0, - "run_name": "AllocateDeallocate/size:1048576/real_time", + "run_name": name, "run_type": "aggregate", "threads": 1, "time_unit": "ns", @@ -134,7 +259,7 @@ def test_omits_aggregates(): "real_time": 1835.3137357788837, "repetition_index": 0, "repetitions": 0, - "run_name": "AllocateDeallocate/size:1048576/real_time", + "run_name": name, "run_type": "iteration", "threads": 1, "time_unit": "ns", @@ -144,6 +269,8 @@ def test_omits_aggregates(): "unit": "ns", "less_is_better": True, "values": [1835.3137357788837], + "time_unit": "ns", + "times": [1835.3137357788837], } assert google_aggregate["run_type"] == "aggregate" assert google_result["run_type"] == "iteration" From 6d703c4c7b15be630af48d5e9ef61628751674b2 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 25 Feb 2021 06:48:05 +0900 Subject: [PATCH 30/54] ARROW-11768: [CI][C++] Make s390x job required We can consider big-endian support in the C++ implementation stable (except for Parquet where it isn't supported). 
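For context on the configuration change below: a Travis CI `allow_failures` entry exempts any job that shares its listed attributes, so the previous `arch: s390x` entry exempted every s390x job, while matching by name keeps only the listed jobs optional and makes the remaining s390x jobs (including the C++ one) required. A minimal sketch of just the relevant fragment, not the full `.travis.yml`:

```
jobs:
  allow_failures:
    # An entry exempts any job sharing its attributes; "arch: s390x"
    # would exempt all s390x jobs. Matching by name narrows the
    # exemption to these two, so the C++ s390x job must now pass.
    - name: "Go on s390x"
    - name: "Java on s390x"
```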
Closes #9563 from pitrou/ARROW-11768-s390x-required Authored-by: Antoine Pitrou Signed-off-by: Sutou Kouhei --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 57646246c4a54..2cf70cca982ff 100644 --- a/.travis.yml +++ b/.travis.yml @@ -125,7 +125,8 @@ jobs: JDK: 11 allow_failures: - - arch: s390x + - name: "Go on s390x" + - name: "Java on s390x" before_install: - eval "$(python ci/detect-changes.py)" From 4beb514d071c9beec69b8917b5265e77ade22fb3 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 24 Feb 2021 17:12:11 -0500 Subject: [PATCH 31/54] ARROW-11767: [C++] Scalar::Hash may segfault Closes #9562 from bkietz/11767-Scalarhash-may-segfault-f Authored-by: Benjamin Kietzman Signed-off-by: Benjamin Kietzman --- cpp/src/arrow/scalar.cc | 14 +++++++++++--- cpp/src/arrow/scalar_test.cc | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index ee4d0ecad8fea..399eac675f4df 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -90,9 +90,13 @@ struct ScalarHashImpl { return Status::OK(); } + Status Visit(const DictionaryScalar& s) { + AccumulateHashFrom(*s.value.index); + return Status::OK(); + } + // TODO(bkietz) implement less wimpy hashing when these have ValueType Status Visit(const UnionScalar& s) { return Status::OK(); } - Status Visit(const DictionaryScalar& s) { return Status::OK(); } Status Visit(const ExtensionScalar& s) { return Status::OK(); } template @@ -127,14 +131,18 @@ struct ScalarHashImpl { return Status::OK(); } - explicit ScalarHashImpl(const Scalar& scalar) { AccumulateHashFrom(scalar); } + explicit ScalarHashImpl(const Scalar& scalar) : hash_(scalar.type->Hash()) { + if (scalar.is_valid) { + AccumulateHashFrom(scalar); + } + } void AccumulateHashFrom(const Scalar& scalar) { DCHECK_OK(StdHash(scalar.type->fingerprint())); DCHECK_OK(VisitScalarInline(scalar, this)); } - size_t hash_ = 0; + size_t hash_; }; size_t Scalar::Hash::hash(const Scalar& scalar) { return ScalarHashImpl(scalar).hash_; } diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 16c2f92d13b30..d99debb2ba945 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -111,10 +111,12 @@ TYPED_TEST(TestNumericScalar, Hashing) { using ScalarType = typename TypeTraits::ScalarType; std::unordered_set, Scalar::Hash, Scalar::PtrsEqual> set; + set.emplace(std::make_shared()); for (T i = 0; i < 10; ++i) { set.emplace(std::make_shared(i)); } + ASSERT_FALSE(set.emplace(std::make_shared()).second); for (T i = 0; i < 10; ++i) { ASSERT_FALSE(set.emplace(std::make_shared(i)).second); } @@ -406,6 +408,23 @@ TEST(TestBinaryScalar, Basics) { ASSERT_FALSE(two->Equals(BinaryScalar(Buffer::FromString("else")))); } +TEST(TestBinaryScalar, Hashing) { + auto FromInt = [](int i) { + return std::make_shared(Buffer::FromString(std::to_string(i))); + }; + + std::unordered_set, Scalar::Hash, Scalar::PtrsEqual> set; + set.emplace(std::make_shared()); + for (int i = 0; i < 10; ++i) { + set.emplace(FromInt(i)); + } + + ASSERT_FALSE(set.emplace(std::make_shared()).second); + for (int i = 0; i < 10; ++i) { + ASSERT_FALSE(set.emplace(FromInt(i)).second); + } +} + TEST(TestStringScalar, MakeScalar) { auto three = MakeScalar("three"); ASSERT_EQ(StringScalar("three"), *three); From 02addad336ba19a654f9c857ede546331be7b631 Mon Sep 17 00:00:00 2001 From: Diana Clarke Date: Thu, 25 Feb 2021 10:02:51 +0900 Subject: [PATCH 32/54] 
ARROW-11771: [Developer][Archery] Move benchmark tests (so CI runs them) I recently added additional tests to: https://github.com/apache/arrow/blob/master/dev/archery/tests/test_benchmarks.py But just noticed that they aren't actually being executed by CI (because they aren't in the expected location): ``` - name: Archery Unittests working-directory: dev/archery run: pytest -v archery ``` Here you can see them running now. archery_tests Closes #9564 from dianaclarke/ARROW-11771 Authored-by: Diana Clarke Signed-off-by: Sutou Kouhei --- dev/archery/{ => archery}/tests/test_benchmarks.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dev/archery/{ => archery}/tests/test_benchmarks.py (100%) diff --git a/dev/archery/tests/test_benchmarks.py b/dev/archery/archery/tests/test_benchmarks.py similarity index 100% rename from dev/archery/tests/test_benchmarks.py rename to dev/archery/archery/tests/test_benchmarks.py From b5ac048c75cc55f4039d279f554920be3112d7cd Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Thu, 25 Feb 2021 07:01:47 +0200 Subject: [PATCH 33/54] ARROW-11627: [Rust] Make allocator be a generic over type T The background and rational for this is described [here](https://github.com/jorgecarleitao/arrow2/tree/proposal); the idea is that this is groundwork to make our buffers typed, so that we can start introducing strong typing in the crate. This change is backward incompatible: 1. Our allocator is now a generic over type `T: NativeType`, which implies that we can now allocate certain types. 2. The allocator moved from `memory` to a new module `alloc` (inspired after `std::alloc`). Necessary steps to migrate existing code: 1. `use arrow::memory` -> `use arrow::alloc` 2. `memory::allocate_aligned(...)` -> `alloc::allocate_aligned::(...)` Note how `NativeType` contains `to_le_bytes`; we will use this method for IPC, where we need to serialize buffers with a specific endianess. This is ground work to enable multiple endianesses support Closes #9495 from jorgecarleitao/alloc_t Authored-by: Jorge C. Leitao Signed-off-by: Neville Dipale --- rust/arrow/src/alloc/alignment.rs | 119 ++++++++++++ rust/arrow/src/alloc/mod.rs | 136 ++++++++++++++ rust/arrow/src/alloc/types.rs | 175 ++++++++++++++++++ rust/arrow/src/array/array_list.rs | 6 +- rust/arrow/src/array/raw_pointer.rs | 7 +- rust/arrow/src/buffer/immutable.rs | 18 +- rust/arrow/src/buffer/mutable.rs | 16 +- rust/arrow/src/bytes.rs | 4 +- rust/arrow/src/lib.rs | 2 +- rust/arrow/src/memory.rs | 277 ---------------------------- 10 files changed, 456 insertions(+), 304 deletions(-) create mode 100644 rust/arrow/src/alloc/alignment.rs create mode 100644 rust/arrow/src/alloc/mod.rs create mode 100644 rust/arrow/src/alloc/types.rs delete mode 100644 rust/arrow/src/memory.rs diff --git a/rust/arrow/src/alloc/alignment.rs b/rust/arrow/src/alloc/alignment.rs new file mode 100644 index 0000000000000..dbf4602f83af9 --- /dev/null +++ b/rust/arrow/src/alloc/alignment.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: Below code is written for spatial/temporal prefetcher optimizations. Memory allocation +// should align well with usage pattern of cache access and block sizes on layers of storage levels from +// registers to non-volatile memory. These alignments are all cache aware alignments incorporated +// from [cuneiform](https://crates.io/crates/cuneiform) crate. This approach mimicks Intel TBB's +// cache_aligned_allocator which exploits cache locality and minimizes prefetch signals +// resulting in less round trip time between the layers of storage. +// For further info: https://software.intel.com/en-us/node/506094 + +// 32-bit architecture and things other than netburst microarchitecture are using 64 bytes. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "x86")] +pub const ALIGNMENT: usize = 1 << 6; + +// Intel x86_64: +// L2D streamer from L1: +// Loads data or instructions from memory to the second-level cache. To use the streamer, +// organize the data or instructions in blocks of 128 bytes, aligned on 128 bytes. +// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "x86_64")] +pub const ALIGNMENT: usize = 1 << 7; + +// 24Kc: +// Data Line Size +// - https://s3-eu-west-1.amazonaws.com/downloads-mips/documents/MD00346-2B-24K-DTS-04.00.pdf +// - https://gitlab.e.foundation/e/devices/samsung/n7100/stable_android_kernel_samsung_smdk4412/commit/2dbac10263b2f3c561de68b4c369bc679352ccee +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "mips")] +pub const ALIGNMENT: usize = 1 << 5; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "mips64")] +pub const ALIGNMENT: usize = 1 << 5; + +// Defaults for powerpc +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "powerpc")] +pub const ALIGNMENT: usize = 1 << 5; + +// Defaults for the ppc 64 +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "powerpc64")] +pub const ALIGNMENT: usize = 1 << 6; + +// e.g.: sifive +// - https://github.com/torvalds/linux/blob/master/Documentation/devicetree/bindings/riscv/sifive-l2-cache.txt#L41 +// in general all of them are the same. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "riscv")] +pub const ALIGNMENT: usize = 1 << 6; + +// This size is same across all hardware for this architecture. +// - https://docs.huihoo.com/doxygen/linux/kernel/3.7/arch_2s390_2include_2asm_2cache_8h.html +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "s390x")] +pub const ALIGNMENT: usize = 1 << 8; + +// This size is same across all hardware for this architecture. 
+// - https://docs.huihoo.com/doxygen/linux/kernel/3.7/arch_2sparc_2include_2asm_2cache_8h.html#a9400cc2ba37e33279bdbc510a6311fb4 +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "sparc")] +pub const ALIGNMENT: usize = 1 << 5; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "sparc64")] +pub const ALIGNMENT: usize = 1 << 6; + +// On ARM cache line sizes are fixed. both v6 and v7. +// Need to add board specific or platform specific things later. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "thumbv6")] +pub const ALIGNMENT: usize = 1 << 5; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "thumbv7")] +pub const ALIGNMENT: usize = 1 << 5; + +// Operating Systems cache size determines this. +// Currently no way to determine this without runtime inference. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "wasm32")] +pub const ALIGNMENT: usize = 1 << 6; + +// Same as v6 and v7. +// List goes like that: +// Cortex A, M, R, ARM v7, v7-M, Krait and NeoverseN uses this size. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "arm")] +pub const ALIGNMENT: usize = 1 << 5; + +// Combined from 4 sectors. Volta says 128. +// Prevent chunk optimizations better to go to the default size. +// If you have smaller data with less padded functionality then use 32 with force option. +// - https://devtalk.nvidia.com/default/topic/803600/variable-cache-line-width-/ +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "nvptx")] +pub const ALIGNMENT: usize = 1 << 7; +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "nvptx64")] +pub const ALIGNMENT: usize = 1 << 7; + +// This size is same across all hardware for this architecture. +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "aarch64")] +pub const ALIGNMENT: usize = 1 << 6; diff --git a/rust/arrow/src/alloc/mod.rs b/rust/arrow/src/alloc/mod.rs new file mode 100644 index 0000000000000..a225d32dd82d4 --- /dev/null +++ b/rust/arrow/src/alloc/mod.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines memory-related functions, such as allocate/deallocate/reallocate memory +//! regions, cache and allocation alignments. 
+ +use std::mem::size_of; +use std::ptr::NonNull; +use std::{ + alloc::{handle_alloc_error, Layout}, + sync::atomic::AtomicIsize, +}; + +mod alignment; +mod types; + +pub use alignment::ALIGNMENT; +pub use types::NativeType; + +// If this number is not zero after all objects have been `drop`, there is a memory leak +pub static mut ALLOCATIONS: AtomicIsize = AtomicIsize::new(0); + +#[inline] +unsafe fn null_pointer() -> NonNull { + NonNull::new_unchecked(ALIGNMENT as *mut T) +} + +/// Allocates a cache-aligned memory region of `size` bytes with uninitialized values. +/// This is more performant than using [allocate_aligned_zeroed] when all bytes will have +/// an unknown or non-zero value and is semantically similar to `malloc`. +pub fn allocate_aligned(size: usize) -> NonNull { + unsafe { + if size == 0 { + null_pointer() + } else { + let size = size * size_of::(); + ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); + + let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); + let raw_ptr = std::alloc::alloc(layout) as *mut T; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + } +} + +/// Allocates a cache-aligned memory region of `size` bytes with `0` on all of them. +/// This is more performant than using [allocate_aligned] and setting all bytes to zero +/// and is semantically similar to `calloc`. +pub fn allocate_aligned_zeroed(size: usize) -> NonNull { + unsafe { + if size == 0 { + null_pointer() + } else { + let size = size * size_of::(); + ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); + + let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); + let raw_ptr = std::alloc::alloc_zeroed(layout) as *mut T; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + } +} + +/// # Safety +/// +/// This function is unsafe because undefined behavior can result if the caller does not ensure all +/// of the following: +/// +/// * ptr must denote a block of memory currently allocated via this allocator, +/// +/// * size must be the same size that was used to allocate that block of memory, +pub unsafe fn free_aligned(ptr: NonNull, size: usize) { + if ptr != null_pointer() { + let size = size * size_of::(); + ALLOCATIONS.fetch_sub(size as isize, std::sync::atomic::Ordering::SeqCst); + std::alloc::dealloc( + ptr.as_ptr() as *mut u8, + Layout::from_size_align_unchecked(size, ALIGNMENT), + ); + } +} + +/// # Safety +/// +/// This function is unsafe because undefined behavior can result if the caller does not ensure all +/// of the following: +/// +/// * ptr must be currently allocated via this allocator, +/// +/// * new_size must be greater than zero. +/// +/// * new_size, when rounded up to the nearest multiple of [ALIGNMENT], must not overflow (i.e., +/// the rounded value must be less than usize::MAX). 
+pub unsafe fn reallocate( + ptr: NonNull, + old_size: usize, + new_size: usize, +) -> NonNull { + let old_size = old_size * size_of::(); + let new_size = new_size * size_of::(); + if ptr == null_pointer() { + return allocate_aligned(new_size); + } + + if new_size == 0 { + free_aligned(ptr, old_size); + return null_pointer(); + } + + ALLOCATIONS.fetch_add( + new_size as isize - old_size as isize, + std::sync::atomic::Ordering::SeqCst, + ); + let raw_ptr = std::alloc::realloc( + ptr.as_ptr() as *mut u8, + Layout::from_size_align_unchecked(old_size, ALIGNMENT), + new_size, + ) as *mut T; + NonNull::new(raw_ptr).unwrap_or_else(|| { + handle_alloc_error(Layout::from_size_align_unchecked(new_size, ALIGNMENT)) + }) +} diff --git a/rust/arrow/src/alloc/types.rs b/rust/arrow/src/alloc/types.rs new file mode 100644 index 0000000000000..0e177da7db8d5 --- /dev/null +++ b/rust/arrow/src/alloc/types.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::datatypes::DataType; + +/// A type that Rust's custom allocator knows how to allocate and deallocate. +/// This is implemented for all Arrow's physical types whose in-memory representation +/// matches Rust's physical types. Consider this trait sealed. +/// # Safety +/// Do not implement this trait. +pub unsafe trait NativeType: + Sized + Copy + std::fmt::Debug + std::fmt::Display + PartialEq + Default + Sized + 'static +{ + type Bytes: AsRef<[u8]>; + + /// Whether a DataType is a valid type for this physical representation. + fn is_valid(data_type: &DataType) -> bool; + + /// How this type represents itself as bytes in little endianess. + /// This is used for IPC, where data is communicated with a specific endianess. 
+ fn to_le_bytes(&self) -> Self::Bytes; +} + +unsafe impl NativeType for u8 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::UInt8 + } +} + +unsafe impl NativeType for u16 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::UInt16 + } +} + +unsafe impl NativeType for u32 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::UInt32 + } +} + +unsafe impl NativeType for u64 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::UInt64 + } +} + +unsafe impl NativeType for i8 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Int8 + } +} + +unsafe impl NativeType for i16 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Int16 + } +} + +unsafe impl NativeType for i32 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Int32 | DataType::Date32 | DataType::Time32(_) + ) + } +} + +unsafe impl NativeType for i64 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + ) + } +} + +unsafe impl NativeType for f32 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Float32 + } +} + +unsafe impl NativeType for f64 { + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn is_valid(data_type: &DataType) -> bool { + data_type == &DataType::Float64 + } +} diff --git a/rust/arrow/src/array/array_list.rs b/rust/arrow/src/array/array_list.rs index 8458836bfd6cc..f2076b3e86dfe 100644 --- a/rust/arrow/src/array/array_list.rs +++ b/rust/arrow/src/array/array_list.rs @@ -378,12 +378,12 @@ impl fmt::Debug for FixedSizeListArray { #[cfg(test)] mod tests { use crate::{ + alloc, array::ArrayData, array::Int32Array, buffer::Buffer, datatypes::Field, datatypes::{Int32Type, ToByteSlice}, - memory, util::bit_util, }; @@ -993,7 +993,7 @@ mod tests { #[test] #[should_panic(expected = "memory is not aligned")] fn test_primitive_array_alignment() { - let ptr = memory::allocate_aligned(8); + let ptr = alloc::allocate_aligned::(8); let buf = unsafe { Buffer::from_raw_parts(ptr, 8, 8) }; let 
buf2 = buf.slice(1); let array_data = ArrayData::builder(DataType::Int32).add_buffer(buf2).build(); @@ -1003,7 +1003,7 @@ mod tests { #[test] #[should_panic(expected = "memory is not aligned")] fn test_list_array_alignment() { - let ptr = memory::allocate_aligned(8); + let ptr = alloc::allocate_aligned::(8); let buf = unsafe { Buffer::from_raw_parts(ptr, 8, 8) }; let buf2 = buf.slice(1); diff --git a/rust/arrow/src/array/raw_pointer.rs b/rust/arrow/src/array/raw_pointer.rs index 897dc5b591c38..185e1cbe98a7e 100644 --- a/rust/arrow/src/array/raw_pointer.rs +++ b/rust/arrow/src/array/raw_pointer.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::memory; use std::ptr::NonNull; /// This struct is highly `unsafe` and offers the possibility to self-reference a [arrow::Buffer] from [arrow::array::ArrayData]. @@ -36,7 +35,11 @@ impl RawPtrBox { /// * `ptr` is not aligned to a slice of type `T`. This is guaranteed if it was built from a slice of type `T`. pub(super) unsafe fn new(ptr: *const u8) -> Self { let ptr = NonNull::new(ptr as *mut u8).expect("Pointer cannot be null"); - assert!(memory::is_ptr_aligned::(ptr), "memory is not aligned"); + assert_eq!( + ptr.as_ptr().align_offset(std::mem::align_of::()), + 0, + "memory is not aligned" + ); Self { ptr: ptr.cast() } } diff --git a/rust/arrow/src/buffer/immutable.rs b/rust/arrow/src/buffer/immutable.rs index e96bc003c8b5e..c09e4ddc48a1e 100644 --- a/rust/arrow/src/buffer/immutable.rs +++ b/rust/arrow/src/buffer/immutable.rs @@ -21,9 +21,7 @@ use std::ptr::NonNull; use std::sync::Arc; use std::{convert::AsRef, usize}; -use crate::memory; use crate::util::bit_chunk_iterator::BitChunks; -use crate::util::bit_util; use crate::{ bytes::{Bytes, Deallocation}, datatypes::ArrowNativeType, @@ -56,19 +54,11 @@ impl Buffer { /// Initializes a [Buffer] from a slice of items. pub fn from_slice_ref>(items: &T) -> Self { - // allocate aligned memory buffer let slice = items.as_ref(); - let len = slice.len() * std::mem::size_of::(); - let capacity = bit_util::round_upto_multiple_of_64(len); - let buffer = memory::allocate_aligned(capacity); - unsafe { - memory::memcpy( - buffer, - NonNull::new_unchecked(slice.as_ptr() as *mut u8), - len, - ); - Buffer::build_with_arguments(buffer, len, Deallocation::Native(capacity)) - } + let len = slice.len(); + let mut buffer = MutableBuffer::with_capacity(len); + buffer.extend_from_slice(slice); + buffer.into() } /// Creates a buffer from an existing memory region (must already be byte-aligned), this diff --git a/rust/arrow/src/buffer/mutable.rs b/rust/arrow/src/buffer/mutable.rs index 9f0238f9d99be..ddc0501f466f4 100644 --- a/rust/arrow/src/buffer/mutable.rs +++ b/rust/arrow/src/buffer/mutable.rs @@ -1,9 +1,9 @@ use std::ptr::NonNull; use crate::{ + alloc, bytes::{Bytes, Deallocation}, datatypes::{ArrowNativeType, ToByteSlice}, - memory, util::bit_util, }; @@ -53,8 +53,14 @@ impl MutableBuffer { /// Allocate a new [MutableBuffer] with initial capacity to be at least `capacity`. #[inline] pub fn new(capacity: usize) -> Self { + Self::with_capacity(capacity) + } + + /// Allocate a new [MutableBuffer] with initial capacity to be at least `capacity`. 
+ #[inline] + pub fn with_capacity(capacity: usize) -> Self { let capacity = bit_util::round_upto_multiple_of_64(capacity); - let ptr = memory::allocate_aligned(capacity); + let ptr = alloc::allocate_aligned(capacity); Self { data: ptr, len: 0, @@ -75,7 +81,7 @@ impl MutableBuffer { /// ``` pub fn from_len_zeroed(len: usize) -> Self { let new_capacity = bit_util::round_upto_multiple_of_64(len); - let ptr = memory::allocate_aligned_zeroed(new_capacity); + let ptr = alloc::allocate_aligned_zeroed(new_capacity); Self { data: ptr, len, @@ -324,7 +330,7 @@ unsafe fn reallocate( ) -> (NonNull, usize) { let new_capacity = bit_util::round_upto_multiple_of_64(new_capacity); let new_capacity = std::cmp::max(new_capacity, old_capacity * 2); - let ptr = memory::reallocate(ptr, old_capacity, new_capacity); + let ptr = alloc::reallocate(ptr, old_capacity, new_capacity); (ptr, new_capacity) } @@ -460,7 +466,7 @@ impl std::ops::DerefMut for MutableBuffer { impl Drop for MutableBuffer { fn drop(&mut self) { - unsafe { memory::free_aligned(self.data, self.capacity) }; + unsafe { alloc::free_aligned(self.data, self.capacity) }; } } diff --git a/rust/arrow/src/bytes.rs b/rust/arrow/src/bytes.rs index 323654954f802..38fa4439b42d4 100644 --- a/rust/arrow/src/bytes.rs +++ b/rust/arrow/src/bytes.rs @@ -24,7 +24,7 @@ use std::ptr::NonNull; use std::sync::Arc; use std::{fmt::Debug, fmt::Formatter}; -use crate::{ffi, memory}; +use crate::{alloc, ffi}; /// Mode of deallocating memory regions pub enum Deallocation { @@ -126,7 +126,7 @@ impl Drop for Bytes { fn drop(&mut self) { match &self.deallocation { Deallocation::Native(capacity) => { - unsafe { memory::free_aligned(self.ptr, *capacity) }; + unsafe { alloc::free_aligned::(self.ptr, *capacity) }; } // foreign interface knows how to deallocate itself. Deallocation::Foreign(_) => (), diff --git a/rust/arrow/src/lib.rs b/rust/arrow/src/lib.rs index c082d6136e24e..9c2ca2723ce73 100644 --- a/rust/arrow/src/lib.rs +++ b/rust/arrow/src/lib.rs @@ -135,6 +135,7 @@ // introduced to ignore lint errors when upgrading from 2020-04-22 to 2020-11-14 #![allow(clippy::float_equality_without_abs, clippy::type_complexity)] +mod alloc; mod arch; pub mod array; pub mod bitmap; @@ -147,7 +148,6 @@ pub mod error; pub mod ffi; pub mod ipc; pub mod json; -pub mod memory; pub mod record_batch; pub mod temporal_conversions; pub mod tensor; diff --git a/rust/arrow/src/memory.rs b/rust/arrow/src/memory.rs deleted file mode 100644 index 0ea8845decce8..0000000000000 --- a/rust/arrow/src/memory.rs +++ /dev/null @@ -1,277 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines memory-related functions, such as allocate/deallocate/reallocate memory -//! regions, cache and allocation alignments. 
- -use std::mem::align_of; -use std::ptr::NonNull; -use std::{ - alloc::{handle_alloc_error, Layout}, - sync::atomic::AtomicIsize, -}; - -// NOTE: Below code is written for spatial/temporal prefetcher optimizations. Memory allocation -// should align well with usage pattern of cache access and block sizes on layers of storage levels from -// registers to non-volatile memory. These alignments are all cache aware alignments incorporated -// from [cuneiform](https://crates.io/crates/cuneiform) crate. This approach mimicks Intel TBB's -// cache_aligned_allocator which exploits cache locality and minimizes prefetch signals -// resulting in less round trip time between the layers of storage. -// For further info: https://software.intel.com/en-us/node/506094 - -// 32-bit architecture and things other than netburst microarchitecture are using 64 bytes. -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "x86")] -pub const ALIGNMENT: usize = 1 << 6; - -// Intel x86_64: -// L2D streamer from L1: -// Loads data or instructions from memory to the second-level cache. To use the streamer, -// organize the data or instructions in blocks of 128 bytes, aligned on 128 bytes. -// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "x86_64")] -pub const ALIGNMENT: usize = 1 << 7; - -// 24Kc: -// Data Line Size -// - https://s3-eu-west-1.amazonaws.com/downloads-mips/documents/MD00346-2B-24K-DTS-04.00.pdf -// - https://gitlab.e.foundation/e/devices/samsung/n7100/stable_android_kernel_samsung_smdk4412/commit/2dbac10263b2f3c561de68b4c369bc679352ccee -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "mips")] -pub const ALIGNMENT: usize = 1 << 5; -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "mips64")] -pub const ALIGNMENT: usize = 1 << 5; - -// Defaults for powerpc -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "powerpc")] -pub const ALIGNMENT: usize = 1 << 5; - -// Defaults for the ppc 64 -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "powerpc64")] -pub const ALIGNMENT: usize = 1 << 6; - -// e.g.: sifive -// - https://github.com/torvalds/linux/blob/master/Documentation/devicetree/bindings/riscv/sifive-l2-cache.txt#L41 -// in general all of them are the same. -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "riscv")] -pub const ALIGNMENT: usize = 1 << 6; - -// This size is same across all hardware for this architecture. -// - https://docs.huihoo.com/doxygen/linux/kernel/3.7/arch_2s390_2include_2asm_2cache_8h.html -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "s390x")] -pub const ALIGNMENT: usize = 1 << 8; - -// This size is same across all hardware for this architecture. -// - https://docs.huihoo.com/doxygen/linux/kernel/3.7/arch_2sparc_2include_2asm_2cache_8h.html#a9400cc2ba37e33279bdbc510a6311fb4 -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "sparc")] -pub const ALIGNMENT: usize = 1 << 5; -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "sparc64")] -pub const ALIGNMENT: usize = 1 << 6; - -// On ARM cache line sizes are fixed. both v6 and v7. -// Need to add board specific or platform specific things later. 
-/// Cache and allocation multiple alignment size -#[cfg(target_arch = "thumbv6")] -pub const ALIGNMENT: usize = 1 << 5; -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "thumbv7")] -pub const ALIGNMENT: usize = 1 << 5; - -// Operating Systems cache size determines this. -// Currently no way to determine this without runtime inference. -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "wasm32")] -pub const ALIGNMENT: usize = FALLBACK_ALIGNMENT; - -// Same as v6 and v7. -// List goes like that: -// Cortex A, M, R, ARM v7, v7-M, Krait and NeoverseN uses this size. -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "arm")] -pub const ALIGNMENT: usize = 1 << 5; - -// Combined from 4 sectors. Volta says 128. -// Prevent chunk optimizations better to go to the default size. -// If you have smaller data with less padded functionality then use 32 with force option. -// - https://devtalk.nvidia.com/default/topic/803600/variable-cache-line-width-/ -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "nvptx")] -pub const ALIGNMENT: usize = 1 << 7; -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "nvptx64")] -pub const ALIGNMENT: usize = 1 << 7; - -// This size is same across all hardware for this architecture. -/// Cache and allocation multiple alignment size -#[cfg(target_arch = "aarch64")] -pub const ALIGNMENT: usize = 1 << 6; - -#[doc(hidden)] -/// Fallback cache and allocation multiple alignment size -const FALLBACK_ALIGNMENT: usize = 1 << 6; - -/// -/// As you can see this is global and lives as long as the program lives. -/// Be careful to not write anything to this pointer in any scenario. -/// If you use allocation methods shown here you won't have any problems. -const BYPASS_PTR: NonNull = unsafe { NonNull::new_unchecked(ALIGNMENT as *mut u8) }; - -// If this number is not zero after all objects have been `drop`, there is a memory leak -pub static mut ALLOCATIONS: AtomicIsize = AtomicIsize::new(0); - -/// Allocates a cache-aligned memory region of `size` bytes with uninitialized values. -/// This is more performant than using [allocate_aligned_zeroed] when all bytes will have -/// an unknown or non-zero value and is semantically similar to `malloc`. -pub fn allocate_aligned(size: usize) -> NonNull { - unsafe { - if size == 0 { - // In a perfect world, there is no need to request zero size allocation. - // Currently, passing zero sized layout to alloc is UB. - // This will dodge allocator api for any type. - BYPASS_PTR - } else { - ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); - - let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); - let raw_ptr = std::alloc::alloc(layout); - NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) - } - } -} - -/// Allocates a cache-aligned memory region of `size` bytes with `0u8` on all of them. -/// This is more performant than using [allocate_aligned] and setting all bytes to zero -/// and is semantically similar to `calloc`. -pub fn allocate_aligned_zeroed(size: usize) -> NonNull { - unsafe { - if size == 0 { - // In a perfect world, there is no need to request zero size allocation. - // Currently, passing zero sized layout to alloc is UB. - // This will dodge allocator api for any type. 
- BYPASS_PTR - } else { - ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); - - let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); - let raw_ptr = std::alloc::alloc_zeroed(layout); - NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) - } - } -} - -/// # Safety -/// -/// This function is unsafe because undefined behavior can result if the caller does not ensure all -/// of the following: -/// -/// * ptr must denote a block of memory currently allocated via this allocator, -/// -/// * size must be the same size that was used to allocate that block of memory, -pub unsafe fn free_aligned(ptr: NonNull, size: usize) { - if ptr != BYPASS_PTR { - ALLOCATIONS.fetch_sub(size as isize, std::sync::atomic::Ordering::SeqCst); - std::alloc::dealloc( - ptr.as_ptr(), - Layout::from_size_align_unchecked(size, ALIGNMENT), - ); - } -} - -/// # Safety -/// -/// This function is unsafe because undefined behavior can result if the caller does not ensure all -/// of the following: -/// -/// * ptr must be currently allocated via this allocator, -/// -/// * new_size must be greater than zero. -/// -/// * new_size, when rounded up to the nearest multiple of [ALIGNMENT], must not overflow (i.e., -/// the rounded value must be less than usize::MAX). -pub unsafe fn reallocate( - ptr: NonNull, - old_size: usize, - new_size: usize, -) -> NonNull { - if ptr == BYPASS_PTR { - return allocate_aligned(new_size); - } - - if new_size == 0 { - free_aligned(ptr, old_size); - return BYPASS_PTR; - } - - ALLOCATIONS.fetch_add( - new_size as isize - old_size as isize, - std::sync::atomic::Ordering::SeqCst, - ); - let raw_ptr = std::alloc::realloc( - ptr.as_ptr(), - Layout::from_size_align_unchecked(old_size, ALIGNMENT), - new_size, - ); - NonNull::new(raw_ptr).unwrap_or_else(|| { - handle_alloc_error(Layout::from_size_align_unchecked(new_size, ALIGNMENT)) - }) -} - -/// # Safety -/// -/// Behavior is undefined if any of the following conditions are violated: -/// -/// * `src` must be valid for reads of `len * size_of::()` bytes. -/// -/// * `dst` must be valid for writes of `len * size_of::()` bytes. -/// -/// * Both `src` and `dst` must be properly aligned. -/// -/// `memcpy` creates a bitwise copy of `T`, regardless of whether `T` is [`Copy`]. If `T` is not -/// [`Copy`], using both the values in the region beginning at `*src` and the region beginning at -/// `*dst` can [violate memory safety][read-ownership]. 
-pub unsafe fn memcpy(dst: NonNull, src: NonNull, count: usize) { - if src != BYPASS_PTR { - std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_ptr(), count) - } -} - -pub fn is_ptr_aligned(p: NonNull) -> bool { - p.as_ptr().align_offset(align_of::()) == 0 -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_allocate() { - for _ in 0..10 { - let p = allocate_aligned(1024); - // make sure this is 64-byte aligned - assert_eq!(0, (p.as_ptr() as usize) % 64); - unsafe { free_aligned(p, 1024) }; - } - } -} From 6c3f9f04be7c255be4c773a26c9cf2c23063c90d Mon Sep 17 00:00:00 2001 From: Mauricio Vargas Date: Thu, 25 Feb 2021 08:20:54 -0800 Subject: [PATCH 34/54] ARROW-11756: [R] passing a partition as a schema leads to segfaults Closes #9566 from pachamaltese/master Authored-by: Mauricio Vargas Signed-off-by: Neal Richardson --- r/R/dataset-factory.R | 1 + r/R/dataset.R | 1 + r/tests/testthat/test-dataset.R | 6 ++++++ 3 files changed, 8 insertions(+) diff --git a/r/R/dataset-factory.R b/r/R/dataset-factory.R index 30622b8a6d09d..a772be544b0d2 100644 --- a/r/R/dataset-factory.R +++ b/r/R/dataset-factory.R @@ -27,6 +27,7 @@ DatasetFactory <- R6Class("DatasetFactory", inherit = ArrowObject, if (is.null(schema)) { dataset___DatasetFactory__Finish1(self, unify_schemas) } else { + assert_is(schema, "Schema") dataset___DatasetFactory__Finish2(self, schema) } }, diff --git a/r/R/dataset.R b/r/R/dataset.R index e990ff3cb8636..3f7d117d6f62f 100644 --- a/r/R/dataset.R +++ b/r/R/dataset.R @@ -75,6 +75,7 @@ open_dataset <- function(sources, } } # Enforce that all datasets have the same schema + assert_is(schema, "Schema") sources <- lapply(sources, function(x) { x$schema <- schema x diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index e84eb12b08ada..2dbf9c5cbbb6a 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -194,6 +194,12 @@ test_that("Hive partitioning", { ) }) +test_that("input validation", { + expect_error( + open_dataset(hive_dir, hive_partition(other = utf8(), group = uint8())) + ) +}) + test_that("Partitioning inference", { # These are the same tests as above, just using the *PartitioningFactory ds1 <- open_dataset(dataset_dir, partitioning = "part") From 6b09bb6e9ad5bbe126638f6daeda7e24c52f948a Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Thu, 25 Feb 2021 13:19:14 -0500 Subject: [PATCH 35/54] ARROW-11779: [Rust] make alloc module public Polars uses the `arrow::memory` module. With the backwards incompatible change of #9495, the API is refactored to `arrow::alloc`. By making `alloc` public, users can shift to the new changes. Closes #9572 from ritchie46/make_alloc_public Authored-by: Ritchie Vink Signed-off-by: Andrew Lamb --- rust/arrow/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/arrow/src/lib.rs b/rust/arrow/src/lib.rs index 9c2ca2723ce73..68a820bfc54b2 100644 --- a/rust/arrow/src/lib.rs +++ b/rust/arrow/src/lib.rs @@ -135,7 +135,7 @@ // introduced to ignore lint errors when upgrading from 2020-04-22 to 2020-11-14 #![allow(clippy::float_equality_without_abs, clippy::type_complexity)] -mod alloc; +pub mod alloc; mod arch; pub mod array; pub mod bitmap; From e10d2ea9bcaec54a3788c63c027b101c2dfab173 Mon Sep 17 00:00:00 2001 From: Max Meldrum Date: Thu, 25 Feb 2021 13:21:26 -0500 Subject: [PATCH 36/54] ARROW-11777: [Rust] impl AsRef for StringBuilder/BinaryBuilder This patch adds impl AsRef<[u8]> to the append_value/append of StringBuilder and BinaryBuilder. 
```rust pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<()> ``` Non-breaking change that will enable data to be passed as value as well. Closes #9570 from Max-Meldrum/as_ref_patch Authored-by: Max Meldrum Signed-off-by: Andrew Lamb --- rust/arrow/src/array/builder.rs | 24 +++++++++++--------- rust/datafusion/src/physical_plan/explain.rs | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/rust/arrow/src/array/builder.rs b/rust/arrow/src/array/builder.rs index e7519aacac354..6979a9887cafa 100644 --- a/rust/arrow/src/array/builder.rs +++ b/rust/arrow/src/array/builder.rs @@ -1094,8 +1094,8 @@ impl GenericBinaryBuilder { /// /// Automatically calls the `append` method to delimit the slice appended in as a /// distinct array element. - pub fn append_value(&mut self, value: &[u8]) -> Result<()> { - self.builder.values().append_slice(value)?; + pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<()> { + self.builder.values().append_slice(value.as_ref())?; self.builder.append(true)?; Ok(()) } @@ -1140,8 +1140,10 @@ impl GenericStringBuilder { /// /// Automatically calls the `append` method to delimit the string appended in as a /// distinct array element. - pub fn append_value(&mut self, value: &str) -> Result<()> { - self.builder.values().append_slice(value.as_bytes())?; + pub fn append_value(&mut self, value: impl AsRef) -> Result<()> { + self.builder + .values() + .append_slice(value.as_ref().as_bytes())?; self.builder.append(true)?; Ok(()) } @@ -1176,13 +1178,13 @@ impl FixedSizeBinaryBuilder { /// /// Automatically calls the `append` method to delimit the slice appended in as a /// distinct array element. - pub fn append_value(&mut self, value: &[u8]) -> Result<()> { - if self.builder.value_length() != value.len() as i32 { + pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<()> { + if self.builder.value_length() != value.as_ref().len() as i32 { return Err(ArrowError::InvalidArgumentError( "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths".to_string() )); } - self.builder.values().append_slice(value)?; + self.builder.values().append_slice(value.as_ref())?; self.builder.append(true) } @@ -1999,8 +2001,8 @@ where /// Append a primitive value to the array. Return an existing index /// if already present in the values array or a new index if the /// value is appended to the values array. - pub fn append(&mut self, value: &str) -> Result { - if let Some(&key) = self.map.get(value.as_bytes()) { + pub fn append(&mut self, value: impl AsRef) -> Result { + if let Some(&key) = self.map.get(value.as_ref().as_bytes()) { // Append existing value. self.keys_builder.append_value(key)?; Ok(key) @@ -2008,9 +2010,9 @@ where // Append new value. 
let key = K::Native::from_usize(self.values_builder.len()) .ok_or(ArrowError::DictionaryKeyOverflowError)?; - self.values_builder.append_value(value)?; + self.values_builder.append_value(value.as_ref())?; self.keys_builder.append_value(key as K::Native)?; - self.map.insert(value.as_bytes().into(), key); + self.map.insert(value.as_ref().as_bytes().into(), key); Ok(key) } } diff --git a/rust/datafusion/src/physical_plan/explain.rs b/rust/datafusion/src/physical_plan/explain.rs index 56535917dde4e..26d2c94dc80a4 100644 --- a/rust/datafusion/src/physical_plan/explain.rs +++ b/rust/datafusion/src/physical_plan/explain.rs @@ -106,7 +106,7 @@ impl ExecutionPlan for ExplainExec { for p in &self.stringified_plans { type_builder.append_value(&String::from(&p.plan_type))?; - plan_builder.append_value(&p.plan)?; + plan_builder.append_value(&*p.plan)?; } let record_batch = RecordBatch::try_new( From 0b020a13a9eab8b3c2a4e97720bebe88c06ab4e8 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 25 Feb 2021 10:31:08 -0800 Subject: [PATCH 37/54] ARROW-11683: [R] Support dplyr::mutate() First steps: * Rework `selected_columns` to hold field_refs instead of string column names; add code to back out the string field names where needed (e.g. dataset `Project()`) * Create an `array_ref` pseudo-function to do the same as `field_ref` for `array_expressions` * Add a `data` argument to `eval_array_expression` in order to bind `array_ref`s to the actual Arrays before evaluating * Refactor `filter()` NSE code for reuse in `mutate()` * Split up dplyr tests because we're going to be adding lots more Then: * Basic `mutate()` and `transmute()` (done in https://github.com/apache/arrow/pull/9521/commits/578d4929264858916b94e8dc632123dfb85816d2) * Go through the examples in the dplyr::mutate() docs and add tests for all cases. Where possible they're implemented in arrow fully; where we don't support the functions, it falls back to the current behavior of pulling the data into R first. 
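To make the new behaviour concrete, here is a minimal usage sketch of the in-memory (Table/RecordBatch) path described above. The table contents and column names are hypothetical (not taken from the patch's tests), and `x + y` stands in for any expression the bindings can translate:

```r
library(arrow)
library(dplyr)

# Hypothetical in-memory Table (Dataset support is left to ARROW-11704).
tab <- Table$create(x = 1:5, y = 6:10)

# Supported expressions are translated to Arrow array expressions at
# mutate() time and only evaluated when the query is collect()ed;
# anything not yet supported falls back, with a warning, to pulling
# the data into R first (the pre-existing behaviour).
tab %>%
  mutate(z = x + y) %>%
  collect()

# transmute() is mutate() with .keep = "none", so only new columns remain.
tab %>%
  transmute(z = x + y) %>%
  collect()
```
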
Followup JIRAs: * ARROW-11704: Wire up dplyr::mutate() for datasets * ARROW-16999: Implement dplyr::across() and autosplicing * ARROW-11700: Internationalize error handling in tidy eval * ARROW-11701: Implement dplyr::relocate() * ARROW-11702: Enable ungrouped aggregations in non-Dataset expressions * ARROW-11658: Handle mutate/rename inside group_by * ARROW-11705: Support scalar value recycling in RecordBatch/Table$create() * ARROW-11754: Support dplyr::compute() * ARROW-11752: Replace usage of testthat::expect_is() * ARROW-11755: Add tests from dplyr/test-mutate.r * ARROW-11785: Fallback when filtering Table with if_any() expression fails Closes #9521 from nealrichardson/mutate Authored-by: Neal Richardson Signed-off-by: Neal Richardson --- r/NEWS.md | 7 + r/R/arrow-package.R | 2 +- r/R/arrowExports.R | 4 + r/R/dataset-scan.R | 10 + r/R/dataset-write.R | 6 +- r/R/dplyr.R | 287 ++++++++++++++++----- r/R/expression.R | 43 +++- r/src/arrowExports.cpp | 9 + r/src/expression.cpp | 7 + r/tests/testthat/helper-expectation.R | 63 +++++ r/tests/testthat/test-RecordBatch.R | 8 + r/tests/testthat/test-dplyr-filter.R | 287 +++++++++++++++++++++ r/tests/testthat/test-dplyr-mutate.R | 350 +++++++++++++++++++++++++ r/tests/testthat/test-dplyr.R | 356 +------------------------- r/tests/testthat/test-expression.R | 12 + 15 files changed, 1030 insertions(+), 421 deletions(-) create mode 100644 r/tests/testthat/test-dplyr-filter.R create mode 100644 r/tests/testthat/test-dplyr-mutate.R diff --git a/r/NEWS.md b/r/NEWS.md index 65c4e2205cca3..a008088ff8205 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -19,6 +19,13 @@ # arrow 3.0.0.9000 +## dplyr methods + +* `dplyr::mutate()` on Arrow `Table` and `RecordBatch` is now supported in Arrow for many applications. Where not yet supported, the implementation falls back to pulling data into an R `data.frame` first. +* String functions `nchar()`, `tolower()`, and `toupper()`, along with their `stringr` spellings `str_length()`, `str_to_lower()`, and `str_to_upper()`, are supported in Arrow `dplyr` calls. `str_trim()` is also supported. + +## Other improvements + * `value_counts()` to tabulate values in an `Array` or `ChunkedArray`, similar to `base::table()`. * `StructArray` objects gain data.frame-like methods, including `names()`, `$`, `[[`, and `dim()`. 
* RecordBatch columns can now be added, replaced, or removed by assigning (`<-`) with either `$` or `[[` diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 66694a9786730..818d85c85802f 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -30,7 +30,7 @@ "dplyr::", c( "select", "filter", "collect", "summarise", "group_by", "groups", - "group_vars", "ungroup", "mutate", "arrange", "rename", "pull" + "group_vars", "ungroup", "mutate", "transmute", "arrange", "rename", "pull" ) ) for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 3d0f31ce8f366..790232c8e219d 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -744,6 +744,10 @@ dataset___expr__field_ref <- function(name){ .Call(`_arrow_dataset___expr__field_ref`, name) } +dataset___expr__get_field_ref_name <- function(ref){ + .Call(`_arrow_dataset___expr__get_field_ref_name`, ref) +} + dataset___expr__scalar <- function(x){ .Call(`_arrow_dataset___expr__scalar`, x) } diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 45fc968ed08c5..ec6f85c4bab56 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -69,6 +69,10 @@ Scanner$create <- function(dataset, batch_size = NULL, ...) { if (inherits(dataset, "arrow_dplyr_query")) { + if (inherits(dataset$.data, "ArrowTabular")) { + # To handle mutate() on Table/RecordBatch, we need to collect(as_data_frame=FALSE) now + dataset <- dplyr::collect(dataset, as_data_frame = FALSE) + } return(Scanner$create( dataset$.data, dataset$selected_columns, @@ -152,6 +156,12 @@ map_batches <- function(X, FUN, ..., .data.frame = TRUE) { ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject, public = list( Project = function(cols) { + # cols is either a character vector or a named list of Expressions + if (!is.character(cols)) { + # We don't yet support mutate() on datasets, so this is just a list + # of FieldRefs, and we need to back out the field names + cols <- get_field_names(cols) + } assert_is(cols, "character") dataset___ScannerBuilder__Project(self, cols) self diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R index c5c9292671537..5078bc3e371a6 100644 --- a/r/R/dataset-write.R +++ b/r/R/dataset-write.R @@ -62,8 +62,12 @@ write_dataset <- function(dataset, hive_style = TRUE, ...) { if (inherits(dataset, "arrow_dplyr_query")) { + if (inherits(dataset$.data, "ArrowTabular")) { + # collect() to materialize any mutate/rename + dataset <- dplyr::collect(dataset, as_data_frame = FALSE) + } # We can select a subset of columns but we can't rename them - if (!all(dataset$selected_columns == names(dataset$selected_columns))) { + if (!all(get_field_names(dataset) == names(dataset$selected_columns))) { stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE) } # partitioning vars need to be in the `select` schema diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 32713741b5358..2bd8170a1cb33 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -33,11 +33,11 @@ arrow_dplyr_query <- function(.data) { structure( list( .data = .data$clone(), - # selected_columns is a named character vector: - # * vector contents are the names of the columns in the data - # * vector names are the names they should be in the end (i.e. this + # selected_columns is a named list: + # * contents are references/expressions pointing to the data + # * names are the names they should be in the end (i.e. 
this # records any renaming) - selected_columns = set_names(names(.data)), + selected_columns = make_field_refs(names(.data), dataset = inherits(.data, "Dataset")), # filtered_rows will be an Expression filtered_rows = TRUE, # group_by_vars is a character vector of columns (as renamed) @@ -51,8 +51,15 @@ arrow_dplyr_query <- function(.data) { #' @export print.arrow_dplyr_query <- function(x, ...) { schm <- x$.data$schema - cols <- x$selected_columns - fields <- map_chr(cols, ~schm$GetFieldByName(.)$ToString()) + cols <- get_field_names(x) + # If cols are expressions, they won't be in the schema and will be "" in cols + fields <- map_chr(cols, function(name) { + if (nzchar(name)) { + schm$GetFieldByName(name)$ToString() + } else { + "expr" + } + }) # Strip off the field names as they are in the dataset and add the renamed ones fields <- paste(names(cols), sub("^.*?: ", "", fields), sep = ": ", collapse = "\n") cat(class(x$.data)[1], " (query)\n", sep = "") @@ -73,6 +80,33 @@ print.arrow_dplyr_query <- function(x, ...) { invisible(x) } +get_field_names <- function(selected_cols) { + if (inherits(selected_cols, "arrow_dplyr_query")) { + selected_cols <- selected_cols$selected_columns + } + map_chr(selected_cols, function(x) { + if (inherits(x, "Expression")) { + out <- x$field_name + } else if (inherits(x, "array_expression")) { + out <- x$args$field_name + } else { + out <- NULL + } + # If x isn't some kind of field reference, out is NULL, + # but we always need to return a string + out %||% "" + }) +} + +make_field_refs <- function(field_names, dataset = TRUE) { + if (dataset) { + out <- lapply(field_names, Expression$field_ref) + } else { + out <- lapply(field_names, function(x) array_expression("array_ref", field_name = x)) + } + set_names(out, field_names) +} + # These are the names reflecting all select/rename, not what is in Arrow #' @export names.arrow_dplyr_query <- function(x) names(x$selected_columns) @@ -89,7 +123,7 @@ dim.arrow_dplyr_query <- function(x) { rows <- NA_integer_ } else { # Evaluate the filter expression to a BooleanArray and count - rows <- as.integer(sum(eval_array_expression(x$filtered_rows), na.rm = TRUE)) + rows <- as.integer(sum(eval_array_expression(x$filtered_rows, x$.data), na.rm = TRUE)) } c(rows, cols) } @@ -187,29 +221,8 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { } .data <- arrow_dplyr_query(.data) - # The filter() method works by evaluating the filters to generate Expressions - # with references to Arrays (if .data is Table/RecordBatch) or Fields (if - # .data is a Dataset). - dm <- filter_mask(.data) - filters <- lapply(filts, function (f) { - # This should yield an Expression as long as the filter function(s) are - # implemented in Arrow. - tryCatch(eval_tidy(f, dm), error = function(e) { - # Look for the cases where bad input was given, i.e. this would fail - # in regular dplyr anyway, and let those raise those as errors; - # else, for things not supported by Arrow return a "try-error", - # which we'll handle differently - msg <- conditionMessage(e) - # TODO: internationalization? 
- if (grepl("object '.*'.not.found", msg)) { - stop(e) - } - if (grepl('could not find function ".*"', msg)) { - stop(e) - } - invisible(structure(msg, class = "try-error", condition = e)) - }) - }) + # tidy-eval the filter expressions inside an Arrow data_mask + filters <- lapply(filts, arrow_eval, arrow_mask(.data)) bad_filters <- map_lgl(filters, ~inherits(., "try-error")) if (any(bad_filters)) { bads <- oxford_paste(map_chr(filts, as_label)[bad_filters], quote = FALSE) @@ -238,6 +251,30 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { } filter.Dataset <- filter.ArrowTabular <- filter.arrow_dplyr_query +arrow_eval <- function (expr, mask) { + # filter(), mutate(), etc. work by evaluating the quoted `exprs` to generate Expressions + # with references to Arrays (if .data is Table/RecordBatch) or Fields (if + # .data is a Dataset). + + # This yields an Expression as long as the `exprs` are implemented in Arrow. + # Otherwise, it returns a try-error + tryCatch(eval_tidy(expr, mask), error = function(e) { + # Look for the cases where bad input was given, i.e. this would fail + # in regular dplyr anyway, and let those raise those as errors; + # else, for things not supported by Arrow return a "try-error", + # which we'll handle differently + msg <- conditionMessage(e) + # TODO(ARROW-11700): internationalization + if (grepl("object '.*'.not.found", msg)) { + stop(e) + } + if (grepl('could not find function ".*"', msg)) { + stop(e) + } + invisible(structure(msg, class = "try-error", condition = e)) + }) +} + # Helper to assemble the functions that go in the NSE data mask # The only difference between the Dataset and the Table/RecordBatch versions # is that they use a different wrapping function (FUN) to hold the unevaluated @@ -271,23 +308,32 @@ build_function_list <- function(FUN) { dataset_function_list <- build_function_list(build_dataset_expression) array_function_list <- build_function_list(build_array_expression) -# Create a data mask for evaluating a filter expression -filter_mask <- function(.data) { +# Create a data mask for evaluating a dplyr expression +arrow_mask <- function(.data) { if (query_on_dataset(.data)) { f_env <- new_environment(dataset_function_list) - var_binder <- function(x) Expression$field_ref(x) } else { f_env <- new_environment(array_function_list) - var_binder <- function(x) .data$.data[[x]] } - # Add the column references - # Renaming is handled automatically by the named list - data_pronoun <- lapply(.data$selected_columns, var_binder) - env_bind(f_env, !!!data_pronoun) - # Then bind the data pronoun - env_bind(f_env, .data = data_pronoun) - new_data_mask(f_env) + # Add functions that need to error hard and clear. + # Some R functions will still try to evaluate on an Expression + # and return NA with a warning + fail <- function(...) stop("Not implemented") + for (f in c("mean")) { + f_env[[f]] <- fail + } + + # Add the column references and make the mask + out <- new_data_mask( + new_environment(.data$selected_columns, parent = f_env), + f_env + ) + # Then insert the data pronoun + # TODO: figure out what rlang::as_data_pronoun does/why we should use it + # (because if we do we get `Error: Can't modify the data pronoun` in mutate()) + out$.data <- .data$selected_columns + out } set_filters <- function(.data, expressions) { @@ -309,8 +355,27 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) 
{ # See dataset.R for Dataset and Scanner(Builder) classes tab <- Scanner$create(x)$ToTable() } else { - # This is a Table/RecordBatch. See record-batch.R for the [ method - tab <- x$.data[x$filtered_rows, x$selected_columns, keep_na = FALSE] + # This is a Table or RecordBatch + + # Filter and select the data referenced in selected columns + if (isTRUE(x$filtered_rows)) { + filter <- TRUE + } else { + filter <- eval_array_expression(x$filtered_rows, x$.data) + } + # TODO: shortcut if identical(names(x$.data), find_array_refs(x$selected_columns))? + tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE] + # Now evaluate those expressions on the filtered table + cols <- lapply(x$selected_columns, eval_array_expression, data = tab) + if (length(cols) == 0) { + tab <- tab[, integer(0)] + } else { + if (inherits(x$.data, "Table")) { + tab <- Table$create(!!!cols) + } else { + tab <- RecordBatch$create(!!!cols) + } + } } if (as_data_frame) { df <- as.data.frame(tab) @@ -327,7 +392,13 @@ ensure_group_vars <- function(x) { if (inherits(x, "arrow_dplyr_query")) { # Before pulling data from Arrow, make sure all group vars are in the projection gv <- set_names(setdiff(dplyr::group_vars(x), names(x))) - x$selected_columns <- c(x$selected_columns, gv) + if (length(gv)) { + # Add them back + x$selected_columns <- c( + x$selected_columns, + make_field_refs(gv, dataset = query_on_dataset(.data)) + ) + } } x } @@ -337,21 +408,20 @@ restore_dplyr_features <- function(df, query) { # After calling collect(), make sure these features are carried over grouped <- length(query$group_by_vars) > 0 - renamed <- !identical(names(df), names(query)) - if (is.data.frame(df)) { + renamed <- ncol(df) && !identical(names(df), names(query)) + if (renamed) { # In case variables were renamed, apply those names - if (renamed && ncol(df)) { - names(df) <- names(query) - } + names(df) <- names(query) + } + if (grouped) { # Preserve groupings, if present - if (grouped) { + if (is.data.frame(df)) { df <- dplyr::grouped_df(df, dplyr::group_vars(query)) + } else { + # This is a Table, via collect(as_data_frame = FALSE) + df <- arrow_dplyr_query(df) + df$group_by_vars <- query$group_by_vars } - } else if (grouped || renamed) { - # This is a Table, via collect(as_data_frame = FALSE) - df <- arrow_dplyr_query(df) - names(df$selected_columns) <- names(query) - df$group_by_vars <- query$group_by_vars } df } @@ -423,26 +493,117 @@ ungroup.arrow_dplyr_query <- function(x, ...) { } ungroup.Dataset <- ungroup.ArrowTabular <- force -mutate.arrow_dplyr_query <- function(.data, ...) { +mutate.arrow_dplyr_query <- function(.data, + ..., + .keep = c("all", "used", "unused", "none"), + .before = NULL, + .after = NULL) { + call <- match.call() + exprs <- quos(...) + if (length(exprs) == 0) { + # Nothing to do + return(.data) + } + .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("mutate()") } - # TODO: see if we can defer evaluating the expressions and not collect here. - # It's different from filters (as currently implemented) because the basic - # vector transformation functions aren't yet implemented in Arrow C++. - dplyr::mutate(dplyr::collect(.data), ...) 
+ + .keep <- match.arg(.keep) + .before <- enquo(.before) + .after <- enquo(.after) + # Restrict the cases we support for now + if (!quo_is_null(.before) || !quo_is_null(.after)) { + # TODO(ARROW-11701) + return(abandon_ship(call, .data, '.before and .after arguments are not supported in Arrow')) + } else if (length(group_vars(.data)) > 0) { + # mutate() on a grouped dataset does calculations within groups + # This doesn't matter on scalar ops (arithmetic etc.) but it does + # for things with aggregations (e.g. subtracting the mean) + return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) + } + + # Check for unnamed expressions and fix if any + unnamed <- !nzchar(names(exprs)) + # Deparse and take the first element in case they're long expressions + names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label) + + mask <- arrow_mask(.data) + results <- list() + for (i in seq_along(exprs)) { + # Iterate over the indices and not the names because names may be repeated + # (which overwrites the previous name) + new_var <- names(exprs)[i] + results[[new_var]] <- arrow_eval(exprs[[i]], mask) + if (inherits(results[[new_var]], "try-error")) { + msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') + return(abandon_ship(call, .data, msg)) + } + # Put it in the data mask too + mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] + } + + # Assign the new columns into the .data$selected_columns, respecting the .keep param + if (.keep == "none") { + .data$selected_columns <- results + } else { + if (.keep != "all") { + # "used" or "unused" + used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE) + old_vars <- names(.data$selected_columns) + if (.keep == "used") { + .data$selected_columns <- .data$selected_columns[intersect(old_vars, used_vars)] + } else { + # "unused" + .data$selected_columns <- .data$selected_columns[setdiff(old_vars, used_vars)] + } + } + # Note that this is names(exprs) not names(results): + # if results$new_var is NULL, that means we are supposed to remove it + for (new_var in names(exprs)) { + .data$selected_columns[[new_var]] <- results[[new_var]] + } + } + # Even if "none", we still keep group vars + ensure_group_vars(.data) } mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query -# TODO: add transmute() that does what summarise() does (select only the vars we need) + +transmute.arrow_dplyr_query <- function(.data, ...) dplyr::mutate(.data, ..., .keep = "none") +transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query + +# Helper to handle unsupported dplyr features +# * For Table/RecordBatch, we collect() and then call the dplyr method in R +# * For Dataset, we just error +abandon_ship <- function(call, .data, msg = NULL) { + dplyr_fun_name <- sub("^(.*?)\\..*", "\\1", as.character(call[[1]])) + if (query_on_dataset(.data)) { + if (is.null(msg)) { + # Default message: function not implemented + not_implemented_for_dataset(paste0(dplyr_fun_name, "()")) + } else { + stop(msg, call. = FALSE) + } + } + + # else, collect and call dplyr method + if (!is.null(msg)) { + warning(msg, "; pulling data into R", immediate. = TRUE, call. = FALSE) + } + call$.data <- dplyr::collect(.data) + call[[1]] <- get(dplyr_fun_name, envir = asNamespace("dplyr")) + eval.parent(call, 2) +} arrange.arrow_dplyr_query <- function(.data, ...) { .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("arrange()") } - - dplyr::arrange(dplyr::collect(.data), ...) 
+ # TODO(ARROW-11703) move this to Arrow + call <- match.call() + abandon_ship(call, .data) } arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query diff --git a/r/R/expression.R b/r/R/expression.R index 878b800c652e3..74c1aefcae1c5 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -143,7 +143,14 @@ cast_array_expression <- function(x, to_type, safe = TRUE, ...) { .array_function_map <- c(.unary_function_map, .binary_function_map) -eval_array_expression <- function(x) { +eval_array_expression <- function(x, data = NULL) { + if (!is.null(data)) { + x <- bind_array_refs(x, data) + } + if (!inherits(x, "array_expression")) { + # Nothing to evaluate + return(x) + } x$args <- lapply(x$args, function (a) { if (inherits(a, "array_expression")) { eval_array_expression(a) @@ -154,6 +161,27 @@ eval_array_expression <- function(x) { call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } +find_array_refs <- function(x) { + if (identical(x$fun, "array_ref")) { + out <- x$args$field_name + } else { + out <- lapply(x$args, find_array_refs) + } + unlist(out) +} + +# Take an array_expression and replace array_refs with arrays/chunkedarrays from data +bind_array_refs <- function(x, data) { + if (inherits(x, "array_expression")) { + if (identical(x$fun, "array_ref")) { + x <- data[[x$args$field_name]] + } else { + x$args <- lapply(x$args, bind_array_refs, data) + } + } + x +} + #' @export is.na.array_expression <- function(x) array_expression("is.na", x) @@ -181,9 +209,13 @@ print.array_expression <- function(x, ...) { deparse(arg) } }) - # Prune this for readability - function_name <- sub("_kleene", "", x$fun) - paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") + if (identical(x$fun, "array_ref")) { + x$args$field_name + } else { + # Prune this for readability + function_name <- sub("_kleene", "", x$fun) + paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") + } } ########### @@ -217,6 +249,9 @@ Expression <- R6Class("Expression", inherit = ArrowObject, ) Expression$create("cast", self, options = modifyList(opts, list(...))) } + ), + active = list( + field_name = function() dataset___expr__get_field_ref_name(self) ) ) Expression$create <- function(function_name, diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 839c9d6c17310..73ee64844a6da 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1569,6 +1569,14 @@ BEGIN_CPP11 END_CPP11 } // expression.cpp +std::string dataset___expr__get_field_ref_name(const std::shared_ptr& ref); +extern "C" SEXP _arrow_dataset___expr__get_field_ref_name(SEXP ref_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type ref(ref_sexp); + return cpp11::as_sexp(dataset___expr__get_field_ref_name(ref)); +END_CPP11 +} +// expression.cpp std::shared_ptr dataset___expr__scalar(const std::shared_ptr& x); extern "C" SEXP _arrow_dataset___expr__scalar(SEXP x_sexp){ BEGIN_CPP11 @@ -3702,6 +3710,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_FixedSizeListType__list_size", (DL_FUNC) &_arrow_FixedSizeListType__list_size, 1}, { "_arrow_dataset___expr__call", (DL_FUNC) &_arrow_dataset___expr__call, 3}, { "_arrow_dataset___expr__field_ref", (DL_FUNC) &_arrow_dataset___expr__field_ref, 1}, + { "_arrow_dataset___expr__get_field_ref_name", (DL_FUNC) &_arrow_dataset___expr__get_field_ref_name, 1}, { "_arrow_dataset___expr__scalar", (DL_FUNC) &_arrow_dataset___expr__scalar, 1}, { "_arrow_dataset___expr__ToString", (DL_FUNC) &_arrow_dataset___expr__ToString, 1}, { 
"_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, diff --git a/r/src/expression.cpp b/r/src/expression.cpp index ddb1e72c30956..76d8222967b76 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -47,6 +47,13 @@ std::shared_ptr dataset___expr__field_ref(std::string name) { return std::make_shared(ds::field_ref(std::move(name))); } +// [[arrow::export]] +std::string dataset___expr__get_field_ref_name( + const std::shared_ptr& ref) { + auto refname = ref->field_ref()->name(); + return *refname; +} + // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index ce0f9de8a54c0..76edea61f5797 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -59,3 +59,66 @@ verify_output <- function(...) { } testthat::verify_output(...) } + +expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + skip_record_batch = NULL, # Msg, if should skip RB test + skip_table = NULL, # Msg, if should skip Table test + ...) { + expr <- rlang::enquo(expr) + expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) + + skip_msg <- NULL + + if (is.null(skip_record_batch)) { + via_batch <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = record_batch(tbl))) + ) + expect_equivalent(via_batch, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_record_batch) + } + + if (is.null(skip_table)) { + via_table <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Table$create(tbl))) + ) + expect_equivalent(via_table, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_table) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) + } +} + +expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + ...) { + expr <- rlang::enquo(expr) + msg <- tryCatch( + rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))), + error = function (e) conditionMessage(e) + ) + expect_is(msg, "character", label = "dplyr on data.frame did not error") + + expect_error( + rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = record_batch(tbl))) + ), + msg, + ... + ) + expect_error( + rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Table$create(tbl))) + ), + msg, + ... 
+ ) +} \ No newline at end of file diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R index aeee66d87107a..a017823ce34df 100644 --- a/r/tests/testthat/test-RecordBatch.R +++ b/r/tests/testthat/test-RecordBatch.R @@ -416,6 +416,14 @@ test_that("record_batch() handles null type (ARROW-7064)", { expect_equivalent(batch$schema, schema(a = int32(), n = null())) }) +test_that("record_batch() scalar recycling", { + skip("Not implemented (ARROW-11705)") + expect_data_frame( + record_batch(a = 1:10, b = 5), + tibble::tibble(a = 1:10, b = 5) + ) +}) + test_that("RecordBatch$Equals", { df <- tibble::tibble(x = 1:10, y = letters[1:10]) a <- record_batch(df) diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R new file mode 100644 index 0000000000000..f73589496be4a --- /dev/null +++ b/r/tests/testthat/test-dplyr-filter.R @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) +library(stringr) + +tbl <- example_data +# Add some better string data +tbl$verses <- verses[[1]] +# c(" a ", " b ", " c ", ...) 
increasing padding +# nchar = 3 5 7 9 11 13 15 17 19 21 +tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") + +test_that("filter() on is.na()", { + expect_dplyr_equal( + input %>% + filter(is.na(lgl)) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("filter() with NAs in selection", { + expect_dplyr_equal( + input %>% + filter(lgl) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("Filter returning an empty Table should not segfault (ARROW-8354)", { + expect_dplyr_equal( + input %>% + filter(false) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression", { + char_sym <- "b" + expect_dplyr_equal( + input %>% + filter(chr == char_sym) %>% + select(string = chr, int) %>% + collect(), + tbl + ) +}) + +test_that("filtering with arithmetic", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl %/% 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression + autocasting", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("More complex select/filter", { + expect_dplyr_equal( + input %>% + filter(dbl > 2, chr == "d" | chr == "f") %>% + select(chr, int, lgl) %>% + filter(int < 5) %>% + select(int, chr) %>% + collect(), + tbl + ) +}) + +test_that("filter() with %in%", { + expect_dplyr_equal( + input %>% + filter(dbl > 2, chr %in% c("d", "f")) %>% + collect(), + tbl + ) +}) + +test_that("filter() with string ops", { + # Extra instrumentation to ensure that we're calling Arrow compute here + # because many base R string functions implicitly call as.character, + # which means they still work on Arrays but actually force data into R + # 1) wrapper that raises a warning if as.character is called. 
Can't wrap + # the whole test because as.character apparently gets called in other + # (presumably legitimate) places + # 2) Wrap the test in expect_warning(expr, NA) to catch the warning + + with_no_as_character <- function(expr) { + trace( + "as.character", + tracer = quote(warning("as.character was called")), + print = FALSE, + where = toupper + ) + on.exit(untrace("as.character", where = toupper)) + force(expr) + } + + expect_warning( + expect_dplyr_equal( + input %>% + filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F")) %>% + collect(), + tbl + ), + NA) + + expect_dplyr_equal( + input %>% + filter(dbl > 2, str_length(verses) > 25) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>% + collect(), + tbl + ) +}) + +test_that("filter environment scope", { + # "object 'b_var' not found" + expect_dplyr_error(input %>% filter(batch, chr == b_var)) + + b_var <- "b" + expect_dplyr_equal( + input %>% + filter(chr == b_var) %>% + collect(), + tbl + ) + # Also for functions + # 'could not find function "isEqualTo"' because we haven't defined it yet + expect_dplyr_error(filter(batch, isEqualTo(int, 4))) + + skip("Need to substitute in user defined function too") + # TODO: fix this: this isEqualTo function is eagerly evaluating; it should + # instead yield array_expressions. Probably bc the parent env of the function + # has the Ops.Array methods defined; we need to move it so that the parent + # env is the data mask we use in the dplyr eval + isEqualTo <- function(x, y) x == y & !is.na(x) + expect_dplyr_equal( + input %>% + select(-fct) %>% # factor levels aren't identical + filter(isEqualTo(int, 4)) %>% + collect(), + tbl + ) +}) + +test_that("Filtering on a column that doesn't exist errors correctly", { + skip("Error handling in arrow_eval() needs to be internationalized (ARROW-11700)") + expect_error( + batch %>% filter(not_a_col == 42) %>% collect(), + "object 'not_a_col' not found" + ) +}) + +test_that("Filtering with a function that doesn't have an Array/expr method still works", { + expect_warning( + expect_dplyr_equal( + input %>% + filter(int > 2, pnorm(dbl) > .99) %>% + collect(), + tbl + ), + 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling data into R', + fixed = TRUE + ) +}) + +test_that("filter() with .data pronoun", { + expect_dplyr_equal( + input %>% + filter(.data$dbl > 4) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(is.na(.data$lgl)) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + # and the .env pronoun too! + chr <- 4 + expect_dplyr_equal( + input %>% + filter(.data$dbl > .env$chr) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + # but there is an error if we don't override the masking with `.env` + expect_dplyr_error( + tbl %>% + filter(.data$dbl > chr) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect() + ) +}) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R new file mode 100644 index 0000000000000..56d7e368520b4 --- /dev/null +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -0,0 +1,350 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) +library(stringr) + +tbl <- example_data +# Add some better string data +tbl$verses <- verses[[1]] +# c(" a ", " b ", " c ", ...) increasing padding +# nchar = 3 5 7 9 11 13 15 17 19 21 +tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") + +test_that("mutate() is lazy", { + expect_is( + tbl %>% record_batch() %>% mutate(int = int + 6L), + "arrow_dplyr_query" + ) +}) + +test_that("basic mutate", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + mutate(int = int + 6L) %>% + collect(), + tbl + ) +}) + +test_that("transmute", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + transmute(int = int + 6L) %>% + collect(), + tbl + ) +}) + +test_that("mutate and refer to previous mutants", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + line_lengths = nchar(padded_strings), + longer = line_lengths * 10 + ) %>% + filter(line_lengths > 15) %>% + collect(), + tbl + ) +}) + +test_that("mutate with .data pronoun", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + line_lengths = nchar(padded_strings), + longer = .data$line_lengths * 10 + ) %>% + filter(line_lengths > 15) %>% + collect(), + tbl + ) +}) + +test_that("mutate with unnamed expressions", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + int, # bare column name + nchar(padded_strings) # expression + ) %>% + filter(int > 5) %>% + collect(), + tbl + ) +}) + +test_that("mutate with reassigning same name", { + expect_dplyr_equal( + input %>% + transmute( + new = lgl, + new = chr + ) %>% + collect(), + tbl + ) +}) + +test_that("mutate with single value for recycling", { + skip("Not implemented (ARROW-11705") + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + dr_bronner = 1 # ALL ONE! + ) %>% + collect(), + tbl + ) +}) + +test_that("dplyr::mutate's examples", { + # Newly created variables are available immediately + expect_dplyr_equal( + input %>% + select(name, mass) %>% + mutate( + mass2 = mass * 2, + mass2_squared = mass2 * mass2 + ) %>% + collect(), + starwars # this is a test dataset that ships with dplyr + ) + + # As well as adding new variables, you can use mutate() to + # remove variables and modify existing variables. 
+ expect_dplyr_equal( + input %>% + select(name, height, mass, homeworld) %>% + mutate( + mass = NULL, + height = height * 0.0328084 # convert to feet + ) %>% + collect(), + starwars + ) + + # Examples we don't support should succeed + # but warn that they're pulling data into R to do so + + # across + autosplicing: ARROW-11699 + expect_warning( + expect_dplyr_equal( + input %>% + select(name, homeworld, species) %>% + mutate(across(!name, as.factor)) %>% + collect(), + starwars + ), + "Expression across.*not supported in Arrow" + ) + + # group_by then mutate + expect_warning( + expect_dplyr_equal( + input %>% + select(name, mass, homeworld) %>% + group_by(homeworld) %>% + mutate(rank = min_rank(desc(mass))) %>% + collect(), + starwars + ), + "not supported in Arrow" + ) + + # `.before` and `.after` experimental args: ARROW-11701 + df <- tibble(x = 1, y = 2) + expect_dplyr_equal( + input %>% mutate(z = x + y) %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> x y z + #> + #> 1 1 2 3 + expect_warning( + expect_dplyr_equal( + input %>% mutate(z = x + y, .before = 1) %>% collect(), + df + ), + "not supported in Arrow" + ) + #> # A tibble: 1 x 3 + #> z x y + #> + #> 1 3 1 2 + expect_warning( + expect_dplyr_equal( + input %>% mutate(z = x + y, .after = x) %>% collect(), + df + ), + "not supported in Arrow" + ) + #> # A tibble: 1 x 3 + #> x z y + #> + #> 1 1 3 2 + + # By default, mutate() keeps all columns from the input data. + # Experimental: You can override with `.keep` + df <- tibble(x = 1, y = 2, a = "a", b = "b") + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "all") %>% collect(), # the default + df + ) + #> # A tibble: 1 x 5 + #> x y a b z + #> + #> 1 1 2 a b 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "used") %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> x y z + #> + #> 1 1 2 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "unused") %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> a b z + #> + #> 1 a b 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "none") %>% collect(), # same as transmute() + df + ) + #> # A tibble: 1 x 1 + #> z + #> + #> 1 3 + + # Grouping ---------------------------------------- + # The mutate operation may yield different results on grouped + # tibbles because the expressions are computed within groups. 
+ # The following normalises `mass` by the global average: + # TODO(ARROW-11702) + expect_warning( + expect_dplyr_equal( + input %>% + select(name, mass, species) %>% + mutate(mass_norm = mass / mean(mass, na.rm = TRUE)) %>% + collect(), + starwars + ), + "not supported in Arrow" + ) +}) + +test_that("handle bad expressions", { + # TODO: search for functions other than mean() (see above test) + # that need to be forced to fail because they error ambiguously + + skip("Error handling in arrow_eval() needs to be internationalized (ARROW-11700)") + expect_error( + Table$create(tbl) %>% mutate(newvar = NOTAVAR + 2), + "object 'NOTAVAR' not found" + ) +}) + +test_that("print a mutated dataset", { + expect_output( + Table$create(tbl) %>% + select(int) %>% + mutate(twice = int * 2) %>% + print(), +'Table (query) +int: int32 +twice: expr + +See $.data for the source Arrow object', + fixed = TRUE) + + # Handling non-expressions/edge cases + expect_output( + Table$create(tbl) %>% + select(int) %>% + mutate(again = 1:10) %>% + print(), +'Table (query) +int: int32 +again: expr + +See $.data for the source Arrow object', + fixed = TRUE) +}) + +test_that("mutate and write_dataset", { + # See related test in test-dataset.R + + skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651 + + first_date <- lubridate::ymd_hms("2015-04-29 03:12:39") + df1 <- tibble( + int = 1:10, + dbl = as.numeric(1:10), + lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2), + chr = letters[1:10], + fct = factor(LETTERS[1:10]), + ts = first_date + lubridate::days(1:10) + ) + + second_date <- lubridate::ymd_hms("2017-03-09 07:01:02") + df2 <- tibble( + int = 101:110, + dbl = c(as.numeric(51:59), NaN), + lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2), + chr = letters[10:1], + fct = factor(LETTERS[10:1]), + ts = second_date + lubridate::days(10:1) + ) + + dst_dir <- tempfile() + stacked <- record_batch(rbind(df1, df2)) + stacked %>% + mutate(twice = int * 2) %>% + group_by(int) %>% + write_dataset(dst_dir, format = "feather") + expect_true(dir.exists(dst_dir)) + expect_identical(dir(dst_dir), sort(paste("int", c(1:10, 101:110), sep = "="))) + + new_ds <- open_dataset(dst_dir, format = "feather") + + expect_equivalent( + new_ds %>% + select(string = chr, integer = int, twice) %>% + filter(integer > 6 & integer < 11) %>% + collect() %>% + summarize(mean = mean(integer)), + df1 %>% + select(string = chr, integer = int) %>% + mutate(twice = integer * 2) %>% + filter(integer > 6) %>% + summarize(mean = mean(integer)) + ) +}) \ No newline at end of file diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 6d9945a115a45..13610f1c6f1b1 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -15,74 +15,9 @@ # specific language governing permissions and limitations # under the License. -context("dplyr verbs") - library(dplyr) library(stringr) -expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start - tbl, # A tbl/df as reference, will make RB/Table with - skip_record_batch = NULL, # Msg, if should skip RB test - skip_table = NULL, # Msg, if should skip Table test - ...) { - expr <- rlang::enquo(expr) - expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) - - skip_msg <- NULL - - if (is.null(skip_record_batch)) { - via_batch <- rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = record_batch(tbl))) - ) - expect_equivalent(via_batch, expected, ...) 
- } else { - skip_msg <- c(skip_msg, skip_record_batch) - } - - if (is.null(skip_table)) { - via_table <- rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = Table$create(tbl))) - ) - expect_equivalent(via_table, expected, ...) - } else { - skip_msg <- c(skip_msg, skip_table) - } - - if (!is.null(skip_msg)) { - skip(paste(skip_msg, collpase = "\n")) - } -} - -expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its start - tbl, # A tbl/df as reference, will make RB/Table with - ...) { - expr <- rlang::enquo(expr) - msg <- tryCatch( - rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))), - error = function (e) conditionMessage(e) - ) - expect_is(msg, "character", label = "dplyr on data.frame did not error") - - expect_error( - rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = record_batch(tbl))) - ), - msg, - ... - ) - expect_error( - rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = Table$create(tbl))) - ), - msg, - ... - ) -} - tbl <- example_data # Add some better string data tbl$verses <- verses[[1]] @@ -104,127 +39,6 @@ test_that("basic select/filter/collect", { expect_identical(collect(batch), tbl) }) -test_that("filter() on is.na()", { - expect_dplyr_equal( - input %>% - filter(is.na(lgl)) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("filter() with NAs in selection", { - expect_dplyr_equal( - input %>% - filter(lgl) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("Filter returning an empty Table should not segfault (ARROW-8354)", { - expect_dplyr_equal( - input %>% - filter(false) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("filtering with expression", { - char_sym <- "b" - expect_dplyr_equal( - input %>% - filter(chr == char_sym) %>% - select(string = chr, int) %>% - collect(), - tbl - ) -}) - -test_that("filtering with arithmetic", { - expect_dplyr_equal( - input %>% - filter(dbl + 1 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl / 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl / 2L > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int / 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int / 2L > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl %/% 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) -}) - -test_that("filtering with expression + autocasting", { - expect_dplyr_equal( - input %>% - filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int + 1 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) -}) - -test_that("More complex select/filter", { - expect_dplyr_equal( - input %>% - filter(dbl > 2, chr == "d" | chr == "f") %>% - select(chr, int, lgl) %>% - filter(int < 5) %>% - select(int, chr) %>% - collect(), - tbl - ) -}) - test_that("dim() on query", { expect_dplyr_equal( input %>% @@ -247,151 +61,12 @@ test_that("Print method", { int: int32 chr: string -* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5)) +* Filter: and(and(greater(dbl, 2), or(equal(chr, "d"), equal(chr, "f"))), less(int, 5)) See 
$.data for the source Arrow object', fixed = TRUE ) }) -test_that("filter() with %in%", { - expect_dplyr_equal( - input %>% - filter(dbl > 2, chr %in% c("d", "f")) %>% - collect(), - tbl - ) -}) - -test_that("filter() with string ops", { - # Extra instrumentation to ensure that we're calling Arrow compute here - # because many base R string functions implicitly call as.character, - # which means they still work on Arrays but actually force data into R - # 1) wrapper that raises a warning if as.character is called. Can't wrap - # the whole test because as.character apparently gets called in other - # (presumably legitimate) places - # 2) Wrap the test in expect_warning(expr, NA) to catch the warning - - with_no_as_character <- function(expr) { - trace( - "as.character", - tracer = quote(warning("as.character was called")), - print = FALSE, - where = toupper - ) - on.exit(untrace("as.character", where = toupper)) - force(expr) - } - - expect_warning( - expect_dplyr_equal( - input %>% - filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F")) %>% - collect(), - tbl - ), - NA) - - expect_dplyr_equal( - input %>% - filter(dbl > 2, str_length(verses) > 25) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>% - collect(), - tbl - ) -}) - -test_that("filter environment scope", { - # "object 'b_var' not found" - expect_dplyr_error(input %>% filter(batch, chr == b_var)) - - b_var <- "b" - expect_dplyr_equal( - input %>% - filter(chr == b_var) %>% - collect(), - tbl - ) - # Also for functions - # 'could not find function "isEqualTo"' - expect_dplyr_error(filter(batch, isEqualTo(int, 4))) - - # TODO: fix this: this isEqualTo function is eagerly evaluating; it should - # instead yield array_expressions. Probably bc the parent env of the function - # has the Ops.Array methods defined; we need to move it so that the parent - # env is the data mask we use in the dplyr eval - isEqualTo <- function(x, y) x == y & !is.na(x) - expect_dplyr_equal( - input %>% - select(-fct) %>% # factor levels aren't identical - filter(isEqualTo(int, 4)) %>% - collect(), - tbl - ) -}) - -test_that("Filtering on a column that doesn't exist errors correctly", { - skip("Error handling in filter() needs to be internationalized") - expect_error( - batch %>% filter(not_a_col == 42) %>% collect(), - "object 'not_a_col' not found" - ) -}) - -test_that("Filtering with a function that doesn't have an Array/expr method still works", { - expect_warning( - expect_dplyr_equal( - input %>% - filter(int > 2, pnorm(dbl) > .99) %>% - collect(), - tbl - ), - 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling data into R', - fixed = TRUE - ) -}) - -test_that("filter() with .data pronoun", { - expect_dplyr_equal( - input %>% - filter(.data$dbl > 4) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(is.na(.data$lgl)) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - # and the .env pronoun too! 
- chr <- 4 - expect_dplyr_equal( - input %>% - filter(.data$dbl > .env$chr) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - # but there is an error if we don't override the masking with `.env` - expect_dplyr_error( - tbl %>% - filter(.data$dbl > chr) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect() - ) -}) - test_that("summarize", { expect_dplyr_equal( input %>% @@ -410,29 +85,6 @@ test_that("summarize", { ) }) -test_that("mutate", { - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - mutate(int = int + 6L) %>% - summarize(min_int = min(int)), - tbl - ) -}) - -test_that("transmute", { - skip("TODO: reimplement transmute (with dplyr 1.0, it no longer just works via mutate)") - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - transmute(int = int + 6L) %>% - summarize(min_int = min(int)), - tbl - ) -}) - test_that("group_by groupings are recorded", { expect_dplyr_equal( input %>% @@ -599,7 +251,7 @@ test_that("collect(as_data_frame=FALSE)", { select(int, strng = chr) %>% filter(int > 5) %>% collect(as_data_frame = FALSE) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -632,7 +284,7 @@ test_that("head", { select(int, strng = chr) %>% filter(int > 5) %>% head(2) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -665,7 +317,7 @@ test_that("tail", { select(int, strng = chr) %>% filter(int > 5) %>% tail(2) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 3c100812ff19e..3df7270f4c57a 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -34,8 +34,20 @@ test_that("array_expression print method", { ) }) +test_that("array_refs", { + tab <- Table$create(a = 1:5) + ex <- build_array_expression(">", array_expression("array_ref", field_name = "a"), 4) + expect_is(ex, "array_expression") + expect_identical(ex$args[[1]]$args$field_name, "a") + expect_identical(find_array_refs(ex), "a") + out <- eval_array_expression(ex, tab) + expect_is(out, "ChunkedArray") + expect_equal(as.vector(out), c(FALSE, FALSE, FALSE, FALSE, TRUE)) +}) + test_that("C++ expressions", { f <- Expression$field_ref("f") + expect_identical(f$field_name, "f") g <- Expression$field_ref("g") date <- Expression$scalar(as.Date("2020-01-15")) ts <- Expression$scalar(as.POSIXct("2020-01-17 11:11:11")) From d7da16e57258a5fa74cd5e7534a93010058b9a3c Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 25 Feb 2021 20:06:31 +0100 Subject: [PATCH 38/54] ARROW-11695: [C++][FlightRPC] fix option to disable TLS verification gRPC 1.34 and 1.36 both change up the API, so we have to detect both of those versions. The CMake config and C++ code was also refactored a bit so that there's less to copy-paste for each gRPC version change. Closes #9569 from lidavidm/arrow-11695-detection Authored-by: David Li Signed-off-by: Uwe L. 
Korn --- cpp/src/arrow/flight/CMakeLists.txt | 87 ++++++++++++------- cpp/src/arrow/flight/client.cc | 38 ++++++-- .../flight/try_compile/check_tls_opts_134.cc | 44 ++++++++++ .../flight/try_compile/check_tls_opts_136.cc | 38 ++++++++ 4 files changed, 166 insertions(+), 41 deletions(-) create mode 100644 cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc create mode 100644 cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 2fcb6ef077d10..b44bab290746c 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -73,44 +73,67 @@ string(REPLACE "-Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # Probe the version of gRPC being used to see if it supports disabling server # verification when using TLS. -if(NOT DEFINED HAS_GRPC_132) - message(STATUS "Checking support for TlsCredentialsOptions...") - get_property(CURRENT_INCLUDE_DIRECTORIES - DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - PROPERTY INCLUDE_DIRECTORIES) - try_compile(HAS_GRPC_132 ${CMAKE_CURRENT_BINARY_DIR}/try_compile SOURCES - "${CMAKE_CURRENT_SOURCE_DIR}/try_compile/check_tls_opts_132.cc" - CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CURRENT_INCLUDE_DIRECTORIES}" - LINK_LIBRARIES gRPC::grpc - OUTPUT_VARIABLE TSL_CREDENTIALS_OPTIONS_CHECK_OUTPUT CXX_STANDARD 11) - - if(HAS_GRPC_132) - message(STATUS "TlsCredentialsOptions found in grpc::experimental.") - add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc::experimental) - else() - message(STATUS "TlsCredentialsOptions not found in grpc::experimental.") - message(DEBUG "Build output:") - list(APPEND CMAKE_MESSAGE_INDENT "check_tls_opts_132.cc: ") - message(DEBUG ${TSL_CREDENTIALS_OPTIONS_CHECK_OUTPUT}) - list(REMOVE_AT CMAKE_MESSAGE_INDENT -1) - - try_compile(HAS_GRPC_127 ${CMAKE_CURRENT_BINARY_DIR}/try_compile SOURCES - "${CMAKE_CURRENT_SOURCE_DIR}/try_compile/check_tls_opts_127.cc" +function(test_grpc_version DST_VAR DETECT_VERSION TEST_FILE) + if(NOT DEFINED ${DST_VAR}) + message( + STATUS "Checking support for TlsCredentialsOptions (gRPC >= ${DETECT_VERSION})...") + get_property(CURRENT_INCLUDE_DIRECTORIES + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + PROPERTY INCLUDE_DIRECTORIES) + try_compile(HAS_GRPC_VERSION ${CMAKE_CURRENT_BINARY_DIR}/try_compile SOURCES + "${CMAKE_CURRENT_SOURCE_DIR}/try_compile/${TEST_FILE}" CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CURRENT_INCLUDE_DIRECTORIES}" - OUTPUT_VARIABLE TSL_CREDENTIALS_OPTIONS_CHECK_OUTPUT CXX_STANDARD 11) - - if(HAS_GRPC_127) - message(STATUS "TlsCredentialsOptions found in grpc_impl::experimental.") - add_definitions( - -DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc_impl::experimental) + LINK_LIBRARIES gRPC::grpc gRPC::grpc++ + OUTPUT_VARIABLE TLS_CREDENTIALS_OPTIONS_CHECK_OUTPUT CXX_STANDARD 11) + message(STATUS "${TLS_CREDENTIALS_OPTIONS_CHECK_OUTPUT}") + if(HAS_GRPC_VERSION) + set(${DST_VAR} "${DETECT_VERSION}" PARENT_SCOPE) else() - message(STATUS "TlsCredentialsOptions not found in grpc_impl::experimental.") + message( + STATUS + "TlsCredentialsOptions (for gRPC ${DETECT_VERSION}) not found in grpc::experimental." 
+ ) message(DEBUG "Build output:") - list(APPEND CMAKE_MESSAGE_INDENT "check_tls_opts_127.cc: ") - message(DEBUG ${TSL_CREDENTIALS_OPTIONS_CHECK_OUTPUT}) + list(APPEND CMAKE_MESSAGE_INDENT "${TEST_FILE}: ") + message(DEBUG ${TLS_CREDENTIALS_OPTIONS_CHECK_OUTPUT}) list(REMOVE_AT CMAKE_MESSAGE_INDENT -1) endif() endif() +endfunction() + +test_grpc_version(GRPC_VERSION "1.36" "check_tls_opts_136.cc") +test_grpc_version(GRPC_VERSION "1.34" "check_tls_opts_134.cc") +test_grpc_version(GRPC_VERSION "1.32" "check_tls_opts_132.cc") +test_grpc_version(GRPC_VERSION "1.27" "check_tls_opts_127.cc") +message( + STATUS + "Found approximate gRPC version: ${GRPC_VERSION} (ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS=${ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS})" + ) +if(GRPC_VERSION EQUAL "1.27") + add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc_impl::experimental) +elseif(GRPC_VERSION EQUAL "1.32") + add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc::experimental) +elseif(GRPC_VERSION EQUAL "1.34") + add_definitions(-DGRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS + -DGRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS + -DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc::experimental) +elseif(GRPC_VERSION EQUAL "1.36") + add_definitions(-DGRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS + -DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc::experimental) +else() + message( + STATUS + "A proper version of gRPC could not be found to support TlsCredentialsOptions in Arrow Flight." + ) + message( + STATUS + "You may need a newer version of gRPC (>= 1.27), or the gRPC API has changed and Flight must be updated to match." + ) + if(ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS) + message( + FATAL_ERROR "Halting build since ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS is set." + ) + endif() endif() # Restore the CXXFLAGS that were modified above diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 724f999fabedc..cd6f3a97e5874 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -860,7 +860,7 @@ namespace { // requires root CA certs, even if you are skipping server // verification. #if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) -constexpr char BLANK_ROOT_PEM[] = +constexpr char kDummyRootCert[] = "-----BEGIN CERTIFICATE-----\n" "MIICwzCCAaugAwIBAgIJAM12DOkcaqrhMA0GCSqGSIb3DQEBBQUAMBQxEjAQBgNV\n" "BAMTCWxvY2FsaG9zdDAeFw0yMDEwMDcwODIyNDFaFw0zMDEwMDUwODIyNDFaMBQx\n" @@ -893,11 +893,7 @@ class FlightClient::FlightClientImpl { if (scheme == kSchemeGrpcTls) { if (options.disable_server_verification) { -#if !defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) - return Status::NotImplemented( - "Using encryption with server verification disabled is unsupported. 
" - "Please use a release of Arrow Flight built with gRPC 1.27 or higher."); -#else +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) namespace ge = GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS; // A callback to supply to TlsCredentialsOptions that accepts any server @@ -910,16 +906,40 @@ class FlightClient::FlightClientImpl { return 0; } }; - + auto server_authorization_check = std::make_shared(); noop_auth_check_ = std::make_shared( - std::make_shared()); + server_authorization_check); +#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) + auto certificate_provider = + std::make_shared( + kDummyRootCert); +#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) + grpc::experimental::TlsChannelCredentialsOptions tls_options( + certificate_provider); +#else + // While gRPC >= 1.36 does not require a root cert (it has a default) + // in practice the path it hardcodes is broken. See grpc/grpc#21655. + grpc::experimental::TlsChannelCredentialsOptions tls_options; + tls_options.set_certificate_provider(certificate_provider); +#endif + tls_options.watch_root_certs(); + tls_options.set_root_cert_name("dummy"); + tls_options.set_server_verification_option( + grpc_tls_server_verification_option::GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION); + tls_options.set_server_authorization_check_config(noop_auth_check_); +#elif defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) auto materials_config = std::make_shared(); - materials_config->set_pem_root_certs(BLANK_ROOT_PEM); + materials_config->set_pem_root_certs(kDummyRootCert); ge::TlsCredentialsOptions tls_options( GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION, materials_config, std::shared_ptr(), noop_auth_check_); +#endif creds = ge::TlsCredentials(tls_options); +#else + return Status::NotImplemented( + "Using encryption with server verification disabled is unsupported. " + "Please use a release of Arrow Flight built with gRPC 1.27 or higher."); #endif } else { grpc::SslCredentialsOptions ssl_options; diff --git a/cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc b/cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc new file mode 100644 index 0000000000000..4ee2122ef57e7 --- /dev/null +++ b/cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Dummy file for checking if TlsCredentialsOptions exists in +// the grpc::experimental namespace. gRPC starting from 1.34 +// put it here. This is for supporting disabling server +// validation when using TLS. + +#include +#include +#include + +// Dummy file for checking if TlsCredentialsOptions exists in +// the grpc::experimental namespace. gRPC starting from 1.34 +// puts it here. 
This is for supporting disabling server +// validation when using TLS. + +static void check() { + // In 1.34, there's no parameterless constructor; in 1.36, there's + // only a parameterless constructor + auto options = + std::make_shared(nullptr); + options->set_server_verification_option( + grpc_tls_server_verification_option::GRPC_TLS_SERVER_VERIFICATION); +} + +int main(int argc, const char** argv) { + check(); + return 0; +} diff --git a/cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc b/cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc new file mode 100644 index 0000000000000..638eec67ba723 --- /dev/null +++ b/cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Dummy file for checking if TlsCredentialsOptions exists in +// the grpc::experimental namespace. gRPC starting from 1.36 +// puts it here. This is for supporting disabling server +// validation when using TLS. + +#include +#include +#include + +static void check() { + // In 1.34, there's no parameterless constructor; in 1.36, there's + // only a parameterless constructor + auto options = std::make_shared(); + options->set_server_verification_option( + grpc_tls_server_verification_option::GRPC_TLS_SERVER_VERIFICATION); +} + +int main(int argc, const char** argv) { + check(); + return 0; +} From 9a9baf6824db91be2c0913367d4b151d9390a4e6 Mon Sep 17 00:00:00 2001 From: Micah Kornfield Date: Thu, 25 Feb 2021 20:15:54 +0100 Subject: [PATCH 39/54] ARROW-2229: [C++][Python] Add WriteCsv functionality. This offers possibly performance naive CSV writer with limited options to keep the initial PR down. Obvious potential improvements to this approach are: - Smarter casts for dictionaries - Arena allocation for intermediate cast results The implementation also means that for all primitive type support we might have to fill in gaps in our cast function. 
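As a usage reference, a minimal sketch of driving the new writer from C++ follows. It is based only on the declarations added in this patch (WriteCSV, WriteOptions::Defaults(), include_header, batch_size); the helper name, the BufferOutputStream sink, and the exact include set are illustrative assumptions rather than part of the change.

    // Sketch: render an arrow::Table to CSV in an in-memory buffer.
    #include <memory>

    #include "arrow/buffer.h"
    #include "arrow/csv/api.h"  // pulls in csv/writer.h when built with ARROW_COMPUTE
    #include "arrow/io/memory.h"
    #include "arrow/result.h"
    #include "arrow/table.h"

    arrow::Result<std::shared_ptr<arrow::Buffer>> TableToCsv(
        const std::shared_ptr<arrow::Table>& table) {
      // Collect the CSV bytes in memory; any arrow::io::OutputStream works.
      ARROW_ASSIGN_OR_RAISE(auto sink, arrow::io::BufferOutputStream::Create());
      auto options = arrow::csv::WriteOptions::Defaults();
      options.include_header = true;  // write the column-name header line
      options.batch_size = 1024;      // rows converted per slice
      // A null MemoryPool falls back to default_memory_pool() inside WriteCSV.
      ARROW_RETURN_NOT_OK(
          arrow::csv::WriteCSV(*table, options, /*pool=*/nullptr, sink.get()));
      return sink->Finish();
    }

The RecordBatch overload is used the same way; both overloads accept a null MemoryPool and substitute default_memory_pool(), as the implementation below does.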
Closes #9504 from emkornfield/csv Lead-authored-by: Micah Kornfield Co-authored-by: emkornfield Co-authored-by: Antoine Pitrou Co-authored-by: Micah Kornfield Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 3 + cpp/src/arrow/csv/CMakeLists.txt | 22 +- cpp/src/arrow/csv/api.h | 5 + cpp/src/arrow/csv/options.cc | 1 + cpp/src/arrow/csv/options.h | 15 + cpp/src/arrow/csv/writer.cc | 437 +++++++++++++++++++++++++++ cpp/src/arrow/csv/writer.h | 47 +++ cpp/src/arrow/csv/writer_test.cc | 129 ++++++++ cpp/src/arrow/ipc/json_simple.cc | 6 +- cpp/src/arrow/util/config.h.cmake | 2 + cpp/src/arrow/util/iterator.h | 3 +- python/pyarrow/_csv.pyx | 103 ++++++- python/pyarrow/csv.py | 3 +- python/pyarrow/includes/libarrow.pxd | 12 + python/pyarrow/tests/test_csv.py | 37 ++- 15 files changed, 811 insertions(+), 14 deletions(-) create mode 100644 cpp/src/arrow/csv/writer.cc create mode 100644 cpp/src/arrow/csv/writer.h create mode 100644 cpp/src/arrow/csv/writer_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 382a851c159a5..abd5428b3d775 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -349,6 +349,9 @@ if(ARROW_CSV) csv/options.cc csv/parser.cc csv/reader.cc) + if(ARROW_COMPUTE) + list(APPEND ARROW_SRCS csv/writer.cc) + endif() list(APPEND ARROW_TESTING_SRCS csv/test_common.cc) endif() diff --git a/cpp/src/arrow/csv/CMakeLists.txt b/cpp/src/arrow/csv/CMakeLists.txt index 2766cfd3bd2bd..561faf1b58480 100644 --- a/cpp/src/arrow/csv/CMakeLists.txt +++ b/cpp/src/arrow/csv/CMakeLists.txt @@ -15,14 +15,20 @@ # specific language governing permissions and limitations # under the License. -add_arrow_test(csv-test - SOURCES - chunker_test.cc - column_builder_test.cc - column_decoder_test.cc - converter_test.cc - parser_test.cc - reader_test.cc) +set(CSV_TEST_SRCS + chunker_test.cc + column_builder_test.cc + column_decoder_test.cc + converter_test.cc + parser_test.cc + reader_test.cc) + +# Writer depends on compute's cast functionality +if(ARROW_COMPUTE) + list(APPEND CSV_TEST_SRCS writer_test.cc) +endif() + +add_arrow_test(csv-test SOURCES ${CSV_TEST_SRCS}) add_arrow_benchmark(converter_benchmark PREFIX "arrow-csv") add_arrow_benchmark(parser_benchmark PREFIX "arrow-csv") diff --git a/cpp/src/arrow/csv/api.h b/cpp/src/arrow/csv/api.h index df88843f51b7b..7bf39315767e0 100644 --- a/cpp/src/arrow/csv/api.h +++ b/cpp/src/arrow/csv/api.h @@ -19,3 +19,8 @@ #include "arrow/csv/options.h" #include "arrow/csv/reader.h" + +// The writer depends on compute module for casting. 
+#ifdef ARROW_COMPUTE +#include "arrow/csv/writer.h" +#endif diff --git a/cpp/src/arrow/csv/options.cc b/cpp/src/arrow/csv/options.cc index b6f1346bcd34d..a515abf2cf41e 100644 --- a/cpp/src/arrow/csv/options.cc +++ b/cpp/src/arrow/csv/options.cc @@ -34,6 +34,7 @@ ConvertOptions ConvertOptions::Defaults() { } ReadOptions ReadOptions::Defaults() { return ReadOptions(); } +WriteOptions WriteOptions::Defaults() { return WriteOptions(); } } // namespace csv } // namespace arrow diff --git a/cpp/src/arrow/csv/options.h b/cpp/src/arrow/csv/options.h index 82153ed466a20..5c912e7fd8537 100644 --- a/cpp/src/arrow/csv/options.h +++ b/cpp/src/arrow/csv/options.h @@ -137,5 +137,20 @@ struct ARROW_EXPORT ReadOptions { static ReadOptions Defaults(); }; +/// Experimental +struct ARROW_EXPORT WriteOptions { + /// Whether to write an initial header line with column names + bool include_header = true; + + /// \brief Maximum number of rows processed at a time + /// + /// The CSV writer converts and writes data in batches of N rows. + /// This number can impact performance. + int32_t batch_size = 1024; + + /// Create write options with default values + static WriteOptions Defaults(); +}; + } // namespace csv } // namespace arrow diff --git a/cpp/src/arrow/csv/writer.cc b/cpp/src/arrow/csv/writer.cc new file mode 100644 index 0000000000000..ddd59b46fc1c9 --- /dev/null +++ b/cpp/src/arrow/csv/writer.cc @@ -0,0 +1,437 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/csv/writer.h" +#include "arrow/array.h" +#include "arrow/compute/cast.h" +#include "arrow/io/interfaces.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/result_internal.h" +#include "arrow/stl_allocator.h" +#include "arrow/util/iterator.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +#include "arrow/visitor_inline.h" + +namespace arrow { +namespace csv { +// This implementation is intentionally light on configurability to minimize the size of +// the initial PR. Aditional features can be added as there is demand and interest to +// implement them. +// +// The algorithm used here at a high level is to break RecordBatches/Tables into slices +// and convert each slice independently. A slice is then converted to CSV by first +// scanning each column to determine the size of its contents when rendered as a string in +// CSV. For non-string types this requires casting the value to string (which is cached). +// This data is used to understand the precise length of each row and a single allocation +// for the final CSV data buffer. Once the final size is known each column is then +// iterated over again to place its contents into the CSV data buffer. 
The rationale for +// choosing this approach is it allows for reuse of the cast functionality in the compute +// module and inline data visiting functionality in the core library. A performance +// comparison has not been done using a naive single-pass approach. This approach might +// still be competitive due to reduction in the number of per row branches necessary with +// a single pass approach. Profiling would likely yield further opportunities for +// optimization with this approach. + +namespace { + +struct SliceIteratorFunctor { + Result> Next() { + if (current_offset < batch->num_rows()) { + std::shared_ptr next = batch->Slice(current_offset, slice_size); + current_offset += slice_size; + return next; + } + return IterationTraits>::End(); + } + const RecordBatch* const batch; + const int64_t slice_size; + int64_t current_offset; +}; + +RecordBatchIterator RecordBatchSliceIterator(const RecordBatch& batch, + int64_t slice_size) { + SliceIteratorFunctor functor = {&batch, slice_size, /*offset=*/static_cast(0)}; + return RecordBatchIterator(std::move(functor)); +} + +// Counts the number of characters that need escaping in s. +int64_t CountEscapes(util::string_view s) { + return static_cast(std::count(s.begin(), s.end(), '"')); +} + +// Matching quote pair character length. +constexpr int64_t kQuoteCount = 2; +constexpr int64_t kQuoteDelimiterCount = kQuoteCount + /*end_char*/ 1; + +// Interface for generating CSV data per column. +// The intended usage is to iteratively call UpdateRowLengths for a column and +// then PopulateColumns. PopulateColumns must be called in the reverse order of the +// populators (it populates data backwards). +class ColumnPopulator { + public: + ColumnPopulator(MemoryPool* pool, char end_char) : end_char_(end_char), pool_(pool) {} + + virtual ~ColumnPopulator() = default; + + // Adds the number of characters each entry in data will add to to elements + // in row_lengths. + Status UpdateRowLengths(const Array& data, int32_t* row_lengths) { + compute::ExecContext ctx(pool_); + // Populators are intented to be applied to reasonably small data. In most cases + // threading overhead would not be justified. + ctx.set_use_threads(false); + ASSIGN_OR_RAISE( + std::shared_ptr casted, + compute::Cast(data, /*to_type=*/utf8(), compute::CastOptions(), &ctx)); + casted_array_ = internal::checked_pointer_cast(casted); + return UpdateRowLengths(row_lengths); + } + + // Places string data onto each row in output and updates the corresponding row + // row pointers in preparation for calls to other (preceding) ColumnPopulators. + // Args: + // output: character buffer to write to. + // offsets: an array of end of row column within the the output buffer (values are + // one past the end of the position to write to). + virtual void PopulateColumns(char* output, int32_t* offsets) const = 0; + + protected: + virtual Status UpdateRowLengths(int32_t* row_lengths) = 0; + std::shared_ptr casted_array_; + const char end_char_; + + private: + MemoryPool* const pool_; +}; + +// Copies the contents of to out properly escaping any necessary characters. +// Returns the position prior to last copied character (out_end is decremented). +char* EscapeReverse(arrow::util::string_view s, char* out_end) { + for (const char* val = s.data() + s.length() - 1; val >= s.data(); val--, out_end--) { + if (*val == '"') { + *out_end = *val; + out_end--; + } + *out_end = *val; + } + return out_end; +} + +// Populator for non-string types. 
This populator relies on compute Cast functionality to +// String if it doesn't exist it will be an error. it also assumes the resulting string +// from a cast does not require quoting or escaping. +class UnquotedColumnPopulator : public ColumnPopulator { + public: + explicit UnquotedColumnPopulator(MemoryPool* memory_pool, char end_char) + : ColumnPopulator(memory_pool, end_char) {} + + Status UpdateRowLengths(int32_t* row_lengths) override { + for (int x = 0; x < casted_array_->length(); x++) { + row_lengths[x] += casted_array_->value_length(x); + } + return Status::OK(); + } + + void PopulateColumns(char* output, int32_t* offsets) const override { + VisitArrayDataInline( + *casted_array_->data(), + [&](arrow::util::string_view s) { + int64_t next_column_offset = s.length() + /*end_char*/ 1; + memcpy((output + *offsets - next_column_offset), s.data(), s.length()); + *(output + *offsets - 1) = end_char_; + *offsets -= static_cast(next_column_offset); + offsets++; + }, + [&]() { + // Nulls are empty (unquoted) to distinguish with empty string. + *(output + *offsets - 1) = end_char_; + *offsets -= 1; + offsets++; + }); + } +}; + +// Strings need special handling to ensure they are escaped properly. +// This class handles escaping assuming that all strings will be quoted +// and that the only character within the string that needs to escaped is +// a quote character (") and escaping is done my adding another quote. +class QuotedColumnPopulator : public ColumnPopulator { + public: + QuotedColumnPopulator(MemoryPool* pool, char end_char) + : ColumnPopulator(pool, end_char) {} + + Status UpdateRowLengths(int32_t* row_lengths) override { + const StringArray& input = *casted_array_; + int row_number = 0; + row_needs_escaping_.resize(casted_array_->length()); + VisitArrayDataInline( + *input.data(), + [&](arrow::util::string_view s) { + int64_t escaped_count = CountEscapes(s); + // TODO: Maybe use 64 bit row lengths or safe cast? + row_needs_escaping_[row_number] = escaped_count > 0; + row_lengths[row_number] += static_cast(s.length()) + + static_cast(escaped_count + kQuoteCount); + row_number++; + }, + [&]() { + row_needs_escaping_[row_number] = false; + row_number++; + }); + return Status::OK(); + } + + void PopulateColumns(char* output, int32_t* offsets) const override { + auto needs_escaping = row_needs_escaping_.begin(); + VisitArrayDataInline( + *(casted_array_->data()), + [&](arrow::util::string_view s) { + // still needs string content length to be added + char* row_end = output + *offsets; + int32_t next_column_offset = 0; + if (!*needs_escaping) { + next_column_offset = static_cast(s.length() + kQuoteDelimiterCount); + memcpy(row_end - next_column_offset + /*quote_offset=*/1, s.data(), + s.length()); + } else { + // Adjust row_end by 3: 1 quote char, 1 end char and 1 to position at the + // first position to write to. + next_column_offset = + static_cast(row_end - EscapeReverse(s, row_end - 3)); + } + *(row_end - next_column_offset) = '"'; + *(row_end - 2) = '"'; + *(row_end - 1) = end_char_; + *offsets -= next_column_offset; + offsets++; + needs_escaping++; + }, + [&]() { + // Nulls are empty (unquoted) to distinguish with empty string. + *(output + *offsets - 1) = end_char_; + *offsets -= 1; + offsets++; + needs_escaping++; + }); + } + + private: + // Older version of GCC don't support custom allocators + // at some point we should change this to use memory_pool + // backed allocator. 
+ std::vector row_needs_escaping_; +}; + +struct PopulatorFactory { + template + enable_if_t::value || + std::is_same::value, + Status> + Visit(const TypeClass& type) { + populator = new QuotedColumnPopulator(pool, end_char); + return Status::OK(); + } + + template + enable_if_dictionary Visit(const TypeClass& type) { + return VisitTypeInline(*type.value_type(), this); + } + + template + enable_if_t::value || is_extension_type::value, + Status> + Visit(const TypeClass& type) { + return Status::Invalid("Unsupported Type:", type.ToString()); + } + + template + enable_if_t::value || is_decimal_type::value || + is_null_type::value || is_temporal_type::value, + Status> + Visit(const TypeClass& type) { + populator = new UnquotedColumnPopulator(pool, end_char); + return Status::OK(); + } + + char end_char; + MemoryPool* pool; + ColumnPopulator* populator; +}; + +Result> MakePopulator(const Field& field, char end_char, + MemoryPool* pool) { + PopulatorFactory factory{end_char, pool, nullptr}; + RETURN_NOT_OK(VisitTypeInline(*field.type(), &factory)); + return std::unique_ptr(factory.populator); +} + +class CSVConverter { + public: + static Result> Make(std::shared_ptr schema, + MemoryPool* pool) { + std::vector> populators(schema->num_fields()); + for (int col = 0; col < schema->num_fields(); col++) { + char end_char = col < schema->num_fields() - 1 ? ',' : '\n'; + ASSIGN_OR_RAISE(populators[col], + MakePopulator(*schema->field(col), end_char, pool)); + } + return std::unique_ptr( + new CSVConverter(std::move(schema), std::move(populators), pool)); + } + + Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, + io::OutputStream* out) { + RETURN_NOT_OK(PrepareForContentsWrite(options, out)); + RecordBatchIterator iterator = RecordBatchSliceIterator(batch, options.batch_size); + for (auto maybe_slice : iterator) { + ASSIGN_OR_RAISE(std::shared_ptr slice, maybe_slice); + RETURN_NOT_OK(TranslateMinimalBatch(*slice)); + RETURN_NOT_OK(out->Write(data_buffer_)); + } + return Status::OK(); + } + + Status WriteCSV(const Table& table, const WriteOptions& options, + io::OutputStream* out) { + TableBatchReader reader(table); + reader.set_chunksize(options.batch_size); + RETURN_NOT_OK(PrepareForContentsWrite(options, out)); + std::shared_ptr batch; + RETURN_NOT_OK(reader.ReadNext(&batch)); + while (batch != nullptr) { + RETURN_NOT_OK(TranslateMinimalBatch(*batch)); + RETURN_NOT_OK(out->Write(data_buffer_)); + RETURN_NOT_OK(reader.ReadNext(&batch)); + } + + return Status::OK(); + } + + private: + CSVConverter(std::shared_ptr schema, + std::vector> populators, MemoryPool* pool) + : column_populators_(std::move(populators)), + offsets_(0, 0, ::arrow::stl::allocator(pool)), + schema_(std::move(schema)), + pool_(pool) {} + + Status PrepareForContentsWrite(const WriteOptions& options, io::OutputStream* out) { + if (data_buffer_ == nullptr) { + ASSIGN_OR_RAISE( + data_buffer_, + AllocateResizableBuffer( + options.batch_size * schema_->num_fields() * kColumnSizeGuess, pool_)); + } + if (options.include_header) { + RETURN_NOT_OK(WriteHeader(out)); + } + return Status::OK(); + } + + int64_t CalculateHeaderSize() const { + int64_t header_length = 0; + for (int col = 0; col < schema_->num_fields(); col++) { + const std::string& col_name = schema_->field(col)->name(); + header_length += col_name.size(); + header_length += CountEscapes(col_name); + } + return header_length + (kQuoteDelimiterCount * schema_->num_fields()); + } + + Status WriteHeader(io::OutputStream* out) { + 
RETURN_NOT_OK(data_buffer_->Resize(CalculateHeaderSize(), /*shrink_to_fit=*/false)); + char* next = + reinterpret_cast(data_buffer_->mutable_data() + data_buffer_->size() - 1); + for (int col = schema_->num_fields() - 1; col >= 0; col--) { + *next-- = ','; + *next-- = '"'; + next = EscapeReverse(schema_->field(col)->name(), next); + *next-- = '"'; + } + *(data_buffer_->mutable_data() + data_buffer_->size() - 1) = '\n'; + DCHECK_EQ(reinterpret_cast(next + 1), data_buffer_->data()); + return out->Write(data_buffer_); + } + + Status TranslateMinimalBatch(const RecordBatch& batch) { + if (batch.num_rows() == 0) { + return Status::OK(); + } + offsets_.resize(batch.num_rows()); + std::fill(offsets_.begin(), offsets_.end(), 0); + + // Calculate relative offsets for each row (excluding delimiters) + for (int32_t col = 0; col < static_cast(column_populators_.size()); col++) { + RETURN_NOT_OK( + column_populators_[col]->UpdateRowLengths(*batch.column(col), offsets_.data())); + } + // Calculate cumulalative offsets for each row (including delimiters). + offsets_[0] += batch.num_columns(); + for (int64_t row = 1; row < batch.num_rows(); row++) { + offsets_[row] += offsets_[row - 1] + /*delimiter lengths*/ batch.num_columns(); + } + // Resize the target buffer to required size. We assume batch to batch sizes + // should be pretty close so don't shrink the buffer to avoid allocation churn. + RETURN_NOT_OK(data_buffer_->Resize(offsets_.back(), /*shrink_to_fit=*/false)); + + // Use the offsets to populate contents. + for (auto populator = column_populators_.rbegin(); + populator != column_populators_.rend(); populator++) { + (*populator) + ->PopulateColumns(reinterpret_cast(data_buffer_->mutable_data()), + offsets_.data()); + } + DCHECK_EQ(0, offsets_[0]); + return Status::OK(); + } + + static constexpr int64_t kColumnSizeGuess = 8; + std::vector> column_populators_; + std::vector> offsets_; + std::shared_ptr data_buffer_; + const std::shared_ptr schema_; + MemoryPool* pool_; +}; + +} // namespace + +Status WriteCSV(const Table& table, const WriteOptions& options, MemoryPool* pool, + arrow::io::OutputStream* output) { + if (pool == nullptr) { + pool = default_memory_pool(); + } + ASSIGN_OR_RAISE(std::unique_ptr converter, + CSVConverter::Make(table.schema(), pool)); + return converter->WriteCSV(table, options, output); +} + +Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, MemoryPool* pool, + arrow::io::OutputStream* output) { + if (pool == nullptr) { + pool = default_memory_pool(); + } + + ASSIGN_OR_RAISE(std::unique_ptr converter, + CSVConverter::Make(batch.schema(), pool)); + return converter->WriteCSV(batch, options, output); +} + +} // namespace csv +} // namespace arrow diff --git a/cpp/src/arrow/csv/writer.h b/cpp/src/arrow/csv/writer.h new file mode 100644 index 0000000000000..c009d7849f4cb --- /dev/null +++ b/cpp/src/arrow/csv/writer.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/csv/options.h" +#include "arrow/io/interfaces.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" + +namespace arrow { +namespace csv { +// Functionality for converting Arrow data to Comma separated value text. +// This library supports all primitive types that can be cast to a StringArrays. +// It applies to following formatting rules: +// - For non-binary types no quotes surround values. Nulls are represented as the empty +// string. +// - For binary types all non-null data is quoted (and quotes within data are escaped +// with an additional quote). +// Null values are empty and unquoted. +// - LF (\n) is always used as a line ending. + +/// \brief Converts table to a CSV and writes the results to output. +/// Experimental +ARROW_EXPORT Status WriteCSV(const Table& table, const WriteOptions& options, + MemoryPool* pool, arrow::io::OutputStream* output); +/// \brief Converts batch to CSV and writes the results to output. +/// Experimental +ARROW_EXPORT Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, + MemoryPool* pool, arrow::io::OutputStream* output); + +} // namespace csv +} // namespace arrow diff --git a/cpp/src/arrow/csv/writer_test.cc b/cpp/src/arrow/csv/writer_test.cc new file mode 100644 index 0000000000000..dc59fefa8fe7a --- /dev/null +++ b/cpp/src/arrow/csv/writer_test.cc @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
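+
+// Unit tests for the experimental CSV writer: each case converts a RecordBatch
+// (and the equivalent Table) to CSV text and compares the result against the
+// expected output, with and without a header line.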
+ +#include "gtest/gtest.h" + +#include +#include + +#include "arrow/buffer.h" +#include "arrow/csv/writer.h" +#include "arrow/io/memory.h" +#include "arrow/record_batch.h" +#include "arrow/result_internal.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace csv { + +struct TestParams { + std::shared_ptr record_batch; + WriteOptions options; + std::string expected_output; +}; + +WriteOptions DefaultTestOptions(bool include_header) { + WriteOptions options; + options.batch_size = 5; + options.include_header = include_header; + return options; +} + +std::vector GenerateTestCases() { + auto abc_schema = schema({ + {field("a", uint64())}, + {field("b\"", utf8())}, + {field("c ", int32())}, + }); + auto empty_batch = + RecordBatch::Make(abc_schema, /*num_rows=*/0, + { + ArrayFromJSON(abc_schema->field(0)->type(), "[]"), + ArrayFromJSON(abc_schema->field(1)->type(), "[]"), + ArrayFromJSON(abc_schema->field(2)->type(), "[]"), + }); + auto populated_batch = RecordBatchFromJSON(abc_schema, R"([{"a": 1, "c ": -1}, + { "a": 1, "b\"": "abc\"efg", "c ": 2324}, + { "b\"": "abcd", "c ": 5467}, + { }, + { "a": 546, "b\"": "", "c ": 517 }, + { "a": 124, "b\"": "a\"\"b\"" }])"); + std::string expected_without_header = std::string("1,,-1") + "\n" + // line 1 + +R"(1,"abc""efg",2324)" + "\n" + // line 2 + R"(,"abcd",5467)" + "\n" + // line 3 + R"(,,)" + "\n" + // line 4 + R"(546,"",517)" + "\n" + // line 5 + R"(124,"a""""b""",)" + "\n"; // line 6 + std::string expected_header = std::string(R"("a","b""","c ")") + "\n"; + + return std::vector{ + {empty_batch, DefaultTestOptions(/*header=*/false), ""}, + {empty_batch, DefaultTestOptions(/*header=*/true), expected_header}, + {populated_batch, DefaultTestOptions(/*header=*/false), expected_without_header}, + {populated_batch, DefaultTestOptions(/*header=*/true), + expected_header + expected_without_header}}; +} + +class TestWriteCSV : public ::testing::TestWithParam { + protected: + template + Result ToCsvString(const Data& data, const WriteOptions& options) { + std::shared_ptr out; + ASSIGN_OR_RAISE(out, io::BufferOutputStream::Create()); + + RETURN_NOT_OK(WriteCSV(data, options, default_memory_pool(), out.get())); + ASSIGN_OR_RAISE(std::shared_ptr buffer, out->Finish()); + return std::string(reinterpret_cast(buffer->data()), buffer->size()); + } +}; + +TEST_P(TestWriteCSV, TestWrite) { + ASSERT_OK_AND_ASSIGN(std::shared_ptr out, + io::BufferOutputStream::Create()); + WriteOptions options = GetParam().options; + std::string csv; + ASSERT_OK_AND_ASSIGN(csv, ToCsvString(*GetParam().record_batch, options)); + EXPECT_EQ(csv, GetParam().expected_output); + + // Batch size shouldn't matter. + options.batch_size /= 2; + ASSERT_OK_AND_ASSIGN(csv, ToCsvString(*GetParam().record_batch, options)); + EXPECT_EQ(csv, GetParam().expected_output); + + // Table and Record batch should work identically. + ASSERT_OK_AND_ASSIGN(std::shared_ptr
table, + Table::FromRecordBatches({GetParam().record_batch})); + ASSERT_OK_AND_ASSIGN(csv, ToCsvString(*table, options)); + EXPECT_EQ(csv, GetParam().expected_output); +} + +INSTANTIATE_TEST_SUITE_P(MultiColumnWriteCSVTest, TestWriteCSV, + ::testing::ValuesIn(GenerateTestCases())); + +INSTANTIATE_TEST_SUITE_P( + SingleColumnWriteCSVTest, TestWriteCSV, + ::testing::Values(TestParams{ + RecordBatchFromJSON(schema({field("int64", int64())}), + R"([{ "int64": 9999}, {}, { "int64": -15}])"), + WriteOptions(), + R"("int64")" + "\n9999\n\n-15\n"})); + +} // namespace csv +} // namespace arrow diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index fba8194aeb165..caf6fd06b9c31 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -43,6 +43,7 @@ #include #include #include +#include namespace rj = arrow::rapidjson; @@ -652,8 +653,11 @@ class StructConverter final : public ConcreteConverter { } } if (remaining > 0) { + rj::StringBuffer sb; + rj::Writer writer(sb); + json_obj.Accept(writer); return Status::Invalid("Unexpected members in JSON object for type ", - type_->ToString()); + type_->ToString(), " Object: ", sb.GetString()); } return builder_->Append(); } diff --git a/cpp/src/arrow/util/config.h.cmake b/cpp/src/arrow/util/config.h.cmake index 8f8dea0c6c8d4..be6686f253e05 100644 --- a/cpp/src/arrow/util/config.h.cmake +++ b/cpp/src/arrow/util/config.h.cmake @@ -34,6 +34,8 @@ #define ARROW_PACKAGE_KIND "@ARROW_PACKAGE_KIND@" +#cmakedefine ARROW_COMPUTE + #cmakedefine ARROW_S3 #cmakedefine ARROW_USE_NATIVE_INT128 diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h index 75ccf283aa5c3..771b209a40644 100644 --- a/cpp/src/arrow/util/iterator.h +++ b/cpp/src/arrow/util/iterator.h @@ -64,7 +64,8 @@ template class Iterator : public util::EqualityComparable> { public: /// \brief Iterator may be constructed from any type which has a member function - /// with signature Status Next(T*); + /// with signature Result Next(); + /// End of iterator is signalled by returning IteratorTraits::End(); /// /// The argument is moved or copied to the heap and kept in a unique_ptr. Only /// its destructor and its Next method (which are stored in function pointers) are diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index 4068a0b9141d1..f5b8e4d5fbac6 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -30,10 +30,13 @@ from pyarrow.includes.libarrow cimport * from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema, RecordBatchReader, ensure_type, maybe_unbox_memory_pool, get_input_stream, - native_transcoding_input_stream, + get_writer, native_transcoding_input_stream, + pyarrow_unwrap_batch, pyarrow_unwrap_table, pyarrow_wrap_schema, pyarrow_wrap_table, - pyarrow_wrap_data_type, pyarrow_unwrap_data_type) + pyarrow_wrap_data_type, pyarrow_unwrap_data_type, + Table, RecordBatch) from pyarrow.lib import frombytes, tobytes +from pyarrow.util import _stringify_path cdef unsigned char _single_char(s) except 0: @@ -763,3 +766,99 @@ def open_csv(input_file, read_options=None, parse_options=None, move(c_convert_options), maybe_unbox_memory_pool(memory_pool)) return reader + + +cdef class WriteOptions(_Weakrefable): + """ + Options for writing CSV files. 
+ + Parameters + ---------- + include_header : bool, optional (default True) + Whether to write an initial header line with column names + batch_size : int, optional (default 1024) + How many rows to process together when converting and writing + CSV data + """ + cdef: + CCSVWriteOptions options + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, *, include_header=None, batch_size=None): + self.options = CCSVWriteOptions.Defaults() + if include_header is not None: + self.include_header = include_header + if batch_size is not None: + self.batch_size = batch_size + + @property + def include_header(self): + """ + Whether to write an initial header line with column names. + """ + return self.options.include_header + + @include_header.setter + def include_header(self, value): + self.options.include_header = value + + @property + def batch_size(self): + """ + How many rows to process together when converting and writing + CSV data. + """ + return self.options.batch_size + + @batch_size.setter + def batch_size(self, value): + self.options.batch_size = value + + +cdef _get_write_options(WriteOptions write_options, CCSVWriteOptions* out): + if write_options is None: + out[0] = CCSVWriteOptions.Defaults() + else: + out[0] = write_options.options + + +def write_csv(data, output_file, write_options=None, + MemoryPool memory_pool=None): + """ + Write record batch or table to a CSV file. + + Parameters + ---------- + data: pyarrow.RecordBatch or pyarrow.Table + The data to write. + output_file: string, path, pyarrow.OutputStream or file-like object + The location where to write the CSV data. + write_options: pyarrow.csv.WriteOptions + Options to configure writing the CSV data. + memory_pool: MemoryPool, optional + Pool for temporary allocations. 
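+
+    Examples
+    --------
+    Write a table to an in-memory buffer:
+
+    >>> import io
+    >>> import pyarrow as pa
+    >>> from pyarrow import csv
+    >>> table = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
+    >>> sink = io.BytesIO()
+    >>> csv.write_csv(table, sink, csv.WriteOptions(include_header=True))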
+ """ + cdef: + shared_ptr[COutputStream] stream + CCSVWriteOptions c_write_options + CMemoryPool* c_memory_pool + CRecordBatch* batch + CTable* table + _get_write_options(write_options, &c_write_options) + + get_writer(output_file, &stream) + c_memory_pool = maybe_unbox_memory_pool(memory_pool) + if isinstance(data, RecordBatch): + batch = pyarrow_unwrap_batch(data).get() + with nogil: + check_status(WriteCSV(deref(batch), c_write_options, c_memory_pool, + stream.get())) + elif isinstance(data, Table): + table = pyarrow_unwrap_table(data).get() + with nogil: + check_status(WriteCSV(deref(table), c_write_options, c_memory_pool, + stream.get())) + else: + raise TypeError(f"Expected Table or RecordBatch, got '{type(data)}'") diff --git a/python/pyarrow/csv.py b/python/pyarrow/csv.py index b116ea11d83a3..fc1dcafba0b35 100644 --- a/python/pyarrow/csv.py +++ b/python/pyarrow/csv.py @@ -18,4 +18,5 @@ from pyarrow._csv import ( # noqa ReadOptions, ParseOptions, ConvertOptions, ISO8601, - open_csv, read_csv, CSVStreamingReader) + open_csv, read_csv, CSVStreamingReader, write_csv, + WriteOptions) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index ba3c3ad7d2b66..a4f6f18628402 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1618,6 +1618,13 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: @staticmethod CCSVReadOptions Defaults() + cdef cppclass CCSVWriteOptions" arrow::csv::WriteOptions": + c_bool include_header + int32_t batch_size + + @staticmethod + CCSVWriteOptions Defaults() + cdef cppclass CCSVReader" arrow::csv::TableReader": @staticmethod CResult[shared_ptr[CCSVReader]] Make( @@ -1633,6 +1640,11 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: CMemoryPool*, shared_ptr[CInputStream], CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions) + cdef CStatus WriteCSV( + CTable&, CCSVWriteOptions& options, CMemoryPool*, COutputStream*) + cdef CStatus WriteCSV( + CRecordBatch&, CCSVWriteOptions& options, CMemoryPool*, COutputStream*) + cdef extern from "arrow/json/options.h" nogil: diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py index 462fe0114921f..5ca31aefebc0c 100644 --- a/python/pyarrow/tests/test_csv.py +++ b/python/pyarrow/tests/test_csv.py @@ -36,7 +36,8 @@ import pyarrow as pa from pyarrow.csv import ( - open_csv, read_csv, ReadOptions, ParseOptions, ConvertOptions, ISO8601) + open_csv, read_csv, ReadOptions, ParseOptions, ConvertOptions, ISO8601, + write_csv, WriteOptions) def generate_col_names(): @@ -203,6 +204,21 @@ def test_convert_options(): assert opts.timestamp_parsers == [ISO8601, '%Y-%m-%d'] +def test_write_options(): + cls = WriteOptions + opts = cls() + + check_options_class( + cls, include_header=[True, False]) + + assert opts.batch_size > 0 + opts.batch_size = 12345 + assert opts.batch_size == 12345 + + opts = cls(batch_size=9876) + assert opts.batch_size == 9876 + + class BaseTestCSVRead: def read_bytes(self, b, **kwargs): @@ -1257,3 +1273,22 @@ def test_read_csv_does_not_close_passed_file_handles(): buf = io.BytesIO(b"a,b,c\n1,2,3\n4,5,6") read_csv(buf) assert not buf.closed + + +def test_write_read_round_trip(): + t = pa.Table.from_arrays([[1, 2, 3], ["a", "b", "c"]], ["c1", "c2"]) + record_batch = t.to_batches(max_chunksize=4)[0] + for data in [t, record_batch]: + # Test with header + buf = io.BytesIO() + write_csv(data, buf, WriteOptions(include_header=True)) + buf.seek(0) + assert t == read_csv(buf) + + # Test 
without header + buf = io.BytesIO() + write_csv(data, buf, WriteOptions(include_header=False)) + buf.seek(0) + + read_options = ReadOptions(column_names=t.column_names) + assert t == read_csv(buf, read_options=read_options) From 59f9d2089b47ab541741cdb1c153fec37cacd109 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 25 Feb 2021 21:17:33 -0500 Subject: [PATCH 40/54] ARROW-10420: [C++] Refactor io and filesystem APIs to take an IOContext The `io::IOContext` class allows passing various settings such as the MemoryPool used for allocation and the Executor for async methods. Closes #9474 from pitrou/ARROW-10420-io-context Authored-by: Antoine Pitrou Signed-off-by: Benjamin Kietzman --- c_glib/arrow-glib/reader.cpp | 3 +- cpp/examples/minimal_build/example.cc | 2 +- cpp/src/arrow/csv/reader.cc | 54 +++++++++------ cpp/src/arrow/csv/reader.h | 8 ++- cpp/src/arrow/csv/reader_test.cc | 11 ++- cpp/src/arrow/dataset/discovery.h | 3 +- cpp/src/arrow/dataset/file_base.cc | 9 +++ cpp/src/arrow/dataset/file_base.h | 12 +++- cpp/src/arrow/dataset/file_ipc.cc | 10 +-- cpp/src/arrow/dataset/file_ipc.h | 15 ++-- cpp/src/arrow/dataset/file_parquet.cc | 22 +++--- cpp/src/arrow/dataset/file_parquet.h | 9 +-- cpp/src/arrow/filesystem/filesystem.cc | 83 +++++++++++++++-------- cpp/src/arrow/filesystem/filesystem.h | 36 +++++++++- cpp/src/arrow/filesystem/hdfs.cc | 19 +++--- cpp/src/arrow/filesystem/hdfs.h | 5 +- cpp/src/arrow/filesystem/localfs.cc | 19 +++--- cpp/src/arrow/filesystem/localfs.h | 5 +- cpp/src/arrow/filesystem/localfs_test.cc | 9 ++- cpp/src/arrow/filesystem/mockfs.cc | 67 ++++++++++++------ cpp/src/arrow/filesystem/mockfs.h | 8 ++- cpp/src/arrow/filesystem/s3fs.cc | 47 +++++++++---- cpp/src/arrow/filesystem/s3fs.h | 5 +- cpp/src/arrow/filesystem/util_internal.cc | 5 +- cpp/src/arrow/filesystem/util_internal.h | 3 +- cpp/src/arrow/io/caching.cc | 4 +- cpp/src/arrow/io/caching.h | 4 +- cpp/src/arrow/io/file.cc | 2 +- cpp/src/arrow/io/file.h | 2 +- cpp/src/arrow/io/hdfs.cc | 26 ++++--- cpp/src/arrow/io/hdfs.h | 11 ++- cpp/src/arrow/io/interfaces.cc | 34 +++++++--- cpp/src/arrow/io/interfaces.h | 57 +++++++++++++--- cpp/src/arrow/io/memory.cc | 2 +- cpp/src/arrow/io/memory.h | 2 +- cpp/src/arrow/io/type_fwd.h | 8 +++ cpp/src/arrow/ipc/type_fwd.h | 3 + cpp/src/arrow/util/parallel.h | 11 +-- cpp/src/parquet/arrow/reader.cc | 4 +- cpp/src/parquet/file_reader.cc | 4 +- cpp/src/parquet/file_reader.h | 2 +- cpp/src/parquet/properties.h | 6 +- docs/source/cpp/csv.rst | 6 +- python/pyarrow/_csv.pyx | 5 +- python/pyarrow/includes/libarrow.pxd | 9 ++- r/src/csv.cpp | 6 +- r/src/filesystem.cpp | 3 +- 47 files changed, 453 insertions(+), 227 deletions(-) diff --git a/c_glib/arrow-glib/reader.cpp b/c_glib/arrow-glib/reader.cpp index 17100e76a3c12..db6fa544069b1 100644 --- a/c_glib/arrow-glib/reader.cpp +++ b/c_glib/arrow-glib/reader.cpp @@ -1591,8 +1591,7 @@ garrow_csv_reader_new(GArrowInputStream *input, } auto arrow_reader = - arrow::csv::TableReader::Make(arrow::default_memory_pool(), - arrow::io::AsyncContext(), + arrow::csv::TableReader::Make(arrow::io::default_io_context(), arrow_input, read_options, parse_options, diff --git a/cpp/examples/minimal_build/example.cc b/cpp/examples/minimal_build/example.cc index 8f58de5777a49..e1b5c123a85fb 100644 --- a/cpp/examples/minimal_build/example.cc +++ b/cpp/examples/minimal_build/example.cc @@ -39,7 +39,7 @@ Status RunMain(int argc, char** argv) { ARROW_ASSIGN_OR_RAISE( auto csv_reader, arrow::csv::TableReader::Make(arrow::default_memory_pool(), - 
arrow::io::AsyncContext(), + arrow::io::default_io_context(), input_file, arrow::csv::ReadOptions::Defaults(), arrow::csv::ParseOptions::Defaults(), diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc index f0fa1f206d344..bbba60c79c156 100644 --- a/cpp/src/arrow/csv/reader.cc +++ b/cpp/src/arrow/csv/reader.cc @@ -40,6 +40,7 @@ #include "arrow/status.h" #include "arrow/table.h" #include "arrow/type.h" +#include "arrow/type_fwd.h" #include "arrow/util/async_generator.h" #include "arrow/util/future.h" #include "arrow/util/iterator.h" @@ -51,19 +52,12 @@ #include "arrow/util/utf8.h" namespace arrow { - -class MemoryPool; - -namespace io { - -class InputStream; - -} // namespace io - namespace csv { using internal::Executor; +namespace { + struct ConversionSchema { struct Column { std::string name; @@ -154,6 +148,7 @@ struct CSVBlock { std::function consume_bytes; }; +} // namespace } // namespace csv template <> @@ -162,6 +157,7 @@ struct IterationTraits { }; namespace csv { +namespace { // The == operator must be defined to be used as T in Iterator bool operator==(const CSVBlock& left, const CSVBlock& right) { @@ -935,18 +931,17 @@ class AsyncThreadedTableReader AsyncGenerator> buffer_generator_; }; -///////////////////////////////////////////////////////////////////////// -// Factory functions - -Result> TableReader::Make( - MemoryPool* pool, io::AsyncContext async_context, - std::shared_ptr input, const ReadOptions& read_options, - const ParseOptions& parse_options, const ConvertOptions& convert_options) { +Result> MakeTableReader( + MemoryPool* pool, io::IOContext io_context, std::shared_ptr input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { std::shared_ptr reader; if (read_options.use_threads) { - reader = std::make_shared( - pool, input, read_options, parse_options, convert_options, async_context.executor, - internal::GetCpuThreadPool()); + auto cpu_executor = internal::GetCpuThreadPool(); + auto io_executor = io_context.executor(); + reader = std::make_shared(pool, input, read_options, + parse_options, convert_options, + cpu_executor, io_executor); } else { reader = std::make_shared(pool, input, read_options, parse_options, convert_options); @@ -955,6 +950,27 @@ Result> TableReader::Make( return reader; } +} // namespace + +///////////////////////////////////////////////////////////////////////// +// Factory functions + +Result> TableReader::Make( + io::IOContext io_context, std::shared_ptr input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + return MakeTableReader(io_context.pool(), io_context, std::move(input), read_options, + parse_options, convert_options); +} + +Result> TableReader::Make( + MemoryPool* pool, io::IOContext io_context, std::shared_ptr input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + return MakeTableReader(pool, io_context, std::move(input), read_options, parse_options, + convert_options); +} + Result> StreamingReader::Make( MemoryPool* pool, std::shared_ptr input, const ReadOptions& read_options, const ParseOptions& parse_options, diff --git a/cpp/src/arrow/csv/reader.h b/cpp/src/arrow/csv/reader.h index c361fbddce97c..b18dc04eb65b1 100644 --- a/cpp/src/arrow/csv/reader.h +++ b/cpp/src/arrow/csv/reader.h @@ -46,12 +46,16 @@ class ARROW_EXPORT TableReader { virtual Future> ReadAsync() = 0; /// Create a TableReader instance - static Result> 
Make(MemoryPool* pool, - io::AsyncContext async_context, + static Result> Make(io::IOContext io_context, std::shared_ptr input, const ReadOptions&, const ParseOptions&, const ConvertOptions&); + + ARROW_DEPRECATED("Use MemoryPool-less overload (the IOContext holds a pool already)") + static Result> Make( + MemoryPool* pool, io::IOContext io_context, std::shared_ptr input, + const ReadOptions&, const ParseOptions&, const ConvertOptions&); }; /// Experimental diff --git a/cpp/src/arrow/csv/reader_test.cc b/cpp/src/arrow/csv/reader_test.cc index 64010ae481ac4..602adf2f2a6c4 100644 --- a/cpp/src/arrow/csv/reader_test.cc +++ b/cpp/src/arrow/csv/reader_test.cc @@ -108,9 +108,8 @@ TableReaderFactory MakeSerialFactory() { auto read_options = ReadOptions::Defaults(); read_options.block_size = 1 << 10; read_options.use_threads = false; - return TableReader::Make(default_memory_pool(), io::AsyncContext(), input_stream, - read_options, ParseOptions::Defaults(), - ConvertOptions::Defaults()); + return TableReader::Make(io::default_io_context(), input_stream, read_options, + ParseOptions::Defaults(), ConvertOptions::Defaults()); }; } @@ -131,9 +130,9 @@ Result MakeAsyncFactory( ReadOptions read_options = ReadOptions::Defaults(); read_options.use_threads = true; read_options.block_size = 1 << 10; - auto table_reader = TableReader::Make( - default_memory_pool(), io::AsyncContext(thread_pool.get()), input_stream, - read_options, ParseOptions::Defaults(), ConvertOptions::Defaults()); + auto table_reader = + TableReader::Make(io::IOContext(thread_pool.get()), input_stream, read_options, + ParseOptions::Defaults(), ConvertOptions::Defaults()); return table_reader; }; } diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index ca3274cc0862a..94c49ff0b8557 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -30,8 +30,7 @@ #include "arrow/dataset/partition.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" -#include "arrow/filesystem/filesystem.h" -#include "arrow/filesystem/path_forest.h" +#include "arrow/filesystem/type_fwd.h" #include "arrow/result.h" #include "arrow/util/macros.h" #include "arrow/util/variant.h" diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 612c249861cfb..e468b686af5e8 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -160,6 +160,13 @@ Status FileWriter::Write(RecordBatchReader* batches) { return Status::OK(); } +Status FileWriter::Finish() { + RETURN_NOT_OK(FinishInternal()); + return destination_->Close(); +} + +namespace { + constexpr util::string_view kIntegerToken = "{i}"; Status ValidateBasenameTemplate(util::string_view basename_template) { @@ -257,6 +264,8 @@ class WriteQueue { std::shared_ptr schema_; }; +} // namespace + Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options, std::shared_ptr scanner) { RETURN_NOT_OK(ValidateBasenameTemplate(write_options.basename_template)); diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 708f7e0205421..d058ac2c07788 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -268,18 +268,24 @@ class ARROW_DS_EXPORT FileWriter { Status Write(RecordBatchReader* batches); - virtual Status Finish() = 0; + Status Finish(); const std::shared_ptr& format() const { return options_->format(); } const std::shared_ptr& schema() const { return schema_; } const std::shared_ptr& 
options() const { return options_; } protected: - FileWriter(std::shared_ptr schema, std::shared_ptr options) - : schema_(std::move(schema)), options_(std::move(options)) {} + FileWriter(std::shared_ptr schema, std::shared_ptr options, + std::shared_ptr destination) + : schema_(std::move(schema)), + options_(std::move(options)), + destination_(destination) {} + + virtual Status FinishInternal() = 0; std::shared_ptr schema_; std::shared_ptr options_; + std::shared_ptr destination_; }; struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions { diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc index 8bd0121834483..b48b8c767cb3d 100644 --- a/cpp/src/arrow/dataset/file_ipc.cc +++ b/cpp/src/arrow/dataset/file_ipc.cc @@ -193,20 +193,22 @@ Result> IpcFileFormat::MakeWriter( ipc_options->metadata)); return std::shared_ptr( - new IpcFileWriter(std::move(writer), std::move(schema), std::move(ipc_options))); + new IpcFileWriter(std::move(destination), std::move(writer), std::move(schema), + std::move(ipc_options))); } -IpcFileWriter::IpcFileWriter(std::shared_ptr writer, +IpcFileWriter::IpcFileWriter(std::shared_ptr destination, + std::shared_ptr writer, std::shared_ptr schema, std::shared_ptr options) - : FileWriter(std::move(schema), std::move(options)), + : FileWriter(std::move(schema), std::move(options), std::move(destination)), batch_writer_(std::move(writer)) {} Status IpcFileWriter::Write(const std::shared_ptr& batch) { return batch_writer_->WriteRecordBatch(*batch); } -Status IpcFileWriter::Finish() { return batch_writer_->Close(); } +Status IpcFileWriter::FinishInternal() { return batch_writer_->Close(); } } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/file_ipc.h b/cpp/src/arrow/dataset/file_ipc.h index 2cdd837430e4f..35a760604088b 100644 --- a/cpp/src/arrow/dataset/file_ipc.h +++ b/cpp/src/arrow/dataset/file_ipc.h @@ -25,15 +25,10 @@ #include "arrow/dataset/file_base.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" +#include "arrow/ipc/type_fwd.h" #include "arrow/result.h" namespace arrow { -namespace ipc { - -class RecordBatchWriter; -struct IpcWriteOptions; - -} // namespace ipc namespace dataset { /// \brief A FileFormat implementation that reads from and writes to Ipc files @@ -82,13 +77,15 @@ class ARROW_DS_EXPORT IpcFileWriter : public FileWriter { public: Status Write(const std::shared_ptr& batch) override; - Status Finish() override; - private: - IpcFileWriter(std::shared_ptr writer, + IpcFileWriter(std::shared_ptr destination, + std::shared_ptr writer, std::shared_ptr schema, std::shared_ptr options); + Status FinishInternal() override; + + std::shared_ptr destination_; std::shared_ptr batch_writer_; friend class IpcFileFormat; diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index c26ad0490bad6..05bff2d1f5284 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -56,8 +56,7 @@ class ParquetScanTask : public ScanTask { ParquetScanTask(int row_group, std::vector column_projection, std::shared_ptr reader, std::shared_ptr pre_buffer_once, - std::vector pre_buffer_row_groups, - arrow::io::AsyncContext async_context, + std::vector pre_buffer_row_groups, arrow::io::IOContext io_context, arrow::io::CacheOptions cache_options, std::shared_ptr options, std::shared_ptr context) @@ -67,7 +66,7 @@ class ParquetScanTask : public ScanTask { reader_(std::move(reader)), pre_buffer_once_(std::move(pre_buffer_once)), 
pre_buffer_row_groups_(std::move(pre_buffer_row_groups)), - async_context_(async_context), + io_context_(io_context), cache_options_(cache_options) {} Result Execute() override { @@ -106,7 +105,7 @@ class ParquetScanTask : public ScanTask { BEGIN_PARQUET_CATCH_EXCEPTIONS std::call_once(*pre_buffer_once_, [this]() { reader_->parquet_reader()->PreBuffer(pre_buffer_row_groups_, column_projection_, - async_context_, cache_options_); + io_context_, cache_options_); }); END_PARQUET_CATCH_EXCEPTIONS } @@ -121,7 +120,7 @@ class ParquetScanTask : public ScanTask { // to be done. We assume all scan tasks have the same column projection. std::shared_ptr pre_buffer_once_; std::vector pre_buffer_row_groups_; - arrow::io::AsyncContext async_context_; + arrow::io::IOContext io_context_; arrow::io::CacheOptions cache_options_; }; @@ -362,7 +361,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptr( row_groups[i], column_projection, reader, pre_buffer_once, row_groups, - reader_options.async_context, reader_options.cache_options, options, context); + reader_options.io_context, reader_options.cache_options, options, context); } return MakeVectorIterator(std::move(tasks)); @@ -410,13 +409,14 @@ Result> ParquetFileFormat::MakeWriter( *schema, default_memory_pool(), destination, parquet_options->writer_properties, parquet_options->arrow_writer_properties, &parquet_writer)); - return std::shared_ptr( - new ParquetFileWriter(std::move(parquet_writer), std::move(parquet_options))); + return std::shared_ptr(new ParquetFileWriter( + std::move(destination), std::move(parquet_writer), std::move(parquet_options))); } -ParquetFileWriter::ParquetFileWriter(std::shared_ptr writer, +ParquetFileWriter::ParquetFileWriter(std::shared_ptr destination, + std::shared_ptr writer, std::shared_ptr options) - : FileWriter(writer->schema(), std::move(options)), + : FileWriter(writer->schema(), std::move(options), std::move(destination)), parquet_writer_(std::move(writer)) {} Status ParquetFileWriter::Write(const std::shared_ptr& batch) { @@ -424,7 +424,7 @@ Status ParquetFileWriter::Write(const std::shared_ptr& batch) { return parquet_writer_->WriteTable(*table, batch->num_rows()); } -Status ParquetFileWriter::Finish() { return parquet_writer_->Close(); } +Status ParquetFileWriter::FinishInternal() { return parquet_writer_->Close(); } // // ParquetFileFragment diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index 6967ab30669f3..ed0a6f949d627 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -97,7 +97,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { std::unordered_set dict_columns; bool pre_buffer = false; arrow::io::CacheOptions cache_options = arrow::io::CacheOptions::Defaults(); - arrow::io::AsyncContext async_context; + arrow::io::IOContext io_context; /// @} /// EXPERIMENTAL: Parallelize conversion across columns. 
This option is ignored if a @@ -226,12 +226,13 @@ class ARROW_DS_EXPORT ParquetFileWriter : public FileWriter { Status Write(const std::shared_ptr& batch) override; - Status Finish() override; - private: - ParquetFileWriter(std::shared_ptr writer, + ParquetFileWriter(std::shared_ptr destination, + std::shared_ptr writer, std::shared_ptr options); + Status FinishInternal() override; + std::shared_ptr parquet_writer_; friend class ParquetFileFormat; diff --git a/cpp/src/arrow/filesystem/filesystem.cc b/cpp/src/arrow/filesystem/filesystem.cc index 6945aa0646550..15441e7ae64b5 100644 --- a/cpp/src/arrow/filesystem/filesystem.cc +++ b/cpp/src/arrow/filesystem/filesystem.cc @@ -167,7 +167,9 @@ Result> FileSystem::OpenInputFile( SubTreeFileSystem::SubTreeFileSystem(const std::string& base_path, std::shared_ptr base_fs) - : base_path_(NormalizeBasePath(base_path, base_fs).ValueOrDie()), base_fs_(base_fs) {} + : FileSystem(base_fs->io_context()), + base_path_(NormalizeBasePath(base_path, base_fs).ValueOrDie()), + base_fs_(base_fs) {} SubTreeFileSystem::~SubTreeFileSystem() {} @@ -344,15 +346,19 @@ Result> SubTreeFileSystem::OpenAppendStream( SlowFileSystem::SlowFileSystem(std::shared_ptr base_fs, std::shared_ptr latencies) - : base_fs_(base_fs), latencies_(latencies) {} + : FileSystem(base_fs->io_context()), base_fs_(base_fs), latencies_(latencies) {} SlowFileSystem::SlowFileSystem(std::shared_ptr base_fs, double average_latency) - : base_fs_(base_fs), latencies_(io::LatencyGenerator::Make(average_latency)) {} + : FileSystem(base_fs->io_context()), + base_fs_(base_fs), + latencies_(io::LatencyGenerator::Make(average_latency)) {} SlowFileSystem::SlowFileSystem(std::shared_ptr base_fs, double average_latency, int32_t seed) - : base_fs_(base_fs), latencies_(io::LatencyGenerator::Make(average_latency, seed)) {} + : FileSystem(base_fs->io_context()), + base_fs_(base_fs), + latencies_(io::LatencyGenerator::Make(average_latency, seed)) {} bool SlowFileSystem::Equals(const FileSystem& other) const { return this == &other; } @@ -443,34 +449,37 @@ Result> SlowFileSystem::OpenAppendStream( } Status CopyFiles(const std::vector& sources, - const std::vector& destinations, int64_t chunk_size, - bool use_threads) { + const std::vector& destinations, + const io::IOContext& io_context, int64_t chunk_size, bool use_threads) { if (sources.size() != destinations.size()) { return Status::Invalid("Trying to copy ", sources.size(), " files into ", destinations.size(), " paths."); } - return ::arrow::internal::OptionalParallelFor( - use_threads, static_cast(sources.size()), [&](int i) { - if (sources[i].filesystem->Equals(destinations[i].filesystem)) { - return sources[i].filesystem->CopyFile(sources[i].path, destinations[i].path); - } + auto copy_one_file = [&](int i) { + if (sources[i].filesystem->Equals(destinations[i].filesystem)) { + return sources[i].filesystem->CopyFile(sources[i].path, destinations[i].path); + } + + ARROW_ASSIGN_OR_RAISE(auto source, + sources[i].filesystem->OpenInputStream(sources[i].path)); - ARROW_ASSIGN_OR_RAISE(auto source, - sources[i].filesystem->OpenInputStream(sources[i].path)); + ARROW_ASSIGN_OR_RAISE(auto destination, destinations[i].filesystem->OpenOutputStream( + destinations[i].path)); + RETURN_NOT_OK(internal::CopyStream(source, destination, chunk_size, io_context)); + return destination->Close(); + }; - ARROW_ASSIGN_OR_RAISE( - auto destination, - destinations[i].filesystem->OpenOutputStream(destinations[i].path)); - return internal::CopyStream(source, destination, chunk_size); - 
}); + return ::arrow::internal::OptionalParallelFor( + use_threads, static_cast(sources.size()), std::move(copy_one_file), + io_context.executor()); } Status CopyFiles(const std::shared_ptr& source_fs, const FileSelector& source_sel, const std::shared_ptr& destination_fs, - const std::string& destination_base_dir, int64_t chunk_size, - bool use_threads) { + const std::string& destination_base_dir, const io::IOContext& io_context, + int64_t chunk_size, bool use_threads) { ARROW_ASSIGN_OR_RAISE(auto source_infos, source_fs->GetFileInfo(source_sel)); if (source_infos.empty()) { return Status::OK(); @@ -497,12 +506,14 @@ Status CopyFiles(const std::shared_ptr& source_fs, } } + auto create_one_dir = [&](int i) { return destination_fs->CreateDir(dirs[i]); }; + dirs = internal::MinimalCreateDirSet(std::move(dirs)); RETURN_NOT_OK(::arrow::internal::OptionalParallelFor( - use_threads, static_cast(dirs.size()), - [&](int i) { return destination_fs->CreateDir(dirs[i]); })); + use_threads, static_cast(dirs.size()), std::move(create_one_dir), + io_context.executor())); - return CopyFiles(sources, destinations, chunk_size, use_threads); + return CopyFiles(sources, destinations, io_context, chunk_size, use_threads); } namespace { @@ -526,6 +537,7 @@ Result ParseFileSystemUri(const std::string& uri_string) { Result> FileSystemFromUriReal(const Uri& uri, const std::string& uri_string, + const io::IOContext& io_context, std::string* out_path) { const auto scheme = uri.scheme(); @@ -535,7 +547,7 @@ Result> FileSystemFromUriReal(const Uri& uri, if (out_path != nullptr) { *out_path = path; } - return std::make_shared(options); + return std::make_shared(options, io_context); } if (scheme == "hdfs" || scheme == "viewfs") { #ifdef ARROW_HDFS @@ -543,7 +555,7 @@ Result> FileSystemFromUriReal(const Uri& uri, if (out_path != nullptr) { *out_path = uri.path(); } - ARROW_ASSIGN_OR_RAISE(auto hdfs, HadoopFileSystem::Make(options)); + ARROW_ASSIGN_OR_RAISE(auto hdfs, HadoopFileSystem::Make(options, io_context)); return hdfs; #else return Status::NotImplemented("Got HDFS URI but Arrow compiled without HDFS support"); @@ -553,7 +565,7 @@ Result> FileSystemFromUriReal(const Uri& uri, #ifdef ARROW_S3 RETURN_NOT_OK(EnsureS3Initialized()); ARROW_ASSIGN_OR_RAISE(auto options, S3Options::FromUri(uri, out_path)); - ARROW_ASSIGN_OR_RAISE(auto s3fs, S3FileSystem::Make(options)); + ARROW_ASSIGN_OR_RAISE(auto s3fs, S3FileSystem::Make(options, io_context)); return s3fs; #else return Status::NotImplemented("Got S3 URI but Arrow compiled without S3 support"); @@ -566,7 +578,8 @@ Result> FileSystemFromUriReal(const Uri& uri, if (out_path != nullptr) { *out_path = std::string(RemoveLeadingSlash(uri.path())); } - return std::make_shared(internal::CurrentTimePoint()); + return std::make_shared(internal::CurrentTimePoint(), + io_context); } return Status::Invalid("Unrecognized filesystem type in URI: ", uri_string); @@ -576,12 +589,24 @@ Result> FileSystemFromUriReal(const Uri& uri, Result> FileSystemFromUri(const std::string& uri_string, std::string* out_path) { + return FileSystemFromUri(uri_string, io::default_io_context(), out_path); +} + +Result> FileSystemFromUri(const std::string& uri_string, + const io::IOContext& io_context, + std::string* out_path) { ARROW_ASSIGN_OR_RAISE(auto fsuri, ParseFileSystemUri(uri_string)); - return FileSystemFromUriReal(fsuri, uri_string, out_path); + return FileSystemFromUriReal(fsuri, uri_string, io_context, out_path); } Result> FileSystemFromUriOrPath(const std::string& uri_string, std::string* 
out_path) { + return FileSystemFromUriOrPath(uri_string, io::default_io_context(), out_path); +} + +Result> FileSystemFromUriOrPath( + const std::string& uri_string, const io::IOContext& io_context, + std::string* out_path) { if (internal::DetectAbsolutePath(uri_string)) { // Normalize path separators if (out_path != nullptr) { @@ -589,7 +614,7 @@ Result> FileSystemFromUriOrPath(const std::string& u } return std::make_shared(); } - return FileSystemFromUri(uri_string, out_path); + return FileSystemFromUri(uri_string, io_context, out_path); } Status FileSystemFromUri(const std::string& uri, std::shared_ptr* out_fs, diff --git a/cpp/src/arrow/filesystem/filesystem.h b/cpp/src/arrow/filesystem/filesystem.h index 9eeb3d86841a2..9d7aca9852964 100644 --- a/cpp/src/arrow/filesystem/filesystem.h +++ b/cpp/src/arrow/filesystem/filesystem.h @@ -26,7 +26,7 @@ #include #include "arrow/filesystem/type_fwd.h" -#include "arrow/io/type_fwd.h" +#include "arrow/io/interfaces.h" #include "arrow/type_fwd.h" #include "arrow/util/compare.h" #include "arrow/util/macros.h" @@ -147,6 +147,9 @@ class ARROW_EXPORT FileSystem : public std::enable_shared_from_this virtual std::string type_name() const = 0; + /// EXPERIMENTAL: The IOContext associated with this filesystem. + const io::IOContext& io_context() const { return io_context_; } + /// Normalize path for the given filesystem /// /// The default implementation of this method is a no-op, but subclasses @@ -250,6 +253,12 @@ class ARROW_EXPORT FileSystem : public std::enable_shared_from_this /// If the target doesn't exist, a new empty file is created. virtual Result> OpenAppendStream( const std::string& path) = 0; + + protected: + explicit FileSystem(const io::IOContext& io_context = io::default_io_context()) + : io_context_(io_context) {} + + io::IOContext io_context_; }; /// \brief A FileSystem implementation that delegates to another @@ -382,6 +391,19 @@ ARROW_EXPORT Result> FileSystemFromUri(const std::string& uri, std::string* out_path = NULLPTR); +/// \brief Create a new FileSystem by URI with a custom IO context +/// +/// Recognized schemes are "file", "mock", "hdfs" and "s3fs". +/// +/// \param[in] uri a URI-based path, ex: file:///some/local/path +/// \param[in] io_context an IOContext which will be associated with the filesystem +/// \param[out] out_path (optional) Path inside the filesystem. +/// \return out_fs FileSystem instance. +ARROW_EXPORT +Result> FileSystemFromUri(const std::string& uri, + const io::IOContext& io_context, + std::string* out_path = NULLPTR); + /// \brief Create a new FileSystem by URI /// /// Same as FileSystemFromUri, but in addition also recognize non-URIs @@ -391,6 +413,16 @@ ARROW_EXPORT Result> FileSystemFromUriOrPath( const std::string& uri, std::string* out_path = NULLPTR); +/// \brief Create a new FileSystem by URI with a custom IO context +/// +/// Same as FileSystemFromUri, but in addition also recognize non-URIs +/// and treat them as local filesystem paths. Only absolute local filesystem +/// paths are allowed. 
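+///
+/// \param[in] uri a URI-based path or an absolute local filesystem path
+/// \param[in] io_context an IOContext which will be associated with the filesystem
+/// \param[out] out_path (optional) Path inside the filesystem.
+/// \return out_fs FileSystem instance.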
+ARROW_EXPORT +Result> FileSystemFromUriOrPath( + const std::string& uri, const io::IOContext& io_context, + std::string* out_path = NULLPTR); + /// @} /// \brief Copy files, including from one FileSystem to another @@ -401,6 +433,7 @@ Result> FileSystemFromUriOrPath( ARROW_EXPORT Status CopyFiles(const std::vector& sources, const std::vector& destinations, + const io::IOContext& io_context = io::default_io_context(), int64_t chunk_size = 1024 * 1024, bool use_threads = true); /// \brief Copy selected files, including from one FileSystem to another @@ -411,6 +444,7 @@ Status CopyFiles(const std::shared_ptr& source_fs, const FileSelector& source_sel, const std::shared_ptr& destination_fs, const std::string& destination_base_dir, + const io::IOContext& io_context = io::default_io_context(), int64_t chunk_size = 1024 * 1024, bool use_threads = true); struct FileSystemGlobalOptions { diff --git a/cpp/src/arrow/filesystem/hdfs.cc b/cpp/src/arrow/filesystem/hdfs.cc index 1841bf1ff6c48..2fc549565cdce 100644 --- a/cpp/src/arrow/filesystem/hdfs.cc +++ b/cpp/src/arrow/filesystem/hdfs.cc @@ -43,7 +43,8 @@ using internal::RemoveLeadingSlash; class HadoopFileSystem::Impl { public: - explicit Impl(HdfsOptions options) : options_(std::move(options)) {} + Impl(HdfsOptions options, const io::IOContext& io_context) + : options_(std::move(options)), io_context_(io_context) {} ~Impl() { Status st = Close(); @@ -205,13 +206,13 @@ class HadoopFileSystem::Impl { Result> OpenInputStream(const std::string& path) { std::shared_ptr file; - RETURN_NOT_OK(client_->OpenReadable(path, &file)); + RETURN_NOT_OK(client_->OpenReadable(path, io_context_, &file)); return file; } Result> OpenInputFile(const std::string& path) { std::shared_ptr file; - RETURN_NOT_OK(client_->OpenReadable(path, &file)); + RETURN_NOT_OK(client_->OpenReadable(path, io_context_, &file)); return file; } @@ -226,7 +227,8 @@ class HadoopFileSystem::Impl { } protected: - HdfsOptions options_; + const HdfsOptions options_; + const io::IOContext io_context_; std::shared_ptr<::arrow::io::HadoopFileSystem> client_; void PathInfoToFileInfo(const io::HdfsPathInfo& info, FileInfo* out) { @@ -393,14 +395,15 @@ Result HdfsOptions::FromUri(const std::string& uri_string) { return FromUri(uri); } -HadoopFileSystem::HadoopFileSystem(const HdfsOptions& options) - : impl_(new Impl{options}) {} +HadoopFileSystem::HadoopFileSystem(const HdfsOptions& options, + const io::IOContext& io_context) + : FileSystem(io_context), impl_(new Impl{options, io_context_}) {} HadoopFileSystem::~HadoopFileSystem() {} Result> HadoopFileSystem::Make( - const HdfsOptions& options) { - std::shared_ptr ptr(new HadoopFileSystem(options)); + const HdfsOptions& options, const io::IOContext& io_context) { + std::shared_ptr ptr(new HadoopFileSystem(options, io_context)); RETURN_NOT_OK(ptr->impl_->Init()); return ptr; } diff --git a/cpp/src/arrow/filesystem/hdfs.h b/cpp/src/arrow/filesystem/hdfs.h index 5f6340e79b31b..72cb469b79d25 100644 --- a/cpp/src/arrow/filesystem/hdfs.h +++ b/cpp/src/arrow/filesystem/hdfs.h @@ -97,10 +97,11 @@ class ARROW_EXPORT HadoopFileSystem : public FileSystem { const std::string& path) override; /// Create a HdfsFileSystem instance from the given options. 
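+  ///
+  /// \param[in] options HDFS connection options
+  /// \param[in] io_context an IOContext to be associated with the filesystem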
- static Result> Make(const HdfsOptions& options); + static Result> Make( + const HdfsOptions& options, const io::IOContext& = io::default_io_context()); protected: - explicit HadoopFileSystem(const HdfsOptions& options); + HadoopFileSystem(const HdfsOptions& options, const io::IOContext&); class Impl; std::unique_ptr impl_; diff --git a/cpp/src/arrow/filesystem/localfs.cc b/cpp/src/arrow/filesystem/localfs.cc index 88ce46137ec85..490bacea41359 100644 --- a/cpp/src/arrow/filesystem/localfs.cc +++ b/cpp/src/arrow/filesystem/localfs.cc @@ -264,10 +264,12 @@ Result LocalFileSystemOptions::FromUri( return LocalFileSystemOptions(); } -LocalFileSystem::LocalFileSystem() : options_(LocalFileSystemOptions::Defaults()) {} +LocalFileSystem::LocalFileSystem(const io::IOContext& io_context) + : FileSystem(io_context), options_(LocalFileSystemOptions::Defaults()) {} -LocalFileSystem::LocalFileSystem(const LocalFileSystemOptions& options) - : options_(options) {} +LocalFileSystem::LocalFileSystem(const LocalFileSystemOptions& options, + const io::IOContext& io_context) + : FileSystem(io_context), options_(options) {} LocalFileSystem::~LocalFileSystem() {} @@ -378,7 +380,7 @@ Status LocalFileSystem::CopyFile(const std::string& src, const std::string& dest #else ARROW_ASSIGN_OR_RAISE(auto is, OpenInputStream(src)); ARROW_ASSIGN_OR_RAISE(auto os, OpenOutputStream(dest)); - RETURN_NOT_OK(internal::CopyStream(is, os, 1024 * 1024 /* chunk_size */)); + RETURN_NOT_OK(internal::CopyStream(is, os, 1024 * 1024 /* chunk_size */, io_context())); RETURN_NOT_OK(os->Close()); return is->Close(); #endif @@ -388,11 +390,12 @@ namespace { template Result> OpenInputStreamGeneric( - const std::string& path, const LocalFileSystemOptions& options) { + const std::string& path, const LocalFileSystemOptions& options, + const io::IOContext& io_context) { if (options.use_mmap) { return io::MemoryMappedFile::Open(path, io::FileMode::READ); } else { - return io::ReadableFile::Open(path); + return io::ReadableFile::Open(path, io_context.pool()); } } @@ -400,12 +403,12 @@ Result> OpenInputStreamGeneric( Result> LocalFileSystem::OpenInputStream( const std::string& path) { - return OpenInputStreamGeneric(path, options_); + return OpenInputStreamGeneric(path, options_, io_context()); } Result> LocalFileSystem::OpenInputFile( const std::string& path) { - return OpenInputStreamGeneric(path, options_); + return OpenInputStreamGeneric(path, options_, io_context()); } namespace { diff --git a/cpp/src/arrow/filesystem/localfs.h b/cpp/src/arrow/filesystem/localfs.h index add57c6d266d8..d660dd36a5d16 100644 --- a/cpp/src/arrow/filesystem/localfs.h +++ b/cpp/src/arrow/filesystem/localfs.h @@ -55,8 +55,9 @@ struct ARROW_EXPORT LocalFileSystemOptions { /// followed, except when deleting an entry). 
class ARROW_EXPORT LocalFileSystem : public FileSystem { public: - LocalFileSystem(); - explicit LocalFileSystem(const LocalFileSystemOptions&); + explicit LocalFileSystem(const io::IOContext& = io::default_io_context()); + explicit LocalFileSystem(const LocalFileSystemOptions&, + const io::IOContext& = io::default_io_context()); ~LocalFileSystem() override; std::string type_name() const override { return "local"; } diff --git a/cpp/src/arrow/filesystem/localfs_test.cc b/cpp/src/arrow/filesystem/localfs_test.cc index dbe19a1f46fac..e338160951a7b 100644 --- a/cpp/src/arrow/filesystem/localfs_test.cc +++ b/cpp/src/arrow/filesystem/localfs_test.cc @@ -73,6 +73,11 @@ Result> FSFromUri(const std::string& uri, return FileSystemFromUri(uri, out_path); } +Result> FSFromUriOrPath(const std::string& uri, + std::string* out_path = NULLPTR) { + return FileSystemFromUriOrPath(uri, out_path); +} + //////////////////////////////////////////////////////////////////////////// // Misc tests @@ -192,7 +197,7 @@ class TestLocalFS : public LocalFSTestMixin { } void TestFileSystemFromUriOrPath(const std::string& uri) { - CheckFileSystemFromUriFunc(uri, FileSystemFromUriOrPath); + CheckFileSystemFromUriFunc(uri, FSFromUriOrPath); } template @@ -213,7 +218,7 @@ class TestLocalFS : public LocalFSTestMixin { } void TestLocalUriOrPath(const std::string& uri, const std::string& expected_path) { - CheckLocalUri(uri, expected_path, FileSystemFromUriOrPath); + CheckLocalUri(uri, expected_path, FSFromUriOrPath); } void TestInvalidUri(const std::string& uri) { diff --git a/cpp/src/arrow/filesystem/mockfs.cc b/cpp/src/arrow/filesystem/mockfs.cc index 7cef8ac2b5468..294cc85531a99 100644 --- a/cpp/src/arrow/filesystem/mockfs.cc +++ b/cpp/src/arrow/filesystem/mockfs.cc @@ -25,12 +25,14 @@ #include #include "arrow/buffer.h" +#include "arrow/buffer_builder.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/path_util.h" #include "arrow/filesystem/util_internal.h" #include "arrow/io/interfaces.h" #include "arrow/io/memory.h" #include "arrow/util/logging.h" +#include "arrow/util/string_view.h" #include "arrow/util/variant.h" #include "arrow/util/windows_fixup.h" @@ -48,11 +50,19 @@ class Entry; struct File { TimePoint mtime; std::string name; - std::string data; + std::shared_ptr data; File(TimePoint mtime, std::string name) : mtime(mtime), name(std::move(name)) {} - int64_t size() const { return static_cast(data.length()); } + int64_t size() const { return data ? data->size() : 0; } + + explicit operator util::string_view() const { + if (data) { + return util::string_view(*data); + } else { + return ""; + } + } }; struct Directory { @@ -172,13 +182,17 @@ class Entry : public EntryBase { class MockFSOutputStream : public io::OutputStream { public: - explicit MockFSOutputStream(File* file) : file_(file), closed_(false) {} + MockFSOutputStream(File* file, MemoryPool* pool) + : file_(file), builder_(pool), closed_(false) {} ~MockFSOutputStream() override = default; // Implement the OutputStream interface Status Close() override { - closed_ = true; + if (!closed_) { + RETURN_NOT_OK(builder_.Finish(&file_->data)); + closed_ = true; + } return Status::OK(); } @@ -187,8 +201,8 @@ class MockFSOutputStream : public io::OutputStream { // MockFSOutputStream is mainly used for debugging and testing, so // mark an aborted file's contents explicitly. 
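+      // The partially written contents are replaced with a short diagnostic
+      // string recording how many bytes had been written before the abort.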
std::stringstream ss; - ss << "MockFSOutputStream aborted after " << file_->data.size() << " bytes written"; - file_->data = ss.str(); + ss << "MockFSOutputStream aborted after " << file_->size() << " bytes written"; + file_->data = Buffer::FromString(ss.str()); closed_ = true; } return Status::OK(); @@ -200,19 +214,19 @@ class MockFSOutputStream : public io::OutputStream { if (closed_) { return Status::Invalid("Invalid operation on closed stream"); } - return file_->size(); + return builder_.length(); } Status Write(const void* data, int64_t nbytes) override { if (closed_) { return Status::Invalid("Invalid operation on closed stream"); } - file_->data.append(reinterpret_cast(data), static_cast(nbytes)); - return Status::OK(); + return builder_.Append(data, nbytes); } protected: File* file_; + BufferBuilder builder_; bool closed_; }; @@ -234,12 +248,14 @@ std::ostream& operator<<(std::ostream& os, const MockFileInfo& di) { class MockFileSystem::Impl { public: TimePoint current_time; + MemoryPool* pool; + // The root directory Entry root; std::mutex mutex; - explicit Impl(TimePoint current_time) - : current_time(current_time), root(Directory("", current_time)) {} + Impl(TimePoint current_time, MemoryPool* pool) + : current_time(current_time), pool(pool), root(Directory("", current_time)) {} std::unique_lock lock_guard() { return std::unique_lock(mutex); @@ -333,7 +349,7 @@ class MockFileSystem::Impl { Entry* child = pair.second.get(); if (child->is_file()) { auto& file = child->as_file(); - out->push_back({path + file.name, file.mtime, file.data}); + out->push_back({path + file.name, file.mtime, util::string_view(file)}); } else if (child->is_dir()) { DumpFiles(path, child->as_dir(), out); } @@ -352,18 +368,22 @@ class MockFileSystem::Impl { // Find the file in the parent dir, or create it const auto& name = parts.back(); Entry* child = parent->as_dir().Find(name); + File* file; if (child == nullptr) { child = new Entry(File(current_time, name)); parent->as_dir().AssignEntry(name, std::unique_ptr(child)); + file = &child->as_file(); } else if (child->is_file()) { - child->as_file().mtime = current_time; - if (!append) { - child->as_file().data.clear(); - } + file = &child->as_file(); + file->mtime = current_time; } else { return NotAFile(path); } - return std::make_shared(&child->as_file()); + auto ptr = std::make_shared(file, pool); + if (append && file->data) { + RETURN_NOT_OK(ptr->Write(file->data->data(), file->data->size())); + } + return ptr; } Result> OpenInputReader(const std::string& path) { @@ -377,14 +397,19 @@ class MockFileSystem::Impl { if (!entry->is_file()) { return NotAFile(path); } - return std::make_shared(Buffer::FromString(entry->as_file().data)); + const auto& file = entry->as_file(); + if (file.data) { + return std::make_shared(file.data); + } else { + return std::make_shared(""); + } } }; MockFileSystem::~MockFileSystem() = default; -MockFileSystem::MockFileSystem(TimePoint current_time) { - impl_ = std::unique_ptr(new Impl(current_time)); +MockFileSystem::MockFileSystem(TimePoint current_time, const io::IOContext& io_context) { + impl_ = std::unique_ptr(new Impl(current_time, io_context.pool())); } bool MockFileSystem::Equals(const FileSystem& other) const { return this == &other; } @@ -689,7 +714,7 @@ std::vector MockFileSystem::AllFiles() { return result; } -Status MockFileSystem::CreateFile(const std::string& path, const std::string& contents, +Status MockFileSystem::CreateFile(const std::string& path, util::string_view contents, bool recursive) { auto parent = 
fs::internal::GetAbstractPathParent(path).first; diff --git a/cpp/src/arrow/filesystem/mockfs.h b/cpp/src/arrow/filesystem/mockfs.h index 847b4898ec71a..212caf6d7fed1 100644 --- a/cpp/src/arrow/filesystem/mockfs.h +++ b/cpp/src/arrow/filesystem/mockfs.h @@ -23,6 +23,7 @@ #include #include "arrow/filesystem/filesystem.h" +#include "arrow/util/string_view.h" #include "arrow/util/windows_fixup.h" namespace arrow { @@ -43,7 +44,7 @@ struct MockDirInfo { struct MockFileInfo { std::string full_path; TimePoint mtime; - std::string data; + util::string_view data; bool operator==(const MockFileInfo& other) const { return mtime == other.mtime && full_path == other.full_path && data == other.data; @@ -58,7 +59,8 @@ struct MockFileInfo { /// and bootstrapping FileSystem-based APIs. class ARROW_EXPORT MockFileSystem : public FileSystem { public: - explicit MockFileSystem(TimePoint current_time); + explicit MockFileSystem(TimePoint current_time, + const io::IOContext& = io::default_io_context()); ~MockFileSystem() override; std::string type_name() const override { return "mock"; } @@ -98,7 +100,7 @@ class ARROW_EXPORT MockFileSystem : public FileSystem { std::vector AllFiles(); // Create a File with a content from a string. - Status CreateFile(const std::string& path, const std::string& content, + Status CreateFile(const std::string& path, util::string_view content, bool recursive = true); // Create a MockFileSystem out of (empty) FileInfo. The content of every diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index cc8ae1148e07b..d71e8537250b9 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -631,8 +631,13 @@ Result GetObjectRange(Aws::S3::S3Client* client, class ObjectInputFile final : public io::RandomAccessFile { public: ObjectInputFile(std::shared_ptr fs, Aws::S3::S3Client* client, - const S3Path& path, int64_t size = kNoSize) - : fs_(std::move(fs)), client_(client), path_(path), content_length_(size) {} + const io::IOContext& io_context, const S3Path& path, + int64_t size = kNoSize) + : fs_(std::move(fs)), + client_(client), + io_context_(io_context), + path_(path), + content_length_(size) {} Status Init() { // Issue a HEAD Object to get the content-length and ensure any @@ -735,7 +740,7 @@ class ObjectInputFile final : public io::RandomAccessFile { // No need to allocate more than the remaining number of bytes nbytes = std::min(nbytes, content_length_ - position); - ARROW_ASSIGN_OR_RAISE(auto buf, AllocateResizableBuffer(nbytes)); + ARROW_ASSIGN_OR_RAISE(auto buf, AllocateResizableBuffer(nbytes, io_context_.pool())); if (nbytes > 0) { ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, ReadAt(position, nbytes, buf->mutable_data())); @@ -760,7 +765,9 @@ class ObjectInputFile final : public io::RandomAccessFile { protected: std::shared_ptr fs_; // Owner of S3Client Aws::S3::S3Client* client_; + const io::IOContext io_context_; S3Path path_; + bool closed_ = false; int64_t pos_ = 0; int64_t content_length_ = kNoSize; @@ -779,8 +786,13 @@ class ObjectOutputStream final : public io::OutputStream { public: ObjectOutputStream(std::shared_ptr fs, Aws::S3::S3Client* client, - const S3Path& path, const S3Options& options) - : fs_(std::move(fs)), client_(client), path_(path), options_(options) {} + const io::IOContext& io_context, const S3Path& path, + const S3Options& options) + : fs_(std::move(fs)), + client_(client), + io_context_(io_context), + path_(path), + options_(options) {} ~ObjectOutputStream() override { // For compliance with the rest 
of the IO stack, Close rather than Abort, @@ -910,8 +922,9 @@ class ObjectOutputStream final : public io::OutputStream { } // Can't upload data on its own, need to buffer it if (!current_part_) { - ARROW_ASSIGN_OR_RAISE(current_part_, - io::BufferOutputStream::Create(part_upload_threshold_)); + ARROW_ASSIGN_OR_RAISE( + current_part_, + io::BufferOutputStream::Create(part_upload_threshold_, io_context_.pool())); current_part_size_ = 0; } RETURN_NOT_OK(current_part_->Write(data, nbytes)); @@ -974,7 +987,7 @@ class ObjectOutputStream final : public io::OutputStream { // If the data isn't owned, make an immutable copy for the lifetime of the closure if (owned_buffer == nullptr) { - ARROW_ASSIGN_OR_RAISE(owned_buffer, AllocateBuffer(nbytes)); + ARROW_ASSIGN_OR_RAISE(owned_buffer, AllocateBuffer(nbytes, io_context_.pool())); memcpy(owned_buffer->mutable_data(), data, nbytes); } else { DCHECK_EQ(data, owned_buffer->data()); @@ -1048,8 +1061,10 @@ class ObjectOutputStream final : public io::OutputStream { protected: std::shared_ptr fs_; // Owner of S3Client Aws::S3::S3Client* client_; + const io::IOContext io_context_; S3Path path_; const S3Options& options_; + Aws::String upload_id_; bool closed_ = true; int64_t pos_ = 0; @@ -1586,8 +1601,8 @@ class S3FileSystem::Impl { ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s)); RETURN_NOT_OK(ValidateFilePath(path)); - auto ptr = - std::make_shared(fs->shared_from_this(), client_.get(), path); + auto ptr = std::make_shared(fs->shared_from_this(), client_.get(), + fs->io_context(), path); RETURN_NOT_OK(ptr->Init()); return ptr; } @@ -1605,20 +1620,22 @@ class S3FileSystem::Impl { RETURN_NOT_OK(ValidateFilePath(path)); auto ptr = std::make_shared(fs->shared_from_this(), client_.get(), - path, info.size()); + fs->io_context(), path, info.size()); RETURN_NOT_OK(ptr->Init()); return ptr; } }; -S3FileSystem::S3FileSystem(const S3Options& options) : impl_(new Impl{options}) {} +S3FileSystem::S3FileSystem(const S3Options& options, const io::IOContext& io_context) + : FileSystem(io_context), impl_(new Impl{options}) {} S3FileSystem::~S3FileSystem() {} -Result> S3FileSystem::Make(const S3Options& options) { +Result> S3FileSystem::Make( + const S3Options& options, const io::IOContext& io_context) { RETURN_NOT_OK(CheckS3Initialized()); - std::shared_ptr ptr(new S3FileSystem(options)); + std::shared_ptr ptr(new S3FileSystem(options, io_context)); RETURN_NOT_OK(ptr->impl_->Init()); return ptr; } @@ -1890,7 +1907,7 @@ Result> S3FileSystem::OpenOutputStream( RETURN_NOT_OK(ValidateFilePath(path)); auto ptr = std::make_shared( - shared_from_this(), impl_->client_.get(), path, impl_->options()); + shared_from_this(), impl_->client_.get(), io_context(), path, impl_->options()); RETURN_NOT_OK(ptr->Init()); return ptr; } diff --git a/cpp/src/arrow/filesystem/s3fs.h b/cpp/src/arrow/filesystem/s3fs.h index bd8f1fadbb9e7..ac384fcba71f3 100644 --- a/cpp/src/arrow/filesystem/s3fs.h +++ b/cpp/src/arrow/filesystem/s3fs.h @@ -199,10 +199,11 @@ class ARROW_EXPORT S3FileSystem : public FileSystem { const std::string& path) override; /// Create a S3FileSystem instance from the given options. 
- static Result> Make(const S3Options& options); + static Result> Make( + const S3Options& options, const io::IOContext& = io::default_io_context()); protected: - explicit S3FileSystem(const S3Options& options); + explicit S3FileSystem(const S3Options& options, const io::IOContext&); class Impl; std::unique_ptr impl_; diff --git a/cpp/src/arrow/filesystem/util_internal.cc b/cpp/src/arrow/filesystem/util_internal.cc index a9c6a1c212088..8f86707375d43 100644 --- a/cpp/src/arrow/filesystem/util_internal.cc +++ b/cpp/src/arrow/filesystem/util_internal.cc @@ -31,8 +31,9 @@ TimePoint CurrentTimePoint() { } Status CopyStream(const std::shared_ptr& src, - const std::shared_ptr& dest, int64_t chunk_size) { - ARROW_ASSIGN_OR_RAISE(auto chunk, AllocateBuffer(chunk_size)); + const std::shared_ptr& dest, int64_t chunk_size, + const io::IOContext& io_context) { + ARROW_ASSIGN_OR_RAISE(auto chunk, AllocateBuffer(chunk_size, io_context.pool())); while (true) { ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, diff --git a/cpp/src/arrow/filesystem/util_internal.h b/cpp/src/arrow/filesystem/util_internal.h index ffcb2b2a82e7e..915c8d03d4644 100644 --- a/cpp/src/arrow/filesystem/util_internal.h +++ b/cpp/src/arrow/filesystem/util_internal.h @@ -34,7 +34,8 @@ TimePoint CurrentTimePoint(); ARROW_EXPORT Status CopyStream(const std::shared_ptr& src, - const std::shared_ptr& dest, int64_t chunk_size); + const std::shared_ptr& dest, int64_t chunk_size, + const io::IOContext& io_context); ARROW_EXPORT Status PathNotFound(const std::string& path); diff --git a/cpp/src/arrow/io/caching.cc b/cpp/src/arrow/io/caching.cc index a306ca7d28603..31a426dffd7d1 100644 --- a/cpp/src/arrow/io/caching.cc +++ b/cpp/src/arrow/io/caching.cc @@ -131,7 +131,7 @@ struct RangeCacheEntry { struct ReadRangeCache::Impl { std::shared_ptr file; - AsyncContext ctx; + IOContext ctx; CacheOptions options; // Ordered by offset (so as to find a matching region by binary search) @@ -150,7 +150,7 @@ struct ReadRangeCache::Impl { } }; -ReadRangeCache::ReadRangeCache(std::shared_ptr file, AsyncContext ctx, +ReadRangeCache::ReadRangeCache(std::shared_ptr file, IOContext ctx, CacheOptions options) : impl_(new Impl()) { impl_->file = std::move(file); diff --git a/cpp/src/arrow/io/caching.h b/cpp/src/arrow/io/caching.h index fd2a652369eec..089c1c554dc3d 100644 --- a/cpp/src/arrow/io/caching.h +++ b/cpp/src/arrow/io/caching.h @@ -82,11 +82,11 @@ class ARROW_EXPORT ReadRangeCache { static constexpr int64_t kDefaultRangeSizeLimit = 32 * 1024 * 1024; /// Construct a read cache with default - explicit ReadRangeCache(std::shared_ptr file, AsyncContext ctx) + explicit ReadRangeCache(std::shared_ptr file, IOContext ctx) : ReadRangeCache(file, std::move(ctx), CacheOptions::Defaults()) {} /// Construct a read cache with given options - explicit ReadRangeCache(std::shared_ptr file, AsyncContext ctx, + explicit ReadRangeCache(std::shared_ptr file, IOContext ctx, CacheOptions options); ~ReadRangeCache(); diff --git a/cpp/src/arrow/io/file.cc b/cpp/src/arrow/io/file.cc index dbc2af89ea01a..8a4976db4aa61 100644 --- a/cpp/src/arrow/io/file.cc +++ b/cpp/src/arrow/io/file.cc @@ -699,7 +699,7 @@ Result> MemoryMappedFile::Read(int64_t nbytes) { return buffer; } -Future> MemoryMappedFile::ReadAsync(const AsyncContext&, +Future> MemoryMappedFile::ReadAsync(const IOContext&, int64_t position, int64_t nbytes) { return Future>::MakeFinished(ReadAt(position, nbytes)); diff --git a/cpp/src/arrow/io/file.h b/cpp/src/arrow/io/file.h index 87bb8b9f81c27..50d4f2c4dfc90 100644 --- 
a/cpp/src/arrow/io/file.h +++ b/cpp/src/arrow/io/file.h @@ -185,7 +185,7 @@ class ARROW_EXPORT MemoryMappedFile : public ReadWriteFileInterface { Result ReadAt(int64_t position, int64_t nbytes, void* out) override; // Synchronous ReadAsync override - Future> ReadAsync(const AsyncContext&, int64_t position, + Future> ReadAsync(const IOContext&, int64_t position, int64_t nbytes) override; Status WillNeed(const std::vector& ranges) override; diff --git a/cpp/src/arrow/io/hdfs.cc b/cpp/src/arrow/io/hdfs.cc index f4c391d3de10e..af91b35ed3c5c 100644 --- a/cpp/src/arrow/io/hdfs.cc +++ b/cpp/src/arrow/io/hdfs.cc @@ -222,11 +222,8 @@ class HdfsReadableFile::HdfsReadableFileImpl : public HdfsAnyFileImpl { int32_t buffer_size_; }; -HdfsReadableFile::HdfsReadableFile(MemoryPool* pool) { - if (pool == nullptr) { - pool = default_memory_pool(); - } - impl_.reset(new HdfsReadableFileImpl(pool)); +HdfsReadableFile::HdfsReadableFile(const io::IOContext& io_context) { + impl_.reset(new HdfsReadableFileImpl(io_context.pool())); } HdfsReadableFile::~HdfsReadableFile() { DCHECK_OK(impl_->Close()); } @@ -498,6 +495,7 @@ class HadoopFileSystem::HadoopFileSystemImpl { } Status OpenReadable(const std::string& path, int32_t buffer_size, + const io::IOContext& io_context, std::shared_ptr* file) { hdfsFile handle = driver_->OpenFile(fs_, path.c_str(), O_RDONLY, buffer_size, 0, 0); @@ -508,7 +506,7 @@ class HadoopFileSystem::HadoopFileSystemImpl { } // std::make_shared does not work with private ctors - *file = std::shared_ptr(new HdfsReadableFile()); + *file = std::shared_ptr(new HdfsReadableFile(io_context)); (*file)->impl_->set_members(path, driver_, fs_, handle); (*file)->impl_->set_buffer_size(buffer_size); @@ -627,12 +625,24 @@ Status HadoopFileSystem::ListDirectory(const std::string& path, Status HadoopFileSystem::OpenReadable(const std::string& path, int32_t buffer_size, std::shared_ptr* file) { - return impl_->OpenReadable(path, buffer_size, file); + return impl_->OpenReadable(path, buffer_size, io::default_io_context(), file); +} + +Status HadoopFileSystem::OpenReadable(const std::string& path, + std::shared_ptr* file) { + return OpenReadable(path, kDefaultHdfsBufferSize, io::default_io_context(), file); +} + +Status HadoopFileSystem::OpenReadable(const std::string& path, int32_t buffer_size, + const io::IOContext& io_context, + std::shared_ptr* file) { + return impl_->OpenReadable(path, buffer_size, io_context, file); } Status HadoopFileSystem::OpenReadable(const std::string& path, + const io::IOContext& io_context, std::shared_ptr* file) { - return OpenReadable(path, kDefaultHdfsBufferSize, file); + return OpenReadable(path, kDefaultHdfsBufferSize, io_context, file); } Status HadoopFileSystem::OpenWritable(const std::string& path, bool append, diff --git a/cpp/src/arrow/io/hdfs.h b/cpp/src/arrow/io/hdfs.h index f91dfb618e6c0..21b0cd8a282f6 100644 --- a/cpp/src/arrow/io/hdfs.h +++ b/cpp/src/arrow/io/hdfs.h @@ -184,8 +184,15 @@ class ARROW_EXPORT HadoopFileSystem : public FileSystem { Status OpenReadable(const std::string& path, int32_t buffer_size, std::shared_ptr* file); + Status OpenReadable(const std::string& path, int32_t buffer_size, + const io::IOContext& io_context, + std::shared_ptr* file); + Status OpenReadable(const std::string& path, std::shared_ptr* file); + Status OpenReadable(const std::string& path, const io::IOContext& io_context, + std::shared_ptr* file); + // FileMode::WRITE options // @param path complete file path // @param buffer_size 0 by default @@ -228,10 +235,8 @@ class ARROW_EXPORT 
HdfsReadableFile : public RandomAccessFile { Result Tell() const override; Result GetSize() override; - void set_memory_pool(MemoryPool* pool); - private: - explicit HdfsReadableFile(MemoryPool* pool = NULLPTR); + explicit HdfsReadableFile(const io::IOContext&); class ARROW_NO_EXPORT HdfsReadableFileImpl; std::unique_ptr impl_; diff --git a/cpp/src/arrow/io/interfaces.cc b/cpp/src/arrow/io/interfaces.cc index 309d487c52c76..22abbb27bce06 100644 --- a/cpp/src/arrow/io/interfaces.cc +++ b/cpp/src/arrow/io/interfaces.cc @@ -46,18 +46,22 @@ using internal::ThreadPool; namespace io { -AsyncContext::AsyncContext() : AsyncContext(internal::GetIOThreadPool()) {} +static IOContext g_default_io_context{}; -AsyncContext::AsyncContext(Executor* executor) : executor(executor) {} +IOContext::IOContext(MemoryPool* pool) : IOContext(pool, internal::GetIOThreadPool()) {} + +const IOContext& default_io_context() { return g_default_io_context; } FileInterface::~FileInterface() = default; Status FileInterface::Abort() { return Close(); } +namespace { + class InputStreamBlockIterator { public: InputStreamBlockIterator(std::shared_ptr stream, int64_t block_size) - : stream_(stream), block_size_(block_size) {} + : stream_(std::move(stream)), block_size_(block_size) {} Result> Next() { if (done_) { @@ -81,6 +85,10 @@ class InputStreamBlockIterator { bool done_ = false; }; +} // namespace + +const IOContext& Readable::io_context() const { return g_default_io_context; } + Status InputStream::Advance(int64_t nbytes) { return Read(nbytes).status(); } Result InputStream::Peek(int64_t ARROW_ARG_UNUSED(nbytes)) { @@ -98,14 +106,13 @@ Result>> MakeInputStreamIterator( return Iterator>(InputStreamBlockIterator(stream, block_size)); } -struct RandomAccessFile::RandomAccessFileImpl { +struct RandomAccessFile::Impl { std::mutex lock_; }; RandomAccessFile::~RandomAccessFile() = default; -RandomAccessFile::RandomAccessFile() - : interface_impl_(new RandomAccessFile::RandomAccessFileImpl()) {} +RandomAccessFile::RandomAccessFile() : interface_impl_(new Impl()) {} Result RandomAccessFile::ReadAt(int64_t position, int64_t nbytes, void* out) { std::lock_guard lock(interface_impl_->lock_); @@ -121,25 +128,30 @@ Result> RandomAccessFile::ReadAt(int64_t position, } // Default ReadAsync() implementation: simply issue the read on the context's executor -Future> RandomAccessFile::ReadAsync(const AsyncContext& ctx, +Future> RandomAccessFile::ReadAsync(const IOContext& ctx, int64_t position, int64_t nbytes) { auto self = shared_from_this(); TaskHints hints; hints.io_size = nbytes; - hints.external_id = ctx.external_id; - return DeferNotOk(ctx.executor->Submit(std::move(hints), [self, position, nbytes] { + hints.external_id = ctx.external_id(); + return DeferNotOk(ctx.executor()->Submit(std::move(hints), [self, position, nbytes] { return self->ReadAt(position, nbytes); })); } +Future> RandomAccessFile::ReadAsync(int64_t position, + int64_t nbytes) { + return ReadAsync(io_context(), position, nbytes); +} + // Default WillNeed() implementation: no-op Status RandomAccessFile::WillNeed(const std::vector& ranges) { return Status::OK(); } -Status Writable::Write(const std::string& data) { - return Write(data.c_str(), static_cast(data.size())); +Status Writable::Write(util::string_view data) { + return Write(data.data(), static_cast(data.size())); } Status Writable::Write(const std::shared_ptr& data) { diff --git a/cpp/src/arrow/io/interfaces.h b/cpp/src/arrow/io/interfaces.h index b5a1f1220f60e..07c01324ea1fb 100644 --- 
a/cpp/src/arrow/io/interfaces.h +++ b/cpp/src/arrow/io/interfaces.h @@ -48,15 +48,41 @@ struct ReadRange { } }; -// EXPERIMENTAL -struct ARROW_EXPORT AsyncContext { - ::arrow::internal::Executor* executor; +/// EXPERIMENTAL: options provider for IO tasks +/// +/// Includes an Executor (which will be used to execute asynchronous reads), +/// a MemoryPool (which will be used to allocate buffers when zero copy reads +/// are not possible), and an external id (in case the executor receives tasks from +/// multiple sources and must distinguish tasks associated with this IOContext). +struct ARROW_EXPORT IOContext { + // No specified executor: will use a global IO thread pool + IOContext() : IOContext(default_memory_pool()) {} + + // No specified executor: will use a global IO thread pool + explicit IOContext(MemoryPool* pool); + + explicit IOContext(MemoryPool* pool, ::arrow::internal::Executor* executor, + int64_t external_id = -1) + : pool_(pool), executor_(executor), external_id_(external_id) {} + + explicit IOContext(::arrow::internal::Executor* executor, int64_t external_id = -1) + : pool_(default_memory_pool()), executor_(executor), external_id_(external_id) {} + + MemoryPool* pool() const { return pool_; } + + ::arrow::internal::Executor* executor() const { return executor_; } + // An application-specific ID, forwarded to executor task submissions - int64_t external_id = -1; + int64_t external_id() const { return external_id_; } - // Set `executor` to a global IO-specific thread pool. - AsyncContext(); - explicit AsyncContext(::arrow::internal::Executor* executor); + private: + MemoryPool* pool_; + ::arrow::internal::Executor* executor_; + int64_t external_id_; +}; + +struct ARROW_DEPRECATED("renamed to IOContext in 4.0.0") AsyncContext : public IOContext { + using IOContext::IOContext; }; class ARROW_EXPORT FileInterface { @@ -127,7 +153,7 @@ class ARROW_EXPORT Writable { /// \brief Flush buffered bytes, if any virtual Status Flush(); - Status Write(const std::string& data); + Status Write(util::string_view data); }; class ARROW_EXPORT Readable { @@ -148,6 +174,12 @@ class ARROW_EXPORT Readable { /// In some cases (e.g. a memory-mapped file), this method may avoid a /// memory copy. virtual Result> Read(int64_t nbytes) = 0; + + /// EXPERIMENTAL: The IOContext associated with this file. + /// + /// By default, this is the same as default_io_context(), but it may be + /// overriden by subclasses. + virtual const IOContext& io_context() const; }; class ARROW_EXPORT OutputStream : virtual public FileInterface, public Writable { @@ -234,9 +266,12 @@ class ARROW_EXPORT RandomAccessFile virtual Result> ReadAt(int64_t position, int64_t nbytes); /// EXPERIMENTAL: Read data asynchronously. - virtual Future> ReadAsync(const AsyncContext&, int64_t position, + virtual Future> ReadAsync(const IOContext&, int64_t position, int64_t nbytes); + /// EXPERIMENTAL: Read data asynchronously, using the file's IOContext. + Future> ReadAsync(int64_t position, int64_t nbytes); + /// EXPERIMENTAL: Inform that the given ranges may be read soon. /// /// Some implementations might arrange to prefetch some of the data. 
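The new `IOContext` above bundles a `MemoryPool` (used when buffers must be allocated), an `Executor` (used to run asynchronous reads), and an optional external id, and the filesystem constructors in this patch now accept one. A minimal usage sketch, not part of the patch itself, assuming only the `LocalFileSystem` constructor and the context-free `ReadAsync` overload introduced here:

    // Sketch: thread a custom IOContext through a filesystem and an async read.
    #include <memory>
    #include <string>

    #include "arrow/filesystem/localfs.h"
    #include "arrow/io/interfaces.h"
    #include "arrow/memory_pool.h"
    #include "arrow/result.h"
    #include "arrow/status.h"

    arrow::Status ReadHead(const std::string& path) {
      // Buffers allocated on behalf of this context come from the given pool;
      // reads are scheduled on the default IO executor since none is specified.
      arrow::io::IOContext io_context(arrow::default_memory_pool());

      auto fs = std::make_shared<arrow::fs::LocalFileSystem>(io_context);
      ARROW_ASSIGN_OR_RAISE(auto file, fs->OpenInputFile(path));

      // The context-free overload uses the file's own io_context().
      auto fut = file->ReadAsync(/*position=*/0, /*nbytes=*/4096);
      return fut.status();  // wait for the read to finish and surface any error
    }

Passing the context at construction time lets the pool and executor flow through to every buffer allocation and asynchronous read issued by the filesystem, instead of being picked per call.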
@@ -248,8 +283,8 @@ class ARROW_EXPORT RandomAccessFile RandomAccessFile(); private: - struct ARROW_NO_EXPORT RandomAccessFileImpl; - std::unique_ptr interface_impl_; + struct ARROW_NO_EXPORT Impl; + std::unique_ptr interface_impl_; }; class ARROW_EXPORT WritableFile : public OutputStream, public Seekable { diff --git a/cpp/src/arrow/io/memory.cc b/cpp/src/arrow/io/memory.cc index 1ac435ab64297..a953c8f28a7a2 100644 --- a/cpp/src/arrow/io/memory.cc +++ b/cpp/src/arrow/io/memory.cc @@ -320,7 +320,7 @@ Status BufferReader::WillNeed(const std::vector& ranges) { return st; } -Future> BufferReader::ReadAsync(const AsyncContext&, +Future> BufferReader::ReadAsync(const IOContext&, int64_t position, int64_t nbytes) { return Future>::MakeFinished(DoReadAt(position, nbytes)); diff --git a/cpp/src/arrow/io/memory.h b/cpp/src/arrow/io/memory.h index 075398a180bbd..bfebe9945f83f 100644 --- a/cpp/src/arrow/io/memory.h +++ b/cpp/src/arrow/io/memory.h @@ -160,7 +160,7 @@ class ARROW_EXPORT BufferReader std::shared_ptr buffer() const { return buffer_; } // Synchronous ReadAsync override - Future> ReadAsync(const AsyncContext&, int64_t position, + Future> ReadAsync(const IOContext&, int64_t position, int64_t nbytes) override; Status WillNeed(const std::vector& ranges) override; diff --git a/cpp/src/arrow/io/type_fwd.h b/cpp/src/arrow/io/type_fwd.h index 130ced9db673f..041b825c98834 100644 --- a/cpp/src/arrow/io/type_fwd.h +++ b/cpp/src/arrow/io/type_fwd.h @@ -17,6 +17,8 @@ #pragma once +#include "arrow/util/visibility.h" + namespace arrow { namespace io { @@ -24,6 +26,12 @@ struct FileMode { enum type { READ, WRITE, READWRITE }; }; +struct IOContext; + +/// EXPERIMENTAL: convenience global singleton for default IOContext settings +ARROW_EXPORT +const IOContext& default_io_context(); + class FileInterface; class Seekable; class Writable; diff --git a/cpp/src/arrow/ipc/type_fwd.h b/cpp/src/arrow/ipc/type_fwd.h index bef9776c6a0c4..d3f5c5b82e4ad 100644 --- a/cpp/src/arrow/ipc/type_fwd.h +++ b/cpp/src/arrow/ipc/type_fwd.h @@ -47,6 +47,9 @@ enum class MessageType { SPARSE_TENSOR }; +struct IpcReadOptions; +struct IpcWriteOptions; + class MessageReader; class RecordBatchStreamReader; diff --git a/cpp/src/arrow/util/parallel.h b/cpp/src/arrow/util/parallel.h index e2c87a534a6d2..e56a71b91af72 100644 --- a/cpp/src/arrow/util/parallel.h +++ b/cpp/src/arrow/util/parallel.h @@ -30,12 +30,12 @@ namespace internal { // arguments between 0 and `num_tasks - 1`, on an arbitrary number of threads. template -Status ParallelFor(int num_tasks, FUNCTION&& func) { - auto pool = internal::GetCpuThreadPool(); +Status ParallelFor(int num_tasks, FUNCTION&& func, + Executor* executor = internal::GetCpuThreadPool()) { std::vector> futures(num_tasks); for (int i = 0; i < num_tasks; ++i) { - ARROW_ASSIGN_OR_RAISE(futures[i], pool->Submit(func, i)); + ARROW_ASSIGN_OR_RAISE(futures[i], executor->Submit(func, i)); } auto st = Status::OK(); for (auto& fut : futures) { @@ -49,9 +49,10 @@ Status ParallelFor(int num_tasks, FUNCTION&& func) { // depending on the input boolean. 
template -Status OptionalParallelFor(bool use_threads, int num_tasks, FUNCTION&& func) { +Status OptionalParallelFor(bool use_threads, int num_tasks, FUNCTION&& func, + Executor* executor = internal::GetCpuThreadPool()) { if (use_threads) { - return ParallelFor(num_tasks, std::forward(func)); + return ParallelFor(num_tasks, std::forward(func), executor); } else { for (int i = 0; i < num_tasks; ++i) { RETURN_NOT_OK(func(i)); diff --git a/cpp/src/parquet/arrow/reader.cc b/cpp/src/parquet/arrow/reader.cc index e784d39101619..1e66d5c52c011 100644 --- a/cpp/src/parquet/arrow/reader.cc +++ b/cpp/src/parquet/arrow/reader.cc @@ -890,7 +890,7 @@ Status FileReaderImpl::GetRecordBatchReader(const std::vector& row_groups, if (reader_properties_.pre_buffer()) { // PARQUET-1698/PARQUET-1820: pre-buffer row groups/column chunks if enabled BEGIN_PARQUET_CATCH_EXCEPTIONS - reader_->PreBuffer(row_groups, column_indices, reader_properties_.async_context(), + reader_->PreBuffer(row_groups, column_indices, reader_properties_.io_context(), reader_properties_.cache_options()); END_PARQUET_CATCH_EXCEPTIONS } @@ -990,7 +990,7 @@ Status FileReaderImpl::ReadRowGroups(const std::vector& row_groups, if (reader_properties_.pre_buffer()) { BEGIN_PARQUET_CATCH_EXCEPTIONS parquet_reader()->PreBuffer(row_groups, column_indices, - reader_properties_.async_context(), + reader_properties_.io_context(), reader_properties_.cache_options()); END_PARQUET_CATCH_EXCEPTIONS } diff --git a/cpp/src/parquet/file_reader.cc b/cpp/src/parquet/file_reader.cc index 39ef337c3eb18..730f5b9fb9b92 100644 --- a/cpp/src/parquet/file_reader.cc +++ b/cpp/src/parquet/file_reader.cc @@ -252,7 +252,7 @@ class SerializedFile : public ParquetFileReader::Contents { void PreBuffer(const std::vector& row_groups, const std::vector& column_indices, - const ::arrow::io::AsyncContext& ctx, + const ::arrow::io::IOContext& ctx, const ::arrow::io::CacheOptions& options) { cached_source_ = std::make_shared<::arrow::io::internal::ReadRangeCache>(source_, ctx, options); @@ -595,7 +595,7 @@ std::shared_ptr ParquetFileReader::RowGroup(int i) { void ParquetFileReader::PreBuffer(const std::vector& row_groups, const std::vector& column_indices, - const ::arrow::io::AsyncContext& ctx, + const ::arrow::io::IOContext& ctx, const ::arrow::io::CacheOptions& options) { // Access private methods here SerializedFile* file = diff --git a/cpp/src/parquet/file_reader.h b/cpp/src/parquet/file_reader.h index 79af3cd2b35ae..12c2878329184 100644 --- a/cpp/src/parquet/file_reader.h +++ b/cpp/src/parquet/file_reader.h @@ -138,7 +138,7 @@ class PARQUET_EXPORT ParquetFileReader { /// only one row group at a time may be useful. void PreBuffer(const std::vector& row_groups, const std::vector& column_indices, - const ::arrow::io::AsyncContext& ctx, + const ::arrow::io::IOContext& ctx, const ::arrow::io::CacheOptions& options); private: diff --git a/cpp/src/parquet/properties.h b/cpp/src/parquet/properties.h index f042201212217..be17f447a381a 100644 --- a/cpp/src/parquet/properties.h +++ b/cpp/src/parquet/properties.h @@ -616,16 +616,16 @@ class PARQUET_EXPORT ArrowReaderProperties { ::arrow::io::CacheOptions cache_options() const { return cache_options_; } /// Set execution context for read coalescing. 
- void set_async_context(::arrow::io::AsyncContext ctx) { async_context_ = ctx; } + void set_io_context(const ::arrow::io::IOContext& ctx) { io_context_ = ctx; } - ::arrow::io::AsyncContext async_context() const { return async_context_; } + const ::arrow::io::IOContext& io_context() const { return io_context_; } private: bool use_threads_; std::unordered_set read_dict_indices_; int64_t batch_size_; bool pre_buffer_; - ::arrow::io::AsyncContext async_context_; + ::arrow::io::IOContext io_context_; ::arrow::io::CacheOptions cache_options_; }; diff --git a/docs/source/cpp/csv.rst b/docs/source/cpp/csv.rst index 44dc1498f18a2..f8a508c3f946f 100644 --- a/docs/source/cpp/csv.rst +++ b/docs/source/cpp/csv.rst @@ -41,8 +41,7 @@ A CSV file is read from a :class:`~arrow::io::InputStream`. { // ... - arrow::MemoryPool* pool = default_memory_pool(); - arrow::io::AsyncContext async_context; + arrow::io::IOContext io_context = arrow::io::default_io_context(); std::shared_ptr input = ...; auto read_options = arrow::csv::ReadOptions::Defaults(); @@ -51,8 +50,7 @@ A CSV file is read from a :class:`~arrow::io::InputStream`. // Instantiate TableReader from input stream and options auto maybe_reader = - arrow::csv::TableReader::Make(pool, - async_context, + arrow::csv::TableReader::Make(io_context, input, read_options, parse_options, diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index f5b8e4d5fbac6..2f09a82c20ef7 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -703,7 +703,6 @@ def read_csv(input_file, read_options=None, parse_options=None, CCSVConvertOptions c_convert_options shared_ptr[CCSVReader] reader shared_ptr[CTable] table - CAsyncContext c_async_ctx = CAsyncContext() _get_reader(input_file, read_options, &stream) _get_read_options(read_options, &c_read_options) @@ -711,8 +710,8 @@ def read_csv(input_file, read_options=None, parse_options=None, _get_convert_options(convert_options, &c_convert_options) reader = GetResultValue(CCSVReader.Make( - maybe_unbox_memory_pool(memory_pool), c_async_ctx, stream, - c_read_options, c_parse_options, c_convert_options)) + CIOContext(maybe_unbox_memory_pool(memory_pool)), + stream, c_read_options, c_parse_options, c_convert_options)) with nogil: table = GetResultValue(reader.get().Read()) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index a4f6f18628402..046d19892025a 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1140,8 +1140,11 @@ cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil: ObjectType_FILE" arrow::io::ObjectType::FILE" ObjectType_DIRECTORY" arrow::io::ObjectType::DIRECTORY" - cdef cppclass CAsyncContext" arrow::io::AsyncContext": - CAsyncContext() + cdef cppclass CIOContext" arrow::io::IOContext": + CIOContext() + CIOContext(CMemoryPool*) + + CIOContext c_default_io_context "arrow::io::default_io_context"() cdef cppclass FileStatistics: int64_t size @@ -1628,7 +1631,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: cdef cppclass CCSVReader" arrow::csv::TableReader": @staticmethod CResult[shared_ptr[CCSVReader]] Make( - CMemoryPool*, CAsyncContext, shared_ptr[CInputStream], + CIOContext, shared_ptr[CInputStream], CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions) CResult[shared_ptr[CTable]] Read() diff --git a/r/src/csv.cpp b/r/src/csv.cpp index 69b834a6be003..0ce4cd699f853 100644 --- a/r/src/csv.cpp +++ b/r/src/csv.cpp @@ -141,9 +141,9 @@ std::shared_ptr csv___TableReader__Make( const 
std::shared_ptr& read_options, const std::shared_ptr& parse_options, const std::shared_ptr& convert_options) { - return ValueOrStop( - arrow::csv::TableReader::Make(gc_memory_pool(), arrow::io::AsyncContext(), input, - *read_options, *parse_options, *convert_options)); + return ValueOrStop(arrow::csv::TableReader::Make(arrow::io::IOContext(gc_memory_pool()), + input, *read_options, *parse_options, + *convert_options)); } // [[arrow::export]] diff --git a/r/src/filesystem.cpp b/r/src/filesystem.cpp index 066e5b540f278..fced8abdd07b1 100644 --- a/r/src/filesystem.cpp +++ b/r/src/filesystem.cpp @@ -23,6 +23,7 @@ #include namespace fs = ::arrow::fs; +namespace io = ::arrow::io; namespace cpp11 { @@ -268,7 +269,7 @@ void fs___CopyFiles(const std::shared_ptr& source_fs, const std::string& destination_base_dir, int64_t chunk_size = 1024 * 1024, bool use_threads = true) { StopIfNotOk(fs::CopyFiles(source_fs, *source_sel, destination_fs, destination_base_dir, - chunk_size, use_threads)); + io::default_io_context(), chunk_size, use_threads)); } #endif From fd22dd93d9835090a350c859f6e75dd2615d0bd0 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 26 Feb 2021 12:10:09 +0900 Subject: [PATCH 41/54] ARROW-11786: [C++] Remove noisy CMake message Closes #9578 from lidavidm/arrow-11786 Authored-by: David Li Signed-off-by: Sutou Kouhei --- cpp/src/arrow/flight/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index b44bab290746c..77789d4d2643a 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -85,7 +85,6 @@ function(test_grpc_version DST_VAR DETECT_VERSION TEST_FILE) CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CURRENT_INCLUDE_DIRECTORIES}" LINK_LIBRARIES gRPC::grpc gRPC::grpc++ OUTPUT_VARIABLE TLS_CREDENTIALS_OPTIONS_CHECK_OUTPUT CXX_STANDARD 11) - message(STATUS "${TLS_CREDENTIALS_OPTIONS_CHECK_OUTPUT}") if(HAS_GRPC_VERSION) set(${DST_VAR} "${DETECT_VERSION}" PARENT_SCOPE) else() From 6170c5190feab2f585e65e989957bfc67ad6bd52 Mon Sep 17 00:00:00 2001 From: Ximo Guanter Date: Fri, 26 Feb 2021 06:09:47 -0500 Subject: [PATCH 42/54] ARROW-11784: [Rust][DataFusion] CoalesceBatchesStream doesn't honor Stream interface Unit tests now cover the bug to avoid regressions. 
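The underlying problem is that `CoalesceBatchesStream` could poll its input again after the input had already returned `None`; once a stream has reported completion, later polls must keep yielding `Poll::Ready(None)` instead of touching the exhausted input. The `is_closed` flag added below enforces exactly that, much like `futures::StreamExt::fuse()`. A stand-alone sketch of the pattern (the `Fused` wrapper is illustrative, not DataFusion code):

    use futures::stream::{Stream, StreamExt};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Remembers termination: after the inner stream yields `None`, every
    /// later poll returns `Ready(None)` without polling the inner stream
    /// again -- the same guard the `is_closed` flag provides below.
    struct Fused<S> {
        inner: S,
        done: bool,
    }

    impl<S: Stream + Unpin> Stream for Fused<S> {
        type Item = S::Item;

        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            if self.done {
                return Poll::Ready(None);
            }
            match self.inner.poll_next_unpin(cx) {
                Poll::Ready(None) => {
                    self.done = true;
                    Poll::Ready(None)
                }
                other => other,
            }
        }
    }
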
Closes #9574 from edrevo/fix-coalescebatchesstream Authored-by: Ximo Guanter Signed-off-by: Andrew Lamb --- .../src/physical_plan/coalesce_batches.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/rust/datafusion/src/physical_plan/coalesce_batches.rs b/rust/datafusion/src/physical_plan/coalesce_batches.rs index 9f36fd8f794a7..b91e0b672eb58 100644 --- a/rust/datafusion/src/physical_plan/coalesce_batches.rs +++ b/rust/datafusion/src/physical_plan/coalesce_batches.rs @@ -111,6 +111,7 @@ impl ExecutionPlan for CoalesceBatchesExec { target_batch_size: self.target_batch_size, buffer: Vec::new(), buffered_rows: 0, + is_closed: false, })) } } @@ -126,6 +127,8 @@ struct CoalesceBatchesStream { buffer: Vec, /// Buffered row count buffered_rows: usize, + /// Whether the stream has finished returning all of its data or not + is_closed: bool, } impl Stream for CoalesceBatchesStream { @@ -135,6 +138,9 @@ impl Stream for CoalesceBatchesStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + if self.is_closed { + return Poll::Ready(None); + } loop { let input_batch = self.input.poll_next_unpin(cx); match input_batch { @@ -167,6 +173,7 @@ impl Stream for CoalesceBatchesStream { } } None => { + self.is_closed = true; // we have reached the end of the input stream but there could still // be buffered batches if self.buffer.is_empty() { @@ -234,7 +241,7 @@ pub fn concat_batches( #[cfg(test)] mod tests { use super::*; - use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec}; use arrow::array::UInt32Array; use arrow::datatypes::{DataType, Field, Schema}; @@ -244,7 +251,7 @@ mod tests { let partition = create_vec_batches(&schema, 10); let partitions = vec![partition]; - let output_partitions = coalesce_batches(&schema, partitions, 20).await?; + let output_partitions = coalesce_batches(&schema, partitions, 21).await?; assert_eq!(1, output_partitions.len()); // input is 10 batches x 8 rows (80 rows) @@ -287,6 +294,8 @@ mod tests { ) -> Result>> { // create physical plan let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = + RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?; let exec: Arc = Arc::new(CoalesceBatchesExec::new(Arc::new(exec), target_batch_size)); From 4da5822cbb04562f9ced5304cfbbe308c950d040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Fri, 26 Feb 2021 07:50:29 -0800 Subject: [PATCH 43/54] ARROW-11794: [Go] Add concurrent-safe ipc.FileReader.RecordAt(i) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Arrow IPC files are safe to load concurrently. The implementation of `ipc.FileReader.Record(i)` is not safe due to stashing the current record internally. This adds a backward-compatible function `RecordAt` that behaves like ReadAt. 
Closes #9584 from fsaintjacques/go-concurrent-file-reader Authored-by: François Saint-Jacques Signed-off-by: Neal Richardson --- go/arrow/internal/arrdata/ioutil.go | 50 +++++++++++++++++++++++++++++ go/arrow/ipc/file_reader.go | 24 ++++++++++---- go/arrow/ipc/file_test.go | 1 + 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/go/arrow/internal/arrdata/ioutil.go b/go/arrow/internal/arrdata/ioutil.go index 7065f64b503cd..33aab24bb3c39 100644 --- a/go/arrow/internal/arrdata/ioutil.go +++ b/go/arrow/internal/arrdata/ioutil.go @@ -17,8 +17,10 @@ package arrdata // import "github.com/apache/arrow/go/arrow/internal/arrdata" import ( + "fmt" "io" "os" + "sync" "testing" "github.com/apache/arrow/go/arrow" @@ -59,6 +61,54 @@ func CheckArrowFile(t *testing.T, f *os.File, mem memory.Allocator, schema *arro } +func CheckArrowConcurrentFile(t *testing.T, f *os.File, mem memory.Allocator, schema *arrow.Schema, recs []array.Record) { + t.Helper() + + _, err := f.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + + r, err := ipc.NewFileReader(f, ipc.WithSchema(schema), ipc.WithAllocator(mem)) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + var g sync.WaitGroup + errs := make(chan error, r.NumRecords()) + checkRecord := func(i int) { + defer g.Done() + rec, err := r.RecordAt(i) + if err != nil { + errs <- fmt.Errorf("could not read record %d: %v", i, err) + return + } + if !array.RecordEqual(rec, recs[i]) { + errs <- fmt.Errorf("records[%d] differ", i) + } + } + + for i := 0; i < r.NumRecords(); i++ { + g.Add(1) + go checkRecord(i) + } + + g.Wait() + close(errs) + + for err := range errs { + if err != nil { + t.Fatal(err) + } + } + + err = r.Close() + if err != nil { + t.Fatal(err) + } +} + // CheckArrowStream checks whether a given ARROW stream contains the expected list of records. func CheckArrowStream(t *testing.T, f *os.File, mem memory.Allocator, schema *arrow.Schema, recs []array.Record) { t.Helper() diff --git a/go/arrow/ipc/file_reader.go b/go/arrow/ipc/file_reader.go index 961803b33ef1d..cf32448201878 100644 --- a/go/arrow/ipc/file_reader.go +++ b/go/arrow/ipc/file_reader.go @@ -244,6 +244,23 @@ func (f *FileReader) Close() error { // The returned value is valid until the next call to Record. // Users need to call Retain on that Record to keep it valid for longer. func (f *FileReader) Record(i int) (array.Record, error) { + record, err := f.RecordAt(i) + if err != nil { + return nil, err + } + + if f.record != nil { + f.record.Release() + } + + f.record = record + return record, nil +} + +// Record returns the i-th record from the file. Ownership is transferred to the +// caller and must call Release() to free the memory. This method is safe to +// call concurrently. +func (f *FileReader) RecordAt(i int) (array.Record, error) { if i < 0 || i > f.NumRecords() { panic("arrow/ipc: record index out of bounds") } @@ -271,12 +288,7 @@ func (f *FileReader) Record(i int) (array.Record, error) { return nil, xerrors.Errorf("arrow/ipc: message %d is not a Record", i) } - if f.record != nil { - f.record.Release() - } - - f.record = newRecord(f.schema, msg.meta, bytes.NewReader(msg.body.Bytes())) - return f.record, nil + return newRecord(f.schema, msg.meta, bytes.NewReader(msg.body.Bytes())), nil } // Read reads the current record from the underlying stream and an error, if any. 
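Because `RecordAt` no longer stashes the current record on the reader, records can be fetched from several goroutines at once; ownership of each returned record is transferred to the caller, which must call `Release`. A hypothetical usage sketch along the lines of the `CheckArrowConcurrentFile` helper above (the `readAll` function name and error handling are illustrative):

    package arrowexample

    import (
        "log"
        "os"
        "sync"

        "github.com/apache/arrow/go/arrow/ipc"
    )

    // readAll opens an Arrow IPC file and loads every record concurrently
    // using RecordAt, which is safe to call from multiple goroutines.
    func readAll(path string) error {
        f, err := os.Open(path)
        if err != nil {
            return err
        }
        defer f.Close()

        r, err := ipc.NewFileReader(f)
        if err != nil {
            return err
        }
        defer r.Close()

        var wg sync.WaitGroup
        for i := 0; i < r.NumRecords(); i++ {
            wg.Add(1)
            go func(i int) {
                defer wg.Done()
                rec, err := r.RecordAt(i)
                if err != nil {
                    log.Printf("record %d: %v", i, err)
                    return
                }
                defer rec.Release() // ownership is transferred to the caller
                _ = rec.NumRows()   // ... process the record ...
            }(i)
        }
        wg.Wait()
        return nil
    }
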
diff --git a/go/arrow/ipc/file_test.go b/go/arrow/ipc/file_test.go index 8c5d515ba5e57..d0ef9605e61fa 100644 --- a/go/arrow/ipc/file_test.go +++ b/go/arrow/ipc/file_test.go @@ -45,6 +45,7 @@ func TestFile(t *testing.T) { arrdata.WriteFile(t, f, mem, recs[0].Schema(), recs) arrdata.CheckArrowFile(t, f, mem, recs[0].Schema(), recs) + arrdata.CheckArrowConcurrentFile(t, f, mem, recs[0].Schema(), recs) }) } } From dfd232313e1538b81a38db1e59cf4a109b61a467 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Fri, 26 Feb 2021 12:37:52 -0500 Subject: [PATCH 44/54] ARROW-11662: [C++] Support sorting decimal and fixed size binary data Also enable nth_to_indices on decimal and fixed size binary data. Closes #9577 from pitrou/ARROW-11662-sort-decimal Authored-by: Antoine Pitrou Signed-off-by: Benjamin Kietzman --- c_glib/test/test-decimal128-data-type.rb | 4 +- .../arrow/compute/kernels/codegen_internal.cc | 9 + .../arrow/compute/kernels/codegen_internal.h | 32 ++ cpp/src/arrow/compute/kernels/vector_sort.cc | 79 ++-- .../arrow/compute/kernels/vector_sort_test.cc | 385 ++++++++++++------ cpp/src/arrow/testing/gtest_util.h | 2 + cpp/src/arrow/testing/random.cc | 42 +- cpp/src/arrow/testing/random.h | 11 + cpp/src/arrow/type.cc | 2 +- cpp/src/arrow/type.h | 4 +- cpp/src/arrow/type_test.cc | 6 +- docs/source/cpp/compute.rst | 3 +- .../test/test-decimal128-data-type.rb | 4 +- 13 files changed, 432 insertions(+), 151 deletions(-) diff --git a/c_glib/test/test-decimal128-data-type.rb b/c_glib/test/test-decimal128-data-type.rb index a02e3badca051..b27e1cad1ea3f 100644 --- a/c_glib/test/test-decimal128-data-type.rb +++ b/c_glib/test/test-decimal128-data-type.rb @@ -23,12 +23,12 @@ def test_type def test_name data_type = Arrow::Decimal128DataType.new(2, 0) - assert_equal("decimal", data_type.name) + assert_equal("decimal128", data_type.name) end def test_to_s data_type = Arrow::Decimal128DataType.new(2, 0) - assert_equal("decimal(2, 0)", data_type.to_s) + assert_equal("decimal128(2, 0)", data_type.to_s) end def test_precision diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index b321ff3fc8b33..ad43b7a3aa981 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -48,6 +48,7 @@ std::vector> g_numeric_types; std::vector> g_base_binary_types; std::vector> g_temporal_types; std::vector> g_primitive_types; +std::vector g_decimal_type_ids; static std::once_flag codegen_static_initialized; template @@ -71,6 +72,9 @@ static void InitStaticData() { // Floating point types g_floating_types = {float32(), float64()}; + // Decimal types + g_decimal_type_ids = {Type::DECIMAL128, Type::DECIMAL256}; + // Numeric types Extend(g_int_types, &g_numeric_types); Extend(g_floating_types, &g_numeric_types); @@ -132,6 +136,11 @@ const std::vector>& FloatingPointTypes() { return g_floating_types; } +const std::vector& DecimalTypeIds() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_decimal_type_ids; +} + const std::vector& AllTimeUnits() { static std::vector units = {TimeUnit::SECOND, TimeUnit::MILLI, TimeUnit::MICRO, TimeUnit::NANO}; diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 8c49e796623e7..9e2ed82a4267a 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -188,6 +188,16 @@ struct GetViewType { } }; +template <> +struct GetViewType { + 
using T = Decimal256; + using PhysicalType = util::string_view; + + static T LogicalValue(PhysicalType value) { + return Decimal256(reinterpret_cast(value.data())); + } +}; + template struct GetOutputType; @@ -206,6 +216,11 @@ struct GetOutputType { using T = Decimal128; }; +template <> +struct GetOutputType { + using T = Decimal256; +}; + // ---------------------------------------------------------------------- // Iteration / value access utilities @@ -396,6 +411,7 @@ const std::vector>& SignedIntTypes(); const std::vector>& UnsignedIntTypes(); const std::vector>& IntTypes(); const std::vector>& FloatingPointTypes(); +const std::vector& DecimalTypeIds(); ARROW_EXPORT const std::vector& AllTimeUnits(); @@ -1185,6 +1201,22 @@ ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) { } } +// Generate a kernel given a templated functor for decimal types +// +// See "Numeric" above for description of the generator functor +template