Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: constant folding for arithmetic conversion operations #720

Merged
merged 58 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 57 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
bffed99
wip: constant folding
ss2165 Nov 13, 2023
1a27d54
start moving folding to op_def
ss2165 Nov 20, 2023
b84766b
thread through folding methods
ss2165 Nov 23, 2023
8ee49da
integer addition tests passing
ss2165 Nov 23, 2023
520de7c
remove FoldOutput
ss2165 Nov 24, 2023
1d656d6
Merge branch 'main' into feat/const-fold2
ss2165 Dec 18, 2023
9398d9d
refactor int folding to separate repo
ss2165 Dec 18, 2023
7b955a9
add tuple and sum constant folding
ss2165 Dec 18, 2023
6cb3c62
simplify test code
ss2165 Dec 18, 2023
0500624
wip: fold finder
ss2165 Dec 20, 2023
8f554e0
chore(deps): bump actions/upload-artifact from 3 to 4 (#751)
dependabot[bot] Dec 20, 2023
215eb40
chore(deps): bump dawidd6/action-download-artifact from 2 to 3 (#752)
dependabot[bot] Dec 20, 2023
ff26546
fix: case node should not have an external signature (#749)
ss2165 Dec 20, 2023
64b9199
refactor: move hugr equality check out for reuse
ss2165 Dec 20, 2023
6d7d440
feat: implement RemoveConst and RemoveConstIgnore
ss2165 Dec 21, 2023
cdde503
use remove rewrites while folding
ss2165 Dec 21, 2023
114524c
alllow candidate node specification in find_consts
ss2165 Dec 21, 2023
a087fbc
add exhaustive fold pass
ss2165 Dec 21, 2023
07768b2
refactor!: use enum op traits for floats + conversions
ss2165 Dec 21, 2023
9a81260
Merge branch 'refactor/fops-enum' into feat/const-fold2
ss2165 Dec 21, 2023
658adf4
add folding definitions for float ops
ss2165 Dec 21, 2023
2c0e75b
refactor: ERROR_CUSTOM_TYPE
ss2165 Dec 21, 2023
dc7ff13
refactor: const ConstF64::new
ss2165 Dec 21, 2023
aa73ab2
feat: implement folding for conversion ops
ss2165 Dec 21, 2023
a519f34
fixup! refactor: ERROR_CUSTOM_TYPE
ss2165 Dec 21, 2023
a7a4088
Merge branch 'main' into feat/const-fold2
ss2165 Dec 21, 2023
46075c2
implement bigger tests and fix unearthed bugs
ss2165 Dec 21, 2023
df854e8
Revert "refactor: move hugr equality check out for reuse"
ss2165 Dec 22, 2023
ba81e7b
feat: implement RemoveConst and RemoveConstIgnore
ss2165 Dec 21, 2023
09ce1c9
remove conversion foldin
ss2165 Dec 22, 2023
5a372c7
Merge branch 'main' into feat/const-fold-floats
ss2165 Dec 22, 2023
26bc5ff
add rust version guards
ss2165 Dec 22, 2023
b513ace
Merge branch 'feat/const-rewrites' into feat/const-fold-floats
ss2165 Dec 22, 2023
1348891
Revert "remove conversion foldin"
ss2165 Dec 22, 2023
5a71f75
docs: add public method docstrings
ss2165 Dec 22, 2023
6fa7eb9
add some docstrings and comments
ss2165 Dec 22, 2023
7381432
remove integer folding
ss2165 Dec 22, 2023
3bfda50
Revert "remove integer folding"
ss2165 Dec 22, 2023
0e0411f
remove unused imports
ss2165 Dec 22, 2023
8e88f3e
add docstrings and simplify
ss2165 Dec 22, 2023
dea6085
Merge branch 'feat/const-fold-floats' into feat/const-fold2
ss2165 Dec 22, 2023
48eb430
Merge branch 'feat/fold-ints' into feat/const-fold2
ss2165 Dec 22, 2023
41fa47a
Merge branch 'feat/const-fold-floats' into feat/fold-ints
ss2165 Dec 22, 2023
ccf789e
Merge branch 'feat/fold-ints' into feat/const-fold2
ss2165 Dec 22, 2023
0c060fb
Merge branch 'main' into feat/const-fold-floats
lmondada Jan 2, 2024
4e24c28
Merge branch 'main' into feat/const-fold2
ss2165 Jan 3, 2024
4bca931
docs: Spec clarifications (#738)
cqc-alec Jan 3, 2024
3193cdb
docs: Spec updates (#741)
cqc-alec Jan 3, 2024
d0513c4
docs: [spec] Remove references to causal cone and Order edges from In…
acl-cqc Jan 3, 2024
89f1827
chore: remove rustversion (#764)
ss2165 Jan 3, 2024
4b6123e
ci: Setup release-plz and related files (#765)
aborgna-q Jan 3, 2024
9500803
feat: implement RemoveConst and RemoveConstIgnore (#757)
ss2165 Jan 3, 2024
2c6abc6
Merge branch 'main' into feat/const-fold-floats
ss2165 Jan 3, 2024
905ef01
address minor review comments
ss2165 Jan 3, 2024
a6928e0
Merge branch 'feat/const-fold-floats' into feat/const-fold2
ss2165 Jan 3, 2024
6e36684
remove integer folding
ss2165 Jan 3, 2024
b0c686d
Merge branch 'main' into feat/const-fold2
ss2165 Jan 3, 2024
8f693ac
Update src/std_extensions/arithmetic/conversions.rs
ss2165 Jan 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,24 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
#[cfg(test)]
mod test {

use super::*;
use crate::extension::prelude::sum_with_error;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::std_extensions::arithmetic;

use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};

use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
use rstest::rstest;

use super::*;
/// int to constant
fn i2c(b: u64) -> Const {
Const::new(
ConstIntU::new(5, b).unwrap().into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
}

/// float to constant
fn f2c(f: f64) -> Const {
Expand All @@ -244,19 +253,19 @@ mod test {

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}

#[test]
fn test_big() {
/*
Test hugr approximately calculates
let x = (5.5, 3.25);
x.0 - x.1 == 2.25
Test approximately calculates
let x = (5.6, 3.2);
int(x.0 - x.1) == 2
*/
let sum_type = sum_with_error(INT_TYPES[5].to_owned());
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap();
DFGBuilder::new(FunctionType::new(type_row![], vec![sum_type.clone()])).unwrap();

let tup = build
.add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)]))
.add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)]))
.unwrap();

let unpack = build
Expand All @@ -271,19 +280,31 @@ mod test {
let sub = build
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();
let to_int = build
.add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::int_types::EXTENSION.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
arithmetic::float_ops::EXTENSION.to_owned(),
arithmetic::conversions::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(sub.outputs(), &reg).unwrap();
assert_eq!(h.node_count(), 7);
let mut h = build
.finish_hugr_with_outputs(to_int.outputs(), &reg)
.unwrap();
assert_eq!(h.node_count(), 8);

constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &f2c(2.25));
let expected = Value::Sum {
tag: 0,
value: Box::new(i2c(2).value().clone()),
};
let expected = Const::new(expected, sum_type).unwrap();
assert_fully_folded(&h, &expected);
}
fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
// check the hugr just loads and returns a single const
Expand Down
14 changes: 14 additions & 0 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::{
use super::int_types::int_tv;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

Expand Down Expand Up @@ -63,8 +64,21 @@ impl MakeOpDef for ConvertOpDef {
}
.to_string()
}

fn post_opdef(&self, def: &mut OpDef) {
const_fold::set_fold(self, def)
}
}

impl ConvertOpDef {
/// INitialise a conversion op with an integer log width type argument.
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
pub fn with_width(self, log_width: u8) -> ConvertOpType {
ConvertOpType {
def: self,
log_width: log_width as u64,
}
}
}
/// Concrete convert operation with integer width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
Expand Down
134 changes: 134 additions & 0 deletions src/std_extensions/arithmetic/conversions/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use crate::{
extension::{
prelude::{sum_with_error, ConstError},
ConstFold, ConstFoldResult, OpDef,
},
ops,
std_extensions::arithmetic::{
float_types::ConstF64,
int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES},
},
types::ConstTypeError,
values::{CustomConst, Value},
IncomingPort,
};

