Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into fix/infer_stack_overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Nov 7, 2023
2 parents 7780c5c + 57270d5 commit 97b265f
Show file tree
Hide file tree
Showing 24 changed files with 119 additions and 670 deletions.
2 changes: 1 addition & 1 deletion src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ pub(crate) mod test {
/// inference. Using DFGBuilder will default to a root node with an open
/// extension variable
pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: signature.clone(),
}));
hugr.add_op_with_parent(
Expand Down
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub trait Dataflow: Container {
op: impl Into<OpType>,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
self.add_dataflow_node(NodeType::open_extensions(op), input_wires)
self.add_dataflow_node(NodeType::new_auto(op), input_wires)
}

/// Add a dataflow [`NodeType`] to the sibling graph, wiring up the `input_wires` to the
Expand Down Expand Up @@ -628,7 +628,7 @@ fn add_op_with_wires<T: Dataflow + ?Sized>(
optype: impl Into<OpType>,
inputs: Vec<Wire>,
) -> Result<(Node, usize), BuildError> {
add_node_with_wires(data_builder, NodeType::open_extensions(optype), inputs)
add_node_with_wires(data_builder, NodeType::new_auto(optype), inputs)
}

fn add_node_with_wires<T: Dataflow + ?Sized>(
Expand Down
2 changes: 1 addition & 1 deletion src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl CFGBuilder<Hugr> {
signature: signature.clone(),
};

let base = Hugr::new(NodeType::open_extensions(cfg_op));
let base = Hugr::new(NodeType::new_open(cfg_op));
let cfg_node = base.root();
CFGBuilder::create(base, cfg_node, signature.input, signature.output)
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl ConditionalBuilder<Hugr> {
extension_delta,
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(op));
let base = Hugr::new(NodeType::new_open(op));
let conditional_node = base.root();

Ok(ConditionalBuilder {
Expand All @@ -194,7 +194,7 @@ impl CaseBuilder<Hugr> {
let op = ops::Case {
signature: signature.clone(),
};
let base = Hugr::new(NodeType::open_extensions(op));
let base = Hugr::new(NodeType::new_open(op));
let root = base.root();
let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?;

Expand Down
2 changes: 1 addition & 1 deletion src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl DFGBuilder<Hugr> {
let dfg_op = ops::DFG {
signature: signature.clone(),
};
let base = Hugr::new(NodeType::open_extensions(dfg_op));
let base = Hugr::new(NodeType::new_open(dfg_op));
let root = base.root();
DFGBuilder::create_with_io(base, root, signature, None)
}
Expand Down
2 changes: 1 addition & 1 deletion src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
};
self.hugr_mut().replace_op(
f_node,
NodeType::pure(ops::FuncDefn {
NodeType::new_pure(ops::FuncDefn {
name,
signature: signature.clone(),
}),
Expand Down
2 changes: 1 addition & 1 deletion src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl TailLoopBuilder<Hugr> {
rest: inputs_outputs.into(),
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(tail_loop.clone()));
let base = Hugr::new(NodeType::new_open(tail_loop.clone()));
let root = base.root();
Self::create_with_io(base, root, &tail_loop)
}
Expand Down
24 changes: 12 additions & 12 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! system (outside the `types` module), which also parses nested [`OpDef`]s.
use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

Expand Down Expand Up @@ -301,18 +301,13 @@ pub enum ExtensionBuildError {
}

/// A set of extensions identified by their unique [`ExtensionId`].
#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(HashSet<ExtensionId>);
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

impl ExtensionSet {
/// Creates a new empty extension set.
pub fn new() -> Self {
Self(HashSet::new())
}

/// Creates a new extension set from some extensions.
pub fn new_from_extensions(extensions: impl Into<HashSet<ExtensionId>>) -> Self {
Self(extensions.into())
pub const fn new() -> Self {
Self(BTreeSet::new())
}

/// Adds a extension to the set.
Expand Down Expand Up @@ -350,13 +345,18 @@ impl ExtensionSet {

/// The things in other which are in not in self
pub fn missing_from(&self, other: &Self) -> Self {
ExtensionSet(HashSet::from_iter(other.0.difference(&self.0).cloned()))
ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
}

/// Iterate over the contained ExtensionIds
pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
self.0.iter()
}

/// True if this set contains no [ExtensionId]s
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

impl Display for ExtensionSet {
Expand All @@ -367,6 +367,6 @@ impl Display for ExtensionSet {

impl FromIterator<ExtensionId> for ExtensionSet {
fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
Self(HashSet::from_iter(iter))
Self(BTreeSet::from_iter(iter))
}
}
106 changes: 44 additions & 62 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
//! depend on these open variables, then the validation check for extensions
//! will succeed regardless of what the variable is instantiated to.
use super::{ExtensionId, ExtensionSet};
use super::ExtensionSet;
use crate::{
hugr::views::HugrView,
ops::{OpTag, OpTrait},
Expand Down Expand Up @@ -65,8 +65,8 @@ impl Meta {
enum Constraint {
/// A variable has the same value as another variable
Equal(Meta),
/// Variable extends the value of another by one extension
Plus(ExtensionId, Meta),
/// Variable extends the value of another by a set of extensions
Plus(ExtensionSet, Meta),
}

#[derive(Debug, Clone, PartialEq, Error)]
Expand Down Expand Up @@ -230,26 +230,6 @@ impl UnificationContext {
self.solved.get(&self.resolve(*m))
}

/// Convert an extension *set* difference in terms of a sequence of fresh
/// metas with `Plus` constraints which each add only one extension req.
fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) {
let mut last_meta = input;
// Create fresh metavariables with `Plus` constraints for
// each extension that should be added by the node
// Hence a extension delta [A, B] would lead to
// > ma = fresh_meta()
// > add_constraint(ma, Plus(a, input)
// > mb = fresh_meta()
// > add_constraint(mb, Plus(b, ma)
// > add_constraint(output, Equal(mb))
for r in delta.0.into_iter() {
let curr_meta = self.fresh_meta();
self.add_constraint(curr_meta, Constraint::Plus(r, last_meta));
last_meta = curr_meta;
}
self.add_constraint(output, Constraint::Equal(last_meta));
}

/// Return the metavariable corresponding to the given location on the
/// graph, either by making a new meta, or looking it up
fn make_or_get_meta(&mut self, node: Node, dir: Direction) -> Meta {
Expand Down Expand Up @@ -311,17 +291,13 @@ impl UnificationContext {
match node_type.signature() {
// Input extensions are open
None => {
self.gen_union_constraint(
m_input,
m_output,
node_type.op_signature().extension_reqs,
);
if matches!(
node_type.tag(),
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
) {
self.add_solution(m_input, ExtensionSet::new());
}
let delta = node_type.op_signature().extension_reqs;
let c = if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
};
self.add_constraint(m_output, c);
}
// We have a solution for everything!
Some(sig) => {
Expand Down Expand Up @@ -510,8 +486,7 @@ impl UnificationContext {
// to a set which already contained it.
Constraint::Plus(r, other_meta) => {
if let Some(rs) = self.get_solution(other_meta) {
let mut rrs = rs.clone();
rrs.insert(r);
let rrs = rs.clone().union(r);
match self.get_solution(&meta) {
// Let's check that this is right?
Some(rs) => {
Expand Down Expand Up @@ -657,19 +632,19 @@ impl UnificationContext {
// Handle the case where the constraints for `m` contain a self
// reference, i.e. "m = Plus(E, m)", in which case the variable
// should be instantiated to E rather than the empty set.
let solution =
ExtensionSet::from_iter(self.get_constraints(&m).unwrap().iter().filter_map(
|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => {
Some(x.clone())
}
_ => None,
},
));
let solution = self
.get_constraints(&m)
.unwrap()
.iter()
.filter_map(|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x),
_ => None,
})
.fold(ExtensionSet::new(), ExtensionSet::union);
self.add_solution(m, solution);
}
}
Expand All @@ -685,6 +660,7 @@ mod test {
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
Expand Down Expand Up @@ -720,7 +696,7 @@ mod test {
signature: main_sig,
};

let root_node = NodeType::open_extensions(op);
let root_node = NodeType::new_open(op);
let mut hugr = Hugr::new(root_node);

let input = ops::Input::new(type_row![NAT, NAT]);
Expand Down Expand Up @@ -804,8 +780,14 @@ mod test {

ctx.solved.insert(metas[2], ExtensionSet::singleton(&A));
ctx.add_constraint(metas[1], Constraint::Equal(metas[2]));
ctx.add_constraint(metas[0], Constraint::Plus(B, metas[2]));
ctx.add_constraint(metas[4], Constraint::Plus(C, metas[0]));
ctx.add_constraint(
metas[0],
Constraint::Plus(ExtensionSet::singleton(&B), metas[2]),
);
ctx.add_constraint(
metas[4],
Constraint::Plus(ExtensionSet::singleton(&C), metas[0]),
);
ctx.add_constraint(metas[3], Constraint::Equal(metas[4]));
ctx.add_constraint(metas[5], Constraint::Equal(metas[0]));
ctx.main_loop()?;
Expand All @@ -830,21 +812,21 @@ mod test {
// This generates a solution that causes validation to fail
// because of a missing lift node
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A)),
}));

let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Input {
NodeType::new_pure(ops::Input {
types: type_row![NAT],
}),
)?;

let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Output {
NodeType::new_pure(ops::Output {
types: type_row![NAT],
}),
)?;
Expand Down Expand Up @@ -878,8 +860,8 @@ mod test {
.insert((NodeIndex::new(4).into(), Direction::Incoming), ab);
ctx.variables.insert(a);
ctx.variables.insert(b);
ctx.add_constraint(ab, Constraint::Plus(A, b));
ctx.add_constraint(ab, Constraint::Plus(B, a));
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b));
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a));
let solution = ctx.main_loop()?;
// We'll only find concrete solutions for the Incoming extension reqs of
// the main node created by `Hugr::default`
Expand Down Expand Up @@ -1046,7 +1028,7 @@ mod test {
extension_delta: rs.clone(),
};

let mut hugr = Hugr::new(NodeType::pure(op));
let mut hugr = Hugr::new(NodeType::new_pure(op));
let conditional_node = hugr.root();

let case_op = ops::Case {
Expand Down Expand Up @@ -1081,7 +1063,7 @@ mod test {
fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::DFG {
signature: df_sig
.clone()
.with_extension_delta(&ExtensionSet::from_iter([A, B])),
Expand Down Expand Up @@ -1252,7 +1234,7 @@ mod test {
let b = ExtensionSet::singleton(&B);
let c = ExtensionSet::singleton(&C);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc),
}));

Expand Down Expand Up @@ -1350,7 +1332,7 @@ mod test {
/// +--------------------+
#[test]
fn multi_entry() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions?
}));
let cfg = hugr.root();
Expand Down Expand Up @@ -1433,7 +1415,7 @@ mod test {
) -> Result<Hugr, Box<dyn Error>> {
let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext);

let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&hugr_delta),
}));
Expand Down
Loading

0 comments on commit 97b265f

Please sign in to comment.