Skip to content

Commit

Permalink
Update inference logic
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Feb 22, 2023
1 parent 5338306 commit 724ce4a
Showing 1 changed file with 128 additions and 68 deletions.
196 changes: 128 additions & 68 deletions arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ use arrow_cast::parse::Parser;
use arrow_schema::*;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::collections::HashSet;
use std::fmt;
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
Expand All @@ -62,48 +61,70 @@ use crate::reader::records::{RecordDecoder, StringRecords};
use csv::StringRecord;

lazy_static! {
/// Order should match [`InferredDataType`]
static ref REGEX_SET: RegexSet = RegexSet::new([
r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN
r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$", //DECIMAL
r"^-?(\d+)$", //INTEGER
r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$", //DECIMAL
r"^\d{4}-\d\d-\d\d$", //DATE32
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d$", //Timestamp(Second)
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,3}$", //Timestamp(Millisecond)
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,6}$", //Timestamp(Microsecond)
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,9}$", //Timestamp(Nanosecond)
]).unwrap();
//The order should match with REGEX_SET
static ref MATCH_DATA_TYPE: Vec<DataType> = vec![
DataType::Boolean,
DataType::Float64,
DataType::Int64,
DataType::Date32,
DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, None),
];
static ref PARSE_DECIMAL_RE: Regex =
Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap();
}

/// Infer the data type of a record
fn infer_field_schema(string: &str, datetime_re: Option<Regex>) -> DataType {
// when quoting is enabled in the reader, these quotes aren't escaped, we default to
// Utf8 for them
if string.starts_with('"') {
return DataType::Utf8;
}
let matches = REGEX_SET.matches(string).into_iter().next();
// match regex in a particular order
match matches {
Some(ix) => MATCH_DATA_TYPE[ix].clone(),
None => match datetime_re {
Some(datetime_re) if datetime_re.is_match(string) => {
DataType::Timestamp(TimeUnit::Nanosecond, None)
}
#[derive(Default, Copy, Clone)]
struct InferredDataType {
/// Packed booleans indicating type
///
/// 0 - Boolean
/// 1 - Integer
/// 2 - Float64
/// 3 - Date32
/// 4 - Timestamp(Second)
/// 5 - Timestamp(Millisecond)
/// 6 - Timestamp(Microsecond)
/// 7 - Timestamp(Nanosecond)
/// 8 - Utf8
packed: u16,
}

impl InferredDataType {
/// Returns the inferred data type
fn get(&self) -> DataType {
match self.packed {
1 => DataType::Boolean,
2 => DataType::Int64,
4 | 6 => DataType::Float64, // Promote Int64 to Float64
b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
// Promote to highest precision timestamp
8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
9 => DataType::Timestamp(TimeUnit::Microsecond, None),
10 => DataType::Timestamp(TimeUnit::Millisecond, None),
11 => DataType::Timestamp(TimeUnit::Second, None),
12 => DataType::Date32,
_ => unreachable!(),
},
_ => DataType::Utf8,
},
}
}

/// Updates the [`InferredDataType`] with the given string
fn update(&mut self, string: &str, datetime_re: Option<&Regex>) {
self.packed |= if string.starts_with('"') {
1 << 8 // Utf8
} else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
1 << m
} else {
match datetime_re {
// Timestamp(Nanosecond)
Some(d) if d.is_match(string) => 1 << 7,
_ => 1 << 8, // Utf8
}
}
}
}

Expand Down Expand Up @@ -232,10 +253,9 @@ fn infer_reader_schema_with_csv_options<R: Read>(

let header_length = headers.len();
// keep track of inferred field types
let mut column_types: Vec<HashSet<DataType>> = vec![HashSet::new(); header_length];
let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

let mut records_count = 0;
let mut fields = vec![];

let mut record = StringRecord::new();
let max_records = roptions.max_read_records.unwrap_or(usize::MAX);
Expand All @@ -250,40 +270,18 @@ fn infer_reader_schema_with_csv_options<R: Read>(
for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
if let Some(string) = record.get(i) {
if !string.is_empty() {
column_type
.insert(infer_field_schema(string, roptions.datetime_re.clone()));
column_type.update(string, roptions.datetime_re.as_ref())
}
}
}
}

