Skip to content

Commit

Permalink
chore: update hugr dependency (#219)
Browse files Browse the repository at this point in the history
This required a partial move from using `Port` everywhere to the more
specific `OutgoingPort` and `IncomingPort`.

I'll open a separate issue for the more pervasive changes like updating
the `Units` iterator and the `portmatching` module.
For now, we just cast and `unwrap` in those cases.

Requires CQCL/hugr#647
  • Loading branch information
aborgna-q authored Nov 8, 2023
1 parent 1f2b0b6 commit 15a36fb
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ license-file = "LICENCE"
[workspace.dependencies]

tket2 = { path = "./tket2" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "0beb165" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "1fce927" }
portgraph = { version = "0.10" }
pyo3 = { version = "0.20" }
itertools = { version = "0.11.0" }
Expand Down
16 changes: 11 additions & 5 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod units;

pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
use itertools::Either::{Left, Right};

use derive_more::From;
use hugr::hugr::hugrmut::HugrMut;
Expand Down Expand Up @@ -174,7 +175,7 @@ pub(crate) fn remove_empty_wire(
if input_port >= circ.num_outputs(inp) {
return Err(CircuitMutError::InvalidPortOffset(input_port));
}
let input_port = Port::new_outgoing(input_port);
let input_port = Port::new(Direction::Outgoing, input_port);
let link = circ
.linked_ports(inp, input_port)
.at_most_one()
Expand Down Expand Up @@ -233,10 +234,15 @@ fn shift_ports<C: HugrMut + ?Sized>(
circ.disconnect(node, port)?;
}
for (other_n, other_p) in links {
// TODO: simplify when CQCL-DEV/hugr#565 is resolved
match dir {
Direction::Incoming => circ.connect(other_n, other_p, node, free_port),
Direction::Outgoing => circ.connect(node, free_port, other_n, other_p),
match other_p.as_directed() {
Right(other_p) => {
let dst_port = free_port.as_incoming().unwrap();
circ.connect(other_n, other_p, node, dst_port)
}
Left(other_p) => {
let src_port = free_port.as_outgoing().unwrap();
circ.connect(node, src_port, other_n, other_p)
}
}?;
}
free_port = port;
Expand Down
20 changes: 14 additions & 6 deletions tket2/src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::iter::FusedIterator;

use hugr::hugr::NodeType;
use hugr::ops::{OpTag, OpTrait};
use itertools::Either::{Left, Right};
use petgraph::visit as pv;

use super::units::filter::FilteredUnits;
Expand Down Expand Up @@ -169,12 +170,12 @@ impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> {

#[inline]
fn assign_wire(&self, node: Node, port: Port) -> Option<Wire> {
match port.direction() {
Direction::Incoming => {
let (from, from_port) = self.circ.linked_ports(node, port).next()?;
match port.as_directed() {
Left(to_port) => {
let (from, from_port) = self.circ.linked_outputs(node, to_port).next()?;
Some(Wire::new(from, from_port))
}
Direction::Outgoing => Some(Wire::new(node, port)),
Right(from_port) => Some(Wire::new(node, from_port)),
}
}
}
Expand Down Expand Up @@ -274,7 +275,10 @@ where
// TODO: `with_wires` combinator for `Units`?
let wire_unit = circ
.linear_units()
.map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index()))
.map(|(linear_unit, port, _)| {
let port = port.as_outgoing().unwrap();
(Wire::new(circ.input(), port), linear_unit.index())
})
.collect();

let nodes = pv::Topo::new(&circ.as_petgraph());
Expand Down Expand Up @@ -378,7 +382,10 @@ where
// Find the linear unit id for this port.
let linear_id = self
.follow_linear_port(node, port)
.and_then(|input_port| self.circ.linked_ports(node, input_port).next())
.and_then(|input_port| {
let input_port = input_port.as_incoming().unwrap();
self.circ.linked_outputs(node, input_port).next()
})
.and_then(|(from, from_port)| {
// Remove the old wire from the map (if there was one)
self.wire_unit.remove(&Wire::new(from, from_port))
Expand All @@ -388,6 +395,7 @@ where
self.wire_unit.len()
});
// Update the map tracking the linear units
let port = port.as_outgoing().unwrap();
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
LinearUnit::new(linear_id)
Expand Down
6 changes: 2 additions & 4 deletions tket2/src/circuit/units.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,7 @@ impl UnitLabeller for DefaultUnitLabeller {

#[inline]
fn assign_wire(&self, node: Node, port: Port) -> Option<Wire> {
match port.direction() {
Direction::Incoming => None,
Direction::Outgoing => Some(Wire::new(node, port)),
}
let port = port.as_outgoing().ok()?;
Some(Wire::new(node, port))
}
}
33 changes: 18 additions & 15 deletions tket2/src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use hugr::hugr::{HugrError, NodeMetadata};
use hugr::ops::handle::DataflowParentID;
use hugr::ops::OpType;
use hugr::types::{FunctionType, Signature};
use hugr::{Hugr, HugrView, Node, Port, PortIndex, Wire};
use hugr::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire};
use itertools::Itertools;

use crate::Circuit;
Expand Down Expand Up @@ -79,7 +79,7 @@ impl Chunk {
.map(|wires| {
let (inp_node, inp_port) = wires[0];
let (out_node, out_port) = circ
.linked_ports(inp_node, inp_port)
.linked_outputs(inp_node, inp_port)
.exactly_one()
.ok()
.unwrap();
Expand Down Expand Up @@ -127,15 +127,15 @@ impl Chunk {
{
let connection_targets: Vec<ConnectionTarget> = self
.circ
.linked_ports(chunk_inp, chunk_inp_port)
.linked_inputs(chunk_inp, chunk_inp_port)
.map(|(node, port)| {
if node == chunk_out {
// This was a direct wire from the chunk input to the output. Use the output's [`ChunkConnection`].
let output_connection = self.outputs[port.index()];
ConnectionTarget::TransitiveConnection(output_connection)
} else {
// Translate the original chunk node into the inserted node.
(*node_map.get(&node).unwrap(), port).into()
ConnectionTarget::InsertedInput(*node_map.get(&node).unwrap(), port)
}
})
.collect();
Expand All @@ -145,7 +145,7 @@ impl Chunk {
for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) {
let (node, port) = self
.circ
.linked_ports(chunk_out, chunk_out_port)
.linked_outputs(chunk_out, chunk_out_port)
.exactly_one()
.ok()
.unwrap();
Expand All @@ -155,7 +155,7 @@ impl Chunk {
ConnectionTarget::TransitiveConnection(input_connection)
} else {
// Translate the original chunk node into the inserted node.
(*node_map.get(&node).unwrap(), port).into()
ConnectionTarget::InsertedOutput(*node_map.get(&node).unwrap(), port)
};
output_map.insert(wire, target);
}
Expand Down Expand Up @@ -223,11 +223,12 @@ struct ChunkInsertResult {
}

/// The target of a chunk connection in a reassembled circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ConnectionTarget {
/// The target is a single node and port.
#[from]
InsertedNode(Node, Port),
/// The target is a chunk node's input.
InsertedInput(Node, IncomingPort),
/// The target is a chunk node's output.
InsertedOutput(Node, OutgoingPort),
/// The link goes directly to the opposite boundary, without an intermediary
/// node.
TransitiveConnection(ChunkConnection),
Expand Down Expand Up @@ -284,7 +285,7 @@ impl CircuitChunks {
.collect();
let output_connections = circ
.node_inputs(circ_output)
.flat_map(|p| circ.linked_ports(circ_output, p))
.flat_map(|p| circ.linked_outputs(circ_output, p))
.map(|(n, p)| Wire::new(n, p).into())
.collect();

Expand Down Expand Up @@ -336,8 +337,8 @@ impl CircuitChunks {
// The chunks input and outputs are each identified with a
// [`ChunkConnection`]. We collect both sides first, and rewire them
// after the chunks have been inserted.
let mut sources: HashMap<ChunkConnection, (Node, Port)> = HashMap::new();
let mut targets: HashMap<ChunkConnection, Vec<(Node, Port)>> = HashMap::new();
let mut sources: HashMap<ChunkConnection, (Node, OutgoingPort)> = HashMap::new();
let mut targets: HashMap<ChunkConnection, Vec<(Node, IncomingPort)>> = HashMap::new();

// A map for `ChunkConnection`s that have been merged into another (due
// to identity wires in the updated chunks).
Expand Down Expand Up @@ -380,7 +381,7 @@ impl CircuitChunks {
// (due to an identity wire).
for (connection, conn_target) in outgoing_connections {
match conn_target {
ConnectionTarget::InsertedNode(node, port) => {
ConnectionTarget::InsertedOutput(node, port) => {
// The output of a chunk always has fresh `ChunkConnection`s.
sources.insert(connection, (node, port));
}
Expand All @@ -390,21 +391,23 @@ impl CircuitChunks {
get_merged_connection(&transitive_connections, merged_connection);
transitive_connections.insert(connection, merged_connection);
}
_ => panic!("Unexpected connection target"),
}
}
for (connection, conn_targets) in incoming_connections {
// The connection in the chunk's input may have been merged into a earlier one.
let connection = get_merged_connection(&transitive_connections, connection);
for tgt in conn_targets {
match tgt {
ConnectionTarget::InsertedNode(node, port) => {
ConnectionTarget::InsertedInput(node, port) => {
targets.entry(connection).or_default().push((node, port));
}
ConnectionTarget::TransitiveConnection(_merged_connection) => {
// The merge has been registered when scanning the
// outgoing_connections, so we don't need to do
// anything here.
}
_ => panic!("Unexpected connection target"),
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions tket2/src/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod pattern;
#[cfg(feature = "pyo3")]
pub mod pyo3;

use hugr::OutgoingPort;
use itertools::Itertools;
pub use matcher::{PatternMatch, PatternMatcher};
pub use pattern::CircuitPattern;
Expand Down Expand Up @@ -134,6 +135,14 @@ pub(super) enum NodeID {
CopyNode(Node, Port),
}

impl NodeID {
/// Create a new copy NodeID.
pub fn new_copy(node: Node, port: impl Into<OutgoingPort>) -> Self {
let port: OutgoingPort = port.into();
Self::CopyNode(node, port.into())
}
}

impl From<Node> for NodeID {
fn from(node: Node) -> Self {
Self::HugrNode(node)
Expand Down
18 changes: 11 additions & 7 deletions tket2/src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use hugr::hugr::views::sibling_subgraph::{
};
use hugr::hugr::views::SiblingSubgraph;
use hugr::ops::OpType;
use hugr::{Hugr, Node, Port, PortIndex};
use hugr::{Hugr, IncomingPort, Node, OutgoingPort, Port, PortIndex};
use itertools::Itertools;
use portmatching::{
automaton::{LineBuilder, ScopeAutomaton},
Expand Down Expand Up @@ -129,12 +129,16 @@ impl PatternMatch {
let inputs = pattern_ref
.inputs
.iter()
.map(|p| p.iter().map(|(n, p)| (map[n], *p)).collect_vec())
.map(|ps| {
ps.iter()
.map(|(n, p)| (map[n], p.as_incoming().unwrap()))
.collect_vec()
})
.collect_vec();
let outputs = pattern_ref
.outputs
.iter()
.map(|(n, p)| (map[n], *p))
.map(|(n, p)| (map[n], p.as_outgoing().unwrap()))
.collect_vec();
Self::try_from_io_with_checker(root, pattern, circ, inputs, outputs, checker)
}
Expand All @@ -154,8 +158,8 @@ impl PatternMatch {
root: Node,
pattern: PatternID,
circ: &impl Circuit,
inputs: Vec<Vec<(Node, Port)>>,
outputs: Vec<(Node, Port)>,
inputs: Vec<Vec<(Node, IncomingPort)>>,
outputs: Vec<(Node, OutgoingPort)>,
) -> Result<Self, InvalidPatternMatch> {
let checker = ConvexChecker::new(circ);
Self::try_from_io_with_checker(root, pattern, circ, inputs, outputs, &checker)
Expand All @@ -173,8 +177,8 @@ impl PatternMatch {
root: Node,
pattern: PatternID,
circ: &'c C,
inputs: Vec<Vec<(Node, Port)>>,
outputs: Vec<(Node, Port)>,
inputs: Vec<Vec<(Node, IncomingPort)>>,
outputs: Vec<(Node, OutgoingPort)>,
checker: &ConvexChecker<'c, C>,
) -> Result<Self, InvalidPatternMatch> {
let subgraph = SiblingSubgraph::try_new_with_checker(inputs, outputs, circ, checker)?;
Expand Down
17 changes: 9 additions & 8 deletions tket2/src/portmatching/pattern.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Circuit Patterns for pattern matching
use hugr::IncomingPort;
use hugr::{ops::OpTrait, Node, Port};
use itertools::Itertools;
use portmatching::{patterns::NoRootFound, HashMap, Pattern, SinglePatternMatcher};
Expand Down Expand Up @@ -42,17 +43,17 @@ impl CircuitPattern {
let op = cmd.optype().clone();
pattern.require(cmd.node().into(), op.try_into().unwrap());
for in_offset in 0..cmd.input_count() {
let in_offset = Port::new_incoming(in_offset);
let edge_prop =
PEdge::try_from_port(cmd.node(), in_offset, circuit).expect("Invalid HUGR");
let in_offset: IncomingPort = in_offset.into();
let edge_prop = PEdge::try_from_port(cmd.node(), in_offset.into(), circuit)
.expect("Invalid HUGR");
let (prev_node, prev_port) = circuit
.linked_ports(cmd.node(), in_offset)
.linked_outputs(cmd.node(), in_offset)
.exactly_one()
.ok()
.expect("invalid HUGR");
let prev_node = match edge_prop {
PEdge::InternalEdge { .. } => NodeID::HugrNode(prev_node),
PEdge::InputEdge { .. } => NodeID::CopyNode(prev_node, prev_port),
PEdge::InputEdge { .. } => NodeID::new_copy(prev_node, prev_port),
};
pattern.add_edge(cmd.node().into(), prev_node, edge_prop);
}
Expand Down Expand Up @@ -235,8 +236,8 @@ mod tests {
edges,
[
(cx_gate, h_gate),
(cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(0))),
(cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(1))),
(cx_gate, NodeID::new_copy(inp, 0)),
(cx_gate, NodeID::new_copy(inp, 1)),
]
.into_iter()
.collect()
Expand Down Expand Up @@ -292,7 +293,7 @@ mod tests {
assert!(edges.iter().any(|e| {
e.reverse().is_none()
&& e.source.unwrap() == rx_n.into()
&& e.target.unwrap() == NodeID::CopyNode(inp, Port::new_outgoing(1))
&& e.target.unwrap() == NodeID::new_copy(inp, 1)
}));
}
}
Expand Down
Loading

0 comments on commit 15a36fb

Please sign in to comment.