Skip to content

Commit

Permalink
feat: Static checking of Port direction (#614)
Browse files Browse the repository at this point in the history
* Add fallible methods Port::as_{incoming,outgoing} to convert to
{Incoming/Outgoing}Port
* IncomingPort/OutgoingPort convert to Port (infallibly i.e. Into/From}
* Parameterize many omni-directional methods by `impl Into<Port>`;
require caller to do fallible conversions first + explicitly
* Add HugrView::node_{inputs,outputs} and
HugrView::linked_{inputs,outputs} preserving directionality
* Unidirectional methods/structs take (Into)IncomingPort/OutgoingPort
e.g. Wire, SimpleReplacement, InsertIdentity

BREAKING CHANGE: Port::(try_)new_{incom,outgo}ing gone - use
Port::as_{incom,outgo}ing and {Incom,Outgo}ingPort::{into,from}

---------

Co-authored-by: Agustin Borgna <[email protected]>
  • Loading branch information
acl-cqc and aborgna-q authored Oct 27, 2023
1 parent 0c580eb commit cb4bf6a
Show file tree
Hide file tree
Showing 14 changed files with 236 additions and 256 deletions.
20 changes: 10 additions & 10 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};
use crate::{IncomingPort, Node, OutgoingPort, Port};
use crate::{IncomingPort, Node, OutgoingPort};

use std::iter;

Expand Down Expand Up @@ -369,7 +369,7 @@ pub trait Dataflow: Container {
input_extensions,
),
// Constant wire from the constant value node
vec![Wire::new(const_node, Port::new_outgoing(0))],
vec![Wire::new(const_node, OutgoingPort::from(0))],
)?;

