From 759af8668cc93752f70fc638b9387f0bd5e3d249 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 30 May 2024 13:09:41 +0100 Subject: [PATCH] Rm extension/validate.rs (ExtensionValidator), just check child.is_subset(parent) --- hugr-core/src/extension.rs | 1 - hugr-core/src/extension/infer.rs | 33 +++-- hugr-core/src/extension/validate.rs | 209 ---------------------------- hugr-core/src/hugr/validate.rs | 65 +++++---- hugr-core/src/hugr/validate/test.rs | 164 ++-------------------- 5 files changed, 65 insertions(+), 407 deletions(-) delete mode 100644 hugr-core/src/extension/validate.rs diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 1ef9c50cd..b6d6bab14 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -35,7 +35,6 @@ pub use type_def::{TypeDef, TypeDefBound}; mod const_fold; pub mod prelude; pub mod simple_op; -pub mod validate; pub use const_fold::{ConstFold, ConstFoldResult, Folder}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; diff --git a/hugr-core/src/extension/infer.rs b/hugr-core/src/extension/infer.rs index bf00231e7..3310daa6c 100644 --- a/hugr-core/src/extension/infer.rs +++ b/hugr-core/src/extension/infer.rs @@ -18,8 +18,6 @@ use crate::{ Direction, Node, }; -use super::validate::ExtensionError; - use petgraph::graph as pg; use petgraph::{Directed, EdgeType, Undirected}; @@ -88,11 +86,24 @@ pub enum InferExtensionError { /// The location on the hugr that's associated to the unsolved meta location: (Node, Direction), }, - /// An extension mismatch between two nodes which are connected by an edge. - /// This should mirror (or reuse) `ValidationError`'s SrcExceedsTgtExtensions - /// and TgtExceedsSrcExtensions - #[error("Edge mismatch: {0}")] - EdgeMismatch(#[from] ExtensionError), + /// Too many extension requirements coming from src + #[error("Extensions at source node {from:?} ({from_extensions}) exceed those at target {to:?} ({to_extensions})")] + #[allow(missing_docs)] + SrcExceedsTgtExtensions { + from: Node, + from_extensions: ExtensionSet, + to: Node, + to_extensions: ExtensionSet, + }, + /// Missing lift node + #[error("Extensions at target node {to:?} ({to_extensions}) exceed those at source {from:?} ({from_extensions})")] + #[allow(missing_docs)] + TgtExceedsSrcExtensions { + from: Node, + from_extensions: ExtensionSet, + to: Node, + to_extensions: ExtensionSet, + }, } /// A graph of metavariables connected by constraints. @@ -384,21 +395,21 @@ impl UnificationContext { [(node2, rs2.clone()), (node1, rs1.clone())] }; - return InferExtensionError::EdgeMismatch(if src_rs.is_subset(&tgt_rs) { - ExtensionError::TgtExceedsSrcExtensions { + return if src_rs.is_subset(&tgt_rs) { + InferExtensionError::TgtExceedsSrcExtensions { from: *src, from_extensions: src_rs, to: *tgt, to_extensions: tgt_rs, } } else { - ExtensionError::SrcExceedsTgtExtensions { + InferExtensionError::SrcExceedsTgtExtensions { from: *src, from_extensions: src_rs, to: *tgt, to_extensions: tgt_rs, } - }); + }; } } if let (Some(loc1), Some(loc2)) = (loc1, loc2) { diff --git a/hugr-core/src/extension/validate.rs b/hugr-core/src/extension/validate.rs deleted file mode 100644 index 246e4c3a3..000000000 --- a/hugr-core/src/extension/validate.rs +++ /dev/null @@ -1,209 +0,0 @@ -//! Validation routines for instantiations of a extension ops and types in a -//! Hugr. - -use std::collections::HashMap; - -use thiserror::Error; - -use super::{ExtensionSet, ExtensionSolution}; -use crate::hugr::NodeType; -use crate::{Direction, Hugr, HugrView, Node, Port}; - -/// Context for validating the extension requirements defined in a Hugr. -#[derive(Debug, Clone, Default)] -pub struct ExtensionValidator { - /// Extension requirements associated with each edge - extensions: HashMap<(Node, Direction), ExtensionSet>, -} - -impl ExtensionValidator { - /// Initialise a new extension validator, pre-computing the extension - /// requirements for each node in the Hugr. - /// - /// The `closure` argument is a set of extensions which doesn't actually - /// live on the graph, but is used to close the graph for validation - pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self { - let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new(); - for (node, incoming_sol) in closure.into_iter() { - let extension_reqs = hugr - .get_nodetype(node) - .op_signature() - .map(|s| s.extension_reqs) - .unwrap_or_default(); - - let outgoing_sol = extension_reqs.union(incoming_sol.clone()); - - extensions.insert((node, Direction::Incoming), incoming_sol); - extensions.insert((node, Direction::Outgoing), outgoing_sol); - } - - let mut validator = ExtensionValidator { extensions }; - - for node in hugr.nodes() { - validator.gather_extensions(&node, hugr.get_nodetype(node)); - } - - validator - } - - /// Use the signature supplied by a dataflow node to work out the - /// extension requirements for all of its input and output edges, then put - /// those requirements in the extension validation context. - fn gather_extensions(&mut self, node: &Node, node_type: &NodeType) { - if let Some((input_exts, output_exts)) = node_type.io_extensions() { - let prev_i = self - .extensions - .insert((*node, Direction::Incoming), input_exts.clone()); - assert!(prev_i.is_none()); - let prev_o = self - .extensions - .insert((*node, Direction::Outgoing), output_exts); - assert!(prev_o.is_none()); - } - } - - /// Get the input or output extension requirements for a particular node in the Hugr. - /// - /// # Errors - /// - /// If the node extensions are missing. - fn query_extensions( - &self, - node: Node, - dir: Direction, - ) -> Result<&ExtensionSet, ExtensionError> { - self.extensions - .get(&(node, dir)) - .ok_or(ExtensionError::MissingInputExtensions(node)) - } - - /// Check that two `PortIndex` have compatible extension requirements, - /// according to the information accumulated by `gather_extensions`. - /// - /// This extension checking assumes that free extension variables - /// (e.g. implicit lifting of `A -> B` to `[R]A -> [R]B`) - /// and adding of lift nodes - /// (i.e. those which transform an edge from `A` to `[R]A`) - /// has already been done. - pub fn check_extensions_compatible( - &self, - src: &(Node, Port), - tgt: &(Node, Port), - ) -> Result<(), ExtensionError> { - let rs_src = self.query_extensions(src.0, Direction::Outgoing)?; - let rs_tgt = self.query_extensions(tgt.0, Direction::Incoming)?; - - if rs_src == rs_tgt { - Ok(()) - } else if rs_src.is_subset(rs_tgt) { - // The extra extension requirements reside in the target node. - // If so, we can fix this mismatch with a lift node - Err(ExtensionError::TgtExceedsSrcExtensionsAtPort { - from: src.0, - from_offset: src.1, - from_extensions: rs_src.clone(), - to: tgt.0, - to_offset: tgt.1, - to_extensions: rs_tgt.clone(), - }) - } else { - Err(ExtensionError::SrcExceedsTgtExtensionsAtPort { - from: src.0, - from_offset: src.1, - from_extensions: rs_src.clone(), - to: tgt.0, - to_offset: tgt.1, - to_extensions: rs_tgt.clone(), - }) - } - } - - /// Check that a pair of input and output nodes declare the same extensions - /// as in the signature of their parents. - #[allow(unused_variables)] - pub fn validate_io_extensions( - &self, - parent: Node, - input: Node, - output: Node, - ) -> Result<(), ExtensionError> { - #[cfg(feature = "extension_inference")] - { - let parent_input_extensions = self.query_extensions(parent, Direction::Incoming)?; - let parent_output_extensions = self.query_extensions(parent, Direction::Outgoing)?; - for dir in Direction::BOTH { - let input_extensions = self.query_extensions(input, dir)?; - let output_extensions = self.query_extensions(output, dir)?; - if parent_input_extensions != input_extensions { - return Err(ExtensionError::ParentIOExtensionMismatch { - parent, - parent_extensions: parent_input_extensions.clone(), - child: input, - child_extensions: input_extensions.clone(), - }); - }; - if parent_output_extensions != output_extensions { - return Err(ExtensionError::ParentIOExtensionMismatch { - parent, - parent_extensions: parent_output_extensions.clone(), - child: output, - child_extensions: output_extensions.clone(), - }); - }; - } - } - Ok(()) - } -} - -/// Errors that can occur while validating a Hugr. -#[derive(Debug, Clone, PartialEq, Error)] -#[allow(missing_docs)] -#[non_exhaustive] -pub enum ExtensionError { - /// Missing lift node - #[error("Extensions at target node {to:?} ({to_extensions}) exceed those at source {from:?} ({from_extensions})")] - TgtExceedsSrcExtensions { - from: Node, - from_extensions: ExtensionSet, - to: Node, - to_extensions: ExtensionSet, - }, - /// A version of the above which includes port info - #[error("Extensions at target node {to:?} ({to_offset:?}) ({to_extensions}) exceed those at source {from:?} ({from_offset:?}) ({from_extensions})")] - TgtExceedsSrcExtensionsAtPort { - from: Node, - from_offset: Port, - from_extensions: ExtensionSet, - to: Node, - to_offset: Port, - to_extensions: ExtensionSet, - }, - /// Too many extension requirements coming from src - #[error("Extensions at source node {from:?} ({from_extensions}) exceed those at target {to:?} ({to_extensions})")] - SrcExceedsTgtExtensions { - from: Node, - from_extensions: ExtensionSet, - to: Node, - to_extensions: ExtensionSet, - }, - /// A version of the above which includes port info - #[error("Extensions at source node {from:?} ({from_offset:?}) ({from_extensions}) exceed those at target {to:?} ({to_offset:?}) ({to_extensions})")] - SrcExceedsTgtExtensionsAtPort { - from: Node, - from_offset: Port, - from_extensions: ExtensionSet, - to: Node, - to_offset: Port, - to_extensions: ExtensionSet, - }, - #[error("Missing input extensions for node {0:?}")] - MissingInputExtensions(Node), - #[error("Extensions of I/O node ({child:?}) {child_extensions:?} don't match those expected by parent node ({parent:?}): {parent_extensions:?}")] - ParentIOExtensionMismatch { - parent: Node, - parent_extensions: ExtensionSet, - child: Node, - child_extensions: ExtensionSet, - }, -} diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 7f9b2199f..e82287669 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,14 +9,11 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::validate::ExtensionValidator; -use crate::extension::SignatureError; -use crate::extension::{validate::ExtensionError, ExtensionRegistry, InferExtensionError}; +use crate::extension::{ExtensionRegistry, ExtensionSet, InferExtensionError, SignatureError}; -use crate::ops::custom::CustomOpError; -use crate::ops::custom::{resolve_opaque_op, CustomOp}; +use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::{FuncDefn, OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, OpParent, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, Hugr, Node, Port}; @@ -59,30 +56,33 @@ impl Hugr { validator.validate() } - /// Validate extensions on the input and output edges of nodes. Check that - /// the target ends of edges require the extensions from the sources, and - /// check extension deltas from parent nodes are reflected in their children. + /// Validate extensions, i.e. that extension deltas from parent nodes are reflected in their children. pub fn validate_extensions(&self) -> Result<(), ValidationError> { - let validator = ExtensionValidator::new(self, HashMap::new()); - for src_node in self.nodes() { - let node_type = self.get_nodetype(src_node); - - // FuncDefns have no resources since they're static nodes, but the - // functions they define can have any extension delta. - if node_type.tag() != OpTag::FuncDefn { - // If this is a container with I/O nodes, check that the extension they - // define match the extensions of the container. - if let Some([input, output]) = self.get_io(src_node) { - validator.validate_io_extensions(src_node, input, output)?; + for parent in self.nodes() { + let parent_op = self.get_optype(parent); + let parent_extensions = match parent_op.inner_function_type() { + Some(FunctionType { extension_reqs, .. }) => extension_reqs, + None => { + if matches!(parent_op.tag(), OpTag::Cfg | OpTag::Conditional) { + parent_op.extension_delta() + } else { + assert!( + parent_op.tag() == OpTag::ModuleRoot + || self.children(parent).next().is_none() + ); + continue; + } } - } - - for src_port in self.node_outputs(src_node) { - for (tgt_node, tgt_port) in self.linked_inputs(src_node, src_port) { - validator.check_extensions_compatible( - &(src_node, src_port.into()), - &(tgt_node, tgt_port.into()), - )?; + }; + for child in self.children(parent) { + let child_extensions = self.get_optype(child).extension_delta(); + if !parent_extensions.is_superset(&child_extensions) { + return Err(ValidationError::ExtensionError { + parent, + parent_extensions, + child, + child_extensions, + }); } } } @@ -741,9 +741,14 @@ pub enum ValidationError { /// There are invalid inter-graph edges. #[error(transparent)] InterGraphEdgeError(#[from] InterGraphEdgeError), + #[error("Extensions of child node ({child}) {child_extensions} are not a subset of the parent node ({parent}): {parent_extensions}")] /// There are errors in the extension declarations. - #[error(transparent)] - ExtensionError(#[from] ExtensionError), + ExtensionError { + parent: Node, + parent_extensions: ExtensionSet, + child: Node, + child_extensions: ExtensionSet, + }, #[error(transparent)] CantInfer(#[from] InferExtensionError), /// Error in a node signature diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 8e8e1f0c5..30692e5ea 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -872,7 +872,6 @@ mod extension_tests { use crate::macros::const_extension_ids; const_extension_ids! { - const XA: ExtensionId = "A"; const XB: ExtensionId = "BOOL_EXT"; } @@ -998,6 +997,7 @@ mod extension_tests { } #[test] + // ALAN TODO parametrize this by some different OpTypes e.g. DFG, Function fn parent_io_mismatch() { // The DFG node declares that it has an empty extension delta, // but it's child graph adds extension "XB", causing a mismatch. @@ -1033,162 +1033,14 @@ mod extension_tests { hugr.connect(lift, 0, output, 0); let result = hugr.validate(&PRELUDE_REGISTRY); - assert_matches!( + assert_eq!( result, - Err(ValidationError::ExtensionError( - ExtensionError::ParentIOExtensionMismatch { .. } - )) - ); - } - - #[test] - /// A wire with no extension requirements is wired into a node which has - /// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed - /// by adding a lift node, but for validation this is an error. - fn missing_lift_node() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - let mut main = module_builder.define_function( - "main", - FunctionType::new(type_row![NAT], type_row![NAT]).into(), - )?; - let [main_input] = main.input_wires_arr(); - - let f_builder = main.dfg_builder( - FunctionType::new(type_row![NAT], type_row![NAT]), - // Inner DFG has extension requirements that the wire wont satisfy - Some(ExtensionSet::from_iter([XA, XB])), - [main_input], - )?; - let f_inputs = f_builder.input_wires(); - let f_handle = f_builder.finish_with_outputs(f_inputs)?; - let [f_output] = f_handle.outputs_arr(); - main.finish_with_outputs([f_output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } - )) - ); - Ok(()) - } - - #[test] - /// A wire with extension requirement `[A]` is wired into a an output with no - /// extension req. In the validation extension typechecking, we don't do any - /// unification, so don't allow open extension variables on the function - /// signature, so this fails. - fn too_many_extension() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - - let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]).into(); - - let mut main = module_builder.define_function("main", main_sig)?; - let [main_input] = main.input_wires_arr(); - - let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(XA); - - let f_builder = main.dfg_builder(inner_sig, Some(ExtensionSet::new()), [main_input])?; - let f_inputs = f_builder.input_wires(); - let f_handle = f_builder.finish_with_outputs(f_inputs)?; - let [f_output] = f_handle.outputs_arr(); - main.finish_with_outputs([f_output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::SrcExceedsTgtExtensionsAtPort { .. } - )) - ); - Ok(()) - } - - #[test] - /// A wire with extension requirements `[A]` and another with requirements - /// `[BOOL_T]` are both wired into a node which requires its inputs to have - /// requirements `[A,BOOL_T]`. A slightly more complex test of the error from - /// `missing_lift_node`. - fn extensions_mismatch() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - - let all_rs = ExtensionSet::from_iter([XA, XB]); - - let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(all_rs.clone()) - .into(); - - let mut main = module_builder.define_function("main", main_sig)?; - - let [inp_wire] = main.input_wires_arr(); - - let [left_wire] = main - .dfg_builder( - FunctionType::new(type_row![], type_row![NAT]), - Some(XA.into()), - [], - )? - .finish_with_outputs([inp_wire])? - .outputs_arr(); - - let [right_wire] = main - .dfg_builder( - FunctionType::new(type_row![], type_row![NAT]), - Some(XB.into()), - [], - )? - .finish_with_outputs([inp_wire])? - .outputs_arr(); - - let builder = main.dfg_builder( - FunctionType::new(type_row![NAT, NAT], type_row![NAT]), - Some(all_rs), - [left_wire, right_wire], - )?; - let [left, _] = builder.input_wires_arr(); - let [output] = builder.finish_with_outputs([left])?.outputs_arr(); - - main.finish_with_outputs([output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } - )) - ); - Ok(()) - } - - #[test] - fn parent_signature_mismatch() { - let main_signature = - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(XA); - - let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { - signature: main_signature, - })); - let input = hugr.add_node_with_parent( - hugr.root(), - NodeType::new_pure(ops::Input { - types: type_row![NAT], - }), - ); - let output = hugr.add_node_with_parent( - hugr.root(), - NodeType::new( - ops::Output { - types: type_row![NAT], - }, - Some(XA.into()), - ), - ); - hugr.connect(input, 0, output, 0); - - assert_matches!( - hugr.validate(&PRELUDE_REGISTRY), - Err(ValidationError::ExtensionError( - ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } - )) + Err(ValidationError::ExtensionError { + parent: hugr.root(), + parent_extensions: ExtensionSet::new(), + child: lift, + child_extensions: XB.into() + }) ); } }