Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Optional direction check when querying a port index #566

Merged
merged 3 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{Node, NodeMetadata, Port, PortIndex, ValidationError};
use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, Port, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};

use std::iter;
Expand Down Expand Up @@ -658,27 +658,26 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
fn wire_up<T: Dataflow + ?Sized>(
data_builder: &mut T,
src: Node,
src_port: impl PortIndex,
src_port: impl TryInto<OutgoingPort>,
dst: Node,
dst_port: impl PortIndex,
dst_port: impl TryInto<IncomingPort>,
) -> Result<bool, BuildError> {
let src_port = src_port.index();
let dst_port = dst_port.index();
let src_port = Port::try_new_outgoing(src_port)?;
let dst_port = Port::try_new_incoming(dst_port)?;
let base = data_builder.hugr_mut();
let src_offset = Port::new_outgoing(src_port);

let src_parent = base.get_parent(src);
let dst_parent = base.get_parent(dst);
let local_source = src_parent == dst_parent;
if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_offset).unwrap() {
if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
if !local_source {
// Non-local value sources require a state edge to an ancestor of dst
if !typ.copyable() {
let val_err: ValidationError = InterGraphEdgeError::NonCopyableData {
from: src,
from_offset: Port::new_outgoing(src_port),
from_offset: src_port,
to: dst,
to_offset: Port::new_incoming(dst_port),
to_offset: dst_port,
ty: EdgeKind::Value(typ),
}
.into();
Expand All @@ -694,9 +693,9 @@ fn wire_up<T: Dataflow + ?Sized>(
else {
let val_err: ValidationError = InterGraphEdgeError::NoRelation {
from: src,
from_offset: Port::new_outgoing(src_port),
from_offset: src_port,
to: dst,
to_offset: Port::new_incoming(dst_port),
to_offset: dst_port,
}
.into();
return Err(val_err.into());
Expand All @@ -705,7 +704,7 @@ fn wire_up<T: Dataflow + ?Sized>(
// TODO: Avoid adding duplicate edges
// This should be easy with https://github.com/CQCL-DEV/hugr/issues/130
base.add_other_edge(src, src_sibling)?;
} else if !typ.copyable() & base.linked_ports(src, src_offset).next().is_some() {
} else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
// Don't copy linear edges.
return Err(BuildError::NoCopyLinear(typ));
}
Expand All @@ -719,7 +718,7 @@ fn wire_up<T: Dataflow + ?Sized>(
data_builder
.hugr_mut()
.get_optype(dst)
.port_kind(Port::new_incoming(dst_port))
.port_kind(dst_port)
.unwrap(),
EdgeKind::Value(_)
))
Expand Down
114 changes: 101 additions & 13 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,23 @@ pub struct Port {
}

/// A trait for getting the undirected index of a port.
///
/// This allows functions to admit both [`Port`]s and explicit `usize`s for
/// identifying port offsets.
pub trait PortIndex {
/// Returns the offset of the port.
fn index(self) -> usize;
}

/// A port in the incoming direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)]
pub struct IncomingPort {
index: u16,
}

/// A port in the outgoing direction.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)]
pub struct OutgoingPort {
index: u16,
}

/// The direction of a port.
pub type Direction = portgraph::Direction;

Expand Down Expand Up @@ -372,18 +381,36 @@ impl Port {

/// Creates a new incoming port.
#[inline]
pub fn new_incoming(port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new_incoming(port),
}
pub fn new_incoming(port: impl Into<IncomingPort>) -> Self {
Self::try_new_incoming(port).unwrap()
}

/// Creates a new outgoing port.
#[inline]
pub fn new_outgoing(port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new_outgoing(port),
}
pub fn new_outgoing(port: impl Into<OutgoingPort>) -> Self {
Self::try_new_outgoing(port).unwrap()
}

/// Creates a new incoming port.
#[inline]
pub fn try_new_incoming(port: impl TryInto<IncomingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Outgoing));
};
Ok(Self {
offset: portgraph::PortOffset::new_incoming(port.index()),
})
}

/// Creates a new outgoing port.
#[inline]
pub fn try_new_outgoing(port: impl TryInto<OutgoingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Incoming));
};
Ok(Self {
offset: portgraph::PortOffset::new_outgoing(port.index()),
})
}

/// Returns the direction of the port.
Expand All @@ -407,6 +434,64 @@ impl PortIndex for usize {
}
}

impl PortIndex for IncomingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl PortIndex for OutgoingPort {
#[inline(always)]
fn index(self) -> usize {
self.index as usize
}
}

impl From<usize> for IncomingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl From<usize> for OutgoingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}

impl TryFrom<Port> for IncomingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Incoming => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

impl TryFrom<Port> for OutgoingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Outgoing => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
}
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
Expand All @@ -415,8 +500,8 @@ pub struct Wire(Node, usize);
impl Wire {
/// Create a new wire from a node and a port.
#[inline]
pub fn new(node: Node, port: Port) -> Self {
Self(node, port.index())
pub fn new(node: Node, port: impl TryInto<OutgoingPort>) -> Self {
Self(node, Port::try_new_outgoing(port).unwrap().index())
}

/// The node that this wire is connected to.
Expand Down Expand Up @@ -484,6 +569,9 @@ pub enum HugrError {
/// The node doesn't exist.
#[error("Invalid node {0:?}.")]
InvalidNode(Node),
/// An invalid port was specified.
#[error("Invalid port direction {0:?}.")]
InvalidPortDirection(Direction),
}

#[cfg(feature = "pyo3")]
Expand Down
19 changes: 11 additions & 8 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{Hugr, Port};
use self::sealed::HugrMutInternals;

use super::views::SiblingSubgraph;
use super::{NodeMetadata, PortIndex, Rewrite};
use super::{IncomingPort, NodeMetadata, OutgoingPort, PortIndex, Rewrite};

/// Functions for low-level building of a HUGR.
pub trait HugrMut: HugrView + HugrMutInternals {
Expand Down Expand Up @@ -110,9 +110,9 @@ pub trait HugrMut: HugrView + HugrMutInternals {
fn connect(
&mut self,
src: Node,
src_port: impl PortIndex,
src_port: impl TryInto<OutgoingPort>,
dst: Node,
dst_port: impl PortIndex,
dst_port: impl TryInto<IncomingPort>,
) -> Result<(), HugrError> {
self.valid_node(src)?;
self.valid_node(dst)?;
Expand Down Expand Up @@ -262,13 +262,16 @@ where
fn connect(
&mut self,
src: Node,
src_port: impl PortIndex,
src_port: impl TryInto<OutgoingPort>,
dst: Node,
dst_port: impl PortIndex,
dst_port: impl TryInto<IncomingPort>,
) -> Result<(), HugrError> {
self.as_mut()
.graph
.link_nodes(src.index, src_port.index(), dst.index, dst_port.index())?;
self.as_mut().graph.link_nodes(
src.index,
Port::try_new_outgoing(src_port)?.index(),
dst.index,
Port::try_new_incoming(dst_port)?.index(),
)?;
Ok(())
}

Expand Down