Skip to content

Commit

Permalink
refactor!: move non_df_count to OpTrait
Browse files Browse the repository at this point in the history
Closes #521
  • Loading branch information
ss2165 committed Nov 13, 2023
1 parent d32d033 commit d602521
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 28 deletions.
11 changes: 9 additions & 2 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl OpType {
///
/// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports.
pub fn other_port_index(&self, dir: Direction) -> Option<Port> {
let non_df_count = self.validity_flags().non_df_port_count(dir).unwrap_or(1);
let non_df_count = self.non_df_port_count(dir).unwrap_or(1);
if self.other_port(dir).is_some() && non_df_count == 1 {
// if there is a static input it comes before the non_df_ports
let static_input =
Expand All @@ -119,7 +119,6 @@ impl OpType {
let signature = self.signature();
let has_other_ports = self.other_port(dir).is_some();
let non_df_count = self
.validity_flags()
.non_df_port_count(dir)
.unwrap_or(has_other_ports as usize);
// if there is a static input it comes before the non_df_ports
Expand Down Expand Up @@ -214,6 +213,14 @@ pub trait OpTrait {
fn other_output(&self) -> Option<EdgeKind> {
None
}

/// Get the number of non-dataflow multiports.
///
/// If None, the operation must have exactly one non-dataflow port
/// if the operation type has other_edges, or zero otherwise.
fn non_df_port_count(&self, _dir: Direction) -> Option<usize> {
None
}
}

#[enum_dispatch]
Expand Down
11 changes: 10 additions & 1 deletion src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
use smol_str::SmolStr;

use crate::extension::ExtensionSet;
use crate::type_row;
use crate::types::{EdgeKind, FunctionType, Type, TypeRow};
use crate::{type_row, Direction};

use super::dataflow::DataflowOpTrait;
use super::OpTag;
Expand Down Expand Up @@ -177,6 +177,15 @@ impl OpTrait for BasicBlock {
BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]),
}
}

fn non_df_port_count(&self, dir: Direction) -> Option<usize> {
match self {
Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => {
Some(tuple_sum_rows.len())
}
_ => None,
}
}
}

impl BasicBlock {
Expand Down
26 changes: 1 addition & 25 deletions src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use portgraph::{NodeIndex, PortOffset};
use thiserror::Error;

use crate::types::{Type, TypeRow};
use crate::Direction;

use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp};

Expand All @@ -32,30 +31,12 @@ pub struct OpValidityFlags {
pub requires_children: bool,
/// Whether the children must form a DAG (no cycles).
pub requires_dag: bool,
/// A strict requirement on the number of non-dataflow multiports.
///
/// If not specified, the operation must have exactly one non-dataflow port
/// if the operation type has other_edges, or zero otherwise.
pub non_df_ports: (Option<usize>, Option<usize>),
/// A validation check for edges between children
///
// Enclosed in an `Option` to avoid iterating over the edges if not needed.
pub edge_check: Option<fn(ChildrenEdgeData) -> Result<(), EdgeValidationError>>,
}

impl OpValidityFlags {
/// Get the number of non-dataflow multiports.
///
/// If None, the operation must have exactly one non-dataflow port
/// if the operation type has other_edges, or zero otherwise.
pub fn non_df_port_count(&self, dir: Direction) -> Option<usize> {
match dir {
Direction::Incoming => self.non_df_ports.0,
Direction::Outgoing => self.non_df_ports.1,
}
}
}

impl Default for OpValidityFlags {
fn default() -> Self {
// Defaults to flags valid for non-container operations
Expand All @@ -65,7 +46,6 @@ impl Default for OpValidityFlags {
allowed_second_child: OpTag::Any,
requires_children: false,
requires_dag: false,
non_df_ports: (None, None),
edge_check: None,
}
}
Expand Down Expand Up @@ -316,16 +296,12 @@ impl ValidateOp for BasicBlock {
/// Returns the set of allowed parent operation types.
fn validity_flags(&self) -> OpValidityFlags {
match self {
BasicBlock::DFB {
tuple_sum_rows: tuple_sum_variants,
..
} => OpValidityFlags {
BasicBlock::DFB { .. } => OpValidityFlags {
allowed_children: OpTag::DataflowChild,
allowed_first_child: OpTag::Input,
allowed_second_child: OpTag::Output,
requires_children: true,
requires_dag: true,
non_df_ports: (None, Some(tuple_sum_variants.len())),
..Default::default()
},
// Default flags are valid for non-container operations
Expand Down

0 comments on commit d602521

Please sign in to comment.