Skip to content

Commit

Permalink
feat: roundtrip FixedSizeList Scalar to protobuf
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Nov 16, 2023
1 parent b013087 commit 5b90427
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 18 deletions.
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ message ScalarValue{
int32 date_32_value = 14;
ScalarTime32Value time32_value = 15;
ScalarListValue list_value = 17;
ScalarListValue fixed_size_list_value = 18;

Decimal128 decimal128_value = 20;
Decimal256 decimal256_value = 39;
Expand Down
14 changes: 14 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::Float64Value(v) => Self::Float64(Some(*v)),
Value::Date32Value(v) => Self::Date32(Some(*v)),
// ScalarValue::List is serialized using arrow IPC format
Value::ListValue(scalar_list) => {
Value::ListValue(scalar_list) | Value::FixedSizeListValue(scalar_list) => {
let protobuf::ScalarListValue {
ipc_message,
arrow_data,
Expand Down Expand Up @@ -698,7 +698,11 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
.map_err(DataFusionError::ArrowError)
.map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
let arr = record_batch.column(0);
Self::List(arr.to_owned())
match value {
Value::ListValue(_) => Self::List(arr.to_owned()),
Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()),
_ => unreachable!(),
}
}
Value::NullValue(v) => {
let null_type: DataType = v.try_into()?;
Expand Down
28 changes: 16 additions & 12 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,13 +1134,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
Value::LargeUtf8Value(s.to_owned())
})
}
ScalarValue::FixedSizeList(..) => Err(Error::General(
"Proto serialization error: ScalarValue::Fixedsizelist not supported"
.to_string(),
)),
// ScalarValue::List is serialized using Arrow IPC messages.
// as a single column RecordBatch
ScalarValue::List(arr) => {
// ScalarValue::List and ScalarValue::FixedSizeList are serialized using
// Arrow IPC messages as a single column RecordBatch
ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => {
// Wrap in a "field_name" column
let batch = RecordBatch::try_from_iter(vec![(
"field_name",
Expand Down Expand Up @@ -1168,11 +1164,19 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
schema: Some(schema),
};

Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::ListValue(
scalar_list_value,
)),
})
match val {
ScalarValue::List(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::ListValue(
scalar_list_value,
)),
}),
ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::FixedSizeListValue(
scalar_list_value,
)),
}),
_ => unreachable!(),
}
}
ScalarValue::Date32(val) => {
create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s))
Expand Down
14 changes: 11 additions & 3 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ use std::collections::HashMap;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::array::{ArrayRef, FixedSizeListArray};
use arrow::datatypes::{
DataType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
};

use prost::Message;
Expand Down Expand Up @@ -690,6 +690,14 @@ fn round_trip_scalar_values() {
],
&DataType::List(new_arc_field("item", DataType::Float32, true)),
)),
ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::<
Int32Type,
_,
_,
>(
vec![Some(vec![Some(1), Some(2), Some(3)])],
3,
))),
ScalarValue::Dictionary(
Box::new(DataType::Int32),
Box::new(ScalarValue::Utf8(Some("foo".into()))),
Expand Down

0 comments on commit 5b90427

Please sign in to comment.