From 9fe73ed7587c03f1c849ba9afac33e18bd1eb018 Mon Sep 17 00:00:00 2001 From: Douglas Wilson <141026920+doug-q@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:07:03 +0100 Subject: [PATCH] feat(hugr-passes): Add `force_order` pass. (#1285) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit closes #1282 --------- Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/force_order.rs | 307 +++++++++++++++++++++++++++++++++ hugr-passes/src/lib.rs | 3 + 3 files changed, 311 insertions(+) create mode 100644 hugr-passes/src/force_order.rs diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 467c92505..161566cfc 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -18,6 +18,7 @@ itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } +petgraph = { workspace = true } [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs new file mode 100644 index 000000000..4278b2879 --- /dev/null +++ b/hugr-passes/src/force_order.rs @@ -0,0 +1,307 @@ +//! Provides [force_order], a tool for fixing the order of nodes in a Hugr. +use std::{cmp::Reverse, collections::BinaryHeap}; + +use hugr_core::{ + hugr::{ + hugrmut::HugrMut, + views::{DescendantsGraph, HierarchyView, SiblingGraph}, + HugrError, + }, + ops::{OpTag, OpTrait}, + types::EdgeKind, + Direction, HugrView as _, Node, +}; +use itertools::Itertools as _; +use petgraph::{ + visit::{ + GraphBase, GraphRef, IntoNeighbors as _, IntoNeighborsDirected, IntoNodeIdentifiers, + NodeFiltered, VisitMap, Visitable, Walker, + }, + Direction::Incoming, +}; + +/// Insert order edges into a Hugr according to a rank function. 
+/// +/// All dataflow parents which are transitive children of `root`, including +/// `root` itself, will have their dataflow regions ordered. +/// +/// Dataflow regions are ordered by inserting order edges between their +/// immediate children. A dataflow parent with `C` children will have at most +/// `C-1` edges added. Any node that can be ordered will be. +/// +/// Nodes are ordered according to the `rank` function. Nodes of lower rank will +/// be ordered earlier in their parent. Note that if `rank(n1) < rank(n2)` it +/// is not guaranteed that `n1` will be ordered before `n2`. If `n2` dominates +/// `n1` it cannot be ordered after `n1` without invalidating `hugr`. Nodes of +/// equal rank will be ordered arbitrarily, although that arbitrary order is +/// deterministic. +pub fn force_order( + hugr: &mut impl HugrMut, + root: Node, + rank: impl Fn(Node) -> i64, +) -> Result<(), HugrError> { + force_order_by_key(hugr, root, rank) +} + +/// As [force_order], but allows a generic [Ord] choice for the result of the +/// `rank` function. +pub fn force_order_by_key( + hugr: &mut impl HugrMut, + root: Node, + rank: impl Fn(Node) -> K, +) -> Result<(), HugrError> { + let dataflow_parents = DescendantsGraph::::try_new(hugr, root)? 
+ .nodes() + .filter(|n| hugr.get_optype(*n).tag() <= OpTag::DataflowParent) + .collect_vec(); + for dp in dataflow_parents { + let sg = SiblingGraph::::try_new(hugr, dp)?; + let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp); + let ordered_nodes = ForceOrder::new(&petgraph, &rank) + .iter(&petgraph) + .filter(|&x| hugr.get_optype(x).tag() <= OpTag::DataflowChild) + .collect_vec(); + + for (&n1, &n2) in ordered_nodes.iter().tuple_windows() { + let (n1_ot, n2_ot) = (hugr.get_optype(n1), hugr.get_optype(n2)); + assert_eq!( + Some(EdgeKind::StateOrder), + n1_ot.other_port_kind(Direction::Outgoing), + "Node {n1} does not support state order edges" + ); + assert_eq!( + Some(EdgeKind::StateOrder), + n2_ot.other_port_kind(Direction::Incoming), + "Node {n2} does not support state order edges" + ); + if !hugr.output_neighbours(n1).contains(&n2) { + hugr.connect( + n1, + n1_ot.other_output_port().unwrap(), + n2, + n2_ot.other_input_port().unwrap(), + ); + } + } + } + + Ok(()) +} + +/// An adaptation of [petgraph::visit::Topo]. We differ only in that pending nodes +/// are kept in a binary heap keyed by the rank function instead of a work stack. This +/// ensures we visit lower ranked nodes before higher ranked nodes whenever the +/// topology of the graph allows. 
+#[derive(Clone)] +struct ForceOrder { + tovisit: BinaryHeap<(Reverse, N)>, + ordered: VM, + rank: F, +} + +impl K> ForceOrder +where + N: Copy + PartialEq, + VM: VisitMap, +{ + pub fn new(graph: G, rank: F) -> Self + where + G: IntoNodeIdentifiers + IntoNeighborsDirected + Visitable, + { + let mut topo = Self::empty(graph, rank); + topo.extend_with_initials(graph); + topo + } + + fn empty(graph: G, rank: F) -> Self + where + G: GraphRef + Visitable, + { + Self { + ordered: graph.visit_map(), + tovisit: Default::default(), + rank, + } + } + + fn extend_with_initials(&mut self, g: G) + where + G: IntoNodeIdentifiers + IntoNeighborsDirected, + { + // find all initial nodes (nodes without incoming edges) + self.extend( + g.node_identifiers() + .filter(move |&a| g.neighbors_directed(a, Incoming).next().is_none()), + ); + } + + fn extend(&mut self, new_nodes: impl IntoIterator) { + self.tovisit + .extend(new_nodes.into_iter().map(|x| (Reverse((self.rank)(x)), x))); + } + + /// Return the next node in the current topological order traversal, or + /// `None` if the traversal is at the end. + /// + /// *Note:* The graph may not have a complete topological order, and the only + /// way to know is to run the whole traversal and make sure it visits every node. + pub fn next(&mut self, g: G) -> Option + where + G: IntoNeighborsDirected + Visitable, + { + // Take an unvisited element and find which of its neighbors are next + while let Some((_, nix)) = self.tovisit.pop() { + if self.ordered.is_visited(&nix) { + continue; + } + self.ordered.visit(nix); + // Look at each neighbor; those whose incoming edges all come from 
+ // already-ordered nodes are the next to visit. + let new_nodes = g + .neighbors(nix) + .filter(|&n| { + petgraph::visit::Reversed(g) + .neighbors(n) + .all(|b| self.ordered.is_visited(&b)) + }) + .collect_vec(); + + self.extend(new_nodes); + return Some(nix); + } + None + } +} + +impl K> Walker + for ForceOrder +where + G::NodeId: Ord, +{ + type Item = ::NodeId; + + fn walk_next(&mut self, g: G) -> Option { + self.next(g) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use super::*; + use hugr_core::builder::{BuildHandle, Dataflow, DataflowHugr}; + use hugr_core::ops::handle::{DataflowOpID, NodeHandle}; + + use hugr_core::std_extensions::arithmetic::int_ops::{self, IntOpDef}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::FunctionType; + use hugr_core::{builder::DFGBuilder, hugr::Hugr}; + use hugr_core::{HugrView, Wire}; + + use petgraph::visit::Topo; + + const I: u8 = 3; + + fn build_neg(builder: &mut impl Dataflow, wire: Wire) -> BuildHandle { + builder + .add_dataflow_op(IntOpDef::ineg.with_log_width(I), [wire]) + .unwrap() + } + + fn build_add(builder: &mut impl Dataflow, w1: Wire, w2: Wire) -> BuildHandle { + builder + .add_dataflow_op(IntOpDef::iadd.with_log_width(I), [w1, w2]) + .unwrap() + } + + /// Our tests use the following hugr: + /// + /// a DFG with sig: [i8,i8] -> [i8,i8] + /// + /// Input + /// | | + /// | iw1 | iw2 + /// v0(neg) v1(neg) + /// | \ | + /// | \ | + /// | \ | + /// v2(neg) v3(add) + /// | | + /// Output + fn test_hugr() -> (Hugr, [Node; 4]) { + let t = INT_TYPES[I as usize].clone(); + let mut builder = + DFGBuilder::new(FunctionType::new_endo(vec![t.clone(), t.clone()])).unwrap(); + let [iw1, iw2] = builder.input_wires_arr(); + let v0 = build_neg(&mut builder, iw1); + let v1 = build_neg(&mut builder, iw2); + let v2 = build_neg(&mut builder, v0.out_wire(0)); + let v3 = build_add(&mut builder, v0.out_wire(0), v1.out_wire(0)); + let nodes = [v0, v1, v2, v3] + .into_iter() + .map(|x| x.handle().node()) + .collect_vec() 
+ .try_into() + .unwrap(); + ( + builder + .finish_hugr_with_outputs( + [v2.out_wire(0), v3.out_wire(0)], + &int_ops::INT_OPS_REGISTRY, + ) + .unwrap(), + nodes, + ) + } + + type RankMap = HashMap; + + fn force_order_test_impl(hugr: &mut Hugr, rank_map: RankMap) -> Vec { + force_order(hugr, hugr.root(), |n| *rank_map.get(&n).unwrap_or(&0)).unwrap(); + + let topo_sorted = Topo::new(&hugr.as_petgraph()) + .iter(&hugr.as_petgraph()) + .filter(|n| rank_map.contains_key(n)) + .collect_vec(); + hugr.validate_no_extensions(&int_ops::INT_OPS_REGISTRY) + .unwrap(); + + topo_sorted + } + + #[test] + fn test_force_order_1() { + let (mut hugr, [v0, v1, v2, v3]) = test_hugr(); + + // v0 has a higher rank than v2, but v0 dominates v2, so v2 cannot be + // ordered before it. + // + // v1 and v3 are pushed to the bottom of the graph with high weights. + let rank_map = [(v0, 2), (v2, 1), (v1, 10), (v3, 9)].into_iter().collect(); + + let topo_sort = force_order_test_impl(&mut hugr, rank_map); + assert_eq!(vec![v0, v2, v1, v3], topo_sort); + } + + #[test] + fn test_force_order_2() { + let (mut hugr, [v0, v1, v2, v3]) = test_hugr(); + + // v1 and v3 are pulled to the top of the graph with low weights. + // v3 cannot ascend past v0 because it is dominated by v0 + let rank_map = [(v0, 2), (v2, 1), (v1, -10), (v3, -9)] + .into_iter() + .collect(); + let topo_sort = force_order_test_impl(&mut hugr, rank_map); + assert_eq!(vec![v1, v0, v3, v2], topo_sort); + } + + #[test] + fn test_force_order_3() { + let (mut hugr, [v0, v1, v2, v3]) = test_hugr(); + let rank_map = [(v0, 0), (v1, 1), (v2, 2), (v3, 3)].into_iter().collect(); + let topo_sort = force_order_test_impl(&mut hugr, rank_map); + assert_eq!(vec![v0, v1, v2, v3], topo_sort); + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index f6e09b71b..f477d27cb 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,7 +1,10 @@ //! Compilation passes acting on the HUGR program representation. 
pub mod const_fold; +pub mod force_order; mod half_node; pub mod merge_bbs; pub mod nest_cfgs; pub mod validation; + +pub use force_order::{force_order, force_order_by_key};