diff --git a/src/walker.rs b/src/walker.rs index 316e65ed7..009203d80 100644 --- a/src/walker.rs +++ b/src/walker.rs @@ -1,5 +1,7 @@ +#![allow(missing_docs)] use std::rc::Rc; +use itertools::Itertools; use lazy_static::__Deref; use crate::{ops::OpType, HugrView, Node}; @@ -16,31 +18,31 @@ pub enum WalkOrder { Postorder, } -struct WalkerCallback<'a, T>(Box WalkResult>); +struct WalkerCallback<'a, T, E>(Box Result>); -impl<'a, T, F: 'a + Fn(Node, OpType, &mut T) -> WalkResult> From for WalkerCallback<'a, T> { +impl<'a, T, E, F: 'a + Fn(Node, OpType, T) -> Result> From for WalkerCallback<'a, T, E> { fn from(f: F) -> Self { Self(Box::new(f)) } } -pub struct Walker<'a, T> { - pre_callbacks: Vec>, - post_callbacks: Vec>, +pub struct Walker<'a, T, E> { + pre_callbacks: Vec>, + post_callbacks: Vec>, } -fn call_back( +fn call_back( n: Node, o: OpType, - t: &mut T, - f: impl Fn(Node, O, &mut T) -> WalkResult, -) -> WalkResult + t: T, + f: &impl Fn(Node, O, T) -> Result, +) -> Result where OpType: TryInto, { match o.try_into() { Ok(x) => f(n, x, t), - _ => WalkResult::Advance, + _ => Ok(t), } } @@ -49,7 +51,7 @@ enum WorkItem { Callback(WalkOrder, Node), } -impl<'a, T> Walker<'a, T> { +impl<'a, T, E> Walker<'a, T, E> { pub fn new() -> Self { Self { pre_callbacks: Vec::new(), @@ -57,7 +59,7 @@ impl<'a, T> Walker<'a, T> { } } - pub fn visit WalkResult>( + pub fn visit Result>( &mut self, walk_order: WalkOrder, f: F, @@ -70,11 +72,11 @@ impl<'a, T> Walker<'a, T> { WalkOrder::Preorder => &mut self.pre_callbacks, WalkOrder::Postorder => &mut self.post_callbacks, }; - callbacks.push((move |n, o, t: &'_ mut _| call_back(n, o, t, g.as_ref())).into()); + callbacks.push((move |n, o, t| call_back(n, o, t, g.as_ref())).into()); self } - pub fn walk(&self, hugr: impl HugrView, t: &mut T) { + pub fn walk(&self, hugr: impl HugrView, mut t: T) -> Result { // We intentionally avoid recursion so that we can robustly accept very deep hugrs let mut worklist = vec![WorkItem::Visit(hugr.root())]; @@ -83,7 +85,9 @@ impl<'a, T> Walker<'a, T> { WorkItem::Visit(n) => { worklist.push(WorkItem::Callback(WalkOrder::Postorder, n)); // TODO we should add children in topological order - worklist.extend(hugr.children(n).map(WorkItem::Visit)); + let mut children = hugr.children(n).collect_vec(); + children.reverse(); + worklist.extend(children.into_iter().map(WorkItem::Visit)); worklist.push(WorkItem::Callback(WalkOrder::Preorder, n)); } WorkItem::Callback(order, n) => { @@ -91,27 +95,26 @@ impl<'a, T> Walker<'a, T> { WalkOrder::Preorder => &self.pre_callbacks, WalkOrder::Postorder => &self.post_callbacks, }; + let optype = hugr.get_optype(n); for cb in callbacks.iter() { // this clone is unfortunate, to avoid this we would need a TryInto variant: // try_into(&O) -> Option<&T> - if cb.0.as_ref()(n, hugr.get_optype(n).clone(), t) == WalkResult::Interrupt - { - return; - } + t = cb.0.as_ref()(n, optype.clone(), t)?; } } } } + Ok(t) } } #[cfg(test)] mod test { - use std::{error::Error, iter::empty}; + use std::error::Error; use crate::types::Signature; use crate::{ - builder::{Container, HugrBuilder, ModuleBuilder}, + builder::{Container, HugrBuilder, ModuleBuilder, SubContainer}, extension::{ExtensionRegistry, ExtensionSet}, type_row, types::FunctionType, @@ -119,48 +122,50 @@ mod test { use super::*; + #[test] fn test1() -> Result<(), Box> { let mut module_builder = ModuleBuilder::new(); let sig = Signature { signature: FunctionType::new(type_row![], type_row![]), input_extensions: ExtensionSet::new(), }; - module_builder.define_function("f1", sig.clone()); - module_builder.define_function("f2", sig.clone()); + module_builder + .define_function("f1", sig.clone())? + .finish_sub_container()?; + module_builder + .define_function("f2", sig.clone())? + .finish_sub_container()?; let hugr = module_builder.finish_hugr(&ExtensionRegistry::new())?; - let mut s = String::new(); - Walker::::new() - .visit(WalkOrder::Preorder, |_, crate::ops::Module, r| { - r.extend("pre".chars()); - r.extend(['m']); - WalkResult::Advance + let s = Walker::>::new() + .visit(WalkOrder::Preorder, |_, crate::ops::Module, mut r| { + r += "prem"; + Ok(r) }) - .visit(WalkOrder::Postorder, |_, crate::ops::Module, r| { - r.extend("post".chars()); - r.extend(['n']); - WalkResult::Advance + .visit(WalkOrder::Postorder, |_, crate::ops::Module, mut r| { + r += "postm"; + Ok(r) }) .visit( WalkOrder::Preorder, - |_, crate::ops::FuncDecl { ref name, .. }, r| { - r.extend("pre".chars()); - r.extend(name.chars()); - WalkResult::Advance + |_, crate::ops::FuncDefn { ref name, .. }, mut r| { + r += "pre"; + r += name.as_ref(); + Ok(r) }, ) .visit( WalkOrder::Postorder, - |_, crate::ops::FuncDecl { ref name, .. }, r| { - r.extend("post".chars()); - r.extend(name.chars()); - WalkResult::Advance + |_, crate::ops::FuncDefn { ref name, .. }, mut r| { + r += "post"; + r += name.as_ref(); + Ok(r) }, ) - .walk(&hugr, &mut s); + .walk(&hugr, String::new())?; - assert_eq!(s, "prempref1pref2postf2postf1postn"); - Ok(()) + assert_eq!(s, "prempref1postf1pref2postf2postm"); + Ok::<(), Box>(()) } }