diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 950a85903..e69125fdc 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -122,17 +122,22 @@ pub trait Container { } /// Add metadata to the container node. - fn set_metadata(&mut self, meta: NodeMetadata) { + fn set_metadata(&mut self, key: impl AsRef, meta: impl Into) { let parent = self.container_node(); // Implementor's container_node() should be a valid node - self.hugr_mut().set_metadata(parent, meta).unwrap(); + self.hugr_mut().set_metadata(parent, key, meta).unwrap(); } /// Add metadata to a child node. /// /// Returns an error if the specified `child` is not a child of this container - fn set_child_metadata(&mut self, child: Node, meta: NodeMetadata) -> Result<(), BuildError> { - self.hugr_mut().set_metadata(child, meta)?; + fn set_child_metadata( + &mut self, + child: Node, + key: impl AsRef, + meta: impl Into, + ) -> Result<(), BuildError> { + self.hugr_mut().set_metadata(child, key, meta)?; Ok(()) } } diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 03480d13b..b5f1960b8 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -398,24 +398,31 @@ pub(crate) mod test { // Create a simple DFG let mut dfg_builder = DFGBuilder::new(FunctionType::new(type_row![BIT], type_row![BIT]))?; let [i1] = dfg_builder.input_wires_arr(); - dfg_builder.set_metadata(json!(42)); + dfg_builder.set_metadata("x", 42); let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1], &EMPTY_REG)?; // Create a module, and insert the DFG into it let mut module_builder = ModuleBuilder::new(); - { + let (dfg_node, f_node) = { let mut f_build = module_builder.define_function( "main", FunctionType::new(type_row![BIT], type_row![BIT]).pure(), )?; let [i1] = f_build.input_wires_arr(); - let id = f_build.add_hugr_with_wires(dfg_hugr, [i1])?; - f_build.finish_with_outputs([id.out_wire(0)])?; - } + let dfg = f_build.add_hugr_with_wires(dfg_hugr, [i1])?; + let f = f_build.finish_with_outputs([dfg.out_wire(0)])?; + module_builder.set_child_metadata(f.node(), "x", "hi")?; + (dfg.node(), f.node()) + }; + + let hugr = module_builder.finish_hugr(&EMPTY_REG)?; + assert_eq!(hugr.node_count(), 7); - assert_eq!(module_builder.finish_hugr(&EMPTY_REG)?.node_count(), 7); + assert_eq!(hugr.get_metadata(hugr.root(), "x"), None); + assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42))); + assert_eq!(hugr.get_metadata(f_node, "x").cloned(), Some(json!("hi"))); Ok(()) } diff --git a/src/hugr.rs b/src/hugr.rs index 128eee522..39e256e70 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -53,7 +53,7 @@ pub struct Hugr { op_types: UnmanagedDenseMap, /// Node metadata - metadata: UnmanagedDenseMap, + metadata: UnmanagedDenseMap>, } #[derive(Clone, Debug, Default, PartialEq, serde::Serialize, serde::Deserialize)] @@ -178,9 +178,14 @@ impl AsMut for Hugr { } } -/// Arbitrary metadata for a node. +/// Arbitrary metadata entry for a node. +/// +/// Each entry is associated to a string key. pub type NodeMetadata = serde_json::Value; +/// The container of all the metadata entries for a node. +pub type NodeMetadataMap = serde_json::Map; + /// Public API for HUGRs. impl Hugr { /// Resolve extension ops, infer extensions used, and pass the closure into validation diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 3e1ef81dc..73ed659e6 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -14,17 +14,53 @@ use crate::{Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use self::sealed::HugrMutInternals; +use super::NodeMetadataMap; + /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrMutInternals { - /// Returns the metadata associated with a node. - fn get_metadata_mut(&mut self, node: Node) -> Result<&mut NodeMetadata, HugrError> { + /// Returns a metadata entry associated with a node. + fn get_metadata_mut( + &mut self, + node: Node, + key: impl AsRef, + ) -> Result<&mut NodeMetadata, HugrError> { self.valid_node(node)?; - Ok(self.hugr_mut().metadata.get_mut(node.pg_index())) + let node_meta = self + .hugr_mut() + .metadata + .get_mut(node.pg_index()) + .get_or_insert_with(Default::default); + Ok(node_meta + .entry(key.as_ref()) + .or_insert(serde_json::Value::Null)) + } + + /// Sets a metadata value associated with a node. + fn set_metadata( + &mut self, + node: Node, + key: impl AsRef, + metadata: impl Into, + ) -> Result<(), HugrError> { + let entry = self.get_metadata_mut(node, key)?; + *entry = metadata.into(); + Ok(()) + } + + /// Retrieve the complete metadata map for a node. + fn take_node_metadata(&mut self, node: Node) -> Option { + self.valid_node(node).ok()?; + self.hugr_mut().metadata.take(node.pg_index()) } - /// Sets the metadata associated with a node. - fn set_metadata(&mut self, node: Node, metadata: NodeMetadata) -> Result<(), HugrError> { - *self.get_metadata_mut(node)? = metadata; + /// Overwrite the complete metadata map for a node. + fn overwrite_node_metadata( + &mut self, + node: Node, + metadata: Option, + ) -> Result<(), HugrError> { + self.valid_node(node)?; + self.hugr_mut().metadata.set(node.pg_index(), metadata); Ok(()) } @@ -304,7 +340,7 @@ impl + AsMut> HugrMut for T { let optype = other.op_types.take(node); self.as_mut().op_types.set(new_node, optype); let meta = other.metadata.take(node); - self.as_mut().set_metadata(new_node.into(), meta).unwrap(); + self.as_mut().metadata.set(new_node, meta); } debug_assert_eq!( Some(&new_root.pg_index()), @@ -326,10 +362,8 @@ impl + AsMut> HugrMut for T { for (&node, &new_node) in node_map.iter() { let nodetype = other.get_nodetype(node.into()); self.as_mut().op_types.set(new_node, nodetype.clone()); - let meta = other.get_metadata(node.into()); - self.as_mut() - .set_metadata(new_node.into(), meta.clone()) - .unwrap(); + let meta = other.base_hugr().metadata.get(node); + self.as_mut().metadata.set(new_node, meta.clone()); } debug_assert_eq!( Some(&new_root.pg_index()), @@ -359,10 +393,8 @@ impl + AsMut> HugrMut for T { for (&node, &new_node) in node_map.iter() { let nodetype = other.get_nodetype(node.into()); self.as_mut().op_types.set(new_node, nodetype.clone()); - let meta = other.get_metadata(node.into()); - self.as_mut() - .set_metadata(new_node.into(), meta.clone()) - .unwrap(); + let meta = other.base_hugr().metadata.get(node); + self.as_mut().metadata.set(new_node, meta.clone()); } Ok(translate_indices(node_map)) } diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 9b661c8bd..9cb08c07f 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -7,7 +7,7 @@ use std::slice; use itertools::Itertools; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrMut, HugrView, NodeMetadata, Rewrite}; +use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::{Hugr, IncomingPort, Node}; use thiserror::Error; @@ -76,7 +76,7 @@ impl Rewrite for SimpleReplacement { unimplemented!() } - fn apply(self, h: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { + fn apply(mut self, h: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { let parent = self.subgraph.get_parent(h); // 1. Check the parent node exists and is a DataflowParent. if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) { @@ -107,8 +107,8 @@ impl Rewrite for SimpleReplacement { index_map.insert(node, new_node); // Move the metadata - let meta: &NodeMetadata = self.replacement.get_metadata(node); - h.set_metadata(new_node, meta.clone()).unwrap(); + let meta: Option = self.replacement.take_node_metadata(node); + h.overwrite_node_metadata(new_node, meta).unwrap(); } // Add edges between all newly added nodes matching those in replacement. // TODO This will probably change when implicit copies are implemented. diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index edc517b0c..8e6f08d52 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -1,7 +1,6 @@ //! Serialization definition for [`Hugr`] //! [`Hugr`]: crate::hugr::Hugr -use serde_json::json; use std::collections::HashMap; use thiserror::Error; @@ -55,6 +54,9 @@ struct SerHugrV0 { /// for each edge: (src, src_offset, tgt, tgt_offset) edges: Vec<[(Node, Option); 2]>, /// for each node: (metadata) + // + // TODO: Update to Vec>> to more closely + // match the internal representation. #[serde(default)] metadata: Vec, } @@ -143,7 +145,7 @@ impl TryFrom<&Hugr> for SerHugrV0 { } let mut nodes = vec![None; hugr.node_count()]; - let mut metadata = vec![json!(null); hugr.node_count()]; + let mut metadata = vec![serde_json::Value::Null; hugr.node_count()]; for n in hugr.nodes() { let parent = node_rekey[&hugr.get_parent(n).unwrap_or(n)]; let opt = hugr.get_nodetype(n); @@ -153,7 +155,11 @@ impl TryFrom<&Hugr> for SerHugrV0 { input_extensions: opt.input_extensions.clone(), op: opt.op.clone(), }); - metadata[new_node] = hugr.get_metadata(n).clone(); + let node_metadata = hugr.metadata.get(n.pg_index()).clone(); + metadata[new_node] = match node_metadata { + Some(m) => serde_json::Value::Object(m.clone()), + None => serde_json::Value::Null, + }; } let nodes = nodes .into_iter() @@ -227,8 +233,8 @@ impl TryFrom for Hugr { } for (node, metadata) in metadata.into_iter().enumerate() { - let node = portgraph::NodeIndex::new(node).into(); - hugr.set_metadata(node, metadata)?; + let node = portgraph::NodeIndex::new(node); + hugr.metadata[node] = metadata.as_object().cloned(); } let unwrap_offset = |node: Node, offset, dir, hugr: &Hugr| -> Result { @@ -354,7 +360,7 @@ pub mod test { fn weighted_hugr_ser() { let hugr = { let mut module_builder = ModuleBuilder::new(); - module_builder.set_metadata(json!({"name": "test"})); + module_builder.set_metadata("name", "test"); let t_row = vec![Type::new_sum(vec![NAT, QB])]; let mut f_build = module_builder @@ -375,7 +381,7 @@ pub mod test { .out_wire(0) }) .collect_vec(); - f_build.set_metadata(json!(42)); + f_build.set_metadata("val", 42); f_build.finish_with_outputs(outputs).unwrap(); module_builder.finish_prelude_hugr().unwrap() diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 5cfc5698c..eafc68986 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -22,7 +22,7 @@ use itertools::{Itertools, MapInto}; use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle}; use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; -use super::{Hugr, HugrError, NodeMetadata, NodeType, DEFAULT_NODETYPE}; +use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE}; use crate::ops::handle::NodeHandle; use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; use crate::types::{EdgeKind, FunctionType}; @@ -130,13 +130,19 @@ pub trait HugrView: sealed::HugrInternals { /// Returns the metadata associated with a node. #[inline] - fn get_metadata(&self, node: Node) -> &NodeMetadata { + fn get_metadata(&self, node: Node, key: impl AsRef) -> Option<&NodeMetadata> { match self.contains_node(node) { - true => self.base_hugr().metadata.get(node.pg_index()), - false => &NodeMetadata::Null, + true => self.get_node_metadata(node)?.get(key.as_ref()), + false => None, } } + /// Retrieve the complete metadata map for a node. + fn get_node_metadata(&self, node: Node) -> Option<&NodeMetadataMap> { + self.valid_node(node).ok()?; + self.base_hugr().metadata.get(node.pg_index()).as_ref() + } + /// Returns the number of nodes in the hugr. fn node_count(&self) -> usize;