From 70574aa58cfc86f8eeff8159ee04e2f1268c11e6 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 02:42:33 -0800 Subject: [PATCH] [FEAT] connect: support basic column operations --- src/daft-connect/src/translation.rs | 2 +- src/daft-connect/src/translation/datatype.rs | 154 +++++++++++++++++- src/daft-connect/src/translation/expr.rs | 68 +++++++- .../translation/expr/unresolved_function.rs | 28 ++++ tests/connect/test_basic_column.py | 36 ++++ 5 files changed, 282 insertions(+), 6 deletions(-) create mode 100644 tests/connect/test_basic_column.py diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs index bb2d73b507..a03fe113a7 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation.rs @@ -6,7 +6,7 @@ mod literal; mod logical_plan; mod schema; -pub use datatype::to_spark_datatype; +pub use datatype::{to_daft_datatype, to_spark_datatype}; pub use expr::to_daft_expr; pub use literal::to_daft_literal; pub use logical_plan::to_logical_plan; diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs index 9a40844464..d5f186c659 100644 --- a/src/daft-connect/src/translation/datatype.rs +++ b/src/daft-connect/src/translation/datatype.rs @@ -1,4 +1,5 @@ -use daft_schema::dtype::DataType; +use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit}; +use eyre::{bail, ensure, WrapErr}; use spark_connect::data_type::Kind; use tracing::warn; @@ -112,3 +113,154 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { _ => unimplemented!("Unsupported datatype: {datatype:?}"), } } + +// todo(test): add tests for this esp in Python +pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result { + let Some(kind) = &datatype.kind else { + bail!("Datatype is required"); + }; + + let type_variation_err = "Custom type variation reference not supported"; + + match kind { + Kind::Null(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Null) + } + Kind::Binary(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Binary) + } + Kind::Boolean(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Boolean) + } + Kind::Byte(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int8) + } + Kind::Short(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int16) + } + Kind::Integer(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int32) + } + Kind::Long(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int64) + } + Kind::Float(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Float32) + } + Kind::Double(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Float64) + } + Kind::Decimal(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + + let Some(precision) = value.precision else { + bail!("Decimal precision is required"); + }; + + let Some(scale) = value.scale else { + bail!("Decimal scale is required"); + }; + + let precision = usize::try_from(precision) + .wrap_err("Decimal precision must be a non-negative integer")?; + + let scale = + usize::try_from(scale).wrap_err("Decimal scale must be a non-negative integer")?; + + Ok(DataType::Decimal128(precision, scale)) + } + Kind::String(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Utf8) + } + Kind::Char(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Utf8) + } + Kind::VarChar(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Utf8) + } + Kind::Date(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Date) + } + Kind::Timestamp(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + + // todo(?): is this correct? + + Ok(DataType::Timestamp(TimeUnit::Microseconds, None)) + } + Kind::TimestampNtz(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + + // todo(?): is this correct? + + Ok(DataType::Timestamp(TimeUnit::Microseconds, None)) + } + Kind::CalendarInterval(_) => bail!("Calendar interval type not supported"), + Kind::YearMonthInterval(_) => bail!("Year-month interval type not supported"), + Kind::DayTimeInterval(_) => bail!("Day-time interval type not supported"), + Kind::Array(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + let element_type = to_daft_datatype( + value + .element_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Array element type is required"))?, + )?; + Ok(DataType::List(Box::new(element_type))) + } + Kind::Struct(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + let fields = value + .fields + .iter() + .map(|f| { + let field_type = to_daft_datatype( + f.data_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Struct field type is required"))?, + )?; + Ok(Field::new(&f.name, field_type)) + }) + .collect::>>()?; + Ok(DataType::Struct(fields)) + } + Kind::Map(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + let key_type = to_daft_datatype( + value + .key_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Map key type is required"))?, + )?; + let value_type = to_daft_datatype( + value + .value_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Map value type is required"))?, + )?; + + let map = DataType::Map { + key: Box::new(key_type), + value: Box::new(value_type), + }; + + Ok(map) + } + Kind::Variant(_) => bail!("Variant type not supported"), + Kind::Udt(_) => bail!("User-defined type not supported"), + Kind::Unparsed(_) => bail!("Unparsed type not supported"), + } +} diff --git a/src/daft-connect/src/translation/expr.rs b/src/daft-connect/src/translation/expr.rs index bcbadf9737..f5307fae9d 100644 --- a/src/daft-connect/src/translation/expr.rs +++ b/src/daft-connect/src/translation/expr.rs @@ -1,11 +1,18 @@ use std::sync::Arc; use eyre::{bail, Context}; -use spark_connect::{expression as spark_expr, Expression}; +use spark_connect::{ + expression as spark_expr, + expression::{ + cast::{CastToType, EvalMode}, + sort_order::{NullOrdering, SortDirection}, + }, + Expression, +}; use tracing::warn; use unresolved_function::unresolved_to_daft_expr; -use crate::translation::to_daft_literal; +use crate::translation::{to_daft_datatype, to_daft_literal}; mod unresolved_function; @@ -69,11 +76,64 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result Ok(child.alias(name)) } - spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"), + spark_expr::ExprType::Cast(c) => { + // Cast { expr: Some(Expression { common: None, expr_type: Some(UnresolvedAttribute(UnresolvedAttribute { unparsed_identifier: "id", plan_id: None, is_metadata_column: None })) }), eval_mode: Unspecified, cast_to_type: Some(Type(DataType { kind: Some(String(String { type_variation_reference: 0, collation: "" })) })) } + // thread 'tokio-runtime-worker' panicked at src/daft-connect/src/trans + println!("got cast {c:?}"); + let spark_expr::Cast { + expr, + eval_mode, + cast_to_type, + } = &**c; + + let Some(expr) = expr else { + bail!("Cast expression is required"); + }; + + let expr = to_daft_expr(expr)?; + + let Some(cast_to_type) = cast_to_type else { + bail!("Cast to type is required"); + }; + + let data_type = match cast_to_type { + CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| { + format!("Failed to convert spark datatype to daft datatype: {kind:?}") + })?, + CastToType::TypeStr(s) => { + bail!("Cast to type string not yet supported; tried to cast to {s}"); + } + }; + + let eval_mode = EvalMode::try_from(*eval_mode) + .wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?; + + warn!("Ignoring cast eval mode: {eval_mode:?}"); + + Ok(expr.cast(&data_type)) + } spark_expr::ExprType::UnresolvedRegex(_) => { bail!("Unresolved regex expressions not yet supported") } - spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"), + spark_expr::ExprType::SortOrder(s) => { + let spark_expr::SortOrder { + child, + direction, + null_ordering, + } = &**s; + + let Some(_child) = child else { + bail!("Sort order child is required"); + }; + + let _sort_direction = SortDirection::try_from(*direction) + .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?; + + let _sort_nulls = NullOrdering::try_from(*null_ordering) + .wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?; + + bail!("Sort order expressions not yet supported"); + } spark_expr::ExprType::LambdaFunction(_) => { bail!("Lambda function expressions not yet supported") } diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs index ffb8c802ce..230924c5de 100644 --- a/src/daft-connect/src/translation/expr/unresolved_function.rs +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -24,6 +24,8 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result handle_count(arguments).wrap_err("Failed to handle count function"), + "isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"), + "isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"), n => bail!("Unresolved function {n} not yet supported"), } } @@ -42,3 +44,29 @@ pub fn handle_count(arguments: Vec) -> eyre::Result) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + + Ok(arg.is_null()) +} + +pub fn handle_isnotnull(arguments: Vec) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + + Ok(arg.not_null()) +} diff --git a/tests/connect/test_basic_column.py b/tests/connect/test_basic_column.py new file mode 100644 index 0000000000..fefb41eb98 --- /dev/null +++ b/tests/connect/test_basic_column.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from pyspark.sql.functions import col +from pyspark.sql.types import StringType + + +def test_column_operations(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Test __getattr__ + df_attr = df.select(col("id").desc()) # Fix: call desc() as method + assert df_attr.toPandas()["id"].iloc[0] == 9, "desc should sort in descending order" + + # Test __getitem__ + # df_item = df.select(col("id")[0]) + # assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element" + + # Test alias + df_alias = df.select(col("id").alias("my_number")) + assert "my_number" in df_alias.columns, "alias should rename column" + assert df_alias.toPandas()["my_number"].equals(df.toPandas()["id"]), "data should be unchanged" + + # Test cast + df_cast = df.select(col("id").cast(StringType())) + assert df_cast.schema.fields[0].dataType == StringType(), "cast should change data type" + + # Test isNotNull/isNull + df_null = df.select(col("id").isNotNull().alias("not_null"), col("id").isNull().alias("is_null")) + assert df_null.toPandas()["not_null"].iloc[0] == True, "isNotNull should be True for non-null values" + assert df_null.toPandas()["is_null"].iloc[0] == False, "isNull should be False for non-null values" + + # Test name + df_name = df.select(col("id").name("renamed_id")) + assert "renamed_id" in df_name.columns, "name should rename column" + assert df_name.toPandas()["renamed_id"].equals(df.toPandas()["id"]), "data should be unchanged"