Skip to content

Commit

Permalink
change to accept multi args
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Oct 30, 2023
1 parent 0d4dc36 commit a9a0c2f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 29 deletions.
32 changes: 16 additions & 16 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1840,12 +1840,12 @@ impl ScalarValue {
let arr = Decimal128Array::from(vals)
.with_precision_and_scale(*precision, *scale)
.unwrap();
wrap_into_list_array(Arc::new(arr))
wrap_into_list_array(&[&arr]).unwrap()
}

DataType::Null => {
let arr = new_null_array(&DataType::Null, values.len());
wrap_into_list_array(arr)
wrap_into_list_array(&[&arr]).unwrap()
}
_ => panic!(
"Unsupported data type {:?} for ScalarValue::list_to_array",
Expand Down Expand Up @@ -2242,18 +2242,14 @@ impl ScalarValue {
let list_array = as_list_array(array);
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));

ScalarValue::List(arr)
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nested_array])?))
}
// TODO: There is no test for FixedSizeList now, add it later
DataType::FixedSizeList(_, _) => {
let list_array = as_fixed_size_list_array(array)?;
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));

ScalarValue::List(arr)
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nested_array])?))
}
DataType::Date32 => {
typed_cast!(array, index, Date32Array, Date32)
Expand Down Expand Up @@ -3236,11 +3232,12 @@ mod tests {

let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8);

let expected = wrap_into_list_array(Arc::new(StringArray::from(vec![
let expected = wrap_into_list_array(&[&StringArray::from(vec![
"rust",
"arrow",
"data-fusion",
])));
])])
.unwrap();
let result = as_list_array(&array);
assert_eq!(result, &expected);
}
Expand Down Expand Up @@ -3274,10 +3271,10 @@ mod tests {

#[test]
fn iter_to_array_string_test() {
let arr1 =
wrap_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
let arr1 = wrap_into_list_array(&[&StringArray::from(vec!["foo", "bar", "baz"])])
.unwrap();
let arr2 =
wrap_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"])));
wrap_into_list_array(&[&StringArray::from(vec!["rust", "world"])]).unwrap();

let scalars = vec![
ScalarValue::List(Arc::new(arr1)),
Expand Down Expand Up @@ -4519,13 +4516,16 @@ mod tests {
// Define list-of-structs scalars

let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap();
let nl0 = ScalarValue::List(Arc::new(wrap_into_list_array(nl0_array)));
let nl0 =
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nl0_array]).unwrap()));

let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap();
let nl1 = ScalarValue::List(Arc::new(wrap_into_list_array(nl1_array)));
let nl1 =
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nl1_array]).unwrap()));

let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap();
let nl2 = ScalarValue::List(Arc::new(wrap_into_list_array(nl2_array)));
let nl2 =
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nl2_array]).unwrap()));

// iter_to_array for list-of-struct
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
Expand Down
26 changes: 17 additions & 9 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::compute;
use arrow::compute::{partition, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_array::ListArray;
use arrow_array::{Array, ListArray};
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -336,16 +336,24 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
count
}

/// Wrap an array into a single element `ListArray`.
/// Wrap arrays into a single element `ListArray`.
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
pub fn wrap_into_list_array(arr: ArrayRef) -> ListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
ListArray::new(
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
offsets,
arr,
pub fn wrap_into_list_array(arr: &[&dyn Array]) -> Result<ListArray> {
if arr.is_empty() {
return Err(DataFusionError::Internal(
"Cannot wrap empty array into list array".to_owned(),
));
}

let lens = arr.iter().map(|x| x.len()).collect::<Vec<_>>();
// Assume data type is consistent
let data_type = arr[0].data_type().to_owned();
Ok(ListArray::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(lens),
arrow::compute::concat(arr)?,
None,
)
))
}

/// An extension trait for smart pointers. Provides an interface to get a
Expand Down
3 changes: 1 addition & 2 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ impl Accumulator for ArrayAggAccumulator {
}

let concated_array = arrow::compute::concat(&element_arrays)?;
let list_array = wrap_into_list_array(concated_array);

let list_array = wrap_into_list_array(&[&concated_array])?;
Ok(ScalarValue::List(Arc::new(list_array)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ mod tests {
};

let arr = arrow::compute::sort(&arr, None).unwrap();
let list_arr = wrap_into_list_array(arr);
let list_arr = wrap_into_list_array(&[&arr]).unwrap();
ScalarValue::List(Arc::new(list_arr))
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
// Either an empty array or all nulls:
DataType::Null => {
let array = new_null_array(&DataType::Null, arrays.len());
Ok(Arc::new(wrap_into_list_array(array)))
Ok(Arc::new(wrap_into_list_array(&[&array])?))
}
data_type => array_array(arrays, data_type),
}
Expand Down

0 comments on commit a9a0c2f

Please sign in to comment.