Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Partially qualified joins: `join a.x = y` and `join x = b.y` #3290

Merged
2 changes: 1 addition & 1 deletion benchmarking/tpcds/queries/72.sql
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk
AND cr_order_number = cs_order_number)
WHERE d1.d_week_seq = d2.d_week_seq
AND inv_quantity_on_hand < cs_quantity
AND d3.d_date > d1.d_date + 5 -- SQL Server: DATEADD(day, 5, d1.d_date)
AND d3.d_date > d1.d_date + interval '5 days' -- SQL Server: DATEADD(day, 5, d1.d_date)
AND hd_buy_potential = '>10000'
AND d1.d_year = 1999
AND cd_marital_status = 'D'
Expand Down
126 changes: 85 additions & 41 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,25 +665,70 @@ impl SQLPlanner {

let from = from.iter().next().unwrap();

fn collect_compound_identifiers(
fn collect_idents(
left: &[Ident],
right: &[Ident],
left_rel: &Relation,
right_rel: &Relation,
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>)> {
if left.len() == 2 && right.len() == 2 {
let (tbl_a, col_a) = (&left[0].value, &left[1].value);
let (tbl_b, col_b) = (&right[0].value, &right[1].value);
let (left, right) = match (left, right) {
// both are fully qualified: `join on a.x = b.y`
([tbl_a, Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => {
if left_rel.get_name() == tbl_b.value && right_rel.get_name() == tbl_a.value {
(col_b.clone(), col_a.clone())
} else {
(col_a.clone(), col_b.clone())
}
}
// only one is fully qualified: `join on x = b.y`
([Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => {
if tbl_b.value == right_rel.get_name() {
(col_a.clone(), col_b.clone())
} else if tbl_b.value == left_rel.get_name() {
(col_b.clone(), col_a.clone())
} else {
unsupported_sql_err!("Could not determine which table the identifiers belong to")
}
}
// only one is fully qualified: `join on a.x = y`
([tbl_a, Ident{value: col_a, ..}], [Ident{value: col_b, ..}]) => {
// find out which one the qualified identifier belongs to
// we assume the other identifier belongs to the other table
if tbl_a.value == left_rel.get_name() {
(col_a.clone(), col_b.clone())
} else if tbl_a.value == right_rel.get_name() {
(col_b.clone(), col_a.clone())
} else {
unsupported_sql_err!("Could not determine which table the identifiers belong to")
}
}
// neither are fully qualified: `join on x = y`
([left], [right]) => {
let left = ident_to_str(left);
let right = ident_to_str(right);

// we don't know which table the identifiers belong to, so we need to check both
let left_schema = left_rel.schema();
let right_schema = right_rel.schema();

// if the left side is in the left schema, then we assume the right side is in the right schema
if left_schema.get_field(&left).is_ok() {
(left, right)
// if the left side is in the right schema, then we assume the right side is in the left schema, so the operands are swapped
} else if right_schema.get_field(&left).is_ok() {
(right, left)
} else {
unsupported_sql_err!("JOIN clauses must reference columns in the joined tables; found `{}`", left);
}

// switch left/right operands if the caller has them in reverse
if &left_rel.get_name() == tbl_b || &right_rel.get_name() == tbl_a {
Ok((vec![col(col_b.as_ref())], vec![col(col_a.as_ref())]))
} else {
Ok((vec![col(col_a.as_ref())], vec![col(col_b.as_ref())]))
}
} else {
unsupported_sql_err!("collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len());
}
_ => unsupported_sql_err!(
"collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}",
left.len(),
right.len()
),
};
Ok((vec![col(left)], vec![col(right)]))
}

fn process_join_on(
Expand All @@ -694,48 +739,47 @@ impl SQLPlanner {
if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression {
match *op {
BinaryOperator::Eq | BinaryOperator::Spaceship => {
if let (
let null_equals_null = *op == BinaryOperator::Spaceship;

match (left.as_ref(), right.as_ref()) {
(
sqlparser::ast::Expr::CompoundIdentifier(left),
sqlparser::ast::Expr::CompoundIdentifier(right),
) = (left.as_ref(), right.as_ref())
{
let null_equals_null = *op == BinaryOperator::Spaceship;
collect_compound_identifiers(left, right, left_rel, right_rel)
) => {
collect_idents(left, right, left_rel, right_rel)
.map(|(left, right)| (left, right, vec![null_equals_null]))
} else if let (
}
(
sqlparser::ast::Expr::Identifier(left),
sqlparser::ast::Expr::Identifier(right),
) = (left.as_ref(), right.as_ref())
{
let left = ident_to_str(left);
let right = ident_to_str(right);

// we don't know which table the identifiers belong to, so we need to check both
let left_schema = left_rel.schema();
let right_schema = right_rel.schema();

// if the left side is in the left schema, then we assume the right side is in the right schema
let (left_on, right_on) = if left_schema.get_field(&left).is_ok() {
(col(left), col(right))
// if the left side is in the right schema, then we assume the right side is in the left schema, so the operands are swapped
} else if right_schema.get_field(&left).is_ok() {
(col(right), col(left))
} else {
unsupported_sql_err!("JOIN clauses must reference columns in the joined tables; found `{}`", left);
};
) =>{
collect_idents(&[left.clone()], &[right.clone()], left_rel, right_rel)
.map(|(left, right)| (left, right, vec![null_equals_null]))

let null_equals_null = *op == BinaryOperator::Spaceship;
}
(
sqlparser::ast::Expr::CompoundIdentifier(left),
sqlparser::ast::Expr::Identifier(right)
) => {
collect_idents(left, &[right.clone()], left_rel, right_rel)
.map(|(left, right)| (left, right, vec![null_equals_null]))

Ok((vec![left_on], vec![right_on], vec![null_equals_null]))
} else {
unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found `{left} {op} {right}`");
}
(
sqlparser::ast::Expr::Identifier(left),
sqlparser::ast::Expr::CompoundIdentifier(right)
) => {
collect_idents(&[left.clone()], right, left_rel, right_rel)
.map(|(left, right)| (left, right, vec![null_equals_null]))
}
_ => unsupported_sql_err!("process_join_on: Expected CompoundIdentifier, but found left: {:?}, right: {:?}", left, right),
}
universalmind303 marked this conversation as resolved.
Show resolved Hide resolved
}
BinaryOperator::And => {
let (mut left_i, mut right_i, mut null_equals_nulls_i) =
process_join_on(left, left_rel, right_rel)?;
let (mut left_j, mut right_j, mut null_equals_nulls_j) =
process_join_on(left, left_rel, right_rel)?;
process_join_on(right, left_rel, right_rel)?;
left_i.append(&mut left_j);
right_i.append(&mut right_j);
null_equals_nulls_i.append(&mut null_equals_nulls_j);
Expand Down
32 changes: 22 additions & 10 deletions tests/sql/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,27 @@ def test_joins_with_duplicate_columns():
assert actual.to_pydict() == expected


@pytest.mark.parametrize("join_condition", ["idx=idax", "idax=idx"])
def test_joins_without_compound_ident(join_condition):
df1 = daft.from_pydict({"idx": [1, None], "val": [10, 20]})
df2 = daft.from_pydict({"idax": [1, None], "score": [0.1, 0.2]})

catalog = SQLCatalog({"df1": df1, "df2": df2})

df_sql = daft.sql(f"select * from df1 join df2 on {join_condition}", catalog).to_pydict()

expected = {"idx": [1], "val": [10], "idax": [1], "score": [0.1]}
@pytest.mark.parametrize(
    "join_condition",
    [
        "x = y",
        "x = b.y",
        "y = x",
        "y = a.x",
        "a.x = y",
        "a.x = b.y",
        "b.y = x",
        "b.y = a.x",
    ],
)
def test_join_qualifiers(join_condition):
    """Join tables `a` and `b` using every mix of qualified and unqualified keys.

    Each parametrized condition names the same two columns (`a.x` and `b.y`),
    written with or without a table qualifier and in either operand order.
    Every variant must yield the same single matching row; the rows whose key
    is None on both sides are not expected in the output.
    """
    # Register the two frames under the names the SQL text refers to.
    tables = {
        "a": daft.from_pydict({"x": [1, None], "val": [10, 20]}),
        "b": daft.from_pydict({"y": [1, None], "score": [0.1, 0.2]}),
    }
    result = daft.sql(f"select * from a join b on {join_condition}", SQLCatalog(tables))

    assert result.to_pydict() == {"x": [1], "val": [10], "y": [1], "score": [0.1]}
Loading