From bffed99f6ffed2b0a9663cfcfefe1c59a8e06556 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 16:15:00 +0000 Subject: [PATCH 01/40] wip: constant folding refactor!: Closes Flatten `Prim(Type/Value)` in to parent enum #665 BREAKING_CHANGES: In serialization, extension and function values no longer wrapped by "pv". --- src/algorithm.rs | 1 + src/algorithm/const_fold.rs | 60 ++++++++++++++++++++++ src/std_extensions/arithmetic/int_types.rs | 11 +++- src/values.rs | 13 ++++- 4 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 src/algorithm/const_fold.rs diff --git a/src/algorithm.rs b/src/algorithm.rs index 0023b5916..633231504 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,4 +1,5 @@ //! Algorithms using the Hugr. +pub mod const_fold; mod half_node; pub mod nest_cfgs; diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs new file mode 100644 index 000000000..eeee6bba5 --- /dev/null +++ b/src/algorithm/const_fold.rs @@ -0,0 +1,60 @@ +//! Constant folding routines. + +use crate::{ + ops::{Const, OpType}, + values::Value, + IncomingPort, OutgoingPort, +}; + +/// For a given op and consts, attempt to evaluate the op. +pub fn fold_const( + op: &OpType, + consts: &[(IncomingPort, Const)], +) -> Option> { + consts.iter().find_map(|(_, cnst)| match cnst.value() { + Value::Extension { c: (c,) } => c.fold(op, consts), + Value::Tuple { .. } => todo!(), + Value::Sum { .. } => todo!(), + Value::Function { .. } => None, + }) +} + +#[cfg(test)] +mod test { + use crate::{ + extension::PRELUDE_REGISTRY, + ops::LeafOp, + std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, + types::TypeArg, + }; + use rstest::rstest; + + use super::*; + + fn i2c(b: u64) -> Const { + Const::new( + ConstIntU::new(5, b).unwrap().into(), + INT_TYPES[5].to_owned(), + ) + .unwrap() + } + + fn u64_add() -> LeafOp { + crate::std_extensions::arithmetic::int_types::EXTENSION + .instantiate_extension_op("iadd", [TypeArg::BoundedNat { n: 5 }], &PRELUDE_REGISTRY) + .unwrap() + .into() + } + #[rstest] + #[case(0, 0, 0)] + #[case(0, 1, 1)] + #[case(23, 435, 458)] + // c = a && b + fn test_and(#[case] a: u64, #[case] b: u64, #[case] c: u64) { + let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))]; + let add_op: OpType = u64_add().into(); + let out = fold_const(&add_op, &consts).unwrap(); + + assert_eq!(&out[..], &[(0.into(), i2c(c))]); + } +} diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 7a67de28a..079310f96 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -6,12 +6,13 @@ use smol_str::SmolStr; use crate::{ extension::ExtensionId, + ops::OpType, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, }, values::CustomConst, - Extension, + Extension, IncomingPort, OutgoingPort, }; use lazy_static::lazy_static; /// The extension identifier. @@ -161,6 +162,14 @@ impl CustomConst for ConstIntU { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn fold( + &self, + _op: &OpType, + _consts: &[(IncomingPort, crate::ops::Const)], + ) -> Option> { + None + } } #[typetag::serde] diff --git a/src/values.rs b/src/values.rs index 428654066..07391c592 100644 --- a/src/values.rs +++ b/src/values.rs @@ -9,7 +9,8 @@ use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; use crate::macros::impl_box_clone; -use crate::{Hugr, HugrView}; +use crate::ops::OpType; +use crate::{Hugr, HugrView, IncomingPort, OutgoingPort}; use crate::types::{CustomCheckFailure, CustomType}; @@ -143,6 +144,16 @@ pub trait CustomConst: // false unless overloaded false } + + /// Attempt to evaluate an operation given some constant inputs - typically + /// involving instances of Self + fn fold( + &self, + _op: &OpType, + _consts: &[(IncomingPort, crate::ops::Const)], + ) -> Option> { + None + } } /// Const equality for types that have PartialEq From 1a27d5440adedf304432888019eff5ac185773a8 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 20 Nov 2023 14:25:16 +0000 Subject: [PATCH 02/40] start moving folding to op_def --- src/extension/op_def.rs | 46 ++++++++++++++++++++++ src/std_extensions/arithmetic/int_types.rs | 11 +----- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index fbd832b92..1ce2ea825 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -245,6 +245,44 @@ impl Debug for LowerFunc { } } +type ConstFoldResult = Option>; +pub trait GenericConstFold: Send + Sync { + fn fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult; +} + +impl Debug for Box { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl Default for Box { + fn default() -> Self { + Box::new(|&_: &_| None) + } +} + +impl GenericConstFold for T +where + T: Fn( + &[(crate::IncomingPort, crate::ops::Const)], + ) -> Option> + + Send + + Sync, +{ + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult { + self(consts) + } +} + /// Serializable definition for dynamically loaded operations. /// /// TODO: Define a way to construct new OpDef's from a serialized definition. @@ -267,6 +305,9 @@ pub struct OpDef { // can only treat them as opaque/black-box ops. #[serde(flatten)] lower_funcs: Vec, + + #[serde(skip)] + constant_fold: Box, } impl OpDef { @@ -399,6 +440,10 @@ impl OpDef { ) -> Option { self.misc.insert(k.to_string(), v) } + + pub fn add_constant_folding(&mut self, fold: impl GenericConstFold + 'static) { + self.constant_fold = Box::new(fold) + } } impl Extension { @@ -419,6 +464,7 @@ impl Extension { signature_func: signature_func.into(), misc: Default::default(), lower_funcs: Default::default(), + constant_fold: Default::default(), }; match self.operations.entry(op.name.clone()) { diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 079310f96..7a67de28a 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -6,13 +6,12 @@ use smol_str::SmolStr; use crate::{ extension::ExtensionId, - ops::OpType, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, }, values::CustomConst, - Extension, IncomingPort, OutgoingPort, + Extension, }; use lazy_static::lazy_static; /// The extension identifier. @@ -162,14 +161,6 @@ impl CustomConst for ConstIntU { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } - - fn fold( - &self, - _op: &OpType, - _consts: &[(IncomingPort, crate::ops::Const)], - ) -> Option> { - None - } } #[typetag::serde] From b84766bb2ffd6b67157249185e8abd19858abd43 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 23 Nov 2023 14:58:03 +0000 Subject: [PATCH 03/40] thread through folding methods --- src/algorithm/const_fold.rs | 24 ++++++++++++++---------- src/extension/op_def.rs | 24 ++++++++++++++++-------- src/ops/custom.rs | 9 ++++++++- src/ops/leaf.rs | 16 +++++++++++++++- src/values.rs | 10 ---------- 5 files changed, 53 insertions(+), 30 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index eeee6bba5..b64774399 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -1,7 +1,7 @@ //! Constant folding routines. use crate::{ - ops::{Const, OpType}, + ops::{custom::ExternalOp, Const, LeafOp, OpType}, values::Value, IncomingPort, OutgoingPort, }; @@ -11,18 +11,16 @@ pub fn fold_const( op: &OpType, consts: &[(IncomingPort, Const)], ) -> Option> { - consts.iter().find_map(|(_, cnst)| match cnst.value() { - Value::Extension { c: (c,) } => c.fold(op, consts), - Value::Tuple { .. } => todo!(), - Value::Sum { .. } => todo!(), - Value::Function { .. } => None, - }) + let op = op.as_leaf_op()?; + let ext_op = op.as_extension_op()?; + + ext_op.constant_fold(consts) } #[cfg(test)] mod test { use crate::{ - extension::PRELUDE_REGISTRY, + extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY}, ops::LeafOp, std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, types::TypeArg, @@ -40,8 +38,14 @@ mod test { } fn u64_add() -> LeafOp { - crate::std_extensions::arithmetic::int_types::EXTENSION - .instantiate_extension_op("iadd", [TypeArg::BoundedNat { n: 5 }], &PRELUDE_REGISTRY) + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + crate::std_extensions::arithmetic::int_ops::EXTENSION.to_owned(), + crate::std_extensions::arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + crate::std_extensions::arithmetic::int_ops::EXTENSION + .instantiate_extension_op("iadd", [TypeArg::BoundedNat { n: 5 }], ®) .unwrap() .into() } diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 1ce2ea825..43b965c24 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -246,7 +246,7 @@ impl Debug for LowerFunc { } type ConstFoldResult = Option>; -pub trait GenericConstFold: Send + Sync { +pub trait ConstFold: Send + Sync { fn fold( &self, type_args: &[TypeArg], @@ -254,19 +254,19 @@ pub trait GenericConstFold: Send + Sync { ) -> ConstFoldResult; } -impl Debug for Box { +impl Debug for Box { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "") } } -impl Default for Box { +impl Default for Box { fn default() -> Self { Box::new(|&_: &_| None) } } -impl GenericConstFold for T +impl ConstFold for T where T: Fn( &[(crate::IncomingPort, crate::ops::Const)], @@ -307,7 +307,7 @@ pub struct OpDef { lower_funcs: Vec, #[serde(skip)] - constant_fold: Box, + constant_folder: Box, } impl OpDef { @@ -441,8 +441,16 @@ impl OpDef { self.misc.insert(k.to_string(), v) } - pub fn add_constant_folding(&mut self, fold: impl GenericConstFold + 'static) { - self.constant_fold = Box::new(fold) + pub fn add_constant_folding(&mut self, fold: impl ConstFold + 'static) { + self.constant_folder = Box::new(fold) + } + + pub fn constant_fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult { + self.constant_folder.fold(type_args, consts) } } @@ -464,7 +472,7 @@ impl Extension { signature_func: signature_func.into(), misc: Default::default(), lower_funcs: Default::default(), - constant_fold: Default::default(), + constant_folder: Default::default(), }; match self.operations.entry(op.name.clone()) { diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 5f2af204f..089b1f5f9 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -8,7 +8,7 @@ use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; -use crate::{Hugr, Node}; +use crate::{ops, Hugr, IncomingPort, Node, OutgoingPort}; use super::tag::OpTag; use super::{LeafOp, OpTrait, OpType}; @@ -127,6 +127,13 @@ impl ExtensionOp { pub fn def(&self) -> &OpDef { self.def.as_ref() } + + pub fn constant_fold( + &self, + consts: &[(IncomingPort, ops::Const)], + ) -> Option> { + self.def().constant_fold(self.args(), consts) + } } impl From for OpaqueOp { diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 432790cf8..6a1e5ac62 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -2,7 +2,7 @@ use smol_str::SmolStr; -use super::custom::ExternalOp; +use super::custom::{ExtensionOp, ExternalOp}; use super::dataflow::DataflowOpTrait; use super::{OpName, OpTag}; @@ -62,6 +62,20 @@ pub enum LeafOp { }, } +impl LeafOp { + /// If instance of [ExtensionOp] return a reference to it. + pub fn as_extension_op(&self) -> Option<&ExtensionOp> { + let LeafOp::CustomOp(ext) = self else { + return None; + }; + + match ext.as_ref() { + ExternalOp::Extension(e) => Some(e), + ExternalOp::Opaque(_) => None, + } + } +} + /// Records details of an application of a [PolyFuncType] to some [TypeArg]s /// and the result (a less-, but still potentially-, polymorphic type). #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] diff --git a/src/values.rs b/src/values.rs index 07391c592..76494de55 100644 --- a/src/values.rs +++ b/src/values.rs @@ -144,16 +144,6 @@ pub trait CustomConst: // false unless overloaded false } - - /// Attempt to evaluate an operation given some constant inputs - typically - /// involving instances of Self - fn fold( - &self, - _op: &OpType, - _consts: &[(IncomingPort, crate::ops::Const)], - ) -> Option> { - None - } } /// Const equality for types that have PartialEq From 8ee49da9d9793a14d34441f0c57917e6134b0508 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 23 Nov 2023 17:16:37 +0000 Subject: [PATCH 04/40] integer addition tests passing --- src/algorithm/const_fold.rs | 26 ++++++---- src/extension.rs | 3 +- src/extension/const_fold.rs | 61 ++++++++++++++++++++++++ src/extension/op_def.rs | 45 ++--------------- src/ops/custom.rs | 7 +-- src/std_extensions/arithmetic/int_ops.rs | 46 ++++++++++++++++-- 6 files changed, 129 insertions(+), 59 deletions(-) create mode 100644 src/extension/const_fold.rs diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index b64774399..f0b682fc9 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -1,16 +1,14 @@ //! Constant folding routines. use crate::{ + extension::ConstFoldResult, ops::{custom::ExternalOp, Const, LeafOp, OpType}, values::Value, IncomingPort, OutgoingPort, }; /// For a given op and consts, attempt to evaluate the op. -pub fn fold_const( - op: &OpType, - consts: &[(IncomingPort, Const)], -) -> Option> { +pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { let op = op.as_leaf_op()?; let ext_op = op.as_extension_op()?; @@ -20,7 +18,7 @@ pub fn fold_const( #[cfg(test)] mod test { use crate::{ - extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY}, + extension::{ExtensionRegistry, FoldOutput, PRELUDE, PRELUDE_REGISTRY}, ops::LeafOp, std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, types::TypeArg, @@ -53,12 +51,24 @@ mod test { #[case(0, 0, 0)] #[case(0, 1, 1)] #[case(23, 435, 458)] - // c = a && b - fn test_and(#[case] a: u64, #[case] b: u64, #[case] c: u64) { + // c = a + b + fn test_add(#[case] a: u64, #[case] b: u64, #[case] c: u64) { let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))]; let add_op: OpType = u64_add().into(); let out = fold_const(&add_op, &consts).unwrap(); - assert_eq!(&out[..], &[(0.into(), i2c(c))]); + assert_eq!(&out[..], &[(0.into(), FoldOutput::Value(Box::new(i2c(c))))]); + } + + #[test] + // a = a + 0 + fn test_zero_add() { + for in_port in [0, 1] { + let other_in = 1 - in_port; + let consts = vec![(in_port.into(), i2c(0))]; + let add_op: OpType = u64_add().into(); + let out = fold_const(&add_op, &consts).unwrap(); + assert_eq!(&out[..], &[(0.into(), FoldOutput::Input(other_in.into()))]); + } } } diff --git a/src/extension.rs b/src/extension.rs index dfdfc5acc..8ae3d7210 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -28,9 +28,10 @@ pub use op_def::{ }; mod type_def; pub use type_def::{TypeDef, TypeDefBound}; +mod const_fold; pub mod prelude; pub mod validate; - +pub use const_fold::{ConstFold, ConstFoldResult, FoldOutput}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; /// Extension Registries store extensions to be looked up e.g. during validation. diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs new file mode 100644 index 000000000..2baa03cd1 --- /dev/null +++ b/src/extension/const_fold.rs @@ -0,0 +1,61 @@ +use std::fmt::Formatter; + +use std::fmt::Debug; + +use crate::types::TypeArg; + +use crate::OutgoingPort; + +use crate::IncomingPort; + +use crate::ops; +use derive_more::From; + +#[derive(From, Clone, PartialEq, Debug)] +pub enum FoldOutput { + /// Value from port can be replaced with a constant + Value(Box), + /// Value from port corresponds to one of the incoming values. + Input(IncomingPort), +} + +impl From for FoldOutput { + fn from(value: ops::Const) -> Self { + Self::Value(Box::new(value)) + } +} + +pub type ConstFoldResult = Option>; + +pub trait ConstFold: Send + Sync { + fn fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult; +} + +impl Debug for Box { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl Default for Box { + fn default() -> Self { + Box::new(|&_: &_| None) + } +} + +impl ConstFold for T +where + T: Fn(&[(crate::IncomingPort, crate::ops::Const)]) -> ConstFoldResult + Send + Sync, +{ + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult { + self(consts) + } +} diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 43b965c24..282808dc9 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -7,12 +7,13 @@ use std::sync::Arc; use smol_str::SmolStr; use super::{ - Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, + ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, + ExtensionSet, SignatureError, }; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; use crate::types::{FunctionType, PolyFuncType}; -use crate::Hugr; +use crate::{ops, Hugr, IncomingPort}; /// Trait necessary for binary computations of OpDef signature pub trait CustomSignatureFunc: Send + Sync { @@ -245,44 +246,6 @@ impl Debug for LowerFunc { } } -type ConstFoldResult = Option>; -pub trait ConstFold: Send + Sync { - fn fold( - &self, - type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult; -} - -impl Debug for Box { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "") - } -} - -impl Default for Box { - fn default() -> Self { - Box::new(|&_: &_| None) - } -} - -impl ConstFold for T -where - T: Fn( - &[(crate::IncomingPort, crate::ops::Const)], - ) -> Option> - + Send - + Sync, -{ - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult { - self(consts) - } -} - /// Serializable definition for dynamically loaded operations. /// /// TODO: Define a way to construct new OpDef's from a serialized definition. @@ -441,7 +404,7 @@ impl OpDef { self.misc.insert(k.to_string(), v) } - pub fn add_constant_folding(&mut self, fold: impl ConstFold + 'static) { + pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { self.constant_folder = Box::new(fold) } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 089b1f5f9..3ed692702 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -4,7 +4,7 @@ use smol_str::SmolStr; use std::sync::Arc; use thiserror::Error; -use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; +use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; @@ -128,10 +128,7 @@ impl ExtensionOp { self.def.as_ref() } - pub fn constant_fold( - &self, - consts: &[(IncomingPort, ops::Const)], - ) -> Option> { + pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { self.def().constant_fold(self.args(), consts) } } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 9a77f5dfb..1281c5f16 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,16 +1,17 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_type_var, LOG_WIDTH_TYPE_PARAM}; +use super::int_types::{get_log_width, int_type_var, ConstIntU, INT_TYPES, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; -use crate::extension::{CustomValidator, ValidateJustArgs}; -use crate::type_row; +use crate::extension::{ConstFoldResult, CustomValidator, FoldOutput, ValidateJustArgs}; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; +use crate::values::Value; use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, types::{type_param::TypeArg, Type, TypeRow}, Extension, }; +use crate::{ops, type_row, IncomingPort}; use lazy_static::lazy_static; @@ -71,6 +72,42 @@ fn idivmod_sig() -> PolyFuncType { int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)]) } +fn zero(width: u8) -> ops::Const { + ops::Const::new( + ConstIntU::new(width, 0).unwrap().into(), + INT_TYPES[5].to_owned(), + ) + .unwrap() +} + +fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { + // TODO get width from const + let width = 5; + match consts { + [(p, c)] if c == &zero(width) => { + let other_port: IncomingPort = if &IncomingPort::from(0) == p { 1 } else { 0 }.into(); + Some(vec![(0.into(), other_port.into())]) + } + [(_, c1), (_, c2)] => { + let [c1, c2]: [&ConstIntU; 2] = [c1, c2].map(|c| c.get_custom_value().unwrap()); + + Some(vec![( + 0.into(), + ops::Const::new( + ConstIntU::new(width, c1.value() + c2.value()) + .unwrap() + .into(), + INT_TYPES[5].to_owned(), + ) + .unwrap() + .into(), + )]) + } + + _ => None, + } +} + /// Extension for basic integer operations. fn extension() -> Extension { let itob_sig = int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T]); @@ -246,13 +283,14 @@ fn extension() -> Extension { ibinop_sig(), ) .unwrap(); - extension + let iadd = extension .add_op( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), ibinop_sig(), ) .unwrap(); + iadd.set_constant_folder(iadd_fold); extension .add_op( "isub".into(), From 520de7cfd99faf3d3a7f6bc1d3a269171dda3371 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 24 Nov 2023 14:16:21 +0000 Subject: [PATCH 05/40] remove FoldOutput --- src/algorithm/const_fold.rs | 21 ++++----------------- src/extension.rs | 2 +- src/extension/const_fold.rs | 19 +------------------ src/extension/op_def.rs | 2 +- src/ops/custom.rs | 2 +- src/std_extensions/arithmetic/int_ops.rs | 19 +++---------------- src/values.rs | 4 ++-- 7 files changed, 13 insertions(+), 56 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index f0b682fc9..b317909f0 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -2,9 +2,8 @@ use crate::{ extension::ConstFoldResult, - ops::{custom::ExternalOp, Const, LeafOp, OpType}, - values::Value, - IncomingPort, OutgoingPort, + ops::{Const, OpType}, + IncomingPort, }; /// For a given op and consts, attempt to evaluate the op. @@ -18,7 +17,7 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes #[cfg(test)] mod test { use crate::{ - extension::{ExtensionRegistry, FoldOutput, PRELUDE, PRELUDE_REGISTRY}, + extension::{ExtensionRegistry, PRELUDE}, ops::LeafOp, std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, types::TypeArg, @@ -57,18 +56,6 @@ mod test { let add_op: OpType = u64_add().into(); let out = fold_const(&add_op, &consts).unwrap(); - assert_eq!(&out[..], &[(0.into(), FoldOutput::Value(Box::new(i2c(c))))]); - } - - #[test] - // a = a + 0 - fn test_zero_add() { - for in_port in [0, 1] { - let other_in = 1 - in_port; - let consts = vec![(in_port.into(), i2c(0))]; - let add_op: OpType = u64_add().into(); - let out = fold_const(&add_op, &consts).unwrap(); - assert_eq!(&out[..], &[(0.into(), FoldOutput::Input(other_in.into()))]); - } + assert_eq!(&out[..], &[(0.into(), i2c(c))]); } } diff --git a/src/extension.rs b/src/extension.rs index 8ae3d7210..99ffb5782 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -31,7 +31,7 @@ pub use type_def::{TypeDef, TypeDefBound}; mod const_fold; pub mod prelude; pub mod validate; -pub use const_fold::{ConstFold, ConstFoldResult, FoldOutput}; +pub use const_fold::{ConstFold, ConstFoldResult}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; /// Extension Registries store extensions to be looked up e.g. during validation. diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs index 2baa03cd1..a3bb48c27 100644 --- a/src/extension/const_fold.rs +++ b/src/extension/const_fold.rs @@ -6,26 +6,9 @@ use crate::types::TypeArg; use crate::OutgoingPort; -use crate::IncomingPort; - use crate::ops; -use derive_more::From; - -#[derive(From, Clone, PartialEq, Debug)] -pub enum FoldOutput { - /// Value from port can be replaced with a constant - Value(Box), - /// Value from port corresponds to one of the incoming values. - Input(IncomingPort), -} - -impl From for FoldOutput { - fn from(value: ops::Const) -> Self { - Self::Value(Box::new(value)) - } -} -pub type ConstFoldResult = Option>; +pub type ConstFoldResult = Option>; pub trait ConstFold: Send + Sync { fn fold( diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 282808dc9..6643a859a 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -13,7 +13,7 @@ use super::{ use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; use crate::types::{FunctionType, PolyFuncType}; -use crate::{ops, Hugr, IncomingPort}; +use crate::Hugr; /// Trait necessary for binary computations of OpDef signature pub trait CustomSignatureFunc: Send + Sync { diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 3ed692702..6269f6773 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -8,7 +8,7 @@ use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, S use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; -use crate::{ops, Hugr, IncomingPort, Node, OutgoingPort}; +use crate::{ops, Hugr, IncomingPort, Node}; use super::tag::OpTag; use super::{LeafOp, OpTrait, OpType}; diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 1281c5f16..2ff9e07c5 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -2,10 +2,10 @@ use super::int_types::{get_log_width, int_type_var, ConstIntU, INT_TYPES, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; -use crate::extension::{ConstFoldResult, CustomValidator, FoldOutput, ValidateJustArgs}; +use crate::extension::{ConstFoldResult, CustomValidator, ValidateJustArgs}; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; -use crate::values::Value; + use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, types::{type_param::TypeArg, Type, TypeRow}, @@ -72,22 +72,10 @@ fn idivmod_sig() -> PolyFuncType { int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)]) } -fn zero(width: u8) -> ops::Const { - ops::Const::new( - ConstIntU::new(width, 0).unwrap().into(), - INT_TYPES[5].to_owned(), - ) - .unwrap() -} - fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { // TODO get width from const let width = 5; match consts { - [(p, c)] if c == &zero(width) => { - let other_port: IncomingPort = if &IncomingPort::from(0) == p { 1 } else { 0 }.into(); - Some(vec![(0.into(), other_port.into())]) - } [(_, c1), (_, c2)] => { let [c1, c2]: [&ConstIntU; 2] = [c1, c2].map(|c| c.get_custom_value().unwrap()); @@ -99,8 +87,7 @@ fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { .into(), INT_TYPES[5].to_owned(), ) - .unwrap() - .into(), + .unwrap(), )]) } diff --git a/src/values.rs b/src/values.rs index 76494de55..549f48fae 100644 --- a/src/values.rs +++ b/src/values.rs @@ -9,8 +9,8 @@ use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; use crate::macros::impl_box_clone; -use crate::ops::OpType; -use crate::{Hugr, HugrView, IncomingPort, OutgoingPort}; + +use crate::{Hugr, HugrView}; use crate::types::{CustomCheckFailure, CustomType}; From 9398d9d45c3c9a3af7ffb7e83b39470e31c2692e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 18 Dec 2023 14:11:56 +0000 Subject: [PATCH 06/40] refactor int folding to separate repo --- src/std_extensions/arithmetic/int_ops.rs | 35 +++-------------- src/std_extensions/arithmetic/int_ops/fold.rs | 38 +++++++++++++++++++ 2 files changed, 43 insertions(+), 30 deletions(-) create mode 100644 src/std_extensions/arithmetic/int_ops/fold.rs diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index d4735f400..3deac3e60 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,11 +1,10 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_tv, ConstIntU, INT_TYPES, LOG_WIDTH_TYPE_PARAM}; +use super::int_types::{get_log_width, int_tv, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; use crate::extension::{ - ConstFoldResult, CustomValidator, ExtensionRegistry, OpDef, SignatureFunc, ValidateJustArgs, - PRELUDE, + CustomValidator, ExtensionRegistry, OpDef, SignatureFunc, ValidateJustArgs, PRELUDE, }; use crate::ops::custom::ExtensionOp; use crate::ops::OpName; @@ -18,12 +17,12 @@ use crate::{ types::{type_param::TypeArg, Type, TypeRow}, Extension, }; -use crate::{ops, IncomingPort}; use lazy_static::lazy_static; use smol_str::SmolStr; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; +mod fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int"); @@ -218,10 +217,9 @@ impl MakeOpDef for IntOpDef { (rightmost bits replace leftmost bits)", }.into() } + fn post_opdef(&self, def: &mut OpDef) { - if self == &Self::iadd { - def.set_constant_folder(iadd_fold); - } + fold::set_fold(self, def) } } fn int_polytype( @@ -246,29 +244,6 @@ fn iunop_sig() -> PolyFuncType { int_polytype(1, vec![int_type_var.clone()], vec![int_type_var]) } -fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { - // TODO get width from const - let width = 5; - match consts { - [(_, c1), (_, c2)] => { - let [c1, c2]: [&ConstIntU; 2] = [c1, c2].map(|c| c.get_custom_value().unwrap()); - - Some(vec![( - 0.into(), - ops::Const::new( - ConstIntU::new(width, c1.value() + c2.value()) - .unwrap() - .into(), - INT_TYPES[5].to_owned(), - ) - .unwrap(), - )]) - } - - _ => None, - } -} - lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Extension = { diff --git a/src/std_extensions/arithmetic/int_ops/fold.rs b/src/std_extensions/arithmetic/int_ops/fold.rs new file mode 100644 index 000000000..b092d9a29 --- /dev/null +++ b/src/std_extensions/arithmetic/int_ops/fold.rs @@ -0,0 +1,38 @@ +use crate::{ + extension::{ConstFoldResult, OpDef}, + ops, + std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, + IncomingPort, +}; + +use super::IntOpDef; + +pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { + match op { + IntOpDef::iadd => def.set_constant_folder(iadd_fold), + _ => (), + } +} + +// TODO get width from const +fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { + let width = 5; + match consts { + [(_, c1), (_, c2)] => { + let [c1, c2]: [&ConstIntU; 2] = [c1, c2].map(|c| c.get_custom_value().unwrap()); + + Some(vec![( + 0.into(), + ops::Const::new( + ConstIntU::new(width, c1.value() + c2.value()) + .unwrap() + .into(), + INT_TYPES[5].to_owned(), + ) + .unwrap(), + )]) + } + + _ => None, + } +} From 7b955a96a2b27fd00489ac0fd7c92628e546fcc5 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 18 Dec 2023 15:01:18 +0000 Subject: [PATCH 07/40] add tuple and sum constant folding --- src/algorithm/const_fold.rs | 59 +++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index b317909f0..3f73866b4 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -2,16 +2,69 @@ use crate::{ extension::ConstFoldResult, - ops::{Const, OpType}, + ops::{Const, LeafOp, OpType}, + types::{Type, TypeEnum}, + values::Value, IncomingPort, }; +fn out_row(consts: impl IntoIterator) -> ConstFoldResult { + let vec = consts + .into_iter() + .enumerate() + .map(|(i, c)| (i.into(), c)) + .collect(); + + Some(vec) +} + +fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> { + let mut v: Vec<_> = consts.iter().collect(); + v.sort_by_key(|(i, _)| i); + v +} + +fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { + sort_by_in_port(consts) + .into_iter() + .map(|(_, c)| c) + .collect() +} /// For a given op and consts, attempt to evaluate the op. pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { let op = op.as_leaf_op()?; - let ext_op = op.as_extension_op()?; - ext_op.constant_fold(consts) + match op { + LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]), + LeafOp::MakeTuple { .. } => { + out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())]) + } + LeafOp::UnpackTuple { .. } => { + let c = &consts.first()?.1; + + if let Value::Tuple { vs } = c.value() { + if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() { + return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| { + Const::new(v.clone(), t.clone()) + .expect("types should already have been checked") + })); + } + } + None + } + + LeafOp::Tag { tag, variants } => out_row([Const::new( + Value::sum(*tag, consts.first()?.1.value().clone()), + Type::new_sum(variants.clone()), + ) + .unwrap()]), + LeafOp::CustomOp(_) => { + let ext_op = op.as_extension_op()?; + + ext_op.constant_fold(consts) + } + _ => None, + } } #[cfg(test)] From 6cb3c629ff4604f812bd28b8b48c316c598f77cb Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 18 Dec 2023 15:03:21 +0000 Subject: [PATCH 08/40] simplify test code --- src/algorithm/const_fold.rs | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 3f73866b4..90d8947fb 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -69,12 +69,9 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes #[cfg(test)] mod test { - use crate::{ - extension::{ExtensionRegistry, PRELUDE}, - ops::LeafOp, - std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, - types::TypeArg, - }; + use crate::std_extensions::arithmetic::int_ops::IntOpDef; + use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; + use rstest::rstest; use super::*; @@ -87,18 +84,6 @@ mod test { .unwrap() } - fn u64_add() -> LeafOp { - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - crate::std_extensions::arithmetic::int_ops::EXTENSION.to_owned(), - crate::std_extensions::arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - crate::std_extensions::arithmetic::int_ops::EXTENSION - .instantiate_extension_op("iadd", [TypeArg::BoundedNat { n: 5 }], ®) - .unwrap() - .into() - } #[rstest] #[case(0, 0, 0)] #[case(0, 1, 1)] @@ -106,7 +91,7 @@ mod test { // c = a + b fn test_add(#[case] a: u64, #[case] b: u64, #[case] c: u64) { let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))]; - let add_op: OpType = u64_add().into(); + let add_op: OpType = IntOpDef::iadd.with_width(6).into(); let out = fold_const(&add_op, &consts).unwrap(); assert_eq!(&out[..], &[(0.into(), i2c(c))]); From 0500624c883bb89912f860db6e8a5703b0cf1baf Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 20 Dec 2023 12:39:32 +0000 Subject: [PATCH 09/40] wip: fold finder --- src/algorithm/const_fold.rs | 106 +++++++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 3 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 90d8947fb..92ec422a7 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -1,11 +1,16 @@ //! Constant folding routines. +use itertools::Itertools; + use crate::{ - extension::ConstFoldResult, + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::{ConstFoldResult, ExtensionSet, PRELUDE_REGISTRY}, + hugr::views::SiblingSubgraph, ops::{Const, LeafOp, OpType}, - types::{Type, TypeEnum}, + type_row, + types::{FunctionType, Type, TypeEnum}, values::Value, - IncomingPort, + Hugr, HugrView, IncomingPort, OutgoingPort, SimpleReplacement, }; fn out_row(consts: impl IntoIterator) -> ConstFoldResult { @@ -67,8 +72,76 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes } } +fn const_graph(mut consts: Vec<(OutgoingPort, Const)>) -> Hugr { + consts.sort_by_key(|(o, _)| *o); + let consts = consts.into_iter().map(|(_, c)| c).collect_vec(); + let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); + // TODO need to get const extensions. + let extensions: ExtensionSet = consts + .iter() + .map(|c| c.value().extension_reqs()) + .fold(ExtensionSet::new(), |e, e2| e.union(&e2)); + let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); + + let outputs = consts + .into_iter() + .map(|c| b.add_load_const(c).unwrap()) + .collect_vec(); + + b.finish_hugr_with_outputs(outputs, &PRELUDE_REGISTRY) + .unwrap() +} + +fn find_consts( + hugr: &impl HugrView, + reg: &ExtensionRegistry, +) -> impl Iterator + '_ { + hugr.nodes().filter_map(|n| { + let op = hugr.get_optype(n); + + op.is_load_constant().then_some(())?; + + let neighbour = hugr.output_neighbours(n).exactly_one().ok()?; + + let mut remove_nodes = vec![neighbour]; + + let all_ins = hugr + .node_inputs(neighbour) + .filter_map(|in_p| { + let (in_n, _) = hugr.single_linked_output(neighbour, in_p)?; + let op = hugr.get_optype(in_n); + + op.is_load_constant().then_some(())?; + + let const_node = hugr.input_neighbours(in_n).exactly_one().ok()?; + let const_op = hugr.get_optype(const_node).as_const()?; + + remove_nodes.push(const_node); + remove_nodes.push(in_n); + + // TODO avoid const clone here + Some((in_p, const_op.clone())) + }) + .collect_vec(); + + let neighbour_op = hugr.get_optype(neighbour); + + let folded = fold_const(neighbour_op, &all_ins)?; + + let replacement = const_graph(folded); + + let sibling_graph = SiblingSubgraph::try_from_nodes(remove_nodes, hugr) + .expect("Make unmake should be valid subgraph."); + + sibling_graph + .create_simple_replacement(hugr, replacement) + .ok() + }) +} + #[cfg(test)] mod test { + use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::std_extensions::arithmetic::int_ops::IntOpDef; use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; @@ -96,4 +169,31 @@ mod test { assert_eq!(&out[..], &[(0.into(), i2c(c))]); } + + #[test] + fn test_fold() { + let mut b = DFGBuilder::new(FunctionType::new( + type_row![], + vec![INT_TYPES[5].to_owned()], + )) + .unwrap(); + + let one = b.add_load_const(i2c(1)).unwrap(); + let two = b.add_load_const(i2c(2)).unwrap(); + + let add = b + .add_dataflow_op(IntOpDef::iadd.with_width(5), [one, two]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + crate::std_extensions::arithmetic::int_types::EXTENSION.to_owned(), + crate::std_extensions::arithmetic::int_ops::EXTENSION.to_owned(), + ]) + .unwrap(); + let h = b.finish_hugr_with_outputs(add.outputs(), ®).unwrap(); + + let consts = find_consts(&h).collect_vec(); + + dbg!(consts); + } } From 8f554e0f7774b2ff791fbdaada6425a0963f3429 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 Dec 2023 12:31:18 +0000 Subject: [PATCH 10/40] chore(deps): bump actions/upload-artifact from 3 to 4 (#751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4.
Release notes

