From 25cf67df0ed2b183ddcb1ba592ee72b91627bcdd Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 7 Dec 2023 08:01:35 -0700 Subject: [PATCH 1/2] Failing test --- .../tests/cases/roundtrip_logical_plan.rs | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 5e36a838f311..bb201ccf23dc 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -972,6 +972,37 @@ fn round_trip_datatype() { } } +#[test] +fn roundtrip_dict_id() -> Result<()> { + let dict_id = 42; + let field = Field::new( + "keys", + DataType::List(Arc::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + dict_id, + false, + ))), + false, + ); + let schema = Arc::new(Schema::new(vec![field])); + + // encode + let mut buf: Vec = vec![]; + let schema_proto: datafusion_proto::generated::datafusion::Schema = schema.try_into().unwrap(); + schema_proto.encode(&mut buf).unwrap(); + + // decode + let schema_proto = datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let decoded: Schema = (&schema_proto).try_into()?; + + let decoded_field = decoded.fields().iter().last().unwrap(); + assert_eq!(decoded_field.dict_id(), Some(dict_id), "dict_id should be retained"); + + Ok(()) +} + #[test] fn roundtrip_null_scalar_values() { let test_types = vec![ From 24ad54eb2aa10ffe6fa003a161b6ae81fef2efe0 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 7 Dec 2023 09:50:11 -0700 Subject: [PATCH 2/2] Passing test --- datafusion/proto/proto/datafusion.proto | 2 + datafusion/proto/src/generated/pbjson.rs | 39 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 4 ++ .../proto/src/logical_plan/from_proto.rs | 16 +++++++- datafusion/proto/src/logical_plan/to_proto.rs | 2 + .../tests/cases/roundtrip_logical_plan.rs | 16 ++++++-- 6 files changed, 73 insertions(+), 6 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 863e3c315c82..de7afd5c7bb2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -840,6 +840,8 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; } message FixedSizeBinary{ diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 74798ee8e94c..3001a9c09503 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6910,6 +6910,12 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -6926,6 +6932,13 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + } struct_ser.end() } } @@ -6942,6 +6955,10 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -6951,6 +6968,8 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, + DictId, + DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6977,6 +6996,8 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7001,6 +7022,8 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -7035,6 +7058,20 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? ); } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); + } } } Ok(Field { @@ -7043,6 +7080,8 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ae20913e3dd7..cfa424e444f8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1027,6 +1027,10 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 13576aaa089a..a8c9b92ee887 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -377,8 +377,20 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - Ok(Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone())) + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 0af8d9f3e719..4c6fdaa894ae 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -108,6 +108,8 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index bb201ccf23dc..8e15b5d0d480 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -990,15 +990,23 @@ fn roundtrip_dict_id() -> Result<()> { // encode let mut buf: Vec = vec![]; - let schema_proto: datafusion_proto::generated::datafusion::Schema = schema.try_into().unwrap(); + let schema_proto: datafusion_proto::generated::datafusion::Schema = + schema.try_into().unwrap(); schema_proto.encode(&mut buf).unwrap(); // decode - let schema_proto = datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let schema_proto = + datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); let decoded: Schema = (&schema_proto).try_into()?; - let decoded_field = decoded.fields().iter().last().unwrap(); - assert_eq!(decoded_field.dict_id(), Some(dict_id), "dict_id should be retained"); + // assert + let keys = decoded.fields().iter().last().unwrap(); + match keys.data_type() { + DataType::List(field) => { + assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); + } + _ => panic!("Invalid type"), + } Ok(()) }