diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 5191a620b473e..b73d246e19899 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -361,21 +361,45 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; - let sort_fields = sort - .expr + LogicalPlan::Sort(datafusion::logical_expr::Sort { expr, input, fetch }) => { + let sort_fields = expr .iter() - .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) + .map(|e| substrait_sort_field(state, e, input.schema(), extensions)) .collect::>>()?; - Ok(Box::new(Rel { + + let input = to_substrait_rel(input.as_ref(), state, extensions)?; + + let sort_rel = Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { common: None, input: Some(input), sorts: sort_fields, advanced_extension: None, }))), - })) + }); + + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(sort_rel), + offset_mode: None, + count_mode, + advanced_extension: None, + }))), + })) + } + None => Ok(sort_rel), + } } LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1ce0eec1b21df..1d1a87015135b 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -199,6 +199,16 @@ async fn select_with_filter() -> Result<()> { roundtrip("SELECT * FROM data WHERE a > 1").await } +#[tokio::test] +async fn select_with_filter_sort_limit() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1 ORDER BY b ASC LIMIT 2").await +} + +#[tokio::test] +async fn select_with_filter_sort_limit_offset() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1 ORDER BY b ASC LIMIT 2 OFFSET 1").await +} + #[tokio::test] async fn select_with_reused_functions() -> Result<()> { let ctx = create_context().await?;