Ok(load_n.out_wire(0))
Expand Down Expand Up @@ -659,12 +659,12 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
fn wire_up<T: Dataflow + ?Sized>(
data_builder: &mut T,
src: Node,
src_port: impl TryInto<OutgoingPort>,
src_port: impl Into<OutgoingPort>,
dst: Node,
dst_port: impl TryInto<IncomingPort>,
dst_port: impl Into<IncomingPort>,
) -> Result<bool, BuildError> {
let src_port = Port::try_new_outgoing(src_port)?;
let dst_port = Port::try_new_incoming(dst_port)?;
let src_port = src_port.into();
let dst_port = dst_port.into();
let base = data_builder.hugr_mut();

let src_parent = base.get_parent(src);
Expand All @@ -676,9 +676,9 @@ fn wire_up<T: Dataflow + ?Sized>(
if !typ.copyable() {
let val_err: ValidationError = InterGraphEdgeError::NonCopyableData {
from: src,
from_offset: src_port,
from_offset: src_port.into(),
to: dst,
to_offset: dst_port,
to_offset: dst_port.into(),
ty: EdgeKind::Value(typ),
}
.into();
Expand All @@ -694,9 +694,9 @@ fn wire_up<T: Dataflow + ?Sized>(
else {
let val_err: ValidationError = InterGraphEdgeError::NoRelation {
from: src,
from_offset: src_port,
from_offset: src_port.into(),
to: dst,
to_offset: dst_port,
to_offset: dst_port.into(),
}
.into();
return Err(val_err.into());
Expand Down
21 changes: 7 additions & 14 deletions src/builder/handle.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
//! Handles to nodes in HUGR used during the building phase.
//!
use crate::{
ops::{
handle::{BasicBlockID, CaseID, DfgID, FuncID, NodeHandle, TailLoopID},
OpTag,
},
Port,
};
use crate::{Node, Wire};
use crate::ops::handle::{BasicBlockID, CaseID, DfgID, FuncID, NodeHandle, TailLoopID};
use crate::ops::OpTag;
use crate::{Node, OutgoingPort, Wire};

use itertools::Itertools;
use std::iter::FusedIterator;
Expand Down Expand Up @@ -64,7 +59,7 @@ impl<T: NodeHandle> BuildHandle<T> {
/// Retrieve a [`Wire`] corresponding to the given offset.
/// Does not check whether such a wire is valid for this node.
pub fn out_wire(&self, offset: usize) -> Wire {
Wire::new(self.node(), Port::new_outgoing(offset))
Wire::new(self.node(), OutgoingPort::from(offset))
}

#[inline]
Expand Down Expand Up @@ -124,14 +119,12 @@ impl Iterator for Outputs {
fn next(&mut self) -> Option<Self::Item> {
self.range
.next()
.map(|offset| Wire::new(self.node, Port::new_outgoing(offset)))
.map(|offset| Wire::new(self.node, OutgoingPort::from(offset)))
}

#[inline]
fn nth(&mut self, n: usize) -> Option<Self::Item> {
self.range
.nth(n)
.map(|offset| Wire::new(self.node, Port::new_outgoing(offset)))
self.range.nth(n).map(|offset| Wire::new(self.node, offset))
}

#[inline]
Expand All @@ -157,7 +150,7 @@ impl DoubleEndedIterator for Outputs {
fn next_back(&mut self) -> Option<Self::Item> {
self.range
.next_back()
.map(|offset| Wire::new(self.node, Port::new_outgoing(offset)))
.map(|offset| Wire::new(self.node, offset))
}
}

Expand Down
72 changes: 31 additions & 41 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub type Direction = portgraph::Direction;
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
pub struct Wire(Node, usize);
pub struct Wire(Node, OutgoingPort);

impl Node {
/// Returns the node as a portgraph `NodeIndex`.
Expand All @@ -88,29 +88,29 @@ impl Port {
}
}

/// Creates a new incoming port.
/// Converts to an [IncomingPort] if this port is one; else fails with
/// [HugrError::InvalidPortDirection]
#[inline]
pub fn new_incoming(port: impl Into<IncomingPort>) -> Self {
Self::try_new_incoming(port).unwrap()
}

/// Creates a new outgoing port.
#[inline]
pub fn new_outgoing(port: impl Into<OutgoingPort>) -> Self {
Self::try_new_outgoing(port).unwrap()
pub fn as_incoming(&self) -> Result<IncomingPort, HugrError> {
match self.direction() {
Direction::Incoming => Ok(IncomingPort {
index: self.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
}
}

/// Creates a new incoming port.
/// Converts to an [OutgoingPort] if this port is one; else fails with
/// [HugrError::InvalidPortDirection]
#[inline]
pub fn try_new_incoming(port: impl TryInto<IncomingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Outgoing));
};
Ok(Self {
offset: portgraph::PortOffset::new_incoming(port.index()),
})
pub fn as_outgoing(&self) -> Result<OutgoingPort, HugrError> {
match self.direction() {
Direction::Outgoing => Ok(OutgoingPort {
index: self.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
}
}

/// Creates a new outgoing port.
#[inline]
pub fn try_new_outgoing(port: impl TryInto<OutgoingPort>) -> Result<Self, HugrError> {
Expand Down Expand Up @@ -181,28 +181,18 @@ impl From<usize> for OutgoingPort {
}
}

impl TryFrom<Port> for IncomingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Incoming => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
impl From<IncomingPort> for Port {
fn from(value: IncomingPort) -> Self {
Self {
offset: portgraph::PortOffset::new_incoming(value.index()),
}
}
}

impl TryFrom<Port> for OutgoingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Outgoing => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
impl From<OutgoingPort> for Port {
fn from(value: OutgoingPort) -> Self {
Self {
offset: portgraph::PortOffset::new_outgoing(value.index()),
}
}
}
Expand All @@ -216,8 +206,8 @@ impl NodeIndex for Node {
impl Wire {
/// Create a new wire from a node and a port.
#[inline]
pub fn new(node: Node, port: impl TryInto<OutgoingPort>) -> Self {
Self(node, Port::try_new_outgoing(port).unwrap().index())
pub fn new(node: Node, port: impl Into<OutgoingPort>) -> Self {
Self(node, port.into())
}

/// The node that this wire is connected to.
Expand All @@ -228,8 +218,8 @@ impl Wire {

/// The output port that this wire is connected to.
#[inline]
pub fn source(&self) -> Port {
Port::new_outgoing(self.1)
pub fn source(&self) -> OutgoingPort {
self.1
}
}

Expand Down
39 changes: 25 additions & 14 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ pub trait HugrMut: HugrMutInternals {
fn connect(
&mut self,
src: Node,
src_port: impl TryInto<OutgoingPort>,
src_port: impl Into<OutgoingPort>,
dst: Node,
dst_port: impl TryInto<IncomingPort>,
dst_port: impl Into<IncomingPort>,
) -> Result<(), HugrError> {
self.valid_node(src)?;
self.valid_node(dst)?;
Expand All @@ -121,7 +121,7 @@ pub trait HugrMut: HugrMutInternals {
///
/// The port is left in place.
#[inline]
fn disconnect(&mut self, node: Node, port: Port) -> Result<(), HugrError> {
fn disconnect(&mut self, node: Node, port: impl Into<Port>) -> Result<(), HugrError> {
self.valid_node(node)?;
self.hugr_mut().disconnect(node, port)
}
Expand All @@ -134,7 +134,11 @@ pub trait HugrMut: HugrMutInternals {
///
/// [`OpTrait::other_input`]: crate::ops::OpTrait::other_input
/// [`OpTrait::other_output`]: crate::ops::OpTrait::other_output
fn add_other_edge(&mut self, src: Node, dst: Node) -> Result<(Port, Port), HugrError> {
fn add_other_edge(
&mut self,
src: Node,
dst: Node,
) -> Result<(OutgoingPort, IncomingPort), HugrError> {
self.valid_node(src)?;
self.valid_node(dst)?;
self.hugr_mut().add_other_edge(src, dst)
Expand Down Expand Up @@ -247,20 +251,21 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
fn connect(
&mut self,
src: Node,
src_port: impl TryInto<OutgoingPort>,
src_port: impl Into<OutgoingPort>,
dst: Node,
dst_port: impl TryInto<IncomingPort>,
dst_port: impl Into<IncomingPort>,
) -> Result<(), HugrError> {
self.as_mut().graph.link_nodes(
src.pg_index(),
Port::try_new_outgoing(src_port)?.index(),
src_port.into().index(),
dst.pg_index(),
Port::try_new_incoming(dst_port)?.index(),
dst_port.into().index(),
)?;
Ok(())
}

fn disconnect(&mut self, node: Node, port: Port) -> Result<(), HugrError> {
fn disconnect(&mut self, node: Node, port: impl Into<Port>) -> Result<(), HugrError> {
let port = port.into();
let offset = port.pg_offset();
let port = self
.as_mut()
Expand All @@ -274,15 +279,21 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
Ok(())
}

fn add_other_edge(&mut self, src: Node, dst: Node) -> Result<(Port, Port), HugrError> {
let src_port: Port = self
fn add_other_edge(
&mut self,
src: Node,
dst: Node,
) -> Result<(OutgoingPort, IncomingPort), HugrError> {
let src_port = self
.get_optype(src)
.other_port_index(Direction::Outgoing)
.expect("Source operation has no non-dataflow outgoing edges");
let dst_port: Port = self
.expect("Source operation has no non-dataflow outgoing edges")
.as_outgoing()?;
let dst_port = self
.get_optype(dst)
.other_port_index(Direction::Incoming)
.expect("Destination operation has no non-dataflow incoming edges");
.expect("Destination operation has no non-dataflow incoming edges")
.as_incoming()?;
self.connect(src, src_port, dst, dst_port)?;
Ok((src_port, dst_port))
}
Expand Down
21 changes: 4 additions & 17 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::iter;
use crate::hugr::{HugrMut, Node};
use crate::ops::{LeafOp, OpTag, OpTrait};
use crate::types::EdgeKind;
use crate::{Direction, HugrView, Port};
use crate::{HugrView, IncomingPort};

use super::Rewrite;

Expand All @@ -18,12 +18,12 @@ pub struct IdentityInsertion {
/// The node following the identity to be inserted.
pub post_node: Node,
/// The port following the identity to be inserted.
pub post_port: Port,
pub post_port: IncomingPort,
}

impl IdentityInsertion {
/// Create a new [`IdentityInsertion`] specification.
pub fn new(post_node: Node, post_port: Port) -> Self {
pub fn new(post_node: Node, post_port: IncomingPort) -> Self {
Self {
post_node,
post_port,
Expand All @@ -43,10 +43,6 @@ pub enum IdentityInsertionError {
/// Invalid port kind.
#[error("post_port has invalid kind {0:?}. Must be Value.")]
InvalidPortKind(Option<EdgeKind>),

/// Must be input port.
#[error("post_port is an output port, must be input.")]
PortIsOutput,
}

impl Rewrite for IdentityInsertion {
Expand All @@ -71,17 +67,13 @@ impl Rewrite for IdentityInsertion {
unimplemented!()
}
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, IdentityInsertionError> {
if self.post_port.direction() != Direction::Incoming {
return Err(IdentityInsertionError::PortIsOutput);
}

let kind = h.get_optype(self.post_node).port_kind(self.post_port);
let Some(EdgeKind::Value(ty)) = kind else {
return Err(IdentityInsertionError::InvalidPortKind(kind));
};

let (pre_node, pre_port) = h
.linked_ports(self.post_node, self.post_port)
.linked_outputs(self.post_node, self.post_port)
.exactly_one()
.ok()
.expect("Value kind input can only have one connection.");
Expand Down Expand Up @@ -155,11 +147,6 @@ mod tests {

let final_node = tail.node();

let final_node_output = h.node_outputs(final_node).next().unwrap();
let rw = IdentityInsertion::new(final_node, final_node_output);
let apply_result = h.apply_rewrite(rw);
assert_eq!(apply_result, Err(IdentityInsertionError::PortIsOutput));

let final_node_input = h.node_inputs(final_node).next().unwrap();

let rw = IdentityInsertion::new(final_node, final_node_input);
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ impl Rewrite for OutlineCfg {

// 3. Entry edges. Change any edges into entry_block from outside, to target new_block
let preds: Vec<_> = h
.linked_ports(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
.linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
.collect();
for (pred, br) in preds {
if !self.blocks.contains(&pred) {
Expand Down
Loading

0 comments on commit cb4bf6a

Please sign in to comment.