Skip to content

Commit

Permalink
Add subtrait support for IS NULL and IS NOT NULL (#8093)
Browse files Browse the repository at this point in the history
* added match arms and tests for is null

* fixed formatting

---------

Co-authored-by: Tanmay Gujar <[email protected]>
  • Loading branch information
tgujar and tgujar authored Nov 12, 2023
1 parent 9e012a6 commit e18c709
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 1 deletion.
42 changes: 42 additions & 0 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ enum ScalarFunctionType {
Like,
/// [Expr::Like] Case insensitive operator counterpart of `Like`
ILike,
/// [Expr::IsNull]
IsNull,
/// [Expr::IsNotNull]
IsNotNull,
}

pub fn name_to_op(name: &str) -> Result<Operator> {
Expand Down Expand Up @@ -126,6 +130,8 @@ fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
"not" => Ok(ScalarFunctionType::Not),
"like" => Ok(ScalarFunctionType::Like),
"ilike" => Ok(ScalarFunctionType::ILike),
"is_null" => Ok(ScalarFunctionType::IsNull),
"is_not_null" => Ok(ScalarFunctionType::IsNotNull),
others => not_impl_err!("Unsupported function name: {others:?}"),
}
}
Expand Down Expand Up @@ -880,6 +886,42 @@ pub async fn from_substrait_rex(
ScalarFunctionType::ILike => {
make_datafusion_like(true, f, input_schema, extensions).await
}
ScalarFunctionType::IsNull => {
let arg = f.arguments.first().ok_or_else(|| {
DataFusionError::Substrait(
"expect one argument for `IS NULL` expr".to_string(),
)
})?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
let expr = from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
Ok(Arc::new(Expr::IsNull(Box::new(expr))))
}
_ => not_impl_err!("Invalid arguments for IS NULL expression"),
}
}
ScalarFunctionType::IsNotNull => {
let arg = f.arguments.first().ok_or_else(|| {
DataFusionError::Substrait(
"expect one argument for `IS NOT NULL` expr".to_string(),
)
})?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
let expr = from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
Ok(Arc::new(Expr::IsNotNull(Box::new(expr))))
}
_ => {
not_impl_err!("Invalid arguments for IS NOT NULL expression")
}
}
}
}
}
Some(RexType::Literal(lit)) => {
Expand Down
48 changes: 47 additions & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,53 @@ pub fn to_substrait_rex(
col_ref_offset,
extension_info,
),
_ => not_impl_err!("Unsupported expression: {expr:?}"),
Expr::IsNull(arg) => {
let arguments: Vec<FunctionArgument> = vec![FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
}];

let function_name = "is_null".to_string();
let function_anchor = _register_function(function_name, extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
args: vec![],
options: vec![],
})),
})
}
Expr::IsNotNull(arg) => {
let arguments: Vec<FunctionArgument> = vec![FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
}];

let function_name = "is_not_null".to_string();
let function_anchor = _register_function(function_name, extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
args: vec![],
options: vec![],
})),
})
}
_ => {
not_impl_err!("Unsupported expression: {expr:?}")
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,16 @@ async fn simple_scalar_function_substr() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await
}

#[tokio::test]
async fn simple_scalar_function_is_null() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a IS NULL").await
}

#[tokio::test]
async fn simple_scalar_function_is_not_null() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a IS NOT NULL").await
}

#[tokio::test]
async fn case_without_base_expression() -> Result<()> {
roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data")
Expand Down

0 comments on commit e18c709

Please sign in to comment.