diff --git a/src/io/json/read/infer_schema.rs b/src/io/json/read/infer_schema.rs index da5ce613d92..c54e05195dd 100644 --- a/src/io/json/read/infer_schema.rs +++ b/src/io/json/read/infer_schema.rs @@ -31,132 +31,47 @@ use crate::error::{ArrowError, Result}; /// * `Int64` and `Float64` should be `Float64` /// * Lists and scalars are coerced to a list of a compatible scalar /// * All other types are coerced to `Utf8` -fn coerce_data_type(dt: Vec<&DataType>) -> Result { - match dt.len() { - 1 => Ok(dt[0].clone()), - 2 => { - // there can be a case where a list and scalar both exist - if dt.contains(&&DataType::List(Box::new(Field::new( - "item", - DataType::Float64, - true, - )))) || dt.contains(&&DataType::List(Box::new(Field::new( - "item", - DataType::Int64, - true, - )))) || dt.contains(&&DataType::List(Box::new(Field::new( - "item", - DataType::Boolean, - true, - )))) || dt.contains(&&DataType::List(Box::new(Field::new( - "item", - DataType::Utf8, - true, - )))) { - // we have a list and scalars, so we should get the values and coerce them - let mut dt = dt; - // sorting guarantees that the list will be the second value - dt.sort(); - match (dt[0], dt[1]) { - (t1, DataType::List(e)) if e.data_type() == &DataType::Float64 => { - if t1 == &DataType::Float64 { - Ok(DataType::List(Box::new(Field::new( - "item", - DataType::Float64, - true, - )))) - } else { - Ok(DataType::List(Box::new(Field::new( - "item", - coerce_data_type(vec![t1, &DataType::Float64])?, - true, - )))) - } - } - (t1, DataType::List(e)) if e.data_type() == &DataType::Int64 => { - if t1 == &DataType::Int64 { - Ok(DataType::List(Box::new(Field::new( - "item", - DataType::Int64, - true, - )))) - } else { - Ok(DataType::List(Box::new(Field::new( - "item", - coerce_data_type(vec![t1, &DataType::Int64])?, - true, - )))) - } - } - (t1, DataType::List(e)) if e.data_type() == &DataType::Boolean => { - if t1 == &DataType::Boolean { - Ok(DataType::List(Box::new(Field::new( - "item", - DataType::Boolean, - true, - )))) - } else { - Ok(DataType::List(Box::new(Field::new( - "item", - coerce_data_type(vec![t1, &DataType::Boolean])?, - true, - )))) - } - } - (t1, DataType::List(e)) if e.data_type() == &DataType::Utf8 => { - if t1 == &DataType::Utf8 { - Ok(DataType::List(Box::new(Field::new( - "item", - DataType::Utf8, - true, - )))) - } else { - Ok(DataType::List(Box::new(Field::new( - "item", - coerce_data_type(vec![t1, &DataType::Utf8])?, - true, - )))) - } - } - (t1, t2) => Err(ArrowError::Schema(format!( - "Cannot coerce data types for {:?} and {:?}", - t1, t2 - ))), - } - } else if dt.contains(&&DataType::Float64) && dt.contains(&&DataType::Int64) { - Ok(DataType::Float64) - } else { - Ok(DataType::Utf8) - } +fn coerce_data_type(dt: &[&DataType]) -> DataType { + use DataType::*; + if dt.len() == 1 { + return dt[0].clone(); + } else if dt.len() > 2 { + return List(Box::new(Field::new("item", Utf8, true))); + } + let (lhs, rhs) = (dt[0], dt[1]); + + return match (lhs, rhs) { + (lhs, rhs) if lhs == rhs => lhs.clone(), + (List(lhs), List(rhs)) => { + let inner = coerce_data_type(&[lhs.data_type(), rhs.data_type()]); + List(Box::new(Field::new("item", inner, true))) } - _ => { - // TODO(nevi_me) It's possible to have [float, int, list(float)], which should - // return list(float). Will hash this out later - Ok(DataType::List(Box::new(Field::new( - "item", - DataType::Utf8, - true, - )))) + (scalar, List(list)) => { + let inner = coerce_data_type(&[scalar, list.data_type()]); + List(Box::new(Field::new("item", inner, true))) } - } + (List(list), scalar) => { + let inner = coerce_data_type(&[scalar, list.data_type()]); + List(Box::new(Field::new("item", inner, true))) + } + (Float64, Int64) => Float64, + (Int64, Float64) => Float64, + (Int64, Boolean) => Int64, + (Boolean, Int64) => Int64, + (_, _) => Utf8, + }; } /// Generate schema from JSON field names and inferred data types -fn generate_schema(spec: HashMap>) -> Result { - let fields: Result> = spec +fn generate_schema(spec: HashMap>) -> Schema { + let fields: Vec = spec .iter() .map(|(k, hs)| { let v: Vec<&DataType> = hs.iter().collect(); - coerce_data_type(v).map(|t| Field::new(k, t, true)) + Field::new(k, coerce_data_type(&v), true) }) .collect(); - match fields { - Ok(fields) => { - let schema = Schema::new(fields); - Ok(schema) - } - Err(e) => Err(e), - } + Schema::new(fields) } /// Infer the fields of a JSON file by reading the first n records of the buffer, with @@ -234,7 +149,7 @@ where // if a record contains only nulls, it is not // added to values if !types.is_empty() { - let dt = coerce_data_type(types)?; + let dt = coerce_data_type(&types); if values.contains_key(k) { let x = values.get_mut(k).unwrap(); @@ -329,7 +244,7 @@ where }; } - generate_schema(values) + Ok(generate_schema(values)) } /// Infer the fields of a JSON file by reading the first n records of the file, with @@ -376,36 +291,20 @@ mod test { assert_eq!( List(Box::new(Field::new("item", Float64, true))), - coerce_data_type(vec![ - &Float64, - &List(Box::new(Field::new("item", Float64, true))) - ]) - .unwrap() + coerce_data_type(&[&Float64, &List(Box::new(Field::new("item", Float64, true)))]) ); assert_eq!( List(Box::new(Field::new("item", Float64, true))), - coerce_data_type(vec![ - &Float64, - &List(Box::new(Field::new("item", Int64, true))) - ]) - .unwrap() + coerce_data_type(&[&Float64, &List(Box::new(Field::new("item", Int64, true)))]) ); assert_eq!( List(Box::new(Field::new("item", Int64, true))), - coerce_data_type(vec![ - &Int64, - &List(Box::new(Field::new("item", Int64, true))) - ]) - .unwrap() + coerce_data_type(&[&Int64, &List(Box::new(Field::new("item", Int64, true)))]) ); // boolean and number are incompatible, return utf8 assert_eq!( List(Box::new(Field::new("item", Utf8, true))), - coerce_data_type(vec![ - &Boolean, - &List(Box::new(Field::new("item", Float64, true))) - ]) - .unwrap() + coerce_data_type(&[&Boolean, &List(Box::new(Field::new("item", Float64, true)))]) ); } }