Skip to content

Commit

Permalink
Rm extension/validate.rs (ExtensionValidator), just check child.is_su…
Browse files Browse the repository at this point in the history
…bset(parent)
  • Loading branch information
acl-cqc committed May 30, 2024
1 parent 16c72e5 commit 759af86
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 407 deletions.
1 change: 0 additions & 1 deletion hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ pub use type_def::{TypeDef, TypeDefBound};
mod const_fold;
pub mod prelude;
pub mod simple_op;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult, Folder};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

Expand Down
33 changes: 22 additions & 11 deletions hugr-core/src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ use crate::{
Direction, Node,
};

use super::validate::ExtensionError;

use petgraph::graph as pg;
use petgraph::{Directed, EdgeType, Undirected};

Expand Down Expand Up @@ -88,11 +86,24 @@ pub enum InferExtensionError {
/// The location on the hugr that's associated to the unsolved meta
location: (Node, Direction),
},
/// An extension mismatch between two nodes which are connected by an edge.
/// This should mirror (or reuse) `ValidationError`'s SrcExceedsTgtExtensions
/// and TgtExceedsSrcExtensions
#[error("Edge mismatch: {0}")]
EdgeMismatch(#[from] ExtensionError),
/// Too many extension requirements coming from src
#[error("Extensions at source node {from:?} ({from_extensions}) exceed those at target {to:?} ({to_extensions})")]
#[allow(missing_docs)]
SrcExceedsTgtExtensions {
from: Node,
from_extensions: ExtensionSet,
to: Node,
to_extensions: ExtensionSet,
},
/// Missing lift node
#[error("Extensions at target node {to:?} ({to_extensions}) exceed those at source {from:?} ({from_extensions})")]
#[allow(missing_docs)]
TgtExceedsSrcExtensions {
from: Node,
from_extensions: ExtensionSet,
to: Node,
to_extensions: ExtensionSet,
},
}

/// A graph of metavariables connected by constraints.
Expand Down Expand Up @@ -384,21 +395,21 @@ impl UnificationContext {
[(node2, rs2.clone()), (node1, rs1.clone())]
};

return InferExtensionError::EdgeMismatch(if src_rs.is_subset(&tgt_rs) {
ExtensionError::TgtExceedsSrcExtensions {
return if src_rs.is_subset(&tgt_rs) {
InferExtensionError::TgtExceedsSrcExtensions {
from: *src,
from_extensions: src_rs,
to: *tgt,
to_extensions: tgt_rs,
}
} else {
ExtensionError::SrcExceedsTgtExtensions {
InferExtensionError::SrcExceedsTgtExtensions {
from: *src,
from_extensions: src_rs,
to: *tgt,
to_extensions: tgt_rs,
}
});
};
}
}
if let (Some(loc1), Some(loc2)) = (loc1, loc2) {
Expand Down
209 changes: 0 additions & 209 deletions hugr-core/src/extension/validate.rs

This file was deleted.

65 changes: 35 additions & 30 deletions hugr-core/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@ use petgraph::visit::{Topo, Walker};
use portgraph::{LinkView, PortView};
use thiserror::Error;

use crate::extension::validate::ExtensionValidator;
use crate::extension::SignatureError;
use crate::extension::{validate::ExtensionError, ExtensionRegistry, InferExtensionError};
use crate::extension::{ExtensionRegistry, ExtensionSet, InferExtensionError, SignatureError};

use crate::ops::custom::CustomOpError;
use crate::ops::custom::{resolve_opaque_op, CustomOp};
use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::{FuncDefn, OpTag, OpTrait, OpType, ValidateOp};
use crate::ops::{FuncDefn, OpParent, OpTag, OpTrait, OpType, ValidateOp};
use crate::types::type_param::TypeParam;
use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, Hugr, Node, Port};
Expand Down Expand Up @@ -59,30 +56,33 @@ impl Hugr {
validator.validate()
}

/// Validate extensions on the input and output edges of nodes. Check that
/// the target ends of edges require the extensions from the sources, and
/// check extension deltas from parent nodes are reflected in their children.
/// Validate extensions, i.e. that extension deltas from parent nodes are reflected in their children.
pub fn validate_extensions(&self) -> Result<(), ValidationError> {
let validator = ExtensionValidator::new(self, HashMap::new());
for src_node in self.nodes() {
let node_type = self.get_nodetype(src_node);

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.get_io(src_node) {
validator.validate_io_extensions(src_node, input, output)?;
for parent in self.nodes() {
let parent_op = self.get_optype(parent);
let parent_extensions = match parent_op.inner_function_type() {
Some(FunctionType { extension_reqs, .. }) => extension_reqs,
None => {
if matches!(parent_op.tag(), OpTag::Cfg | OpTag::Conditional) {
parent_op.extension_delta()
} else {
assert!(
parent_op.tag() == OpTag::ModuleRoot
|| self.children(parent).next().is_none()
);
continue;
}
}
}

for src_port in self.node_outputs(src_node) {
for (tgt_node, tgt_port) in self.linked_inputs(src_node, src_port) {
validator.check_extensions_compatible(
&(src_node, src_port.into()),
&(tgt_node, tgt_port.into()),
)?;
};
for child in self.children(parent) {
let child_extensions = self.get_optype(child).extension_delta();
if !parent_extensions.is_superset(&child_extensions) {
return Err(ValidationError::ExtensionError {
parent,
parent_extensions,
child,
child_extensions,
});
}
}
}
Expand Down Expand Up @@ -741,9 +741,14 @@ pub enum ValidationError {
/// There are invalid inter-graph edges.
#[error(transparent)]
InterGraphEdgeError(#[from] InterGraphEdgeError),
#[error("Extensions of child node ({child}) {child_extensions} are not a subset of the parent node ({parent}): {parent_extensions}")]
/// There are errors in the extension declarations.
#[error(transparent)]
ExtensionError(#[from] ExtensionError),
ExtensionError {
parent: Node,
parent_extensions: ExtensionSet,
child: Node,
child_extensions: ExtensionSet,
},
#[error(transparent)]
CantInfer(#[from] InferExtensionError),
/// Error in a node signature
Expand Down
Loading

0 comments on commit 759af86

Please sign in to comment.