diff --git a/src/algorithms/boundary.rs b/src/algorithms/boundary.rs index b137b14..cc7cf25 100644 --- a/src/algorithms/boundary.rs +++ b/src/algorithms/boundary.rs @@ -107,6 +107,18 @@ impl Boundary { }) } + /// Returns an iterator over the input ports in the boundary. + #[inline] + pub fn inputs(&self) -> impl Iterator { + self.ports(Direction::Incoming) + } + + /// Returns an iterator over the output ports in the boundary. + #[inline] + pub fn outputs(&self) -> impl Iterator { + self.ports(Direction::Outgoing) + } + /// Returns the [`PortIndex`] corresponding to a [`BoundaryPort`] in this boundary. pub fn port_index(&self, port: &BoundaryPort) -> PortIndex { match port.direction { @@ -156,13 +168,16 @@ impl Boundary { // Toposort the subgraph, and collect the reaching input ports for each node. // We keep track of how many output neighbors remain to be visited, so we can // trim the `reaching` set when we reach the last one. + + // We have to start the toposort from the input nodes that do not have predecessors. + let source_nodes = input_nodes + .keys() + .copied() + .filter(|&node| graph.input_neighbours(node).count() == 0); + let mut reaching: HashMap)> = HashMap::with_capacity(self.num_ports()); - for node in toposort::<_, HashSet>( - graph, - input_nodes.keys().copied(), - Direction::Outgoing, - ) { + for node in toposort::<_, HashSet>(graph, source_nodes, Direction::Outgoing) { // Collect the reaching ports, plus any ports in the node itself. let mut reaching_ports: HashSet = input_nodes .get(&node) @@ -322,3 +337,179 @@ impl PortOrdering { self.reaching[to_index].insert(from); } } + +#[cfg(test)] +mod test { + use crate::view::Subgraph; + use crate::{LinkMut, MultiPortGraph, PortMut}; + + use super::*; + use itertools::Itertools; + use rstest::{fixture, rstest}; + + /// A complete bipartite graph with `N` input nodes and `M` output nodes. + /// + /// Each input node is connected to all output nodes. + /// + /// Returns the graph and the input and output nodes arrays. + #[fixture] + fn graph_kn() -> (MultiPortGraph, [NodeIndex; N], [NodeIndex; M]) + { + let mut graph = MultiPortGraph::new(); + let inputs: [NodeIndex; N] = (0..N) + .map(|_| graph.add_node(1, M)) + .collect_vec() + .try_into() + .unwrap(); + let outputs: [NodeIndex; M] = (0..M) + .map(|_| graph.add_node(N, 1)) + .collect_vec() + .try_into() + .unwrap(); + + for (i, &input) in inputs.iter().enumerate() { + for (j, &output) in outputs.iter().enumerate() { + graph.link_nodes(input, j, output, i).unwrap(); + } + } + + (graph, inputs, outputs) + } + + /// Test DAG + /// + /// ```text + /// 0 -> 1 -> 2 -> 3 + /// | \ | + /// v \ v + /// 4 -> 5 -> 6 -> 7 + /// ``` + #[fixture] + fn graph() -> MultiPortGraph { + let mut graph = MultiPortGraph::new(); + let nodes: Vec = (0..8).map(|_| graph.add_node(4, 4)).collect(); + // Horizontal links between from port 0 to port 0 + graph.link_nodes(nodes[0], 0, nodes[1], 0).unwrap(); + graph.link_nodes(nodes[1], 0, nodes[2], 0).unwrap(); + graph.link_nodes(nodes[2], 0, nodes[3], 0).unwrap(); + graph.link_nodes(nodes[4], 0, nodes[5], 0).unwrap(); + graph.link_nodes(nodes[5], 0, nodes[6], 0).unwrap(); + graph.link_nodes(nodes[6], 0, nodes[7], 0).unwrap(); + // Other ports + graph.link_nodes(nodes[1], 1, nodes[5], 1).unwrap(); + graph.link_nodes(nodes[2], 1, nodes[6], 0).unwrap(); + graph.link_nodes(nodes[1], 2, nodes[6], 2).unwrap(); + + graph + } + + #[rstest] + fn test_boundary_new(graph: MultiPortGraph) { + let nodes = graph.nodes_iter().collect_vec(); + + // Create a boundary containing the {1,2,5,6} subgraph. + let boundary = Boundary::new( + &graph, + [ + graph.input(nodes[1], 0).unwrap(), + graph.input(nodes[5], 0).unwrap(), + ], + [ + graph.output(nodes[2], 0).unwrap(), + graph.output(nodes[6], 0).unwrap(), + ], + ); + let subgraph = Subgraph::with_nodes(&graph, [nodes[1], nodes[2], nodes[5], nodes[6]]); + assert_eq!(boundary, subgraph.port_boundary()); + assert_eq!(boundary.num_ports(), 4); + assert_eq!( + boundary.port(graph.input(nodes[5], 0).unwrap(), Direction::Incoming), + BoundaryPort { + index: 1, + direction: Direction::Incoming + } + ); + assert_eq!(boundary.ports(Direction::Incoming).count(), 2); + } + + #[rstest] + fn test_port_ordering(graph: MultiPortGraph) { + let nodes = graph.nodes_iter().collect_vec(); + let subgraph = Subgraph::with_nodes(&graph, [nodes[1], nodes[2], nodes[5], nodes[6]]); + let boundary = subgraph.port_boundary(); + + let (in_0, in_1) = boundary.inputs().collect_tuple().unwrap(); + let (out_0, out_1) = boundary.outputs().collect_tuple().unwrap(); + + let ordering = boundary.port_ordering(&subgraph); + assert_eq!( + ordering + .reachable_ports(in_0) + .iter() + .copied() + .sorted() + .collect_vec() + .as_slice(), + [out_0, out_1] + ); + assert_eq!( + ordering + .reachable_ports(in_1) + .iter() + .copied() + .collect_vec() + .as_slice(), + [out_1] + ); + assert_eq!( + ordering + .reaching_ports(out_0) + .iter() + .copied() + .collect_vec() + .as_slice(), + [in_0] + ); + assert_eq!( + ordering + .reaching_ports(out_1) + .iter() + .copied() + .sorted() + .collect_vec() + .as_slice(), + [in_0, in_1] + ); + } + + #[rstest] + fn test_order_comparison(graph: MultiPortGraph) { + let nodes = graph.nodes_iter().collect_vec(); + let subgraph = Subgraph::with_nodes(&graph, [nodes[1], nodes[2], nodes[5], nodes[6]]); + let boundary = subgraph.port_boundary(); + + let (graph_22, ins_22, outs_22) = graph_kn::<2, 2>(); + let boundary_22 = Boundary::new( + &graph_22, + [ + graph_22.input(ins_22[0], 0).unwrap(), + graph_22.input(ins_22[1], 0).unwrap(), + ], + [ + graph_22.output(outs_22[0], 0).unwrap(), + graph_22.output(outs_22[1], 0).unwrap(), + ], + ); + + assert!(boundary.is_compatible(&boundary_22)); + assert!(boundary_22.is_compatible(&boundary)); + + let ordering = boundary.port_ordering(&subgraph); + let ordering_22 = boundary_22.port_ordering(&graph_22); + + assert!(!ordering.is_stronger_than(&ordering_22)); + assert!(ordering_22.is_stronger_than(&ordering)); + assert!(!boundary.is_stronger_than(&boundary_22, &subgraph, &graph_22)); + assert!(boundary_22.is_stronger_than(&boundary, &graph_22, &subgraph)); + } +} diff --git a/src/view/subgraph.rs b/src/view/subgraph.rs index d99ff0b..c7c666b 100644 --- a/src/view/subgraph.rs +++ b/src/view/subgraph.rs @@ -5,6 +5,7 @@ use std::collections::BTreeSet; use delegate::delegate; use itertools::{Either, Itertools}; +use crate::algorithms::boundary::{Boundary, HasBoundary}; use crate::PortOffset; use crate::{ algorithms::{ConvexChecker, TopoConvexChecker}, @@ -383,6 +384,12 @@ where } } +impl HasBoundary for Subgraph { + fn port_boundary(&self) -> crate::algorithms::boundary::Boundary { + Boundary::new_unchecked(self.inputs.clone(), self.outputs.clone()) + } +} + #[cfg(test)] mod tests { use itertools::Itertools;