Skip to content

Commit

Permalink
feat: Default DFG builders to open extension sets (#473)
Browse files Browse the repository at this point in the history
This does NOT resolve #424 (even just for DFGs) yet, but is a first step
towards doing so.
- In inference tests, when we want concrete resources we need to avoid
using builder methods.
- We should do inference on every kind of graph (instead of assuming
some kinds will be closed).
- Default _some_ DFG builder methods to instantiating nodes with open
resources (but not enough)
  • Loading branch information
croyzor authored Aug 31, 2023
1 parent 9b66d6d commit fa55478
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 29 deletions.
3 changes: 1 addition & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,7 @@ pub trait Dataflow: Container {
signature: signature.clone(),
};
let nodetype = match &input_extensions {
// TODO: Make this NodeType::open_extensions
None => NodeType::pure(op),
None => NodeType::open_extensions(op),
Some(rs) => NodeType::new(op, rs.clone()),
};
let (dfg_n, _) = add_node_with_wires(self, nodetype, input_wires.into_iter().collect())?;
Expand Down
4 changes: 2 additions & 2 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ impl CFGBuilder<Hugr> {

impl HugrBuilder for CFGBuilder<Hugr> {
fn finish_hugr(
self,
mut self,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, crate::hugr::ValidationError> {
self.base.validate(extension_registry)?;
self.base.infer_and_validate(extension_registry)?;
Ok(self.base)
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ impl HugrBuilder for ConditionalBuilder<Hugr> {
mut self,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, crate::hugr::ValidationError> {
self.base.infer_extensions()?;
self.base.validate(extension_registry)?;
self.base.infer_and_validate(extension_registry)?;
Ok(self.base)
}
}
Expand Down
9 changes: 3 additions & 6 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,15 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
base.as_mut().add_node_with_parent(
parent,
match &input_extensions {
// TODO: Make this NodeType::open_extensions
None => NodeType::pure(input),
None => NodeType::open_extensions(input),
Some(rs) => NodeType::new(input, rs.clone()),
},
)?;
base.as_mut().add_node_with_parent(
parent,
match input_extensions.map(|inp| inp.union(&signature.extension_reqs)) {
// TODO: Make this NodeType::open_extensions
None => NodeType::new(output, signature.extension_reqs),
None => NodeType::open_extensions(output),
Some(rs) => NodeType::new(output, rs),
},
)?;
Expand Down Expand Up @@ -100,9 +99,7 @@ impl HugrBuilder for DFGBuilder<Hugr> {
mut self,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, ValidationError> {
let closure = self.base.infer_extensions()?;
self.base
.validate_with_extension_closure(closure, extension_registry)?;
self.base.infer_and_validate(extension_registry)?;
Ok(self.base)
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ impl Default for ModuleBuilder<Hugr> {
}

impl HugrBuilder for ModuleBuilder<Hugr> {
fn finish_hugr(self, extension_registry: &ExtensionRegistry) -> Result<Hugr, ValidationError> {
self.0.validate(extension_registry)?;
fn finish_hugr(
mut self,
extension_registry: &ExtensionRegistry,
) -> Result<Hugr, ValidationError> {
self.0.infer_and_validate(extension_registry)?;
Ok(self.0)
}
}
Expand Down
14 changes: 13 additions & 1 deletion src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use thiserror::Error;
use pyo3::prelude::*;

pub use self::views::HugrView;
use crate::extension::{infer_extensions, ExtensionSet, ExtensionSolution, InferExtensionError};
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
use crate::ops::{OpTag, OpTrait, OpType};
use crate::types::{FunctionType, Signature};

Expand Down Expand Up @@ -196,6 +198,16 @@ impl Hugr {
rw.apply(self)
}

/// Run resource inference and pass the closure into validation
pub fn infer_and_validate(
&mut self,
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
let closure = self.infer_extensions()?;
self.validate_with_extension_closure(closure, extension_registry)?;
Ok(())
}

/// Infer extension requirements and add new information to `op_types` field
///
/// See [`infer_extensions`] for details on the "closure" value
Expand Down
42 changes: 28 additions & 14 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,10 +695,7 @@ mod test {
use cool_asserts::assert_matches;

use super::*;
use crate::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder,
ModuleBuilder,
};
use crate::builder::{BuildError, Container, Dataflow, DataflowSubContainer, ModuleBuilder};
use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T};
use crate::extension::{prelude_registry, Extension, ExtensionSet, TypeDefBound, EMPTY_REG};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
Expand Down Expand Up @@ -1134,7 +1131,7 @@ mod test {
let f_handle = f_builder.finish_with_outputs(f_inputs)?;
let [f_output] = f_handle.outputs_arr();
main.finish_with_outputs([f_output])?;
let handle = module_builder.finish_prelude_hugr();
let handle = module_builder.hugr().validate(&prelude_registry());

assert_matches!(
handle,
Expand Down Expand Up @@ -1171,7 +1168,7 @@ mod test {
let f_handle = f_builder.finish_with_outputs(f_inputs)?;
let [f_output] = f_handle.outputs_arr();
main.finish_with_outputs([f_output])?;
let handle = module_builder.finish_prelude_hugr();
let handle = module_builder.hugr().validate(&prelude_registry());
assert_matches!(
handle,
Err(ValidationError::ExtensionError(
Expand Down Expand Up @@ -1233,7 +1230,7 @@ mod test {
let [output] = builder.finish_with_outputs([])?.outputs_arr();

main.finish_with_outputs([output])?;
let handle = module_builder.finish_prelude_hugr();
let handle = module_builder.hugr().validate(&prelude_registry());
assert_matches!(
handle,
Err(ValidationError::ExtensionError(
Expand All @@ -1245,16 +1242,33 @@ mod test {

#[test]
fn parent_signature_mismatch() -> Result<(), BuildError> {
let main_signature = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"R".into()));
let rs = ExtensionSet::singleton(&"R".into());

let mut builder = DFGBuilder::new(main_signature)?;
let [w] = builder.input_wires_arr();
builder.set_outputs([w])?;
let hugr = builder.base.validate(&prelude_registry()); // finish_hugr_with_outputs([w]);
let main_signature =
FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);

let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
signature: main_signature,
}));
let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Input {
types: type_row![NAT],
}),
)?;
let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::new(
ops::Output {
types: type_row![NAT],
},
rs,
),
)?;
hugr.connect(input, 0, output, 0)?;

assert_matches!(
hugr,
hugr.validate(&prelude_registry()),
Err(ValidationError::ExtensionError(
ExtensionError::TgtExceedsSrcExtensionsAtPort { .. }
))
Expand Down

0 comments on commit fa55478

Please sign in to comment.