From dcdbe88bc7c6796fcb11d189ce3a6a8c33b3c511 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 4 Apr 2024 07:12:40 -0400 Subject: [PATCH] user defined nodes --- datafusion/expr/src/logical_plan/plan.rs | 35 ++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6f4f8bcdd31e9..73c59f9200b90 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1191,6 +1191,38 @@ where rewrite_arc(node, f).map(|res| res.discard_data()) } +/// Rewrites all inputs for an Extension node "in place" +/// (it currently has to copy values because there are no APIs for in place modification) +/// +/// Should be removed when we have an API for in place modifications of the +/// extension to avoid these copies +fn rewrite_extension_inputs( + node: &mut Arc, + mut f: F, +) -> Result> +where + F: FnMut(LogicalPlan) -> Result>, +{ + let Transformed { + data: new_inputs, + transformed, + tnr, + } = node + .inputs() + .into_iter() + .cloned() + .map_until_stop_and_collect(|input| f(input))?; + + let exprs = node.expressions(); + let mut new_node = node.from_template(&exprs, &new_inputs); + std::mem::swap(node, &mut new_node); + Ok(Transformed { + data: (), + transformed, + tnr, + }) +} + impl LogicalPlan { /// applies `f` to each input of this plan node, rewriting them *in place.* /// @@ -1241,8 +1273,7 @@ impl LogicalPlan { rewrite_arc_no_data(input, &mut f) } LogicalPlan::Extension(extension) => { - todo!(); - //rewrite_extension_inputs(&mut extension.node, &mut f) + rewrite_extension_inputs(&mut extension.node, &mut f) } LogicalPlan::Union(Union { inputs, .. }) => { let results = inputs