fix: implement consistent formatting for constraint expressions (#1985)
# Description
Implements consistent formatting for constraint expressions so that something
like `value<1000` is normalized to `value < 1000`.


Also includes drive-by improvements (a usage sketch follows this list):
1. Tests and a fix so that DataFusion expressions can actually be used when adding a constraint
2. Tests and a fix so that constraints can be added to columns with capitalized names
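
A minimal usage sketch of both paths, mirroring the new tests (the `DeltaOps` setup, crate paths, and column name are assumptions; `with_constraint` and the normalization behavior come from this change):

```rust
use datafusion_expr::{col, lit};
use deltalake_core::{DeltaOps, DeltaResult, DeltaTable};

// `ops` is assumed to wrap an existing table with an integer `value` column.
async fn add_constraints(ops: DeltaOps) -> DeltaResult<DeltaTable> {
    // A string expression is parsed against the table schema and stored in
    // normalized form, e.g. "value<1000" is persisted as "value < 1000".
    let table = ops
        .add_constraint()
        .with_constraint("value_upper_bound", "value<1000")
        .await?;

    // A DataFusion expression is now accepted directly and runs through the
    // same formatting path before being stored.
    DeltaOps(table)
        .add_constraint()
        .with_constraint("valid_values", col("value").lt(lit(1000)))
        .await
}
```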
 
# Related Issue(s)
- closes #1971
Blajda authored Dec 19, 2023
1 parent 4ece26d commit f6d2061
Showing 3 changed files with 157 additions and 42 deletions.
15 changes: 12 additions & 3 deletions crates/deltalake-core/src/delta_datafusion/expr.rs
@@ -347,9 +347,10 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> {
mod test {
use arrow_schema::DataType as ArrowDataType;
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_common::{Column, DFSchema, ScalarValue};
use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable};

use crate::delta_datafusion::DeltaSessionContext;
use crate::kernel::{DataType, PrimitiveType, StructField, StructType};
use crate::{DeltaOps, DeltaTable};

@@ -388,6 +389,11 @@ mod test {
DataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"Value3".to_string(),
DataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"modified".to_string(),
DataType::Primitive(PrimitiveType::String),
@@ -442,7 +448,10 @@
}),
"arrow_cast(1, 'Int32')".to_string()
),
simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()),
simple!(
Expr::Column(Column::from_qualified_name_ignore_case("Value3")).eq(lit(3_i64)),
"Value3 = 3".to_string()
),
simple!(col("active").is_true(), "active IS TRUE".to_string()),
simple!(col("active"), "active".to_string()),
simple!(col("active").eq(lit(true)), "active = true".to_string()),
@@ -536,7 +545,7 @@
),
];

let session = SessionContext::new();
let session: SessionContext = DeltaSessionContext::default().into();

for test in tests {
let actual = fmt_expr_to_sql(&test.expr).unwrap();
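The `Value3` test case above hinges on the switch to `DeltaSessionContext`; a compact sketch of the round-trip it checks (assuming, as the switch away from `SessionContext::new()` suggests, that `DeltaSessionContext` keeps identifiers case-sensitive when the formatted string is later parsed back):

```rust
use datafusion_common::Column;
use datafusion_expr::{lit, Expr};

use crate::delta_datafusion::expr::fmt_expr_to_sql;

// Format a case-sensitive column reference; under a lowercase-normalizing
// context the resulting "Value3" would not resolve on a round-trip parse.
fn format_case_sensitive() {
    let expr =
        Expr::Column(Column::from_qualified_name_ignore_case("Value3")).eq(lit(3_i64));
    assert_eq!(fmt_expr_to_sql(&expr).unwrap(), "Value3 = 3");
}
```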
12 changes: 9 additions & 3 deletions crates/deltalake-core/src/delta_datafusion/mod.rs
@@ -1033,7 +1033,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints: vec![],
ctx: SessionContext::new(),
ctx: DeltaSessionContext::default().into(),
}
}

@@ -1042,10 +1042,16 @@ impl DeltaDataChecker {
Self {
constraints,
invariants: vec![],
ctx: SessionContext::new(),
ctx: DeltaSessionContext::default().into(),
}
}

/// Specify the Datafusion context
pub fn with_session_context(mut self, context: SessionContext) -> Self {
self.ctx = context;
self
}

/// Create a new DeltaDataChecker
pub fn new(snapshot: &DeltaTableState) -> Self {
let metadata = snapshot.metadata();
@@ -1059,7 +1065,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints,
ctx: SessionContext::new(),
ctx: DeltaSessionContext::default().into(),
}
}

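A short sketch of the new `with_session_context` hook added above (the constraint text is illustrative; types and paths match the imports used elsewhere in this diff):

```rust
use datafusion::prelude::SessionContext;

use crate::delta_datafusion::DeltaDataChecker;
use crate::table::Constraint;

// Build a checker for one ad-hoc constraint, then swap in a caller-supplied
// context in place of the DeltaSessionContext default.
fn checker_with_custom_ctx() -> DeltaDataChecker {
    let ctx = SessionContext::new();
    DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", "value < 1000")])
        .with_session_context(ctx)
}
```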
172 changes: 136 additions & 36 deletions crates/deltalake-core/src/operations/constraints.rs
@@ -8,11 +8,15 @@ use datafusion::execution::context::SessionState;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::ToDFSchema;
use futures::future::BoxFuture;
use futures::StreamExt;
use serde_json::json;

use crate::delta_datafusion::{register_store, DeltaDataChecker, DeltaScanBuilder};
use crate::delta_datafusion::expr::fmt_expr_to_sql;
use crate::delta_datafusion::{
register_store, DeltaDataChecker, DeltaScanBuilder, DeltaSessionContext,
};
use crate::kernel::{Action, CommitInfo, IsolationLevel, Metadata, Protocol};
use crate::logstore::LogStoreRef;
use crate::operations::datafusion_utils::Expression;
@@ -23,6 +27,8 @@ use crate::table::Constraint;
use crate::DeltaTable;
use crate::{DeltaResult, DeltaTableError};

use super::datafusion_utils::into_expr;

/// Build a constraint to add to a table
pub struct ConstraintBuilder {
snapshot: DeltaTableState,
@@ -47,10 +53,10 @@ impl ConstraintBuilder {
/// Specify the constraint to be added
pub fn with_constraint<S: Into<String>, E: Into<Expression>>(
mut self,
column: S,
name: S,
expression: E,
) -> Self {
self.name = Some(column.into());
self.name = Some(name.into());
self.expr = Some(expression.into());
self
}
@@ -75,15 +81,10 @@ impl std::future::IntoFuture for ConstraintBuilder {
Some(v) => v,
None => return Err(DeltaTableError::Generic("No name provided".to_string())),
};
let expr = match this.expr {
Some(Expression::String(s)) => s,
Some(Expression::DataFusion(e)) => e.to_string(),
None => {
return Err(DeltaTableError::Generic(
"No expression provided".to_string(),
))
}
};

let expr = this
.expr
.ok_or_else(|| DeltaTableError::Generic("No expression provided".to_string()))?;

let mut metadata = this
.snapshot
@@ -94,23 +95,29 @@

if metadata.configuration.contains_key(&configuration_key) {
return Err(DeltaTableError::Generic(format!(
"Constraint with name: {} already exists, expr: {}",
name, expr
"Constraint with name: {} already exists",
name
)));
}

let state = this.state.unwrap_or_else(|| {
let session = SessionContext::new();
let session: SessionContext = DeltaSessionContext::default().into();
register_store(this.log_store.clone(), session.runtime_env());
session.state()
});

// Checker built here with the one time constraint to check.
let checker = DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr)]);
let scan = DeltaScanBuilder::new(&this.snapshot, this.log_store.clone(), &state)
.build()
.await?;

let schema = scan.schema().to_dfschema()?;
let expr = into_expr(expr, &schema, &state)?;
let expr_str = fmt_expr_to_sql(&expr)?;

// Checker built here with the one time constraint to check.
let checker =
DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr_str)]);

let plan: Arc<dyn ExecutionPlan> = Arc::new(scan);
let mut tasks = vec![];
for p in 0..plan.output_partitioning().partition_count() {
@@ -140,9 +147,10 @@ impl std::future::IntoFuture for ConstraintBuilder {
// We have validated that the table passes its constraints; now add the
// constraint to the table.

metadata
.configuration
.insert(format!("delta.constraints.{}", name), Some(expr.clone()));
metadata.configuration.insert(
format!("delta.constraints.{}", name),
Some(expr_str.clone()),
);

let old_protocol = this.snapshot.protocol();
let protocol = Protocol {
@@ -162,12 +170,12 @@

let operational_parameters = HashMap::from_iter([
("name".to_string(), json!(&name)),
("expr".to_string(), json!(&expr)),
("expr".to_string(), json!(&expr_str)),
]);

let operations = DeltaOperation::AddConstraint {
name: name.clone(),
expr: expr.clone(),
expr: expr_str.clone(),
};

let commit_info = CommitInfo {
@@ -208,11 +216,37 @@ mod tests {
use std::sync::Arc;

use arrow_array::{Array, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
use datafusion_expr::{col, lit};

use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch};
use crate::{DeltaOps, DeltaResult};
use crate::{DeltaOps, DeltaResult, DeltaTable};

fn get_constraint(table: &DeltaTable, name: &str) -> String {
table
.metadata()
.unwrap()
.configuration
.get(name)
.unwrap()
.clone()
.unwrap()
}

async fn get_constraint_op_params(table: &mut DeltaTable) -> String {
let commit_info = table.history(None).await.unwrap();
let last_commit = &commit_info[commit_info.len() - 1];
last_commit
.operation_parameters
.as_ref()
.unwrap()
.get("expr")
.unwrap()
.as_str()
.unwrap()
.to_owned()
}

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn add_constraint_with_invalid_data() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
@@ -225,12 +259,10 @@ mod tests {
.add_constraint()
.with_constraint("id", "value > 5")
.await;
dbg!(&constraint);
assert!(constraint.is_err());
Ok(())
}

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn add_valid_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
@@ -239,18 +271,89 @@ mod tests {
.await?;
let table = DeltaOps(write);

let constraint = table
let mut table = table
.add_constraint()
.with_constraint("id", "value < 1000")
.await;
dbg!(&constraint);
assert!(constraint.is_ok());
let version = constraint?.version();
.with_constraint("id", "value < 1000")
.await?;
let version = table.version();
assert_eq!(version, 1);

let expected_expr = "value < 1000";
assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
assert_eq!(
get_constraint(&table, "delta.constraints.id"),
expected_expr
);
Ok(())
}

#[tokio::test]
async fn add_constraint_datafusion() -> DeltaResult<()> {
// Add constraint by providing a datafusion expression.
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;
let table = DeltaOps(write);

let mut table = table
.add_constraint()
.with_constraint("valid_values", col("value").lt(lit(1000)))
.await?;
let version = table.version();
assert_eq!(version, 1);

let expected_expr = "value < 1000";
assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
assert_eq!(
get_constraint(&table, "delta.constraints.valid_values"),
expected_expr
);

Ok(())
}

#[tokio::test]
async fn test_constraint_case_sensitive() -> DeltaResult<()> {
let arrow_schema = Arc::new(ArrowSchema::new(vec![
Field::new("Id", ArrowDataType::Utf8, true),
Field::new("vAlue", ArrowDataType::Int32, true),
Field::new("mOdifieD", ArrowDataType::Utf8, true),
]));

let batch = RecordBatch::try_new(
Arc::clone(&arrow_schema.clone()),
vec![
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])),
Arc::new(arrow::array::StringArray::from(vec![
"2021-02-02",
"2023-07-04",
"2023-07-04",
])),
],
)
.unwrap();

let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap();

let mut table = DeltaOps(table)
.add_constraint()
.with_constraint("valid_values", "vAlue < 1000")
.await?;
let version = table.version();
assert_eq!(version, 1);

let expected_expr = "vAlue < 1000";
assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
assert_eq!(
get_constraint(&table, "delta.constraints.valid_values"),
expected_expr
);

Ok(())
}

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn add_conflicting_named_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
@@ -269,12 +372,10 @@ mod tests {
.add_constraint()
.with_constraint("id", "value < 10")
.await;
dbg!(&second_constraint);
assert!(second_constraint.is_err());
Ok(())
}

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn write_data_that_violates_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
@@ -294,7 +395,6 @@ mod tests {
];
let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?;
let err = table.write(vec![batch]).await;
dbg!(&err);
assert!(err.is_err());
Ok(())
}
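For reference, a sketch of reading a stored constraint back, mirroring the `get_constraint` test helper above (the `delta.constraints.<name>` key format comes straight from the builder; `metadata()` returning a `Result` is an assumption based on its use in the tests):

```rust
use crate::DeltaTable;

// Look up a constraint's normalized expression in the table configuration.
fn read_constraint(table: &DeltaTable, name: &str) -> Option<String> {
    table
        .metadata()
        .ok()?
        .configuration
        .get(&format!("delta.constraints.{name}"))
        .cloned()
        .flatten()
}
```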
