Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement RemoveConst and RemoveConstIgnore #757

Merged
merged 12 commits into from
Jan 3, 2024
6 changes: 3 additions & 3 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?;
fn add_constant(&mut self, constant: impl Into<ops::Const>) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?;

Ok(const_n.into())
}
Expand Down Expand Up @@ -374,7 +374,7 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?;
let const_wire = loop_b.add_load_const(ConstUsize::new(1))?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -173,7 +173,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(ConstUsize::new(2).into())?;
let wire = branch_1.add_load_const(ConstUsize::new(2))?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
1 change: 1 addition & 0 deletions src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Rewrite operations on the HUGR - replacement, outlining, etc.

pub mod consts;
pub mod insert_identity;
pub mod outline_cfg;
pub mod replace;
Expand Down
214 changes: 214 additions & 0 deletions src/hugr/rewrite/consts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
//! Rewrite operations involving Const and LoadConst operations

use std::iter;

use crate::{
hugr::{HugrError, HugrMut},
HugrView, Node,
};

use itertools::Itertools;
use thiserror::Error;

use super::Rewrite;

/// Remove a [`crate::ops::LoadConstant`] node with no consumers.
#[derive(Debug, Clone)]
pub struct RemoveConstIgnore(pub Node);
ss2165 marked this conversation as resolved.
Show resolved Hide resolved

/// Error from an [`RemoveConst`] or [`RemoveConstIgnore`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RemoveError {
/// Invalid node.
#[error("Node is invalid (either not in HUGR or not correct operation).")]
InvalidNode(Node),
/// Node in use.
#[error("Node: {0:?} has non-zero outgoing connections.")]
ValueUsed(Node),
/// Removal error
#[error("Removing node caused error: {0:?}.")]
RemoveFail(#[from] HugrError),
}

impl Rewrite for RemoveConstIgnore {
type Error = RemoveError;

// The Const node the LoadConstant was connected to.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
return Err(RemoveError::InvalidNode(node));
}

if h.out_value_types(node)
.next()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about h.out_value_types(node).exactly_one().unwrap()? Aren't we allowed to assume the input Hugr validates?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, just use h.output_neighbours().next().is_some() ?

Copy link
Member Author

@ss2165 ss2165 Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe things connected by order edges would show up there as well and I want to make sure I get the value edge.

.is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some())
{
return Err(RemoveError::ValueUsed(node));
}

Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
self.verify(h)?;
let node = self.0;
let source = h
.input_neighbours(node)
.exactly_one()
.ok()
.expect("Validation should check a Const is connected to LoadConstant.");
h.remove_node(node)?;

Ok(source)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.0)
}
}

/// Remove a [`crate::ops::Const`] node with no outputs.
#[derive(Debug, Clone)]
pub struct RemoveConst(pub Node);

impl Rewrite for RemoveConst {
type Error = RemoveError;

// The parent of the Const node.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
return Err(RemoveError::InvalidNode(node));
}

if h.output_neighbours(node).next().is_some() {
return Err(RemoveError::ValueUsed(node));
}

Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
self.verify(h)?;
let node = self.0;
let parent = h
.get_parent(node)
.expect("Const node without a parent shouldn't happen.");
h.remove_node(node)?;

Ok(parent)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.0)
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::{
builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
extension::{
prelude::{ConstUsize, USIZE_T},
PRELUDE_REGISTRY,
},
hugr::HugrMut,
ops::{handle::NodeHandle, LeafOp},
type_row,
types::FunctionType,
};
#[test]
fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thorough, test :)

let mut build = ModuleBuilder::new();
let con_node = build.add_constant(ConstUsize::new(2))?;

let mut dfg_build =
build.define_function("main", FunctionType::new_endo(type_row![]).into())?;
let load_1 = dfg_build.load_const(&con_node)?;
let load_2 = dfg_build.load_const(&con_node)?;
let tup = dfg_build.add_dataflow_op(
LeafOp::MakeTuple {
tys: type_row![USIZE_T, USIZE_T],
},
[load_1, load_2],
)?;
dfg_build.finish_sub_container()?;

let mut h = build.finish_prelude_hugr()?;
// nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple
assert_eq!(h.node_count(), 8);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could change this to a Function-rooted Hugr to reduce counts by 1. Would be nice to comment why there are so many (Module, Function, Input, Output, Const, LoadConstant*2, Tuple)...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to test that non-local const->load edges work ok
Will add comment

let tup_node = tup.node();
// can't remove invalid node
assert_eq!(
h.apply_rewrite(RemoveConst(tup_node)),
Err(RemoveError::InvalidNode(tup_node))
);

assert_eq!(
h.apply_rewrite(RemoveConstIgnore(tup_node)),
Err(RemoveError::InvalidNode(tup_node))
);
let load_1_node = load_1.node();
let load_2_node = load_2.node();
let con_node = con_node.node();

let remove_1 = RemoveConstIgnore(load_1_node);
assert_eq!(
remove_1.invalidation_set().exactly_one().ok(),
Some(load_1_node)
);

let remove_2 = RemoveConstIgnore(load_2_node);

let remove_con = RemoveConst(con_node);
assert_eq!(
remove_con.invalidation_set().exactly_one().ok(),
Some(con_node)
);

// can't remove nodes in use
assert_eq!(
h.apply_rewrite(remove_1.clone()),
Err(RemoveError::ValueUsed(load_1_node))
);

// remove the use
h.remove_node(tup_node)?;

// remove first load
let reported_con_node = h.apply_rewrite(remove_1)?;
assert_eq!(reported_con_node, con_node);

// still can't remove const, in use by second load
assert_eq!(
h.apply_rewrite(remove_con.clone()),
Err(RemoveError::ValueUsed(con_node))
);

// remove second use
let reported_con_node = h.apply_rewrite(remove_2)?;
assert_eq!(reported_con_node, con_node);
// remove const
assert_eq!(h.apply_rewrite(remove_con)?, h.root());

assert_eq!(h.node_count(), 4);
assert!(h.validate(&PRELUDE_REGISTRY).is_ok());
Ok(())
}
}
2 changes: 1 addition & 1 deletion src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ fn static_targets() {
)
.unwrap();

let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap();
let c = dfg.add_constant(ConstUsize::new(1)).unwrap();

let load = dfg.load_const(&c).unwrap();

Expand Down