diff --git a/datafusion/common/src/file_options/arrow_writer.rs b/datafusion/common/src/file_options/arrow_writer.rs index a30e6d800e20..cb921535aba5 100644 --- a/datafusion/common/src/file_options/arrow_writer.rs +++ b/datafusion/common/src/file_options/arrow_writer.rs @@ -27,6 +27,18 @@ use super::StatementOptions; #[derive(Clone, Debug)] pub struct ArrowWriterOptions {} +impl ArrowWriterOptions { + pub fn new() -> Self { + Self {} + } +} + +impl Default for ArrowWriterOptions { + fn default() -> Self { + Self::new() + } +} + impl TryFrom<(&ConfigOptions, &StatementOptions)> for ArrowWriterOptions { type Error = DataFusionError; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8bde0da133eb..d79879e57a7d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1213,6 +1213,7 @@ message FileTypeWriterOptions { JsonWriterOptions json_options = 1; ParquetWriterOptions parquet_options = 2; CsvWriterOptions csv_options = 3; + ArrowWriterOptions arrow_options = 4; } } @@ -1243,6 +1244,8 @@ message CsvWriterOptions { string null_value = 8; } +message ArrowWriterOptions {} + message WriterProperties { uint64 data_page_size_limit = 1; uint64 dictionary_page_size_limit = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 528761136ca3..d7ad6fb03c92 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1929,6 +1929,77 @@ impl<'de> serde::Deserialize<'de> for ArrowType { deserializer.deserialize_struct("datafusion.ArrowType", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ArrowWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion.ArrowWriterOptions", len)?; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ArrowWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ArrowWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ArrowWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(ArrowWriterOptions { + }) + } + } + deserializer.deserialize_struct("datafusion.ArrowWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AvroFormat { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -8354,6 +8425,9 @@ impl serde::Serialize for FileTypeWriterOptions { file_type_writer_options::FileType::CsvOptions(v) => { struct_ser.serialize_field("csvOptions", v)?; } + file_type_writer_options::FileType::ArrowOptions(v) => { + struct_ser.serialize_field("arrowOptions", v)?; + } } } struct_ser.end() @@ -8372,6 +8446,8 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { "parquetOptions", "csv_options", "csvOptions", + "arrow_options", + "arrowOptions", ]; #[allow(clippy::enum_variant_names)] @@ -8379,6 +8455,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { JsonOptions, ParquetOptions, CsvOptions, + ArrowOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8403,6 +8480,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions), + "arrowOptions" | "arrow_options" => Ok(GeneratedField::ArrowOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8444,6 +8522,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { return Err(serde::de::Error::duplicate_field("csvOptions")); } file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions) +; + } + GeneratedField::ArrowOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ArrowOptions) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9a0b7ab332a6..d594da90879c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1646,7 +1646,7 @@ pub struct PartitionColumn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileTypeWriterOptions { - #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")] + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3, 4")] pub file_type: ::core::option::Option, } /// Nested message and enum types in `FileTypeWriterOptions`. @@ -1660,6 +1660,8 @@ pub mod file_type_writer_options { ParquetOptions(super::ParquetWriterOptions), #[prost(message, tag = "3")] CsvOptions(super::CsvWriterOptions), + #[prost(message, tag = "4")] + ArrowOptions(super::ArrowWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1704,6 +1706,9 @@ pub struct CsvWriterOptions { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ArrowWriterOptions {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct WriterProperties { #[prost(uint64, tag = "1")] pub data_page_size_limit: u64, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 6ca95519a9b1..f10f11c1c093 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -16,6 +16,7 @@ // under the License. use arrow::csv::WriterBuilder; +use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -858,6 +859,13 @@ impl AsLogicalPlan for LogicalPlanNode { Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { match &opt.file_type { Some(ft) => match ft { + file_type_writer_options::FileType::ArrowOptions(_) => { + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Arrow( + ArrowWriterOptions::new(), + ), + )) + } file_type_writer_options::FileType::CsvOptions( writer_options, ) => { @@ -1659,6 +1667,17 @@ impl AsLogicalPlan for LogicalPlanNode { } CopyOptions::WriterOptions(opt) => { match opt.as_ref() { + FileTypeWriterOptions::Arrow(_) => { + let arrow_writer_options = + file_type_writer_options::FileType::ArrowOptions( + protobuf::ArrowWriterOptions {}, + ); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(arrow_writer_options), + }, + )) + } FileTypeWriterOptions::CSV(csv_opts) => { let csv_options = &csv_opts.writer_options; let csv_writer_options = csv_writer_options_to_proto( diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index ea28eeee8810..dc827d02bf25 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -42,6 +42,7 @@ use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; @@ -834,6 +835,10 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; match file_type { + protobuf::file_type_writer_options::FileType::ArrowOptions(_) => { + Ok(Self::Arrow(ArrowWriterOptions::new())) + } + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => { let compression: CompressionTypeVariant = opts.compression().into(); Ok(Self::JSON(JsonWriterOptions::new(compression))) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ed21124a9e22..2d38cfd400ad 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -27,6 +27,7 @@ use arrow::datatypes::{ IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion_common::file_options::arrow_writer::ArrowWriterOptions; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -394,6 +395,45 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.arrow".to_string(), + file_format: FileType::ARROW, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::Arrow( + ArrowWriterOptions::new(), + ))), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.arrow", copy_to.output_url); + assert_eq!(FileType::ARROW, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::Arrow(_) => {} + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + + Ok(()) +} + #[tokio::test] async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { let ctx = SessionContext::new();