diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml new file mode 100644 index 000000000000..b8c0e56d2566 --- /dev/null +++ b/datafusion/substrait/Cargo.toml @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-substrait" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-recursion = "1.0" +datafusion = "13.0" +prost = "0.9" +prost-types = "0.9" +substrait = "0.2" +tokio = "1.17" + +[build-dependencies] +prost-build = { version = "0.9" } diff --git a/datafusion/substrait/README.md b/datafusion/substrait/README.md new file mode 100644 index 000000000000..9f21d514ab82 --- /dev/null +++ b/datafusion/substrait/README.md @@ -0,0 +1,34 @@ + + +# DataFusion + Substrait + +[Substrait](https://substrait.io/) provides a cross-language serialization format for relational algebra, based on +protocol buffers. + +This repository provides a Substrait producer and consumer for DataFusion: + +- The producer converts a DataFusion logical plan into a Substrait protobuf. +- The consumer converts a Substrait protobuf into a DataFusion logical plan. + +Potential uses of this crate: + +- Replace the current [DataFusion protobuf definition](https://github.com/apache/arrow-datafusion/blob/master/datafusion-proto/proto/datafusion.proto) used in Ballista for passing query plan fragments to executors +- Make it easier to pass query plans over FFI boundaries, such as from Python to Rust +- Allow Apache Calcite query plans to be executed in DataFusion diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs new file mode 100644 index 000000000000..c747a30a6bec --- /dev/null +++ b/datafusion/substrait/src/consumer.rs @@ -0,0 +1,544 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use async_recursion::async_recursion; +use datafusion::common::{DFField, DFSchema, DFSchemaRef}; +use datafusion::logical_expr::{LogicalPlan, aggregate_function}; +use datafusion::logical_plan::build_join_schema; +use datafusion::prelude::JoinType; +use datafusion::{ + error::{DataFusionError, Result}, + logical_plan::{Expr, Operator}, + optimizer::utils::split_conjunction, + prelude::{Column, DataFrame, SessionContext}, + scalar::ScalarValue, +}; + +use datafusion::sql::TableReference; +use substrait::protobuf::{ + aggregate_function::AggregationInvocation, + expression::{ + field_reference::ReferenceType::DirectReference, + literal::LiteralType, + MaskExpression, + reference_segment::ReferenceType::StructField, + RexType, + }, + extensions::simple_extension_declaration::MappingType, + function_argument::ArgType, + read_rel::ReadType, + rel::RelType, + sort_field::{SortKind::*, SortDirection}, + AggregateFunction, Expression, Plan, Rel, +}; + +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; + +pub fn name_to_op(name: &str) -> Result { + match name { + "equal" => Ok(Operator::Eq), + "not_equal" => Ok(Operator::NotEq), + "lt" => Ok(Operator::Lt), + "lte" => Ok(Operator::LtEq), + "gt" => Ok(Operator::Gt), + "gte" => Ok(Operator::GtEq), + "add" => Ok(Operator::Plus), + "subtract" => Ok(Operator::Minus), + "multiply" => Ok(Operator::Multiply), + "divide" => Ok(Operator::Divide), + "mod" => Ok(Operator::Modulo), + "and" => Ok(Operator::And), + "or" => Ok(Operator::Or), + "like" => Ok(Operator::Like), + "not_like" => Ok(Operator::NotLike), + "is_distinct_from" => Ok(Operator::IsDistinctFrom), + "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom), + "regex_match" => Ok(Operator::RegexMatch), + "regex_imatch" => Ok(Operator::RegexIMatch), + "regex_not_match" => Ok(Operator::RegexNotMatch), + "regex_not_imatch" => Ok(Operator::RegexNotIMatch), + "bitwise_and" => Ok(Operator::BitwiseAnd), + "bitwise_or" => Ok(Operator::BitwiseOr), + "str_concat" => Ok(Operator::StringConcat), + "bitwise_xor" => Ok(Operator::BitwiseXor), + "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported function name: {:?}", + name + ))), + } +} + +/// Convert Substrait Plan to DataFusion DataFrame +pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result> { + // Register function extension + let function_extension = plan.extensions + .iter() + .map(|e| match &e.mapping_type { + Some(ext) => { + match ext { + MappingType::ExtensionFunction(ext_f) => Ok((ext_f.function_anchor, &ext_f.name)), + _ => Err(DataFusionError::NotImplemented(format!("Extension type not supported: {:?}", ext))) + } + } + None => Err(DataFusionError::NotImplemented("Cannot parse empty extension".to_string())) + }) + .collect::>>()?; + // Parse relations + match plan.relations.len() { + 1 => { + match plan.relations[0].rel_type.as_ref() { + Some(rt) => match rt { + substrait::protobuf::plan_rel::RelType::Rel(rel) => { + Ok(from_substrait_rel(ctx, &rel, &function_extension).await?) + }, + substrait::protobuf::plan_rel::RelType::Root(root) => { + Ok(from_substrait_rel(ctx, &root.input.as_ref().unwrap(), &function_extension).await?) + } + }, + None => Err(DataFusionError::Internal("Cannot parse plan relation: None".to_string())) + } + + }, + _ => Err(DataFusionError::NotImplemented(format!( + "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", + plan.relations.len() + ))) + } +} + +/// Convert Substrait Rel to DataFusion DataFrame +#[async_recursion] +pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: &HashMap) -> Result> { + match &rel.rel_type { + Some(RelType::Project(p)) => { + if let Some(input) = p.input.as_ref() { + let input = from_substrait_rel(ctx, input, extensions).await?; + let mut exprs: Vec = vec![]; + for e in &p.expressions { + let x = from_substrait_rex(e, &input.schema(), extensions).await?; + exprs.push(x.as_ref().clone()); + } + input.select(exprs) + } else { + Err(DataFusionError::NotImplemented( + "Projection without an input is not supported".to_string(), + )) + } + } + Some(RelType::Filter(filter)) => { + if let Some(input) = filter.input.as_ref() { + let input = from_substrait_rel(ctx, input, extensions).await?; + if let Some(condition) = filter.condition.as_ref() { + let expr = from_substrait_rex(condition, &input.schema(), extensions).await?; + input.filter(expr.as_ref().clone()) + } else { + Err(DataFusionError::NotImplemented( + "Filter without an condition is not valid".to_string(), + )) + } + } else { + Err(DataFusionError::NotImplemented( + "Filter without an input is not valid".to_string(), + )) + } + } + Some(RelType::Fetch(fetch)) => { + if let Some(input) = fetch.input.as_ref() { + let input = from_substrait_rel(ctx, input, extensions).await?; + let offset = fetch.offset as usize; + let count = fetch.count as usize; + input.limit(offset, Some(count)) + } else { + Err(DataFusionError::NotImplemented( + "Fetch without an input is not valid".to_string(), + )) + } + } + Some(RelType::Sort(sort)) => { + if let Some(input) = sort.input.as_ref() { + let input = from_substrait_rel(ctx, input, extensions).await?; + let mut sorts: Vec = vec![]; + for s in &sort.sorts { + let expr = from_substrait_rex(&s.expr.as_ref().unwrap(), &input.schema(), extensions).await?; + let asc_nullfirst = match &s.sort_kind { + Some(k) => match k { + Direction(d) => { + let direction : SortDirection = unsafe { + ::std::mem::transmute(*d) + }; + match direction { + SortDirection::AscNullsFirst => Ok((true, true)), + SortDirection::AscNullsLast => Ok((true, false)), + SortDirection::DescNullsFirst => Ok((false, true)), + SortDirection::DescNullsLast => Ok((false, false)), + SortDirection::Clustered => { + Err(DataFusionError::NotImplemented( + "Sort with direction clustered is not yet supported".to_string(), + )) + }, + SortDirection::Unspecified => { + Err(DataFusionError::NotImplemented( + "Unspecified sort direction is invalid".to_string(), + )) + } + } + } + ComparisonFunctionReference(_) => { + Err(DataFusionError::NotImplemented( + "Sort using comparison function reference is not supported".to_string(), + )) + }, + }, + None => { + Err(DataFusionError::NotImplemented( + "Sort without sort kind is invalid".to_string(), + )) + }, + }; + let (asc, nulls_first) = asc_nullfirst.unwrap(); + sorts.push(Expr::Sort { expr: Box::new(expr.as_ref().clone()), asc: asc, nulls_first: nulls_first }); + } + input.sort(sorts) + } else { + Err(DataFusionError::NotImplemented( + "Sort without an input is not valid".to_string(), + )) + } + } + Some(RelType::Aggregate(agg)) => { + if let Some(input) = agg.input.as_ref() { + let input = from_substrait_rel(ctx, input, extensions).await?; + let mut group_expr = vec![]; + let mut aggr_expr = vec![]; + + let groupings = match agg.groupings.len() { + 1 => { Ok(&agg.groupings[0]) }, + _ => { + Err(DataFusionError::NotImplemented( + "Aggregate with multiple grouping sets is not supported".to_string(), + )) + } + }; + + for e in &groupings?.grouping_expressions { + let x = from_substrait_rex(&e, &input.schema(), extensions).await?; + group_expr.push(x.as_ref().clone()); + } + + for m in &agg.measures { + let filter = match &m.filter { + Some(fil) => Some(Box::new(from_substrait_rex(fil, &input.schema(), extensions).await?.as_ref().clone())), + None => None + }; + let agg_func = match &m.measure { + Some(f) => { + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => true, + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false + }; + from_substrait_agg_func(&f, &input.schema(), extensions, filter, distinct).await + }, + None => Err(DataFusionError::NotImplemented( + "Aggregate without aggregate function is not supported".to_string(), + )), + }; + aggr_expr.push(agg_func?.as_ref().clone()); + } + + input.aggregate(group_expr, aggr_expr) + } else { + Err(DataFusionError::NotImplemented( + "Aggregate without an input is not valid".to_string(), + )) + } + } + Some(RelType::Join(join)) => { + let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?; + let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap(), extensions).await?; + let join_type = match join.r#type { + 1 => JoinType::Inner, + 2 => JoinType::Left, + 3 => JoinType::Right, + 4 => JoinType::Full, + 5 => JoinType::Anti, + 6 => JoinType::Semi, + _ => return Err(DataFusionError::Internal("invalid join type".to_string())), + }; + let mut predicates = vec![]; + let schema = build_join_schema(&left.schema(), &right.schema(), &JoinType::Inner)?; + let on = from_substrait_rex(&join.expression.as_ref().unwrap(), &schema, extensions).await?; + split_conjunction(&on, &mut predicates); + let pairs = predicates + .iter() + .map(|p| match p { + Expr::BinaryExpr { + left, + op: Operator::Eq, + right, + } => match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => Ok((l.flat_name(), r.flat_name())), + _ => { + return Err(DataFusionError::Internal( + "invalid join condition".to_string(), + )) + } + }, + _ => { + return Err(DataFusionError::Internal( + "invalid join condition".to_string(), + )) + } + }) + .collect::>>()?; + let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| l.as_str()).collect(); + let right_cols: Vec<&str> = pairs.iter().map(|(_, r)| r.as_str()).collect(); + left.join(right, join_type, &left_cols, &right_cols, None) + } + Some(RelType::Read(read)) => match &read.as_ref().read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return Err(DataFusionError::Internal( + "No table name found in NamedTable".to_string(), + )); + } + 1 => TableReference::Bare { + table: &nt.names[0], + }, + 2 => TableReference::Partial { + schema: &nt.names[0], + table: &nt.names[1], + }, + _ => TableReference::Full { + catalog: &nt.names[0], + schema: &nt.names[1], + table: &nt.names[2], + }, + }; + let t = ctx.table(table_reference)?; + match &read.projection { + Some(MaskExpression { select, .. }) => match &select.as_ref() { + Some(projection) => { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + match t.to_logical_plan()? { + LogicalPlan::TableScan(scan) => { + let mut scan = scan.clone(); + let fields: Vec = column_indices + .iter() + .map(|i| scan.projected_schema.field(*i).clone()) + .collect(); + scan.projection = Some(column_indices); + scan.projected_schema = DFSchemaRef::new( + DFSchema::new_with_metadata(fields, HashMap::new())?, + ); + let plan = LogicalPlan::TableScan(scan); + Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan))) + } + _ => Err(DataFusionError::Internal( + "unexpected plan for table".to_string(), + )), + } + } + _ => Ok(t), + }, + _ => Ok(t), + } + } + _ => Err(DataFusionError::NotImplemented( + "Only NamedTable reads are supported".to_string(), + )), + }, + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported RelType: {:?}", + rel.rel_type + ))), + } +} + +/// Convert Substrait AggregateFunction to DataFusion Expr +pub async fn from_substrait_agg_func( + f: &AggregateFunction, + input_schema: &DFSchema, + extensions: &HashMap, + filter: Option>, + distinct: bool +) -> Result> { + let mut args: Vec = vec![]; + for arg in &f.arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => from_substrait_rex(e, input_schema, extensions).await, + _ => Err(DataFusionError::NotImplemented( + "Aggregated function argument non-Value type not supported".to_string(), + )) + }; + args.push(arg_expr?.as_ref().clone()); + } + + let fun = match extensions.get(&f.function_reference) { + Some(function_name) => aggregate_function::AggregateFunction::from_str(function_name), + None => Err(DataFusionError::NotImplemented(format!( + "Aggregated function not found: function anchor = {:?}", + f.function_reference + ) + )) + }; + + Ok( + Arc::new( + Expr::AggregateFunction { + fun: fun.unwrap(), + args: args, + distinct: distinct, + filter: filter + } + ) + ) +} + +/// Convert Substrait Rex to DataFusion Expr +#[async_recursion] +pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensions: &HashMap) -> Result> { + match &e.rex_type { + Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => Err(DataFusionError::NotImplemented( + "Direct reference StructField with child is not supported".to_string(), + )), + None => Ok(Arc::new(Expr::Column(Column { + relation: None, + name: input_schema + .field(x.field as usize) + .name() + .to_string(), + }))), + }, + _ => Err(DataFusionError::NotImplemented( + "Direct reference with types other than StructField is not supported".to_string(), + )), + }, + _ => Err(DataFusionError::NotImplemented( + "unsupported field ref type".to_string(), + )), + }, + Some(RexType::IfThen(if_then)) => { + // Parse `ifs` + // If the first element does not have a `then` part, then we can assume it's a base expression + let mut when_then_expr: Vec<(Box, Box)> = vec![]; + let mut expr = None; + for (i, if_expr) in if_then.ifs.iter().enumerate() { + if i == 0 { + // Check if the first element is type base expression + if if_expr.then.is_none() { + expr = Some(Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone())); + continue; + } + } + when_then_expr.push( + ( + Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()), + Box::new(from_substrait_rex(&if_expr.then.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()) + ), + ); + } + // Parse `else` + let else_expr = match &if_then.r#else { + Some(e) => Some(Box::new( + from_substrait_rex(&e, input_schema, extensions).await?.as_ref().clone(), + )), + None => None + }; + Ok(Arc::new(Expr::Case { expr: expr, when_then_expr: when_then_expr, else_expr: else_expr })) + }, + Some(RexType::ScalarFunction(f)) => { + assert!(f.arguments.len() == 2); + let op = match extensions.get(&f.function_reference) { + Some(fname) => name_to_op(fname), + None => Err(DataFusionError::NotImplemented(format!( + "Aggregated function not found: function reference = {:?}", + f.function_reference + ) + )) + }; + match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) { + (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { + Ok(Arc::new(Expr::BinaryExpr { + left: Box::new(from_substrait_rex(l, input_schema, extensions).await?.as_ref().clone()), + op: op?, + right: Box::new( + from_substrait_rex(r, input_schema, extensions).await?.as_ref().clone(), + ), + })) + } + (l, r) => Err(DataFusionError::NotImplemented(format!( + "Invalid arguments for binary expression: {:?} and {:?}", + l, r + ))), + } + } + Some(RexType::Literal(lit)) => match &lit.literal_type { + Some(LiteralType::I8(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8))))) + } + Some(LiteralType::I16(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16))))) + } + Some(LiteralType::I32(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n as i32))))) + } + Some(LiteralType::I64(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n as i64))))) + } + Some(LiteralType::Boolean(b)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b))))) + } + Some(LiteralType::Date(d)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d))))) + } + Some(LiteralType::Fp32(f)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f))))) + } + Some(LiteralType::Fp64(f)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f))))) + } + Some(LiteralType::String(s)) => Ok(Arc::new(Expr::Literal(ScalarValue::Utf8( + Some(s.clone()), + )))), + Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some( + b.clone(), + ))))), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Unsupported literal_type: {:?}", + lit.literal_type + ))) + } + }, + _ => Err(DataFusionError::NotImplemented( + "unsupported rex_type".to_string(), + )), + } +} diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs new file mode 100644 index 000000000000..07b8e7addeac --- /dev/null +++ b/datafusion/substrait/src/lib.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod consumer; +pub mod producer; +pub mod serializer; diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs new file mode 100644 index 000000000000..78532046bced --- /dev/null +++ b/datafusion/substrait/src/producer.rs @@ -0,0 +1,557 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +use datafusion::{ + error::{DataFusionError, Result}, + logical_plan::{DFSchemaRef, Expr, JoinConstraint, LogicalPlan, Operator}, + prelude::JoinType, + scalar::ScalarValue, +}; + +use substrait::protobuf::{ + aggregate_function::AggregationInvocation, + aggregate_rel::{Grouping, Measure}, + expression::{ + field_reference::ReferenceType, + if_then::IfClause, + literal::LiteralType, + mask_expression::{StructItem, StructSelect}, + reference_segment, + FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, ScalarFunction, + }, + extensions::{self, simple_extension_declaration::{MappingType, ExtensionFunction}}, + function_argument::ArgType, + plan_rel, + read_rel::{NamedTable, ReadType}, + rel::RelType, + sort_field::{ + SortDirection, + SortKind, + }, + AggregateRel, Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, SortField, SortRel, + PlanRel, + Plan, Rel, RelRoot, AggregateFunction, +}; + +/// Convert DataFusion LogicalPlan to Substrait Plan +pub fn to_substrait_plan(plan: &LogicalPlan) -> Result> { + // Parse relation nodes + let mut extension_info: (Vec, HashMap) = (vec![], HashMap::new()); + // Generate PlanRel(s) + // Note: Only 1 relation tree is currently supported + let plan_rels = vec![PlanRel { + rel_type: Some(plan_rel::RelType::Root( + RelRoot { + input: Some(*to_substrait_rel(plan, &mut extension_info)?), + names: plan.schema().field_names(), + } + )) + }]; + + let (function_extensions, _) = extension_info; + + // Return parsed plan + Ok(Box::new(Plan { + extension_uris: vec![], + extensions: function_extensions, + relations: plan_rels, + advanced_extensions: None, + expected_type_urls: vec![], + })) + +} + +/// Convert DataFusion LogicalPlan to Substrait Rel +pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec, HashMap)) -> Result> { + match plan { + LogicalPlan::TableScan(scan) => { + let projection = scan.projection.as_ref().map(|p| { + p.iter() + .map(|i| StructItem { + field: *i as i32, + child: None, + }) + .collect() + }); + + if let Some(struct_items) = projection { + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(NamedStruct { + names: scan + .projected_schema + .fields() + .iter() + .map(|f| f.name().to_owned()) + .collect(), + r#struct: None, + }), + filter: None, + projection: Some(MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + }), + advanced_extension: None, + read_type: Some(ReadType::NamedTable(NamedTable { + names: vec![scan.table_name.clone()], + advanced_extension: None, + })), + }))), + })) + } else { + Err(DataFusionError::NotImplemented( + "TableScan without projection is not supported".to_string(), + )) + } + } + LogicalPlan::Projection(p) => { + let expressions = p + .expr + .iter() + .map(|e| to_substrait_rex(e, p.input.schema(), extension_info)) + .collect::>>()?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + common: None, + input: Some(to_substrait_rel(p.input.as_ref(), extension_info)?), + expressions, + advanced_extension: None, + }))), + })) + } + LogicalPlan::Filter(filter) => { + let input = to_substrait_rel(filter.input.as_ref(), extension_info)?; + let filter_expr = to_substrait_rex(&filter.predicate, filter.input.schema(), extension_info)?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Filter(Box::new(FilterRel { + common: None, + input: Some(input), + condition: Some(Box::new(filter_expr)), + advanced_extension: None, + }))), + })) + } + LogicalPlan::Limit(limit) => { + let input = to_substrait_rel(limit.input.as_ref(), extension_info)?; + let limit_fetch = match limit.fetch { + Some(count) => count, + None => 0, + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(input), + offset: limit.skip as i64, + count: limit_fetch as i64, + advanced_extension: None, + }))), + })) + } + LogicalPlan::Sort(sort) => { + let input = to_substrait_rel(sort.input.as_ref(), extension_info)?; + let sort_fields = sort + .expr + .iter() + .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info)) + .collect::>>()?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Sort(Box::new(SortRel { + common: None, + input: Some(input), + sorts: sort_fields, + advanced_extension: None, + }))), + })) + } + LogicalPlan::Aggregate(agg) => { + let input = to_substrait_rel(agg.input.as_ref(), extension_info)?; + // Translate aggregate expression to Substrait's groupings (repeated repeated Expression) + let grouping = agg + .group_expr + .iter() + .map(|e| to_substrait_rex(e, agg.input.schema(), extension_info)) + .collect::>>()?; + let measures = agg + .aggr_expr + .iter() + .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info)) + .collect::>>()?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + groupings: vec![Grouping { grouping_expressions: grouping }], //groupings, + measures: measures, + advanced_extension: None, + }))), + })) + } + LogicalPlan::Distinct(distinct) => { + // Use Substrait's AggregateRel with empty measures to represent `select distinct` + let input = to_substrait_rel(distinct.input.as_ref(), extension_info)?; + // Get grouping keys from the input relation's number of output fields + let grouping = (0..distinct.input.schema().fields().len()) + .map(|x: usize| substrait_field_ref(x)) + .collect::>>()?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + groupings: vec![Grouping { grouping_expressions: grouping }], + measures: vec![], + advanced_extension: None, + }))), + })) + } + LogicalPlan::Join(join) => { + let left = to_substrait_rel(join.left.as_ref(), extension_info)?; + let right = to_substrait_rel(join.right.as_ref(), extension_info)?; + let join_type = match join.join_type { + JoinType::Inner => 1, + JoinType::Left => 2, + JoinType::Right => 3, + JoinType::Full => 4, + JoinType::Anti => 5, + JoinType::Semi => 6, + }; + // we only support basic joins so return an error for anything not yet supported + if join.null_equals_null { + return Err(DataFusionError::NotImplemented( + "join null_equals_null".to_string(), + )); + } + if join.filter.is_some() { + return Err(DataFusionError::NotImplemented("join filter".to_string())); + } + match join.join_constraint { + JoinConstraint::On => {} + _ => { + return Err(DataFusionError::NotImplemented( + "join constraint".to_string(), + )) + } + } + // map the left and right columns to binary expressions in the form `l = r` + let join_expression: Vec = join + .on + .iter() + .map(|(l, r)| Expr::Column(l.clone()).eq(Expr::Column(r.clone()))) + .collect(); + // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` + let join_expression = join_expression + .into_iter() + .reduce(|acc: Expr, expr: Expr| acc.and(expr)); + if let Some(e) = join_expression { + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type, + expression: Some(Box::new(to_substrait_rex(&e, &join.schema, extension_info)?)), + post_join_filter: None, + advanced_extension: None, + }))), + })) + } else { + Err(DataFusionError::NotImplemented( + "Empty join condition".to_string(), + )) + } + } + LogicalPlan::SubqueryAlias(alias) => { + // Do nothing if encounters SubqueryAlias + // since there is no corresponding relation type in Substrait + to_substrait_rel(alias.input.as_ref(), extension_info) + } + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported operator: {:?}", + plan + ))), + } +} + +pub fn operator_to_name(op: Operator) -> &'static str { + match op { + Operator::Eq => "equal", + Operator::NotEq => "not_equal", + Operator::Lt => "lt", + Operator::LtEq => "lte", + Operator::Gt => "gt", + Operator::GtEq => "gte", + Operator::Plus => "add", + Operator::Minus => "substract", + Operator::Multiply => "multiply", + Operator::Divide => "divide", + Operator::Modulo => "mod", + Operator::And => "and", + Operator::Or => "or", + Operator::Like => "like", + Operator::NotLike => "not_like", + Operator::IsDistinctFrom => "is_distinct_from", + Operator::IsNotDistinctFrom => "is_not_distinct_from", + Operator::RegexMatch => "regex_match", + Operator::RegexIMatch => "regex_imatch", + Operator::RegexNotMatch => "regex_not_match", + Operator::RegexNotIMatch => "regex_not_imatch", + Operator::BitwiseAnd => "bitwise_and", + Operator::BitwiseOr => "bitwise_or", + Operator::StringConcat => "str_concat", + Operator::BitwiseXor => "bitwise_xor", + Operator::BitwiseShiftRight => "bitwise_shift_right", + Operator::BitwiseShiftLeft => "bitwise_shift_left", + } +} + +pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { + match expr { + Expr::AggregateFunction { fun, args, distinct, filter } => { + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); + } + let function_name = fun.to_string().to_lowercase(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments: arguments, + sorts: vec![], + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: substrait::protobuf::AggregationPhase::Unspecified as i32, + args: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, extension_info)?), + None => None + } + }) + }, + _ => Err(DataFusionError::Internal(format!( + "Expression must be compatible with aggregation. Unsupported expression: {:?}", + expr + ))), + } +} + +fn _register_function(function_name: String, extension_info: &mut (Vec, HashMap)) -> u32 { + let (function_extensions, function_set) = extension_info; + let function_name = function_name.to_lowercase(); + // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, + // a plan-relative identifier starting from 0 is used as the function_anchor. + // The consumer is responsible for correctly registering + // mapping info stored in the extensions by the producer. + let function_anchor = match function_set.get(&function_name) { + Some(function_anchor) => { + // Function has been registered + *function_anchor + }, + None => { + // Function has NOT been registered + let function_anchor = function_set.len() as u32; + function_set.insert(function_name.clone(), function_anchor); + + let function_extension = ExtensionFunction { + extension_uri_reference: u32::MAX, + function_anchor: function_anchor, + name: function_name, + }; + let simple_extension = extensions::SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(function_extension)), + }; + function_extensions.push(simple_extension); + function_anchor + } + }; + + // Return function anchor + function_anchor + +} + +/// Return Substrait scalar function with two arguments +pub fn make_binary_op_scalar_func(lhs: &Expression, rhs: &Expression, op: Operator, extension_info: &mut (Vec, HashMap)) -> Expression { + let function_name = operator_to_name(op).to_string().to_lowercase(); + let function_anchor = _register_function(function_name, extension_info); + Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(lhs.clone())), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(rhs.clone())), + }, + ], + output_type: None, + args: vec![], + })), + } +} + +/// Convert DataFusion Expr to Substrait Rex +pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { + match expr { + Expr::Between { expr, negated, low, high } => { + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; + let substrait_low = to_substrait_rex(low, schema, extension_info)?; + let substrait_high = to_substrait_rex(high, schema, extension_info)?; + + let l_expr = make_binary_op_scalar_func(&substrait_expr, &substrait_low, Operator::Lt, extension_info); + let r_expr = make_binary_op_scalar_func(&substrait_high, &substrait_expr, Operator::Lt, extension_info); + + Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::Or, extension_info)) + } else { + // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) + let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; + let substrait_low = to_substrait_rex(low, schema, extension_info)?; + let substrait_high = to_substrait_rex(high, schema, extension_info)?; + + let l_expr = make_binary_op_scalar_func(&substrait_low, &substrait_expr, Operator::LtEq, extension_info); + let r_expr = make_binary_op_scalar_func(&substrait_expr, &substrait_high, Operator::LtEq, extension_info); + + Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::And, extension_info)) + } + } + Expr::Column(col) => { + let index = schema.index_of_column(&col)?; + substrait_field_ref(index) + } + Expr::BinaryExpr { left, op, right } => { + let l = to_substrait_rex(left, schema, extension_info)?; + let r = to_substrait_rex(right, schema, extension_info)?; + + Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) + } + Expr::Case { expr, when_then_expr, else_expr } => { + let mut ifs: Vec = vec![]; + // Parse base + if let Some(e) = expr { // Base expression exists + ifs.push(IfClause { + r#if: Some(to_substrait_rex(e, schema, extension_info)?), + then: None, + }); + } + // Parse `when`s + for (r#if, then) in when_then_expr { + ifs.push(IfClause { + r#if: Some(to_substrait_rex(r#if, schema, extension_info)?), + then: Some(to_substrait_rex(then, schema, extension_info)?), + }); + } + + // Parse outer `else` + let r#else: Option> = match else_expr { + Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)), + None => None, + }; + + Ok(Expression { + rex_type: Some(RexType::IfThen(Box::new(IfThen { + ifs: ifs, + r#else: r#else + }))), + }) + } + Expr::Literal(value) => { + let literal_type = match value { + ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)), + ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)), + ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)), + ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)), + ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)), + ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)), + ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)), + ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())), + ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())), + ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())), + ScalarValue::LargeBinary(Some(b)) => Some(LiteralType::Binary(b.clone())), + ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Unsupported literal: {:?}", + value + ))) + } + }; + Ok(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: true, + type_variation_reference: 0, + literal_type, + })), + }) + } + Expr::Alias(expr, _alias) => { + to_substrait_rex(expr, schema, extension_info) + } + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported expression: {:?}", + expr + ))), + } +} + +fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { + match expr { + Expr::Sort { expr, asc, nulls_first } => { + let e = to_substrait_rex(expr, schema, extension_info)?; + let d = match (asc, nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(e), + sort_kind: Some(SortKind::Direction(d as i32)), + }) + }, + _ => Err(DataFusionError::NotImplemented(format!( + "Expecting sort expression but got {:?}", + expr + ))), + } +} + +fn substrait_field_ref(index: usize) -> Result { + Ok(Expression { + rex_type: Some(RexType::Selection(Box::new(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: None, + }))), + }) +} diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs new file mode 100644 index 000000000000..7f52077f1be9 --- /dev/null +++ b/datafusion/substrait/src/serializer.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::producer; + +use datafusion::error::Result; +use datafusion::prelude::*; + +use prost::Message; +use substrait::protobuf::Plan; + +use std::fs::OpenOptions; +use std::io::{Write, Read}; + +pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { + let df = ctx.sql(sql).await?; + let plan = df.to_logical_plan()?; + let proto = producer::to_substrait_plan(&plan)?; + + let mut protobuf_out = Vec::::new(); + proto.encode(&mut protobuf_out).unwrap(); + let mut file = OpenOptions::new() + .create(true) + .write(true) + .open(path)?; + file.write_all(&protobuf_out)?; + Ok(()) +} + +pub async fn deserialize(path: &str) -> Result> { + let mut protobuf_in = Vec::::new(); + + let mut file = OpenOptions::new() + .read(true) + .open(path)?; + + file.read_to_end(&mut protobuf_in)?; + let proto = Message::decode(&*protobuf_in).unwrap(); + + Ok(Box::new(proto)) +} + + diff --git a/datafusion/substrait/tests/roundtrip.rs b/datafusion/substrait/tests/roundtrip.rs new file mode 100644 index 000000000000..21a3a5f291a3 --- /dev/null +++ b/datafusion/substrait/tests/roundtrip.rs @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_substrait::consumer; +use datafusion_substrait::producer; + +#[cfg(test)] +mod tests { + + use crate::{consumer::from_substrait_plan, producer::to_substrait_plan}; + use datafusion::error::Result; + use datafusion::prelude::*; + use substrait::protobuf::extensions::simple_extension_declaration::MappingType; + + #[tokio::test] + async fn simple_select() -> Result<()> { + roundtrip("SELECT a, b FROM data").await + } + + #[tokio::test] + async fn wildcard_select() -> Result<()> { + roundtrip("SELECT * FROM data").await + } + + #[tokio::test] + async fn select_with_filter() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1").await + } + + #[tokio::test] + async fn select_with_reused_functions() -> Result<()> { + let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; + roundtrip(sql).await?; + let (mut function_names, mut function_anchors) = function_extension_info(sql).await?; + function_names.sort(); + function_anchors.sort(); + + assert_eq!(function_names, ["and", "gt", "lt"]); + assert_eq!(function_anchors, [0, 1, 2]); + + Ok(()) + } + + #[tokio::test] + async fn select_with_filter_date() -> Result<()> { + roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await + } + + #[tokio::test] + async fn select_with_filter_bool_expr() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d AND a > 1").await + } + + #[tokio::test] + async fn select_with_limit() -> Result<()> { + roundtrip_fill_na("SELECT * FROM data LIMIT 100").await + } + + #[tokio::test] + async fn select_with_limit_offset() -> Result<()> { + roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await + } + + #[tokio::test] + async fn simple_aggregate() -> Result<()> { + roundtrip("SELECT a, sum(b) FROM data GROUP BY a").await + } + + #[tokio::test] + async fn aggregate_distinct_with_having() -> Result<()> { + roundtrip("SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100").await + } + + #[tokio::test] + async fn aggregate_multiple_keys() -> Result<()> { + roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await + } + + #[tokio::test] + async fn simple_distinct() -> Result<()> { + test_alias( + "SELECT * FROM (SELECT distinct a FROM data)", // `SELECT *` is used to add `projection` at the root + "SELECT a FROM data GROUP BY a", + ).await + } + + #[tokio::test] + async fn select_distinct_two_fields() -> Result<()> { + test_alias( + "SELECT * FROM (SELECT distinct a, b FROM data)", // `SELECT *` is used to add `projection` at the root + "SELECT a, b FROM data GROUP BY a, b", + ).await + } + + #[tokio::test] + async fn simple_alias() -> Result<()> { + test_alias( + "SELECT d1.a, d1.b FROM data d1", + "SELECT a, b FROM data", + ).await + } + + #[tokio::test] + async fn two_table_alias() -> Result<()> { + test_alias( + "SELECT d1.a FROM data d1 JOIN data2 d2 ON d1.a = d2.a", + "SELECT data.a FROM data JOIN data2 ON data.a = data2.a", + ) + .await + } + + #[tokio::test] + async fn between_integers() -> Result<()> { + test_alias( + "SELECT * FROM data WHERE a BETWEEN 2 AND 6", + "SELECT * FROM data WHERE a >= 2 AND a <= 6" + ) + .await + } + + #[tokio::test] + async fn not_between_integers() -> Result<()> { + test_alias( + "SELECT * FROM data WHERE a NOT BETWEEN 2 AND 6", + "SELECT * FROM data WHERE a < 2 OR a > 6" + ) + .await + } + + #[tokio::test] + async fn case_without_base_expression() -> Result<()> { + roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data").await + } + + #[tokio::test] + async fn case_with_base_expression() -> Result<()> { + roundtrip("SELECT (CASE a + WHEN 0 THEN 'zero' + WHEN 1 THEN 'one' + ELSE 'other' + END) FROM data").await + } + + #[tokio::test] + async fn roundtrip_inner_join() -> Result<()> { + roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await + } + + #[tokio::test] + async fn inner_join() -> Result<()> { + assert_expected_plan( + "SELECT data.a FROM data JOIN data2 ON data.a = data2.a", + "Projection: data.a\ + \n Inner Join: data.a = data2.a\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]", + ) + .await + } + + async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { + let mut ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.to_logical_plan()?; + let proto = to_substrait_plan(&plan)?; + let df = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = df.to_logical_plan()?; + let plan2str = format!("{:?}", plan2); + assert_eq!(expected_plan_str, &plan2str); + Ok(()) + } + + async fn roundtrip_fill_na(sql: &str) -> Result<()> { + let mut ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan1 = df.to_logical_plan()?; + let proto = to_substrait_plan(&plan1)?; + + let df = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = df.to_logical_plan()?; + + // Format plan string and replace all None's with 0 + let plan1str = format!("{:?}", plan1).replace("None", "0"); + let plan2str = format!("{:?}", plan2).replace("None", "0"); + + assert_eq!(plan1str, plan2str); + Ok(()) + } + + async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { + // Since we ignore the SubqueryAlias in the producer, the result should be + // the same as producing a Substrait plan from the same query without aliases + // sql_with_alias -> substrait -> logical plan = sql_no_alias -> substrait -> logical plan + let mut ctx = create_context().await?; + + let df_a = ctx.sql(sql_with_alias).await?; + let proto_a = to_substrait_plan(&df_a.to_logical_plan()?)?; + let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?.to_logical_plan()?; + + let df = ctx.sql(sql_no_alias).await?; + let proto = to_substrait_plan(&df.to_logical_plan()?)?; + let plan = from_substrait_plan(&mut ctx, &proto).await?.to_logical_plan()?; + + println!("{:#?}", plan_with_alias); + println!("{:#?}", plan); + + let plan1str = format!("{:?}", plan_with_alias); + let plan2str = format!("{:?}", plan); + assert_eq!(plan1str, plan2str); + Ok(()) + } + + async fn roundtrip(sql: &str) -> Result<()> { + let mut ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.to_logical_plan()?; + let proto = to_substrait_plan(&plan)?; + + let df = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = df.to_logical_plan()?; + + println!("{:#?}", plan); + println!("{:#?}", plan2); + + let plan1str = format!("{:?}", plan); + let plan2str = format!("{:?}", plan2); + assert_eq!(plan1str, plan2str); + Ok(()) + } + + async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.to_logical_plan()?; + let proto = to_substrait_plan(&plan)?; + + let mut function_names: Vec = vec![]; + let mut function_anchors: Vec = vec![]; + for e in &proto.extensions { + let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { + MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), + _ => unreachable!("Producer does not generate a non-function extension") + }; + function_names.push(function_name.to_string()); + function_anchors.push(function_anchor); + } + + Ok((function_names, function_anchors)) + } + + async fn create_context() -> Result { + let ctx = SessionContext::new(); + ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new()) + .await?; + ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new()) + .await?; + Ok(ctx) + } +} diff --git a/datafusion/substrait/tests/serialize.rs b/datafusion/substrait/tests/serialize.rs new file mode 100644 index 000000000000..505c4f5f4ec4 --- /dev/null +++ b/datafusion/substrait/tests/serialize.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod tests { + + use datafusion_substrait::consumer::from_substrait_plan; + use datafusion_substrait::serializer; + + use datafusion::error::Result; + use datafusion::prelude::*; + + use std::fs; + + #[tokio::test] + async fn serialize_simple_select() -> Result<()> { + let mut ctx = create_context().await?; + let path = "tests/simple_select.bin"; + let sql = "SELECT a, b FROM data"; + // Test reference + let df_ref = ctx.sql(sql).await?; + let plan_ref = df_ref.to_logical_plan()?; + // Test + // Write substrait plan to file + serializer::serialize(sql, &ctx, &path).await?; + // Read substrait plan from file + let proto = serializer::deserialize(path).await?; + // Check plan equality + let df = from_substrait_plan(&mut ctx, &proto).await?; + let plan = df.to_logical_plan()?; + let plan_str_ref = format!("{:?}", plan_ref); + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str_ref, plan_str); + // Delete test binary file + fs::remove_file(path)?; + + Ok(()) + } + + async fn create_context() -> Result { + let ctx = SessionContext::new(); + ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new()) + .await?; + ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new()) + .await?; + Ok(ctx) + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv new file mode 100644 index 000000000000..4394789bcda6 --- /dev/null +++ b/datafusion/substrait/tests/testdata/data.csv @@ -0,0 +1,3 @@ +a,b,c,d +1,2,2020-01-01,false +3,4,2020-01-01,true \ No newline at end of file