diff --git a/Cargo.toml b/Cargo.toml index b6c886f0..29ce91aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ license-file = "LICENCE" [workspace.dependencies] tket2 = { path = "./tket2" } -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "0beb165" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "1fce927" } portgraph = { version = "0.10" } pyo3 = { version = "0.20" } itertools = { version = "0.11.0" } diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 5bf9678b..7e46563d 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -7,6 +7,7 @@ pub mod units; pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; +use itertools::Either::{Left, Right}; use derive_more::From; use hugr::hugr::hugrmut::HugrMut; @@ -174,7 +175,7 @@ pub(crate) fn remove_empty_wire( if input_port >= circ.num_outputs(inp) { return Err(CircuitMutError::InvalidPortOffset(input_port)); } - let input_port = Port::new_outgoing(input_port); + let input_port = Port::new(Direction::Outgoing, input_port); let link = circ .linked_ports(inp, input_port) .at_most_one() @@ -233,10 +234,15 @@ fn shift_ports( circ.disconnect(node, port)?; } for (other_n, other_p) in links { - // TODO: simplify when CQCL-DEV/hugr#565 is resolved - match dir { - Direction::Incoming => circ.connect(other_n, other_p, node, free_port), - Direction::Outgoing => circ.connect(node, free_port, other_n, other_p), + match other_p.as_directed() { + Right(other_p) => { + let dst_port = free_port.as_incoming().unwrap(); + circ.connect(other_n, other_p, node, dst_port) + } + Left(other_p) => { + let src_port = free_port.as_outgoing().unwrap(); + circ.connect(node, src_port, other_n, other_p) + } }?; } free_port = port; diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 4df47321..b049bbda 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -8,6 +8,7 @@ use std::iter::FusedIterator; use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; +use itertools::Either::{Left, Right}; use petgraph::visit as pv; use super::units::filter::FilteredUnits; @@ -169,12 +170,12 @@ impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { #[inline] fn assign_wire(&self, node: Node, port: Port) -> Option { - match port.direction() { - Direction::Incoming => { - let (from, from_port) = self.circ.linked_ports(node, port).next()?; + match port.as_directed() { + Left(to_port) => { + let (from, from_port) = self.circ.linked_outputs(node, to_port).next()?; Some(Wire::new(from, from_port)) } - Direction::Outgoing => Some(Wire::new(node, port)), + Right(from_port) => Some(Wire::new(node, from_port)), } } } @@ -274,7 +275,10 @@ where // TODO: `with_wires` combinator for `Units`? let wire_unit = circ .linear_units() - .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index())) + .map(|(linear_unit, port, _)| { + let port = port.as_outgoing().unwrap(); + (Wire::new(circ.input(), port), linear_unit.index()) + }) .collect(); let nodes = pv::Topo::new(&circ.as_petgraph()); @@ -378,7 +382,10 @@ where // Find the linear unit id for this port. let linear_id = self .follow_linear_port(node, port) - .and_then(|input_port| self.circ.linked_ports(node, input_port).next()) + .and_then(|input_port| { + let input_port = input_port.as_incoming().unwrap(); + self.circ.linked_outputs(node, input_port).next() + }) .and_then(|(from, from_port)| { // Remove the old wire from the map (if there was one) self.wire_unit.remove(&Wire::new(from, from_port)) @@ -388,6 +395,7 @@ where self.wire_unit.len() }); // Update the map tracking the linear units + let port = port.as_outgoing().unwrap(); let new_wire = Wire::new(node, port); self.wire_unit.insert(new_wire, linear_id); LinearUnit::new(linear_id) diff --git a/tket2/src/circuit/units.rs b/tket2/src/circuit/units.rs index f96a8f74..9dd17dce 100644 --- a/tket2/src/circuit/units.rs +++ b/tket2/src/circuit/units.rs @@ -239,9 +239,7 @@ impl UnitLabeller for DefaultUnitLabeller { #[inline] fn assign_wire(&self, node: Node, port: Port) -> Option { - match port.direction() { - Direction::Incoming => None, - Direction::Outgoing => Some(Wire::new(node, port)), - } + let port = port.as_outgoing().ok()?; + Some(Wire::new(node, port)) } } diff --git a/tket2/src/passes/chunks.rs b/tket2/src/passes/chunks.rs index 6be458f3..e15a6da5 100644 --- a/tket2/src/passes/chunks.rs +++ b/tket2/src/passes/chunks.rs @@ -16,7 +16,7 @@ use hugr::hugr::{HugrError, NodeMetadata}; use hugr::ops::handle::DataflowParentID; use hugr::ops::OpType; use hugr::types::{FunctionType, Signature}; -use hugr::{Hugr, HugrView, Node, Port, PortIndex, Wire}; +use hugr::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; use itertools::Itertools; use crate::Circuit; @@ -79,7 +79,7 @@ impl Chunk { .map(|wires| { let (inp_node, inp_port) = wires[0]; let (out_node, out_port) = circ - .linked_ports(inp_node, inp_port) + .linked_outputs(inp_node, inp_port) .exactly_one() .ok() .unwrap(); @@ -127,7 +127,7 @@ impl Chunk { { let connection_targets: Vec = self .circ - .linked_ports(chunk_inp, chunk_inp_port) + .linked_inputs(chunk_inp, chunk_inp_port) .map(|(node, port)| { if node == chunk_out { // This was a direct wire from the chunk input to the output. Use the output's [`ChunkConnection`]. @@ -135,7 +135,7 @@ impl Chunk { ConnectionTarget::TransitiveConnection(output_connection) } else { // Translate the original chunk node into the inserted node. - (*node_map.get(&node).unwrap(), port).into() + ConnectionTarget::InsertedInput(*node_map.get(&node).unwrap(), port) } }) .collect(); @@ -145,7 +145,7 @@ impl Chunk { for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) { let (node, port) = self .circ - .linked_ports(chunk_out, chunk_out_port) + .linked_outputs(chunk_out, chunk_out_port) .exactly_one() .ok() .unwrap(); @@ -155,7 +155,7 @@ impl Chunk { ConnectionTarget::TransitiveConnection(input_connection) } else { // Translate the original chunk node into the inserted node. - (*node_map.get(&node).unwrap(), port).into() + ConnectionTarget::InsertedOutput(*node_map.get(&node).unwrap(), port) }; output_map.insert(wire, target); } @@ -223,11 +223,12 @@ struct ChunkInsertResult { } /// The target of a chunk connection in a reassembled circuit. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum ConnectionTarget { - /// The target is a single node and port. - #[from] - InsertedNode(Node, Port), + /// The target is a chunk node's input. + InsertedInput(Node, IncomingPort), + /// The target is a chunk node's output. + InsertedOutput(Node, OutgoingPort), /// The link goes directly to the opposite boundary, without an intermediary /// node. TransitiveConnection(ChunkConnection), @@ -284,7 +285,7 @@ impl CircuitChunks { .collect(); let output_connections = circ .node_inputs(circ_output) - .flat_map(|p| circ.linked_ports(circ_output, p)) + .flat_map(|p| circ.linked_outputs(circ_output, p)) .map(|(n, p)| Wire::new(n, p).into()) .collect(); @@ -336,8 +337,8 @@ impl CircuitChunks { // The chunks input and outputs are each identified with a // [`ChunkConnection`]. We collect both sides first, and rewire them // after the chunks have been inserted. - let mut sources: HashMap = HashMap::new(); - let mut targets: HashMap> = HashMap::new(); + let mut sources: HashMap = HashMap::new(); + let mut targets: HashMap> = HashMap::new(); // A map for `ChunkConnection`s that have been merged into another (due // to identity wires in the updated chunks). @@ -380,7 +381,7 @@ impl CircuitChunks { // (due to an identity wire). for (connection, conn_target) in outgoing_connections { match conn_target { - ConnectionTarget::InsertedNode(node, port) => { + ConnectionTarget::InsertedOutput(node, port) => { // The output of a chunk always has fresh `ChunkConnection`s. sources.insert(connection, (node, port)); } @@ -390,6 +391,7 @@ impl CircuitChunks { get_merged_connection(&transitive_connections, merged_connection); transitive_connections.insert(connection, merged_connection); } + _ => panic!("Unexpected connection target"), } } for (connection, conn_targets) in incoming_connections { @@ -397,7 +399,7 @@ impl CircuitChunks { let connection = get_merged_connection(&transitive_connections, connection); for tgt in conn_targets { match tgt { - ConnectionTarget::InsertedNode(node, port) => { + ConnectionTarget::InsertedInput(node, port) => { targets.entry(connection).or_default().push((node, port)); } ConnectionTarget::TransitiveConnection(_merged_connection) => { @@ -405,6 +407,7 @@ impl CircuitChunks { // outgoing_connections, so we don't need to do // anything here. } + _ => panic!("Unexpected connection target"), } } } diff --git a/tket2/src/portmatching.rs b/tket2/src/portmatching.rs index 423c612c..74db117d 100644 --- a/tket2/src/portmatching.rs +++ b/tket2/src/portmatching.rs @@ -5,6 +5,7 @@ pub mod pattern; #[cfg(feature = "pyo3")] pub mod pyo3; +use hugr::OutgoingPort; use itertools::Itertools; pub use matcher::{PatternMatch, PatternMatcher}; pub use pattern::CircuitPattern; @@ -134,6 +135,14 @@ pub(super) enum NodeID { CopyNode(Node, Port), } +impl NodeID { + /// Create a new copy NodeID. + pub fn new_copy(node: Node, port: impl Into) -> Self { + let port: OutgoingPort = port.into(); + Self::CopyNode(node, port.into()) + } +} + impl From for NodeID { fn from(node: Node) -> Self { Self::HugrNode(node) diff --git a/tket2/src/portmatching/matcher.rs b/tket2/src/portmatching/matcher.rs index aea9d8aa..8df5f0e2 100644 --- a/tket2/src/portmatching/matcher.rs +++ b/tket2/src/portmatching/matcher.rs @@ -13,7 +13,7 @@ use hugr::hugr::views::sibling_subgraph::{ }; use hugr::hugr::views::SiblingSubgraph; use hugr::ops::OpType; -use hugr::{Hugr, Node, Port, PortIndex}; +use hugr::{Hugr, IncomingPort, Node, OutgoingPort, Port, PortIndex}; use itertools::Itertools; use portmatching::{ automaton::{LineBuilder, ScopeAutomaton}, @@ -129,12 +129,16 @@ impl PatternMatch { let inputs = pattern_ref .inputs .iter() - .map(|p| p.iter().map(|(n, p)| (map[n], *p)).collect_vec()) + .map(|ps| { + ps.iter() + .map(|(n, p)| (map[n], p.as_incoming().unwrap())) + .collect_vec() + }) .collect_vec(); let outputs = pattern_ref .outputs .iter() - .map(|(n, p)| (map[n], *p)) + .map(|(n, p)| (map[n], p.as_outgoing().unwrap())) .collect_vec(); Self::try_from_io_with_checker(root, pattern, circ, inputs, outputs, checker) } @@ -154,8 +158,8 @@ impl PatternMatch { root: Node, pattern: PatternID, circ: &impl Circuit, - inputs: Vec>, - outputs: Vec<(Node, Port)>, + inputs: Vec>, + outputs: Vec<(Node, OutgoingPort)>, ) -> Result { let checker = ConvexChecker::new(circ); Self::try_from_io_with_checker(root, pattern, circ, inputs, outputs, &checker) @@ -173,8 +177,8 @@ impl PatternMatch { root: Node, pattern: PatternID, circ: &'c C, - inputs: Vec>, - outputs: Vec<(Node, Port)>, + inputs: Vec>, + outputs: Vec<(Node, OutgoingPort)>, checker: &ConvexChecker<'c, C>, ) -> Result { let subgraph = SiblingSubgraph::try_new_with_checker(inputs, outputs, circ, checker)?; diff --git a/tket2/src/portmatching/pattern.rs b/tket2/src/portmatching/pattern.rs index 9154e9bf..9eeddc8f 100644 --- a/tket2/src/portmatching/pattern.rs +++ b/tket2/src/portmatching/pattern.rs @@ -1,5 +1,6 @@ //! Circuit Patterns for pattern matching +use hugr::IncomingPort; use hugr::{ops::OpTrait, Node, Port}; use itertools::Itertools; use portmatching::{patterns::NoRootFound, HashMap, Pattern, SinglePatternMatcher}; @@ -42,17 +43,17 @@ impl CircuitPattern { let op = cmd.optype().clone(); pattern.require(cmd.node().into(), op.try_into().unwrap()); for in_offset in 0..cmd.input_count() { - let in_offset = Port::new_incoming(in_offset); - let edge_prop = - PEdge::try_from_port(cmd.node(), in_offset, circuit).expect("Invalid HUGR"); + let in_offset: IncomingPort = in_offset.into(); + let edge_prop = PEdge::try_from_port(cmd.node(), in_offset.into(), circuit) + .expect("Invalid HUGR"); let (prev_node, prev_port) = circuit - .linked_ports(cmd.node(), in_offset) + .linked_outputs(cmd.node(), in_offset) .exactly_one() .ok() .expect("invalid HUGR"); let prev_node = match edge_prop { PEdge::InternalEdge { .. } => NodeID::HugrNode(prev_node), - PEdge::InputEdge { .. } => NodeID::CopyNode(prev_node, prev_port), + PEdge::InputEdge { .. } => NodeID::new_copy(prev_node, prev_port), }; pattern.add_edge(cmd.node().into(), prev_node, edge_prop); } @@ -235,8 +236,8 @@ mod tests { edges, [ (cx_gate, h_gate), - (cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(0))), - (cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(1))), + (cx_gate, NodeID::new_copy(inp, 0)), + (cx_gate, NodeID::new_copy(inp, 1)), ] .into_iter() .collect() @@ -292,7 +293,7 @@ mod tests { assert!(edges.iter().any(|e| { e.reverse().is_none() && e.source.unwrap() == rx_n.into() - && e.target.unwrap() == NodeID::CopyNode(inp, Port::new_outgoing(1)) + && e.target.unwrap() == NodeID::new_copy(inp, 1) })); } } diff --git a/tket2/src/portmatching/pyo3.rs b/tket2/src/portmatching/pyo3.rs index 4098caa9..410898e2 100644 --- a/tket2/src/portmatching/pyo3.rs +++ b/tket2/src/portmatching/pyo3.rs @@ -4,7 +4,7 @@ use std::fmt; use derive_more::{From, Into}; use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError; -use hugr::{Hugr, Port}; +use hugr::{Hugr, IncomingPort, OutgoingPort}; use itertools::Itertools; use portmatching::{HashMap, PatternID}; use pyo3::{prelude::*, types::PyIterator}; @@ -94,11 +94,11 @@ pub struct PyPatternMatch { /// This is the incoming boundary of a [`hugr::hugr::views::SiblingSubgraph`]. /// The input ports are grouped together if they are connected to the same /// source. - pub inputs: Vec>, + pub inputs: Vec>, /// The output ports of the subcircuit. /// /// This is the outgoing boundary of a [`hugr::hugr::views::SiblingSubgraph`]. - pub outputs: Vec<(Node, Port)>, + pub outputs: Vec<(Node, OutgoingPort)>, /// The node map from pattern to circuit. pub node_map: HashMap, } @@ -134,16 +134,16 @@ impl PyPatternMatch { let inputs = pattern .inputs .iter() - .map(|p| { - p.iter() - .map(|&(n, p)| (node_map[&Node(n)], p)) + .map(|ps| { + ps.iter() + .map(|&(n, p)| (node_map[&Node(n)], p.as_incoming().unwrap())) .collect_vec() }) .collect_vec(); let outputs = pattern .outputs .iter() - .map(|&(n, p)| (node_map[&Node(n)], p)) + .map(|&(n, p)| (node_map[&Node(n)], p.as_outgoing().unwrap())) .collect_vec(); Ok(Self { pattern_id: pattern_id.0,