From 1cc57cae079c429583e5b1e264b17abd96a4d158 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Mon, 16 Dec 2024 10:57:36 -0800
Subject: [PATCH 1/5] feat(connect): support `DdlParse`

---
 Cargo.lock                         |   1 +
 Cargo.toml                         |   1 +
 src/daft-connect/Cargo.toml        |   1 +
 src/daft-connect/src/lib.rs        |  25 +++-
 src/daft-schema/src/schema.rs      |   7 +-
 src/daft-sql/src/lib.rs            |   3 +
 src/daft-sql/src/planner.rs        | 140 ++++++++++++++++++++++++++++-
 tests/connect/test_analyze_plan.py |   7 ++
 8 files changed, 179 insertions(+), 6 deletions(-)
 create mode 100644 tests/connect/test_analyze_plan.py

diff --git a/Cargo.lock b/Cargo.lock
index fd01681f0b..3011a56b24 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1988,6 +1988,7 @@ dependencies = [
  "daft-micropartition",
  "daft-scan",
  "daft-schema",
+ "daft-sql",
  "daft-table",
  "dashmap",
  "eyre",
diff --git a/Cargo.toml b/Cargo.toml
index b6f5284a60..d5a5cf218d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -204,6 +204,7 @@ daft-logical-plan = {path = "src/daft-logical-plan"}
 daft-micropartition = {path = "src/daft-micropartition"}
 daft-scan = {path = "src/daft-scan"}
 daft-schema = {path = "src/daft-schema"}
+daft-sql = {path = "src/daft-sql"}
 daft-table = {path = "src/daft-table"}
 derivative = "2.2.0"
 derive_builder = "0.20.2"
diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml
index b1d1f63052..22ddfe04bc 100644
--- a/src/daft-connect/Cargo.toml
+++ b/src/daft-connect/Cargo.toml
@@ -11,6 +11,7 @@ daft-logical-plan = {workspace = true}
 daft-micropartition = {workspace = true}
 daft-scan = {workspace = true}
 daft-schema = {workspace = true}
+daft-sql = {workspace = true}
 daft-table = {workspace = true}
 dashmap = "6.1.0"
 eyre = "0.6.12"
diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index 439a74dc57..7b8432652e 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -6,6 +6,7 @@
 #![feature(stmt_expr_attributes)]
 #![feature(try_trait_v2_residual)]
 
+use daft_sql::error::SQLPlannerResult;
 use dashmap::DashMap;
 use eyre::Context;
 #[cfg(feature = "python")]
@@ -323,7 +324,29 @@ impl SparkConnectService for DaftSparkConnectService {
 
                     Ok(Response::new(response))
                 }
-                _ => unimplemented_err!("Analyze plan operation is not yet implemented"),
+                Analyze::DdlParse(DdlParse { ddl_string }) => {
+                    let daft_schema = match daft_sql::sql_schema(&ddl_string) {
+                        Ok(daft_schema) => daft_schema,
+                        Err(e) => return invalid_argument_err!("{e}"),
+                    };
+
+                    let daft_schema = daft_schema.to_struct();
+
+                    let schema = translation::to_spark_datatype(&daft_schema);
+
+                    let schema = analyze_plan_response::Schema {
+                        schema: Some(schema),
+                    };
+
+                    let response = AnalyzePlanResponse {
+                        session_id,
+                        server_side_session_id: String::new(),
+                        result: Some(analyze_plan_response::Result::Schema(schema)),
+                    };
+
+                    Ok(Response::new(response))
+                }
+                other => unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}"),
             }
         }
diff --git a/src/daft-schema/src/schema.rs b/src/daft-schema/src/schema.rs
index af8eb77e96..a1fc464e96 100644
--- a/src/daft-schema/src/schema.rs
+++ b/src/daft-schema/src/schema.rs
@@ -13,7 +13,7 @@ use derive_more::Display;
 use indexmap::IndexMap;
 use serde::{Deserialize, Serialize};
 
-use crate::field::Field;
+use crate::{field::Field, prelude::DataType};
 
 pub type SchemaRef = Arc<Schema>;
 
@@ -48,6 +48,11 @@ impl Schema {
         Ok(Self { fields: map })
     }
 
+    pub fn to_struct(&self) -> DataType {
+        let fields = self.fields.values().cloned().collect();
+        DataType::Struct(fields)
+    }
+
     pub fn exclude<S: AsRef<str>>(&self, names: &[S]) -> DaftResult<Self> {
         let mut fields = IndexMap::new();
         let names = names.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs
index bcb71494b6..75a819c204 100644
--- a/src/daft-sql/src/lib.rs
+++ b/src/daft-sql/src/lib.rs
@@ -4,7 +4,10 @@ pub mod catalog;
 pub mod error;
 pub mod functions;
 mod modules;
+
 mod planner;
+pub use planner::*;
+
 #[cfg(feature = "python")]
 pub mod python;
 mod table_provider;
diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs
index e7a1fa381c..7d5296aa60 100644
--- a/src/daft-sql/src/planner.rs
+++ b/src/daft-sql/src/planner.rs
@@ -21,10 +21,10 @@ use daft_functions::{
 use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef};
 use sqlparser::{
     ast::{
-        ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo,
-        ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, Statement, StructField,
-        Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
-        WildcardAdditionalOptions, With,
+        ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct,
+        ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr,
+        Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator,
+        Value, WildcardAdditionalOptions, With,
     },
     dialect::GenericDialect,
     parser::{Parser, ParserOptions},
@@ -202,6 +202,35 @@ impl<'a> SQLPlanner<'a> {
         Ok(())
     }
 
+    pub fn parse_column_definitions(&self, column_defs: &str) {
+        let tokens = Tokenizer::new(&GenericDialect, column_defs)
+            .tokenize()
+            .unwrap();
+
+        let mut parser = Parser::new(&GenericDialect)
+            .with_options(ParserOptions {
+                trailing_commas: true,
+                ..Default::default()
+            })
+            .with_tokens(tokens);
+
+        let o = parser.parse_comma_separated(Parser::parse_column_def);
+
+        // Vec<(String, String)>
+        let outputs = o
+            .unwrap()
+            .into_iter()
+            .map(|cd| {
+                let name = cd.name.to_string();
+                let data_type = cd.data_type.to_string();
+
+                (name, data_type)
+            })
+            .collect::<Vec<_>>();
+
+        println!("{:?}", outputs);
+    }
+
     pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult<LogicalPlanRef> {
         let tokens = Tokenizer::new(&GenericDialect {}, sql).tokenize()?;
 
@@ -1262,6 +1291,28 @@ impl<'a> SQLPlanner<'a> {
         }
     }
 
+    fn column_to_field(&self, column_def: &ColumnDef) -> SQLPlannerResult<Field> {
+        let ColumnDef {
+            name,
+            data_type,
+            collation,
+            options,
+        } = column_def;
+
+        if let Some(collation) = collation {
+            unsupported_sql_err!("collation operation ({collation:?}) is not supported")
+        }
+
+        if !options.is_empty() {
+            unsupported_sql_err!("unsupported options: {options:?}")
+        }
+
+        let name = ident_to_str(name);
+        let data_type = self.sql_dtype_to_dtype(data_type)?;
+
+        Ok(Field::new(name, data_type))
+    }
+
     fn value_to_lit(&self, value: &Value) -> SQLPlannerResult<LiteralValue> {
         Ok(match value {
             Value::SingleQuotedString(s) => LiteralValue::Utf8(s.clone()),
@@ -2114,6 +2165,32 @@ fn check_wildcard_options(
     Ok(())
 }
 
+pub fn sql_schema<S: AsRef<str>>(s: S) -> SQLPlannerResult<SchemaRef> {
+    let planner = SQLPlanner::default();
+
+    let tokens = Tokenizer::new(&GenericDialect, s.as_ref()).tokenize()?;
+
+    let mut parser = Parser::new(&GenericDialect)
+        .with_options(ParserOptions {
+            trailing_commas: true,
+            ..Default::default()
+        })
+        .with_tokens(tokens);
+
+    let column_defs = parser.parse_comma_separated(Parser::parse_column_def)?;
+
+    let fields: Result<Vec<_>, _> = column_defs
+        .into_iter()
+        .map(|c| planner.column_to_field(&c))
+        .collect();
+
+    let fields = fields?;
+
+    let schema = Schema::new(fields)?;
+    Ok(Arc::new(schema))
+}
+
 pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
     let mut planner = SQLPlanner::default();
 
@@ -2138,6 +2215,12 @@ pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
 // ----------------
 // Helper functions
 // ----------------
+
+/// # Examples
+/// ```
+/// // Quoted identifier "MyCol" -> "MyCol"
+/// // Unquoted identifier MyCol -> "MyCol"
+/// ```
 fn ident_to_str(ident: &Ident) -> String {
     if ident.quote_style == Some('"') {
         ident.value.to_string()
@@ -2190,3 +2273,52 @@ fn unresolve_alias(expr: ExprRef, projection: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
diff --git a/tests/connect/test_analyze_plan.py b/tests/connect/test_analyze_plan.py
new file mode 100644
index 0000000000..7e5d8091eb
--- /dev/null
+++ b/tests/connect/test_analyze_plan.py
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+
+def test_analyze_plan(spark_session):
+    data = [[1000, 99]]
+    df1 = spark_session.createDataFrame(data, schema="Value int, Total int")
+    df1.collect()

From: Andrew Gazelka
Date: Mon, 16 Dec 2024 14:02:28 -0800
Subject: [PATCH 2/5] add test

---
 src/daft-connect/src/lib.rs        |  1 -
 src/daft-sql/src/planner.rs        | 29 -----------------------------
 tests/connect/test_analyze_plan.py |  3 ++-
 3 files changed, 2 insertions(+), 31 deletions(-)

diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index 7b8432652e..369bfe8e47 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -6,7 +6,6 @@
 #![feature(stmt_expr_attributes)]
 #![feature(try_trait_v2_residual)]
 
-use daft_sql::error::SQLPlannerResult;
 use dashmap::DashMap;
 use eyre::Context;
 #[cfg(feature = "python")]
diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs
index 7d5296aa60..683391b601 100644
--- a/src/daft-sql/src/planner.rs
+++ b/src/daft-sql/src/planner.rs
@@ -202,35 +202,6 @@ impl<'a> SQLPlanner<'a> {
         Ok(())
     }
 
-    pub fn parse_column_definitions(&self, column_defs: &str) {
-        let tokens = Tokenizer::new(&GenericDialect, column_defs)
-            .tokenize()
-            .unwrap();
-
-        let mut parser = Parser::new(&GenericDialect)
-            .with_options(ParserOptions {
-                trailing_commas: true,
-                ..Default::default()
-            })
-            .with_tokens(tokens);
-
-        let o = parser.parse_comma_separated(Parser::parse_column_def);
-
-        // Vec<(String, String)>
-        let outputs = o
-            .unwrap()
-            .into_iter()
-            .map(|cd| {
-                let name = cd.name.to_string();
-                let data_type = cd.data_type.to_string();
-
-                (name, data_type)
-            })
-            .collect::<Vec<_>>();
-
-        println!("{:?}", outputs);
-    }
-
     pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult<LogicalPlanRef> {
         let tokens = Tokenizer::new(&GenericDialect {}, sql).tokenize()?;
diff --git a/tests/connect/test_analyze_plan.py b/tests/connect/test_analyze_plan.py
index 7e5d8091eb..954adc33fb 100644
--- a/tests/connect/test_analyze_plan.py
+++ b/tests/connect/test_analyze_plan.py
@@ -4,4 +4,5 @@
 def test_analyze_plan(spark_session):
     data = [[1000, 99]]
     df1 = spark_session.createDataFrame(data, schema="Value int, Total int")
-    df1.collect()
+    s = df1.schema
+    assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])"

From dc20be849b18e7a116df0f78f7456b97e7cacfeb Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Mon, 16 Dec 2024 15:04:56 -0800
Subject: [PATCH 3/5] add disclaimer

---
 tests/connect/test_analyze_plan.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/connect/test_analyze_plan.py b/tests/connect/test_analyze_plan.py
index 954adc33fb..b2f50155f9 100644
--- a/tests/connect/test_analyze_plan.py
+++ b/tests/connect/test_analyze_plan.py
@@ -5,4 +5,7 @@ def test_analyze_plan(spark_session):
     data = [[1000, 99]]
     df1 = spark_session.createDataFrame(data, schema="Value int, Total int")
     s = df1.schema
+
+    # todo: this is INCORRECT but it is an issue with pyspark client
+    # ideally should be StructType([StructField('Value', IntegerType(), True), StructField('Total', IntegerType(), True)])
     assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])"
From 24dde3cf0305198635ce05feea1977557554e46b Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Tue, 17 Dec 2024 11:37:32 -0800
Subject: [PATCH 4/5] Update tests/connect/test_analyze_plan.py

Co-authored-by: Cory Grinstead
---
 tests/connect/test_analyze_plan.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/connect/test_analyze_plan.py b/tests/connect/test_analyze_plan.py
index b2f50155f9..fae5ab57dc 100644
--- a/tests/connect/test_analyze_plan.py
+++ b/tests/connect/test_analyze_plan.py
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
 
+@pytest.mark.skip(
+    reason="Currently an issue in the spark connect code. It always passes the inferred schema instead of the supplied schema."
+)
 def test_analyze_plan(spark_session):
+
     data = [[1000, 99]]
     df1 = spark_session.createDataFrame(data, schema="Value int, Total int")
     s = df1.schema

From eba4b4f38375e29039fb8d201fe389b6e31a0be9 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Tue, 17 Dec 2024 11:39:09 -0800
Subject: [PATCH 5/5] modify test

---
 tests/connect/test_analyze_plan.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/connect/test_analyze_plan.py b/tests/connect/test_analyze_plan.py
index fae5ab57dc..492de7e53c 100644
--- a/tests/connect/test_analyze_plan.py
+++ b/tests/connect/test_analyze_plan.py
@@ -1,15 +1,18 @@
 from __future__ import annotations
 
+import pytest
+
 
 @pytest.mark.skip(
     reason="Currently an issue in the spark connect code. It always passes the inferred schema instead of the supplied schema."
 )
 def test_analyze_plan(spark_session):
-
     data = [[1000, 99]]
     df1 = spark_session.createDataFrame(data, schema="Value int, Total int")
     s = df1.schema
 
     # todo: this is INCORRECT but it is an issue with pyspark client
-    # ideally should be StructType([StructField('Value', IntegerType(), True), StructField('Total', IntegerType(), True)])
-    assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])"
+    # right now it is assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])"
+    assert (
+        str(s) == "StructType([StructField('Value', IntegerType(), True), StructField('Total', IntegerType(), True)])"
+    )
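
For reference, a minimal usage sketch of the `daft_sql::sql_schema` entry point added in PATCH 1. This is not part of the patch series; the crate dependency setup is assumed, and the printed `Debug` form of the dtype is illustrative only.

fn main() {
    // Parse a Spark-style DDL column list into a Daft schema; this is the
    // same call the DdlParse handler makes in src/daft-connect/src/lib.rs.
    let schema = daft_sql::sql_schema("Value int, Total int").expect("DDL should parse");

    // Collapse the schema into a single Struct dtype, mirroring the
    // Schema::to_struct call the handler performs before translating the
    // result to a Spark datatype via translation::to_spark_datatype.
    let dtype = schema.to_struct();
    println!("{:?}", dtype);
}

This is the server-side path exercised when a Spark Connect client supplies a DDL schema string such as the `schema="Value int, Total int"` argument used in the tests above.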