diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 641ef1ae2..c85a02eb0 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -70,8 +70,8 @@ pub trait Container { /// /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. - fn add_constant(&mut self, constant: ops::Const) -> Result { - let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?; + fn add_constant(&mut self, constant: impl Into) -> Result { + let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?; Ok(const_n.into()) } @@ -374,7 +374,7 @@ pub trait Dataflow: Container { /// # Errors /// /// This function will return an error if there is an error when adding the node. - fn add_load_const(&mut self, constant: ops::Const) -> Result { + fn add_load_const(&mut self, constant: impl Into) -> Result { let cid = self.add_constant(constant)?; self.load_const(&cid) } diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index bbcddade7..8f99ee512 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -109,7 +109,7 @@ mod test { let build_result: Result = { let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; let [i1] = loop_b.input_wires_arr(); - let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?; + let const_wire = loop_b.add_load_const(ConstUsize::new(1))?; let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; @@ -173,7 +173,7 @@ mod test { let mut branch_1 = conditional_b.case_builder(1)?; let [_b1] = branch_1.input_wires_arr(); - let wire = branch_1.add_load_const(ConstUsize::new(2).into())?; + let wire = branch_1.add_load_const(ConstUsize::new(2))?; let break_wire = branch_1.make_break(signature, [wire])?; branch_1.finish_with_outputs([break_wire])?; diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs index 3f524e2db..05f1f48d9 100644 --- a/src/hugr/rewrite.rs +++ b/src/hugr/rewrite.rs @@ -1,5 +1,6 @@ //! Rewrite operations on the HUGR - replacement, outlining, etc. +pub mod consts; pub mod insert_identity; pub mod outline_cfg; pub mod replace; diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs new file mode 100644 index 000000000..61be178d5 --- /dev/null +++ b/src/hugr/rewrite/consts.rs @@ -0,0 +1,214 @@ +//! Rewrite operations involving Const and LoadConst operations + +use std::iter; + +use crate::{ + hugr::{HugrError, HugrMut}, + HugrView, Node, +}; + +use itertools::Itertools; +use thiserror::Error; + +use super::Rewrite; + +/// Remove a [`crate::ops::LoadConstant`] node with no consumers. +#[derive(Debug, Clone)] +pub struct RemoveConstIgnore(pub Node); + +/// Error from an [`RemoveConst`] or [`RemoveConstIgnore`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not correct operation).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConstIgnore { + type Error = RemoveError; + + // The Const node the LoadConstant was connected to. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { + return Err(RemoveError::InvalidNode(node)); + } + + if h.out_value_types(node) + .next() + .is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some()) + { + return Err(RemoveError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let source = h + .input_neighbours(node) + .exactly_one() + .ok() + .expect("Validation should check a Const is connected to LoadConstant."); + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +/// Remove a [`crate::ops::Const`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConst(pub Node); + +impl Rewrite for RemoveConst { + type Error = RemoveError; + + // The parent of the Const node. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { + return Err(RemoveError::InvalidNode(node)); + } + + if h.output_neighbours(node).next().is_some() { + return Err(RemoveError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let parent = h + .get_parent(node) + .expect("Const node without a parent shouldn't happen."); + h.remove_node(node)?; + + Ok(parent) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, + extension::{ + prelude::{ConstUsize, USIZE_T}, + PRELUDE_REGISTRY, + }, + hugr::HugrMut, + ops::{handle::NodeHandle, LeafOp}, + type_row, + types::FunctionType, + }; + #[test] + fn test_const_remove() -> Result<(), Box> { + let mut build = ModuleBuilder::new(); + let con_node = build.add_constant(ConstUsize::new(2))?; + + let mut dfg_build = + build.define_function("main", FunctionType::new_endo(type_row![]).into())?; + let load_1 = dfg_build.load_const(&con_node)?; + let load_2 = dfg_build.load_const(&con_node)?; + let tup = dfg_build.add_dataflow_op( + LeafOp::MakeTuple { + tys: type_row![USIZE_T, USIZE_T], + }, + [load_1, load_2], + )?; + dfg_build.finish_sub_container()?; + + let mut h = build.finish_prelude_hugr()?; + // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple + assert_eq!(h.node_count(), 8); + let tup_node = tup.node(); + // can't remove invalid node + assert_eq!( + h.apply_rewrite(RemoveConst(tup_node)), + Err(RemoveError::InvalidNode(tup_node)) + ); + + assert_eq!( + h.apply_rewrite(RemoveConstIgnore(tup_node)), + Err(RemoveError::InvalidNode(tup_node)) + ); + let load_1_node = load_1.node(); + let load_2_node = load_2.node(); + let con_node = con_node.node(); + + let remove_1 = RemoveConstIgnore(load_1_node); + assert_eq!( + remove_1.invalidation_set().exactly_one().ok(), + Some(load_1_node) + ); + + let remove_2 = RemoveConstIgnore(load_2_node); + + let remove_con = RemoveConst(con_node); + assert_eq!( + remove_con.invalidation_set().exactly_one().ok(), + Some(con_node) + ); + + // can't remove nodes in use + assert_eq!( + h.apply_rewrite(remove_1.clone()), + Err(RemoveError::ValueUsed(load_1_node)) + ); + + // remove the use + h.remove_node(tup_node)?; + + // remove first load + let reported_con_node = h.apply_rewrite(remove_1)?; + assert_eq!(reported_con_node, con_node); + + // still can't remove const, in use by second load + assert_eq!( + h.apply_rewrite(remove_con.clone()), + Err(RemoveError::ValueUsed(con_node)) + ); + + // remove second use + let reported_con_node = h.apply_rewrite(remove_2)?; + assert_eq!(reported_con_node, con_node); + // remove const + assert_eq!(h.apply_rewrite(remove_con)?, h.root()); + + assert_eq!(h.node_count(), 4); + assert!(h.validate(&PRELUDE_REGISTRY).is_ok()); + Ok(()) + } +} diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index c248fda2b..ce0353d48 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -140,7 +140,7 @@ fn static_targets() { ) .unwrap(); - let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap(); + let c = dfg.add_constant(ConstUsize::new(1)).unwrap(); let load = dfg.load_const(&c).unwrap();