diff --git a/src/ops.rs b/src/ops.rs index f6ef25004..e7162f65a 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -99,7 +99,7 @@ impl OpType { /// /// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports. pub fn other_port_index(&self, dir: Direction) -> Option { - let non_df_count = self.validity_flags().non_df_port_count(dir).unwrap_or(1); + let non_df_count = self.non_df_port_count(dir).unwrap_or(1); if self.other_port(dir).is_some() && non_df_count == 1 { // if there is a static input it comes before the non_df_ports let static_input = @@ -119,7 +119,6 @@ impl OpType { let signature = self.signature(); let has_other_ports = self.other_port(dir).is_some(); let non_df_count = self - .validity_flags() .non_df_port_count(dir) .unwrap_or(has_other_ports as usize); // if there is a static input it comes before the non_df_ports @@ -214,6 +213,14 @@ pub trait OpTrait { fn other_output(&self) -> Option { None } + + /// Get the number of non-dataflow multiports. + /// + /// If None, the operation must have exactly one non-dataflow port + /// if the operation type has other_edges, or zero otherwise. + fn non_df_port_count(&self, _dir: Direction) -> Option { + None + } } #[enum_dispatch] diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 54adcf803..b6836ea70 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -3,8 +3,8 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; -use crate::type_row; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; +use crate::{type_row, Direction}; use super::dataflow::DataflowOpTrait; use super::OpTag; @@ -177,6 +177,15 @@ impl OpTrait for BasicBlock { BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), } } + + fn non_df_port_count(&self, dir: Direction) -> Option { + match self { + Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => { + Some(tuple_sum_rows.len()) + } + _ => None, + } + } } impl BasicBlock { diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 6c6ff65cd..546d6cec6 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -11,7 +11,6 @@ use portgraph::{NodeIndex, PortOffset}; use thiserror::Error; use crate::types::{Type, TypeRow}; -use crate::Direction; use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp}; @@ -32,30 +31,12 @@ pub struct OpValidityFlags { pub requires_children: bool, /// Whether the children must form a DAG (no cycles). pub requires_dag: bool, - /// A strict requirement on the number of non-dataflow multiports. - /// - /// If not specified, the operation must have exactly one non-dataflow port - /// if the operation type has other_edges, or zero otherwise. - pub non_df_ports: (Option, Option), /// A validation check for edges between children /// // Enclosed in an `Option` to avoid iterating over the edges if not needed. pub edge_check: Option Result<(), EdgeValidationError>>, } -impl OpValidityFlags { - /// Get the number of non-dataflow multiports. - /// - /// If None, the operation must have exactly one non-dataflow port - /// if the operation type has other_edges, or zero otherwise. - pub fn non_df_port_count(&self, dir: Direction) -> Option { - match dir { - Direction::Incoming => self.non_df_ports.0, - Direction::Outgoing => self.non_df_ports.1, - } - } -} - impl Default for OpValidityFlags { fn default() -> Self { // Defaults to flags valid for non-container operations @@ -65,7 +46,6 @@ impl Default for OpValidityFlags { allowed_second_child: OpTag::Any, requires_children: false, requires_dag: false, - non_df_ports: (None, None), edge_check: None, } } @@ -316,16 +296,12 @@ impl ValidateOp for BasicBlock { /// Returns the set of allowed parent operation types. fn validity_flags(&self) -> OpValidityFlags { match self { - BasicBlock::DFB { - tuple_sum_rows: tuple_sum_variants, - .. - } => OpValidityFlags { + BasicBlock::DFB { .. } => OpValidityFlags { allowed_children: OpTag::DataflowChild, allowed_first_child: OpTag::Input, allowed_second_child: OpTag::Output, requires_children: true, requires_dag: true, - non_df_ports: (None, Some(tuple_sum_variants.len())), ..Default::default() }, // Default flags are valid for non-container operations