diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 439f5bd551..df0ae6c2a4 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -18,6 +18,8 @@ use futures::TryStreamExt;
 use spark_connect::{relation::RelType, Limit, Relation, ShowString};
 use tracing::warn;
 
+use crate::translation::logical_plan::with_columns_renamed::with_columns_renamed;
+
 mod aggregate;
 mod drop;
 mod filter;
@@ -27,6 +29,7 @@ mod range;
 mod read;
 mod to_df;
 mod with_columns;
+mod with_columns_renamed;
 
 pub struct SparkAnalyzer<'a> {
     pub psets: &'a InMemoryPartitionSetCache,
@@ -110,6 +113,9 @@ impl SparkAnalyzer<'_> {
                 self.local_relation(plan_id, l)
                     .wrap_err("Failed to apply local_relation to logical plan")
             }
+            RelType::WithColumnsRenamed(w) => with_columns_renamed(*w)
+                .await
+                .wrap_err("Failed to apply with_columns_renamed to logical plan"),
             RelType::Read(r) => read::read(r)
                 .await
                 .wrap_err("Failed to apply read to logical plan"),
diff --git a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs
new file mode 100644
index 0000000000..01c6493974
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs
@@ -0,0 +1,45 @@
+use daft_dsl::col;
+use eyre::{bail, Context};
+
+use crate::translation::Plan;
+
+pub async fn with_columns_renamed(
+    with_columns_renamed: spark_connect::WithColumnsRenamed,
+) -> eyre::Result<Plan> {
+    let spark_connect::WithColumnsRenamed {
+        input,
+        rename_columns_map,
+        renames,
+    } = with_columns_renamed;
+
+    let Some(input) = input else {
+        bail!("Input is required");
+    };
+
+    let mut plan = Box::pin(crate::translation::to_logical_plan(*input)).await?;
+
+    // todo: let's implement this directly into daft
+
+    // Convert the rename mappings into expressions
+    let rename_exprs = if !rename_columns_map.is_empty() {
+        // Use rename_columns_map if provided (legacy format)
+        rename_columns_map
+            .into_iter()
+            .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str()))
+            .collect()
+    } else {
+        // Use renames if provided (new format)
+        renames
+            .into_iter()
+            .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str()))
+            .collect()
+    };
+
+    // Apply the rename expressions to the plan
+    plan.builder = plan
+        .builder
+        .select(rename_exprs)
+        .wrap_err("Failed to apply rename expressions to logical plan")?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_with_columns_renamed.py b/tests/connect/test_with_columns_renamed.py
new file mode 100644
index 0000000000..124f142ca2
--- /dev/null
+++ b/tests/connect/test_with_columns_renamed.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+
+def test_with_columns_renamed(spark_session):
+    # Test withColumnRenamed
+    df = spark_session.range(5)
+    renamed_df = df.withColumnRenamed("id", "number")
+
+    collected = renamed_df.collect()
+    assert len(collected) == 5
+    assert "number" in renamed_df.columns
+    assert "id" not in renamed_df.columns
+    assert [row["number"] for row in collected] == list(range(5))
+
+    # todo: duplicate keys collapse in the Python dict literal itself, so only the id -> character rename is sent over protobuf
+    # # Test withColumnsRenamed
+    # df = spark_session.range(2)
+    # renamed_df = df.withColumnsRenamed({"id": "number", "id": "character"})
+    #
+    # collected = renamed_df.collect()
+    # assert len(collected) == 2
+    # assert set(renamed_df.columns) == {"number", "character"}
+    # assert "id" not in renamed_df.columns
+    # assert [(row["number"], row["character"]) for row in collected] == [(0, 0), (1, 1)]