feat(connect): printSchema
andrewgazelka committed Dec 19, 2024
1 parent ea8f8bd commit 56e872c
Showing 7 changed files with 82 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -190,6 +190,7 @@ chrono = "0.4.38"
 chrono-tz = "0.10.0"
 comfy-table = "7.1.1"
 common-daft-config = {path = "src/common/daft-config"}
+common-display = {path = "src/common/display", default-features = false}
 common-error = {path = "src/common/error", default-features = false}
 common-file-formats = {path = "src/common/file-formats"}
 common-runtime = {path = "src/common/runtime", default-features = false}
1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
@@ -2,6 +2,7 @@
 arrow2 = {workspace = true, features = ["io_json_integration"]}
 async-stream = "0.3.6"
 common-daft-config = {workspace = true}
+common-display = {workspace = true}
 common-file-formats = {workspace = true}
 daft-core = {workspace = true}
 daft-dsl = {workspace = true}
47 changes: 44 additions & 3 deletions src/daft-connect/src/lib.rs
@@ -6,6 +6,8 @@
 #![feature(stmt_expr_attributes)]
 #![feature(try_trait_v2_residual)]
 
+use common_display::{tree::TreeDisplay, DisplayLevel};
+use daft_micropartition::partitioning::InMemoryPartitionSetCache;
 use dashmap::DashMap;
 use eyre::Context;
 #[cfg(feature = "python")]
@@ -22,10 +24,10 @@ use spark_connect::{
     ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse,
 };
 use tonic::{transport::Server, Request, Response, Status};
-use tracing::info;
+use tracing::{info, warn};
 use uuid::Uuid;
 
-use crate::session::Session;
+use crate::{session::Session, translation::SparkAnalyzer};
 
 mod config;
 mod err;
@@ -303,7 +305,7 @@ impl SparkConnectService for DaftSparkConnectService {
                     return Err(Status::invalid_argument("op_type is required to be root"));
                 };
 
-                let result = match translation::relation_to_schema(relation).await {
+                let result = match translation::relation_to_spark_schema(relation).await {
                     Ok(schema) => schema,
                     Err(e) => {
                         return invalid_argument_err!(
@@ -346,6 +348,45 @@ impl SparkConnectService for DaftSparkConnectService {
 
                 Ok(Response::new(response))
             }
+            Analyze::TreeString(TreeString { plan, level }) => {
+                let Some(plan) = plan else {
+                    return invalid_argument_err!("plan is required");
+                };
+
+                let Some(op_type) = plan.op_type else {
+                    return invalid_argument_err!("op_type is required");
+                };
+
+                let OpType::Root(input) = op_type else {
+                    return invalid_argument_err!("op_type must be Root");
+                };
+
+                if let Some(common) = &input.common {
+                    if common.origin.is_some() {
+                        warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
+                    }
+                }
+
+                // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used
+                let pset = InMemoryPartitionSetCache::empty();
+                let translator = SparkAnalyzer::new(&pset);
+                let plan = Box::pin(translator.to_logical_plan(input))
+                    .await
+                    .unwrap()
+                    .build();
+
+                let s = plan.display_as(DisplayLevel::Default);
+
+                let response = AnalyzePlanResponse {
+                    session_id,
+                    server_side_session_id: String::new(),
+                    result: Some(spark_connect::analyze_plan_response::Result::TreeString(
+                        analyze_plan_response::TreeString { tree_string: s },
+                    )),
+                };
+
+                Ok(Response::new(response))
+            }
             other => unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}"),
         }
     }
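Note: the new TreeString branch does its work in two steps — translate the Spark Connect relation into a Daft logical plan via SparkAnalyzer, then render that plan with the TreeDisplay trait imported at the top of the file. A minimal sketch of the rendering step, assuming (as the import and call above suggest) that display_as is TreeDisplay's string-rendering method:

use common_display::{tree::TreeDisplay, DisplayLevel};

// Render any TreeDisplay implementor the way the new branch does.
// The handler hard-codes DisplayLevel::Default; the request's `level`
// field is destructured but currently unused.
fn render_tree(node: &dyn TreeDisplay) -> String {
    node.display_as(DisplayLevel::Default)
}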
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation.rs
@@ -10,4 +10,4 @@ pub use datatype::{to_daft_datatype, to_spark_datatype};
 pub use expr::to_daft_expr;
 pub use literal::to_daft_literal;
 pub use logical_plan::SparkAnalyzer;
-pub use schema::relation_to_schema;
+pub use schema::{relation_to_daft_schema, relation_to_spark_schema};
34 changes: 21 additions & 13 deletions src/daft-connect/src/translation/schema.rs
@@ -1,4 +1,5 @@
 use daft_micropartition::partitioning::InMemoryPartitionSetCache;
+use daft_schema::schema::{Schema, SchemaRef};
 use spark_connect::{
     data_type::{Kind, Struct, StructField},
     DataType, Relation,
@@ -9,19 +10,8 @@ use super::SparkAnalyzer;
 use crate::translation::to_spark_datatype;
 
 #[tracing::instrument(skip_all)]
-pub async fn relation_to_schema(input: Relation) -> eyre::Result<DataType> {
-    if let Some(common) = &input.common {
-        if common.origin.is_some() {
-            warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
-        }
-    }
-
-    // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used
-    let pset = InMemoryPartitionSetCache::empty();
-    let translator = SparkAnalyzer::new(&pset);
-    let plan = Box::pin(translator.to_logical_plan(input)).await?;
-
-    let result = plan.schema();
+pub async fn relation_to_spark_schema(input: Relation) -> eyre::Result<DataType> {
+    let result = relation_to_daft_schema(input).await?;
 
     let fields: eyre::Result<Vec<StructField>> = result
         .fields
@@ -44,3 +34,21 @@
         })),
     })
 }
+
+#[tracing::instrument(skip_all)]
+pub async fn relation_to_daft_schema(input: Relation) -> eyre::Result<SchemaRef> {
+    if let Some(common) = &input.common {
+        if common.origin.is_some() {
+            warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
+        }
+    }
+
+    // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used
+    let pset = InMemoryPartitionSetCache::empty();
+    let translator = SparkAnalyzer::new(&pset);
+    let plan = Box::pin(translator.to_logical_plan(input)).await?;
+
+    let result = plan.schema();
+
+    Ok(result)
+}
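Note: the refactor splits schema derivation in two — relation_to_daft_schema stops at Daft's native SchemaRef, while relation_to_spark_schema layers the Spark Connect DataType (a Struct kind) on top. A hedged usage sketch, not part of the commit; the clone() and the .fields access are assumptions about the prost-generated Relation and Daft's Schema:

use spark_connect::Relation;

use crate::translation::{relation_to_daft_schema, relation_to_spark_schema};

// Hypothetical caller: stop at the Daft schema for internal checks, or
// go one step further to the Spark Connect wire type.
async fn describe(relation: Relation) -> eyre::Result<()> {
    let daft_schema = relation_to_daft_schema(relation.clone()).await?; // SchemaRef
    let spark_schema = relation_to_spark_schema(relation).await?; // DataType with Kind::Struct
    println!("daft fields: {}, spark: {spark_schema:?}", daft_schema.fields.len());
    Ok(())
}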
13 changes: 13 additions & 0 deletions tests/connect/test_print_schema.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+
+def test_print_schema(spark_session: object, capsys: object) -> None:
+    df = spark_session.range(10)
+    df.printSchema()
+
+    captured = capsys.readouterr()
+    expected = (
+        "root\n"
+        " |-- id: long (nullable = true)\n"
+    )
+    assert captured.out == expected
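For context, this test drives the new server path through PySpark: df.printSchema() issues an AnalyzePlanRequest whose analyze oneof is TreeString, which the handler in lib.rs pattern-matches on. A hedged sketch of roughly that request, with module paths and field names assumed from the prost-generated spark_connect bindings used above:

use spark_connect::{
    analyze_plan_request::{Analyze, TreeString},
    plan::OpType,
    AnalyzePlanRequest, Plan, Relation,
};

// Hypothetical helper (not in the commit): build a TreeString analyze
// request for a root relation, the shape the handler expects.
fn tree_string_request(session_id: String, relation: Relation) -> AnalyzePlanRequest {
    AnalyzePlanRequest {
        session_id,
        analyze: Some(Analyze::TreeString(TreeString {
            plan: Some(Plan { op_type: Some(OpType::Root(relation)) }),
            level: None, // optional detail level; None when the client omits it
        })),
        ..Default::default()
    }
}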
