Skip to content

Commit

Permalink
feat: HugrView::signature(Node) and port+type iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 13, 2023
1 parent 0437efc commit 47d4646
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 10 deletions.
3 changes: 1 addition & 2 deletions src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use pyo3::{create_exception, exceptions::PyException, PyErr};
use crate::core::NodeIndex;
use crate::extension::ExtensionSet;
use crate::hugr::{Hugr, NodeType};
use crate::ops::OpTrait;
use crate::ops::OpType;
use crate::{Node, PortIndex};
use portgraph::hierarchy::AttachError;
Expand Down Expand Up @@ -167,7 +166,7 @@ impl TryFrom<&Hugr> for SerHugrV0 {
.expect("Could not reach one of the nodes");

let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| {
let sig = hugr.get_optype(node).signature();
let sig = hugr.signature(node);
let offset = match offset < sig.port_count(dir) {
true => Some(offset as u16),
false => None,
Expand Down
33 changes: 33 additions & 0 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};
use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE};
use crate::ops::handle::NodeHandle;
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG};
#[rustversion::since(1.75)] // uses impl in return position
use crate::types::Type;
use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, IncomingPort, Node, OutgoingPort, Port};

Expand Down Expand Up @@ -355,6 +357,37 @@ pub trait HugrView: sealed::HugrInternals {
.next()
.map(|(n, _)| n)
}

/// Get the "signature" (incoming and outgoing types) of a node, non-Value
/// kind edges will be missing.
fn signature(&self, node: Node) -> FunctionType {
self.get_optype(node).signature()
}

#[rustversion::since(1.75)] // uses impl in return position
/// Iterator over all ports in a given direction that have Value type, along
/// with corresponding types.
fn value_types(&self, node: Node, dir: Direction) -> impl Iterator<Item = (Port, Type)> {
let sig = self.signature(node);
self.node_ports(node, dir)
.flat_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone())))
}

#[rustversion::since(1.75)] // uses impl in return position
/// Iterator over all incoming ports that have Value type, along
/// with corresponding types.
fn in_value_types(&self, node: Node) -> impl Iterator<Item = (IncomingPort, Type)> {
self.value_types(node, Direction::Incoming)
.map(|(p, t)| (p.as_incoming().unwrap(), t))
}

#[rustversion::since(1.75)] // uses impl in return position
/// Iterator over all incoming ports that have Value type, along
/// with corresponding types.
fn out_value_types(&self, node: Node) -> impl Iterator<Item = (OutgoingPort, Type)> {
self.value_types(node, Direction::Outgoing)
.map(|(p, t)| (p.as_outgoing().unwrap(), t))
}
}

/// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s
Expand Down
16 changes: 8 additions & 8 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,15 @@ impl SiblingSubgraph {
.iter()
.map(|part| {
let &(n, p) = part.iter().next().expect("is non-empty");
let sig = hugr.get_optype(n).signature();
let sig = hugr.signature(n);
sig.port_type(p).cloned().expect("must be dataflow edge")
})
.collect_vec();
let output = self
.outputs
.iter()
.map(|&(n, p)| {
let sig = hugr.get_optype(n).signature();
let sig = hugr.signature(n);
sig.port_type(p).cloned().expect("must be dataflow edge")
})
.collect_vec();
Expand Down Expand Up @@ -356,10 +356,10 @@ impl SiblingSubgraph {
// See https://github.com/CQCL-DEV/hugr/discussions/432
let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p));
let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p));
let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs
.partition(|&(n, p)| replacement.get_optype(n).signature().port_type(p).is_some());
let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs
.partition(|&(n, p)| replacement.get_optype(n).signature().port_type(p).is_some());
let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) =
rep_inputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some());
let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) =
rep_outputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some());

if combine_in_out(&vec![out_order_ports], &in_order_ports)
.any(|(n, p)| is_order_edge(&replacement, n, p))
Expand Down Expand Up @@ -467,10 +467,10 @@ impl<'g, Base: HugrView> ConvexChecker<'g, Base> {
/// If the array is empty or a port does not exist, returns `None`.
fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option<Type> {
let &(n, p) = ports.first()?;
let edge_t = hugr.get_optype(n).signature().port_type(p)?.clone();
let edge_t = hugr.signature(n).port_type(p)?.clone();
ports
.iter()
.all(|&(n, p)| hugr.get_optype(n).signature().port_type(p) == Some(&edge_t))
.all(|&(n, p)| hugr.signature(n).port_type(p) == Some(&edge_t))
.then_some(edge_t)
}

Expand Down
31 changes: 31 additions & 0 deletions src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,34 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle<DataflowOpID>, BuildHandle<Dataflow
]
);
}

#[rustversion::since(1.75)] // uses impl in return position
#[test]
fn value_types() {
use crate::builder::Container;
use crate::extension::prelude::BOOL_T;
use crate::std_extensions::logic::test::not_op;
use crate::utils::test_quantum_extension::h_gate;
use itertools::Itertools;
let mut dfg = DFGBuilder::new(FunctionType::new(
type_row![QB_T, BOOL_T],
type_row![BOOL_T, QB_T],
))
.unwrap();

let [q, b] = dfg.input_wires_arr();
let n1 = dfg.add_dataflow_op(h_gate(), [q]).unwrap();
let n2 = dfg.add_dataflow_op(not_op(), [b]).unwrap();
dfg.add_other_wire(n1.node(), n2.node()).unwrap();
let h = dfg
.finish_prelude_hugr_with_outputs([n2.out_wire(0), n1.out_wire(0)])
.unwrap();

let [_, o] = h.get_io(h.root()).unwrap();
let n1_out_types = h.out_value_types(n1.node()).collect_vec();

assert_eq!(&n1_out_types[..], &[(0.into(), QB_T)]);
let out_types = h.in_value_types(o).collect_vec();

assert_eq!(&out_types[..], &[(0.into(), BOOL_T), (1.into(), QB_T)]);
}

0 comments on commit 47d4646

Please sign in to comment.