Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: made generalize_filter less permissive, also added more cases #2149

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 121 additions & 22 deletions crates/core/src/operations/merge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
use datafusion_expr::expr::Placeholder;
use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
use datafusion_expr::{
BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection,
BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Operator, Projection,
UserDefinedLogicalNode, UNNAMED_TABLE,
};
use futures::future::BoxFuture;
Expand Down Expand Up @@ -697,34 +697,57 @@ fn generalize_filter(
target_name: &TableReference,
placeholders: &mut HashMap<String, Expr>,
) -> Option<Expr> {
fn references_table(expr: &Expr, table: &TableReference) -> Option<String> {
match expr {
#[derive(Debug)]
enum ReferenceTableCheck {
HasReference(String),
NoReference,
Unknown,
}
impl ReferenceTableCheck {
fn has_reference(&self) -> bool {
match self {
ReferenceTableCheck::HasReference(_) => true,
_ => false,
}
}
}
fn references_table(expr: &Expr, table: &TableReference) -> ReferenceTableCheck {
let res = match expr {
Expr::Alias(alias) => references_table(&alias.expr, table),
Expr::Column(col) => col.relation.as_ref().and_then(|rel| {
if rel == table {
Some(col.name.to_owned())
} else {
None
}
}),
Expr::Column(col) => col
.relation
.as_ref()
.map(|rel| {
if rel == table {
ReferenceTableCheck::HasReference(col.name.to_owned())
} else {
ReferenceTableCheck::NoReference
}
})
.unwrap_or(ReferenceTableCheck::NoReference),
Expr::Negative(neg) => references_table(neg, table),
Expr::Cast(cast) => references_table(&cast.expr, table),
Expr::TryCast(try_cast) => references_table(&try_cast.expr, table),
Expr::ScalarFunction(func) => {
if func.args.len() == 1 {
references_table(&func.args[0], table)
} else {
None
ReferenceTableCheck::Unknown
}
}
_ => None,
}
Expr::IsNull(inner) => references_table(&inner, table),
Expr::Literal(_) => ReferenceTableCheck::NoReference,
_ => ReferenceTableCheck::Unknown,
};
res
}

match predicate {
Expr::BinaryExpr(binary) => {
if references_table(&binary.right, source_name).is_some() {
if let Some(left_target) = references_table(&binary.left, target_name) {
if references_table(&binary.right, source_name).has_reference() {
if let ReferenceTableCheck::HasReference(left_target) =
references_table(&binary.left, target_name)
{
if partition_columns.contains(&left_target) {
let placeholder_name = format!("{left_target}_{}", placeholders.len());

Expand All @@ -745,8 +768,10 @@ fn generalize_filter(
}
return None;
}
if references_table(&binary.left, source_name).is_some() {
if let Some(right_target) = references_table(&binary.right, target_name) {
if references_table(&binary.left, source_name).has_reference() {
if let ReferenceTableCheck::HasReference(right_target) =
references_table(&binary.right, target_name)
{
if partition_columns.contains(&right_target) {
let placeholder_name = format!("{right_target}_{}", placeholders.len());

Expand Down Expand Up @@ -783,19 +808,45 @@ fn generalize_filter(
placeholders,
);

match (left, right) {
let res = match (left, right) {
(None, None) => None,
(None, Some(r)) => Some(r),
(Some(l), None) => Some(l),
(None, Some(one_side)) | (Some(one_side), None) => {
// in the case of an AND clause, it's safe to generalize the filter down to just one side of the AND.
// this is because this filter will be more permissive than the actual predicate, so we know that
// we will catch all data that could be matched by the predicate. For OR this is not the case - we
// could potentially eliminate one side of the predicate and the filter would only match half the
// cases that would have satisfied the match predicate.
match binary.op {
Operator::And => Some(one_side),
Operator::Or => None,
_ => None,
}
}
(Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr {
left: l.into(),
op: binary.op,
right: r.into(),
})
.into(),
}
};
res
}
other => Some(other),
other => match references_table(&other, source_name) {
ReferenceTableCheck::HasReference(col) => {
let placeholder_name = format!("{col}_{}", placeholders.len());

let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder {
id: placeholder_name.clone(),
data_type: None,
});

placeholders.insert(placeholder_name, other);

Some(placeholder)
}
ReferenceTableCheck::NoReference => Some(other),
ReferenceTableCheck::Unknown => None,
},
}
}

Expand Down Expand Up @@ -1484,6 +1535,7 @@ mod tests {
use datafusion_expr::Expr;
use datafusion_expr::LogicalPlanBuilder;
use datafusion_expr::Operator;
use itertools::Itertools;
use serde_json::json;
use std::collections::HashMap;
use std::ops::Neg;
Expand Down Expand Up @@ -2430,6 +2482,51 @@ mod tests {
assert_eq!(generalized, expected_filter);
}

#[tokio::test]
async fn test_generalize_filter_with_partitions_nulls() {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");

let source_id = col(Column::new(source.clone().into(), "id"));
let target_id = col(Column::new(target.clone().into(), "id"));

// source.id = target.id OR (source.id is null and target.id is null)
let parsed_filter = (source_id.clone().eq(target_id.clone()))
.or(source_id.clone().is_null().and(target_id.clone().is_null()));

let mut placeholders = HashMap::default();

let generalized = generalize_filter(
parsed_filter,
&vec!["id".to_owned()],
&source,
&target,
&mut placeholders,
)
.unwrap();

// id_1 = target.id OR (id_2 and target.id is null)
let expected_filter = Expr::Placeholder(Placeholder {
id: "id_0".to_owned(),
data_type: None,
})
.eq(target_id.clone())
.or(Expr::Placeholder(Placeholder {
id: "id_1".to_owned(),
data_type: None,
})
.and(target_id.clone().is_null()));

assert!(placeholders.len() == 2);

let captured_expressions = placeholders.values().collect_vec();

assert!(captured_expressions.contains(&&source_id));
assert!(captured_expressions.contains(&&source_id.is_null()));

assert_eq!(generalized, expected_filter);
}

#[tokio::test]
async fn test_generalize_filter_with_partitions_captures_expression() {
// Check that when generalizing the filter, the placeholder map captures the expression needed to make the statement the same
Expand Down Expand Up @@ -2474,6 +2571,7 @@ mod tests {
let source = TableReference::parse_str("source");
let target = TableReference::parse_str("target");

// source.id = target.id and target.id = 'C'
let parsed_filter = col(Column::new(source.clone().into(), "id"))
.eq(col(Column::new(target.clone().into(), "id")))
.and(col(Column::new(target.clone().into(), "id")).eq(lit("C")));
Expand All @@ -2489,6 +2587,7 @@ mod tests {
)
.unwrap();

// id_0 = target.id and target.id = 'C'
let expected_filter = Expr::Placeholder(Placeholder {
id: "id_0".to_owned(),
data_type: None,
Expand Down
Loading