Skip to content

Commit

Permalink
Support Union types in ScalarValue (#9683)
Browse files Browse the repository at this point in the history
Support Union types in `ScalarValue`  (#9683)
  • Loading branch information
avantgardnerio authored Mar 19, 2024
1 parent 0974759 commit 8074ca1
Show file tree
Hide file tree
Showing 9 changed files with 496 additions and 3 deletions.
4 changes: 2 additions & 2 deletions datafusion/common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub enum DataFusionError {
IoError(io::Error),
/// Error when SQL is syntactically incorrect.
///
/// 2nd argument is for optional backtrace
/// 2nd argument is for optional backtrace
SQL(ParserError, Option<String>),
/// Error when a feature is not yet implemented.
///
Expand Down Expand Up @@ -101,7 +101,7 @@ pub enum DataFusionError {
/// This error can be returned in cases such as when schema inference is not
/// possible and when column names are not unique.
///
/// 2nd argument is for optional backtrace
/// 2nd argument is for optional backtrace
/// Boxing the optional backtrace to prevent <https://rust-lang.github.io/rust-clippy/master/index.html#/result_large_err>
SchemaError(SchemaError, Box<Option<String>>),
/// Error during execution of the query.
Expand Down
82 changes: 82 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ use arrow::{
},
};
use arrow_array::{ArrowNativeTypeOp, Scalar};
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};

pub use struct_builder::ScalarStructBuilder;

Expand Down Expand Up @@ -275,6 +277,11 @@ pub enum ScalarValue {
DurationMicrosecond(Option<i64>),
/// Duration in nanoseconds
DurationNanosecond(Option<i64>),
/// A nested datatype that can represent slots of differing types. Components:
/// `.0`: a tuple of union `type_id` and the single value held by this Scalar
/// `.1`: the list of fields, zero-to-one of which will be set in `.0`
/// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came
Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode),
/// Dictionary type: index type and value
Dictionary(Box<DataType>, Box<ScalarValue>),
}
Expand Down Expand Up @@ -375,6 +382,10 @@ impl PartialEq for ScalarValue {
(IntervalDayTime(_), _) => false,
(IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2),
(IntervalMonthDayNano(_), _) => false,
(Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => {
val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2)
}
(Union(_, _, _), _) => false,
(Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2),
(Dictionary(_, _), _) => false,
(Null, Null) => true,
Expand Down Expand Up @@ -500,6 +511,14 @@ impl PartialOrd for ScalarValue {
(DurationMicrosecond(_), _) => None,
(DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2),
(DurationNanosecond(_), _) => None,
(Union(v1, t1, m1), Union(v2, t2, m2)) => {
if t1.eq(t2) && m1.eq(m2) {
v1.partial_cmp(v2)
} else {
None
}
}
(Union(_, _, _), _) => None,
(Dictionary(k1, v1), Dictionary(k2, v2)) => {
// Don't compare if the key types don't match (it is effectively a different datatype)
if k1 == k2 {
Expand Down Expand Up @@ -663,6 +682,11 @@ impl std::hash::Hash for ScalarValue {
IntervalYearMonth(v) => v.hash(state),
IntervalDayTime(v) => v.hash(state),
IntervalMonthDayNano(v) => v.hash(state),
Union(v, t, m) => {
v.hash(state);
t.hash(state);
m.hash(state);
}
Dictionary(k, v) => {
k.hash(state);
v.hash(state);
Expand Down Expand Up @@ -1093,6 +1117,7 @@ impl ScalarValue {
ScalarValue::DurationNanosecond(_) => {
DataType::Duration(TimeUnit::Nanosecond)
}
ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode),
ScalarValue::Dictionary(k, v) => {
DataType::Dictionary(k.clone(), Box::new(v.data_type()))
}
Expand Down Expand Up @@ -1292,6 +1317,7 @@ impl ScalarValue {
ScalarValue::DurationMillisecond(v) => v.is_none(),
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
ScalarValue::Union(v, _, _) => v.is_none(),
ScalarValue::Dictionary(_, v) => v.is_null(),
}
}
Expand Down Expand Up @@ -2087,6 +2113,39 @@ impl ScalarValue {
e,
size
),
ScalarValue::Union(value, fields, _mode) => match value {
Some((v_id, value)) => {
let mut field_type_ids = Vec::<i8>::with_capacity(fields.len());
let mut child_arrays =
Vec::<(Field, ArrayRef)>::with_capacity(fields.len());
for (f_id, field) in fields.iter() {
let ar = if f_id == *v_id {
value.to_array_of_size(size)?
} else {
let dt = field.data_type();
new_null_array(dt, size)
};
let field = (**field).clone();
child_arrays.push((field, ar));
field_type_ids.push(f_id);
}
let type_ids = repeat(*v_id).take(size).collect::<Vec<_>>();
let type_ids = Buffer::from_slice_ref(type_ids);
let value_offsets: Option<Buffer> = None;
let ar = UnionArray::try_new(
field_type_ids.as_slice(),
type_ids,
value_offsets,
child_arrays,
)
.map_err(|e| DataFusionError::ArrowError(e, None))?;
Arc::new(ar)
}
None => {
let dt = self.data_type();
new_null_array(&dt, size)
}
},
ScalarValue::Dictionary(key_type, v) => {
// values array is one element long (the value)
match key_type.as_ref() {
Expand Down Expand Up @@ -2626,6 +2685,9 @@ impl ScalarValue {
ScalarValue::DurationNanosecond(val) => {
eq_array_primitive!(array, index, DurationNanosecondArray, val)?
}
ScalarValue::Union(_, _, _) => {
return _not_impl_err!("Union is not supported yet")
}
ScalarValue::Dictionary(key_type, v) => {
let (values_array, values_index) = match key_type.as_ref() {
DataType::Int8 => get_dict_value::<Int8Type>(array, index)?,
Expand Down Expand Up @@ -2703,6 +2765,15 @@ impl ScalarValue {
ScalarValue::LargeList(arr) => arr.get_array_memory_size(),
ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
ScalarValue::Struct(arr) => arr.get_array_memory_size(),
ScalarValue::Union(vals, fields, _mode) => {
vals.as_ref()
.map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv))
.unwrap_or_default()
// `fields` is boxed, so it is NOT already included in `self`
+ std::mem::size_of_val(fields)
+ (std::mem::size_of::<Field>() * fields.len())
+ fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::<usize>()
}
ScalarValue::Dictionary(dt, sv) => {
// `dt` and `sv` are boxed, so they are NOT already included in `self`
dt.size() + sv.size()
Expand Down Expand Up @@ -3048,6 +3119,9 @@ impl TryFrom<&DataType> for ScalarValue {
.to_owned()
.into(),
),
DataType::Union(fields, mode) => {
ScalarValue::Union(None, fields.clone(), *mode)
}
DataType::Null => ScalarValue::Null,
_ => {
return _not_impl_err!(
Expand Down Expand Up @@ -3164,6 +3238,10 @@ impl fmt::Display for ScalarValue {
.join(",")
)?
}
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "{}:{}", id, val)?,
None => write!(f, "NULL")?,
},
ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?,
ScalarValue::Null => write!(f, "NULL")?,
};
Expand Down Expand Up @@ -3279,6 +3357,10 @@ impl fmt::Debug for ScalarValue {
ScalarValue::DurationNanosecond(_) => {
write!(f, "DurationNanosecond(\"{self}\")")
}
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "Union {}:{}", id, val),
None => write!(f, "Union(NULL)"),
},
ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"),
ScalarValue::Null => write!(f, "NULL"),
}
Expand Down
35 changes: 35 additions & 0 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,9 @@ mod tests {
use crate::test::exec::StatisticsExec;
use crate::ExecutionPlan;

use crate::empty::EmptyExec;
use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::{UnionFields, UnionMode};
use datafusion_common::{ColumnStatistics, ScalarValue};
use datafusion_expr::Operator;

Expand Down Expand Up @@ -1090,4 +1092,37 @@ mod tests {
assert_eq!(statistics.total_byte_size, Precision::Inexact(1600));
Ok(())
}

/// Verifies that `FilterExec::statistics` can be called when the input
/// schema contains a sparse union column alongside the filtered column.
#[test]
fn test_equivalence_properties_union_type() -> Result<()> {
    // Two-member sparse union: type id 0 -> Int32, type id 1 -> Utf8.
    let union_fields = UnionFields::new(
        vec![0, 1],
        vec![
            Field::new("f1", DataType::Int32, true),
            Field::new("f2", DataType::Utf8, true),
        ],
    );
    let union_type = DataType::Union(union_fields, UnionMode::Sparse);

    let schema = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", union_type, true),
    ]));

    // Predicate: c1 >= 1 AND c1 <= 4. Only the Int32 column is referenced;
    // the union column is merely present in the schema.
    let lower_bound = binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?;
    let upper_bound = binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?;
    let predicate = binary(lower_bound, Operator::And, upper_bound, &schema)?;

    let filter =
        FilterExec::try_new(predicate, Arc::new(EmptyExec::new(schema.clone())))?;

    // The returned statistics are discarded; the call only has to succeed.
    filter.statistics().unwrap();

    Ok(())
}
}
15 changes: 15 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,20 @@ message IntervalMonthDayNanoValue {
int64 nanos = 3;
}

// One member of a union type: a numeric id paired with the member's Field
// definition.
// NOTE(review): field_id presumably carries the Arrow union type id (an i8 in
// the Rust ScalarValue::Union) widened to int32 — confirm against the
// to/from-proto conversion code.
message UnionField {
int32 field_id = 1;
Field field = 2;
}

// Serialized form of a union scalar: the id and value of the selected member,
// together with the union's full field list and its mode.
message UnionValue {
// Note that a null union value must still carry one or more fields, so
// null-ness cannot be expressed by leaving the message empty. Instead, a
// null UnionValue is encoded as one with value_id == 128.
// NOTE(review): 128 is presumably chosen because it lies outside the i8
// range used for union type ids on the Rust side — confirm.
int32 value_id = 1;
ScalarValue value = 2;
repeated UnionField fields = 3;
UnionMode mode = 4;
}

message ScalarFixedSizeBinary{
bytes values = 1;
int32 length = 2;
Expand Down Expand Up @@ -1042,6 +1056,7 @@ message ScalarValue{
ScalarTime64Value time64_value = 30;
IntervalMonthDayNanoValue interval_month_day_nano = 31;
ScalarFixedSizeBinary fixed_size_binary_value = 34;
UnionValue union_value = 42;
}
}

Expand Down
Loading

0 comments on commit 8074ca1

Please sign in to comment.