Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Nov 24, 2024
1 parent c6fe2f0 commit 9315779
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 5 deletions.
201 changes: 196 additions & 5 deletions src/algorithms/boundary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ impl Boundary {
})
}

/// Returns an iterator over the input ports in the boundary.
#[inline]
pub fn inputs(&self) -> impl Iterator<Item = BoundaryPort> {
self.ports(Direction::Incoming)
}

/// Returns an iterator over the output ports in the boundary.
#[inline]
pub fn outputs(&self) -> impl Iterator<Item = BoundaryPort> {
self.ports(Direction::Outgoing)
}

/// Returns the [`PortIndex`] corresponding to a [`BoundaryPort`] in this boundary.
pub fn port_index(&self, port: &BoundaryPort) -> PortIndex {
match port.direction {
Expand Down Expand Up @@ -156,13 +168,16 @@ impl Boundary {
// Toposort the subgraph, and collect the reaching input ports for each node.
// We keep track of how many output neighbors remain to be visited, so we can
// trim the `reaching` set when we reach the last one.

// We have to start the toposort from the input nodes that do not have predecessors.
let source_nodes = input_nodes
.keys()
.copied()
.filter(|&node| graph.input_neighbours(node).count() == 0);

let mut reaching: HashMap<NodeIndex, (usize, HashSet<PortIndex>)> =
HashMap::with_capacity(self.num_ports());
for node in toposort::<_, HashSet<PortIndex>>(
graph,
input_nodes.keys().copied(),
Direction::Outgoing,
) {
for node in toposort::<_, HashSet<PortIndex>>(graph, source_nodes, Direction::Outgoing) {
// Collect the reaching ports, plus any ports in the node itself.
let mut reaching_ports: HashSet<PortIndex> = input_nodes
.get(&node)
Expand Down Expand Up @@ -322,3 +337,179 @@ impl PortOrdering {
self.reaching[to_index].insert(from);
}
}

#[cfg(test)]
mod test {
use crate::view::Subgraph;
use crate::{LinkMut, MultiPortGraph, PortMut};

use super::*;
use itertools::Itertools;
use rstest::{fixture, rstest};

/// A complete bipartite graph with `N` input nodes and `M` output nodes.
///
/// Each input node is connected to all output nodes.
///
/// Returns the graph and the input and output nodes arrays.
#[fixture]
fn graph_kn<const N: usize, const M: usize>() -> (MultiPortGraph, [NodeIndex; N], [NodeIndex; M])
{
let mut graph = MultiPortGraph::new();
let inputs: [NodeIndex; N] = (0..N)
.map(|_| graph.add_node(1, M))
.collect_vec()
.try_into()
.unwrap();
let outputs: [NodeIndex; M] = (0..M)
.map(|_| graph.add_node(N, 1))
.collect_vec()
.try_into()
.unwrap();

for (i, &input) in inputs.iter().enumerate() {
for (j, &output) in outputs.iter().enumerate() {
graph.link_nodes(input, j, output, i).unwrap();
}
}

(graph, inputs, outputs)
}

/// Test DAG
///
/// ```text
/// 0 -> 1 -> 2 -> 3
/// | \ |
/// v \ v
/// 4 -> 5 -> 6 -> 7
/// ```
#[fixture]
fn graph() -> MultiPortGraph {
let mut graph = MultiPortGraph::new();
let nodes: Vec<NodeIndex> = (0..8).map(|_| graph.add_node(4, 4)).collect();
// Horizontal links between from port 0 to port 0
graph.link_nodes(nodes[0], 0, nodes[1], 0).unwrap();
graph.link_nodes(nodes[1], 0, nodes[2], 0).unwrap();
graph.link_nodes(nodes[2], 0, nodes[3], 0).unwrap();
graph.link_nodes(nodes[4], 0, nodes[5], 0).unwrap();
graph.link_nodes(nodes[5], 0, nodes[6], 0).unwrap();
graph.link_nodes(nodes[6], 0, nodes[7], 0).unwrap();
// Other ports
graph.link_nodes(nodes[1], 1, nodes[5], 1).unwrap();
graph.link_nodes(nodes[2], 1, nodes[6], 0).unwrap();
graph.link_nodes(nodes[1], 2, nodes[6], 2).unwrap();

graph
}

#[rstest]
fn test_boundary_new(graph: MultiPortGraph) {
let nodes = graph.nodes_iter().collect_vec();

// Create a boundary containing the {1,2,5,6} subgraph.
let boundary = Boundary::new(
&graph,
[
graph.input(nodes[1], 0).unwrap(),
graph.input(nodes[5], 0).unwrap(),
],
[
graph.output(nodes[2], 0).unwrap(),
graph.output(nodes[6], 0).unwrap(),
],
);
let subgraph = Subgraph::with_nodes(&graph, [nodes[1], nodes[2], nodes[5], nodes[6]]);
assert_eq!(boundary, subgraph.port_boundary());
assert_eq!(boundary.num_ports(), 4);
assert_eq!(
boundary.port(graph.input(nodes[5], 0).unwrap(), Direction::Incoming),
BoundaryPort {
index: 1,
direction: Direction::Incoming
}
);
assert_eq!(boundary.ports(Direction::Incoming).count(), 2);
}

#[rstest]
fn test_port_ordering(graph: MultiPortGraph) {
let nodes = graph.nodes_iter().collect_vec();
let subgraph = Subgraph::with_nodes(&graph, [nodes[1], nodes[2], nodes[5], nodes[6]]);
let boundary = subgraph.port_boundary();

let (in_0, in_1) = boundary.inputs().collect_tuple().unwrap();
let (out_0, out_1) = boundary.outputs().collect_tuple().unwrap();

let ordering = boundary.port_ordering(&subgraph);
assert_eq!(
ordering
.reachable_ports(in_0)
.iter()
.copied()
.sorted()
.collect_vec()
.as_slice(),
[out_0, out_1]
);
assert_eq!(
ordering
.reachable_ports(in_1)
.iter()
.copied()
.collect_vec()
.as_slice(),
[out_1]
);
assert_eq!(
ordering
.reaching_ports(out_0)
.iter()
.copied()
.collect_vec()
.as_slice(),
[in_0]
);
assert_eq!(
ordering
.reaching_ports(out_1)
.iter()
.copied()
.sorted()
.collect_vec()
.as_slice(),
[in_0, in_1]
);
}

#[rstest]
fn test_order_comparison(graph: MultiPortGraph) {
let nodes = graph.nodes_iter().collect_vec();
let subgraph = Subgraph::with_nodes(&graph, [nodes[1], nodes[2], nodes[5], nodes[6]]);
let boundary = subgraph.port_boundary();

let (graph_22, ins_22, outs_22) = graph_kn::<2, 2>();
let boundary_22 = Boundary::new(
&graph_22,
[
graph_22.input(ins_22[0], 0).unwrap(),
graph_22.input(ins_22[1], 0).unwrap(),
],
[
graph_22.output(outs_22[0], 0).unwrap(),
graph_22.output(outs_22[1], 0).unwrap(),
],
);

assert!(boundary.is_compatible(&boundary_22));
assert!(boundary_22.is_compatible(&boundary));

let ordering = boundary.port_ordering(&subgraph);
let ordering_22 = boundary_22.port_ordering(&graph_22);

assert!(!ordering.is_stronger_than(&ordering_22));
assert!(ordering_22.is_stronger_than(&ordering));
assert!(!boundary.is_stronger_than(&boundary_22, &subgraph, &graph_22));
assert!(boundary_22.is_stronger_than(&boundary, &graph_22, &subgraph));
}
}
7 changes: 7 additions & 0 deletions src/view/subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::collections::BTreeSet;
use delegate::delegate;
use itertools::{Either, Itertools};

use crate::algorithms::boundary::{Boundary, HasBoundary};
use crate::PortOffset;
use crate::{
algorithms::{ConvexChecker, TopoConvexChecker},
Expand Down Expand Up @@ -383,6 +384,12 @@ where
}
}

impl<G> HasBoundary for Subgraph<G> {
fn port_boundary(&self) -> crate::algorithms::boundary::Boundary {
Boundary::new_unchecked(self.inputs.clone(), self.outputs.clone())
}
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
Expand Down

0 comments on commit 9315779

Please sign in to comment.