From 554d658f7afd38f0eb567ceff0599e65e74541c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Fri, 29 Sep 2023 18:06:58 +0200 Subject: [PATCH] feat: Optional direction check when querying a port index (#566) So `connect` throws an error when passed invalid ports, but we still support `usize`s. Closes #565. I choose not to allow backwards edges as that introduces a new failure mode when the directions are the same (and weird behaviour when passed one `Port` and one `usize`). --- src/builder/build_traits.rs | 25 ++++---- src/hugr.rs | 114 ++++++++++++++++++++++++++++++++---- src/hugr/hugrmut.rs | 19 +++--- 3 files changed, 124 insertions(+), 34 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 7ef288e2c..a1411dcd2 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -1,7 +1,7 @@ use crate::hugr::hugrmut::InsertionResult; use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::views::HugrView; -use crate::hugr::{Node, NodeMetadata, Port, PortIndex, ValidationError}; +use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, Port, ValidationError}; use crate::ops::{self, LeafOp, OpTrait, OpType}; use std::iter; @@ -658,27 +658,26 @@ fn wire_up_inputs( fn wire_up( data_builder: &mut T, src: Node, - src_port: impl PortIndex, + src_port: impl TryInto, dst: Node, - dst_port: impl PortIndex, + dst_port: impl TryInto, ) -> Result { - let src_port = src_port.index(); - let dst_port = dst_port.index(); + let src_port = Port::try_new_outgoing(src_port)?; + let dst_port = Port::try_new_incoming(dst_port)?; let base = data_builder.hugr_mut(); - let src_offset = Port::new_outgoing(src_port); let src_parent = base.get_parent(src); let dst_parent = base.get_parent(dst); let local_source = src_parent == dst_parent; - if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_offset).unwrap() { + if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() { if !local_source { // Non-local value sources require a state edge to an ancestor of dst if !typ.copyable() { let val_err: ValidationError = InterGraphEdgeError::NonCopyableData { from: src, - from_offset: Port::new_outgoing(src_port), + from_offset: src_port, to: dst, - to_offset: Port::new_incoming(dst_port), + to_offset: dst_port, ty: EdgeKind::Value(typ), } .into(); @@ -694,9 +693,9 @@ fn wire_up( else { let val_err: ValidationError = InterGraphEdgeError::NoRelation { from: src, - from_offset: Port::new_outgoing(src_port), + from_offset: src_port, to: dst, - to_offset: Port::new_incoming(dst_port), + to_offset: dst_port, } .into(); return Err(val_err.into()); @@ -705,7 +704,7 @@ fn wire_up( // TODO: Avoid adding duplicate edges // This should be easy with https://github.com/CQCL-DEV/hugr/issues/130 base.add_other_edge(src, src_sibling)?; - } else if !typ.copyable() & base.linked_ports(src, src_offset).next().is_some() { + } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() { // Don't copy linear edges. return Err(BuildError::NoCopyLinear(typ)); } @@ -719,7 +718,7 @@ fn wire_up( data_builder .hugr_mut() .get_optype(dst) - .port_kind(Port::new_incoming(dst_port)) + .port_kind(dst_port) .unwrap(), EdgeKind::Value(_) )) diff --git a/src/hugr.rs b/src/hugr.rs index 6f78fe43f..294f99fea 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -208,14 +208,23 @@ pub struct Port { } /// A trait for getting the undirected index of a port. -/// -/// This allows functions to admit both [`Port`]s and explicit `usize`s for -/// identifying port offsets. pub trait PortIndex { /// Returns the offset of the port. fn index(self) -> usize; } +/// A port in the incoming direction. +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)] +pub struct IncomingPort { + index: u16, +} + +/// A port in the outgoing direction. +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)] +pub struct OutgoingPort { + index: u16, +} + /// The direction of a port. pub type Direction = portgraph::Direction; @@ -372,18 +381,36 @@ impl Port { /// Creates a new incoming port. #[inline] - pub fn new_incoming(port: usize) -> Self { - Self { - offset: portgraph::PortOffset::new_incoming(port), - } + pub fn new_incoming(port: impl Into) -> Self { + Self::try_new_incoming(port).unwrap() } /// Creates a new outgoing port. #[inline] - pub fn new_outgoing(port: usize) -> Self { - Self { - offset: portgraph::PortOffset::new_outgoing(port), - } + pub fn new_outgoing(port: impl Into) -> Self { + Self::try_new_outgoing(port).unwrap() + } + + /// Creates a new incoming port. + #[inline] + pub fn try_new_incoming(port: impl TryInto) -> Result { + let Ok(port) = port.try_into() else { + return Err(HugrError::InvalidPortDirection(Direction::Outgoing)); + }; + Ok(Self { + offset: portgraph::PortOffset::new_incoming(port.index()), + }) + } + + /// Creates a new outgoing port. + #[inline] + pub fn try_new_outgoing(port: impl TryInto) -> Result { + let Ok(port) = port.try_into() else { + return Err(HugrError::InvalidPortDirection(Direction::Incoming)); + }; + Ok(Self { + offset: portgraph::PortOffset::new_outgoing(port.index()), + }) } /// Returns the direction of the port. @@ -407,6 +434,64 @@ impl PortIndex for usize { } } +impl PortIndex for IncomingPort { + #[inline(always)] + fn index(self) -> usize { + self.index as usize + } +} + +impl PortIndex for OutgoingPort { + #[inline(always)] + fn index(self) -> usize { + self.index as usize + } +} + +impl From for IncomingPort { + #[inline(always)] + fn from(index: usize) -> Self { + Self { + index: index as u16, + } + } +} + +impl From for OutgoingPort { + #[inline(always)] + fn from(index: usize) -> Self { + Self { + index: index as u16, + } + } +} + +impl TryFrom for IncomingPort { + type Error = HugrError; + #[inline(always)] + fn try_from(port: Port) -> Result { + match port.direction() { + Direction::Incoming => Ok(Self { + index: port.index() as u16, + }), + dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)), + } + } +} + +impl TryFrom for OutgoingPort { + type Error = HugrError; + #[inline(always)] + fn try_from(port: Port) -> Result { + match port.direction() { + Direction::Outgoing => Ok(Self { + index: port.index() as u16, + }), + dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)), + } + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] /// A DataFlow wire, defined by a Value-kind output port of a node // Stores node and offset to output port @@ -415,8 +500,8 @@ pub struct Wire(Node, usize); impl Wire { /// Create a new wire from a node and a port. #[inline] - pub fn new(node: Node, port: Port) -> Self { - Self(node, port.index()) + pub fn new(node: Node, port: impl TryInto) -> Self { + Self(node, Port::try_new_outgoing(port).unwrap().index()) } /// The node that this wire is connected to. @@ -484,6 +569,9 @@ pub enum HugrError { /// The node doesn't exist. #[error("Invalid node {0:?}.")] InvalidNode(Node), + /// An invalid port was specified. + #[error("Invalid port direction {0:?}.")] + InvalidPortDirection(Direction), } #[cfg(feature = "pyo3")] diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index a02d822c9..94f6a1305 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -14,7 +14,7 @@ use crate::{Hugr, Port}; use self::sealed::HugrMutInternals; use super::views::SiblingSubgraph; -use super::{NodeMetadata, PortIndex, Rewrite}; +use super::{IncomingPort, NodeMetadata, OutgoingPort, PortIndex, Rewrite}; /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrView + HugrMutInternals { @@ -110,9 +110,9 @@ pub trait HugrMut: HugrView + HugrMutInternals { fn connect( &mut self, src: Node, - src_port: impl PortIndex, + src_port: impl TryInto, dst: Node, - dst_port: impl PortIndex, + dst_port: impl TryInto, ) -> Result<(), HugrError> { self.valid_node(src)?; self.valid_node(dst)?; @@ -262,13 +262,16 @@ where fn connect( &mut self, src: Node, - src_port: impl PortIndex, + src_port: impl TryInto, dst: Node, - dst_port: impl PortIndex, + dst_port: impl TryInto, ) -> Result<(), HugrError> { - self.as_mut() - .graph - .link_nodes(src.index, src_port.index(), dst.index, dst_port.index())?; + self.as_mut().graph.link_nodes( + src.index, + Port::try_new_outgoing(src_port)?.index(), + dst.index, + Port::try_new_incoming(dst_port)?.index(), + )?; Ok(()) }