From ba81e7b0e2e4bf597a49e27b29e034c03431af22 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 13:26:44 +0000 Subject: [PATCH 1/9] feat: implement RemoveConst and RemoveConstIgnore as per spec refactor!: allow Into for builder.add_const BREAKING_CHANGES: existing .into() calls will error --- src/builder/build_traits.rs | 6 +- src/builder/tail_loop.rs | 4 +- src/hugr/rewrite.rs | 1 + src/hugr/rewrite/consts.rs | 228 ++++++++++++++++++++++++++++++++++++ src/hugr/views/tests.rs | 2 +- 5 files changed, 235 insertions(+), 6 deletions(-) create mode 100644 src/hugr/rewrite/consts.rs diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 641ef1ae2..c85a02eb0 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -70,8 +70,8 @@ pub trait Container { /// /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. - fn add_constant(&mut self, constant: ops::Const) -> Result { - let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?; + fn add_constant(&mut self, constant: impl Into) -> Result { + let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?; Ok(const_n.into()) } @@ -374,7 +374,7 @@ pub trait Dataflow: Container { /// # Errors /// /// This function will return an error if there is an error when adding the node. - fn add_load_const(&mut self, constant: ops::Const) -> Result { + fn add_load_const(&mut self, constant: impl Into) -> Result { let cid = self.add_constant(constant)?; self.load_const(&cid) } diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index bbcddade7..8f99ee512 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -109,7 +109,7 @@ mod test { let build_result: Result = { let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; let [i1] = loop_b.input_wires_arr(); - let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?; + let const_wire = loop_b.add_load_const(ConstUsize::new(1))?; let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; @@ -173,7 +173,7 @@ mod test { let mut branch_1 = conditional_b.case_builder(1)?; let [_b1] = branch_1.input_wires_arr(); - let wire = branch_1.add_load_const(ConstUsize::new(2).into())?; + let wire = branch_1.add_load_const(ConstUsize::new(2))?; let break_wire = branch_1.make_break(signature, [wire])?; branch_1.finish_with_outputs([break_wire])?; diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs index 3f524e2db..05f1f48d9 100644 --- a/src/hugr/rewrite.rs +++ b/src/hugr/rewrite.rs @@ -1,5 +1,6 @@ //! Rewrite operations on the HUGR - replacement, outlining, etc. +pub mod consts; pub mod insert_identity; pub mod outline_cfg; pub mod replace; diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs new file mode 100644 index 000000000..f00b67cb0 --- /dev/null +++ b/src/hugr/rewrite/consts.rs @@ -0,0 +1,228 @@ +//! Rewrite operations involving Const and LoadConst operations + +use std::iter; + +use crate::{ + hugr::{HugrError, HugrMut}, + HugrView, Node, +}; +use itertools::Itertools; +use thiserror::Error; + +use super::Rewrite; + +/// Remove a [`crate::ops::LoadConstant`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConstIgnore(pub Node); + +/// Error from an [`RemoveConstIgnore`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveConstIgnoreError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not LoadConst).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Not connected to a Const. + #[error("Node: {0:?} is not connected to a Const node.")] + NoConst(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConstIgnore { + type Error = RemoveConstIgnoreError; + + // The Const node the LoadConstant was connected to. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { + return Err(RemoveConstIgnoreError::InvalidNode(node)); + } + + if h.out_value_types(node) + .next() + .is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some()) + { + return Err(RemoveConstIgnoreError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let source = h + .input_neighbours(node) + .exactly_one() + .map_err(|_| RemoveConstIgnoreError::NoConst(node))?; + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +/// Remove a [`crate::ops::Const`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConst(pub Node); + +/// Error from an [`RemoveConst`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveConstError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not Const).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConst { + type Error = RemoveConstError; + + // The parent of the Const node. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { + return Err(RemoveConstError::InvalidNode(node)); + } + + if h.output_neighbours(node).next().is_some() { + return Err(RemoveConstError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let source = h + .get_parent(node) + .expect("Const node without a parent shouldn't happen."); + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, + extension::{ + prelude::{ConstUsize, USIZE_T}, + PRELUDE_REGISTRY, + }, + hugr::HugrMut, + ops::{handle::NodeHandle, LeafOp}, + type_row, + types::FunctionType, + }; + #[test] + fn test_const_remove() -> Result<(), Box> { + let mut build = ModuleBuilder::new(); + let con_node = build.add_constant(ConstUsize::new(2))?; + + let mut dfg_build = + build.define_function("main", FunctionType::new_endo(type_row![]).into())?; + let load_1 = dfg_build.load_const(&con_node)?; + let load_2 = dfg_build.load_const(&con_node)?; + let tup = dfg_build.add_dataflow_op( + LeafOp::MakeTuple { + tys: type_row![USIZE_T, USIZE_T], + }, + [load_1, load_2], + )?; + dfg_build.finish_sub_container()?; + + let mut h = build.finish_prelude_hugr()?; + assert_eq!(h.node_count(), 8); + let tup_node = tup.node(); + // can't remove invalid node + assert_eq!( + h.apply_rewrite(RemoveConst(tup_node)), + Err(RemoveConstError::InvalidNode(tup_node)) + ); + + assert_eq!( + h.apply_rewrite(RemoveConstIgnore(tup_node)), + Err(RemoveConstIgnoreError::InvalidNode(tup_node)) + ); + let load_1_node = load_1.node(); + let load_2_node = load_2.node(); + let con_node = con_node.node(); + + let remove_1 = RemoveConstIgnore(load_1_node); + assert_eq!( + remove_1.invalidation_set().exactly_one().ok(), + Some(load_1_node) + ); + + let remove_2 = RemoveConstIgnore(load_2_node); + + let remove_con = RemoveConst(con_node); + assert_eq!( + remove_con.invalidation_set().exactly_one().ok(), + Some(con_node) + ); + + // can't remove nodes in use + assert_eq!( + h.apply_rewrite(remove_1.clone()), + Err(RemoveConstIgnoreError::ValueUsed(load_1_node)) + ); + + // remove the use + h.remove_node(tup_node)?; + + // remove first load + let reported_con_node = h.apply_rewrite(remove_1)?; + assert_eq!(reported_con_node, con_node); + + // still can't remove const, in use by second load + assert_eq!( + h.apply_rewrite(remove_con.clone()), + Err(RemoveConstError::ValueUsed(con_node)) + ); + + // remove second use + let reported_con_node = h.apply_rewrite(remove_2)?; + assert_eq!(reported_con_node, con_node); + // remove const + assert_eq!(h.apply_rewrite(remove_con)?, h.root()); + + assert_eq!(h.node_count(), 4); + assert!(h.validate(&PRELUDE_REGISTRY).is_ok()); + Ok(()) + } +} diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 97fb50861..9a712dc03 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -143,7 +143,7 @@ fn static_targets() { ) .unwrap(); - let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap(); + let c = dfg.add_constant(ConstUsize::new(1)).unwrap(); let load = dfg.load_const(&c).unwrap(); From 26bc5ff1afb2a589e2b7f97249ac1e3c5cdac2c0 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 13:41:17 +0000 Subject: [PATCH 2/9] add rust version guards --- src/hugr/rewrite/consts.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs index f00b67cb0..d8a0060b7 100644 --- a/src/hugr/rewrite/consts.rs +++ b/src/hugr/rewrite/consts.rs @@ -6,6 +6,7 @@ use crate::{ hugr::{HugrError, HugrMut}, HugrView, Node, }; +#[rustversion::since(1.75)] // uses impl in return position use itertools::Itertools; use thiserror::Error; @@ -32,6 +33,7 @@ pub enum RemoveConstIgnoreError { RemoveFail(#[from] HugrError), } +#[rustversion::since(1.75)] // uses impl in return position impl Rewrite for RemoveConstIgnore { type Error = RemoveConstIgnoreError; @@ -134,6 +136,7 @@ impl Rewrite for RemoveConst { } } +#[rustversion::since(1.75)] // uses impl in return position #[cfg(test)] mod test { use super::*; From 664fe89bce99ff5d02678c8ad0c57f4e4e1d3fb9 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 2 Jan 2024 11:06:34 +0000 Subject: [PATCH 3/9] feat: constant folding implemented for core and float extension (#758) Closes #711 CI will pass after updating MSRV to 1.75 (from end of year) based on #757 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Alan Lawrence Co-authored-by: Luca Mondada <72734770+lmondada@users.noreply.github.com> Co-authored-by: Luca Mondada --- .github/workflows/ci.yml | 4 +- Cargo.toml | 4 +- README.md | 2 +- src/algorithm.rs | 1 + src/algorithm/const_fold.rs | 302 ++++++++++++++++++ src/extension.rs | 2 + src/extension/const_fold.rs | 53 +++ src/extension/op_def.rs | 23 +- src/extension/prelude.rs | 66 +++- src/ops/custom.rs | 9 +- src/std_extensions/arithmetic/conversions.rs | 12 +- src/std_extensions/arithmetic/float_ops.rs | 6 +- .../arithmetic/float_ops/fold.rs | 124 +++++++ src/std_extensions/arithmetic/float_types.rs | 2 +- src/std_extensions/arithmetic/int_ops.rs | 1 + src/values.rs | 1 + 16 files changed, 592 insertions(+), 20 deletions(-) create mode 100644 src/algorithm/const_fold.rs create mode 100644 src/extension/const_fold.rs create mode 100644 src/std_extensions/arithmetic/float_ops/fold.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f895bc87c..5514dedee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,13 +56,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: ['1.70', stable, beta, nightly] + rust: ['1.75', stable, beta, nightly] # workaround to ignore non-stable tests when running the merge queue checks # see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6 isMerge: - ${{ github.event_name == 'merge_group' }} exclude: - - rust: '1.70' + - rust: '1.75' isMerge: true - rust: beta isMerge: true diff --git a/Cargo.toml b/Cargo.toml index b1acc0bd4..ea4861062 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation" #categories = [] # TODO edition = "2021" -rust-version = "1.70" +rust-version = "1.75" [lib] # Using different names for the lib and for the package is supported, but may be confusing. @@ -46,7 +46,7 @@ lazy_static = "1.4.0" petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" -delegate = "0.11.0" +delegate = "0.12.0" rustversion = "1.0.14" paste = "1.0" strum = "0.25.0" diff --git a/README.md b/README.md index 91ab2c617..14adf32ed 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,6 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). [build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg + [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: LICENCE diff --git a/src/algorithm.rs b/src/algorithm.rs index 0023b5916..633231504 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,4 +1,5 @@ //! Algorithms using the Hugr. +pub mod const_fold; mod half_node; pub mod nest_cfgs; diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs new file mode 100644 index 000000000..49474c43e --- /dev/null +++ b/src/algorithm/const_fold.rs @@ -0,0 +1,302 @@ +//! Constant folding routines. + +use std::collections::{BTreeSet, HashMap}; + +use itertools::Itertools; + +use crate::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::{ConstFoldResult, ExtensionRegistry}, + hugr::{ + rewrite::consts::{RemoveConst, RemoveConstIgnore}, + views::SiblingSubgraph, + HugrMut, + }, + ops::{Const, LeafOp, OpType}, + type_row, + types::{FunctionType, Type, TypeEnum}, + values::Value, + Hugr, HugrView, IncomingPort, Node, SimpleReplacement, +}; + +/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. +fn out_row(consts: impl IntoIterator) -> ConstFoldResult { + let vec = consts + .into_iter() + .enumerate() + .map(|(i, c)| (i.into(), c)) + .collect(); + Some(vec) +} + +/// Sort folding inputs with [`IncomingPort`] as key +fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> { + let mut v: Vec<_> = consts.iter().collect(); + v.sort_by_key(|(i, _)| i); + v +} + +/// Sort some input constants by port and just return the constants. +pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { + sort_by_in_port(consts) + .into_iter() + .map(|(_, c)| c) + .collect() +} +/// For a given op and consts, attempt to evaluate the op. +pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { + let op = op.as_leaf_op()?; + + match op { + LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]), + LeafOp::MakeTuple { .. } => { + out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())]) + } + LeafOp::UnpackTuple { .. } => { + let c = &consts.first()?.1; + + if let Value::Tuple { vs } = c.value() { + if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() { + return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| { + Const::new(v.clone(), t.clone()) + .expect("types should already have been checked") + })); + } + } + None // could panic + } + + LeafOp::Tag { tag, variants } => out_row([Const::new( + Value::sum(*tag, consts.first()?.1.value().clone()), + Type::new_sum(variants.clone()), + ) + .unwrap()]), + LeafOp::CustomOp(_) => { + let ext_op = op.as_extension_op()?; + + ext_op.constant_fold(consts) + } + _ => None, + } +} + +/// Generate a graph that loads and outputs `consts` in order, validating +/// against `reg`. +fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { + let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); + let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); + + let outputs = consts + .into_iter() + .map(|c| b.add_load_const(c).unwrap()) + .collect_vec(); + + b.finish_hugr_with_outputs(outputs, reg).unwrap() +} + +/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, +/// return an iterator of possible constant folding rewrites. The +/// [`SimpleReplacement`] replaces an operation with constants that result from +/// evaluating it, the extension registry `reg` is used to validate the +/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the +/// LoadConstant nodes that could be removed. +pub fn find_consts<'a, 'r: 'a>( + hugr: &'a impl HugrView, + candidate_nodes: impl IntoIterator + 'a, + reg: &'r ExtensionRegistry, +) -> impl Iterator)> + 'a { + // track nodes for operations that have already been considered for folding + let mut used_neighbours = BTreeSet::new(); + + candidate_nodes + .into_iter() + .filter_map(move |n| { + // only look at LoadConstant + hugr.get_optype(n).is_load_constant().then_some(())?; + + let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; + let neighbours = hugr + .linked_inputs(n, out_p) + .filter(|(n, _)| used_neighbours.insert(*n)) + .collect_vec(); + if neighbours.is_empty() { + // no uses of LoadConstant that haven't already been considered. + return None; + } + let fold_iter = neighbours + .into_iter() + .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); + Some(fold_iter) + }) + .flatten() +} + +/// Attempt to evaluate and generate rewrites for the operation at `op_node` +fn fold_op( + hugr: &impl HugrView, + op_node: Node, + reg: &ExtensionRegistry, +) -> Option<(SimpleReplacement, Vec)> { + let (in_consts, removals): (Vec<_>, Vec<_>) = hugr + .node_inputs(op_node) + .filter_map(|in_p| { + let (con_op, load_n) = get_const(hugr, op_node, in_p)?; + Some(((in_p, con_op), RemoveConstIgnore(load_n))) + }) + .unzip(); + let neighbour_op = hugr.get_optype(op_node); + // attempt to evaluate op + let folded = fold_const(neighbour_op, &in_consts)?; + let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); + let nu_out = op_outs + .into_iter() + .enumerate() + .filter_map(|(i, out)| { + // map from the ports the op was linked to, to the output ports of + // the replacement. + hugr.single_linked_input(op_node, out) + .map(|np| (np, i.into())) + }) + .collect(); + let replacement = const_graph(consts, reg); + let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) + .expect("Operation should form valid subgraph."); + + let simple_replace = SimpleReplacement::new( + sibling_graph, + replacement, + // no inputs to replacement + HashMap::new(), + nu_out, + ); + Some((simple_replace, removals)) +} + +/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant +/// and the LoadConstant node +fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> { + let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; + let load_op = hugr.get_optype(load_n).as_load_constant()?; + let const_node = hugr + .linked_outputs(load_n, load_op.constant_port()) + .exactly_one() + .ok()? + .0; + + let const_op = hugr.get_optype(const_node).as_const()?; + + // TODO avoid const clone here + Some((const_op.clone(), load_n)) +} + +/// Exhaustively apply constant folding to a HUGR. +pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { + loop { + // would be preferable if the candidates were updated to be just the + // neighbouring nodes of those added. + let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); + if rewrites.is_empty() { + break; + } + for (replace, removes) in rewrites { + h.apply_rewrite(replace).unwrap(); + for rem in removes { + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + if h.apply_rewrite(RemoveConst(const_node)).is_err() { + // const cannot be removed - no problem + continue; + } + } + } + } + } +} + +#[cfg(test)] +mod test { + + use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::std_extensions::arithmetic; + + use crate::std_extensions::arithmetic::float_ops::FloatOps; + use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; + + use rstest::rstest; + + use super::*; + + /// float to constant + fn f2c(f: f64) -> Const { + ConstF64::new(f).into() + } + + #[rstest] + #[case(0.0, 0.0, 0.0)] + #[case(0.0, 1.0, 1.0)] + #[case(23.5, 435.5, 459.0)] + // c = a + b + fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; + let add_op: OpType = FloatOps::fadd.into(); + let out = fold_const(&add_op, &consts).unwrap(); + + assert_eq!(&out[..], &[(0.into(), f2c(c))]); + } + + #[test] + fn test_big() { + /* + Test hugr approximately calculates + let x = (5.5, 3.25); + x.0 - x.1 == 2.25 + */ + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); + + let tup = build + .add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)])) + .unwrap(); + + let unpack = build + .add_dataflow_op( + LeafOp::UnpackTuple { + tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], + }, + [tup], + ) + .unwrap(); + + let sub = build + .add_dataflow_op(FloatOps::fsub, unpack.outputs()) + .unwrap(); + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(sub.outputs(), ®).unwrap(); + assert_eq!(h.node_count(), 7); + + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &f2c(2.25)); + } + fn assert_fully_folded(h: &Hugr, expected_const: &Const) { + // check the hugr just loads and returns a single const + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if c == expected_const => node_count += 1, + _ => panic!("unexpected op: {:?}", op), + } + } + + assert_eq!(node_count, 4); + } +} diff --git a/src/extension.rs b/src/extension.rs index 95b0474ea..5519e456b 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -28,9 +28,11 @@ pub use op_def::{ }; mod type_def; pub use type_def::{TypeDef, TypeDefBound}; +mod const_fold; pub mod prelude; pub mod simple_op; pub mod validate; +pub use const_fold::{ConstFold, ConstFoldResult}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; /// Extension Registries store extensions to be looked up e.g. during validation. diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs new file mode 100644 index 000000000..29cf4b5a9 --- /dev/null +++ b/src/extension/const_fold.rs @@ -0,0 +1,53 @@ +use std::fmt::Formatter; + +use std::fmt::Debug; + +use crate::types::TypeArg; + +use crate::OutgoingPort; + +use crate::ops; + +/// Output of constant folding an operation, None indicates folding was either +/// not possible or unsuccessful. An empty vector indicates folding was +/// successful and no values are output. +pub type ConstFoldResult = Option>; + +/// Trait implemented by extension operations that can perform constant folding. +pub trait ConstFold: Send + Sync { + /// Given type arguments `type_args` and + /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, + /// try to evaluate the operation. + fn fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult; +} + +impl Debug for Box { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl Default for Box { + fn default() -> Self { + Box::new(|&_: &_| None) + } +} + +/// Blanket implementation for functions that only require the constants to +/// evaluate - type arguments are not relevant. +impl ConstFold for T +where + T: Fn(&[(crate::IncomingPort, crate::ops::Const)]) -> ConstFoldResult + Send + Sync, +{ + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult { + self(consts) + } +} diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 143426123..2ea686ab4 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -7,7 +7,8 @@ use std::sync::Arc; use smol_str::SmolStr; use super::{ - Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, + ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, + ExtensionSet, SignatureError, }; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; @@ -307,6 +308,9 @@ pub struct OpDef { // can only treat them as opaque/black-box ops. #[serde(flatten)] lower_funcs: Vec, + + #[serde(skip)] + constant_folder: Box, } impl OpDef { @@ -412,6 +416,22 @@ impl OpDef { ) -> Option { self.misc.insert(k.to_string(), v) } + + /// Set the constant folding function for this Op, which can evaluate it + /// given constant inputs. + pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { + self.constant_folder = Box::new(fold) + } + + /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given + /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s. + pub fn constant_fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult { + self.constant_folder.fold(type_args, consts) + } } impl Extension { @@ -432,6 +452,7 @@ impl Extension { signature_func: signature_func.into(), misc: Default::default(), lower_funcs: Default::default(), + constant_folder: Default::default(), }; match self.operations.entry(op.name.clone()) { diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index f96046ba8..b411c3667 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -137,12 +137,11 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { .into() } +/// The custom type for Errors. +pub const ERROR_CUSTOM_TYPE: CustomType = + CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Eq); /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( - ERROR_TYPE_NAME, - PRELUDE_ID, - TypeBound::Eq, -)); +pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); @@ -191,6 +190,48 @@ impl KnownTypeConst for ConstUsize { const TYPE: CustomType = USIZE_CUSTOM_T; } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +/// Structure for holding constant usize values. +pub struct ConstError { + /// Integer tag/signal for the error. + pub signal: u32, + /// Error message. + pub message: String, +} + +impl ConstError { + /// Define a new error value. + pub fn new(signal: u32, message: impl ToString) -> Self { + Self { + signal, + message: message.to_string(), + } + } +} + +#[typetag::serde] +impl CustomConst for ConstError { + fn name(&self) -> SmolStr { + format!("ConstError({:?}, {:?})", self.signal, self.message).into() + } + + fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { + self.check_known_type(typ) + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::values::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } +} + +impl KnownTypeConst for ConstError { + const TYPE: CustomType = ERROR_CUSTOM_TYPE; +} + #[cfg(test)] mod test { use crate::{ @@ -219,7 +260,7 @@ mod test { } #[test] - /// Test building a HUGR involving a new_array operation. + /// test the prelude error type. fn test_error_type() { let ext_def = PRELUDE .get_type(&ERROR_TYPE_NAME) @@ -229,5 +270,18 @@ mod test { let ext_type = Type::new_extension(ext_def); assert_eq!(ext_type, ERROR_TYPE); + + let error_val = ConstError::new(2, "my message"); + + assert_eq!(error_val.name(), "ConstError(2, \"my message\")"); + + assert!(error_val.check_custom_type(&ERROR_CUSTOM_TYPE).is_ok()); + + assert_eq!( + error_val.extension_reqs(), + ExtensionSet::singleton(&PRELUDE_ID) + ); + assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); + assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 9d2c0ab00..d4e4c88a6 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -4,11 +4,11 @@ use smol_str::SmolStr; use std::sync::Arc; use thiserror::Error; -use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; +use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; -use crate::{Hugr, Node}; +use crate::{ops, Hugr, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; use super::tag::OpTag; @@ -137,6 +137,11 @@ impl ExtensionOp { pub fn def(&self) -> &OpDef { self.def.as_ref() } + + /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. + pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { + self.def().constant_fold(self.args(), consts) + } } impl From for OpaqueOp { diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 98e5df887..23b457f7c 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -8,6 +8,7 @@ use crate::{ prelude::sum_with_error, simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, + PRELUDE, }, ops::{custom::ExtensionOp, OpName}, type_row, @@ -18,7 +19,6 @@ use crate::{ use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; use lazy_static::lazy_static; - /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); @@ -69,7 +69,7 @@ impl MakeOpDef for ConvertOpDef { #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { def: ConvertOpDef, - width: u64, + log_width: u64, } impl OpName for ConvertOpType { @@ -85,11 +85,14 @@ impl MakeExtensionOp for ConvertOpType { [TypeArg::BoundedNat { n }] => n, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; - Ok(Self { def, width }) + Ok(Self { + def, + log_width: width, + }) } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat { n: self.width }] + vec![TypeArg::BoundedNat { n: self.log_width }] } } @@ -111,6 +114,7 @@ lazy_static! { /// Registry of extensions required to validate integer operations. pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), super::int_types::EXTENSION.to_owned(), super::float_types::EXTENSION.to_owned(), EXTENSION.to_owned(), diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 87c87751b..672b8a1ed 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; use lazy_static::lazy_static; - +mod fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); @@ -82,6 +82,10 @@ impl MakeOpDef for FloatOps { } .to_string() } + + fn post_opdef(&self, def: &mut OpDef) { + fold::set_fold(self, def) + } } lazy_static! { diff --git a/src/std_extensions/arithmetic/float_ops/fold.rs b/src/std_extensions/arithmetic/float_ops/fold.rs new file mode 100644 index 000000000..34d162f4d --- /dev/null +++ b/src/std_extensions/arithmetic/float_ops/fold.rs @@ -0,0 +1,124 @@ +use crate::{ + algorithm::const_fold::sorted_consts, + extension::{ConstFold, ConstFoldResult, OpDef}, + ops, + std_extensions::arithmetic::float_types::ConstF64, + IncomingPort, +}; + +use super::FloatOps; + +pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) { + use FloatOps::*; + + match op { + fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)), + feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)), + fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)), + } +} + +/// Extract float values from constants in port order. +fn get_floats(consts: &[(IncomingPort, ops::Const)]) -> Option<[f64; N]> { + let consts: [&ops::Const; N] = sorted_consts(consts).try_into().ok()?; + + Some(consts.map(|c| { + let const_f64: &ConstF64 = c + .get_custom_value() + .expect("This function assumes all incoming constants are floats."); + const_f64.value() + })) +} + +/// Fold binary operations +struct BinaryFold(Box f64 + Send + Sync>); +impl BinaryFold { + fn from_op(op: &FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(match op { + fmax => f64::max, + fmin => f64::min, + fadd => std::ops::Add::add, + fsub => std::ops::Sub::sub, + fmul => std::ops::Mul::mul, + fdiv => std::ops::Div::div, + _ => panic!("not binary op"), + })) + } +} +impl ConstFold for BinaryFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1, f2] = get_floats(consts)?; + + let res = ConstF64::new((self.0)(f1, f2)); + Some(vec![(0.into(), res.into())]) + } +} + +/// Fold comparisons. +struct CmpFold(Box bool + Send + Sync>); +impl CmpFold { + fn from_op(op: FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(move |x, y| { + (match op { + feq => f64::eq, + fne => f64::lt, + flt => f64::lt, + fgt => f64::gt, + fle => f64::le, + fge => f64::ge, + _ => panic!("not cmp op"), + })(&x, &y) + })) + } +} + +impl ConstFold for CmpFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1, f2] = get_floats(consts)?; + + let res = if (self.0)(f1, f2) { + ops::Const::true_val() + } else { + ops::Const::false_val() + }; + + Some(vec![(0.into(), res)]) + } +} + +/// Fold unary operations +struct UnaryFold(Box f64 + Send + Sync>); +impl UnaryFold { + fn from_op(op: &FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(match op { + fneg => std::ops::Neg::neg, + fabs => f64::abs, + ffloor => f64::floor, + fceil => f64::ceil, + _ => panic!("not unary op."), + })) + } +} + +impl ConstFold for UnaryFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1] = get_floats(consts)?; + let res = ConstF64::new((self.0)(f1)); + Some(vec![(0.into(), res.into())]) + } +} diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 71f91bf87..ba5fe0956 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -40,7 +40,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Create a new [`ConstF64`] - pub fn new(value: f64) -> Self { + pub const fn new(value: f64) -> Self { Self { value } } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index ae5160ffd..e1cea6c49 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -11,6 +11,7 @@ use crate::ops::OpName; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; + use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, types::{type_param::TypeArg, Type, TypeRow}, diff --git a/src/values.rs b/src/values.rs index 17d173a00..46d2778ee 100644 --- a/src/values.rs +++ b/src/values.rs @@ -10,6 +10,7 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; + use crate::{Hugr, HugrView}; use crate::types::{CustomCheckFailure, CustomType}; From f7a4cf7614381cc17a66dec1c6a214e786978e99 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:09:38 +0000 Subject: [PATCH 4/9] Revert "feat: constant folding implemented for core and float extension (#758)" This reverts commit 664fe89bce99ff5d02678c8ad0c57f4e4e1d3fb9. --- .github/workflows/ci.yml | 4 +- Cargo.toml | 4 +- README.md | 2 +- src/algorithm.rs | 1 - src/algorithm/const_fold.rs | 302 ------------------ src/extension.rs | 2 - src/extension/const_fold.rs | 53 --- src/extension/op_def.rs | 23 +- src/extension/prelude.rs | 66 +--- src/ops/custom.rs | 9 +- src/std_extensions/arithmetic/conversions.rs | 12 +- src/std_extensions/arithmetic/float_ops.rs | 6 +- .../arithmetic/float_ops/fold.rs | 124 ------- src/std_extensions/arithmetic/float_types.rs | 2 +- src/std_extensions/arithmetic/int_ops.rs | 1 - src/values.rs | 1 - 16 files changed, 20 insertions(+), 592 deletions(-) delete mode 100644 src/algorithm/const_fold.rs delete mode 100644 src/extension/const_fold.rs delete mode 100644 src/std_extensions/arithmetic/float_ops/fold.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5514dedee..f895bc87c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,13 +56,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: ['1.75', stable, beta, nightly] + rust: ['1.70', stable, beta, nightly] # workaround to ignore non-stable tests when running the merge queue checks # see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6 isMerge: - ${{ github.event_name == 'merge_group' }} exclude: - - rust: '1.75' + - rust: '1.70' isMerge: true - rust: beta isMerge: true diff --git a/Cargo.toml b/Cargo.toml index ea4861062..b1acc0bd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation" #categories = [] # TODO edition = "2021" -rust-version = "1.75" +rust-version = "1.70" [lib] # Using different names for the lib and for the package is supported, but may be confusing. @@ -46,7 +46,7 @@ lazy_static = "1.4.0" petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" -delegate = "0.12.0" +delegate = "0.11.0" rustversion = "1.0.14" paste = "1.0" strum = "0.25.0" diff --git a/README.md b/README.md index 14adf32ed..91ab2c617 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,6 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). [build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: LICENCE diff --git a/src/algorithm.rs b/src/algorithm.rs index 633231504..0023b5916 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,5 +1,4 @@ //! Algorithms using the Hugr. -pub mod const_fold; mod half_node; pub mod nest_cfgs; diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs deleted file mode 100644 index 49474c43e..000000000 --- a/src/algorithm/const_fold.rs +++ /dev/null @@ -1,302 +0,0 @@ -//! Constant folding routines. - -use std::collections::{BTreeSet, HashMap}; - -use itertools::Itertools; - -use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{ConstFoldResult, ExtensionRegistry}, - hugr::{ - rewrite::consts::{RemoveConst, RemoveConstIgnore}, - views::SiblingSubgraph, - HugrMut, - }, - ops::{Const, LeafOp, OpType}, - type_row, - types::{FunctionType, Type, TypeEnum}, - values::Value, - Hugr, HugrView, IncomingPort, Node, SimpleReplacement, -}; - -/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. -fn out_row(consts: impl IntoIterator) -> ConstFoldResult { - let vec = consts - .into_iter() - .enumerate() - .map(|(i, c)| (i.into(), c)) - .collect(); - Some(vec) -} - -/// Sort folding inputs with [`IncomingPort`] as key -fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> { - let mut v: Vec<_> = consts.iter().collect(); - v.sort_by_key(|(i, _)| i); - v -} - -/// Sort some input constants by port and just return the constants. -pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { - sort_by_in_port(consts) - .into_iter() - .map(|(_, c)| c) - .collect() -} -/// For a given op and consts, attempt to evaluate the op. -pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { - let op = op.as_leaf_op()?; - - match op { - LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]), - LeafOp::MakeTuple { .. } => { - out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())]) - } - LeafOp::UnpackTuple { .. } => { - let c = &consts.first()?.1; - - if let Value::Tuple { vs } = c.value() { - if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() { - return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| { - Const::new(v.clone(), t.clone()) - .expect("types should already have been checked") - })); - } - } - None // could panic - } - - LeafOp::Tag { tag, variants } => out_row([Const::new( - Value::sum(*tag, consts.first()?.1.value().clone()), - Type::new_sum(variants.clone()), - ) - .unwrap()]), - LeafOp::CustomOp(_) => { - let ext_op = op.as_extension_op()?; - - ext_op.constant_fold(consts) - } - _ => None, - } -} - -/// Generate a graph that loads and outputs `consts` in order, validating -/// against `reg`. -fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { - let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); - let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); - - let outputs = consts - .into_iter() - .map(|c| b.add_load_const(c).unwrap()) - .collect_vec(); - - b.finish_hugr_with_outputs(outputs, reg).unwrap() -} - -/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, -/// return an iterator of possible constant folding rewrites. The -/// [`SimpleReplacement`] replaces an operation with constants that result from -/// evaluating it, the extension registry `reg` is used to validate the -/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the -/// LoadConstant nodes that could be removed. -pub fn find_consts<'a, 'r: 'a>( - hugr: &'a impl HugrView, - candidate_nodes: impl IntoIterator + 'a, - reg: &'r ExtensionRegistry, -) -> impl Iterator)> + 'a { - // track nodes for operations that have already been considered for folding - let mut used_neighbours = BTreeSet::new(); - - candidate_nodes - .into_iter() - .filter_map(move |n| { - // only look at LoadConstant - hugr.get_optype(n).is_load_constant().then_some(())?; - - let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; - let neighbours = hugr - .linked_inputs(n, out_p) - .filter(|(n, _)| used_neighbours.insert(*n)) - .collect_vec(); - if neighbours.is_empty() { - // no uses of LoadConstant that haven't already been considered. - return None; - } - let fold_iter = neighbours - .into_iter() - .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); - Some(fold_iter) - }) - .flatten() -} - -/// Attempt to evaluate and generate rewrites for the operation at `op_node` -fn fold_op( - hugr: &impl HugrView, - op_node: Node, - reg: &ExtensionRegistry, -) -> Option<(SimpleReplacement, Vec)> { - let (in_consts, removals): (Vec<_>, Vec<_>) = hugr - .node_inputs(op_node) - .filter_map(|in_p| { - let (con_op, load_n) = get_const(hugr, op_node, in_p)?; - Some(((in_p, con_op), RemoveConstIgnore(load_n))) - }) - .unzip(); - let neighbour_op = hugr.get_optype(op_node); - // attempt to evaluate op - let folded = fold_const(neighbour_op, &in_consts)?; - let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); - let nu_out = op_outs - .into_iter() - .enumerate() - .filter_map(|(i, out)| { - // map from the ports the op was linked to, to the output ports of - // the replacement. - hugr.single_linked_input(op_node, out) - .map(|np| (np, i.into())) - }) - .collect(); - let replacement = const_graph(consts, reg); - let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) - .expect("Operation should form valid subgraph."); - - let simple_replace = SimpleReplacement::new( - sibling_graph, - replacement, - // no inputs to replacement - HashMap::new(), - nu_out, - ); - Some((simple_replace, removals)) -} - -/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant -/// and the LoadConstant node -fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> { - let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; - let load_op = hugr.get_optype(load_n).as_load_constant()?; - let const_node = hugr - .linked_outputs(load_n, load_op.constant_port()) - .exactly_one() - .ok()? - .0; - - let const_op = hugr.get_optype(const_node).as_const()?; - - // TODO avoid const clone here - Some((const_op.clone(), load_n)) -} - -/// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { - loop { - // would be preferable if the candidates were updated to be just the - // neighbouring nodes of those added. - let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); - if rewrites.is_empty() { - break; - } - for (replace, removes) in rewrites { - h.apply_rewrite(replace).unwrap(); - for rem in removes { - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - if h.apply_rewrite(RemoveConst(const_node)).is_err() { - // const cannot be removed - no problem - continue; - } - } - } - } - } -} - -#[cfg(test)] -mod test { - - use crate::extension::{ExtensionRegistry, PRELUDE}; - use crate::std_extensions::arithmetic; - - use crate::std_extensions::arithmetic::float_ops::FloatOps; - use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - - use rstest::rstest; - - use super::*; - - /// float to constant - fn f2c(f: f64) -> Const { - ConstF64::new(f).into() - } - - #[rstest] - #[case(0.0, 0.0, 0.0)] - #[case(0.0, 1.0, 1.0)] - #[case(23.5, 435.5, 459.0)] - // c = a + b - fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; - let add_op: OpType = FloatOps::fadd.into(); - let out = fold_const(&add_op, &consts).unwrap(); - - assert_eq!(&out[..], &[(0.into(), f2c(c))]); - } - - #[test] - fn test_big() { - /* - Test hugr approximately calculates - let x = (5.5, 3.25); - x.0 - x.1 == 2.25 - */ - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); - - let tup = build - .add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)])) - .unwrap(); - - let unpack = build - .add_dataflow_op( - LeafOp::UnpackTuple { - tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], - }, - [tup], - ) - .unwrap(); - - let sub = build - .add_dataflow_op(FloatOps::fsub, unpack.outputs()) - .unwrap(); - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(sub.outputs(), ®).unwrap(); - assert_eq!(h.node_count(), 7); - - constant_fold_pass(&mut h, ®); - - assert_fully_folded(&h, &f2c(2.25)); - } - fn assert_fully_folded(h: &Hugr, expected_const: &Const) { - // check the hugr just loads and returns a single const - let mut node_count = 0; - - for node in h.children(h.root()) { - let op = h.get_optype(node); - match op { - OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, - OpType::Const(c) if c == expected_const => node_count += 1, - _ => panic!("unexpected op: {:?}", op), - } - } - - assert_eq!(node_count, 4); - } -} diff --git a/src/extension.rs b/src/extension.rs index 5519e456b..95b0474ea 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -28,11 +28,9 @@ pub use op_def::{ }; mod type_def; pub use type_def::{TypeDef, TypeDefBound}; -mod const_fold; pub mod prelude; pub mod simple_op; pub mod validate; -pub use const_fold::{ConstFold, ConstFoldResult}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; /// Extension Registries store extensions to be looked up e.g. during validation. diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs deleted file mode 100644 index 29cf4b5a9..000000000 --- a/src/extension/const_fold.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::fmt::Formatter; - -use std::fmt::Debug; - -use crate::types::TypeArg; - -use crate::OutgoingPort; - -use crate::ops; - -/// Output of constant folding an operation, None indicates folding was either -/// not possible or unsuccessful. An empty vector indicates folding was -/// successful and no values are output. -pub type ConstFoldResult = Option>; - -/// Trait implemented by extension operations that can perform constant folding. -pub trait ConstFold: Send + Sync { - /// Given type arguments `type_args` and - /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, - /// try to evaluate the operation. - fn fold( - &self, - type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult; -} - -impl Debug for Box { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "") - } -} - -impl Default for Box { - fn default() -> Self { - Box::new(|&_: &_| None) - } -} - -/// Blanket implementation for functions that only require the constants to -/// evaluate - type arguments are not relevant. -impl ConstFold for T -where - T: Fn(&[(crate::IncomingPort, crate::ops::Const)]) -> ConstFoldResult + Send + Sync, -{ - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult { - self(consts) - } -} diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 2ea686ab4..143426123 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -7,8 +7,7 @@ use std::sync::Arc; use smol_str::SmolStr; use super::{ - ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, - ExtensionSet, SignatureError, + Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, }; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; @@ -308,9 +307,6 @@ pub struct OpDef { // can only treat them as opaque/black-box ops. #[serde(flatten)] lower_funcs: Vec, - - #[serde(skip)] - constant_folder: Box, } impl OpDef { @@ -416,22 +412,6 @@ impl OpDef { ) -> Option { self.misc.insert(k.to_string(), v) } - - /// Set the constant folding function for this Op, which can evaluate it - /// given constant inputs. - pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { - self.constant_folder = Box::new(fold) - } - - /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given - /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s. - pub fn constant_fold( - &self, - type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult { - self.constant_folder.fold(type_args, consts) - } } impl Extension { @@ -452,7 +432,6 @@ impl Extension { signature_func: signature_func.into(), misc: Default::default(), lower_funcs: Default::default(), - constant_folder: Default::default(), }; match self.operations.entry(op.name.clone()) { diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index b411c3667..f96046ba8 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -137,11 +137,12 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { .into() } -/// The custom type for Errors. -pub const ERROR_CUSTOM_TYPE: CustomType = - CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Eq); /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); +pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( + ERROR_TYPE_NAME, + PRELUDE_ID, + TypeBound::Eq, +)); /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); @@ -190,48 +191,6 @@ impl KnownTypeConst for ConstUsize { const TYPE: CustomType = USIZE_CUSTOM_T; } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -/// Structure for holding constant usize values. -pub struct ConstError { - /// Integer tag/signal for the error. - pub signal: u32, - /// Error message. - pub message: String, -} - -impl ConstError { - /// Define a new error value. - pub fn new(signal: u32, message: impl ToString) -> Self { - Self { - signal, - message: message.to_string(), - } - } -} - -#[typetag::serde] -impl CustomConst for ConstError { - fn name(&self) -> SmolStr { - format!("ConstError({:?}, {:?})", self.signal, self.message).into() - } - - fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { - self.check_known_type(typ) - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::values::downcast_equal_consts(self, other) - } - - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&PRELUDE_ID) - } -} - -impl KnownTypeConst for ConstError { - const TYPE: CustomType = ERROR_CUSTOM_TYPE; -} - #[cfg(test)] mod test { use crate::{ @@ -260,7 +219,7 @@ mod test { } #[test] - /// test the prelude error type. + /// Test building a HUGR involving a new_array operation. fn test_error_type() { let ext_def = PRELUDE .get_type(&ERROR_TYPE_NAME) @@ -270,18 +229,5 @@ mod test { let ext_type = Type::new_extension(ext_def); assert_eq!(ext_type, ERROR_TYPE); - - let error_val = ConstError::new(2, "my message"); - - assert_eq!(error_val.name(), "ConstError(2, \"my message\")"); - - assert!(error_val.check_custom_type(&ERROR_CUSTOM_TYPE).is_ok()); - - assert_eq!( - error_val.extension_reqs(), - ExtensionSet::singleton(&PRELUDE_ID) - ); - assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); - assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index d4e4c88a6..9d2c0ab00 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -4,11 +4,11 @@ use smol_str::SmolStr; use std::sync::Arc; use thiserror::Error; -use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; +use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; -use crate::{ops, Hugr, IncomingPort, Node}; +use crate::{Hugr, Node}; use super::dataflow::DataflowOpTrait; use super::tag::OpTag; @@ -137,11 +137,6 @@ impl ExtensionOp { pub fn def(&self) -> &OpDef { self.def.as_ref() } - - /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. - pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { - self.def().constant_fold(self.args(), consts) - } } impl From for OpaqueOp { diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 23b457f7c..98e5df887 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -8,7 +8,6 @@ use crate::{ prelude::sum_with_error, simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, - PRELUDE, }, ops::{custom::ExtensionOp, OpName}, type_row, @@ -19,6 +18,7 @@ use crate::{ use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; use lazy_static::lazy_static; + /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); @@ -69,7 +69,7 @@ impl MakeOpDef for ConvertOpDef { #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { def: ConvertOpDef, - log_width: u64, + width: u64, } impl OpName for ConvertOpType { @@ -85,14 +85,11 @@ impl MakeExtensionOp for ConvertOpType { [TypeArg::BoundedNat { n }] => n, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; - Ok(Self { - def, - log_width: width, - }) + Ok(Self { def, width }) } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat { n: self.log_width }] + vec![TypeArg::BoundedNat { n: self.width }] } } @@ -114,7 +111,6 @@ lazy_static! { /// Registry of extensions required to validate integer operations. pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), super::int_types::EXTENSION.to_owned(), super::float_types::EXTENSION.to_owned(), EXTENSION.to_owned(), diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 672b8a1ed..87c87751b 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; use lazy_static::lazy_static; -mod fold; + /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); @@ -82,10 +82,6 @@ impl MakeOpDef for FloatOps { } .to_string() } - - fn post_opdef(&self, def: &mut OpDef) { - fold::set_fold(self, def) - } } lazy_static! { diff --git a/src/std_extensions/arithmetic/float_ops/fold.rs b/src/std_extensions/arithmetic/float_ops/fold.rs deleted file mode 100644 index 34d162f4d..000000000 --- a/src/std_extensions/arithmetic/float_ops/fold.rs +++ /dev/null @@ -1,124 +0,0 @@ -use crate::{ - algorithm::const_fold::sorted_consts, - extension::{ConstFold, ConstFoldResult, OpDef}, - ops, - std_extensions::arithmetic::float_types::ConstF64, - IncomingPort, -}; - -use super::FloatOps; - -pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) { - use FloatOps::*; - - match op { - fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)), - feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)), - fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)), - } -} - -/// Extract float values from constants in port order. -fn get_floats(consts: &[(IncomingPort, ops::Const)]) -> Option<[f64; N]> { - let consts: [&ops::Const; N] = sorted_consts(consts).try_into().ok()?; - - Some(consts.map(|c| { - let const_f64: &ConstF64 = c - .get_custom_value() - .expect("This function assumes all incoming constants are floats."); - const_f64.value() - })) -} - -/// Fold binary operations -struct BinaryFold(Box f64 + Send + Sync>); -impl BinaryFold { - fn from_op(op: &FloatOps) -> Self { - use FloatOps::*; - Self(Box::new(match op { - fmax => f64::max, - fmin => f64::min, - fadd => std::ops::Add::add, - fsub => std::ops::Sub::sub, - fmul => std::ops::Mul::mul, - fdiv => std::ops::Div::div, - _ => panic!("not binary op"), - })) - } -} -impl ConstFold for BinaryFold { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let [f1, f2] = get_floats(consts)?; - - let res = ConstF64::new((self.0)(f1, f2)); - Some(vec![(0.into(), res.into())]) - } -} - -/// Fold comparisons. -struct CmpFold(Box bool + Send + Sync>); -impl CmpFold { - fn from_op(op: FloatOps) -> Self { - use FloatOps::*; - Self(Box::new(move |x, y| { - (match op { - feq => f64::eq, - fne => f64::lt, - flt => f64::lt, - fgt => f64::gt, - fle => f64::le, - fge => f64::ge, - _ => panic!("not cmp op"), - })(&x, &y) - })) - } -} - -impl ConstFold for CmpFold { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let [f1, f2] = get_floats(consts)?; - - let res = if (self.0)(f1, f2) { - ops::Const::true_val() - } else { - ops::Const::false_val() - }; - - Some(vec![(0.into(), res)]) - } -} - -/// Fold unary operations -struct UnaryFold(Box f64 + Send + Sync>); -impl UnaryFold { - fn from_op(op: &FloatOps) -> Self { - use FloatOps::*; - Self(Box::new(match op { - fneg => std::ops::Neg::neg, - fabs => f64::abs, - ffloor => f64::floor, - fceil => f64::ceil, - _ => panic!("not unary op."), - })) - } -} - -impl ConstFold for UnaryFold { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let [f1] = get_floats(consts)?; - let res = ConstF64::new((self.0)(f1)); - Some(vec![(0.into(), res.into())]) - } -} diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index ba5fe0956..71f91bf87 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -40,7 +40,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Create a new [`ConstF64`] - pub const fn new(value: f64) -> Self { + pub fn new(value: f64) -> Self { Self { value } } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index e1cea6c49..ae5160ffd 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -11,7 +11,6 @@ use crate::ops::OpName; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; - use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, types::{type_param::TypeArg, Type, TypeRow}, diff --git a/src/values.rs b/src/values.rs index 46d2778ee..17d173a00 100644 --- a/src/values.rs +++ b/src/values.rs @@ -10,7 +10,6 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; - use crate::{Hugr, HugrView}; use crate::types::{CustomCheckFailure, CustomType}; From 70d99dd2bfd99b260e627de198d990a1dd8bae0a Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:22:55 +0000 Subject: [PATCH 5/9] minor review suggestions --- src/hugr/rewrite/consts.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs index ca0b3afb3..4e3e1c873 100644 --- a/src/hugr/rewrite/consts.rs +++ b/src/hugr/rewrite/consts.rs @@ -12,7 +12,7 @@ use thiserror::Error; use super::Rewrite; -/// Remove a [`crate::ops::LoadConstant`] node with no outputs. +/// Remove a [`crate::ops::LoadConstant`] node with no consumers. #[derive(Debug, Clone)] pub struct RemoveConstIgnore(pub Node); @@ -20,7 +20,7 @@ pub struct RemoveConstIgnore(pub Node); #[derive(Debug, Clone, Error, PartialEq, Eq)] pub enum RemoveConstIgnoreError { /// Invalid node. - #[error("Node is invalid (either not in HUGR or not LoadConst).")] + #[error("Node is invalid (either not in HUGR or not LoadConstant).")] InvalidNode(Node), /// Node in use. #[error("Node: {0:?} has non-zero outgoing connections.")] @@ -122,12 +122,12 @@ impl Rewrite for RemoveConst { fn apply(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; - let source = h + let parent = h .get_parent(node) .expect("Const node without a parent shouldn't happen."); h.remove_node(node)?; - Ok(source) + Ok(parent) } fn invalidation_set(&self) -> Self::InvalidationSet<'_> { @@ -167,6 +167,7 @@ mod test { dfg_build.finish_sub_container()?; let mut h = build.finish_prelude_hugr()?; + // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple assert_eq!(h.node_count(), 8); let tup_node = tup.node(); // can't remove invalid node From 9b6999b8e9faaaf65308ba6c91fb7a01fd5f03e1 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:25:46 +0000 Subject: [PATCH 6/9] remove error that should be prevented by validation --- src/hugr/rewrite/consts.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs index 4e3e1c873..8921c174d 100644 --- a/src/hugr/rewrite/consts.rs +++ b/src/hugr/rewrite/consts.rs @@ -25,9 +25,6 @@ pub enum RemoveConstIgnoreError { /// Node in use. #[error("Node: {0:?} has non-zero outgoing connections.")] ValueUsed(Node), - /// Not connected to a Const. - #[error("Node: {0:?} is not connected to a Const node.")] - NoConst(Node), /// Removal error #[error("Removing node caused error: {0:?}.")] RemoveFail(#[from] HugrError), @@ -66,7 +63,8 @@ impl Rewrite for RemoveConstIgnore { let source = h .input_neighbours(node) .exactly_one() - .map_err(|_| RemoveConstIgnoreError::NoConst(node))?; + .ok() + .expect("Validation should check a Const is connected to LoadConstant."); h.remove_node(node)?; Ok(source) From 37dfb0d31aad9ebae44f64b969183bbb447231cc Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:27:26 +0000 Subject: [PATCH 7/9] merge error enums --- src/hugr/rewrite/consts.rs | 40 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs index 8921c174d..61be178d5 100644 --- a/src/hugr/rewrite/consts.rs +++ b/src/hugr/rewrite/consts.rs @@ -16,11 +16,11 @@ use super::Rewrite; #[derive(Debug, Clone)] pub struct RemoveConstIgnore(pub Node); -/// Error from an [`RemoveConstIgnore`] operation. +/// Error from an [`RemoveConst`] or [`RemoveConstIgnore`] operation. #[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum RemoveConstIgnoreError { +pub enum RemoveError { /// Invalid node. - #[error("Node is invalid (either not in HUGR or not LoadConstant).")] + #[error("Node is invalid (either not in HUGR or not correct operation).")] InvalidNode(Node), /// Node in use. #[error("Node: {0:?} has non-zero outgoing connections.")] @@ -31,7 +31,7 @@ pub enum RemoveConstIgnoreError { } impl Rewrite for RemoveConstIgnore { - type Error = RemoveConstIgnoreError; + type Error = RemoveError; // The Const node the LoadConstant was connected to. type ApplyResult = Node; @@ -44,14 +44,14 @@ impl Rewrite for RemoveConstIgnore { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { - return Err(RemoveConstIgnoreError::InvalidNode(node)); + return Err(RemoveError::InvalidNode(node)); } if h.out_value_types(node) .next() .is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some()) { - return Err(RemoveConstIgnoreError::ValueUsed(node)); + return Err(RemoveError::ValueUsed(node)); } Ok(()) @@ -79,22 +79,8 @@ impl Rewrite for RemoveConstIgnore { #[derive(Debug, Clone)] pub struct RemoveConst(pub Node); -/// Error from an [`RemoveConst`] operation. -#[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum RemoveConstError { - /// Invalid node. - #[error("Node is invalid (either not in HUGR or not Const).")] - InvalidNode(Node), - /// Node in use. - #[error("Node: {0:?} has non-zero outgoing connections.")] - ValueUsed(Node), - /// Removal error - #[error("Removing node caused error: {0:?}.")] - RemoveFail(#[from] HugrError), -} - impl Rewrite for RemoveConst { - type Error = RemoveConstError; + type Error = RemoveError; // The parent of the Const node. type ApplyResult = Node; @@ -107,11 +93,11 @@ impl Rewrite for RemoveConst { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { - return Err(RemoveConstError::InvalidNode(node)); + return Err(RemoveError::InvalidNode(node)); } if h.output_neighbours(node).next().is_some() { - return Err(RemoveConstError::ValueUsed(node)); + return Err(RemoveError::ValueUsed(node)); } Ok(()) @@ -171,12 +157,12 @@ mod test { // can't remove invalid node assert_eq!( h.apply_rewrite(RemoveConst(tup_node)), - Err(RemoveConstError::InvalidNode(tup_node)) + Err(RemoveError::InvalidNode(tup_node)) ); assert_eq!( h.apply_rewrite(RemoveConstIgnore(tup_node)), - Err(RemoveConstIgnoreError::InvalidNode(tup_node)) + Err(RemoveError::InvalidNode(tup_node)) ); let load_1_node = load_1.node(); let load_2_node = load_2.node(); @@ -199,7 +185,7 @@ mod test { // can't remove nodes in use assert_eq!( h.apply_rewrite(remove_1.clone()), - Err(RemoveConstIgnoreError::ValueUsed(load_1_node)) + Err(RemoveError::ValueUsed(load_1_node)) ); // remove the use @@ -212,7 +198,7 @@ mod test { // still can't remove const, in use by second load assert_eq!( h.apply_rewrite(remove_con.clone()), - Err(RemoveConstError::ValueUsed(con_node)) + Err(RemoveError::ValueUsed(con_node)) ); // remove second use From a8c2ca9e8899ae8562d9a7ed578547923844bdfa Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 2 Jan 2024 09:52:53 +0000 Subject: [PATCH 8/9] chore!: hike MSRV to 1.75 (#761) --- .github/workflows/ci.yml | 4 ++-- Cargo.toml | 2 +- README.md | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f895bc87c..5514dedee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,13 +56,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: ['1.70', stable, beta, nightly] + rust: ['1.75', stable, beta, nightly] # workaround to ignore non-stable tests when running the merge queue checks # see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6 isMerge: - ${{ github.event_name == 'merge_group' }} exclude: - - rust: '1.70' + - rust: '1.75' isMerge: true - rust: beta isMerge: true diff --git a/Cargo.toml b/Cargo.toml index a2d81cd7d..c633a1222 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation" #categories = [] # TODO edition = "2021" -rust-version = "1.70" +rust-version = "1.75" [lib] # Using different names for the lib and for the package is supported, but may be confusing. diff --git a/README.md b/README.md index 1c0437eeb..4428eac1b 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/quantinuum-hugr/ [build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg + [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: LICENCE [CHANGELOG]: CHANGELOG.md From b3721af4d37a196bd47226c85c47c900a174c9d5 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 13:30:18 +0000 Subject: [PATCH 9/9] feat: Custom const for ERROR_TYPE (#756) --- src/extension/prelude.rs | 66 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index f96046ba8..b411c3667 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -137,12 +137,11 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { .into() } +/// The custom type for Errors. +pub const ERROR_CUSTOM_TYPE: CustomType = + CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Eq); /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( - ERROR_TYPE_NAME, - PRELUDE_ID, - TypeBound::Eq, -)); +pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); @@ -191,6 +190,48 @@ impl KnownTypeConst for ConstUsize { const TYPE: CustomType = USIZE_CUSTOM_T; } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +/// Structure for holding constant usize values. +pub struct ConstError { + /// Integer tag/signal for the error. + pub signal: u32, + /// Error message. + pub message: String, +} + +impl ConstError { + /// Define a new error value. + pub fn new(signal: u32, message: impl ToString) -> Self { + Self { + signal, + message: message.to_string(), + } + } +} + +#[typetag::serde] +impl CustomConst for ConstError { + fn name(&self) -> SmolStr { + format!("ConstError({:?}, {:?})", self.signal, self.message).into() + } + + fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { + self.check_known_type(typ) + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::values::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } +} + +impl KnownTypeConst for ConstError { + const TYPE: CustomType = ERROR_CUSTOM_TYPE; +} + #[cfg(test)] mod test { use crate::{ @@ -219,7 +260,7 @@ mod test { } #[test] - /// Test building a HUGR involving a new_array operation. + /// test the prelude error type. fn test_error_type() { let ext_def = PRELUDE .get_type(&ERROR_TYPE_NAME) @@ -229,5 +270,18 @@ mod test { let ext_type = Type::new_extension(ext_def); assert_eq!(ext_type, ERROR_TYPE); + + let error_val = ConstError::new(2, "my message"); + + assert_eq!(error_val.name(), "ConstError(2, \"my message\")"); + + assert!(error_val.check_custom_type(&ERROR_CUSTOM_TYPE).is_ok()); + + assert_eq!( + error_val.extension_reqs(), + ExtensionSet::singleton(&PRELUDE_ID) + ); + assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); + assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); } }