From 62c976f50ecbe721827ca2cdf94eda99cca6f116 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Tue, 19 Nov 2024 20:05:28 -0800 Subject: [PATCH] [FEAT] connect: add alias support --- Cargo.lock | 3 + Cargo.toml | 3 + src/daft-connect/Cargo.toml | 5 +- src/daft-connect/src/lib.rs | 12 +- src/daft-connect/src/translation.rs | 6 + src/daft-connect/src/translation/datatype.rs | 114 ++++++++++++++++++ src/daft-connect/src/translation/expr.rs | 105 ++++++++++++++++ .../translation/expr/unresolved_function.rs | 44 +++++++ src/daft-connect/src/translation/literal.rs | 52 ++++++++ .../src/translation/logical_plan.rs | 63 ++-------- .../src/translation/logical_plan/aggregate.rs | 72 +++++++++++ .../src/translation/logical_plan/project.rs | 26 ++++ .../src/translation/logical_plan/range.rs | 55 +++++++++ src/daft-connect/src/translation/schema.rs | 69 +++++------ src/daft-dsl/src/lit.rs | 6 + tests/connect/test_alias.py | 21 ++++ 16 files changed, 556 insertions(+), 100 deletions(-) create mode 100644 src/daft-connect/src/translation/datatype.rs create mode 100644 src/daft-connect/src/translation/expr.rs create mode 100644 src/daft-connect/src/translation/expr/unresolved_function.rs create mode 100644 src/daft-connect/src/translation/literal.rs create mode 100644 src/daft-connect/src/translation/logical_plan/aggregate.rs create mode 100644 src/daft-connect/src/translation/logical_plan/project.rs create mode 100644 src/daft-connect/src/translation/logical_plan/range.rs create mode 100644 tests/connect/test_alias.py diff --git a/Cargo.lock b/Cargo.lock index 36e3598748..87fae1cd72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1934,10 +1934,13 @@ dependencies = [ "arrow2", "async-stream", "common-daft-config", + "daft-core", + "daft-dsl", "daft-local-execution", "daft-local-plan", "daft-logical-plan", "daft-scan", + "daft-schema", "daft-table", "dashmap", "eyre", diff --git a/Cargo.toml b/Cargo.toml index 79f933dad9..be1146166a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -194,11 +194,14 @@ chrono-tz = "0.8.4" comfy-table = "7.1.1" common-daft-config = {path = "src/common/daft-config"} common-error = {path = "src/common/error", default-features = false} +daft-core = {path = "src/daft-core"} +daft-dsl = {path = "src/daft-dsl"} daft-hash = {path = "src/daft-hash"} daft-local-execution = {path = "src/daft-local-execution"} daft-local-plan = {path = "src/daft-local-plan"} daft-logical-plan = {path = "src/daft-logical-plan"} daft-scan = {path = "src/daft-scan"} +daft-schema = {path = "src/daft-schema"} daft-table = {path = "src/daft-table"} derivative = "2.2.0" derive_builder = "0.20.2" diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index f94cb284be..47a718465f 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -2,10 +2,13 @@ arrow2 = {workspace = true} async-stream = "0.3.6" common-daft-config = {workspace = true} +daft-core = {workspace = true} +daft-dsl = {workspace = true} daft-local-execution = {workspace = true} daft-local-plan = {workspace = true} daft-logical-plan = {workspace = true} daft-scan = {workspace = true} +daft-schema = {workspace = true} daft-table = {workspace = true} dashmap = "6.1.0" eyre = "0.6.12" @@ -19,7 +22,7 @@ tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} [features] -python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-local-plan/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python"] +python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-local-plan/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python"] [lints] workspace = true diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 882b2af1af..70171ad0d4 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -22,7 +22,7 @@ use spark_connect::{ ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, }; use tonic::{transport::Server, Request, Response, Status}; -use tracing::info; +use tracing::{debug, info}; use uuid::Uuid; use crate::session::Session; @@ -309,22 +309,22 @@ impl SparkConnectService for DaftSparkConnectService { Ok(schema) => schema, Err(e) => { return invalid_argument_err!( - "Failed to translate relation to schema: {e}" + "Failed to translate relation to schema: {e:?}" ); } }; - let schema = analyze_plan_response::DdlParse { - parsed: Some(result), + let schema = analyze_plan_response::Schema { + schema: Some(result), }; let response = AnalyzePlanResponse { session_id, server_side_session_id: String::new(), - result: Some(analyze_plan_response::Result::DdlParse(schema)), + result: Some(analyze_plan_response::Result::Schema(schema)), }; - println!("response: {response:#?}"); + debug!("response: {response:#?}"); Ok(Response::new(response)) } diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs index 125aa6e884..bb2d73b507 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation.rs @@ -1,7 +1,13 @@ //! Translation between Spark Connect and Daft +mod datatype; +mod expr; +mod literal; mod logical_plan; mod schema; +pub use datatype::to_spark_datatype; +pub use expr::to_daft_expr; +pub use literal::to_daft_literal; pub use logical_plan::to_logical_plan; pub use schema::relation_to_schema; diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs new file mode 100644 index 0000000000..9a40844464 --- /dev/null +++ b/src/daft-connect/src/translation/datatype.rs @@ -0,0 +1,114 @@ +use daft_schema::dtype::DataType; +use spark_connect::data_type::Kind; +use tracing::warn; + +pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { + match datatype { + DataType::Null => spark_connect::DataType { + kind: Some(Kind::Null(spark_connect::data_type::Null { + type_variation_reference: 0, + })), + }, + DataType::Boolean => spark_connect::DataType { + kind: Some(Kind::Boolean(spark_connect::data_type::Boolean { + type_variation_reference: 0, + })), + }, + DataType::Int8 => spark_connect::DataType { + kind: Some(Kind::Byte(spark_connect::data_type::Byte { + type_variation_reference: 0, + })), + }, + DataType::Int16 => spark_connect::DataType { + kind: Some(Kind::Short(spark_connect::data_type::Short { + type_variation_reference: 0, + })), + }, + DataType::Int32 => spark_connect::DataType { + kind: Some(Kind::Integer(spark_connect::data_type::Integer { + type_variation_reference: 0, + })), + }, + DataType::Int64 => spark_connect::DataType { + kind: Some(Kind::Long(spark_connect::data_type::Long { + type_variation_reference: 0, + })), + }, + DataType::UInt8 => spark_connect::DataType { + kind: Some(Kind::Byte(spark_connect::data_type::Byte { + type_variation_reference: 0, + })), + }, + DataType::UInt16 => spark_connect::DataType { + kind: Some(Kind::Short(spark_connect::data_type::Short { + type_variation_reference: 0, + })), + }, + DataType::UInt32 => spark_connect::DataType { + kind: Some(Kind::Integer(spark_connect::data_type::Integer { + type_variation_reference: 0, + })), + }, + DataType::UInt64 => spark_connect::DataType { + kind: Some(Kind::Long(spark_connect::data_type::Long { + type_variation_reference: 0, + })), + }, + DataType::Float32 => spark_connect::DataType { + kind: Some(Kind::Float(spark_connect::data_type::Float { + type_variation_reference: 0, + })), + }, + DataType::Float64 => spark_connect::DataType { + kind: Some(Kind::Double(spark_connect::data_type::Double { + type_variation_reference: 0, + })), + }, + DataType::Decimal128(precision, scale) => spark_connect::DataType { + kind: Some(Kind::Decimal(spark_connect::data_type::Decimal { + scale: Some(*scale as i32), + precision: Some(*precision as i32), + type_variation_reference: 0, + })), + }, + DataType::Timestamp(unit, _) => { + warn!("Ignoring time unit {unit:?} for timestamp type"); + spark_connect::DataType { + kind: Some(Kind::Timestamp(spark_connect::data_type::Timestamp { + type_variation_reference: 0, + })), + } + } + DataType::Date => spark_connect::DataType { + kind: Some(Kind::Date(spark_connect::data_type::Date { + type_variation_reference: 0, + })), + }, + DataType::Binary => spark_connect::DataType { + kind: Some(Kind::Binary(spark_connect::data_type::Binary { + type_variation_reference: 0, + })), + }, + DataType::Utf8 => spark_connect::DataType { + kind: Some(Kind::String(spark_connect::data_type::String { + type_variation_reference: 0, + collation: String::new(), // todo(correctness): is this correct? + })), + }, + DataType::Struct(fields) => spark_connect::DataType { + kind: Some(Kind::Struct(spark_connect::data_type::Struct { + fields: fields + .iter() + .map(|f| spark_connect::data_type::StructField { + name: f.name.clone(), + data_type: Some(to_spark_datatype(&f.dtype)), + nullable: true, // todo(correctness): is this correct? + metadata: None, + }) + .collect(), + type_variation_reference: 0, + })), + }, + _ => unimplemented!("Unsupported datatype: {datatype:?}"), + } +} diff --git a/src/daft-connect/src/translation/expr.rs b/src/daft-connect/src/translation/expr.rs new file mode 100644 index 0000000000..299a770d3c --- /dev/null +++ b/src/daft-connect/src/translation/expr.rs @@ -0,0 +1,105 @@ +use std::sync::Arc; + +use eyre::{bail, Context}; +use spark_connect::{expression as spark_expr, Expression}; +use tracing::warn; +use unresolved_function::unresolved_to_daft_expr; + +use crate::translation::to_daft_literal; + +mod unresolved_function; + +pub fn to_daft_expr(expression: Expression) -> eyre::Result { + if let Some(common) = expression.common { + warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + }; + + let Some(expr) = expression.expr_type else { + bail!("Expression is required"); + }; + + match expr { + spark_expr::ExprType::Literal(l) => to_daft_literal(l), + spark_expr::ExprType::UnresolvedAttribute(attr) => { + let spark_expr::UnresolvedAttribute { + unparsed_identifier, + plan_id, + is_metadata_column, + } = attr; + + if let Some(plan_id) = plan_id { + warn!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented"); + } + + if let Some(is_metadata_column) = is_metadata_column { + warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented"); + } + + Ok(daft_dsl::col(unparsed_identifier)) + } + spark_expr::ExprType::UnresolvedFunction(f) => { + unresolved_to_daft_expr(f).wrap_err("Failed to handle unresolved function") + } + spark_expr::ExprType::ExpressionString(_) => bail!("Expression string not yet supported"), + spark_expr::ExprType::UnresolvedStar(_) => { + bail!("Unresolved star expressions not yet supported") + } + spark_expr::ExprType::Alias(alias) => { + let spark_expr::Alias { + expr, + name, + metadata, + } = *alias; + + let Some(expr) = expr else { + bail!("Alias expr is required"); + }; + + let [name] = name.as_slice() else { + bail!("Alias name is required and currently only works with a single string; got {name:?}"); + }; + + if let Some(metadata) = metadata { + bail!("Alias metadata is not yet supported; got {metadata:?}"); + } + + let child = to_daft_expr(*expr)?; + + let name = Arc::from(name.as_str()); + + Ok(child.alias(name)) + } + spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"), + spark_expr::ExprType::UnresolvedRegex(_) => { + bail!("Unresolved regex expressions not yet supported") + } + spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"), + spark_expr::ExprType::LambdaFunction(_) => { + bail!("Lambda function expressions not yet supported") + } + spark_expr::ExprType::Window(_) => bail!("Window expressions not yet supported"), + spark_expr::ExprType::UnresolvedExtractValue(_) => { + bail!("Unresolved extract value expressions not yet supported") + } + spark_expr::ExprType::UpdateFields(_) => { + bail!("Update fields expressions not yet supported") + } + spark_expr::ExprType::UnresolvedNamedLambdaVariable(_) => { + bail!("Unresolved named lambda variable expressions not yet supported") + } + spark_expr::ExprType::CommonInlineUserDefinedFunction(_) => { + bail!("Common inline user defined function expressions not yet supported") + } + spark_expr::ExprType::CallFunction(_) => { + bail!("Call function expressions not yet supported") + } + spark_expr::ExprType::NamedArgumentExpression(_) => { + bail!("Named argument expressions not yet supported") + } + spark_expr::ExprType::MergeAction(_) => bail!("Merge action expressions not yet supported"), + spark_expr::ExprType::TypedAggregateExpression(_) => { + bail!("Typed aggregate expressions not yet supported") + } + spark_expr::ExprType::Extension(_) => bail!("Extension expressions not yet supported"), + } +} diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs new file mode 100644 index 0000000000..e81b13057e --- /dev/null +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -0,0 +1,44 @@ +use daft_core::count_mode::CountMode; +use eyre::{bail, Context}; +use spark_connect::expression::UnresolvedFunction; + +use crate::translation::to_daft_expr; + +pub fn unresolved_to_daft_expr(f: UnresolvedFunction) -> eyre::Result { + let UnresolvedFunction { + function_name, + arguments, + is_distinct, + is_user_defined_function, + } = f; + + let arguments: Vec<_> = arguments.into_iter().map(to_daft_expr).try_collect()?; + + if is_distinct { + bail!("Distinct not yet supported"); + } + + if is_user_defined_function { + bail!("User-defined functions not yet supported"); + } + + match function_name.as_str() { + "count" => handle_count(arguments).wrap_err("Failed to handle count function"), + n => bail!("Unresolved function {n} not yet supported"), + } +} + +pub fn handle_count(arguments: Vec) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + + let count = arg.count(CountMode::All); + + Ok(count) +} diff --git a/src/daft-connect/src/translation/literal.rs b/src/daft-connect/src/translation/literal.rs new file mode 100644 index 0000000000..19895b047c --- /dev/null +++ b/src/daft-connect/src/translation/literal.rs @@ -0,0 +1,52 @@ +use daft_core::datatypes::IntervalValue; +use eyre::bail; +use spark_connect::expression::{literal::LiteralType, Literal}; + +// todo(test): add tests for this esp in Python +pub fn to_daft_literal(literal: Literal) -> eyre::Result { + let Some(literal) = literal.literal_type else { + bail!("Literal is required"); + }; + + match literal { + LiteralType::Array(_) => bail!("Array literals not yet supported"), + LiteralType::Binary(bytes) => Ok(daft_dsl::lit(bytes.as_slice())), + LiteralType::Boolean(b) => Ok(daft_dsl::lit(b)), + LiteralType::Byte(_) => bail!("Byte literals not yet supported"), + LiteralType::CalendarInterval(_) => { + bail!("Calendar interval literals not yet supported") + } + LiteralType::Date(d) => Ok(daft_dsl::lit(d)), + LiteralType::DayTimeInterval(_) => { + bail!("Day-time interval literals not yet supported") + } + LiteralType::Decimal(_) => bail!("Decimal literals not yet supported"), + LiteralType::Double(d) => Ok(daft_dsl::lit(d)), + LiteralType::Float(f) => { + let f = f64::from(f); + Ok(daft_dsl::lit(f)) + } + LiteralType::Integer(i) => Ok(daft_dsl::lit(i)), + LiteralType::Long(l) => Ok(daft_dsl::lit(l)), + LiteralType::Map(_) => bail!("Map literals not yet supported"), + LiteralType::Null(_) => { + // todo(correctness): is it ok to assume type is i32 here? + Ok(daft_dsl::null_lit()) + } + LiteralType::Short(_) => bail!("Short literals not yet supported"), + LiteralType::String(s) => Ok(daft_dsl::lit(s)), + LiteralType::Struct(_) => bail!("Struct literals not yet supported"), + LiteralType::Timestamp(ts) => { + // todo(correctness): is it ok that the type is different logically? + Ok(daft_dsl::lit(ts)) + } + LiteralType::TimestampNtz(ts) => { + // todo(correctness): is it ok that the type is different logically? + Ok(daft_dsl::lit(ts)) + } + LiteralType::YearMonthInterval(value) => { + let interval = IntervalValue::new(value, 0, 0); + Ok(daft_dsl::lit(interval)) + } + } +} diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 58255e2ef9..947e0cd0d3 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,8 +1,14 @@ use daft_logical_plan::LogicalPlanBuilder; -use eyre::{bail, ensure, Context}; -use spark_connect::{relation::RelType, Range, Relation}; +use eyre::{bail, Context}; +use spark_connect::{relation::RelType, Relation}; use tracing::warn; +use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; + +mod aggregate; +mod project; +mod range; + pub fn to_logical_plan(relation: Relation) -> eyre::Result { if let Some(common) = relation.common { warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); @@ -14,55 +20,10 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { match rel_type { RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"), + RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"), + RelType::Aggregate(a) => { + aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") + } plan => bail!("Unsupported relation type: {plan:?}"), } } - -fn range(range: Range) -> eyre::Result { - #[cfg(not(feature = "python"))] - bail!("Range operations require Python feature to be enabled"); - - #[cfg(feature = "python")] - { - use daft_scan::python::pylib::ScanOperatorHandle; - use pyo3::prelude::*; - let Range { - start, - end, - step, - num_partitions, - } = range; - - let partitions = num_partitions.unwrap_or(1); - - ensure!(partitions > 0, "num_partitions must be greater than 0"); - - let start = start.unwrap_or(0); - - let step = usize::try_from(step).wrap_err("step must be a positive integer")?; - ensure!(step > 0, "step must be greater than 0"); - - let plan = Python::with_gil(|py| { - let range_module = PyModule::import_bound(py, "daft.io._range") - .wrap_err("Failed to import range module")?; - - let range = range_module - .getattr(pyo3::intern!(py, "RangeScanOperator")) - .wrap_err("Failed to get range function")?; - - let range = range - .call1((start, end, step, partitions)) - .wrap_err("Failed to create range scan operator")? - .to_object(py); - - let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; - - let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; - - eyre::Result::<_>::Ok(plan) - }) - .wrap_err("Failed to create range scan")?; - - Ok(plan) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs new file mode 100644 index 0000000000..a355449f20 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/aggregate.rs @@ -0,0 +1,72 @@ +use daft_logical_plan::LogicalPlanBuilder; +use eyre::{bail, WrapErr}; +use spark_connect::aggregate::GroupType; + +use crate::translation::{to_daft_expr, to_logical_plan}; + +pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result { + let spark_connect::Aggregate { + input, + group_type, + grouping_expressions, + aggregate_expressions, + pivot, + grouping_sets, + } = aggregate; + + let Some(input) = input else { + bail!("input is required"); + }; + + let plan = to_logical_plan(*input)?; + + let group_type = GroupType::try_from(group_type) + .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?; + + assert_groupby(group_type)?; + + if let Some(pivot) = pivot { + bail!("Pivot not yet supported; got {pivot:?}"); + } + + if !grouping_sets.is_empty() { + bail!("Grouping sets not yet supported; got {grouping_sets:?}"); + } + + let grouping_expressions: Vec<_> = grouping_expressions + .into_iter() + .map(to_daft_expr) + .try_collect()?; + + let aggregate_expressions: Vec<_> = aggregate_expressions + .into_iter() + .map(to_daft_expr) + .try_collect()?; + + let plan = plan + .aggregate(aggregate_expressions.clone(), grouping_expressions.clone()) + .wrap_err_with(|| format!("Failed to apply aggregate to logical plan aggregate_expressions={aggregate_expressions:?} grouping_expressions={grouping_expressions:?}"))?; + + Ok(plan) +} + +fn assert_groupby(plan: GroupType) -> eyre::Result<()> { + match plan { + GroupType::Unspecified => { + bail!("GroupType must be specified; got Unspecified") + } + GroupType::Groupby => Ok(()), + GroupType::Rollup => { + bail!("Rollup not yet supported") + } + GroupType::Cube => { + bail!("Cube not yet supported") + } + GroupType::Pivot => { + bail!("Pivot not yet supported") + } + GroupType::GroupingSets => { + bail!("GroupingSets not yet supported") + } + } +} diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs new file mode 100644 index 0000000000..1a5614cbc2 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/project.rs @@ -0,0 +1,26 @@ +//! Project operation for selecting and manipulating columns from a dataset +//! +//! TL;DR: Project is Spark's equivalent of SQL SELECT - it selects columns, renames them via aliases, +//! and creates new columns from expressions. Example: `df.select(col("id").alias("my_number"))` + +use daft_logical_plan::LogicalPlanBuilder; +use eyre::bail; +use spark_connect::Project; + +use crate::translation::{to_daft_expr, to_logical_plan}; + +pub fn project(project: Project) -> eyre::Result { + let Project { input, expressions } = project; + + let Some(input) = input else { + bail!("Project input is required"); + }; + + let plan = to_logical_plan(*input)?; + + let daft_exprs: Vec<_> = expressions.into_iter().map(to_daft_expr).try_collect()?; + + let plan = plan.select(daft_exprs)?; + + Ok(plan) +} diff --git a/src/daft-connect/src/translation/logical_plan/range.rs b/src/daft-connect/src/translation/logical_plan/range.rs new file mode 100644 index 0000000000..e11fef26cb --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/range.rs @@ -0,0 +1,55 @@ +use daft_logical_plan::LogicalPlanBuilder; +use eyre::{ensure, Context}; +use spark_connect::Range; + +pub fn range(range: Range) -> eyre::Result { + #[cfg(not(feature = "python"))] + { + use eyre::bail; + bail!("Range operations require Python feature to be enabled"); + } + + #[cfg(feature = "python")] + { + use daft_scan::python::pylib::ScanOperatorHandle; + use pyo3::prelude::*; + let Range { + start, + end, + step, + num_partitions, + } = range; + + let partitions = num_partitions.unwrap_or(1); + + ensure!(partitions > 0, "num_partitions must be greater than 0"); + + let start = start.unwrap_or(0); + + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); + + let plan = Python::with_gil(|py| { + let range_module = PyModule::import_bound(py, "daft.io._range") + .wrap_err("Failed to import range module")?; + + let range = range_module + .getattr(pyo3::intern!(py, "RangeScanOperator")) + .wrap_err("Failed to get range function")?; + + let range = range + .call1((start, end, step, partitions)) + .wrap_err("Failed to create range scan operator")? + .to_object(py); + + let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; + + let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; + + eyre::Result::<_>::Ok(plan) + }) + .wrap_err("Failed to create range scan")?; + + Ok(plan) + } +} diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs index de28a587fc..1b242428d2 100644 --- a/src/daft-connect/src/translation/schema.rs +++ b/src/daft-connect/src/translation/schema.rs @@ -1,54 +1,39 @@ -use eyre::bail; use spark_connect::{ - data_type::{Kind, Long, Struct, StructField}, - relation::RelType, + data_type::{Kind, Struct, StructField}, DataType, Relation, }; use tracing::warn; +use crate::translation::{to_logical_plan, to_spark_datatype}; + #[tracing::instrument(skip_all)] pub fn relation_to_schema(input: Relation) -> eyre::Result { if input.common.is_some() { warn!("We do not currently look at common fields"); } - let result = match input - .rel_type - .ok_or_else(|| tonic::Status::internal("rel_type is None"))? - { - RelType::Range(spark_connect::Range { num_partitions, .. }) => { - if num_partitions.is_some() { - warn!("We do not currently support num_partitions"); - } - - let long = Long { - type_variation_reference: 0, - }; - - let id_field = StructField { - name: "id".to_string(), - data_type: Some(DataType { - kind: Some(Kind::Long(long)), - }), - nullable: false, - metadata: None, - }; - - let fields = vec![id_field]; - - let strct = Struct { - fields, - type_variation_reference: 0, - }; - - DataType { - kind: Some(Kind::Struct(strct)), - } - } - other => { - bail!("Unsupported relation type: {other:?}"); - } - }; - - Ok(result) + let plan = to_logical_plan(input)?; + + let result = plan.schema(); + + let fields: eyre::Result> = result + .fields + .iter() + .map(|(name, field)| { + let field_type = to_spark_datatype(&field.dtype); + Ok(StructField { + name: name.clone(), // todo(correctness): name vs field.name... will they always be the same? + data_type: Some(field_type), + nullable: true, // todo(correctness): is this correct? + metadata: None, // todo(completeness): might want to add metadata here + }) + }) + .collect(); + + Ok(DataType { + kind: Some(Kind::Struct(Struct { + fields: fields?, + type_variation_reference: 0, + })), + }) } diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 8f0fba1fec..1d86442aef 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -366,6 +366,12 @@ pub trait Literal: Sized { fn literal_value(self) -> LiteralValue; } +impl Literal for IntervalValue { + fn literal_value(self) -> LiteralValue { + LiteralValue::Interval(self) + } +} + impl Literal for String { fn literal_value(self) -> LiteralValue { LiteralValue::Utf8(self) diff --git a/tests/connect/test_alias.py b/tests/connect/test_alias.py new file mode 100644 index 0000000000..94efb35fc2 --- /dev/null +++ b/tests/connect/test_alias.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from pyspark.sql.functions import col + + +def test_alias(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Simply rename the 'id' column to 'my_number' + df_renamed = df.select(col("id").alias("my_number")) + + # Verify the alias was set correctly + assert df_renamed.schema != df.schema, "Schema should be changed after alias" + + # Verify the data is unchanged but column name is different + df_rows = df.collect() + df_renamed_rows = df_renamed.collect() + assert [row.id for row in df_rows] == [ + row.my_number for row in df_renamed_rows + ], "Data should be unchanged after alias"