From 21fa8e4ac5ace3a75d8ed152b2fac4bf6287670f Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 18 Mar 2024 11:07:45 -0600 Subject: [PATCH 1/6] Impl ScalarValue --- datafusion/common/src/error.rs | 4 +- datafusion/common/src/scalar/mod.rs | 78 +++++ datafusion/physical-plan/src/filter.rs | 35 +++ datafusion/proto/proto/datafusion.proto | 15 + datafusion/proto/src/generated/pbjson.rs | 272 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 26 +- .../proto/src/logical_plan/from_proto.rs | 35 +++ datafusion/proto/src/logical_plan/to_proto.rs | 29 ++ 8 files changed, 491 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 1ecd5b62bee8..d1e47b473499 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -63,7 +63,7 @@ pub enum DataFusionError { IoError(io::Error), /// Error when SQL is syntactically incorrect. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace SQL(ParserError, Option), /// Error when a feature is not yet implemented. /// @@ -101,7 +101,7 @@ pub enum DataFusionError { /// This error can be returned in cases such as when schema inference is not /// possible and when column names are not unique. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace /// Boxing the optional backtrace to prevent SchemaError(SchemaError, Box>), /// Error during execution of the query. 
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5ace44f24b69..81d66b1d4ed8 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -53,6 +53,8 @@ use arrow::{ }, }; use arrow_array::{ArrowNativeTypeOp, Scalar}; +use arrow_buffer::Buffer; +use arrow_schema::{UnionFields, UnionMode}; pub use struct_builder::ScalarStructBuilder; @@ -275,6 +277,8 @@ pub enum ScalarValue { DurationMicrosecond(Option), /// Duration in nanoseconds DurationNanosecond(Option), + /// A nested datatype that can represent slots of differing types. Components: + Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), } @@ -375,6 +379,7 @@ impl PartialEq for ScalarValue { (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), (IntervalMonthDayNano(_), _) => false, + (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, (Null, Null) => true, @@ -500,6 +505,14 @@ impl PartialOrd for ScalarValue { (DurationMicrosecond(_), _) => None, (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), (DurationNanosecond(_), _) => None, + (Union(v1, t1, _m1), Union(v2, t2, _m2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Union(_, _, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) if k1 == k2 { @@ -663,6 +676,11 @@ impl std::hash::Hash for ScalarValue { IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), IntervalMonthDayNano(v) => v.hash(state), + Union(v, t, m) => { + v.hash(state); + t.hash(state); + m.hash(state); + } Dictionary(k, v) => { k.hash(state); v.hash(state); @@ -1093,6 +1111,9 @@ impl ScalarValue { ScalarValue::DurationNanosecond(_) => { 
DataType::Duration(TimeUnit::Nanosecond) } + ScalarValue::Union(_, fields, mode) => { + DataType::Union(fields.clone(), mode.clone()) + } ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } @@ -1292,6 +1313,7 @@ impl ScalarValue { ScalarValue::DurationMillisecond(v) => v.is_none(), ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), + ScalarValue::Union(v, _, _) => v.is_none(), ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -2083,6 +2105,39 @@ impl ScalarValue { e, size ), + ScalarValue::Union(value, fields, _mode) => match value { + Some((v_id, value)) => { + let mut field_type_ids = Vec::::with_capacity(fields.len()); + let mut child_arrays = + Vec::<(Field, ArrayRef)>::with_capacity(fields.len()); + for (f_id, field) in fields.iter() { + let ar = if f_id == *v_id { + value.to_array_of_size(size)? + } else { + let dt = field.data_type(); + new_null_array(&dt, size) + }; + let field = (**field).clone(); + child_arrays.push((field, ar)); + field_type_ids.push(f_id); + } + let type_ids = repeat(*v_id).take(size).collect::>(); + let type_ids = Buffer::from_slice_ref(type_ids); + let value_offsets: Option = None; + let ar = UnionArray::try_new( + field_type_ids.as_slice(), + type_ids, + value_offsets, + child_arrays, + ) + .map_err(|e| DataFusionError::ArrowError(e))?; + Arc::new(ar) + } + None => { + let dt = self.data_type(); + new_null_array(&dt, size) + } + }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { @@ -2622,6 +2677,9 @@ impl ScalarValue { ScalarValue::DurationNanosecond(val) => { eq_array_primitive!(array, index, DurationNanosecondArray, val)? 
} + ScalarValue::Union(_, _, _) => { + return _not_impl_err!("Union is not supported yet") + } ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { DataType::Int8 => get_dict_value::(array, index)?, @@ -2699,6 +2757,15 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(arr) => arr.get_array_memory_size(), + ScalarValue::Union(vals, fields, _mode) => { + vals.as_ref() + .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .unwrap_or_default() + // `fields` is boxed, so it is NOT already included in `self` + + std::mem::size_of_val(fields) + + (std::mem::size_of::() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() + } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() @@ -3044,6 +3111,9 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), mode.clone()) + } DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( @@ -3160,6 +3230,10 @@ impl fmt::Display for ScalarValue { .join(",") )? 
} + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "{{{}}}", format!("{}:{}", id, val))?, + None => write!(f, "NULL")?, + }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; @@ -3275,6 +3349,10 @@ impl fmt::Debug for ScalarValue { ScalarValue::DurationNanosecond(_) => { write!(f, "DurationNanosecond(\"{self}\")") } + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "{{{}}}", format!("{}:{}", id, val)), + None => write!(f, "Union(NULL)"), + }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 72f885a93962..f44ade7106df 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -441,7 +441,9 @@ mod tests { use crate::test::exec::StatisticsExec; use crate::ExecutionPlan; + use crate::empty::EmptyExec; use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{UnionFields, UnionMode}; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_expr::Operator; @@ -1090,4 +1092,37 @@ mod tests { assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) } + + #[test] + fn test_equivalence_properties_union_type() -> Result<()> { + let union_type = DataType::Union( + UnionFields::new( + vec![0, 1], + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + ), + UnionMode::Sparse, + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", union_type, true), + ])); + + let exec = FilterExec::try_new( + binary( + binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?, + Operator::And, + binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?, + &schema, + )?, + Arc::new(EmptyExec::new(schema.clone())), + )?; 
+ + exec.statistics().unwrap(); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 597094758584..e456f5333f70 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -988,6 +988,20 @@ message IntervalMonthDayNanoValue { int64 nanos = 3; } +message UnionField { + int32 field_id = 1; + Field field = 2; +} + +message UnionValue { + // Note that a null union value must have one or more fields, so we + // encode a null UnionValue as one with value_id == 128 + int32 value_id = 1; + ScalarValue value = 2; + repeated UnionField fields = 3; + UnionMode mode = 4; +} + message ScalarFixedSizeBinary{ bytes values = 1; int32 length = 2; @@ -1042,6 +1056,7 @@ message ScalarValue{ ScalarTime64Value time64_value = 30; IntervalMonthDayNanoValue interval_month_day_nano = 31; ScalarFixedSizeBinary fixed_size_binary_value = 34; + UnionValue union_value = 42; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index cb9633338e8f..3a1d871337b0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -24041,6 +24041,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::FixedSizeBinaryValue(v) => { struct_ser.serialize_field("fixedSizeBinaryValue", v)?; } + scalar_value::Value::UnionValue(v) => { + struct_ser.serialize_field("unionValue", v)?; + } } } struct_ser.end() @@ -24125,6 +24128,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano", "fixed_size_binary_value", "fixedSizeBinaryValue", + "union_value", + "unionValue", ]; #[allow(clippy::enum_variant_names)] @@ -24165,6 +24170,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Time64Value, IntervalMonthDayNano, FixedSizeBinaryValue, + UnionValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24222,6 +24228,7 @@ 
impl<'de> serde::Deserialize<'de> for ScalarValue { "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), + "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24471,6 +24478,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) +; + } + GeneratedField::UnionValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("unionValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) ; } } @@ -26930,6 +26944,117 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_id != 0 { + len += 1; + } + if self.field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionField", len)?; + if self.field_id != 0 { + struct_ser.serialize_field("fieldId", &self.field_id)?; + } + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_id", + "fieldId", + "field", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + 
FieldId, + Field, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldId" | "field_id" => Ok(GeneratedField::FieldId), + "field" => Ok(GeneratedField::Field), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_id__ = None; + let mut field__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::FieldId => { + if field_id__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldId")); + } + field_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + } + } + Ok(UnionField { + field_id: field_id__.unwrap_or_default(), + field: field__, + }) + } + } + deserializer.deserialize_struct("datafusion.UnionField", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UnionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -27092,6 +27217,153 @@ impl<'de> serde::Deserialize<'de> for UnionNode { deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value_id != 0 { + len += 1; + } + if self.value.is_some() { + len += 1; + } + if !self.fields.is_empty() { + len += 1; + } + if self.mode != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionValue", len)?; + if self.value_id != 0 { + struct_ser.serialize_field("valueId", &self.value_id)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + if !self.fields.is_empty() { + struct_ser.serialize_field("fields", &self.fields)?; + } + if self.mode != 0 { + let v = UnionMode::try_from(self.mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; + struct_ser.serialize_field("mode", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const 
FIELDS: &[&str] = &[ + "value_id", + "valueId", + "value", + "fields", + "mode", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ValueId, + Value, + Fields, + Mode, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "valueId" | "value_id" => Ok(GeneratedField::ValueId), + "value" => Ok(GeneratedField::Value), + "fields" => Ok(GeneratedField::Fields), + "mode" => Ok(GeneratedField::Mode), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value_id__ = None; + let mut value__ = None; + let mut fields__ = None; + let mut mode__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::ValueId => { + if value_id__.is_some() { + return Err(serde::de::Error::duplicate_field("valueId")); + } + value_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } + GeneratedField::Mode => { + if mode__.is_some() { + return Err(serde::de::Error::duplicate_field("mode")); + } + mode__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(UnionValue { + value_id: value_id__.unwrap_or_default(), + value: value__, + fields: fields__.unwrap_or_default(), + mode: mode__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.UnionValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UniqueConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f5ef6c1f74f0..c715cceb1ca6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1225,6 +1225,28 @@ pub struct IntervalMonthDayNanoValue { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionField { + #[prost(int32, tag = "1")] + pub field_id: i32, + #[prost(message, optional, tag = "2")] + pub field: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionValue { + /// Note that a null union value must have one or more fields, so we + /// encode a null UnionValue as one with value_id == 128 + #[prost(int32, tag = "1")] + pub value_id: i32, + #[prost(message, optional, boxed, tag = "2")] + pub 
value: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub fields: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "4")] + pub mode: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -1236,7 +1258,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -1320,6 +1342,8 @@ pub mod scalar_value { IntervalMonthDayNano(super::IntervalMonthDayNanoValue), #[prost(message, tag = "34")] FixedSizeBinaryValue(super::ScalarFixedSizeBinary), + #[prost(message, tag = "42")] + UnionValue(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3822b74bc18c..cc28df71c0fb 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -771,6 +771,41 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), )), + Value::UnionValue(val) => { + let mode = match val.mode { + 0 => UnionMode::Sparse, + 1 => UnionMode::Dense, + id => Err(Error::unknown("UnionMode", id))?, + }; + let ids = val + .fields + .iter() + .map(|f| f.field_id as i8) + .collect::>(); + let fields = val + .fields + .iter() + .map(|f| f.field.clone()) + .collect::>>(); + let fields = 
fields.ok_or_else(|| Error::required("UnionField"))?; + let fields = fields + .iter() + .map(Field::try_from) + .collect::, _>>()?; + let fields = UnionFields::new(ids, fields); + let v_id = val.value_id as i8; + let val = match &val.value { + None => None, + Some(val) => { + let val: ScalarValue = val + .as_ref() + .try_into() + .map_err(|_| Error::General("Invalid Scalar".to_string()))?; + Some((v_id, Box::new(val))) + } + }; + Self::Union(val, fields, mode) + } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7a17d2a2b405..93d7835993fc 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,6 +30,7 @@ use crate::protobuf::{ }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, + UnionField, UnionValue, }; use arrow::{ @@ -1402,6 +1403,34 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }; Ok(protobuf::ScalarValue { value: Some(value) }) } + + ScalarValue::Union(val, df_fields, mode) => { + let mut fields = Vec::::with_capacity(df_fields.len()); + for (id, field) in df_fields.iter() { + let field_id = id as i32; + let field = Some(field.as_ref().try_into()?); + let field = UnionField { field_id, field }; + fields.push(field); + } + let mode = match mode { + UnionMode::Sparse => 0, + UnionMode::Dense => 1, + }; + let value = match val { + None => None, + Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)), + }; + let val = UnionValue { + value_id: val.as_ref().map(|(id, _v)| *id as i32).unwrap_or(0), + value, + fields, + mode, + }; + let val = Value::UnionValue(Box::new(val)); + let val = protobuf::ScalarValue { value: Some(val) }; + Ok(val) + } + ScalarValue::Dictionary(index_type, val) => { let value: 
protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { From b80b47b6a7e69ad67b9603fe87fb493ba818caea Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 18 Mar 2024 14:08:12 -0600 Subject: [PATCH 2/6] cargo check --- datafusion/common/src/scalar/mod.rs | 2 +- datafusion/sql/src/unparser/expr.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 81d66b1d4ed8..14aea44997de 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -2130,7 +2130,7 @@ impl ScalarValue { value_offsets, child_arrays, ) - .map_err(|e| DataFusionError::ArrowError(e))?; + .map_err(|e| DataFusionError::ArrowError(e, None))?; Arc::new(ar) } None => { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9680177d736f..6e20b739886d 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -456,6 +456,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), + ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) 
=> not_impl_err!("Unsupported scalar: {v:?}"), } } From 807218d7433d2360f10eb0d269c2a111d9b5da20 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 18 Mar 2024 14:48:48 -0600 Subject: [PATCH 3/6] clippy --- datafusion/common/src/scalar/mod.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 14aea44997de..2677b6ffdbef 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1111,9 +1111,7 @@ impl ScalarValue { ScalarValue::DurationNanosecond(_) => { DataType::Duration(TimeUnit::Nanosecond) } - ScalarValue::Union(_, fields, mode) => { - DataType::Union(fields.clone(), mode.clone()) - } + ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } @@ -2115,7 +2113,7 @@ impl ScalarValue { value.to_array_of_size(size)? } else { let dt = field.data_type(); - new_null_array(&dt, size) + new_null_array(dt, size) }; let field = (**field).clone(); child_arrays.push((field, ar)); @@ -3112,7 +3110,7 @@ impl TryFrom<&DataType> for ScalarValue { .into(), ), DataType::Union(fields, mode) => { - ScalarValue::Union(None, fields.clone(), mode.clone()) + ScalarValue::Union(None, fields.clone(), *mode) } DataType::Null => ScalarValue::Null, _ => { @@ -3231,7 +3229,7 @@ impl fmt::Display for ScalarValue { )? 
} ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "{{{}}}", format!("{}:{}", id, val))?, + Some((id, val)) => write!(f, "{}:{}", id, val)?, None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, @@ -3350,7 +3348,7 @@ impl fmt::Debug for ScalarValue { write!(f, "DurationNanosecond(\"{self}\")") } ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "{{{}}}", format!("{}:{}", id, val)), + Some((id, val)) => write!(f, "Union {}:{}", id, val), None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), From 26e153e3aeb64bcdbd76fae8ef1a4a6c5c3264bd Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 18 Mar 2024 16:29:43 -0600 Subject: [PATCH 4/6] eq --- datafusion/common/src/scalar/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 2677b6ffdbef..595c5b30e31b 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -379,6 +379,9 @@ impl PartialEq for ScalarValue { (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), (IntervalMonthDayNano(_), _) => false, + (Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => { + val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2) + } (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, From 1329c8328dc2bd3266dd1caf8ae7471ba06ba2db Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 19 Mar 2024 10:20:52 -0600 Subject: [PATCH 5/6] comments --- datafusion/common/src/scalar/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 595c5b30e31b..33eff21585c7 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs 
@@ -278,6 +278,9 @@ pub enum ScalarValue { /// Duration in nanoseconds DurationNanosecond(Option), /// A nested datatype that can represent slots of differing types. Components: + /// `.0`: a tuple of union `type_id` and the single value held by this Scalar + /// `.1`: the list of fields, zero-to-one of which will be set in `.0` + /// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), } From 267dcaea54a99265baf9a3d23e4f3dad07898fb0 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 19 Mar 2024 11:21:49 -0600 Subject: [PATCH 6/6] PR feedback --- datafusion/common/src/scalar/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 33eff21585c7..cc44e715186f 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -511,8 +511,8 @@ impl PartialOrd for ScalarValue { (DurationMicrosecond(_), _) => None, (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), (DurationNanosecond(_), _) => None, - (Union(v1, t1, _m1), Union(v2, t2, _m2)) => { - if t1.eq(t2) { + (Union(v1, t1, m1), Union(v2, t2, m2)) => { + if t1.eq(t2) && m1.eq(m2) { v1.partial_cmp(v2) } else { None