Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Improve convexity checking and fix test #585

Merged
merged 8 commits into from
Oct 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 101 additions & 14 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,10 @@ fn get_edge_type<H: HugrView>(hugr: &H, ports: &[(Node, Port)]) -> Option<Type>

/// Whether a subgraph is valid.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Whether a subgraph is valid.
/// Whether a subgraph is valid.
///
/// Verifies that input and output ports are valid subgraph boundaries, i.e. they belong
/// to nodes within the subgraph and are linked to at least one node outside of the subgraph.
/// This does NOT check convexity proper, i.e. whether the set of nodes form a convex
/// induced graph.

///
/// Does NOT check for convexity.
/// Verifies that input and output ports are valid subgraph boundaries, i.e. they belong
/// to nodes within the subgraph and are linked to at least one node outside of the subgraph.
/// This does NOT check convexity proper, i.e. whether the set of nodes form a convex
/// induced graph.
fn validate_subgraph<H: HugrView>(
hugr: &H,
nodes: &[Node],
Expand Down Expand Up @@ -517,16 +520,34 @@ fn validate_subgraph<H: HugrView>(
}

let mut ports_inside = inputs.iter().flatten().chain(outputs).copied();
let mut ports_outside = ports_inside
.clone()
.flat_map(|(n, p)| hugr.linked_ports(n, p));
// Check incoming & outgoing ports have target resp. source inside
let nodes = nodes.iter().copied().collect::<HashSet<_>>();
if ports_inside.any(|(n, _)| !nodes.contains(&n)) {
return Err(InvalidSubgraph::InvalidBoundary);
}
// Check incoming & outgoing ports have source resp. target outside
if ports_outside.any(|(n, _)| nodes.contains(&n)) {
// Check that every inside port has at least one linked port outside.
if ports_inside.any(|(n, p)| hugr.linked_ports(n, p).all(|(n1, _)| nodes.contains(&n1))) {
return Err(InvalidSubgraph::InvalidBoundary);
}
// Check that every incoming port of a node in the subgraph whose source is not in the subgraph
// belongs to inputs.
if nodes.clone().into_iter().any(|n| {
hugr.node_inputs(n).any(|p| {
hugr.linked_ports(n, p).any(|(n1, _)| {
!nodes.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
})
})
}) {
return Err(InvalidSubgraph::NotConvex);
}
// Check that every outgoing port of a node in the subgraph whose target is not in the subgraph
// belongs to outputs.
if nodes.clone().into_iter().any(|n| {
hugr.node_outputs(n).any(|p| {
hugr.linked_ports(n, p)
.any(|(n1, _)| !nodes.contains(&n1) && !outputs.contains(&(n, p)))
})
}) {
return Err(InvalidSubgraph::NotConvex);
}

Expand Down Expand Up @@ -662,10 +683,13 @@ mod tests {
hugr::views::{HierarchyView, SiblingGraph},
hugr::HugrMut,
ops::{
handle::{FuncID, NodeHandle},
handle::{DfgID, FuncID, NodeHandle},
OpType,
},
std_extensions::{logic::test::and_op, quantum::test::cx_gate},
std_extensions::{
logic::test::{and_op, not_op},
quantum::test::cx_gate,
},
type_row,
};

Expand Down Expand Up @@ -712,6 +736,23 @@ mod tests {
Ok((hugr, func_id.node()))
}

fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
let mut mod_builder = ModuleBuilder::new();
let func =
mod_builder.declare("test", FunctionType::new_linear(type_row![BOOL_T]).pure())?;
let func_id = {
let mut dfg = mod_builder.define_declaration(&func)?;
let outs1 = dfg.add_dataflow_op(not_op(), dfg.input_wires())?;
let outs2 = dfg.add_dataflow_op(not_op(), outs1.outputs())?;
let outs3 = dfg.add_dataflow_op(not_op(), outs2.outputs())?;
dfg.finish_with_outputs(outs3.outputs())?
};
let hugr = mod_builder
.finish_prelude_hugr()
.map_err(|e| -> BuildError { e.into() })?;
Ok((hugr, func_id.node()))
}

/// A HUGR with a copy
fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
let mut mod_builder = ModuleBuilder::new();
Expand Down Expand Up @@ -854,19 +895,41 @@ mod tests {

#[test]
fn non_convex_subgraph() {
let (hugr, func_root) = build_3not_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, _out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let not1 = hugr.output_neighbours(inp).exactly_one().unwrap();
let not2 = hugr.output_neighbours(not1).exactly_one().unwrap();
let not3 = hugr.output_neighbours(not2).exactly_one().unwrap();
let not1_inp = hugr.node_inputs(not1).next().unwrap();
let not1_out = hugr.node_outputs(not1).next().unwrap();
let not3_inp = hugr.node_inputs(not3).next().unwrap();
let not3_out = hugr.node_outputs(not3).next().unwrap();
assert!(matches!(
SiblingSubgraph::try_new(
vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
vec![(not1, not1_out), (not3, not3_out)],
&func
),
Err(InvalidSubgraph::NotConvex)
));
}

#[test]
fn invalid_boundary() {
let (hugr, func_root) = build_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
let snd_cx_edge = hugr.node_inputs(out).next().unwrap();
// All graph but one edge
let cx_edges_in = hugr.node_outputs(inp);
let cx_edges_out = hugr.node_inputs(out);
// All graph but the CX
assert!(matches!(
SiblingSubgraph::try_new(
vec![vec![(out, snd_cx_edge)]],
vec![(inp, first_cx_edge)],
cx_edges_out.map(|p| vec![(out, p)]).collect(),
cx_edges_in.map(|p| (inp, p)).collect(),
&func,
),
Err(InvalidSubgraph::NotConvex)
Err(InvalidSubgraph::InvalidBoundary)
));
}

Expand Down Expand Up @@ -894,4 +957,28 @@ mod tests {

Ok(())
}

#[test]
fn edge_both_output_and_copy() {
// https://github.com/CQCL-DEV/hugr/issues/518
let one_bit = type_row![BOOL_T];
let two_bit = type_row![BOOL_T, BOOL_T];

let mut builder =
DFGBuilder::new(FunctionType::new(one_bit.clone(), two_bit.clone())).unwrap();
let inw = builder.input_wires().exactly_one().unwrap();
let outw1 = builder
.add_dataflow_op(not_op(), [inw])
.unwrap()
.out_wire(0);
let outw2 = builder
.add_dataflow_op(and_op(), [inw, outw1])
.unwrap()
.outputs();
let outw = [outw1].into_iter().chain(outw2);
let h = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap();
let view = SiblingGraph::<DfgID>::try_new(&h, h.root()).unwrap();
let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap();
assert_eq!(subg.nodes().len(), 2);
}
}