-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] connect: add
df.{intersection,union}
- Loading branch information
1 parent
6c483e4
commit c0bac90
Showing
5 changed files
with
103 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
use eyre::{bail, Context}; | ||
use spark_connect::set_operation::SetOpType; | ||
use tracing::warn; | ||
|
||
use crate::translation::to_logical_plan; | ||
|
||
pub fn set_op( | ||
set_op: spark_connect::SetOperation, | ||
) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> { | ||
let spark_connect::SetOperation { | ||
left_input, | ||
right_input, | ||
set_op_type, | ||
is_all, | ||
by_name, | ||
allow_missing_columns, | ||
} = set_op; | ||
|
||
let Some(left_input) = left_input else { | ||
bail!("Left input is required"); | ||
}; | ||
|
||
let Some(right_input) = right_input else { | ||
bail!("Right input is required"); | ||
}; | ||
|
||
let set_op = SetOpType::try_from(set_op_type) | ||
.wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?; | ||
|
||
if let Some(by_name) = by_name { | ||
warn!("Ignoring by_name: {by_name}"); | ||
} | ||
|
||
if let Some(allow_missing_columns) = allow_missing_columns { | ||
warn!("Ignoring allow_missing_columns: {allow_missing_columns}"); | ||
} | ||
|
||
let left = to_logical_plan(*left_input)?; | ||
let right = to_logical_plan(*right_input)?; | ||
|
||
let is_all = is_all.unwrap_or(false); | ||
|
||
match set_op { | ||
SetOpType::Unspecified => { | ||
bail!("Unspecified set operation is not supported"); | ||
} | ||
SetOpType::Intersect => left | ||
.intersect(&right, is_all) | ||
.wrap_err("Failed to apply intersect to logical plan"), | ||
SetOpType::Union => left | ||
.union(&right, is_all) | ||
.wrap_err("Failed to apply union to logical plan"), | ||
SetOpType::Except => { | ||
bail!("Except set operation is not supported"); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from __future__ import annotations | ||
|
||
|
||
def test_intersection(spark_session): | ||
# Create ranges using Spark - with overlap | ||
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 | ||
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 | ||
|
||
# Intersect the two ranges | ||
intersected = range1.intersect(range2) | ||
|
||
# Collect results | ||
results = intersected.collect() | ||
|
||
# Verify the DataFrame has expected values | ||
# Intersection should only include overlapping values once | ||
assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)" | ||
|
||
# Check that all expected values are present | ||
values = [row.id for row in results] | ||
assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from __future__ import annotations | ||
|
||
|
||
def test_union(spark_session): | ||
# Create ranges using Spark - with overlap | ||
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 | ||
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 | ||
|
||
# Union the two ranges | ||
unioned = range1.union(range2) | ||
|
||
# Collect results | ||
results = unioned.collect() | ||
|
||
# Verify the DataFrame has expected values | ||
# Union includes duplicates, so length should be sum of both ranges | ||
assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)" | ||
|
||
# Check that all expected values are present, including duplicates | ||
values = [row.id for row in results] | ||
assert sorted(values) == [0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 8, 9], "Values should match expected sequence with duplicates" |