Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(connect): support DdlParse #3580

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ daft-logical-plan = {path = "src/daft-logical-plan"}
daft-micropartition = {path = "src/daft-micropartition"}
daft-scan = {path = "src/daft-scan"}
daft-schema = {path = "src/daft-schema"}
daft-sql = {path = "src/daft-sql"}
daft-table = {path = "src/daft-table"}
derivative = "2.2.0"
derive_builder = "0.20.2"
Expand Down
1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ daft-logical-plan = {workspace = true}
daft-micropartition = {workspace = true}
daft-scan = {workspace = true}
daft-schema = {workspace = true}
daft-sql = {workspace = true}
daft-table = {workspace = true}
dashmap = "6.1.0"
eyre = "0.6.12"
Expand Down
24 changes: 23 additions & 1 deletion src/daft-connect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,29 @@

Ok(Response::new(response))
}
_ => unimplemented_err!("Analyze plan operation is not yet implemented"),
Analyze::DdlParse(DdlParse { ddl_string }) => {
let daft_schema = match daft_sql::sql_schema(&ddl_string) {
Ok(daft_schema) => daft_schema,
Err(e) => return invalid_argument_err!("{e}"),

Check warning on line 329 in src/daft-connect/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/lib.rs#L329

Added line #L329 was not covered by tests
};

let daft_schema = daft_schema.to_struct();

let schema = translation::to_spark_datatype(&daft_schema);

let schema = analyze_plan_response::Schema {
schema: Some(schema),
};

let response = AnalyzePlanResponse {
session_id,
server_side_session_id: String::new(),
result: Some(analyze_plan_response::Result::Schema(schema)),
};

Ok(Response::new(response))
}
other => unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}"),

Check warning on line 348 in src/daft-connect/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/lib.rs#L348

Added line #L348 was not covered by tests
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/daft-schema/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use derive_more::Display;
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};

use crate::field::Field;
use crate::{field::Field, prelude::DataType};

pub type SchemaRef = Arc<Schema>;

Expand Down Expand Up @@ -48,6 +48,11 @@ impl Schema {
Ok(Self { fields: map })
}

/// Collapses this schema into a single [`DataType::Struct`] whose struct
/// fields are this schema's fields, preserving their original order.
pub fn to_struct(&self) -> DataType {
    DataType::Struct(self.fields.values().cloned().collect())
}

pub fn exclude<S: AsRef<str>>(&self, names: &[S]) -> DaftResult<Self> {
let mut fields = IndexMap::new();
let names = names.iter().map(|s| s.as_ref()).collect::<HashSet<&str>>();
Expand Down
3 changes: 3 additions & 0 deletions src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ pub mod catalog;
pub mod error;
pub mod functions;
mod modules;

mod planner;
pub use planner::*;

#[cfg(feature = "python")]
pub mod python;
mod table_provider;
Expand Down
111 changes: 107 additions & 4 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo,
ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, Statement, StructField,
Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
WildcardAdditionalOptions, With,
ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct,
ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr,
Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator,
Value, WildcardAdditionalOptions, With,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand Down Expand Up @@ -1262,6 +1262,28 @@
}
}

fn column_to_field(&self, column_def: &ColumnDef) -> SQLPlannerResult<Field> {
let ColumnDef {
name,
data_type,
collation,
options,
} = column_def;

if let Some(collation) = collation {
unsupported_sql_err!("collation operation ({collation:?}) is not supported")

Check warning on line 1274 in src/daft-sql/src/planner.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/planner.rs#L1274

Added line #L1274 was not covered by tests
}

if !options.is_empty() {
unsupported_sql_err!("unsupported options: {options:?}")

Check warning on line 1278 in src/daft-sql/src/planner.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/planner.rs#L1278

Added line #L1278 was not covered by tests
}

let name = ident_to_str(name);
let data_type = self.sql_dtype_to_dtype(data_type)?;

Ok(Field::new(name, data_type))
}

