From cd5e544a740db9c7e1a2abfb3fda476d000f128d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Sat, 11 Nov 2023 16:08:31 +0000 Subject: [PATCH 01/31] refactor!: move non_df_count to OpTrait Closes #521 --- src/ops.rs | 11 +++++++++-- src/ops/controlflow.rs | 11 ++++++++++- src/ops/validate.rs | 26 +------------------------- 3 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index f6ef25004..e7162f65a 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -99,7 +99,7 @@ impl OpType { /// /// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports. pub fn other_port_index(&self, dir: Direction) -> Option { - let non_df_count = self.validity_flags().non_df_port_count(dir).unwrap_or(1); + let non_df_count = self.non_df_port_count(dir).unwrap_or(1); if self.other_port(dir).is_some() && non_df_count == 1 { // if there is a static input it comes before the non_df_ports let static_input = @@ -119,7 +119,6 @@ impl OpType { let signature = self.signature(); let has_other_ports = self.other_port(dir).is_some(); let non_df_count = self - .validity_flags() .non_df_port_count(dir) .unwrap_or(has_other_ports as usize); // if there is a static input it comes before the non_df_ports @@ -214,6 +213,14 @@ pub trait OpTrait { fn other_output(&self) -> Option { None } + + /// Get the number of non-dataflow multiports. + /// + /// If None, the operation must have exactly one non-dataflow port + /// if the operation type has other_edges, or zero otherwise. + fn non_df_port_count(&self, _dir: Direction) -> Option { + None + } } #[enum_dispatch] diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 54adcf803..b6836ea70 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -3,8 +3,8 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; -use crate::type_row; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; +use crate::{type_row, Direction}; use super::dataflow::DataflowOpTrait; use super::OpTag; @@ -177,6 +177,15 @@ impl OpTrait for BasicBlock { BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), } } + + fn non_df_port_count(&self, dir: Direction) -> Option { + match self { + Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => { + Some(tuple_sum_rows.len()) + } + _ => None, + } + } } impl BasicBlock { diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 6c6ff65cd..546d6cec6 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -11,7 +11,6 @@ use portgraph::{NodeIndex, PortOffset}; use thiserror::Error; use crate::types::{Type, TypeRow}; -use crate::Direction; use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp}; @@ -32,30 +31,12 @@ pub struct OpValidityFlags { pub requires_children: bool, /// Whether the children must form a DAG (no cycles). pub requires_dag: bool, - /// A strict requirement on the number of non-dataflow multiports. - /// - /// If not specified, the operation must have exactly one non-dataflow port - /// if the operation type has other_edges, or zero otherwise. - pub non_df_ports: (Option, Option), /// A validation check for edges between children /// // Enclosed in an `Option` to avoid iterating over the edges if not needed. pub edge_check: Option Result<(), EdgeValidationError>>, } -impl OpValidityFlags { - /// Get the number of non-dataflow multiports. - /// - /// If None, the operation must have exactly one non-dataflow port - /// if the operation type has other_edges, or zero otherwise. - pub fn non_df_port_count(&self, dir: Direction) -> Option { - match dir { - Direction::Incoming => self.non_df_ports.0, - Direction::Outgoing => self.non_df_ports.1, - } - } -} - impl Default for OpValidityFlags { fn default() -> Self { // Defaults to flags valid for non-container operations @@ -65,7 +46,6 @@ impl Default for OpValidityFlags { allowed_second_child: OpTag::Any, requires_children: false, requires_dag: false, - non_df_ports: (None, None), edge_check: None, } } @@ -316,16 +296,12 @@ impl ValidateOp for BasicBlock { /// Returns the set of allowed parent operation types. fn validity_flags(&self) -> OpValidityFlags { match self { - BasicBlock::DFB { - tuple_sum_rows: tuple_sum_variants, - .. - } => OpValidityFlags { + BasicBlock::DFB { .. } => OpValidityFlags { allowed_children: OpTag::DataflowChild, allowed_first_child: OpTag::Input, allowed_second_child: OpTag::Output, requires_children: true, requires_dag: true, - non_df_ports: (None, Some(tuple_sum_variants.len())), ..Default::default() }, // Default flags are valid for non-container operations From 2b1bd372be9d1cbf2fc404691e69d0f591cb9f94 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Sat, 11 Nov 2023 16:27:14 +0000 Subject: [PATCH 02/31] fix: builder `call` port calculation incorrect --- src/builder/build_traits.rs | 2 +- src/builder/module.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index dfdc86488..47d6ec63f 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -612,7 +612,7 @@ pub trait Dataflow: Container { }) } }; - let const_in_port = signature.output.len(); + let const_in_port = signature.input.len(); let op_id = self.add_dataflow_op(ops::Call { signature }, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; diff --git a/src/builder/module.rs b/src/builder/module.rs index a78c047d7..fb3e2a1e0 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -233,14 +233,14 @@ mod test { let mut f_build = module_builder.define_function( "main", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(), )?; let local_build = f_build.define_function( "local", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(), )?; let [wire] = local_build.input_wires_arr(); - let f_id = local_build.finish_with_outputs([wire])?; + let f_id = local_build.finish_with_outputs([wire, wire])?; let call = f_build.call(f_id.handle(), f_build.input_wires())?; From b959e16c376dde6bea346da39995c6b0e68a8088 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Sat, 11 Nov 2023 16:57:06 +0000 Subject: [PATCH 03/31] refactor!: remove `static_input` Since it only applies to LoadConst and Call, use special case methods, and a tag to capture these ops. --- src/ops.rs | 25 +++++++++++++------------ src/ops/dataflow.rs | 24 ++++++++++-------------- src/ops/tag.rs | 8 ++++++-- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index e7162f65a..e33359063 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -86,9 +86,9 @@ impl OpType { signature.get(port).cloned().map(EdgeKind::Value) } else if port.index() == port_count && dir == Direction::Incoming - && self.static_input().is_some() + && OpTag::StaticInput.is_superset(self.tag()) { - self.static_input().map(EdgeKind::Static) + Some(EdgeKind::Static(static_in_type(self))) } else { self.other_port(dir) } @@ -103,7 +103,7 @@ impl OpType { if self.other_port(dir).is_some() && non_df_count == 1 { // if there is a static input it comes before the non_df_ports let static_input = - (dir == Direction::Incoming && self.static_input().is_some()) as usize; + (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; Some(Port::new( dir, @@ -122,7 +122,8 @@ impl OpType { .non_df_port_count(dir) .unwrap_or(has_other_ports as usize); // if there is a static input it comes before the non_df_ports - let static_input = (dir == Direction::Incoming && self.static_input().is_some()) as usize; + let static_input = + (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; signature.port_count(dir) + non_df_count + static_input } @@ -142,6 +143,14 @@ impl OpType { } } +fn static_in_type(op: &OpType) -> Type { + match op { + OpType::Call(call) => Type::new_function(call.called_function_type().clone()), + OpType::LoadConstant(load) => load.constant_type().clone(), + _ => panic!("this function should not be called if the optype is not known to be Call or LoadConst.") + } +} + /// Macro used by operations that want their /// name to be the same as their type name macro_rules! impl_op_name { @@ -188,14 +197,6 @@ pub trait OpTrait { fn signature(&self) -> FunctionType { Default::default() } - - /// Get the static input type of this operation if it has one (only Some for - /// [`LoadConstant`] and [`Call`]) - #[inline] - fn static_input(&self) -> Option { - None - } - /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 5b63b706f..4ab06d5cf 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -11,10 +11,6 @@ pub(super) trait DataflowOpTrait { fn description(&self) -> &str; fn signature(&self) -> FunctionType; - /// Get the static input type of this operation if it has one. - fn static_input(&self) -> Option { - None - } /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// @@ -125,10 +121,6 @@ impl OpTrait for T { fn other_output(&self) -> Option { DataflowOpTrait::other_output(self) } - - fn static_input(&self) -> Option { - DataflowOpTrait::static_input(self) - } } impl StaticTag for T { const TAG: OpTag = T::TAG; @@ -156,10 +148,12 @@ impl DataflowOpTrait for Call { fn signature(&self) -> FunctionType { self.signature.clone() } - +} +impl Call { #[inline] - fn static_input(&self) -> Option { - Some(Type::new_function(self.signature.clone())) + /// Return the signature of the function called by this op. + pub fn called_function_type(&self) -> &FunctionType { + &self.signature } } @@ -204,10 +198,12 @@ impl DataflowOpTrait for LoadConstant { fn signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), vec![self.datatype.clone()]) } - +} +impl LoadConstant { #[inline] - fn static_input(&self) -> Option { - Some(self.datatype.clone()) + /// The type of the constant loaded by this op. + pub fn constant_type(&self) -> &Type { + &self.datatype } } diff --git a/src/ops/tag.rs b/src/ops/tag.rs index e45014977..049f383e1 100644 --- a/src/ops/tag.rs +++ b/src/ops/tag.rs @@ -46,6 +46,8 @@ pub enum OpTag { Input, /// A dataflow output. Output, + /// Dataflow node that has a static input + StaticInput, /// A function call. FnCall, /// A constant load operation. @@ -121,8 +123,9 @@ impl OpTag { ], OpTag::TailLoop => &[OpTag::DataflowChild, OpTag::DataflowParent], OpTag::Conditional => &[OpTag::DataflowChild], - OpTag::FnCall => &[OpTag::DataflowChild], - OpTag::LoadConst => &[OpTag::DataflowChild], + OpTag::StaticInput => &[OpTag::DataflowChild], + OpTag::FnCall => &[OpTag::StaticInput], + OpTag::LoadConst => &[OpTag::StaticInput], OpTag::Leaf => &[OpTag::DataflowChild], OpTag::DataflowParent => &[OpTag::Any], } @@ -150,6 +153,7 @@ impl OpTag { OpTag::Cfg => "Nested control-flow operation", OpTag::TailLoop => "Tail-recursive loop", OpTag::Conditional => "Conditional operation", + OpTag::StaticInput => "Dataflow child with static input (LoadConst or FnCall)", OpTag::FnCall => "Function call", OpTag::LoadConst => "Constant load operation", OpTag::Leaf => "Leaf operation", From 29350eec745ead41c6a3b1c86cfaa94862beeeac Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Sat, 11 Nov 2023 17:23:15 +0000 Subject: [PATCH 04/31] refactor!: clearer OpType port/kind method names And addds `static_input_port` method. Closes Clean up OpType methods #495 --- src/hugr/hugrmut.rs | 4 +-- src/hugr/serialize.rs | 2 +- src/hugr/views/sibling_subgraph.rs | 4 +-- src/ops.rs | 43 ++++++++++++++++++++++-------- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 00bef6d18..f549929d0 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -314,12 +314,12 @@ impl + AsMut> HugrMut for T { ) -> Result<(OutgoingPort, IncomingPort), HugrError> { let src_port = self .get_optype(src) - .other_port_index(Direction::Outgoing) + .other_output_port() .expect("Source operation has no non-dataflow outgoing edges") .as_outgoing()?; let dst_port = self .get_optype(dst) - .other_port_index(Direction::Incoming) + .other_input_port() .expect("Destination operation has no non-dataflow incoming edges") .as_incoming()?; self.connect(src, src_port, dst, dst_port)?; diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index a96d1abca..abc985616 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -246,7 +246,7 @@ impl TryFrom for Hugr { None => { let op_type = hugr.get_optype(node); op_type - .other_port_index(dir) + .other_port(dir) .ok_or(HUGRSerializationError::MissingPortOffset { node, op_type: op_type.clone(), diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 515392615..55d1426b5 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -596,13 +596,13 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort /// Whether a port is linked to a state order edge. fn is_order_edge(hugr: &H, node: Node, port: Port) -> bool { let op = hugr.get_optype(node); - op.other_port_index(port.direction()) == Some(port) && hugr.is_linked(node, port) + op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port) } /// Whether node has a non-df linked port in the given direction. fn has_other_edge(hugr: &H, node: Node, dir: Direction) -> bool { let op = hugr.get_optype(node); - op.other_port(dir).is_some() && hugr.is_linked(node, op.other_port_index(dir).unwrap()) + op.other_port_kind(dir).is_some() && hugr.is_linked(node, op.other_port(dir).unwrap()) } /// Errors that can occur while constructing a [`SimpleReplacement`]. diff --git a/src/ops.rs b/src/ops.rs index e33359063..fa16e1803 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -63,12 +63,12 @@ impl Default for OpType { } impl OpType { - /// The edge kind for the non-dataflow or constant-input ports of the + /// The edge kind for the non-dataflow or constant ports of the /// operation, not described by the signature. /// /// If not None, a single extra multiport of that kind will be present on /// the given direction. - pub fn other_port(&self, dir: Direction) -> Option { + pub fn other_port_kind(&self, dir: Direction) -> Option { match dir { Direction::Incoming => self.other_input(), Direction::Outgoing => self.other_output(), @@ -84,23 +84,20 @@ impl OpType { let port_count = signature.port_count(dir); if port.index() < port_count { signature.get(port).cloned().map(EdgeKind::Value) - } else if port.index() == port_count - && dir == Direction::Incoming - && OpTag::StaticInput.is_superset(self.tag()) - { + } else if Some(port) == self.static_input_port() { Some(EdgeKind::Static(static_in_type(self))) } else { - self.other_port(dir) + self.other_port_kind(dir) } } /// The non-dataflow port for the operation, not described by the signature. - /// See `[OpType::other_port]`. + /// See `[OpType::other_port_kind]`. /// /// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports. - pub fn other_port_index(&self, dir: Direction) -> Option { + pub fn other_port(&self, dir: Direction) -> Option { let non_df_count = self.non_df_port_count(dir).unwrap_or(1); - if self.other_port(dir).is_some() && non_df_count == 1 { + if self.other_port_kind(dir).is_some() && non_df_count == 1 { // if there is a static input it comes before the non_df_ports let static_input = (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; @@ -114,10 +111,34 @@ impl OpType { } } + /// The non-dataflow input port for the operation, not described by the signature. + /// See `[OpType::other_port]`. + pub fn other_input_port(&self) -> Option { + self.other_port(Direction::Incoming) + } + + /// The non-dataflow input port for the operation, not described by the signature. + /// See `[OpType::other_port]`. + pub fn other_output_port(&self) -> Option { + self.other_port(Direction::Outgoing) + } + + /// If the op has a static input (Call and LoadConstant), the port of that input. + pub fn static_input_port(&self) -> Option { + match self { + OpType::Call(call) => Some(Port::new( + Direction::Incoming, + call.called_function_type().input_count(), + )), + OpType::LoadConstant(_) => Some(Port::new(Direction::Incoming, 0)), + _ => None, + } + } + /// Returns the number of ports for the given direction. pub fn port_count(&self, dir: Direction) -> usize { let signature = self.signature(); - let has_other_ports = self.other_port(dir).is_some(); + let has_other_ports = self.other_port_kind(dir).is_some(); let non_df_count = self .non_df_port_count(dir) .unwrap_or(has_other_ports as usize); From 7477cea2ef6106b56b228b1bbc78cf41a1485674 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Sat, 11 Nov 2023 18:28:43 +0000 Subject: [PATCH 05/31] feat: iterator over all ports connected to node Uses return position impl because it will be stable soon and its much nicer. Closes #655 --- Cargo.toml | 1 + src/hugr/views.rs | 14 ++++++++++++ src/hugr/views/tests.rs | 49 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 53a0f81f2..f445091bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" delegate = "0.10.0" +rustversion = "1.0.14" [features] pyo3 = ["dep:pyo3"] diff --git a/src/hugr/views.rs b/src/hugr/views.rs index eafc68986..ead394c0b 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -179,6 +179,20 @@ pub trait HugrView: sealed::HugrInternals { /// Iterator over the nodes and ports connected to a port. fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all the nodes and ports connected to a node's inputs.. + fn all_linked_outputs(&self, node: Node) -> impl Iterator { + self.node_inputs(node) + .flat_map(move |port| self.linked_outputs(node, port)) + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all the nodes and ports connected to a node's outputs.. + fn all_linked_inputs(&self, node: Node) -> impl Iterator { + self.node_outputs(node) + .flat_map(move |port| self.linked_inputs(node, port)) + } + /// Iterator over the nodes and output ports connected to a given *input* port. /// Like [`linked_ports`][HugrView::linked_ports] but preserves knowledge /// that the linked ports are [OutgoingPort]s. diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 4e592a2de..816a0c41b 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -68,3 +68,52 @@ fn dot_string(sample_hugr: (Hugr, BuildHandle, BuildHandle, BuildHandle)) { + use crate::hugr::Direction; + use crate::Port; + use itertools::Itertools; + let (h, n1, n2) = sample_hugr; + + let all_output_ports = h.all_linked_outputs(n2.node()).collect_vec(); + + assert_eq!( + &all_output_ports[..], + &[ + ( + n1.node(), + Port::new(Direction::Outgoing, 1).as_outgoing().unwrap() + ), + ( + n1.node(), + Port::new(Direction::Outgoing, 0).as_outgoing().unwrap() + ), + ( + n1.node(), + Port::new(Direction::Outgoing, 2).as_outgoing().unwrap() + ), + ] + ); + + let all_linked_inputs = h.all_linked_inputs(n1.node()).collect_vec(); + + assert_eq!( + &all_linked_inputs[..], + &[ + ( + n2.node(), + Port::new(Direction::Incoming, 1).as_incoming().unwrap() + ), + ( + n2.node(), + Port::new(Direction::Incoming, 0).as_incoming().unwrap() + ), + ( + n2.node(), + Port::new(Direction::Incoming, 2).as_incoming().unwrap() + ), + ] + ); +} From e6cc8e1e7ddaa8fe51ec6e8b744b2967ed363cbc Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Sat, 11 Nov 2023 22:24:09 +0000 Subject: [PATCH 06/31] feat: `static_source` for getting static source node --- src/builder/build_traits.rs | 5 +++-- src/hugr/validate/test.rs | 2 ++ src/hugr/views.rs | 7 +++++++ src/ops.rs | 23 +++++++++++++++-------- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 47d6ec63f..2cea5568a 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -612,8 +612,9 @@ pub trait Dataflow: Container { }) } }; - let const_in_port = signature.input.len(); - let op_id = self.add_dataflow_op(ops::Call { signature }, input_wires)?; + let op: OpType = ops::Call { signature }.into(); + let const_in_port = op.static_input_port().unwrap(); + let op_id = self.add_dataflow_op(op, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; self.hugr_mut() diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 2cf5d1812..5d855ebdf 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -428,8 +428,10 @@ fn test_local_const() -> Result<(), HugrError> { // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op)?; let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: BOOL_T })?; + h.connect(cst, 0, lcst, 0)?; h.connect(lcst, 0, and, 1)?; + assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: h.update_validate(&EMPTY_REG).unwrap(); Ok(()) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index ead394c0b..114c41859 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -348,6 +348,13 @@ pub trait HugrView: sealed::HugrInternals { }) .finish() } + + /// If a node has a static input, return the source node. + fn static_source(&self, node: Node) -> Option { + self.linked_outputs(node, self.get_optype(node).static_input_port()?) + .next() + .map(|(n, _)| n) + } } /// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s diff --git a/src/ops.rs b/src/ops.rs index fa16e1803..89fa8e44b 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -10,8 +10,8 @@ pub mod module; pub mod tag; pub mod validate; use crate::types::{EdgeKind, FunctionType, Type}; -use crate::PortIndex; use crate::{Direction, Port}; +use crate::{IncomingPort, PortIndex}; use portgraph::NodeIndex; use smol_str::SmolStr; @@ -82,9 +82,10 @@ impl OpType { let dir = port.direction(); let port_count = signature.port_count(dir); + let port_as_in = port.as_incoming().ok(); if port.index() < port_count { signature.get(port).cloned().map(EdgeKind::Value) - } else if Some(port) == self.static_input_port() { + } else if port_as_in.is_some() && port_as_in == self.static_input_port() { Some(EdgeKind::Static(static_in_type(self))) } else { self.other_port_kind(dir) @@ -124,13 +125,19 @@ impl OpType { } /// If the op has a static input (Call and LoadConstant), the port of that input. - pub fn static_input_port(&self) -> Option { + pub fn static_input_port(&self) -> Option { match self { - OpType::Call(call) => Some(Port::new( - Direction::Incoming, - call.called_function_type().input_count(), - )), - OpType::LoadConstant(_) => Some(Port::new(Direction::Incoming, 0)), + OpType::Call(call) => Some( + Port::new( + Direction::Incoming, + call.called_function_type().input_count(), + ) + .as_incoming() + .unwrap(), + ), + OpType::LoadConstant(_) => { + Some(Port::new(Direction::Incoming, 0).as_incoming().unwrap()) + } _ => None, } } From 11ef7baf8f2c27ff1c9a9154f132510a315956fe Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 09:37:57 +0000 Subject: [PATCH 07/31] refactor!: rename `FunctionType::get` -> `FunctionType::port_type` --- src/hugr/rewrite/simple_replace.rs | 11 ++++++++--- src/hugr/views/sibling_subgraph.rs | 16 ++++++++-------- src/ops.rs | 2 +- src/types/signature.rs | 10 +++++----- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 4a376604d..44eb2c1ae 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -505,7 +505,7 @@ pub(in crate::hugr::rewrite) mod test { .collect_vec(); let inputs = h .node_outputs(input) - .filter(|&p| h.get_optype(input).signature().get(p).is_some()) + .filter(|&p| h.get_optype(input).signature().port_type(p).is_some()) .map(|p| { let link = h.linked_inputs(input, p).next().unwrap(); (link, link) @@ -513,7 +513,7 @@ pub(in crate::hugr::rewrite) mod test { .collect(); let outputs = h .node_inputs(output) - .filter(|&p| h.get_optype(output).signature().get(p).is_some()) + .filter(|&p| h.get_optype(output).signature().port_type(p).is_some()) .map(|p| ((output, p), p)) .collect(); h.apply_rewrite(SimpleReplacement::new( @@ -565,7 +565,12 @@ pub(in crate::hugr::rewrite) mod test { let outputs = repl .node_inputs(repl_output) - .filter(|&p| repl.get_optype(repl_output).signature().get(p).is_some()) + .filter(|&p| { + repl.get_optype(repl_output) + .signature() + .port_type(p) + .is_some() + }) .map(|p| ((repl_output, p), p)) .collect(); diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 55d1426b5..deb518405 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -299,7 +299,7 @@ impl SiblingSubgraph { .map(|part| { let &(n, p) = part.iter().next().expect("is non-empty"); let sig = hugr.get_optype(n).signature(); - sig.get(p).cloned().expect("must be dataflow edge") + sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); let output = self @@ -307,7 +307,7 @@ impl SiblingSubgraph { .iter() .map(|&(n, p)| { let sig = hugr.get_optype(n).signature(); - sig.get(p).cloned().expect("must be dataflow edge") + sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); FunctionType::new(input, output) @@ -356,10 +356,10 @@ impl SiblingSubgraph { // See https://github.com/CQCL-DEV/hugr/discussions/432 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p)); let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p)); - let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = - rep_inputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); - let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = - rep_outputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); + let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs + .partition(|&(n, p)| replacement.get_optype(n).signature().port_type(p).is_some()); + let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs + .partition(|&(n, p)| replacement.get_optype(n).signature().port_type(p).is_some()); if combine_in_out(&vec![out_order_ports], &in_order_ports) .any(|(n, p)| is_order_edge(&replacement, n, p)) @@ -467,10 +467,10 @@ impl<'g, Base: HugrView> ConvexChecker<'g, Base> { /// If the array is empty or a port does not exist, returns `None`. fn get_edge_type + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option { let &(n, p) = ports.first()?; - let edge_t = hugr.get_optype(n).signature().get(p)?.clone(); + let edge_t = hugr.get_optype(n).signature().port_type(p)?.clone(); ports .iter() - .all(|&(n, p)| hugr.get_optype(n).signature().get(p) == Some(&edge_t)) + .all(|&(n, p)| hugr.get_optype(n).signature().port_type(p) == Some(&edge_t)) .then_some(edge_t) } diff --git a/src/ops.rs b/src/ops.rs index 89fa8e44b..fd794895d 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -84,7 +84,7 @@ impl OpType { let port_count = signature.port_count(dir); let port_as_in = port.as_incoming().ok(); if port.index() < port_count { - signature.get(port).cloned().map(EdgeKind::Value) + signature.port_type(port).cloned().map(EdgeKind::Value) } else if port_as_in.is_some() && port_as_in == self.static_input_port() { Some(EdgeKind::Static(static_in_type(self))) } else { diff --git a/src/types/signature.rs b/src/types/signature.rs index 90df38efd..202df790d 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -117,7 +117,7 @@ impl FunctionType { /// Returns the type of a value [`Port`]. Returns `None` if the port is out /// of bounds. #[inline] - pub fn get(&self, port: impl Into) -> Option<&Type> { + pub fn port_type(&self, port: impl Into) -> Option<&Type> { let port = port.into(); match port.direction() { Direction::Incoming => self.input.get(port), @@ -128,7 +128,7 @@ impl FunctionType { /// Returns a mutable reference to the type of a value [`Port`]. /// Returns `None` if the port is out of bounds. #[inline] - pub fn get_mut(&mut self, port: impl Into) -> Option<&mut Type> { + pub fn port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { let port = port.into(); match port.direction() { Direction::Incoming => self.input.get_mut(port), @@ -271,14 +271,14 @@ mod test { assert_eq!(f_type.input_types(), &[Type::UNIT]); assert_eq!( - f_type.get(Port::new(Direction::Incoming, 0)), + f_type.port_type(Port::new(Direction::Incoming, 0)), Some(&Type::UNIT) ); let out = Port::new(Direction::Outgoing, 0); - *(f_type.get_mut(out).unwrap()) = USIZE_T; + *(f_type.port_type_mut(out).unwrap()) = USIZE_T; - assert_eq!(f_type.get(out), Some(&USIZE_T)); + assert_eq!(f_type.port_type(out), Some(&USIZE_T)); assert_eq!(f_type.input_types(), &[Type::UNIT]); assert_eq!(f_type.output_types(), &[USIZE_T]); From 0437efc480058eeb89caa007468fb2980a3318cd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 09:41:37 +0000 Subject: [PATCH 08/31] feat: add `in/out_port_type` to `FunctionType` --- src/types/signature.rs | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/types/signature.rs b/src/types/signature.rs index 202df790d..78a701196 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -118,13 +118,27 @@ impl FunctionType { /// of bounds. #[inline] pub fn port_type(&self, port: impl Into) -> Option<&Type> { - let port = port.into(); + let port: Port = port.into(); match port.direction() { - Direction::Incoming => self.input.get(port), - Direction::Outgoing => self.output.get(port), + Direction::Incoming => self.in_port_type(port.as_incoming().unwrap()), + Direction::Outgoing => self.out_port_type(port.as_outgoing().unwrap()), } } + /// Returns the type of a value input [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn in_port_type(&self, port: impl Into) -> Option<&Type> { + self.input.get(port.into()) + } + + /// Returns the type of a value output [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn out_port_type(&self, port: impl Into) -> Option<&Type> { + self.output.get(port.into()) + } + /// Returns a mutable reference to the type of a value [`Port`]. /// Returns `None` if the port is out of bounds. #[inline] From 47d46465ca6ceb198bbe2e24130e6dd155d3b5c8 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 11:15:25 +0000 Subject: [PATCH 09/31] feat: `HugrView::signature(Node)` and port+type iterators --- src/hugr/serialize.rs | 3 +-- src/hugr/views.rs | 33 ++++++++++++++++++++++++++++++ src/hugr/views/sibling_subgraph.rs | 16 +++++++-------- src/hugr/views/tests.rs | 31 ++++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index abc985616..67e209ac1 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -10,7 +10,6 @@ use pyo3::{create_exception, exceptions::PyException, PyErr}; use crate::core::NodeIndex; use crate::extension::ExtensionSet; use crate::hugr::{Hugr, NodeType}; -use crate::ops::OpTrait; use crate::ops::OpType; use crate::{Node, PortIndex}; use portgraph::hierarchy::AttachError; @@ -167,7 +166,7 @@ impl TryFrom<&Hugr> for SerHugrV0 { .expect("Could not reach one of the nodes"); let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| { - let sig = hugr.get_optype(node).signature(); + let sig = hugr.signature(node); let offset = match offset < sig.port_count(dir) { true => Some(offset as u16), false => None, diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 114c41859..2078abf29 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -25,6 +25,8 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE}; use crate::ops::handle::NodeHandle; use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; +#[rustversion::since(1.75)] // uses impl in return position +use crate::types::Type; use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; @@ -355,6 +357,37 @@ pub trait HugrView: sealed::HugrInternals { .next() .map(|(n, _)| n) } + + /// Get the "signature" (incoming and outgoing types) of a node, non-Value + /// kind edges will be missing. + fn signature(&self, node: Node) -> FunctionType { + self.get_optype(node).signature() + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all ports in a given direction that have Value type, along + /// with corresponding types. + fn value_types(&self, node: Node, dir: Direction) -> impl Iterator { + let sig = self.signature(node); + self.node_ports(node, dir) + .flat_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone()))) + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all incoming ports that have Value type, along + /// with corresponding types. + fn in_value_types(&self, node: Node) -> impl Iterator { + self.value_types(node, Direction::Incoming) + .map(|(p, t)| (p.as_incoming().unwrap(), t)) + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all incoming ports that have Value type, along + /// with corresponding types. + fn out_value_types(&self, node: Node) -> impl Iterator { + self.value_types(node, Direction::Outgoing) + .map(|(p, t)| (p.as_outgoing().unwrap(), t)) + } } /// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index deb518405..91a7dcb66 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -298,7 +298,7 @@ impl SiblingSubgraph { .iter() .map(|part| { let &(n, p) = part.iter().next().expect("is non-empty"); - let sig = hugr.get_optype(n).signature(); + let sig = hugr.signature(n); sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); @@ -306,7 +306,7 @@ impl SiblingSubgraph { .outputs .iter() .map(|&(n, p)| { - let sig = hugr.get_optype(n).signature(); + let sig = hugr.signature(n); sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); @@ -356,10 +356,10 @@ impl SiblingSubgraph { // See https://github.com/CQCL-DEV/hugr/discussions/432 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p)); let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p)); - let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs - .partition(|&(n, p)| replacement.get_optype(n).signature().port_type(p).is_some()); - let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs - .partition(|&(n, p)| replacement.get_optype(n).signature().port_type(p).is_some()); + let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = + rep_inputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some()); + let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = + rep_outputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some()); if combine_in_out(&vec![out_order_ports], &in_order_ports) .any(|(n, p)| is_order_edge(&replacement, n, p)) @@ -467,10 +467,10 @@ impl<'g, Base: HugrView> ConvexChecker<'g, Base> { /// If the array is empty or a port does not exist, returns `None`. fn get_edge_type + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option { let &(n, p) = ports.first()?; - let edge_t = hugr.get_optype(n).signature().port_type(p)?.clone(); + let edge_t = hugr.signature(n).port_type(p)?.clone(); ports .iter() - .all(|&(n, p)| hugr.get_optype(n).signature().port_type(p) == Some(&edge_t)) + .all(|&(n, p)| hugr.signature(n).port_type(p) == Some(&edge_t)) .then_some(edge_t) } diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 816a0c41b..9b1833213 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -117,3 +117,34 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle Date: Mon, 13 Nov 2023 13:58:47 +0000 Subject: [PATCH 10/31] feat: single_source/target for when you know only one connected port Closes Shorthand HugrView method for "connected port" #499 --- src/hugr/rewrite/insert_identity.rs | 5 +---- src/hugr/rewrite/outline_cfg.rs | 2 +- src/hugr/rewrite/replace.rs | 5 ++--- src/hugr/rewrite/simple_replace.rs | 27 +++++++-------------------- src/hugr/views.rs | 25 +++++++++++++++++++++++++ src/hugr/views/sibling_subgraph.rs | 4 ++-- 6 files changed, 38 insertions(+), 30 deletions(-) diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs index d5613a59f..53d2790fe 100644 --- a/src/hugr/rewrite/insert_identity.rs +++ b/src/hugr/rewrite/insert_identity.rs @@ -9,7 +9,6 @@ use crate::{HugrView, IncomingPort}; use super::Rewrite; -use itertools::Itertools; use thiserror::Error; /// Specification of a identity-insertion operation. @@ -73,9 +72,7 @@ impl Rewrite for IdentityInsertion { }; let (pre_node, pre_port) = h - .linked_outputs(self.post_node, self.post_port) - .exactly_one() - .ok() + .single_source(self.post_node, self.post_port) .expect("Value kind input can only have one connection."); h.disconnect(self.post_node, self.post_port).unwrap(); diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 179ea486d..5ebf6edd9 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -177,7 +177,7 @@ impl Rewrite for OutlineCfg { let exit_port = h .node_outputs(exit) .filter(|p| { - let (t, p2) = h.linked_ports(exit, *p).exactly_one().ok().unwrap(); + let (t, p2) = h.single_target(exit, *p).unwrap(); assert!(p2.index() == 0); t == outside }) diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 0b01df5aa..0ae10acec 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -139,9 +139,8 @@ impl NewEdgeSpec { true }; let found_incoming = h - .linked_ports(self.tgt, tgt_pos) - .exactly_one() - .is_ok_and(|(src_n, _)| descends_from_legal(src_n)); + .single_source(self.tgt, tgt_pos) + .is_some_and(|(src_n, _)| descends_from_legal(src_n)); if !found_incoming { return Err(ReplaceError::NoRemovedEdge(err_edge())); }; diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 44eb2c1ae..b4480fe8d 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -4,8 +4,6 @@ use std::collections::{hash_map, HashMap}; use std::iter::{self, Copied}; use std::slice; -use itertools::Itertools; - use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; @@ -129,11 +127,8 @@ impl Rewrite for SimpleReplacement { for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp { if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) - let (rem_inp_pred_node, rem_inp_pred_port) = h - .linked_outputs(*rem_inp_node, *rem_inp_port) - .exactly_one() - .ok() // PortLinks does not implement Debug - .unwrap(); + let (rem_inp_pred_node, rem_inp_pred_port) = + h.single_source(*rem_inp_node, *rem_inp_port).unwrap(); h.disconnect(*rem_inp_node, *rem_inp_port).unwrap(); let new_inp_node = index_map.get(rep_inp_node).unwrap(); h.connect( @@ -150,8 +145,7 @@ impl Rewrite for SimpleReplacement { for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { let (rep_out_pred_node, rep_out_pred_port) = self .replacement - .linked_outputs(replacement_output_node, *rep_out_port) - .exactly_one() + .single_source(replacement_output_node, *rep_out_port) .unwrap(); if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input { let new_out_node = index_map.get(&rep_out_pred_node).unwrap(); @@ -171,11 +165,8 @@ impl Rewrite for SimpleReplacement { let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port)); if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): - let (rem_inp_pred_node, rem_inp_pred_port) = h - .linked_outputs(*rem_inp_node, *rem_inp_port) - .exactly_one() - .ok() // PortLinks does not implement Debug - .unwrap(); + let (rem_inp_pred_node, rem_inp_pred_port) = + h.single_source(*rem_inp_node, *rem_inp_port).unwrap(); h.disconnect(*rem_inp_node, *rem_inp_port).unwrap(); h.disconnect(*rem_out_node, *rem_out_port).unwrap(); h.connect( @@ -603,7 +594,7 @@ pub(in crate::hugr::rewrite) mod test { if *tgt == out { unimplemented!() }; - let (src, src_port) = h.linked_outputs(*r_n, *r_p).exactly_one().ok().unwrap(); + let (src, src_port) = h.single_source(*r_n, *r_p).unwrap(); NewEdgeSpec { src, tgt: *tgt, @@ -618,11 +609,7 @@ pub(in crate::hugr::rewrite) mod test { .nu_out .iter() .map(|((tgt, tgt_port), out_port)| { - let (src, src_port) = replacement - .linked_outputs(out, *out_port) - .exactly_one() - .ok() - .unwrap(); + let (src, src_port) = replacement.single_source(out, *out_port).unwrap(); if src == in_ { unimplemented!() }; diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 2078abf29..501eeb118 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -195,6 +195,31 @@ pub trait HugrView: sealed::HugrInternals { .flat_map(move |port| self.linked_inputs(node, port)) } + /// If there is exactly one OutgoingPort connected to this IncomingPort, return + /// it and its node. + fn single_source( + &self, + node: Node, + port: impl Into, + ) -> Option<(Node, OutgoingPort)> { + self.linked_ports(node, port.into()) + .exactly_one() + .ok() + .map(|(n, p)| (n, p.as_outgoing().unwrap())) + } + + /// If there is exactly one IncomingPort connected to this OutgoingPort, return + /// it and its node. + fn single_target( + &self, + node: Node, + port: impl Into, + ) -> Option<(Node, IncomingPort)> { + self.linked_ports(node, port.into()) + .exactly_one() + .ok() + .map(|(n, p)| (n, p.as_incoming().unwrap())) + } /// Iterator over the nodes and output ports connected to a given *input* port. /// Like [`linked_ports`][HugrView::linked_ports] but preserves knowledge /// that the linked ports are [OutgoingPort]s. diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 91a7dcb66..443c4864b 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -251,7 +251,7 @@ impl SiblingSubgraph { if !hugr.is_linked(n, p) { return false; } - let (out_n, _) = hugr.linked_ports(n, p).exactly_one().ok().unwrap(); + let (out_n, _) = hugr.single_source(n, p).unwrap(); !nodes_set.contains(&out_n) }) // Every incoming edge is its own input. @@ -882,7 +882,7 @@ mod tests { .collect(), hugr.node_inputs(out) .take(2) - .filter_map(|p| hugr.linked_outputs(out, p).exactly_one().ok()) + .filter_map(|p| hugr.single_source(out, p)) .collect(), &func, ) From da9aea82e752f4870ef4763ee0f9c6983edac122 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 14:25:56 +0000 Subject: [PATCH 11/31] feat: `static_targets` and `OpType::static_output_port` --- src/hugr/views.rs | 6 ++++++ src/hugr/views/tests.rs | 23 +++++++++++++++++++++++ src/ops.rs | 9 ++++++++- src/ops/tag.rs | 8 ++++++-- 4 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 501eeb118..b5b300619 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -383,6 +383,12 @@ pub trait HugrView: sealed::HugrInternals { .map(|(n, _)| n) } + #[rustversion::since(1.75)] // uses impl in return position + /// If a node has a static output, return the targets. + fn static_targets(&self, node: Node) -> Option> { + Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?)) + } + /// Get the "signature" (incoming and outgoing types) of a node, non-Value /// kind edges will be missing. fn signature(&self, node: Node) -> FunctionType { diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 9b1833213..f26b54612 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use portgraph::PortOffset; use rstest::{fixture, rstest}; @@ -148,3 +149,25 @@ fn value_types() { assert_eq!(&out_types[..], &[(0.into(), BOOL_T), (1.into(), QB_T)]); } + +#[rustversion::since(1.75)] // uses impl in return position +#[test] +fn static_targets() { + use crate::extension::prelude::{ConstUsize, USIZE_T}; + let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap(); + + let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap(); + + let load = dfg.load_const(&c).unwrap(); + + let h = dfg + .finish_hugr_with_outputs([load], &crate::extension::PRELUDE_REGISTRY) + .unwrap(); + + assert_eq!(h.static_source(load.node()), Some(c.node())); + + assert_eq!( + &h.static_targets(c.node()).unwrap().collect_vec()[..], + &[(load.node(), 0.into())] + ) +} diff --git a/src/ops.rs b/src/ops.rs index fd794895d..e764648dc 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -10,7 +10,7 @@ pub mod module; pub mod tag; pub mod validate; use crate::types::{EdgeKind, FunctionType, Type}; -use crate::{Direction, Port}; +use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use portgraph::NodeIndex; @@ -142,6 +142,13 @@ impl OpType { } } + /// If the op has a static output (Const, FuncDefn, FuncDecl), the port of that output. + pub fn static_output_port(&self) -> Option { + OpTag::StaticOutput + .is_superset(self.tag()) + .then_some(0.into()) + } + /// Returns the number of ports for the given direction. pub fn port_count(&self, dir: Direction) -> usize { let signature = self.signature(); diff --git a/src/ops/tag.rs b/src/ops/tag.rs index 049f383e1..d1e56346d 100644 --- a/src/ops/tag.rs +++ b/src/ops/tag.rs @@ -48,6 +48,8 @@ pub enum OpTag { Output, /// Dataflow node that has a static input StaticInput, + /// Node that has a static output + StaticOutput, /// A function call. FnCall, /// A constant load operation. @@ -106,14 +108,14 @@ impl OpTag { OpTag::DataflowChild => &[OpTag::Any], OpTag::Input => &[OpTag::DataflowChild], OpTag::Output => &[OpTag::DataflowChild], - OpTag::Function => &[OpTag::ModuleOp], + OpTag::Function => &[OpTag::ModuleOp, OpTag::StaticOutput], OpTag::Alias => &[OpTag::ScopedDefn], OpTag::FuncDefn => &[OpTag::Function, OpTag::ScopedDefn, OpTag::DataflowParent], OpTag::BasicBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent], OpTag::BasicBlockExit => &[OpTag::BasicBlock], OpTag::Case => &[OpTag::Any, OpTag::DataflowParent], OpTag::ModuleRoot => &[OpTag::Any], - OpTag::Const => &[OpTag::ScopedDefn], + OpTag::Const => &[OpTag::ScopedDefn, OpTag::StaticOutput], OpTag::Dfg => &[OpTag::DataflowChild, OpTag::DataflowParent], OpTag::Cfg => &[OpTag::DataflowChild], OpTag::ScopedDefn => &[ @@ -124,6 +126,7 @@ impl OpTag { OpTag::TailLoop => &[OpTag::DataflowChild, OpTag::DataflowParent], OpTag::Conditional => &[OpTag::DataflowChild], OpTag::StaticInput => &[OpTag::DataflowChild], + OpTag::StaticOutput => &[OpTag::ModuleOp], OpTag::FnCall => &[OpTag::StaticInput], OpTag::LoadConst => &[OpTag::StaticInput], OpTag::Leaf => &[OpTag::DataflowChild], @@ -154,6 +157,7 @@ impl OpTag { OpTag::TailLoop => "Tail-recursive loop", OpTag::Conditional => "Conditional operation", OpTag::StaticInput => "Dataflow child with static input (LoadConst or FnCall)", + OpTag::StaticOutput => "Node with static input (FuncDefn, FuncDecl, Const)", OpTag::FnCall => "Function call", OpTag::LoadConst => "Constant load operation", OpTag::Leaf => "Leaf operation", From b2a2ae0cd3dcd160a73aa1ecd59d17dc69a01363 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 14:56:37 +0000 Subject: [PATCH 12/31] feat: `dataflow_ports_only` function to filter node-ports --- src/hugr/views.rs | 13 ++++++++++++ src/hugr/views/tests.rs | 44 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index b5b300619..d0b8d7af4 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -590,6 +590,19 @@ impl> HugrView for T { } } +/// Filter an iterator of node-ports to only dataflow dependency specifying +/// ports (Value and StateOrder) +pub fn dataflow_ports_only<'i, 'a: 'i, P: Into + Copy>( + hugr: &'a impl HugrView, + it: impl Iterator + 'i, +) -> impl Iterator + 'i { + it.filter(move |(n, p)| { + matches!( + hugr.get_optype(*n).port_kind(*p), + Some(EdgeKind::Value(_) | EdgeKind::StateOrder) + ) + }) +} pub(crate) mod sealed { use super::*; diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index f26b54612..65601bf15 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -1,4 +1,3 @@ -use itertools::Itertools; use portgraph::PortOffset; use rstest::{fixture, rstest}; @@ -154,6 +153,8 @@ fn value_types() { #[test] fn static_targets() { use crate::extension::prelude::{ConstUsize, USIZE_T}; + use itertools::Itertools; + let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap(); let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap(); @@ -171,3 +172,44 @@ fn static_targets() { &[(load.node(), 0.into())] ) } + +#[rustversion::since(1.75)] // uses impl in return position +#[test] +fn test_dataflow_ports_only() { + use crate::builder::DataflowSubContainer; + use crate::extension::prelude::BOOL_T; + use crate::hugr::views::dataflow_ports_only; + use crate::std_extensions::logic::test::not_op; + use itertools::Itertools; + let mut dfg = DFGBuilder::new(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap(); + let local_and = { + let local_and = dfg + .define_function( + "and", + FunctionType::new(type_row![BOOL_T; 2], type_row![BOOL_T]).pure(), + ) + .unwrap(); + let first_input = local_and.input().out_wire(0); + local_and.finish_with_outputs([first_input]).unwrap() + }; + let [in_bool] = dfg.input_wires_arr(); + + let not = dfg.add_dataflow_op(not_op(), [in_bool]).unwrap(); + let call = dfg.call(local_and.handle(), [not.out_wire(0); 2]).unwrap(); + dfg.add_other_wire(not.node(), call.node()).unwrap(); + let h = dfg + .finish_hugr_with_outputs(not.outputs(), &crate::extension::PRELUDE_REGISTRY) + .unwrap(); + let filtered_ports = dataflow_ports_only(&h, h.all_linked_outputs(call.node())).collect_vec(); + + // should ignore the static input in to call, but report the two value ports + // and the order port. + assert_eq!( + &filtered_ports[..], + &[ + (not.node(), 0.into()), + (not.node(), 0.into()), + (not.node(), 1.into()) + ] + ) +} From b4db711d9487bc4f2130ecd0c4e2e5f8cdfdd16e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 11:55:13 +0000 Subject: [PATCH 13/31] refactor!: OpType::signature returns option (non-dataflow ops don't return signature) --- src/builder/build_traits.rs | 10 ++++---- src/builder/conditional.rs | 5 ++-- src/extension/infer.rs | 14 +++++++---- src/extension/validate.rs | 9 +++++-- src/hugr.rs | 11 ++++++--- src/hugr/rewrite/outline_cfg.rs | 3 ++- src/hugr/rewrite/replace.rs | 2 +- src/hugr/rewrite/simple_replace.rs | 23 ++++++++++++------ src/hugr/serialize.rs | 2 +- src/hugr/views.rs | 4 +-- src/hugr/views/sibling.rs | 2 +- src/hugr/views/sibling_subgraph.rs | 39 +++++++++++++++++++++--------- src/ops.rs | 31 ++++++++++++++++-------- src/ops/controlflow.rs | 16 ++++++------ src/ops/custom.rs | 6 ++--- src/ops/dataflow.rs | 20 +++++++-------- src/ops/leaf.rs | 8 +++--- src/ops/validate.rs | 19 +++++++-------- 18 files changed, 136 insertions(+), 88 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 2cea5568a..e1435fffe 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -236,7 +236,7 @@ pub trait Dataflow: Container { hugr: Hugr, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); + let num_outputs = hugr.get_optype(hugr.root()).value_output_count(); let node = self.add_hugr(hugr)?.new_root; let inputs = input_wires.into_iter().collect(); @@ -257,8 +257,8 @@ pub trait Dataflow: Container { hugr: &impl HugrView, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); let node = self.add_hugr_view(hugr)?.new_root; + let num_outputs = hugr.get_optype(hugr.root()).value_output_count(); let inputs = input_wires.into_iter().collect(); wire_up_inputs(inputs, node, self)?; @@ -634,13 +634,13 @@ fn add_node_with_wires( nodetype: impl Into, inputs: Vec, ) -> Result<(Node, usize), BuildError> { - let nodetype = nodetype.into(); - let sig = nodetype.op_signature(); + let nodetype: NodeType = nodetype.into(); + let num_outputs = nodetype.op().value_output_count(); let op_node = data_builder.add_child_node(nodetype)?; wire_up_inputs(inputs, op_node, data_builder)?; - Ok((op_node, sig.output().len())) + Ok((op_node, num_outputs)) } fn wire_up_inputs( diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index e6219f46a..78472d8cc 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -1,9 +1,10 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::HugrView; +use crate::ops::dataflow::DataflowOpTrait; use crate::types::{FunctionType, TypeRow}; +use crate::ops; use crate::ops::handle::CaseID; -use crate::ops::{self, OpTrait}; use super::build_traits::SubContainer; use super::handle::BuildHandle; @@ -104,12 +105,12 @@ impl + AsRef> ConditionalBuilder { pub fn case_builder(&mut self, case: usize) -> Result, BuildError> { let conditional = self.conditional_node; let control_op = self.hugr().get_optype(self.conditional_node); - let extension_delta = control_op.signature().extension_reqs; let cond: ops::Conditional = control_op .clone() .try_into() .expect("Parent node does not have Conditional optype."); + let extension_delta = cond.dataflow_signature().extension_reqs; let inputs = cond .case_input_row(case) .ok_or(ConditionalBuildError::NotCase { conditional, case })?; diff --git a/src/extension/infer.rs b/src/extension/infer.rs index e2ea1f89f..4ef2ac47d 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -305,11 +305,15 @@ impl UnificationContext { match node_type.signature() { // Input extensions are open None => { - let delta = node_type.op_signature().extension_reqs; - let c = if delta.is_empty() { - Constraint::Equal(m_input) + let c = if let Some(sig) = node_type.op_signature() { + let delta = sig.extension_reqs; + if delta.is_empty() { + Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) + } } else { - Constraint::Plus(delta, m_input) + Constraint::Equal(m_input) }; self.add_constraint(m_output, c); } @@ -1063,7 +1067,7 @@ mod test { hugr, conditional_node, op.clone(), - Into::::into(op).signature(), + Into::::into(op).signature().unwrap(), )?; let lift1 = hugr.add_node_with_parent( diff --git a/src/extension/validate.rs b/src/extension/validate.rs index c026c2519..2f267edf9 100644 --- a/src/extension/validate.rs +++ b/src/extension/validate.rs @@ -25,8 +25,13 @@ impl ExtensionValidator { pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self { let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new(); for (node, incoming_sol) in closure.into_iter() { - let op_signature = hugr.get_nodetype(node).op_signature(); - let outgoing_sol = op_signature.extension_reqs.union(&incoming_sol); + let extension_reqs = hugr + .get_nodetype(node) + .op_signature() + .map(|s| s.extension_reqs) + .unwrap_or_default(); + + let outgoing_sol = extension_reqs.union(&incoming_sol); extensions.insert((node, Direction::Incoming), incoming_sol); extensions.insert((node, Direction::Outgoing), outgoing_sol); diff --git a/src/hugr.rs b/src/hugr.rs index 97bf44ed6..a23703bc0 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -111,13 +111,16 @@ impl NodeType { /// Use the input extensions to calculate the concrete signature of the node pub fn signature(&self) -> Option { - self.input_extensions - .as_ref() - .map(|rs| self.op.signature().with_input_extensions(rs.clone())) + self.input_extensions.as_ref().map(|rs| { + self.op + .signature() + .unwrap_or_default() + .with_input_extensions(rs.clone()) + }) } /// Get the function type from the embedded op - pub fn op_signature(&self) -> FunctionType { + pub fn op_signature(&self) -> Option { self.op.signature() } diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 5ebf6edd9..e24316bfb 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -70,7 +70,8 @@ impl OutlineCfg { } } } - extension_delta = extension_delta.union(&o.signature().extension_reqs); + extension_delta = extension_delta + .union(&o.signature().expect("cfg missing signature").extension_reqs); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 0ae10acec..9099e9913 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -616,7 +616,7 @@ mod test { entry: bool, ) -> Result { let op: OpType = op.into(); - let op_sig = op.signature(); + let op_sig = op.signature().expect("dataflow op needs signature"); let mut bb = if entry { assert_eq!( match h.hugr().get_optype(h.container_node()) { diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index b4480fe8d..d0922f233 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -496,7 +496,13 @@ pub(in crate::hugr::rewrite) mod test { .collect_vec(); let inputs = h .node_outputs(input) - .filter(|&p| h.get_optype(input).signature().port_type(p).is_some()) + .filter(|&p| { + h.get_optype(input) + .signature() + .unwrap() + .port_type(p) + .is_some() + }) .map(|p| { let link = h.linked_inputs(input, p).next().unwrap(); (link, link) @@ -504,7 +510,13 @@ pub(in crate::hugr::rewrite) mod test { .collect(); let outputs = h .node_inputs(output) - .filter(|&p| h.get_optype(output).signature().port_type(p).is_some()) + .filter(|&p| { + h.get_optype(output) + .signature() + .unwrap() + .port_type(p) + .is_some() + }) .map(|p| ((output, p), p)) .collect(); h.apply_rewrite(SimpleReplacement::new( @@ -556,12 +568,7 @@ pub(in crate::hugr::rewrite) mod test { let outputs = repl .node_inputs(repl_output) - .filter(|&p| { - repl.get_optype(repl_output) - .signature() - .port_type(p) - .is_some() - }) + .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some()) .map(|p| ((repl_output, p), p)) .collect(); diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 67e209ac1..b1f99a4a2 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -166,7 +166,7 @@ impl TryFrom<&Hugr> for SerHugrV0 { .expect("Could not reach one of the nodes"); let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| { - let sig = hugr.signature(node); + let sig = hugr.signature(node).unwrap_or_default(); let offset = match offset < sig.port_count(dir) { true => Some(offset as u16), false => None, diff --git a/src/hugr/views.rs b/src/hugr/views.rs index d0b8d7af4..1f4a22974 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -391,7 +391,7 @@ pub trait HugrView: sealed::HugrInternals { /// Get the "signature" (incoming and outgoing types) of a node, non-Value /// kind edges will be missing. - fn signature(&self, node: Node) -> FunctionType { + fn signature(&self, node: Node) -> Option { self.get_optype(node).signature() } @@ -399,7 +399,7 @@ pub trait HugrView: sealed::HugrInternals { /// Iterator over all ports in a given direction that have Value type, along /// with corresponding types. fn value_types(&self, node: Node, dir: Direction) -> impl Iterator { - let sig = self.signature(node); + let sig = self.signature(node).unwrap_or_default(); self.node_ports(node, dir) .flat_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone()))) } diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 77f74f39b..7d54469f4 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -472,7 +472,7 @@ mod test { fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { let root = simple_dfg_hugr.root(); let case_nodetype = NodeType::new_open(crate::ops::Case { - signature: simple_dfg_hugr.root_type().op_signature(), + signature: simple_dfg_hugr.root_type().op_signature().unwrap(), }); let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); // As expected, we cannot replace the root with a Case diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 443c4864b..0b1fb9135 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -298,7 +298,7 @@ impl SiblingSubgraph { .iter() .map(|part| { let &(n, p) = part.iter().next().expect("is non-empty"); - let sig = hugr.signature(n); + let sig = hugr.signature(n).expect("must have dataflow signature"); sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); @@ -306,7 +306,7 @@ impl SiblingSubgraph { .outputs .iter() .map(|&(n, p)| { - let sig = hugr.signature(n); + let sig = hugr.signature(n).expect("must have dataflow signature"); sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); @@ -348,7 +348,7 @@ impl SiblingSubgraph { let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else { return Err(InvalidReplacement::InvalidDataflowParent); }; - if dfg_optype.signature() != self.signature(hugr) { + if dfg_optype.signature() != Some(self.signature(hugr)) { return Err(InvalidReplacement::InvalidSignature); } @@ -356,10 +356,16 @@ impl SiblingSubgraph { // See https://github.com/CQCL-DEV/hugr/discussions/432 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p)); let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p)); - let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = - rep_inputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some()); - let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = - rep_outputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some()); + let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs.partition(|&(n, p)| { + replacement + .signature(n) + .is_some_and(|s| s.port_type(p).is_some()) + }); + let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs.partition(|&(n, p)| { + replacement + .signature(n) + .is_some_and(|s| s.port_type(p).is_some()) + }); if combine_in_out(&vec![out_order_ports], &in_order_ports) .any(|(n, p)| is_order_edge(&replacement, n, p)) @@ -467,10 +473,13 @@ impl<'g, Base: HugrView> ConvexChecker<'g, Base> { /// If the array is empty or a port does not exist, returns `None`. fn get_edge_type + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option { let &(n, p) = ports.first()?; - let edge_t = hugr.signature(n).port_type(p)?.clone(); + let edge_t = hugr.signature(n)?.port_type(p)?.clone(); ports .iter() - .all(|&(n, p)| hugr.signature(n).port_type(p) == Some(&edge_t)) + .all(|&(n, p)| { + hugr.signature(n) + .is_some_and(|s| s.port_type(p) == Some(&edge_t)) + }) .then_some(edge_t) } @@ -567,11 +576,19 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort if has_other_edge(hugr, inp, Direction::Outgoing) { unimplemented!("Non-dataflow output not supported at input node") } - let dfg_inputs = hugr.get_optype(inp).signature().output_ports(); + let dfg_inputs = hugr + .get_optype(inp) + .signature() + .unwrap_or_default() + .output_ports(); if has_other_edge(hugr, out, Direction::Incoming) { unimplemented!("Non-dataflow input not supported at output node") } - let dfg_outputs = hugr.get_optype(out).signature().input_ports(); + let dfg_outputs = hugr + .get_optype(out) + .signature() + .unwrap_or_default() + .input_ports(); // Collect for each port in the input the set of target ports, filtering // direct wires to the output. diff --git a/src/ops.rs b/src/ops.rs index e764648dc..eceb7185e 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -77,12 +77,12 @@ impl OpType { /// Returns the edge kind for the given port. pub fn port_kind(&self, port: impl Into) -> Option { - let signature = self.signature(); + let signature = self.signature().unwrap_or_default(); let port: Port = port.into(); + let port_as_in = port.as_incoming().ok(); let dir = port.direction(); let port_count = signature.port_count(dir); - let port_as_in = port.as_incoming().ok(); if port.index() < port_count { signature.port_type(port).cloned().map(EdgeKind::Value) } else if port_as_in.is_some() && port_as_in == self.static_input_port() { @@ -103,15 +103,27 @@ impl OpType { let static_input = (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; - Some(Port::new( - dir, - self.signature().port_count(dir) + static_input, - )) + Some(Port::new(dir, self.value_port_count(dir) + static_input)) } else { None } } + /// The number of Value ports in given direction. + pub fn value_port_count(&self, dir: portgraph::Direction) -> usize { + self.signature().map(|sig| sig.port_count(dir)).unwrap_or(0) + } + + /// The number of Value input ports. + pub fn value_input_count(&self) -> usize { + self.value_port_count(Direction::Incoming) + } + + /// The number of Value output ports. + pub fn value_output_count(&self) -> usize { + self.value_port_count(Direction::Outgoing) + } + /// The non-dataflow input port for the operation, not described by the signature. /// See `[OpType::other_port]`. pub fn other_input_port(&self) -> Option { @@ -151,7 +163,6 @@ impl OpType { /// Returns the number of ports for the given direction. pub fn port_count(&self, dir: Direction) -> usize { - let signature = self.signature(); let has_other_ports = self.other_port_kind(dir).is_some(); let non_df_count = self .non_df_port_count(dir) @@ -159,7 +170,7 @@ impl OpType { // if there is a static input it comes before the non_df_ports let static_input = (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; - signature.port_count(dir) + non_df_count + static_input + self.value_port_count(dir) + non_df_count + static_input } /// Returns the number of inputs ports for the operation. @@ -228,8 +239,8 @@ pub trait OpTrait { /// The signature of the operation. /// - /// Only dataflow operations have a non-empty signature. - fn signature(&self) -> FunctionType { + /// Only dataflow operations have a signature, otherwise returns None. + fn signature(&self) -> Option { Default::default() } /// The edge kind for the non-dataflow or constant inputs of the operation, diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index b6836ea70..a54aac41d 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -30,7 +30,7 @@ impl DataflowOpTrait for TailLoop { "A tail-controlled loop" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| tuple_sum_first(row, &self.rest)); FunctionType::new(inputs, outputs) @@ -74,7 +74,7 @@ impl DataflowOpTrait for Conditional { "HUGR conditional operation" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { let mut inputs = self.other_inputs.clone(); inputs .to_mut() @@ -109,7 +109,7 @@ impl DataflowOpTrait for CFG { "A dataflow node defined by a child CFG" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { self.signature.clone() } } @@ -169,13 +169,13 @@ impl OpTrait for BasicBlock { Some(EdgeKind::ControlFlow) } - fn signature(&self) -> FunctionType { - match self { + fn signature(&self) -> Option { + Some(match self { BasicBlock::DFB { extension_delta, .. } => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta), BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), - } + }) } fn non_df_port_count(&self, dir: Direction) -> Option { @@ -233,8 +233,8 @@ impl OpTrait for Case { ::TAG } - fn signature(&self) -> FunctionType { - self.signature.clone() + fn signature(&self) -> Option { + Some(self.signature.clone()) } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index b1c5a39b3..724b64984 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -82,10 +82,10 @@ impl OpTrait for ExternalOp { /// Note the case of an OpaqueOp without a signature should already /// have been detected in [resolve_extension_ops] - fn signature(&self) -> FunctionType { + fn signature(&self) -> Option { match self { - Self::Opaque(op) => op.signature.clone().unwrap(), - Self::Extension(ExtensionOp { signature, .. }) => signature.clone(), + Self::Opaque(op) => op.signature.clone(), + Self::Extension(ExtensionOp { signature, .. }) => Some(signature.clone()), } } } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 4ab06d5cf..fe109ab15 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -6,10 +6,10 @@ use crate::extension::ExtensionSet; use crate::ops::StaticTag; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; -pub(super) trait DataflowOpTrait { +pub(crate) trait DataflowOpTrait { const TAG: OpTag; fn description(&self) -> &str; - fn signature(&self) -> FunctionType; + fn dataflow_signature(&self) -> FunctionType; /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. @@ -81,7 +81,7 @@ impl DataflowOpTrait for Input { None } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), self.types.clone()) .with_extension_delta(&ExtensionSet::new()) } @@ -95,7 +95,7 @@ impl DataflowOpTrait for Output { // Note: We know what the input extensions should be, so we *could* give an // instantiated Signature instead - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { FunctionType::new(self.types.clone(), TypeRow::new()) } @@ -111,8 +111,8 @@ impl OpTrait for T { fn tag(&self) -> OpTag { T::TAG } - fn signature(&self) -> FunctionType { - DataflowOpTrait::signature(self) + fn signature(&self) -> Option { + Some(DataflowOpTrait::dataflow_signature(self)) } fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) @@ -145,7 +145,7 @@ impl DataflowOpTrait for Call { "Call a function directly" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { self.signature.clone() } } @@ -172,7 +172,7 @@ impl DataflowOpTrait for CallIndirect { "Call a function indirectly" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { let mut s = self.signature.clone(); s.input .to_mut() @@ -195,7 +195,7 @@ impl DataflowOpTrait for LoadConstant { "Load a static constant in to the local dataflow graph" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), vec![self.datatype.clone()]) } } @@ -222,7 +222,7 @@ impl DataflowOpTrait for DFG { "A simply nested dataflow graph" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { self.signature.clone() } } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 24f1ebb7e..ef9da58d8 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -153,13 +153,13 @@ impl OpTrait for LeafOp { } /// The signature of the operation. - fn signature(&self) -> FunctionType { + fn signature(&self) -> Option { // Static signatures. The `TypeRow`s in the `FunctionType` use a // copy-on-write strategy, so we can avoid unnecessary allocations. - match self { + Some(match self { LeafOp::Noop { ty: typ } => FunctionType::new(vec![typ.clone()], vec![typ.clone()]), - LeafOp::CustomOp(ext) => ext.signature(), + LeafOp::CustomOp(ext) => return ext.signature(), LeafOp::MakeTuple { tys: types } => { FunctionType::new(types.clone(), vec![Type::new_tuple(types.clone())]) } @@ -179,7 +179,7 @@ impl OpTrait for LeafOp { vec![Type::new_function(ta.input.clone())], vec![Type::new_function(ta.output.clone())], ), - } + }) } fn other_input(&self) -> Option { diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 546d6cec6..28a9bdb55 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -102,12 +102,8 @@ impl ValidateOp for super::DFG { &self, children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { - validate_io_nodes( - &self.signature().input, - &self.signature().output, - "nested graph", - children, - ) + let sig = self.signature().unwrap_or_default(); + validate_io_nodes(&sig.input, &sig.output, "nested graph", children) } } @@ -371,19 +367,22 @@ fn validate_io_nodes<'a>( let (first, first_optype) = children.next().unwrap(); let (second, second_optype) = children.next().unwrap(); - if &first_optype.signature().output != expected_input { + let first_sig = first_optype.signature().unwrap_or_default(); + if &first_sig.output != expected_input { return Err(ChildrenValidationError::IOSignatureMismatch { child: first, - actual: first_optype.signature().output, + actual: first_sig.output, expected: expected_input.clone(), node_desc: "Input", container_desc, }); } - if &second_optype.signature().input != expected_output { + let second_sig = second_optype.signature().unwrap_or_default(); + + if &second_sig.input != expected_output { return Err(ChildrenValidationError::IOSignatureMismatch { child: second, - actual: second_optype.signature().input, + actual: second_sig.input, expected: expected_output.clone(), node_desc: "Output", container_desc, From 39e14909d19ca151fbfd8bd78ed5adfcd5f304f9 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 12:18:52 +0000 Subject: [PATCH 14/31] refactor!: OpType::signature returns option (non-dataflow ops don't return signature) --- src/builder/build_traits.rs | 10 ++++---- src/builder/conditional.rs | 5 ++-- src/extension/infer.rs | 14 +++++++---- src/extension/validate.rs | 9 +++++-- src/hugr.rs | 11 ++++++--- src/hugr/rewrite/outline_cfg.rs | 3 ++- src/hugr/rewrite/replace.rs | 2 +- src/hugr/rewrite/simple_replace.rs | 23 ++++++++++++------ src/hugr/serialize.rs | 2 +- src/hugr/views.rs | 4 +-- src/hugr/views/sibling.rs | 2 +- src/hugr/views/sibling_subgraph.rs | 39 +++++++++++++++++++++--------- src/ops.rs | 31 ++++++++++++++++-------- src/ops/controlflow.rs | 16 ++++++------ src/ops/custom.rs | 6 ++--- src/ops/dataflow.rs | 20 +++++++-------- src/ops/leaf.rs | 8 +++--- src/ops/validate.rs | 19 +++++++-------- 18 files changed, 136 insertions(+), 88 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 2cea5568a..e1435fffe 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -236,7 +236,7 @@ pub trait Dataflow: Container { hugr: Hugr, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); + let num_outputs = hugr.get_optype(hugr.root()).value_output_count(); let node = self.add_hugr(hugr)?.new_root; let inputs = input_wires.into_iter().collect(); @@ -257,8 +257,8 @@ pub trait Dataflow: Container { hugr: &impl HugrView, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); let node = self.add_hugr_view(hugr)?.new_root; + let num_outputs = hugr.get_optype(hugr.root()).value_output_count(); let inputs = input_wires.into_iter().collect(); wire_up_inputs(inputs, node, self)?; @@ -634,13 +634,13 @@ fn add_node_with_wires( nodetype: impl Into, inputs: Vec, ) -> Result<(Node, usize), BuildError> { - let nodetype = nodetype.into(); - let sig = nodetype.op_signature(); + let nodetype: NodeType = nodetype.into(); + let num_outputs = nodetype.op().value_output_count(); let op_node = data_builder.add_child_node(nodetype)?; wire_up_inputs(inputs, op_node, data_builder)?; - Ok((op_node, sig.output().len())) + Ok((op_node, num_outputs)) } fn wire_up_inputs( diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index e6219f46a..78472d8cc 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -1,9 +1,10 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::HugrView; +use crate::ops::dataflow::DataflowOpTrait; use crate::types::{FunctionType, TypeRow}; +use crate::ops; use crate::ops::handle::CaseID; -use crate::ops::{self, OpTrait}; use super::build_traits::SubContainer; use super::handle::BuildHandle; @@ -104,12 +105,12 @@ impl + AsRef> ConditionalBuilder { pub fn case_builder(&mut self, case: usize) -> Result, BuildError> { let conditional = self.conditional_node; let control_op = self.hugr().get_optype(self.conditional_node); - let extension_delta = control_op.signature().extension_reqs; let cond: ops::Conditional = control_op .clone() .try_into() .expect("Parent node does not have Conditional optype."); + let extension_delta = cond.dataflow_signature().extension_reqs; let inputs = cond .case_input_row(case) .ok_or(ConditionalBuildError::NotCase { conditional, case })?; diff --git a/src/extension/infer.rs b/src/extension/infer.rs index e2ea1f89f..4ef2ac47d 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -305,11 +305,15 @@ impl UnificationContext { match node_type.signature() { // Input extensions are open None => { - let delta = node_type.op_signature().extension_reqs; - let c = if delta.is_empty() { - Constraint::Equal(m_input) + let c = if let Some(sig) = node_type.op_signature() { + let delta = sig.extension_reqs; + if delta.is_empty() { + Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) + } } else { - Constraint::Plus(delta, m_input) + Constraint::Equal(m_input) }; self.add_constraint(m_output, c); } @@ -1063,7 +1067,7 @@ mod test { hugr, conditional_node, op.clone(), - Into::::into(op).signature(), + Into::::into(op).signature().unwrap(), )?; let lift1 = hugr.add_node_with_parent( diff --git a/src/extension/validate.rs b/src/extension/validate.rs index c026c2519..2f267edf9 100644 --- a/src/extension/validate.rs +++ b/src/extension/validate.rs @@ -25,8 +25,13 @@ impl ExtensionValidator { pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self { let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new(); for (node, incoming_sol) in closure.into_iter() { - let op_signature = hugr.get_nodetype(node).op_signature(); - let outgoing_sol = op_signature.extension_reqs.union(&incoming_sol); + let extension_reqs = hugr + .get_nodetype(node) + .op_signature() + .map(|s| s.extension_reqs) + .unwrap_or_default(); + + let outgoing_sol = extension_reqs.union(&incoming_sol); extensions.insert((node, Direction::Incoming), incoming_sol); extensions.insert((node, Direction::Outgoing), outgoing_sol); diff --git a/src/hugr.rs b/src/hugr.rs index 97bf44ed6..a23703bc0 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -111,13 +111,16 @@ impl NodeType { /// Use the input extensions to calculate the concrete signature of the node pub fn signature(&self) -> Option { - self.input_extensions - .as_ref() - .map(|rs| self.op.signature().with_input_extensions(rs.clone())) + self.input_extensions.as_ref().map(|rs| { + self.op + .signature() + .unwrap_or_default() + .with_input_extensions(rs.clone()) + }) } /// Get the function type from the embedded op - pub fn op_signature(&self) -> FunctionType { + pub fn op_signature(&self) -> Option { self.op.signature() } diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 5ebf6edd9..e24316bfb 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -70,7 +70,8 @@ impl OutlineCfg { } } } - extension_delta = extension_delta.union(&o.signature().extension_reqs); + extension_delta = extension_delta + .union(&o.signature().expect("cfg missing signature").extension_reqs); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 0ae10acec..9099e9913 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -616,7 +616,7 @@ mod test { entry: bool, ) -> Result { let op: OpType = op.into(); - let op_sig = op.signature(); + let op_sig = op.signature().expect("dataflow op needs signature"); let mut bb = if entry { assert_eq!( match h.hugr().get_optype(h.container_node()) { diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index b4480fe8d..d0922f233 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -496,7 +496,13 @@ pub(in crate::hugr::rewrite) mod test { .collect_vec(); let inputs = h .node_outputs(input) - .filter(|&p| h.get_optype(input).signature().port_type(p).is_some()) + .filter(|&p| { + h.get_optype(input) + .signature() + .unwrap() + .port_type(p) + .is_some() + }) .map(|p| { let link = h.linked_inputs(input, p).next().unwrap(); (link, link) @@ -504,7 +510,13 @@ pub(in crate::hugr::rewrite) mod test { .collect(); let outputs = h .node_inputs(output) - .filter(|&p| h.get_optype(output).signature().port_type(p).is_some()) + .filter(|&p| { + h.get_optype(output) + .signature() + .unwrap() + .port_type(p) + .is_some() + }) .map(|p| ((output, p), p)) .collect(); h.apply_rewrite(SimpleReplacement::new( @@ -556,12 +568,7 @@ pub(in crate::hugr::rewrite) mod test { let outputs = repl .node_inputs(repl_output) - .filter(|&p| { - repl.get_optype(repl_output) - .signature() - .port_type(p) - .is_some() - }) + .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some()) .map(|p| ((repl_output, p), p)) .collect(); diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 67e209ac1..b1f99a4a2 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -166,7 +166,7 @@ impl TryFrom<&Hugr> for SerHugrV0 { .expect("Could not reach one of the nodes"); let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| { - let sig = hugr.signature(node); + let sig = hugr.signature(node).unwrap_or_default(); let offset = match offset < sig.port_count(dir) { true => Some(offset as u16), false => None, diff --git a/src/hugr/views.rs b/src/hugr/views.rs index d0b8d7af4..1f4a22974 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -391,7 +391,7 @@ pub trait HugrView: sealed::HugrInternals { /// Get the "signature" (incoming and outgoing types) of a node, non-Value /// kind edges will be missing. - fn signature(&self, node: Node) -> FunctionType { + fn signature(&self, node: Node) -> Option { self.get_optype(node).signature() } @@ -399,7 +399,7 @@ pub trait HugrView: sealed::HugrInternals { /// Iterator over all ports in a given direction that have Value type, along /// with corresponding types. fn value_types(&self, node: Node, dir: Direction) -> impl Iterator { - let sig = self.signature(node); + let sig = self.signature(node).unwrap_or_default(); self.node_ports(node, dir) .flat_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone()))) } diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 77f74f39b..7d54469f4 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -472,7 +472,7 @@ mod test { fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { let root = simple_dfg_hugr.root(); let case_nodetype = NodeType::new_open(crate::ops::Case { - signature: simple_dfg_hugr.root_type().op_signature(), + signature: simple_dfg_hugr.root_type().op_signature().unwrap(), }); let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); // As expected, we cannot replace the root with a Case diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 443c4864b..0b1fb9135 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -298,7 +298,7 @@ impl SiblingSubgraph { .iter() .map(|part| { let &(n, p) = part.iter().next().expect("is non-empty"); - let sig = hugr.signature(n); + let sig = hugr.signature(n).expect("must have dataflow signature"); sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); @@ -306,7 +306,7 @@ impl SiblingSubgraph { .outputs .iter() .map(|&(n, p)| { - let sig = hugr.signature(n); + let sig = hugr.signature(n).expect("must have dataflow signature"); sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); @@ -348,7 +348,7 @@ impl SiblingSubgraph { let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else { return Err(InvalidReplacement::InvalidDataflowParent); }; - if dfg_optype.signature() != self.signature(hugr) { + if dfg_optype.signature() != Some(self.signature(hugr)) { return Err(InvalidReplacement::InvalidSignature); } @@ -356,10 +356,16 @@ impl SiblingSubgraph { // See https://github.com/CQCL-DEV/hugr/discussions/432 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p)); let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p)); - let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = - rep_inputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some()); - let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = - rep_outputs.partition(|&(n, p)| replacement.signature(n).port_type(p).is_some()); + let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs.partition(|&(n, p)| { + replacement + .signature(n) + .is_some_and(|s| s.port_type(p).is_some()) + }); + let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs.partition(|&(n, p)| { + replacement + .signature(n) + .is_some_and(|s| s.port_type(p).is_some()) + }); if combine_in_out(&vec![out_order_ports], &in_order_ports) .any(|(n, p)| is_order_edge(&replacement, n, p)) @@ -467,10 +473,13 @@ impl<'g, Base: HugrView> ConvexChecker<'g, Base> { /// If the array is empty or a port does not exist, returns `None`. fn get_edge_type + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option { let &(n, p) = ports.first()?; - let edge_t = hugr.signature(n).port_type(p)?.clone(); + let edge_t = hugr.signature(n)?.port_type(p)?.clone(); ports .iter() - .all(|&(n, p)| hugr.signature(n).port_type(p) == Some(&edge_t)) + .all(|&(n, p)| { + hugr.signature(n) + .is_some_and(|s| s.port_type(p) == Some(&edge_t)) + }) .then_some(edge_t) } @@ -567,11 +576,19 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort if has_other_edge(hugr, inp, Direction::Outgoing) { unimplemented!("Non-dataflow output not supported at input node") } - let dfg_inputs = hugr.get_optype(inp).signature().output_ports(); + let dfg_inputs = hugr + .get_optype(inp) + .signature() + .unwrap_or_default() + .output_ports(); if has_other_edge(hugr, out, Direction::Incoming) { unimplemented!("Non-dataflow input not supported at output node") } - let dfg_outputs = hugr.get_optype(out).signature().input_ports(); + let dfg_outputs = hugr + .get_optype(out) + .signature() + .unwrap_or_default() + .input_ports(); // Collect for each port in the input the set of target ports, filtering // direct wires to the output. diff --git a/src/ops.rs b/src/ops.rs index e764648dc..eceb7185e 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -77,12 +77,12 @@ impl OpType { /// Returns the edge kind for the given port. pub fn port_kind(&self, port: impl Into) -> Option { - let signature = self.signature(); + let signature = self.signature().unwrap_or_default(); let port: Port = port.into(); + let port_as_in = port.as_incoming().ok(); let dir = port.direction(); let port_count = signature.port_count(dir); - let port_as_in = port.as_incoming().ok(); if port.index() < port_count { signature.port_type(port).cloned().map(EdgeKind::Value) } else if port_as_in.is_some() && port_as_in == self.static_input_port() { @@ -103,15 +103,27 @@ impl OpType { let static_input = (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; - Some(Port::new( - dir, - self.signature().port_count(dir) + static_input, - )) + Some(Port::new(dir, self.value_port_count(dir) + static_input)) } else { None } } + /// The number of Value ports in given direction. + pub fn value_port_count(&self, dir: portgraph::Direction) -> usize { + self.signature().map(|sig| sig.port_count(dir)).unwrap_or(0) + } + + /// The number of Value input ports. + pub fn value_input_count(&self) -> usize { + self.value_port_count(Direction::Incoming) + } + + /// The number of Value output ports. + pub fn value_output_count(&self) -> usize { + self.value_port_count(Direction::Outgoing) + } + /// The non-dataflow input port for the operation, not described by the signature. /// See `[OpType::other_port]`. pub fn other_input_port(&self) -> Option { @@ -151,7 +163,6 @@ impl OpType { /// Returns the number of ports for the given direction. pub fn port_count(&self, dir: Direction) -> usize { - let signature = self.signature(); let has_other_ports = self.other_port_kind(dir).is_some(); let non_df_count = self .non_df_port_count(dir) @@ -159,7 +170,7 @@ impl OpType { // if there is a static input it comes before the non_df_ports let static_input = (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; - signature.port_count(dir) + non_df_count + static_input + self.value_port_count(dir) + non_df_count + static_input } /// Returns the number of inputs ports for the operation. @@ -228,8 +239,8 @@ pub trait OpTrait { /// The signature of the operation. /// - /// Only dataflow operations have a non-empty signature. - fn signature(&self) -> FunctionType { + /// Only dataflow operations have a signature, otherwise returns None. + fn signature(&self) -> Option { Default::default() } /// The edge kind for the non-dataflow or constant inputs of the operation, diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index b6836ea70..a54aac41d 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -30,7 +30,7 @@ impl DataflowOpTrait for TailLoop { "A tail-controlled loop" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| tuple_sum_first(row, &self.rest)); FunctionType::new(inputs, outputs) @@ -74,7 +74,7 @@ impl DataflowOpTrait for Conditional { "HUGR conditional operation" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { let mut inputs = self.other_inputs.clone(); inputs .to_mut() @@ -109,7 +109,7 @@ impl DataflowOpTrait for CFG { "A dataflow node defined by a child CFG" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { self.signature.clone() } } @@ -169,13 +169,13 @@ impl OpTrait for BasicBlock { Some(EdgeKind::ControlFlow) } - fn signature(&self) -> FunctionType { - match self { + fn signature(&self) -> Option { + Some(match self { BasicBlock::DFB { extension_delta, .. } => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta), BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), - } + }) } fn non_df_port_count(&self, dir: Direction) -> Option { @@ -233,8 +233,8 @@ impl OpTrait for Case { ::TAG } - fn signature(&self) -> FunctionType { - self.signature.clone() + fn signature(&self) -> Option { + Some(self.signature.clone()) } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index b1c5a39b3..724b64984 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -82,10 +82,10 @@ impl OpTrait for ExternalOp { /// Note the case of an OpaqueOp without a signature should already /// have been detected in [resolve_extension_ops] - fn signature(&self) -> FunctionType { + fn signature(&self) -> Option { match self { - Self::Opaque(op) => op.signature.clone().unwrap(), - Self::Extension(ExtensionOp { signature, .. }) => signature.clone(), + Self::Opaque(op) => op.signature.clone(), + Self::Extension(ExtensionOp { signature, .. }) => Some(signature.clone()), } } } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 4ab06d5cf..fe109ab15 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -6,10 +6,10 @@ use crate::extension::ExtensionSet; use crate::ops::StaticTag; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; -pub(super) trait DataflowOpTrait { +pub(crate) trait DataflowOpTrait { const TAG: OpTag; fn description(&self) -> &str; - fn signature(&self) -> FunctionType; + fn dataflow_signature(&self) -> FunctionType; /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. @@ -81,7 +81,7 @@ impl DataflowOpTrait for Input { None } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), self.types.clone()) .with_extension_delta(&ExtensionSet::new()) } @@ -95,7 +95,7 @@ impl DataflowOpTrait for Output { // Note: We know what the input extensions should be, so we *could* give an // instantiated Signature instead - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { FunctionType::new(self.types.clone(), TypeRow::new()) } @@ -111,8 +111,8 @@ impl OpTrait for T { fn tag(&self) -> OpTag { T::TAG } - fn signature(&self) -> FunctionType { - DataflowOpTrait::signature(self) + fn signature(&self) -> Option { + Some(DataflowOpTrait::dataflow_signature(self)) } fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) @@ -145,7 +145,7 @@ impl DataflowOpTrait for Call { "Call a function directly" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { self.signature.clone() } } @@ -172,7 +172,7 @@ impl DataflowOpTrait for CallIndirect { "Call a function indirectly" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { let mut s = self.signature.clone(); s.input .to_mut() @@ -195,7 +195,7 @@ impl DataflowOpTrait for LoadConstant { "Load a static constant in to the local dataflow graph" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), vec![self.datatype.clone()]) } } @@ -222,7 +222,7 @@ impl DataflowOpTrait for DFG { "A simply nested dataflow graph" } - fn signature(&self) -> FunctionType { + fn dataflow_signature(&self) -> FunctionType { self.signature.clone() } } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 24f1ebb7e..ef9da58d8 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -153,13 +153,13 @@ impl OpTrait for LeafOp { } /// The signature of the operation. - fn signature(&self) -> FunctionType { + fn signature(&self) -> Option { // Static signatures. The `TypeRow`s in the `FunctionType` use a // copy-on-write strategy, so we can avoid unnecessary allocations. - match self { + Some(match self { LeafOp::Noop { ty: typ } => FunctionType::new(vec![typ.clone()], vec![typ.clone()]), - LeafOp::CustomOp(ext) => ext.signature(), + LeafOp::CustomOp(ext) => return ext.signature(), LeafOp::MakeTuple { tys: types } => { FunctionType::new(types.clone(), vec![Type::new_tuple(types.clone())]) } @@ -179,7 +179,7 @@ impl OpTrait for LeafOp { vec![Type::new_function(ta.input.clone())], vec![Type::new_function(ta.output.clone())], ), - } + }) } fn other_input(&self) -> Option { diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 546d6cec6..28a9bdb55 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -102,12 +102,8 @@ impl ValidateOp for super::DFG { &self, children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { - validate_io_nodes( - &self.signature().input, - &self.signature().output, - "nested graph", - children, - ) + let sig = self.signature().unwrap_or_default(); + validate_io_nodes(&sig.input, &sig.output, "nested graph", children) } } @@ -371,19 +367,22 @@ fn validate_io_nodes<'a>( let (first, first_optype) = children.next().unwrap(); let (second, second_optype) = children.next().unwrap(); - if &first_optype.signature().output != expected_input { + let first_sig = first_optype.signature().unwrap_or_default(); + if &first_sig.output != expected_input { return Err(ChildrenValidationError::IOSignatureMismatch { child: first, - actual: first_optype.signature().output, + actual: first_sig.output, expected: expected_input.clone(), node_desc: "Input", container_desc, }); } - if &second_optype.signature().input != expected_output { + let second_sig = second_optype.signature().unwrap_or_default(); + + if &second_sig.input != expected_output { return Err(ChildrenValidationError::IOSignatureMismatch { child: second, - actual: second_optype.signature().input, + actual: second_sig.input, expected: expected_output.clone(), node_desc: "Output", container_desc, From b1db6cc487388dfd2f8d96e51372ec9658ddb29a Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 12:46:28 +0000 Subject: [PATCH 15/31] refactor: impl DataFlowOpTrait for LeafOp and use dataflow_signature --- src/hugr/rewrite/outline_cfg.rs | 8 ++++---- src/hugr/rewrite/replace.rs | 12 ++++++------ src/ops/leaf.rs | 23 ++++++++++------------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index e24316bfb..b161f27ee 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -12,8 +12,9 @@ use crate::hugr::rewrite::Rewrite; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::{HugrMut, HugrView}; use crate::ops; +use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; -use crate::ops::{BasicBlock, OpTrait, OpType}; +use crate::ops::{BasicBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; @@ -49,7 +50,7 @@ impl OutlineCfg { _ => return Err(OutlineCfgError::NotSiblings), }; let o = h.get_optype(cfg_n); - if !matches!(o, OpType::CFG(_)) { + let OpType::CFG(o) = o else { return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); }; let cfg_entry = h.children(cfg_n).next().unwrap(); @@ -70,8 +71,7 @@ impl OutlineCfg { } } } - extension_delta = extension_delta - .union(&o.signature().expect("cfg missing signature").extension_reqs); + extension_delta = extension_delta.union(&o.dataflow_signature().extension_reqs); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 9099e9913..44edc5291 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -447,8 +447,9 @@ mod test { use crate::hugr::rewrite::replace::WhichHugr; use crate::hugr::{HugrMut, NodeType, Rewrite}; use crate::ops::custom::{ExternalOp, OpaqueOp}; + use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; - use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpTrait, OpType, DFG}; + use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpType, DFG}; use crate::std_extensions::collections; use crate::types::{FunctionType, Type, TypeArg, TypeRow}; use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; @@ -609,14 +610,13 @@ mod test { .unwrap() } - fn single_node_block + AsMut>( + fn single_node_block + AsMut, O: DataflowOpTrait + Into>( h: &mut CFGBuilder, - op: impl Into, + op: O, pred_const: &ConstID, entry: bool, ) -> Result { - let op: OpType = op.into(); - let op_sig = op.signature().expect("dataflow op needs signature"); + let op_sig = op.dataflow_signature(); let mut bb = if entry { assert_eq!( match h.hugr().get_optype(h.container_node()) { @@ -629,7 +629,7 @@ mod test { } else { h.simple_block_builder(op_sig, 1)? }; - + let op: OpType = op.into(); let op = bb.add_dataflow_op(op, bb.input_wires())?; let load_pred = bb.load_const(pred_const)?; bb.finish_with_outputs(load_pred, op.outputs()) diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index ef9da58d8..1aab93b8b 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -3,7 +3,8 @@ use smol_str::SmolStr; use super::custom::ExternalOp; -use super::{OpName, OpTag, OpTrait, StaticTag}; +use super::dataflow::DataflowOpTrait; +use super::{OpName, OpTag, OpTrait}; use crate::extension::{ExtensionRegistry, SignatureError}; use crate::types::type_param::TypeArg; @@ -128,11 +129,11 @@ impl OpName for LeafOp { } } -impl StaticTag for LeafOp { - const TAG: OpTag = OpTag::Leaf; -} +// impl StaticTag for LeafOp { +// } -impl OpTrait for LeafOp { +impl DataflowOpTrait for LeafOp { + const TAG: OpTag = OpTag::Leaf; /// A human-readable description of the operation. fn description(&self) -> &str { match self { @@ -148,18 +149,14 @@ impl OpTrait for LeafOp { } } - fn tag(&self) -> OpTag { - ::TAG - } - /// The signature of the operation. - fn signature(&self) -> Option { + fn dataflow_signature(&self) -> FunctionType { // Static signatures. The `TypeRow`s in the `FunctionType` use a // copy-on-write strategy, so we can avoid unnecessary allocations. - Some(match self { + match self { LeafOp::Noop { ty: typ } => FunctionType::new(vec![typ.clone()], vec![typ.clone()]), - LeafOp::CustomOp(ext) => return ext.signature(), + LeafOp::CustomOp(ext) => ext.signature().unwrap_or_default(), LeafOp::MakeTuple { tys: types } => { FunctionType::new(types.clone(), vec![Type::new_tuple(types.clone())]) } @@ -179,7 +176,7 @@ impl OpTrait for LeafOp { vec![Type::new_function(ta.input.clone())], vec![Type::new_function(ta.output.clone())], ), - }) + } } fn other_input(&self) -> Option { From 1677014f2f8e7052e960826d8dcb246a2b85e9b0 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 14:20:35 +0000 Subject: [PATCH 16/31] feat: use macro for OpType reference casting to inner --- Cargo.toml | 3 ++- src/ops.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f445091bb..0c8e7ae8c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ context-iterators = "0.2.0" serde_json = "1.0.97" delegate = "0.10.0" rustversion = "1.0.14" +paste = "1.0" [features] pyo3 = ["dep:pyo3"] @@ -72,4 +73,4 @@ harness = false [profile.dev.package] insta.opt-level = 3 -similar.opt-level = 3 \ No newline at end of file +similar.opt-level = 3 diff --git a/src/ops.rs b/src/ops.rs index eceb7185e..7193c10c7 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -12,6 +12,7 @@ pub mod validate; use crate::types::{EdgeKind, FunctionType, Type}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; +use paste::paste; use portgraph::NodeIndex; use smol_str::SmolStr; @@ -53,6 +54,45 @@ pub enum OpType { Case, } +macro_rules! impl_op_ref_try_into { + ($Op: ident, $sname:expr) => { + paste! { + impl OpType { + #[doc = "If is an instance of `" $Op "` return a reference to it."] + pub fn [](&self) -> Option<&$Op> { + if let OpType::$Op(l) = self { + Some(l) + } else { + None + } + } + } + } + }; + ($name:tt) => { + impl_op_ref_try_into!($name, stringify!($name)); + }; +} + +impl_op_ref_try_into!(Module); +impl_op_ref_try_into!(FuncDefn); +impl_op_ref_try_into!(FuncDecl); +impl_op_ref_try_into!(AliasDecl); +impl_op_ref_try_into!(AliasDefn); +impl_op_ref_try_into!(Const); +impl_op_ref_try_into!(Input); +impl_op_ref_try_into!(Output); +impl_op_ref_try_into!(Call); +impl_op_ref_try_into!(CallIndirect); +impl_op_ref_try_into!(LoadConstant); +impl_op_ref_try_into!(DFG); +impl_op_ref_try_into!(LeafOp); +impl_op_ref_try_into!(BasicBlock); +impl_op_ref_try_into!(TailLoop); +impl_op_ref_try_into!(CFG); +impl_op_ref_try_into!(Conditional); +impl_op_ref_try_into!(Case); + /// The default OpType (as returned by [Default::default]) pub const DEFAULT_OPTYPE: OpType = OpType::Module(Module); From 5fd10357e32ad01457e3507f730faf8a9b6d758f Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 14:52:25 +0000 Subject: [PATCH 17/31] refactor: use enum to simplify optype pattern matches --- src/builder/module.rs | 17 ++++++++--------- src/builder/tail_loop.rs | 11 +++++------ src/hugr/rewrite/replace.rs | 4 ++-- src/hugr/rewrite/simple_replace.rs | 7 +++++-- src/hugr/views/sibling_subgraph.rs | 20 +++++++++----------- src/ops.rs | 17 +++++++++++------ src/ops/validate.rs | 12 ++++++------ 7 files changed, 46 insertions(+), 42 deletions(-) diff --git a/src/builder/module.rs b/src/builder/module.rs index fb3e2a1e0..939df96aa 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -12,7 +12,6 @@ use crate::{ }; use crate::ops::handle::{AliasID, FuncID, NodeHandle}; -use crate::ops::OpType; use crate::types::Signature; @@ -78,16 +77,16 @@ impl + AsRef> ModuleBuilder { f_id: &FuncID, ) -> Result, BuildError> { let f_node = f_id.node(); - let (signature, name) = if let OpType::FuncDecl(ops::FuncDecl { signature, name }) = - self.hugr().get_optype(f_node) - { - (signature.clone(), name.clone()) - } else { - return Err(BuildError::UnexpectedType { + let ops::FuncDecl { signature, name } = self + .hugr() + .get_optype(f_node) + .as_func_decl() + .ok_or(BuildError::UnexpectedType { node: f_node, op_desc: "OpType::FuncDecl", - }); - }; + })? + .clone(); + self.hugr_mut().replace_op( f_node, NodeType::new_pure(ops::FuncDefn { diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 5eb8286e1..8ea5113ae 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -1,4 +1,4 @@ -use crate::ops::{self, OpType}; +use crate::ops; use crate::hugr::{views::HugrView, NodeType}; use crate::types::{FunctionType, TypeRow}; @@ -38,14 +38,13 @@ impl + AsRef> TailLoopBuilder { /// Get a reference to the [`ops::TailLoop`] /// that defines the signature of the [`ops::TailLoop`] pub fn loop_signature(&self) -> Result<&ops::TailLoop, BuildError> { - if let OpType::TailLoop(tail_loop) = self.hugr().get_optype(self.container_node()) { - Ok(tail_loop) - } else { - Err(BuildError::UnexpectedType { + self.hugr() + .get_optype(self.container_node()) + .as_tail_loop() + .ok_or(BuildError::UnexpectedType { node: self.container_node(), op_desc: "crate::ops::TailLoop", }) - } } /// The output types of the child graph, including the TupleSum as the first. diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 44edc5291..a83162dd4 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -588,8 +588,8 @@ mod test { let popp = h.get_parent(pop).unwrap(); let pushp = h.get_parent(push).unwrap(); assert_ne!(popp, pushp); // Two different DFGs - assert!(matches!(h.get_optype(popp), OpType::DFG(_))); - assert!(matches!(h.get_optype(pushp), OpType::DFG(_))); + assert!(h.get_optype(popp).is_dfg()); + assert!(h.get_optype(pushp).is_dfg()); let grandp = h.get_parent(popp).unwrap(); assert_eq!(grandp, h.get_parent(pushp).unwrap()); diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index d0922f233..e58b952d1 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -222,6 +222,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; + use crate::ops::dataflow::DataflowOpTrait; use crate::ops::OpTag; use crate::ops::{OpTrait, OpType}; use crate::std_extensions::logic::test::and_op; @@ -498,8 +499,9 @@ pub(in crate::hugr::rewrite) mod test { .node_outputs(input) .filter(|&p| { h.get_optype(input) - .signature() + .as_input() .unwrap() + .dataflow_signature() .port_type(p) .is_some() }) @@ -512,8 +514,9 @@ pub(in crate::hugr::rewrite) mod test { .node_inputs(output) .filter(|&p| { h.get_optype(output) - .signature() + .as_output() .unwrap() + .dataflow_signature() .port_type(p) .is_some() }) diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 0b1fb9135..e7e35358a 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -19,6 +19,7 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::extension::ExtensionSet; use crate::hugr::{HugrError, HugrMut, HugrView, RootTagged}; +use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; use crate::ops::{OpTag, OpTrait}; use crate::types::Signature; @@ -578,16 +579,18 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort } let dfg_inputs = hugr .get_optype(inp) - .signature() - .unwrap_or_default() + .as_input() + .unwrap() + .dataflow_signature() .output_ports(); if has_other_edge(hugr, out, Direction::Incoming) { unimplemented!("Non-dataflow input not supported at output node") } let dfg_outputs = hugr .get_optype(out) - .signature() - .unwrap_or_default() + .as_output() + .unwrap() + .dataflow_signature() .input_ports(); // Collect for each port in the input the set of target ports, filtering @@ -710,10 +713,7 @@ mod tests { }, hugr::views::{HierarchyView, SiblingGraph}, hugr::HugrMut, - ops::{ - handle::{DfgID, FuncID, NodeHandle}, - OpType, - }, + ops::handle::{DfgID, FuncID, NodeHandle}, std_extensions::logic::test::{and_op, not_op}, type_row, }; @@ -976,9 +976,7 @@ mod tests { let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); - let OpType::FuncDefn(func_defn) = hugr.get_optype(func_root) else { - panic!() - }; + let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); assert_eq!(func_defn.signature, func.signature(&func_graph)) } diff --git a/src/ops.rs b/src/ops.rs index 7193c10c7..e6185b386 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -55,22 +55,27 @@ pub enum OpType { } macro_rules! impl_op_ref_try_into { - ($Op: ident, $sname:expr) => { + ($Op: tt, $sname:ident) => { paste! { impl OpType { #[doc = "If is an instance of `" $Op "` return a reference to it."] - pub fn [](&self) -> Option<&$Op> { + pub fn [](&self) -> Option<&$Op> { if let OpType::$Op(l) = self { Some(l) } else { None } } + + #[doc = "If is an instance of `" $Op "`."] + pub fn [](&self) -> bool { + self.[]().is_some() + } } } }; - ($name:tt) => { - impl_op_ref_try_into!($name, stringify!($name)); + ($Op:tt) => { + impl_op_ref_try_into!($Op, $Op); }; } @@ -85,11 +90,11 @@ impl_op_ref_try_into!(Output); impl_op_ref_try_into!(Call); impl_op_ref_try_into!(CallIndirect); impl_op_ref_try_into!(LoadConstant); -impl_op_ref_try_into!(DFG); +impl_op_ref_try_into!(DFG, dfg); impl_op_ref_try_into!(LeafOp); impl_op_ref_try_into!(BasicBlock); impl_op_ref_try_into!(TailLoop); -impl_op_ref_try_into!(CFG); +impl_op_ref_try_into!(CFG, cfg); impl_op_ref_try_into!(Conditional); impl_op_ref_try_into!(Case); diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 28a9bdb55..a914f6247 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -135,9 +135,9 @@ impl ValidateOp for super::Conditional { // Each child must have its variant's row and the rest of `inputs` as input, // and matching output for (i, (child, optype)) in children.into_iter().enumerate() { - let OpType::Case(case_op) = optype else { - panic!("Child check should have already checked valid ops.") - }; + let case_op = optype + .as_case() + .expect("Child check should have already checked valid ops."); let sig = &case_op.signature; if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs { return Err(ChildrenValidationError::ConditionalCaseSignature { @@ -415,9 +415,9 @@ fn validate_io_nodes<'a>( /// Validate an edge between two basic blocks in a CFG sibling graph. fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> { let [source, target]: [&BasicBlock; 2] = [&edge.source_op, &edge.target_op].map(|op| { - let OpType::BasicBlock(block_op) = op else { - panic!("CFG sibling graphs can only contain basic block operations.") - }; + let block_op = op + .as_basic_block() + .expect("CFG sibling graphs can only contain basic block operations."); block_op }); From 09844a123374595b59b1c230729451bb5e68781a Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 15:05:38 +0000 Subject: [PATCH 18/31] chore: remove paste from dev-dependencies --- Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0c8e7ae8c..9e2fd8af7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,6 @@ rmp-serde = "1.1.1" webbrowser = "0.8.10" urlencoding = "2.1.2" cool_asserts = "2.0.3" -paste = "1.0" insta = { version = "1.34.0", features = ["yaml"] } [[bench]] From b124160ac50beec23b5cf504a2368f3a00a9d92b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 15:11:00 +0000 Subject: [PATCH 19/31] feat: `all_linked_ports` --- src/hugr/views.rs | 22 +++++++++++++++++++--- src/hugr/views/tests.rs | 41 ++++++++++++----------------------------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 1f4a22974..327177658 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -18,7 +18,7 @@ pub use sibling::SiblingGraph; pub use sibling_subgraph::SiblingSubgraph; use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; +use itertools::{Either, Itertools, MapInto}; use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle}; use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; @@ -182,14 +182,30 @@ pub trait HugrView: sealed::HugrInternals { fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; #[rustversion::since(1.75)] // uses impl in return position - /// Iterator over all the nodes and ports connected to a node's inputs.. + /// Iterator over all the nodes and ports connected to a node. + /// First the OutgoingPorts connected to the inputs of this node, + /// then the IncomingPorts connected to the outputs of this node. + fn all_linked_ports( + &self, + node: Node, + ) -> impl Iterator)> { + self.all_linked_outputs(node) + .map(|(n, p)| (n, Either::Left(p))) + .chain( + self.all_linked_inputs(node) + .map(|(n, p)| (n, Either::Right(p))), + ) + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all the nodes and ports connected to a node's inputs. fn all_linked_outputs(&self, node: Node) -> impl Iterator { self.node_inputs(node) .flat_map(move |port| self.linked_outputs(node, port)) } #[rustversion::since(1.75)] // uses impl in return position - /// Iterator over all the nodes and ports connected to a node's outputs.. + /// Iterator over all the nodes and ports connected to a node's outputs. fn all_linked_inputs(&self, node: Node) -> impl Iterator { self.node_outputs(node) .flat_map(move |port| self.linked_inputs(node, port)) diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 65601bf15..89da2e5c5 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -1,3 +1,4 @@ +use itertools::Either; use portgraph::PortOffset; use rstest::{fixture, rstest}; @@ -72,8 +73,6 @@ fn dot_string(sample_hugr: (Hugr, BuildHandle, BuildHandle, BuildHandle)) { - use crate::hugr::Direction; - use crate::Port; use itertools::Itertools; let (h, n1, n2) = sample_hugr; @@ -82,38 +81,22 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle Date: Tue, 14 Nov 2023 15:18:24 +0000 Subject: [PATCH 20/31] address minor review comments --- src/hugr/views.rs | 4 ++-- src/ops.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 327177658..f08ce8814 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -406,13 +406,13 @@ pub trait HugrView: sealed::HugrInternals { } /// Get the "signature" (incoming and outgoing types) of a node, non-Value - /// kind edges will be missing. + /// kind ports will be missing. fn signature(&self, node: Node) -> Option { self.get_optype(node).signature() } #[rustversion::since(1.75)] // uses impl in return position - /// Iterator over all ports in a given direction that have Value type, along + /// Iterator over all outgoing ports that have Value type, along /// with corresponding types. fn value_types(&self, node: Node, dir: Direction) -> impl Iterator { let sig = self.signature(node).unwrap_or_default(); diff --git a/src/ops.rs b/src/ops.rs index e6185b386..28ab9c739 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -286,7 +286,7 @@ pub trait OpTrait { /// /// Only dataflow operations have a signature, otherwise returns None. fn signature(&self) -> Option { - Default::default() + None } /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. From 44cb5b011ebc7ccb58e2916bb768e797e56e358d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 15:21:37 +0000 Subject: [PATCH 21/31] feat!: single_linked_port, single_linked_input, single_linked_output --- src/hugr/rewrite/insert_identity.rs | 2 +- src/hugr/rewrite/outline_cfg.rs | 2 +- src/hugr/rewrite/replace.rs | 2 +- src/hugr/rewrite/simple_replace.rs | 16 +++++++++------- src/hugr/views.rs | 18 ++++++++++-------- src/hugr/views/sibling_subgraph.rs | 4 ++-- 6 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs index 53d2790fe..e3fd29318 100644 --- a/src/hugr/rewrite/insert_identity.rs +++ b/src/hugr/rewrite/insert_identity.rs @@ -72,7 +72,7 @@ impl Rewrite for IdentityInsertion { }; let (pre_node, pre_port) = h - .single_source(self.post_node, self.post_port) + .single_linked_output(self.post_node, self.post_port) .expect("Value kind input can only have one connection."); h.disconnect(self.post_node, self.post_port).unwrap(); diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index b161f27ee..951b4cf17 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -178,7 +178,7 @@ impl Rewrite for OutlineCfg { let exit_port = h .node_outputs(exit) .filter(|p| { - let (t, p2) = h.single_target(exit, *p).unwrap(); + let (t, p2) = h.single_linked_input(exit, *p).unwrap(); assert!(p2.index() == 0); t == outside }) diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index a83162dd4..c321b00d8 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -139,7 +139,7 @@ impl NewEdgeSpec { true }; let found_incoming = h - .single_source(self.tgt, tgt_pos) + .single_linked_output(self.tgt, tgt_pos) .is_some_and(|(src_n, _)| descends_from_legal(src_n)); if !found_incoming { return Err(ReplaceError::NoRemovedEdge(err_edge())); diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index e58b952d1..25e266bc1 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -127,8 +127,9 @@ impl Rewrite for SimpleReplacement { for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp { if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) - let (rem_inp_pred_node, rem_inp_pred_port) = - h.single_source(*rem_inp_node, *rem_inp_port).unwrap(); + let (rem_inp_pred_node, rem_inp_pred_port) = h + .single_linked_output(*rem_inp_node, *rem_inp_port) + .unwrap(); h.disconnect(*rem_inp_node, *rem_inp_port).unwrap(); let new_inp_node = index_map.get(rep_inp_node).unwrap(); h.connect( @@ -145,7 +146,7 @@ impl Rewrite for SimpleReplacement { for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { let (rep_out_pred_node, rep_out_pred_port) = self .replacement - .single_source(replacement_output_node, *rep_out_port) + .single_linked_output(replacement_output_node, *rep_out_port) .unwrap(); if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input { let new_out_node = index_map.get(&rep_out_pred_node).unwrap(); @@ -165,8 +166,9 @@ impl Rewrite for SimpleReplacement { let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port)); if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): - let (rem_inp_pred_node, rem_inp_pred_port) = - h.single_source(*rem_inp_node, *rem_inp_port).unwrap(); + let (rem_inp_pred_node, rem_inp_pred_port) = h + .single_linked_output(*rem_inp_node, *rem_inp_port) + .unwrap(); h.disconnect(*rem_inp_node, *rem_inp_port).unwrap(); h.disconnect(*rem_out_node, *rem_out_port).unwrap(); h.connect( @@ -604,7 +606,7 @@ pub(in crate::hugr::rewrite) mod test { if *tgt == out { unimplemented!() }; - let (src, src_port) = h.single_source(*r_n, *r_p).unwrap(); + let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap(); NewEdgeSpec { src, tgt: *tgt, @@ -619,7 +621,7 @@ pub(in crate::hugr::rewrite) mod test { .nu_out .iter() .map(|((tgt, tgt_port), out_port)| { - let (src, src_port) = replacement.single_source(out, *out_port).unwrap(); + let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap(); if src == in_ { unimplemented!() }; diff --git a/src/hugr/views.rs b/src/hugr/views.rs index f08ce8814..55249e881 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -211,29 +211,31 @@ pub trait HugrView: sealed::HugrInternals { .flat_map(move |port| self.linked_inputs(node, port)) } + /// If there is exactly one port connected to this port, return + /// it and its node. + fn single_linked_port(&self, node: Node, port: impl Into) -> Option<(Node, Port)> { + self.linked_ports(node, port).exactly_one().ok() + } + /// If there is exactly one OutgoingPort connected to this IncomingPort, return /// it and its node. - fn single_source( + fn single_linked_output( &self, node: Node, port: impl Into, ) -> Option<(Node, OutgoingPort)> { - self.linked_ports(node, port.into()) - .exactly_one() - .ok() + self.single_linked_port(node, port.into()) .map(|(n, p)| (n, p.as_outgoing().unwrap())) } /// If there is exactly one IncomingPort connected to this OutgoingPort, return /// it and its node. - fn single_target( + fn single_linked_input( &self, node: Node, port: impl Into, ) -> Option<(Node, IncomingPort)> { - self.linked_ports(node, port.into()) - .exactly_one() - .ok() + self.single_linked_port(node, port.into()) .map(|(n, p)| (n, p.as_incoming().unwrap())) } /// Iterator over the nodes and output ports connected to a given *input* port. diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index e7e35358a..5abb4ed94 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -252,7 +252,7 @@ impl SiblingSubgraph { if !hugr.is_linked(n, p) { return false; } - let (out_n, _) = hugr.single_source(n, p).unwrap(); + let (out_n, _) = hugr.single_linked_output(n, p).unwrap(); !nodes_set.contains(&out_n) }) // Every incoming edge is its own input. @@ -899,7 +899,7 @@ mod tests { .collect(), hugr.node_inputs(out) .take(2) - .filter_map(|p| hugr.single_source(out, p)) + .filter_map(|p| hugr.single_linked_output(out, p)) .collect(), &func, ) From f7bfb466d32ffaf39d56c98acad7a5f2c544461b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 15:33:49 +0000 Subject: [PATCH 22/31] refactor!: allow chaining with `dataflow_ports_only` --- src/hugr/views.rs | 36 ++++++++++++++++++++++++------------ src/hugr/views/tests.rs | 7 +++++-- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 55249e881..7f72f7697 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -608,19 +608,31 @@ impl> HugrView for T { } } -/// Filter an iterator of node-ports to only dataflow dependency specifying -/// ports (Value and StateOrder) -pub fn dataflow_ports_only<'i, 'a: 'i, P: Into + Copy>( - hugr: &'a impl HugrView, - it: impl Iterator + 'i, -) -> impl Iterator + 'i { - it.filter(move |(n, p)| { - matches!( - hugr.get_optype(*n).port_kind(*p), - Some(EdgeKind::Value(_) | EdgeKind::StateOrder) - ) - }) +#[rustversion::since(1.75)] // uses impl in return position +/// Trait implementing methods on port iterators. +pub trait PortIterator

