Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Put extension inference behind a feature gate #786

Merged
merged 8 commits into from
Jan 8, 2024
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Continuous integration
on:
push:
branches:
- main
- main
pull_request:
branches:
- main
Expand Down Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Check formatting
run: cargo fmt -- --check
- name: Run clippy
run: cargo clippy --all-targets -- -D warnings
run: cargo clippy --all-targets --all-features -- -D warnings
- name: Build docs
run: cargo doc --no-deps --all-features
env:
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ name = "hugr"
bench = false
path = "src/lib.rs"

[features]
extension_inference = []

[dependencies]
thiserror = "1.0.28"
portgraph = { version = "0.11.0", features = ["serde", "petgraph"] }
Expand Down
5 changes: 4 additions & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ use crate::types::type_param::{check_type_args, TypeArgError};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound};

#[allow(dead_code)]
mod infer;
pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError};
#[cfg(feature = "extension_inference")]
pub use infer::infer_extensions;
pub use infer::{ExtensionSolution, InferExtensionError};

mod op_def;
pub use op_def::{
Expand Down
14 changes: 12 additions & 2 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use std::error::Error;

use super::*;
#[cfg(feature = "extension_inference")]
croyzor marked this conversation as resolved.
Show resolved Hide resolved
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
#[cfg(feature = "extension_inference")]
use crate::hugr::validate::ValidationError;
use crate::hugr::{Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
#[cfg(feature = "extension_inference")]
use crate::ops::handle::NodeHandle;
use crate::ops::{self, dataflow::IOTrait};
use crate::ops::{LeafOp, OpType};

use crate::type_row;
Expand Down Expand Up @@ -153,6 +158,7 @@ fn plus() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// This generates a solution that causes validation to fail
// because of a missing lift node
Expand Down Expand Up @@ -214,6 +220,7 @@ fn open_variables() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
croyzor marked this conversation as resolved.
Show resolved Hide resolved
#[test]
// Infer the extensions on a child node with no inputs
fn dangling_src() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -305,6 +312,7 @@ fn create_with_io(
Ok([node, input, output])
}

#[cfg(feature = "extension_inference")]
#[test]
fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
fn build_case(
Expand Down Expand Up @@ -967,6 +975,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
Expand Down Expand Up @@ -997,6 +1006,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// Test that the difference between a FuncDefn's input and output nodes is being
// constrained to be the same as the extension delta in the FuncDefn signature.
Expand Down
23 changes: 20 additions & 3 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(not(feature = "extension_inference"))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand All @@ -23,9 +25,9 @@ use thiserror::Error;

pub use self::views::{HugrView, RootTagged};
use crate::core::NodeIndex;
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
#[cfg(feature = "extension_inference")]
use crate::extension::infer_extensions;
use crate::extension::{ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::FunctionType;
Expand Down Expand Up @@ -197,12 +199,19 @@ impl Hugr {
/// Infer extension requirements and add new information to `op_types` field
///
/// See [`infer_extensions`] for details on the "closure" value
#[cfg(feature = "extension_inference")]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
let (solution, extension_closure) = infer_extensions(self)?;
self.instantiate_extensions(solution);
Ok(extension_closure)
}
/// Do nothing - this functionality is gated by the feature "extension_inference"
#[cfg(not(feature = "extension_inference"))]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
Ok(HashMap::new())
}

#[allow(dead_code)]
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
Expand Down Expand Up @@ -345,13 +354,20 @@ pub enum HugrError {
#[cfg(test)]
mod test {
use super::{Hugr, HugrView};
#[cfg(feature = "extension_inference")]
use crate::builder::test::closed_dfg_root_hugr;
#[cfg(feature = "extension_inference")]
use crate::extension::ExtensionSet;
#[cfg(feature = "extension_inference")]
use crate::hugr::HugrMut;
#[cfg(feature = "extension_inference")]
use crate::ops;
#[cfg(feature = "extension_inference")]
use crate::type_row;
#[cfg(feature = "extension_inference")]
use crate::types::{FunctionType, Type};

#[cfg(feature = "extension_inference")]
use std::error::Error;
croyzor marked this conversation as resolved.
Show resolved Hide resolved

#[test]
Expand All @@ -371,6 +387,7 @@ mod test {
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

#[cfg(feature = "extension_inference")]
#[test]
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
const BIT: Type = crate::extension::prelude::USIZE_T;
Expand Down
12 changes: 10 additions & 2 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use petgraph::visit::{Topo, Walker};
use portgraph::{LinkView, PortView};
use thiserror::Error;

#[cfg(feature = "extension_inference")]
use crate::extension::validate::ExtensionValidator;
use crate::extension::SignatureError;
use crate::extension::{
validate::{ExtensionError, ExtensionValidator},
ExtensionRegistry, ExtensionSolution, InferExtensionError,
validate::ExtensionError, ExtensionRegistry, ExtensionSolution, InferExtensionError,
};

use crate::ops::custom::CustomOpError;
Expand All @@ -36,6 +37,7 @@ struct ValidationContext<'a, 'b> {
/// Dominator tree for each CFG region, using the container node as index.
dominators: HashMap<Node, Dominators<Node>>,
/// Context for the extension validation.
#[cfg(feature = "extension_inference")]
extension_validator: ExtensionValidator,
/// Registry of available Extensions
extension_registry: &'b ExtensionRegistry,
Expand Down Expand Up @@ -64,6 +66,9 @@ impl Hugr {

impl<'a, 'b> ValidationContext<'a, 'b> {
/// Create a new validation context.
// Allow unused "extension_closure" variable for when
// the "extension_inference" feature is disabled.
#[allow(unused_variables)]
pub fn new(
hugr: &'a Hugr,
extension_closure: ExtensionSolution,
Expand All @@ -72,6 +77,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
Self {
hugr,
dominators: HashMap::new(),
#[cfg(feature = "extension_inference")]
croyzor marked this conversation as resolved.
Show resolved Hide resolved
extension_validator: ExtensionValidator::new(hugr, extension_closure),
extension_registry,
}
Expand Down Expand Up @@ -163,6 +169,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
#[cfg(feature = "extension_inference")]
croyzor marked this conversation as resolved.
Show resolved Hide resolved
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
Expand Down Expand Up @@ -240,6 +247,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
let other_node: Node = self.hugr.graph.port_node(link).unwrap().into();
let other_offset = self.hugr.graph.port_offset(link).unwrap().into();

#[cfg(feature = "extension_inference")]
self.extension_validator
.check_extensions_compatible(&(node, port), &(other_node, other_offset))?;

Expand Down
13 changes: 12 additions & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use cool_asserts::assert_matches;

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
#[cfg(feature = "extension_inference")]
use crate::builder::ModuleBuilder;
use crate::builder::{
BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder,
ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T};
use crate::extension::{
Extension, ExtensionId, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY,
};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut, NodeType};
#[cfg(feature = "extension_inference")]
use crate::macros::const_extension_ids;
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, Const, LeafOp, OpType};
Expand All @@ -23,6 +25,7 @@ use crate::values::Value;
use crate::{type_row, Direction, IncomingPort, Node};

const NAT: Type = crate::extension::prelude::USIZE_T;
#[cfg(feature = "extension_inference")]
const Q: Type = crate::extension::prelude::QB_T;

/// Creates a hugr with a single function definition that copies a bit `copies` times.
Expand Down Expand Up @@ -71,6 +74,7 @@ fn add_df_children(b: &mut Hugr, parent: Node, copies: usize) -> (Node, Node, No
/// Intended to be used to populate a BasicBlock node in a CFG.
///
/// Returns the node indices of each of the operations.
#[cfg(feature = "extension_inference")]
croyzor marked this conversation as resolved.
Show resolved Hide resolved
fn add_block_children(
b: &mut Hugr,
parent: Node,
Expand Down Expand Up @@ -257,6 +261,7 @@ fn df_children_restrictions() {
);
}

#[cfg(feature = "extension_inference")]
#[test]
/// Validation errors in a dataflow subgraph.
fn cfg_children_restrictions() {
Expand Down Expand Up @@ -404,6 +409,7 @@ fn test_ext_edge() -> Result<(), HugrError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
croyzor marked this conversation as resolved.
Show resolved Hide resolved
const_extension_ids! {
const XA: ExtensionId = "A";
const XB: ExtensionId = "BOOL_EXT";
Expand Down Expand Up @@ -441,6 +447,7 @@ fn test_local_const() -> Result<(), HugrError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
/// A wire with no extension requirements is wired into a node which has
/// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed
Expand Down Expand Up @@ -474,6 +481,7 @@ fn missing_lift_node() -> Result<(), BuildError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
/// A wire with extension requirement `[A]` is wired into a an output with no
/// extension req. In the validation extension typechecking, we don't do any
Expand Down Expand Up @@ -505,6 +513,7 @@ fn too_many_extension() -> Result<(), BuildError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
/// A wire with extension requirements `[A]` and another with requirements
/// `[BOOL_T]` are both wired into a node which requires its inputs to have
Expand Down Expand Up @@ -558,6 +567,7 @@ fn extensions_mismatch() -> Result<(), BuildError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
fn parent_signature_mismatch() -> Result<(), BuildError> {
let rs = ExtensionSet::singleton(&XA);
Expand Down Expand Up @@ -740,6 +750,7 @@ fn invalid_types() {
);
}

#[cfg(feature = "extension_inference")]
#[test]
fn parent_io_mismatch() {
// The DFG node declares that it has an empty extension delta,
Expand Down