diff --git a/.github/workflows/notify-coverage.yml b/.github/workflows/notify-coverage.yml index c67b6a939..e176500a9 100644 --- a/.github/workflows/notify-coverage.yml +++ b/.github/workflows/notify-coverage.yml @@ -15,33 +15,98 @@ jobs: runs-on: ubuntu-latest outputs: msg: ${{ steps.make_msg.outputs.msg }} + should_notify: ${{ steps.get_coverage.outputs.should_notify }} steps: + - name: Download commit sha of the most recent successful run + uses: dawidd6/action-download-artifact@v2 + with: + # Downloads the artifact from the most recent successful run + workflow: 'notify-coverage.yml' + name: head-sha.txt + if_no_artifact_found: ignore - name: Get today's and yesterday's coverage trends from codecov + id: get_coverage # API reference: https://docs.codecov.com/reference/repos_totals_retrieve run: | - YESTERDAY=$( date -u +%Y-%m-%dT%H:%M:%SZ -d 'yesterday' ) + # Get the previous commit coverage, if the last sha is available + if [ ! -f head-sha.txt ] + then + echo "No previous coverage found." + # Update the head-sha.txt file with the current sha, + # so next time we campare against the current coverage. + echo ${{ github.sha }} > head-sha.txt + + echo "should_notify=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + + PREV_SHA=$( cat head-sha.txt ) + echo "Previous sha: \"$PREV_SHA\"" + + # Check if the sha has changed + if [ "$PREV_SHA" == "${{ github.sha }}" ] + then + echo "No new commits since last run." + echo "should_notify=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + + # Query the previous coverage from codecov curl --request GET \ - --url "https://api.codecov.io/api/v2/github/${{ github.repository_owner }}/repos/${{ github.event.repository.name }}/coverage/?interval=1d&start_date=$YESTERDAY" \ + --url "https://api.codecov.io/api/v2/github/${{ github.repository_owner }}/repos/${{ github.event.repository.name }}/totals/?sha=$PREV_SHA" \ --header 'accept: application/json' \ --header "authorization: Bearer ${{ secrets.CODECOV_GET_TOKEN }}" \ - > coverage.json - - echo "Coverage JSON:" - cat coverage.json + > coverage-prev.json + cat coverage-prev.json | jq ".totals.coverage" > coverage-prev.txt + echo "Previous coverage query result:" + cat coverage-prev.json | jq "del(.files)" echo - cat coverage.json | jq ".results[0].max" > coverage-prev.txt - cat coverage.json | jq ".results[-1].max" > coverage.txt + # Query the current coverage from codecov + curl --request GET \ + --url "https://api.codecov.io/api/v2/github/${{ github.repository_owner }}/repos/${{ github.event.repository.name }}/totals/?sha=${{ github.sha }}" \ + --header 'accept: application/json' \ + --header "authorization: Bearer ${{ secrets.CODECOV_GET_TOKEN }}" \ + > coverage.json + cat coverage.json | jq ".totals.coverage" > coverage.txt + echo "Current coverage query result:" + cat coverage.json | jq "del(.files)" + echo + echo echo "Previous coverage: `cat coverage-prev.txt`%" echo "Current coverage: `cat coverage.txt`%" + + # A `null` in either coverage means that the coverage is not available, + # so we don't want to notify about that. + if [ "$( cat coverage-prev.txt )" == "null" ] + then + echo "Previous coverage not available." + echo ${{ github.sha }} > head-sha.txt + echo "should_notify=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + if [ "$( cat coverage.txt )" == "null" ] + then + echo "Current coverage not available." + # Note that we don't update the head-sha.txt file here, + # so next time we compare against the one that had coverage data. + echo "should_notify=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + + echo ${{ github.sha }} > head-sha.txt + echo "should_notify=true" >> "$GITHUB_OUTPUT" - name: Compare with previous summary and make message id: make_msg + if: steps.get_coverage.outputs.should_notify == 'true' run: | - change="`cat coverage-prev.txt`% --> `cat coverage.txt`%" + prev=`cat coverage-prev.txt` + current=`cat coverage-prev.txt` + change=`printf "%.2f%% --> %.2f%%" $prev $current` codecov="https://codecov.io/gh/${{ github.repository }}?search=&trend=7%20days" - if (( $(echo "`cat coverage-prev.txt` < `cat coverage.txt` + 0.04" | bc -l) )) + if (( $(echo "$prev < $current + 0.04" | bc -l) )) then MSG="msg=Coverage check for hugr shows no regression (${change}). ✅ ${codecov}" else @@ -49,9 +114,16 @@ jobs: fi echo $MSG echo $MSG >> "$GITHUB_OUTPUT" + - name: Upload current HEAD sha + uses: actions/upload-artifact@v3 + with: + name: head-sha.txt + path: head-sha.txt + notify-slack: needs: check-coverage runs-on: ubuntu-latest + if: needs.check-coverage.outputs.should_notify == 'true' steps: - name: Send notification uses: slackapi/slack-github-action@v1.24.0 @@ -60,4 +132,3 @@ jobs: slack-message: ${{ needs.check-coverage.outputs.msg }} env: SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} - diff --git a/Cargo.toml b/Cargo.toml index 53a0f81f2..9e2fd8af7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,8 @@ petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" delegate = "0.10.0" +rustversion = "1.0.14" +paste = "1.0" [features] pyo3 = ["dep:pyo3"] @@ -61,7 +63,6 @@ rmp-serde = "1.1.1" webbrowser = "0.8.10" urlencoding = "2.1.2" cool_asserts = "2.0.3" -paste = "1.0" insta = { version = "1.34.0", features = ["yaml"] } [[bench]] @@ -71,4 +72,4 @@ harness = false [profile.dev.package] insta.opt-level = 3 -similar.opt-level = 3 \ No newline at end of file +similar.opt-level = 3 diff --git a/specification/hugr.md b/specification/hugr.md index a72560ddb..93e9ceb53 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -1719,72 +1719,6 @@ Conversions between integers and floats: | `convert_u` | `int` | `float64` | unsigned int to float | | `convert_s` | `int` | `float64` | signed int to float | -### Quantum Extension - -This is the extension that is designed to be natively understood by -TKET2. Besides a range of quantum operations (like Hadamard, CX, etc.) -that take and return `Qubit`, we note the following operations for -allocating/deallocating `Qubit`s: - -``` -qalloc: () -> Qubit -qfree: Qubit -> () -``` - -`qalloc` allocates a fresh, 0 state Qubit - if none is available at -runtime it panics. `qfree` loses a handle to a Qubit (may be reallocated -in future). The point at which an allocated qubit is reset may be -target/compiler specific. - -Note there are also `measurez: Qubit -> (i1, Qubit)` and on supported -targets `reset: Qubit -> Qubit` operations to measure or reset a qubit -without losing a handle to it. - -#### Dynamic vs static allocation - -With these operations the programmer/front-end can request dynamic qubit -allocation, and the compiler can add/remove/move these operations to use -more or fewer qubits. In some use cases, that may not be desirable, and -we may instead want to guarantee only a certain number of qubits are -used by the program. For this purpose TKET2 places additional -constraints on the HUGR that are in line with TKET1 backwards -compatibility: - -1. The `main` function takes one `Array` -input and has one output of the same type (the same statically known -size). -2. All Operations that have a `FunctionType` involving `Qubit` have as - many `Qubit` input wires as output. - - -With these constraints, we can treat all `Qubit` operations as returning all qubits they take -in. The implicit bijection from input `Qubit` to output allows register -allocation for all `Qubit` wires. -If further the program does not contain any `qalloc` or `qfree` -operations we can state the program only uses `N` qubits. - -#### Angles - -The Quantum extension also defines a specialized `angle` type which is used -to express parameters of rotation gates. The type is parametrized by the -_log-denominator_, which is an integer $N \in [0, 53]$; angles with -log-denominator $N$ are multiples of $2 \pi / 2^N$, where the multiplier is an -unsigned `int` in the range $[0, 2^N]$. The maximum log-denominator $53$ -effectively gives the resolution of a `float64` value; but note that unlike -`float64` all angle values are equatable and hashable; and two `angle` that -differ by a multiple of $2 \pi$ are _equal_. - -The following operations are defined: - -| Name | Inputs | Outputs | Meaning | -| -------------- | ---------- | ---------- | ------- | -| `aconst` | none | `angle` | const node producing angle $2 \pi x / 2^N$ (where $0 \leq x \lt 2^N$) | -| `atrunc` | `angle` | `angle` | round `angle` to `angle`, where $M \geq N$, rounding down in $[0, 2\pi)$ if necessary | -| `aconvert` | `angle` | `Sum(angle, ErrorType)` | convert `angle` to `angle`, returning an error if $M \gt N$ and exact conversion is impossible | -| `aadd` | `angle`, `angle` | `angle` | add two angles | -| `asub` | `angle`, `angle` | `angle` | subtract the second angle from the first | -| `aneg` | `angle` | `angle` | negate an angle | - ### Higher-order (Tierkreis) Extension In **some** contexts, notably the Tierkreis runtime, higher-order diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index dfdc86488..e1435fffe 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -236,7 +236,7 @@ pub trait Dataflow: Container { hugr: Hugr, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); + let num_outputs = hugr.get_optype(hugr.root()).value_output_count(); let node = self.add_hugr(hugr)?.new_root; let inputs = input_wires.into_iter().collect(); @@ -257,8 +257,8 @@ pub trait Dataflow: Container { hugr: &impl HugrView, input_wires: impl IntoIterator, ) -> Result, BuildError> { - let num_outputs = hugr.get_optype(hugr.root()).signature().output_count(); let node = self.add_hugr_view(hugr)?.new_root; + let num_outputs = hugr.get_optype(hugr.root()).value_output_count(); let inputs = input_wires.into_iter().collect(); wire_up_inputs(inputs, node, self)?; @@ -612,8 +612,9 @@ pub trait Dataflow: Container { }) } }; - let const_in_port = signature.output.len(); - let op_id = self.add_dataflow_op(ops::Call { signature }, input_wires)?; + let op: OpType = ops::Call { signature }.into(); + let const_in_port = op.static_input_port().unwrap(); + let op_id = self.add_dataflow_op(op, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; self.hugr_mut() @@ -633,13 +634,13 @@ fn add_node_with_wires( nodetype: impl Into, inputs: Vec, ) -> Result<(Node, usize), BuildError> { - let nodetype = nodetype.into(); - let sig = nodetype.op_signature(); + let nodetype: NodeType = nodetype.into(); + let num_outputs = nodetype.op().value_output_count(); let op_node = data_builder.add_child_node(nodetype)?; wire_up_inputs(inputs, op_node, data_builder)?; - Ok((op_node, sig.output().len())) + Ok((op_node, num_outputs)) } fn wire_up_inputs( diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index e6219f46a..895bdf552 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -1,9 +1,10 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::HugrView; +use crate::ops::dataflow::DataflowOpTrait; use crate::types::{FunctionType, TypeRow}; +use crate::ops; use crate::ops::handle::CaseID; -use crate::ops::{self, OpTrait}; use super::build_traits::SubContainer; use super::handle::BuildHandle; @@ -104,12 +105,12 @@ impl + AsRef> ConditionalBuilder { pub fn case_builder(&mut self, case: usize) -> Result, BuildError> { let conditional = self.conditional_node; let control_op = self.hugr().get_optype(self.conditional_node); - let extension_delta = control_op.signature().extension_reqs; let cond: ops::Conditional = control_op .clone() .try_into() .expect("Parent node does not have Conditional optype."); + let extension_delta = cond.signature().extension_reqs; let inputs = cond .case_input_row(case) .ok_or(ConditionalBuildError::NotCase { conditional, case })?; diff --git a/src/builder/module.rs b/src/builder/module.rs index a78c047d7..3fbce55c4 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -12,7 +12,6 @@ use crate::{ }; use crate::ops::handle::{AliasID, FuncID, NodeHandle}; -use crate::ops::OpType; use crate::types::Signature; @@ -72,22 +71,22 @@ impl + AsRef> ModuleBuilder { /// # Errors /// /// This function will return an error if there is an error in adding the - /// [`OpType::FuncDefn`] node. + /// [`crate::ops::OpType::FuncDefn`] node. pub fn define_declaration( &mut self, f_id: &FuncID, ) -> Result, BuildError> { let f_node = f_id.node(); - let (signature, name) = if let OpType::FuncDecl(ops::FuncDecl { signature, name }) = - self.hugr().get_optype(f_node) - { - (signature.clone(), name.clone()) - } else { - return Err(BuildError::UnexpectedType { + let ops::FuncDecl { signature, name } = self + .hugr() + .get_optype(f_node) + .as_func_decl() + .ok_or(BuildError::UnexpectedType { node: f_node, - op_desc: "OpType::FuncDecl", - }); - }; + op_desc: "crate::ops::OpType::FuncDecl", + })? + .clone(); + self.hugr_mut().replace_op( f_node, NodeType::new_pure(ops::FuncDefn { @@ -105,7 +104,7 @@ impl + AsRef> ModuleBuilder { /// # Errors /// /// This function will return an error if there is an error in adding the - /// [`OpType::FuncDecl`] node. + /// [`crate::ops::OpType::FuncDecl`] node. pub fn declare( &mut self, name: impl Into, @@ -124,11 +123,11 @@ impl + AsRef> ModuleBuilder { Ok(declare_n.into()) } - /// Add a [`OpType::AliasDefn`] node and return a handle to the Alias. + /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias. /// /// # Errors /// - /// Error in adding [`OpType::AliasDefn`] child node. + /// Error in adding [`crate::ops::OpType::AliasDefn`] child node. pub fn add_alias_def( &mut self, name: impl Into, @@ -149,10 +148,10 @@ impl + AsRef> ModuleBuilder { Ok(AliasID::new(node, name, bound)) } - /// Add a [`OpType::AliasDecl`] node and return a handle to the Alias. + /// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias. /// # Errors /// - /// Error in adding [`OpType::AliasDecl`] child node. + /// Error in adding [`crate::ops::OpType::AliasDecl`] child node. pub fn add_alias_declare( &mut self, name: impl Into, @@ -233,14 +232,14 @@ mod test { let mut f_build = module_builder.define_function( "main", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(), )?; let local_build = f_build.define_function( "local", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(), )?; let [wire] = local_build.input_wires_arr(); - let f_id = local_build.finish_with_outputs([wire])?; + let f_id = local_build.finish_with_outputs([wire, wire])?; let call = f_build.call(f_id.handle(), f_build.input_wires())?; diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 5eb8286e1..8ea5113ae 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -1,4 +1,4 @@ -use crate::ops::{self, OpType}; +use crate::ops; use crate::hugr::{views::HugrView, NodeType}; use crate::types::{FunctionType, TypeRow}; @@ -38,14 +38,13 @@ impl + AsRef> TailLoopBuilder { /// Get a reference to the [`ops::TailLoop`] /// that defines the signature of the [`ops::TailLoop`] pub fn loop_signature(&self) -> Result<&ops::TailLoop, BuildError> { - if let OpType::TailLoop(tail_loop) = self.hugr().get_optype(self.container_node()) { - Ok(tail_loop) - } else { - Err(BuildError::UnexpectedType { + self.hugr() + .get_optype(self.container_node()) + .as_tail_loop() + .ok_or(BuildError::UnexpectedType { node: self.container_node(), op_desc: "crate::ops::TailLoop", }) - } } /// The output types of the child graph, including the TupleSum as the first. diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 8b60c0ac8..7a1b03cfb 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -313,11 +313,15 @@ impl UnificationContext { match node_type.signature() { // Input extensions are open None => { - let delta = node_type.op_signature().extension_reqs; - let c = if delta.is_empty() { - Constraint::Equal(m_input) + let c = if let Some(sig) = node_type.op_signature() { + let delta = sig.extension_reqs; + if delta.is_empty() { + Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) + } } else { - Constraint::Plus(delta, m_input) + Constraint::Equal(m_input) }; self.add_constraint(m_output, c); } diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 9fd550b8e..8f51e6236 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -334,7 +334,7 @@ fn test_conditional_inference() -> Result<(), Box> { hugr, conditional_node, op.clone(), - Into::::into(op).signature(), + Into::::into(op).dataflow_signature().unwrap(), )?; let lift1 = hugr.add_node_with_parent( @@ -885,7 +885,7 @@ fn simple_cfg_loop() -> Result<(), Box> { fn plus_on_self() -> Result<(), Box> { let ext = ExtensionId::new("unknown1").unwrap(); let delta = ExtensionSet::singleton(&ext); - let ft = FunctionType::new_linear(type_row![QB_T, QB_T]).with_extension_delta(&delta); + let ft = FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(&delta); let mut dfg = DFGBuilder::new(ft.clone())?; // While https://github.com/CQCL-DEV/hugr/issues/388 is unsolved, @@ -899,7 +899,7 @@ fn plus_on_self() -> Result<(), Box> { Some(ft), )) .into(); - let unary_sig = FunctionType::new_linear(type_row![QB_T]) + let unary_sig = FunctionType::new_endo(type_row![QB_T]) .with_extension_delta(&ExtensionSet::singleton(&ext)); let unop: LeafOp = ExternalOp::Opaque(OpaqueOp::new( ext, diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 60916914a..4774d5c62 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -198,8 +198,8 @@ impl OpDef { pub(crate) fn should_serialize_signature(&self) -> bool { match self.signature_func { - SignatureFunc::TypeScheme { .. } => true, - SignatureFunc::CustomFunc { .. } => false, + SignatureFunc::TypeScheme { .. } => false, + SignatureFunc::CustomFunc { .. } => true, } } @@ -379,14 +379,14 @@ mod test { const OP_NAME: SmolStr = SmolStr::new_inline("Reverse"); let type_scheme = PolyFuncType::new_validated( vec![TP], - FunctionType::new_linear(vec![list_of_var]), + FunctionType::new_endo(vec![list_of_var]), ®1, )?; e.add_op_type_scheme(OP_NAME, "".into(), Default::default(), vec![], type_scheme)?; let list_usize = Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: USIZE_T }])?); - let mut dfg = DFGBuilder::new(FunctionType::new_linear(vec![list_usize]))?; + let mut dfg = DFGBuilder::new(FunctionType::new_endo(vec![list_usize]))?; let rev = dfg.add_dataflow_op( LeafOp::from(ExternalOp::Extension( e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], ®1) diff --git a/src/extension/validate.rs b/src/extension/validate.rs index c026c2519..2f267edf9 100644 --- a/src/extension/validate.rs +++ b/src/extension/validate.rs @@ -25,8 +25,13 @@ impl ExtensionValidator { pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self { let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new(); for (node, incoming_sol) in closure.into_iter() { - let op_signature = hugr.get_nodetype(node).op_signature(); - let outgoing_sol = op_signature.extension_reqs.union(&incoming_sol); + let extension_reqs = hugr + .get_nodetype(node) + .op_signature() + .map(|s| s.extension_reqs) + .unwrap_or_default(); + + let outgoing_sol = extension_reqs.union(&incoming_sol); extensions.insert((node, Direction::Incoming), incoming_sol); extensions.insert((node, Direction::Outgoing), outgoing_sol); diff --git a/src/hugr.rs b/src/hugr.rs index 97bf44ed6..33f02dc43 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -111,14 +111,17 @@ impl NodeType { /// Use the input extensions to calculate the concrete signature of the node pub fn signature(&self) -> Option { - self.input_extensions - .as_ref() - .map(|rs| self.op.signature().with_input_extensions(rs.clone())) + self.input_extensions.as_ref().map(|rs| { + self.op + .dataflow_signature() + .unwrap_or_default() + .with_input_extensions(rs.clone()) + }) } /// Get the function type from the embedded op - pub fn op_signature(&self) -> FunctionType { - self.op.signature() + pub fn op_signature(&self) -> Option { + self.op.dataflow_signature() } /// The input extensions defined for this node. diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 00bef6d18..f549929d0 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -314,12 +314,12 @@ impl + AsMut> HugrMut for T { ) -> Result<(OutgoingPort, IncomingPort), HugrError> { let src_port = self .get_optype(src) - .other_port_index(Direction::Outgoing) + .other_output_port() .expect("Source operation has no non-dataflow outgoing edges") .as_outgoing()?; let dst_port = self .get_optype(dst) - .other_port_index(Direction::Incoming) + .other_input_port() .expect("Destination operation has no non-dataflow incoming edges") .as_incoming()?; self.connect(src, src_port, dst, dst_port)?; diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs index d5613a59f..e3fd29318 100644 --- a/src/hugr/rewrite/insert_identity.rs +++ b/src/hugr/rewrite/insert_identity.rs @@ -9,7 +9,6 @@ use crate::{HugrView, IncomingPort}; use super::Rewrite; -use itertools::Itertools; use thiserror::Error; /// Specification of a identity-insertion operation. @@ -73,9 +72,7 @@ impl Rewrite for IdentityInsertion { }; let (pre_node, pre_port) = h - .linked_outputs(self.post_node, self.post_port) - .exactly_one() - .ok() + .single_linked_output(self.post_node, self.post_port) .expect("Value kind input can only have one connection."); h.disconnect(self.post_node, self.post_port).unwrap(); diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 179ea486d..3cd5ad2f5 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -12,8 +12,9 @@ use crate::hugr::rewrite::Rewrite; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::{HugrMut, HugrView}; use crate::ops; +use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; -use crate::ops::{BasicBlock, OpTrait, OpType}; +use crate::ops::{BasicBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; @@ -49,7 +50,7 @@ impl OutlineCfg { _ => return Err(OutlineCfgError::NotSiblings), }; let o = h.get_optype(cfg_n); - if !matches!(o, OpType::CFG(_)) { + let OpType::CFG(o) = o else { return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); }; let cfg_entry = h.children(cfg_n).next().unwrap(); @@ -177,7 +178,7 @@ impl Rewrite for OutlineCfg { let exit_port = h .node_outputs(exit) .filter(|p| { - let (t, p2) = h.linked_ports(exit, *p).exactly_one().ok().unwrap(); + let (t, p2) = h.single_linked_input(exit, *p).unwrap(); assert!(p2.index() == 0); t == outside }) diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 0b01df5aa..16eeb9321 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -139,9 +139,8 @@ impl NewEdgeSpec { true }; let found_incoming = h - .linked_ports(self.tgt, tgt_pos) - .exactly_one() - .is_ok_and(|(src_n, _)| descends_from_legal(src_n)); + .single_linked_output(self.tgt, tgt_pos) + .is_some_and(|(src_n, _)| descends_from_legal(src_n)); if !found_incoming { return Err(ReplaceError::NoRemovedEdge(err_edge())); }; @@ -448,8 +447,9 @@ mod test { use crate::hugr::rewrite::replace::WhichHugr; use crate::hugr::{HugrMut, NodeType, Rewrite}; use crate::ops::custom::{ExternalOp, OpaqueOp}; + use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; - use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpTrait, OpType, DFG}; + use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpType, DFG}; use crate::std_extensions::collections; use crate::types::{FunctionType, Type, TypeArg, TypeRow}; use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; @@ -479,7 +479,7 @@ mod test { let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]); let mut cfg = CFGBuilder::new( - FunctionType::new_linear(just_list.clone()).with_extension_delta(&exset), + FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset), )?; let pred_const = cfg.add_constant(ops::Const::unary_unit_sum(), None)?; @@ -516,7 +516,7 @@ mod test { // Replacement: one BB with two DFGs inside. // Use Hugr rather than Builder because DFGs must be empty (not even Input/Output). let mut replacement = Hugr::new(NodeType::new_open(ops::CFG { - signature: FunctionType::new_linear(just_list.clone()), + signature: FunctionType::new_endo(just_list.clone()), })); let r_bb = replacement.add_node_with_parent( replacement.root(), @@ -588,8 +588,8 @@ mod test { let popp = h.get_parent(pop).unwrap(); let pushp = h.get_parent(push).unwrap(); assert_ne!(popp, pushp); // Two different DFGs - assert!(matches!(h.get_optype(popp), OpType::DFG(_))); - assert!(matches!(h.get_optype(pushp), OpType::DFG(_))); + assert!(h.get_optype(popp).is_dfg()); + assert!(h.get_optype(pushp).is_dfg()); let grandp = h.get_parent(popp).unwrap(); assert_eq!(grandp, h.get_parent(pushp).unwrap()); @@ -610,13 +610,12 @@ mod test { .unwrap() } - fn single_node_block + AsMut>( + fn single_node_block + AsMut, O: DataflowOpTrait + Into>( h: &mut CFGBuilder, - op: impl Into, + op: O, pred_const: &ConstID, entry: bool, ) -> Result { - let op: OpType = op.into(); let op_sig = op.signature(); let mut bb = if entry { assert_eq!( @@ -630,7 +629,7 @@ mod test { } else { h.simple_block_builder(op_sig, 1)? }; - + let op: OpType = op.into(); let op = bb.add_dataflow_op(op, bb.input_wires())?; let load_pred = bb.load_const(pred_const)?; bb.finish_with_outputs(load_pred, op.outputs()) @@ -644,7 +643,7 @@ mod test { #[test] fn test_invalid() -> Result<(), Box> { - let utou = FunctionType::new_linear(vec![USIZE_T]); + let utou = FunctionType::new_endo(vec![USIZE_T]); let mk_op = |s| { LeafOp::from(ExternalOp::Opaque(OpaqueOp::new( ExtensionId::new("unknown_ext").unwrap(), diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 4a376604d..b02dacafd 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -4,8 +4,6 @@ use std::collections::{hash_map, HashMap}; use std::iter::{self, Copied}; use std::slice; -use itertools::Itertools; - use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; @@ -130,9 +128,7 @@ impl Rewrite for SimpleReplacement { if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) let (rem_inp_pred_node, rem_inp_pred_port) = h - .linked_outputs(*rem_inp_node, *rem_inp_port) - .exactly_one() - .ok() // PortLinks does not implement Debug + .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); h.disconnect(*rem_inp_node, *rem_inp_port).unwrap(); let new_inp_node = index_map.get(rep_inp_node).unwrap(); @@ -150,8 +146,7 @@ impl Rewrite for SimpleReplacement { for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { let (rep_out_pred_node, rep_out_pred_port) = self .replacement - .linked_outputs(replacement_output_node, *rep_out_port) - .exactly_one() + .single_linked_output(replacement_output_node, *rep_out_port) .unwrap(); if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input { let new_out_node = index_map.get(&rep_out_pred_node).unwrap(); @@ -172,9 +167,7 @@ impl Rewrite for SimpleReplacement { if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): let (rem_inp_pred_node, rem_inp_pred_port) = h - .linked_outputs(*rem_inp_node, *rem_inp_port) - .exactly_one() - .ok() // PortLinks does not implement Debug + .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); h.disconnect(*rem_inp_node, *rem_inp_port).unwrap(); h.disconnect(*rem_out_node, *rem_out_port).unwrap(); @@ -231,6 +224,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; + use crate::ops::dataflow::DataflowOpTrait; use crate::ops::OpTag; use crate::ops::{OpTrait, OpType}; use crate::std_extensions::logic::test::and_op; @@ -505,7 +499,14 @@ pub(in crate::hugr::rewrite) mod test { .collect_vec(); let inputs = h .node_outputs(input) - .filter(|&p| h.get_optype(input).signature().get(p).is_some()) + .filter(|&p| { + h.get_optype(input) + .as_input() + .unwrap() + .signature() + .port_type(p) + .is_some() + }) .map(|p| { let link = h.linked_inputs(input, p).next().unwrap(); (link, link) @@ -513,7 +514,14 @@ pub(in crate::hugr::rewrite) mod test { .collect(); let outputs = h .node_inputs(output) - .filter(|&p| h.get_optype(output).signature().get(p).is_some()) + .filter(|&p| { + h.get_optype(output) + .as_output() + .unwrap() + .signature() + .port_type(p) + .is_some() + }) .map(|p| ((output, p), p)) .collect(); h.apply_rewrite(SimpleReplacement::new( @@ -565,7 +573,7 @@ pub(in crate::hugr::rewrite) mod test { let outputs = repl .node_inputs(repl_output) - .filter(|&p| repl.get_optype(repl_output).signature().get(p).is_some()) + .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some()) .map(|p| ((repl_output, p), p)) .collect(); @@ -598,7 +606,7 @@ pub(in crate::hugr::rewrite) mod test { if *tgt == out { unimplemented!() }; - let (src, src_port) = h.linked_outputs(*r_n, *r_p).exactly_one().ok().unwrap(); + let (src, src_port) = h.single_linked_output(*r_n, *r_p).unwrap(); NewEdgeSpec { src, tgt: *tgt, @@ -613,11 +621,7 @@ pub(in crate::hugr::rewrite) mod test { .nu_out .iter() .map(|((tgt, tgt_port), out_port)| { - let (src, src_port) = replacement - .linked_outputs(out, *out_port) - .exactly_one() - .ok() - .unwrap(); + let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap(); if src == in_ { unimplemented!() }; diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index a96d1abca..b1f99a4a2 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -10,7 +10,6 @@ use pyo3::{create_exception, exceptions::PyException, PyErr}; use crate::core::NodeIndex; use crate::extension::ExtensionSet; use crate::hugr::{Hugr, NodeType}; -use crate::ops::OpTrait; use crate::ops::OpType; use crate::{Node, PortIndex}; use portgraph::hierarchy::AttachError; @@ -167,7 +166,7 @@ impl TryFrom<&Hugr> for SerHugrV0 { .expect("Could not reach one of the nodes"); let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| { - let sig = hugr.get_optype(node).signature(); + let sig = hugr.signature(node).unwrap_or_default(); let offset = match offset < sig.port_count(dir) { true => Some(offset as u16), false => None, @@ -246,7 +245,7 @@ impl TryFrom for Hugr { None => { let op_type = hugr.get_optype(node); op_type - .other_port_index(dir) + .other_port(dir) .ok_or(HUGRSerializationError::MissingPortOffset { node, op_type: op_type.clone(), diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 0c6594a95..edc80c955 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -139,7 +139,7 @@ fn leaf_root() { #[test] fn dfg_root() { let dfg_op: OpType = ops::DFG { - signature: FunctionType::new_linear(type_row![BOOL_T]), + signature: FunctionType::new_endo(type_row![BOOL_T]), } .into(); @@ -366,7 +366,7 @@ fn test_ext_edge() -> Result<(), HugrError> { let sub_dfg = h.add_node_with_parent( h.root(), ops::DFG { - signature: FunctionType::new_linear(type_row![BOOL_T]), + signature: FunctionType::new_endo(type_row![BOOL_T]), }, )?; // this Xor has its 2nd input unconnected @@ -428,8 +428,10 @@ fn test_local_const() -> Result<(), HugrError> { // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op)?; let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: BOOL_T })?; + h.connect(cst, 0, lcst, 0)?; h.connect(lcst, 0, and, 1)?; + assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: h.update_validate(&EMPTY_REG).unwrap(); Ok(()) diff --git a/src/hugr/views.rs b/src/hugr/views.rs index eafc68986..63fb03ab4 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -25,8 +25,12 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE}; use crate::ops::handle::NodeHandle; use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; +#[rustversion::since(1.75)] // uses impl in return position +use crate::types::Type; use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; +#[rustversion::since(1.75)] // uses impl in return position +use itertools::Either; /// A trait for inspecting HUGRs. /// For end users we intend this to be superseded by region-specific APIs. @@ -179,6 +183,71 @@ pub trait HugrView: sealed::HugrInternals { /// Iterator over the nodes and ports connected to a port. fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all the nodes and ports connected to a node in a given direction. + fn all_linked_ports( + &self, + node: Node, + dir: Direction, + ) -> Either< + impl Iterator, + impl Iterator, + > { + match dir { + Direction::Incoming => Either::Left( + self.node_inputs(node) + .flat_map(move |port| self.linked_outputs(node, port)), + ), + Direction::Outgoing => Either::Right( + self.node_outputs(node) + .flat_map(move |port| self.linked_inputs(node, port)), + ), + } + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all the nodes and ports connected to a node's inputs. + fn all_linked_outputs(&self, node: Node) -> impl Iterator { + self.all_linked_ports(node, Direction::Incoming) + .left() + .unwrap() + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all the nodes and ports connected to a node's outputs. + fn all_linked_inputs(&self, node: Node) -> impl Iterator { + self.all_linked_ports(node, Direction::Outgoing) + .right() + .unwrap() + } + + /// If there is exactly one port connected to this port, return + /// it and its node. + fn single_linked_port(&self, node: Node, port: impl Into) -> Option<(Node, Port)> { + self.linked_ports(node, port).exactly_one().ok() + } + + /// If there is exactly one OutgoingPort connected to this IncomingPort, return + /// it and its node. + fn single_linked_output( + &self, + node: Node, + port: impl Into, + ) -> Option<(Node, OutgoingPort)> { + self.single_linked_port(node, port.into()) + .map(|(n, p)| (n, p.as_outgoing().unwrap())) + } + + /// If there is exactly one IncomingPort connected to this OutgoingPort, return + /// it and its node. + fn single_linked_input( + &self, + node: Node, + port: impl Into, + ) -> Option<(Node, IncomingPort)> { + self.single_linked_port(node, port.into()) + .map(|(n, p)| (n, p.as_incoming().unwrap())) + } /// Iterator over the nodes and output ports connected to a given *input* port. /// Like [`linked_ports`][HugrView::linked_ports] but preserves knowledge /// that the linked ports are [OutgoingPort]s. @@ -334,6 +403,50 @@ pub trait HugrView: sealed::HugrInternals { }) .finish() } + + /// If a node has a static input, return the source node. + fn static_source(&self, node: Node) -> Option { + self.linked_outputs(node, self.get_optype(node).static_input_port()?) + .next() + .map(|(n, _)| n) + } + + #[rustversion::since(1.75)] // uses impl in return position + /// If a node has a static output, return the targets. + fn static_targets(&self, node: Node) -> Option> { + Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?)) + } + + /// Get the "signature" (incoming and outgoing types) of a node, non-Value + /// kind ports will be missing. + fn signature(&self, node: Node) -> Option { + self.get_optype(node).dataflow_signature() + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all outgoing ports that have Value type, along + /// with corresponding types. + fn value_types(&self, node: Node, dir: Direction) -> impl Iterator { + let sig = self.signature(node).unwrap_or_default(); + self.node_ports(node, dir) + .flat_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone()))) + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all incoming ports that have Value type, along + /// with corresponding types. + fn in_value_types(&self, node: Node) -> impl Iterator { + self.value_types(node, Direction::Incoming) + .map(|(p, t)| (p.as_incoming().unwrap(), t)) + } + + #[rustversion::since(1.75)] // uses impl in return position + /// Iterator over all incoming ports that have Value type, along + /// with corresponding types. + fn out_value_types(&self, node: Node) -> impl Iterator { + self.value_types(node, Direction::Outgoing) + .map(|(p, t)| (p.as_outgoing().unwrap(), t)) + } } /// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s @@ -505,6 +618,32 @@ impl> HugrView for T { } } +#[rustversion::since(1.75)] // uses impl in return position +/// Trait implementing methods on port iterators. +pub trait PortIterator