use super::ConvertOpDef;

pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) {
use ConvertOpDef::*;

match op {
trunc_u => def.set_constant_folder(TruncU),
trunc_s => def.set_constant_folder(TruncS),
convert_u => def.set_constant_folder(ConvertU),
convert_s => def.set_constant_folder(ConvertS),
}
}

fn get_input<T: CustomConst>(consts: &[(IncomingPort, ops::Const)]) -> Option<&T> {
let [(_, c)] = consts else {
return None;
};
c.get_custom_value()
}

fn fold_trunc(
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
convert: impl Fn(f64, u8) -> Result<Value, ConstTypeError>,
) -> ConstFoldResult {
let f: &ConstF64 = get_input(consts)?;
let f = f.value();
let [arg] = type_args else {
return None;
};
let log_width = get_log_width(arg).ok()?;
let int_type = INT_TYPES[log_width as usize].to_owned();
let sum_type = sum_with_error(int_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Can't truncate non-finite float".to_string(),
};
let sum_val = Value::Sum {
tag: 1,
value: Box::new(err_val.into()),
};

ops::Const::new(sum_val, sum_type.clone()).unwrap()
};
let out_const: ops::Const = if !f.is_finite() {
err_value()
} else {
let cv = convert(f, log_width);
if let Ok(cv) = cv {
let sum_val = Value::Sum {
tag: 0,
value: Box::new(cv),
};

ops::Const::new(sum_val, sum_type).unwrap()
} else {
err_value()
}
};

Some(vec![(0.into(), out_const)])
}

struct TruncU;

impl ConstFold for TruncU {
fn fold(
&self,
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntU::new(log_width, f.trunc() as u64).map(Into::into)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if f is negative? Does as u64 cause a panic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as u64 maps all negative floats to 0, should I add a panic?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this is a can of worms. The spec for trunc_u currently says "Returns an error when the float is non-finite or cannot be exactly stored in N bits". This should probably say "cannot be exactly represented as a u<N>". But what we're doing here is rounding, which is much more forgiving. WASM seems to say that the result is undefined if the number is negative, but rounded otherwise, which makes me think we should change the definition of trunc_{u,s} in the spec.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds like this could be a follow up issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll make one.

})
}
}

struct TruncS;

impl ConstFold for TruncS {
fn fold(
&self,
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntS::new(log_width, f.trunc() as i64).map(Into::into)
})
}
}

struct ConvertU;

impl ConstFold for ConvertU {
fn fold(
&self,
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
let u: &ConstIntU = get_input(consts)?;
let f = u.value() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}

struct ConvertS;

impl ConstFold for ConvertS {
fn fold(
&self,
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
let u: &ConstIntS = get_input(consts)?;
let f = u.value() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}