diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a15121652452..c6bcbb479e80 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -78,6 +78,10 @@ enum ScalarFunctionType { Like, /// [Expr::Like] Case insensitive operator counterpart of `Like` ILike, + /// [Expr::IsNull] + IsNull, + /// [Expr::IsNotNull] + IsNotNull, } pub fn name_to_op(name: &str) -> Result { @@ -126,6 +130,8 @@ fn scalar_function_type_from_str(name: &str) -> Result { "not" => Ok(ScalarFunctionType::Not), "like" => Ok(ScalarFunctionType::Like), "ilike" => Ok(ScalarFunctionType::ILike), + "is_null" => Ok(ScalarFunctionType::IsNull), + "is_not_null" => Ok(ScalarFunctionType::IsNotNull), others => not_impl_err!("Unsupported function name: {others:?}"), } } @@ -880,6 +886,42 @@ pub async fn from_substrait_rex( ScalarFunctionType::ILike => { make_datafusion_like(true, f, input_schema, extensions).await } + ScalarFunctionType::IsNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait( + "expect one argument for `IS NULL` expr".to_string(), + ) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + Ok(Arc::new(Expr::IsNull(Box::new(expr)))) + } + _ => not_impl_err!("Invalid arguments for IS NULL expression"), + } + } + ScalarFunctionType::IsNotNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait( + "expect one argument for `IS NOT NULL` expr".to_string(), + ) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) + } + _ => { + not_impl_err!("Invalid arguments for IS NOT NULL expression") + } + } + } } } Some(RexType::Literal(lit)) => { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e3c6f94d43d5..142b6c3628bb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1025,7 +1025,53 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), - _ => not_impl_err!("Unsupported expression: {expr:?}"), + Expr::IsNull(arg) => { + let arguments: Vec = vec![FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + col_ref_offset, + extension_info, + )?)), + }]; + + let function_name = "is_null".to_string(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], + })), + }) + } + Expr::IsNotNull(arg) => { + let arguments: Vec = vec![FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + col_ref_offset, + extension_info, + )?)), + }]; + + let function_name = "is_not_null".to_string(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], + })), + }) + } + _ => { + not_impl_err!("Unsupported expression: {expr:?}") + } } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ca2b4d48c460..582e5a5d7c8e 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -314,6 +314,16 @@ async fn simple_scalar_function_substr() -> Result<()> { roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await } +#[tokio::test] +async fn simple_scalar_function_is_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NULL").await +} + +#[tokio::test] +async fn simple_scalar_function_is_not_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NOT NULL").await +} + #[tokio::test] async fn case_without_base_expression() -> Result<()> { roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data")