Dataframe join_on method #5210

Merged · 3 commits · Feb 9, 2023
Changes from 1 commit
1 change: 1 addition & 0 deletions datafusion/common/src/table_reference.rs
@@ -194,6 +194,7 @@ impl<'a> TableReference<'a> {
/// failing that then taking the entire unnormalized input as the identifier itself.
///
/// Will normalize (convert to lowercase) any unquoted identifiers.
///
/// e.g. `Foo` will be parsed as `foo`, and `"Foo"".bar"` will be parsed as
/// `Foo".bar` (note the preserved case, and that two double quotes are required
/// to represent a single double quote in the identifier).
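The quoting rule described in the doc comment above — lowercase unquoted identifiers, preserve case inside quotes, collapse doubled quotes — can be sketched in plain Rust. This is a hypothetical stand-alone helper, not the actual `TableReference` parser (which also splits multi-part identifiers on dots):

```rust
// Hypothetical sketch of the normalization rule only; the real parser
// also handles multi-part identifiers such as `schema.table`.
fn normalize_ident(raw: &str) -> String {
    if raw.len() >= 2 && raw.starts_with('"') && raw.ends_with('"') {
        // Quoted: strip the outer quotes, preserve case, and collapse
        // doubled quotes ("" -> ") inside the identifier.
        raw[1..raw.len() - 1].replace("\"\"", "\"")
    } else {
        // Unquoted: normalize to lowercase.
        raw.to_lowercase()
    }
}

fn main() {
    assert_eq!(normalize_ident("Foo"), "foo");
    assert_eq!(normalize_ident("\"Foo\"\".bar\""), "Foo\".bar");
    println!("ok");
}
```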
76 changes: 76 additions & 0 deletions datafusion/core/src/dataframe.rs
@@ -363,6 +363,55 @@ impl DataFrame {
Ok(DataFrame::new(self.session_state, plan))
}

/// Join this DataFrame with another DataFrame using the specified expressions.
///
/// Simply a thin wrapper over [`join`](Self::join) where the join keys are not provided,
/// and the provided expressions are AND'ed together to form the filter expression.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let left = ctx
/// .read_csv("tests/data/example.csv", CsvReadOptions::new())
/// .await?;
/// let right = ctx
/// .read_csv("tests/data/example.csv", CsvReadOptions::new())
/// .await?
/// .select(vec![
/// col("a").alias("a2"),
/// col("b").alias("b2"),
/// col("c").alias("c2"),
/// ])?;
/// let join_on = left.join_on(
/// right,
/// JoinType::Inner,
/// [col("a").not_eq(col("a2")), col("b").not_eq(col("b2"))],
/// )?;
/// let batches = join_on.collect().await?;
/// # Ok(())
/// # }
/// ```
pub fn join_on(
Review comment (Contributor): 👍 LGTM

self,
right: DataFrame,
join_type: JoinType,
on_exprs: impl IntoIterator<Item = Expr>,
) -> Result<DataFrame> {
let expr = on_exprs.into_iter().reduce(Expr::and);
let plan = LogicalPlanBuilder::from(self.plan)
.join(
right.plan,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
expr,
)?
.build()?;
Ok(DataFrame::new(self.session_state, plan))
}
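`join_on` combines the provided predicates with `reduce(Expr::and)`. The shape of that fold, sketched over plain strings instead of DataFusion `Expr` values (illustrative only):

```rust
// Illustrative only: the same left-fold that
// `on_exprs.into_iter().reduce(Expr::and)` performs, over plain strings.
fn and_all(exprs: Vec<&str>) -> Option<String> {
    exprs
        .into_iter()
        .map(String::from)
        .reduce(|acc, e| format!("{acc} AND {e}"))
}

fn main() {
    // No predicates -> no filter expression at all.
    assert_eq!(and_all(vec![]), None);
    // Two predicates -> a single AND-ed filter, matching the plan shape
    // produced by the test below.
    assert_eq!(
        and_all(vec!["a.c1 != b.c1", "a.c2 != b.c2"]),
        Some("a.c1 != b.c1 AND a.c2 != b.c2".to_string())
    );
    println!("ok");
}
```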

/// Repartition a DataFrame based on a logical partitioning scheme.
///
/// ```
@@ -1039,6 +1088,33 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join_on() -> Result<()> {
let left = test_table_with_name("a")
.await?
.select_columns(&["c1", "c2"])?;
let right = test_table_with_name("b")
.await?
.select_columns(&["c1", "c2"])?;
let join = left.join_on(
right,
JoinType::Inner,
[
col("a.c1").not_eq(col("b.c1")),
col("a.c2").not_eq(col("b.c2")),
Review comment (Contributor): Would it be possible here to also add an equality predicate, to demonstrate they are automatically recognized as equi-predicates? Perhaps something like:

Suggested change:
-                col("a.c2").not_eq(col("b.c2")),
+                col("a.c2").eq(col("b.c2")),

Reply (@Jefffrey, Contributor Author, Feb 8, 2023): Done as you suggested. It seems they are still considered part of the filter, though this tracks with the explicit SQL version too:
https://github.com/apache/arrow-datafusion/blob/f0c67193a3d18ff1d94f9dd55bfb1715e5473bf1/datafusion/sql/tests/integration_test.rs#L1661-L1672
Edit: never mind, the `extract_equijoin_predicate` logical optimization does extract it into an equijoin predicate.
],
)?;

let expected_plan = "Inner Join: Filter: a.c1 != b.c1 AND a.c2 != b.c2\
\n Projection: a.c1, a.c2\
\n TableScan: a\
\n Projection: b.c1, b.c2\
\n TableScan: b";
assert_eq!(expected_plan, format!("{:?}", join.logical_plan()));

Ok(())
}

#[tokio::test]
async fn limit() -> Result<()> {
// build query using Table API
24 changes: 23 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
@@ -22,7 +22,10 @@ use crate::expr_rewriter::{
normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::type_coercion::binary::comparison_coercion;
use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
use crate::utils::{
columnize_expr, compare_sort_expr, ensure_any_column_reference_is_unambiguous,
exprlist_to_fields, from_plan,
};
use crate::{and, binary_expr, Operator};
use crate::{
logical_plan::{
@@ -502,6 +505,25 @@ impl LogicalPlanBuilder {
));
}

let filter = if let Some(expr) = filter {
// ambiguous check
ensure_any_column_reference_is_unambiguous(
&expr,
&[self.schema(), right.schema()],
)?;

// normalize all columns in expression
let using_columns = expr.to_columns()?;
let filter = normalize_col_with_schemas(
expr,
&[self.schema(), right.schema()],
&[using_columns],
)?;
Some(filter)
} else {
None
};
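The `normalize_col_with_schemas` step qualifies bare column references against the input schemas. A minimal stand-alone sketch of that resolution, modeling schemas as plain tuples in place of `DFSchema` (hypothetical helper; resolution is first-match, which is why the ambiguity check must run first):

```rust
// Hypothetical sketch of unqualified-column resolution. Schemas are modeled
// as (relation name, field names) tuples. First match wins, so an ambiguous
// bare name would silently resolve to the left input — hence the separate
// ambiguity check before normalization.
fn qualify(name: &str, schemas: &[(&str, &[&str])]) -> Option<String> {
    for (rel, fields) in schemas {
        if fields.contains(&name) {
            return Some(format!("{rel}.{name}"));
        }
    }
    None
}

fn main() {
    let left_fields = ["c1", "c2"];
    let right_fields = ["c1", "c2"];
    let schemas = [("a", &left_fields[..]), ("b", &right_fields[..])];
    // Bare `c2` resolves to the first schema that defines it.
    assert_eq!(qualify("c2", &schemas), Some("a.c2".to_string()));
    // Unknown columns stay unresolved.
    assert_eq!(qualify("c9", &schemas), None);
    println!("ok");
}
```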

Review comment on lines +508 to +526 (Contributor Author): Related to #4196. This fixes a bug where a DataFrame join could be built with an ambiguous column reference in the filter expression. Instead of performing the check in both the DataFrame join API and the SQL planner join module, it is unified by doing the check inside the logical plan builder. This is technically an unrelated fix to the actual issue, so I can extract it into a separate issue if needed.

Reply (Contributor): I think it is fine to include in this PR as long as it also has a test (for the ambiguity check using the DataFrame API).

Reply (Contributor Author): Test added.

let (left_keys, right_keys): (Vec<Result<Column>>, Vec<Result<Column>>) =
join_keys
.0
61 changes: 60 additions & 1 deletion datafusion/expr/src/utils.rs
@@ -37,7 +37,7 @@ use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use std::cmp::Ordering;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

/// The value to which `COUNT(*)` is expanded to in
@@ -1023,6 +1023,65 @@ pub fn find_valid_equijoin_key_pair(
Ok(join_key_pair)
}

/// Ensure that any column reference in the expression is unambiguous.
/// Assume we have two schemas:
/// schema1: a, b, c
/// schema2: a, d, e
///
/// `schema1.a + schema2.a` is unambiguous.
/// `a + d` is ambiguous, because `a` may come from schema1 or schema2.
pub fn ensure_any_column_reference_is_unambiguous(
expr: &Expr,
schemas: &[&DFSchema],
) -> Result<()> {
if schemas.len() == 1 {
return Ok(());
}
    // all referenced columns in the expression that don't have a relation qualifier
let referenced_cols = expr.to_columns()?;
let mut no_relation_cols = referenced_cols
.iter()
.filter_map(|col| {
if col.relation.is_none() {
Some((col.name.as_str(), 0))
} else {
None
}
})
.collect::<HashMap<&str, u8>>();
    // find the name of a column that exists in multiple schemas.
let ambiguous_col_name = schemas
.iter()
.flat_map(|schema| schema.fields())
.map(|field| field.name())
.find(|col_name| {
no_relation_cols.entry(col_name).and_modify(|v| *v += 1);
matches!(
no_relation_cols.get_key_value(col_name.as_str()),
Some((_, 2..))
)
});

if let Some(col_name) = ambiguous_col_name {
let maybe_field = schemas
.iter()
.flat_map(|schema| {
schema
.field_with_unqualified_name(col_name)
.map(|f| f.qualified_name())
.ok()
})
.collect::<Vec<_>>();
Err(DataFusionError::Plan(format!(
"reference \'{}\' is ambiguous, could be {};",
col_name,
maybe_field.join(","),
)))
} else {
Ok(())
}
}
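The check above can be exercised on its own. A simplified stand-alone sketch of the same counting logic over plain string slices (hypothetical helper, not the DataFusion API), using the schemas from the doc comment:

```rust
use std::collections::HashMap;

// Hypothetical sketch of the ambiguity check: count, for every unqualified
// column referenced by the expression, how many schemas define a field with
// that name; a count of two or more means the reference is ambiguous.
fn find_ambiguous<'a>(
    unqualified_refs: &[&'a str],
    schemas: &[&[&str]],
) -> Option<&'a str> {
    let mut counts: HashMap<&str, u8> =
        unqualified_refs.iter().map(|c| (*c, 0)).collect();
    for field in schemas.iter().flat_map(|s| s.iter()) {
        if let Some(v) = counts.get_mut(*field) {
            *v += 1;
        }
    }
    unqualified_refs.iter().copied().find(|c| counts[*c] >= 2)
}

fn main() {
    let schema1: &[&str] = &["a", "b", "c"];
    let schema2: &[&str] = &["a", "d", "e"];
    // `a + d`: `a` exists in both schemas -> ambiguous.
    assert_eq!(find_ambiguous(&["a", "d"], &[schema1, schema2]), Some("a"));
    // `b + d`: each name exists in exactly one schema -> unambiguous.
    assert_eq!(find_ambiguous(&["b", "d"], &[schema1, schema2]), None);
    println!("ok");
}
```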

#[cfg(test)]
mod tests {
use super::*;
84 changes: 4 additions & 80 deletions datafusion/sql/src/relation/join.rs
@@ -17,11 +17,10 @@

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use crate::utils::normalize_ident;
use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
use datafusion_expr::{Expr, JoinType, LogicalPlan, LogicalPlanBuilder};
use datafusion_common::{Column, DataFusionError, Result};
use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn plan_table_with_joins(
@@ -133,30 +132,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
match constraint {
JoinConstraint::On(sql_expr) => {
let join_schema = left.schema().join(right.schema())?;

// parse ON expression
let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?;

// ambiguous check
ensure_any_column_reference_is_unambiguous(
&expr,
&[left.schema().clone(), right.schema().clone()],
)?;

// normalize all columns in expression
let using_columns = expr.to_columns()?;
let filter = normalize_col_with_schemas(
expr,
&[left.schema(), right.schema()],
&[using_columns],
)?;

LogicalPlanBuilder::from(left)
.join(
right,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
Some(filter),
Some(expr),
)?
.build()
}
@@ -198,62 +181,3 @@
}
}
}

/// Ensure any column reference of the expression is unambiguous.
/// Assume we have two schema:
/// schema1: a, b ,c
/// schema2: a, d, e
///
/// `schema1.a + schema2.a` is unambiguous.
/// `a + d` is ambiguous, because `a` may come from schema1 or schema2.
fn ensure_any_column_reference_is_unambiguous(
expr: &Expr,
schemas: &[DFSchemaRef],
) -> Result<()> {
if schemas.len() == 1 {
return Ok(());
}
// all referenced columns in the expression that don't have relation
let referenced_cols = expr.to_columns()?;
let mut no_relation_cols = referenced_cols
.iter()
.filter_map(|col| {
if col.relation.is_none() {
Some((col.name.as_str(), 0))
} else {
None
}
})
.collect::<HashMap<&str, u8>>();
// find the name of the column existing in multi schemas.
let ambiguous_col_name = schemas
.iter()
.flat_map(|schema| schema.fields())
.map(|field| field.name())
.find(|col_name| {
no_relation_cols.entry(col_name).and_modify(|v| *v += 1);
matches!(
no_relation_cols.get_key_value(col_name.as_str()),
Some((_, 2..))
)
});

if let Some(col_name) = ambiguous_col_name {
let maybe_field = schemas
.iter()
.flat_map(|schema| {
schema
.field_with_unqualified_name(col_name)
.map(|f| f.qualified_name())
.ok()
})
.collect::<Vec<_>>();
Err(DataFusionError::Plan(format!(
"reference \'{}\' is ambiguous, could be {};",
col_name,
maybe_field.join(","),
)))
} else {
Ok(())
}
}
1 change: 1 addition & 0 deletions docs/source/user-guide/dataframe.md
@@ -68,6 +68,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su
| filter | Filter a DataFrame to only include rows that match the specified filter expression. |
| intersect | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema |
| join | Join this DataFrame with another DataFrame using the specified columns as join keys. |
| join_on | Join this DataFrame with another DataFrame using arbitrary expressions. |
| limit | Limit the number of rows returned from this DataFrame. |
| repartition | Repartition a DataFrame based on a logical partitioning scheme. |
| sort | Sort the DataFrame by the specified sorting expressions. Any expression can be turned into a sort expression by calling its `sort` method. |