diff --git a/rust/src/delta_arrow.rs b/rust/src/delta_arrow.rs
index 5014910b22..0057175df2 100644
--- a/rust/src/delta_arrow.rs
+++ b/rust/src/delta_arrow.rs
@@ -28,21 +28,20 @@ impl TryFrom<&schema::SchemaField> for ArrowField {
     type Error = ArrowError;
 
     fn try_from(f: &schema::SchemaField) -> Result<Self, ArrowError> {
-        let mut field = ArrowField::new(
-            f.get_name(),
-            ArrowDataType::try_from(f.get_type())?,
-            f.is_nullable(),
-        );
-
         let metadata = f
             .get_metadata()
-            .to_owned()
             .iter()
             .map(|(key, val)| Ok((key.clone(), serde_json::to_string(val)?)))
             .collect::<Result<_, serde_json::Error>>()
             .map_err(|err| ArrowError::JsonError(err.to_string()))?;
 
-        field.set_metadata(metadata);
+        let field = ArrowField::new(
+            f.get_name(),
+            ArrowDataType::try_from(f.get_type())?,
+            f.is_nullable(),
+        )
+        .with_metadata(metadata);
+
         Ok(field)
     }
 }
@@ -200,7 +199,6 @@ impl TryFrom<&ArrowField> for schema::SchemaField {
             arrow_field.is_nullable(),
             arrow_field
                 .metadata()
-                .to_owned()
                 .iter()
                 .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
                 .collect(),
diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs
index 78c4f8a84e..e756a9b2fc 100644
--- a/rust/src/delta_datafusion.rs
+++ b/rust/src/delta_datafusion.rs
@@ -29,6 +29,7 @@ use std::sync::Arc;
 use arrow::array::ArrayRef;
 use arrow::compute::{cast_with_options, CastOptions};
 use arrow::datatypes::{DataType as ArrowDataType, Schema as ArrowSchema, SchemaRef, TimeUnit};
+use arrow::error::ArrowError;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
 use chrono::{DateTime, NaiveDateTime, Utc};
@@ -42,7 +43,7 @@ use datafusion::execution::FunctionRegistry;
 use datafusion::optimizer::utils::conjunction;
 use datafusion::physical_expr::PhysicalSortExpr;
 use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
-use datafusion::physical_plan::file_format::FileScanConfig;
+use datafusion::physical_plan::file_format::{partition_type_wrap, FileScanConfig};
 use datafusion::physical_plan::{
     ColumnStatistics, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };
@@ -375,26 +376,12 @@ impl TableProvider for DeltaTable {
             });
         };
 
-        let table_partition_cols = self
-            .get_metadata()?
-            .partition_columns
-            .clone()
-            .into_iter()
-            .flat_map(|c| {
-                let f_types = schema
-                    .fields()
-                    .iter()
-                    .map(|e| (e.name(), e.data_type()))
-                    .filter(|&(col_name, _)| c.eq_ignore_ascii_case(col_name))
-                    .collect::<Vec<_>>();
-                f_types.first().map(|o| o.to_owned())
-            })
-            .collect::<HashMap<_, _>>();
+        let table_partition_cols = self.get_metadata()?.partition_columns.clone();
         let file_schema = Arc::new(ArrowSchema::new(
             schema
                 .fields()
                 .iter()
-                .filter(|f| !table_partition_cols.contains_key(f.name()))
+                .filter(|f| !table_partition_cols.contains(f.name()))
                 .cloned()
                 .collect(),
         ));
@@ -406,12 +393,17 @@ impl TableProvider for DeltaTable {
                     file_schema,
                     file_groups: file_groups.into_values().collect(),
                     statistics: self.datafusion_table_statistics(),
-                    projection: projection.map(|o| o.to_owned()),
+                    projection: projection.cloned(),
                     limit,
                     table_partition_cols: table_partition_cols
-                        .into_iter()
-                        .map(|(a, b)| (a.to_owned(), b.to_owned()))
-                        .collect::<Vec<_>>(),
+                        .iter()
+                        .map(|c| {
+                            Ok((
+                                c.to_owned(),
+                                partition_type_wrap(schema.field_with_name(c)?.data_type().clone()),
+                            ))
+                        })
+                        .collect::<Result<Vec<_>, ArrowError>>()?,
                     output_ordering: None,
                     config_options: Default::default(),
                 },
diff --git a/rust/src/operations/write.rs b/rust/src/operations/write.rs
index 51f2374f71..9ce9bf5b0a 100644
--- a/rust/src/operations/write.rs
+++ b/rust/src/operations/write.rs
@@ -33,7 +33,7 @@ use crate::storage::DeltaObjectStore;
 use crate::writer::record_batch::divide_by_partition_values;
 use crate::writer::utils::PartitionPath;
 
-use arrow::datatypes::SchemaRef as ArrowSchemaRef;
+use arrow::datatypes::{DataType, SchemaRef as ArrowSchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion::execution::context::{SessionContext, TaskContext};
 use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan};
@@ -196,6 +196,18 @@ impl std::future::IntoFuture for WriteBuilder {
     fn into_future(self) -> Self::IntoFuture {
         let this = self;
 
+        fn schema_to_vec_name_type(schema: ArrowSchemaRef) -> Vec<(String, DataType)> {
+            schema
+                .fields()
+                .iter()
+                .map(|f| (f.name().to_owned(), f.data_type().clone()))
+                .collect::<Vec<_>>()
+        }
+
+        fn schema_eq(l: ArrowSchemaRef, r: ArrowSchemaRef) -> bool {
+            schema_to_vec_name_type(l) == schema_to_vec_name_type(r)
+        }
+
         Box::pin(async move {
             let object_store = if let Some(store) = this.object_store {
                 Ok(store)
@@ -274,7 +286,8 @@ impl std::future::IntoFuture for WriteBuilder {
 
                     if let Ok(meta) = table.get_metadata() {
                         let curr_schema: ArrowSchemaRef = Arc::new((&meta.schema).try_into()?);
-                        if schema != curr_schema {
+
+                        if !schema_eq(curr_schema, schema.clone()) {
                             return Err(DeltaTableError::Generic(
                                 "Updating table schema not yet implemented".to_string(),
                             ));
diff --git a/rust/tests/datafusion_test.rs b/rust/tests/datafusion_test.rs
index c521f91ca0..8684fe29fe 100644
--- a/rust/tests/datafusion_test.rs
+++ b/rust/tests/datafusion_test.rs
@@ -86,7 +86,7 @@ async fn prepare_table(
 #[tokio::test]
 async fn test_datafusion_sql_registration() -> Result<()> {
     let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> = HashMap::new();
-    table_factories.insert("deltatable".to_string(), Arc::new(DeltaTableFactory {}));
+    table_factories.insert("DELTATABLE".to_string(), Arc::new(DeltaTableFactory {}));
     let cfg = RuntimeConfig::new().with_table_factories(table_factories);
     let env = RuntimeEnv::new(cfg).unwrap();
     let ses = SessionConfig::new();