fn value_to_lit(&self, value: &Value) -> SQLPlannerResult<LiteralValue> {
Ok(match value {
Value::SingleQuotedString(s) => LiteralValue::Utf8(s.clone()),
Expand Down Expand Up @@ -2114,6 +2136,32 @@

Ok(())
}

/// Parses a comma-separated list of SQL column definitions
/// (e.g. `"a INT, b STRING"`) into a Daft [`SchemaRef`].
pub fn sql_schema<S: AsRef<str>>(s: S) -> SQLPlannerResult<SchemaRef> {
    let planner = SQLPlanner::default();

    // Tokenize the raw schema string using the generic SQL dialect.
    let dialect = GenericDialect;
    let tokens = Tokenizer::new(&dialect, s.as_ref()).tokenize()?;

    // Permit a trailing comma after the last column definition.
    let opts = ParserOptions {
        trailing_commas: true,
        ..Default::default()
    };
    let mut parser = Parser::new(&dialect).with_options(opts).with_tokens(tokens);

    // Parse every `<name> <type>` column definition, convert each to a Daft
    // field, and short-circuit on the first conversion failure.
    let fields = parser
        .parse_comma_separated(Parser::parse_column_def)?
        .iter()
        .map(|col| planner.column_to_field(col))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(Arc::new(Schema::new(fields)?))
}

pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
let mut planner = SQLPlanner::default();

Expand All @@ -2138,6 +2186,12 @@
// ----------------
// Helper functions
// ----------------

/// # Examples
/// ```
/// // Quoted identifier "MyCol" -> "MyCol"
/// // Unquoted identifier MyCol -> "MyCol"
/// ```
fn ident_to_str(ident: &Ident) -> String {
if ident.quote_style == Some('"') {
ident.value.to_string()
Expand Down Expand Up @@ -2190,3 +2244,52 @@
})
.ok_or_else(|| PlannerError::column_not_found(expr.name(), "projection"))
}

#[cfg(test)]
mod tests {
    use daft_core::prelude::*;

    use crate::sql_schema;

    /// A multi-column DDL string parses into the matching Daft schema.
    #[test]
    fn test_sql_schema_creates_expected_schema() {
        let parsed =
            sql_schema("Year int, First_Name STRING, County STRING, Sex STRING, Count int")
                .unwrap();

        let want = Schema::new(vec![
            Field::new("Year", DataType::Int32),
            Field::new("First_Name", DataType::Utf8),
            Field::new("County", DataType::Utf8),
            Field::new("Sex", DataType::Utf8),
            Field::new("Count", DataType::Int32),
        ])
        .unwrap();

        assert_eq!(&*parsed, &want);
    }

    /// Duplicate column names are rejected by `Schema` construction, so
    /// `sql_schema` must surface that error rather than dedupe silently.
    #[test]
    fn test_duplicate_column_names_in_schema() {
        let err = sql_schema("col1 INT, col1 STRING").unwrap_err();

        assert_eq!(
            err.to_string(),
            "Daft error: DaftError::ValueError Attempting to make a Schema with duplicate field names: col1"
        );
    }

    /// An empty string contains no column definitions and must fail to parse.
    #[test]
    fn test_degenerate_empty_schema() {
        assert!(sql_schema("").is_err());
    }

    /// The minimal one-column case round-trips to a one-field schema.
    #[test]
    fn test_single_field_schema() {
        let parsed = sql_schema("col1 INT").unwrap();
        let want = Schema::new(vec![Field::new("col1", DataType::Int32)]).unwrap();
        assert_eq!(&*parsed, &want);
    }
}
11 changes: 11 additions & 0 deletions tests/connect/test_analyze_plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations


def test_analyze_plan(spark_session):
    """Exercises schema retrieval (the AnalyzePlan RPC path) for a connect DataFrame.

    Creates a two-column DataFrame and asserts on the stringified schema the
    server reports back.
    """
    data = [[1000, 99]]
    df1 = spark_session.createDataFrame(data, schema="Value int, Total int")
    s = df1.schema

    # todo: this is INCORRECT but it is an issue with pyspark client
    # ideally should be StructType([StructField('Value', IntegerType(), True), StructField('Total', IntegerType(), True)])
    assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])"
Loading