: Iterator +where + P: Into + Copy, + Self: Sized, +{ + /// Filter an iterator of node-ports to only dataflow dependency specifying + /// ports (Value and StateOrder) + fn dataflow_ports_only(self, hugr: &impl HugrView) -> impl Iterator { + self.filter(move |(n, p)| { + matches!( + hugr.get_optype(*n).port_kind(*p), + Some(EdgeKind::Value(_) | EdgeKind::StateOrder) + ) + }) + } } +impl PortIterator

for I +where + I: Iterator, + P: Into + Copy, +{ +} + pub(crate) mod sealed { use super::*; diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 89da2e5c5..96c4c43ae 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -161,7 +161,7 @@ fn static_targets() { fn test_dataflow_ports_only() { use crate::builder::DataflowSubContainer; use crate::extension::prelude::BOOL_T; - use crate::hugr::views::dataflow_ports_only; + use crate::hugr::views::PortIterator; use crate::std_extensions::logic::test::not_op; use itertools::Itertools; let mut dfg = DFGBuilder::new(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap(); @@ -183,7 +183,10 @@ fn test_dataflow_ports_only() { let h = dfg .finish_hugr_with_outputs(not.outputs(), &crate::extension::PRELUDE_REGISTRY) .unwrap(); - let filtered_ports = dataflow_ports_only(&h, h.all_linked_outputs(call.node())).collect_vec(); + let filtered_ports = h + .all_linked_outputs(call.node()) + .dataflow_ports_only(&h) + .collect_vec(); // should ignore the static input in to call, but report the two value ports // and the order port. From 920880ea86d7880822e274179bb3e3c5b9a11eec Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 15:41:25 +0000 Subject: [PATCH 23/31] feat!: parametric all_linked_ports --- src/hugr/views.rs | 36 ++++++++++++++++++++++-------------- src/hugr/views/tests.rs | 14 +++++--------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 7f72f7697..1026f32f5 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -182,33 +182,41 @@ pub trait HugrView: sealed::HugrInternals { fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; #[rustversion::since(1.75)] // uses impl in return position - /// Iterator over all the nodes and ports connected to a node. - /// First the OutgoingPorts connected to the inputs of this node, - /// then the IncomingPorts connected to the outputs of this node. + /// Iterator over all the nodes and ports connected to a node in a given direction. fn all_linked_ports( &self, node: Node, - ) -> impl Iterator)> { - self.all_linked_outputs(node) - .map(|(n, p)| (n, Either::Left(p))) - .chain( - self.all_linked_inputs(node) - .map(|(n, p)| (n, Either::Right(p))), - ) + dir: Direction, + ) -> Either< + impl Iterator, + impl Iterator, + > { + match dir { + Direction::Incoming => Either::Left( + self.node_inputs(node) + .flat_map(move |port| self.linked_outputs(node, port)), + ), + Direction::Outgoing => Either::Right( + self.node_outputs(node) + .flat_map(move |port| self.linked_inputs(node, port)), + ), + } } #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node's inputs. fn all_linked_outputs(&self, node: Node) -> impl Iterator { - self.node_inputs(node) - .flat_map(move |port| self.linked_outputs(node, port)) + self.all_linked_ports(node, Direction::Incoming) + .left() + .unwrap() } #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node's outputs. fn all_linked_inputs(&self, node: Node) -> impl Iterator { - self.node_outputs(node) - .flat_map(move |port| self.linked_inputs(node, port)) + self.all_linked_ports(node, Direction::Outgoing) + .right() + .unwrap() } /// If there is exactly one port connected to this port, return diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 96c4c43ae..44b4ee8df 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -1,4 +1,3 @@ -use itertools::Either; use portgraph::PortOffset; use rstest::{fixture, rstest}; @@ -87,16 +86,13 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle Date: Tue, 14 Nov 2023 15:47:02 +0000 Subject: [PATCH 24/31] refactor: simpler static_input_port --- src/ops.rs | 13 ++----------- src/ops/dataflow.rs | 13 +++++++++++++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index 28ab9c739..26aeace37 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -184,17 +184,8 @@ impl OpType { /// If the op has a static input (Call and LoadConstant), the port of that input. pub fn static_input_port(&self) -> Option { match self { - OpType::Call(call) => Some( - Port::new( - Direction::Incoming, - call.called_function_type().input_count(), - ) - .as_incoming() - .unwrap(), - ), - OpType::LoadConstant(_) => { - Some(Port::new(Direction::Incoming, 0).as_incoming().unwrap()) - } + OpType::Call(call) => Some(call.called_function_port()), + OpType::LoadConstant(l) => Some(l.constant_port()), _ => None, } } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index fe109ab15..777aaf2f3 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -5,6 +5,7 @@ use super::{impl_op_name, OpTag, OpTrait}; use crate::extension::ExtensionSet; use crate::ops::StaticTag; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; +use crate::IncomingPort; pub(crate) trait DataflowOpTrait { const TAG: OpTag; @@ -155,6 +156,12 @@ impl Call { pub fn called_function_type(&self) -> &FunctionType { &self.signature } + + /// The IncomingPort which links to the function being called. + #[inline] + pub fn called_function_port(&self) -> IncomingPort { + self.called_function_type().input_count().into() + } } /// Call a function indirectly. Like call, but the first input is a standard dataflow graph type. @@ -205,6 +212,12 @@ impl LoadConstant { pub fn constant_type(&self) -> &Type { &self.datatype } + + /// The IncomingPort which links to the loaded constant. + #[inline] + pub fn constant_port(&self) -> IncomingPort { + 0.into() + } } /// A simply nested dataflow graph. From 7ea609a95afeb33ee1a19620a52c6b69cff85617 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 16:16:20 +0000 Subject: [PATCH 25/31] refactor: remove uncessary trait impls for ExternalOp --- src/ops/custom.rs | 25 +++++++++++++------------ src/ops/leaf.rs | 4 ++-- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 724b64984..5f2af204f 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -11,7 +11,7 @@ use crate::types::{type_param::TypeArg, FunctionType}; use crate::{Hugr, Node}; use super::tag::OpTag; -use super::{LeafOp, OpName, OpTrait, OpType}; +use super::{LeafOp, OpTrait, OpType}; /// An instantiation of an operation (declared by a extension) with values for the type arguments #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -58,8 +58,9 @@ impl From for LeafOp { } } -impl OpName for ExternalOp { - fn name(&self) -> SmolStr { +impl ExternalOp { + /// Name of the ExternalOp + pub fn name(&self) -> SmolStr { let (res_id, op_name) = match self { Self::Opaque(op) => (&op.extension, &op.op_name), Self::Extension(ExtensionOp { def, .. }) => (def.extension(), def.name()), @@ -68,24 +69,24 @@ impl OpName for ExternalOp { } } -impl OpTrait for ExternalOp { - fn description(&self) -> &str { +impl ExternalOp { + /// A description of the external op. + pub fn description(&self) -> &str { match self { Self::Opaque(op) => op.description.as_str(), Self::Extension(ExtensionOp { def, .. }) => def.description(), } } - fn tag(&self) -> OpTag { - OpTag::Leaf - } - /// Note the case of an OpaqueOp without a signature should already /// have been detected in [resolve_extension_ops] - fn signature(&self) -> Option { + pub fn dataflow_signature(&self) -> FunctionType { match self { - Self::Opaque(op) => op.signature.clone(), - Self::Extension(ExtensionOp { signature, .. }) => Some(signature.clone()), + Self::Opaque(op) => op + .signature + .clone() + .expect("Op should have been serialized with signature."), + Self::Extension(ExtensionOp { signature, .. }) => signature.clone(), } } } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 1aab93b8b..e21e19adf 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -4,7 +4,7 @@ use smol_str::SmolStr; use super::custom::ExternalOp; use super::dataflow::DataflowOpTrait; -use super::{OpName, OpTag, OpTrait}; +use super::{OpName, OpTag}; use crate::extension::{ExtensionRegistry, SignatureError}; use crate::types::type_param::TypeArg; @@ -156,7 +156,7 @@ impl DataflowOpTrait for LeafOp { match self { LeafOp::Noop { ty: typ } => FunctionType::new(vec![typ.clone()], vec![typ.clone()]), - LeafOp::CustomOp(ext) => ext.signature().unwrap_or_default(), + LeafOp::CustomOp(ext) => ext.dataflow_signature(), LeafOp::MakeTuple { tys: types } => { FunctionType::new(types.clone(), vec![Type::new_tuple(types.clone())]) } From 37ceea82902526d3401e33ec224aee92c7bda4bf Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 16:18:16 +0000 Subject: [PATCH 26/31] simplify static i/o tags --- src/ops/tag.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ops/tag.rs b/src/ops/tag.rs index d1e56346d..435f279f1 100644 --- a/src/ops/tag.rs +++ b/src/ops/tag.rs @@ -125,10 +125,10 @@ impl OpTag { ], OpTag::TailLoop => &[OpTag::DataflowChild, OpTag::DataflowParent], OpTag::Conditional => &[OpTag::DataflowChild], - OpTag::StaticInput => &[OpTag::DataflowChild], - OpTag::StaticOutput => &[OpTag::ModuleOp], - OpTag::FnCall => &[OpTag::StaticInput], - OpTag::LoadConst => &[OpTag::StaticInput], + OpTag::StaticInput => &[OpTag::Any], + OpTag::StaticOutput => &[OpTag::Any], + OpTag::FnCall => &[OpTag::StaticInput, OpTag::DataflowChild], + OpTag::LoadConst => &[OpTag::StaticInput, OpTag::DataflowChild], OpTag::Leaf => &[OpTag::DataflowChild], OpTag::DataflowParent => &[OpTag::Any], } @@ -156,8 +156,8 @@ impl OpTag { OpTag::Cfg => "Nested control-flow operation", OpTag::TailLoop => "Tail-recursive loop", OpTag::Conditional => "Conditional operation", - OpTag::StaticInput => "Dataflow child with static input (LoadConst or FnCall)", - OpTag::StaticOutput => "Node with static input (FuncDefn, FuncDecl, Const)", + OpTag::StaticInput => "Node with static input (LoadConst or FnCall)", + OpTag::StaticOutput => "Node with static output (FuncDefn, FuncDecl, Const)", OpTag::FnCall => "Function call", OpTag::LoadConst => "Constant load operation", OpTag::Leaf => "Leaf operation", From d31794dfc76c55a7a30ed926756196bef1892d27 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 16:52:35 +0000 Subject: [PATCH 27/31] refactor: nicer code in `signature.rs` --- src/types/signature.rs | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/types/signature.rs b/src/types/signature.rs index 78a701196..6e4c6b60c 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -1,5 +1,6 @@ //! Abstract and concrete Signature types. +use itertools::Either; #[cfg(feature = "pyo3")] use pyo3::{pyclass, pymethods}; @@ -119,9 +120,9 @@ impl FunctionType { #[inline] pub fn port_type(&self, port: impl Into) -> Option<&Type> { let port: Port = port.into(); - match port.direction() { - Direction::Incoming => self.in_port_type(port.as_incoming().unwrap()), - Direction::Outgoing => self.out_port_type(port.as_outgoing().unwrap()), + match port.as_directed() { + Either::Left(port) => self.in_port_type(port), + Either::Right(port) => self.out_port_type(port), } } @@ -139,14 +140,28 @@ impl FunctionType { self.output.get(port.into()) } + /// Returns a mutable reference to the type of a value input [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn in_port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { + self.input.get_mut(port.into()) + } + + /// Returns the type of a value output [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn out_port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { + self.output.get_mut(port.into()) + } + /// Returns a mutable reference to the type of a value [`Port`]. /// Returns `None` if the port is out of bounds. #[inline] pub fn port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { let port = port.into(); - match port.direction() { - Direction::Incoming => self.input.get_mut(port), - Direction::Outgoing => self.output.get_mut(port), + match port.as_directed() { + Either::Left(port) => self.in_port_type_mut(port), + Either::Right(port) => self.out_port_type_mut(port), } } From 4636bf7b0b258b0582692c330afafb8efd13378c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 17:13:27 +0000 Subject: [PATCH 28/31] refactor!: swap `signature` and `dataflow_signature` --- src/builder/conditional.rs | 2 +- src/extension/infer.rs | 2 +- src/hugr.rs | 4 ++-- src/hugr/rewrite/outline_cfg.rs | 2 +- src/hugr/rewrite/replace.rs | 2 +- src/hugr/rewrite/simple_replace.rs | 4 ++-- src/hugr/views.rs | 2 +- src/hugr/views/sibling_subgraph.rs | 6 +++--- src/ops.rs | 8 +++++--- src/ops/controlflow.rs | 10 +++++----- src/ops/dataflow.rs | 18 +++++++++--------- src/ops/leaf.rs | 2 +- src/ops/validate.rs | 6 +++--- 13 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 78472d8cc..895bdf552 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -110,7 +110,7 @@ impl + AsRef> ConditionalBuilder { .clone() .try_into() .expect("Parent node does not have Conditional optype."); - let extension_delta = cond.dataflow_signature().extension_reqs; + let extension_delta = cond.signature().extension_reqs; let inputs = cond .case_input_row(case) .ok_or(ConditionalBuildError::NotCase { conditional, case })?; diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 4ef2ac47d..7af4788b1 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1067,7 +1067,7 @@ mod test { hugr, conditional_node, op.clone(), - Into::::into(op).signature().unwrap(), + Into::::into(op).dataflow_signature().unwrap(), )?; let lift1 = hugr.add_node_with_parent( diff --git a/src/hugr.rs b/src/hugr.rs index a23703bc0..33f02dc43 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -113,7 +113,7 @@ impl NodeType { pub fn signature(&self) -> Option { self.input_extensions.as_ref().map(|rs| { self.op - .signature() + .dataflow_signature() .unwrap_or_default() .with_input_extensions(rs.clone()) }) @@ -121,7 +121,7 @@ impl NodeType { /// Get the function type from the embedded op pub fn op_signature(&self) -> Option { - self.op.signature() + self.op.dataflow_signature() } /// The input extensions defined for this node. diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 951b4cf17..3cd5ad2f5 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -71,7 +71,7 @@ impl OutlineCfg { } } } - extension_delta = extension_delta.union(&o.dataflow_signature().extension_reqs); + extension_delta = extension_delta.union(&o.signature().extension_reqs); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index c321b00d8..3d78668c8 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -616,7 +616,7 @@ mod test { pred_const: &ConstID, entry: bool, ) -> Result { - let op_sig = op.dataflow_signature(); + let op_sig = op.signature(); let mut bb = if entry { assert_eq!( match h.hugr().get_optype(h.container_node()) { diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 25e266bc1..b02dacafd 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -503,7 +503,7 @@ pub(in crate::hugr::rewrite) mod test { h.get_optype(input) .as_input() .unwrap() - .dataflow_signature() + .signature() .port_type(p) .is_some() }) @@ -518,7 +518,7 @@ pub(in crate::hugr::rewrite) mod test { h.get_optype(output) .as_output() .unwrap() - .dataflow_signature() + .signature() .port_type(p) .is_some() }) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 1026f32f5..ed0c048c0 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -418,7 +418,7 @@ pub trait HugrView: sealed::HugrInternals { /// Get the "signature" (incoming and outgoing types) of a node, non-Value /// kind ports will be missing. fn signature(&self, node: Node) -> Option { - self.get_optype(node).signature() + self.get_optype(node).dataflow_signature() } #[rustversion::since(1.75)] // uses impl in return position diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 5abb4ed94..c9b615eb0 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -349,7 +349,7 @@ impl SiblingSubgraph { let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else { return Err(InvalidReplacement::InvalidDataflowParent); }; - if dfg_optype.signature() != Some(self.signature(hugr)) { + if dfg_optype.dataflow_signature() != Some(self.signature(hugr)) { return Err(InvalidReplacement::InvalidSignature); } @@ -581,7 +581,7 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort .get_optype(inp) .as_input() .unwrap() - .dataflow_signature() + .signature() .output_ports(); if has_other_edge(hugr, out, Direction::Incoming) { unimplemented!("Non-dataflow input not supported at output node") @@ -590,7 +590,7 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort .get_optype(out) .as_output() .unwrap() - .dataflow_signature() + .signature() .input_ports(); // Collect for each port in the input the set of target ports, filtering diff --git a/src/ops.rs b/src/ops.rs index 26aeace37..3ffdf1f2b 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -122,7 +122,7 @@ impl OpType { /// Returns the edge kind for the given port. pub fn port_kind(&self, port: impl Into) -> Option { - let signature = self.signature().unwrap_or_default(); + let signature = self.dataflow_signature().unwrap_or_default(); let port: Port = port.into(); let port_as_in = port.as_incoming().ok(); let dir = port.direction(); @@ -156,7 +156,9 @@ impl OpType { /// The number of Value ports in given direction. pub fn value_port_count(&self, dir: portgraph::Direction) -> usize { - self.signature().map(|sig| sig.port_count(dir)).unwrap_or(0) + self.dataflow_signature() + .map(|sig| sig.port_count(dir)) + .unwrap_or(0) } /// The number of Value input ports. @@ -276,7 +278,7 @@ pub trait OpTrait { /// The signature of the operation. /// /// Only dataflow operations have a signature, otherwise returns None. - fn signature(&self) -> Option { + fn dataflow_signature(&self) -> Option { None } /// The edge kind for the non-dataflow or constant inputs of the operation, diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index a54aac41d..f8bbe7e40 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -30,7 +30,7 @@ impl DataflowOpTrait for TailLoop { "A tail-controlled loop" } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| tuple_sum_first(row, &self.rest)); FunctionType::new(inputs, outputs) @@ -74,7 +74,7 @@ impl DataflowOpTrait for Conditional { "HUGR conditional operation" } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { let mut inputs = self.other_inputs.clone(); inputs .to_mut() @@ -109,7 +109,7 @@ impl DataflowOpTrait for CFG { "A dataflow node defined by a child CFG" } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { self.signature.clone() } } @@ -169,7 +169,7 @@ impl OpTrait for BasicBlock { Some(EdgeKind::ControlFlow) } - fn signature(&self) -> Option { + fn dataflow_signature(&self) -> Option { Some(match self { BasicBlock::DFB { extension_delta, .. @@ -233,7 +233,7 @@ impl OpTrait for Case { ::TAG } - fn signature(&self) -> Option { + fn dataflow_signature(&self) -> Option { Some(self.signature.clone()) } } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 777aaf2f3..bf4fd83a9 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -10,7 +10,7 @@ use crate::IncomingPort; pub(crate) trait DataflowOpTrait { const TAG: OpTag; fn description(&self) -> &str; - fn dataflow_signature(&self) -> FunctionType; + fn signature(&self) -> FunctionType; /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. @@ -82,7 +82,7 @@ impl DataflowOpTrait for Input { None } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), self.types.clone()) .with_extension_delta(&ExtensionSet::new()) } @@ -96,7 +96,7 @@ impl DataflowOpTrait for Output { // Note: We know what the input extensions should be, so we *could* give an // instantiated Signature instead - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { FunctionType::new(self.types.clone(), TypeRow::new()) } @@ -112,8 +112,8 @@ impl OpTrait for T { fn tag(&self) -> OpTag { T::TAG } - fn signature(&self) -> Option { - Some(DataflowOpTrait::dataflow_signature(self)) + fn dataflow_signature(&self) -> Option { + Some(DataflowOpTrait::signature(self)) } fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) @@ -146,7 +146,7 @@ impl DataflowOpTrait for Call { "Call a function directly" } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { self.signature.clone() } } @@ -179,7 +179,7 @@ impl DataflowOpTrait for CallIndirect { "Call a function indirectly" } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { let mut s = self.signature.clone(); s.input .to_mut() @@ -202,7 +202,7 @@ impl DataflowOpTrait for LoadConstant { "Load a static constant in to the local dataflow graph" } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), vec![self.datatype.clone()]) } } @@ -235,7 +235,7 @@ impl DataflowOpTrait for DFG { "A simply nested dataflow graph" } - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { self.signature.clone() } } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index e21e19adf..85de1a0a8 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -150,7 +150,7 @@ impl DataflowOpTrait for LeafOp { } /// The signature of the operation. - fn dataflow_signature(&self) -> FunctionType { + fn signature(&self) -> FunctionType { // Static signatures. The `TypeRow`s in the `FunctionType` use a // copy-on-write strategy, so we can avoid unnecessary allocations. diff --git a/src/ops/validate.rs b/src/ops/validate.rs index a914f6247..ec8b75c24 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -102,7 +102,7 @@ impl ValidateOp for super::DFG { &self, children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { - let sig = self.signature().unwrap_or_default(); + let sig = self.dataflow_signature().unwrap_or_default(); validate_io_nodes(&sig.input, &sig.output, "nested graph", children) } } @@ -367,7 +367,7 @@ fn validate_io_nodes<'a>( let (first, first_optype) = children.next().unwrap(); let (second, second_optype) = children.next().unwrap(); - let first_sig = first_optype.signature().unwrap_or_default(); + let first_sig = first_optype.dataflow_signature().unwrap_or_default(); if &first_sig.output != expected_input { return Err(ChildrenValidationError::IOSignatureMismatch { child: first, @@ -377,7 +377,7 @@ fn validate_io_nodes<'a>( container_desc, }); } - let second_sig = second_optype.signature().unwrap_or_default(); + let second_sig = second_optype.dataflow_signature().unwrap_or_default(); if &second_sig.input != expected_output { return Err(ChildrenValidationError::IOSignatureMismatch { From 0a6bb8af19bcfab2d25d7400b0eecfdcdaac4725 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 17:28:58 +0000 Subject: [PATCH 29/31] refactor!: `non_df_port_count` just returns usize --- src/ops.rs | 18 ++++++++---------- src/ops/controlflow.rs | 9 ++++----- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index 3ffdf1f2b..7deb6328d 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -142,7 +142,7 @@ impl OpType { /// /// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports. pub fn other_port(&self, dir: Direction) -> Option { - let non_df_count = self.non_df_port_count(dir).unwrap_or(1); + let non_df_count = self.non_df_port_count(dir); if self.other_port_kind(dir).is_some() && non_df_count == 1 { // if there is a static input it comes before the non_df_ports let static_input = @@ -201,10 +201,7 @@ impl OpType { /// Returns the number of ports for the given direction. pub fn port_count(&self, dir: Direction) -> usize { - let has_other_ports = self.other_port_kind(dir).is_some(); - let non_df_count = self - .non_df_port_count(dir) - .unwrap_or(has_other_ports as usize); + let non_df_count = self.non_df_port_count(dir); // if there is a static input it comes before the non_df_ports let static_input = (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; @@ -300,11 +297,12 @@ pub trait OpTrait { } /// Get the number of non-dataflow multiports. - /// - /// If None, the operation must have exactly one non-dataflow port - /// if the operation type has other_edges, or zero otherwise. - fn non_df_port_count(&self, _dir: Direction) -> Option { - None + fn non_df_port_count(&self, dir: Direction) -> usize { + match dir { + Direction::Incoming => self.other_input(), + Direction::Outgoing => self.other_output(), + } + .is_some() as usize } } diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index f8bbe7e40..d2121cd36 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -178,12 +178,11 @@ impl OpTrait for BasicBlock { }) } - fn non_df_port_count(&self, dir: Direction) -> Option { + fn non_df_port_count(&self, dir: Direction) -> usize { match self { - Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => { - Some(tuple_sum_rows.len()) - } - _ => None, + Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => tuple_sum_rows.len(), + Self::Exit { .. } if dir == Direction::Outgoing => 0, + _ => 1, } } } From f091b5afc611f17e09b386397e9a1c47af4c6eed Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 17:36:46 +0000 Subject: [PATCH 30/31] fix: put correct compile flags in --- src/hugr/views.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index ed0c048c0..63fb03ab4 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -18,7 +18,7 @@ pub use sibling::SiblingGraph; pub use sibling_subgraph::SiblingSubgraph; use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Either, Itertools, MapInto}; +use itertools::{Itertools, MapInto}; use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle}; use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; @@ -29,6 +29,8 @@ use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; use crate::types::Type; use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; +#[rustversion::since(1.75)] // uses impl in return position +use itertools::Either; /// A trait for inspecting HUGRs. /// For end users we intend this to be superseded by region-specific APIs. @@ -634,6 +636,7 @@ where }) } } +#[rustversion::since(1.75)] // uses impl in return position impl PortIterator

for I where I: Iterator, From 451ce23610cd9b7ef9e60569ee4d96c791138881 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 14 Nov 2023 20:32:57 +0000 Subject: [PATCH 31/31] fix doclinks --- src/builder/module.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/builder/module.rs b/src/builder/module.rs index 939df96aa..3fbce55c4 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -71,7 +71,7 @@ impl + AsRef> ModuleBuilder { /// # Errors /// /// This function will return an error if there is an error in adding the - /// [`OpType::FuncDefn`] node. + /// [`crate::ops::OpType::FuncDefn`] node. pub fn define_declaration( &mut self, f_id: &FuncID, @@ -83,7 +83,7 @@ impl + AsRef> ModuleBuilder { .as_func_decl() .ok_or(BuildError::UnexpectedType { node: f_node, - op_desc: "OpType::FuncDecl", + op_desc: "crate::ops::OpType::FuncDecl", })? .clone(); @@ -104,7 +104,7 @@ impl + AsRef> ModuleBuilder { /// # Errors /// /// This function will return an error if there is an error in adding the - /// [`OpType::FuncDecl`] node. + /// [`crate::ops::OpType::FuncDecl`] node. pub fn declare( &mut self, name: impl Into, @@ -123,11 +123,11 @@ impl + AsRef> ModuleBuilder { Ok(declare_n.into()) } - /// Add a [`OpType::AliasDefn`] node and return a handle to the Alias. + /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias. /// /// # Errors /// - /// Error in adding [`OpType::AliasDefn`] child node. + /// Error in adding [`crate::ops::OpType::AliasDefn`] child node. pub fn add_alias_def( &mut self, name: impl Into, @@ -148,10 +148,10 @@ impl + AsRef> ModuleBuilder { Ok(AliasID::new(node, name, bound)) } - /// Add a [`OpType::AliasDecl`] node and return a handle to the Alias. + /// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias. /// # Errors /// - /// Error in adding [`OpType::AliasDecl`] child node. + /// Error in adding [`crate::ops::OpType::AliasDecl`] child node. pub fn add_alias_declare( &mut self, name: impl Into,