: Iterator +where + P: Into + Copy, + Self: Sized, +{ + /// Filter an iterator of node-ports to only dataflow dependency specifying + /// ports (Value and StateOrder) + fn dataflow_ports_only(self, hugr: &impl HugrView) -> impl Iterator { + self.filter(move |(n, p)| { + matches!( + hugr.get_optype(*n).port_kind(*p), + Some(EdgeKind::Value(_) | EdgeKind::StateOrder) + ) + }) + } +} +#[rustversion::since(1.75)] // uses impl in return position +impl PortIterator

for I +where + I: Iterator, + P: Into + Copy, +{ +} + pub(crate) mod sealed { use super::*; diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 77f74f39b..7d54469f4 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -472,7 +472,7 @@ mod test { fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { let root = simple_dfg_hugr.root(); let case_nodetype = NodeType::new_open(crate::ops::Case { - signature: simple_dfg_hugr.root_type().op_signature(), + signature: simple_dfg_hugr.root_type().op_signature().unwrap(), }); let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); // As expected, we cannot replace the root with a Case diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 515392615..592446a54 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -19,6 +19,7 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::extension::ExtensionSet; use crate::hugr::{HugrError, HugrMut, HugrView, RootTagged}; +use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; use crate::ops::{OpTag, OpTrait}; use crate::types::Signature; @@ -251,7 +252,7 @@ impl SiblingSubgraph { if !hugr.is_linked(n, p) { return false; } - let (out_n, _) = hugr.linked_ports(n, p).exactly_one().ok().unwrap(); + let (out_n, _) = hugr.single_linked_output(n, p).unwrap(); !nodes_set.contains(&out_n) }) // Every incoming edge is its own input. @@ -298,16 +299,16 @@ impl SiblingSubgraph { .iter() .map(|part| { let &(n, p) = part.iter().next().expect("is non-empty"); - let sig = hugr.get_optype(n).signature(); - sig.get(p).cloned().expect("must be dataflow edge") + let sig = hugr.signature(n).expect("must have dataflow signature"); + sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); let output = self .outputs .iter() .map(|&(n, p)| { - let sig = hugr.get_optype(n).signature(); - sig.get(p).cloned().expect("must be dataflow edge") + let sig = hugr.signature(n).expect("must have dataflow signature"); + sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); FunctionType::new(input, output) @@ -348,7 +349,7 @@ impl SiblingSubgraph { let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else { return Err(InvalidReplacement::InvalidDataflowParent); }; - if dfg_optype.signature() != self.signature(hugr) { + if dfg_optype.dataflow_signature() != Some(self.signature(hugr)) { return Err(InvalidReplacement::InvalidSignature); } @@ -356,10 +357,16 @@ impl SiblingSubgraph { // See https://github.com/CQCL-DEV/hugr/discussions/432 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p)); let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p)); - let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = - rep_inputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); - let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = - rep_outputs.partition(|&(n, p)| replacement.get_optype(n).signature().get(p).is_some()); + let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs.partition(|&(n, p)| { + replacement + .signature(n) + .is_some_and(|s| s.port_type(p).is_some()) + }); + let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs.partition(|&(n, p)| { + replacement + .signature(n) + .is_some_and(|s| s.port_type(p).is_some()) + }); if combine_in_out(&vec![out_order_ports], &in_order_ports) .any(|(n, p)| is_order_edge(&replacement, n, p)) @@ -467,10 +474,13 @@ impl<'g, Base: HugrView> ConvexChecker<'g, Base> { /// If the array is empty or a port does not exist, returns `None`. fn get_edge_type + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option { let &(n, p) = ports.first()?; - let edge_t = hugr.get_optype(n).signature().get(p)?.clone(); + let edge_t = hugr.signature(n)?.port_type(p)?.clone(); ports .iter() - .all(|&(n, p)| hugr.get_optype(n).signature().get(p) == Some(&edge_t)) + .all(|&(n, p)| { + hugr.signature(n) + .is_some_and(|s| s.port_type(p) == Some(&edge_t)) + }) .then_some(edge_t) } @@ -567,11 +577,21 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort if has_other_edge(hugr, inp, Direction::Outgoing) { unimplemented!("Non-dataflow output not supported at input node") } - let dfg_inputs = hugr.get_optype(inp).signature().output_ports(); + let dfg_inputs = hugr + .get_optype(inp) + .as_input() + .unwrap() + .signature() + .output_ports(); if has_other_edge(hugr, out, Direction::Incoming) { unimplemented!("Non-dataflow input not supported at output node") } - let dfg_outputs = hugr.get_optype(out).signature().input_ports(); + let dfg_outputs = hugr + .get_optype(out) + .as_output() + .unwrap() + .signature() + .input_ports(); // Collect for each port in the input the set of target ports, filtering // direct wires to the output. @@ -596,13 +616,13 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort /// Whether a port is linked to a state order edge. fn is_order_edge(hugr: &H, node: Node, port: Port) -> bool { let op = hugr.get_optype(node); - op.other_port_index(port.direction()) == Some(port) && hugr.is_linked(node, port) + op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port) } /// Whether node has a non-df linked port in the given direction. fn has_other_edge(hugr: &H, node: Node, dir: Direction) -> bool { let op = hugr.get_optype(node); - op.other_port(dir).is_some() && hugr.is_linked(node, op.other_port_index(dir).unwrap()) + op.other_port_kind(dir).is_some() && hugr.is_linked(node, op.other_port(dir).unwrap()) } /// Errors that can occur while constructing a [`SimpleReplacement`]. @@ -693,10 +713,7 @@ mod tests { }, hugr::views::{HierarchyView, SiblingGraph}, hugr::HugrMut, - ops::{ - handle::{DfgID, FuncID, NodeHandle}, - OpType, - }, + ops::handle::{DfgID, FuncID, NodeHandle}, std_extensions::logic::test::{and_op, not_op}, type_row, }; @@ -731,7 +748,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - FunctionType::new_linear(type_row![QB_T, QB_T, QB_T]).pure(), + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]).pure(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -747,8 +764,7 @@ mod tests { fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = - mod_builder.declare("test", FunctionType::new_linear(type_row![BOOL_T]).pure())?; + let func = mod_builder.declare("test", FunctionType::new_endo(type_row![BOOL_T]).pure())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let outs1 = dfg.add_dataflow_op(not_op(), dfg.input_wires())?; @@ -806,7 +822,7 @@ mod tests { let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; let empty_dfg = { - let builder = DFGBuilder::new(FunctionType::new_linear(type_row![QB_T, QB_T])).unwrap(); + let builder = DFGBuilder::new(FunctionType::new_endo(type_row![QB_T, QB_T])).unwrap(); let inputs = builder.input_wires(); builder.finish_prelude_hugr_with_outputs(inputs).unwrap() }; @@ -831,7 +847,7 @@ mod tests { // the first two qubits. assert_eq!( sub.signature(&func), - FunctionType::new_linear(type_row![QB_T, QB_T]) + FunctionType::new_endo(type_row![QB_T, QB_T]) ); Ok(()) } @@ -843,7 +859,7 @@ mod tests { let sub = SiblingSubgraph::from_sibling_graph(&func)?; let empty_dfg = { - let builder = DFGBuilder::new(FunctionType::new_linear(type_row![QB_T])).unwrap(); + let builder = DFGBuilder::new(FunctionType::new_endo(type_row![QB_T])).unwrap(); let inputs = builder.input_wires(); builder.finish_prelude_hugr_with_outputs(inputs).unwrap() }; @@ -882,7 +898,7 @@ mod tests { .collect(), hugr.node_inputs(out) .take(2) - .filter_map(|p| hugr.linked_outputs(out, p).exactly_one().ok()) + .filter_map(|p| hugr.single_linked_output(out, p)) .collect(), &func, ) @@ -959,9 +975,7 @@ mod tests { let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); - let OpType::FuncDefn(func_defn) = hugr.get_optype(func_root) else { - panic!() - }; + let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); assert_eq!(func_defn.signature, func.signature(&func_graph)) } diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 4e592a2de..44b4ee8df 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -68,3 +68,130 @@ fn dot_string(sample_hugr: (Hugr, BuildHandle, BuildHandle, BuildHandle)) { + use itertools::Itertools; + let (h, n1, n2) = sample_hugr; + + let all_output_ports = h.all_linked_outputs(n2.node()).collect_vec(); + + assert_eq!( + &all_output_ports[..], + &[ + (n1.node(), 1.into()), + (n1.node(), 0.into()), + (n1.node(), 2.into()), + ] + ); + + let all_linked_inputs = h.all_linked_inputs(n1.node()).collect_vec(); + assert_eq!( + &all_linked_inputs[..], + &[ + (n2.node(), 1.into()), + (n2.node(), 0.into()), + (n2.node(), 2.into()), + ] + ); +} + +#[rustversion::since(1.75)] // uses impl in return position +#[test] +fn value_types() { + use crate::builder::Container; + use crate::extension::prelude::BOOL_T; + use crate::std_extensions::logic::test::not_op; + use crate::utils::test_quantum_extension::h_gate; + use itertools::Itertools; + let mut dfg = DFGBuilder::new(FunctionType::new( + type_row![QB_T, BOOL_T], + type_row![BOOL_T, QB_T], + )) + .unwrap(); + + let [q, b] = dfg.input_wires_arr(); + let n1 = dfg.add_dataflow_op(h_gate(), [q]).unwrap(); + let n2 = dfg.add_dataflow_op(not_op(), [b]).unwrap(); + dfg.add_other_wire(n1.node(), n2.node()).unwrap(); + let h = dfg + .finish_prelude_hugr_with_outputs([n2.out_wire(0), n1.out_wire(0)]) + .unwrap(); + + let [_, o] = h.get_io(h.root()).unwrap(); + let n1_out_types = h.out_value_types(n1.node()).collect_vec(); + + assert_eq!(&n1_out_types[..], &[(0.into(), QB_T)]); + let out_types = h.in_value_types(o).collect_vec(); + + assert_eq!(&out_types[..], &[(0.into(), BOOL_T), (1.into(), QB_T)]); +} + +#[rustversion::since(1.75)] // uses impl in return position +#[test] +fn static_targets() { + use crate::extension::prelude::{ConstUsize, USIZE_T}; + use itertools::Itertools; + + let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap(); + + let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap(); + + let load = dfg.load_const(&c).unwrap(); + + let h = dfg + .finish_hugr_with_outputs([load], &crate::extension::PRELUDE_REGISTRY) + .unwrap(); + + assert_eq!(h.static_source(load.node()), Some(c.node())); + + assert_eq!( + &h.static_targets(c.node()).unwrap().collect_vec()[..], + &[(load.node(), 0.into())] + ) +} + +#[rustversion::since(1.75)] // uses impl in return position +#[test] +fn test_dataflow_ports_only() { + use crate::builder::DataflowSubContainer; + use crate::extension::prelude::BOOL_T; + use crate::hugr::views::PortIterator; + use crate::std_extensions::logic::test::not_op; + use itertools::Itertools; + let mut dfg = DFGBuilder::new(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap(); + let local_and = { + let local_and = dfg + .define_function( + "and", + FunctionType::new(type_row![BOOL_T; 2], type_row![BOOL_T]).pure(), + ) + .unwrap(); + let first_input = local_and.input().out_wire(0); + local_and.finish_with_outputs([first_input]).unwrap() + }; + let [in_bool] = dfg.input_wires_arr(); + + let not = dfg.add_dataflow_op(not_op(), [in_bool]).unwrap(); + let call = dfg.call(local_and.handle(), [not.out_wire(0); 2]).unwrap(); + dfg.add_other_wire(not.node(), call.node()).unwrap(); + let h = dfg + .finish_hugr_with_outputs(not.outputs(), &crate::extension::PRELUDE_REGISTRY) + .unwrap(); + let filtered_ports = h + .all_linked_outputs(call.node()) + .dataflow_ports_only(&h) + .collect_vec(); + + // should ignore the static input in to call, but report the two value ports + // and the order port. + assert_eq!( + &filtered_ports[..], + &[ + (not.node(), 0.into()), + (not.node(), 0.into()), + (not.node(), 1.into()) + ] + ) +} diff --git a/src/lib.rs b/src/lib.rs index 80a7b8edd..2fca6ace7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,139 @@ -//! `hugr` is the Hierarchical Unified Graph Representation of quantum circuits -//! and operations in the Quantinuum ecosystem. +//! Extensible, graph-based program representation with first-class support for linear types. //! -//! # Features +//! The name HUGR stands for "Hierarchical Unified Graph Representation". It is designed primarily +//! as an intermediate representation and interchange format for quantum and hybrid +//! classical–quantum programs. //! -//! - `serde` enables serialization and deserialization of the components and -//! structures. +//! Both data-flow and control-flow graphs can be represented in the HUGR. Nodes in the graph may +//! represent basic operations, or may themselves have "child" graphs, which inherit their inputs +//! and outputs. Special "non-local" edges allow data to pass directly from a node to another node +//! that is not a direct descendent (subject to causality constraints). +//! +//! The specification can be found +//! [here](https://github.com/CQCL/hugr/blob/main/specification/hugr.md). +//! +//! This crate provides a Rust implementation of HUGR and the standard extensions defined in the +//! specification. +//! +//! It includes methods for: +//! +//! - building HUGRs from basic operations; +//! - defining new extensions; +//! - serializing and deserializing HUGRs; +//! - performing local rewrites. +//! +//! # Example +//! +//! To build a HUGR for a simple quantum circuit and then serialize it to a buffer, we can define +//! a simple quantum extension and then use the [[builder::DFGBuilder]] as follows: +//! ``` +//! use hugr::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}; +//! use hugr::extension::prelude::{BOOL_T, QB_T}; +//! use hugr::hugr::Hugr; +//! use hugr::type_row; +//! use hugr::types::FunctionType; +//! +//! mod mini_quantum_extension { +//! use smol_str::SmolStr; +//! +//! use hugr::{ +//! extension::{ +//! prelude::{BOOL_T, QB_T}, +//! ExtensionId, ExtensionRegistry, PRELUDE, +//! }, +//! ops::LeafOp, +//! type_row, +//! types::{FunctionType, PolyFuncType}, +//! Extension, +//! }; +//! +//! use lazy_static::lazy_static; +//! +//! fn one_qb_func() -> PolyFuncType { +//! FunctionType::new_endo(type_row![QB_T]).into() +//! } +//! +//! fn two_qb_func() -> PolyFuncType { +//! FunctionType::new_endo(type_row![QB_T, QB_T]).into() +//! } +//! /// The extension identifier. +//! pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("mini.quantum"); +//! fn extension() -> Extension { +//! let mut extension = Extension::new(EXTENSION_ID); +//! +//! extension +//! .add_op_type_scheme_simple(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func()) +//! .unwrap(); +//! +//! extension +//! .add_op_type_scheme_simple(SmolStr::new_inline("CX"), "CX".into(), two_qb_func()) +//! .unwrap(); +//! +//! extension +//! .add_op_type_scheme_simple( +//! SmolStr::new_inline("Measure"), +//! "Measure a qubit, returning the qubit and the measurement result.".into(), +//! FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T]).into(), +//! ) +//! .unwrap(); +//! +//! extension +//! } +//! +//! lazy_static! { +//! /// Quantum extension definition. +//! pub static ref EXTENSION: Extension = extension(); +//! static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); +//! +//! } +//! fn get_gate(gate_name: &str) -> LeafOp { +//! EXTENSION +//! .instantiate_extension_op(gate_name, [], ®) +//! .unwrap() +//! .into() +//! } +//! pub fn h_gate() -> LeafOp { +//! get_gate("H") +//! } +//! +//! pub fn cx_gate() -> LeafOp { +//! get_gate("CX") +//! } +//! +//! pub fn measure() -> LeafOp { +//! get_gate("Measure") +//! } +//! } +//! +//! use mini_quantum_extension::{cx_gate, h_gate, measure}; +//! +//! // ┌───┐ +//! // q_0: ┤ H ├──■───── +//! // ├───┤┌─┴─┐┌─┐ +//! // q_1: ┤ H ├┤ X ├┤M├ +//! // └───┘└───┘└╥┘ +//! // c: ╩═ +//! fn make_dfg_hugr() -> Result { +//! let mut dfg_builder = DFGBuilder::new(FunctionType::new( +//! type_row![QB_T, QB_T], +//! type_row![QB_T, QB_T, BOOL_T], +//! ))?; +//! let [wire0, wire1] = dfg_builder.input_wires_arr(); +//! let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?; +//! let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; +//! let wire45 = dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?; +//! let wire67 = dfg_builder.add_dataflow_op(measure(), wire45.outputs().last())?; +//! dfg_builder.finish_prelude_hugr_with_outputs(wire45.outputs().take(1).chain(wire67.outputs())) +//! } +//! +//! let h: Hugr = make_dfg_hugr().unwrap(); +//! let serialized = serde_json::to_string(&h).unwrap(); +//! println!("{}", serialized); +//! ``` +//! +//! # Optional feature flags +//! +//! - `pyo3`: Enable Python bindings via [pyo3](https://docs.rs/pyo3). //! #![warn(missing_docs)] diff --git a/src/ops.rs b/src/ops.rs index f6ef25004..7deb6328d 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -10,8 +10,9 @@ pub mod module; pub mod tag; pub mod validate; use crate::types::{EdgeKind, FunctionType, Type}; -use crate::PortIndex; -use crate::{Direction, Port}; +use crate::{Direction, OutgoingPort, Port}; +use crate::{IncomingPort, PortIndex}; +use paste::paste; use portgraph::NodeIndex; use smol_str::SmolStr; @@ -53,6 +54,50 @@ pub enum OpType { Case, } +macro_rules! impl_op_ref_try_into { + ($Op: tt, $sname:ident) => { + paste! { + impl OpType { + #[doc = "If is an instance of `" $Op "` return a reference to it."] + pub fn [](&self) -> Option<&$Op> { + if let OpType::$Op(l) = self { + Some(l) + } else { + None + } + } + + #[doc = "If is an instance of `" $Op "`."] + pub fn [](&self) -> bool { + self.[]().is_some() + } + } + } + }; + ($Op:tt) => { + impl_op_ref_try_into!($Op, $Op); + }; +} + +impl_op_ref_try_into!(Module); +impl_op_ref_try_into!(FuncDefn); +impl_op_ref_try_into!(FuncDecl); +impl_op_ref_try_into!(AliasDecl); +impl_op_ref_try_into!(AliasDefn); +impl_op_ref_try_into!(Const); +impl_op_ref_try_into!(Input); +impl_op_ref_try_into!(Output); +impl_op_ref_try_into!(Call); +impl_op_ref_try_into!(CallIndirect); +impl_op_ref_try_into!(LoadConstant); +impl_op_ref_try_into!(DFG, dfg); +impl_op_ref_try_into!(LeafOp); +impl_op_ref_try_into!(BasicBlock); +impl_op_ref_try_into!(TailLoop); +impl_op_ref_try_into!(CFG, cfg); +impl_op_ref_try_into!(Conditional); +impl_op_ref_try_into!(Case); + /// The default OpType (as returned by [Default::default]) pub const DEFAULT_OPTYPE: OpType = OpType::Module(Module); @@ -63,12 +108,12 @@ impl Default for OpType { } impl OpType { - /// The edge kind for the non-dataflow or constant-input ports of the + /// The edge kind for the non-dataflow or constant ports of the /// operation, not described by the signature. /// /// If not None, a single extra multiport of that kind will be present on /// the given direction. - pub fn other_port(&self, dir: Direction) -> Option { + pub fn other_port_kind(&self, dir: Direction) -> Option { match dir { Direction::Incoming => self.other_input(), Direction::Outgoing => self.other_output(), @@ -77,54 +122,90 @@ impl OpType { /// Returns the edge kind for the given port. pub fn port_kind(&self, port: impl Into) -> Option { - let signature = self.signature(); + let signature = self.dataflow_signature().unwrap_or_default(); let port: Port = port.into(); + let port_as_in = port.as_incoming().ok(); let dir = port.direction(); let port_count = signature.port_count(dir); if port.index() < port_count { - signature.get(port).cloned().map(EdgeKind::Value) - } else if port.index() == port_count - && dir == Direction::Incoming - && self.static_input().is_some() - { - self.static_input().map(EdgeKind::Static) + signature.port_type(port).cloned().map(EdgeKind::Value) + } else if port_as_in.is_some() && port_as_in == self.static_input_port() { + Some(EdgeKind::Static(static_in_type(self))) } else { - self.other_port(dir) + self.other_port_kind(dir) } } /// The non-dataflow port for the operation, not described by the signature. - /// See `[OpType::other_port]`. + /// See `[OpType::other_port_kind]`. /// /// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports. - pub fn other_port_index(&self, dir: Direction) -> Option { - let non_df_count = self.validity_flags().non_df_port_count(dir).unwrap_or(1); - if self.other_port(dir).is_some() && non_df_count == 1 { + pub fn other_port(&self, dir: Direction) -> Option { + let non_df_count = self.non_df_port_count(dir); + if self.other_port_kind(dir).is_some() && non_df_count == 1 { // if there is a static input it comes before the non_df_ports let static_input = - (dir == Direction::Incoming && self.static_input().is_some()) as usize; + (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; - Some(Port::new( - dir, - self.signature().port_count(dir) + static_input, - )) + Some(Port::new(dir, self.value_port_count(dir) + static_input)) } else { None } } + /// The number of Value ports in given direction. + pub fn value_port_count(&self, dir: portgraph::Direction) -> usize { + self.dataflow_signature() + .map(|sig| sig.port_count(dir)) + .unwrap_or(0) + } + + /// The number of Value input ports. + pub fn value_input_count(&self) -> usize { + self.value_port_count(Direction::Incoming) + } + + /// The number of Value output ports. + pub fn value_output_count(&self) -> usize { + self.value_port_count(Direction::Outgoing) + } + + /// The non-dataflow input port for the operation, not described by the signature. + /// See `[OpType::other_port]`. + pub fn other_input_port(&self) -> Option { + self.other_port(Direction::Incoming) + } + + /// The non-dataflow input port for the operation, not described by the signature. + /// See `[OpType::other_port]`. + pub fn other_output_port(&self) -> Option { + self.other_port(Direction::Outgoing) + } + + /// If the op has a static input (Call and LoadConstant), the port of that input. + pub fn static_input_port(&self) -> Option { + match self { + OpType::Call(call) => Some(call.called_function_port()), + OpType::LoadConstant(l) => Some(l.constant_port()), + _ => None, + } + } + + /// If the op has a static output (Const, FuncDefn, FuncDecl), the port of that output. + pub fn static_output_port(&self) -> Option { + OpTag::StaticOutput + .is_superset(self.tag()) + .then_some(0.into()) + } + /// Returns the number of ports for the given direction. pub fn port_count(&self, dir: Direction) -> usize { - let signature = self.signature(); - let has_other_ports = self.other_port(dir).is_some(); - let non_df_count = self - .validity_flags() - .non_df_port_count(dir) - .unwrap_or(has_other_ports as usize); + let non_df_count = self.non_df_port_count(dir); // if there is a static input it comes before the non_df_ports - let static_input = (dir == Direction::Incoming && self.static_input().is_some()) as usize; - signature.port_count(dir) + non_df_count + static_input + let static_input = + (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; + self.value_port_count(dir) + non_df_count + static_input } /// Returns the number of inputs ports for the operation. @@ -143,6 +224,14 @@ impl OpType { } } +fn static_in_type(op: &OpType) -> Type { + match op { + OpType::Call(call) => Type::new_function(call.called_function_type().clone()), + OpType::LoadConstant(load) => load.constant_type().clone(), + _ => panic!("this function should not be called if the optype is not known to be Call or LoadConst.") + } +} + /// Macro used by operations that want their /// name to be the same as their type name macro_rules! impl_op_name { @@ -185,18 +274,10 @@ pub trait OpTrait { /// The signature of the operation. /// - /// Only dataflow operations have a non-empty signature. - fn signature(&self) -> FunctionType { - Default::default() - } - - /// Get the static input type of this operation if it has one (only Some for - /// [`LoadConstant`] and [`Call`]) - #[inline] - fn static_input(&self) -> Option { + /// Only dataflow operations have a signature, otherwise returns None. + fn dataflow_signature(&self) -> Option { None } - /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// @@ -214,6 +295,15 @@ pub trait OpTrait { fn other_output(&self) -> Option { None } + + /// Get the number of non-dataflow multiports. + fn non_df_port_count(&self, dir: Direction) -> usize { + match dir { + Direction::Incoming => self.other_input(), + Direction::Outgoing => self.other_output(), + } + .is_some() as usize + } } #[enum_dispatch] diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 54adcf803..d2121cd36 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -3,8 +3,8 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; -use crate::type_row; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; +use crate::{type_row, Direction}; use super::dataflow::DataflowOpTrait; use super::OpTag; @@ -169,12 +169,20 @@ impl OpTrait for BasicBlock { Some(EdgeKind::ControlFlow) } - fn signature(&self) -> FunctionType { - match self { + fn dataflow_signature(&self) -> Option { + Some(match self { BasicBlock::DFB { extension_delta, .. } => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta), BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), + }) + } + + fn non_df_port_count(&self, dir: Direction) -> usize { + match self { + Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => tuple_sum_rows.len(), + Self::Exit { .. } if dir == Direction::Outgoing => 0, + _ => 1, } } } @@ -224,8 +232,8 @@ impl OpTrait for Case { ::TAG } - fn signature(&self) -> FunctionType { - self.signature.clone() + fn dataflow_signature(&self) -> Option { + Some(self.signature.clone()) } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index b1c5a39b3..5f2af204f 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -11,7 +11,7 @@ use crate::types::{type_param::TypeArg, FunctionType}; use crate::{Hugr, Node}; use super::tag::OpTag; -use super::{LeafOp, OpName, OpTrait, OpType}; +use super::{LeafOp, OpTrait, OpType}; /// An instantiation of an operation (declared by a extension) with values for the type arguments #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -58,8 +58,9 @@ impl From for LeafOp { } } -impl OpName for ExternalOp { - fn name(&self) -> SmolStr { +impl ExternalOp { + /// Name of the ExternalOp + pub fn name(&self) -> SmolStr { let (res_id, op_name) = match self { Self::Opaque(op) => (&op.extension, &op.op_name), Self::Extension(ExtensionOp { def, .. }) => (def.extension(), def.name()), @@ -68,23 +69,23 @@ impl OpName for ExternalOp { } } -impl OpTrait for ExternalOp { - fn description(&self) -> &str { +impl ExternalOp { + /// A description of the external op. + pub fn description(&self) -> &str { match self { Self::Opaque(op) => op.description.as_str(), Self::Extension(ExtensionOp { def, .. }) => def.description(), } } - fn tag(&self) -> OpTag { - OpTag::Leaf - } - /// Note the case of an OpaqueOp without a signature should already /// have been detected in [resolve_extension_ops] - fn signature(&self) -> FunctionType { + pub fn dataflow_signature(&self) -> FunctionType { match self { - Self::Opaque(op) => op.signature.clone().unwrap(), + Self::Opaque(op) => op + .signature + .clone() + .expect("Op should have been serialized with signature."), Self::Extension(ExtensionOp { signature, .. }) => signature.clone(), } } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index 5b63b706f..bf4fd83a9 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -5,16 +5,13 @@ use super::{impl_op_name, OpTag, OpTrait}; use crate::extension::ExtensionSet; use crate::ops::StaticTag; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; +use crate::IncomingPort; -pub(super) trait DataflowOpTrait { +pub(crate) trait DataflowOpTrait { const TAG: OpTag; fn description(&self) -> &str; fn signature(&self) -> FunctionType; - /// Get the static input type of this operation if it has one. - fn static_input(&self) -> Option { - None - } /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// @@ -115,8 +112,8 @@ impl OpTrait for T { fn tag(&self) -> OpTag { T::TAG } - fn signature(&self) -> FunctionType { - DataflowOpTrait::signature(self) + fn dataflow_signature(&self) -> Option { + Some(DataflowOpTrait::signature(self)) } fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) @@ -125,10 +122,6 @@ impl OpTrait for T { fn other_output(&self) -> Option { DataflowOpTrait::other_output(self) } - - fn static_input(&self) -> Option { - DataflowOpTrait::static_input(self) - } } impl StaticTag for T { const TAG: OpTag = T::TAG; @@ -156,10 +149,18 @@ impl DataflowOpTrait for Call { fn signature(&self) -> FunctionType { self.signature.clone() } +} +impl Call { + #[inline] + /// Return the signature of the function called by this op. + pub fn called_function_type(&self) -> &FunctionType { + &self.signature + } + /// The IncomingPort which links to the function being called. #[inline] - fn static_input(&self) -> Option { - Some(Type::new_function(self.signature.clone())) + pub fn called_function_port(&self) -> IncomingPort { + self.called_function_type().input_count().into() } } @@ -204,10 +205,18 @@ impl DataflowOpTrait for LoadConstant { fn signature(&self) -> FunctionType { FunctionType::new(TypeRow::new(), vec![self.datatype.clone()]) } +} +impl LoadConstant { + #[inline] + /// The type of the constant loaded by this op. + pub fn constant_type(&self) -> &Type { + &self.datatype + } + /// The IncomingPort which links to the loaded constant. #[inline] - fn static_input(&self) -> Option { - Some(self.datatype.clone()) + pub fn constant_port(&self) -> IncomingPort { + 0.into() } } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 24f1ebb7e..85de1a0a8 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -3,7 +3,8 @@ use smol_str::SmolStr; use super::custom::ExternalOp; -use super::{OpName, OpTag, OpTrait, StaticTag}; +use super::dataflow::DataflowOpTrait; +use super::{OpName, OpTag}; use crate::extension::{ExtensionRegistry, SignatureError}; use crate::types::type_param::TypeArg; @@ -128,11 +129,11 @@ impl OpName for LeafOp { } } -impl StaticTag for LeafOp { - const TAG: OpTag = OpTag::Leaf; -} +// impl StaticTag for LeafOp { +// } -impl OpTrait for LeafOp { +impl DataflowOpTrait for LeafOp { + const TAG: OpTag = OpTag::Leaf; /// A human-readable description of the operation. fn description(&self) -> &str { match self { @@ -148,10 +149,6 @@ impl OpTrait for LeafOp { } } - fn tag(&self) -> OpTag { - ::TAG - } - /// The signature of the operation. fn signature(&self) -> FunctionType { // Static signatures. The `TypeRow`s in the `FunctionType` use a @@ -159,7 +156,7 @@ impl OpTrait for LeafOp { match self { LeafOp::Noop { ty: typ } => FunctionType::new(vec![typ.clone()], vec![typ.clone()]), - LeafOp::CustomOp(ext) => ext.signature(), + LeafOp::CustomOp(ext) => ext.dataflow_signature(), LeafOp::MakeTuple { tys: types } => { FunctionType::new(types.clone(), vec![Type::new_tuple(types.clone())]) } diff --git a/src/ops/tag.rs b/src/ops/tag.rs index e45014977..435f279f1 100644 --- a/src/ops/tag.rs +++ b/src/ops/tag.rs @@ -46,6 +46,10 @@ pub enum OpTag { Input, /// A dataflow output. Output, + /// Dataflow node that has a static input + StaticInput, + /// Node that has a static output + StaticOutput, /// A function call. FnCall, /// A constant load operation. @@ -104,14 +108,14 @@ impl OpTag { OpTag::DataflowChild => &[OpTag::Any], OpTag::Input => &[OpTag::DataflowChild], OpTag::Output => &[OpTag::DataflowChild], - OpTag::Function => &[OpTag::ModuleOp], + OpTag::Function => &[OpTag::ModuleOp, OpTag::StaticOutput], OpTag::Alias => &[OpTag::ScopedDefn], OpTag::FuncDefn => &[OpTag::Function, OpTag::ScopedDefn, OpTag::DataflowParent], OpTag::BasicBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent], OpTag::BasicBlockExit => &[OpTag::BasicBlock], OpTag::Case => &[OpTag::Any, OpTag::DataflowParent], OpTag::ModuleRoot => &[OpTag::Any], - OpTag::Const => &[OpTag::ScopedDefn], + OpTag::Const => &[OpTag::ScopedDefn, OpTag::StaticOutput], OpTag::Dfg => &[OpTag::DataflowChild, OpTag::DataflowParent], OpTag::Cfg => &[OpTag::DataflowChild], OpTag::ScopedDefn => &[ @@ -121,8 +125,10 @@ impl OpTag { ], OpTag::TailLoop => &[OpTag::DataflowChild, OpTag::DataflowParent], OpTag::Conditional => &[OpTag::DataflowChild], - OpTag::FnCall => &[OpTag::DataflowChild], - OpTag::LoadConst => &[OpTag::DataflowChild], + OpTag::StaticInput => &[OpTag::Any], + OpTag::StaticOutput => &[OpTag::Any], + OpTag::FnCall => &[OpTag::StaticInput, OpTag::DataflowChild], + OpTag::LoadConst => &[OpTag::StaticInput, OpTag::DataflowChild], OpTag::Leaf => &[OpTag::DataflowChild], OpTag::DataflowParent => &[OpTag::Any], } @@ -150,6 +156,8 @@ impl OpTag { OpTag::Cfg => "Nested control-flow operation", OpTag::TailLoop => "Tail-recursive loop", OpTag::Conditional => "Conditional operation", + OpTag::StaticInput => "Node with static input (LoadConst or FnCall)", + OpTag::StaticOutput => "Node with static output (FuncDefn, FuncDecl, Const)", OpTag::FnCall => "Function call", OpTag::LoadConst => "Constant load operation", OpTag::Leaf => "Leaf operation", diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 6c6ff65cd..ec8b75c24 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -11,7 +11,6 @@ use portgraph::{NodeIndex, PortOffset}; use thiserror::Error; use crate::types::{Type, TypeRow}; -use crate::Direction; use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp}; @@ -32,30 +31,12 @@ pub struct OpValidityFlags { pub requires_children: bool, /// Whether the children must form a DAG (no cycles). pub requires_dag: bool, - /// A strict requirement on the number of non-dataflow multiports. - /// - /// If not specified, the operation must have exactly one non-dataflow port - /// if the operation type has other_edges, or zero otherwise. - pub non_df_ports: (Option, Option), /// A validation check for edges between children /// // Enclosed in an `Option` to avoid iterating over the edges if not needed. pub edge_check: Option Result<(), EdgeValidationError>>, } -impl OpValidityFlags { - /// Get the number of non-dataflow multiports. - /// - /// If None, the operation must have exactly one non-dataflow port - /// if the operation type has other_edges, or zero otherwise. - pub fn non_df_port_count(&self, dir: Direction) -> Option { - match dir { - Direction::Incoming => self.non_df_ports.0, - Direction::Outgoing => self.non_df_ports.1, - } - } -} - impl Default for OpValidityFlags { fn default() -> Self { // Defaults to flags valid for non-container operations @@ -65,7 +46,6 @@ impl Default for OpValidityFlags { allowed_second_child: OpTag::Any, requires_children: false, requires_dag: false, - non_df_ports: (None, None), edge_check: None, } } @@ -122,12 +102,8 @@ impl ValidateOp for super::DFG { &self, children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { - validate_io_nodes( - &self.signature().input, - &self.signature().output, - "nested graph", - children, - ) + let sig = self.dataflow_signature().unwrap_or_default(); + validate_io_nodes(&sig.input, &sig.output, "nested graph", children) } } @@ -159,9 +135,9 @@ impl ValidateOp for super::Conditional { // Each child must have its variant's row and the rest of `inputs` as input, // and matching output for (i, (child, optype)) in children.into_iter().enumerate() { - let OpType::Case(case_op) = optype else { - panic!("Child check should have already checked valid ops.") - }; + let case_op = optype + .as_case() + .expect("Child check should have already checked valid ops."); let sig = &case_op.signature; if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs { return Err(ChildrenValidationError::ConditionalCaseSignature { @@ -316,16 +292,12 @@ impl ValidateOp for BasicBlock { /// Returns the set of allowed parent operation types. fn validity_flags(&self) -> OpValidityFlags { match self { - BasicBlock::DFB { - tuple_sum_rows: tuple_sum_variants, - .. - } => OpValidityFlags { + BasicBlock::DFB { .. } => OpValidityFlags { allowed_children: OpTag::DataflowChild, allowed_first_child: OpTag::Input, allowed_second_child: OpTag::Output, requires_children: true, requires_dag: true, - non_df_ports: (None, Some(tuple_sum_variants.len())), ..Default::default() }, // Default flags are valid for non-container operations @@ -395,19 +367,22 @@ fn validate_io_nodes<'a>( let (first, first_optype) = children.next().unwrap(); let (second, second_optype) = children.next().unwrap(); - if &first_optype.signature().output != expected_input { + let first_sig = first_optype.dataflow_signature().unwrap_or_default(); + if &first_sig.output != expected_input { return Err(ChildrenValidationError::IOSignatureMismatch { child: first, - actual: first_optype.signature().output, + actual: first_sig.output, expected: expected_input.clone(), node_desc: "Input", container_desc, }); } - if &second_optype.signature().input != expected_output { + let second_sig = second_optype.dataflow_signature().unwrap_or_default(); + + if &second_sig.input != expected_output { return Err(ChildrenValidationError::IOSignatureMismatch { child: second, - actual: second_optype.signature().input, + actual: second_sig.input, expected: expected_output.clone(), node_desc: "Output", container_desc, @@ -440,9 +415,9 @@ fn validate_io_nodes<'a>( /// Validate an edge between two basic blocks in a CFG sibling graph. fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> { let [source, target]: [&BasicBlock; 2] = [&edge.source_op, &edge.target_op].map(|op| { - let OpType::BasicBlock(block_op) = op else { - panic!("CFG sibling graphs can only contain basic block operations.") - }; + let block_op = op + .as_basic_block() + .expect("CFG sibling graphs can only contain basic block operations."); block_op }); diff --git a/src/types.rs b/src/types.rs index 93e05c677..123a320a6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -215,7 +215,7 @@ impl TypeEnum { /// ``` /// # use hugr::types::{Type, TypeBound, FunctionType}; /// -/// let func_type = Type::new_function(FunctionType::new_linear(vec![])); +/// let func_type = Type::new_function(FunctionType::new_endo(vec![])); /// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable); /// /// ``` @@ -428,7 +428,7 @@ pub(crate) mod test { fn construct() { let t: Type = Type::new_tuple(vec![ USIZE_T, - Type::new_function(FunctionType::new_linear(vec![])), + Type::new_function(FunctionType::new_endo(vec![])), Type::new_extension(CustomType::new( "my_custom", [], diff --git a/src/types/serialize.rs b/src/types/serialize.rs index 4febe2238..4c76fb75e 100644 --- a/src/types/serialize.rs +++ b/src/types/serialize.rs @@ -66,7 +66,7 @@ mod test { #[test] fn serialize_types_roundtrip() { - let g: Type = Type::new_function(FunctionType::new_linear(vec![])); + let g: Type = Type::new_function(FunctionType::new_endo(vec![])); assert_eq!(ser_roundtrip(&g), g); diff --git a/src/types/signature.rs b/src/types/signature.rs index 90df38efd..f7be4cc1c 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -1,5 +1,6 @@ //! Abstract and concrete Signature types. +use itertools::Either; #[cfg(feature = "pyo3")] use pyo3::{pyclass, pymethods}; @@ -108,8 +109,9 @@ impl FunctionType { extension_reqs: ExtensionSet::new(), } } - /// Create a new signature with the same input and output types. - pub fn new_linear(linear: impl Into) -> Self { + /// Create a new signature with the same input and output types (signature of an endomorphic + /// function). + pub fn new_endo(linear: impl Into) -> Self { let linear = linear.into(); Self::new(linear.clone(), linear) } @@ -117,22 +119,50 @@ impl FunctionType { /// Returns the type of a value [`Port`]. Returns `None` if the port is out /// of bounds. #[inline] - pub fn get(&self, port: impl Into) -> Option<&Type> { - let port = port.into(); - match port.direction() { - Direction::Incoming => self.input.get(port), - Direction::Outgoing => self.output.get(port), + pub fn port_type(&self, port: impl Into) -> Option<&Type> { + let port: Port = port.into(); + match port.as_directed() { + Either::Left(port) => self.in_port_type(port), + Either::Right(port) => self.out_port_type(port), } } + /// Returns the type of a value input [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn in_port_type(&self, port: impl Into) -> Option<&Type> { + self.input.get(port.into()) + } + + /// Returns the type of a value output [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn out_port_type(&self, port: impl Into) -> Option<&Type> { + self.output.get(port.into()) + } + + /// Returns a mutable reference to the type of a value input [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn in_port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { + self.input.get_mut(port.into()) + } + + /// Returns the type of a value output [`Port`]. Returns `None` if the port is out + /// of bounds. + #[inline] + pub fn out_port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { + self.output.get_mut(port.into()) + } + /// Returns a mutable reference to the type of a value [`Port`]. /// Returns `None` if the port is out of bounds. #[inline] - pub fn get_mut(&mut self, port: impl Into) -> Option<&mut Type> { + pub fn port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { let port = port.into(); - match port.direction() { - Direction::Incoming => self.input.get_mut(port), - Direction::Outgoing => self.output.get_mut(port), + match port.as_directed() { + Either::Left(port) => self.in_port_type_mut(port), + Either::Right(port) => self.out_port_type_mut(port), } } @@ -271,14 +301,14 @@ mod test { assert_eq!(f_type.input_types(), &[Type::UNIT]); assert_eq!( - f_type.get(Port::new(Direction::Incoming, 0)), + f_type.port_type(Port::new(Direction::Incoming, 0)), Some(&Type::UNIT) ); let out = Port::new(Direction::Outgoing, 0); - *(f_type.get_mut(out).unwrap()) = USIZE_T; + *(f_type.port_type_mut(out).unwrap()) = USIZE_T; - assert_eq!(f_type.get(out), Some(&USIZE_T)); + assert_eq!(f_type.port_type(out), Some(&USIZE_T)); assert_eq!(f_type.input_types(), &[Type::UNIT]); assert_eq!(f_type.output_types(), &[USIZE_T]); diff --git a/src/utils.rs b/src/utils.rs index 9ce62ba37..73787b4ed 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -46,11 +46,11 @@ pub(crate) mod test_quantum_extension { use lazy_static::lazy_static; fn one_qb_func() -> PolyFuncType { - FunctionType::new_linear(type_row![QB_T]).into() + FunctionType::new_endo(type_row![QB_T]).into() } fn two_qb_func() -> PolyFuncType { - FunctionType::new_linear(type_row![QB_T, QB_T]).into() + FunctionType::new_endo(type_row![QB_T, QB_T]).into() } /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); diff --git a/src/values.rs b/src/values.rs index 3ace6fe5a..428654066 100644 --- a/src/values.rs +++ b/src/values.rs @@ -266,7 +266,7 @@ pub(crate) mod test { hugr: Box::new(simple_dfg_hugr), }; - let correct_type = Type::new_function(FunctionType::new_linear(type_row![ + let correct_type = Type::new_function(FunctionType::new_endo(type_row![ crate::extension::prelude::BOOL_T ]));