// build schema from inference results
for i in 0..header_length {
let possibilities = &column_types[i];
let field_name = &headers[i];

// determine data type based on possible types
// if there are incompatible types, use DataType::Utf8
match possibilities.len() {
1 => {
for dtype in possibilities.iter() {
fields.push(Field::new(field_name, dtype.clone(), true));
}
}
2 => {
if possibilities.contains(&DataType::Int64)
&& possibilities.contains(&DataType::Float64)
{
// we have an integer and double, fall down to double
fields.push(Field::new(field_name, DataType::Float64, true));
} else {
// default to Utf8 for conflicting datatypes (e.g bool and int)
fields.push(Field::new(field_name, DataType::Utf8, true));
}
}
_ => fields.push(Field::new(field_name, DataType::Utf8, true)),
}
}
let fields = column_types
.iter()
.zip(&headers)
.map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
.collect();

Ok((Schema::new(fields), records_count))
}
Expand Down Expand Up @@ -683,14 +681,11 @@ fn parse(
>(
line_number, rows, i, None
),
DataType::Timestamp(TimeUnit::Second, _) => {
build_primitive_array::<TimestampSecondType>(
line_number,
rows,
i,
None,
)
}
DataType::Timestamp(TimeUnit::Second, _) => build_primitive_array::<
TimestampSecondType,
>(
line_number, rows, i, None
),
DataType::Timestamp(TimeUnit::Millisecond, _) => {
build_primitive_array::<TimestampMillisecondType>(
line_number,
Expand Down Expand Up @@ -1655,7 +1650,10 @@ mod tests {
assert_eq!(&DataType::Float64, schema.field(2).data_type());
assert_eq!(&DataType::Boolean, schema.field(3).data_type());
assert_eq!(&DataType::Date32, schema.field(4).data_type());
assert_eq!(&DataType::Timestamp(TimeUnit::Second, None), schema.field(5).data_type());
assert_eq!(
&DataType::Timestamp(TimeUnit::Second, None),
schema.field(5).data_type()
);

let names: Vec<&str> =
schema.fields().iter().map(|x| x.name().as_str()).collect();
Expand Down Expand Up @@ -1716,6 +1714,13 @@ mod tests {
}
}

/// Infer the data type of a record
fn infer_field_schema(string: &str, datetime_re: Option<Regex>) -> DataType {
let mut v = InferredDataType::default();
v.update(string, datetime_re.as_ref());
v.get()
}

#[test]
fn test_infer_field_schema() {
assert_eq!(infer_field_schema("A", None), DataType::Utf8);
Expand Down Expand Up @@ -2425,4 +2430,59 @@ mod tests {
assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
assert_eq!(read.fill_count, 4);
}

#[test]
fn test_inference() {
let cases: &[(&[&str], DataType)] = &[
(&[], DataType::Utf8),
(&["false", "12"], DataType::Utf8),
(&["12", "cupcakes"], DataType::Utf8),
(&["12", "12.4"], DataType::Float64),
(&["14050", "24332"], DataType::Int64),
(&["14050.0", "true"], DataType::Utf8),
(&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
(&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
(
&["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
DataType::Timestamp(TimeUnit::Second, None),
),
(&["2020-03-19", "2020-03-20"], DataType::Date32),
(
&["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
DataType::Timestamp(TimeUnit::Second, None),
),
(
&[
"2020-03-19",
"2020-03-19 02:00:00",
"2020-03-19 00:00:00.000",
],
DataType::Timestamp(TimeUnit::Millisecond, None),
),
(
&[
"2020-03-19",
"2020-03-19 02:00:00",
"2020-03-19 00:00:00.000000",
],
DataType::Timestamp(TimeUnit::Microsecond, None),
),
(
&[
"2020-03-19",
"2020-03-19 02:00:00.000000000",
"2020-03-19 00:00:00.000000",
],
DataType::Timestamp(TimeUnit::Nanosecond, None),
),
];

for (values, expected) in cases {
let mut t = InferredDataType::default();
for v in *values {
t.update(v, None)
}
assert_eq!(&t.get(), expected, "{:?}", values)
}
}
}

0 comments on commit 724ce4a

Please sign in to comment.