Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Rework ifs transform #245

Merged
merged 3 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 177 additions & 133 deletions src/transformer/ifs.rs
Original file line number Diff line number Diff line change
@@ -1,131 +1,143 @@
use anyhow::Result;

Check warning on line 1 in src/transformer/ifs.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `anyhow::Result`

Check warning on line 1 in src/transformer/ifs.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `anyhow::Result`
use num_traits::Zero;

use crate::compiler::{Constraint, ConstraintSet, Expression, Intrinsic, Node};

use super::flatten_list;

Check warning on line 6 in src/transformer/ifs.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `super::flatten_list`

Check warning on line 6 in src/transformer/ifs.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `super::flatten_list`

/// Expand if conditions, assuming they are roughly in "top-most"
/// positions. That is, we can have arbitrary nested if `List` and
/// `IfZero` / `IfNotZero` but nothing else. The simplest example is
/// something like this:
/// Lower an expression by eliminating if conditionals. The simplest
/// example is something like this:
///
/// ```
/// (if (vanishes! A) B C)
/// ```
///
/// Which is translated into a list of two constraints:
/// Which is translated into a list of two lowered constraints:
///
/// ```
/// {
/// (1 - NORM(A)) * B
/// A * C
/// }
/// ```
fn do_expand_ifs(e: &mut Node) -> Result<()> {
match e.e_mut() {
fn lower_expr(node: &Node) -> Node {
match node.e() {
Expression::List(es) => {
for e in es.iter_mut() {
do_expand_ifs(e)?;
let mut nes = Vec::new();
// Lower each expression in turn
for e in es {
let le = lower_expr(e);
if !is_zero(Some(&le)) {
nes.push(le);
}
}
// Fold back into a list
Expression::List(nes).into()
}
Expression::Funcall { func, args, .. } => {
for e in args.iter_mut() {
do_expand_ifs(e)?;
}
if matches!(func, Intrinsic::IfZero | Intrinsic::IfNotZero) {
let cond = args[0].clone();
let if_not_zero = matches!(func, Intrinsic::IfNotZero);

// If the condition reduces to a constant, we can determine the result
if let Ok(constant_cond) = cond.pure_eval() {
if if_not_zero {
if !constant_cond.is_zero() {
*e = args[1].clone();
} else {
*e = flatten_list(args.get(2).cloned().unwrap_or_else(Node::zero));
}
} else {
if constant_cond.is_zero() {
*e = args[1].clone();
} else {
*e = flatten_list(args.get(2).cloned().unwrap_or_else(Node::zero));
}
}
} else {
// Construct condition for then branch, and
// condition for else branch.
let conds = {
// Multiplier for if-non-zero branch.
let cond_not_zero = cond.clone();
// Multiplier for if-zero branch.
let cond_zero = Intrinsic::Sub.unchecked_call(&[
Node::one(),
Intrinsic::Normalize.unchecked_call(&[cond.clone()])?,
])?;
// Set ordering based on function itself.
if if_not_zero {
[cond_not_zero, cond_zero]
} else {
[cond_zero, cond_not_zero]
}
};
// Apply condition to body.
let then_else: Node = match (args.get(1), args.get(2)) {
(Some(e), None) => {
let then_cond = conds[0].clone();
Intrinsic::Mul
.unchecked_call(&[then_cond, e.clone()])
.unwrap()
}
(None, Some(e)) => {
let else_cond = conds[1].clone();
Intrinsic::Mul
.unchecked_call(&[else_cond, e.clone()])
.unwrap()
}
(_, _) => unreachable!(),
};
// Finally, replace existing node.
*e = then_else.clone();
};
_ => {
let body = extract_body(node);
// Construct lowered expression
match extract_condition(node) {
None => body,
Some(cond) => {
// Construct cond * body
mul2(Some(cond), Some(body)).unwrap()
}
}
}
_ => (),
}

Ok(())
}

/// Pull `if` conditionals out of nested positions and into top-most
/// positions. Specifically, something like this:
/// Extract the _condition_ of an expression. Every expression can be
/// view as a conditional constraint of the form `if c then e`, where
/// `c` is the condition. This is allowed to return `None` if the
/// body is unconditional. For example, consider this:
///
/// ```lisp
/// (defconstraint test () (+ (if A B) C))
/// ```
///
/// Has the nested `if` raised into the following position:
/// Then, the extracted condition is `A`. Likewise, for this case:
///
/// ```lisp
/// (defconstraint test () (if A (+ B C)))
/// (defconstraint test () (+ (if A (if B C)) D))
/// ```
///
/// The purpose of this is to sanitize the structure of `if`
/// conditions to make their subsequent translation easier.
///
/// **NOTE:** the algorithm implemented here is not particular
/// efficient, and can result in unnecessary cloning of expressions.
fn raise_ifs(mut e: Node) -> Node {
match e.e_mut() {
Expression::Funcall { func, ref mut args } => {
*args = args.iter_mut().map(|a| raise_ifs(a.clone())).collect();
// This is a sanity check, though I'm not sure how it can
// arise.
assert!(args
.iter()
.fold(true, |b, e| b && !matches!(e.e(), Expression::Void)));
//
/// Then, the extracted condition is `A * B`.
fn extract_condition(node: &Node) -> Option<Node> {
match node.e() {
Expression::Funcall { func, args } => {
match func {
Intrinsic::Neg | Intrinsic::Inv | Intrinsic::Normalize => {
assert_eq!(args.len(), 1);
extract_condition(&args[0])
}
Intrinsic::Add
| Intrinsic::Sub
| Intrinsic::Mul
| Intrinsic::VectorAdd
| Intrinsic::VectorSub
| Intrinsic::VectorMul
| Intrinsic::Exp => {
let mut r = None;
// Extract condition for each term
for n in args {
r = mul2(r, extract_condition(n));
}
//
r
}
Intrinsic::IfZero => {
assert_eq!(args.len(), 2);
extract_condition_if(true, &args[0], &args[1])
}
Intrinsic::IfNotZero => {
assert_eq!(args.len(), 2);
extract_condition_if(false, &args[0], &args[1])
}
Intrinsic::Begin => {
// Should be unreachable here since this function should only
// never be called with a list, or a node containing a list.
unreachable!()
}
}
}
Expression::List(_) => {
// Should be unreachable here since this function should only
// never be called with a list, or a node containing a list.
unreachable!()
}
_ => None, // unconditional
}
}

fn extract_condition_if(sign: bool, cond: &Node, body: &Node) -> Option<Node> {
let cc = extract_condition(cond);
let mut cb = extract_body(cond);
// Account for true branch
if sign {
// 1 - X
let args = &[
Node::one(),
Intrinsic::Normalize.unchecked_call(&[cb]).unwrap(),
];
cb = Intrinsic::Sub.unchecked_call(args).unwrap();
}
//
let bc = extract_condition(body);
//
mul3(cc, Some(cb), bc)
}

/// Translate the _body_ of an expression. Every expression can be
/// viewed as a conditional constraint of the form `if c then e`,
/// where `e` is the constraint.
fn extract_body(node: &Node) -> Node {
match node.e() {
Expression::Funcall { func, args } => {
match func {
Intrinsic::IfZero => extract_body(&args[1]),
Intrinsic::IfNotZero => extract_body(&args[1]),
Intrinsic::Neg
| Intrinsic::Inv
| Intrinsic::Normalize
Expand All @@ -136,47 +148,84 @@
| Intrinsic::VectorAdd
| Intrinsic::VectorSub
| Intrinsic::VectorMul => {
for (i, a) in args.iter().enumerate() {
if let Expression::Funcall {
func: func_if @ (Intrinsic::IfZero | Intrinsic::IfNotZero),
args: args_if,
} = a.e()
{
let cond = args_if[0].clone();
// Pull out true-branch:
// (func a b (if cond c d) e)
// ==> (if cond (func a b c e))
let mut then_args = args.clone();
then_args[i] = args_if[1].clone();
let new_then = func.unchecked_call(&then_args).unwrap();
let mut new_args = vec![cond, new_then];
// Pull out false branch (if applicable):
// (func a b (if cond c d) e)
// ==> (if !cond (func a b d e))
if let Some(arg_else) = args_if.get(2).cloned() {
let mut else_args = args.clone();
else_args[i] = arg_else;
new_args.push(func.unchecked_call(&else_args).unwrap());
}
// Repeat this until ifs pulled out
// from all argument positions.
return raise_ifs(
func_if.unchecked_call(&new_args).unwrap().with_type(a.t()),
);
}
let mut bodies = Vec::new();
// Extract bodies from each term
for n in args {
bodies.push(extract_body(n));
}
e
// Combine back together
func.unchecked_call(&bodies).unwrap()
}
Intrinsic::Begin => {
// Should be unreachable here since this function should only
// never be called with a list, or a node containing a list.
unreachable!()
}
Intrinsic::IfZero | Intrinsic::IfNotZero | Intrinsic::Begin => e,
}
}
Expression::List(xs) => {
for x in xs.iter_mut() {
*x = raise_ifs(x.clone());
Expression::List(_) => {
// Should be unreachable here since this function should only
// never be called with a list, or a node containing a list.
unreachable!()
}
_ => node.clone(),
}
}

/// Multiply two optional nodes together, whilst performing some
/// simplistic optimisations when possible.
fn mul2(lhs: Option<Node>, rhs: Option<Node>) -> Option<Node> {
if is_zero(lhs.as_ref()) || is_zero(rhs.as_ref()) {
Some(Node::zero())
} else if is_not_zero(lhs.as_ref()) {
rhs
} else if is_not_zero(rhs.as_ref()) {
lhs
} else {
match (lhs, rhs) {
(None, r) => r,
(l, None) => l,
(Some(l), Some(r)) => Some(Intrinsic::Mul.unchecked_call(&[l, r]).unwrap()),
}
}
}

/// Multiply three optional nodes together.
fn mul3(lhs: Option<Node>, mhs: Option<Node>, rhs: Option<Node>) -> Option<Node> {
mul2(lhs, mul2(mhs, rhs))
}

/// Determine whether a given expression definitely evaluates to `0`.
/// Note that, if this returns `false`, it may still be that the
/// expression will always evaluate to `0` --- but this cannot be
/// easily determined.
fn is_zero(node: Option<&Node>) -> bool {
match node {
Some(n) => {
if let Ok(constant) = n.pure_eval() {
constant.is_zero()
} else {
false
}
}
_ => false,
}
}

/// Determine whether a given expression definitely does not evaluate
/// to `0`. Note that, if this returns `false`, it may still be that
/// the expression will never evaluate to `0` --- but this cannot be
/// easily determined.
fn is_not_zero(node: Option<&Node>) -> bool {
match node {
Some(n) => {
if let Ok(constant) = n.pure_eval() {
!constant.is_zero()
} else {
false
}
e
}
_ => e,
_ => false,
}
}

Expand Down Expand Up @@ -328,14 +377,9 @@
// Raise ifs
for c in cs.constraints.iter_mut() {
if let Constraint::Vanishes { expr, .. } = c {
let nexpr = raise_ifs(*expr.clone());
// Replace old expression with new
let nexpr = lower_expr(expr);
// Done
*expr = Box::new(nexpr);
}
}
for c in cs.constraints.iter_mut() {
if let Constraint::Vanishes { expr: e, .. } = c {
do_expand_ifs(e).unwrap();
}
}
}
11 changes: 11 additions & 0 deletions tests/issue241_c.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(defcolumns X ST)

(defconstraint c0 ()
(if-not-zero ST (vanishes!
(if (is-zero (if (is-zero 1) 0 0))
X
(~and! 1 1)))))

(defconstraint c1 () (if-not-zero ST (vanishes! (if (is-zero 0) X (~and! 1 1)))))

(defconstraint c2 () (if-not-zero ST (vanishes! X)))
10 changes: 10 additions & 0 deletions tests/issue241_d.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
(defcolumns X ST)

(defconstraint c0 ()
(if-not-zero ST (vanishes!
(if (is-zero (if (is-zero 1) 0 0)) X (~and! 1 1)))))

(defconstraint c1 ()
(if-not-zero ST (vanishes! (if (is-zero 0) X (~and! 1 1)))))

(defconstraint c2 () (if-not-zero ST (vanishes! X)))
2 changes: 2 additions & 0 deletions tests/issue241_e.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
(defcolumns ST)
(defconstraint c0 () (is-not-zero! (~and! (if (is-zero 1) 1 1) (if (is-zero 1) 1 1))))
Loading
Loading