Skip to content

Commit

Permalink
feat: Optional direction check when querying a port index (#566)
Browse files Browse the repository at this point in the history
So `connect` throws an error when passed invalid ports, but we still
support `usize`s.

Closes #565. I choose not to allow backwards edges as that introduces a
new failure mode when the directions are the same (and weird behaviour
when passed one `Port` and one `usize`).
  • Loading branch information
aborgna-q authored Sep 29, 2023
1 parent 0ce711b commit 554d658
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 34 deletions.
25 changes: 12 additions & 13 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{Node, NodeMetadata, Port, PortIndex, ValidationError};
use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, Port, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};

use std::iter;
Expand Down Expand Up @@ -658,27 +658,26 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
fn wire_up<T: Dataflow + ?Sized>(
data_builder: &mut T,
src: Node,
src_port: impl PortIndex,
src_port: impl TryInto<OutgoingPort>,
dst: Node,
dst_port: impl PortIndex,
dst_port: impl TryInto<IncomingPort>,
) -> Result<bool, BuildError> {
let src_port = src_port.index();
let dst_port = dst_port.index();
let src_port = Port::try_new_outgoing(src_port)?;
let dst_port = Port::try_new_incoming(dst_port)?;
let base = data_builder.hugr_mut();
let src_offset = Port::new_outgoing(src_port);

let src_parent = base.get_parent(src);
let dst_parent = base.get_parent(dst);
let local_source = src_parent == dst_parent;
if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_offset).unwrap() {
if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
if !local_source {
// Non-local value sources require a state edge to an ancestor of dst
if !typ.copyable() {
let val_err: ValidationError = InterGraphEdgeError::NonCopyableData {
from: src,
from_offset: Port::new_outgoing(src_port),
from_offset: src_port,
to: dst,
to_offset: Port::new_incoming(dst_port),
to_offset: dst_port,
ty: EdgeKind::Value(typ),
}
.into();
Expand All @@ -694,9 +693,9 @@ fn wire_up<T: Dataflow + ?Sized>(
else {
let val_err: ValidationError = InterGraphEdgeError::NoRelation {
from: src,
from_offset: Port::new_outgoing(src_port),
from_offset: src_port,
to: dst,
to_offset: Port::new_incoming(dst_port),
to_offset: dst_port,
}
.into();
return Err(val_err.into());
Expand All @@ -705,7 +704,7 @@ fn wire_up<T: Dataflow + ?Sized>(
// TODO: Avoid adding duplicate edges
// This should be easy with https://github.com/CQCL-DEV/hugr/issues/130
base.add_other_edge(src, src_sibling)?;
} else if !typ.copyable() & base.linked_ports(src, src_offset).next().is_some() {
} else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
// Don't copy linear edges.
return Err(BuildError::NoCopyLinear(typ));
}
Expand All @@ -719,7 +718,7 @@ fn wire_up<T: Dataflow + ?Sized>(
data_builder
.hugr_mut()
.get_optype(dst)
.port_kind(Port::new_incoming(dst_port))
.port_kind(dst_port)
.unwrap(),
EdgeKind::Value(_)
))
Expand Down
114 changes: 101 additions & 13 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,23 @@ pub struct Port {
}

/// A trait for getting the undirected index of a port.
///
/// This allows functions to admit both [`Port`]s and explicit `usize`s for
/// identifying port offsets.
pub trait PortIndex {
/// Returns the offset of the port.
fn index(self) -> usize;
}

/// A port in the incoming direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)]
pub struct IncomingPort {
index: u16,
}

/// A port in the outgoing direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)]
pub struct OutgoingPort {
index: u16,
}

/// The direction of a port.
pub type Direction = portgraph::Direction;

Expand Down Expand Up @@ -372,18 +381,36 @@ impl Port {

/// Creates a new incoming port.
#[inline]
pub fn new_incoming(port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new_incoming(port),
}
pub fn new_incoming(port: impl Into<IncomingPort>) -> Self {
Self::try_new_incoming(port).unwrap()
}

/// Creates a new outgoing port.
#[inline]
pub fn new_outgoing(port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new_outgoing(port),
}
pub fn new_outgoing(port: impl Into<OutgoingPort>) -> Self {
Self::try_new_outgoing(port).unwrap()
}

/// Creates a new incoming port.
#[inline]
pub fn try_new_incoming(port: impl TryInto<IncomingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Outgoing));
};
Ok(Self {
offset: portgraph::PortOffset::new_incoming(port.index()),
})
}

/// Creates a new outgoing port.
#[inline]
pub fn try_new_outgoing(port: impl TryInto<OutgoingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Incoming));
};
Ok(Self {
offset: portgraph::PortOffset::new_outgoing(port.index()),
})
}

/// Returns the direction of the port.
Expand All @@ -407,6 +434,64 @@ impl PortIndex for usize {
}
}

impl PortIndex for IncomingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl PortIndex for OutgoingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl From<usize> for IncomingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl From<usize> for OutgoingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl TryFrom<Port> for IncomingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Incoming => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

impl TryFrom<Port> for OutgoingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Outgoing => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
Expand All @@ -415,8 +500,8 @@ pub struct Wire(Node, usize);
impl Wire {
/// Create a new wire from a node and a port.
#[inline]
pub fn new(node: Node, port: Port) -> Self {
Self(node, port.index())
pub fn new(node: Node, port: impl TryInto<OutgoingPort>) -> Self {
Self(node, Port::try_new_outgoing(port).unwrap().index())
}

/// The node that this wire is connected to.
Expand Down Expand Up @@ -484,6 +569,9 @@ pub enum HugrError {
/// The node doesn't exist.
#[error("Invalid node {0:?}.")]
InvalidNode(Node),
/// An invalid port was specified.
#[error("Invalid port direction {0:?}.")]
InvalidPortDirection(Direction),
}

#[cfg(feature = "pyo3")]
Expand Down
19 changes: 11 additions & 8 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{Hugr, Port};
use self::sealed::HugrMutInternals;

use super::views::SiblingSubgraph;
use super::{NodeMetadata, PortIndex, Rewrite};
use super::{IncomingPort, NodeMetadata, OutgoingPort, PortIndex, Rewrite};

/// Functions for low-level building of a HUGR.
pub trait HugrMut: HugrView + HugrMutInternals {
Expand Down Expand Up @@ -110,9 +110,9 @@ pub trait HugrMut: HugrView + HugrMutInternals {
fn connect(
&mut self,
src: Node,
src_port: impl PortIndex,
src_port: impl TryInto<OutgoingPort>,
dst: Node,
dst_port: impl PortIndex,
dst_port: impl TryInto<IncomingPort>,
) -> Result<(), HugrError> {
self.valid_node(src)?;
self.valid_node(dst)?;
Expand Down Expand Up @@ -262,13 +262,16 @@ where
fn connect(
&mut self,
src: Node,
src_port: impl PortIndex,
src_port: impl TryInto<OutgoingPort>,
dst: Node,
dst_port: impl PortIndex,
dst_port: impl TryInto<IncomingPort>,
) -> Result<(), HugrError> {
self.as_mut()
.graph
.link_nodes(src.index, src_port.index(), dst.index, dst_port.index())?;
self.as_mut().graph.link_nodes(
src.index,
Port::try_new_outgoing(src_port)?.index(),
dst.index,
Port::try_new_incoming(dst_port)?.index(),
)?;
Ok(())
}

Expand Down

0 comments on commit 554d658

Please sign in to comment.