From 61729171dc6b83f62256a66fc7880568c760f0ee Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Feb 2023 22:05:29 +1100 Subject: [PATCH 1/3] Dataframe join_on method --- datafusion/common/src/table_reference.rs | 1 + datafusion/core/src/dataframe.rs | 76 +++++++++++++++++++ datafusion/expr/src/logical_plan/builder.rs | 24 +++++- datafusion/expr/src/utils.rs | 61 ++++++++++++++- datafusion/sql/src/relation/join.rs | 84 +-------------------- docs/source/user-guide/dataframe.md | 1 + 6 files changed, 165 insertions(+), 82 deletions(-) diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index 370f5e46ee80..1e6292b292a3 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -194,6 +194,7 @@ impl<'a> TableReference<'a> { /// failing that then taking the entire unnormalized input as the identifier itself. /// /// Will normalize (convert to lowercase) any unquoted identifiers. + /// /// e.g. `Foo` will be parsed as `foo`, and `"Foo"".bar"` will be parsed as /// `Foo".bar` (note the preserved case and requiring two double quotes to represent /// a single double quote in the identifier) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 557e04a3b66f..ab799a276582 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -363,6 +363,55 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, plan)) } + /// Join this DataFrame with another DataFrame using the specified expressions. + /// + /// Simply a thin wrapper over [`join`](Self::join) where the join keys are not provided, + /// and the provided expressions are AND'ed together to form the filter expression. + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let left = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let right = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await? + /// .select(vec![ + /// col("a").alias("a2"), + /// col("b").alias("b2"), + /// col("c").alias("c2"), + /// ])?; + /// let join_on = left.join_on( + /// right, + /// JoinType::Inner, + /// [col("a").not_eq(col("a2")), col("b").not_eq(col("b2"))], + /// )?; + /// let batches = join_on.collect().await?; + /// # Ok(()) + /// # } + /// ``` + pub fn join_on( + self, + right: DataFrame, + join_type: JoinType, + on_exprs: impl IntoIterator, + ) -> Result { + let expr = on_exprs.into_iter().reduce(Expr::and); + let plan = LogicalPlanBuilder::from(self.plan) + .join( + right.plan, + join_type, + (Vec::::new(), Vec::::new()), + expr, + )? + .build()?; + Ok(DataFrame::new(self.session_state, plan)) + } + /// Repartition a DataFrame based on a logical partitioning scheme. /// /// ``` @@ -1039,6 +1088,33 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_on() -> Result<()> { + let left = test_table_with_name("a") + .await? + .select_columns(&["c1", "c2"])?; + let right = test_table_with_name("b") + .await? + .select_columns(&["c1", "c2"])?; + let join = left.join_on( + right, + JoinType::Inner, + [ + col("a.c1").not_eq(col("b.c1")), + col("a.c2").not_eq(col("b.c2")), + ], + )?; + + let expected_plan = "Inner Join: Filter: a.c1 != b.c1 AND a.c2 != b.c2\ + \n Projection: a.c1, a.c2\ + \n TableScan: a\ + \n Projection: b.c1, b.c2\ + \n TableScan: b"; + assert_eq!(expected_plan, format!("{:?}", join.logical_plan())); + + Ok(()) + } + #[tokio::test] async fn limit() -> Result<()> { // build query using Table API diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 23256662f7fd..4bbb83bb74ce 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -22,7 +22,10 @@ use crate::expr_rewriter::{ normalize_cols, rewrite_sort_cols_by_aggs, }; use crate::type_coercion::binary::comparison_coercion; -use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan}; +use crate::utils::{ + columnize_expr, compare_sort_expr, ensure_any_column_reference_is_unambiguous, + exprlist_to_fields, from_plan, +}; use crate::{and, binary_expr, Operator}; use crate::{ logical_plan::{ @@ -502,6 +505,25 @@ impl LogicalPlanBuilder { )); } + let filter = if let Some(expr) = filter { + // ambiguous check + ensure_any_column_reference_is_unambiguous( + &expr, + &[self.schema(), right.schema()], + )?; + + // normalize all columns in expression + let using_columns = expr.to_columns()?; + let filter = normalize_col_with_schemas( + expr, + &[self.schema(), right.schema()], + &[using_columns], + )?; + Some(filter) + } else { + None + }; + let (left_keys, right_keys): (Vec>, Vec>) = join_keys .0 diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 6f64bc14f812..8ce959e793f0 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -37,7 +37,7 @@ use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use std::cmp::Ordering; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; /// The value to which `COUNT(*)` is expanded to in @@ -1023,6 +1023,65 @@ pub fn find_valid_equijoin_key_pair( Ok(join_key_pair) } +/// Ensure any column reference of the expression is unambiguous. +/// Assume we have two schema: +/// schema1: a, b ,c +/// schema2: a, d, e +/// +/// `schema1.a + schema2.a` is unambiguous. +/// `a + d` is ambiguous, because `a` may come from schema1 or schema2. +pub fn ensure_any_column_reference_is_unambiguous( + expr: &Expr, + schemas: &[&DFSchema], +) -> Result<()> { + if schemas.len() == 1 { + return Ok(()); + } + // all referenced columns in the expression that don't have relation + let referenced_cols = expr.to_columns()?; + let mut no_relation_cols = referenced_cols + .iter() + .filter_map(|col| { + if col.relation.is_none() { + Some((col.name.as_str(), 0)) + } else { + None + } + }) + .collect::>(); + // find the name of the column existing in multi schemas. + let ambiguous_col_name = schemas + .iter() + .flat_map(|schema| schema.fields()) + .map(|field| field.name()) + .find(|col_name| { + no_relation_cols.entry(col_name).and_modify(|v| *v += 1); + matches!( + no_relation_cols.get_key_value(col_name.as_str()), + Some((_, 2..)) + ) + }); + + if let Some(col_name) = ambiguous_col_name { + let maybe_field = schemas + .iter() + .flat_map(|schema| { + schema + .field_with_unqualified_name(col_name) + .map(|f| f.qualified_name()) + .ok() + }) + .collect::>(); + Err(DataFusionError::Plan(format!( + "reference \'{}\' is ambiguous, could be {};", + col_name, + maybe_field.join(","), + ))) + } else { + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 6f2233f3949a..591194136286 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -17,11 +17,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; -use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result}; -use datafusion_expr::expr_rewriter::normalize_col_with_schemas; -use datafusion_expr::{Expr, JoinType, LogicalPlan, LogicalPlanBuilder}; +use datafusion_common::{Column, DataFusionError, Result}; +use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn plan_table_with_joins( @@ -133,30 +132,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match constraint { JoinConstraint::On(sql_expr) => { let join_schema = left.schema().join(right.schema())?; - // parse ON expression let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?; - - // ambiguous check - ensure_any_column_reference_is_unambiguous( - &expr, - &[left.schema().clone(), right.schema().clone()], - )?; - - // normalize all columns in expression - let using_columns = expr.to_columns()?; - let filter = normalize_col_with_schemas( - expr, - &[left.schema(), right.schema()], - &[using_columns], - )?; - LogicalPlanBuilder::from(left) .join( right, join_type, (Vec::::new(), Vec::::new()), - Some(filter), + Some(expr), )? .build() } @@ -198,62 +181,3 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } } - -/// Ensure any column reference of the expression is unambiguous. -/// Assume we have two schema: -/// schema1: a, b ,c -/// schema2: a, d, e -/// -/// `schema1.a + schema2.a` is unambiguous. -/// `a + d` is ambiguous, because `a` may come from schema1 or schema2. -fn ensure_any_column_reference_is_unambiguous( - expr: &Expr, - schemas: &[DFSchemaRef], -) -> Result<()> { - if schemas.len() == 1 { - return Ok(()); - } - // all referenced columns in the expression that don't have relation - let referenced_cols = expr.to_columns()?; - let mut no_relation_cols = referenced_cols - .iter() - .filter_map(|col| { - if col.relation.is_none() { - Some((col.name.as_str(), 0)) - } else { - None - } - }) - .collect::>(); - // find the name of the column existing in multi schemas. - let ambiguous_col_name = schemas - .iter() - .flat_map(|schema| schema.fields()) - .map(|field| field.name()) - .find(|col_name| { - no_relation_cols.entry(col_name).and_modify(|v| *v += 1); - matches!( - no_relation_cols.get_key_value(col_name.as_str()), - Some((_, 2..)) - ) - }); - - if let Some(col_name) = ambiguous_col_name { - let maybe_field = schemas - .iter() - .flat_map(|schema| { - schema - .field_with_unqualified_name(col_name) - .map(|f| f.qualified_name()) - .ok() - }) - .collect::>(); - Err(DataFusionError::Plan(format!( - "reference \'{}\' is ambiguous, could be {};", - col_name, - maybe_field.join(","), - ))) - } else { - Ok(()) - } -} diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 5ba803fce7ef..cc831f5ea5f5 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -68,6 +68,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su | filter | Filter a DataFrame to only include rows that match the specified filter expression. | | intersect | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema | | join | Join this DataFrame with another DataFrame using the specified columns as join keys. | +| join_on | Join this DataFrame with another DataFrame using arbitrary expressions. | | limit | Limit the number of rows returned from this DataFrame. | | repartition | Repartition a DataFrame based on a logical partitioning scheme. | | sort | Sort the DataFrame by the specified sorting expressions. Any expression can be turned into a sort expression by calling its `sort` method. | From 4043a53f629cd6851b7f8e2ed1e3d56946eb0f8b Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Feb 2023 22:11:12 +1100 Subject: [PATCH 2/3] Fix formatting --- docs/source/user-guide/dataframe.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index cc831f5ea5f5..c7d490e40484 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -68,7 +68,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su | filter | Filter a DataFrame to only include rows that match the specified filter expression. | | intersect | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema | | join | Join this DataFrame with another DataFrame using the specified columns as join keys. | -| join_on | Join this DataFrame with another DataFrame using arbitrary expressions. | +| join_on | Join this DataFrame with another DataFrame using arbitrary expressions. | | limit | Limit the number of rows returned from this DataFrame. | | repartition | Repartition a DataFrame based on a logical partitioning scheme. | | sort | Sort the DataFrame by the specified sorting expressions. Any expression can be turned into a sort expression by calling its `sort` method. | From 52e78c6ced4ba878f7a5c1a32ddbf4d6bbe71128 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Wed, 8 Feb 2023 20:05:14 +1100 Subject: [PATCH 3/3] Add tests --- datafusion/core/src/dataframe.rs | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index ab799a276582..26fe5c051204 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -1099,13 +1099,10 @@ mod tests { let join = left.join_on( right, JoinType::Inner, - [ - col("a.c1").not_eq(col("b.c1")), - col("a.c2").not_eq(col("b.c2")), - ], + [col("a.c1").not_eq(col("b.c1")), col("a.c2").eq(col("b.c2"))], )?; - let expected_plan = "Inner Join: Filter: a.c1 != b.c1 AND a.c2 != b.c2\ + let expected_plan = "Inner Join: Filter: a.c1 != b.c1 AND a.c2 = b.c2\ \n Projection: a.c1, a.c2\ \n TableScan: a\ \n Projection: b.c1, b.c2\ @@ -1115,6 +1112,25 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_ambiguous_filter() -> Result<()> { + let left = test_table_with_name("a") + .await? + .select_columns(&["c1", "c2"])?; + let right = test_table_with_name("b") + .await? + .select_columns(&["c1", "c2"])?; + + let join = left + .join_on(right, JoinType::Inner, [col("c1").eq(col("c1"))]) + .expect_err("join didn't fail check"); + let expected = + "Error during planning: reference 'c1' is ambiguous, could be a.c1,b.c1;"; + assert_eq!(join.to_string(), expected); + + Ok(()) + } + #[tokio::test] async fn limit() -> Result<()> { // build query using Table API