diff --git a/src/builder.rs b/src/builder.rs index 52d2b7ca4..b2b8d7371 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -146,7 +146,7 @@ pub(crate) mod test { /// inference. Using DFGBuilder will default to a root node with an open /// extension variable pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr { - let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { signature: signature.clone(), })); hugr.add_op_with_parent( diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 6294df033..950a85903 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -200,7 +200,7 @@ pub trait Dataflow: Container { op: impl Into, input_wires: impl IntoIterator, ) -> Result, BuildError> { - self.add_dataflow_node(NodeType::open_extensions(op), input_wires) + self.add_dataflow_node(NodeType::new_auto(op), input_wires) } /// Add a dataflow [`NodeType`] to the sibling graph, wiring up the `input_wires` to the @@ -628,7 +628,7 @@ fn add_op_with_wires( optype: impl Into, inputs: Vec, ) -> Result<(Node, usize), BuildError> { - add_node_with_wires(data_builder, NodeType::open_extensions(optype), inputs) + add_node_with_wires(data_builder, NodeType::new_auto(optype), inputs) } fn add_node_with_wires( diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index eb168082e..809093652 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -62,7 +62,7 @@ impl CFGBuilder { signature: signature.clone(), }; - let base = Hugr::new(NodeType::open_extensions(cfg_op)); + let base = Hugr::new(NodeType::new_open(cfg_op)); let cfg_node = base.root(); CFGBuilder::create(base, cfg_node, signature.input, signature.output) } diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index da8808eea..6a46b5e55 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -176,7 +176,7 @@ impl ConditionalBuilder { extension_delta, }; // TODO: Allow input extensions to be specified - let base = Hugr::new(NodeType::open_extensions(op)); + let base = Hugr::new(NodeType::new_open(op)); let conditional_node = base.root(); Ok(ConditionalBuilder { @@ -194,7 +194,7 @@ impl CaseBuilder { let op = ops::Case { signature: signature.clone(), }; - let base = Hugr::new(NodeType::open_extensions(op)); + let base = Hugr::new(NodeType::new_open(op)); let root = base.root(); let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?; diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 9088b3c5c..03480d13b 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -79,7 +79,7 @@ impl DFGBuilder { let dfg_op = ops::DFG { signature: signature.clone(), }; - let base = Hugr::new(NodeType::open_extensions(dfg_op)); + let base = Hugr::new(NodeType::new_open(dfg_op)); let root = base.root(); DFGBuilder::create_with_io(base, root, signature, None) } diff --git a/src/builder/module.rs b/src/builder/module.rs index 83b08a32c..a78c047d7 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -90,7 +90,7 @@ impl + AsRef> ModuleBuilder { }; self.hugr_mut().replace_op( f_node, - NodeType::pure(ops::FuncDefn { + NodeType::new_pure(ops::FuncDefn { name, signature: signature.clone(), }), diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index a8a07eb4a..5eb8286e1 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -82,7 +82,7 @@ impl TailLoopBuilder { rest: inputs_outputs.into(), }; // TODO: Allow input extensions to be specified - let base = Hugr::new(NodeType::open_extensions(tail_loop.clone())); + let base = Hugr::new(NodeType::new_open(tail_loop.clone())); let root = base.root(); Self::create_with_io(base, root, &tail_loop) } diff --git a/src/extension.rs b/src/extension.rs index aec4f6020..dd7401cfe 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -4,7 +4,7 @@ //! system (outside the `types` module), which also parses nested [`OpDef`]s. use std::collections::hash_map::Entry; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; @@ -301,18 +301,13 @@ pub enum ExtensionBuildError { } /// A set of extensions identified by their unique [`ExtensionId`]. -#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct ExtensionSet(HashSet); +#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct ExtensionSet(BTreeSet); impl ExtensionSet { /// Creates a new empty extension set. - pub fn new() -> Self { - Self(HashSet::new()) - } - - /// Creates a new extension set from some extensions. - pub fn new_from_extensions(extensions: impl Into>) -> Self { - Self(extensions.into()) + pub const fn new() -> Self { + Self(BTreeSet::new()) } /// Adds a extension to the set. @@ -350,13 +345,18 @@ impl ExtensionSet { /// The things in other which are in not in self pub fn missing_from(&self, other: &Self) -> Self { - ExtensionSet(HashSet::from_iter(other.0.difference(&self.0).cloned())) + ExtensionSet::from_iter(other.0.difference(&self.0).cloned()) } /// Iterate over the contained ExtensionIds pub fn iter(&self) -> impl Iterator { self.0.iter() } + + /// True if this set contains no [ExtensionId]s + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } impl Display for ExtensionSet { @@ -367,6 +367,6 @@ impl Display for ExtensionSet { impl FromIterator for ExtensionSet { fn from_iter>(iter: I) -> Self { - Self(HashSet::from_iter(iter)) + Self(BTreeSet::from_iter(iter)) } } diff --git a/src/extension/infer.rs b/src/extension/infer.rs index c86710e76..5ebb0b8b7 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -10,7 +10,7 @@ //! depend on these open variables, then the validation check for extensions //! will succeed regardless of what the variable is instantiated to. -use super::{ExtensionId, ExtensionSet}; +use super::ExtensionSet; use crate::{ hugr::views::HugrView, ops::{OpTag, OpTrait}, @@ -65,8 +65,8 @@ impl Meta { enum Constraint { /// A variable has the same value as another variable Equal(Meta), - /// Variable extends the value of another by one extension - Plus(ExtensionId, Meta), + /// Variable extends the value of another by a set of extensions + Plus(ExtensionSet, Meta), } #[derive(Debug, Clone, PartialEq, Error)] @@ -230,26 +230,6 @@ impl UnificationContext { self.solved.get(&self.resolve(*m)) } - /// Convert an extension *set* difference in terms of a sequence of fresh - /// metas with `Plus` constraints which each add only one extension req. - fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) { - let mut last_meta = input; - // Create fresh metavariables with `Plus` constraints for - // each extension that should be added by the node - // Hence a extension delta [A, B] would lead to - // > ma = fresh_meta() - // > add_constraint(ma, Plus(a, input) - // > mb = fresh_meta() - // > add_constraint(mb, Plus(b, ma) - // > add_constraint(output, Equal(mb)) - for r in delta.0.into_iter() { - let curr_meta = self.fresh_meta(); - self.add_constraint(curr_meta, Constraint::Plus(r, last_meta)); - last_meta = curr_meta; - } - self.add_constraint(output, Constraint::Equal(last_meta)); - } - /// Return the metavariable corresponding to the given location on the /// graph, either by making a new meta, or looking it up fn make_or_get_meta(&mut self, node: Node, dir: Direction) -> Meta { @@ -311,17 +291,13 @@ impl UnificationContext { match node_type.signature() { // Input extensions are open None => { - self.gen_union_constraint( - m_input, - m_output, - node_type.op_signature().extension_reqs, - ); - if matches!( - node_type.tag(), - OpTag::Alias | OpTag::Function | OpTag::FuncDefn - ) { - self.add_solution(m_input, ExtensionSet::new()); - } + let delta = node_type.op_signature().extension_reqs; + let c = if delta.is_empty() { + Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) + }; + self.add_constraint(m_output, c); } // We have a solution for everything! Some(sig) => { @@ -510,8 +486,7 @@ impl UnificationContext { // to a set which already contained it. Constraint::Plus(r, other_meta) => { if let Some(rs) = self.get_solution(other_meta) { - let mut rrs = rs.clone(); - rrs.insert(r); + let rrs = rs.clone().union(r); match self.get_solution(&meta) { // Let's check that this is right? Some(rs) => { @@ -657,19 +632,19 @@ impl UnificationContext { // Handle the case where the constraints for `m` contain a self // reference, i.e. "m = Plus(E, m)", in which case the variable // should be instantiated to E rather than the empty set. - let solution = - ExtensionSet::from_iter(self.get_constraints(&m).unwrap().iter().filter_map( - |c| match c { - // If `m` has been merged, [`self.variables`] entry - // will have already been updated to the merged - // value by [`self.merge_equal_metas`] so we don't - // need to worry about resolving it. - Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => { - Some(x.clone()) - } - _ => None, - }, - )); + let solution = self + .get_constraints(&m) + .unwrap() + .iter() + .filter_map(|c| match c { + // If `m` has been merged, [`self.variables`] entry + // will have already been updated to the merged + // value by [`self.merge_equal_metas`] so we don't + // need to worry about resolving it. + Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x), + _ => None, + }) + .fold(ExtensionSet::new(), ExtensionSet::union); self.add_solution(m, solution); } } @@ -685,6 +660,7 @@ mod test { use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::prelude::QB_T; + use crate::extension::ExtensionId; use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; @@ -720,7 +696,7 @@ mod test { signature: main_sig, }; - let root_node = NodeType::open_extensions(op); + let root_node = NodeType::new_open(op); let mut hugr = Hugr::new(root_node); let input = ops::Input::new(type_row![NAT, NAT]); @@ -804,8 +780,14 @@ mod test { ctx.solved.insert(metas[2], ExtensionSet::singleton(&A)); ctx.add_constraint(metas[1], Constraint::Equal(metas[2])); - ctx.add_constraint(metas[0], Constraint::Plus(B, metas[2])); - ctx.add_constraint(metas[4], Constraint::Plus(C, metas[0])); + ctx.add_constraint( + metas[0], + Constraint::Plus(ExtensionSet::singleton(&B), metas[2]), + ); + ctx.add_constraint( + metas[4], + Constraint::Plus(ExtensionSet::singleton(&C), metas[0]), + ); ctx.add_constraint(metas[3], Constraint::Equal(metas[4])); ctx.add_constraint(metas[5], Constraint::Equal(metas[0])); ctx.main_loop()?; @@ -830,21 +812,21 @@ mod test { // This generates a solution that causes validation to fail // because of a missing lift node fn missing_lift_node() -> Result<(), Box> { - let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]) .with_extension_delta(&ExtensionSet::singleton(&A)), })); let input = hugr.add_node_with_parent( hugr.root(), - NodeType::pure(ops::Input { + NodeType::new_pure(ops::Input { types: type_row![NAT], }), )?; let output = hugr.add_node_with_parent( hugr.root(), - NodeType::pure(ops::Output { + NodeType::new_pure(ops::Output { types: type_row![NAT], }), )?; @@ -878,8 +860,8 @@ mod test { .insert((NodeIndex::new(4).into(), Direction::Incoming), ab); ctx.variables.insert(a); ctx.variables.insert(b); - ctx.add_constraint(ab, Constraint::Plus(A, b)); - ctx.add_constraint(ab, Constraint::Plus(B, a)); + ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b)); + ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a)); let solution = ctx.main_loop()?; // We'll only find concrete solutions for the Incoming extension reqs of // the main node created by `Hugr::default` @@ -1046,7 +1028,7 @@ mod test { extension_delta: rs.clone(), }; - let mut hugr = Hugr::new(NodeType::pure(op)); + let mut hugr = Hugr::new(NodeType::new_pure(op)); let conditional_node = hugr.root(); let case_op = ops::Case { @@ -1081,7 +1063,7 @@ mod test { fn extension_adding_sequence() -> Result<(), Box> { let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]); - let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::DFG { signature: df_sig .clone() .with_extension_delta(&ExtensionSet::from_iter([A, B])), @@ -1252,7 +1234,7 @@ mod test { let b = ExtensionSet::singleton(&B); let c = ExtensionSet::singleton(&C); - let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc), })); @@ -1350,7 +1332,7 @@ mod test { /// +--------------------+ #[test] fn multi_entry() -> Result<(), Box> { - let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions? })); let cfg = hugr.root(); @@ -1433,7 +1415,7 @@ mod test { ) -> Result> { let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext); - let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]) .with_extension_delta(&hugr_delta), })); diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 817adcec1..8fe92e4fc 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -8,8 +8,6 @@ use super::{ TypeParametrised, }; -use crate::types::SignatureDescription; - use crate::types::FunctionType; use crate::types::type_param::TypeArg; @@ -34,18 +32,6 @@ pub trait CustomSignatureFunc: Send + Sync { misc: &HashMap, extension_registry: &ExtensionRegistry, ) -> Result; - - /// Describe the signature of a node, given the operation name, - /// values for the type parameters, - /// and 'misc' data from the extension definition YAML. - fn describe_signature( - &self, - _name: &SmolStr, - _arg_values: &[TypeArg], - _misc: &HashMap, - ) -> SignatureDescription { - SignatureDescription::default() - } } // Note this is very much a utility, rather than definitive; @@ -208,16 +194,6 @@ impl OpDef { Ok(res) } - /// Optional description of the ports in the signature. - pub fn signature_desc(&self, args: &[TypeArg]) -> SignatureDescription { - match &self.signature_func { - SignatureFunc::FromDecl { .. } => { - todo!() - } - SignatureFunc::CustomFunc(bf) => bf.describe_signature(&self.name, args, &self.misc), - } - } - pub(crate) fn should_serialize_signature(&self) -> bool { match self.signature_func { SignatureFunc::CustomFunc(_) => true, diff --git a/src/hugr.rs b/src/hugr.rs index d6dcd5ec6..128eee522 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -82,7 +82,7 @@ impl NodeType { } /// Instantiate an OpType with no input extensions - pub fn pure(op: impl Into) -> Self { + pub fn new_pure(op: impl Into) -> Self { NodeType { op: op.into(), input_extensions: Some(ExtensionSet::new()), @@ -91,13 +91,24 @@ impl NodeType { /// Instantiate an OpType with an unknown set of input extensions /// (to be inferred later) - pub fn open_extensions(op: impl Into) -> Self { + pub fn new_open(op: impl Into) -> Self { NodeType { op: op.into(), input_extensions: None, } } + /// Instantiate an [OpType] with the default set of input extensions + /// for that OpType. + pub fn new_auto(op: impl Into) -> Self { + let op = op.into(); + if OpTag::ModuleOp.is_superset(op.tag()) { + Self::new_pure(op) + } else { + Self::new_open(op) + } + } + /// Use the input extensions to calculate the concrete signature of the node pub fn signature(&self) -> Option { self.input_extensions @@ -119,9 +130,7 @@ impl NodeType { pub fn input_extensions(&self) -> Option<&ExtensionSet> { self.input_extensions.as_ref() } -} -impl NodeType { /// Gets the underlying [OpType] i.e. without any [input_extensions] /// /// [input_extensions]: NodeType::input_extensions @@ -153,7 +162,7 @@ impl OpType { impl Default for Hugr { fn default() -> Self { - Self::new(NodeType::pure(crate::ops::Module)) + Self::new(NodeType::new_pure(crate::ops::Module)) } } @@ -239,7 +248,7 @@ impl Hugr { /// Add a node to the graph, with the default conversion from OpType to NodeType pub(crate) fn add_op(&mut self, op: impl Into) -> Node { - self.add_node(NodeType::open_extensions(op)) + self.add_node(NodeType::new_auto(op)) } /// Add a node to the graph. diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index fca006b5d..3e1ef81dc 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -37,7 +37,7 @@ pub trait HugrMut: HugrMutInternals { parent: Node, op: impl Into, ) -> Result { - self.add_node_with_parent(parent, NodeType::open_extensions(op)) + self.add_node_with_parent(parent, NodeType::new_auto(op)) } /// Add a node to the graph with a parent in the hierarchy. @@ -217,7 +217,7 @@ impl + AsMut> HugrMut for T { } fn add_op_before(&mut self, sibling: Node, op: impl Into) -> Result { - self.add_node_before(sibling, NodeType::open_extensions(op)) + self.add_node_before(sibling, NodeType::new_auto(op)) } fn add_node_before(&mut self, sibling: Node, nodetype: NodeType) -> Result { @@ -620,7 +620,7 @@ mod test { { let f_in = hugr - .add_node_with_parent(f, NodeType::pure(ops::Input::new(type_row![NAT]))) + .add_node_with_parent(f, NodeType::new_pure(ops::Input::new(type_row![NAT]))) .unwrap(); let f_out = hugr .add_op_with_parent(f, ops::Output::new(type_row![NAT, NAT])) diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 74e380367..edc517b0c 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -222,10 +222,7 @@ impl TryFrom for Hugr { for node_ser in nodes { hugr.add_node_with_parent( node_ser.parent, - match node_ser.input_extensions { - None => NodeType::open_extensions(node_ser.op), - Some(rs) => NodeType::new(node_ser.op, rs), - }, + NodeType::new(node_ser.op, node_ser.input_extensions), )?; } @@ -332,11 +329,11 @@ pub mod test { let mut h = Hierarchy::new(); let mut op_types = UnmanagedDenseMap::new(); - op_types[root] = NodeType::open_extensions(gen_optype(&g, root)); + op_types[root] = NodeType::new_open(gen_optype(&g, root)); for n in [a, b, c] { h.push_child(n, root).unwrap(); - op_types[n] = NodeType::pure(gen_optype(&g, n)); + op_types[n] = NodeType::new_pure(gen_optype(&g, n)); } let hg = Hugr { diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index be8d5062d..9cca30c54 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -853,7 +853,7 @@ mod test { Err(ValidationError::NoParent { node }) => assert_eq!(node, other) ); b.set_parent(other, root).unwrap(); - b.replace_op(other, NodeType::pure(declare_op)).unwrap(); + b.replace_op(other, NodeType::new_pure(declare_op)).unwrap(); b.add_ports(other, Direction::Outgoing, 1); assert_eq!(b.validate(&EMPTY_REG), Ok(())); @@ -872,7 +872,7 @@ mod test { fn leaf_root() { let leaf_op: OpType = LeafOp::Noop { ty: USIZE_T }.into(); - let b = Hugr::new(NodeType::pure(leaf_op)); + let b = Hugr::new(NodeType::new_pure(leaf_op)); assert_eq!(b.validate(&EMPTY_REG), Ok(())); } @@ -883,7 +883,7 @@ mod test { } .into(); - let mut b = Hugr::new(NodeType::pure(dfg_op)); + let mut b = Hugr::new(NodeType::new_pure(dfg_op)); let root = b.root(); add_df_children(&mut b, root, 1); assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); @@ -956,7 +956,7 @@ mod test { .unwrap(); // Replace the output operation of the df subgraph with a copy - b.replace_op(output, NodeType::pure(LeafOp::Noop { ty: NAT })) + b.replace_op(output, NodeType::new_pure(LeafOp::Noop { ty: NAT })) .unwrap(); assert_matches!( b.validate(&EMPTY_REG), @@ -964,8 +964,11 @@ mod test { ); // Revert it back to an output, but with the wrong number of ports - b.replace_op(output, NodeType::pure(ops::Output::new(type_row![BOOL_T]))) - .unwrap(); + b.replace_op( + output, + NodeType::new_pure(ops::Output::new(type_row![BOOL_T])), + ) + .unwrap(); assert_matches!( b.validate(&EMPTY_REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) @@ -973,14 +976,14 @@ mod test { ); b.replace_op( output, - NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), + NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), ) .unwrap(); // After fixing the output back, replace the copy with an output op b.replace_op( copy, - NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), + NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), ) .unwrap(); assert_matches!( @@ -1007,7 +1010,7 @@ mod test { b.validate(&EMPTY_REG).unwrap(); b.replace_op( copy, - NodeType::pure(ops::CFG { + NodeType::new_pure(ops::CFG { signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), }), ) @@ -1063,7 +1066,7 @@ mod test { // Change the types in the BasicBlock node to work on qubits instead of bits b.replace_op( block, - NodeType::pure(ops::BasicBlock::DFB { + NodeType::new_pure(ops::BasicBlock::DFB { inputs: type_row![Q], tuple_sum_rows: vec![type_row![]], other_outputs: type_row![Q], @@ -1074,11 +1077,14 @@ mod test { let mut block_children = b.hierarchy.children(block.pg_index()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); - b.replace_op(block_input, NodeType::pure(ops::Input::new(type_row![Q]))) - .unwrap(); + b.replace_op( + block_input, + NodeType::new_pure(ops::Input::new(type_row![Q])), + ) + .unwrap(); b.replace_op( block_output, - NodeType::pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), + NodeType::new_pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), ) .unwrap(); assert_matches!( @@ -1310,12 +1316,12 @@ mod test { let main_signature = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); - let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { signature: main_signature, })); let input = hugr.add_node_with_parent( hugr.root(), - NodeType::pure(ops::Input { + NodeType::new_pure(ops::Input { types: type_row![NAT], }), )?; diff --git a/src/hugr/views/root_checked.rs b/src/hugr/views/root_checked.rs index 26815e8ed..6b3f7aba3 100644 --- a/src/hugr/views/root_checked.rs +++ b/src/hugr/views/root_checked.rs @@ -79,7 +79,7 @@ mod test { #[test] fn root_checked() { - let root_type = NodeType::pure(ops::DFG { + let root_type = NodeType::new_pure(ops::DFG { signature: FunctionType::new(vec![], vec![]), }); let mut h = Hugr::new(root_type.clone()); @@ -94,7 +94,7 @@ mod test { let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); // That is a HugrMutInternal, so we can try: let root = dfg_v.root(); - let bb = NodeType::pure(BasicBlock::DFB { + let bb = NodeType::new_pure(BasicBlock::DFB { inputs: type_row![], other_outputs: type_row![], tuple_sum_rows: vec![type_row![]], @@ -129,7 +129,7 @@ mod test { let mut bb_v = RootChecked::<_, BasicBlockID>::try_new(dfp_v).unwrap(); // And it's a HugrMut: - let nodetype = NodeType::pure(LeafOp::MakeTuple { tys: type_row![] }); + let nodetype = NodeType::new_pure(LeafOp::MakeTuple { tys: type_row![] }); bb_v.add_node_with_parent(bb_v.root(), nodetype).unwrap(); } } diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 69a2da4f5..77f74f39b 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -454,7 +454,7 @@ mod test { ); let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); - let bad_nodetype = NodeType::open_extensions(crate::ops::CFG { signature }); + let bad_nodetype = NodeType::new_open(crate::ops::CFG { signature }); assert_eq!( sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), Err(HugrError::InvalidTag { @@ -471,7 +471,7 @@ mod test { #[rstest] fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { let root = simple_dfg_hugr.root(); - let case_nodetype = NodeType::open_extensions(crate::ops::Case { + let case_nodetype = NodeType::new_open(crate::ops::Case { signature: simple_dfg_hugr.root_type().op_signature(), }); let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); diff --git a/src/ops.rs b/src/ops.rs index 4926e60b6..f6ef25004 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -9,7 +9,7 @@ pub mod leaf; pub mod module; pub mod tag; pub mod validate; -use crate::types::{EdgeKind, FunctionType, SignatureDescription, Type}; +use crate::types::{EdgeKind, FunctionType, Type}; use crate::PortIndex; use crate::{Direction, Port}; @@ -189,12 +189,6 @@ pub trait OpTrait { fn signature(&self) -> FunctionType { Default::default() } - /// Optional description of the ports in the signature. - /// - /// Only dataflow operations have a non-empty signature. - fn signature_desc(&self) -> SignatureDescription { - Default::default() - } /// Get the static input type of this operation if it has one (only Some for /// [`LoadConstant`] and [`Call`]) diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 77ef60399..b1c5a39b3 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; -use crate::types::{type_param::TypeArg, FunctionType, SignatureDescription}; +use crate::types::{type_param::TypeArg, FunctionType}; use crate::{Hugr, Node}; use super::tag::OpTag; @@ -76,13 +76,6 @@ impl OpTrait for ExternalOp { } } - fn signature_desc(&self) -> SignatureDescription { - match self { - Self::Opaque(_) => SignatureDescription::default(), - Self::Extension(ExtensionOp { def, args, .. }) => def.signature_desc(args), - } - } - fn tag(&self) -> OpTag { OpTag::Leaf } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index f09cd2328..768d35fcb 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -7,7 +7,7 @@ use super::{OpName, OpTag, OpTrait, StaticTag}; use crate::{ extension::{ExtensionId, ExtensionSet}, - types::{EdgeKind, FunctionType, SignatureDescription, Type, TypeRow}, + types::{EdgeKind, FunctionType, Type, TypeRow}, }; /// Dataflow operations with no children. @@ -118,15 +118,6 @@ impl OpTrait for LeafOp { } } - /// Optional description of the ports in the signature. - fn signature_desc(&self) -> SignatureDescription { - match self { - LeafOp::CustomOp(ext) => ext.signature_desc(), - // TODO: More port descriptions - _ => Default::default(), - } - } - fn other_input(&self) -> Option { Some(EdgeKind::StateOrder) } diff --git a/src/std_extensions.rs b/src/std_extensions.rs index 2ee80378f..bc4af73cb 100644 --- a/src/std_extensions.rs +++ b/src/std_extensions.rs @@ -6,4 +6,3 @@ pub mod arithmetic; pub mod collections; pub mod logic; pub mod quantum; -pub mod rotation; diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 3918fd8ae..d07b97f62 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,7 +1,5 @@ //! Conversions between integer and floating-point values. -use std::collections::HashSet; - use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, type_row, @@ -39,10 +37,10 @@ fn itof_sig(arg_values: &[TypeArg]) -> Result { pub fn extension() -> Extension { let mut extension = Extension::new_with_reqs( EXTENSION_ID, - ExtensionSet::new_from_extensions(HashSet::from_iter(vec![ + ExtensionSet::from_iter(vec![ super::int_types::EXTENSION_ID, super::float_types::EXTENSION_ID, - ])), + ]), ); extension diff --git a/src/std_extensions/rotation.rs b/src/std_extensions/rotation.rs deleted file mode 100644 index 35b2102b9..000000000 --- a/src/std_extensions/rotation.rs +++ /dev/null @@ -1,400 +0,0 @@ -#![allow(missing_docs)] -//! This is an experiment, it is probably already outdated. - -use std::ops::{Add, Div, Mul, Neg, Sub}; - -use cgmath::num_traits::ToPrimitive; -use num_rational::Rational64; -use smol_str::SmolStr; - -#[cfg(feature = "pyo3")] -use pyo3::{pyclass, FromPyObject}; - -use crate::extension::ExtensionId; -use crate::types::type_param::TypeArg; -use crate::types::{CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow}; -use crate::values::CustomConst; -use crate::{ops, Extension}; - -pub const PI_NAME: &str = "PI"; -pub const ANGLE_T_NAME: &str = "angle"; -pub const QUAT_T_NAME: &str = "quat"; -pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("rotations"); - -pub const ANGLE_T: Type = Type::new_extension(CustomType::new_simple( - SmolStr::new_inline(ANGLE_T_NAME), - EXTENSION_ID, - TypeBound::Copyable, -)); - -pub const QUAT_T: Type = Type::new_extension(CustomType::new_simple( - SmolStr::new_inline(QUAT_T_NAME), - EXTENSION_ID, - TypeBound::Copyable, -)); -/// The extension with all the operations and types defined in this extension. -pub fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_ID); - - RotationType::Angle.add_to_extension(&mut extension); - RotationType::Quaternion.add_to_extension(&mut extension); - - extension - .add_op_custom_sig_simple( - "AngleAdd".into(), - "".into(), - vec![], - |_arg_values: &[TypeArg]| { - let t: TypeRow = - vec![Type::new_extension(RotationType::Angle.custom_type())].into(); - Ok(FunctionType::new(t.clone(), t)) - }, - ) - .unwrap(); - - let pi_val = RotationValue::PI; - - extension - .add_value(PI_NAME, ops::Const::new(pi_val.into(), ANGLE_T).unwrap()) - .unwrap(); - extension -} - -/// Custom types defined by this extension. -#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -pub enum RotationType { - Angle, - Quaternion, -} - -impl RotationType { - pub const fn name(&self) -> SmolStr { - match self { - RotationType::Angle => SmolStr::new_inline(ANGLE_T_NAME), - RotationType::Quaternion => SmolStr::new_inline(QUAT_T_NAME), - } - } - - pub const fn description(&self) -> &str { - match self { - RotationType::Angle => "Floating point angle", - RotationType::Quaternion => "Quaternion specifying rotation.", - } - } - - pub fn custom_type(self) -> CustomType { - CustomType::new(self.name(), [], EXTENSION_ID, TypeBound::Copyable) - } - - fn add_to_extension(self, extension: &mut Extension) { - extension - .add_type( - self.name(), - vec![], - self.description().to_string(), - TypeBound::Copyable.into(), - ) - .unwrap(); - } -} - -impl From for CustomType { - fn from(ty: RotationType) -> Self { - ty.custom_type() - } -} - -/// Constant values for [`RotationType`]. -#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -pub enum RotationValue { - Angle(AngleValue), - Quaternion(cgmath::Quaternion), -} - -impl RotationValue { - const PI: Self = Self::Angle(AngleValue::PI); - fn rotation_type(&self) -> RotationType { - match self { - RotationValue::Angle(_) => RotationType::Angle, - RotationValue::Quaternion(_) => RotationType::Quaternion, - } - } -} - -#[typetag::serde] -impl CustomConst for RotationValue { - fn name(&self) -> SmolStr { - match self { - RotationValue::Angle(val) => format!("AngleConstant({})", val.radians()), - RotationValue::Quaternion(val) => format!("QuatConstant({:?})", val), - } - .into() - } - - fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { - let self_typ = self.rotation_type(); - - if &self_typ.custom_type() == typ { - Ok(()) - } else { - Err(CustomCheckFailure::Message( - "Rotation constant type mismatch.".into(), - )) - } - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::values::downcast_equal_consts(self, other) - } -} - -// -// TODO: -// -// operations: -// -// AngleAdd, -// AngleMul, -// AngleNeg, -// QuatMul, -// RxF64, -// RzF64, -// TK1, -// Rotation, -// ToRotation, -// -// -// -// signatures: -// -// LeafOp::AngleAdd | LeafOp::AngleMul => FunctionType::new_linear([Type::Angle]), -// LeafOp::QuatMul => FunctionType::new_linear([Type::Quat64]), -// LeafOp::AngleNeg => FunctionType::new_linear([Type::Angle]), -// LeafOp::RxF64 | LeafOp::RzF64 => { -// FunctionType::new_df([Type::Qubit], [Type::Angle]) -// } -// LeafOp::TK1 => FunctionType::new_df(vec![Type::Qubit], vec![Type::Angle; 3]), -// LeafOp::Rotation => FunctionType::new_df([Type::Qubit], [Type::Quat64]), -// LeafOp::ToRotation => FunctionType::new_df( -// [ -// Type::Angle, -// Type::F64, -// Type::F64, -// Type::F64, -// ], -// [Type::Quat64], -// ), - -#[derive(Clone, Copy, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "pyo3", pyclass(name = "Rational"))] -pub struct Rational(pub Rational64); - -impl From for Rational { - fn from(r: Rational64) -> Self { - Self(r) - } -} - -// angle is contained value * pi in radians -#[derive(Clone, PartialEq, Debug, Copy, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "pyo3", derive(FromPyObject))] -pub enum AngleValue { - F64(f64), - Rational(Rational), -} - -impl AngleValue { - const PI: Self = AngleValue::Rational(Rational(Rational64::new_raw(1, 1))); - fn binary_op f64, G: FnOnce(Rational64, Rational64) -> Rational64>( - self, - rhs: Self, - opf: F, - opr: G, - ) -> Self { - match (self, rhs) { - (AngleValue::F64(x), AngleValue::F64(y)) => AngleValue::F64(opf(x, y)), - (AngleValue::F64(x), AngleValue::Rational(y)) - | (AngleValue::Rational(y), AngleValue::F64(x)) => { - AngleValue::F64(opf(x, y.0.to_f64().unwrap())) - } - (AngleValue::Rational(x), AngleValue::Rational(y)) => { - AngleValue::Rational(Rational(opr(x.0, y.0))) - } - } - } - - fn unary_op f64, G: FnOnce(Rational64) -> Rational64>( - self, - opf: F, - opr: G, - ) -> Self { - match self { - AngleValue::F64(x) => AngleValue::F64(opf(x)), - AngleValue::Rational(x) => AngleValue::Rational(Rational(opr(x.0))), - } - } - - pub fn to_f64(&self) -> f64 { - match self { - AngleValue::F64(x) => *x, - AngleValue::Rational(x) => x.0.to_f64().expect("Floating point conversion error."), - } - } - - pub fn radians(&self) -> f64 { - self.to_f64() * std::f64::consts::PI - } -} - -impl Add for AngleValue { - type Output = AngleValue; - - fn add(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x + y, |x, y| x + y) - } -} - -impl Sub for AngleValue { - type Output = AngleValue; - - fn sub(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x - y, |x, y| x - y) - } -} - -impl Mul for AngleValue { - type Output = AngleValue; - - fn mul(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x * y, |x, y| x * y) - } -} - -impl Div for AngleValue { - type Output = AngleValue; - - fn div(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x / y, |x, y| x / y) - } -} - -impl Neg for AngleValue { - type Output = AngleValue; - - fn neg(self) -> Self::Output { - self.unary_op(|x| -x, |x| -x) - } -} - -impl Add for &AngleValue { - type Output = AngleValue; - - fn add(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x + y, |x, y| x + y) - } -} - -impl Sub for &AngleValue { - type Output = AngleValue; - - fn sub(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x - y, |x, y| x - y) - } -} - -impl Mul for &AngleValue { - type Output = AngleValue; - - fn mul(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x * y, |x, y| x * y) - } -} - -impl Div for &AngleValue { - type Output = AngleValue; - - fn div(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x / y, |x, y| x / y) - } -} - -impl Neg for &AngleValue { - type Output = AngleValue; - - fn neg(self) -> Self::Output { - self.unary_op(|x| -x, |x| -x) - } -} - -#[cfg(test)] -mod test { - - use rstest::{fixture, rstest}; - - use super::{AngleValue, RotationValue, ANGLE_T, ANGLE_T_NAME, EXTENSION_ID, PI_NAME}; - use crate::{ - extension::ExtensionId, - extension::SignatureError, - types::{CustomType, Type, TypeBound}, - values::CustomConst, - Extension, - }; - - #[fixture] - fn extension() -> Extension { - super::extension() - } - - #[rstest] - fn test_types(extension: Extension) { - let angle = extension.get_type(ANGLE_T_NAME).unwrap(); - - let custom = angle.instantiate_concrete([]).unwrap(); - - angle.check_custom(&custom).unwrap(); - - let wrong_ext = ExtensionId::new("wrong_extensions").unwrap(); - - let false_custom = CustomType::new( - custom.name().clone(), - vec![], - wrong_ext.clone(), - TypeBound::Copyable, - ); - assert_eq!( - angle.check_custom(&false_custom), - Err(SignatureError::ExtensionMismatch(EXTENSION_ID, wrong_ext,)) - ); - - assert_eq!(Type::new_extension(custom), ANGLE_T); - } - - #[rstest] - fn test_type_check(extension: Extension) { - let custom_type = extension - .get_type(ANGLE_T_NAME) - .unwrap() - .instantiate_concrete([]) - .unwrap(); - - let custom_value = RotationValue::Angle(AngleValue::F64(0.0)); - - // correct type - custom_value.check_custom_type(&custom_type).unwrap(); - - let wrong_custom_type = extension - .get_type("quat") - .unwrap() - .instantiate_concrete([]) - .unwrap(); - let res = custom_value.check_custom_type(&wrong_custom_type); - assert!(res.is_err()); - } - - #[rstest] - fn test_constant(extension: Extension) { - let pi_val = extension.get_value(PI_NAME).unwrap(); - - ANGLE_T.check_type(pi_val.typed_value().value()).unwrap(); - } -} diff --git a/src/types.rs b/src/types.rs index cabdcf39d..8ea81504e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -10,7 +10,7 @@ pub mod type_row; pub use check::{ConstTypeError, CustomCheckFailure}; pub use custom::CustomType; -pub use signature::{FunctionType, Signature, SignatureDescription}; +pub use signature::{FunctionType, Signature}; pub use type_param::TypeArg; pub use type_row::TypeRow; diff --git a/src/types/signature.rs b/src/types/signature.rs index ce13b5e6e..b2995dec0 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -4,13 +4,11 @@ use pyo3::{pyclass, pymethods}; use delegate::delegate; -use smol_str::SmolStr; use std::fmt::{self, Display, Write}; -use std::ops::Index; use crate::extension::ExtensionSet; use crate::types::{Type, TypeRow}; -use crate::{Direction, IncomingPort, OutgoingPort, Port, PortIndex}; +use crate::{Direction, IncomingPort, OutgoingPort, Port}; #[cfg_attr(feature = "pyo3", pyclass)] #[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -249,97 +247,3 @@ impl Display for Signature { } } } - -/// Descriptive names for the ports in a [`Signature`]. -/// -/// This is a separate type from [`Signature`] as it is not normally used during the compiler operations. -#[cfg_attr(feature = "pyo3", pyclass)] -#[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct SignatureDescription { - /// Input of the function. - pub input: Vec, - /// Output of the function. - pub output: Vec, -} - -#[cfg_attr(feature = "pyo3", pymethods)] -impl SignatureDescription { - /// The number of wires in the signature. - #[inline(always)] - pub fn is_empty(&self) -> bool { - self.input.is_empty() && self.output.is_empty() - } -} - -impl SignatureDescription { - /// Create a new signature. - pub fn new(input: impl Into>, output: impl Into>) -> Self { - Self { - input: input.into(), - output: output.into(), - } - } - - /// Create a new signature with only linear inputs and outputs. - pub fn new_linear(linear: impl Into>) -> Self { - let linear = linear.into(); - SignatureDescription::new(linear.clone(), linear) - } - - pub(crate) fn row_zip<'a>( - type_row: &'a TypeRow, - name_row: &'a [SmolStr], - ) -> impl Iterator { - name_row - .iter() - .chain(&EmptyStringIterator) - .zip(type_row.iter()) - } - - /// Iterate over the input wires of the signature and their names. - /// - /// Unnamed wires are given an empty string name. - /// - /// TODO: Return Option<&String> instead of &String for the description. - pub fn input_zip<'a>( - &'a self, - signature: &'a Signature, - ) -> impl Iterator { - Self::row_zip(signature.input(), &self.input) - } - - /// Iterate over the output wires of the signature and their names. - /// - /// Unnamed wires are given an empty string name. - pub fn output_zip<'a>( - &'a self, - signature: &'a Signature, - ) -> impl Iterator { - Self::row_zip(signature.output(), &self.output) - } -} - -impl Index for SignatureDescription { - type Output = SmolStr; - - fn index(&self, index: Port) -> &Self::Output { - match index.direction() { - Direction::Incoming => self.input.get(index.index()).unwrap_or(EMPTY_STRING_REF), - Direction::Outgoing => self.output.get(index.index()).unwrap_or(EMPTY_STRING_REF), - } - } -} - -/// An iterator that always returns the an empty string. -pub(crate) struct EmptyStringIterator; - -/// A reference to an empty string. Used by [`EmptyStringIterator`]. -pub(crate) const EMPTY_STRING_REF: &SmolStr = &SmolStr::new_inline(""); - -impl<'a> Iterator for &'a EmptyStringIterator { - type Item = &'a SmolStr; - - fn next(&mut self) -> Option { - Some(EMPTY_STRING_REF) - } -}