From 2b05270d2cbce436694e0a62d8ecc44cba78badd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 31 Aug 2023 17:45:51 +0100 Subject: [PATCH] fix: in IdentityInsertion add noop to correct parent (#477) --- src/hugr/rewrite/insert_identity.rs | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs index e61be1e0c..a12dd045b 100644 --- a/src/hugr/rewrite/insert_identity.rs +++ b/src/hugr/rewrite/insert_identity.rs @@ -1,7 +1,7 @@ //! Implementation of the `InsertIdentity` operation. use crate::hugr::{HugrMut, Node}; -use crate::ops::LeafOp; +use crate::ops::{LeafOp, OpTag, OpTrait}; use crate::types::EdgeKind; use crate::{Direction, Hugr, HugrView, Port}; @@ -32,6 +32,9 @@ impl IdentityInsertion { /// Error from an [`IdentityInsertion`] operation. #[derive(Debug, Clone, Error, PartialEq, Eq)] pub enum IdentityInsertionError { + /// Invalid parent node. + #[error("Parent node is invalid.")] + InvalidParentNode, /// Invalid node. #[error("Node is invalid.")] InvalidNode(), @@ -78,7 +81,15 @@ impl Rewrite for IdentityInsertion { .expect("Value kind input can only have one connection."); h.disconnect(self.post_node, self.post_port).unwrap(); - let new_node = h.add_op(LeafOp::Noop { ty }); + let parent = h + .get_parent(self.post_node) + .ok_or(IdentityInsertionError::InvalidParentNode)?; + if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) { + return Err(IdentityInsertionError::InvalidParentNode); + } + let new_node = h + .add_op_with_parent(parent, LeafOp::Noop { ty }) + .expect("Parent validity already checked."); h.connect(pre_node, pre_port.index(), new_node, 0) .expect("Should only fail if ports don't exist."); @@ -95,8 +106,10 @@ mod tests { use super::super::simple_replace::test::dfg_hugr; use super::*; use crate::{ - algorithm::nest_cfgs::test::build_conditional_in_loop_cfg, extension::prelude::QB_T, - ops::handle::NodeHandle, Hugr, + algorithm::nest_cfgs::test::build_conditional_in_loop_cfg, + extension::{prelude::QB_T, prelude_registry}, + ops::handle::NodeHandle, + Hugr, }; #[rstest] @@ -121,6 +134,8 @@ mod tests { let noop: LeafOp = h.get_optype(noop_node).clone().try_into().unwrap(); assert_eq!(noop, LeafOp::Noop { ty: QB_T }); + + h.validate(&prelude_registry()).unwrap(); } #[test]