diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index b9ea98989..6ae55f8a9 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -12,7 +12,7 @@ use crate::{ views::SiblingSubgraph, HugrMut, }, - ops::{Const, LeafOp, OpType}, + ops::{Const, LeafOp}, type_row, types::{FunctionType, Type, TypeEnum}, values::Value, @@ -44,9 +44,7 @@ pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { .collect() } /// For a given op and consts, attempt to evaluate the op. -pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { - let op = op.as_leaf_op()?; - +pub fn fold_leaf_op(op: &LeafOp, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { match op { LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]), LeafOp::MakeTuple { .. } => { @@ -138,6 +136,8 @@ fn fold_op( op_node: Node, reg: &ExtensionRegistry, ) -> Option<(SimpleReplacement, Vec)> { + // only support leaf folding for now. + let neighbour_op = hugr.get_optype(op_node).as_leaf_op()?; let (in_consts, removals): (Vec<_>, Vec<_>) = hugr .node_inputs(op_node) .filter_map(|in_p| { @@ -145,9 +145,8 @@ fn fold_op( Some(((in_p, con_op), RemoveConstIgnore(load_n))) }) .unzip(); - let neighbour_op = hugr.get_optype(op_node); // attempt to evaluate op - let folded = fold_const(neighbour_op, &in_consts)?; + let folded = fold_leaf_op(neighbour_op, &in_consts)?; let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); let nu_out = op_outs .into_iter() @@ -220,6 +219,7 @@ mod test { use super::*; use crate::extension::prelude::sum_with_error; use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::ops::OpType; use crate::std_extensions::arithmetic; use crate::std_extensions::arithmetic::conversions::ConvertOpDef; use crate::std_extensions::arithmetic::float_ops::FloatOps; @@ -249,7 +249,7 @@ mod test { fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; let add_op: OpType = FloatOps::fadd.into(); - let out = fold_const(&add_op, &consts).unwrap(); + let out = fold_leaf_op(add_op.as_leaf_op().unwrap(), &consts).unwrap(); assert_eq!(&out[..], &[(0.into(), f2c(c))]); }