diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f895bc87c..5514dedee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,13 +56,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: ['1.70', stable, beta, nightly] + rust: ['1.75', stable, beta, nightly] # workaround to ignore non-stable tests when running the merge queue checks # see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6 isMerge: - ${{ github.event_name == 'merge_group' }} exclude: - - rust: '1.70' + - rust: '1.75' isMerge: true - rust: beta isMerge: true diff --git a/Cargo.toml b/Cargo.toml index b1acc0bd4..ea4861062 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation" #categories = [] # TODO edition = "2021" -rust-version = "1.70" +rust-version = "1.75" [lib] # Using different names for the lib and for the package is supported, but may be confusing. @@ -46,7 +46,7 @@ lazy_static = "1.4.0" petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" -delegate = "0.11.0" +delegate = "0.12.0" rustversion = "1.0.14" paste = "1.0" strum = "0.25.0" diff --git a/README.md b/README.md index 91ab2c617..14adf32ed 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,6 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). [build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg + [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: LICENCE diff --git a/src/algorithm.rs b/src/algorithm.rs index 0023b5916..633231504 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,4 +1,5 @@ //! Algorithms using the Hugr. +pub mod const_fold; mod half_node; pub mod nest_cfgs; diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs new file mode 100644 index 000000000..49474c43e --- /dev/null +++ b/src/algorithm/const_fold.rs @@ -0,0 +1,302 @@ +//! Constant folding routines. + +use std::collections::{BTreeSet, HashMap}; + +use itertools::Itertools; + +use crate::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::{ConstFoldResult, ExtensionRegistry}, + hugr::{ + rewrite::consts::{RemoveConst, RemoveConstIgnore}, + views::SiblingSubgraph, + HugrMut, + }, + ops::{Const, LeafOp, OpType}, + type_row, + types::{FunctionType, Type, TypeEnum}, + values::Value, + Hugr, HugrView, IncomingPort, Node, SimpleReplacement, +}; + +/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. +fn out_row(consts: impl IntoIterator) -> ConstFoldResult { + let vec = consts + .into_iter() + .enumerate() + .map(|(i, c)| (i.into(), c)) + .collect(); + Some(vec) +} + +/// Sort folding inputs with [`IncomingPort`] as key +fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> { + let mut v: Vec<_> = consts.iter().collect(); + v.sort_by_key(|(i, _)| i); + v +} + +/// Sort some input constants by port and just return the constants. +pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { + sort_by_in_port(consts) + .into_iter() + .map(|(_, c)| c) + .collect() +} +/// For a given op and consts, attempt to evaluate the op. +pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { + let op = op.as_leaf_op()?; + + match op { + LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]), + LeafOp::MakeTuple { .. } => { + out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())]) + } + LeafOp::UnpackTuple { .. } => { + let c = &consts.first()?.1; + + if let Value::Tuple { vs } = c.value() { + if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() { + return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| { + Const::new(v.clone(), t.clone()) + .expect("types should already have been checked") + })); + } + } + None // could panic + } + + LeafOp::Tag { tag, variants } => out_row([Const::new( + Value::sum(*tag, consts.first()?.1.value().clone()), + Type::new_sum(variants.clone()), + ) + .unwrap()]), + LeafOp::CustomOp(_) => { + let ext_op = op.as_extension_op()?; + + ext_op.constant_fold(consts) + } + _ => None, + } +} + +/// Generate a graph that loads and outputs `consts` in order, validating +/// against `reg`. +fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { + let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); + let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); + + let outputs = consts + .into_iter() + .map(|c| b.add_load_const(c).unwrap()) + .collect_vec(); + + b.finish_hugr_with_outputs(outputs, reg).unwrap() +} + +/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, +/// return an iterator of possible constant folding rewrites. The +/// [`SimpleReplacement`] replaces an operation with constants that result from +/// evaluating it, the extension registry `reg` is used to validate the +/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the +/// LoadConstant nodes that could be removed. +pub fn find_consts<'a, 'r: 'a>( + hugr: &'a impl HugrView, + candidate_nodes: impl IntoIterator + 'a, + reg: &'r ExtensionRegistry, +) -> impl Iterator)> + 'a { + // track nodes for operations that have already been considered for folding + let mut used_neighbours = BTreeSet::new(); + + candidate_nodes + .into_iter() + .filter_map(move |n| { + // only look at LoadConstant + hugr.get_optype(n).is_load_constant().then_some(())?; + + let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; + let neighbours = hugr + .linked_inputs(n, out_p) + .filter(|(n, _)| used_neighbours.insert(*n)) + .collect_vec(); + if neighbours.is_empty() { + // no uses of LoadConstant that haven't already been considered. + return None; + } + let fold_iter = neighbours + .into_iter() + .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); + Some(fold_iter) + }) + .flatten() +} + +/// Attempt to evaluate and generate rewrites for the operation at `op_node` +fn fold_op( + hugr: &impl HugrView, + op_node: Node, + reg: &ExtensionRegistry, +) -> Option<(SimpleReplacement, Vec)> { + let (in_consts, removals): (Vec<_>, Vec<_>) = hugr + .node_inputs(op_node) + .filter_map(|in_p| { + let (con_op, load_n) = get_const(hugr, op_node, in_p)?; + Some(((in_p, con_op), RemoveConstIgnore(load_n))) + }) + .unzip(); + let neighbour_op = hugr.get_optype(op_node); + // attempt to evaluate op + let folded = fold_const(neighbour_op, &in_consts)?; + let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); + let nu_out = op_outs + .into_iter() + .enumerate() + .filter_map(|(i, out)| { + // map from the ports the op was linked to, to the output ports of + // the replacement. + hugr.single_linked_input(op_node, out) + .map(|np| (np, i.into())) + }) + .collect(); + let replacement = const_graph(consts, reg); + let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) + .expect("Operation should form valid subgraph."); + + let simple_replace = SimpleReplacement::new( + sibling_graph, + replacement, + // no inputs to replacement + HashMap::new(), + nu_out, + ); + Some((simple_replace, removals)) +} + +/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant +/// and the LoadConstant node +fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> { + let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; + let load_op = hugr.get_optype(load_n).as_load_constant()?; + let const_node = hugr + .linked_outputs(load_n, load_op.constant_port()) + .exactly_one() + .ok()? + .0; + + let const_op = hugr.get_optype(const_node).as_const()?; + + // TODO avoid const clone here + Some((const_op.clone(), load_n)) +} + +/// Exhaustively apply constant folding to a HUGR. +pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { + loop { + // would be preferable if the candidates were updated to be just the + // neighbouring nodes of those added. + let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); + if rewrites.is_empty() { + break; + } + for (replace, removes) in rewrites { + h.apply_rewrite(replace).unwrap(); + for rem in removes { + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + if h.apply_rewrite(RemoveConst(const_node)).is_err() { + // const cannot be removed - no problem + continue; + } + } + } + } + } +} + +#[cfg(test)] +mod test { + + use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::std_extensions::arithmetic; + + use crate::std_extensions::arithmetic::float_ops::FloatOps; + use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; + + use rstest::rstest; + + use super::*; + + /// float to constant + fn f2c(f: f64) -> Const { + ConstF64::new(f).into() + } + + #[rstest] + #[case(0.0, 0.0, 0.0)] + #[case(0.0, 1.0, 1.0)] + #[case(23.5, 435.5, 459.0)] + // c = a + b + fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; + let add_op: OpType = FloatOps::fadd.into(); + let out = fold_const(&add_op, &consts).unwrap(); + + assert_eq!(&out[..], &[(0.into(), f2c(c))]); + } + + #[test] + fn test_big() { + /* + Test hugr approximately calculates + let x = (5.5, 3.25); + x.0 - x.1 == 2.25 + */ + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); + + let tup = build + .add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)])) + .unwrap(); + + let unpack = build + .add_dataflow_op( + LeafOp::UnpackTuple { + tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], + }, + [tup], + ) + .unwrap(); + + let sub = build + .add_dataflow_op(FloatOps::fsub, unpack.outputs()) + .unwrap(); + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(sub.outputs(), ®).unwrap(); + assert_eq!(h.node_count(), 7); + + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &f2c(2.25)); + } + fn assert_fully_folded(h: &Hugr, expected_const: &Const) { + // check the hugr just loads and returns a single const + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if c == expected_const => node_count += 1, + _ => panic!("unexpected op: {:?}", op), + } + } + + assert_eq!(node_count, 4); + } +} diff --git a/src/extension.rs b/src/extension.rs index 95b0474ea..5519e456b 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -28,9 +28,11 @@ pub use op_def::{ }; mod type_def; pub use type_def::{TypeDef, TypeDefBound}; +mod const_fold; pub mod prelude; pub mod simple_op; pub mod validate; +pub use const_fold::{ConstFold, ConstFoldResult}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; /// Extension Registries store extensions to be looked up e.g. during validation. diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs new file mode 100644 index 000000000..29cf4b5a9 --- /dev/null +++ b/src/extension/const_fold.rs @@ -0,0 +1,53 @@ +use std::fmt::Formatter; + +use std::fmt::Debug; + +use crate::types::TypeArg; + +use crate::OutgoingPort; + +use crate::ops; + +/// Output of constant folding an operation, None indicates folding was either +/// not possible or unsuccessful. An empty vector indicates folding was +/// successful and no values are output. +pub type ConstFoldResult = Option>; + +/// Trait implemented by extension operations that can perform constant folding. +pub trait ConstFold: Send + Sync { + /// Given type arguments `type_args` and + /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, + /// try to evaluate the operation. + fn fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult; +} + +impl Debug for Box { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl Default for Box { + fn default() -> Self { + Box::new(|&_: &_| None) + } +} + +/// Blanket implementation for functions that only require the constants to +/// evaluate - type arguments are not relevant. +impl ConstFold for T +where + T: Fn(&[(crate::IncomingPort, crate::ops::Const)]) -> ConstFoldResult + Send + Sync, +{ + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult { + self(consts) + } +} diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 143426123..2ea686ab4 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -7,7 +7,8 @@ use std::sync::Arc; use smol_str::SmolStr; use super::{ - Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, + ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, + ExtensionSet, SignatureError, }; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; @@ -307,6 +308,9 @@ pub struct OpDef { // can only treat them as opaque/black-box ops. #[serde(flatten)] lower_funcs: Vec, + + #[serde(skip)] + constant_folder: Box, } impl OpDef { @@ -412,6 +416,22 @@ impl OpDef { ) -> Option { self.misc.insert(k.to_string(), v) } + + /// Set the constant folding function for this Op, which can evaluate it + /// given constant inputs. + pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { + self.constant_folder = Box::new(fold) + } + + /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given + /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s. + pub fn constant_fold( + &self, + type_args: &[TypeArg], + consts: &[(crate::IncomingPort, crate::ops::Const)], + ) -> ConstFoldResult { + self.constant_folder.fold(type_args, consts) + } } impl Extension { @@ -432,6 +452,7 @@ impl Extension { signature_func: signature_func.into(), misc: Default::default(), lower_funcs: Default::default(), + constant_folder: Default::default(), }; match self.operations.entry(op.name.clone()) { diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index f96046ba8..b411c3667 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -137,12 +137,11 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { .into() } +/// The custom type for Errors. +pub const ERROR_CUSTOM_TYPE: CustomType = + CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Eq); /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( - ERROR_TYPE_NAME, - PRELUDE_ID, - TypeBound::Eq, -)); +pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); @@ -191,6 +190,48 @@ impl KnownTypeConst for ConstUsize { const TYPE: CustomType = USIZE_CUSTOM_T; } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +/// Structure for holding constant usize values. +pub struct ConstError { + /// Integer tag/signal for the error. + pub signal: u32, + /// Error message. + pub message: String, +} + +impl ConstError { + /// Define a new error value. + pub fn new(signal: u32, message: impl ToString) -> Self { + Self { + signal, + message: message.to_string(), + } + } +} + +#[typetag::serde] +impl CustomConst for ConstError { + fn name(&self) -> SmolStr { + format!("ConstError({:?}, {:?})", self.signal, self.message).into() + } + + fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { + self.check_known_type(typ) + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::values::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } +} + +impl KnownTypeConst for ConstError { + const TYPE: CustomType = ERROR_CUSTOM_TYPE; +} + #[cfg(test)] mod test { use crate::{ @@ -219,7 +260,7 @@ mod test { } #[test] - /// Test building a HUGR involving a new_array operation. + /// test the prelude error type. fn test_error_type() { let ext_def = PRELUDE .get_type(&ERROR_TYPE_NAME) @@ -229,5 +270,18 @@ mod test { let ext_type = Type::new_extension(ext_def); assert_eq!(ext_type, ERROR_TYPE); + + let error_val = ConstError::new(2, "my message"); + + assert_eq!(error_val.name(), "ConstError(2, \"my message\")"); + + assert!(error_val.check_custom_type(&ERROR_CUSTOM_TYPE).is_ok()); + + assert_eq!( + error_val.extension_reqs(), + ExtensionSet::singleton(&PRELUDE_ID) + ); + assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); + assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 9d2c0ab00..d4e4c88a6 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -4,11 +4,11 @@ use smol_str::SmolStr; use std::sync::Arc; use thiserror::Error; -use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; +use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; -use crate::{Hugr, Node}; +use crate::{ops, Hugr, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; use super::tag::OpTag; @@ -137,6 +137,11 @@ impl ExtensionOp { pub fn def(&self) -> &OpDef { self.def.as_ref() } + + /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. + pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { + self.def().constant_fold(self.args(), consts) + } } impl From for OpaqueOp { diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 98e5df887..23b457f7c 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -8,6 +8,7 @@ use crate::{ prelude::sum_with_error, simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, + PRELUDE, }, ops::{custom::ExtensionOp, OpName}, type_row, @@ -18,7 +19,6 @@ use crate::{ use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; use lazy_static::lazy_static; - /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); @@ -69,7 +69,7 @@ impl MakeOpDef for ConvertOpDef { #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { def: ConvertOpDef, - width: u64, + log_width: u64, } impl OpName for ConvertOpType { @@ -85,11 +85,14 @@ impl MakeExtensionOp for ConvertOpType { [TypeArg::BoundedNat { n }] => n, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; - Ok(Self { def, width }) + Ok(Self { + def, + log_width: width, + }) } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat { n: self.width }] + vec![TypeArg::BoundedNat { n: self.log_width }] } } @@ -111,6 +114,7 @@ lazy_static! { /// Registry of extensions required to validate integer operations. pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), super::int_types::EXTENSION.to_owned(), super::float_types::EXTENSION.to_owned(), EXTENSION.to_owned(), diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 87c87751b..672b8a1ed 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; use lazy_static::lazy_static; - +mod fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); @@ -82,6 +82,10 @@ impl MakeOpDef for FloatOps { } .to_string() } + + fn post_opdef(&self, def: &mut OpDef) { + fold::set_fold(self, def) + } } lazy_static! { diff --git a/src/std_extensions/arithmetic/float_ops/fold.rs b/src/std_extensions/arithmetic/float_ops/fold.rs new file mode 100644 index 000000000..34d162f4d --- /dev/null +++ b/src/std_extensions/arithmetic/float_ops/fold.rs @@ -0,0 +1,124 @@ +use crate::{ + algorithm::const_fold::sorted_consts, + extension::{ConstFold, ConstFoldResult, OpDef}, + ops, + std_extensions::arithmetic::float_types::ConstF64, + IncomingPort, +}; + +use super::FloatOps; + +pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) { + use FloatOps::*; + + match op { + fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)), + feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)), + fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)), + } +} + +/// Extract float values from constants in port order. +fn get_floats(consts: &[(IncomingPort, ops::Const)]) -> Option<[f64; N]> { + let consts: [&ops::Const; N] = sorted_consts(consts).try_into().ok()?; + + Some(consts.map(|c| { + let const_f64: &ConstF64 = c + .get_custom_value() + .expect("This function assumes all incoming constants are floats."); + const_f64.value() + })) +} + +/// Fold binary operations +struct BinaryFold(Box f64 + Send + Sync>); +impl BinaryFold { + fn from_op(op: &FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(match op { + fmax => f64::max, + fmin => f64::min, + fadd => std::ops::Add::add, + fsub => std::ops::Sub::sub, + fmul => std::ops::Mul::mul, + fdiv => std::ops::Div::div, + _ => panic!("not binary op"), + })) + } +} +impl ConstFold for BinaryFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1, f2] = get_floats(consts)?; + + let res = ConstF64::new((self.0)(f1, f2)); + Some(vec![(0.into(), res.into())]) + } +} + +/// Fold comparisons. +struct CmpFold(Box bool + Send + Sync>); +impl CmpFold { + fn from_op(op: FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(move |x, y| { + (match op { + feq => f64::eq, + fne => f64::lt, + flt => f64::lt, + fgt => f64::gt, + fle => f64::le, + fge => f64::ge, + _ => panic!("not cmp op"), + })(&x, &y) + })) + } +} + +impl ConstFold for CmpFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1, f2] = get_floats(consts)?; + + let res = if (self.0)(f1, f2) { + ops::Const::true_val() + } else { + ops::Const::false_val() + }; + + Some(vec![(0.into(), res)]) + } +} + +/// Fold unary operations +struct UnaryFold(Box f64 + Send + Sync>); +impl UnaryFold { + fn from_op(op: &FloatOps) -> Self { + use FloatOps::*; + Self(Box::new(match op { + fneg => std::ops::Neg::neg, + fabs => f64::abs, + ffloor => f64::floor, + fceil => f64::ceil, + _ => panic!("not unary op."), + })) + } +} + +impl ConstFold for UnaryFold { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let [f1] = get_floats(consts)?; + let res = ConstF64::new((self.0)(f1)); + Some(vec![(0.into(), res.into())]) + } +} diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 71f91bf87..ba5fe0956 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -40,7 +40,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Create a new [`ConstF64`] - pub fn new(value: f64) -> Self { + pub const fn new(value: f64) -> Self { Self { value } } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index ae5160ffd..e1cea6c49 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -11,6 +11,7 @@ use crate::ops::OpName; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; + use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, types::{type_param::TypeArg, Type, TypeRow}, diff --git a/src/values.rs b/src/values.rs index 17d173a00..46d2778ee 100644 --- a/src/values.rs +++ b/src/values.rs @@ -10,6 +10,7 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; + use crate::{Hugr, HugrView}; use crate::types::{CustomCheckFailure, CustomType};