Skip to content

Commit

Permalink
[FEAT] connect: add df.{intersection,union}
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 4, 2024
1 parent a58fecf commit c57e270
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, project::project, range::range, with_columns::with_columns,
aggregate::aggregate, project::project, range::range, set_op::set_op,
with_columns::with_columns,
};

mod aggregate;
mod project;
mod range;
mod set_op;
mod with_columns;

pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
Expand All @@ -31,6 +33,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::WithColumns(w) => {
with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
}
RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
57 changes: 57 additions & 0 deletions src/daft-connect/src/translation/logical_plan/set_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use eyre::{bail, Context};
use spark_connect::set_operation::SetOpType;
use tracing::warn;

use crate::translation::to_logical_plan;

pub fn set_op(
set_op: spark_connect::SetOperation,
) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> {
let spark_connect::SetOperation {
left_input,
right_input,
set_op_type,
is_all,
by_name,
allow_missing_columns,
} = set_op;

let Some(left_input) = left_input else {
bail!("Left input is required");
};

let Some(right_input) = right_input else {
bail!("Right input is required");
};

let set_op = SetOpType::try_from(set_op_type)
.wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?;

if let Some(by_name) = by_name {
warn!("Ignoring by_name: {by_name}");
}

if let Some(allow_missing_columns) = allow_missing_columns {
warn!("Ignoring allow_missing_columns: {allow_missing_columns}");
}

let left = to_logical_plan(*left_input)?;
let right = to_logical_plan(*right_input)?;

let is_all = is_all.unwrap_or(false);

match set_op {
SetOpType::Unspecified => {
bail!("Unspecified set operation is not supported");
}
SetOpType::Intersect => left
.intersect(&right, is_all)
.wrap_err("Failed to apply intersect to logical plan"),
SetOpType::Union => left
.union(&right, is_all)
.wrap_err("Failed to apply union to logical plan"),
SetOpType::Except => {
bail!("Except set operation is not supported");
}
}
}
21 changes: 21 additions & 0 deletions tests/connect/test_intersection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations


def test_intersection(spark_session):
    """Intersecting two overlapping ranges yields each shared id exactly once."""
    # ids 0..6 and ids 3..9 overlap on 3, 4, 5, 6
    lower = spark_session.range(7)
    upper = spark_session.range(3, 10)

    rows = lower.intersect(upper).collect()

    # Only the overlapping ids should survive, one row each
    assert len(rows) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)"

    # Row order is not relied upon, so compare the sorted ids
    ids = sorted(row.id for row in rows)
    assert ids == [3, 4, 5, 6], "Values should match expected overlapping sequence"
21 changes: 21 additions & 0 deletions tests/connect/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations


def test_union(spark_session):
    """Unioning two overlapping ranges keeps duplicate ids from both sides."""
    # ids 0..6 and ids 3..9; the overlap 3..6 appears in both inputs
    lower = spark_session.range(7)
    upper = spark_session.range(3, 10)

    rows = lower.union(upper).collect()

    # Duplicates are retained, so the row count is the sum of both inputs
    assert len(rows) == 14, "DataFrame should have 14 rows (7 + 7)"

    # Row order is not relied upon, so compare the sorted ids
    ids = sorted(row.id for row in rows)
    assert ids == [0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 8, 9], "Values should match expected sequence with duplicates"

0 comments on commit c57e270

Please sign in to comment.