diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index c9343dd601..53a0cfc923 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -4,12 +4,14 @@ use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; use crate::translation::logical_plan::{ - aggregate::aggregate, project::project, range::range, with_columns::with_columns, + aggregate::aggregate, project::project, range::range, set_op::set_op, + with_columns::with_columns, }; mod aggregate; mod project; mod range; +mod set_op; mod with_columns; pub fn to_logical_plan(relation: Relation) -> eyre::Result { @@ -31,6 +33,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::WithColumns(w) => { with_columns(*w).wrap_err("Failed to apply with_columns to logical plan") } + RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/set_op.rs b/src/daft-connect/src/translation/logical_plan/set_op.rs new file mode 100644 index 0000000000..7dfeff9650 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/set_op.rs @@ -0,0 +1,57 @@ +use eyre::{bail, Context}; +use spark_connect::set_operation::SetOpType; +use tracing::warn; + +use crate::translation::to_logical_plan; + +pub fn set_op( + set_op: spark_connect::SetOperation, +) -> eyre::Result { + let spark_connect::SetOperation { + left_input, + right_input, + set_op_type, + is_all, + by_name, + allow_missing_columns, + } = set_op; + + let Some(left_input) = left_input else { + bail!("Left input is required"); + }; + + let Some(right_input) = right_input else { + bail!("Right input is required"); + }; + + let set_op = SetOpType::try_from(set_op_type) + .wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?; + + if let Some(by_name) = by_name { + warn!("Ignoring by_name: {by_name}"); + } + + if let Some(allow_missing_columns) = allow_missing_columns { + warn!("Ignoring allow_missing_columns: {allow_missing_columns}"); + } + + let left = to_logical_plan(*left_input)?; + let right = to_logical_plan(*right_input)?; + + let is_all = is_all.unwrap_or(false); + + match set_op { + SetOpType::Unspecified => { + bail!("Unspecified set operation is not supported"); + } + SetOpType::Intersect => left + .intersect(&right, is_all) + .wrap_err("Failed to apply intersect to logical plan"), + SetOpType::Union => left + .union(&right, is_all) + .wrap_err("Failed to apply union to logical plan"), + SetOpType::Except => { + bail!("Except set operation is not supported"); + } + } +} diff --git a/tests/connect/test_intersection.py b/tests/connect/test_intersection.py new file mode 100644 index 0000000000..7944de5cae --- /dev/null +++ b/tests/connect/test_intersection.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_intersection(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Intersect the two ranges + intersected = range1.intersect(range2) + + # Collect results + results = intersected.collect() + + # Verify the DataFrame has expected values + # Intersection should only include overlapping values once + assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)" + + # Check that all expected values are present + values = [row.id for row in results] + assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence" diff --git a/tests/connect/test_union.py b/tests/connect/test_union.py new file mode 100644 index 0000000000..9ac235d9e5 --- /dev/null +++ b/tests/connect/test_union.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_union(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Union the two ranges + unioned = range1.union(range2) + + # Collect results + results = unioned.collect() + + # Verify the DataFrame has expected values + # Union includes duplicates, so length should be sum of both ranges + assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)" + + # Check that all expected values are present, including duplicates + values = [row.id for row in results] + assert sorted(values) == [0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 8, 9], "Values should match expected sequence with duplicates"