Skip to content

Commit

Permalink
feat: single_source/target for when you know only one connected port
Browse files Browse the repository at this point in the history
Closes Shorthand HugrView method for "connected port" #499
  • Loading branch information
ss2165 committed Nov 13, 2023
1 parent b191108 commit 6094663
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 30 deletions.
5 changes: 1 addition & 4 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::{HugrView, IncomingPort};

use super::Rewrite;

use itertools::Itertools;
use thiserror::Error;

/// Specification of a identity-insertion operation.
Expand Down Expand Up @@ -73,9 +72,7 @@ impl Rewrite for IdentityInsertion {
};

let (pre_node, pre_port) = h
.linked_outputs(self.post_node, self.post_port)
.exactly_one()
.ok()
.single_source(self.post_node, self.post_port)
.expect("Value kind input can only have one connection.");

h.disconnect(self.post_node, self.post_port).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl Rewrite for OutlineCfg {
let exit_port = h
.node_outputs(exit)
.filter(|p| {
let (t, p2) = h.linked_ports(exit, *p).exactly_one().ok().unwrap();
let (t, p2) = h.single_target(exit, *p).unwrap();
assert!(p2.index() == 0);
t == outside
})
Expand Down
5 changes: 2 additions & 3 deletions src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ impl NewEdgeSpec {
true
};
let found_incoming = h
.linked_ports(self.tgt, tgt_pos)
.exactly_one()
.is_ok_and(|(src_n, _)| descends_from_legal(src_n));
.single_source(self.tgt, tgt_pos)
.is_some_and(|(src_n, _)| descends_from_legal(src_n));
if !found_incoming {
return Err(ReplaceError::NoRemovedEdge(err_edge()));
};
Expand Down
27 changes: 7 additions & 20 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use std::collections::{hash_map, HashMap};
use std::iter::{self, Copied};
use std::slice;

use itertools::Itertools;

use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite};
use crate::ops::{OpTag, OpTrait, OpType};
Expand Down Expand Up @@ -129,11 +127,8 @@ impl Rewrite for SimpleReplacement {
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp {
if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output {
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let (rem_inp_pred_node, rem_inp_pred_port) = h
.linked_outputs(*rem_inp_node, *rem_inp_port)
.exactly_one()
.ok() // PortLinks does not implement Debug
.unwrap();
let (rem_inp_pred_node, rem_inp_pred_port) =
h.single_source(*rem_inp_node, *rem_inp_port).unwrap();
h.disconnect(*rem_inp_node, *rem_inp_port).unwrap();
let new_inp_node = index_map.get(rep_inp_node).unwrap();
h.connect(
Expand All @@ -150,8 +145,7 @@ impl Rewrite for SimpleReplacement {
for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out {
let (rep_out_pred_node, rep_out_pred_port) = self
.replacement
.linked_outputs(replacement_output_node, *rep_out_port)
.exactly_one()
.single_source(replacement_output_node, *rep_out_port)
.unwrap();
if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input {
let new_out_node = index_map.get(&rep_out_pred_node).unwrap();
Expand All @@ -171,11 +165,8 @@ impl Rewrite for SimpleReplacement {
let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port));
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
let (rem_inp_pred_node, rem_inp_pred_port) = h
.linked_outputs(*rem_inp_node, *rem_inp_port)
.exactly_one()
.ok() // PortLinks does not implement Debug
.unwrap();
let (rem_inp_pred_node, rem_inp_pred_port) =
h.single_source(*rem_inp_node, *rem_inp_port).unwrap();
h.disconnect(*rem_inp_node, *rem_inp_port).unwrap();
h.disconnect(*rem_out_node, *rem_out_port).unwrap();
h.connect(
Expand Down Expand Up @@ -603,7 +594,7 @@ pub(in crate::hugr::rewrite) mod test {
if *tgt == out {
unimplemented!()
};
let (src, src_port) = h.linked_outputs(*r_n, *r_p).exactly_one().ok().unwrap();
let (src, src_port) = h.single_source(*r_n, *r_p).unwrap();
NewEdgeSpec {
src,
tgt: *tgt,
Expand All @@ -618,11 +609,7 @@ pub(in crate::hugr::rewrite) mod test {
.nu_out
.iter()
.map(|((tgt, tgt_port), out_port)| {
let (src, src_port) = replacement
.linked_outputs(out, *out_port)
.exactly_one()
.ok()
.unwrap();
let (src, src_port) = replacement.single_source(out, *out_port).unwrap();
if src == in_ {
unimplemented!()
};
Expand Down
25 changes: 25 additions & 0 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,31 @@ pub trait HugrView: sealed::HugrInternals {
.flat_map(move |port| self.linked_inputs(node, port))
}

/// If there is exactly one OutgoingPort connected to this IncomingPort, return
/// it and its node.
fn single_source(
&self,
node: Node,
port: impl Into<IncomingPort>,
) -> Option<(Node, OutgoingPort)> {
self.linked_ports(node, port.into())
.exactly_one()
.ok()
.map(|(n, p)| (n, p.as_outgoing().unwrap()))
}

/// If there is exactly one IncomingPort connected to this OutgoingPort, return
/// it and its node.
fn single_target(
&self,
node: Node,
port: impl Into<OutgoingPort>,
) -> Option<(Node, IncomingPort)> {
self.linked_ports(node, port.into())
.exactly_one()
.ok()
.map(|(n, p)| (n, p.as_incoming().unwrap()))
}
/// Iterator over the nodes and output ports connected to a given *input* port.
/// Like [`linked_ports`][HugrView::linked_ports] but preserves knowledge
/// that the linked ports are [OutgoingPort]s.
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ impl SiblingSubgraph {
if !hugr.is_linked(n, p) {
return false;
}
let (out_n, _) = hugr.linked_ports(n, p).exactly_one().ok().unwrap();
let (out_n, _) = hugr.single_source(n, p).unwrap();
!nodes_set.contains(&out_n)
})
// Every incoming edge is its own input.
Expand Down Expand Up @@ -882,7 +882,7 @@ mod tests {
.collect(),
hugr.node_inputs(out)
.take(2)
.filter_map(|p| hugr.linked_outputs(out, p).exactly_one().ok())
.filter_map(|p| hugr.single_source(out, p))
.collect(),
&func,
)
Expand Down

0 comments on commit 6094663

Please sign in to comment.