Sourced from actions/upload-artifact's releases.

v4.0.0

What's Changed

The release of upload-artifact@v4 and download-artifact@v4 are major changes to the backend architecture of Artifacts. They have numerous performance and behavioral improvements.

For more information, see the @​actions/artifact documentation.

New Contributors

Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v4.0.0

v3.1.3

What's Changed

Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v3.1.3

v3.1.2

  • Update all @actions/* NPM packages to their latest versions- #374
  • Update all dev dependencies to their most recent versions - #375

v3.1.1

  • Update actions/core package to latest version to remove set-output deprecation warning #351

v3.1.0

What's Changed

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/upload-artifact&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/notify-coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/notify-coverage.yml b/.github/workflows/notify-coverage.yml index 048212cbc..2721173c1 100644 --- a/.github/workflows/notify-coverage.yml +++ b/.github/workflows/notify-coverage.yml @@ -115,7 +115,7 @@ jobs: echo $MSG echo $MSG >> "$GITHUB_OUTPUT" - name: Upload current HEAD sha - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: head-sha.txt path: head-sha.txt From 215eb403f3142415a0b2cd696a8b9d8a72bce18b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 Dec 2023 12:31:40 +0000 Subject: [PATCH 11/40] chore(deps): bump dawidd6/action-download-artifact from 2 to 3 (#752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact) from 2 to 3.
Release notes

Sourced from dawidd6/action-download-artifact's releases.

v3.0.0

Node was updated from 16 to 20. Node 20 requires glibc>=2.28.

v2.28.0

No release notes provided.

v2.27.0

No release notes provided.

v2.26.1

No release notes provided.

v2.26.0

No release notes provided.

v2.25.0

No release notes provided.

v2.24.4

No release notes provided.

v2.24.3

No release notes provided.

v2.24.2

No release notes provided.

v2.24.0

No release notes provided.

v2.23.0

No release notes provided.

v2.22.0

No release notes provided.

v2.21.1

No release notes provided.

v2.21.0

No release notes provided.

v2.20.0

No release notes provided.

v2.19.0

No release notes provided.

v2.18.0

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=dawidd6/action-download-artifact&package-manager=github_actions&previous-version=2&new-version=3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/notify-coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/notify-coverage.yml b/.github/workflows/notify-coverage.yml index 2721173c1..1c095d7d7 100644 --- a/.github/workflows/notify-coverage.yml +++ b/.github/workflows/notify-coverage.yml @@ -18,7 +18,7 @@ jobs: should_notify: ${{ steps.get_coverage.outputs.should_notify }} steps: - name: Download commit sha of the most recent successful run - uses: dawidd6/action-download-artifact@v2 + uses: dawidd6/action-download-artifact@v3 with: # Downloads the artifact from the most recent successful run workflow: 'notify-coverage.yml' From ff26546b3cf3b38cfda34be0454d8176ea545ed6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 20 Dec 2023 13:47:58 +0000 Subject: [PATCH 12/40] fix: case node should not have an external signature (#749) Previously it was incorrectly reporting the internal signature as the node signature. Fixes #750 Required further fixes to `replace.rs`: * test was looking for an error that should have been hidden behind another (that was not previously reported because of #750). Look for the latter, not the former. * Fix node indices not being correctly reported from `apply` --------- Co-authored-by: Alan Lawrence --- src/extension/infer/test.rs | 10 +++------- src/hugr/rewrite/replace.rs | 22 +++++++++++----------- src/ops/controlflow.rs | 9 +++++---- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 5714cd4e8..7653c7578 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -11,7 +11,7 @@ use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; use crate::ops::custom::{ExternalOp, OpaqueOp}; -use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; +use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle}; use crate::ops::{LeafOp, OpType}; use crate::type_row; @@ -314,12 +314,8 @@ fn test_conditional_inference() -> Result<(), Box> { first_ext: ExtensionId, second_ext: ExtensionId, ) -> Result> { - let [case, case_in, case_out] = create_with_io( - hugr, - conditional_node, - op.clone(), - Into::::into(op).dataflow_signature().unwrap(), - )?; + let [case, case_in, case_out] = + create_with_io(hugr, conditional_node, op.clone(), op.inner_signature())?; let lift1 = hugr.add_node_with_parent( case, diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 1ca02b4ac..017407c97 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -85,7 +85,7 @@ pub struct Replacement { } impl NewEdgeSpec { - fn check_src(&self, h: &impl HugrView) -> Result<(), ReplaceError> { + fn check_src(&self, h: &impl HugrView, err_spec: &NewEdgeSpec) -> Result<(), ReplaceError> { let optype = h.get_optype(self.src); let ok = match self.kind { NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder), @@ -100,9 +100,9 @@ impl NewEdgeSpec { } }; ok.then_some(()) - .ok_or(ReplaceError::BadEdgeKind(Direction::Outgoing, self.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec.clone())) } - fn check_tgt(&self, h: &impl HugrView) -> Result<(), ReplaceError> { + fn check_tgt(&self, h: &impl HugrView, err_spec: &NewEdgeSpec) -> Result<(), ReplaceError> { let optype = h.get_optype(self.tgt); let ok = match self.kind { NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder), @@ -118,7 +118,7 @@ impl NewEdgeSpec { ), }; ok.then_some(()) - .ok_or(ReplaceError::BadEdgeKind(Direction::Incoming, self.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec.clone())) } fn check_existing_edge( @@ -233,20 +233,20 @@ impl Rewrite for Replacement { e.clone(), )); } - e.check_src(h)?; + e.check_src(h, e)?; } self.mu_out.iter().try_for_each(|e| { self.replacement.valid_non_root(e.src).map_err(|_| { ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, e.clone()) })?; - e.check_src(&self.replacement) + e.check_src(&self.replacement, e) })?; // Edge targets... self.mu_inp.iter().try_for_each(|e| { self.replacement.valid_non_root(e.tgt).map_err(|_| { ReplaceError::BadEdgeSpec(Direction::Incoming, WhichHugr::Replacement, e.clone()) })?; - e.check_tgt(&self.replacement) + e.check_tgt(&self.replacement, e) })?; for e in self.mu_out.iter().chain(self.mu_new.iter()) { if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { @@ -256,7 +256,7 @@ impl Rewrite for Replacement { e.clone(), )); } - e.check_tgt(h)?; + e.check_tgt(h, e)?; // The descendant check is to allow the case where the old edge is nonlocal // from a part of the Hugr being moved (which may require changing source, // depending on where the transplanted portion ends up). While this subsumes @@ -353,8 +353,8 @@ fn transfer_edges<'a>( h.valid_node(e.tgt).map_err(|_| { ReplaceError::BadEdgeSpec(Direction::Incoming, WhichHugr::Retained, oe.clone()) })?; - e.check_src(h)?; - e.check_tgt(h)?; + e.check_src(h, oe)?; + e.check_tgt(h, oe)?; match e.kind { NewEdgeKind::Order => { h.add_other_edge(e.src, e.tgt).unwrap(); @@ -820,7 +820,7 @@ mod test { mu_out: vec![new_out_edge.clone()], ..r.clone() }), - Err(ReplaceError::NoRemovedEdge(new_out_edge)) + Err(ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)) ); Ok(()) } diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index afd0ef5f8..4489cc4fe 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -235,10 +235,6 @@ impl OpTrait for Case { fn tag(&self) -> OpTag { ::TAG } - - fn dataflow_signature(&self) -> Option { - Some(self.signature.clone()) - } } impl Case { @@ -251,6 +247,11 @@ impl Case { pub fn dataflow_output(&self) -> &TypeRow { &self.signature.output } + + /// The signature of the dataflow sibling graph contained in the [`Case`] + pub fn inner_signature(&self) -> FunctionType { + self.signature.clone() + } } fn tuple_sum_first(tuple_sum_row: &TypeRow, rest: &TypeRow) -> TypeRow { From 64b91997e7cfee61264ab93179cdff9150c768fa Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 20 Dec 2023 17:03:39 +0000 Subject: [PATCH 13/40] refactor: move hugr equality check out for reuse --- src/hugr.rs | 44 +++++++++++++++++++++++++++++++++++++++++-- src/hugr/serialize.rs | 37 ++---------------------------------- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/hugr.rs b/src/hugr.rs index 9672f3dbb..2efb8e867 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -343,12 +343,16 @@ pub enum HugrError { } #[cfg(test)] -mod test { +pub(crate) mod test { + use itertools::Itertools; + use portgraph::{LinkView, PortView}; + use super::{Hugr, HugrView}; use crate::builder::test::closed_dfg_root_hugr; use crate::extension::ExtensionSet; use crate::hugr::HugrMut; - use crate::ops; + use crate::ops::LeafOp; + use crate::ops::{self, OpType}; use crate::type_row; use crate::types::{FunctionType, Type}; @@ -398,4 +402,40 @@ mod test { assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r); Ok(()) } + + pub(crate) fn assert_hugr_equality(hugr: &Hugr, other: &Hugr) { + assert_eq!(other.root, hugr.root); + assert_eq!(other.hierarchy, hugr.hierarchy); + assert_eq!(other.metadata, hugr.metadata); + + // Extension operations may have been downgraded to opaque operations. + for node in other.nodes() { + let new_op = other.get_optype(node); + let old_op = hugr.get_optype(node); + if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op { + if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op { + assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque()); + } else { + panic!("Expected old_op to be a custom op"); + } + } else { + assert_eq!(new_op, old_op); + } + } + + // Check that the graphs are equivalent up to port renumbering. + let new_graph = &other.graph; + let old_graph = &hugr.graph; + assert_eq!(new_graph.node_count(), old_graph.node_count()); + assert_eq!(new_graph.port_count(), old_graph.port_count()); + assert_eq!(new_graph.link_count(), old_graph.link_count()); + for n in old_graph.nodes_iter() { + assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n)); + assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n)); + assert_eq!( + new_graph.output_neighbours(n).collect_vec(), + old_graph.output_neighbours(n).collect_vec() + ); + } + } } diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index b49549b0f..5f9b236f8 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -260,6 +260,7 @@ pub mod test { use crate::extension::simple_op::MakeRegisteredOp; use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; + use crate::hugr::test::assert_hugr_equality; use crate::hugr::NodeType; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{dataflow::IOTrait, Input, LeafOp, Module, Output, DFG}; @@ -267,7 +268,6 @@ pub mod test { use crate::types::{FunctionType, Type}; use crate::OutgoingPort; use itertools::Itertools; - use portgraph::LinkView; use portgraph::{ multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, PortView, UnmanagedDenseMap, }; @@ -298,40 +298,7 @@ pub mod test { // The internal port indices may still be different. let mut h_canon = hugr.clone(); h_canon.canonicalize_nodes(|_, _| {}); - - assert_eq!(new_hugr.root, h_canon.root); - assert_eq!(new_hugr.hierarchy, h_canon.hierarchy); - assert_eq!(new_hugr.metadata, h_canon.metadata); - - // Extension operations may have been downgraded to opaque operations. - for node in new_hugr.nodes() { - let new_op = new_hugr.get_optype(node); - let old_op = h_canon.get_optype(node); - if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op { - if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op { - assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque()); - } else { - panic!("Expected old_op to be a custom op"); - } - } else { - assert_eq!(new_op, old_op); - } - } - - // Check that the graphs are equivalent up to port renumbering. - let new_graph = &new_hugr.graph; - let old_graph = &h_canon.graph; - assert_eq!(new_graph.node_count(), old_graph.node_count()); - assert_eq!(new_graph.port_count(), old_graph.port_count()); - assert_eq!(new_graph.link_count(), old_graph.link_count()); - for n in old_graph.nodes_iter() { - assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n)); - assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n)); - assert_eq!( - new_graph.output_neighbours(n).collect_vec(), - old_graph.output_neighbours(n).collect_vec() - ); - } + assert_hugr_equality(&h_canon, &new_hugr); new_hugr } From 6d7d4403790b7772d23933092d4fd54ecace7f45 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 13:26:44 +0000 Subject: [PATCH 14/40] feat: implement RemoveConst and RemoveConstIgnore as per spec --- src/hugr/rewrite.rs | 1 + src/hugr/rewrite/consts.rs | 135 +++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 src/hugr/rewrite/consts.rs diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs index 3f524e2db..05f1f48d9 100644 --- a/src/hugr/rewrite.rs +++ b/src/hugr/rewrite.rs @@ -1,5 +1,6 @@ //! Rewrite operations on the HUGR - replacement, outlining, etc. +pub mod consts; pub mod insert_identity; pub mod outline_cfg; pub mod replace; diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs new file mode 100644 index 000000000..c4c7123cd --- /dev/null +++ b/src/hugr/rewrite/consts.rs @@ -0,0 +1,135 @@ +//! Rewrite operations involving Const and LoadConst operations + +use std::iter; + +use crate::{ + hugr::{HugrError, HugrMut}, + HugrView, Node, +}; +use itertools::Itertools; +use thiserror::Error; + +use super::Rewrite; + +/// Remove a [`crate::ops::LoadConstant`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConstIgnore(pub Node); + +/// Error from an [`RemoveConstIgnore`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveConstIgnoreError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not LoadConst).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Not connected to a Const. + #[error("Node: {0:?} is not connected to a Const node.")] + NoConst(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConstIgnore { + type Error = RemoveConstIgnoreError; + + // The Const node the LoadConstant was connected to. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { + return Err(RemoveConstIgnoreError::InvalidNode(node)); + } + + if h.out_value_types(node) + .next() + .is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some()) + { + return Err(RemoveConstIgnoreError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let source = h + .input_neighbours(node) + .exactly_one() + .map_err(|_| RemoveConstIgnoreError::NoConst(node))?; + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +/// Remove a [`crate::ops::Const`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConst(pub Node); + +/// Error from an [`RemoveConst`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveConstError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not Const).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConst { + type Error = RemoveConstError; + + // The parent of the Const node. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { + return Err(RemoveConstError::InvalidNode(node)); + } + + if h.output_neighbours(node).next().is_some() { + return Err(RemoveConstError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let source = h + .get_parent(node) + .expect("Const node without a parent shouldn't happen."); + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} From cdde50352c7853b409a24459660658d1e54a1ca1 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 13:27:07 +0000 Subject: [PATCH 15/40] use remove rewrites while folding --- src/algorithm/const_fold.rs | 180 ++++++++++++++++++++++++------------ 1 file changed, 120 insertions(+), 60 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 92ec422a7..532251dff 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -1,16 +1,18 @@ //! Constant folding routines. +use std::collections::{BTreeSet, HashMap}; + use itertools::Itertools; use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{ConstFoldResult, ExtensionSet, PRELUDE_REGISTRY}, - hugr::views::SiblingSubgraph, + extension::{ConstFoldResult, ExtensionRegistry}, + hugr::{rewrite::consts::RemoveConstIgnore, views::SiblingSubgraph}, ops::{Const, LeafOp, OpType}, type_row, types::{FunctionType, Type, TypeEnum}, values::Value, - Hugr, HugrView, IncomingPort, OutgoingPort, SimpleReplacement, + Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; fn out_row(consts: impl IntoIterator) -> ConstFoldResult { @@ -72,15 +74,8 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes } } -fn const_graph(mut consts: Vec<(OutgoingPort, Const)>) -> Hugr { - consts.sort_by_key(|(o, _)| *o); - let consts = consts.into_iter().map(|(_, c)| c).collect_vec(); +fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); - // TODO need to get const extensions. - let extensions: ExtensionSet = consts - .iter() - .map(|c| c.value().extension_reqs()) - .fold(ExtensionSet::new(), |e, e2| e.union(&e2)); let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); let outputs = consts @@ -88,60 +83,100 @@ fn const_graph(mut consts: Vec<(OutgoingPort, Const)>) -> Hugr { .map(|c| b.add_load_const(c).unwrap()) .collect_vec(); - b.finish_hugr_with_outputs(outputs, &PRELUDE_REGISTRY) - .unwrap() + b.finish_hugr_with_outputs(outputs, reg).unwrap() +} + +pub fn find_consts<'a, 'r: 'a>( + hugr: &'a impl HugrView, + reg: &'r ExtensionRegistry, +) -> impl Iterator)> + 'a { + let mut used_neighbours = BTreeSet::new(); + + hugr.nodes() + .filter_map(move |n| { + hugr.get_optype(n).is_load_constant().then_some(())?; + + let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; + let neighbours = hugr + .linked_inputs(n, out_p) + .filter(|(n, _)| used_neighbours.insert(*n)) + .collect_vec(); + if neighbours.is_empty() { + return None; + } + let fold_iter = neighbours + .into_iter() + .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); + Some(fold_iter) + }) + .flatten() } -fn find_consts( +fn fold_op( hugr: &impl HugrView, + op_node: Node, reg: &ExtensionRegistry, -) -> impl Iterator + '_ { - hugr.nodes().filter_map(|n| { - let op = hugr.get_optype(n); - - op.is_load_constant().then_some(())?; - - let neighbour = hugr.output_neighbours(n).exactly_one().ok()?; - - let mut remove_nodes = vec![neighbour]; - - let all_ins = hugr - .node_inputs(neighbour) - .filter_map(|in_p| { - let (in_n, _) = hugr.single_linked_output(neighbour, in_p)?; - let op = hugr.get_optype(in_n); - - op.is_load_constant().then_some(())?; - - let const_node = hugr.input_neighbours(in_n).exactly_one().ok()?; - let const_op = hugr.get_optype(const_node).as_const()?; - - remove_nodes.push(const_node); - remove_nodes.push(in_n); - - // TODO avoid const clone here - Some((in_p, const_op.clone())) - }) - .collect_vec(); - - let neighbour_op = hugr.get_optype(neighbour); - - let folded = fold_const(neighbour_op, &all_ins)?; - - let replacement = const_graph(folded); - - let sibling_graph = SiblingSubgraph::try_from_nodes(remove_nodes, hugr) - .expect("Make unmake should be valid subgraph."); +) -> Option<(SimpleReplacement, Vec)> { + let (in_consts, removals): (Vec<_>, Vec<_>) = hugr + .node_inputs(op_node) + .filter_map(|in_p| get_const(hugr, op_node, in_p)) + .unzip(); + let neighbour_op = hugr.get_optype(op_node); + let folded = fold_const(neighbour_op, &in_consts)?; + let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); + let nu_out = op_outs + .into_iter() + .flat_map(|out| { + // map from the ports the op was linked to, to the output ports of + // the replacement. + hugr.linked_inputs(op_node, out) + .enumerate() + .map(|(i, np)| (np, i.into())) + }) + .collect(); + let replacement = const_graph(consts, reg); + let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) + .expect("Load consts and operation should form valid subgraph."); + + let simple_replace = SimpleReplacement::new( + sibling_graph, + replacement, + // no inputs to replacement + HashMap::new(), + nu_out, + ); + Some((simple_replace, removals)) +} - sibling_graph - .create_simple_replacement(hugr, replacement) - .ok() - }) +fn get_const( + hugr: &impl HugrView, + op_node: Node, + in_p: IncomingPort, +) -> Option<((IncomingPort, Const), RemoveConstIgnore)> { + let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; + let load_op = hugr.get_optype(load_n).as_load_constant()?; + let const_node = hugr + .linked_outputs(load_n, load_op.constant_port()) + .exactly_one() + .ok()? + .0; + + let const_op = hugr.get_optype(const_node).as_const()?; + + // remove_nodes.push(in_n); + + // TODO avoid const clone here + Some(((in_p, const_op.clone()), RemoveConstIgnore(load_n))) } #[cfg(test)] mod test { + use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::hugr::rewrite::consts::RemoveConst; + + use crate::hugr::HugrMut; + use crate::std_extensions::arithmetic; use crate::std_extensions::arithmetic::int_ops::IntOpDef; use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; @@ -186,14 +221,39 @@ mod test { .unwrap(); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), - crate::std_extensions::arithmetic::int_types::EXTENSION.to_owned(), - crate::std_extensions::arithmetic::int_ops::EXTENSION.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::int_ops::EXTENSION.to_owned(), ]) .unwrap(); - let h = b.finish_hugr_with_outputs(add.outputs(), ®).unwrap(); + let mut h = b.finish_hugr_with_outputs(add.outputs(), ®).unwrap(); + assert_eq!(h.node_count(), 8); + + let (repl, removes) = find_consts(&h, ®).exactly_one().ok().unwrap(); + h.apply_rewrite(repl).unwrap(); + for rem in removes { + if let Ok(const_node) = h.apply_rewrite(rem) { + if h.apply_rewrite(RemoveConst(const_node)).is_err() { + continue; + } + } + } + + assert_fully_folded(&h, &i2c(3)); + } - let consts = find_consts(&h).collect_vec(); + fn assert_fully_folded(h: &Hugr, expected_const: &Const) { + // check the hugr just loads and returns a single const + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if c == expected_const => node_count += 1, + _ => panic!("unexpected op: {:?}", op), + } + } - dbg!(consts); + assert_eq!(node_count, 4); } } From 114524ca4523ec9fc32d995aae8c752573b74440 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 15:11:47 +0000 Subject: [PATCH 16/40] alllow candidate node specification in find_consts --- src/algorithm/const_fold.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 532251dff..30d2a2a43 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -88,11 +88,13 @@ fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { pub fn find_consts<'a, 'r: 'a>( hugr: &'a impl HugrView, + candidate_nodes: impl IntoIterator + 'a, reg: &'r ExtensionRegistry, ) -> impl Iterator)> + 'a { let mut used_neighbours = BTreeSet::new(); - hugr.nodes() + candidate_nodes + .into_iter() .filter_map(move |n| { hugr.get_optype(n).is_load_constant().then_some(())?; @@ -228,7 +230,7 @@ mod test { let mut h = b.finish_hugr_with_outputs(add.outputs(), ®).unwrap(); assert_eq!(h.node_count(), 8); - let (repl, removes) = find_consts(&h, ®).exactly_one().ok().unwrap(); + let (repl, removes) = find_consts(&h, h.nodes(), ®).exactly_one().ok().unwrap(); h.apply_rewrite(repl).unwrap(); for rem in removes { if let Ok(const_node) = h.apply_rewrite(rem) { From a087fbc9749629f00fdf7b391d89a707beed52e9 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 15:41:29 +0000 Subject: [PATCH 17/40] add exhaustive fold pass --- src/algorithm/const_fold.rs | 42 +++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 30d2a2a43..d2de33a40 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -7,7 +7,11 @@ use itertools::Itertools; use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, extension::{ConstFoldResult, ExtensionRegistry}, - hugr::{rewrite::consts::RemoveConstIgnore, views::SiblingSubgraph}, + hugr::{ + rewrite::consts::{RemoveConst, RemoveConstIgnore}, + views::SiblingSubgraph, + HugrMut, + }, ops::{Const, LeafOp, OpType}, type_row, types::{FunctionType, Type, TypeEnum}, @@ -165,12 +169,33 @@ fn get_const( let const_op = hugr.get_optype(const_node).as_const()?; - // remove_nodes.push(in_n); - // TODO avoid const clone here Some(((in_p, const_op.clone()), RemoveConstIgnore(load_n))) } +pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { + loop { + // would be preferable if the candidates were updated to be just the + // neighbouring nodes of those added. + let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); + if rewrites.is_empty() { + break; + } + for (replace, removes) in rewrites { + h.apply_rewrite(replace).unwrap(); + for rem in removes { + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + if h.apply_rewrite(RemoveConst(const_node)).is_err() { + // const cannot be removed - no problem + continue; + } + } + } + } + } +} + #[cfg(test)] mod test { @@ -231,13 +256,12 @@ mod test { assert_eq!(h.node_count(), 8); let (repl, removes) = find_consts(&h, h.nodes(), ®).exactly_one().ok().unwrap(); + let [remove_1, remove_2] = removes.try_into().unwrap(); + h.apply_rewrite(repl).unwrap(); - for rem in removes { - if let Ok(const_node) = h.apply_rewrite(rem) { - if h.apply_rewrite(RemoveConst(const_node)).is_err() { - continue; - } - } + for rem in [remove_1, remove_2] { + let const_node = h.apply_rewrite(rem).unwrap(); + h.apply_rewrite(RemoveConst(const_node)).unwrap(); } assert_fully_folded(&h, &i2c(3)); From 07768b2ac3939c58317c2fba2ee631f4f035ac29 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 16:21:38 +0000 Subject: [PATCH 18/40] refactor!: use enum op traits for floats + conversions BREAKING CHANGES: extension() function replaced with EXTENSION static ref for float_ops and conversions --- src/ops/constant.rs | 2 +- src/std_extensions/arithmetic/conversions.rs | 164 ++++++++++----- src/std_extensions/arithmetic/float_ops.rs | 198 ++++++++++--------- src/std_extensions/arithmetic/float_types.rs | 34 ++-- src/std_extensions/arithmetic/int_ops.rs | 2 +- src/std_extensions/collections.rs | 2 +- src/utils.rs | 2 +- 7 files changed, 245 insertions(+), 159 deletions(-) diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 6db0006b3..c66ab6f76 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -148,7 +148,7 @@ mod test { use super::*; fn test_registry() -> ExtensionRegistry { - ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap() + ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap() } #[test] diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 4ae262b77..98e5df887 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,63 +1,131 @@ //! Conversions between integer and floating-point values. +use smol_str::SmolStr; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + use crate::{ - extension::{prelude::sum_with_error, ExtensionId, ExtensionSet}, + extension::{ + prelude::sum_with_error, + simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, + ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, + }, + ops::{custom::ExtensionOp, OpName}, type_row, - types::{FunctionType, PolyFuncType}, + types::{FunctionType, PolyFuncType, TypeArg}, Extension, }; use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -/// Extension for basic arithmetic operations. -pub fn extension() -> Extension { - let ftoi_sig = PolyFuncType::new( - vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]), - ); - - let itof_sig = PolyFuncType::new( - vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]), - ); - - let mut extension = Extension::new_with_reqs( - EXTENSION_ID, - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ]), - ); - extension - .add_op( - "trunc_u".into(), - "float to unsigned int".to_owned(), - ftoi_sig.clone(), - ) - .unwrap(); - extension - .add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig) - .unwrap(); - extension - .add_op( - "convert_u".into(), - "unsigned int to float".to_owned(), - itof_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "convert_s".into(), - "signed int to float".to_owned(), - itof_sig, - ) - .unwrap(); - - extension +/// Extensiop for conversions between floats and integers. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs, non_camel_case_types)] +pub enum ConvertOpDef { + trunc_u, + trunc_s, + convert_u, + convert_s, +} + +impl MakeOpDef for ConvertOpDef { + fn from_def(op_def: &OpDef) -> Result { + crate::extension::simple_op::try_from_name(op_def.name()) + } + + fn signature(&self) -> SignatureFunc { + use ConvertOpDef::*; + match self { + trunc_s | trunc_u => PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]), + ), + + convert_s | convert_u => PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]), + ), + } + .into() + } + + fn description(&self) -> String { + use ConvertOpDef::*; + match self { + trunc_u => "float to unsigned int", + trunc_s => "float to signed int", + convert_u => "unsigned int to float", + convert_s => "signed int to float", + } + .to_string() + } +} + +/// Concrete convert operation with integer width set. +#[derive(Debug, Clone, PartialEq)] +pub struct ConvertOpType { + def: ConvertOpDef, + width: u64, +} + +impl OpName for ConvertOpType { + fn name(&self) -> SmolStr { + self.def.name() + } +} + +impl MakeExtensionOp for ConvertOpType { + fn from_extension_op(ext_op: &ExtensionOp) -> Result { + let def = ConvertOpDef::from_def(ext_op.def())?; + let width = match *ext_op.args() { + [TypeArg::BoundedNat { n }] => n, + _ => return Err(SignatureError::InvalidTypeArgs.into()), + }; + Ok(Self { def, width }) + } + + fn type_args(&self) -> Vec { + vec![TypeArg::BoundedNat { n: self.width }] + } +} + +lazy_static! { + /// Extension for conversions between integers and floats. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new_with_reqs( + EXTENSION_ID, + ExtensionSet::from_iter(vec![ + super::int_types::EXTENSION_ID, + super::float_types::EXTENSION_ID, + ]), + ); + + ConvertOpDef::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate integer operations. + pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + super::int_types::EXTENSION.to_owned(), + super::float_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} + +impl MakeRegisteredOp for ConvertOpType { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &CONVERT_OPS_REGISTRY + } } #[cfg(test)] @@ -66,7 +134,7 @@ mod test { #[test] fn test_conversions_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.conversions"); assert_eq!(r.types().count(), 0); for (name, _) in r.operations() { diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 5cef5d19a..62a23b294 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -1,103 +1,119 @@ //! Basic floating-point operations. +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + +use super::float_types::FLOAT64_TYPE; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::{ + prelude::BOOL_T, + simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, + ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, PRELUDE, + }, type_row, - types::{FunctionType, PolyFuncType}, + types::FunctionType, Extension, }; - -use super::float_types::FLOAT64_TYPE; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); -/// Extension for basic arithmetic operations. -pub fn extension() -> Extension { - let mut extension = Extension::new_with_reqs( - EXTENSION_ID, - ExtensionSet::singleton(&super::float_types::EXTENSION_ID), - ); - - let fcmp_sig: PolyFuncType = FunctionType::new( - type_row![FLOAT64_TYPE; 2], - type_row![crate::extension::prelude::BOOL_T], - ) - .into(); - let fbinop_sig: PolyFuncType = - FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into(); - let funop_sig: PolyFuncType = - FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into(); - extension - .add_op("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op("fne".into(), "inequality test".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op( - "fgt".into(), - "\"greater than\"".to_owned(), - fcmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fle".into(), - "\"less than or equal\"".to_owned(), - fcmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fge".into(), - "\"greater than or equal\"".to_owned(), - fcmp_sig, - ) - .unwrap(); - extension - .add_op("fmax".into(), "maximum".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fmin".into(), "minimum".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fadd".into(), "addition".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fneg".into(), "negation".to_owned(), funop_sig.clone()) - .unwrap(); - extension - .add_op( - "fabs".into(), - "absolute value".to_owned(), - funop_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fmul".into(), - "multiplication".to_owned(), - fbinop_sig.clone(), - ) - .unwrap(); - extension - .add_op("fdiv".into(), "division".to_owned(), fbinop_sig) - .unwrap(); - extension - .add_op("ffloor".into(), "floor".to_owned(), funop_sig.clone()) - .unwrap(); - extension - .add_op("fceil".into(), "ceiling".to_owned(), funop_sig) - .unwrap(); - - extension +/// Integer extension operation definitions. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs, non_camel_case_types)] +pub enum FloatOps { + feq, + fne, + flt, + fgt, + fle, + fge, + fmax, + fmin, + fadd, + fsub, + fneg, + fabs, + fmul, + fdiv, + ffloor, + fceil, +} + +impl MakeOpDef for FloatOps { + fn from_def(op_def: &OpDef) -> Result { + crate::extension::simple_op::try_from_name(op_def.name()) + } + + fn signature(&self) -> SignatureFunc { + use FloatOps::*; + + match self { + feq | fne | flt | fgt | fle | fge => { + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![BOOL_T]) + } + fmax | fmin | fadd | fsub | fmul | fdiv => { + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]) + } + fneg | fabs | ffloor | fceil => FunctionType::new_endo(type_row![FLOAT64_TYPE]), + } + .into() + } + + fn description(&self) -> String { + use FloatOps::*; + match self { + feq => "equality test", + fne => "inequality test", + flt => "\"less than\"", + fgt => "\"greater than\"", + fle => "\"less than or equal\"", + fge => "\"greater than or equal\"", + fmax => "maximum", + fmin => "minimum", + fadd => "addition", + fsub => "subtraction", + fneg => "negation", + fabs => "absolute value", + fmul => "multiplication", + fdiv => "division", + ffloor => "floot", + fceil => "ceiling", + } + .to_string() + } +} + +lazy_static! { + /// Extension for basic float operations. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new_with_reqs( + EXTENSION_ID, + ExtensionSet::singleton(&super::int_types::EXTENSION_ID), + ); + + FloatOps::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate float operations. + pub static ref FLOAT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + super::float_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} + +impl MakeRegisteredOp for FloatOps { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &FLOAT_OPS_REGISTRY + } } #[cfg(test)] @@ -106,7 +122,7 @@ mod test { #[test] fn test_float_ops_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.float"); assert_eq!(r.types().count(), 0); for (name, _) in r.operations() { diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 582c0eee7..71f91bf87 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -8,6 +8,7 @@ use crate::{ values::{CustomConst, KnownTypeConst}, Extension, }; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float.types"); @@ -72,29 +73,30 @@ impl CustomConst for ConstF64 { } } -/// Extension for basic floating-point types. -pub fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_ID); - - extension - .add_type( - FLOAT_TYPE_ID, - vec![], - "64-bit IEEE 754-2019 floating-point value".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - extension +lazy_static! { + /// Extension defining the float type. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new(EXTENSION_ID); + + extension + .add_type( + FLOAT_TYPE_ID, + vec![], + "64-bit IEEE 754-2019 floating-point value".to_owned(), + TypeBound::Copyable.into(), + ) + .unwrap(); + + extension + }; } - #[cfg(test)] mod test { use super::*; #[test] fn test_float_types_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.float.types"); assert_eq!(r.types().count(), 1); assert_eq!(r.operations().count(), 0); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 267fde902..ae5160ffd 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -41,7 +41,7 @@ impl ValidateJustArgs for IOValidator { Ok(()) } } -/// Logic extension operation definitions. +/// Integer extension operation definitions. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs, non_camel_case_types)] pub enum IntOpDef { diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index a78f4793a..ebec9bda7 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -177,7 +177,7 @@ mod test { let reg = ExtensionRegistry::try_new([ EXTENSION.to_owned(), PRELUDE.to_owned(), - float_types::extension(), + float_types::EXTENSION.to_owned(), ]) .unwrap(); let pop_sig = get_op(&POP_NAME) diff --git a/src/utils.rs b/src/utils.rs index 6c297b4bd..f62cee50f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -86,7 +86,7 @@ pub(crate) mod test_quantum_extension { lazy_static! { /// Quantum extension definition. pub static ref EXTENSION: Extension = extension(); - static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::extension()]).unwrap(); + static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap(); } fn get_gate(gate_name: &str) -> LeafOp { From 658adf44eaa8cccb4ca097fa0f4c6e4674be4d19 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 17:25:15 +0000 Subject: [PATCH 19/40] add folding definitions for float ops --- src/algorithm/const_fold.rs | 2 +- src/std_extensions/arithmetic/float_ops.rs | 6 +- .../arithmetic/float_ops/fold.rs | 120 ++++++++++++++++++ 3 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 src/std_extensions/arithmetic/float_ops/fold.rs diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index d2de33a40..5a3e6037c 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -35,7 +35,7 @@ fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Cons v } -fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { +pub fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { sort_by_in_port(consts) .into_iter() .map(|(_, c)| c) diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 62a23b294..f21ce104c 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; use lazy_static::lazy_static; - +mod fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); @@ -82,6 +82,10 @@ impl MakeOpDef for FloatOps { } .to_string() } + + fn post_opdef(&self, def: &mut OpDef) { + fold::set_fold(self, def) + } } lazy_static! { diff --git a/src/std_extensions/arithmetic/float_ops/fold.rs b/src/std_extensions/arithmetic/float_ops/fold.rs new file mode 100644 index 000000000..f56814d22 --- /dev/null +++ b/src/std_extensions/arithmetic/float_ops/fold.rs @@ -0,0 +1,120 @@ +use crate::{ + algorithm::const_fold::sorted_consts, + extension::{ConstFold, ConstFoldResult, OpDef}, + ops, + std_extensions::arithmetic::float_types::ConstF64, + IncomingPort, +}; + +use super::FloatOps; + +pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) { + use FloatOps::*; + + match op { + fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)), + feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)), + fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)), + } +} + +fn get_floats(consts: &[(IncomingPort, ops::Const)]) -> Option<[f64; N]> { + let consts: [&ops::Const; N] = sorted_consts(consts).try_into().ok()?; + + Some(consts.map(|c| { + let const_f64: &ConstF64 = c + .get_custom_value() + .expect("This function assumes all incoming constants are floats."); + const_f64.value() + })) +} + +struct BinaryFold(Box f64 + Send + Sync>); +impl BinaryFold { + fn from_op(op: &FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(match op { + fmax => f64::max, + fmin => f64::min, + fadd => std::ops::Add::add, + fsub => std::ops::Sub::sub, + fmul => std::ops::Mul::mul, + fdiv => std::ops::Div::div, + _ => panic!("not binary op"), + })) + } +} +impl ConstFold for BinaryFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1, f2] = get_floats(consts)?; + + let res = ConstF64::new((self.0)(f1, f2)); + Some(vec![(0.into(), res.into())]) + } +} + +struct CmpFold(Box bool + Send + Sync>); +impl CmpFold { + fn from_op(op: FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(move |x, y| { + (match op { + feq => f64::eq, + fne => f64::lt, + flt => f64::lt, + fgt => f64::gt, + fle => f64::le, + fge => f64::ge, + _ => panic!("not cmp op"), + })(&x, &y) + })) + } +} + +impl ConstFold for CmpFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1, f2] = get_floats(consts)?; + + let res = if (self.0)(f1, f2) { + ops::Const::true_val() + } else { + ops::Const::false_val() + }; + + Some(vec![(0.into(), res)]) + } +} + +struct UnaryFold(Box f64 + Send + Sync>); +impl UnaryFold { + fn from_op(op: &FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(match op { + fneg => std::ops::Neg::neg, + fabs => f64::abs, + ffloor => f64::floor, + fceil => f64::ceil, + _ => panic!("not unary op."), + })) + } +} + +impl ConstFold for UnaryFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1] = get_floats(consts)?; + let res = ConstF64::new((self.0)(f1)); + Some(vec![(0.into(), res.into())]) + } +} From 2c0e75b54f8c9744faa12d555c75787f96f0dfc7 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 19:03:01 +0000 Subject: [PATCH 20/40] refactor: ERROR_CUSTOM_TYPE --- src/extension/prelude.rs | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index f96046ba8..c2587621e 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -137,12 +137,11 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { .into() } +/// The custom type for Errors. +pub const ERROR_CUSTOM_TYPE: CustomType = + CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Eq); /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( - ERROR_TYPE_NAME, - PRELUDE_ID, - TypeBound::Eq, -)); +pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); @@ -191,6 +190,36 @@ impl KnownTypeConst for ConstUsize { const TYPE: CustomType = USIZE_CUSTOM_T; } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +/// Structure for holding constant usize values. +pub struct ConstError { + pub signal: u32, + pub message: String, +} + +#[typetag::serde] +impl CustomConst for ConstError { + fn name(&self) -> SmolStr { + format!("ConstError({:?}, {:?})", self.signal, self.message).into() + } + + fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { + self.check_known_type(typ) + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::values::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } +} + +impl KnownTypeConst for ConstError { + const TYPE: CustomType = ERROR_CUSTOM_TYPE; +} + #[cfg(test)] mod test { use crate::{ From dc7ff131607c4bb4454b7e7ac8d4624f24f76d17 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 19:03:29 +0000 Subject: [PATCH 21/40] refactor: const ConstF64::new --- src/std_extensions/arithmetic/float_types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 71f91bf87..ba5fe0956 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -40,7 +40,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Create a new [`ConstF64`] - pub fn new(value: f64) -> Self { + pub const fn new(value: f64) -> Self { Self { value } } From aa73ab24afe4541bd1767b8e5cb64a8180b54c16 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 19:03:41 +0000 Subject: [PATCH 22/40] feat: implement folding for conversion ops --- src/std_extensions/arithmetic/conversions.rs | 6 +- .../arithmetic/conversions/fold.rs | 120 ++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 src/std_extensions/arithmetic/conversions/fold.rs diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 98e5df887..cf5b53d3f 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -18,7 +18,7 @@ use crate::{ use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; use lazy_static::lazy_static; - +mod fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); @@ -63,6 +63,10 @@ impl MakeOpDef for ConvertOpDef { } .to_string() } + + fn post_opdef(&self, def: &mut OpDef) { + fold::set_fold(self, def) + } } /// Concrete convert operation with integer width set. diff --git a/src/std_extensions/arithmetic/conversions/fold.rs b/src/std_extensions/arithmetic/conversions/fold.rs new file mode 100644 index 000000000..bfb32892f --- /dev/null +++ b/src/std_extensions/arithmetic/conversions/fold.rs @@ -0,0 +1,120 @@ +use crate::{ + extension::{prelude::ConstError, ConstFold, ConstFoldResult, OpDef}, + ops, + std_extensions::arithmetic::{ + float_types::ConstF64, + int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES}, + }, + types::ConstTypeError, + values::{CustomConst, Value}, + IncomingPort, +}; + +use super::ConvertOpDef; + +pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) { + use ConvertOpDef::*; + + match op { + trunc_u => def.set_constant_folder(TruncU), + trunc_s => def.set_constant_folder(TruncS), + convert_u => def.set_constant_folder(ConvertU), + convert_s => def.set_constant_folder(ConvertS), + } +} + +fn get_input(consts: &[(IncomingPort, ops::Const)]) -> Option<&T> { + let [(_, c)] = consts else { + return None; + }; + c.get_custom_value() +} + +fn fold_trunc( + type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + convert: impl Fn(f64, u8) -> Result, +) -> ConstFoldResult { + let f: &ConstF64 = get_input(consts)?; + let f = f.value(); + + let err_value = || { + ConstError { + signal: 0, + message: "Can't truncate non-finite float".to_string(), + } + .into() + }; + let out_const: ops::Const = if !f.is_finite() { + err_value() + } else { + let [arg] = type_args else { + return None; + }; + let log_width = get_log_width(arg).ok()?; + let cv = convert(f, log_width); + if let Ok(cv) = cv { + ops::Const::new(cv, INT_TYPES[log_width as usize].to_owned()).unwrap() + } else { + err_value() + } + }; + + Some(vec![(0.into(), out_const)]) +} + +struct TruncU; + +impl ConstFold for TruncU { + fn fold( + &self, + type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + fold_trunc(type_args, consts, |f, log_width| { + ConstIntU::new(log_width, f.trunc() as u64).map(Into::into) + }) + } +} + +struct TruncS; + +impl ConstFold for TruncS { + fn fold( + &self, + type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + fold_trunc(type_args, consts, |f, log_width| { + ConstIntS::new(log_width, f.trunc() as i64).map(Into::into) + }) + } +} + +struct ConvertU; + +impl ConstFold for ConvertU { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let u: &ConstIntU = get_input(consts)?; + let f = u.value() as f64; + Some(vec![(0.into(), ConstF64::new(f).into())]) + } +} + +struct ConvertS; + +impl ConstFold for ConvertS { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let u: &ConstIntS = get_input(consts)?; + let f = u.value() as f64; + Some(vec![(0.into(), ConstF64::new(f).into())]) + } +} From a519f344efeb53741c8dfc490757f018c8484ec6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 19:05:14 +0000 Subject: [PATCH 23/40] fixup! refactor: ERROR_CUSTOM_TYPE --- src/extension/prelude.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index c2587621e..83d24db03 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -193,7 +193,9 @@ impl KnownTypeConst for ConstUsize { #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// Structure for holding constant usize values. pub struct ConstError { + /// Integer tag/signal for the error. pub signal: u32, + /// Error message. pub message: String, } From 46075c2b73f545aafa5d34d90cfdfb525866787b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 20:03:43 +0000 Subject: [PATCH 24/40] implement bigger tests and fix unearthed bugs --- src/algorithm/const_fold.rs | 72 +++++++++++++++++-- src/std_extensions/arithmetic/conversions.rs | 20 +++++- .../arithmetic/conversions/fold.rs | 34 ++++++--- 3 files changed, 107 insertions(+), 19 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 5a3e6037c..a27c1ed4f 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -25,7 +25,6 @@ fn out_row(consts: impl IntoIterator) -> ConstFoldResult { .enumerate() .map(|(i, c)| (i.into(), c)) .collect(); - Some(vec) } @@ -35,7 +34,7 @@ fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Cons v } -pub fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { +pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { sort_by_in_port(consts) .into_iter() .map(|(_, c)| c) @@ -132,12 +131,12 @@ fn fold_op( let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); let nu_out = op_outs .into_iter() - .flat_map(|out| { + .enumerate() + .filter_map(|(i, out)| { // map from the ports the op was linked to, to the output ports of // the replacement. - hugr.linked_inputs(op_node, out) - .enumerate() - .map(|(i, np)| (np, i.into())) + hugr.single_linked_input(op_node, out) + .map(|np| (np, i.into())) }) .collect(); let replacement = const_graph(consts, reg); @@ -199,11 +198,15 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { #[cfg(test)] mod test { + use crate::extension::prelude::sum_with_error; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::hugr::rewrite::consts::RemoveConst; use crate::hugr::HugrMut; use crate::std_extensions::arithmetic; + use crate::std_extensions::arithmetic::conversions::ConvertOpDef; + use crate::std_extensions::arithmetic::float_ops::FloatOps; + use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use crate::std_extensions::arithmetic::int_ops::IntOpDef; use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; @@ -219,6 +222,10 @@ mod test { .unwrap() } + fn f2c(f: f64) -> Const { + ConstF64::new(f).into() + } + #[rstest] #[case(0, 0, 0)] #[case(0, 1, 1)] @@ -267,6 +274,59 @@ mod test { assert_fully_folded(&h, &i2c(3)); } + #[test] + fn test_big() { + /* + Test approximately calculates + let x = (5.6, 3.2); + int(x.0 - x.1) == 2 + */ + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![sum_type.clone()])).unwrap(); + + let tup = build + .add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)])) + .unwrap(); + + let unpack = build + .add_dataflow_op( + LeafOp::UnpackTuple { + tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], + }, + [tup], + ) + .unwrap(); + + let sub = build + .add_dataflow_op(FloatOps::fsub, unpack.outputs()) + .unwrap(); + let to_int = build + .add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs()) + .unwrap(); + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build + .finish_hugr_with_outputs(to_int.outputs(), ®) + .unwrap(); + assert_eq!(h.node_count(), 8); + + constant_fold_pass(&mut h, ®); + + let expected = Value::Sum { + tag: 0, + value: Box::new(i2c(2).value().clone()), + }; + let expected = Const::new(expected, sum_type).unwrap(); + assert_fully_folded(&h, &expected); + } fn assert_fully_folded(h: &Hugr, expected_const: &Const) { // check the hugr just loads and returns a single const let mut node_count = 0; diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index cf5b53d3f..5ff716e7a 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -8,6 +8,7 @@ use crate::{ prelude::sum_with_error, simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, + PRELUDE, }, ops::{custom::ExtensionOp, OpName}, type_row, @@ -69,11 +70,20 @@ impl MakeOpDef for ConvertOpDef { } } +impl ConvertOpDef { + /// INitialise a conversion op with an integer log width type argument. + pub fn with_width(self, log_width: u8) -> ConvertOpType { + ConvertOpType { + def: self, + log_width: log_width as u64, + } + } +} /// Concrete convert operation with integer width set. #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { def: ConvertOpDef, - width: u64, + log_width: u64, } impl OpName for ConvertOpType { @@ -89,11 +99,14 @@ impl MakeExtensionOp for ConvertOpType { [TypeArg::BoundedNat { n }] => n, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; - Ok(Self { def, width }) + Ok(Self { + def, + log_width: width, + }) } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat { n: self.width }] + vec![TypeArg::BoundedNat { n: self.log_width }] } } @@ -115,6 +128,7 @@ lazy_static! { /// Registry of extensions required to validate integer operations. pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), super::int_types::EXTENSION.to_owned(), super::float_types::EXTENSION.to_owned(), EXTENSION.to_owned(), diff --git a/src/std_extensions/arithmetic/conversions/fold.rs b/src/std_extensions/arithmetic/conversions/fold.rs index bfb32892f..3814c0504 100644 --- a/src/std_extensions/arithmetic/conversions/fold.rs +++ b/src/std_extensions/arithmetic/conversions/fold.rs @@ -1,5 +1,8 @@ use crate::{ - extension::{prelude::ConstError, ConstFold, ConstFoldResult, OpDef}, + extension::{ + prelude::{sum_with_error, ConstError}, + ConstFold, ConstFoldResult, OpDef, + }, ops, std_extensions::arithmetic::{ float_types::ConstF64, @@ -37,24 +40,35 @@ fn fold_trunc( ) -> ConstFoldResult { let f: &ConstF64 = get_input(consts)?; let f = f.value(); - + let [arg] = type_args else { + return None; + }; + let log_width = get_log_width(arg).ok()?; + let int_type = INT_TYPES[log_width as usize].to_owned(); + let sum_type = sum_with_error(int_type.clone()); let err_value = || { - ConstError { + let err_val = ConstError { signal: 0, message: "Can't truncate non-finite float".to_string(), - } - .into() + }; + let sum_val = Value::Sum { + tag: 1, + value: Box::new(err_val.into()), + }; + + ops::Const::new(sum_val, sum_type.clone()).unwrap() }; let out_const: ops::Const = if !f.is_finite() { err_value() } else { - let [arg] = type_args else { - return None; - }; - let log_width = get_log_width(arg).ok()?; let cv = convert(f, log_width); if let Ok(cv) = cv { - ops::Const::new(cv, INT_TYPES[log_width as usize].to_owned()).unwrap() + let sum_val = Value::Sum { + tag: 0, + value: Box::new(cv), + }; + + ops::Const::new(sum_val, sum_type).unwrap() } else { err_value() } From df854e82bb877a36713d6f22576f097e87435141 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 11:54:47 +0000 Subject: [PATCH 25/40] Revert "refactor: move hugr equality check out for reuse" This reverts commit 64b91997e7cfee61264ab93179cdff9150c768fa. --- src/hugr.rs | 44 ++----------------------------------------- src/hugr/serialize.rs | 37 ++++++++++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 44 deletions(-) diff --git a/src/hugr.rs b/src/hugr.rs index 2efb8e867..9672f3dbb 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -343,16 +343,12 @@ pub enum HugrError { } #[cfg(test)] -pub(crate) mod test { - use itertools::Itertools; - use portgraph::{LinkView, PortView}; - +mod test { use super::{Hugr, HugrView}; use crate::builder::test::closed_dfg_root_hugr; use crate::extension::ExtensionSet; use crate::hugr::HugrMut; - use crate::ops::LeafOp; - use crate::ops::{self, OpType}; + use crate::ops; use crate::type_row; use crate::types::{FunctionType, Type}; @@ -402,40 +398,4 @@ pub(crate) mod test { assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r); Ok(()) } - - pub(crate) fn assert_hugr_equality(hugr: &Hugr, other: &Hugr) { - assert_eq!(other.root, hugr.root); - assert_eq!(other.hierarchy, hugr.hierarchy); - assert_eq!(other.metadata, hugr.metadata); - - // Extension operations may have been downgraded to opaque operations. - for node in other.nodes() { - let new_op = other.get_optype(node); - let old_op = hugr.get_optype(node); - if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op { - if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op { - assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque()); - } else { - panic!("Expected old_op to be a custom op"); - } - } else { - assert_eq!(new_op, old_op); - } - } - - // Check that the graphs are equivalent up to port renumbering. - let new_graph = &other.graph; - let old_graph = &hugr.graph; - assert_eq!(new_graph.node_count(), old_graph.node_count()); - assert_eq!(new_graph.port_count(), old_graph.port_count()); - assert_eq!(new_graph.link_count(), old_graph.link_count()); - for n in old_graph.nodes_iter() { - assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n)); - assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n)); - assert_eq!( - new_graph.output_neighbours(n).collect_vec(), - old_graph.output_neighbours(n).collect_vec() - ); - } - } } diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 5f9b236f8..b49549b0f 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -260,7 +260,6 @@ pub mod test { use crate::extension::simple_op::MakeRegisteredOp; use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; - use crate::hugr::test::assert_hugr_equality; use crate::hugr::NodeType; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{dataflow::IOTrait, Input, LeafOp, Module, Output, DFG}; @@ -268,6 +267,7 @@ pub mod test { use crate::types::{FunctionType, Type}; use crate::OutgoingPort; use itertools::Itertools; + use portgraph::LinkView; use portgraph::{ multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, PortView, UnmanagedDenseMap, }; @@ -298,7 +298,40 @@ pub mod test { // The internal port indices may still be different. let mut h_canon = hugr.clone(); h_canon.canonicalize_nodes(|_, _| {}); - assert_hugr_equality(&h_canon, &new_hugr); + + assert_eq!(new_hugr.root, h_canon.root); + assert_eq!(new_hugr.hierarchy, h_canon.hierarchy); + assert_eq!(new_hugr.metadata, h_canon.metadata); + + // Extension operations may have been downgraded to opaque operations. + for node in new_hugr.nodes() { + let new_op = new_hugr.get_optype(node); + let old_op = h_canon.get_optype(node); + if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op { + if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op { + assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque()); + } else { + panic!("Expected old_op to be a custom op"); + } + } else { + assert_eq!(new_op, old_op); + } + } + + // Check that the graphs are equivalent up to port renumbering. + let new_graph = &new_hugr.graph; + let old_graph = &h_canon.graph; + assert_eq!(new_graph.node_count(), old_graph.node_count()); + assert_eq!(new_graph.port_count(), old_graph.port_count()); + assert_eq!(new_graph.link_count(), old_graph.link_count()); + for n in old_graph.nodes_iter() { + assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n)); + assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n)); + assert_eq!( + new_graph.output_neighbours(n).collect_vec(), + old_graph.output_neighbours(n).collect_vec() + ); + } new_hugr } From ba81e7b0e2e4bf597a49e27b29e034c03431af22 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 13:26:44 +0000 Subject: [PATCH 26/40] feat: implement RemoveConst and RemoveConstIgnore as per spec refactor!: allow Into for builder.add_const BREAKING_CHANGES: existing .into() calls will error --- src/builder/build_traits.rs | 6 +- src/builder/tail_loop.rs | 4 +- src/hugr/rewrite.rs | 1 + src/hugr/rewrite/consts.rs | 228 ++++++++++++++++++++++++++++++++++++ src/hugr/views/tests.rs | 2 +- 5 files changed, 235 insertions(+), 6 deletions(-) create mode 100644 src/hugr/rewrite/consts.rs diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 641ef1ae2..c85a02eb0 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -70,8 +70,8 @@ pub trait Container { /// /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. - fn add_constant(&mut self, constant: ops::Const) -> Result { - let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?; + fn add_constant(&mut self, constant: impl Into) -> Result { + let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?; Ok(const_n.into()) } @@ -374,7 +374,7 @@ pub trait Dataflow: Container { /// # Errors /// /// This function will return an error if there is an error when adding the node. - fn add_load_const(&mut self, constant: ops::Const) -> Result { + fn add_load_const(&mut self, constant: impl Into) -> Result { let cid = self.add_constant(constant)?; self.load_const(&cid) } diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index bbcddade7..8f99ee512 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -109,7 +109,7 @@ mod test { let build_result: Result = { let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; let [i1] = loop_b.input_wires_arr(); - let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?; + let const_wire = loop_b.add_load_const(ConstUsize::new(1))?; let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; @@ -173,7 +173,7 @@ mod test { let mut branch_1 = conditional_b.case_builder(1)?; let [_b1] = branch_1.input_wires_arr(); - let wire = branch_1.add_load_const(ConstUsize::new(2).into())?; + let wire = branch_1.add_load_const(ConstUsize::new(2))?; let break_wire = branch_1.make_break(signature, [wire])?; branch_1.finish_with_outputs([break_wire])?; diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs index 3f524e2db..05f1f48d9 100644 --- a/src/hugr/rewrite.rs +++ b/src/hugr/rewrite.rs @@ -1,5 +1,6 @@ //! Rewrite operations on the HUGR - replacement, outlining, etc. +pub mod consts; pub mod insert_identity; pub mod outline_cfg; pub mod replace; diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs new file mode 100644 index 000000000..f00b67cb0 --- /dev/null +++ b/src/hugr/rewrite/consts.rs @@ -0,0 +1,228 @@ +//! Rewrite operations involving Const and LoadConst operations + +use std::iter; + +use crate::{ + hugr::{HugrError, HugrMut}, + HugrView, Node, +}; +use itertools::Itertools; +use thiserror::Error; + +use super::Rewrite; + +/// Remove a [`crate::ops::LoadConstant`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConstIgnore(pub Node); + +/// Error from an [`RemoveConstIgnore`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveConstIgnoreError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not LoadConst).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Not connected to a Const. + #[error("Node: {0:?} is not connected to a Const node.")] + NoConst(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConstIgnore { + type Error = RemoveConstIgnoreError; + + // The Const node the LoadConstant was connected to. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { + return Err(RemoveConstIgnoreError::InvalidNode(node)); + } + + if h.out_value_types(node) + .next() + .is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some()) + { + return Err(RemoveConstIgnoreError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let source = h + .input_neighbours(node) + .exactly_one() + .map_err(|_| RemoveConstIgnoreError::NoConst(node))?; + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +/// Remove a [`crate::ops::Const`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConst(pub Node); + +/// Error from an [`RemoveConst`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveConstError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not Const).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConst { + type Error = RemoveConstError; + + // The parent of the Const node. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { + return Err(RemoveConstError::InvalidNode(node)); + } + + if h.output_neighbours(node).next().is_some() { + return Err(RemoveConstError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.verify(h)?; + let node = self.0; + let source = h + .get_parent(node) + .expect("Const node without a parent shouldn't happen."); + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, + extension::{ + prelude::{ConstUsize, USIZE_T}, + PRELUDE_REGISTRY, + }, + hugr::HugrMut, + ops::{handle::NodeHandle, LeafOp}, + type_row, + types::FunctionType, + }; + #[test] + fn test_const_remove() -> Result<(), Box> { + let mut build = ModuleBuilder::new(); + let con_node = build.add_constant(ConstUsize::new(2))?; + + let mut dfg_build = + build.define_function("main", FunctionType::new_endo(type_row![]).into())?; + let load_1 = dfg_build.load_const(&con_node)?; + let load_2 = dfg_build.load_const(&con_node)?; + let tup = dfg_build.add_dataflow_op( + LeafOp::MakeTuple { + tys: type_row![USIZE_T, USIZE_T], + }, + [load_1, load_2], + )?; + dfg_build.finish_sub_container()?; + + let mut h = build.finish_prelude_hugr()?; + assert_eq!(h.node_count(), 8); + let tup_node = tup.node(); + // can't remove invalid node + assert_eq!( + h.apply_rewrite(RemoveConst(tup_node)), + Err(RemoveConstError::InvalidNode(tup_node)) + ); + + assert_eq!( + h.apply_rewrite(RemoveConstIgnore(tup_node)), + Err(RemoveConstIgnoreError::InvalidNode(tup_node)) + ); + let load_1_node = load_1.node(); + let load_2_node = load_2.node(); + let con_node = con_node.node(); + + let remove_1 = RemoveConstIgnore(load_1_node); + assert_eq!( + remove_1.invalidation_set().exactly_one().ok(), + Some(load_1_node) + ); + + let remove_2 = RemoveConstIgnore(load_2_node); + + let remove_con = RemoveConst(con_node); + assert_eq!( + remove_con.invalidation_set().exactly_one().ok(), + Some(con_node) + ); + + // can't remove nodes in use + assert_eq!( + h.apply_rewrite(remove_1.clone()), + Err(RemoveConstIgnoreError::ValueUsed(load_1_node)) + ); + + // remove the use + h.remove_node(tup_node)?; + + // remove first load + let reported_con_node = h.apply_rewrite(remove_1)?; + assert_eq!(reported_con_node, con_node); + + // still can't remove const, in use by second load + assert_eq!( + h.apply_rewrite(remove_con.clone()), + Err(RemoveConstError::ValueUsed(con_node)) + ); + + // remove second use + let reported_con_node = h.apply_rewrite(remove_2)?; + assert_eq!(reported_con_node, con_node); + // remove const + assert_eq!(h.apply_rewrite(remove_con)?, h.root()); + + assert_eq!(h.node_count(), 4); + assert!(h.validate(&PRELUDE_REGISTRY).is_ok()); + Ok(()) + } +} diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 97fb50861..9a712dc03 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -143,7 +143,7 @@ fn static_targets() { ) .unwrap(); - let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap(); + let c = dfg.add_constant(ConstUsize::new(1)).unwrap(); let load = dfg.load_const(&c).unwrap(); From 09ce1c935b911f33a9077c610b24440f242a8158 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 13:35:44 +0000 Subject: [PATCH 27/40] remove conversion foldin --- src/algorithm/const_fold.rs | 28 +--- src/std_extensions/arithmetic/conversions.rs | 14 -- .../arithmetic/conversions/fold.rs | 134 ------------------ 3 files changed, 7 insertions(+), 169 deletions(-) delete mode 100644 src/std_extensions/arithmetic/conversions/fold.rs diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index a27c1ed4f..b786b3aa8 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -198,13 +198,12 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { #[cfg(test)] mod test { - use crate::extension::prelude::sum_with_error; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::hugr::rewrite::consts::RemoveConst; use crate::hugr::HugrMut; use crate::std_extensions::arithmetic; - use crate::std_extensions::arithmetic::conversions::ConvertOpDef; + use crate::std_extensions::arithmetic::float_ops::FloatOps; use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use crate::std_extensions::arithmetic::int_ops::IntOpDef; @@ -279,14 +278,13 @@ mod test { /* Test approximately calculates let x = (5.6, 3.2); - int(x.0 - x.1) == 2 + x.0 - x.1 == 2.4 */ - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![sum_type.clone()])).unwrap(); + DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); let tup = build - .add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)])) + .add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)])) .unwrap(); let unpack = build @@ -301,31 +299,19 @@ mod test { let sub = build .add_dataflow_op(FloatOps::fsub, unpack.outputs()) .unwrap(); - let to_int = build - .add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs()) - .unwrap(); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), arithmetic::float_types::EXTENSION.to_owned(), arithmetic::float_ops::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), ]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), ®) - .unwrap(); - assert_eq!(h.node_count(), 8); + let mut h = build.finish_hugr_with_outputs(sub.outputs(), ®).unwrap(); + assert_eq!(h.node_count(), 7); constant_fold_pass(&mut h, ®); - let expected = Value::Sum { - tag: 0, - value: Box::new(i2c(2).value().clone()), - }; - let expected = Const::new(expected, sum_type).unwrap(); - assert_fully_folded(&h, &expected); + assert_fully_folded(&h, &f2c(2.25)); } fn assert_fully_folded(h: &Hugr, expected_const: &Const) { // check the hugr just loads and returns a single const diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 5ff716e7a..23b457f7c 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -19,7 +19,6 @@ use crate::{ use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; use lazy_static::lazy_static; -mod fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); @@ -64,21 +63,8 @@ impl MakeOpDef for ConvertOpDef { } .to_string() } - - fn post_opdef(&self, def: &mut OpDef) { - fold::set_fold(self, def) - } } -impl ConvertOpDef { - /// INitialise a conversion op with an integer log width type argument. - pub fn with_width(self, log_width: u8) -> ConvertOpType { - ConvertOpType { - def: self, - log_width: log_width as u64, - } - } -} /// Concrete convert operation with integer width set. #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { diff --git a/src/std_extensions/arithmetic/conversions/fold.rs b/src/std_extensions/arithmetic/conversions/fold.rs deleted file mode 100644 index 3814c0504..000000000 --- a/src/std_extensions/arithmetic/conversions/fold.rs +++ /dev/null @@ -1,134 +0,0 @@ -use crate::{ - extension::{ - prelude::{sum_with_error, ConstError}, - ConstFold, ConstFoldResult, OpDef, - }, - ops, - std_extensions::arithmetic::{ - float_types::ConstF64, - int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES}, - }, - types::ConstTypeError, - values::{CustomConst, Value}, - IncomingPort, -}; - -use super::ConvertOpDef; - -pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) { - use ConvertOpDef::*; - - match op { - trunc_u => def.set_constant_folder(TruncU), - trunc_s => def.set_constant_folder(TruncS), - convert_u => def.set_constant_folder(ConvertU), - convert_s => def.set_constant_folder(ConvertS), - } -} - -fn get_input(consts: &[(IncomingPort, ops::Const)]) -> Option<&T> { - let [(_, c)] = consts else { - return None; - }; - c.get_custom_value() -} - -fn fold_trunc( - type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - convert: impl Fn(f64, u8) -> Result, -) -> ConstFoldResult { - let f: &ConstF64 = get_input(consts)?; - let f = f.value(); - let [arg] = type_args else { - return None; - }; - let log_width = get_log_width(arg).ok()?; - let int_type = INT_TYPES[log_width as usize].to_owned(); - let sum_type = sum_with_error(int_type.clone()); - let err_value = || { - let err_val = ConstError { - signal: 0, - message: "Can't truncate non-finite float".to_string(), - }; - let sum_val = Value::Sum { - tag: 1, - value: Box::new(err_val.into()), - }; - - ops::Const::new(sum_val, sum_type.clone()).unwrap() - }; - let out_const: ops::Const = if !f.is_finite() { - err_value() - } else { - let cv = convert(f, log_width); - if let Ok(cv) = cv { - let sum_val = Value::Sum { - tag: 0, - value: Box::new(cv), - }; - - ops::Const::new(sum_val, sum_type).unwrap() - } else { - err_value() - } - }; - - Some(vec![(0.into(), out_const)]) -} - -struct TruncU; - -impl ConstFold for TruncU { - fn fold( - &self, - type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - fold_trunc(type_args, consts, |f, log_width| { - ConstIntU::new(log_width, f.trunc() as u64).map(Into::into) - }) - } -} - -struct TruncS; - -impl ConstFold for TruncS { - fn fold( - &self, - type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - fold_trunc(type_args, consts, |f, log_width| { - ConstIntS::new(log_width, f.trunc() as i64).map(Into::into) - }) - } -} - -struct ConvertU; - -impl ConstFold for ConvertU { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let u: &ConstIntU = get_input(consts)?; - let f = u.value() as f64; - Some(vec![(0.into(), ConstF64::new(f).into())]) - } -} - -struct ConvertS; - -impl ConstFold for ConvertS { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let u: &ConstIntS = get_input(consts)?; - let f = u.value() as f64; - Some(vec![(0.into(), ConstF64::new(f).into())]) - } -} From 26bc5ff1afb2a589e2b7f97249ac1e3c5cdac2c0 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 13:41:17 +0000 Subject: [PATCH 28/40] add rust version guards --- src/hugr/rewrite/consts.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs index f00b67cb0..d8a0060b7 100644 --- a/src/hugr/rewrite/consts.rs +++ b/src/hugr/rewrite/consts.rs @@ -6,6 +6,7 @@ use crate::{ hugr::{HugrError, HugrMut}, HugrView, Node, }; +#[rustversion::since(1.75)] // uses impl in return position use itertools::Itertools; use thiserror::Error; @@ -32,6 +33,7 @@ pub enum RemoveConstIgnoreError { RemoveFail(#[from] HugrError), } +#[rustversion::since(1.75)] // uses impl in return position impl Rewrite for RemoveConstIgnore { type Error = RemoveConstIgnoreError; @@ -134,6 +136,7 @@ impl Rewrite for RemoveConst { } } +#[rustversion::since(1.75)] // uses impl in return position #[cfg(test)] mod test { use super::*; From 5a71f75e41425867416dea84d7fcec0e1102567c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 14:49:50 +0000 Subject: [PATCH 29/40] docs: add public method docstrings --- src/algorithm/const_fold.rs | 7 +++++++ src/extension/const_fold.rs | 7 +++++++ src/extension/op_def.rs | 4 ++++ src/ops/custom.rs | 1 + 4 files changed, 19 insertions(+) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index b786b3aa8..941946df2 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -89,6 +89,12 @@ fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { b.finish_hugr_with_outputs(outputs, reg).unwrap() } +/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, +/// return an iterator of possible constant folding rewrites. The +/// [`SimpleReplacement`] replaces an operation with constants that result from +/// evaluating it, the extension registry `reg` is used to validate the +/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the +/// LoadConstant nodes that could be removed. pub fn find_consts<'a, 'r: 'a>( hugr: &'a impl HugrView, candidate_nodes: impl IntoIterator + 'a, @@ -172,6 +178,7 @@ fn get_const( Some(((in_p, const_op.clone()), RemoveConstIgnore(load_n))) } +/// Exhaustively apply constant folding to a HUGR. pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { loop { // would be preferable if the candidates were updated to be just the diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs index a3bb48c27..ee3c4f7ed 100644 --- a/src/extension/const_fold.rs +++ b/src/extension/const_fold.rs @@ -8,9 +8,16 @@ use crate::OutgoingPort; use crate::ops; +/// Output of constant folding an operation, None indicates folding was either +/// not possible or unsuccessful. An empty vector indicates folding was +/// successful and no values are output. pub type ConstFoldResult = Option>; +/// Trait implemented by extension operations that can perform constant folding. pub trait ConstFold: Send + Sync { + /// Given type arguments `type_args` and + /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, + /// try to evaluate the operation. fn fold( &self, type_args: &[TypeArg], diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 42f6fcf31..2ea686ab4 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -417,10 +417,14 @@ impl OpDef { self.misc.insert(k.to_string(), v) } + /// Set the constant folding function for this Op, which can evaluate it + /// given constant inputs. pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { self.constant_folder = Box::new(fold) } + /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given + /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s. pub fn constant_fold( &self, type_args: &[TypeArg], diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 1a9df29c7..d4e4c88a6 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -138,6 +138,7 @@ impl ExtensionOp { self.def.as_ref() } + /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { self.def().constant_fold(self.args(), consts) } From 6fa7eb9eddbb12c4ad7099a237f51f2b8826be42 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 14:55:15 +0000 Subject: [PATCH 30/40] add some docstrings and comments --- src/algorithm/const_fold.rs | 14 ++++++++++---- src/extension/const_fold.rs | 2 ++ src/std_extensions/arithmetic/float_ops/fold.rs | 4 ++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 941946df2..8231b8114 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -220,6 +220,7 @@ mod test { use super::*; + /// int to constant fn i2c(b: u64) -> Const { Const::new( ConstIntU::new(5, b).unwrap().into(), @@ -228,6 +229,7 @@ mod test { .unwrap() } + /// float to constant fn f2c(f: f64) -> Const { ConstF64::new(f).into() } @@ -239,7 +241,7 @@ mod test { // c = a + b fn test_add(#[case] a: u64, #[case] b: u64, #[case] c: u64) { let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))]; - let add_op: OpType = IntOpDef::iadd.with_width(6).into(); + let add_op: OpType = IntOpDef::iadd.with_width(5).into(); let out = fold_const(&add_op, &consts).unwrap(); assert_eq!(&out[..], &[(0.into(), i2c(c))]); @@ -247,6 +249,10 @@ mod test { #[test] fn test_fold() { + /* + Test hugr calculates + 1 + 2 == 3 + */ let mut b = DFGBuilder::new(FunctionType::new( type_row![], vec![INT_TYPES[5].to_owned()], @@ -283,9 +289,9 @@ mod test { #[test] fn test_big() { /* - Test approximately calculates - let x = (5.6, 3.2); - x.0 - x.1 == 2.4 + Test hugr approximately calculates + let x = (5.5, 3.25); + x.0 - x.1 == 2.25 */ let mut build = DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs index ee3c4f7ed..29cf4b5a9 100644 --- a/src/extension/const_fold.rs +++ b/src/extension/const_fold.rs @@ -37,6 +37,8 @@ impl Default for Box { } } +/// Blanket implementation for functions that only require the constants to +/// evaluate - type arguments are not relevant. impl ConstFold for T where T: Fn(&[(crate::IncomingPort, crate::ops::Const)]) -> ConstFoldResult + Send + Sync, diff --git a/src/std_extensions/arithmetic/float_ops/fold.rs b/src/std_extensions/arithmetic/float_ops/fold.rs index f56814d22..34d162f4d 100644 --- a/src/std_extensions/arithmetic/float_ops/fold.rs +++ b/src/std_extensions/arithmetic/float_ops/fold.rs @@ -18,6 +18,7 @@ pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) { } } +/// Extract float values from constants in port order. fn get_floats(consts: &[(IncomingPort, ops::Const)]) -> Option<[f64; N]> { let consts: [&ops::Const; N] = sorted_consts(consts).try_into().ok()?; @@ -29,6 +30,7 @@ fn get_floats(consts: &[(IncomingPort, ops::Const)]) -> Option<[ })) } +/// Fold binary operations struct BinaryFold(Box f64 + Send + Sync>); impl BinaryFold { fn from_op(op: &FloatOps) -> Self { @@ -57,6 +59,7 @@ impl ConstFold for BinaryFold { } } +/// Fold comparisons. struct CmpFold(Box bool + Send + Sync>); impl CmpFold { fn from_op(op: FloatOps) -> Self { @@ -93,6 +96,7 @@ impl ConstFold for CmpFold { } } +/// Fold unary operations struct UnaryFold(Box f64 + Send + Sync>); impl UnaryFold { fn from_op(op: &FloatOps) -> Self { From 7381432a537adb295f0c7b8ecd5410464f8ab1ee Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 14:58:01 +0000 Subject: [PATCH 31/40] remove integer folding --- src/algorithm/const_fold.rs | 62 +++---------------- src/std_extensions/arithmetic/int_ops.rs | 5 -- src/std_extensions/arithmetic/int_ops/fold.rs | 38 ------------ 3 files changed, 7 insertions(+), 98 deletions(-) delete mode 100644 src/std_extensions/arithmetic/int_ops/fold.rs diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 8231b8114..5e9ab992e 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -220,70 +220,22 @@ mod test { use super::*; - /// int to constant - fn i2c(b: u64) -> Const { - Const::new( - ConstIntU::new(5, b).unwrap().into(), - INT_TYPES[5].to_owned(), - ) - .unwrap() - } - /// float to constant fn f2c(f: f64) -> Const { ConstF64::new(f).into() } #[rstest] - #[case(0, 0, 0)] - #[case(0, 1, 1)] - #[case(23, 435, 458)] + #[case(0.0, 0.0, 0.0)] + #[case(0.0, 1.0, 1.0)] + #[case(23.5, 435.5, 459.0)] // c = a + b - fn test_add(#[case] a: u64, #[case] b: u64, #[case] c: u64) { - let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))]; - let add_op: OpType = IntOpDef::iadd.with_width(5).into(); + fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; + let add_op: OpType = FloatOps::fadd.into(); let out = fold_const(&add_op, &consts).unwrap(); - assert_eq!(&out[..], &[(0.into(), i2c(c))]); - } - - #[test] - fn test_fold() { - /* - Test hugr calculates - 1 + 2 == 3 - */ - let mut b = DFGBuilder::new(FunctionType::new( - type_row![], - vec![INT_TYPES[5].to_owned()], - )) - .unwrap(); - - let one = b.add_load_const(i2c(1)).unwrap(); - let two = b.add_load_const(i2c(2)).unwrap(); - - let add = b - .add_dataflow_op(IntOpDef::iadd.with_width(5), [one, two]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::int_ops::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = b.finish_hugr_with_outputs(add.outputs(), ®).unwrap(); - assert_eq!(h.node_count(), 8); - - let (repl, removes) = find_consts(&h, h.nodes(), ®).exactly_one().ok().unwrap(); - let [remove_1, remove_2] = removes.try_into().unwrap(); - - h.apply_rewrite(repl).unwrap(); - for rem in [remove_1, remove_2] { - let const_node = h.apply_rewrite(rem).unwrap(); - h.apply_rewrite(RemoveConst(const_node)).unwrap(); - } - - assert_fully_folded(&h, &i2c(3)); + assert_eq!(&out[..], &[(0.into(), f2c(c))]); } #[test] diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 6b68782d8..e1cea6c49 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -22,7 +22,6 @@ use lazy_static::lazy_static; use smol_str::SmolStr; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; -mod fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int"); @@ -217,10 +216,6 @@ impl MakeOpDef for IntOpDef { (rightmost bits replace leftmost bits)", }.into() } - - fn post_opdef(&self, def: &mut OpDef) { - fold::set_fold(self, def) - } } fn int_polytype( n_vars: usize, diff --git a/src/std_extensions/arithmetic/int_ops/fold.rs b/src/std_extensions/arithmetic/int_ops/fold.rs deleted file mode 100644 index b092d9a29..000000000 --- a/src/std_extensions/arithmetic/int_ops/fold.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::{ - extension::{ConstFoldResult, OpDef}, - ops, - std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, - IncomingPort, -}; - -use super::IntOpDef; - -pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { - match op { - IntOpDef::iadd => def.set_constant_folder(iadd_fold), - _ => (), - } -} - -// TODO get width from const -fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { - let width = 5; - match consts { - [(_, c1), (_, c2)] => { - let [c1, c2]: [&ConstIntU; 2] = [c1, c2].map(|c| c.get_custom_value().unwrap()); - - Some(vec![( - 0.into(), - ops::Const::new( - ConstIntU::new(width, c1.value() + c2.value()) - .unwrap() - .into(), - INT_TYPES[5].to_owned(), - ) - .unwrap(), - )]) - } - - _ => None, - } -} From 0e0411fd5725f9028e8c569c6078063480c4e0cf Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 15:07:41 +0000 Subject: [PATCH 32/40] remove unused imports --- src/algorithm/const_fold.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 5e9ab992e..b2a84947a 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -19,6 +19,7 @@ use crate::{ Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; +/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. fn out_row(consts: impl IntoIterator) -> ConstFoldResult { let vec = consts .into_iter() @@ -28,12 +29,14 @@ fn out_row(consts: impl IntoIterator) -> ConstFoldResult { Some(vec) } +/// Sort folding inputs with [`IncomingPort`] as key fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> { let mut v: Vec<_> = consts.iter().collect(); v.sort_by_key(|(i, _)| i); v } +/// Sort some input constants by port and just return the constants. pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { sort_by_in_port(consts) .into_iter() @@ -60,7 +63,7 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes })); } } - None + None // could panic } LeafOp::Tag { tag, variants } => out_row([Const::new( @@ -77,6 +80,8 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes } } +/// Generate a graph that loads and outputs `consts` in order, validating +/// against `reg`. fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); @@ -206,15 +211,10 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { mod test { use crate::extension::{ExtensionRegistry, PRELUDE}; - use crate::hugr::rewrite::consts::RemoveConst; - - use crate::hugr::HugrMut; use crate::std_extensions::arithmetic; use crate::std_extensions::arithmetic::float_ops::FloatOps; use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - use crate::std_extensions::arithmetic::int_ops::IntOpDef; - use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; use rstest::rstest; From 8e88f3eb14236cfb87d2ef42798a6fd062357b0e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 22 Dec 2023 15:17:10 +0000 Subject: [PATCH 33/40] add docstrings and simplify --- src/algorithm/const_fold.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index b2a84947a..49474c43e 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -105,11 +105,13 @@ pub fn find_consts<'a, 'r: 'a>( candidate_nodes: impl IntoIterator + 'a, reg: &'r ExtensionRegistry, ) -> impl Iterator)> + 'a { + // track nodes for operations that have already been considered for folding let mut used_neighbours = BTreeSet::new(); candidate_nodes .into_iter() .filter_map(move |n| { + // only look at LoadConstant hugr.get_optype(n).is_load_constant().then_some(())?; let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; @@ -118,6 +120,7 @@ pub fn find_consts<'a, 'r: 'a>( .filter(|(n, _)| used_neighbours.insert(*n)) .collect_vec(); if neighbours.is_empty() { + // no uses of LoadConstant that haven't already been considered. return None; } let fold_iter = neighbours @@ -128,6 +131,7 @@ pub fn find_consts<'a, 'r: 'a>( .flatten() } +/// Attempt to evaluate and generate rewrites for the operation at `op_node` fn fold_op( hugr: &impl HugrView, op_node: Node, @@ -135,9 +139,13 @@ fn fold_op( ) -> Option<(SimpleReplacement, Vec)> { let (in_consts, removals): (Vec<_>, Vec<_>) = hugr .node_inputs(op_node) - .filter_map(|in_p| get_const(hugr, op_node, in_p)) + .filter_map(|in_p| { + let (con_op, load_n) = get_const(hugr, op_node, in_p)?; + Some(((in_p, con_op), RemoveConstIgnore(load_n))) + }) .unzip(); let neighbour_op = hugr.get_optype(op_node); + // attempt to evaluate op let folded = fold_const(neighbour_op, &in_consts)?; let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); let nu_out = op_outs @@ -152,7 +160,7 @@ fn fold_op( .collect(); let replacement = const_graph(consts, reg); let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) - .expect("Load consts and operation should form valid subgraph."); + .expect("Operation should form valid subgraph."); let simple_replace = SimpleReplacement::new( sibling_graph, @@ -164,11 +172,9 @@ fn fold_op( Some((simple_replace, removals)) } -fn get_const( - hugr: &impl HugrView, - op_node: Node, - in_p: IncomingPort, -) -> Option<((IncomingPort, Const), RemoveConstIgnore)> { +/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant +/// and the LoadConstant node +fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> { let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; let load_op = hugr.get_optype(load_n).as_load_constant()?; let const_node = hugr @@ -180,7 +186,7 @@ fn get_const( let const_op = hugr.get_optype(const_node).as_const()?; // TODO avoid const clone here - Some(((in_p, const_op.clone()), RemoveConstIgnore(load_n))) + Some((const_op.clone(), load_n)) } /// Exhaustively apply constant folding to a HUGR. From 4bca931a49f73140dcd89f43da659ee409f33187 Mon Sep 17 00:00:00 2001 From: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Date: Wed, 3 Jan 2024 09:10:07 +0000 Subject: [PATCH 34/40] docs: Spec clarifications (#738) A couple of small clarifications suggested by Will. --- specification/hugr.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/specification/hugr.md b/specification/hugr.md index 3c6517dcf..ce989f495 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -1234,7 +1234,9 @@ Note that considering all three $\mu$ lists together, - the `TgtNode` + `TgtPos`s of all `NewEdgeSpec`s with `EdgeKind` == `Value` will be unique - and similarly for `EdgeKind` == `Static` -The well-formedness requirements of Hugr imply that $\mu\_\textrm{inp}$ and $\mu\_\textrm{out}$ may only contain `NewEdgeSpec`s with certain `EdgeKind`s, depending on $P$: +The well-formedness requirements of Hugr imply that $\mu\_\textrm{inp}$, +$\mu\_\textrm{out}$ and $\mu\_\textrm{new}$ may only contain `NewEdgeSpec`s with +certain `EdgeKind`s, depending on $P$: - if $P$ is a dataflow container, `EdgeKind`s may be `Order`, `Value` or `Static` only (no `ControlFlow`) - if $P$ is a CFG node, `EdgeKind`s may be `ControlFlow`, `Value`, or `Static` only (no `Order`) - if $P$ is a Module node, there may be `Value` or `Static` only (no `Order`). @@ -1262,7 +1264,8 @@ The new hugr is then derived as follows: 6. For each node $(n, b = B(n))$ and for each child $m$ of $b$, replace the hierarchy edge from $b$ to $m$ with a hierarchy edge from the new copy of $n$ to $m$ (preserving the order). -7. Remove all nodes in $R$ and edges adjoining them. +7. Remove all nodes in $R$ and edges adjoining them. (Reindexing may be + necessary after this step.) ##### Outlining methods From 3193cdbeb12a2ea569b1f3d5ada2d04e4a13df49 Mon Sep 17 00:00:00 2001 From: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Date: Wed, 3 Jan 2024 09:11:44 +0000 Subject: [PATCH 35/40] docs: Spec updates (#741) Following Will's review. --- specification/hugr.md | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/specification/hugr.md b/specification/hugr.md index ce989f495..5fd65cd3b 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -1191,7 +1191,19 @@ The new hugr is then derived as follows: ###### `Replace` -This is the general subgraph-replacement method. +This is the general subgraph-replacement method. Intuitively, it takes a set of +sibling nodes to remove and replace with a new set of nodes. The new set of +nodes is itself a HUGR with some "holes" (edges and nodes that get "filled in" +by the `Replace` operation). To fully specify the operation, some further data +are needed: + + - The replacement may include container nodes with no children, which adopt + the children of removed container nodes and prevent those children being + removed. + - All new incoming edges from the retained nodes to the new nodes, all + outgoing edges from the new nodes to the retained nodes, and any new edges + that bypass the replacement (going between retained nodes) must be + specified. Given a set $S$ of nodes in a hugr, let $S^\*$ be the set of all nodes descended from nodes in $S$ (i.e. reachable from $S$ by following hierarchy edges), @@ -1328,8 +1340,8 @@ successor. Insert an Order edge from `n0` to `n1` where `n0` and `n1` are distinct siblings in a DSG such that there is no path in the DSG from `n1` to -`n0`. If there is already an order edge from `n0` to `n1` this does -nothing (but is not an error). +`n0`. (Thus acyclicity is preserved.) If there is already an order edge from +`n0` to `n1` this does nothing (but is not an error). ###### `RemoveOrder` @@ -1374,7 +1386,7 @@ nodes. The most basic case – replacing a convex set of Op nodes in a DSG with another graph of Op nodes having the same signature – is implemented by -having T map everything to the parent node, and bot(G) is empty. +`SimpleReplace`. If one of the nodes in the region is a complex container node that we wish to preserve in the replacement without doing a deep copy, we can From d0513c43670e2ff43c3a113b9809cf20d6a2bf54 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 3 Jan 2024 10:54:04 +0000 Subject: [PATCH 36/40] docs: [spec] Remove references to causal cone and Order edges from Input (#762) I believe the question of causal cones was definitively resolved by #468, and the codebase seems clear on that point - however, I just stumbled across this I-believe-erroneous reference in the spec. (FWIW, searching through the spec for "causal" has only one other occurrence, which appears to explicitly say we don't need this rule.) --- specification/hugr.md | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/specification/hugr.md b/specification/hugr.md index 5fd65cd3b..fc5f6f65b 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -267,19 +267,13 @@ the following basic dataflow operations are available (in addition to the - `Input/Output`: input/output nodes, the outputs of `Input` node are the inputs to the function, and the inputs to `Output` are the - outputs of the function. In a data dependency subgraph, a valid - ordering of operations can be achieved by topologically sorting the - nodes starting from `Input` with respect to the Value and Order - edges. + outputs of the function. - `Call`: Call a statically defined function. There is an incoming `Static` edge to specify the graph being called. The signature of the node (defined by its incoming and outgoing `Value` edges) matches the function being called. - `LoadConstant`: has an incoming `Static` edge, where `T` is a `CopyableType`, and a `Value` output, used to load a static constant into the local - dataflow graph. They also have an incoming `Order` edge connecting - them to the `Input` node, as should all operations that - take no dataflow input, to ensure they lie in the causal cone of the - `Input` node when traversing. + dataflow graph. - `identity`: pass-through, no operation is performed. - `DFG`: A nested dataflow graph. These nodes are parents in the hierarchy. @@ -515,10 +509,11 @@ graph: cycles. The common parent is a CFG-node. **Dataflow Sibling Graph (DSG)**: nodes are operations, `CFG`, -`Conditional`, `TailLoop` and `DFG` nodes; edges are `Value`, `Order` and `Static`; -and must be acyclic. There is a unique Input node and Output node. All nodes must be -reachable from the Input node, and must reach the Output node. The common parent -may be a `FuncDefn`, `TailLoop`, `DFG`, `Case` or `DFB` node. +`Conditional`, `TailLoop` and `DFG` nodes; edges are `Value`, `Order` and `Static`, and must be acyclic. +(Thus a valid ordering of operations can be achieved by topologically sorting the +nodes.) +There is a unique Input node and Output node. +The common parent may be a `FuncDefn`, `TailLoop`, `DFG`, `Case` or `DFB` node. | **Edge Kind** | **Locality** | | -------------- | ------------ | @@ -1355,8 +1350,7 @@ remove it. (If there is an non-local edge from `n0` to a descendent of Given a `Const` node `c`, and optionally `P`, a parent of a DSG, add a new `LoadConstant` node `n` as a child of `P` with a `Static` edge -from `c` to `n` and no outgoing edges from `n`. Also add an Order edge -from the Input node under `P` to `n`. Return the ID of `n`. If `P` is +from `c` to `n` and no outgoing edges from `n`. Return the ID of `n`. If `P` is omitted it defaults to the parent of `c` (in this case said `c` will have to be in a DSG or CSG rather than under the Module Root.) If `P` is provided, it must be a descendent of the parent of `c`. @@ -1364,7 +1358,7 @@ provided, it must be a descendent of the parent of `c`. ###### `RemoveConstIgnore` Given a `LoadConstant` node `n` that has no outgoing edges, remove -it (and its incoming value and Order edges) from the hugr. +it (and its incoming Static edge and any Order edges) from the hugr. ##### Insertion and removal of const nodes From 89f1827567fde1ecd308cafcb74dc6d425d6040b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:10:21 +0000 Subject: [PATCH 37/40] chore: remove rustversion (#764) and devenv update to 1.75 stable --- Cargo.toml | 2 -- devenv.lock | 36 ++++++++++++++++++------------------ devenv.nix | 2 +- src/hugr/views.rs | 14 +++----------- src/hugr/views/tests.rs | 4 ---- 5 files changed, 22 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ea4861062..c633a1222 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,6 @@ petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" delegate = "0.12.0" -rustversion = "1.0.14" paste = "1.0" strum = "0.25.0" strum_macros = "0.25.3" @@ -68,4 +67,3 @@ harness = false [profile.dev.package] insta.opt-level = 3 -similar.opt-level = 3 diff --git a/devenv.lock b/devenv.lock index 3d6e2de22..7750dfcfa 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,11 +3,11 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1700140236, - "narHash": "sha256-OpukFO0rRG2hJzD+pCQq+nSWuT9dBL6DSvADQaUlmFg=", + "lastModified": 1703939110, + "narHash": "sha256-GgjYWkkHQ8pUBwXX++ah+4d07DqOeCDaaQL6Ab86C50=", "owner": "cachix", "repo": "devenv", - "rev": "525d60c44de848a6b2dd468f6efddff078eb2af2", + "rev": "7354096fc026f79645fdac73e9aeea71a09412c3", "type": "github" }, "original": { @@ -25,11 +25,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1700461394, - "narHash": "sha256-lBpjEshdBxeuJwc4+vh4jbO3AmhXbiFrkdWy2pABAAc=", + "lastModified": 1704262971, + "narHash": "sha256-3HB1yaMBBox3z9oXEiQuZzQhXegOc9P3FR6/XNsJGn0=", "owner": "nix-community", "repo": "fenix", - "rev": "5ad1b10123ca40c9d983fb0863403fd97a06c0f8", + "rev": "38aaea4e54dc3874a6355c10861bd8316a6f09f3", "type": "github" }, "original": { @@ -95,11 +95,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1700444282, - "narHash": "sha256-s/+tgT+Iz0LZO+nBvSms+xsMqvHt2LqYniG9r+CYyJc=", + "lastModified": 1704008649, + "narHash": "sha256-rGPSWjXTXTurQN9beuHdyJhB8O761w1Zc5BqSSmHvoM=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "3f21a22b5aafefa1845dec6f4a378a8f53d8681c", + "rev": "d44d59d2b5bd694cd9d996fd8c51d03e3e9ba7f7", "type": "github" }, "original": { @@ -111,11 +111,11 @@ }, "nixpkgs-stable": { "locked": { - "lastModified": 1700403855, - "narHash": "sha256-Q0Uzjik9kUTN9pd/kp52XJi5kletBhy29ctBlAG+III=", + "lastModified": 1704018918, + "narHash": "sha256-erjg/HrpC9liEfm7oLqb8GXCqsxaFwIIPqCsknW5aFY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "0c5678df521e1407884205fe3ce3cf1d7df297db", + "rev": "2c9c58e98243930f8cb70387934daa4bc8b00373", "type": "github" }, "original": { @@ -152,11 +152,11 @@ "nixpkgs-stable": "nixpkgs-stable_2" }, "locked": { - "lastModified": 1700064067, - "narHash": "sha256-1ZWNDzhu8UlVCK7+DUN9dVQfiHX1bv6OQP9VxstY/gs=", + "lastModified": 1703939133, + "narHash": "sha256-Gxe+mfOT6bL7wLC/tuT2F+V+Sb44jNr8YsJ3cyIl4Mo=", "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "e558068cba67b23b4fbc5537173dbb43748a17e8", + "rev": "9d3d7e18c6bc4473d7520200d4ddab12f8402d38", "type": "github" }, "original": { @@ -177,11 +177,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1700247620, - "narHash": "sha256-+Xg0qZLbC9dZx0Z6JbaVHR/BklAr2I83dzKLB8r41c8=", + "lastModified": 1704207973, + "narHash": "sha256-VEWsjIKtdinx5iyhfxuTHRijYBKSbO/8Gw1HPoWD9mQ=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "255eed40c45fcf108ba844b4ad126bdc4e7a18df", + "rev": "426d2842c1f0e5cc5e34bb37c7ac3ee0945f9746", "type": "github" }, "original": { diff --git a/devenv.nix b/devenv.nix index f4a85aca5..bc9fdba1c 100644 --- a/devenv.nix +++ b/devenv.nix @@ -31,7 +31,7 @@ in # https://devenv.sh/languages/ # https://devenv.sh/reference/options/#languagesrustversion languages.rust = { - channel = "beta"; + channel = "stable"; enable = true; components = [ "rustc" "cargo" "clippy" "rustfmt" "rust-analyzer" ]; }; diff --git a/src/hugr/views.rs b/src/hugr/views.rs index b9f798370..7e9e2e3c6 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -25,11 +25,11 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE}; use crate::ops::handle::NodeHandle; use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; -#[rustversion::since(1.75)] // uses impl in return position + use crate::types::Type; use crate::types::{EdgeKind, FunctionType, PolyFuncType}; use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; -#[rustversion::since(1.75)] // uses impl in return position + use itertools::Either; /// A trait for inspecting HUGRs. @@ -183,7 +183,6 @@ pub trait HugrView: sealed::HugrInternals { /// Iterator over the nodes and ports connected to a port. fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node in a given direction. fn all_linked_ports( &self, @@ -205,7 +204,6 @@ pub trait HugrView: sealed::HugrInternals { } } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node's inputs. fn all_linked_outputs(&self, node: Node) -> impl Iterator { self.all_linked_ports(node, Direction::Incoming) @@ -213,7 +211,6 @@ pub trait HugrView: sealed::HugrInternals { .unwrap() } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node's outputs. fn all_linked_inputs(&self, node: Node) -> impl Iterator { self.all_linked_ports(node, Direction::Outgoing) @@ -411,7 +408,6 @@ pub trait HugrView: sealed::HugrInternals { .map(|(n, _)| n) } - #[rustversion::since(1.75)] // uses impl in return position /// If a node has a static output, return the targets. fn static_targets(&self, node: Node) -> Option> { Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?)) @@ -423,7 +419,6 @@ pub trait HugrView: sealed::HugrInternals { self.get_optype(node).dataflow_signature() } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all outgoing ports that have Value type, along /// with corresponding types. fn value_types(&self, node: Node, dir: Direction) -> impl Iterator { @@ -432,7 +427,6 @@ pub trait HugrView: sealed::HugrInternals { .flat_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone()))) } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all incoming ports that have Value type, along /// with corresponding types. fn in_value_types(&self, node: Node) -> impl Iterator { @@ -440,7 +434,6 @@ pub trait HugrView: sealed::HugrInternals { .map(|(p, t)| (p.as_incoming().unwrap(), t)) } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all incoming ports that have Value type, along /// with corresponding types. fn out_value_types(&self, node: Node) -> impl Iterator { @@ -618,7 +611,6 @@ impl> HugrView for T { } } -#[rustversion::since(1.75)] // uses impl in return position /// Trait implementing methods on port iterators. pub trait PortIterator

: Iterator where @@ -636,7 +628,7 @@ where }) } } -#[rustversion::since(1.75)] // uses impl in return position + impl PortIterator

