Skip to content

Commit

Permalink
fix walker, rustfmt
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Oct 26, 2023
1 parent bb76660 commit adba209
Showing 1 changed file with 49 additions and 44 deletions.
93 changes: 49 additions & 44 deletions src/walker.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(missing_docs)]
use std::rc::Rc;

use itertools::Itertools;
use lazy_static::__Deref;

use crate::{ops::OpType, HugrView, Node};
Expand All @@ -16,31 +18,31 @@ pub enum WalkOrder {
Postorder,
}

struct WalkerCallback<'a, T>(Box<dyn 'a + Fn(Node, OpType, &mut T) -> WalkResult>);
struct WalkerCallback<'a, T, E>(Box<dyn 'a + Fn(Node, OpType, T) -> Result<T, E>>);

impl<'a, T, F: 'a + Fn(Node, OpType, &mut T) -> WalkResult> From<F> for WalkerCallback<'a, T> {
impl<'a, T, E, F: 'a + Fn(Node, OpType, T) -> Result<T, E>> From<F> for WalkerCallback<'a, T, E> {
fn from(f: F) -> Self {
Self(Box::new(f))
}
}

pub struct Walker<'a, T> {
pre_callbacks: Vec<WalkerCallback<'a, T>>,
post_callbacks: Vec<WalkerCallback<'a, T>>,
pub struct Walker<'a, T, E> {
pre_callbacks: Vec<WalkerCallback<'a, T, E>>,
post_callbacks: Vec<WalkerCallback<'a, T, E>>,
}

fn call_back<O, T>(
fn call_back<O, T, E>(
n: Node,
o: OpType,
t: &mut T,
f: impl Fn(Node, O, &mut T) -> WalkResult,
) -> WalkResult
t: T,
f: &impl Fn(Node, O, T) -> Result<T, E>,
) -> Result<T, E>
where
OpType: TryInto<O>,
{
match o.try_into() {
Ok(x) => f(n, x, t),
_ => WalkResult::Advance,
_ => Ok(t),
}
}

Expand All @@ -49,15 +51,15 @@ enum WorkItem {
Callback(WalkOrder, Node),
}

impl<'a, T> Walker<'a, T> {
impl<'a, T, E> Walker<'a, T, E> {
pub fn new() -> Self {
Self {
pre_callbacks: Vec::new(),
post_callbacks: Vec::new(),
}
}

pub fn visit<O, F: 'a + Fn(Node, O, &mut T) -> WalkResult>(
pub fn visit<O, F: 'a + Fn(Node, O, T) -> Result<T, E>>(
&mut self,
walk_order: WalkOrder,
f: F,
Expand All @@ -70,11 +72,11 @@ impl<'a, T> Walker<'a, T> {
WalkOrder::Preorder => &mut self.pre_callbacks,
WalkOrder::Postorder => &mut self.post_callbacks,
};
callbacks.push((move |n, o, t: &'_ mut _| call_back(n, o, t, g.as_ref())).into());
callbacks.push((move |n, o, t| call_back(n, o, t, g.as_ref())).into());
self
}

pub fn walk(&self, hugr: impl HugrView, t: &mut T) {
pub fn walk(&self, hugr: impl HugrView, mut t: T) -> Result<T, E> {
// We intentionally avoid recursion so that we can robustly accept very deep hugrs
let mut worklist = vec![WorkItem::Visit(hugr.root())];

Expand All @@ -83,84 +85,87 @@ impl<'a, T> Walker<'a, T> {
WorkItem::Visit(n) => {
worklist.push(WorkItem::Callback(WalkOrder::Postorder, n));
// TODO we should add children in topological order
worklist.extend(hugr.children(n).map(WorkItem::Visit));
let mut children = hugr.children(n).collect_vec();
children.reverse();
worklist.extend(children.into_iter().map(WorkItem::Visit));
worklist.push(WorkItem::Callback(WalkOrder::Preorder, n));
}
WorkItem::Callback(order, n) => {
let callbacks = match order {
WalkOrder::Preorder => &self.pre_callbacks,
WalkOrder::Postorder => &self.post_callbacks,
};
let optype = hugr.get_optype(n);
for cb in callbacks.iter() {
// this clone is unfortunate, to avoid this we would need a TryInto variant:
// try_into(&O) -> Option<&T>
if cb.0.as_ref()(n, hugr.get_optype(n).clone(), t) == WalkResult::Interrupt
{
return;
}
t = cb.0.as_ref()(n, optype.clone(), t)?;
}
}
}
}
Ok(t)
}
}

#[cfg(test)]
mod test {
use std::{error::Error, iter::empty};
use std::error::Error;

use crate::types::Signature;
use crate::{
builder::{Container, HugrBuilder, ModuleBuilder},
builder::{Container, HugrBuilder, ModuleBuilder, SubContainer},
extension::{ExtensionRegistry, ExtensionSet},
type_row,
types::FunctionType,
};

use super::*;

#[test]
fn test1() -> Result<(), Box<dyn Error>> {
let mut module_builder = ModuleBuilder::new();
let sig = Signature {
signature: FunctionType::new(type_row![], type_row![]),
input_extensions: ExtensionSet::new(),
};
module_builder.define_function("f1", sig.clone());
module_builder.define_function("f2", sig.clone());
module_builder
.define_function("f1", sig.clone())?
.finish_sub_container()?;
module_builder
.define_function("f2", sig.clone())?
.finish_sub_container()?;

let hugr = module_builder.finish_hugr(&ExtensionRegistry::new())?;

let mut s = String::new();
Walker::<String>::new()
.visit(WalkOrder::Preorder, |_, crate::ops::Module, r| {
r.extend("pre".chars());
r.extend(['m']);
WalkResult::Advance
let s = Walker::<String, Box<dyn Error>>::new()
.visit(WalkOrder::Preorder, |_, crate::ops::Module, mut r| {
r += "prem";
Ok(r)
})
.visit(WalkOrder::Postorder, |_, crate::ops::Module, r| {
r.extend("post".chars());
r.extend(['n']);
WalkResult::Advance
.visit(WalkOrder::Postorder, |_, crate::ops::Module, mut r| {
r += "postm";
Ok(r)
})
.visit(
WalkOrder::Preorder,
|_, crate::ops::FuncDecl { ref name, .. }, r| {
r.extend("pre".chars());
r.extend(name.chars());
WalkResult::Advance
|_, crate::ops::FuncDefn { ref name, .. }, mut r| {
r += "pre";
r += name.as_ref();
Ok(r)
},
)
.visit(
WalkOrder::Postorder,
|_, crate::ops::FuncDecl { ref name, .. }, r| {
r.extend("post".chars());
r.extend(name.chars());
WalkResult::Advance
|_, crate::ops::FuncDefn { ref name, .. }, mut r| {
r += "post";
r += name.as_ref();
Ok(r)
},
)
.walk(&hugr, &mut s);
.walk(&hugr, String::new())?;

assert_eq!(s, "prempref1pref2postf2postf1postn");
Ok(())
assert_eq!(s, "prempref1postf1pref2postf2postm");
Ok::<(), Box<dyn Error>>(())
}
}

0 comments on commit adba209

Please sign in to comment.