Skip to content

Commit

Permalink
fix: cfg not validating entry/exit types
Browse files Browse the repository at this point in the history
Closes #1189

Failing example generated using hugr-py
  • Loading branch information
ss2165 committed Jun 26, 2024
1 parent 70e0b87 commit 9048c8b
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 4 deletions.
29 changes: 28 additions & 1 deletion hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::fs::File;
use std::io::BufReader;

use cool_asserts::assert_matches;
use rstest::rstest;

Expand All @@ -19,7 +22,7 @@ use crate::std_extensions::logic::test::{and_op, or_op};
use crate::std_extensions::logic::{self, NotOp};
use crate::types::type_param::{TypeArg, TypeArgError};
use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow};
use crate::{const_extension_ids, type_row, Direction, IncomingPort, Node};
use crate::{const_extension_ids, test_file, type_row, Direction, IncomingPort, Node};

const NAT: Type = crate::extension::prelude::USIZE_T;

Expand Down Expand Up @@ -926,6 +929,13 @@ fn cfg_children_restrictions() {
b.remove_node(exit2);

// Change the types in the BasicBlock node to work on qubits instead of bits
b.replace_op(
cfg,
ops::CFG {
signature: FunctionType::new(type_row![QB_T], type_row![BOOL_T]),
},
)
.unwrap();
b.replace_op(
block,
ops::DataflowBlock {
Expand Down Expand Up @@ -989,6 +999,23 @@ fn cfg_connections() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn cfg_entry_io_bug() -> Result<(), Box<dyn std::error::Error>> {
// load test file where input node of entry block has types in reversed
// order compared to parent CFG node.
let mut hugr: Hugr = serde_json::from_reader(BufReader::new(
File::open(test_file!("issue-1189.json")).unwrap(),
))
.unwrap();
assert_matches!(
hugr.update_validate(&PRELUDE_REGISTRY),
Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch{..}, .. })
=> assert_eq!(parent, hugr.root())
);

Ok(())
}

#[cfg(feature = "extension_inference")]
mod extension_tests {
use self::ops::handle::{BasicBlockID, TailLoopID};
Expand Down
34 changes: 31 additions & 3 deletions hugr-core/src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use thiserror::Error;

use crate::types::TypeRow;

use super::dataflow::DataflowParent;
use super::dataflow::{DataflowOpTrait, DataflowParent};
use super::{impl_validate_op, BasicBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp};

/// A set of property flags required for an operation.
Expand Down Expand Up @@ -121,9 +121,37 @@ impl ValidateOp for super::CFG {

fn validate_op_children<'a>(
&self,
children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
mut children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
) -> Result<(), ChildrenValidationError> {
for (child, optype) in children.dropping(2) {
let (entry, entry_op) = children.next().unwrap();
let (exit, exit_op) = children.next().unwrap();
let entry_op = entry_op
.as_dataflow_block()
.expect("Child check should have already checked valid ops.");
let exit_op = exit_op
.as_exit_block()
.expect("Child check should have already checked valid ops.");

let sig = self.signature();
if entry_op.inner_signature().input() != sig.input() {
return Err(ChildrenValidationError::IOSignatureMismatch {
child: entry,
actual: entry_op.inner_signature().input().clone(),
expected: sig.input().clone(),
node_desc: "BasicBlock Input",
container_desc: "CFG",
});
}
if &exit_op.cfg_outputs != sig.output() {
return Err(ChildrenValidationError::IOSignatureMismatch {
child: exit,
actual: exit_op.cfg_outputs.clone(),
expected: sig.output().clone(),
node_desc: "BasicBlockExit Output",
container_desc: "CFG",
});
}
for (child, optype) in children {
if optype.tag() == OpTag::BasicBlockExit {
return Err(ChildrenValidationError::InternalExitChildren { child });
}
Expand Down
136 changes: 136 additions & 0 deletions resources/test/issue-1189.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
{
"version": "v1",
"nodes": [
{
"parent": 0,
"op": "CFG",
"signature": {
"t": "G",
"input": [
{
"t": "Sum",
"s": "Unit",
"size": 1
},
{
"t": "Sum",
"s": "Unit",
"size": 2
}
],
"output": [
{
"t": "Sum",
"s": "Unit",
"size": 2
}
],
"extension_reqs": []
}
},
{
"parent": 0,
"op": "DataflowBlock",
"inputs": [
{
"t": "Sum",
"s": "Unit",
"size": 2
},
{
"t": "Sum",
"s": "Unit",
"size": 1
}
],
"other_outputs": [
{
"t": "Sum",
"s": "Unit",
"size": 2
}
],
"sum_rows": [
[]
],
"extension_delta": []
},
{
"parent": 1,
"op": "Input",
"types": [
{
"t": "Sum",
"s": "Unit",
"size": 2
},
{
"t": "Sum",
"s": "Unit",
"size": 1
}
]
},
{
"parent": 1,
"op": "Output",
"types": [
{
"t": "Sum",
"s": "Unit",
"size": 1
},
{
"t": "Sum",
"s": "Unit",
"size": 2
}
]
},
{
"parent": 0,
"op": "ExitBlock",
"cfg_outputs": [
{
"t": "Sum",
"s": "Unit",
"size": 2
}
]
}
],
"edges": [
[
[
2,
1
],
[
3,
0
]
],
[
[
2,
0
],
[
3,
1
]
],
[
[
1,
0
],
[
4,
0
]
]
],
"metadata": null,
"encoder": "hugr-py v0.2.1"
}

0 comments on commit 9048c8b

Please sign in to comment.