for I where I: Iterator, diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 9a712dc03..ce0353d48 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -69,7 +69,6 @@ fn dot_string(sample_hugr: (Hugr, BuildHandle, BuildHandle, BuildHandle)) { use itertools::Itertools; @@ -97,7 +96,6 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle Date: Wed, 3 Jan 2024 11:12:51 +0000 Subject: [PATCH 38/40] ci: Setup release-plz and related files (#765) - Setups automatic changelog generation on merges to `main` - Updates the README with the appropriate links (assuming the crate is already released) --- .github/workflows/release-plz.yml | 28 ++++++++++++ README.md | 17 +++++-- cliff.toml | 73 +++++++++++++++++++++++++++++++ release-plz.toml | 2 + 4 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/release-plz.yml create mode 100644 cliff.toml create mode 100644 release-plz.toml diff --git a/.github/workflows/release-plz.yml b/.github/workflows/release-plz.yml new file mode 100644 index 000000000..436bb2f4d --- /dev/null +++ b/.github/workflows/release-plz.yml @@ -0,0 +1,28 @@ +name: Release-plz + +permissions: + pull-requests: write + contents: write + +on: + push: + branches: + - main + +jobs: + release-plz: + name: Release-plz + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + - name: Run release-plz + uses: MarcoIeni/release-plz-action@v0.5 + with: + command: release-pr + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index 14adf32ed..4428eac1b 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ quantinuum-hugr =============== [![build_status][]](https://github.com/CQCL/hugr/actions) +[![crates][]](https://crates.io/crates/quantinuum-hugr) [![msrv][]](https://github.com/CQCL/hugr) [![codecov][]](https://codecov.io/gh/CQCL/hugr) @@ -16,15 +17,21 @@ The HUGR specification is [here](specification/hugr.md). ## Usage -Add this to your `Cargo.toml`: +Add the dependency to your project: -```toml -[dependencies] -quantinuum-hugr = "0.1" +```bash +cargo add quantinuum-hugr ``` The library crate is called `hugr`. +Please read the [API documentation here][]. + +## Recent Changes + +See [CHANGELOG][] for a list of changes. The minimum supported rust +version will only change on major releases. + ## Development See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the development environment. @@ -33,7 +40,9 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). + [API documentation here]: https://docs.rs/quantinuum-hugr/ [build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: LICENCE + [CHANGELOG]: CHANGELOG.md diff --git a/cliff.toml b/cliff.toml new file mode 100644 index 000000000..8c3dba1ca --- /dev/null +++ b/cliff.toml @@ -0,0 +1,73 @@ +# git-cliff ~ default configuration file +# https://git-cliff.org/docs/configuration +# +# Lines starting with "#" are comments. +# Configuration options are organized into tables and keys. +# See documentation for more information on available options. + +[changelog] +# changelog header +header = """ +# Changelog\n +""" +# template for the changelog body +# https://tera.netlify.app/docs +body = """ +{% if version %}\ + ## {{ version }} ({{ timestamp | date(format="%Y-%m-%d") }}) +{% else %}\ + ## Unreleased (XXXX-XX-XX) +{% endif %}\ +{% for group, commits in commits | group_by(attribute="group") %} + ### {{ group | upper_first }} + {% for commit in commits %} + - {% if commit.breaking %}[**breaking**] {% endif %}{{ commit.message | upper_first }}\ + {% endfor %} +{% endfor %}\n +""" +# remove the leading and trailing whitespace from the template +trim = true +# changelog footer +footer = "" + +[git] +# parse the commits based on https://www.conventionalcommits.org +conventional_commits = true +# filter out the commits that are not conventional +filter_unconventional = true +# process each line of a commit as an individual commit +split_commits = false +# regex for preprocessing the commit messages +commit_preprocessors = [ + { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/CQCL/portgraph/issues/${2}))"}, # replace issue numbers +] +# regex for parsing and grouping commits +commit_parsers = [ + { message = "^feat", group = "Features" }, + { message = "^fix", group = "Bug Fixes" }, + { message = "^docs", group = "Documentation" }, + { message = "^style", group = "Styling" }, + { message = "^refactor", group = "Refactor" }, + { message = "^perf", group = "Performance" }, + { message = "^test", group = "Testing" }, + { message = "^chore\\(release\\): prepare for", skip = true }, + { message = "^chore", group = "Miscellaneous Tasks", skip = true }, + { message = "^revert", group = "Reverted changes", skip = true }, + { message = "^ci", group = "CI", skip = true }, +] +# protect breaking changes from being skipped due to matching a skipping commit_parser +protect_breaking_commits = true +# filter out the commits that are not matched by commit parsers +filter_commits = false +# glob pattern for matching git tags +tag_pattern = "v[0-9.]*" +# regex for skipping tags +skip_tags = "v0.1.0-beta.1" +# regex for ignoring tags +ignore_tags = "" +# sort the tags topologically +topo_order = false +# sort the commits inside sections by oldest-first/newest-first +sort_commits = "oldest" +# limit the number of commits included in the changelog. +# limit_commits = 42 diff --git a/release-plz.toml b/release-plz.toml new file mode 100644 index 000000000..0d7dd7e3e --- /dev/null +++ b/release-plz.toml @@ -0,0 +1,2 @@ +[workspace] +changelog_config = "cliff.toml" # use a custom git-cliff configuration From 9500803890fd34108747b23bb250f75153b81b46 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:36:01 +0000 Subject: [PATCH 39/40] feat: implement RemoveConst and RemoveConstIgnore (#757) as per spec refactor!: allow Into for builder.add_const BREAKING_CHANGES: existing CustomConst.into() calls will error --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Alan Lawrence Co-authored-by: Luca Mondada <72734770+lmondada@users.noreply.github.com> Co-authored-by: Luca Mondada --- src/hugr/rewrite/consts.rs | 56 +++++++++++++------------------------- 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs index d8a0060b7..6ac32a65a 100644 --- a/src/hugr/rewrite/consts.rs +++ b/src/hugr/rewrite/consts.rs @@ -6,36 +6,31 @@ use crate::{ hugr::{HugrError, HugrMut}, HugrView, Node, }; -#[rustversion::since(1.75)] // uses impl in return position use itertools::Itertools; use thiserror::Error; use super::Rewrite; -/// Remove a [`crate::ops::LoadConstant`] node with no outputs. +/// Remove a [`crate::ops::LoadConstant`] node with no consumers. #[derive(Debug, Clone)] pub struct RemoveConstIgnore(pub Node); -/// Error from an [`RemoveConstIgnore`] operation. +/// Error from an [`RemoveConst`] or [`RemoveConstIgnore`] operation. #[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum RemoveConstIgnoreError { +pub enum RemoveError { /// Invalid node. - #[error("Node is invalid (either not in HUGR or not LoadConst).")] + #[error("Node is invalid (either not in HUGR or not correct operation).")] InvalidNode(Node), /// Node in use. #[error("Node: {0:?} has non-zero outgoing connections.")] ValueUsed(Node), - /// Not connected to a Const. - #[error("Node: {0:?} is not connected to a Const node.")] - NoConst(Node), /// Removal error #[error("Removing node caused error: {0:?}.")] RemoveFail(#[from] HugrError), } -#[rustversion::since(1.75)] // uses impl in return position impl Rewrite for RemoveConstIgnore { - type Error = RemoveConstIgnoreError; + type Error = RemoveError; // The Const node the LoadConstant was connected to. type ApplyResult = Node; @@ -48,14 +43,14 @@ impl Rewrite for RemoveConstIgnore { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { - return Err(RemoveConstIgnoreError::InvalidNode(node)); + return Err(RemoveError::InvalidNode(node)); } if h.out_value_types(node) .next() .is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some()) { - return Err(RemoveConstIgnoreError::ValueUsed(node)); + return Err(RemoveError::ValueUsed(node)); } Ok(()) @@ -67,7 +62,8 @@ impl Rewrite for RemoveConstIgnore { let source = h .input_neighbours(node) .exactly_one() - .map_err(|_| RemoveConstIgnoreError::NoConst(node))?; + .ok() + .expect("Validation should check a Const is connected to LoadConstant."); h.remove_node(node)?; Ok(source) @@ -82,22 +78,8 @@ impl Rewrite for RemoveConstIgnore { #[derive(Debug, Clone)] pub struct RemoveConst(pub Node); -/// Error from an [`RemoveConst`] operation. -#[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum RemoveConstError { - /// Invalid node. - #[error("Node is invalid (either not in HUGR or not Const).")] - InvalidNode(Node), - /// Node in use. - #[error("Node: {0:?} has non-zero outgoing connections.")] - ValueUsed(Node), - /// Removal error - #[error("Removing node caused error: {0:?}.")] - RemoveFail(#[from] HugrError), -} - impl Rewrite for RemoveConst { - type Error = RemoveConstError; + type Error = RemoveError; // The parent of the Const node. type ApplyResult = Node; @@ -110,11 +92,11 @@ impl Rewrite for RemoveConst { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { - return Err(RemoveConstError::InvalidNode(node)); + return Err(RemoveError::InvalidNode(node)); } if h.output_neighbours(node).next().is_some() { - return Err(RemoveConstError::ValueUsed(node)); + return Err(RemoveError::ValueUsed(node)); } Ok(()) @@ -123,12 +105,12 @@ impl Rewrite for RemoveConst { fn apply(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; - let source = h + let parent = h .get_parent(node) .expect("Const node without a parent shouldn't happen."); h.remove_node(node)?; - Ok(source) + Ok(parent) } fn invalidation_set(&self) -> Self::InvalidationSet<'_> { @@ -136,7 +118,6 @@ impl Rewrite for RemoveConst { } } -#[rustversion::since(1.75)] // uses impl in return position #[cfg(test)] mod test { use super::*; @@ -169,17 +150,18 @@ mod test { dfg_build.finish_sub_container()?; let mut h = build.finish_prelude_hugr()?; + // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple assert_eq!(h.node_count(), 8); let tup_node = tup.node(); // can't remove invalid node assert_eq!( h.apply_rewrite(RemoveConst(tup_node)), - Err(RemoveConstError::InvalidNode(tup_node)) + Err(RemoveError::InvalidNode(tup_node)) ); assert_eq!( h.apply_rewrite(RemoveConstIgnore(tup_node)), - Err(RemoveConstIgnoreError::InvalidNode(tup_node)) + Err(RemoveError::InvalidNode(tup_node)) ); let load_1_node = load_1.node(); let load_2_node = load_2.node(); @@ -202,7 +184,7 @@ mod test { // can't remove nodes in use assert_eq!( h.apply_rewrite(remove_1.clone()), - Err(RemoveConstIgnoreError::ValueUsed(load_1_node)) + Err(RemoveError::ValueUsed(load_1_node)) ); // remove the use @@ -215,7 +197,7 @@ mod test { // still can't remove const, in use by second load assert_eq!( h.apply_rewrite(remove_con.clone()), - Err(RemoveConstError::ValueUsed(con_node)) + Err(RemoveError::ValueUsed(con_node)) ); // remove second use From 905ef0148360db85a84fc1ffc284bd5929440590 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:57:37 +0000 Subject: [PATCH 40/40] address minor review comments --- src/algorithm/const_fold.rs | 5 +++-- src/extension/op_def.rs | 1 + src/std_extensions/arithmetic/float_ops.rs | 4 ++-- .../arithmetic/float_ops/{fold.rs => const_fold.rs} | 0 4 files changed, 6 insertions(+), 4 deletions(-) rename src/std_extensions/arithmetic/float_ops/{fold.rs => const_fold.rs} (100%) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 49474c43e..e19de46df 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -63,7 +63,7 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes })); } } - None // could panic + panic!("This op always takes a Tuple input."); } LeafOp::Tag { tag, variants } => out_row([Const::new( @@ -99,7 +99,8 @@ fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { /// [`SimpleReplacement`] replaces an operation with constants that result from /// evaluating it, the extension registry `reg` is used to validate the /// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the -/// LoadConstant nodes that could be removed. +/// LoadConstant nodes that could be removed - they are not automatically +/// removed as they may be used by other operations. pub fn find_consts<'a, 'r: 'a>( hugr: &'a impl HugrView, candidate_nodes: impl IntoIterator + 'a, diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 2ea686ab4..875778e4f 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -309,6 +309,7 @@ pub struct OpDef { #[serde(flatten)] lower_funcs: Vec, + /// Operations can optionally implement [`ConstFold`] to implement constant folding. #[serde(skip)] constant_folder: Box, } diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 672b8a1ed..9b4438340 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; use lazy_static::lazy_static; -mod fold; +mod const_fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); @@ -84,7 +84,7 @@ impl MakeOpDef for FloatOps { } fn post_opdef(&self, def: &mut OpDef) { - fold::set_fold(self, def) + const_fold::set_fold(self, def) } } diff --git a/src/std_extensions/arithmetic/float_ops/fold.rs b/src/std_extensions/arithmetic/float_ops/const_fold.rs similarity index 100% rename from src/std_extensions/arithmetic/float_ops/fold.rs rename to src/std_extensions/arithmetic/float_ops/const_fold.rs