Skip to content

Commit

Permalink
Change schema_infer_max_rec config to use Option<usize> rather t…
Browse files Browse the repository at this point in the history
…han `usize` (#13250)

* Make schema_infer_max_rec an Option

* Add lifetime parameter to CSV and compression BoxStreams
  • Loading branch information
alihan-synnada authored Nov 6, 2024
1 parent 6612d7c commit 39aa15e
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 55 deletions.
6 changes: 3 additions & 3 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1648,7 +1648,7 @@ config_namespace! {
/// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting.
pub newlines_in_values: Option<bool>, default = None
pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
pub schema_infer_max_rec: usize, default = 100
pub schema_infer_max_rec: Option<usize>, default = None
pub date_format: Option<String>, default = None
pub datetime_format: Option<String>, default = None
pub timestamp_format: Option<String>, default = None
Expand All @@ -1673,7 +1673,7 @@ impl CsvOptions {
/// Set a limit in terms of records to scan to infer the schema
/// - default to `DEFAULT_SCHEMA_INFER_MAX_RECORD`
pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
self.schema_infer_max_rec = max_rec;
self.schema_infer_max_rec = Some(max_rec);
self
}

Expand Down Expand Up @@ -1773,7 +1773,7 @@ config_namespace! {
/// Options controlling JSON format
pub struct JsonOptions {
pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
pub schema_infer_max_rec: usize, default = 100
pub schema_infer_max_rec: Option<usize>, default = None
}
}

Expand Down
24 changes: 15 additions & 9 deletions datafusion/core/src/datasource/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::fmt::{self, Debug};
use std::sync::Arc;

use super::write::orchestration::stateless_multipart_put;
use super::{FileFormat, FileFormatFactory};
use super::{FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD};
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::file_format::write::BatchSerializer;
use crate::datasource::physical_plan::{
Expand Down Expand Up @@ -137,11 +137,11 @@ impl CsvFormat {
/// Return a newline delimited stream from the specified file on
/// Stream, decompressing if necessary
/// Each returned `Bytes` has a whole number of newline delimited rows
async fn read_to_delimited_chunks(
async fn read_to_delimited_chunks<'a>(
&self,
store: &Arc<dyn ObjectStore>,
object: &ObjectMeta,
) -> BoxStream<'static, Result<Bytes>> {
) -> BoxStream<'a, Result<Bytes>> {
// stream to only read as many rows as needed into memory
let stream = store
.get(&object.location)
Expand All @@ -165,10 +165,10 @@ impl CsvFormat {
stream.boxed()
}

async fn read_to_delimited_chunks_from_stream(
async fn read_to_delimited_chunks_from_stream<'a>(
&self,
stream: BoxStream<'static, Result<Bytes>>,
) -> BoxStream<'static, Result<Bytes>> {
stream: BoxStream<'a, Result<Bytes>>,
) -> BoxStream<'a, Result<Bytes>> {
let file_compression_type: FileCompressionType = self.options.compression.into();
let decoder = file_compression_type.convert_stream(stream);
let steam = match decoder {
Expand Down Expand Up @@ -204,7 +204,7 @@ impl CsvFormat {
/// Set a limit in terms of records to scan to infer the schema
/// - default to `DEFAULT_SCHEMA_INFER_MAX_RECORD`
pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
self.options.schema_infer_max_rec = max_rec;
self.options.schema_infer_max_rec = Some(max_rec);
self
}

Expand Down Expand Up @@ -319,7 +319,10 @@ impl FileFormat for CsvFormat {
) -> Result<SchemaRef> {
let mut schemas = vec![];

let mut records_to_read = self.options.schema_infer_max_rec;
let mut records_to_read = self
.options
.schema_infer_max_rec
.unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

for object in objects {
let stream = self.read_to_delimited_chunks(store, object).await;
Expand Down Expand Up @@ -945,7 +948,10 @@ mod tests {
let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap();
let path = Path::from("csv/aggregate_test_100.csv");
let csv = CsvFormat::default().with_has_header(true);
let records_to_read = csv.options().schema_infer_max_rec;
let records_to_read = csv
.options()
.schema_infer_max_rec
.unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);
let store = Arc::new(integration) as Arc<dyn ObjectStore>;
let original_stream = store.get(&path).await?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ impl FileCompressionType {
}

/// Given a `Stream`, create a `Stream` which data are compressed with `FileCompressionType`.
pub fn convert_to_compress_stream(
pub fn convert_to_compress_stream<'a>(
&self,
s: BoxStream<'static, Result<Bytes>>,
) -> Result<BoxStream<'static, Result<Bytes>>> {
s: BoxStream<'a, Result<Bytes>>,
) -> Result<BoxStream<'a, Result<Bytes>>> {
Ok(match self.variant {
#[cfg(feature = "compression")]
GZIP => ReaderStream::new(AsyncGzEncoder::new(StreamReader::new(s)))
Expand Down Expand Up @@ -180,10 +180,10 @@ impl FileCompressionType {
}

/// Given a `Stream`, create a `Stream` which data are decompressed with `FileCompressionType`.
pub fn convert_stream(
pub fn convert_stream<'a>(
&self,
s: BoxStream<'static, Result<Bytes>>,
) -> Result<BoxStream<'static, Result<Bytes>>> {
s: BoxStream<'a, Result<Bytes>>,
) -> Result<BoxStream<'a, Result<Bytes>>> {
Ok(match self.variant {
#[cfg(feature = "compression")]
GZIP => {
Expand Down
11 changes: 8 additions & 3 deletions datafusion/core/src/datasource/file_format/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use std::io::BufReader;
use std::sync::Arc;

use super::write::orchestration::stateless_multipart_put;
use super::{FileFormat, FileFormatFactory, FileScanConfig};
use super::{
FileFormat, FileFormatFactory, FileScanConfig, DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::file_format::write::BatchSerializer;
use crate::datasource::physical_plan::FileGroupDisplay;
Expand Down Expand Up @@ -147,7 +149,7 @@ impl JsonFormat {
/// Set a limit in terms of records to scan to infer the schema
/// - defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD`
pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
self.options.schema_infer_max_rec = max_rec;
self.options.schema_infer_max_rec = Some(max_rec);
self
}

Expand Down Expand Up @@ -187,7 +189,10 @@ impl FileFormat for JsonFormat {
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
let mut schemas = Vec::new();
let mut records_to_read = self.options.schema_infer_max_rec;
let mut records_to_read = self
.options
.schema_infer_max_rec
.unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);
let file_compression_type = FileCompressionType::from(self.options.compression);
for object in objects {
let mut take_while = || {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing_table_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ mod tests {
let format = listing_table.options().format.clone();
let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
let csv_options = csv_format.options().clone();
assert_eq!(csv_options.schema_infer_max_rec, 1000);
assert_eq!(csv_options.schema_infer_max_rec, Some(1000));
let listing_options = listing_table.options();
assert_eq!(".tbl", listing_options.file_extension);
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ message CsvOptions {
bytes quote = 3; // Quote character as a byte
bytes escape = 4; // Optional escape character as a byte
CompressionTypeVariant compression = 5; // Compression type
uint64 schema_infer_max_rec = 6; // Max records for schema inference
optional uint64 schema_infer_max_rec = 6; // Optional max records for schema inference
string date_format = 7; // Optional date format
string datetime_format = 8; // Optional datetime format
string timestamp_format = 9; // Optional timestamp format
Expand All @@ -430,7 +430,7 @@ message CsvOptions {
// Options controlling CSV format
message JsonOptions {
CompressionTypeVariant compression = 1; // Compression type
uint64 schema_infer_max_rec = 2; // Max records for schema inference
optional uint64 schema_infer_max_rec = 2; // Optional max records for schema inference
}

message TableParquetOptions {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions {
double_quote: proto_opts.has_header.first().map(|h| *h != 0),
newlines_in_values: proto_opts.newlines_in_values.first().map(|h| *h != 0),
compression: proto_opts.compression().into(),
schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize,
schema_infer_max_rec: proto_opts.schema_infer_max_rec.map(|h| h as usize),
date_format: (!proto_opts.date_format.is_empty())
.then(|| proto_opts.date_format.clone()),
datetime_format: (!proto_opts.datetime_format.is_empty())
Expand Down Expand Up @@ -1050,7 +1050,7 @@ impl TryFrom<&protobuf::JsonOptions> for JsonOptions {
let compression: protobuf::CompressionTypeVariant = proto_opts.compression();
Ok(JsonOptions {
compression: compression.into(),
schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize,
schema_infer_max_rec: proto_opts.schema_infer_max_rec.map(|h| h as usize),
})
}
}
Expand Down
20 changes: 10 additions & 10 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1512,7 +1512,7 @@ impl serde::Serialize for CsvOptions {
if self.compression != 0 {
len += 1;
}
if self.schema_infer_max_rec != 0 {
if self.schema_infer_max_rec.is_some() {
len += 1;
}
if !self.date_format.is_empty() {
Expand Down Expand Up @@ -1571,10 +1571,10 @@ impl serde::Serialize for CsvOptions {
.map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?;
struct_ser.serialize_field("compression", &v)?;
}
if self.schema_infer_max_rec != 0 {
if let Some(v) = self.schema_infer_max_rec.as_ref() {
#[allow(clippy::needless_borrow)]
#[allow(clippy::needless_borrows_for_generic_args)]
struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?;
struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&v).as_str())?;
}
if !self.date_format.is_empty() {
struct_ser.serialize_field("dateFormat", &self.date_format)?;
Expand Down Expand Up @@ -1787,7 +1787,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
return Err(serde::de::Error::duplicate_field("schemaInferMaxRec"));
}
schema_infer_max_rec__ =
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
;
}
GeneratedField::DateFormat => {
Expand Down Expand Up @@ -1866,7 +1866,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions {
quote: quote__.unwrap_or_default(),
escape: escape__.unwrap_or_default(),
compression: compression__.unwrap_or_default(),
schema_infer_max_rec: schema_infer_max_rec__.unwrap_or_default(),
schema_infer_max_rec: schema_infer_max_rec__,
date_format: date_format__.unwrap_or_default(),
datetime_format: datetime_format__.unwrap_or_default(),
timestamp_format: timestamp_format__.unwrap_or_default(),
Expand Down Expand Up @@ -3929,7 +3929,7 @@ impl serde::Serialize for JsonOptions {
if self.compression != 0 {
len += 1;
}
if self.schema_infer_max_rec != 0 {
if self.schema_infer_max_rec.is_some() {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.JsonOptions", len)?;
Expand All @@ -3938,10 +3938,10 @@ impl serde::Serialize for JsonOptions {
.map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?;
struct_ser.serialize_field("compression", &v)?;
}
if self.schema_infer_max_rec != 0 {
if let Some(v) = self.schema_infer_max_rec.as_ref() {
#[allow(clippy::needless_borrow)]
#[allow(clippy::needless_borrows_for_generic_args)]
struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?;
struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&v).as_str())?;
}
struct_ser.end()
}
Expand Down Expand Up @@ -4019,14 +4019,14 @@ impl<'de> serde::Deserialize<'de> for JsonOptions {
return Err(serde::de::Error::duplicate_field("schemaInferMaxRec"));
}
schema_infer_max_rec__ =
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
;
}
}
}
Ok(JsonOptions {
compression: compression__.unwrap_or_default(),
schema_infer_max_rec: schema_infer_max_rec__.unwrap_or_default(),
schema_infer_max_rec: schema_infer_max_rec__,
})
}
}
Expand Down
12 changes: 6 additions & 6 deletions datafusion/proto-common/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,9 @@ pub struct CsvOptions {
/// Compression type
#[prost(enumeration = "CompressionTypeVariant", tag = "5")]
pub compression: i32,
/// Max records for schema inference
#[prost(uint64, tag = "6")]
pub schema_infer_max_rec: u64,
/// Optional max records for schema inference
#[prost(uint64, optional, tag = "6")]
pub schema_infer_max_rec: ::core::option::Option<u64>,
/// Optional date format
#[prost(string, tag = "7")]
pub date_format: ::prost::alloc::string::String,
Expand Down Expand Up @@ -612,9 +612,9 @@ pub struct JsonOptions {
/// Compression type
#[prost(enumeration = "CompressionTypeVariant", tag = "1")]
pub compression: i32,
/// Max records for schema inference
#[prost(uint64, tag = "2")]
pub schema_infer_max_rec: u64,
/// Optional max records for schema inference
#[prost(uint64, optional, tag = "2")]
pub schema_infer_max_rec: ::core::option::Option<u64>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TableParquetOptions {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto-common/src/to_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions {
.newlines_in_values
.map_or_else(Vec::new, |h| vec![h as u8]),
compression: compression.into(),
schema_infer_max_rec: opts.schema_infer_max_rec as u64,
schema_infer_max_rec: opts.schema_infer_max_rec.map(|h| h as u64),
date_format: opts.date_format.clone().unwrap_or_default(),
datetime_format: opts.datetime_format.clone().unwrap_or_default(),
timestamp_format: opts.timestamp_format.clone().unwrap_or_default(),
Expand All @@ -940,7 +940,7 @@ impl TryFrom<&JsonOptions> for protobuf::JsonOptions {
let compression: protobuf::CompressionTypeVariant = opts.compression.into();
Ok(protobuf::JsonOptions {
compression: compression.into(),
schema_infer_max_rec: opts.schema_infer_max_rec as u64,
schema_infer_max_rec: opts.schema_infer_max_rec.map(|h| h as u64),
})
}
}
Expand Down
12 changes: 6 additions & 6 deletions datafusion/proto/src/generated/datafusion_proto_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,9 @@ pub struct CsvOptions {
/// Compression type
#[prost(enumeration = "CompressionTypeVariant", tag = "5")]
pub compression: i32,
/// Max records for schema inference
#[prost(uint64, tag = "6")]
pub schema_infer_max_rec: u64,
/// Optional max records for schema inference
#[prost(uint64, optional, tag = "6")]
pub schema_infer_max_rec: ::core::option::Option<u64>,
/// Optional date format
#[prost(string, tag = "7")]
pub date_format: ::prost::alloc::string::String,
Expand Down Expand Up @@ -612,9 +612,9 @@ pub struct JsonOptions {
/// Compression type
#[prost(enumeration = "CompressionTypeVariant", tag = "1")]
pub compression: i32,
/// Max records for schema inference
#[prost(uint64, tag = "2")]
pub schema_infer_max_rec: u64,
/// Optional max records for schema inference
#[prost(uint64, optional, tag = "2")]
pub schema_infer_max_rec: ::core::option::Option<u64>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TableParquetOptions {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/proto/src/logical_plan/file_formats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl CsvOptionsProto {
escape: options.escape.map_or(vec![], |v| vec![v]),
double_quote: options.double_quote.map_or(vec![], |v| vec![v as u8]),
compression: options.compression as i32,
schema_infer_max_rec: options.schema_infer_max_rec as u64,
schema_infer_max_rec: options.schema_infer_max_rec.map(|v| v as u64),
date_format: options.date_format.clone().unwrap_or_default(),
datetime_format: options.datetime_format.clone().unwrap_or_default(),
timestamp_format: options.timestamp_format.clone().unwrap_or_default(),
Expand Down Expand Up @@ -110,7 +110,7 @@ impl From<&CsvOptionsProto> for CsvOptions {
3 => CompressionTypeVariant::ZSTD,
_ => CompressionTypeVariant::UNCOMPRESSED,
},
schema_infer_max_rec: proto.schema_infer_max_rec as usize,
schema_infer_max_rec: proto.schema_infer_max_rec.map(|v| v as usize),
date_format: if proto.date_format.is_empty() {
None
} else {
Expand Down Expand Up @@ -239,7 +239,7 @@ impl JsonOptionsProto {
if let Some(options) = &factory.options {
JsonOptionsProto {
compression: options.compression as i32,
schema_infer_max_rec: options.schema_infer_max_rec as u64,
schema_infer_max_rec: options.schema_infer_max_rec.map(|v| v as u64),
}
} else {
JsonOptionsProto::default()
Expand All @@ -257,7 +257,7 @@ impl From<&JsonOptionsProto> for JsonOptions {
3 => CompressionTypeVariant::ZSTD,
_ => CompressionTypeVariant::UNCOMPRESSED,
},
schema_infer_max_rec: proto.schema_infer_max_rec as usize,
schema_infer_max_rec: proto.schema_infer_max_rec.map(|v| v as usize),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ async fn roundtrip_logical_plan_copy_to_json() -> Result<()> {

// Set specific JSON format options
json_format.compression = CompressionTypeVariant::GZIP;
json_format.schema_infer_max_rec = 1000;
json_format.schema_infer_max_rec = Some(1000);

let file_type = format_as_file_type(Arc::new(JsonFormatFactory::new_with_options(
json_format.clone(),
Expand Down

0 comments on commit 39aa15e

Please sign in to comment.