Skip to content

Commit

Permalink
Integrated feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-ionescu committed Dec 16, 2022
1 parent c8c0a03 commit 8ff24a8
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 33 deletions.
16 changes: 7 additions & 9 deletions rust/src/delta_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,20 @@ impl TryFrom<&schema::SchemaField> for ArrowField {
type Error = ArrowError;

fn try_from(f: &schema::SchemaField) -> Result<Self, ArrowError> {
let mut field = ArrowField::new(
f.get_name(),
ArrowDataType::try_from(f.get_type())?,
f.is_nullable(),
);

let metadata = f
.get_metadata()
.to_owned()
.iter()
.map(|(key, val)| Ok((key.clone(), serde_json::to_string(val)?)))
.collect::<Result<_, serde_json::Error>>()
.map_err(|err| ArrowError::JsonError(err.to_string()))?;

field.set_metadata(metadata);
let field = ArrowField::new(
f.get_name(),
ArrowDataType::try_from(f.get_type())?,
f.is_nullable(),
)
.with_metadata(metadata);

Ok(field)
}
}
Expand Down Expand Up @@ -200,7 +199,6 @@ impl TryFrom<&ArrowField> for schema::SchemaField {
arrow_field.is_nullable(),
arrow_field
.metadata()
.to_owned()
.iter()
.map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
.collect(),
Expand Down
34 changes: 13 additions & 21 deletions rust/src/delta_datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::{DataType as ArrowDataType, Schema as ArrowSchema, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use chrono::{DateTime, NaiveDateTime, Utc};
Expand All @@ -42,7 +43,7 @@ use datafusion::execution::FunctionRegistry;
use datafusion::optimizer::utils::conjunction;
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
use datafusion::physical_plan::file_format::FileScanConfig;
use datafusion::physical_plan::file_format::{partition_type_wrap, FileScanConfig};
use datafusion::physical_plan::{
ColumnStatistics, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
};
Expand Down Expand Up @@ -375,26 +376,12 @@ impl TableProvider for DeltaTable {
});
};

let table_partition_cols = self
.get_metadata()?
.partition_columns
.clone()
.into_iter()
.flat_map(|c| {
let f_types = schema
.fields()
.iter()
.map(|e| (e.name(), e.data_type()))
.filter(|&(col_name, _)| c.eq_ignore_ascii_case(col_name))
.collect::<Vec<_>>();
f_types.first().map(|o| o.to_owned())
})
.collect::<HashMap<_, _>>();
let table_partition_cols = self.get_metadata()?.partition_columns.clone();
let file_schema = Arc::new(ArrowSchema::new(
schema
.fields()
.iter()
.filter(|f| !table_partition_cols.contains_key(f.name()))
.filter(|f| !table_partition_cols.contains(f.name()))
.cloned()
.collect(),
));
Expand All @@ -406,12 +393,17 @@ impl TableProvider for DeltaTable {
file_schema,
file_groups: file_groups.into_values().collect(),
statistics: self.datafusion_table_statistics(),
projection: projection.map(|o| o.to_owned()),
projection: projection.cloned(),
limit,
table_partition_cols: table_partition_cols
.into_iter()
.map(|(a, b)| (a.to_owned(), b.to_owned()))
.collect::<Vec<(_, _)>>(),
.iter()
.map(|c| {
Ok((
c.to_owned(),
partition_type_wrap(schema.field_with_name(c)?.data_type().clone()),
))
})
.collect::<Result<Vec<_>, ArrowError>>()?,
output_ordering: None,
config_options: Default::default(),
},
Expand Down
17 changes: 15 additions & 2 deletions rust/src/operations/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::storage::DeltaObjectStore;
use crate::writer::record_batch::divide_by_partition_values;
use crate::writer::utils::PartitionPath;

use arrow::datatypes::SchemaRef as ArrowSchemaRef;
use arrow::datatypes::{DataType, SchemaRef as ArrowSchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion::execution::context::{SessionContext, TaskContext};
use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan};
Expand Down Expand Up @@ -196,6 +196,18 @@ impl std::future::IntoFuture for WriteBuilder {
fn into_future(self) -> Self::IntoFuture {
let this = self;

fn schema_to_vec_name_type(schema: ArrowSchemaRef) -> Vec<(String, DataType)> {
schema
.fields()
.iter()
.map(|f| (f.name().to_owned(), f.data_type().clone()))
.collect::<Vec<_>>()
}

fn schema_eq(l: ArrowSchemaRef, r: ArrowSchemaRef) -> bool {
schema_to_vec_name_type(l) == schema_to_vec_name_type(r)
}

Box::pin(async move {
let object_store = if let Some(store) = this.object_store {
Ok(store)
Expand Down Expand Up @@ -274,7 +286,8 @@ impl std::future::IntoFuture for WriteBuilder {

if let Ok(meta) = table.get_metadata() {
let curr_schema: ArrowSchemaRef = Arc::new((&meta.schema).try_into()?);
if schema != curr_schema {

if !schema_eq(curr_schema, schema.clone()) {
return Err(DeltaTableError::Generic(
"Updating table schema not yet implemented".to_string(),
));
Expand Down
2 changes: 1 addition & 1 deletion rust/tests/datafusion_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async fn prepare_table(
#[tokio::test]
async fn test_datafusion_sql_registration() -> Result<()> {
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> = HashMap::new();
table_factories.insert("deltatable".to_string(), Arc::new(DeltaTableFactory {}));
table_factories.insert("DELTATABLE".to_string(), Arc::new(DeltaTableFactory {}));
let cfg = RuntimeConfig::new().with_table_factories(table_factories);
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
Expand Down

0 comments on commit 8ff24a8

Please sign in to comment.