From 5a80c3890b0bbaed0fb765850ea056b2ff1b8267 Mon Sep 17 00:00:00 2001 From: Graham Date: Wed, 13 Sep 2023 09:34:15 -0400 Subject: [PATCH] new(rs): fully by-hand expression simplification (#273) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * testing: filter out nAn-looking things in proptests * fix: some more rules, big test commented out * fix: use velcro, simplify and add some rules * fix: rename some rules * fix: remove pathological test case; simplify ruleset * fix: typo * new(rs): combine by hand & egg to simplify expressions * feat: remove egg, rely solely on by_hand * fix: unused dep * fix: debug * fix: clippy * fix: un-remove hash utility * wip: still getting syntax errors on the affine case * fix: fewer references and clones * fix: get affine working * fix: π => 3.1415926535897932384626… * fix: various tidyings and comments * doc: justify LIMIT = 1 * fix: respond to PR review comments * fix: some other opportunities for simplification * fix: responding to PR review comments * fix: test coverage * fix: make private something that no longer needs pub(crate) * fix: responding to PR comments * fix: Commentary tweaks per review comments * fix: simplify exponentiations of numbers or π * fix: remove erroneous subtraction rule * fix: add simple double subtraction test to protect against previous bad rule --- Cargo.lock | 87 +- quil-rs/Cargo.toml | 3 +- quil-rs/src/expression/mod.rs | 28 +- quil-rs/src/expression/simplification.rs | 588 ------------ .../src/expression/simplification/by_hand.rs | 868 ++++++++++++++++++ quil-rs/src/expression/simplification/mod.rs | 345 +++++++ quil-rs/src/hash.rs | 12 +- 7 files changed, 1231 insertions(+), 700 deletions(-) delete mode 100644 quil-rs/src/expression/simplification.rs create mode 100644 quil-rs/src/expression/simplification/by_hand.rs create mode 100644 quil-rs/src/expression/simplification/mod.rs diff --git a/Cargo.lock b/Cargo.lock index c0b3eda6..d7284e04 100644 --- 
a/Cargo.lock +++ b/Cargo.lock @@ -2,17 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "ahash" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] - [[package]] name = "aho-corasick" version = "1.0.2" @@ -336,24 +325,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d1b11bd5e7e98406c6ff39fbc94d6e910a489b978ce7f17c19fce91a1195b7a" -[[package]] -name = "egg" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96beaf9d35dbc4686bc86a4ecb851fd6a406f0bf32d9f646b1225a5c5cf5b5d7" -dependencies = [ - "env_logger", - "fxhash", - "hashbrown", - "indexmap", - "instant", - "log", - "smallvec", - "symbol_table", - "symbolic_expressions", - "thiserror", -] - [[package]] name = "either" version = "1.9.0" @@ -366,15 +337,6 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" -[[package]] -name = "env_logger" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" -dependencies = [ - "log", -] - [[package]] name = "errno" version = "0.3.2" @@ -509,15 +471,6 @@ dependencies = [ "slab", ] -[[package]] -name = "fxhash" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" -dependencies = [ - "byteorder", -] - [[package]] name = "getrandom" version = "0.2.10" @@ -546,9 +499,6 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" 
-dependencies = [ - "ahash", -] [[package]] name = "heck" @@ -591,18 +541,6 @@ dependencies = [ "yaml-rust", ] -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", - "js-sys", - "wasm-bindgen", - "web-sys", -] - [[package]] name = "inventory" version = "0.3.11" @@ -1108,7 +1046,7 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quil-py" -version = "0.5.0" +version = "0.5.1-rc.1" dependencies = [ "ndarray", "numpy", @@ -1127,7 +1065,6 @@ dependencies = [ "clap", "criterion", "dot-writer", - "egg", "indexmap", "insta", "itertools 0.11.0", @@ -1484,22 +1421,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "symbol_table" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32bf088d1d7df2b2b6711b06da3471bc86677383c57b27251e18c56df8deac14" -dependencies = [ - "ahash", - "hashbrown", -] - -[[package]] -name = "symbolic_expressions" -version = "5.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71" - [[package]] name = "syn" version = "0.15.44" @@ -1629,12 +1550,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - [[package]] name = "wait-timeout" version = "0.2.0" diff --git a/quil-rs/Cargo.toml b/quil-rs/Cargo.toml index a7c177b0..42d9af6b 100644 --- a/quil-rs/Cargo.toml +++ b/quil-rs/Cargo.toml @@ -12,7 +12,6 @@ categories = ["parser-implementations", "science", "compilers", "emulators"] 
[dependencies] approx = { version = "0.5.1", features = ["num-complex"] } dot-writer = { version = "0.1.2", optional = true } -egg = { version = "0.9.4", features = ["deterministic"] } indexmap = "1.6.1" itertools = "0.11.0" lexical = "6.1.1" @@ -40,7 +39,7 @@ rstest = "0.18.1" # These are described in the crate README.md [features] graphviz-dot = ["dot-writer"] -wasm-bindgen = ["egg/wasm-bindgen"] +wasm-bindgen = [] [[bench]] name = "parser" diff --git a/quil-rs/src/expression/mod.rs b/quil-rs/src/expression/mod.rs index 3a6b5aa4..311832f6 100644 --- a/quil-rs/src/expression/mod.rs +++ b/quil-rs/src/expression/mod.rs @@ -259,11 +259,7 @@ impl Expression { Expression::PiConstant => { *self = Expression::Number(Complex64::from(PI)); } - _ => { - if let Ok(simpler) = simplification::run(self) { - *self = simpler; - } - } + _ => *self = simplification::run(self), } } @@ -666,11 +662,19 @@ impl fmt::Display for InfixOperator { #[allow(clippy::arc_with_non_send_sync)] mod tests { use super::*; - use crate::hash::hash_to_u64; use crate::reserved::ReservedToken; use proptest::prelude::*; + use std::collections::hash_map::DefaultHasher; use std::collections::HashSet; + /// Hash value helper: turn a hashable thing into a u64. 
+ #[inline] + fn hash_to_u64(t: &T) -> u64 { + let mut s = DefaultHasher::new(); + t.hash(&mut s); + s.finish() + } + #[test] fn simplify_and_evaluate() { use Expression::*; @@ -790,7 +794,7 @@ mod tests { // Better behaved than the auto-derived version for names fn arb_name() -> impl Strategy { r"[a-z][a-zA-Z0-9]{1,10}".prop_filter("Exclude reserved tokens", |t| { - ReservedToken::from_str(t).is_err() + ReservedToken::from_str(t).is_err() && !t.to_lowercase().starts_with("nan") }) } @@ -835,12 +839,10 @@ mod tests { right: Box::new(r) }) ), - (any::(), expr).prop_map(|(operator, e)| Prefix( - PrefixExpression { - operator, - expression: Box::new(e) - } - )) + (expr).prop_map(|e| Prefix(PrefixExpression { + operator: PrefixOperator::Minus, + expression: Box::new(e) + })) ] }, ) diff --git a/quil-rs/src/expression/simplification.rs b/quil-rs/src/expression/simplification.rs deleted file mode 100644 index e1e059a1..00000000 --- a/quil-rs/src/expression/simplification.rs +++ /dev/null @@ -1,588 +0,0 @@ -/// Complex machinery for simplifying [`Expression`]s. 
-use crate::{ - expression::{ - format_complex, is_small, Expression, ExpressionFunction, FunctionCallExpression, - InfixExpression, InfixOperator, MemoryReference, PrefixExpression, PrefixOperator, - }, - hash::{hash_f64, hash_to_u64}, - imag, - quil::Quil, - real, -}; -use egg::{define_language, rewrite as rw, Id, Language, RecExpr}; -use once_cell::sync::Lazy; -use std::{ - cmp::Ordering, - f64::consts::PI, - hash::{Hash, Hasher}, - ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, - str::FromStr, -}; - -/// Simplify an [`Expression`]: -/// - turn it into a [`RecExpr`], -/// - let [`egg`] simplify the recursive expression as best as it can, -/// - and turn that back into an [`Expression`] -pub(super) fn run(expression: &Expression) -> Result { - let recexpr = expression_to_recexpr(expression); - let runner = egg::Runner::default().with_expr(&recexpr).run(&(*RULES)); - let root = runner.roots[0]; - let (_, best) = egg::Extractor::new(&runner.egraph, egg::AstSize).find_best(root); - recexpr_to_expression(best) -} - -/// All the myriad ways simplifying an [`Expression`] can fail. -#[derive(Debug, thiserror::Error)] -pub enum SimplificationError { - #[error("Invalid string for a complex number: {0}")] - ComplexParsingError(#[from] num_complex::ParseComplexError), - #[error("Expected a valid index: {0}")] - IndexExpected(#[from] std::num::ParseIntError), - #[error("Invalid string for a memory reference: {0}")] - MemoryReferenceSyntax(#[from] ::Err), -} - -/// An [`egg`]-friendly complex number. -/// We can't just use `num_complex::Complex64`, because we need `Ord` and `Hash`. -/// -/// Fun fact, there is no natural ordering on the complex numbers; however, the implementations -/// here are good enough for our purposes. 
-/// -/// https://en.wikipedia.org/wiki/Complex_number#Ordering -#[derive(Debug, Default, Clone, Copy)] -struct Complex(num_complex::Complex64); - -impl Hash for Complex { - fn hash(&self, state: &mut H) { - // Skip zero values (akin to `format_complex`). - // Also, since f64 isn't hashable, use the u64 binary representation. - // The docs claim this is rather portable: https://doc.rust-lang.org/std/primitive.f64.html#method.to_bits - if self.0.re.abs() > 0f64 { - hash_f64(self.0.re, state) - } - if self.0.im.abs() > 0f64 { - hash_f64(self.0.im, state) - } - } -} - -impl PartialEq for Complex { - // Partial equality by hash value - fn eq(&self, other: &Self) -> bool { - hash_to_u64(self) == hash_to_u64(other) - } -} - -impl Eq for Complex {} - -impl PartialOrd for Complex { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -/// Typical ordering, but with NAN as the biggest value; borrowed the idea from ordered-float -#[inline(always)] -fn _fcmp(x: f64, y: f64) -> Ordering { - if let Some(ordering) = x.partial_cmp(&y) { - ordering - } else { - match (x.is_nan(), y.is_nan()) { - (true, true) => Ordering::Equal, - (false, true) => Ordering::Less, - (true, false) => Ordering::Greater, - (false, false) => unreachable!("These floats should be partially comparable"), - } - } -} - -/// lexicographic ordering with NAN as the biggest value -impl Ord for Complex { - fn cmp(&self, other: &Self) -> Ordering { - match (_fcmp(self.0.re, other.0.re), _fcmp(self.0.im, other.0.im)) { - (Ordering::Less, _) => Ordering::Less, - (Ordering::Greater, _) => Ordering::Greater, - (Ordering::Equal, other) => other, - } - } -} - -impl std::str::FromStr for Complex { - type Err = (); - fn from_str(s: &str) -> Result { - num_complex::Complex64::from_str(s) - .map(Self) - .map_err(|_| ()) - } -} - -impl std::fmt::Display for Complex { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.write_str(&format_complex(&self.0)) - } -} - -macro_rules! 
impl_via_inner { - ($function:ident) => { - fn $function(self) -> Self { - Self(self.0.$function()) - } - }; - ($trait:ident, $function:ident) => { - impl $trait for Complex { - type Output = Self; - fn $function(self) -> Self::Output { - Self(self.0.$function()) - } - } - }; - ($trait:ident, $trait_assign:ident, $function:ident, $function_assign:ident) => { - impl $trait for Complex { - type Output = Self; - fn $function(self, other: Self) -> Self { - Self(self.0.$function(other.0)) - } - } - impl $trait_assign for Complex { - fn $function_assign(&mut self, other: Self) { - *self = Self(self.0.$function(other.0)) - } - } - }; -} - -impl_via_inner!(Neg, neg); -impl_via_inner!(Add, AddAssign, add, add_assign); -impl_via_inner!(Sub, SubAssign, sub, sub_assign); -impl_via_inner!(Mul, MulAssign, mul, mul_assign); -impl_via_inner!(Div, DivAssign, div, div_assign); - -impl Complex { - const ZERO: Self = Self(real!(0.0)); - const PI: Self = Self(real!(std::f64::consts::PI)); - fn close(&self, other: Self) -> bool { - is_small((*self - other).abs()) - } - fn abs(self) -> f64 { - self.0.norm() - } - fn cis(self) -> Self { - // num_complex::Complex64::cis takes a float :-( - Self(self.0.cos() + imag!(1.0) * self.0.sin()) - } - fn pow(self, other: Self) -> Self { - Self(self.0.powc(other.0)) - } - impl_via_inner!(cos); - impl_via_inner!(exp); - impl_via_inner!(sin); - impl_via_inner!(sqrt); -} - -define_language! { - /// An [`egg`]-friendly version of [`Expression`]s, this language allows us to manipulate - /// and simplify terms. 
- enum Expr { - // Numbers - "pi" = Pi, - Number(Complex), - // Functions - "cis" = Cis(Id), - "cos" = Cos(Id), - "exp" = Exp(Id), - "sin" = Sin(Id), - "sqrt" = Sqrt(Id), - // Prefix arithmetic - "pos" = Pos(Id), - "neg" = Neg(Id), - // Infix arithmetic - "^" = Pow([Id; 2]), - "*" = Mul([Id; 2]), - "/" = Div([Id; 2]), - "+" = Add([Id; 2]), - "-" = Sub([Id; 2]), - // Variables and Addresses - Symbol(egg::Symbol), - } -} - -/// Parse the [`Expression`] into a [`RecExpr`] -fn expression_to_recexpr(expression: &Expression) -> RecExpr { - fn helper(e: &Expression, r: &mut RecExpr) -> Id { - match e { - Expression::Address(m) => { - let expr = Expr::Symbol(m.to_quil_or_debug().into()); - r.add(expr) - } - Expression::FunctionCall(FunctionCallExpression { - function, - expression, - }) => { - let id = helper(expression, r); - let expr = match function { - ExpressionFunction::Cis => Expr::Cis(id), - ExpressionFunction::Cosine => Expr::Cos(id), - ExpressionFunction::Exponent => Expr::Exp(id), - ExpressionFunction::Sine => Expr::Sin(id), - ExpressionFunction::SquareRoot => Expr::Sqrt(id), - }; - r.add(expr) - } - Expression::Infix(InfixExpression { - left, - operator, - right, - }) => { - let ids = [helper(left, r), helper(right, r)]; - let expr = match operator { - InfixOperator::Caret => Expr::Pow(ids), - InfixOperator::Plus => Expr::Add(ids), - InfixOperator::Minus => Expr::Sub(ids), - InfixOperator::Slash => Expr::Div(ids), - InfixOperator::Star => Expr::Mul(ids), - }; - r.add(expr) - } - Expression::Number(x) => r.add(Expr::Number(Complex(*x))), - Expression::Prefix(PrefixExpression { - operator, - expression, - }) => { - let id = helper(expression, r); - let expr = match operator { - PrefixOperator::Plus => Expr::Pos(id), - PrefixOperator::Minus => Expr::Neg(id), - }; - r.add(expr) - } - Expression::PiConstant => r.add(Expr::Pi), - Expression::Variable(v) => r.add(Expr::Symbol(format!("%{v}").into())), - } - } - let mut r = RecExpr::default(); - helper(expression, &mut 
r); - r -} - -/// Parse the [`RecExpr`] back into an [`Expression`] -/// -/// This returns a [`Result`] rather than just the expression due to some `FromStr` usage in the -/// very last case. -fn recexpr_to_expression(recexpr: RecExpr) -> Result { - fn helper(nodes: &[Expr], i: usize) -> Result { - match nodes[i] { - Expr::Pi => Ok(Expression::Number(PI.into())), - Expr::Number(x) => Ok(Expression::Number(x.0)), - Expr::Cis(id) => { - let expression = Box::new(helper(nodes, id.into())?); - Ok(Expression::FunctionCall(FunctionCallExpression { - function: ExpressionFunction::Cis, - expression, - })) - } - Expr::Cos(id) => { - let expression = Box::new(helper(nodes, id.into())?); - Ok(Expression::FunctionCall(FunctionCallExpression { - function: ExpressionFunction::Cosine, - expression, - })) - } - Expr::Exp(id) => { - let expression = Box::new(helper(nodes, id.into())?); - Ok(Expression::FunctionCall(FunctionCallExpression { - function: ExpressionFunction::Exponent, - expression, - })) - } - Expr::Sin(id) => { - let expression = Box::new(helper(nodes, id.into())?); - Ok(Expression::FunctionCall(FunctionCallExpression { - function: ExpressionFunction::Sine, - expression, - })) - } - Expr::Sqrt(id) => { - let expression = Box::new(helper(nodes, id.into())?); - Ok(Expression::FunctionCall(FunctionCallExpression { - function: ExpressionFunction::SquareRoot, - expression, - })) - } - Expr::Pos(id) => { - let expression = Box::new(helper(nodes, id.into())?); - Ok(Expression::Prefix(PrefixExpression { - operator: PrefixOperator::Plus, - expression, - })) - } - Expr::Neg(id) => { - let expression = Box::new(helper(nodes, id.into())?); - Ok(Expression::Prefix(PrefixExpression { - operator: PrefixOperator::Minus, - expression, - })) - } - Expr::Pow([left_id, right_id]) => { - let left = Box::new(helper(nodes, left_id.into())?); - let right = Box::new(helper(nodes, right_id.into())?); - Ok(Expression::Infix(InfixExpression { - operator: InfixOperator::Caret, - left, - right, - 
})) - } - Expr::Mul([left_id, right_id]) => { - let left = Box::new(helper(nodes, left_id.into())?); - let right = Box::new(helper(nodes, right_id.into())?); - Ok(Expression::Infix(InfixExpression { - operator: InfixOperator::Star, - left, - right, - })) - } - Expr::Div([left_id, right_id]) => { - let left = Box::new(helper(nodes, left_id.into())?); - let right = Box::new(helper(nodes, right_id.into())?); - Ok(Expression::Infix(InfixExpression { - operator: InfixOperator::Slash, - left, - right, - })) - } - Expr::Add([left_id, right_id]) => { - let left = Box::new(helper(nodes, left_id.into())?); - let right = Box::new(helper(nodes, right_id.into())?); - Ok(Expression::Infix(InfixExpression { - operator: InfixOperator::Plus, - left, - right, - })) - } - Expr::Sub([left_id, right_id]) => { - let left = Box::new(helper(nodes, left_id.into())?); - let right = Box::new(helper(nodes, right_id.into())?); - Ok(Expression::Infix(InfixExpression { - operator: InfixOperator::Minus, - left, - right, - })) - } - Expr::Symbol(sym) => { - let s = sym.to_string(); - match s { - ref x if x.starts_with('%') => Ok(Expression::Variable(s[1..].to_string())), - ref x if x.contains('[') => { - Ok(Expression::Address(MemoryReference::from_str(x)?)) - } - _ => num_complex::Complex64::from_str(&s) - .map(Expression::Number) - .map_err(SimplificationError::ComplexParsingError), - } - } - } - } - let nodes = recexpr.as_ref(); - helper(nodes, nodes.len() - 1) -} - -/// Our analysis will perform arithmetic simplification (largely, constant folding) on our -/// language. -#[derive(Default)] -struct Arithmetic; -type EGraph = egg::EGraph; - -/// Our analysis will perform constant folding on our language. -impl egg::Analysis for Arithmetic { - /// Constant values - type Data = Option; - - /// Pull the (possible) [`Self::Data`] from the given expression. 
- fn make(egraph: &EGraph, enode: &Expr) -> Self::Data { - let x = |id: &Id| egraph[*id].data.as_ref(); - match enode { - Expr::Pi => Some(Complex::PI), - Expr::Number(c) => Some(*c), - Expr::Cis(id) => Some(x(id)?.cis()), - Expr::Cos(id) => Some(x(id)?.cos()), - Expr::Exp(id) => Some(x(id)?.exp()), - Expr::Sin(id) => Some(x(id)?.sin()), - Expr::Sqrt(id) => Some(x(id)?.sqrt()), - Expr::Pos(id) => Some(*x(id)?), - Expr::Neg(id) => Some(-*x(id)?), - Expr::Pow([base, power]) => Some(x(base)?.pow(*x(power)?)), - Expr::Mul([left, right]) => Some(*x(left)? * *x(right)?), - Expr::Div([left, right]) => Some(*x(left)? / *x(right)?), - Expr::Add([left, right]) => Some(*x(left)? + *x(right)?), - Expr::Sub([left, right]) => Some(*x(left)? - *x(right)?), - Expr::Symbol(_) => None, - } - } - - /// Merge two pieces of data with the same value. - fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> egg::DidMerge { - egg::merge_option(to, from, |_a, _b| egg::DidMerge(false, false)) - } - - /// Update the graph to equate and simplify constant values. - fn modify(egraph: &mut EGraph, id: Id) { - if let Some(c) = egraph[id].data { - let value = if c.close(Complex::PI) { - Expr::Pi - } else { - Expr::Number(c) - }; - let added = egraph.add(value); - egraph.union(id, added); - egraph[id].nodes.retain(|n| n.is_leaf()); - } - } -} - -/// Is the variable equivalent to zero in the given circumstances? -fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &egg::Subst) -> bool { - let key = var.parse().unwrap(); - move |egraph, _, subst| { - egraph[subst[key]] - .data - .as_ref() - .map(|value| !value.close(Complex::ZERO)) - .unwrap_or(false) - } -} - -fn is_number(var: &str) -> impl Fn(&mut EGraph, Id, &egg::Subst) -> bool { - let key = var.parse().unwrap(); - move |egraph, _, subst| egraph[subst[key]].data.as_ref().is_some() -} - -/// Rewrite terms of our [`Expr`] language by reducing with our [`Arithmetic`] analysis. 
-type Rewrite = egg::Rewrite; - -/// Instantiate our rewrite rules for simplifying [`Expr`] terms. -static RULES: Lazy> = Lazy::new(|| { - vec![ - // Largely copied from https://github.com/egraphs-good/egg/blob/82c00e970f0bc1fbfe90ce6dc3c3c79ee919c933/tests/math.rs - // and https://github.com/herbie-fp/herbie/blob/2052806f2ffe0d46bc2e151dd096b127c39e12bd/egg-herbie/src/rules.rs - - // addition & subtraction - rw!("add-zero" ; "(+ ?a 0)" => "?a"), - rw!("zero-add" ; "(+ 0 ?a)" => "?a"), - rw!("cancel-sub" ; "(- ?a ?a)" => "0"), - // multiplication & division - rw!("zero-mul" ; "(* 0 ?a)" => "0"), - rw!("one-mul" ; "(* 1 ?a)" => "?a"), - rw!("one-div" ; "(/ ?a 1)" => "?a"), - rw!("cancel-div" ; "(/ ?a ?a)" => "1" if is_not_zero("?a")), - // + & * - rw!("distribute" ; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), - rw!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), - rw!("associate-add" ; "(+ ?a ?b)" => "(+ ?b ?a)"), - rw!("associate-mul" ; "(* ?a ?b)" => "(* ?b ?a)"), - rw!("commute-add" ; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"), - rw!("commute-mul" ; "(* (* ?a ?b) ?c)" => "(* ?a (* ?b ?c))"), - // pow & sqrt - rw!("pow0" ; "(^ ?a 0)" => "1" if is_not_zero("?a")), - rw!("pow1" ; "(^ ?a 1)" => "?a"), - rw!("pow2" ; "(^ ?a 2)" => "(* ?a ?a)"), - rw!("pow2-neg" ; "(^ (neg ?a) 2)" => "(* ?a ?a)"), - rw!("pow2-sqrt" ; "(^ (sqrt ?a) 2)" => "?a"), - rw!("sqrt-pow2" ; "(sqrt (^ ?a 2))" => "?a"), - rw!("pow1/2" ; "(^ ?a 0.5)" => "(sqrt ?a)"), - rw!("div-mul" ; "(/ ?a (* ?b ?a))" => "?b" if is_not_zero("?a")), - rw!("div-mul-2" ; "(/ (* ?b ?a) ?a)" => "?b" if is_not_zero("?a")), - rw!("mul-div" ; "(* ?a (/ ?b ?a))" => "?b" if is_not_zero("?a")), - rw!("mul-div-2" ; "(* (/ ?b ?a) ?a)" => "?b" if is_not_zero("?a")), - rw!("pow-mul" ; "(* (^ ?a ?b) (^ ?a ?c))" => "(^ ?a (+ ?b ?c))"), - rw!("mul-pow" ; "(^ ?a (+ ?b ?c))" => "(* (^ ?a ?b) (^ ?a ?c))"), - // pos and neg - rw!("pos-canon" ; "(pos ?a)" => "?a"), - rw!("sub-neg" ; "(- ?a (neg ?b))" => "(+ ?a ?b)"), 
- rw!("neg-canon" ; "(neg ?a)" => "-?a" if is_number("?a")), - // exp - rw!("exp-zero" ; "(exp 0)" => "1"), - rw!("exp-neg" ; "(exp (neg ?a))" => "(/ 1 (exp ?a))"), - ] -}); - -#[cfg(test)] -mod tests { - use super::*; - - egg::test_fn! { - docstring_example, - &RULES, - "(+ (cos (* 2 pi)) 2)" => "3" - } - - egg::test_fn! { - issue_208_1, - &RULES, - "(* 0 theta)" => "0" - } - - egg::test_fn! { - issue_208_2, - &RULES, - "(/ theta 1)" => "theta" - } - - egg::test_fn! { - issue_208_3, - &RULES, - "(/ (* theta 5) 5)" => "theta" - } - - egg::test_fn! { - memory_ref, - &RULES, - "theta[0]" => "theta[0]" - } - - egg::test_fn! { - var, - &RULES, - "%foo" => "%foo" - } - - egg::test_fn! { - prefix_neg, - &RULES, - "(neg -1)" => "1" - } - - egg::test_fn! { - neg_sub, - &RULES, - "(neg (- 1 2))" => "1" - } - - egg::test_fn! { - neg_imag, - &RULES, - "(neg 9.48e42i)" => "-9.48e42i" - } - - egg::test_fn! { - pow_neg_address, - &RULES, - "(^ (neg 9.48e42i) A[9])" => "(^ -9.48e42i A[9]))" - } - - egg::test_fn! { - fold_constant_mul, - &RULES, - "(* 2 pi)" => "6.283185307179586" - } - - egg::test_fn! { - fold_constant_mul_div, - &RULES, - "(/ (* 2 pi) 6.283185307179586)" => "1" - } - - egg::test_fn! { - fold_constant_mul_div_with_ref, - &RULES, - "(/ (* (* a[0] 2) pi) 6.283185307179586)" => "a[0]" - } -} diff --git a/quil-rs/src/expression/simplification/by_hand.rs b/quil-rs/src/expression/simplification/by_hand.rs new file mode 100644 index 00000000..75e49637 --- /dev/null +++ b/quil-rs/src/expression/simplification/by_hand.rs @@ -0,0 +1,868 @@ +/// Complex machinery for simplifying [`Expression`]s. +use crate::expression::{ + imag, real, Expression, ExpressionFunction, FunctionCallExpression, InfixExpression, + InfixOperator, PrefixExpression, PrefixOperator, +}; +use std::cmp::min_by_key; + +/// Simplify an [`Expression`]. 
+pub(super) fn run(expression: &Expression) -> Expression { + simplify(expression, LIMIT) +} + +/// Keep stack sizes under control +/// +/// Note(@genos): If this limit is allowed to be too large (100, in local testing on my laptop), +/// the recursive nature of `simplify` and friends (below) will build up large callstacks and then +/// crash with an "I've overflowed my stack" error. Except for exceedingly large expressions +/// (`the_big_one` test case in `mod.rs`, for example), a larger limit here doesn't seem to be of +/// practical value in anecdotal testing. +const LIMIT: u64 = 10; + +/// Recursively simplify an [`Expression`] by hand, breaking into cases to make things more +/// manageable. +fn simplify(e: &Expression, limit: u64) -> Expression { + if limit == 0 { + // bail + e.clone() + } else { + match e { + Expression::Address(_) | Expression::Number(_) | Expression::Variable(_) => e.clone(), + Expression::FunctionCall(FunctionCallExpression { + function, + expression, + }) => simplify_function_call(*function, expression, limit - 1), + Expression::PiConstant => Expression::Number(std::f64::consts::PI.into()), + Expression::Infix(InfixExpression { + left, + operator, + right, + }) => simplify_infix(left, *operator, right, limit - 1), + Expression::Prefix(PrefixExpression { + operator, + expression, + }) => simplify_prefix(*operator, expression, limit - 1), + } + } +} + +const PI: num_complex::Complex64 = real!(std::f64::consts::PI); +const ZERO: num_complex::Complex64 = real!(0.0); +const ONE: num_complex::Complex64 = real!(1.0); +const TWO: num_complex::Complex64 = real!(2.0); + +/// Simplify a function call inside an `Expression`, terminating the recursion if `limit` has reached zero. 
+fn simplify_function_call(func: ExpressionFunction, expr: &Expression, limit: u64) -> Expression { + if limit == 0 { + // bail + Expression::FunctionCall(FunctionCallExpression { + function: func, + expression: expr.clone().into(), + }) + } else { + // Evaluate numbers and π + // Pass through otherwise + match (func, simplify(expr, limit - 1)) { + (ExpressionFunction::Cis, Expression::Number(x)) => { + // num_complex::Complex64::cis only accepts f64 + Expression::Number(x.cos() + imag!(1.0) * x.sin()) + } + (ExpressionFunction::Cis, Expression::PiConstant) => Expression::Number(-ONE), + (ExpressionFunction::Cosine, Expression::Number(x)) => Expression::Number(x.cos()), + (ExpressionFunction::Cosine, Expression::PiConstant) => Expression::Number(-ONE), + (ExpressionFunction::Exponent, Expression::Number(x)) => Expression::Number(x.exp()), + (ExpressionFunction::Exponent, Expression::PiConstant) => Expression::Number(PI.exp()), + (ExpressionFunction::Sine, Expression::Number(x)) => Expression::Number(x.sin()), + (ExpressionFunction::Sine, Expression::PiConstant) => Expression::Number(PI.sin()), + (ExpressionFunction::SquareRoot, Expression::Number(x)) => Expression::Number(x.sqrt()), + (ExpressionFunction::SquareRoot, Expression::PiConstant) => { + Expression::Number(PI.sqrt()) + } + (function, expression) => Expression::FunctionCall(FunctionCallExpression { + function, + expression: expression.into(), + }), + } + } +} + +#[inline] +fn is_zero(x: num_complex::Complex64) -> bool { + x.norm() < 1e-10 +} + +#[inline] +fn is_one(x: num_complex::Complex64) -> bool { + is_zero(x - 1.0) +} + +/// Helper: in simplification, we'll bias towards smaller expressions +fn size(expr: &Expression) -> usize { + match expr { + Expression::Address(_) + | Expression::Number(_) + | Expression::PiConstant + | Expression::Variable(_) => 1, + Expression::FunctionCall(FunctionCallExpression { + function: _, + expression, + }) => 1 + size(expression), + Expression::Infix(InfixExpression { + 
left, + operator: _, + right, + }) => 1 + size(left) + size(right), + Expression::Prefix(PrefixExpression { + operator: _, + expression, + }) => 1 + size(expression), + } +} + +// It's verbose to go alone! Take this. +macro_rules! infix { + ($left:expr, $op:expr, $right:expr) => { + Expression::Infix(InfixExpression { + left: $left.into(), + operator: $op, + right: $right.into(), + }) + }; +} +macro_rules! add { + ($left:expr, $right:expr) => { + infix!($left, InfixOperator::Plus, $right) + }; +} +macro_rules! sub { + ($left:expr, $right:expr) => { + infix!($left, InfixOperator::Minus, $right) + }; +} +macro_rules! mul { + ($left:expr, $right:expr) => { + infix!($left, InfixOperator::Star, $right) + }; +} +macro_rules! div { + ($left:expr, $right:expr) => { + infix!($left, InfixOperator::Slash, $right) + }; +} + +/// Check if both arguments are of the form "something * x" for the _same_ x. +fn mul_matches(left_ax: &Expression, right_ax: &Expression) -> bool { + match (left_ax, right_ax) { + ( + Expression::Infix(InfixExpression { + left: ref ll, + operator: InfixOperator::Star, + right: ref lr, + }), + Expression::Infix(InfixExpression { + left: ref rl, + operator: InfixOperator::Star, + right: ref rr, + }), + ) => ll == rl || ll == rr || lr == rl || lr == rr, + _ => false, + } +} + +/// Simplify an infix expression inside an `Expression`, terminating the recursion if `limit` has reached +/// zero. 
+fn simplify_infix(l: &Expression, op: InfixOperator, r: &Expression, limit: u64) -> Expression {
+    if limit == 0 {
+        // Recursion budget exhausted: rebuild the expression as-is, without simplifying
+        Expression::Infix(InfixExpression {
+            left: l.clone().into(),
+            operator: op,
+            right: r.clone().into(),
+        })
+    } else {
+        // Simplify both operands first; arms are tried top to bottom, so their order matters
+        match (simplify(l, limit - 1), op, simplify(r, limit - 1)) {
+            //----------------------------------------------------------------
+            // First: only diving one deep, pattern matching on the operation
+            // (Constant folding and cancellations, mostly)
+            //----------------------------------------------------------------
+
+            // Addition and Subtraction
+
+            // Adding with zero
+            (Expression::Number(x), InfixOperator::Plus, other)
+            | (other, InfixOperator::Plus, Expression::Number(x))
+                if is_zero(x) =>
+            {
+                other
+            }
+            // Adding numbers or π
+            (Expression::Number(x), InfixOperator::Plus, Expression::Number(y)) => {
+                Expression::Number(x + y)
+            }
+            (Expression::Number(x), InfixOperator::Plus, Expression::PiConstant)
+            | (Expression::PiConstant, InfixOperator::Plus, Expression::Number(x)) => {
+                Expression::Number(PI + x)
+            }
+            (Expression::PiConstant, InfixOperator::Plus, Expression::PiConstant) => {
+                Expression::Number(2.0 * PI)
+            }
+
+            // Subtracting with zero
+            (Expression::Number(x), InfixOperator::Minus, right) if is_zero(x) => {
+                simplify_prefix(PrefixOperator::Minus, &right, limit - 1)
+            }
+            (left, InfixOperator::Minus, Expression::Number(y)) if is_zero(y) => left,
+            // Subtracting self
+            (left, InfixOperator::Minus, right) if left == right => Expression::Number(ZERO),
+            // Subtracting numbers or π (π - π already covered)
+            (Expression::Number(x), InfixOperator::Minus, Expression::Number(y)) => {
+                Expression::Number(x - y)
+            }
+            (Expression::Number(x), InfixOperator::Minus, Expression::PiConstant) => {
+                Expression::Number(x - PI)
+            }
+            (Expression::PiConstant, InfixOperator::Minus, Expression::Number(y)) => {
+                Expression::Number(PI - y)
+            }
+
+            // Multiplication and Division
+
+            // Multiplication with zero
+            (Expression::Number(x), InfixOperator::Star, _)
+            | (_, InfixOperator::Star, Expression::Number(x))
+                if is_zero(x) =>
+            {
+                Expression::Number(ZERO)
+            }
+            // Multiplication with one
+            (Expression::Number(x), InfixOperator::Star, other)
+            | (other, InfixOperator::Star, Expression::Number(x))
+                if is_one(x) =>
+            {
+                other
+            }
+            // Multiplying with numbers or π
+            (Expression::Number(x), InfixOperator::Star, Expression::Number(y)) => {
+                Expression::Number(x * y)
+            }
+            (Expression::Number(x), InfixOperator::Star, Expression::PiConstant)
+            | (Expression::PiConstant, InfixOperator::Star, Expression::Number(x)) => {
+                Expression::Number(PI * x)
+            }
+            (Expression::PiConstant, InfixOperator::Star, Expression::PiConstant) => {
+                Expression::Number(PI * PI)
+            }
+
+            // Division with zero (0 / _ is matched first, so 0 / 0 folds to 0, not NaN)
+            (Expression::Number(x), InfixOperator::Slash, _) if is_zero(x) => {
+                Expression::Number(ZERO)
+            }
+            (_, InfixOperator::Slash, Expression::Number(y)) if is_zero(y) => {
+                Expression::Number(real!(f64::NAN))
+            }
+            // Division with one
+            (left, InfixOperator::Slash, Expression::Number(y)) if is_one(y) => left,
+            // Division with self
+            (left, InfixOperator::Slash, right) if left == right => Expression::Number(ONE),
+            // Division with numbers or π (π / π already covered)
+            (Expression::Number(x), InfixOperator::Slash, Expression::Number(y)) => {
+                Expression::Number(x / y)
+            }
+            (Expression::Number(x), InfixOperator::Slash, Expression::PiConstant) => {
+                Expression::Number(x / PI)
+            }
+            (Expression::PiConstant, InfixOperator::Slash, Expression::Number(y)) => {
+                Expression::Number(PI / y)
+            }
+
+            // Exponentiation
+
+            // Exponentiation with zero (0 ^ _ is matched first, so 0 ^ 0 folds to 0, not 1)
+            (Expression::Number(x), InfixOperator::Caret, _) if is_zero(x) => {
+                Expression::Number(ZERO)
+            }
+            (_, InfixOperator::Caret, Expression::Number(y)) if is_zero(y) => {
+                Expression::Number(ONE)
+            }
+            // Exponentiation with one
+            (Expression::Number(x), InfixOperator::Caret, _) if is_one(x) => {
+                Expression::Number(ONE)
+            }
+            (left, InfixOperator::Caret, Expression::Number(y)) if is_one(y) => left,
+            // Exponentiation with numbers or π
+            (Expression::Number(x), InfixOperator::Caret, Expression::Number(y)) => {
+                Expression::Number(x.powc(y))
+            }
+            (Expression::Number(x), InfixOperator::Caret, Expression::PiConstant) => {
+                Expression::Number(x.powc(PI))
+            }
+            (Expression::PiConstant, InfixOperator::Caret, Expression::Number(y)) => {
+                Expression::Number(PI.powc(y))
+            }
+            (Expression::PiConstant, InfixOperator::Caret, Expression::PiConstant) => {
+                Expression::Number(PI.powc(PI))
+            }
+
+            //----------------------------------------------------------------
+            // Second: dealing with negation in subexpressions
+            //----------------------------------------------------------------
+
+            // Addition with negation
+            // a + (-b) = (-b) + a = a - b
+            (
+                ref other,
+                InfixOperator::Plus,
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+            )
+            | (
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+                InfixOperator::Plus,
+                ref other,
+            ) => simplify_infix(other, InfixOperator::Minus, expression, limit - 1),
+
+            // Subtraction with negation
+
+            // a - (-b) = a + b
+            (
+                ref left,
+                InfixOperator::Minus,
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+            ) => simplify_infix(left, InfixOperator::Plus, expression, limit - 1),
+
+            // -expression - right = smaller of [(-expression) - right, -(expression + right)]
+            (
+                ref left @ Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+                InfixOperator::Minus,
+                ref right,
+            ) => {
+                let original = sub!(left.clone(), right.clone());
+                let new = simplify_prefix(
+                    PrefixOperator::Minus,
+                    &simplify_infix(expression, InfixOperator::Plus, right, limit - 1),
+                    limit - 1,
+                );
+                min_by_key(original, new, size)
+            }
+
+            // Multiplication with negation
+
+            // Double negative: (-a) * (-b) = a * b
+            (
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    expression: ref a,
+                }),
+                InfixOperator::Star,
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    expression: ref b,
+                }),
+            ) => simplify_infix(a, InfixOperator::Star, b, limit - 1),
+
+            // a * (-b) = (-a) * b, pick the shorter
+            (
+                ref left,
+                InfixOperator::Star,
+                ref right @ Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+            ) => {
+                let original = mul!(left.clone(), right.clone());
+                let neg_left = simplify_prefix(PrefixOperator::Minus, left, limit - 1);
+                let new = simplify_infix(&neg_left, InfixOperator::Star, expression, limit - 1);
+                min_by_key(original, new, size)
+            }
+            // (-a) * b = a * (-b), pick the shorter
+            (
+                ref left @ Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+                InfixOperator::Star,
+                ref right,
+            ) => {
+                let original = mul!(left.clone(), right.clone());
+                let neg_right = simplify_prefix(PrefixOperator::Minus, right, limit - 1);
+                let new = simplify_infix(expression, InfixOperator::Star, &neg_right, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Division with negation
+
+            // Double negative: (-a) / (-b) = a / b
+            (
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    expression: ref a,
+                }),
+                InfixOperator::Slash,
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    expression: ref b,
+                }),
+            ) => simplify_infix(a, InfixOperator::Slash, b, limit - 1),
+
+            // (-a) / a = a / (-a) = -1
+            (
+                ref other,
+                InfixOperator::Slash,
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+            )
+            | (
+                Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+                InfixOperator::Slash,
+                ref other,
+            ) if *other == **expression => Expression::Number(-ONE),
+
+            // a / (-b) = (-a) / b, pick the shorter
+            (
+                ref left,
+                InfixOperator::Slash,
+                ref right @ Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+            ) => {
+                let original = div!(left.clone(), right.clone());
+                let neg_left = simplify_prefix(PrefixOperator::Minus, left, limit - 1);
+                let new = simplify_infix(&neg_left, InfixOperator::Slash, expression, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // (-a) / b = a / (-b), pick the shorter
+            (
+                ref left @ Expression::Prefix(PrefixExpression {
+                    operator: PrefixOperator::Minus,
+                    ref expression,
+                }),
+                InfixOperator::Slash,
+                ref right,
+            ) => {
+                let original = div!(left.clone(), right.clone());
+                let neg_right = simplify_prefix(PrefixOperator::Minus, right, limit - 1);
+                let new = simplify_infix(expression, InfixOperator::Slash, &neg_right, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            //----------------------------------------------------------------
+            // Third: Affine relationships
+            //----------------------------------------------------------------
+
+            // (a1 * x + b1) + (a2 * x + b2) = (a1 + a2) * x + (b1 + b2)
+            //
+            // Apologies for this one; I couldn't get the compiler to let me match two levels deep in a
+            // recursive data type, and `if let` in a match guard isn't stabilized.
+            (
+                Expression::Infix(InfixExpression {
+                    left: ref left_ax,
+                    operator: InfixOperator::Plus,
+                    right: ref left_b,
+                }),
+                InfixOperator::Plus,
+                Expression::Infix(InfixExpression {
+                    left: ref right_ax,
+                    operator: InfixOperator::Plus,
+                    right: ref right_b,
+                }),
+            ) if mul_matches(left_ax, right_ax) => {
+                let &Expression::Infix(InfixExpression {
+                    left: ref ll,
+                    operator: InfixOperator::Star,
+                    right: ref lr,
+                }) = &**left_ax
+                else {
+                    unreachable!("This is handled by mul_matches")
+                };
+                let &Expression::Infix(InfixExpression {
+                    left: ref rl,
+                    operator: InfixOperator::Star,
+                    right: ref rr,
+                }) = &**right_ax
+                else {
+                    unreachable!("This is handled by mul_matches")
+                };
+                let (left_a, right_a, x) = if **ll == **rl {
+                    (lr, rr, ll)
+                } else if **ll == **rr {
+                    (lr, rl, ll)
+                } else if **lr == **rl {
+                    (ll, rr, lr)
+                } else {
+                    (ll, rl, rr)
+                };
+                simplify_infix(
+                    &simplify_infix(
+                        &simplify_infix(left_a, InfixOperator::Plus, right_a, limit - 1),
+                        InfixOperator::Star,
+                        x,
+                        limit - 1,
+                    ),
+                    InfixOperator::Plus,
+                    &simplify_infix(left_b, InfixOperator::Plus, right_b, limit - 1),
+                    limit - 1,
+                )
+            }
+
+            // (a1 * x) + (a2 * x) = (a1 + a2) * x
+            (
+                Expression::Infix(InfixExpression {
+                    left: ref left_a,
+                    operator: InfixOperator::Star,
+                    right: ref left_x,
+                }),
+                InfixOperator::Plus,
+                Expression::Infix(InfixExpression {
+                    left: ref right_a,
+                    operator: InfixOperator::Star,
+                    right: ref right_x,
+                }),
+            ) if left_x == right_x => simplify_infix(
+                &simplify_infix(left_a, InfixOperator::Plus, right_a, limit - 1),
+                InfixOperator::Star,
+                left_x,
+                limit - 1,
+            ),
+
+            // (x + b1) + (x + b2) = 2 * x + (b1 + b2)
+            (
+                Expression::Infix(InfixExpression {
+                    left: ref left_x,
+                    operator: InfixOperator::Plus,
+                    right: ref left_b,
+                }),
+                InfixOperator::Plus,
+                Expression::Infix(InfixExpression {
+                    left: ref right_x,
+                    operator: InfixOperator::Plus,
+                    right: ref right_b,
+                }),
+            ) if left_x == right_x => simplify_infix(
+                &simplify_infix(
+                    &Expression::Number(TWO),
+                    InfixOperator::Star,
+                    left_x,
+                    limit - 1,
+                ),
+                InfixOperator::Plus,
+                &simplify_infix(left_b, InfixOperator::Plus, right_b, limit - 1),
+                limit - 1,
+            ),
+
+            //----------------------------------------------------------------
+            // Fourth: commutation, association, distribution
+            //----------------------------------------------------------------
+
+            // Addition associative, right: a + (b + c) = (a + b) + c, pick the shorter
+            (
+                ref a,
+                InfixOperator::Plus,
+                ref right @ Expression::Infix(InfixExpression {
+                    left: ref b,
+                    operator: InfixOperator::Plus,
+                    right: ref c,
+                }),
+            ) => {
+                let original = add!(a.clone(), right.clone());
+                let new_ab = simplify_infix(a, InfixOperator::Plus, b, limit - 1);
+                let new = simplify_infix(&new_ab, InfixOperator::Plus, c, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Addition associative, left: (a + b) + c = a + (b + c), pick the shorter
+            (
+                ref left @ Expression::Infix(InfixExpression {
+                    left: ref a,
+                    operator: InfixOperator::Plus,
+                    right: ref b,
+                }),
+                InfixOperator::Plus,
+                ref c,
+            ) => {
+                let original = add!(left.clone(), c.clone());
+                let bc = simplify_infix(b, InfixOperator::Plus, c, limit - 1);
+                let new = simplify_infix(a, InfixOperator::Plus, &bc, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Multiplication associative, right: a * (b * c) = (a * b) * c, pick the shorter
+            (
+                ref a,
+                InfixOperator::Star,
+                ref right @ Expression::Infix(InfixExpression {
+                    left: ref b,
+                    operator: InfixOperator::Star,
+                    right: ref c,
+                }),
+            ) => {
+                let original = mul!(a.clone(), right.clone());
+                let ab = simplify_infix(a, InfixOperator::Star, b, limit - 1);
+                let new = simplify_infix(&ab, InfixOperator::Star, c, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Multiplication associative, left: (a * b) * c = a * (b * c), pick the shorter
+            (
+                ref left @ Expression::Infix(InfixExpression {
+                    left: ref a,
+                    operator: InfixOperator::Star,
+                    right: ref b,
+                }),
+                InfixOperator::Star,
+                ref c,
+            ) => {
+                let original = mul!(left.clone(), c.clone());
+                let bc = simplify_infix(b, InfixOperator::Star, c, limit - 1);
+                let new = simplify_infix(a, InfixOperator::Star, &bc, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Subtraction "associative" (not really), right: a - (b - c) = (a + c) - b
+            (
+                ref a,
+                InfixOperator::Minus,
+                ref right @ Expression::Infix(InfixExpression {
+                    left: ref b,
+                    operator: InfixOperator::Minus,
+                    right: ref c,
+                }),
+            ) => {
+                let original = sub!(a.clone(), right.clone());
+                let ac = simplify_infix(a, InfixOperator::Plus, c, limit - 1);
+                let new = simplify_infix(&ac, InfixOperator::Minus, b, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Division "associative" (not really), right: a / (b / c) = (a * c) / b
+            (
+                ref a,
+                InfixOperator::Slash,
+                ref right @ Expression::Infix(InfixExpression {
+                    left: ref b,
+                    operator: InfixOperator::Slash,
+                    right: ref c,
+                }),
+            ) => {
+                let original = div!(a.clone(), right.clone());
+                let ac = simplify_infix(a, InfixOperator::Star, c, limit - 1);
+                let new = simplify_infix(&ac, InfixOperator::Slash, b, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Division "associative" (not really), left: (a / b) / c = a / (b * c)
+            (
+                ref left @ Expression::Infix(InfixExpression {
+                    left: ref a,
+                    operator: InfixOperator::Slash,
+                    right: ref b,
+                }),
+                InfixOperator::Slash,
+                ref c,
+            ) => {
+                let original = div!(left.clone(), c.clone());
+                let bc = simplify_infix(b, InfixOperator::Star, c, limit - 1);
+                let new = simplify_infix(a, InfixOperator::Slash, &bc, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Right distribution: a * (b + c) = (a * b) + (a * c), pick the shorter
+            (
+                ref a,
+                InfixOperator::Star,
+                ref right @ Expression::Infix(InfixExpression {
+                    left: ref b,
+                    operator: InfixOperator::Plus,
+                    right: ref c,
+                }),
+            ) => {
+                let original = mul!(a.clone(), right.clone());
+                let ab = simplify_infix(a, InfixOperator::Star, b, limit - 1);
+                let ac = simplify_infix(a, InfixOperator::Star, c, limit - 1);
+                let new = simplify_infix(&ab, InfixOperator::Plus, &ac, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            // Left distribution: (a + b) * c = (a * c) + (b * c), pick the shorter
+            (
+                ref left @ Expression::Infix(InfixExpression {
+                    left: ref a,
+                    operator: InfixOperator::Plus,
+                    right: ref b,
+                }),
+                InfixOperator::Star,
+                ref c,
+            ) => {
+                let original = mul!(left.clone(), c.clone());
+                let ac = simplify_infix(a, InfixOperator::Star, c, limit - 1);
+                let bc = simplify_infix(b, InfixOperator::Star, c, limit - 1);
+                let new = simplify_infix(&ac, InfixOperator::Plus, &bc, limit - 1);
+                min_by_key(original, new, size)
+            }
+
+            //----------------------------------------------------------------
+            // Fifth: other parenthesis manipulation
+            //----------------------------------------------------------------
+
+            // Mul inside Div on left with cancellation: (a * b) / a = (b * a) / a = b
+            (
+                Expression::Infix(InfixExpression {
+                    left: ref same_1,
+                    operator: InfixOperator::Star,
+                    right: ref other,
+                }),
+                InfixOperator::Slash,
+                ref same_2,
+            )
+            | (
+                Expression::Infix(InfixExpression {
+                    left: ref other,
+                    operator: InfixOperator::Star,
+                    right: ref same_1,
+                }),
+                InfixOperator::Slash,
+                ref same_2,
+            ) if **same_1 == *same_2 => simplify(other, limit - 1),
+
+            // Mul inside Div on right with cancellation: a / (a * b) = a / (b * a) = 1 / b
+            (
+                ref same_1,
+                InfixOperator::Slash,
+                Expression::Infix(InfixExpression {
+                    left: ref same_2,
+                    operator: InfixOperator::Star,
+                    right: ref other,
+                }),
+            )
+            | (
+                ref same_1,
+                InfixOperator::Slash,
+                Expression::Infix(InfixExpression {
+                    left: ref other,
+                    operator: InfixOperator::Star,
+                    right: ref same_2,
+                }),
+            ) if *same_1 == **same_2 => simplify_infix(
+                &Expression::Number(ONE),
+                InfixOperator::Slash,
+                other,
+                limit - 1,
+            ),
+
+            // Mul inside Div on left: (a * b) / c = a * (b / c), pick the shorter
+            (
+                ref numerator @ Expression::Infix(InfixExpression {
+                    left: ref multiplier,
+                    operator: InfixOperator::Star,
+                    right: ref multiplicand,
+                }),
+                InfixOperator::Slash,
+                ref denominator,
+            ) => {
+                let original = div!(numerator.clone(), denominator.clone());
+                let new_multiplicand =
+                    simplify_infix(multiplicand, InfixOperator::Slash, denominator, limit - 1);
+                let new = simplify_infix(
+                    multiplier,
+                    InfixOperator::Star,
+                    &new_multiplicand,
+                    limit - 1,
+                );
+                min_by_key(original, new, size)
+            }
+
+            // Mul inside Div on right: a / (b * c) = (a / b) / c, pick the shorter
+            (
+                ref numerator,
+                InfixOperator::Slash,
+                ref denominator @ Expression::Infix(InfixExpression {
+                    left: ref multiplier,
+                    operator: InfixOperator::Star,
+                    right: ref multiplicand,
+                }),
+            ) => {
+                let original = div!(numerator.clone(), denominator.clone());
+                let partial_quotient =
+                    simplify_infix(numerator, InfixOperator::Slash, multiplier, limit - 1);
+                let new = simplify_infix(
+                    &partial_quotient,
+                    InfixOperator::Slash,
+                    multiplicand,
+                    limit - 1,
+                );
+                min_by_key(original, new, size)
+            }
+
+            // Div inside Mul with cancellation: (a / b) * b = b * (a / b) = a
+            (
+                Expression::Infix(InfixExpression {
+                    left: ref other,
+                    operator: InfixOperator::Slash,
+                    right: ref same_1,
+                }),
+                InfixOperator::Star,
+                ref same_2,
+            )
+            | (
+                ref same_2,
+                InfixOperator::Star,
+                Expression::Infix(InfixExpression {
+                    left: ref other,
+                    operator: InfixOperator::Slash,
+                    right: ref same_1,
+                }),
+            ) if **same_1 == *same_2 => simplify(other, limit - 1),
+
+            //----------------------------------------------------------------
+            // Sixth: catch-all if no other patterns match
+            //----------------------------------------------------------------
+            (left, operator, right) => Expression::Infix(InfixExpression {
+                left: left.into(),
+                operator,
+                right: right.into(),
+            }),
+        }
+    }
+}
+
+/// Simplify a prefix expression inside an `Expression`, terminating the recursion if `limit` has reached zero.
+fn simplify_prefix(op: PrefixOperator, expr: &Expression, limit: u64) -> Expression {
+    if limit == 0 {
+        // Recursion budget exhausted: rebuild the prefix expression without simplifying
+        Expression::Prefix(PrefixExpression {
+            operator: op,
+            expression: expr.clone().into(),
+        })
+    } else {
+        // Unary plus is dropped
+        // Unary minus is folded into numbers and π
+        // Anything else is rebuilt around the simplified operand
+        match (op, simplify(expr, limit - 1)) {
+            (PrefixOperator::Plus, expression) => expression,
+            (PrefixOperator::Minus, Expression::Number(x)) => Expression::Number(-x),
+            (PrefixOperator::Minus, Expression::PiConstant) => Expression::Number(-PI),
+            (operator, expression) => Expression::Prefix(PrefixExpression {
+                operator,
+                expression: expression.into(),
+            }),
+        }
+    }
+}
diff --git a/quil-rs/src/expression/simplification/mod.rs b/quil-rs/src/expression/simplification/mod.rs
new file mode 100644
index 00000000..205a8494
--- /dev/null
+++ b/quil-rs/src/expression/simplification/mod.rs
@@ -0,0 +1,345 @@
+use crate::expression::Expression;
+
+mod by_hand;
+
+/// Simplify an [`Expression`].
+pub(super) fn run(expression: &Expression) -> Expression {
+    by_hand::run(expression)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::str::FromStr;
+
+    macro_rules! test_simplify {
+        ($name:ident, $input:expr, $expected:expr) => {
+            #[test]
+            fn $name() {
+                let parsed_input = Expression::from_str($input)
+                    .unwrap_or_else(|error| panic!("Parsing input `{}` failed: {error}", $input));
+                let parsed_expected = Expression::from_str($expected).unwrap_or_else(|error| {
+                    panic!(
+                        "Parsing expected expression `{}` failed: {error}",
+                        $expected
+                    )
+                });
+                let computed = run(&parsed_input);
+                assert_eq!(parsed_expected, computed);
+            }
+        };
+    }
+
+    test_simplify! {
+        function_cis,
+        "cis(0)",
+        "1"
+    }
+
+    test_simplify! {
+        function_cos,
+        "cos(0)",
+        "1"
+    }
+
+    test_simplify! {
+        function_exp,
+        "exp(1)",
+        "2.718281828459045"
+    }
+
+    test_simplify! {
+        function_sin,
+        "sin(0)",
+        "0"
+    }
+
+    test_simplify! {
+        function_sqrt,
+        "sqrt(9)",
+        "3"
+    }
+
+    test_simplify! {
+        infix_add_0_r,
+        "x + 0",
+        "x"
+    }
+
+    test_simplify! {
+        infix_add_0_l,
+        "0 + x",
+        "x"
+    }
+
+    test_simplify! {
+        infix_add,
+        "1 + 2",
+        "3"
+    }
+
+    test_simplify! {
+        infix_sub_0_r,
+        "x - 0",
+        "x"
+    }
+
+    test_simplify! {
+        infix_sub_self,
+        "x - x",
+        "0"
+    }
+
+    test_simplify! {
+        infix_mul_0_r,
+        "x * 0",
+        "0"
+    }
+
+    test_simplify! {
+        infix_mul_0_l,
+        "0 * x",
+        "0"
+    }
+
+    test_simplify! {
+        infix_mul_1_r,
+        "x * 1",
+        "x"
+    }
+
+    test_simplify! {
+        infix_mul_1_l,
+        "1 * x",
+        "x"
+    }
+
+    test_simplify! {
+        infix_div_0_l,
+        "0 / x",
+        "0"
+    }
+
+    test_simplify! {
+        infix_div_1_r,
+        "x / 1",
+        "x"
+    }
+
+    test_simplify! {
+        infix_div_self,
+        "x / x",
+        "1"
+    }
+
+    test_simplify! {
+        infix_exp_0_l,
+        "0^x",
+        "0"
+    }
+
+    test_simplify! {
+        infix_exp_0_r,
+        "x^0",
+        "1"
+    }
+
+    test_simplify! {
+        infix_sub_neg,
+        "x - (-y)",
+        "x + y"
+    }
+
+    test_simplify! {
+        infix_mul_double_neg,
+        "(-x) * (-y)",
+        "x * y"
+    }
+
+    test_simplify! {
+        infix_div_double_neg,
+        "(-x) / (-y)",
+        "x / y"
+    }
+
+    test_simplify! {
+        infix_affine_full,
+        "(a1 * x + b1) + (a2 * x + b2)",
+        "(a1 + a2) * x + (b1 + b2)"
+    }
+
+    test_simplify! {
+        infix_affine_coeffs,
+        "(a1 * x) + (a2 * x)",
+        "(a1 + a2) * x"
+    }
+
+    test_simplify! {
+        infix_affine_constants,
+        "(x + b1) + (x + b2)",
+        "(2 * x) + (b1 + b2)"
+    }
+
+    test_simplify! {
+        infix_mul_div_ll,
+        "(y * x) / x",
+        "y"
+    }
+
+    test_simplify! {
+        infix_mul_div_lr,
+        "(x * y) / x",
+        "y"
+    }
+
+    test_simplify! {
+        infix_mul_div_rl,
+        "x / (y * x)",
+        "1 / y"
+    }
+
+    test_simplify! {
+        infix_mul_div_rr,
+        "x / (x * y)",
+        "1 / y"
+    }
+
+    test_simplify! {
+        infix_div_mul_l,
+        "(x / y) * y",
+        "x"
+    }
+
+    test_simplify! {
+        infix_div_mul_r,
+        "y * (x / y)",
+        "x"
+    }
+
+    test_simplify! {
+        docstring_example,
+        "cos(2 * pi) + 2",
+        "3"
+    }
+
+    test_simplify! {
+        issue_208_1,
+        "0 * theta[0]",
+        "0"
+    }
+
+    test_simplify! {
+        issue_208_2,
+        "theta[0] / 1",
+        "theta[0]"
+    }
+
+    test_simplify! {
+        issue_208_3,
+        "(theta[0] * 5) / 5",
+        "theta[0]"
+    }
+
+    test_simplify! {
+        memory_ref,
+        "theta[0]",
+        "theta[0]"
+    }
+
+    test_simplify! {
+        var,
+        "%foo",
+        "%foo"
+    }
+
+    test_simplify! {
+        prefix_neg,
+        "-(-1)",
+        "1"
+    }
+
+    test_simplify! {
+        sub_neg,
+        "2 - (-1)",
+        "3"
+    }
+
+    test_simplify! {
+        neg_sub,
+        "-(1 - 2)",
+        "1"
+    }
+
+    test_simplify! {
+        fold_constant_mul,
+        "2 * pi",
+        "6.283185307179586"
+    }
+
+    test_simplify! {
+        fold_constant_mul_div,
+        "(2 * pi) / 6.283185307179586",
+        "1"
+    }
+
+    test_simplify! {
+        fold_constant_mul_div_2,
+        "2 * (pi / 6.283185307179586)",
+        "1"
+    }
+
+    test_simplify! {
+        fold_constant_mul_div_with_ref,
+        "((a[0] * 2) * pi) / 6.283185307179586",
+        "a[0]"
+    }
+
+    test_simplify! {
+        fold_constant_mul_div_with_ref_2,
+        "a[0] * (2 * pi) / 6.283185307179586",
+        "a[0]"
+    }
+
+    test_simplify! {
+        fold_constant_mul_div_with_ref_3,
+        "a[0] * (2 * (pi / 6.283185307179586))",
+        "a[0]"
+    }
+
+    test_simplify! {
+        affine,
+        "(2 * x[0] + 3) + (4 * x[0] + 5)",
+        "6 * x[0] + 8"
+    }
+
+    test_simplify! {
+        affine_2,
+        "2 * x[0] + (4 * x[0] + 5)",
+        "6 * x[0] + 5"
+    }
+
+    test_simplify! {
+        affine_3,
+        "2 * x[0] + 4 * x[0]",
+        "6 * x[0]"
+    }
+
+    test_simplify! {
+        affine_4,
+        "(x[0] + 3) + (x[0] + 5)",
+        "2 * x[0] + 8"
+    }
+
+    test_simplify! {
+        double_subtraction,
+        "3 - 2 - 1",
+        "0"
+    }
+
+    // TODO doesn't fully simplify in a reasonable amount of recursion
+    // test_simplify!
{ + // the_big_one, + // "(6.283185307179586*(-((-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(-((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+(1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)+1830.4305845069357))+2997.220957806505) - ((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527)+3082.921997445349)+-((-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - 
((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+(-((-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(-((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+(1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)+1830.4305845069357))+2997.220957806505) - ((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527)+3082.921997445349)+(1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - 
((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+-((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527))+3552.7822825370968) - ((-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(-((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+(1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)+1830.4305845069357))+2997.220957806505) - ((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - 
(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527)+3082.921997445349)+(((1827.142690137572+-(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702) - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - ((-3.141592653589793+gamma[0]*-1.3670709112264738)/6.283185307179586+1293.2884354900702)+1830.4305845069357))+(2404.366183299857+(-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+(1292.2571023206997 - (-3.141592653589793+(gamma[0]*-1.3670709112264738)/6.283185307179586)+1293.2884354900702)) - 2473.4667568746527)+3654.308518679512+(-3.141592653589793+gamma[0]*-1.4598346220303238)/6.283185307179586)))+0.4345210910722077)/6.283185307179586", + // "-0.637964476122525*gamma[0] - 14553.9199845484" + // } +} diff --git a/quil-rs/src/hash.rs b/quil-rs/src/hash.rs index 7fbea4e3..2a8221ad 100644 --- a/quil-rs/src/hash.rs +++ b/quil-rs/src/hash.rs @@ -1,15 +1,5 @@ -use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, -}; +use std::hash::{Hash, Hasher}; -/// Hash value helper: turn a hashable thing into a u64. -#[inline] -pub(crate) fn hash_to_u64(t: &T) -> u64 { - let mut s = DefaultHasher::new(); - t.hash(&mut s); - s.finish() -} /// Hashes a f64 using its u64 representation. /// /// Notes: