Skip to content

Commit

Permalink
refactor: Put extension inference behind a feature gate (#786)
Browse files Browse the repository at this point in the history
Add a feature flag "extension_inference", without which
`Hugr::infer_extensions` is replaced by a dummy method, and
`ValidationContext` doesn't use a `ExtensionValidator`, and tests that
require extension inference to be done are disabled.
Our CI runs tests and benchmarks once with default features, then again
with all features, so extension inference is still being tested by CI.
The workflow has only been altered to tell `clippy` to run with all
features as well.

Resolves #784
  • Loading branch information
croyzor authored Jan 8, 2024
1 parent b3ab15a commit 2b7de5d
Show file tree
Hide file tree
Showing 8 changed files with 429 additions and 387 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Continuous integration
on:
push:
branches:
- main
- main
pull_request:
branches:
- main
Expand Down Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Check formatting
run: cargo fmt -- --check
- name: Run clippy
run: cargo clippy --all-targets -- -D warnings
run: cargo clippy --all-targets --all-features -- -D warnings
- name: Build docs
run: cargo doc --no-deps --all-features
env:
Expand Down Expand Up @@ -102,9 +102,9 @@ jobs:
- name: Run tests with coverage instrumentation
run: |
cargo llvm-cov clean --workspace
cargo llvm-cov --doctests
cargo llvm-cov --all-features --doctests
- name: Generate coverage report
run: cargo llvm-cov report --codecov --output-path coverage.json
run: cargo llvm-cov --all-features report --codecov --output-path coverage.json
- name: Upload coverage to codecov.io
uses: codecov/codecov-action@v3
with:
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ name = "hugr"
bench = false
path = "src/lib.rs"

[features]
extension_inference = []

[dependencies]
thiserror = "1.0.28"
portgraph = { version = "0.11.0", features = ["serde", "petgraph"] }
Expand Down
5 changes: 4 additions & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ use crate::types::type_param::{check_type_args, TypeArgError};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound};

#[allow(dead_code)]
mod infer;
pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError};
#[cfg(feature = "extension_inference")]
pub use infer::infer_extensions;
pub use infer::{ExtensionSolution, InferExtensionError};

mod op_def;
pub use op_def::{
Expand Down
17 changes: 13 additions & 4 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
use std::error::Error;

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::hugr::{Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::dataflow::DataflowParent;
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
use crate::ops::{self, dataflow::IOTrait};
use crate::ops::{LeafOp, OpType};
#[cfg(feature = "extension_inference")]
use crate::{
builder::test::closed_dfg_root_hugr,
hugr::validate::ValidationError,
ops::{dataflow::DataflowParent, handle::NodeHandle},
};

use crate::type_row;
use crate::types::{FunctionType, Type, TypeRow};
Expand Down Expand Up @@ -154,6 +158,7 @@ fn plus() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// This generates a solution that causes validation to fail
// because of a missing lift node
Expand Down Expand Up @@ -215,6 +220,7 @@ fn open_variables() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// Infer the extensions on a child node with no inputs
fn dangling_src() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -306,6 +312,7 @@ fn create_with_io(
Ok([node, input, output])
}

#[cfg(feature = "extension_inference")]
#[test]
fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
fn build_case(
Expand Down Expand Up @@ -968,6 +975,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
Expand Down Expand Up @@ -998,6 +1006,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// Test that the difference between a FuncDefn's input and output nodes is being
// constrained to be the same as the extension delta in the FuncDefn signature.
Expand Down
46 changes: 25 additions & 21 deletions src/extension/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,33 +120,37 @@ impl ExtensionValidator {

/// Check that a pair of input and output nodes declare the same extensions
/// as in the signature of their parents.
#[allow(unused_variables)]
pub fn validate_io_extensions(
&self,
parent: Node,
input: Node,
output: Node,
) -> Result<(), ExtensionError> {
let parent_input_extensions = self.query_extensions(parent, Direction::Incoming)?;
let parent_output_extensions = self.query_extensions(parent, Direction::Outgoing)?;
for dir in Direction::BOTH {
let input_extensions = self.query_extensions(input, dir)?;
let output_extensions = self.query_extensions(output, dir)?;
if parent_input_extensions != input_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_input_extensions.clone(),
child: input,
child_extensions: input_extensions.clone(),
});
};
if parent_output_extensions != output_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_output_extensions.clone(),
child: output,
child_extensions: output_extensions.clone(),
});
};
#[cfg(feature = "extension_inference")]
{
let parent_input_extensions = self.query_extensions(parent, Direction::Incoming)?;
let parent_output_extensions = self.query_extensions(parent, Direction::Outgoing)?;
for dir in Direction::BOTH {
let input_extensions = self.query_extensions(input, dir)?;
let output_extensions = self.query_extensions(output, dir)?;
if parent_input_extensions != input_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_input_extensions.clone(),
child: input,
child_extensions: input_extensions.clone(),
});
};
if parent_output_extensions != output_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_output_extensions.clone(),
child: output,
child_extensions: output_extensions.clone(),
});
};
}
}
Ok(())
}
Expand Down
33 changes: 22 additions & 11 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(not(feature = "extension_inference"))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand All @@ -23,9 +25,9 @@ use thiserror::Error;

pub use self::views::{HugrView, RootTagged};
use crate::core::NodeIndex;
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
#[cfg(feature = "extension_inference")]
use crate::extension::infer_extensions;
use crate::extension::{ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::FunctionType;
Expand Down Expand Up @@ -197,12 +199,19 @@ impl Hugr {
/// Infer extension requirements and add new information to `op_types` field
///
/// See [`infer_extensions`] for details on the "closure" value
#[cfg(feature = "extension_inference")]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
let (solution, extension_closure) = infer_extensions(self)?;
self.instantiate_extensions(solution);
Ok(extension_closure)
}
/// Do nothing - this functionality is gated by the feature "extension_inference"
#[cfg(not(feature = "extension_inference"))]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
Ok(HashMap::new())
}

#[allow(dead_code)]
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
Expand Down Expand Up @@ -345,13 +354,7 @@ pub enum HugrError {
#[cfg(test)]
mod test {
use super::{Hugr, HugrView};
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
use crate::ops;
use crate::type_row;
use crate::types::{FunctionType, Type};

#[cfg(feature = "extension_inference")]
use std::error::Error;

#[test]
Expand All @@ -371,8 +374,16 @@ mod test {
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

#[cfg(feature = "extension_inference")]
#[test]
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
use crate::ops::LeafOp;
use crate::type_row;
use crate::types::{FunctionType, Type};

const BIT: Type = crate::extension::prelude::USIZE_T;
let r = ExtensionSet::singleton(&"R".try_into().unwrap());

Expand All @@ -382,7 +393,7 @@ mod test {
let [input, output] = hugr.get_io(hugr.root()).unwrap();
let lift = hugr.add_node_with_parent(
hugr.root(),
ops::LeafOp::Lift {
LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "R".try_into().unwrap(),
},
Expand Down
10 changes: 8 additions & 2 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ use petgraph::visit::{Topo, Walker};
use portgraph::{LinkView, PortView};
use thiserror::Error;

use crate::extension::validate::ExtensionValidator;
use crate::extension::SignatureError;
use crate::extension::{
validate::{ExtensionError, ExtensionValidator},
ExtensionRegistry, ExtensionSolution, InferExtensionError,
validate::ExtensionError, ExtensionRegistry, ExtensionSolution, InferExtensionError,
};

use crate::ops::custom::CustomOpError;
Expand All @@ -36,6 +36,7 @@ struct ValidationContext<'a, 'b> {
/// Dominator tree for each CFG region, using the container node as index.
dominators: HashMap<Node, Dominators<Node>>,
/// Context for the extension validation.
#[allow(dead_code)]
extension_validator: ExtensionValidator,
/// Registry of available Extensions
extension_registry: &'b ExtensionRegistry,
Expand Down Expand Up @@ -64,6 +65,9 @@ impl Hugr {

impl<'a, 'b> ValidationContext<'a, 'b> {
/// Create a new validation context.
// Allow unused "extension_closure" variable for when
// the "extension_inference" feature is disabled.
#[allow(unused_variables)]
pub fn new(
hugr: &'a Hugr,
extension_closure: ExtensionSolution,
Expand Down Expand Up @@ -163,6 +167,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
#[cfg(feature = "extension_inference")]
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
Expand Down Expand Up @@ -240,6 +245,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
let other_node: Node = self.hugr.graph.port_node(link).unwrap().into();
let other_offset = self.hugr.graph.port_offset(link).unwrap().into();

#[cfg(feature = "extension_inference")]
self.extension_validator
.check_extensions_compatible(&(node, port), &(other_node, other_offset))?;

Expand Down
Loading

0 comments on commit 2b7de5d

Please sign in to comment.