diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 66a29ccdb..6294df033 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -3,7 +3,7 @@ use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::views::HugrView; use crate::hugr::{NodeMetadata, ValidationError}; use crate::ops::{self, LeafOp, OpTrait, OpType}; -use crate::{IncomingPort, Node, OutgoingPort, Port}; +use crate::{IncomingPort, Node, OutgoingPort}; use std::iter; @@ -369,7 +369,7 @@ pub trait Dataflow: Container { input_extensions, ), // Constant wire from the constant value node - vec![Wire::new(const_node, Port::new_outgoing(0))], + vec![Wire::new(const_node, OutgoingPort::from(0))], )?; Ok(load_n.out_wire(0)) @@ -659,12 +659,12 @@ fn wire_up_inputs( fn wire_up( data_builder: &mut T, src: Node, - src_port: impl TryInto, + src_port: impl Into, dst: Node, - dst_port: impl TryInto, + dst_port: impl Into, ) -> Result { - let src_port = Port::try_new_outgoing(src_port)?; - let dst_port = Port::try_new_incoming(dst_port)?; + let src_port = src_port.into(); + let dst_port = dst_port.into(); let base = data_builder.hugr_mut(); let src_parent = base.get_parent(src); @@ -676,9 +676,9 @@ fn wire_up( if !typ.copyable() { let val_err: ValidationError = InterGraphEdgeError::NonCopyableData { from: src, - from_offset: src_port, + from_offset: src_port.into(), to: dst, - to_offset: dst_port, + to_offset: dst_port.into(), ty: EdgeKind::Value(typ), } .into(); @@ -694,9 +694,9 @@ fn wire_up( else { let val_err: ValidationError = InterGraphEdgeError::NoRelation { from: src, - from_offset: src_port, + from_offset: src_port.into(), to: dst, - to_offset: dst_port, + to_offset: dst_port.into(), } .into(); return Err(val_err.into()); diff --git a/src/builder/handle.rs b/src/builder/handle.rs index eca836d9e..1235659da 100644 --- a/src/builder/handle.rs +++ b/src/builder/handle.rs @@ -1,13 +1,8 @@ //! Handles to nodes in HUGR used during the building phase. //! -use crate::{ - ops::{ - handle::{BasicBlockID, CaseID, DfgID, FuncID, NodeHandle, TailLoopID}, - OpTag, - }, - Port, -}; -use crate::{Node, Wire}; +use crate::ops::handle::{BasicBlockID, CaseID, DfgID, FuncID, NodeHandle, TailLoopID}; +use crate::ops::OpTag; +use crate::{Node, OutgoingPort, Wire}; use itertools::Itertools; use std::iter::FusedIterator; @@ -64,7 +59,7 @@ impl BuildHandle { /// Retrieve a [`Wire`] corresponding to the given offset. /// Does not check whether such a wire is valid for this node. pub fn out_wire(&self, offset: usize) -> Wire { - Wire::new(self.node(), Port::new_outgoing(offset)) + Wire::new(self.node(), OutgoingPort::from(offset)) } #[inline] @@ -124,14 +119,12 @@ impl Iterator for Outputs { fn next(&mut self) -> Option { self.range .next() - .map(|offset| Wire::new(self.node, Port::new_outgoing(offset))) + .map(|offset| Wire::new(self.node, OutgoingPort::from(offset))) } #[inline] fn nth(&mut self, n: usize) -> Option { - self.range - .nth(n) - .map(|offset| Wire::new(self.node, Port::new_outgoing(offset))) + self.range.nth(n).map(|offset| Wire::new(self.node, offset)) } #[inline] @@ -157,7 +150,7 @@ impl DoubleEndedIterator for Outputs { fn next_back(&mut self) -> Option { self.range .next_back() - .map(|offset| Wire::new(self.node, Port::new_outgoing(offset))) + .map(|offset| Wire::new(self.node, offset)) } } diff --git a/src/core.rs b/src/core.rs index 21145729a..83bba3ab5 100644 --- a/src/core.rs +++ b/src/core.rs @@ -69,7 +69,7 @@ pub type Direction = portgraph::Direction; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] /// A DataFlow wire, defined by a Value-kind output port of a node // Stores node and offset to output port -pub struct Wire(Node, usize); +pub struct Wire(Node, OutgoingPort); impl Node { /// Returns the node as a portgraph `NodeIndex`. @@ -88,29 +88,29 @@ impl Port { } } - /// Creates a new incoming port. + /// Converts to an [IncomingPort] if this port is one; else fails with + /// [HugrError::InvalidPortDirection] #[inline] - pub fn new_incoming(port: impl Into) -> Self { - Self::try_new_incoming(port).unwrap() - } - - /// Creates a new outgoing port. - #[inline] - pub fn new_outgoing(port: impl Into) -> Self { - Self::try_new_outgoing(port).unwrap() + pub fn as_incoming(&self) -> Result { + match self.direction() { + Direction::Incoming => Ok(IncomingPort { + index: self.index() as u16, + }), + dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)), + } } - /// Creates a new incoming port. + /// Converts to an [OutgoingPort] if this port is one; else fails with + /// [HugrError::InvalidPortDirection] #[inline] - pub fn try_new_incoming(port: impl TryInto) -> Result { - let Ok(port) = port.try_into() else { - return Err(HugrError::InvalidPortDirection(Direction::Outgoing)); - }; - Ok(Self { - offset: portgraph::PortOffset::new_incoming(port.index()), - }) + pub fn as_outgoing(&self) -> Result { + match self.direction() { + Direction::Outgoing => Ok(OutgoingPort { + index: self.index() as u16, + }), + dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)), + } } - /// Creates a new outgoing port. #[inline] pub fn try_new_outgoing(port: impl TryInto) -> Result { @@ -181,28 +181,18 @@ impl From for OutgoingPort { } } -impl TryFrom for IncomingPort { - type Error = HugrError; - #[inline(always)] - fn try_from(port: Port) -> Result { - match port.direction() { - Direction::Incoming => Ok(Self { - index: port.index() as u16, - }), - dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)), +impl From for Port { + fn from(value: IncomingPort) -> Self { + Self { + offset: portgraph::PortOffset::new_incoming(value.index()), } } } -impl TryFrom for OutgoingPort { - type Error = HugrError; - #[inline(always)] - fn try_from(port: Port) -> Result { - match port.direction() { - Direction::Outgoing => Ok(Self { - index: port.index() as u16, - }), - dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)), +impl From for Port { + fn from(value: OutgoingPort) -> Self { + Self { + offset: portgraph::PortOffset::new_outgoing(value.index()), } } } @@ -216,8 +206,8 @@ impl NodeIndex for Node { impl Wire { /// Create a new wire from a node and a port. #[inline] - pub fn new(node: Node, port: impl TryInto) -> Self { - Self(node, Port::try_new_outgoing(port).unwrap().index()) + pub fn new(node: Node, port: impl Into) -> Self { + Self(node, port.into()) } /// The node that this wire is connected to. @@ -228,8 +218,8 @@ impl Wire { /// The output port that this wire is connected to. #[inline] - pub fn source(&self) -> Port { - Port::new_outgoing(self.1) + pub fn source(&self) -> OutgoingPort { + self.1 } } diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index a9c9357ea..ac8bb53ce 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -108,9 +108,9 @@ pub trait HugrMut: HugrMutInternals { fn connect( &mut self, src: Node, - src_port: impl TryInto, + src_port: impl Into, dst: Node, - dst_port: impl TryInto, + dst_port: impl Into, ) -> Result<(), HugrError> { self.valid_node(src)?; self.valid_node(dst)?; @@ -121,7 +121,7 @@ pub trait HugrMut: HugrMutInternals { /// /// The port is left in place. #[inline] - fn disconnect(&mut self, node: Node, port: Port) -> Result<(), HugrError> { + fn disconnect(&mut self, node: Node, port: impl Into) -> Result<(), HugrError> { self.valid_node(node)?; self.hugr_mut().disconnect(node, port) } @@ -134,7 +134,11 @@ pub trait HugrMut: HugrMutInternals { /// /// [`OpTrait::other_input`]: crate::ops::OpTrait::other_input /// [`OpTrait::other_output`]: crate::ops::OpTrait::other_output - fn add_other_edge(&mut self, src: Node, dst: Node) -> Result<(Port, Port), HugrError> { + fn add_other_edge( + &mut self, + src: Node, + dst: Node, + ) -> Result<(OutgoingPort, IncomingPort), HugrError> { self.valid_node(src)?; self.valid_node(dst)?; self.hugr_mut().add_other_edge(src, dst) @@ -247,20 +251,21 @@ impl + AsMut> HugrMut for T { fn connect( &mut self, src: Node, - src_port: impl TryInto, + src_port: impl Into, dst: Node, - dst_port: impl TryInto, + dst_port: impl Into, ) -> Result<(), HugrError> { self.as_mut().graph.link_nodes( src.pg_index(), - Port::try_new_outgoing(src_port)?.index(), + src_port.into().index(), dst.pg_index(), - Port::try_new_incoming(dst_port)?.index(), + dst_port.into().index(), )?; Ok(()) } - fn disconnect(&mut self, node: Node, port: Port) -> Result<(), HugrError> { + fn disconnect(&mut self, node: Node, port: impl Into) -> Result<(), HugrError> { + let port = port.into(); let offset = port.pg_offset(); let port = self .as_mut() @@ -274,15 +279,21 @@ impl + AsMut> HugrMut for T { Ok(()) } - fn add_other_edge(&mut self, src: Node, dst: Node) -> Result<(Port, Port), HugrError> { - let src_port: Port = self + fn add_other_edge( + &mut self, + src: Node, + dst: Node, + ) -> Result<(OutgoingPort, IncomingPort), HugrError> { + let src_port = self .get_optype(src) .other_port_index(Direction::Outgoing) - .expect("Source operation has no non-dataflow outgoing edges"); - let dst_port: Port = self + .expect("Source operation has no non-dataflow outgoing edges") + .as_outgoing()?; + let dst_port = self .get_optype(dst) .other_port_index(Direction::Incoming) - .expect("Destination operation has no non-dataflow incoming edges"); + .expect("Destination operation has no non-dataflow incoming edges") + .as_incoming()?; self.connect(src, src_port, dst, dst_port)?; Ok((src_port, dst_port)) } diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs index cf4e12a70..4ac3331e8 100644 --- a/src/hugr/rewrite/insert_identity.rs +++ b/src/hugr/rewrite/insert_identity.rs @@ -5,7 +5,7 @@ use std::iter; use crate::hugr::{HugrMut, Node}; use crate::ops::{LeafOp, OpTag, OpTrait}; use crate::types::EdgeKind; -use crate::{Direction, HugrView, Port}; +use crate::{HugrView, IncomingPort}; use super::Rewrite; @@ -18,12 +18,12 @@ pub struct IdentityInsertion { /// The node following the identity to be inserted. pub post_node: Node, /// The port following the identity to be inserted. - pub post_port: Port, + pub post_port: IncomingPort, } impl IdentityInsertion { /// Create a new [`IdentityInsertion`] specification. - pub fn new(post_node: Node, post_port: Port) -> Self { + pub fn new(post_node: Node, post_port: IncomingPort) -> Self { Self { post_node, post_port, @@ -43,10 +43,6 @@ pub enum IdentityInsertionError { /// Invalid port kind. #[error("post_port has invalid kind {0:?}. Must be Value.")] InvalidPortKind(Option), - - /// Must be input port. - #[error("post_port is an output port, must be input.")] - PortIsOutput, } impl Rewrite for IdentityInsertion { @@ -71,17 +67,13 @@ impl Rewrite for IdentityInsertion { unimplemented!() } fn apply(self, h: &mut impl HugrMut) -> Result { - if self.post_port.direction() != Direction::Incoming { - return Err(IdentityInsertionError::PortIsOutput); - } - let kind = h.get_optype(self.post_node).port_kind(self.post_port); let Some(EdgeKind::Value(ty)) = kind else { return Err(IdentityInsertionError::InvalidPortKind(kind)); }; let (pre_node, pre_port) = h - .linked_ports(self.post_node, self.post_port) + .linked_outputs(self.post_node, self.post_port) .exactly_one() .ok() .expect("Value kind input can only have one connection."); @@ -155,11 +147,6 @@ mod tests { let final_node = tail.node(); - let final_node_output = h.node_outputs(final_node).next().unwrap(); - let rw = IdentityInsertion::new(final_node, final_node_output); - let apply_result = h.apply_rewrite(rw); - assert_eq!(apply_result, Err(IdentityInsertionError::PortIsOutput)); - let final_node_input = h.node_inputs(final_node).next().unwrap(); let rw = IdentityInsertion::new(final_node, final_node_input); diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 1cd720ef0..589a85643 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -158,7 +158,7 @@ impl Rewrite for OutlineCfg { // 3. Entry edges. Change any edges into entry_block from outside, to target new_block let preds: Vec<_> = h - .linked_ports(entry, h.node_inputs(entry).exactly_one().ok().unwrap()) + .linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap()) .collect(); for (pred, br) in preds { if !self.blocks.contains(&pred) { diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 53e3f68a7..9b661c8bd 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -7,12 +7,9 @@ use std::slice; use itertools::Itertools; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrMut, HugrView, NodeMetadata}; -use crate::{ - hugr::{Node, Rewrite}, - ops::{OpTag, OpTrait, OpType}, - Hugr, Port, -}; +use crate::hugr::{HugrMut, HugrView, NodeMetadata, Rewrite}; +use crate::ops::{OpTag, OpTrait, OpType}; +use crate::{Hugr, IncomingPort, Node}; use thiserror::Error; /// Specification of a simple replacement operation. @@ -24,10 +21,10 @@ pub struct SimpleReplacement { replacement: Hugr, /// A map from (target ports of edges from the Input node of `replacement`) to (target ports of /// edges from nodes not in `removal` to nodes in `removal`). - nu_inp: HashMap<(Node, Port), (Node, Port)>, + nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)>, /// A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to /// (input ports of the Output node of `replacement`). - nu_out: HashMap<(Node, Port), Port>, + nu_out: HashMap<(Node, IncomingPort), IncomingPort>, } impl SimpleReplacement { @@ -36,8 +33,8 @@ impl SimpleReplacement { pub fn new( subgraph: SiblingSubgraph, replacement: Hugr, - nu_inp: HashMap<(Node, Port), (Node, Port)>, - nu_out: HashMap<(Node, Port), Port>, + nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)>, + nu_out: HashMap<(Node, IncomingPort), IncomingPort>, ) -> Self { Self { subgraph, @@ -61,8 +58,10 @@ impl SimpleReplacement { } type SubgraphNodesIter<'a> = Copied>; -type NuOutNodesIter<'a> = - iter::Map, fn(&'a (Node, Port)) -> Node>; +type NuOutNodesIter<'a> = iter::Map< + hash_map::Keys<'a, (Node, IncomingPort), IncomingPort>, + fn(&'a (Node, IncomingPort)) -> Node, +>; impl Rewrite for SimpleReplacement { type Error = SimpleReplacementError; @@ -116,7 +115,7 @@ impl Rewrite for SimpleReplacement { for &node in replacement_inner_nodes { let new_node = index_map.get(&node).unwrap(); for outport in self.replacement.node_outputs(node) { - for target in self.replacement.linked_ports(node, outport) { + for target in self.replacement.linked_inputs(node, outport) { if self.replacement.get_optype(target.0).tag() != OpTag::Output { let new_target = index_map.get(&target.0).unwrap(); h.connect(*new_node, outport, *new_target, target.1) @@ -131,7 +130,7 @@ impl Rewrite for SimpleReplacement { if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) let (rem_inp_pred_node, rem_inp_pred_port) = h - .linked_ports(*rem_inp_node, *rem_inp_port) + .linked_outputs(*rem_inp_node, *rem_inp_port) .exactly_one() .ok() // PortLinks does not implement Debug .unwrap(); @@ -151,7 +150,7 @@ impl Rewrite for SimpleReplacement { for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { let (rep_out_pred_node, rep_out_pred_port) = self .replacement - .linked_ports(replacement_output_node, *rep_out_port) + .linked_outputs(replacement_output_node, *rep_out_port) .exactly_one() .unwrap(); if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input { @@ -173,7 +172,7 @@ impl Rewrite for SimpleReplacement { if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): let (rem_inp_pred_node, rem_inp_pred_port) = h - .linked_ports(*rem_inp_node, *rem_inp_port) + .linked_outputs(*rem_inp_node, *rem_inp_port) .exactly_one() .ok() // PortLinks does not implement Debug .unwrap(); @@ -198,7 +197,7 @@ impl Rewrite for SimpleReplacement { #[inline] fn invalidation_set(&self) -> Self::InvalidationSet<'_> { let subcirc = self.subgraph.nodes().iter().copied(); - let get_node: fn(&(Node, Port)) -> Node = |key: &(Node, Port)| key.0; + let get_node: fn(&(Node, IncomingPort)) -> Node = |key| key.0; let out_neighs = self.nu_out.keys().map(get_node); subcirc.chain(out_neighs) } @@ -221,7 +220,6 @@ pub enum SimpleReplacementError { #[cfg(test)] pub(in crate::hugr::rewrite) mod test { use itertools::Itertools; - use portgraph::Direction; use rstest::{fixture, rstest}; use std::collections::{HashMap, HashSet}; @@ -232,13 +230,14 @@ pub(in crate::hugr::rewrite) mod test { use crate::extension::prelude::BOOL_T; use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; - use crate::hugr::{Hugr, HugrMut, Node, Rewrite}; + use crate::hugr::{Hugr, HugrMut, Rewrite}; use crate::ops::OpTag; use crate::ops::{OpTrait, OpType}; use crate::std_extensions::logic::test::and_op; use crate::std_extensions::quantum::test::{cx_gate, h_gate}; + use crate::type_row; use crate::types::{FunctionType, Type}; - use crate::{type_row, Port}; + use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -371,28 +370,20 @@ pub(in crate::hugr::rewrite) mod test { .unwrap(); let (n_node_h0, n_node_h1) = n.input_neighbours(n_node_cx).collect_tuple().unwrap(); // 3.2. Locate the ports we need to specify as "glue" in n - let n_port_0 = n.node_ports(n_node_h0, Direction::Incoming).next().unwrap(); - let n_port_1 = n.node_ports(n_node_h1, Direction::Incoming).next().unwrap(); - let (n_cx_out_0, n_cx_out_1) = n - .node_ports(n_node_cx, Direction::Outgoing) - .take(2) - .collect_tuple() - .unwrap(); - let n_port_2 = n.linked_ports(n_node_cx, n_cx_out_0).next().unwrap().1; - let n_port_3 = n.linked_ports(n_node_cx, n_cx_out_1).next().unwrap().1; + let n_port_0 = n.node_inputs(n_node_h0).next().unwrap(); + let n_port_1 = n.node_inputs(n_node_h1).next().unwrap(); + let (n_cx_out_0, n_cx_out_1) = n.node_outputs(n_node_cx).take(2).collect_tuple().unwrap(); + let n_port_2 = n.linked_inputs(n_node_cx, n_cx_out_0).next().unwrap().1; + let n_port_3 = n.linked_inputs(n_node_cx, n_cx_out_1).next().unwrap().1; // 3.3. Locate the ports we need to specify as "glue" in h - let (h_port_0, h_port_1) = h - .node_ports(h_node_cx, Direction::Incoming) - .take(2) - .collect_tuple() - .unwrap(); - let h_h0_out = h.node_ports(h_node_h0, Direction::Outgoing).next().unwrap(); - let h_h1_out = h.node_ports(h_node_h1, Direction::Outgoing).next().unwrap(); - let (h_outp_node, h_port_2) = h.linked_ports(h_node_h0, h_h0_out).next().unwrap(); - let h_port_3 = h.linked_ports(h_node_h1, h_h1_out).next().unwrap().1; + let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap(); + let h_h0_out = h.node_outputs(h_node_h0).next().unwrap(); + let h_h1_out = h.node_outputs(h_node_h1).next().unwrap(); + let (h_outp_node, h_port_2) = h.linked_inputs(h_node_h0, h_h0_out).next().unwrap(); + let h_port_3 = h.linked_inputs(h_node_h1, h_h1_out).next().unwrap().1; // 3.4. Construct the maps - let mut nu_inp: HashMap<(Node, Port), (Node, Port)> = HashMap::new(); - let mut nu_out: HashMap<(Node, Port), Port> = HashMap::new(); + let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new(); + let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new(); nu_inp.insert((n_node_h0, n_port_0), (h_node_cx, h_port_0)); nu_inp.insert((n_node_h1, n_port_1), (h_node_cx, h_port_1)); nu_out.insert((h_outp_node, h_port_2), n_port_2); @@ -465,11 +456,11 @@ pub(in crate::hugr::rewrite) mod test { // 3.3. Locate the ports we need to specify as "glue" in h let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap(); let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap(); - let h_port_2 = h.node_ports(h_node_h0, Direction::Incoming).next().unwrap(); - let h_port_3 = h.node_ports(h_node_h1, Direction::Incoming).next().unwrap(); + let h_port_2 = h.node_inputs(h_node_h0).next().unwrap(); + let h_port_3 = h.node_inputs(h_node_h1).next().unwrap(); // 3.4. Construct the maps - let mut nu_inp: HashMap<(Node, Port), (Node, Port)> = HashMap::new(); - let mut nu_out: HashMap<(Node, Port), Port> = HashMap::new(); + let mut nu_inp: HashMap<(Node, IncomingPort), (Node, IncomingPort)> = HashMap::new(); + let mut nu_out: HashMap<(Node, IncomingPort), IncomingPort> = HashMap::new(); nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0)); nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1)); nu_out.insert((h_node_h0, h_port_2), n_port_0); @@ -512,7 +503,7 @@ pub(in crate::hugr::rewrite) mod test { .node_outputs(input) .filter(|&p| h.get_optype(input).signature().get(p).is_some()) .map(|p| { - let link = h.linked_ports(input, p).next().unwrap(); + let link = h.linked_inputs(input, p).next().unwrap(); (link, link) }) .collect(); @@ -562,10 +553,10 @@ pub(in crate::hugr::rewrite) mod test { .collect_vec(); let first_out_p = h.node_outputs(input).next().unwrap(); - let embedded_inputs = h.linked_ports(input, first_out_p); + let embedded_inputs = h.linked_inputs(input, first_out_p); let repl_inputs = repl .node_outputs(repl_input) - .map(|p| repl.linked_ports(repl_input, p).next().unwrap()); + .map(|p| repl.linked_inputs(repl_input, p).next().unwrap()); let inputs = embedded_inputs.zip(repl_inputs).collect(); let outputs = repl diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 87eba59b3..74e380367 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -268,19 +268,17 @@ impl TryFrom for Hugr { pub mod test { use super::*; + use crate::builder::{ + test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, ModuleBuilder, + }; + use crate::extension::prelude::BOOL_T; use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; - use crate::{ - builder::{ - test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr, - DataflowSubContainer, HugrBuilder, ModuleBuilder, - }, - extension::prelude::BOOL_T, - hugr::NodeType, - ops::{dataflow::IOTrait, Input, LeafOp, Module, Output, DFG}, - types::{FunctionType, Type}, - Port, - }; + use crate::hugr::NodeType; + use crate::ops::{dataflow::IOTrait, Input, LeafOp, Module, Output, DFG}; + use crate::types::{FunctionType, Type}; + use crate::OutgoingPort; use itertools::Itertools; use portgraph::{ multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, PortView, UnmanagedDenseMap, @@ -462,7 +460,7 @@ pub mod test { // Now add a new input let new_in = hugr.add_op(Input::new([QB].to_vec())); - hugr.disconnect(old_in, Port::new_outgoing(0)).unwrap(); + hugr.disconnect(old_in, OutgoingPort::from(0)).unwrap(); hugr.connect(new_in, 0, out, 0).unwrap(); hugr.move_before_sibling(new_in, old_in).unwrap(); hugr.remove_node(old_in).unwrap(); diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 9d40c52fa..1b56f874c 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -756,7 +756,7 @@ mod test { use crate::std_extensions::logic::test::{and_op, not_op}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow}; - use crate::{type_row, Direction, Node}; + use crate::{type_row, Direction, IncomingPort, Node}; const NAT: Type = crate::extension::prelude::USIZE_T; const Q: Type = crate::extension::prelude::QB_T; @@ -1145,7 +1145,7 @@ mod test { h.update_validate(&EMPTY_REG), Err(ValidationError::UnconnectedPort { node: and, - port: Port::new_incoming(1), + port: IncomingPort::from(1).into(), port_kind: EdgeKind::Value(BOOL_T) }) ); diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 2aadbdc82..5cfc5698c 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -9,6 +9,8 @@ pub mod sibling_subgraph; #[cfg(test)] mod tests; +use std::iter::Map; + pub use self::petgraph::PetgraphWrapper; pub use descendants::DescendantsGraph; pub use root_checked::RootChecked; @@ -24,7 +26,7 @@ use super::{Hugr, HugrError, NodeMetadata, NodeType, DEFAULT_NODETYPE}; use crate::ops::handle::NodeHandle; use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; use crate::types::{EdgeKind, FunctionType}; -use crate::{Direction, Node, Port}; +use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; /// A trait for inspecting HUGRs. /// For end users we intend this to be superseded by region-specific APIs. @@ -148,30 +150,58 @@ pub trait HugrView: sealed::HugrInternals { fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_>; /// Iterator over output ports of node. - /// Shorthand for [`node_ports`][HugrView::node_ports]`(node, Direction::Outgoing)`. + /// Like [`node_ports`][HugrView::node_ports]`(node, Direction::Outgoing)` + /// but preserves knowledge that the ports are [OutgoingPort]s. #[inline] - fn node_outputs(&self, node: Node) -> Self::NodePorts<'_> { + fn node_outputs(&self, node: Node) -> OutgoingPorts> { self.node_ports(node, Direction::Outgoing) + .map(|p| p.as_outgoing().unwrap()) } /// Iterator over inputs ports of node. - /// Shorthand for [`node_ports`][HugrView::node_ports]`(node, Direction::Incoming)`. + /// Like [`node_ports`][HugrView::node_ports]`(node, Direction::Incoming)` + /// but preserves knowledge that the ports are [IncomingPort]s. #[inline] - fn node_inputs(&self, node: Node) -> Self::NodePorts<'_> { + fn node_inputs(&self, node: Node) -> IncomingPorts> { self.node_ports(node, Direction::Incoming) + .map(|p| p.as_incoming().unwrap()) } /// Iterator over both the input and output ports of node. fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_>; /// Iterator over the nodes and ports connected to a port. - fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_>; + fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; + + /// Iterator over the nodes and output ports connected to a given *input* port. + /// Like [`linked_ports`][HugrView::linked_ports] but preserves knowledge + /// that the linked ports are [OutgoingPort]s. + fn linked_outputs( + &self, + node: Node, + port: impl Into, + ) -> OutgoingNodePorts> { + self.linked_ports(node, port.into()) + .map(|(n, p)| (n, p.as_outgoing().unwrap())) + } + + /// Iterator over the nodes and input ports connected to a given *output* port + /// Like [`linked_ports`][HugrView::linked_ports] but preserves knowledge + /// that the linked ports are [IncomingPort]s. + fn linked_inputs( + &self, + node: Node, + port: impl Into, + ) -> IncomingNodePorts> { + self.linked_ports(node, port.into()) + .map(|(n, p)| (n, p.as_incoming().unwrap())) + } /// Iterator the links between two nodes. fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_>; /// Returns whether a port is connected. - fn is_linked(&self, node: Node, port: Port) -> bool { + fn is_linked(&self, node: Node, port: impl Into) -> bool { self.linked_ports(node, port).next().is_some() } @@ -300,6 +330,18 @@ pub trait HugrView: sealed::HugrInternals { } } +/// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s +pub type OutgoingPorts = Map OutgoingPort>; + +/// Wraps an iterator over [Port]s that are known to be [IncomingPort]s +pub type IncomingPorts = Map IncomingPort>; + +/// Wraps an iterator over `(`[`Node`],[`Port`]`)` when the ports are known to be [OutgoingPort]s +pub type OutgoingNodePorts = Map (Node, OutgoingPort)>; + +/// Wraps an iterator over `(`[`Node`],[`Port`]`)` when the ports are known to be [IncomingPort]s +pub type IncomingNodePorts = Map (Node, IncomingPort)>; + /// Trait for views that provides a guaranteed bound on the type of the root node. pub trait RootTagged: HugrView { /// The kind of handle that can be used to refer to the root node. @@ -397,7 +439,8 @@ impl> HugrView for T { } #[inline] - fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { + fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { + let port = port.into(); let hugr = self.as_ref(); let port = hugr .graph diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs index 0f39e0180..bbbfd81ca 100644 --- a/src/hugr/views/descendants.rs +++ b/src/hugr/views/descendants.rs @@ -100,10 +100,10 @@ impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { self.graph.all_port_offsets(node.pg_index()).map_into() } - fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { + fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { let port = self .graph - .port_index(node.pg_index(), port.pg_offset()) + .port_index(node.pg_index(), port.into().pg_offset()) .unwrap(); self.graph .port_links(port) diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index d5ed45865..69a2da4f5 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -131,10 +131,10 @@ impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> { self.graph.all_port_offsets(node.pg_index()).map_into() } - fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { + fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { let port = self .graph - .port_index(node.pg_index(), port.pg_offset()) + .port_index(node.pg_index(), port.into().pg_offset()) .unwrap(); self.graph .port_links(port) @@ -316,7 +316,7 @@ impl<'g, Root: NodeHandle> HugrView for SiblingMut<'g, Root> { } } - fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { + fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { // Need to filter only to links inside the sibling graph SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) .linked_ports(node, port) diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 33e249dd3..2e1c4aaa0 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -18,18 +18,12 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::extension::ExtensionSet; -use crate::hugr::{HugrError, HugrMut}; +use crate::hugr::{HugrError, HugrMut, HugrView, RootTagged}; +use crate::ops::handle::{ContainerHandle, DataflowOpID}; +use crate::ops::{OpTag, OpTrait}; use crate::types::Signature; -use crate::{ - ops::{ - handle::{ContainerHandle, DataflowOpID}, - OpTag, OpTrait, - }, - types::{FunctionType, Type}, - Hugr, Node, Port, SimpleReplacement, -}; - -use super::{HugrView, RootTagged}; +use crate::types::{FunctionType, Type}; +use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement}; #[cfg(feature = "pyo3")] use pyo3::{create_exception, exceptions::PyException, PyErr}; @@ -68,12 +62,12 @@ pub struct SiblingSubgraph { /// /// Grouped by input parameter. Each port must be unique and belong to a /// node in `nodes`. - inputs: Vec>, + inputs: Vec>, /// The output ports of the subgraph. /// /// Repeated ports are allowed and correspond to copying the output. Every /// port must belong to a node in `nodes`. - outputs: Vec<(Node, Port)>, + outputs: Vec<(Node, OutgoingPort)>, } /// The type of the incoming boundary of [`SiblingSubgraph`]. @@ -82,9 +76,9 @@ pub struct SiblingSubgraph { /// input parameter. A set in the partition that has more than one element /// corresponds to an input parameter that is copied and useful multiple times /// in the subgraph. -pub type IncomingPorts = Vec>; +pub type IncomingPorts = Vec>; /// The type of the outgoing boundary of [`SiblingSubgraph`]. -pub type OutgoingPorts = Vec<(Node, Port)>; +pub type OutgoingPorts = Vec<(Node, OutgoingPort)>; impl SiblingSubgraph { /// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted @@ -192,15 +186,8 @@ impl SiblingSubgraph { }; // Ordering of the edges here is preserved and becomes ordering of the signature. - let subpg = Subgraph::new_subgraph( - pg.clone(), - inputs - .iter() - .flatten() - .copied() - .chain(outputs.iter().copied()) - .map(to_pg), - ); + let subpg = + Subgraph::new_subgraph(pg.clone(), combine_in_out(&inputs, &outputs).map(to_pg)); let nodes = subpg.nodes_iter().map_into().collect_vec(); validate_subgraph(hugr, &nodes, &inputs, &outputs)?; @@ -373,8 +360,10 @@ impl SiblingSubgraph { rep_inputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); - let mut order_ports = in_order_ports.into_iter().chain(out_order_ports); - if order_ports.any(|(n, p)| is_order_edge(&replacement, n, p)) { + + if combine_in_out(&vec![out_order_ports], &in_order_ports) + .any(|(n, p)| is_order_edge(&replacement, n, p)) + { unimplemented!("Found state order edges in replacement graph"); } @@ -383,7 +372,7 @@ impl SiblingSubgraph { .zip_eq(&self.inputs) .flat_map(|((rep_source_n, rep_source_p), self_targets)| { replacement - .linked_ports(rep_source_n, rep_source_p) + .linked_inputs(rep_source_n, rep_source_p) .flat_map(move |rep_target| { self_targets .iter() @@ -396,7 +385,7 @@ impl SiblingSubgraph { .iter() .zip_eq(rep_outputs) .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| { - hugr.linked_ports(self_source_n, self_source_p) + hugr.linked_inputs(self_source_n, self_source_p) .map(move |self_target| (self_target, rep_target_p)) }) .collect(); @@ -431,17 +420,13 @@ impl SiblingSubgraph { // Connect the inserted nodes in-between the input and output nodes. let [inp, out] = extracted.get_io(extracted.root()).unwrap(); - for (inp_port, repl_ports) in extracted - .node_ports(inp, Direction::Outgoing) - .zip(self.inputs.iter()) - { + for (inp_port, repl_ports) in extracted.node_outputs(inp).zip(self.inputs.iter()) { for (repl_node, repl_port) in repl_ports { extracted.connect(inp, inp_port, node_map[repl_node], *repl_port)?; } } - for (out_port, (repl_node, repl_port)) in extracted - .node_ports(out, Direction::Incoming) - .zip(self.outputs.iter()) + for (out_port, (repl_node, repl_port)) in + extracted.node_inputs(out).zip(self.outputs.iter()) { extracted.connect(node_map[repl_node], *repl_port, out, out_port)?; } @@ -450,6 +435,17 @@ impl SiblingSubgraph { } } +fn combine_in_out<'a>( + inputs: &'a IncomingPorts, + outputs: &'a OutgoingPorts, +) -> impl Iterator + 'a { + inputs + .iter() + .flatten() + .map(|(n, p)| (*n, (*p).into())) + .chain(outputs.iter().map(|(n, p)| (*n, (*p).into()))) +} + /// Precompute convexity information for a HUGR. /// /// This can be used when constructing multiple sibling subgraphs to speed up @@ -469,7 +465,7 @@ impl<'g, Base: HugrView> ConvexChecker<'g, Base> { /// The type of all ports in the iterator. /// /// If the array is empty or a port does not exist, returns `None`. -fn get_edge_type(hugr: &H, ports: &[(Node, Port)]) -> Option { +fn get_edge_type + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option { let &(n, p) = ports.first()?; let edge_t = hugr.get_optype(n).signature().get(p)?.clone(); ports @@ -503,36 +499,11 @@ fn validate_subgraph( } // Check there are no linked "other" ports - if inputs - .iter() - .flatten() - .chain(outputs) - .any(|&(n, p)| is_order_edge(hugr, n, p)) - { + if combine_in_out(inputs, outputs).any(|(n, p)| is_order_edge(hugr, n, p)) { unimplemented!("Connected order edges not supported at the boundary") } - // Check inputs are incoming ports and outputs are outgoing ports - if let Some(&(n, p)) = inputs - .iter() - .flatten() - .find(|(_, p)| p.direction() == Direction::Outgoing) - { - Err(InvalidSubgraphBoundary::InputPortDirection(n, p))?; - }; - if let Some(&(n, p)) = outputs - .iter() - .find(|(_, p)| p.direction() == Direction::Incoming) - { - Err(InvalidSubgraphBoundary::OutputPortDirection(n, p))?; - }; - - let boundary_ports = inputs - .iter() - .flatten() - .chain(outputs) - .copied() - .collect_vec(); + let boundary_ports = combine_in_out(inputs, outputs).collect_vec(); // Check that the boundary ports are all in the subgraph. if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) { Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?; @@ -607,7 +578,7 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort let inputs = dfg_inputs .into_iter() .map(|p| { - hugr.linked_ports(inp, p) + hugr.linked_inputs(inp, p) .filter(|&(n, _)| n != out) .collect_vec() }) @@ -617,7 +588,7 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort // direct wires to the input. let outputs = dfg_outputs .into_iter() - .filter_map(|p| hugr.linked_ports(out, p).find(|&(n, _)| n != inp)) + .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp)) .collect(); (inputs, outputs) } @@ -686,12 +657,6 @@ pub enum InvalidSubgraph { /// Errors that can occur while constructing a [`SiblingSubgraph`]. #[derive(Debug, Clone, PartialEq, Eq, Error)] pub enum InvalidSubgraphBoundary { - /// A node in the input boundary is not Incoming. - #[error("Expected (node {0:?}, port {1:?}) in the input boundary to be an incoming port.")] - InputPortDirection(Node, Port), - /// A node in the output boundary is not Outgoing. - #[error("Expected (node {0:?}, port {1:?}) in the input boundary to be an outgoing port.")] - OutputPortDirection(Node, Port), /// A boundary port's node is not in the set of nodes. #[error("(node {0:?}, port {1:?}) is in the boundary, but node {0:?} is not in the set.")] PortNodeNotInSet(Node, Port), @@ -914,12 +879,12 @@ mod tests { SiblingSubgraph::try_new( hugr.node_outputs(inp) .take(2) - .map(|p| hugr.linked_ports(inp, p).collect_vec()) + .map(|p| hugr.linked_inputs(inp, p).collect_vec()) .filter(|ps| !ps.is_empty()) .collect(), hugr.node_inputs(out) .take(2) - .filter_map(|p| hugr.linked_ports(out, p).exactly_one().ok()) + .filter_map(|p| hugr.linked_outputs(out, p).exactly_one().ok()) .collect(), &func, ) @@ -935,7 +900,10 @@ mod tests { // All graph but one edge assert_matches!( SiblingSubgraph::try_new( - vec![hugr.linked_ports(inp, first_cx_edge).collect()], + vec![hugr + .linked_ports(inp, first_cx_edge) + .map(|(n, p)| (n, p.as_incoming().unwrap())) + .collect()], vec![(inp, first_cx_edge)], &func, ), diff --git a/src/types/signature.rs b/src/types/signature.rs index 7b3216a36..ce13b5e6e 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -3,18 +3,14 @@ #[cfg(feature = "pyo3")] use pyo3::{pyclass, pymethods}; -use std::ops::Index; - +use delegate::delegate; use smol_str::SmolStr; - use std::fmt::{self, Display, Write}; - -use crate::{Direction, Port, PortIndex}; - -use super::{Type, TypeRow}; +use std::ops::Index; use crate::extension::ExtensionSet; -use delegate::delegate; +use crate::types::{Type, TypeRow}; +use crate::{Direction, IncomingPort, OutgoingPort, Port, PortIndex}; #[cfg_attr(feature = "pyo3", pyclass)] #[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -101,7 +97,8 @@ impl FunctionType { /// Returns the type of a value [`Port`]. Returns `None` if the port is out /// of bounds. #[inline] - pub fn get(&self, port: Port) -> Option<&Type> { + pub fn get(&self, port: impl Into) -> Option<&Type> { + let port = port.into(); match port.direction() { Direction::Incoming => self.input.get(port), Direction::Outgoing => self.output.get(port), @@ -199,14 +196,16 @@ impl FunctionType { /// Returns the incoming `Port`s in the signature. #[inline] - pub fn input_ports(&self) -> impl Iterator { + pub fn input_ports(&self) -> impl Iterator { self.ports(Direction::Incoming) + .map(|p| p.as_incoming().unwrap()) } /// Returns the outgoing `Port`s in the signature. #[inline] - pub fn output_ports(&self) -> impl Iterator { + pub fn output_ports(&self) -> impl Iterator { self.ports(Direction::Outgoing) + .map(|p| p.as_outgoing().unwrap()) } }