Skip to content

Commit

Permalink
refactor: Put extension inference behind a feature gate
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Jan 5, 2024
1 parent b680662 commit 3e83207
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 11 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Continuous integration
on:
push:
branches:
- main
- main
pull_request:
branches:
- main
Expand Down Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Check formatting
run: cargo fmt -- --check
- name: Run clippy
run: cargo clippy --all-targets -- -D warnings
run: cargo clippy --all-targets --all-features -- -D warnings
- name: Build docs
run: cargo doc --no-deps --all-features
env:
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ name = "hugr"
bench = false
path = "src/lib.rs"

[features]
extension_inference = []

[dependencies]
thiserror = "1.0.28"
portgraph = { version = "0.11.0", features = ["serde", "petgraph"] }
Expand Down
5 changes: 4 additions & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ use crate::types::type_param::{check_type_args, TypeArgError};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound};

#[allow(dead_code)]
mod infer;
pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError};
#[cfg(feature = "extension_inference")]
pub use infer::infer_extensions;
pub use infer::{ExtensionSolution, InferExtensionError};

mod op_def;
pub use op_def::{
Expand Down
14 changes: 12 additions & 2 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use std::error::Error;

use super::*;
#[cfg(feature = "extension_inference")]
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
#[cfg(feature = "extension_inference")]
use crate::hugr::validate::ValidationError;
use crate::hugr::{Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
#[cfg(feature = "extension_inference")]
use crate::ops::handle::NodeHandle;
use crate::ops::{self, dataflow::IOTrait};
use crate::ops::{LeafOp, OpType};

use crate::type_row;
Expand Down Expand Up @@ -153,6 +158,7 @@ fn plus() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// This generates a solution that causes validation to fail
// because of a missing lift node
Expand Down Expand Up @@ -214,6 +220,7 @@ fn open_variables() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// Infer the extensions on a child node with no inputs
fn dangling_src() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -305,6 +312,7 @@ fn create_with_io(
Ok([node, input, output])
}

#[cfg(feature = "extension_inference")]
#[test]
fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
fn build_case(
Expand Down Expand Up @@ -967,6 +975,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
Expand Down Expand Up @@ -997,6 +1006,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// Test that the difference between a FuncDefn's input and output nodes is being
// constrained to be the same as the extension delta in the FuncDefn signature.
Expand Down
24 changes: 21 additions & 3 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(not(feature = "extension-inference"))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand All @@ -23,9 +25,9 @@ use thiserror::Error;

pub use self::views::{HugrView, RootTagged};
use crate::core::NodeIndex;
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
#[cfg(feature = "extension_inference")]
use crate::extension::infer_extensions;
use crate::extension::{ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::FunctionType;
Expand Down Expand Up @@ -197,12 +199,20 @@ impl Hugr {
/// Infer extension requirements and add new information to `op_types` field
///
/// See [`infer_extensions`] for details on the "closure" value
#[cfg(feature = "extension_inference")]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
let (solution, extension_closure) = infer_extensions(self)?;
self.instantiate_extensions(solution);
Ok(extension_closure)
}
/// Do nothing - this functionality is gated by the feature "extension_inference"
#[cfg(not(feature = "extension_inference"))]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
Ok(HashMap::new())
}

#[cfg(not(feature = "extension_inference"))]
#[allow(dead_code)]
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
Expand Down Expand Up @@ -345,13 +355,20 @@ pub enum HugrError {
#[cfg(test)]
mod test {
use super::{Hugr, HugrView};
#[cfg(feature = "extension_inference")]
use crate::builder::test::closed_dfg_root_hugr;
#[cfg(feature = "extension_inference")]
use crate::extension::ExtensionSet;
#[cfg(feature = "extension_inference")]
use crate::hugr::HugrMut;
#[cfg(feature = "extension_inference")]
use crate::ops;
#[cfg(feature = "extension_inference")]
use crate::type_row;
#[cfg(feature = "extension_inference")]
use crate::types::{FunctionType, Type};

#[cfg(feature = "extension_inference")]
use std::error::Error;

#[test]
Expand All @@ -371,6 +388,7 @@ mod test {
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

#[cfg(feature = "extension_inference")]
#[test]
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
const BIT: Type = crate::extension::prelude::USIZE_T;
Expand Down
11 changes: 9 additions & 2 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use petgraph::visit::{Topo, Walker};
use portgraph::{LinkView, PortView};
use thiserror::Error;

#[cfg(feature = "extension_inference")]
use crate::extension::ExtensionValidator;
use crate::extension::SignatureError;
use crate::extension::{
validate::{ExtensionError, ExtensionValidator},
ExtensionRegistry, ExtensionSolution, InferExtensionError,
validate::ExtensionError, ExtensionRegistry, ExtensionSolution, InferExtensionError,
};

use crate::ops::custom::CustomOpError;
Expand All @@ -36,6 +37,7 @@ struct ValidationContext<'a, 'b> {
/// Dominator tree for each CFG region, using the container node as index.
dominators: HashMap<Node, Dominators<Node>>,
/// Context for the extension validation.
#[cfg(feature = "extension_inference")]
extension_validator: ExtensionValidator,
/// Registry of available Extensions
extension_registry: &'b ExtensionRegistry,
Expand Down Expand Up @@ -64,6 +66,8 @@ impl Hugr {

impl<'a, 'b> ValidationContext<'a, 'b> {
/// Create a new validation context.
#[cfg(not(feature = "extension_inference"))]
#[allow(unused_variables)]
pub fn new(
hugr: &'a Hugr,
extension_closure: ExtensionSolution,
Expand All @@ -72,6 +76,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
Self {
hugr,
dominators: HashMap::new(),
#[cfg(feature = "extension_inference")]
extension_validator: ExtensionValidator::new(hugr, extension_closure),
extension_registry,
}
Expand Down Expand Up @@ -163,6 +168,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
#[cfg(feature = "extension_inference")]
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
Expand Down Expand Up @@ -240,6 +246,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
let other_node: Node = self.hugr.graph.port_node(link).unwrap().into();
let other_offset = self.hugr.graph.port_offset(link).unwrap().into();

#[cfg(feature = "extension_inference")]
self.extension_validator
.check_extensions_compatible(&(node, port), &(other_node, other_offset))?;

Expand Down
10 changes: 9 additions & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use cool_asserts::assert_matches;

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
#[cfg(feature = "extension_inference")]
use crate::builder::ModuleBuilder;
use crate::builder::{
BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder,
ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T};
use crate::extension::{
Extension, ExtensionId, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY,
};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut, NodeType};
#[cfg(feature = "extension_inference")]
use crate::macros::const_extension_ids;
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, Const, LeafOp, OpType};
Expand Down Expand Up @@ -404,6 +406,7 @@ fn test_ext_edge() -> Result<(), HugrError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
const_extension_ids! {
const XA: ExtensionId = "A";
const XB: ExtensionId = "BOOL_EXT";
Expand Down Expand Up @@ -441,6 +444,7 @@ fn test_local_const() -> Result<(), HugrError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
/// A wire with no extension requirements is wired into a node which has
/// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed
Expand Down Expand Up @@ -474,6 +478,7 @@ fn missing_lift_node() -> Result<(), BuildError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
/// A wire with extension requirement `[A]` is wired into a an output with no
/// extension req. In the validation extension typechecking, we don't do any
Expand Down Expand Up @@ -505,6 +510,7 @@ fn too_many_extension() -> Result<(), BuildError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
/// A wire with extension requirements `[A]` and another with requirements
/// `[BOOL_T]` are both wired into a node which requires its inputs to have
Expand Down Expand Up @@ -558,6 +564,7 @@ fn extensions_mismatch() -> Result<(), BuildError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
fn parent_signature_mismatch() -> Result<(), BuildError> {
let rs = ExtensionSet::singleton(&XA);
Expand Down Expand Up @@ -740,6 +747,7 @@ fn invalid_types() {
);
}

#[cfg(feature = "extension_inference")]
#[test]
fn parent_io_mismatch() {
// The DFG node declares that it has an empty extension delta,
Expand Down

0 comments on commit 3e83207

Please sign in to comment.