diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 6371e7772..f467b4b7f 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -24,7 +24,7 @@ use itertools::{Itertools, MapInto}; use portgraph::render::{DotFormat, MermaidFormat}; use portgraph::{multiportgraph, LinkView, PortView}; -use super::internal::HugrInternals; +use super::internal::{HugrInternals, HugrMutInternals}; use super::{ Hugr, HugrError, HugrMut, NodeMetadata, NodeMetadataMap, ValidationError, DEFAULT_OPTYPE, }; @@ -512,6 +512,7 @@ pub trait ExtractHugr: HugrView + Sized { let old_root = hugr.root(); let new_root = hugr.insert_from_view(old_root, &self).new_root; hugr.set_root(new_root); + hugr.set_num_ports(new_root, 0, 0); hugr.remove_node(old_root); hugr } diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 35a6e46c6..d2b66b250 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -201,6 +201,9 @@ where #[cfg(test)] pub(super) mod test { + use rstest::rstest; + + use crate::extension::PRELUDE_REGISTRY; use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, type_row, @@ -269,4 +272,20 @@ pub(super) mod test { Ok(()) } + + #[rstest] + fn extract_hugr() -> Result<(), Box> { + let (hugr, def, _inner) = make_module_hgr()?; + + let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; + let extracted = region.extract_hugr(); + extracted.validate(&PRELUDE_REGISTRY)?; + + let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; + + assert_eq!(region.node_count(), extracted.node_count()); + assert_eq!(region.root_type(), extracted.root_type()); + + Ok(()) + } } diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 787cad1a2..7131c2451 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -491,12 +491,13 @@ mod test { #[rstest] fn extract_hugr() -> Result<(), Box> { - let (hugr, def, _inner) = make_module_hgr()?; + let (hugr, _def, inner) = make_module_hgr()?; - let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; + let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; let extracted = region.extract_hugr(); + extracted.validate(&PRELUDE_REGISTRY)?; - let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; + let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; assert_eq!(region.node_count(), extracted.node_count()); assert_eq!(region.root_type(), extracted.root_type());