Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use SiblingSubgraph in SimpleReplacement #517

Merged
merged 6 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 55 additions & 60 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Implementation of the `SimpleReplace` operation.

use std::collections::{HashMap, HashSet};
use std::collections::HashMap;

use itertools::Itertools;

use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrMut, HugrView, NodeMetadata};
use crate::{
hugr::{Node, Rewrite},
Expand All @@ -13,40 +14,45 @@ use crate::{
use thiserror::Error;

/// Specification of a simple replacement operation.
// TODO: use `SiblingSubgraph` to define the replacement.
#[derive(Debug, Clone)]
pub struct SimpleReplacement {
/// The common DataflowParent of all nodes to be replaced.
pub parent: Node,
/// The set of nodes to remove (a convex set of leaf children of `parent`).
pub removal: HashSet<Node>,
/// The subgraph of the hugr to be replaced.
subgraph: SiblingSubgraph,
/// A hugr with DFG root (consisting of replacement nodes).
pub replacement: Hugr,
replacement: Hugr,
/// A map from (target ports of edges from the Input node of `replacement`) to (target ports of
/// edges from nodes not in `removal` to nodes in `removal`).
pub nu_inp: HashMap<(Node, Port), (Node, Port)>,
nu_inp: HashMap<(Node, Port), (Node, Port)>,
/// A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
/// (input ports of the Output node of `replacement`).
pub nu_out: HashMap<(Node, Port), Port>,
nu_out: HashMap<(Node, Port), Port>,
}

impl SimpleReplacement {
/// Create a new [`SimpleReplacement`] specification.
pub fn new(
parent: Node,
removal: HashSet<Node>,
subgraph: SiblingSubgraph,
replacement: Hugr,
nu_inp: HashMap<(Node, Port), (Node, Port)>,
nu_out: HashMap<(Node, Port), Port>,
) -> Self {
Self {
parent,
removal,
subgraph,
replacement,
nu_inp,
nu_out,
}
}

/// The replacement hugr.
pub fn replacement(&self) -> &Hugr {
&self.replacement
}

/// Subgraph to be replaced.
pub fn subgraph(&self) -> &SiblingSubgraph {
&self.subgraph
}
}

impl Rewrite for SimpleReplacement {
Expand All @@ -60,13 +66,14 @@ impl Rewrite for SimpleReplacement {
}

fn apply(self, h: &mut impl HugrMut) -> Result<(), SimpleReplacementError> {
let parent = self.subgraph.get_parent(h);
// 1. Check the parent node exists and is a DataflowParent.
if !OpTag::DataflowParent.is_superset(h.get_optype(self.parent).tag()) {
if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
return Err(SimpleReplacementError::InvalidParentNode());
}
// 2. Check that all the to-be-removed nodes are children of it and are leaves.
for node in &self.removal {
if h.get_parent(*node) != Some(self.parent) || h.children(*node).next().is_some() {
for node in self.subgraph.nodes() {
if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
return Err(SimpleReplacementError::InvalidRemovedNode());
}
}
Expand All @@ -80,7 +87,7 @@ impl Rewrite for SimpleReplacement {
.collect::<Vec<Node>>();
// slice of nodes omitting Input and Output:
let replacement_inner_nodes = &replacement_nodes[2..];
let self_output_node = h.children(self.parent).nth(1).unwrap();
let self_output_node = h.children(parent).nth(1).unwrap();
let replacement_output_node = *replacement_nodes.get(1).unwrap();
for &node in replacement_inner_nodes {
// Add the nodes.
Expand Down Expand Up @@ -170,8 +177,8 @@ impl Rewrite for SimpleReplacement {
}
}
// 3.5. Remove all nodes in self.removal and edges between them.
for node in &self.removal {
h.remove_node(*node).unwrap();
for &node in self.subgraph.nodes() {
h.remove_node(node).unwrap();
}
Ok(())
}
Expand All @@ -196,15 +203,15 @@ pub(in crate::hugr::rewrite) mod test {
use itertools::Itertools;
use portgraph::Direction;
use rstest::{fixture, rstest};
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;

use crate::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::BOOL_T;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::views::HugrView;
use crate::hugr::views::{HugrView, SiblingSubgraph};
use crate::hugr::{Hugr, HugrMut, Node};
use crate::ops::OpTag;
use crate::ops::{OpTrait, OpType};
Expand Down Expand Up @@ -327,28 +334,23 @@ pub(in crate::hugr::rewrite) mod test {
/// └───┘└───┘
fn test_simple_replacement(simple_hugr: Hugr, dfg_hugr: Hugr) {
let mut h: Hugr = simple_hugr;
// 1. Find the DFG node for the inner circuit
let p: Node = h
.nodes()
.find(|node: &Node| h.get_optype(*node).tag() == OpTag::Dfg)
.unwrap();
// 2. Locate the CX and its successor H's in h
// 1. Locate the CX and its successor H's in h
let h_node_cx: Node = h
.nodes()
.find(|node: &Node| *h.get_optype(*node) == OpType::LeafOp(cx_gate()))
.unwrap();
let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
let s: HashSet<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
// 3. Construct a new DFG-rooted hugr for the replacement
let s: Vec<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
// 2. Construct a new DFG-rooted hugr for the replacement
let n: Hugr = dfg_hugr;
// 4. Construct the input and output matchings
// 4.1. Locate the CX and its predecessor H's in n
// 3. Construct the input and output matchings
// 3.1. Locate the CX and its predecessor H's in n
let n_node_cx = n
.nodes()
.find(|node: &Node| *n.get_optype(*node) == OpType::LeafOp(cx_gate()))
.unwrap();
let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap();
// 4.2. Locate the ports we need to specify as "glue" in n
// 3.2. Locate the ports we need to specify as "glue" in n
let n_port_0 = n.node_ports(n_node_h0, Direction::Incoming).next().unwrap();
let n_port_1 = n.node_ports(n_node_h1, Direction::Incoming).next().unwrap();
let (n_cx_out_0, n_cx_out_1) = n
Expand All @@ -358,7 +360,7 @@ pub(in crate::hugr::rewrite) mod test {
.unwrap();
let n_port_2 = n.linked_ports(n_node_cx, n_cx_out_0).next().unwrap().1;
let n_port_3 = n.linked_ports(n_node_cx, n_cx_out_1).next().unwrap().1;
// 4.3. Locate the ports we need to specify as "glue" in h
// 3.3. Locate the ports we need to specify as "glue" in h
let (h_port_0, h_port_1) = h
.node_ports(h_node_cx, Direction::Incoming)
.take(2)
Expand All @@ -368,17 +370,16 @@ pub(in crate::hugr::rewrite) mod test {
let h_h1_out = h.node_ports(h_node_h1, Direction::Outgoing).next().unwrap();
let (h_outp_node, h_port_2) = h.linked_ports(h_node_h0, h_h0_out).next().unwrap();
let h_port_3 = h.linked_ports(h_node_h1, h_h1_out).next().unwrap().1;
// 4.4. Construct the maps
// 3.4. Construct the maps
let mut nu_inp: HashMap<(Node, Port), (Node, Port)> = HashMap::new();
let mut nu_out: HashMap<(Node, Port), Port> = HashMap::new();
nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0));
nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1));
nu_out.insert((h_outp_node, h_port_2), n_port_2);
nu_out.insert((h_outp_node, h_port_3), n_port_3);
// 5. Define the replacement
// 4. Define the replacement
let r = SimpleReplacement {
parent: p,
removal: s,
subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
replacement: n,
nu_inp,
nu_out,
Expand Down Expand Up @@ -414,49 +415,43 @@ pub(in crate::hugr::rewrite) mod test {
fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
let mut h: Hugr = simple_hugr;

// 1. Find the DFG node for the inner circuit
let p: Node = h
.nodes()
.find(|node: &Node| h.get_optype(*node).tag() == OpTag::Dfg)
.unwrap();
// 2. Locate the CX in h
// 1. Locate the CX in h
let h_node_cx: Node = h
.nodes()
.find(|node: &Node| *h.get_optype(*node) == OpType::LeafOp(cx_gate()))
.unwrap();
let s: HashSet<Node> = vec![h_node_cx].into_iter().collect();
// 3. Construct a new DFG-rooted hugr for the replacement
let s: Vec<Node> = vec![h_node_cx].into_iter().collect();
// 2. Construct a new DFG-rooted hugr for the replacement
let n: Hugr = dfg_hugr2;
// 4. Construct the input and output matchings
// 4.1. Locate the Output and its predecessor H in n
// 3. Construct the input and output matchings
// 3.1. Locate the Output and its predecessor H in n
let n_node_output = n
.nodes()
.find(|node: &Node| n.get_optype(*node).tag() == OpTag::Output)
.unwrap();
let (_n_node_input, n_node_h) = n.input_neighbours(n_node_output).collect_tuple().unwrap();
// 4.2. Locate the ports we need to specify as "glue" in n
// 3.2. Locate the ports we need to specify as "glue" in n
let (n_port_0, n_port_1) = n
.node_inputs(n_node_output)
.take(2)
.collect_tuple()
.unwrap();
let n_port_2 = n.node_inputs(n_node_h).next().unwrap();
// 4.3. Locate the ports we need to specify as "glue" in h
// 3.3. Locate the ports we need to specify as "glue" in h
let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap();
let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
let h_port_2 = h.node_ports(h_node_h0, Direction::Incoming).next().unwrap();
let h_port_3 = h.node_ports(h_node_h1, Direction::Incoming).next().unwrap();
// 4.4. Construct the maps
// 3.4. Construct the maps
let mut nu_inp: HashMap<(Node, Port), (Node, Port)> = HashMap::new();
let mut nu_out: HashMap<(Node, Port), Port> = HashMap::new();
nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0));
nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1));
nu_out.insert((h_node_h0, h_port_2), n_port_0);
nu_out.insert((h_node_h1, h_port_3), n_port_1);
// 5. Define the replacement
// 4. Define the replacement
let r = SimpleReplacement {
parent: p,
removal: s,
subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(),
replacement: n,
nu_inp,
nu_out,
Expand Down Expand Up @@ -484,11 +479,10 @@ pub(in crate::hugr::rewrite) mod test {
let replacement = h.clone();
let orig = h.clone();

let parent = h.root();
let removal = h
.nodes()
.filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
.collect();
.collect_vec();
let inputs = h
.node_outputs(input)
.filter(|&p| h.get_optype(input).signature().get(p).is_some())
Expand All @@ -503,8 +497,7 @@ pub(in crate::hugr::rewrite) mod test {
.map(|p| ((output, p), p))
.collect();
h.apply_rewrite(SimpleReplacement::new(
parent,
removal,
SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
replacement,
inputs,
outputs,
Expand Down Expand Up @@ -538,11 +531,10 @@ pub(in crate::hugr::rewrite) mod test {

let orig = h.clone();

let parent = h.root();
let removal = h
.nodes()
.filter(|&n| h.get_optype(n).tag() == OpTag::Leaf)
.collect();
.collect_vec();

let first_out_p = h.node_outputs(input).next().unwrap();
let embedded_inputs = h.linked_ports(input, first_out_p);
Expand All @@ -558,7 +550,10 @@ pub(in crate::hugr::rewrite) mod test {
.collect();

h.apply_rewrite(SimpleReplacement::new(
parent, removal, repl, inputs, outputs,
SiblingSubgraph::try_from_nodes(removal, &h).unwrap(),
repl,
inputs,
outputs,
))
.unwrap();

Expand Down
64 changes: 59 additions & 5 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ pub struct SiblingSubgraph {
}

/// The type of the incoming boundary of [`SiblingSubgraph`].
///
/// The nested vec represents a partition of the incoming boundary ports by
/// input parameter. A set in the partition that has more than one element
/// corresponds to an input parameter that is copied and useful multiple times
/// in the subgraph.
pub type IncomingPorts = Vec<Vec<(Node, Port)>>;
/// The type of the outgoing boundary of [`SiblingSubgraph`].
pub type OutgoingPorts = Vec<(Node, Port)>;
Expand Down Expand Up @@ -194,6 +199,58 @@ impl SiblingSubgraph {
})
}

/// Create a subgraph from a set of nodes.
///
/// The incoming boundary is given by the set of edges with a source
/// not in nodes and a target in nodes. Conversely, the outgoing boundary
/// is given by the set of edges with a source in nodes and a target not
/// in nodes.
///
/// The subgraph signature will be given by the types of the incoming and
/// outgoing edges ordered by the node order in `nodes` and within each node
/// by the port order.

/// The in- and out-arity of the signature will match the
/// number of incoming and outgoing edges respectively. In particular, the
/// assumption is made that no two incoming edges have the same source
/// (no copy nodes at the input bounary).
pub fn try_from_nodes(
nodes: impl Into<Vec<Node>>,
hugr: &impl HugrView,
) -> Result<Self, InvalidSubgraph> {
let nodes = nodes.into();
let nodes_set = nodes.iter().copied().collect::<HashSet<_>>();
let incoming_edges = nodes
.iter()
.flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
let outgoing_edges = nodes
.iter()
.flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
let inputs = incoming_edges
.filter(|&(n, p)| {
if !hugr.is_linked(n, p) {
return false;
}
let (out_n, _) = hugr.linked_ports(n, p).exactly_one().ok().unwrap();
!nodes_set.contains(&out_n)
})
// Every incoming edge is its own input.
.map(|p| vec![p])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm puzzled why we convert p to a singleton vector here, i.e. why is IncomingPorts a vector of vectors (not really pertinent to this PR though).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've better documented this peculiarity. Let me know if this clarifies the matter.

.collect_vec();
let outputs = outgoing_edges
.filter(|&(n, p)| {
if !hugr.is_linked(n, p) {
return false;
}
// TODO: what if there are multiple outgoing edges?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reference a github issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

// See https://github.com/CQCL-DEV/hugr/issues/518
let (in_n, _) = hugr.linked_ports(n, p).next().unwrap();
!nodes_set.contains(&in_n)
})
.collect_vec();
Self::try_new(inputs, outputs, hugr)
}

/// An iterator over the nodes in the subgraph.
pub fn nodes(&self) -> &[Node] {
&self.nodes
Expand Down Expand Up @@ -248,8 +305,6 @@ impl SiblingSubgraph {
hugr: &impl HugrView,
replacement: Hugr,
) -> Result<SimpleReplacement, InvalidReplacement> {
let removal = self.nodes().iter().copied().collect();

let rep_root = replacement.root();
let dfg_optype = replacement.get_optype(rep_root);
if !OpTag::Dfg.is_superset(dfg_optype.tag()) {
Expand Down Expand Up @@ -300,8 +355,7 @@ impl SiblingSubgraph {
.collect();

Ok(SimpleReplacement::new(
self.get_parent(hugr),
removal,
self.clone(),
replacement,
nu_inp,
nu_out,
Expand Down Expand Up @@ -609,7 +663,7 @@ mod tests {

let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap();

assert_eq!(rep.removal.len(), 1);
assert_eq!(rep.subgraph().nodes().len(), 1);

assert_eq!(hugr.node_count(), 5); // Module + Def + In + CX + Out
hugr.apply_rewrite(rep).unwrap();
Expand Down